1use super::{compiled::CompiledGraph, topology::TopologyError, topology::TopologyValidator};
7use crate::{
8 State,
9 edge::{CompiledEdge, END, Edge, START, TriggerSource},
10 node::IntoNode,
11 state::{FromState, IntoState},
12};
13use indexmap::IndexMap;
14use std::collections::HashMap;
15use std::sync::Arc;
16
/// Options supplied when compiling a `StateGraph` into a `CompiledGraph`.
///
/// Both lists are forwarded unchanged to `CompiledGraph::new` during
/// compilation.
#[derive(Clone, Debug, Default)]
pub struct CompileConfig {
    /// Node names passed to the compiled graph as its interrupt-before list.
    pub interrupt_before: Vec<String>,

    /// Node names passed to the compiled graph as its interrupt-after list.
    pub interrupt_after: Vec<String>,
}
50
51#[derive(Clone, Debug, Default)]
56pub struct NodeMetadata {
57 pub defer: bool,
59
60 pub metadata: Option<HashMap<String, serde_json::Value>>,
62
63 pub destinations: Option<Vec<String>>,
65
66 pub retry_policies: Vec<RetryPolicy>,
68
69 pub error_handler: Option<String>,
77
78 pub timeout_policies: Vec<crate::TimeoutPolicy>,
82
83 pub circuit_breaker: Option<CircuitBreakerConfig>,
89
90 pub fallback_node: Option<String>,
102}
103
104#[derive(Clone)]
109pub struct RetryPolicy {
110 pub max_attempts: u32,
112
113 pub initial_interval: std::time::Duration,
115
116 pub backoff_factor: f64,
118
119 pub max_interval: std::time::Duration,
121
122 pub jitter: bool,
124
125 #[allow(
127 clippy::type_complexity,
128 reason = "trait object requires full signature"
129 )]
130 pub retry_on: Option<Arc<dyn Fn(&crate::JunctureError) -> bool + Send + Sync>>,
131}
132
133impl std::fmt::Debug for RetryPolicy {
134 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
135 f.debug_struct("RetryPolicy")
136 .field("max_attempts", &self.max_attempts)
137 .field("initial_interval", &self.initial_interval)
138 .field("backoff_factor", &self.backoff_factor)
139 .field("max_interval", &self.max_interval)
140 .field("jitter", &self.jitter)
141 .field("retry_on", &self.retry_on.as_ref().map(|_| "<fn>"))
142 .finish()
143 }
144}
145
146impl Default for RetryPolicy {
147 fn default() -> Self {
148 Self {
149 max_attempts: 3,
150 initial_interval: std::time::Duration::from_millis(500),
151 backoff_factor: 2.0,
152 max_interval: std::time::Duration::from_secs(10),
153 jitter: true,
154 retry_on: None,
155 }
156 }
157}
158
/// Static configuration of a node's circuit breaker.
///
/// The mutable counterpart that tracks failures at runtime is
/// `CircuitBreakerState`.
///
/// `Debug` is derived. The derived output is identical to the previous
/// hand-written impl, which listed the same fields in the same order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CircuitBreakerConfig {
    /// Consecutive failures in the closed state that trip the breaker open.
    pub failure_threshold: usize,

    /// How long the breaker stays open before a half-open probe is admitted.
    pub cooldown_duration: std::time::Duration,

    /// Attempt budget while half-open. `new` sets it to 1, and
    /// `with_half_open_max_attempts` never stores less than 1.
    pub half_open_max_attempts: usize,
}

impl CircuitBreakerConfig {
    /// Creates a configuration with a half-open budget of one attempt.
    #[must_use]
    pub const fn new(failure_threshold: usize, cooldown_duration: std::time::Duration) -> Self {
        Self {
            failure_threshold,
            cooldown_duration,
            half_open_max_attempts: 1,
        }
    }

    /// Returns `self` with the half-open attempt budget replaced.
    ///
    /// A value of 0 is raised to 1 so the breaker can always probe.
    #[must_use]
    pub const fn with_half_open_max_attempts(mut self, max_attempts: usize) -> Self {
        // `Ord::max` is not callable in a `const fn`, hence the explicit branch.
        self.half_open_max_attempts = if max_attempts == 0 { 1 } else { max_attempts };
        self
    }
}
230
231#[derive(Clone, Debug)]
236pub struct CircuitBreakerState {
237 state: CircuitState,
239
240 consecutive_failures: usize,
242
243 opened_at: Option<crate::time::Instant>,
245
246 half_open_attempts: usize,
248}
249
/// Position of a circuit breaker.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CircuitState {
    /// Calls are admitted.
    Closed,
    /// Calls are rejected until the cooldown has elapsed.
    Open,
    /// Calls are admitted only while the half-open attempt budget allows.
    HalfOpen,
}
260
261impl CircuitBreakerState {
262 #[must_use]
264 pub const fn new() -> Self {
265 Self {
266 state: CircuitState::Closed,
267 consecutive_failures: 0,
268 opened_at: None,
269 half_open_attempts: 0,
270 }
271 }
272
273 pub fn should_allow(&mut self, config: &CircuitBreakerConfig) -> bool {
282 match self.state {
283 CircuitState::Closed => true,
284 CircuitState::Open => {
285 if let Some(opened) = self.opened_at {
287 if opened.elapsed() >= config.cooldown_duration {
288 self.state = CircuitState::HalfOpen;
289 self.half_open_attempts = 1;
293 true
294 } else {
295 false
296 }
297 } else {
298 self.state = CircuitState::Closed;
300 true
301 }
302 }
303 CircuitState::HalfOpen => {
304 self.half_open_attempts < config.half_open_max_attempts
307 }
308 }
309 }
310
311 pub fn mark_half_open_attempt(&mut self) {
317 if self.state == CircuitState::HalfOpen {
318 self.half_open_attempts += 1;
319 }
320 }
321
322 pub fn record_success(&mut self) {
328 if self.state == CircuitState::Open {
329 return;
330 }
331 self.consecutive_failures = 0;
332 self.half_open_attempts = 0;
333 self.state = CircuitState::Closed;
334 self.opened_at = None;
335 }
336
337 pub fn record_failure(&mut self, config: &CircuitBreakerConfig) {
339 self.consecutive_failures += 1;
340
341 match self.state {
342 CircuitState::Closed => {
343 if self.consecutive_failures >= config.failure_threshold {
344 self.state = CircuitState::Open;
345 self.opened_at = Some(crate::time::Instant::now());
346 }
347 }
348 CircuitState::HalfOpen => {
349 self.state = CircuitState::Open;
351 self.opened_at = Some(crate::time::Instant::now());
352 self.half_open_attempts = 0;
353 }
354 CircuitState::Open => {
355 }
357 }
358 }
359
360 #[must_use]
362 pub const fn state(&self) -> &CircuitState {
363 &self.state
364 }
365
366 #[must_use]
368 pub const fn consecutive_failures(&self) -> usize {
369 self.consecutive_failures
370 }
371}
372
373impl Default for CircuitBreakerState {
374 fn default() -> Self {
375 Self::new()
376 }
377}
378
379pub struct NodeError<S: State> {
384 pub node: String,
386
387 pub error: crate::JunctureError,
389
390 pub state: S,
392
393 pub attempt: u32,
395}
396
397impl<S: State> std::fmt::Debug for NodeError<S> {
398 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
399 f.debug_struct("NodeError")
400 .field("node", &self.node)
401 .field("error", &self.error)
402 .field("state", &"<state>")
403 .field("attempt", &self.attempt)
404 .finish()
405 }
406}
407
408pub struct ErrorHandlerNode<S: State> {
417 inner: Arc<dyn crate::Node<S>>,
419
420 #[allow(
425 clippy::type_complexity,
426 reason = "trait object requires full signature"
427 )]
428 handler: Arc<dyn Fn(NodeError<S>) -> crate::Command<S> + Send + Sync>,
429
430 name: String,
432}
433
434impl<S: State> std::fmt::Debug for ErrorHandlerNode<S> {
435 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
436 f.debug_struct("ErrorHandlerNode")
437 .field("name", &self.name)
438 .field("inner", &"<node>")
439 .field("handler", &"<fn>")
440 .finish()
441 }
442}
443
444impl<S: State> ErrorHandlerNode<S> {
445 #[allow(
453 clippy::type_complexity,
454 reason = "trait object requires full signature"
455 )]
456 pub fn new(
457 inner: Arc<dyn crate::Node<S>>,
458 handler: Arc<dyn Fn(NodeError<S>) -> crate::Command<S> + Send + Sync>,
459 ) -> Self {
460 let name = inner.name().to_string();
461 Self {
462 inner,
463 handler,
464 name,
465 }
466 }
467}
468
469impl<S: State + Clone> crate::Node<S> for ErrorHandlerNode<S> {
470 fn call(
471 &self,
472 state: &S,
473 config: &crate::RunnableConfig,
474 ) -> std::pin::Pin<
475 Box<
476 dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
477 + Send
478 + '_,
479 >,
480 > {
481 let state_backup = state.clone();
484 let result = self.inner.call(state, config);
485 let handler = Arc::clone(&self.handler);
486 let node_name = self.name.clone();
487 Box::pin(async move {
488 match result.await {
489 Ok(command) => Ok(command),
490 Err(error) => {
491 let node_error = NodeError {
493 node: node_name,
494 error,
495 state: state_backup,
496 attempt: 1, };
498 Ok(handler(node_error))
499 }
500 }
501 })
502 }
503
504 fn name(&self) -> &str {
505 &self.name
506 }
507}
508
509pub struct RetryingNode<S: State> {
514 inner: Arc<dyn crate::Node<S>>,
516
517 policy: RetryPolicy,
519
520 name: String,
522}
523
524impl<S: State> std::fmt::Debug for RetryingNode<S> {
525 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
526 f.debug_struct("RetryingNode")
527 .field("name", &self.name)
528 .field("inner", &"<node>")
529 .field("policy", &self.policy)
530 .finish()
531 }
532}
533
534impl<S: State> RetryingNode<S> {
535 #[must_use]
542 pub fn new(inner: Arc<dyn crate::Node<S>>, policy: RetryPolicy) -> Self {
543 let name = inner.name().to_string();
544 Self {
545 inner,
546 policy,
547 name,
548 }
549 }
550}
551
552impl<S: State + Clone> crate::Node<S> for RetryingNode<S> {
553 fn call(
554 &self,
555 state: &S,
556 config: &crate::RunnableConfig,
557 ) -> std::pin::Pin<
558 Box<
559 dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
560 + Send
561 + '_,
562 >,
563 > {
564 let policy = self.policy.clone();
565 let inner = Arc::clone(&self.inner);
566 let config = config.clone();
567 let node_name = self.name.clone();
568 let state_owned = state.clone();
569
570 Box::pin(async move {
571 execute_with_retry(
572 &node_name,
573 &policy,
574 |s, cfg| inner.call(s, cfg),
575 &state_owned,
576 &config,
577 )
578 .await
579 })
580 }
581
582 fn name(&self) -> &str {
583 &self.name
584 }
585}
586
587pub async fn execute_with_retry<S, F, Fut>(
622 node_name: &str,
623 policy: &RetryPolicy,
624 operation: F,
625 state: &S,
626 config: &crate::RunnableConfig,
627) -> Result<crate::Command<S>, crate::JunctureError>
628where
629 S: State,
630 F: Fn(&S, &crate::RunnableConfig) -> Fut,
631 Fut: std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>,
632{
633 let mut last_error: Option<crate::JunctureError> = None;
634 let mut delay = policy.initial_interval;
635
636 for attempt in 0..policy.max_attempts {
637 match operation(state, config).await {
638 Ok(command) => {
639 if attempt > 0 {
640 tracing::debug!(
641 node_name = node_name,
642 attempt = attempt + 1,
643 "node succeeded after retry"
644 );
645 }
646 return Ok(command);
647 }
648 Err(error) => {
649 let should_retry = policy.should_retry(&error);
650
651 if !should_retry || attempt + 1 >= policy.max_attempts {
652 return Err(error);
653 }
654
655 tracing::warn!(
656 node_name = node_name,
657 attempt = attempt + 1,
658 max_attempts = policy.max_attempts,
659 error = %error,
660 "node failed, will retry"
661 );
662
663 last_error = Some(error);
664
665 let actual_delay = compute_delay(delay, policy.jitter, policy.max_interval);
666 tokio::time::sleep(actual_delay).await;
667
668 delay = cap_delay(delay.mul_f64(policy.backoff_factor), policy.max_interval);
669 }
670 }
671 }
672
673 Err(last_error.unwrap_or_else(|| {
674 crate::JunctureError::execution(format!(
675 "node '{node_name}': retry policy exhausted with no error recorded"
676 ))
677 }))
678}
679
680fn compute_delay(
686 base: std::time::Duration,
687 jitter: bool,
688 max_interval: std::time::Duration,
689) -> std::time::Duration {
690 let capped = cap_delay(base, max_interval);
691
692 if !jitter {
693 return capped;
694 }
695
696 let jitter_fraction: f64 = rand::random_range(0.75..=1.25);
698 let jittered = capped.mul_f64(jitter_fraction);
699 cap_delay(jittered, max_interval)
700}
701
/// Clamps `delay` so that it never exceeds `max`.
fn cap_delay(delay: std::time::Duration, max: std::time::Duration) -> std::time::Duration {
    if delay > max { max } else { delay }
}
706
707impl RetryPolicy {
708 fn should_retry(&self, error: &crate::JunctureError) -> bool {
714 self.retry_on.as_ref().map_or_else(
715 || !error.is_cancelled() && !error.is_interrupt(),
716 |predicate| predicate(error),
717 )
718 }
719}
720
721pub struct TimeoutNode<S: State> {
727 inner: Arc<dyn crate::Node<S>>,
729
730 policy: crate::TimeoutPolicy,
732
733 name: String,
735}
736
737impl<S: State> std::fmt::Debug for TimeoutNode<S> {
738 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
739 f.debug_struct("TimeoutNode")
740 .field("name", &self.name)
741 .field("inner", &"<node>")
742 .field("policy", &self.policy)
743 .finish()
744 }
745}
746
747impl<S: State> TimeoutNode<S> {
748 #[must_use]
755 pub fn new(inner: Arc<dyn crate::Node<S>>, policy: crate::TimeoutPolicy) -> Self {
756 let name = inner.name().to_string();
757 Self {
758 inner,
759 policy,
760 name,
761 }
762 }
763}
764
765impl<S: State + Clone> crate::Node<S> for TimeoutNode<S> {
766 fn call(
767 &self,
768 state: &S,
769 config: &crate::RunnableConfig,
770 ) -> std::pin::Pin<
771 Box<
772 dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
773 + Send
774 + '_,
775 >,
776 > {
777 let inner = Arc::clone(&self.inner);
778 let config = config.clone();
779 let node_name = self.name.clone();
780 let run_timeout = self.policy.run_timeout;
781
782 let state_cloned = state.clone();
783 Box::pin(async move {
784 execute_with_timeout(
785 &node_name,
786 run_timeout,
787 |s, cfg| inner.call(s, cfg),
788 &state_cloned,
789 &config,
790 )
791 .await
792 })
793 }
794
795 fn name(&self) -> &str {
796 &self.name
797 }
798}
799
800pub async fn execute_with_timeout<S, F, Fut>(
834 node_name: &str,
835 run_timeout: std::time::Duration,
836 operation: F,
837 state: &S,
838 config: &crate::RunnableConfig,
839) -> Result<crate::Command<S>, crate::JunctureError>
840where
841 S: State,
842 F: FnOnce(&S, &crate::RunnableConfig) -> Fut,
843 Fut: std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>,
844{
845 let result = tokio::time::timeout(run_timeout, operation(state, config)).await;
846
847 match result {
848 Ok(Ok(command)) => Ok(command),
849 Ok(Err(error)) => Err(error),
850 Err(_) => Err(crate::JunctureError::node_timeout(
851 crate::error::NodeTimeoutError::RunTimeout {
852 node: node_name.to_string(),
853 timeout: u64::try_from(run_timeout.as_millis()).unwrap_or(u64::MAX),
854 },
855 )),
856 }
857}
858
859pub struct StateGraph<S: State, I: IntoState<S> = S, O: FromState<S> = S> {
883 nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,
885
886 edges: Vec<Edge<S>>,
888
889 entry_point: Option<String>,
891
892 finish_points: Vec<String>,
894
895 builder_metadata: IndexMap<String, NodeMetadata>,
897
898 subgraphs: Vec<crate::subgraph::SubgraphMount<S>>,
900
901 _input: std::marker::PhantomData<I>,
903 _output: std::marker::PhantomData<O>,
905}
906
907impl<S: State, I: IntoState<S>, O: FromState<S>> std::fmt::Debug for StateGraph<S, I, O> {
908 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
909 f.debug_struct("StateGraph")
910 .field("nodes", &format_args!("{} nodes", self.nodes.len()))
911 .field("edges", &format_args!("{} edges", self.edges.len()))
912 .field("entry_point", &self.entry_point)
913 .field("finish_points", &self.finish_points)
914 .field("builder_metadata", &self.builder_metadata)
915 .field(
916 "subgraphs",
917 &format_args!("{} subgraphs", self.subgraphs.len()),
918 )
919 .finish()
920 }
921}
922
923impl<S: State, I: IntoState<S>, O: FromState<S>> StateGraph<S, I, O> {
924 #[must_use]
926 pub fn new() -> Self {
927 Self {
928 nodes: IndexMap::new(),
929 edges: Vec::new(),
930 entry_point: None,
931 finish_points: Vec::new(),
932 builder_metadata: IndexMap::new(),
933 subgraphs: Vec::new(),
934 _input: std::marker::PhantomData,
935 _output: std::marker::PhantomData,
936 }
937 }
938
939 #[expect(
951 clippy::too_many_arguments,
952 reason = "add_node requires name, node, defer, metadata, destinations, retry_policies, and timeout_policies. All are necessary for the builder pattern."
953 )]
954 pub fn add_node(
955 &mut self,
956 name: impl Into<String>,
957 node: impl IntoNode<S>,
958 defer: bool,
959 metadata: Option<HashMap<String, serde_json::Value>>,
960 destinations: Option<Vec<String>>,
961 retry_policies: Vec<RetryPolicy>,
962 timeout_policies: Vec<crate::TimeoutPolicy>,
963 ) -> Result<&mut Self, TopologyError> {
964 let name = name.into();
965 if self.nodes.contains_key(&name) {
966 return Err(TopologyError::DuplicateNode { name });
967 }
968
969 let node_arc = node.into_node(&name);
970 self.nodes.insert(name.clone(), node_arc);
971
972 self.builder_metadata.insert(
973 name,
974 NodeMetadata {
975 defer,
976 metadata,
977 destinations,
978 retry_policies,
979 error_handler: None,
980 timeout_policies,
981 circuit_breaker: None,
982 fallback_node: None,
983 },
984 );
985
986 Ok(self)
987 }
988
989 pub fn add_node_simple(
1004 &mut self,
1005 name: impl Into<String>,
1006 node: impl IntoNode<S>,
1007 ) -> Result<&mut Self, TopologyError> {
1008 self.add_node(name, node, false, None, None, Vec::new(), Vec::new())
1009 }
1010
1011 #[allow(
1031 clippy::type_complexity,
1032 reason = "trait object requires full signature"
1033 )]
1034 pub fn add_node_with_error_handler(
1035 &mut self,
1036 name: impl Into<String>,
1037 node: impl IntoNode<S>,
1038 handler: Arc<dyn Fn(super::builder::NodeError<S>) -> crate::Command<S> + Send + Sync>,
1039 ) -> Result<&mut Self, TopologyError>
1040 where
1041 S: Clone,
1042 {
1043 let name_str = name.into();
1044 let inner = node.into_node(&name_str);
1045 let wrapped: Arc<dyn crate::Node<S>> = Arc::new(ErrorHandlerNode::new(inner, handler));
1046
1047 if self.nodes.contains_key(&name_str) {
1048 return Err(TopologyError::DuplicateNode { name: name_str });
1049 }
1050
1051 self.nodes.insert(name_str.clone(), wrapped);
1052 self.builder_metadata
1053 .insert(name_str, NodeMetadata::default());
1054
1055 Ok(self)
1056 }
1057
1058 pub fn add_node_with_retry(
1075 &mut self,
1076 name: impl Into<String>,
1077 node: impl IntoNode<S>,
1078 policy: RetryPolicy,
1079 ) -> Result<&mut Self, TopologyError>
1080 where
1081 S: Clone,
1082 {
1083 let name_str = name.into();
1084 let inner = node.into_node(&name_str);
1085 let wrapped: Arc<dyn crate::Node<S>> = Arc::new(RetryingNode::new(inner, policy));
1086
1087 if self.nodes.contains_key(&name_str) {
1088 return Err(TopologyError::DuplicateNode { name: name_str });
1089 }
1090
1091 self.nodes.insert(name_str.clone(), wrapped);
1092 self.builder_metadata
1093 .insert(name_str, NodeMetadata::default());
1094
1095 Ok(self)
1096 }
1097
1098 pub fn add_node_with_circuit_breaker(
1116 &mut self,
1117 name: impl Into<String>,
1118 node: impl IntoNode<S>,
1119 config: CircuitBreakerConfig,
1120 ) -> Result<&mut Self, TopologyError> {
1121 let name_str = name.into();
1122 let node_arc = node.into_node(&name_str);
1123
1124 if self.nodes.contains_key(&name_str) {
1125 return Err(TopologyError::DuplicateNode { name: name_str });
1126 }
1127
1128 self.nodes.insert(name_str.clone(), node_arc);
1129 self.builder_metadata.insert(
1130 name_str,
1131 NodeMetadata {
1132 circuit_breaker: Some(config),
1133 ..NodeMetadata::default()
1134 },
1135 );
1136
1137 Ok(self)
1138 }
1139
1140 pub fn add_node_with_retry_and_circuit_breaker(
1159 &mut self,
1160 name: impl Into<String>,
1161 node: impl IntoNode<S>,
1162 retry_policy: RetryPolicy,
1163 circuit_breaker_config: CircuitBreakerConfig,
1164 ) -> Result<&mut Self, TopologyError>
1165 where
1166 S: Clone,
1167 {
1168 let name_str = name.into();
1169 let inner = node.into_node(&name_str);
1170 let wrapped: Arc<dyn crate::Node<S>> = Arc::new(RetryingNode::new(inner, retry_policy));
1171
1172 if self.nodes.contains_key(&name_str) {
1173 return Err(TopologyError::DuplicateNode { name: name_str });
1174 }
1175
1176 self.nodes.insert(name_str.clone(), wrapped);
1177 self.builder_metadata.insert(
1178 name_str,
1179 NodeMetadata {
1180 circuit_breaker: Some(circuit_breaker_config),
1181 ..NodeMetadata::default()
1182 },
1183 );
1184
1185 Ok(self)
1186 }
1187
1188 pub fn add_node_with_fallback(
1206 &mut self,
1207 name: impl Into<String>,
1208 node: impl IntoNode<S>,
1209 fallback: impl Into<String>,
1210 ) -> Result<&mut Self, TopologyError> {
1211 let name_str = name.into();
1212 let node_arc = node.into_node(&name_str);
1213
1214 if self.nodes.contains_key(&name_str) {
1215 return Err(TopologyError::DuplicateNode { name: name_str });
1216 }
1217
1218 self.nodes.insert(name_str.clone(), node_arc);
1219 self.builder_metadata.insert(
1220 name_str,
1221 NodeMetadata {
1222 fallback_node: Some(fallback.into()),
1223 ..NodeMetadata::default()
1224 },
1225 );
1226
1227 Ok(self)
1228 }
1229
1230 pub fn add_subgraph(
1245 &mut self,
1246 mount: crate::subgraph::SubgraphMount<S>,
1247 ) -> Result<&mut Self, TopologyError> {
1248 if self.nodes.contains_key(&mount.name) {
1249 return Err(TopologyError::DuplicateNode {
1250 name: mount.name.clone(),
1251 });
1252 }
1253
1254 let name = mount.name.clone();
1255 let node = Arc::clone(&mount.node);
1256 self.nodes.insert(name.clone(), node);
1257 self.builder_metadata.insert(name, NodeMetadata::default());
1258 self.subgraphs.push(mount);
1259
1260 Ok(self)
1261 }
1262
1263 #[allow(
1285 dead_code,
1286 reason = "fully implemented public API awaiting external consumers"
1287 )]
1288 pub fn add_subgraph_node<Sub>(
1289 &mut self,
1290 name: &str,
1291 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
1292 ) -> Result<&mut Self, TopologyError>
1293 where
1294 Sub: crate::subgraph::StateSubset<S>
1295 + State
1296 + Clone
1297 + serde::Serialize
1298 + for<'de> serde::Deserialize<'de>,
1299 Sub::Update: serde::Serialize,
1300 S: Clone,
1301 {
1302 let input_map = Arc::new(move |parent: &S| Sub::extract(parent));
1306 let output_map = Arc::new(|_sub_output: &Sub| Sub::map_update(Default::default()));
1307
1308 let node: Arc<dyn crate::Node<S>> = Arc::new(crate::subgraph::SubgraphNode::new(
1310 subgraph,
1311 name.to_string(),
1312 input_map,
1313 output_map,
1314 crate::subgraph::SubgraphConfig::default(),
1315 ));
1316
1317 if self.nodes.contains_key(name) {
1318 return Err(TopologyError::DuplicateNode {
1319 name: name.to_string(),
1320 });
1321 }
1322
1323 self.nodes.insert(name.to_string(), node);
1324 self.builder_metadata
1325 .insert(name.to_string(), NodeMetadata::default());
1326
1327 Ok(self)
1328 }
1329
1330 #[allow(
1355 clippy::type_complexity,
1356 reason = "requires type erasure for trait object storage"
1357 )]
1358 pub fn add_subgraph_with_config<Sub>(
1359 &mut self,
1360 name: &str,
1361 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
1362 input_map: impl Fn(&S) -> Sub + Send + Sync + 'static,
1363 output_map: impl Fn(&Sub) -> S::Update + Send + Sync + 'static,
1364 config: crate::subgraph::SubgraphConfig,
1365 ) -> Result<&mut Self, TopologyError>
1366 where
1367 Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
1368 Sub::Update: serde::Serialize,
1369 S: Clone,
1370 {
1371 let input_map_arc = Arc::new(input_map);
1372 let output_map_arc: Arc<dyn Fn(&Sub) -> S::Update + Send + Sync> = Arc::new(output_map);
1373
1374 let node: Arc<dyn crate::Node<S>> = Arc::new(crate::subgraph::SubgraphNode::new(
1376 subgraph,
1377 name.to_string(),
1378 input_map_arc,
1379 output_map_arc,
1380 config,
1381 ));
1382
1383 if self.nodes.contains_key(name) {
1384 return Err(TopologyError::DuplicateNode {
1385 name: name.to_string(),
1386 });
1387 }
1388
1389 self.nodes.insert(name.to_string(), node);
1390 self.builder_metadata
1391 .insert(name.to_string(), NodeMetadata::default());
1392
1393 Ok(self)
1394 }
1395
1396 #[allow(
1416 clippy::type_complexity,
1417 reason = "requires type erasure for trait object storage"
1418 )]
1419 pub fn add_subgraph_explicit<Sub>(
1420 &mut self,
1421 name: &str,
1422 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
1423 input_map: impl Fn(&S) -> Sub + Send + Sync + 'static,
1424 output_map: impl Fn(&Sub) -> S::Update + Send + Sync + 'static,
1425 ) -> Result<&mut Self, TopologyError>
1426 where
1427 Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
1428 Sub::Update: serde::Serialize,
1429 S: Clone,
1430 {
1431 self.add_subgraph_with_config(
1432 name,
1433 subgraph,
1434 input_map,
1435 output_map,
1436 crate::subgraph::SubgraphConfig::default(),
1437 )
1438 }
1439
1440 pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) {
1452 self.edges.push(Edge::Fixed {
1453 from: from.into(),
1454 to: to.into(),
1455 });
1456 }
1457
1458 pub fn add_conditional_edges(
1483 &mut self,
1484 from: impl Into<String>,
1485 router: Arc<dyn crate::edge::Router<S>>,
1486 path_map: crate::edge::PathMap,
1487 ) {
1488 self.edges.push(Edge::Conditional {
1489 from: from.into(),
1490 router,
1491 path_map,
1492 });
1493 }
1494
1495 pub fn set_entry_point(&mut self, node: impl Into<String>) {
1505 let node = node.into();
1506 self.entry_point = Some(node.clone());
1507 self.edges.push(Edge::Fixed {
1508 from: START.to_string(),
1509 to: node,
1510 });
1511 }
1512
1513 pub fn set_finish_point(&mut self, node: impl Into<String>) {
1523 let node = node.into();
1524 self.finish_points.push(node.clone());
1525 self.edges.push(Edge::Fixed {
1526 from: node,
1527 to: END.to_string(),
1528 });
1529 }
1530
1531 pub fn add_sequence(&mut self, nodes: &[impl AsRef<str>]) -> Result<&mut Self, TopologyError> {
1548 if nodes.is_empty() {
1549 return Ok(self);
1550 }
1551
1552 let node_names: Vec<&str> = nodes.iter().map(std::convert::AsRef::as_ref).collect();
1553
1554 for name in &node_names {
1556 if !self.nodes.contains_key(*name) {
1557 return Err(TopologyError::NodeNotFound {
1558 name: (*name).to_string(),
1559 });
1560 }
1561 }
1562
1563 if self.entry_point.is_none() {
1565 self.set_entry_point(node_names[0]);
1566 }
1567
1568 for window in node_names.windows(2) {
1570 self.add_edge(window[0], window[1]);
1571 }
1572
1573 Ok(self)
1574 }
1575
1576 pub fn validate_keys(&self) -> Result<(), TopologyError> {
1593 for name in self.nodes.keys() {
1595 if name.is_empty() {
1596 return Err(TopologyError::InvalidNodeName {
1597 name: name.clone(),
1598 reason: "node name cannot be empty".to_string(),
1599 });
1600 }
1601
1602 if name.contains(':') || name.contains('/') || name.contains('\\') {
1604 return Err(TopologyError::InvalidNodeName {
1605 name: name.clone(),
1606 reason: "node name cannot contain ':', '/', or '\\'".to_string(),
1607 });
1608 }
1609 }
1610
1611 if let Some(ref entry) = self.entry_point
1613 && !self.nodes.contains_key(entry)
1614 {
1615 return Err(TopologyError::NodeNotFound {
1616 name: entry.clone(),
1617 });
1618 }
1619
1620 for finish in &self.finish_points {
1622 if !self.nodes.contains_key(finish) {
1623 return Err(TopologyError::NodeNotFound {
1624 name: finish.clone(),
1625 });
1626 }
1627 }
1628
1629 let field_count = S::field_count();
1631 let field_names = S::field_names();
1632
1633 for &idx in S::replace_field_indices() {
1634 if idx >= field_count {
1635 return Err(TopologyError::InvalidFieldReference {
1636 index: idx,
1637 field_count,
1638 field_names,
1639 context: "replace_field_indices".to_string(),
1640 });
1641 }
1642 }
1643
1644 for &idx in S::replace_after_finish_field_indices() {
1645 if idx >= field_count {
1646 return Err(TopologyError::InvalidFieldReference {
1647 index: idx,
1648 field_count,
1649 field_names,
1650 context: "replace_after_finish_field_indices".to_string(),
1651 });
1652 }
1653 }
1654
1655 Ok(())
1656 }
1657
1658 pub fn compile(&self) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1673 self.compile_inner(CompileConfig::default(), None)
1674 }
1675
1676 pub fn compile_with_config(
1698 &self,
1699 config: CompileConfig,
1700 ) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1701 self.compile_inner(config, None)
1702 }
1703
1704 pub fn compile_ephemeral(&self) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1713 self.compile_inner(CompileConfig::default(), None)
1714 }
1715
1716 pub fn compile_with_checkpointer(
1725 &self,
1726 checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
1727 ) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1728 self.compile_inner(CompileConfig::default(), checkpointer)
1729 }
1730
1731 fn compile_inner(
1736 &self,
1737 config: CompileConfig,
1738 checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
1739 ) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1740 TopologyValidator::validate(
1742 &self.nodes,
1743 &self.edges,
1744 self.entry_point.as_deref(),
1745 &self.builder_metadata,
1746 )?;
1747 self.validate_keys()?;
1748
1749 let trigger_table = self.build_trigger_table();
1751
1752 let subgraph_info: Vec<super::compiled::SubgraphInfo> = self
1754 .subgraphs
1755 .iter()
1756 .map(|mount| super::compiled::SubgraphInfo {
1757 name: mount.name.clone(),
1758 persistence: mount.config.persistence,
1759 })
1760 .collect();
1761
1762 Ok(CompiledGraph::new(
1764 self.nodes.clone(),
1765 trigger_table,
1766 self.builder_metadata.clone(),
1767 config.interrupt_before,
1768 config.interrupt_after,
1769 checkpointer,
1770 subgraph_info,
1771 ))
1772 }
1773
1774 fn build_trigger_table(&self) -> crate::edge::TriggerTable<S> {
1776 let mut trigger_table = crate::edge::TriggerTable::new();
1777
1778 for edge in &self.edges {
1779 match edge {
1780 Edge::Fixed { from, to } => {
1781 if from == START {
1782 trigger_table
1784 .add_incoming(to.clone(), TriggerSource::Edge { from: from.clone() });
1785 } else if to == END {
1786 } else {
1788 trigger_table
1790 .add_outgoing(from.clone(), CompiledEdge::Fixed { target: to.clone() });
1791 trigger_table
1792 .add_incoming(to.clone(), TriggerSource::Edge { from: from.clone() });
1793 }
1794 }
1795 Edge::Conditional {
1796 from,
1797 path_map,
1798 router,
1799 } => {
1800 let router = Arc::clone(router);
1801 let path_map = path_map.clone();
1802
1803 if from == START {
1804 for target in path_map.iter().map(|(_, v)| v) {
1806 trigger_table.add_incoming(
1807 target.clone(),
1808 TriggerSource::Edge { from: from.clone() },
1809 );
1810 }
1811 } else {
1812 trigger_table.add_outgoing(
1814 from.clone(),
1815 CompiledEdge::Conditional {
1816 router,
1817 path_map: path_map.clone(),
1818 },
1819 );
1820
1821 for target in path_map.iter().map(|(_, v)| v) {
1822 trigger_table.add_incoming(
1823 target.clone(),
1824 TriggerSource::Edge { from: from.clone() },
1825 );
1826 }
1827 }
1828 }
1829 }
1830 }
1831
1832 trigger_table
1833 }
1834}
1835
1836impl<S: State, I: IntoState<S>, O: FromState<S>> Default for StateGraph<S, I, O> {
1837 fn default() -> Self {
1838 Self::new()
1839 }
1840}
1841
1842#[cfg(test)]
1843mod tests {
1844 use super::*;
1845 use crate::Node;
1846 use crate::error::JunctureError;
1847 use crate::node::NodeFnUpdate;
1848 use std::pin::Pin;
1849
1850 type BoxResult<T> = Pin<Box<dyn Future<Output = Result<T, JunctureError>> + Send>>;
1852
1853 #[test]
1854 fn test_state_graph_new() {
1855 let graph: StateGraph<StateDummy> = StateGraph::new();
1856 assert!(graph.nodes.is_empty());
1857 assert!(graph.edges.is_empty());
1858 assert!(graph.entry_point.is_none());
1859 assert!(graph.subgraphs.is_empty());
1860 }
1861
1862 #[test]
1863 fn test_add_node_simple() {
1864 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1865 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1866 Box::pin(async move { Ok(StateDummyUpdate) })
1867 });
1868
1869 graph.add_node_simple("test", node).unwrap();
1870 assert!(graph.nodes.contains_key("test"));
1871 }
1872
1873 #[test]
1874 fn test_add_node_duplicate() {
1875 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1876
1877 graph
1878 .add_node_simple(
1879 "test",
1880 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1881 Box::pin(async move { Ok(StateDummyUpdate) })
1882 }),
1883 )
1884 .unwrap();
1885 let result = graph.add_node_simple(
1886 "test",
1887 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1888 Box::pin(async move { Ok(StateDummyUpdate) })
1889 }),
1890 );
1891 assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
1892 }
1893
1894 #[test]
1895 fn test_set_entry_point() {
1896 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1897 graph.set_entry_point("start");
1898 assert_eq!(graph.entry_point, Some("start".to_string()));
1899 assert_eq!(graph.edges.len(), 1);
1900 }
1901
1902 #[test]
1903 fn test_set_finish_point() {
1904 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1905 graph.set_finish_point("end");
1906 assert_eq!(graph.finish_points, vec!["end"]);
1907 assert_eq!(graph.edges.len(), 1);
1908 }
1909
1910 #[test]
1911 fn test_add_sequence() {
1912 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1913
1914 graph
1916 .add_node_simple(
1917 "a",
1918 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1919 Box::pin(async move { Ok(StateDummyUpdate) })
1920 }),
1921 )
1922 .unwrap();
1923 graph
1924 .add_node_simple(
1925 "b",
1926 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1927 Box::pin(async move { Ok(StateDummyUpdate) })
1928 }),
1929 )
1930 .unwrap();
1931 graph
1932 .add_node_simple(
1933 "c",
1934 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1935 Box::pin(async move { Ok(StateDummyUpdate) })
1936 }),
1937 )
1938 .unwrap();
1939
1940 graph.add_sequence(&["a", "b", "c"]).unwrap();
1942
1943 assert_eq!(graph.entry_point, Some("a".to_string()));
1944 assert_eq!(graph.edges.len(), 3); }
1946
1947 #[test]
1948 fn test_add_sequence_missing_node() {
1949 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1950 let result = graph.add_sequence(&["missing"]);
1951 assert!(matches!(result, Err(TopologyError::NodeNotFound { .. })));
1952 }
1953
1954 #[test]
1955 fn test_compile_ephemeral() {
1956 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1957 graph
1958 .add_node_simple(
1959 "a",
1960 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1961 Box::pin(async move { Ok(StateDummyUpdate) })
1962 }),
1963 )
1964 .unwrap();
1965 graph.set_entry_point("a");
1966 graph.set_finish_point("a");
1967
1968 let compiled = graph.compile_ephemeral().unwrap();
1969 assert_eq!(compiled.nodes().len(), 1);
1970 }
1971
1972 #[test]
1973 fn test_add_subgraph() {
1974 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1975
1976 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1977 Box::pin(async move { Ok(StateDummyUpdate) })
1978 })
1979 .into_node("sub");
1980 let mount = crate::subgraph::SubgraphMount::new(
1981 "my_subgraph",
1982 crate::subgraph::SubgraphConfig::default(),
1983 node,
1984 );
1985
1986 graph.add_subgraph(mount).unwrap();
1987 assert!(graph.nodes.contains_key("my_subgraph"));
1988 assert_eq!(graph.subgraphs.len(), 1);
1989 }
1990
1991 #[test]
1992 fn test_compile_wires_subgraph_info() {
1993 use crate::subgraph::{SubgraphConfig, SubgraphMount, SubgraphPersistence};
1994
1995 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1996
1997 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1998 Box::pin(async move { Ok(StateDummyUpdate) })
1999 })
2000 .into_node("sub");
2001 let mount = SubgraphMount::new(
2002 "my_subgraph",
2003 SubgraphConfig {
2004 persistence: SubgraphPersistence::PerThread,
2005 },
2006 node,
2007 );
2008
2009 graph.add_subgraph(mount).unwrap();
2010 graph.set_entry_point("my_subgraph");
2011 graph.set_finish_point("my_subgraph");
2012
2013 let compiled = graph.compile().unwrap();
2014 let subgraphs = compiled.get_subgraphs();
2015 assert_eq!(subgraphs.len(), 1);
2016 assert_eq!(subgraphs[0].name, "my_subgraph");
2017 assert_eq!(subgraphs[0].persistence, SubgraphPersistence::PerThread);
2018 }
2019
2020 #[test]
2021 fn test_add_subgraph_duplicate() {
2022 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2023
2024 graph
2025 .add_node_simple(
2026 "my_subgraph",
2027 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2028 Box::pin(async move { Ok(StateDummyUpdate) })
2029 }),
2030 )
2031 .unwrap();
2032
2033 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2034 Box::pin(async move { Ok(StateDummyUpdate) })
2035 })
2036 .into_node("sub");
2037 let mount = crate::subgraph::SubgraphMount::new(
2038 "my_subgraph",
2039 crate::subgraph::SubgraphConfig::default(),
2040 node,
2041 );
2042
2043 let result = graph.add_subgraph(mount);
2044 assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
2045 }
2046
2047 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
2049 struct ChildState {
2050 value: i32,
2051 }
2052
2053 impl crate::State for ChildState {
2054 type Update = ChildStateUpdate;
2055 type FieldVersions = crate::state::FieldVersions;
2056
2057 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
2058 if let Some(v) = update.value {
2059 self.value = v;
2060 }
2061 crate::FieldsChanged(0)
2062 }
2063
2064 fn reset_ephemeral(&mut self) {}
2065 }
2066
2067 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
2068 struct ChildStateUpdate {
2069 value: Option<i32>,
2070 }
2071
2072 #[test]
2073 fn test_add_subgraph_with_config_registers_node() {
2074 let mut child_graph: StateGraph<ChildState> = StateGraph::new();
2075 child_graph
2076 .add_node_simple(
2077 "child_node",
2078 crate::node::NodeFnUpdate(|_s: &ChildState| -> BoxResult<_> {
2079 Box::pin(async move { Ok(ChildStateUpdate { value: Some(42) }) })
2080 }),
2081 )
2082 .unwrap();
2083 child_graph.set_entry_point("child_node");
2084 child_graph.set_finish_point("child_node");
2085
2086 let compiled_child = Arc::new(child_graph.compile().unwrap());
2087
2088 let mut parent_graph: StateGraph<StateDummy> = StateGraph::new();
2089 parent_graph
2090 .add_subgraph_with_config(
2091 "explicit_subgraph",
2092 compiled_child,
2093 |_parent: &StateDummy| ChildState { value: 0 },
2094 |_child: &ChildState| StateDummyUpdate,
2095 crate::subgraph::SubgraphConfig::default(),
2096 )
2097 .unwrap();
2098
2099 assert!(parent_graph.nodes.contains_key("explicit_subgraph"));
2100 }
2101
2102 #[test]
2103 fn test_add_subgraph_with_config_duplicate_node() {
2104 let mut child_graph: StateGraph<ChildState> = StateGraph::new();
2105 child_graph
2106 .add_node_simple(
2107 "child_node",
2108 crate::node::NodeFnUpdate(|_s: &ChildState| -> BoxResult<_> {
2109 Box::pin(async move { Ok(ChildStateUpdate { value: Some(42) }) })
2110 }),
2111 )
2112 .unwrap();
2113 child_graph.set_entry_point("child_node");
2114 child_graph.set_finish_point("child_node");
2115
2116 let compiled_child = Arc::new(child_graph.compile().unwrap());
2117
2118 let mut parent_graph: StateGraph<StateDummy> = StateGraph::new();
2119 parent_graph
2120 .add_node_simple(
2121 "explicit_subgraph",
2122 crate::node::NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2123 Box::pin(async move { Ok(StateDummyUpdate) })
2124 }),
2125 )
2126 .unwrap();
2127
2128 let result = parent_graph.add_subgraph_with_config(
2129 "explicit_subgraph",
2130 compiled_child,
2131 |_parent: &StateDummy| ChildState { value: 0 },
2132 |_child: &ChildState| StateDummyUpdate,
2133 crate::subgraph::SubgraphConfig::default(),
2134 );
2135
2136 assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
2137 }
2138
2139 #[test]
2140 fn test_add_node_with_retry() {
2141 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2142
2143 let policy = RetryPolicy {
2144 max_attempts: 3,
2145 initial_interval: std::time::Duration::from_millis(100),
2146 backoff_factor: 2.0,
2147 max_interval: std::time::Duration::from_secs(10),
2148 jitter: true,
2149 retry_on: None,
2150 };
2151
2152 graph
2153 .add_node_with_retry(
2154 "retry_node",
2155 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2156 Box::pin(async move { Ok(StateDummyUpdate) })
2157 }),
2158 policy,
2159 )
2160 .unwrap();
2161
2162 assert!(graph.nodes.contains_key("retry_node"));
2163 }
2164
2165 #[test]
2166 fn test_add_node_with_error_handler() {
2167 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2168
2169 let handler = Arc::new(|_err: NodeError<StateDummy>| crate::Command::end());
2170
2171 graph
2172 .add_node_with_error_handler(
2173 "error_handler_node",
2174 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2175 Box::pin(async move { Ok(StateDummyUpdate) })
2176 }),
2177 handler,
2178 )
2179 .unwrap();
2180
2181 assert!(graph.nodes.contains_key("error_handler_node"));
2182 }
2183
2184 #[test]
2185 fn test_default_implementation() {
2186 let graph: StateGraph<StateDummy> = StateGraph::default();
2187 assert!(graph.nodes.is_empty());
2188 assert!(graph.subgraphs.is_empty());
2189 }
2190
2191 #[test]
2192 fn test_validate_keys_empty_graph() {
2193 let graph: StateGraph<StateDummy> = StateGraph::new();
2194 graph.validate_keys().unwrap();
2195 }
2196
2197 #[test]
2198 fn test_validate_keys_valid_nodes() {
2199 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2200 graph
2201 .add_node_simple(
2202 "node_a",
2203 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2204 Box::pin(async move { Ok(StateDummyUpdate) })
2205 }),
2206 )
2207 .unwrap();
2208 graph
2209 .add_node_simple(
2210 "node_b",
2211 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2212 Box::pin(async move { Ok(StateDummyUpdate) })
2213 }),
2214 )
2215 .unwrap();
2216
2217 graph.validate_keys().unwrap();
2218 }
2219
2220 #[test]
2221 fn test_validate_keys_empty_node_name() {
2222 let graph: StateGraph<StateDummy> = StateGraph::new();
2223 let result = graph.validate_keys();
2226 result.unwrap();
2228 }
2229
2230 #[test]
2231 fn test_validate_keys_reserved_characters() {
2232 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2233
2234 graph
2236 .add_node_simple(
2237 "node:test",
2238 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2239 Box::pin(async move { Ok(StateDummyUpdate) })
2240 }),
2241 )
2242 .unwrap();
2243
2244 let result = graph.validate_keys();
2245 assert!(matches!(result, Err(TopologyError::InvalidNodeName { .. })));
2247 }
2248
2249 #[test]
2250 fn test_validate_keys_entry_point_not_found() {
2251 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2252 graph.set_entry_point("nonexistent");
2253
2254 let result = graph.validate_keys();
2255 assert!(matches!(result, Err(TopologyError::NodeNotFound { .. })));
2256 }
2257
2258 #[test]
2259 fn test_validate_keys_finish_point_not_found() {
2260 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2261 graph
2262 .add_node_simple(
2263 "node_a",
2264 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2265 Box::pin(async move { Ok(StateDummyUpdate) })
2266 }),
2267 )
2268 .unwrap();
2269 graph.set_finish_point("nonexistent");
2270
2271 let result = graph.validate_keys();
2272 assert!(matches!(result, Err(TopologyError::NodeNotFound { .. })));
2273 }
2274
2275 #[test]
2276 fn test_validate_keys_with_valid_entry_and_finish() {
2277 let mut graph: StateGraph<StateDummy> = StateGraph::new();
2278 graph
2279 .add_node_simple(
2280 "start",
2281 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2282 Box::pin(async move { Ok(StateDummyUpdate) })
2283 }),
2284 )
2285 .unwrap();
2286 graph
2287 .add_node_simple(
2288 "end",
2289 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
2290 Box::pin(async move { Ok(StateDummyUpdate) })
2291 }),
2292 )
2293 .unwrap();
2294 graph.set_entry_point("start");
2295 graph.set_finish_point("end");
2296
2297 graph.validate_keys().unwrap();
2298 }
2299
2300 #[test]
2301 fn test_validate_keys_catches_invalid_replace_field_index() {
2302 let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
2303 graph
2304 .add_node_simple(
2305 "node_a",
2306 NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
2307 Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
2308 }),
2309 )
2310 .unwrap();
2311 graph.set_entry_point("node_a");
2312 graph.set_finish_point("node_a");
2313
2314 let result = graph.validate_keys();
2315 assert!(matches!(
2316 result,
2317 Err(TopologyError::InvalidFieldReference { .. })
2318 ));
2319 if let Err(TopologyError::InvalidFieldReference {
2320 index,
2321 field_count,
2322 context,
2323 ..
2324 }) = result
2325 {
2326 assert_eq!(index, 5);
2327 assert_eq!(field_count, 2);
2328 assert_eq!(context, "replace_field_indices");
2329 }
2330 }
2331
2332 #[test]
2333 fn test_validate_keys_catches_invalid_replace_after_finish_field_index() {
2334 let mut graph: StateGraph<StateWithBadAfterFinishIndex> = StateGraph::new();
2335 graph
2336 .add_node_simple(
2337 "node_a",
2338 NodeFnUpdate(|_s: &StateWithBadAfterFinishIndex| -> BoxResult<_> {
2339 Box::pin(async move { Ok(StateWithBadAfterFinishIndexUpdate::default()) })
2340 }),
2341 )
2342 .unwrap();
2343 graph.set_entry_point("node_a");
2344 graph.set_finish_point("node_a");
2345
2346 let result = graph.validate_keys();
2347 assert!(matches!(
2348 result,
2349 Err(TopologyError::InvalidFieldReference { .. })
2350 ));
2351 if let Err(TopologyError::InvalidFieldReference {
2352 index,
2353 field_count,
2354 context,
2355 ..
2356 }) = result
2357 {
2358 assert_eq!(index, 99);
2359 assert_eq!(field_count, 2);
2360 assert_eq!(context, "replace_after_finish_field_indices");
2361 }
2362 }
2363
2364 #[derive(Clone, Debug, Default)]
2367 struct StateWithBadReplaceIndex {
2368 a: i32,
2369 b: i32,
2370 }
2371
2372 #[derive(Clone, Debug, Default)]
2373 struct StateWithBadReplaceIndexUpdate {
2374 a: Option<i32>,
2375 b: Option<i32>,
2376 }
2377
2378 impl crate::State for StateWithBadReplaceIndex {
2379 type Update = StateWithBadReplaceIndexUpdate;
2380 type FieldVersions = crate::state::FieldVersions;
2381
2382 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
2383 let mut changed = crate::FieldsChanged::default();
2384 if let Some(v) = update.a {
2385 self.a = v;
2386 changed.set_field(0);
2387 }
2388 if let Some(v) = update.b {
2389 self.b = v;
2390 changed.set_field(1);
2391 }
2392 changed
2393 }
2394
2395 fn reset_ephemeral(&mut self) {}
2396
2397 fn field_count() -> usize {
2398 2
2399 }
2400
2401 fn field_names() -> &'static [&'static str] {
2402 &["a", "b"]
2403 }
2404
2405 fn replace_field_indices() -> &'static [usize] {
2406 &[5] }
2408 }
2409
2410 #[derive(Clone, Debug, Default)]
2413 struct StateWithBadAfterFinishIndex {
2414 x: String,
2415 y: String,
2416 }
2417
2418 #[derive(Clone, Debug, Default)]
2419 struct StateWithBadAfterFinishIndexUpdate {
2420 x: Option<String>,
2421 y: Option<String>,
2422 }
2423
2424 impl crate::State for StateWithBadAfterFinishIndex {
2425 type Update = StateWithBadAfterFinishIndexUpdate;
2426 type FieldVersions = crate::state::FieldVersions;
2427
2428 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
2429 let mut changed = crate::FieldsChanged::default();
2430 if let Some(v) = update.x {
2431 self.x = v;
2432 changed.set_field(0);
2433 }
2434 if let Some(v) = update.y {
2435 self.y = v;
2436 changed.set_field(1);
2437 }
2438 changed
2439 }
2440
2441 fn reset_ephemeral(&mut self) {}
2442
2443 fn field_count() -> usize {
2444 2
2445 }
2446
2447 fn field_names() -> &'static [&'static str] {
2448 &["x", "y"]
2449 }
2450
2451 fn replace_after_finish_field_indices() -> &'static [usize] {
2452 &[99] }
2454 }
2455
2456 #[test]
2457 fn test_compile_calls_validate_keys_and_catches_invalid_replace_field_index() {
2458 let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
2459 graph
2460 .add_node_simple(
2461 "node_a",
2462 NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
2463 Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
2464 }),
2465 )
2466 .unwrap();
2467 graph.set_entry_point("node_a");
2468 graph.set_finish_point("node_a");
2469
2470 let result = graph.compile();
2472 assert!(matches!(
2473 result,
2474 Err(TopologyError::InvalidFieldReference { .. })
2475 ));
2476 if let Err(TopologyError::InvalidFieldReference {
2477 index,
2478 field_count,
2479 context,
2480 ..
2481 }) = result
2482 {
2483 assert_eq!(index, 5);
2484 assert_eq!(field_count, 2);
2485 assert_eq!(context, "replace_field_indices");
2486 }
2487 }
2488
2489 #[test]
2490 fn test_compile_calls_validate_keys_and_catches_invalid_replace_after_finish_field_index() {
2491 let mut graph: StateGraph<StateWithBadAfterFinishIndex> = StateGraph::new();
2492 graph
2493 .add_node_simple(
2494 "node_a",
2495 NodeFnUpdate(|_s: &StateWithBadAfterFinishIndex| -> BoxResult<_> {
2496 Box::pin(async move { Ok(StateWithBadAfterFinishIndexUpdate::default()) })
2497 }),
2498 )
2499 .unwrap();
2500 graph.set_entry_point("node_a");
2501 graph.set_finish_point("node_a");
2502
2503 let result = graph.compile();
2505 assert!(matches!(
2506 result,
2507 Err(TopologyError::InvalidFieldReference { .. })
2508 ));
2509 if let Err(TopologyError::InvalidFieldReference {
2510 index,
2511 field_count,
2512 context,
2513 ..
2514 }) = result
2515 {
2516 assert_eq!(index, 99);
2517 assert_eq!(field_count, 2);
2518 assert_eq!(context, "replace_after_finish_field_indices");
2519 }
2520 }
2521
2522 #[test]
2523 fn test_validate_keys_validates_reducer_indices_during_compile() {
2524 let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
2530 graph
2531 .add_node_simple(
2532 "process",
2533 NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
2534 Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
2535 }),
2536 )
2537 .unwrap();
2538 graph.set_entry_point("process");
2539 graph.set_finish_point("process");
2540
2541 let validate_result = graph.validate_keys();
2543 assert!(
2544 validate_result.is_err(),
2545 "validate_keys should detect invalid field index"
2546 );
2547
2548 let compile_result = graph.compile();
2550 assert!(
2551 compile_result.is_err(),
2552 "compile should detect invalid field index"
2553 );
2554
2555 match (validate_result, compile_result) {
2557 (
2558 Err(TopologyError::InvalidFieldReference { index: v_idx, .. }),
2559 Err(TopologyError::InvalidFieldReference { index: c_idx, .. }),
2560 ) => {
2561 assert_eq!(
2562 v_idx, c_idx,
2563 "Both methods should report the same invalid index"
2564 );
2565 }
2566 _ => panic!("Both methods should return InvalidFieldReference error"),
2567 }
2568 }
2569
2570 #[derive(Clone, Debug, Default)]
2571 struct StateDummy;
2572
2573 impl crate::State for StateDummy {
2574 type Update = StateDummyUpdate;
2575 type FieldVersions = crate::state::FieldVersions;
2576
2577 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
2578 crate::FieldsChanged(0)
2579 }
2580
2581 fn reset_ephemeral(&mut self) {}
2582 }
2583
2584 #[derive(Clone, Debug, Default)]
2585 struct StateDummyUpdate;
2586
2587 #[tokio::test]
2590 async fn test_execute_with_retry_succeeds_first_attempt() {
2591 let policy = RetryPolicy {
2592 max_attempts: 3,
2593 initial_interval: std::time::Duration::from_millis(1),
2594 backoff_factor: 2.0,
2595 max_interval: std::time::Duration::from_secs(1),
2596 jitter: false,
2597 retry_on: None,
2598 };
2599 let config = crate::RunnableConfig::new();
2600
2601 let result = execute_with_retry(
2602 "test_node",
2603 &policy,
2604 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2605 Box::pin(async { Ok(crate::Command::end()) })
2606 },
2607 &StateDummy,
2608 &config,
2609 )
2610 .await;
2611
2612 result.unwrap();
2613 }
2614
2615 #[tokio::test]
2616 async fn test_execute_with_retry_succeeds_after_retries() {
2617 let policy = RetryPolicy {
2618 max_attempts: 3,
2619 initial_interval: std::time::Duration::from_millis(1),
2620 backoff_factor: 2.0,
2621 max_interval: std::time::Duration::from_secs(1),
2622 jitter: false,
2623 retry_on: None,
2624 };
2625 let config = crate::RunnableConfig::new();
2626
2627 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2628 let attempt_clone = Arc::clone(&attempt_count);
2629
2630 let result = execute_with_retry(
2631 "test_node",
2632 &policy,
2633 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2634 let counter = Arc::clone(&attempt_clone);
2635 Box::pin(async move {
2636 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2637 if n < 2 {
2638 Err(crate::JunctureError::execution("transient failure"))
2639 } else {
2640 Ok(crate::Command::end())
2641 }
2642 })
2643 },
2644 &StateDummy,
2645 &config,
2646 )
2647 .await;
2648
2649 result.unwrap();
2650 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
2651 }
2652
2653 #[tokio::test]
2654 async fn test_execute_with_retry_exhausts_attempts() {
2655 let policy = RetryPolicy {
2656 max_attempts: 3,
2657 initial_interval: std::time::Duration::from_millis(1),
2658 backoff_factor: 2.0,
2659 max_interval: std::time::Duration::from_secs(1),
2660 jitter: false,
2661 retry_on: None,
2662 };
2663 let config = crate::RunnableConfig::new();
2664
2665 let result = execute_with_retry(
2666 "test_node",
2667 &policy,
2668 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2669 Box::pin(async { Err(crate::JunctureError::execution("always fails")) })
2670 },
2671 &StateDummy,
2672 &config,
2673 )
2674 .await;
2675
2676 assert!(result.is_err());
2677 assert!(result.unwrap_err().is_execution());
2678 }
2679
2680 #[tokio::test]
2681 async fn test_execute_with_retry_does_not_retry_cancelled() {
2682 let policy = RetryPolicy {
2683 max_attempts: 3,
2684 initial_interval: std::time::Duration::from_millis(1),
2685 backoff_factor: 2.0,
2686 max_interval: std::time::Duration::from_secs(1),
2687 jitter: false,
2688 retry_on: None,
2689 };
2690 let config = crate::RunnableConfig::new();
2691
2692 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2693 let attempt_clone = Arc::clone(&attempt_count);
2694
2695 let result = execute_with_retry(
2696 "test_node",
2697 &policy,
2698 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2699 let counter = Arc::clone(&attempt_clone);
2700 Box::pin(async move {
2701 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2702 Err(crate::JunctureError::cancelled())
2703 })
2704 },
2705 &StateDummy,
2706 &config,
2707 )
2708 .await;
2709
2710 assert!(result.is_err());
2711 assert!(result.unwrap_err().is_cancelled());
2712 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2714 }
2715
2716 #[tokio::test]
2717 async fn test_execute_with_retry_does_not_retry_interrupt() {
2718 let policy = RetryPolicy {
2719 max_attempts: 3,
2720 initial_interval: std::time::Duration::from_millis(1),
2721 backoff_factor: 2.0,
2722 max_interval: std::time::Duration::from_secs(1),
2723 jitter: false,
2724 retry_on: None,
2725 };
2726 let config = crate::RunnableConfig::new();
2727
2728 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2729 let attempt_clone = Arc::clone(&attempt_count);
2730
2731 let result = execute_with_retry(
2732 "test_node",
2733 &policy,
2734 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2735 let counter = Arc::clone(&attempt_clone);
2736 Box::pin(async move {
2737 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2738 Err(crate::JunctureError::interrupt("user input needed"))
2739 })
2740 },
2741 &StateDummy,
2742 &config,
2743 )
2744 .await;
2745
2746 assert!(result.is_err());
2747 assert!(result.unwrap_err().is_interrupt());
2748 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2749 }
2750
2751 #[tokio::test]
2752 async fn test_execute_with_retry_custom_retry_on_predicate() {
2753 let policy = RetryPolicy {
2755 max_attempts: 3,
2756 initial_interval: std::time::Duration::from_millis(1),
2757 backoff_factor: 2.0,
2758 max_interval: std::time::Duration::from_secs(1),
2759 jitter: false,
2760 retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
2761 };
2762 let config = crate::RunnableConfig::new();
2763
2764 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2765 let attempt_clone = Arc::clone(&attempt_count);
2766
2767 let result = execute_with_retry(
2768 "test_node",
2769 &policy,
2770 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2771 let counter = Arc::clone(&attempt_clone);
2772 Box::pin(async move {
2773 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2774 Err(crate::JunctureError::execution("not a timeout"))
2776 })
2777 },
2778 &StateDummy,
2779 &config,
2780 )
2781 .await;
2782
2783 assert!(result.is_err());
2784 assert!(result.unwrap_err().is_execution());
2785 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2787 }
2788
2789 #[tokio::test]
2790 async fn test_execute_with_retry_custom_predicate_allows_retry() {
2791 let policy = RetryPolicy {
2793 max_attempts: 3,
2794 initial_interval: std::time::Duration::from_millis(1),
2795 backoff_factor: 2.0,
2796 max_interval: std::time::Duration::from_secs(1),
2797 jitter: false,
2798 retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
2799 };
2800 let config = crate::RunnableConfig::new();
2801
2802 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2803 let attempt_clone = Arc::clone(&attempt_count);
2804
2805 let result = execute_with_retry(
2806 "test_node",
2807 &policy,
2808 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2809 let counter = Arc::clone(&attempt_clone);
2810 Box::pin(async move {
2811 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2812 if n < 2 {
2813 Err(crate::JunctureError::timeout("timed out"))
2814 } else {
2815 Ok(crate::Command::end())
2816 }
2817 })
2818 },
2819 &StateDummy,
2820 &config,
2821 )
2822 .await;
2823
2824 result.unwrap();
2825 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
2826 }
2827
2828 #[test]
2829 fn test_compute_delay_no_jitter() {
2830 let base = std::time::Duration::from_millis(100);
2831 let max = std::time::Duration::from_secs(10);
2832 let result = compute_delay(base, false, max);
2833 assert_eq!(result, std::time::Duration::from_millis(100));
2834 }
2835
2836 #[test]
2837 fn test_compute_delay_caps_at_max() {
2838 let base = std::time::Duration::from_secs(20);
2839 let max = std::time::Duration::from_secs(10);
2840 let result = compute_delay(base, false, max);
2841 assert_eq!(result, std::time::Duration::from_secs(10));
2842 }
2843
2844 #[test]
2845 fn test_compute_delay_with_jitter_stays_within_range() {
2846 let base = std::time::Duration::from_millis(100);
2847 let max = std::time::Duration::from_secs(10);
2848 for _ in 0..100 {
2850 let result = compute_delay(base, true, max);
2851 let millis = result.as_secs_f64() * 1000.0;
2852 assert!(
2854 (75.0..=125.0).contains(&millis),
2855 "jittered delay {millis}ms outside expected range [75, 125]"
2856 );
2857 }
2858 }
2859
2860 #[test]
2861 fn test_compute_delay_jitter_capped_by_max() {
2862 let base = std::time::Duration::from_millis(100);
2863 let max = std::time::Duration::from_millis(50);
2865 for _ in 0..100 {
2866 let result = compute_delay(base, true, max);
2867 assert!(
2868 result <= max,
2869 "jittered delay {result:?} exceeded max {max:?}",
2870 );
2871 }
2872 }
2873
2874 #[test]
2875 fn test_cap_delay_returns_min() {
2876 let delay = std::time::Duration::from_secs(5);
2877 let max = std::time::Duration::from_secs(10);
2878 assert_eq!(cap_delay(delay, max), delay);
2879
2880 let delay_large = std::time::Duration::from_secs(15);
2881 assert_eq!(cap_delay(delay_large, max), max);
2882 }
2883
2884 #[test]
2885 fn test_retry_policy_should_retry_default_allows_execution_errors() {
2886 let policy = RetryPolicy::default();
2887 let error = crate::JunctureError::execution("something went wrong");
2888 assert!(policy.should_retry(&error));
2889 }
2890
2891 #[test]
2892 fn test_retry_policy_should_retry_default_blocks_cancelled() {
2893 let policy = RetryPolicy::default();
2894 let error = crate::JunctureError::cancelled();
2895 assert!(!policy.should_retry(&error));
2896 }
2897
2898 #[test]
2899 fn test_retry_policy_should_retry_default_blocks_interrupt() {
2900 let policy = RetryPolicy::default();
2901 let error = crate::JunctureError::interrupt("waiting for user");
2902 assert!(!policy.should_retry(&error));
2903 }
2904
2905 #[test]
2906 fn test_retry_policy_should_retry_custom_predicate() {
2907 let policy = RetryPolicy {
2908 max_attempts: 3,
2909 initial_interval: std::time::Duration::from_millis(100),
2910 backoff_factor: 2.0,
2911 max_interval: std::time::Duration::from_secs(10),
2912 jitter: false,
2913 retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
2914 };
2915
2916 assert!(policy.should_retry(&crate::JunctureError::timeout("slow")));
2917 assert!(!policy.should_retry(&crate::JunctureError::execution("not timeout")));
2918 }
2919
2920 #[tokio::test]
2921 async fn test_retrying_node_delegates_to_execute_with_retry() {
2922 use crate::node::NodeFnCommand;
2923
2924 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2925 let count_clone = Arc::clone(&call_count);
2926
2927 let inner: Arc<dyn crate::Node<StateDummy>> =
2928 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
2929 let counter = Arc::clone(&count_clone);
2930 Box::pin(async move {
2931 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2932 if n == 0 {
2933 Err(crate::JunctureError::execution("first try fails"))
2934 } else {
2935 Ok(crate::Command::end())
2936 }
2937 })
2938 })
2939 .into_node("inner");
2940
2941 let policy = RetryPolicy {
2942 max_attempts: 3,
2943 initial_interval: std::time::Duration::from_millis(1),
2944 backoff_factor: 2.0,
2945 max_interval: std::time::Duration::from_secs(1),
2946 jitter: false,
2947 retry_on: None,
2948 };
2949
2950 let retrying_node = RetryingNode::new(inner, policy);
2951 let config = crate::RunnableConfig::new();
2952
2953 let result = retrying_node.call(&StateDummy, &config).await;
2954 result.unwrap();
2955 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 2);
2956 }
2957
2958 #[tokio::test]
2959 async fn test_retrying_node_respects_max_attempts() {
2960 use crate::node::NodeFnCommand;
2961
2962 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2963 let count_clone = Arc::clone(&call_count);
2964
2965 let inner: Arc<dyn crate::Node<StateDummy>> =
2966 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
2967 let counter = Arc::clone(&count_clone);
2968 Box::pin(async move {
2969 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2970 Err(crate::JunctureError::execution("always fails"))
2971 })
2972 })
2973 .into_node("inner");
2974
2975 let policy = RetryPolicy {
2976 max_attempts: 5,
2977 initial_interval: std::time::Duration::from_millis(1),
2978 backoff_factor: 2.0,
2979 max_interval: std::time::Duration::from_secs(1),
2980 jitter: false,
2981 retry_on: None,
2982 };
2983
2984 let retrying_node = RetryingNode::new(inner, policy);
2985 let config = crate::RunnableConfig::new();
2986
2987 let result = retrying_node.call(&StateDummy, &config).await;
2988 let err = result.unwrap_err();
2989 assert!(err.is_execution());
2990 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 5);
2991 }
2992
2993 #[tokio::test]
2994 async fn test_retrying_node_with_jitter_enabled() {
2995 use crate::node::NodeFnCommand;
2996
2997 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2998 let count_clone = Arc::clone(&call_count);
2999
3000 let inner: Arc<dyn crate::Node<StateDummy>> =
3001 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
3002 let counter = Arc::clone(&count_clone);
3003 Box::pin(async move {
3004 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
3005 if n < 2 {
3006 Err(crate::JunctureError::execution("retry me"))
3007 } else {
3008 Ok(crate::Command::end())
3009 }
3010 })
3011 })
3012 .into_node("inner");
3013
3014 let policy = RetryPolicy {
3015 max_attempts: 3,
3016 initial_interval: std::time::Duration::from_millis(1),
3017 backoff_factor: 2.0,
3018 max_interval: std::time::Duration::from_secs(1),
3019 jitter: true,
3020 retry_on: None,
3021 };
3022
3023 let retrying_node = RetryingNode::new(inner, policy);
3024 let config = crate::RunnableConfig::new();
3025
3026 let result = retrying_node.call(&StateDummy, &config).await;
3027 result.unwrap();
3028 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 3);
3029 }
3030
3031 #[tokio::test]
3032 async fn test_execute_with_retry_max_interval_capping() {
3033 let policy = RetryPolicy {
3035 max_attempts: 3,
3036 initial_interval: std::time::Duration::from_millis(50),
3037 backoff_factor: 100.0,
3038 max_interval: std::time::Duration::from_millis(80),
3039 jitter: false,
3040 retry_on: None,
3041 };
3042 let config = crate::RunnableConfig::new();
3043
3044 let start = crate::time::Instant::now();
3045 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
3046 let attempt_clone = Arc::clone(&attempt_count);
3047
3048 let result = execute_with_retry(
3049 "test_node",
3050 &policy,
3051 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
3052 let counter = Arc::clone(&attempt_clone);
3053 Box::pin(async move {
3054 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
3055 if n < 2 {
3056 Err(crate::JunctureError::execution("fail"))
3057 } else {
3058 Ok(crate::Command::end())
3059 }
3060 })
3061 },
3062 &StateDummy,
3063 &config,
3064 )
3065 .await;
3066
3067 let elapsed = start.elapsed();
3068 result.unwrap();
3069 assert!(
3072 elapsed < std::time::Duration::from_secs(2),
3073 "max_interval capping should prevent very long waits, elapsed: {elapsed:?}"
3074 );
3075 }
3076
3077 #[tokio::test]
3080 async fn test_execute_with_timeout_succeeds_within_limit() {
3081 let config = crate::RunnableConfig::new();
3082
3083 let result = execute_with_timeout(
3084 "test_node",
3085 std::time::Duration::from_secs(10),
3086 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
3087 Box::pin(async { Ok(crate::Command::end()) })
3088 },
3089 &StateDummy,
3090 &config,
3091 )
3092 .await;
3093
3094 result.unwrap();
3095 }
3096
3097 #[tokio::test]
3098 async fn test_execute_with_timeout_fires_on_slow_node() {
3099 let config = crate::RunnableConfig::new();
3100
3101 let result = execute_with_timeout(
3102 "slow_node",
3103 std::time::Duration::from_millis(10),
3104 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
3105 Box::pin(async {
3106 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
3107 Ok(crate::Command::end())
3108 })
3109 },
3110 &StateDummy,
3111 &config,
3112 )
3113 .await;
3114
3115 let err = result.unwrap_err();
3116 assert!(err.is_node_timeout());
3117 }
3118
3119 #[tokio::test]
3120 async fn test_execute_with_timeout_passes_through_inner_error() {
3121 let config = crate::RunnableConfig::new();
3122
3123 let result = execute_with_timeout(
3124 "failing_node",
3125 std::time::Duration::from_secs(10),
3126 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
3127 Box::pin(async { Err(crate::JunctureError::execution("inner failure")) })
3128 },
3129 &StateDummy,
3130 &config,
3131 )
3132 .await;
3133
3134 let err = result.unwrap_err();
3135 assert!(err.is_execution());
3136 assert!(!err.is_node_timeout());
3137 }
3138
3139 #[tokio::test]
3140 async fn test_timeout_node_wrapper_integration() {
3141 use crate::node::NodeFnCommand;
3142
3143 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
3144 let count_clone = Arc::clone(&call_count);
3145
3146 let inner: Arc<dyn crate::Node<StateDummy>> =
3147 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
3148 let counter = Arc::clone(&count_clone);
3149 Box::pin(async move {
3150 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
3151 Ok(crate::Command::end())
3152 })
3153 })
3154 .into_node("inner");
3155
3156 let policy =
3157 crate::TimeoutPolicy::new().with_run_timeout(std::time::Duration::from_secs(10));
3158
3159 let timeout_node = TimeoutNode::new(inner, policy);
3160 let config = crate::RunnableConfig::new();
3161
3162 let result = timeout_node.call(&StateDummy, &config).await;
3163 result.unwrap();
3164 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
3165 }
3166
3167 #[tokio::test]
3168 async fn test_timeout_node_fires_on_exceeded_duration() {
3169 use crate::node::NodeFnCommand;
3170
3171 let inner: Arc<dyn crate::Node<StateDummy>> =
3172 NodeFnCommand(|_s: &StateDummy| -> BoxResult<_> {
3173 Box::pin(async {
3174 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
3175 Ok(crate::Command::end())
3176 })
3177 })
3178 .into_node("inner");
3179
3180 let policy =
3181 crate::TimeoutPolicy::new().with_run_timeout(std::time::Duration::from_millis(10));
3182
3183 let timeout_node = TimeoutNode::new(inner, policy);
3184 let config = crate::RunnableConfig::new();
3185
3186 let result = timeout_node.call(&StateDummy, &config).await;
3187 let err = result.unwrap_err();
3188 assert!(err.is_node_timeout());
3189 }
3190
3191 #[tokio::test]
3192 async fn test_timeout_node_passes_through_inner_error() {
3193 use crate::node::NodeFnCommand;
3194
3195 let inner: Arc<dyn crate::Node<StateDummy>> =
3196 NodeFnCommand(|_s: &StateDummy| -> BoxResult<_> {
3197 Box::pin(async { Err(crate::JunctureError::execution("node failure")) })
3198 })
3199 .into_node("inner");
3200
3201 let policy =
3202 crate::TimeoutPolicy::new().with_run_timeout(std::time::Duration::from_secs(10));
3203
3204 let timeout_node = TimeoutNode::new(inner, policy);
3205 let config = crate::RunnableConfig::new();
3206
3207 let result = timeout_node.call(&StateDummy, &config).await;
3208 let err = result.unwrap_err();
3209 assert!(err.is_execution());
3210 assert!(!err.is_node_timeout());
3211 }
3212
#[test]
fn circuit_breaker_config_new() {
    let cooldown = std::time::Duration::from_secs(30);
    let cfg = CircuitBreakerConfig::new(5, cooldown);

    // Both constructor arguments are stored as given; the half-open probe
    // budget starts at a single attempt.
    assert_eq!(cfg.failure_threshold, 5);
    assert_eq!(cfg.cooldown_duration, cooldown);
    assert_eq!(cfg.half_open_max_attempts, 1);
}
3222
#[test]
fn circuit_breaker_config_with_half_open_max_attempts() {
    // The builder method overrides the default half-open probe budget.
    let tuned = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10))
        .with_half_open_max_attempts(3);
    let budget = tuned.half_open_max_attempts;
    assert_eq!(budget, 3);
}
3229
#[test]
fn circuit_breaker_config_debug() {
    // The Debug rendering names the type and includes the threshold value.
    let rendered = format!(
        "{:?}",
        CircuitBreakerConfig::new(5, std::time::Duration::from_secs(30))
    );
    assert!(rendered.contains("CircuitBreakerConfig"));
    assert!(rendered.contains('5'));
}
3237
#[test]
fn circuit_breaker_state_new_is_closed() {
    // A freshly constructed breaker is closed with a clean failure count.
    let fresh = CircuitBreakerState::new();
    assert_eq!(fresh.state(), &CircuitState::Closed);
    assert_eq!(fresh.consecutive_failures(), 0);
}
3246
#[test]
fn circuit_breaker_state_default_is_closed() {
    // `Default` must agree with `new()` on the initial state.
    let fresh: CircuitBreakerState = Default::default();
    assert_eq!(fresh.state(), &CircuitState::Closed);
}
3252
#[test]
fn circuit_breaker_closed_allows_execution() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let mut breaker = CircuitBreakerState::new();

    // A closed breaker lets calls through.
    let admitted = breaker.should_allow(&cfg);
    assert!(admitted);
}
3259
#[test]
fn circuit_breaker_opens_after_threshold() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let mut breaker = CircuitBreakerState::new();

    // Failures below the threshold of 3 leave the breaker closed and admitting.
    for _ in 0..2 {
        breaker.record_failure(&cfg);
        assert_eq!(breaker.state(), &CircuitState::Closed);
        assert!(breaker.should_allow(&cfg));
    }

    // The third consecutive failure trips it open, and calls are rejected.
    breaker.record_failure(&cfg);
    assert_eq!(breaker.state(), &CircuitState::Open);
    assert!(!breaker.should_allow(&cfg));
}
3277
#[test]
fn circuit_breaker_resets_on_success() {
    let cfg = CircuitBreakerConfig::new(3, std::time::Duration::from_secs(10));
    let mut breaker = CircuitBreakerState::new();

    // Accumulate failures that stay below the threshold.
    for _ in 0..2 {
        breaker.record_failure(&cfg);
    }
    assert_eq!(breaker.consecutive_failures(), 2);

    // A single success wipes the streak and leaves the breaker closed.
    breaker.record_success();
    assert_eq!(breaker.state(), &CircuitState::Closed);
    assert_eq!(breaker.consecutive_failures(), 0);
}
3291
#[test]
fn circuit_breaker_half_open_after_cooldown() {
    // Threshold 1 with a zero cooldown: one failure opens the breaker and the
    // cooldown has already elapsed by the next check.
    let cfg = CircuitBreakerConfig::new(1, std::time::Duration::ZERO);
    let mut breaker = CircuitBreakerState::new();

    breaker.record_failure(&cfg);
    assert_eq!(breaker.state(), &CircuitState::Open);

    // The post-cooldown check is admitted and moves the breaker to half-open.
    let admitted = breaker.should_allow(&cfg);
    assert!(admitted);
    assert_eq!(breaker.state(), &CircuitState::HalfOpen);
}
3304
#[test]
fn circuit_breaker_half_open_success_closes() {
    // Threshold 1 + zero cooldown: a single failure opens the breaker, and the
    // next `should_allow` immediately moves it to half-open.
    let config = CircuitBreakerConfig::new(1, std::time::Duration::from_millis(0));
    let mut state = CircuitBreakerState::new();

    state.record_failure(&config);
    assert_eq!(*state.state(), CircuitState::Open);

    // Do not discard the probe decision. The success below only tests the
    // half-open -> closed transition if the breaker actually admitted a
    // half-open trial call.
    assert!(state.should_allow(&config));
    assert_eq!(*state.state(), CircuitState::HalfOpen);

    state.record_success();
    assert_eq!(*state.state(), CircuitState::Closed);
}
3315
#[test]
fn circuit_breaker_half_open_failure_reopens() {
    // Threshold 1 + zero cooldown: a single failure opens the breaker, and the
    // next `should_allow` immediately moves it to half-open.
    let config = CircuitBreakerConfig::new(1, std::time::Duration::from_millis(0));
    let mut state = CircuitBreakerState::new();

    state.record_failure(&config);
    assert_eq!(*state.state(), CircuitState::Open);

    // Assert the precondition instead of discarding it. Otherwise this test
    // could pass without ever exercising the half-open -> open path.
    assert!(state.should_allow(&config));
    assert_eq!(*state.state(), CircuitState::HalfOpen);

    // A failed trial call sends the breaker straight back to open.
    state.record_failure(&config);
    assert_eq!(*state.state(), CircuitState::Open);
}
3326
#[test]
fn circuit_breaker_half_open_limits_attempts() {
    // Threshold 1 with a zero cooldown, and a half-open budget of 2 trial
    // attempts instead of the default of 1.
    let config = CircuitBreakerConfig::new(1, std::time::Duration::from_millis(0))
        .with_half_open_max_attempts(2);
    let mut state = CircuitBreakerState::new();

    // One failure is enough to open the breaker.
    state.record_failure(&config);

    // Cooldown is zero, so the first check is admitted and moves the breaker
    // to HalfOpen.
    assert!(state.should_allow(&config));
    assert_eq!(*state.state(), CircuitState::HalfOpen);

    // A second check is still admitted while half-open, then a trial attempt
    // is recorded explicitly.
    // NOTE(review): the exact accounting (whether the Open -> HalfOpen
    // transition itself consumes one of the 2 attempts, and whether
    // `should_allow` in HalfOpen increments the counter) lives in
    // CircuitBreakerState. Confirm against its implementation; this test
    // depends on the precise call order below.
    assert!(state.should_allow(&config));
    state.mark_half_open_attempt();

    // Once the configured budget of 2 is reached, further calls are rejected.
    assert!(!state.should_allow(&config));
}
3346
#[test]
fn circuit_breaker_open_blocks_until_cooldown() {
    // A one-minute cooldown cannot elapse during the test, so the breaker
    // must stay open and reject calls.
    let cfg = CircuitBreakerConfig::new(1, std::time::Duration::from_secs(60));
    let mut breaker = CircuitBreakerState::new();

    breaker.record_failure(&cfg);
    assert_eq!(breaker.state(), &CircuitState::Open);

    let admitted = breaker.should_allow(&cfg);
    assert!(!admitted);
}
3356
#[test]
fn node_metadata_default_has_no_circuit_breaker() {
    // Circuit breaking is opt-in: default metadata carries no breaker config.
    assert!(NodeMetadata::default().circuit_breaker.is_none());
}
3364
#[test]
fn node_metadata_with_circuit_breaker() {
    // Struct-update syntax sets only the breaker; other fields keep their defaults.
    let meta = NodeMetadata {
        circuit_breaker: Some(CircuitBreakerConfig::new(
            5,
            std::time::Duration::from_secs(30),
        )),
        ..NodeMetadata::default()
    };

    // Single checked access instead of `is_some()` followed by `unwrap()`.
    let breaker = meta
        .circuit_breaker
        .as_ref()
        .expect("circuit_breaker should be set by the struct literal");
    assert_eq!(breaker.failure_threshold, 5);
    assert_eq!(breaker.cooldown_duration, std::time::Duration::from_secs(30));
}
3377}
3378
3379