use super::{compiled::CompiledGraph, topology::TopologyError, topology::TopologyValidator};
use crate::{
State,
edge::{CompiledEdge, END, Edge, START, TriggerSource},
node::IntoNode,
state::{FromState, IntoState},
};
use indexmap::IndexMap;
use std::collections::HashMap;
use std::sync::Arc;
/// Options accepted by [`StateGraph::compile_with_config`]; the other `compile*`
/// entry points use `CompileConfig::default()` (both lists empty).
#[derive(Clone, Debug, Default)]
pub struct CompileConfig {
/// Node names forwarded to `CompiledGraph::new` as `interrupt_before`.
/// NOTE(review): presumably execution pauses before these nodes run — confirm in the executor.
pub interrupt_before: Vec<String>,
/// Node names forwarded to `CompiledGraph::new` as `interrupt_after`.
/// NOTE(review): presumably execution pauses after these nodes run — confirm in the executor.
pub interrupt_after: Vec<String>,
}
/// Per-node configuration collected by [`StateGraph`] and handed to the
/// topology validator and to `CompiledGraph::new` at compile time.
#[derive(Clone, Debug, Default)]
pub struct NodeMetadata {
/// Deferred-execution flag set through `StateGraph::add_node`.
/// NOTE(review): the meaning of "deferred" is defined by the executor, not here — confirm.
pub defer: bool,
/// Arbitrary user metadata attached to the node.
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// Optional list of declared destination node names.
/// NOTE(review): presumably used to validate/visualise dynamic routing — confirm.
pub destinations: Option<Vec<String>>,
/// Retry policies recorded by `StateGraph::add_node`. The `add_node_with_retry*`
/// builders wrap the node in `RetryingNode` instead and leave this empty.
pub retry_policies: Vec<RetryPolicy>,
/// Name of an error-handler node. No builder method in this module sets it.
pub error_handler: Option<String>,
/// Timeout policies recorded by `StateGraph::add_node`.
pub timeout_policies: Vec<crate::TimeoutPolicy>,
/// Circuit-breaker settings; set by `add_node_with_circuit_breaker` and
/// `add_node_with_retry_and_circuit_breaker`.
pub circuit_breaker: Option<CircuitBreakerConfig>,
/// Name of a fallback node; set by `add_node_with_fallback`.
pub fallback_node: Option<String>,
}
/// How a failing node is retried by [`execute_with_retry`] / [`RetryingNode`].
///
/// Delays start at `initial_interval`, are multiplied by `backoff_factor` after
/// each failed attempt, and never exceed `max_interval`.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Total number of calls, including the first one.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_interval: std::time::Duration,
    /// Multiplier applied to the delay after every failed attempt.
    pub backoff_factor: f64,
    /// Upper bound for any single delay (applied before and after jitter).
    pub max_interval: std::time::Duration,
    /// When set, each delay is scaled by a random factor in `0.75..=1.25`.
    pub jitter: bool,
    /// Optional predicate deciding which errors are retried; when `None`, every
    /// error except cancellation and interrupts is retried.
    #[allow(
        clippy::type_complexity,
        reason = "trait object requires full signature"
    )]
    pub retry_on: Option<Arc<dyn Fn(&crate::JunctureError) -> bool + Send + Sync>>,
}

impl std::fmt::Debug for RetryPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures are not `Debug`; only report whether a predicate is installed.
        let retry_on: Option<&str> = self.retry_on.as_ref().map(|_| "<fn>");
        f.debug_struct("RetryPolicy")
            .field("max_attempts", &self.max_attempts)
            .field("initial_interval", &self.initial_interval)
            .field("backoff_factor", &self.backoff_factor)
            .field("max_interval", &self.max_interval)
            .field("jitter", &self.jitter)
            .field("retry_on", &retry_on)
            .finish()
    }
}

impl Default for RetryPolicy {
    /// Three attempts, a 500 ms initial delay that doubles up to 10 s, jitter
    /// enabled, and no custom retry predicate.
    fn default() -> Self {
        let initial_interval = std::time::Duration::from_millis(500);
        let max_interval = std::time::Duration::from_secs(10);
        Self {
            max_attempts: 3,
            initial_interval,
            backoff_factor: 2.0,
            max_interval,
            jitter: true,
            retry_on: None,
        }
    }
}
/// Static configuration of a per-node circuit breaker.
///
/// The derived `Debug` prints the same fields, in the same order, as a
/// hand-written `debug_struct` would.
#[derive(Clone, Debug)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures (while closed) that trip the breaker open.
    pub failure_threshold: usize,
    /// How long the breaker stays open before allowing a half-open probe.
    pub cooldown_duration: std::time::Duration,
    /// Upper bound on half-open probes; always at least 1.
    pub half_open_max_attempts: usize,
}

impl CircuitBreakerConfig {
    /// Builds a config with a single half-open probe.
    #[must_use]
    pub const fn new(failure_threshold: usize, cooldown_duration: std::time::Duration) -> Self {
        Self {
            half_open_max_attempts: 1,
            failure_threshold,
            cooldown_duration,
        }
    }

    /// Overrides the half-open probe limit; `0` is clamped to `1`.
    #[must_use]
    pub const fn with_half_open_max_attempts(mut self, max_attempts: usize) -> Self {
        // `Ord::max` is not `const`, so clamp by hand.
        self.half_open_max_attempts = if max_attempts == 0 { 1 } else { max_attempts };
        self
    }
}
/// Mutable runtime state of one circuit breaker, driven by a
/// [`CircuitBreakerConfig`].
#[derive(Clone, Debug)]
pub struct CircuitBreakerState {
    state: CircuitState,
    consecutive_failures: usize,
    opened_at: Option<crate::time::Instant>,
    half_open_attempts: usize,
}

/// Position of a breaker in its Closed → Open → HalfOpen cycle.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are allowed; failures are counted.
    Closed,
    /// Calls are rejected until the cooldown elapses.
    Open,
    /// A bounded number of probe calls are allowed.
    HalfOpen,
}

impl CircuitBreakerState {
    /// A closed breaker with no recorded failures.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            state: CircuitState::Closed,
            consecutive_failures: 0,
            opened_at: None,
            half_open_attempts: 0,
        }
    }

    /// Returns whether a call may proceed, moving Open → HalfOpen once the
    /// cooldown has elapsed (that call counts as the first half-open probe).
    ///
    /// In HalfOpen this only checks the probe budget; the counter is advanced by
    /// [`Self::mark_half_open_attempt`].
    /// NOTE(review): if callers also mark the probe that caused the Open → HalfOpen
    /// transition, it is counted twice — confirm caller usage.
    pub fn should_allow(&mut self, config: &CircuitBreakerConfig) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::HalfOpen => self.half_open_attempts < config.half_open_max_attempts,
            CircuitState::Open => {
                let Some(opened) = self.opened_at else {
                    // Open without a timestamp cannot time out; fall back to closed.
                    self.state = CircuitState::Closed;
                    return true;
                };
                let cooled_down = opened.elapsed() >= config.cooldown_duration;
                if cooled_down {
                    self.state = CircuitState::HalfOpen;
                    self.half_open_attempts = 1;
                }
                cooled_down
            }
        }
    }

    /// Counts one more half-open probe; a no-op in any other state.
    pub fn mark_half_open_attempt(&mut self) {
        if matches!(self.state, CircuitState::HalfOpen) {
            self.half_open_attempts += 1;
        }
    }

    /// Resets the breaker to a fresh closed state, unless it is currently open
    /// (successes reported while open are ignored).
    pub fn record_success(&mut self) {
        if self.state != CircuitState::Open {
            *self = Self::new();
        }
    }

    /// Counts a failure and trips the breaker open when the closed-state
    /// threshold is reached or when a half-open probe fails.
    pub fn record_failure(&mut self, config: &CircuitBreakerConfig) {
        self.consecutive_failures += 1;
        let trip = match self.state {
            CircuitState::Closed => self.consecutive_failures >= config.failure_threshold,
            CircuitState::HalfOpen => {
                self.half_open_attempts = 0;
                true
            }
            CircuitState::Open => false,
        };
        if trip {
            self.state = CircuitState::Open;
            self.opened_at = Some(crate::time::Instant::now());
        }
    }

    /// Current breaker position.
    #[must_use]
    pub const fn state(&self) -> &CircuitState {
        &self.state
    }

    /// Failures recorded since the last reset.
    #[must_use]
    pub const fn consecutive_failures(&self) -> usize {
        self.consecutive_failures
    }
}

impl Default for CircuitBreakerState {
    fn default() -> Self {
        CircuitBreakerState::new()
    }
}
/// Failure details handed to an error handler registered with
/// [`StateGraph::add_node_with_error_handler`].
pub struct NodeError<S: State> {
    /// Name of the node that failed.
    pub node: String,
    /// The error the node returned.
    pub error: crate::JunctureError,
    /// Snapshot of the state the node was called with.
    pub state: S,
    /// Attempt number of the failed call (1-based).
    pub attempt: u32,
}

impl<S: State> std::fmt::Debug for NodeError<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("NodeError");
        out.field("node", &self.node);
        out.field("error", &self.error);
        // `S` is not required to implement `Debug`, so print a placeholder.
        out.field("state", &"<state>");
        out.field("attempt", &self.attempt);
        out.finish()
    }
}
pub struct ErrorHandlerNode<S: State> {
inner: Arc<dyn crate::Node<S>>,
#[allow(
clippy::type_complexity,
reason = "trait object requires full signature"
)]
handler: Arc<dyn Fn(NodeError<S>) -> crate::Command<S> + Send + Sync>,
name: String,
}
impl<S: State> std::fmt::Debug for ErrorHandlerNode<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ErrorHandlerNode")
.field("name", &self.name)
.field("inner", &"<node>")
.field("handler", &"<fn>")
.finish()
}
}
impl<S: State> ErrorHandlerNode<S> {
#[allow(
clippy::type_complexity,
reason = "trait object requires full signature"
)]
pub fn new(
inner: Arc<dyn crate::Node<S>>,
handler: Arc<dyn Fn(NodeError<S>) -> crate::Command<S> + Send + Sync>,
) -> Self {
let name = inner.name().to_string();
Self {
inner,
handler,
name,
}
}
}
impl<S: State + Clone> crate::Node<S> for ErrorHandlerNode<S> {
fn call(
&self,
state: &S,
config: &crate::RunnableConfig,
) -> std::pin::Pin<
Box<
dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
+ Send
+ '_,
>,
> {
let state_backup = state.clone();
let result = self.inner.call(state, config);
let handler = Arc::clone(&self.handler);
let node_name = self.name.clone();
Box::pin(async move {
match result.await {
Ok(command) => Ok(command),
Err(error) => {
let node_error = NodeError {
node: node_name,
error,
state: state_backup,
attempt: 1, };
Ok(handler(node_error))
}
}
})
}
fn name(&self) -> &str {
&self.name
}
}
/// Wraps a node and re-invokes it according to a [`RetryPolicy`] via
/// [`execute_with_retry`].
pub struct RetryingNode<S: State> {
    inner: Arc<dyn crate::Node<S>>,
    policy: RetryPolicy,
    name: String,
}

impl<S: State> std::fmt::Debug for RetryingNode<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RetryingNode");
        out.field("name", &self.name);
        out.field("inner", &"<node>");
        out.field("policy", &self.policy);
        out.finish()
    }
}

impl<S: State> RetryingNode<S> {
    /// Wraps `inner`, reusing its name.
    #[must_use]
    pub fn new(inner: Arc<dyn crate::Node<S>>, policy: RetryPolicy) -> Self {
        Self {
            name: inner.name().to_string(),
            inner,
            policy,
        }
    }
}

impl<S: State + Clone> crate::Node<S> for RetryingNode<S> {
    fn call(
        &self,
        state: &S,
        config: &crate::RunnableConfig,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
                + Send
                + '_,
        >,
    > {
        // Every retry needs the same inputs, so the future owns copies of them.
        let Self {
            inner,
            policy,
            name,
        } = self;
        let policy = policy.clone();
        let wrapped = Arc::clone(inner);
        let owned_config = config.clone();
        let label = name.clone();
        let owned_state = state.clone();
        Box::pin(async move {
            execute_with_retry(
                &label,
                &policy,
                |s, cfg| wrapped.call(s, cfg),
                &owned_state,
                &owned_config,
            )
            .await
        })
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// Runs `operation` until it succeeds, fails with a non-retryable error, or the
/// attempt budget of `policy` is spent. Between attempts it sleeps with capped
/// exponential backoff (plus optional jitter).
///
/// `policy.max_attempts` counts the first call, so `3` means one call plus up to
/// two retries. A value of `0` is treated as `1`: the operation always runs at
/// least once instead of failing without ever being attempted.
///
/// # Errors
///
/// Returns the error of the final attempt, or the first error that the policy's
/// retry predicate rejects. By default the predicate rejects cancellation and
/// interrupt errors.
pub async fn execute_with_retry<S, F, Fut>(
    node_name: &str,
    policy: &RetryPolicy,
    operation: F,
    state: &S,
    config: &crate::RunnableConfig,
) -> Result<crate::Command<S>, crate::JunctureError>
where
    S: State,
    F: Fn(&S, &crate::RunnableConfig) -> Fut,
    Fut: std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>,
{
    // A zero budget previously skipped the node entirely and returned a
    // synthetic "exhausted" error; always make at least one call.
    let max_attempts = policy.max_attempts.max(1);
    let mut delay = policy.initial_interval;
    let mut attempt: u32 = 0;
    loop {
        attempt += 1;
        match operation(state, config).await {
            Ok(command) => {
                if attempt > 1 {
                    tracing::debug!(
                        node_name = node_name,
                        attempt = attempt,
                        "node succeeded after retry"
                    );
                }
                return Ok(command);
            }
            Err(error) => {
                // The predicate is consulted on every failure, including the last.
                let should_retry = policy.should_retry(&error);
                if !should_retry || attempt >= max_attempts {
                    return Err(error);
                }
                tracing::warn!(
                    node_name = node_name,
                    attempt = attempt,
                    max_attempts = policy.max_attempts,
                    error = %error,
                    "node failed, will retry"
                );
                let actual_delay = compute_delay(delay, policy.jitter, policy.max_interval);
                tokio::time::sleep(actual_delay).await;
                delay = next_backoff_delay(delay, policy.backoff_factor, policy.max_interval);
            }
        }
    }
}

/// Multiplies `delay` by `factor` and caps the result at `max_interval`.
///
/// `Duration::mul_f64` panics when the product is negative, NaN, or overflows
/// (e.g. a misconfigured `backoff_factor`). This computes the same value with
/// `try_from_secs_f64` and falls back to `max_interval` instead of panicking.
fn next_backoff_delay(
    delay: std::time::Duration,
    factor: f64,
    max_interval: std::time::Duration,
) -> std::time::Duration {
    std::time::Duration::try_from_secs_f64(delay.as_secs_f64() * factor)
        .map_or(max_interval, |next| next.min(max_interval))
}
/// Computes the sleep before the next retry: `base` capped at `max_interval`,
/// then optionally scaled by a random factor in `0.75..=1.25` and capped again.
fn compute_delay(
    base: std::time::Duration,
    jitter: bool,
    max_interval: std::time::Duration,
) -> std::time::Duration {
    let bounded = cap_delay(base, max_interval);
    if jitter {
        let scale: f64 = rand::random_range(0.75..=1.25);
        // Jitter can push the delay above the cap, so clamp once more.
        cap_delay(bounded.mul_f64(scale), max_interval)
    } else {
        bounded
    }
}
/// Clamps `delay` so it never exceeds `max`.
fn cap_delay(delay: std::time::Duration, max: std::time::Duration) -> std::time::Duration {
    if delay > max { max } else { delay }
}
impl RetryPolicy {
    /// Decides whether `error` warrants another attempt.
    ///
    /// A custom `retry_on` predicate has the final say; otherwise everything is
    /// retried except cancellation and interrupts.
    fn should_retry(&self, error: &crate::JunctureError) -> bool {
        match &self.retry_on {
            Some(predicate) => predicate(error),
            None => !(error.is_cancelled() || error.is_interrupt()),
        }
    }
}
/// Wraps a node and fails the call with a run-timeout error when it does not
/// finish within the policy's `run_timeout` (see [`execute_with_timeout`]).
pub struct TimeoutNode<S: State> {
    inner: Arc<dyn crate::Node<S>>,
    policy: crate::TimeoutPolicy,
    name: String,
}

impl<S: State> std::fmt::Debug for TimeoutNode<S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TimeoutNode");
        out.field("name", &self.name);
        out.field("inner", &"<node>");
        out.field("policy", &self.policy);
        out.finish()
    }
}

impl<S: State> TimeoutNode<S> {
    /// Wraps `inner`, reusing its name.
    #[must_use]
    pub fn new(inner: Arc<dyn crate::Node<S>>, policy: crate::TimeoutPolicy) -> Self {
        Self {
            name: inner.name().to_string(),
            inner,
            policy,
        }
    }
}

impl<S: State + Clone> crate::Node<S> for TimeoutNode<S> {
    fn call(
        &self,
        state: &S,
        config: &crate::RunnableConfig,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
                + Send
                + '_,
        >,
    > {
        // The future must not borrow the caller's arguments, so own copies.
        let wrapped = Arc::clone(&self.inner);
        let owned_config = config.clone();
        let label = self.name.clone();
        let limit = self.policy.run_timeout;
        let owned_state = state.clone();
        Box::pin(async move {
            execute_with_timeout(
                &label,
                limit,
                |s, cfg| wrapped.call(s, cfg),
                &owned_state,
                &owned_config,
            )
            .await
        })
    }

    fn name(&self) -> &str {
        &self.name
    }
}
/// Runs `operation` once, bounded by `run_timeout`.
///
/// # Errors
///
/// Returns the operation's own error unchanged. If the deadline passes first,
/// returns a `NodeTimeoutError::RunTimeout` carrying the node name and the
/// timeout in milliseconds (saturated to `u64::MAX`).
pub async fn execute_with_timeout<S, F, Fut>(
    node_name: &str,
    run_timeout: std::time::Duration,
    operation: F,
    state: &S,
    config: &crate::RunnableConfig,
) -> Result<crate::Command<S>, crate::JunctureError>
where
    S: State,
    F: FnOnce(&S, &crate::RunnableConfig) -> Fut,
    Fut: std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>,
{
    let pending = operation(state, config);
    let Ok(outcome) = tokio::time::timeout(run_timeout, pending).await else {
        let timeout_ms = u64::try_from(run_timeout.as_millis()).unwrap_or(u64::MAX);
        return Err(crate::JunctureError::node_timeout(
            crate::error::NodeTimeoutError::RunTimeout {
                node: node_name.to_string(),
                timeout: timeout_ms,
            },
        ));
    };
    outcome
}
/// Builder for a state graph: collects nodes, edges, entry/finish points and
/// mounted subgraphs, then validates and compiles them into a [`CompiledGraph`].
///
/// `I` and `O` are the graph's input and output types. They are converted to and
/// from the state `S`.
pub struct StateGraph<S: State, I: IntoState<S> = S, O: FromState<S> = S> {
    nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,
    edges: Vec<Edge<S>>,
    entry_point: Option<String>,
    finish_points: Vec<String>,
    builder_metadata: IndexMap<String, NodeMetadata>,
    subgraphs: Vec<crate::subgraph::SubgraphMount<S>>,
    _input: std::marker::PhantomData<I>,
    _output: std::marker::PhantomData<O>,
}

impl<S: State, I: IntoState<S>, O: FromState<S>> std::fmt::Debug for StateGraph<S, I, O> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Nodes, edges and subgraphs hold trait objects, so only counts are shown.
        let node_count = self.nodes.len();
        let edge_count = self.edges.len();
        let subgraph_count = self.subgraphs.len();
        f.debug_struct("StateGraph")
            .field("nodes", &format_args!("{node_count} nodes"))
            .field("edges", &format_args!("{edge_count} edges"))
            .field("entry_point", &self.entry_point)
            .field("finish_points", &self.finish_points)
            .field("builder_metadata", &self.builder_metadata)
            .field("subgraphs", &format_args!("{subgraph_count} subgraphs"))
            .finish()
    }
}
impl<S: State, I: IntoState<S>, O: FromState<S>> StateGraph<S, I, O> {
#[must_use]
pub fn new() -> Self {
Self {
nodes: IndexMap::new(),
edges: Vec::new(),
entry_point: None,
finish_points: Vec::new(),
builder_metadata: IndexMap::new(),
subgraphs: Vec::new(),
_input: std::marker::PhantomData,
_output: std::marker::PhantomData,
}
}
#[expect(
clippy::too_many_arguments,
reason = "add_node requires name, node, defer, metadata, destinations, retry_policies, and timeout_policies. All are necessary for the builder pattern."
)]
pub fn add_node(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
defer: bool,
metadata: Option<HashMap<String, serde_json::Value>>,
destinations: Option<Vec<String>>,
retry_policies: Vec<RetryPolicy>,
timeout_policies: Vec<crate::TimeoutPolicy>,
) -> Result<&mut Self, TopologyError> {
let name = name.into();
if self.nodes.contains_key(&name) {
return Err(TopologyError::DuplicateNode { name });
}
let node_arc = node.into_node(&name);
self.nodes.insert(name.clone(), node_arc);
self.builder_metadata.insert(
name,
NodeMetadata {
defer,
metadata,
destinations,
retry_policies,
error_handler: None,
timeout_policies,
circuit_breaker: None,
fallback_node: None,
},
);
Ok(self)
}
pub fn add_node_simple(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
) -> Result<&mut Self, TopologyError> {
self.add_node(name, node, false, None, None, Vec::new(), Vec::new())
}
#[allow(
clippy::type_complexity,
reason = "trait object requires full signature"
)]
pub fn add_node_with_error_handler(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
handler: Arc<dyn Fn(super::builder::NodeError<S>) -> crate::Command<S> + Send + Sync>,
) -> Result<&mut Self, TopologyError>
where
S: Clone,
{
let name_str = name.into();
let inner = node.into_node(&name_str);
let wrapped: Arc<dyn crate::Node<S>> = Arc::new(ErrorHandlerNode::new(inner, handler));
if self.nodes.contains_key(&name_str) {
return Err(TopologyError::DuplicateNode { name: name_str });
}
self.nodes.insert(name_str.clone(), wrapped);
self.builder_metadata
.insert(name_str, NodeMetadata::default());
Ok(self)
}
pub fn add_node_with_retry(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
policy: RetryPolicy,
) -> Result<&mut Self, TopologyError>
where
S: Clone,
{
let name_str = name.into();
let inner = node.into_node(&name_str);
let wrapped: Arc<dyn crate::Node<S>> = Arc::new(RetryingNode::new(inner, policy));
if self.nodes.contains_key(&name_str) {
return Err(TopologyError::DuplicateNode { name: name_str });
}
self.nodes.insert(name_str.clone(), wrapped);
self.builder_metadata
.insert(name_str, NodeMetadata::default());
Ok(self)
}
pub fn add_node_with_circuit_breaker(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
config: CircuitBreakerConfig,
) -> Result<&mut Self, TopologyError> {
let name_str = name.into();
let node_arc = node.into_node(&name_str);
if self.nodes.contains_key(&name_str) {
return Err(TopologyError::DuplicateNode { name: name_str });
}
self.nodes.insert(name_str.clone(), node_arc);
self.builder_metadata.insert(
name_str,
NodeMetadata {
circuit_breaker: Some(config),
..NodeMetadata::default()
},
);
Ok(self)
}
pub fn add_node_with_retry_and_circuit_breaker(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
retry_policy: RetryPolicy,
circuit_breaker_config: CircuitBreakerConfig,
) -> Result<&mut Self, TopologyError>
where
S: Clone,
{
let name_str = name.into();
let inner = node.into_node(&name_str);
let wrapped: Arc<dyn crate::Node<S>> = Arc::new(RetryingNode::new(inner, retry_policy));
if self.nodes.contains_key(&name_str) {
return Err(TopologyError::DuplicateNode { name: name_str });
}
self.nodes.insert(name_str.clone(), wrapped);
self.builder_metadata.insert(
name_str,
NodeMetadata {
circuit_breaker: Some(circuit_breaker_config),
..NodeMetadata::default()
},
);
Ok(self)
}
pub fn add_node_with_fallback(
&mut self,
name: impl Into<String>,
node: impl IntoNode<S>,
fallback: impl Into<String>,
) -> Result<&mut Self, TopologyError> {
let name_str = name.into();
let node_arc = node.into_node(&name_str);
if self.nodes.contains_key(&name_str) {
return Err(TopologyError::DuplicateNode { name: name_str });
}
self.nodes.insert(name_str.clone(), node_arc);
self.builder_metadata.insert(
name_str,
NodeMetadata {
fallback_node: Some(fallback.into()),
..NodeMetadata::default()
},
);
Ok(self)
}
pub fn add_subgraph(
&mut self,
mount: crate::subgraph::SubgraphMount<S>,
) -> Result<&mut Self, TopologyError> {
if self.nodes.contains_key(&mount.name) {
return Err(TopologyError::DuplicateNode {
name: mount.name.clone(),
});
}
let name = mount.name.clone();
let node = Arc::clone(&mount.node);
self.nodes.insert(name.clone(), node);
self.builder_metadata.insert(name, NodeMetadata::default());
self.subgraphs.push(mount);
Ok(self)
}
#[allow(
dead_code,
reason = "fully implemented public API awaiting external consumers"
)]
pub fn add_subgraph_node<Sub>(
&mut self,
name: &str,
subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
) -> Result<&mut Self, TopologyError>
where
Sub: crate::subgraph::StateSubset<S>
+ State
+ Clone
+ serde::Serialize
+ for<'de> serde::Deserialize<'de>,
Sub::Update: serde::Serialize,
S: Clone,
{
let input_map = Arc::new(move |parent: &S| Sub::extract(parent));
let output_map = Arc::new(|_sub_output: &Sub| Sub::map_update(Default::default()));
let node: Arc<dyn crate::Node<S>> = Arc::new(crate::subgraph::SubgraphNode::new(
subgraph,
name.to_string(),
input_map,
output_map,
crate::subgraph::SubgraphConfig::default(),
));
if self.nodes.contains_key(name) {
return Err(TopologyError::DuplicateNode {
name: name.to_string(),
});
}
self.nodes.insert(name.to_string(), node);
self.builder_metadata
.insert(name.to_string(), NodeMetadata::default());
Ok(self)
}
#[allow(
clippy::type_complexity,
reason = "requires type erasure for trait object storage"
)]
pub fn add_subgraph_with_config<Sub>(
&mut self,
name: &str,
subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
input_map: impl Fn(&S) -> Sub + Send + Sync + 'static,
output_map: impl Fn(&Sub) -> S::Update + Send + Sync + 'static,
config: crate::subgraph::SubgraphConfig,
) -> Result<&mut Self, TopologyError>
where
Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
Sub::Update: serde::Serialize,
S: Clone,
{
let input_map_arc = Arc::new(input_map);
let output_map_arc: Arc<dyn Fn(&Sub) -> S::Update + Send + Sync> = Arc::new(output_map);
let node: Arc<dyn crate::Node<S>> = Arc::new(crate::subgraph::SubgraphNode::new(
subgraph,
name.to_string(),
input_map_arc,
output_map_arc,
config,
));
if self.nodes.contains_key(name) {
return Err(TopologyError::DuplicateNode {
name: name.to_string(),
});
}
self.nodes.insert(name.to_string(), node);
self.builder_metadata
.insert(name.to_string(), NodeMetadata::default());
Ok(self)
}
#[allow(
clippy::type_complexity,
reason = "requires type erasure for trait object storage"
)]
pub fn add_subgraph_explicit<Sub>(
&mut self,
name: &str,
subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
input_map: impl Fn(&S) -> Sub + Send + Sync + 'static,
output_map: impl Fn(&Sub) -> S::Update + Send + Sync + 'static,
) -> Result<&mut Self, TopologyError>
where
Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
Sub::Update: serde::Serialize,
S: Clone,
{
self.add_subgraph_with_config(
name,
subgraph,
input_map,
output_map,
crate::subgraph::SubgraphConfig::default(),
)
}
pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) {
self.edges.push(Edge::Fixed {
from: from.into(),
to: to.into(),
});
}
pub fn add_conditional_edges(
&mut self,
from: impl Into<String>,
router: Arc<dyn crate::edge::Router<S>>,
path_map: crate::edge::PathMap,
) {
self.edges.push(Edge::Conditional {
from: from.into(),
router,
path_map,
});
}
pub fn set_entry_point(&mut self, node: impl Into<String>) {
let node = node.into();
self.entry_point = Some(node.clone());
self.edges.push(Edge::Fixed {
from: START.to_string(),
to: node,
});
}
pub fn set_finish_point(&mut self, node: impl Into<String>) {
let node = node.into();
self.finish_points.push(node.clone());
self.edges.push(Edge::Fixed {
from: node,
to: END.to_string(),
});
}
pub fn add_sequence(&mut self, nodes: &[impl AsRef<str>]) -> Result<&mut Self, TopologyError> {
if nodes.is_empty() {
return Ok(self);
}
let node_names: Vec<&str> = nodes.iter().map(std::convert::AsRef::as_ref).collect();
for name in &node_names {
if !self.nodes.contains_key(*name) {
return Err(TopologyError::NodeNotFound {
name: (*name).to_string(),
});
}
}
if self.entry_point.is_none() {
self.set_entry_point(node_names[0]);
}
for window in node_names.windows(2) {
self.add_edge(window[0], window[1]);
}
Ok(self)
}
pub fn validate_keys(&self) -> Result<(), TopologyError> {
for name in self.nodes.keys() {
if name.is_empty() {
return Err(TopologyError::InvalidNodeName {
name: name.clone(),
reason: "node name cannot be empty".to_string(),
});
}
if name.contains(':') || name.contains('/') || name.contains('\\') {
return Err(TopologyError::InvalidNodeName {
name: name.clone(),
reason: "node name cannot contain ':', '/', or '\\'".to_string(),
});
}
}
if let Some(ref entry) = self.entry_point
&& !self.nodes.contains_key(entry)
{
return Err(TopologyError::NodeNotFound {
name: entry.clone(),
});
}
for finish in &self.finish_points {
if !self.nodes.contains_key(finish) {
return Err(TopologyError::NodeNotFound {
name: finish.clone(),
});
}
}
let field_count = S::field_count();
let field_names = S::field_names();
for &idx in S::replace_field_indices() {
if idx >= field_count {
return Err(TopologyError::InvalidFieldReference {
index: idx,
field_count,
field_names,
context: "replace_field_indices".to_string(),
});
}
}
for &idx in S::replace_after_finish_field_indices() {
if idx >= field_count {
return Err(TopologyError::InvalidFieldReference {
index: idx,
field_count,
field_names,
context: "replace_after_finish_field_indices".to_string(),
});
}
}
Ok(())
}
pub fn compile(&self) -> Result<CompiledGraph<S, I, O>, TopologyError> {
self.compile_inner(CompileConfig::default(), None)
}
pub fn compile_with_config(
&self,
config: CompileConfig,
) -> Result<CompiledGraph<S, I, O>, TopologyError> {
self.compile_inner(config, None)
}
pub fn compile_ephemeral(&self) -> Result<CompiledGraph<S, I, O>, TopologyError> {
self.compile_inner(CompileConfig::default(), None)
}
pub fn compile_with_checkpointer(
&self,
checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
) -> Result<CompiledGraph<S, I, O>, TopologyError> {
self.compile_inner(CompileConfig::default(), checkpointer)
}
fn compile_inner(
&self,
config: CompileConfig,
checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
) -> Result<CompiledGraph<S, I, O>, TopologyError> {
TopologyValidator::validate(
&self.nodes,
&self.edges,
self.entry_point.as_deref(),
&self.builder_metadata,
)?;
self.validate_keys()?;
let trigger_table = self.build_trigger_table();
let subgraph_info: Vec<super::compiled::SubgraphInfo> = self
.subgraphs
.iter()
.map(|mount| super::compiled::SubgraphInfo {
name: mount.name.clone(),
persistence: mount.config.persistence,
})
.collect();
Ok(CompiledGraph::new(
self.nodes.clone(),
trigger_table,
self.builder_metadata.clone(),
config.interrupt_before,
config.interrupt_after,
checkpointer,
subgraph_info,
))
}
fn build_trigger_table(&self) -> crate::edge::TriggerTable<S> {
let mut trigger_table = crate::edge::TriggerTable::new();
for edge in &self.edges {
match edge {
Edge::Fixed { from, to } => {
if from == START {
trigger_table
.add_incoming(to.clone(), TriggerSource::Edge { from: from.clone() });
} else if to == END {
} else {
trigger_table
.add_outgoing(from.clone(), CompiledEdge::Fixed { target: to.clone() });
trigger_table
.add_incoming(to.clone(), TriggerSource::Edge { from: from.clone() });
}
}
Edge::Conditional {
from,
path_map,
router,
} => {
let router = Arc::clone(router);
let path_map = path_map.clone();
if from == START {
for target in path_map.iter().map(|(_, v)| v) {
trigger_table.add_incoming(
target.clone(),
TriggerSource::Edge { from: from.clone() },
);
}
} else {
trigger_table.add_outgoing(
from.clone(),
CompiledEdge::Conditional {
router,
path_map: path_map.clone(),
},
);
for target in path_map.iter().map(|(_, v)| v) {
trigger_table.add_incoming(
target.clone(),
TriggerSource::Edge { from: from.clone() },
);
}
}
}
}
}
trigger_table
}
}
impl<S: State, I: IntoState<S>, O: FromState<S>> Default for StateGraph<S, I, O> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Node;
use crate::error::JunctureError;
use crate::node::NodeFnUpdate;
use std::pin::Pin;
type BoxResult<T> = Pin<Box<dyn Future<Output = Result<T, JunctureError>> + Send>>;
#[test]
fn test_state_graph_new() {
let graph: StateGraph<StateDummy> = StateGraph::new();
assert!(graph.nodes.is_empty());
assert!(graph.edges.is_empty());
assert!(graph.entry_point.is_none());
assert!(graph.subgraphs.is_empty());
}
#[test]
fn test_add_node_simple() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
});
graph.add_node_simple("test", node).unwrap();
assert!(graph.nodes.contains_key("test"));
}
#[test]
fn test_add_node_duplicate() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
graph
.add_node_simple(
"test",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
let result = graph.add_node_simple(
"test",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
);
assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
}
#[test]
fn test_set_entry_point() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
graph.set_entry_point("start");
assert_eq!(graph.entry_point, Some("start".to_string()));
assert_eq!(graph.edges.len(), 1);
}
#[test]
fn test_set_finish_point() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
graph.set_finish_point("end");
assert_eq!(graph.finish_points, vec!["end"]);
assert_eq!(graph.edges.len(), 1);
}
#[test]
fn test_add_sequence() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
graph
.add_node_simple(
"a",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
graph
.add_node_simple(
"b",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
graph
.add_node_simple(
"c",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
graph.add_sequence(&["a", "b", "c"]).unwrap();
assert_eq!(graph.entry_point, Some("a".to_string()));
assert_eq!(graph.edges.len(), 3); }
#[test]
fn test_add_sequence_missing_node() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
let result = graph.add_sequence(&["missing"]);
assert!(matches!(result, Err(TopologyError::NodeNotFound { .. })));
}
#[test]
fn test_compile_ephemeral() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
graph
.add_node_simple(
"a",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
graph.set_entry_point("a");
graph.set_finish_point("a");
let compiled = graph.compile_ephemeral().unwrap();
assert_eq!(compiled.nodes().len(), 1);
}
#[test]
fn test_add_subgraph() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
})
.into_node("sub");
let mount = crate::subgraph::SubgraphMount::new(
"my_subgraph",
crate::subgraph::SubgraphConfig::default(),
node,
);
graph.add_subgraph(mount).unwrap();
assert!(graph.nodes.contains_key("my_subgraph"));
assert_eq!(graph.subgraphs.len(), 1);
}
#[test]
fn test_compile_wires_subgraph_info() {
use crate::subgraph::{SubgraphConfig, SubgraphMount, SubgraphPersistence};
let mut graph: StateGraph<StateDummy> = StateGraph::new();
let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
})
.into_node("sub");
let mount = SubgraphMount::new(
"my_subgraph",
SubgraphConfig {
persistence: SubgraphPersistence::PerThread,
},
node,
);
graph.add_subgraph(mount).unwrap();
graph.set_entry_point("my_subgraph");
graph.set_finish_point("my_subgraph");
let compiled = graph.compile().unwrap();
let subgraphs = compiled.get_subgraphs();
assert_eq!(subgraphs.len(), 1);
assert_eq!(subgraphs[0].name, "my_subgraph");
assert_eq!(subgraphs[0].persistence, SubgraphPersistence::PerThread);
}
#[test]
fn test_add_subgraph_duplicate() {
let mut graph: StateGraph<StateDummy> = StateGraph::new();
graph
.add_node_simple(
"my_subgraph",
NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
})
.into_node("sub");
let mount = crate::subgraph::SubgraphMount::new(
"my_subgraph",
crate::subgraph::SubgraphConfig::default(),
node,
);
let result = graph.add_subgraph(mount);
assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
}
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
struct ChildState {
value: i32,
}
impl crate::State for ChildState {
type Update = ChildStateUpdate;
type FieldVersions = crate::state::FieldVersions;
fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
if let Some(v) = update.value {
self.value = v;
}
crate::FieldsChanged(0)
}
fn reset_ephemeral(&mut self) {}
}
#[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
struct ChildStateUpdate {
value: Option<i32>,
}
#[test]
fn test_add_subgraph_with_config_registers_node() {
let mut child_graph: StateGraph<ChildState> = StateGraph::new();
child_graph
.add_node_simple(
"child_node",
crate::node::NodeFnUpdate(|_s: &ChildState| -> BoxResult<_> {
Box::pin(async move { Ok(ChildStateUpdate { value: Some(42) }) })
}),
)
.unwrap();
child_graph.set_entry_point("child_node");
child_graph.set_finish_point("child_node");
let compiled_child = Arc::new(child_graph.compile().unwrap());
let mut parent_graph: StateGraph<StateDummy> = StateGraph::new();
parent_graph
.add_subgraph_with_config(
"explicit_subgraph",
compiled_child,
|_parent: &StateDummy| ChildState { value: 0 },
|_child: &ChildState| StateDummyUpdate,
crate::subgraph::SubgraphConfig::default(),
)
.unwrap();
assert!(parent_graph.nodes.contains_key("explicit_subgraph"));
}
#[test]
fn test_add_subgraph_with_config_duplicate_node() {
let mut child_graph: StateGraph<ChildState> = StateGraph::new();
child_graph
.add_node_simple(
"child_node",
crate::node::NodeFnUpdate(|_s: &ChildState| -> BoxResult<_> {
Box::pin(async move { Ok(ChildStateUpdate { value: Some(42) }) })
}),
)
.unwrap();
child_graph.set_entry_point("child_node");
child_graph.set_finish_point("child_node");
let compiled_child = Arc::new(child_graph.compile().unwrap());
let mut parent_graph: StateGraph<StateDummy> = StateGraph::new();
parent_graph
.add_node_simple(
"explicit_subgraph",
crate::node::NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
Box::pin(async move { Ok(StateDummyUpdate) })
}),
)
.unwrap();
let result = parent_graph.add_subgraph_with_config(
"explicit_subgraph",
compiled_child,
|_parent: &StateDummy| ChildState { value: 0 },
|_child: &ChildState| StateDummyUpdate,
crate::subgraph::SubgraphConfig::default(),
);
assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
}
/// `add_node_with_retry` registers the node under its name.
#[test]
fn test_add_node_with_retry() {
    use std::time::Duration;

    // Only the initial interval deviates from the defaults
    // (3 attempts, factor 2.0, 10s cap, jitter on, no predicate).
    let retry = RetryPolicy {
        initial_interval: Duration::from_millis(100),
        ..RetryPolicy::default()
    };
    let body = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
        Box::pin(async move { Ok(StateDummyUpdate) })
    });

    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    graph.add_node_with_retry("retry_node", body, retry).unwrap();

    assert!(graph.nodes.contains_key("retry_node"));
}
/// `add_node_with_error_handler` registers the node under its name.
#[test]
fn test_add_node_with_error_handler() {
    let body = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
        Box::pin(async move { Ok(StateDummyUpdate) })
    });
    // The handler simply ends the run whatever the error.
    let on_error = Arc::new(|_err: NodeError<StateDummy>| crate::Command::end());

    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    graph
        .add_node_with_error_handler("error_handler_node", body, on_error)
        .unwrap();

    assert!(graph.nodes.contains_key("error_handler_node"));
}
/// `StateGraph::default()` yields a graph with no nodes and no subgraphs.
#[test]
fn test_default_implementation() {
    let fresh = StateGraph::<StateDummy>::default();
    assert!(fresh.nodes.is_empty());
    assert!(fresh.subgraphs.is_empty());
}
/// A brand-new graph has nothing to reject, so key validation succeeds.
#[test]
fn test_validate_keys_empty_graph() {
    StateGraph::<StateDummy>::new().validate_keys().unwrap();
}
/// Plain alphanumeric/underscore node names pass key validation.
#[test]
fn test_validate_keys_valid_nodes() {
    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    for name in ["node_a", "node_b"] {
        graph
            .add_node_simple(
                name,
                NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                    Box::pin(async move { Ok(StateDummyUpdate) })
                }),
            )
            .unwrap();
    }
    graph.validate_keys().unwrap();
}
/// Key validation on a graph with no nodes succeeds.
///
/// NOTE(review): despite its name, this test never registers a node with an
/// empty name, so it currently duplicates `test_validate_keys_empty_graph`.
/// Whether an empty name should be rejected by `add_node_simple` or by
/// `validate_keys` is not visible here — TODO confirm and exercise that case.
#[test]
fn test_validate_keys_empty_node_name() {
    let graph = StateGraph::<StateDummy>::new();
    graph.validate_keys().unwrap();
}
/// A node name containing `:` is accepted by `add_node_simple` but must be
/// rejected by `validate_keys` with `InvalidNodeName`.
#[test]
fn test_validate_keys_reserved_characters() {
    let body = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
        Box::pin(async move { Ok(StateDummyUpdate) })
    });
    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    graph.add_node_simple("node:test", body).unwrap();

    let verdict = graph.validate_keys();
    assert!(matches!(verdict, Err(TopologyError::InvalidNodeName { .. })));
}
/// An entry point naming a node that was never added is reported as
/// `NodeNotFound`.
#[test]
fn test_validate_keys_entry_point_not_found() {
    let mut graph = StateGraph::<StateDummy>::new();
    graph.set_entry_point("nonexistent");
    let verdict = graph.validate_keys();
    assert!(matches!(verdict, Err(TopologyError::NodeNotFound { .. })));
}
/// A finish point naming a missing node is reported as `NodeNotFound`, even
/// when other nodes exist.
#[test]
fn test_validate_keys_finish_point_not_found() {
    let body = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
        Box::pin(async move { Ok(StateDummyUpdate) })
    });
    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    graph.add_node_simple("node_a", body).unwrap();
    graph.set_finish_point("nonexistent");

    let verdict = graph.validate_keys();
    assert!(matches!(verdict, Err(TopologyError::NodeNotFound { .. })));
}
/// Entry and finish points that reference registered nodes pass validation.
#[test]
fn test_validate_keys_with_valid_entry_and_finish() {
    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    for name in ["start", "end"] {
        graph
            .add_node_simple(
                name,
                NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                    Box::pin(async move { Ok(StateDummyUpdate) })
                }),
            )
            .unwrap();
    }
    graph.set_entry_point("start");
    graph.set_finish_point("end");
    graph.validate_keys().unwrap();
}
/// `validate_keys` must reject a state whose `replace_field_indices` points
/// past `field_count()`. The error must report the offending index, the field
/// count, and which index list it came from.
#[test]
fn test_validate_keys_catches_invalid_replace_field_index() {
    let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");
    // One exhaustive match checks the variant and destructures it in one step.
    // A wrong outcome fails with the actual value instead of a bare
    // `matches!` assertion.
    match graph.validate_keys() {
        Err(TopologyError::InvalidFieldReference {
            index,
            field_count,
            context,
            ..
        }) => {
            assert_eq!(index, 5);
            assert_eq!(field_count, 2);
            assert_eq!(context, "replace_field_indices");
        }
        Err(other) => panic!("expected InvalidFieldReference, got {other:?}"),
        Ok(_) => panic!("expected InvalidFieldReference, got Ok"),
    }
}
/// `validate_keys` must reject a state whose
/// `replace_after_finish_field_indices` points past `field_count()`. The
/// error must report the offending index, the field count and the source list.
#[test]
fn test_validate_keys_catches_invalid_replace_after_finish_field_index() {
    let mut graph: StateGraph<StateWithBadAfterFinishIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadAfterFinishIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadAfterFinishIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");
    // Single exhaustive match: the variant check and the field checks cannot
    // drift apart, and an unexpected outcome is printed.
    match graph.validate_keys() {
        Err(TopologyError::InvalidFieldReference {
            index,
            field_count,
            context,
            ..
        }) => {
            assert_eq!(index, 99);
            assert_eq!(field_count, 2);
            assert_eq!(context, "replace_after_finish_field_indices");
        }
        Err(other) => panic!("expected InvalidFieldReference, got {other:?}"),
        Ok(_) => panic!("expected InvalidFieldReference, got Ok"),
    }
}
/// Two-field test state whose `State` impl advertises an out-of-range
/// `replace_field_indices` entry.
#[derive(Debug, Clone, Default)]
struct StateWithBadReplaceIndex {
    /// Field 0.
    a: i32,
    /// Field 1.
    b: i32,
}
/// Partial update for [`StateWithBadReplaceIndex`]: each `Some` overwrites the
/// matching field when applied.
#[derive(Debug, Clone, Default)]
struct StateWithBadReplaceIndexUpdate {
    a: Option<i32>,
    b: Option<i32>,
}
impl crate::State for StateWithBadReplaceIndex {
type Update = StateWithBadReplaceIndexUpdate;
type FieldVersions = crate::state::FieldVersions;
fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
let mut changed = crate::FieldsChanged::default();
if let Some(v) = update.a {
self.a = v;
changed.set_field(0);
}
if let Some(v) = update.b {
self.b = v;
changed.set_field(1);
}
changed
}
fn reset_ephemeral(&mut self) {}
fn field_count() -> usize {
2
}
fn field_names() -> &'static [&'static str] {
&["a", "b"]
}
fn replace_field_indices() -> &'static [usize] {
&[5] }
}
/// Two-field test state whose `State` impl advertises an out-of-range
/// `replace_after_finish_field_indices` entry.
#[derive(Debug, Clone, Default)]
struct StateWithBadAfterFinishIndex {
    /// Field 0.
    x: String,
    /// Field 1.
    y: String,
}
/// Partial update for [`StateWithBadAfterFinishIndex`]: each `Some` replaces
/// the matching field when applied.
#[derive(Debug, Clone, Default)]
struct StateWithBadAfterFinishIndexUpdate {
    x: Option<String>,
    y: Option<String>,
}
impl crate::State for StateWithBadAfterFinishIndex {
type Update = StateWithBadAfterFinishIndexUpdate;
type FieldVersions = crate::state::FieldVersions;
fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
let mut changed = crate::FieldsChanged::default();
if let Some(v) = update.x {
self.x = v;
changed.set_field(0);
}
if let Some(v) = update.y {
self.y = v;
changed.set_field(1);
}
changed
}
fn reset_ephemeral(&mut self) {}
fn field_count() -> usize {
2
}
fn field_names() -> &'static [&'static str] {
&["x", "y"]
}
fn replace_after_finish_field_indices() -> &'static [usize] {
&[99] }
}
/// `compile` must run key validation and surface the same
/// `InvalidFieldReference` for a bad `replace_field_indices` entry.
#[test]
fn test_compile_calls_validate_keys_and_catches_invalid_replace_field_index() {
    let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");
    // Exhaustive match instead of `assert!(matches!)` + `if let`.
    // Unexpected errors are printed, and an unexpected success fails explicitly.
    match graph.compile() {
        Err(TopologyError::InvalidFieldReference {
            index,
            field_count,
            context,
            ..
        }) => {
            assert_eq!(index, 5);
            assert_eq!(field_count, 2);
            assert_eq!(context, "replace_field_indices");
        }
        Err(other) => panic!("expected InvalidFieldReference, got {other:?}"),
        Ok(_) => panic!("expected InvalidFieldReference, got Ok"),
    }
}
/// `compile` must run key validation and surface `InvalidFieldReference` for
/// a bad `replace_after_finish_field_indices` entry.
#[test]
fn test_compile_calls_validate_keys_and_catches_invalid_replace_after_finish_field_index() {
    let mut graph: StateGraph<StateWithBadAfterFinishIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadAfterFinishIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadAfterFinishIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");
    // Exhaustive match: one place checks the variant and its fields, and any
    // other outcome is reported verbatim.
    match graph.compile() {
        Err(TopologyError::InvalidFieldReference {
            index,
            field_count,
            context,
            ..
        }) => {
            assert_eq!(index, 99);
            assert_eq!(field_count, 2);
            assert_eq!(context, "replace_after_finish_field_indices");
        }
        Err(other) => panic!("expected InvalidFieldReference, got {other:?}"),
        Ok(_) => panic!("expected InvalidFieldReference, got Ok"),
    }
}
/// `validate_keys` and `compile` must agree: both reject the bad
/// `replace_field_indices` entry and report the same offending index.
#[test]
fn test_validate_keys_validates_reducer_indices_during_compile() {
    let body = NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
        Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
    });
    let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
    graph.add_node_simple("process", body).unwrap();
    graph.set_entry_point("process");
    graph.set_finish_point("process");

    let from_validate = graph.validate_keys();
    assert!(
        from_validate.is_err(),
        "validate_keys should detect invalid field index"
    );
    let from_compile = graph.compile();
    assert!(
        from_compile.is_err(),
        "compile should detect invalid field index"
    );

    let (
        Err(TopologyError::InvalidFieldReference {
            index: validate_index,
            ..
        }),
        Err(TopologyError::InvalidFieldReference {
            index: compile_index,
            ..
        }),
    ) = (from_validate, from_compile)
    else {
        panic!("Both methods should return InvalidFieldReference error");
    };
    assert_eq!(
        validate_index, compile_index,
        "Both methods should report the same invalid index"
    );
}
/// Field-less state used wherever a test only needs *some* `State` type.
#[derive(Debug, Clone, Default)]
struct StateDummy;
impl crate::State for StateDummy {
    type FieldVersions = crate::state::FieldVersions;
    type Update = StateDummyUpdate;

    /// Nothing is resettable on a field-less state.
    fn reset_ephemeral(&mut self) {}

    /// Ignores the (empty) update and reports that no field changed.
    fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
        crate::FieldsChanged(0)
    }
}
/// Empty update type paired with [`StateDummy`].
#[derive(Debug, Clone, Default)]
struct StateDummyUpdate;
#[tokio::test]
async fn test_execute_with_retry_succeeds_first_attempt() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let config = crate::RunnableConfig::new();
let result = execute_with_retry(
"test_node",
&policy,
|_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
Box::pin(async { Ok(crate::Command::end()) })
},
&StateDummy,
&config,
)
.await;
result.unwrap();
}
#[tokio::test]
async fn test_execute_with_retry_succeeds_after_retries() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let config = crate::RunnableConfig::new();
let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let result = execute_with_retry(
"test_node",
&policy,
move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
let counter = Arc::clone(&attempt_clone);
Box::pin(async move {
let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n < 2 {
Err(crate::JunctureError::execution("transient failure"))
} else {
Ok(crate::Command::end())
}
})
},
&StateDummy,
&config,
)
.await;
result.unwrap();
assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
}
#[tokio::test]
async fn test_execute_with_retry_exhausts_attempts() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let config = crate::RunnableConfig::new();
let result = execute_with_retry(
"test_node",
&policy,
|_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
Box::pin(async { Err(crate::JunctureError::execution("always fails")) })
},
&StateDummy,
&config,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_execution());
}
#[tokio::test]
async fn test_execute_with_retry_does_not_retry_cancelled() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let config = crate::RunnableConfig::new();
let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let result = execute_with_retry(
"test_node",
&policy,
move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
let counter = Arc::clone(&attempt_clone);
Box::pin(async move {
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Err(crate::JunctureError::cancelled())
})
},
&StateDummy,
&config,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_cancelled());
assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_execute_with_retry_does_not_retry_interrupt() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let config = crate::RunnableConfig::new();
let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let result = execute_with_retry(
"test_node",
&policy,
move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
let counter = Arc::clone(&attempt_clone);
Box::pin(async move {
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Err(crate::JunctureError::interrupt("user input needed"))
})
},
&StateDummy,
&config,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_interrupt());
assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_execute_with_retry_custom_retry_on_predicate() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
};
let config = crate::RunnableConfig::new();
let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let result = execute_with_retry(
"test_node",
&policy,
move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
let counter = Arc::clone(&attempt_clone);
Box::pin(async move {
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Err(crate::JunctureError::execution("not a timeout"))
})
},
&StateDummy,
&config,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().is_execution());
assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
}
#[tokio::test]
async fn test_execute_with_retry_custom_predicate_allows_retry() {
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
};
let config = crate::RunnableConfig::new();
let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let attempt_clone = Arc::clone(&attempt_count);
let result = execute_with_retry(
"test_node",
&policy,
move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
let counter = Arc::clone(&attempt_clone);
Box::pin(async move {
let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n < 2 {
Err(crate::JunctureError::timeout("timed out"))
} else {
Ok(crate::Command::end())
}
})
},
&StateDummy,
&config,
)
.await;
result.unwrap();
assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
}
/// Without jitter, a delay below the cap is returned unchanged.
#[test]
fn test_compute_delay_no_jitter() {
    use std::time::Duration;

    let delay = compute_delay(Duration::from_millis(100), false, Duration::from_secs(10));
    assert_eq!(delay, Duration::from_millis(100));
}
/// Without jitter, a delay above the cap is clamped to the cap.
#[test]
fn test_compute_delay_caps_at_max() {
    use std::time::Duration;

    let cap = Duration::from_secs(10);
    let delay = compute_delay(Duration::from_secs(20), false, cap);
    assert_eq!(delay, cap);
}
/// With jitter, a 100ms base delay must always land in [75ms, 125ms]
/// (checked over 100 samples).
#[test]
fn test_compute_delay_with_jitter_stays_within_range() {
    use std::time::Duration;

    let base = Duration::from_millis(100);
    let cap = Duration::from_secs(10);
    (0..100).for_each(|_| {
        let millis = compute_delay(base, true, cap).as_secs_f64() * 1000.0;
        assert!(
            (75.0..=125.0).contains(&millis),
            "jittered delay {millis}ms outside expected range [75, 125]"
        );
    });
}
/// Jitter must never push the delay above the cap, even when the base delay
/// already exceeds it.
#[test]
fn test_compute_delay_jitter_capped_by_max() {
    use std::time::Duration;

    let base = Duration::from_millis(100);
    let max = Duration::from_millis(50);
    (0..100).for_each(|_| {
        let result = compute_delay(base, true, max);
        assert!(
            result <= max,
            "jittered delay {result:?} exceeded max {max:?}",
        );
    });
}
/// `cap_delay` returns the smaller of the delay and the cap.
#[test]
fn test_cap_delay_returns_min() {
    use std::time::Duration;

    let cap = Duration::from_secs(10);
    let short = Duration::from_secs(5);
    let long = Duration::from_secs(15);
    assert_eq!(cap_delay(short, cap), short);
    assert_eq!(cap_delay(long, cap), cap);
}
/// The default policy retries ordinary execution errors.
#[test]
fn test_retry_policy_should_retry_default_allows_execution_errors() {
    let failure = crate::JunctureError::execution("something went wrong");
    assert!(RetryPolicy::default().should_retry(&failure));
}
/// The default policy never retries a cancellation.
#[test]
fn test_retry_policy_should_retry_default_blocks_cancelled() {
    let failure = crate::JunctureError::cancelled();
    assert!(!RetryPolicy::default().should_retry(&failure));
}
/// The default policy never retries an interrupt.
#[test]
fn test_retry_policy_should_retry_default_blocks_interrupt() {
    let failure = crate::JunctureError::interrupt("waiting for user");
    assert!(!RetryPolicy::default().should_retry(&failure));
}
/// A custom `retry_on` predicate decides which errors are retryable.
#[test]
fn test_retry_policy_should_retry_custom_predicate() {
    use std::time::Duration;

    let policy = RetryPolicy {
        initial_interval: Duration::from_millis(100),
        jitter: false,
        retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
        ..RetryPolicy::default()
    };
    let timeout = crate::JunctureError::timeout("slow");
    let other = crate::JunctureError::execution("not timeout");
    assert!(policy.should_retry(&timeout));
    assert!(!policy.should_retry(&other));
}
#[tokio::test]
async fn test_retrying_node_delegates_to_execute_with_retry() {
use crate::node::NodeFnCommand;
let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let count_clone = Arc::clone(&call_count);
let inner: Arc<dyn crate::Node<StateDummy>> =
NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
let counter = Arc::clone(&count_clone);
Box::pin(async move {
let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n == 0 {
Err(crate::JunctureError::execution("first try fails"))
} else {
Ok(crate::Command::end())
}
})
})
.into_node("inner");
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let retrying_node = RetryingNode::new(inner, policy);
let config = crate::RunnableConfig::new();
let result = retrying_node.call(&StateDummy, &config).await;
result.unwrap();
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 2);
}
#[tokio::test]
async fn test_retrying_node_respects_max_attempts() {
use crate::node::NodeFnCommand;
let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let count_clone = Arc::clone(&call_count);
let inner: Arc<dyn crate::Node<StateDummy>> =
NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
let counter = Arc::clone(&count_clone);
Box::pin(async move {
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
Err(crate::JunctureError::execution("always fails"))
})
})
.into_node("inner");
let policy = RetryPolicy {
max_attempts: 5,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: false,
retry_on: None,
};
let retrying_node = RetryingNode::new(inner, policy);
let config = crate::RunnableConfig::new();
let result = retrying_node.call(&StateDummy, &config).await;
let err = result.unwrap_err();
assert!(err.is_execution());
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 5);
}
#[tokio::test]
async fn test_retrying_node_with_jitter_enabled() {
use crate::node::NodeFnCommand;
let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
let count_clone = Arc::clone(&call_count);
let inner: Arc<dyn crate::Node<StateDummy>> =
NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
let counter = Arc::clone(&count_clone);
Box::pin(async move {
let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
if n < 2 {
Err(crate::JunctureError::execution("retry me"))
} else {
Ok(crate::Command::end())
}
})
})
.into_node("inner");
let policy = RetryPolicy {
max_attempts: 3,
initial_interval: std::time::Duration::from_millis(1),
backoff_factor: 2.0,
max_interval: std::time::Duration::from_secs(1),
jitter: true,
retry_on: None,
};
let retrying_node = RetryingNode::new(inner, policy);
let config = crate::RunnableConfig::new();
let result = retrying_node.call(&StateDummy, &config).await;
result.unwrap();
assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 3);
}
/// A huge `backoff_factor` would make later delays very long if uncapped
/// (presumably 50ms × 100 on the second retry). `max_interval` (80ms) must
/// keep the whole run short, and the node must still reach its third, successful call.
#[tokio::test]
async fn test_execute_with_retry_max_interval_capping() {
    use std::sync::atomic::{AtomicU32, Ordering};

    let policy = RetryPolicy {
        max_attempts: 3,
        initial_interval: std::time::Duration::from_millis(50),
        backoff_factor: 100.0,
        max_interval: std::time::Duration::from_millis(80),
        jitter: false,
        retry_on: None,
    };
    let config = crate::RunnableConfig::new();
    let start = crate::time::Instant::now();
    let attempt_count = Arc::new(AtomicU32::new(0));
    let attempt_clone = Arc::clone(&attempt_count);
    let result = execute_with_retry(
        "test_node",
        &policy,
        move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
            let counter = Arc::clone(&attempt_clone);
            Box::pin(async move {
                let n = counter.fetch_add(1, Ordering::Relaxed);
                if n < 2 {
                    Err(crate::JunctureError::execution("fail"))
                } else {
                    Ok(crate::Command::end())
                }
            })
        },
        &StateDummy,
        &config,
    )
    .await;
    let elapsed = start.elapsed();
    result.unwrap();
    // Previously the counter was never checked, so the test could pass even if
    // fewer attempts (and hence fewer capped sleeps) were performed.
    assert_eq!(attempt_count.load(Ordering::Relaxed), 3);
    assert!(
        elapsed < std::time::Duration::from_secs(2),
        "max_interval capping should prevent very long waits, elapsed: {elapsed:?}"
    );
}
/// A node that finishes well inside its 10s budget returns `Ok`.
#[tokio::test]
async fn test_execute_with_timeout_succeeds_within_limit() {
    use std::time::Duration;

    let config = crate::RunnableConfig::new();
    let budget = Duration::from_secs(10);
    let outcome = execute_with_timeout(
        "test_node",
        budget,
        |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
            Box::pin(async { Ok(crate::Command::end()) })
        },
        &StateDummy,
        &config,
    )
    .await;
    outcome.unwrap();
}
/// A node sleeping far beyond a 10ms budget yields a node-timeout error.
#[tokio::test]
async fn test_execute_with_timeout_fires_on_slow_node() {
    use std::time::Duration;

    let config = crate::RunnableConfig::new();
    let budget = Duration::from_millis(10);
    let outcome = execute_with_timeout(
        "slow_node",
        budget,
        |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(60)).await;
                Ok(crate::Command::end())
            })
        },
        &StateDummy,
        &config,
    )
    .await;
    assert!(outcome.unwrap_err().is_node_timeout());
}
/// An error raised by the node itself is returned unchanged and is not
/// reported as a timeout.
#[tokio::test]
async fn test_execute_with_timeout_passes_through_inner_error() {
    use std::time::Duration;

    let config = crate::RunnableConfig::new();
    let budget = Duration::from_secs(10);
    let outcome = execute_with_timeout(
        "failing_node",
        budget,
        |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
            Box::pin(async { Err(crate::JunctureError::execution("inner failure")) })
        },
        &StateDummy,
        &config,
    )
    .await;
    let failure = outcome.unwrap_err();
    assert!(failure.is_execution());
    assert!(!failure.is_node_timeout());
}
/// `TimeoutNode` with a generous run timeout calls the inner node once and
/// returns its success.
#[tokio::test]
async fn test_timeout_node_wrapper_integration() {
    use crate::node::NodeFnCommand;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::Duration;

    let calls = Arc::new(AtomicU32::new(0));
    let calls_for_node = Arc::clone(&calls);
    let inner: Arc<dyn crate::Node<StateDummy>> =
        NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
            let counter = Arc::clone(&calls_for_node);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Ok(crate::Command::end())
            })
        })
        .into_node("inner");

    let policy = crate::TimeoutPolicy::new().with_run_timeout(Duration::from_secs(10));
    let wrapped = TimeoutNode::new(inner, policy);
    let config = crate::RunnableConfig::new();
    wrapped.call(&StateDummy, &config).await.unwrap();

    assert_eq!(calls.load(Ordering::Relaxed), 1);
}
/// `TimeoutNode` turns an inner node that outlives its 10ms run timeout into a
/// node-timeout error.
#[tokio::test]
async fn test_timeout_node_fires_on_exceeded_duration() {
    use crate::node::NodeFnCommand;
    use std::time::Duration;

    let inner: Arc<dyn crate::Node<StateDummy>> =
        NodeFnCommand(|_s: &StateDummy| -> BoxResult<_> {
            Box::pin(async {
                tokio::time::sleep(Duration::from_secs(60)).await;
                Ok(crate::Command::end())
            })
        })
        .into_node("inner");

    let policy = crate::TimeoutPolicy::new().with_run_timeout(Duration::from_millis(10));
    let wrapped = TimeoutNode::new(inner, policy);
    let config = crate::RunnableConfig::new();
    let failure = wrapped.call(&StateDummy, &config).await.unwrap_err();

    assert!(failure.is_node_timeout());
}
/// `TimeoutNode` forwards the inner node's own error and does not relabel it
/// as a timeout.
#[tokio::test]
async fn test_timeout_node_passes_through_inner_error() {
    use crate::node::NodeFnCommand;
    use std::time::Duration;

    let inner: Arc<dyn crate::Node<StateDummy>> =
        NodeFnCommand(|_s: &StateDummy| -> BoxResult<_> {
            Box::pin(async { Err(crate::JunctureError::execution("node failure")) })
        })
        .into_node("inner");

    let policy = crate::TimeoutPolicy::new().with_run_timeout(Duration::from_secs(10));
    let wrapped = TimeoutNode::new(inner, policy);
    let config = crate::RunnableConfig::new();
    let failure = wrapped.call(&StateDummy, &config).await.unwrap_err();

    assert!(failure.is_execution());
    assert!(!failure.is_node_timeout());
}
/// `CircuitBreakerConfig::new` stores its arguments and defaults the
/// half-open probe budget to 1.
#[test]
fn circuit_breaker_config_new() {
    use std::time::Duration;

    let cooldown = Duration::from_secs(30);
    let cfg = CircuitBreakerConfig::new(5, cooldown);
    assert_eq!(cfg.failure_threshold, 5);
    assert_eq!(cfg.cooldown_duration, cooldown);
    assert_eq!(cfg.half_open_max_attempts, 1);
}
/// `with_half_open_max_attempts` overrides the half-open probe budget.
#[test]
fn circuit_breaker_config_with_half_open_max_attempts() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let cfg = cfg.with_half_open_max_attempts(3);
    assert_eq!(cfg.half_open_max_attempts, 3);
}
/// The `Debug` output names the type and includes the threshold value.
#[test]
fn circuit_breaker_config_debug() {
    let rendered = format!(
        "{:?}",
        CircuitBreakerConfig::new(5, std::time::Duration::from_secs(30))
    );
    assert!(rendered.contains("CircuitBreakerConfig"));
    assert!(rendered.contains('5'));
}
/// A fresh breaker starts closed with no recorded failures.
#[test]
fn circuit_breaker_state_new_is_closed() {
    let breaker = CircuitBreakerState::new();
    assert_eq!(breaker.consecutive_failures(), 0);
    assert_eq!(*breaker.state(), CircuitState::Closed);
}
/// `Default` also yields a closed breaker.
#[test]
fn circuit_breaker_state_default_is_closed() {
    let breaker: CircuitBreakerState = Default::default();
    assert_eq!(*breaker.state(), CircuitState::Closed);
}
/// A closed breaker lets calls through.
#[test]
fn circuit_breaker_closed_allows_execution() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let mut breaker = CircuitBreakerState::new();
    let allowed = breaker.should_allow(&cfg);
    assert!(allowed);
}
/// With a threshold of 3, the first two failures keep the breaker closed and
/// the third opens it and blocks further calls.
#[test]
fn circuit_breaker_opens_after_threshold() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let mut breaker = CircuitBreakerState::new();

    for _ in 0..2 {
        breaker.record_failure(&cfg);
        assert_eq!(*breaker.state(), CircuitState::Closed);
        assert!(breaker.should_allow(&cfg));
    }

    breaker.record_failure(&cfg);
    assert_eq!(*breaker.state(), CircuitState::Open);
    assert!(!breaker.should_allow(&cfg));
}
/// A success below the threshold clears the failure streak and keeps the
/// breaker closed.
#[test]
fn circuit_breaker_resets_on_success() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let mut breaker = CircuitBreakerState::new();
    (0..2).for_each(|_| breaker.record_failure(&cfg));
    assert_eq!(breaker.consecutive_failures(), 2);

    breaker.record_success();
    assert_eq!(*breaker.state(), CircuitState::Closed);
    assert_eq!(breaker.consecutive_failures(), 0);
}
/// Once the (zero) cooldown has elapsed, the next `should_allow` admits a probe
/// and moves the breaker from Open to HalfOpen.
#[test]
fn circuit_breaker_half_open_after_cooldown() {
    let cfg = CircuitBreakerConfig::new(1, std::time::Duration::ZERO);
    let mut breaker = CircuitBreakerState::new();
    breaker.record_failure(&cfg);
    assert_eq!(*breaker.state(), CircuitState::Open);

    let admitted = breaker.should_allow(&cfg);
    assert!(admitted);
    assert_eq!(*breaker.state(), CircuitState::HalfOpen);
}
/// A successful probe while half-open closes the breaker again.
#[test]
fn circuit_breaker_half_open_success_closes() {
    let config = CircuitBreakerConfig::new(1, std::time::Duration::from_millis(0));
    let mut state = CircuitBreakerState::new();
    state.record_failure(&config);
    // With a zero cooldown the probe is admitted and the breaker goes HalfOpen.
    // Asserting this, instead of discarding `should_allow`'s result, keeps the
    // test from passing without ever exercising the half-open path.
    assert!(state.should_allow(&config));
    assert_eq!(*state.state(), CircuitState::HalfOpen);
    state.record_success();
    assert_eq!(*state.state(), CircuitState::Closed);
}
/// A failed probe while half-open re-opens the breaker.
#[test]
fn circuit_breaker_half_open_failure_reopens() {
    let config = CircuitBreakerConfig::new(1, std::time::Duration::from_millis(0));
    let mut state = CircuitBreakerState::new();
    state.record_failure(&config);
    // Confirm the Open -> HalfOpen transition first. Otherwise the breaker
    // could simply have stayed Open, and the final assertion would pass
    // without testing re-opening.
    assert!(state.should_allow(&config));
    assert_eq!(*state.state(), CircuitState::HalfOpen);
    state.record_failure(&config);
    assert_eq!(*state.state(), CircuitState::Open);
}
/// With a half-open budget of 2, probes are admitted until one attempt is
/// explicitly marked, after which `should_allow` refuses.
#[test]
fn circuit_breaker_half_open_limits_attempts() {
    let cfg = CircuitBreakerConfig::new(1, std::time::Duration::ZERO)
        .with_half_open_max_attempts(2);
    let mut breaker = CircuitBreakerState::new();
    breaker.record_failure(&cfg);

    // First probe after the zero cooldown: Open -> HalfOpen.
    assert!(breaker.should_allow(&cfg));
    assert_eq!(*breaker.state(), CircuitState::HalfOpen);
    // Asking again without marking an attempt is still admitted.
    assert!(breaker.should_allow(&cfg));

    breaker.mark_half_open_attempt();
    assert!(!breaker.should_allow(&cfg));
}
/// While the 60s cooldown is still running, an open breaker refuses calls.
#[test]
fn circuit_breaker_open_blocks_until_cooldown() {
    let cfg = CircuitBreakerConfig::new(1, std::time::Duration::from_secs(60));
    let mut breaker = CircuitBreakerState::new();
    breaker.record_failure(&cfg);
    assert_eq!(*breaker.state(), CircuitState::Open);
    let admitted = breaker.should_allow(&cfg);
    assert!(!admitted);
}
/// Default node metadata carries no circuit-breaker configuration.
#[test]
fn node_metadata_default_has_no_circuit_breaker() {
    assert!(NodeMetadata::default().circuit_breaker.is_none());
}
/// A circuit-breaker config set through struct-update syntax is kept on the
/// metadata as given.
#[test]
fn node_metadata_with_circuit_breaker() {
    let breaker_cfg = CircuitBreakerConfig::new(5, std::time::Duration::from_secs(30));
    let meta = NodeMetadata {
        circuit_breaker: Some(breaker_cfg),
        ..NodeMetadata::default()
    };
    assert!(meta.circuit_breaker.is_some());
    let stored = meta.circuit_breaker.as_ref().unwrap();
    assert_eq!(stored.failure_threshold, 5);
}
}