1use super::{compiled::CompiledGraph, topology::TopologyError, topology::TopologyValidator};
7use crate::{
8 State,
9 edge::{CompiledEdge, END, Edge, START, TriggerSource},
10 node::IntoNode,
11 state::{FromState, IntoState},
12};
13use indexmap::IndexMap;
14use std::collections::HashMap;
15use std::sync::Arc;
16
/// Options applied when compiling a [`StateGraph`] into a [`CompiledGraph`].
///
/// Both lists are passed unchanged to `CompiledGraph::new` by the builder's compile methods.
#[derive(Clone, Debug, Default)]
pub struct CompileConfig {
    /// Entries at which execution is interrupted *before* running.
    /// NOTE(review): presumably node names — confirm against `CompiledGraph`.
    pub interrupt_before: Vec<String>,

    /// Entries at which execution is interrupted *after* running.
    /// NOTE(review): presumably node names — confirm against `CompiledGraph`.
    pub interrupt_after: Vec<String>,
}
50
/// Per-node settings recorded by the [`StateGraph`] builder and handed to the compiled graph.
#[derive(Clone, Debug, Default)]
pub struct NodeMetadata {
    /// Deferral flag supplied to `add_node`.
    /// NOTE(review): its exact semantics are decided by the executor, not visible here.
    pub defer: bool,

    /// Optional free-form JSON metadata attached to the node.
    pub metadata: Option<HashMap<String, serde_json::Value>>,

    /// Optional list of destinations declared for the node.
    /// NOTE(review): presumably the node names it may route to — confirm.
    pub destinations: Option<Vec<String>>,

    /// Retry policies supplied to `add_node`. Recording them here does not by itself wrap
    /// the node; `add_node_with_retry` wraps the node in a `RetryingNode` instead.
    pub retry_policies: Vec<RetryPolicy>,

    /// Optional error-handler reference. Every builder method in this module sets it to
    /// `None`.
    pub error_handler: Option<String>,

    /// Timeout policies supplied to `add_node`.
    pub timeout_policies: Vec<crate::TimeoutPolicy>,
}
83
/// How a failing node is retried (see `execute_with_retry`).
///
/// `Debug` is implemented by hand because `retry_on` holds a closure.
#[derive(Clone)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first one.
    pub max_attempts: u32,

    /// Delay before the first retry.
    pub initial_interval: std::time::Duration,

    /// Multiplier applied to the delay after each retry.
    pub backoff_factor: f64,

    /// Upper bound on any single delay, applied both before and after jitter.
    pub max_interval: std::time::Duration,

    /// When true, each delay is scaled by a random factor in `0.75..=1.25` (then capped).
    pub jitter: bool,

    /// Optional predicate deciding whether an error is retryable. When `None`, every error
    /// except cancellations and interrupts is retried.
    #[allow(
        clippy::type_complexity,
        reason = "trait object requires full signature"
    )]
    pub retry_on: Option<Arc<dyn Fn(&crate::JunctureError) -> bool + Send + Sync>>,
}
112
113impl std::fmt::Debug for RetryPolicy {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("RetryPolicy")
116 .field("max_attempts", &self.max_attempts)
117 .field("initial_interval", &self.initial_interval)
118 .field("backoff_factor", &self.backoff_factor)
119 .field("max_interval", &self.max_interval)
120 .field("jitter", &self.jitter)
121 .field("retry_on", &self.retry_on.as_ref().map(|_| "<fn>"))
122 .finish()
123 }
124}
125
126impl Default for RetryPolicy {
127 fn default() -> Self {
128 Self {
129 max_attempts: 3,
130 initial_interval: std::time::Duration::from_millis(500),
131 backoff_factor: 2.0,
132 max_interval: std::time::Duration::from_secs(10),
133 jitter: true,
134 retry_on: None,
135 }
136 }
137}
138
/// Details of a failed node call, handed to an [`ErrorHandlerNode`]'s handler.
pub struct NodeError<S: State> {
    /// Name of the node that failed.
    pub node: String,

    /// The error the node returned.
    pub error: crate::JunctureError,

    /// Copy of the state the node was called with (taken before the call).
    pub state: S,

    /// Attempt number of the failed call; `ErrorHandlerNode` always reports `1`.
    pub attempt: u32,
}
156
157impl<S: State> std::fmt::Debug for NodeError<S> {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("NodeError")
160 .field("node", &self.node)
161 .field("error", &self.error)
162 .field("state", &"<state>")
163 .field("attempt", &self.attempt)
164 .finish()
165 }
166}
167
/// Node wrapper that converts failures of the wrapped node into a [`crate::Command`] by
/// calling a user-supplied handler.
pub struct ErrorHandlerNode<S: State> {
    /// The node being wrapped.
    inner: Arc<dyn crate::Node<S>>,

    /// Called with a [`NodeError`] when `inner` fails. The command it returns becomes the
    /// wrapper's successful result.
    #[allow(
        clippy::type_complexity,
        reason = "trait object requires full signature"
    )]
    handler: Arc<dyn Fn(NodeError<S>) -> crate::Command<S> + Send + Sync>,

    /// Name reported by the wrapper; copied from `inner` at construction.
    name: String,
}
193
194impl<S: State> std::fmt::Debug for ErrorHandlerNode<S> {
195 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
196 f.debug_struct("ErrorHandlerNode")
197 .field("name", &self.name)
198 .field("inner", &"<node>")
199 .field("handler", &"<fn>")
200 .finish()
201 }
202}
203
204impl<S: State> ErrorHandlerNode<S> {
205 #[allow(
213 clippy::type_complexity,
214 reason = "trait object requires full signature"
215 )]
216 pub fn new(
217 inner: Arc<dyn crate::Node<S>>,
218 handler: Arc<dyn Fn(NodeError<S>) -> crate::Command<S> + Send + Sync>,
219 ) -> Self {
220 let name = inner.name().to_string();
221 Self {
222 inner,
223 handler,
224 name,
225 }
226 }
227}
228
229impl<S: State + Clone> crate::Node<S> for ErrorHandlerNode<S> {
230 fn call(
231 &self,
232 state: &S,
233 config: &crate::RunnableConfig,
234 ) -> std::pin::Pin<
235 Box<
236 dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
237 + Send
238 + '_,
239 >,
240 > {
241 let state_backup = state.clone();
244 let result = self.inner.call(state, config);
245 let handler = Arc::clone(&self.handler);
246 let node_name = self.name.clone();
247 Box::pin(async move {
248 match result.await {
249 Ok(command) => Ok(command),
250 Err(error) => {
251 let node_error = NodeError {
253 node: node_name,
254 error,
255 state: state_backup,
256 attempt: 1, };
258 Ok(handler(node_error))
259 }
260 }
261 })
262 }
263
264 fn name(&self) -> &str {
265 &self.name
266 }
267}
268
/// Node wrapper that retries the wrapped node according to a [`RetryPolicy`].
pub struct RetryingNode<S: State> {
    /// The node being wrapped.
    inner: Arc<dyn crate::Node<S>>,

    /// Retry behaviour applied to every call.
    policy: RetryPolicy,

    /// Name reported by the wrapper; copied from `inner` at construction.
    name: String,
}
283
284impl<S: State> std::fmt::Debug for RetryingNode<S> {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 f.debug_struct("RetryingNode")
287 .field("name", &self.name)
288 .field("inner", &"<node>")
289 .field("policy", &self.policy)
290 .finish()
291 }
292}
293
294impl<S: State> RetryingNode<S> {
295 #[must_use]
302 pub fn new(inner: Arc<dyn crate::Node<S>>, policy: RetryPolicy) -> Self {
303 let name = inner.name().to_string();
304 Self {
305 inner,
306 policy,
307 name,
308 }
309 }
310}
311
312impl<S: State + Clone> crate::Node<S> for RetryingNode<S> {
313 fn call(
314 &self,
315 state: &S,
316 config: &crate::RunnableConfig,
317 ) -> std::pin::Pin<
318 Box<
319 dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
320 + Send
321 + '_,
322 >,
323 > {
324 let policy = self.policy.clone();
325 let inner = Arc::clone(&self.inner);
326 let config = config.clone();
327 let node_name = self.name.clone();
328 let state_owned = state.clone();
329
330 Box::pin(async move {
331 execute_with_retry(
332 &node_name,
333 &policy,
334 |s, cfg| inner.call(s, cfg),
335 &state_owned,
336 &config,
337 )
338 .await
339 })
340 }
341
342 fn name(&self) -> &str {
343 &self.name
344 }
345}
346
347pub async fn execute_with_retry<S, F, Fut>(
382 node_name: &str,
383 policy: &RetryPolicy,
384 operation: F,
385 state: &S,
386 config: &crate::RunnableConfig,
387) -> Result<crate::Command<S>, crate::JunctureError>
388where
389 S: State,
390 F: Fn(&S, &crate::RunnableConfig) -> Fut,
391 Fut: std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>,
392{
393 let mut last_error: Option<crate::JunctureError> = None;
394 let mut delay = policy.initial_interval;
395
396 for attempt in 0..policy.max_attempts {
397 match operation(state, config).await {
398 Ok(command) => {
399 if attempt > 0 {
400 tracing::debug!(
401 node_name = node_name,
402 attempt = attempt + 1,
403 "node succeeded after retry"
404 );
405 }
406 return Ok(command);
407 }
408 Err(error) => {
409 let should_retry = policy.should_retry(&error);
410
411 if !should_retry || attempt + 1 >= policy.max_attempts {
412 return Err(error);
413 }
414
415 tracing::warn!(
416 node_name = node_name,
417 attempt = attempt + 1,
418 max_attempts = policy.max_attempts,
419 error = %error,
420 "node failed, will retry"
421 );
422
423 last_error = Some(error);
424
425 let actual_delay = compute_delay(delay, policy.jitter, policy.max_interval);
426 tokio::time::sleep(actual_delay).await;
427
428 delay = cap_delay(delay.mul_f64(policy.backoff_factor), policy.max_interval);
429 }
430 }
431 }
432
433 Err(last_error.unwrap_or_else(|| {
434 crate::JunctureError::execution(format!(
435 "node '{node_name}': retry policy exhausted with no error recorded"
436 ))
437 }))
438}
439
440fn compute_delay(
446 base: std::time::Duration,
447 jitter: bool,
448 max_interval: std::time::Duration,
449) -> std::time::Duration {
450 let capped = cap_delay(base, max_interval);
451
452 if !jitter {
453 return capped;
454 }
455
456 let jitter_fraction: f64 = rand::random_range(0.75..=1.25);
458 let jittered = capped.mul_f64(jitter_fraction);
459 cap_delay(jittered, max_interval)
460}
461
/// Clamps `delay` so it never exceeds `max`.
fn cap_delay(delay: std::time::Duration, max: std::time::Duration) -> std::time::Duration {
    if delay <= max { delay } else { max }
}
466
467impl RetryPolicy {
468 fn should_retry(&self, error: &crate::JunctureError) -> bool {
474 self.retry_on.as_ref().map_or_else(
475 || !error.is_cancelled() && !error.is_interrupt(),
476 |predicate| predicate(error),
477 )
478 }
479}
480
/// Node wrapper that bounds each call of the wrapped node by a run timeout.
pub struct TimeoutNode<S: State> {
    /// The node being wrapped.
    inner: Arc<dyn crate::Node<S>>,

    /// Timeout settings; `call` uses only its `run_timeout`.
    policy: crate::TimeoutPolicy,

    /// Name reported by the wrapper; copied from `inner` at construction.
    name: String,
}
496
497impl<S: State> std::fmt::Debug for TimeoutNode<S> {
498 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
499 f.debug_struct("TimeoutNode")
500 .field("name", &self.name)
501 .field("inner", &"<node>")
502 .field("policy", &self.policy)
503 .finish()
504 }
505}
506
507impl<S: State> TimeoutNode<S> {
508 #[must_use]
515 pub fn new(inner: Arc<dyn crate::Node<S>>, policy: crate::TimeoutPolicy) -> Self {
516 let name = inner.name().to_string();
517 Self {
518 inner,
519 policy,
520 name,
521 }
522 }
523}
524
525impl<S: State + Clone> crate::Node<S> for TimeoutNode<S> {
526 fn call(
527 &self,
528 state: &S,
529 config: &crate::RunnableConfig,
530 ) -> std::pin::Pin<
531 Box<
532 dyn std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>
533 + Send
534 + '_,
535 >,
536 > {
537 let inner = Arc::clone(&self.inner);
538 let config = config.clone();
539 let node_name = self.name.clone();
540 let run_timeout = self.policy.run_timeout;
541
542 let state_cloned = state.clone();
543 Box::pin(async move {
544 execute_with_timeout(
545 &node_name,
546 run_timeout,
547 |s, cfg| inner.call(s, cfg),
548 &state_cloned,
549 &config,
550 )
551 .await
552 })
553 }
554
555 fn name(&self) -> &str {
556 &self.name
557 }
558}
559
560pub async fn execute_with_timeout<S, F, Fut>(
594 node_name: &str,
595 run_timeout: std::time::Duration,
596 operation: F,
597 state: &S,
598 config: &crate::RunnableConfig,
599) -> Result<crate::Command<S>, crate::JunctureError>
600where
601 S: State,
602 F: FnOnce(&S, &crate::RunnableConfig) -> Fut,
603 Fut: std::future::Future<Output = Result<crate::Command<S>, crate::JunctureError>>,
604{
605 let result = tokio::time::timeout(run_timeout, operation(state, config)).await;
606
607 match result {
608 Ok(Ok(command)) => Ok(command),
609 Ok(Err(error)) => Err(error),
610 Err(_) => Err(crate::JunctureError::node_timeout(
611 crate::error::NodeTimeoutError::RunTimeout {
612 node: node_name.to_string(),
613 timeout: u64::try_from(run_timeout.as_millis()).unwrap_or(u64::MAX),
614 },
615 )),
616 }
617}
618
/// Builder for a graph of nodes operating on state type `S`.
///
/// Nodes and edges are accumulated here and turned into a [`CompiledGraph`] by the
/// `compile*` methods. `I` and `O` are the graph's input and output types (bounded by
/// [`IntoState<S>`] and [`FromState<S>`]); both default to `S`.
pub struct StateGraph<S: State, I: IntoState<S> = S, O: FromState<S> = S> {
    /// Registered nodes, keyed by name, in insertion order.
    nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,

    /// Every edge added so far, including the `START`/`END` edges created by
    /// `set_entry_point` / `set_finish_point`.
    edges: Vec<Edge<S>>,

    /// Entry node, set by `set_entry_point` (or by `add_sequence` when still unset).
    entry_point: Option<String>,

    /// Nodes marked as finish points via `set_finish_point`.
    finish_points: Vec<String>,

    /// Builder metadata for every registered node, keyed by node name.
    builder_metadata: IndexMap<String, NodeMetadata>,

    /// Subgraph mounts registered through `add_subgraph`.
    subgraphs: Vec<crate::subgraph::SubgraphMount<S>>,

    /// Type marker for the input type `I`.
    _input: std::marker::PhantomData<I>,
    /// Type marker for the output type `O`.
    _output: std::marker::PhantomData<O>,
}
666
667impl<S: State, I: IntoState<S>, O: FromState<S>> std::fmt::Debug for StateGraph<S, I, O> {
668 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
669 f.debug_struct("StateGraph")
670 .field("nodes", &format_args!("{} nodes", self.nodes.len()))
671 .field("edges", &format_args!("{} edges", self.edges.len()))
672 .field("entry_point", &self.entry_point)
673 .field("finish_points", &self.finish_points)
674 .field("builder_metadata", &self.builder_metadata)
675 .field(
676 "subgraphs",
677 &format_args!("{} subgraphs", self.subgraphs.len()),
678 )
679 .finish()
680 }
681}
682
683impl<S: State, I: IntoState<S>, O: FromState<S>> StateGraph<S, I, O> {
684 #[must_use]
686 pub fn new() -> Self {
687 Self {
688 nodes: IndexMap::new(),
689 edges: Vec::new(),
690 entry_point: None,
691 finish_points: Vec::new(),
692 builder_metadata: IndexMap::new(),
693 subgraphs: Vec::new(),
694 _input: std::marker::PhantomData,
695 _output: std::marker::PhantomData,
696 }
697 }
698
699 #[expect(
711 clippy::too_many_arguments,
712 reason = "add_node requires name, node, defer, metadata, destinations, retry_policies, and timeout_policies. All are necessary for the builder pattern."
713 )]
714 pub fn add_node(
715 &mut self,
716 name: impl Into<String>,
717 node: impl IntoNode<S>,
718 defer: bool,
719 metadata: Option<HashMap<String, serde_json::Value>>,
720 destinations: Option<Vec<String>>,
721 retry_policies: Vec<RetryPolicy>,
722 timeout_policies: Vec<crate::TimeoutPolicy>,
723 ) -> Result<&mut Self, TopologyError> {
724 let name = name.into();
725 if self.nodes.contains_key(&name) {
726 return Err(TopologyError::DuplicateNode { name });
727 }
728
729 let node_arc = node.into_node(&name);
730 self.nodes.insert(name.clone(), node_arc);
731
732 self.builder_metadata.insert(
733 name,
734 NodeMetadata {
735 defer,
736 metadata,
737 destinations,
738 retry_policies,
739 error_handler: None,
740 timeout_policies,
741 },
742 );
743
744 Ok(self)
745 }
746
747 pub fn add_node_simple(
762 &mut self,
763 name: impl Into<String>,
764 node: impl IntoNode<S>,
765 ) -> Result<&mut Self, TopologyError> {
766 self.add_node(name, node, false, None, None, Vec::new(), Vec::new())
767 }
768
769 #[allow(
789 clippy::type_complexity,
790 reason = "trait object requires full signature"
791 )]
792 pub fn add_node_with_error_handler(
793 &mut self,
794 name: impl Into<String>,
795 node: impl IntoNode<S>,
796 handler: Arc<dyn Fn(super::builder::NodeError<S>) -> crate::Command<S> + Send + Sync>,
797 ) -> Result<&mut Self, TopologyError>
798 where
799 S: Clone,
800 {
801 let name_str = name.into();
802 let inner = node.into_node(&name_str);
803 let wrapped: Arc<dyn crate::Node<S>> = Arc::new(ErrorHandlerNode::new(inner, handler));
804
805 if self.nodes.contains_key(&name_str) {
806 return Err(TopologyError::DuplicateNode { name: name_str });
807 }
808
809 self.nodes.insert(name_str.clone(), wrapped);
810 self.builder_metadata
811 .insert(name_str, NodeMetadata::default());
812
813 Ok(self)
814 }
815
816 pub fn add_node_with_retry(
833 &mut self,
834 name: impl Into<String>,
835 node: impl IntoNode<S>,
836 policy: RetryPolicy,
837 ) -> Result<&mut Self, TopologyError>
838 where
839 S: Clone,
840 {
841 let name_str = name.into();
842 let inner = node.into_node(&name_str);
843 let wrapped: Arc<dyn crate::Node<S>> = Arc::new(RetryingNode::new(inner, policy));
844
845 if self.nodes.contains_key(&name_str) {
846 return Err(TopologyError::DuplicateNode { name: name_str });
847 }
848
849 self.nodes.insert(name_str.clone(), wrapped);
850 self.builder_metadata
851 .insert(name_str, NodeMetadata::default());
852
853 Ok(self)
854 }
855
856 pub fn add_subgraph(
871 &mut self,
872 mount: crate::subgraph::SubgraphMount<S>,
873 ) -> Result<&mut Self, TopologyError> {
874 if self.nodes.contains_key(&mount.name) {
875 return Err(TopologyError::DuplicateNode {
876 name: mount.name.clone(),
877 });
878 }
879
880 let name = mount.name.clone();
881 let node = Arc::clone(&mount.node);
882 self.nodes.insert(name.clone(), node);
883 self.builder_metadata.insert(name, NodeMetadata::default());
884 self.subgraphs.push(mount);
885
886 Ok(self)
887 }
888
889 #[allow(
911 dead_code,
912 reason = "fully implemented public API awaiting external consumers"
913 )]
914 pub fn add_subgraph_node<Sub>(
915 &mut self,
916 name: &str,
917 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
918 ) -> Result<&mut Self, TopologyError>
919 where
920 Sub: crate::subgraph::StateSubset<S>
921 + State
922 + Clone
923 + serde::Serialize
924 + for<'de> serde::Deserialize<'de>,
925 Sub::Update: serde::Serialize,
926 S: Clone,
927 {
928 let input_map = Arc::new(move |parent: &S| Sub::extract(parent));
932 let output_map = Arc::new(|_sub_output: &Sub| Sub::map_update(Default::default()));
933
934 let node: Arc<dyn crate::Node<S>> = Arc::new(crate::subgraph::SubgraphNode::new(
936 subgraph,
937 name.to_string(),
938 input_map,
939 output_map,
940 crate::subgraph::SubgraphConfig::default(),
941 ));
942
943 if self.nodes.contains_key(name) {
944 return Err(TopologyError::DuplicateNode {
945 name: name.to_string(),
946 });
947 }
948
949 self.nodes.insert(name.to_string(), node);
950 self.builder_metadata
951 .insert(name.to_string(), NodeMetadata::default());
952
953 Ok(self)
954 }
955
956 #[allow(
981 clippy::type_complexity,
982 reason = "requires type erasure for trait object storage"
983 )]
984 pub fn add_subgraph_with_config<Sub>(
985 &mut self,
986 name: &str,
987 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
988 input_map: impl Fn(&S) -> Sub + Send + Sync + 'static,
989 output_map: impl Fn(&Sub) -> S::Update + Send + Sync + 'static,
990 config: crate::subgraph::SubgraphConfig,
991 ) -> Result<&mut Self, TopologyError>
992 where
993 Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
994 Sub::Update: serde::Serialize,
995 S: Clone,
996 {
997 let input_map_arc = Arc::new(input_map);
998 let output_map_arc: Arc<dyn Fn(&Sub) -> S::Update + Send + Sync> = Arc::new(output_map);
999
1000 let node: Arc<dyn crate::Node<S>> = Arc::new(crate::subgraph::SubgraphNode::new(
1002 subgraph,
1003 name.to_string(),
1004 input_map_arc,
1005 output_map_arc,
1006 config,
1007 ));
1008
1009 if self.nodes.contains_key(name) {
1010 return Err(TopologyError::DuplicateNode {
1011 name: name.to_string(),
1012 });
1013 }
1014
1015 self.nodes.insert(name.to_string(), node);
1016 self.builder_metadata
1017 .insert(name.to_string(), NodeMetadata::default());
1018
1019 Ok(self)
1020 }
1021
1022 #[allow(
1042 clippy::type_complexity,
1043 reason = "requires type erasure for trait object storage"
1044 )]
1045 pub fn add_subgraph_explicit<Sub>(
1046 &mut self,
1047 name: &str,
1048 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
1049 input_map: impl Fn(&S) -> Sub + Send + Sync + 'static,
1050 output_map: impl Fn(&Sub) -> S::Update + Send + Sync + 'static,
1051 ) -> Result<&mut Self, TopologyError>
1052 where
1053 Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
1054 Sub::Update: serde::Serialize,
1055 S: Clone,
1056 {
1057 self.add_subgraph_with_config(
1058 name,
1059 subgraph,
1060 input_map,
1061 output_map,
1062 crate::subgraph::SubgraphConfig::default(),
1063 )
1064 }
1065
1066 pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) {
1078 self.edges.push(Edge::Fixed {
1079 from: from.into(),
1080 to: to.into(),
1081 });
1082 }
1083
1084 pub fn add_conditional_edges(
1109 &mut self,
1110 from: impl Into<String>,
1111 router: Arc<dyn crate::edge::Router<S>>,
1112 path_map: crate::edge::PathMap,
1113 ) {
1114 self.edges.push(Edge::Conditional {
1115 from: from.into(),
1116 router,
1117 path_map,
1118 });
1119 }
1120
1121 pub fn set_entry_point(&mut self, node: impl Into<String>) {
1131 let node = node.into();
1132 self.entry_point = Some(node.clone());
1133 self.edges.push(Edge::Fixed {
1134 from: START.to_string(),
1135 to: node,
1136 });
1137 }
1138
1139 pub fn set_finish_point(&mut self, node: impl Into<String>) {
1149 let node = node.into();
1150 self.finish_points.push(node.clone());
1151 self.edges.push(Edge::Fixed {
1152 from: node,
1153 to: END.to_string(),
1154 });
1155 }
1156
1157 pub fn add_sequence(&mut self, nodes: &[impl AsRef<str>]) -> Result<&mut Self, TopologyError> {
1174 if nodes.is_empty() {
1175 return Ok(self);
1176 }
1177
1178 let node_names: Vec<&str> = nodes.iter().map(std::convert::AsRef::as_ref).collect();
1179
1180 for name in &node_names {
1182 if !self.nodes.contains_key(*name) {
1183 return Err(TopologyError::NodeNotFound {
1184 name: (*name).to_string(),
1185 });
1186 }
1187 }
1188
1189 if self.entry_point.is_none() {
1191 self.set_entry_point(node_names[0]);
1192 }
1193
1194 for window in node_names.windows(2) {
1196 self.add_edge(window[0], window[1]);
1197 }
1198
1199 Ok(self)
1200 }
1201
1202 pub fn validate_keys(&self) -> Result<(), TopologyError> {
1219 for name in self.nodes.keys() {
1221 if name.is_empty() {
1222 return Err(TopologyError::InvalidNodeName {
1223 name: name.clone(),
1224 reason: "node name cannot be empty".to_string(),
1225 });
1226 }
1227
1228 if name.contains(':') || name.contains('/') || name.contains('\\') {
1230 return Err(TopologyError::InvalidNodeName {
1231 name: name.clone(),
1232 reason: "node name cannot contain ':', '/', or '\\'".to_string(),
1233 });
1234 }
1235 }
1236
1237 if let Some(ref entry) = self.entry_point
1239 && !self.nodes.contains_key(entry)
1240 {
1241 return Err(TopologyError::NodeNotFound {
1242 name: entry.clone(),
1243 });
1244 }
1245
1246 for finish in &self.finish_points {
1248 if !self.nodes.contains_key(finish) {
1249 return Err(TopologyError::NodeNotFound {
1250 name: finish.clone(),
1251 });
1252 }
1253 }
1254
1255 let field_count = S::field_count();
1257 let field_names = S::field_names();
1258
1259 for &idx in S::replace_field_indices() {
1260 if idx >= field_count {
1261 return Err(TopologyError::InvalidFieldReference {
1262 index: idx,
1263 field_count,
1264 field_names,
1265 context: "replace_field_indices".to_string(),
1266 });
1267 }
1268 }
1269
1270 for &idx in S::replace_after_finish_field_indices() {
1271 if idx >= field_count {
1272 return Err(TopologyError::InvalidFieldReference {
1273 index: idx,
1274 field_count,
1275 field_names,
1276 context: "replace_after_finish_field_indices".to_string(),
1277 });
1278 }
1279 }
1280
1281 Ok(())
1282 }
1283
1284 pub fn compile(&self) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1299 self.compile_inner(CompileConfig::default(), None)
1300 }
1301
1302 pub fn compile_with_config(
1324 &self,
1325 config: CompileConfig,
1326 ) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1327 self.compile_inner(config, None)
1328 }
1329
1330 pub fn compile_ephemeral(&self) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1339 self.compile_inner(CompileConfig::default(), None)
1340 }
1341
1342 pub fn compile_with_checkpointer(
1351 &self,
1352 checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
1353 ) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1354 self.compile_inner(CompileConfig::default(), checkpointer)
1355 }
1356
1357 fn compile_inner(
1362 &self,
1363 config: CompileConfig,
1364 checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
1365 ) -> Result<CompiledGraph<S, I, O>, TopologyError> {
1366 TopologyValidator::validate(&self.nodes, &self.edges, self.entry_point.as_deref())?;
1368 self.validate_keys()?;
1369
1370 let trigger_table = self.build_trigger_table();
1372
1373 let subgraph_info: Vec<super::compiled::SubgraphInfo> = self
1375 .subgraphs
1376 .iter()
1377 .map(|mount| super::compiled::SubgraphInfo {
1378 name: mount.name.clone(),
1379 persistence: mount.config.persistence,
1380 })
1381 .collect();
1382
1383 Ok(CompiledGraph::new(
1385 self.nodes.clone(),
1386 trigger_table,
1387 self.builder_metadata.clone(),
1388 config.interrupt_before,
1389 config.interrupt_after,
1390 checkpointer,
1391 subgraph_info,
1392 ))
1393 }
1394
1395 fn build_trigger_table(&self) -> crate::edge::TriggerTable<S> {
1397 let mut trigger_table = crate::edge::TriggerTable::new();
1398
1399 for edge in &self.edges {
1400 match edge {
1401 Edge::Fixed { from, to } => {
1402 if from == START {
1403 trigger_table
1405 .add_incoming(to.clone(), TriggerSource::Edge { from: from.clone() });
1406 } else if to == END {
1407 } else {
1409 trigger_table
1411 .add_outgoing(from.clone(), CompiledEdge::Fixed { target: to.clone() });
1412 trigger_table
1413 .add_incoming(to.clone(), TriggerSource::Edge { from: from.clone() });
1414 }
1415 }
1416 Edge::Conditional {
1417 from,
1418 path_map,
1419 router,
1420 } => {
1421 let router = Arc::clone(router);
1422 let path_map = path_map.clone();
1423
1424 if from == START {
1425 for target in path_map.iter().map(|(_, v)| v) {
1427 trigger_table.add_incoming(
1428 target.clone(),
1429 TriggerSource::Edge { from: from.clone() },
1430 );
1431 }
1432 } else {
1433 trigger_table.add_outgoing(
1435 from.clone(),
1436 CompiledEdge::Conditional {
1437 router,
1438 path_map: path_map.clone(),
1439 },
1440 );
1441
1442 for target in path_map.iter().map(|(_, v)| v) {
1443 trigger_table.add_incoming(
1444 target.clone(),
1445 TriggerSource::Edge { from: from.clone() },
1446 );
1447 }
1448 }
1449 }
1450 }
1451 }
1452
1453 trigger_table
1454 }
1455}
1456
1457impl<S: State, I: IntoState<S>, O: FromState<S>> Default for StateGraph<S, I, O> {
1458 fn default() -> Self {
1459 Self::new()
1460 }
1461}
1462
1463#[cfg(test)]
1464mod tests {
1465 use super::*;
1466 use crate::Node;
1467 use crate::error::JunctureError;
1468 use crate::node::NodeFnUpdate;
1469 use std::pin::Pin;
1470
1471 type BoxResult<T> = Pin<Box<dyn Future<Output = Result<T, JunctureError>> + Send>>;
1473
1474 #[test]
1475 fn test_state_graph_new() {
1476 let graph: StateGraph<StateDummy> = StateGraph::new();
1477 assert!(graph.nodes.is_empty());
1478 assert!(graph.edges.is_empty());
1479 assert!(graph.entry_point.is_none());
1480 assert!(graph.subgraphs.is_empty());
1481 }
1482
1483 #[test]
1484 fn test_add_node_simple() {
1485 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1486 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1487 Box::pin(async move { Ok(StateDummyUpdate) })
1488 });
1489
1490 graph.add_node_simple("test", node).unwrap();
1491 assert!(graph.nodes.contains_key("test"));
1492 }
1493
1494 #[test]
1495 fn test_add_node_duplicate() {
1496 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1497
1498 graph
1499 .add_node_simple(
1500 "test",
1501 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1502 Box::pin(async move { Ok(StateDummyUpdate) })
1503 }),
1504 )
1505 .unwrap();
1506 let result = graph.add_node_simple(
1507 "test",
1508 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1509 Box::pin(async move { Ok(StateDummyUpdate) })
1510 }),
1511 );
1512 assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
1513 }
1514
1515 #[test]
1516 fn test_set_entry_point() {
1517 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1518 graph.set_entry_point("start");
1519 assert_eq!(graph.entry_point, Some("start".to_string()));
1520 assert_eq!(graph.edges.len(), 1);
1521 }
1522
1523 #[test]
1524 fn test_set_finish_point() {
1525 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1526 graph.set_finish_point("end");
1527 assert_eq!(graph.finish_points, vec!["end"]);
1528 assert_eq!(graph.edges.len(), 1);
1529 }
1530
1531 #[test]
1532 fn test_add_sequence() {
1533 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1534
1535 graph
1537 .add_node_simple(
1538 "a",
1539 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1540 Box::pin(async move { Ok(StateDummyUpdate) })
1541 }),
1542 )
1543 .unwrap();
1544 graph
1545 .add_node_simple(
1546 "b",
1547 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1548 Box::pin(async move { Ok(StateDummyUpdate) })
1549 }),
1550 )
1551 .unwrap();
1552 graph
1553 .add_node_simple(
1554 "c",
1555 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1556 Box::pin(async move { Ok(StateDummyUpdate) })
1557 }),
1558 )
1559 .unwrap();
1560
1561 graph.add_sequence(&["a", "b", "c"]).unwrap();
1563
1564 assert_eq!(graph.entry_point, Some("a".to_string()));
1565 assert_eq!(graph.edges.len(), 3); }
1567
1568 #[test]
1569 fn test_add_sequence_missing_node() {
1570 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1571 let result = graph.add_sequence(&["missing"]);
1572 assert!(matches!(result, Err(TopologyError::NodeNotFound { .. })));
1573 }
1574
1575 #[test]
1576 fn test_compile_ephemeral() {
1577 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1578 graph
1579 .add_node_simple(
1580 "a",
1581 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1582 Box::pin(async move { Ok(StateDummyUpdate) })
1583 }),
1584 )
1585 .unwrap();
1586 graph.set_entry_point("a");
1587 graph.set_finish_point("a");
1588
1589 let compiled = graph.compile_ephemeral().unwrap();
1590 assert_eq!(compiled.nodes().len(), 1);
1591 }
1592
1593 #[test]
1594 fn test_add_subgraph() {
1595 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1596
1597 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1598 Box::pin(async move { Ok(StateDummyUpdate) })
1599 })
1600 .into_node("sub");
1601 let mount = crate::subgraph::SubgraphMount::new(
1602 "my_subgraph",
1603 crate::subgraph::SubgraphConfig::default(),
1604 node,
1605 );
1606
1607 graph.add_subgraph(mount).unwrap();
1608 assert!(graph.nodes.contains_key("my_subgraph"));
1609 assert_eq!(graph.subgraphs.len(), 1);
1610 }
1611
1612 #[test]
1613 fn test_compile_wires_subgraph_info() {
1614 use crate::subgraph::{SubgraphConfig, SubgraphMount, SubgraphPersistence};
1615
1616 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1617
1618 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1619 Box::pin(async move { Ok(StateDummyUpdate) })
1620 })
1621 .into_node("sub");
1622 let mount = SubgraphMount::new(
1623 "my_subgraph",
1624 SubgraphConfig {
1625 persistence: SubgraphPersistence::PerThread,
1626 },
1627 node,
1628 );
1629
1630 graph.add_subgraph(mount).unwrap();
1631 graph.set_entry_point("my_subgraph");
1632 graph.set_finish_point("my_subgraph");
1633
1634 let compiled = graph.compile().unwrap();
1635 let subgraphs = compiled.get_subgraphs();
1636 assert_eq!(subgraphs.len(), 1);
1637 assert_eq!(subgraphs[0].name, "my_subgraph");
1638 assert_eq!(subgraphs[0].persistence, SubgraphPersistence::PerThread);
1639 }
1640
1641 #[test]
1642 fn test_add_subgraph_duplicate() {
1643 let mut graph: StateGraph<StateDummy> = StateGraph::new();
1644
1645 graph
1646 .add_node_simple(
1647 "my_subgraph",
1648 NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1649 Box::pin(async move { Ok(StateDummyUpdate) })
1650 }),
1651 )
1652 .unwrap();
1653
1654 let node = NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
1655 Box::pin(async move { Ok(StateDummyUpdate) })
1656 })
1657 .into_node("sub");
1658 let mount = crate::subgraph::SubgraphMount::new(
1659 "my_subgraph",
1660 crate::subgraph::SubgraphConfig::default(),
1661 node,
1662 );
1663
1664 let result = graph.add_subgraph(mount);
1665 assert!(matches!(result, Err(TopologyError::DuplicateNode { .. })));
1666 }
1667
1668 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
1670 struct ChildState {
1671 value: i32,
1672 }
1673
1674 impl crate::State for ChildState {
1675 type Update = ChildStateUpdate;
1676 type FieldVersions = crate::state::FieldVersions;
1677
1678 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
1679 if let Some(v) = update.value {
1680 self.value = v;
1681 }
1682 crate::FieldsChanged(0)
1683 }
1684
1685 fn reset_ephemeral(&mut self) {}
1686 }
1687
1688 #[derive(Clone, Debug, Default, serde::Serialize, serde::Deserialize)]
1689 struct ChildStateUpdate {
1690 value: Option<i32>,
1691 }
1692
1693 #[test]
1694 fn test_add_subgraph_with_config_registers_node() {
1695 let mut child_graph: StateGraph<ChildState> = StateGraph::new();
1696 child_graph
1697 .add_node_simple(
1698 "child_node",
1699 crate::node::NodeFnUpdate(|_s: &ChildState| -> BoxResult<_> {
1700 Box::pin(async move { Ok(ChildStateUpdate { value: Some(42) }) })
1701 }),
1702 )
1703 .unwrap();
1704 child_graph.set_entry_point("child_node");
1705 child_graph.set_finish_point("child_node");
1706
1707 let compiled_child = Arc::new(child_graph.compile().unwrap());
1708
1709 let mut parent_graph: StateGraph<StateDummy> = StateGraph::new();
1710 parent_graph
1711 .add_subgraph_with_config(
1712 "explicit_subgraph",
1713 compiled_child,
1714 |_parent: &StateDummy| ChildState { value: 0 },
1715 |_child: &ChildState| StateDummyUpdate,
1716 crate::subgraph::SubgraphConfig::default(),
1717 )
1718 .unwrap();
1719
1720 assert!(parent_graph.nodes.contains_key("explicit_subgraph"));
1721 }
1722
#[test]
fn test_add_subgraph_with_config_duplicate_node() {
    let mut child: StateGraph<ChildState> = StateGraph::new();
    child
        .add_node_simple(
            "child_node",
            crate::node::NodeFnUpdate(|_s: &ChildState| -> BoxResult<_> {
                Box::pin(async move { Ok(ChildStateUpdate { value: Some(42) }) })
            }),
        )
        .unwrap();
    child.set_entry_point("child_node");
    child.set_finish_point("child_node");
    let compiled = Arc::new(child.compile().unwrap());

    // Claim the name with a plain node before attempting the subgraph.
    let mut parent: StateGraph<StateDummy> = StateGraph::new();
    parent
        .add_node_simple(
            "explicit_subgraph",
            crate::node::NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                Box::pin(async move { Ok(StateDummyUpdate) })
            }),
        )
        .unwrap();

    // Re-using the name for a subgraph must be rejected as a duplicate.
    let outcome = parent.add_subgraph_with_config(
        "explicit_subgraph",
        compiled,
        |_parent: &StateDummy| ChildState { value: 0 },
        |_child: &ChildState| StateDummyUpdate,
        crate::subgraph::SubgraphConfig::default(),
    );
    let is_duplicate = matches!(outcome, Err(TopologyError::DuplicateNode { .. }));
    assert!(is_duplicate);
}
1759
#[test]
fn test_add_node_with_retry() {
    // Same as `RetryPolicy::default()` except for a shorter first delay.
    let retry = RetryPolicy {
        initial_interval: std::time::Duration::from_millis(100),
        ..RetryPolicy::default()
    };

    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    graph
        .add_node_with_retry(
            "retry_node",
            NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                Box::pin(async move { Ok(StateDummyUpdate) })
            }),
            retry,
        )
        .unwrap();

    assert!(graph.nodes.contains_key("retry_node"));
}
1785
#[test]
fn test_add_node_with_error_handler() {
    // Handler that ends the run whatever error it receives.
    let on_error = Arc::new(|_err: NodeError<StateDummy>| crate::Command::end());

    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    graph
        .add_node_with_error_handler(
            "error_handler_node",
            NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                Box::pin(async move { Ok(StateDummyUpdate) })
            }),
            on_error,
        )
        .unwrap();

    assert!(graph.nodes.contains_key("error_handler_node"));
}
1804
#[test]
fn test_default_implementation() {
    // `Default` must produce a graph with no nodes and no subgraphs.
    let fresh = StateGraph::<StateDummy>::default();
    let no_nodes = fresh.nodes.is_empty();
    let no_subgraphs = fresh.subgraphs.is_empty();
    assert!(no_nodes);
    assert!(no_subgraphs);
}
1811
#[test]
fn test_validate_keys_empty_graph() {
    // With no nodes there is nothing for key validation to reject.
    let empty = StateGraph::<StateDummy>::new();
    let outcome = empty.validate_keys();
    outcome.unwrap();
}
1817
#[test]
fn test_validate_keys_valid_nodes() {
    // Register two plainly-named no-op nodes; validation must accept them.
    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    for name in ["node_a", "node_b"] {
        graph
            .add_node_simple(
                name,
                NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                    Box::pin(async move { Ok(StateDummyUpdate) })
                }),
            )
            .unwrap();
    }

    graph.validate_keys().unwrap();
}
1840
#[test]
fn test_validate_keys_empty_node_name() {
    // NOTE(review): despite its name, this test never registers a node with
    // an empty name. It only builds an empty graph and asserts that
    // `validate_keys` succeeds, which duplicates
    // `test_validate_keys_empty_graph`. Consider adding a node named ""
    // once it is confirmed whether the rejection is expected at insertion
    // time or at validation time.
    let graph: StateGraph<StateDummy> = StateGraph::new();
    let result = graph.validate_keys();
    result.unwrap();
}
1850
#[test]
fn test_validate_keys_reserved_characters() {
    let mut graph = StateGraph::<StateDummy>::new();

    // This test expects insertion to accept a name containing ':' and the
    // rejection to come from `validate_keys`.
    graph
        .add_node_simple(
            "node:test",
            NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                Box::pin(async move { Ok(StateDummyUpdate) })
            }),
        )
        .unwrap();

    let outcome = graph.validate_keys();
    assert!(matches!(
        outcome,
        Err(TopologyError::InvalidNodeName { .. })
    ));
}
1869
#[test]
fn test_validate_keys_entry_point_not_found() {
    // The entry point references a node that was never added.
    let mut graph = StateGraph::<StateDummy>::new();
    graph.set_entry_point("nonexistent");

    let outcome = graph.validate_keys();
    assert!(matches!(
        outcome,
        Err(TopologyError::NodeNotFound { .. })
    ));
}
1878
#[test]
fn test_validate_keys_finish_point_not_found() {
    let mut graph = StateGraph::<StateDummy>::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                Box::pin(async move { Ok(StateDummyUpdate) })
            }),
        )
        .unwrap();

    // The graph is non-empty, but the finish point names an unknown node.
    graph.set_finish_point("nonexistent");

    let outcome = graph.validate_keys();
    assert!(matches!(
        outcome,
        Err(TopologyError::NodeNotFound { .. })
    ));
}
1895
#[test]
fn test_validate_keys_with_valid_entry_and_finish() {
    let mut graph: StateGraph<StateDummy> = StateGraph::new();
    for name in ["start", "end"] {
        graph
            .add_node_simple(
                name,
                NodeFnUpdate(|_s: &StateDummy| -> BoxResult<_> {
                    Box::pin(async move { Ok(StateDummyUpdate) })
                }),
            )
            .unwrap();
    }

    // Both endpoints refer to nodes that exist, so validation succeeds.
    graph.set_entry_point("start");
    graph.set_finish_point("end");

    graph.validate_keys().unwrap();
}
1920
/// `validate_keys` must reject a state whose `replace_field_indices`
/// points past its last field (index 5 with only 2 fields).
#[test]
fn test_validate_keys_catches_invalid_replace_field_index() {
    let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");

    // A single destructuring pattern checks the variant and extracts its
    // fields, so the field assertions below can never be skipped.
    let Err(TopologyError::InvalidFieldReference {
        index,
        field_count,
        context,
        ..
    }) = graph.validate_keys()
    else {
        panic!("expected TopologyError::InvalidFieldReference");
    };
    assert_eq!(index, 5);
    assert_eq!(field_count, 2);
    assert_eq!(context, "replace_field_indices");
}
1952
/// `validate_keys` must reject a state whose
/// `replace_after_finish_field_indices` points past its last field
/// (index 99 with only 2 fields).
#[test]
fn test_validate_keys_catches_invalid_replace_after_finish_field_index() {
    let mut graph: StateGraph<StateWithBadAfterFinishIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadAfterFinishIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadAfterFinishIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");

    // One pattern both asserts the variant and binds its fields.
    let Err(TopologyError::InvalidFieldReference {
        index,
        field_count,
        context,
        ..
    }) = graph.validate_keys()
    else {
        panic!("expected TopologyError::InvalidFieldReference");
    };
    assert_eq!(index, 99);
    assert_eq!(field_count, 2);
    assert_eq!(context, "replace_after_finish_field_indices");
}
1984
/// Two-field test fixture whose `State` impl deliberately reports an
/// out-of-range `replace_field_indices` entry (5).
#[derive(Clone, Debug, Default)]
struct StateWithBadReplaceIndex {
    // Field index 0 in `field_names`.
    a: i32,
    // Field index 1 in `field_names`.
    b: i32,
}
1992
/// Partial update for `StateWithBadReplaceIndex`; `None` fields are left
/// unchanged by `apply`.
#[derive(Clone, Debug, Default)]
struct StateWithBadReplaceIndexUpdate {
    a: Option<i32>,
    b: Option<i32>,
}
1998
1999 impl crate::State for StateWithBadReplaceIndex {
2000 type Update = StateWithBadReplaceIndexUpdate;
2001 type FieldVersions = crate::state::FieldVersions;
2002
2003 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
2004 let mut changed = crate::FieldsChanged::default();
2005 if let Some(v) = update.a {
2006 self.a = v;
2007 changed.set_field(0);
2008 }
2009 if let Some(v) = update.b {
2010 self.b = v;
2011 changed.set_field(1);
2012 }
2013 changed
2014 }
2015
2016 fn reset_ephemeral(&mut self) {}
2017
2018 fn field_count() -> usize {
2019 2
2020 }
2021
2022 fn field_names() -> &'static [&'static str] {
2023 &["a", "b"]
2024 }
2025
2026 fn replace_field_indices() -> &'static [usize] {
2027 &[5] }
2029 }
2030
/// Two-field test fixture whose `State` impl deliberately reports an
/// out-of-range `replace_after_finish_field_indices` entry (99).
#[derive(Clone, Debug, Default)]
struct StateWithBadAfterFinishIndex {
    // Field index 0 in `field_names`.
    x: String,
    // Field index 1 in `field_names`.
    y: String,
}
2038
/// Partial update for `StateWithBadAfterFinishIndex`; `None` fields are left
/// unchanged by `apply`.
#[derive(Clone, Debug, Default)]
struct StateWithBadAfterFinishIndexUpdate {
    x: Option<String>,
    y: Option<String>,
}
2044
2045 impl crate::State for StateWithBadAfterFinishIndex {
2046 type Update = StateWithBadAfterFinishIndexUpdate;
2047 type FieldVersions = crate::state::FieldVersions;
2048
2049 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
2050 let mut changed = crate::FieldsChanged::default();
2051 if let Some(v) = update.x {
2052 self.x = v;
2053 changed.set_field(0);
2054 }
2055 if let Some(v) = update.y {
2056 self.y = v;
2057 changed.set_field(1);
2058 }
2059 changed
2060 }
2061
2062 fn reset_ephemeral(&mut self) {}
2063
2064 fn field_count() -> usize {
2065 2
2066 }
2067
2068 fn field_names() -> &'static [&'static str] {
2069 &["x", "y"]
2070 }
2071
2072 fn replace_after_finish_field_indices() -> &'static [usize] {
2073 &[99] }
2075 }
2076
/// `compile` must surface the same out-of-range `replace_field_indices`
/// error that `validate_keys` detects.
#[test]
fn test_compile_calls_validate_keys_and_catches_invalid_replace_field_index() {
    let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");

    // A single destructuring pattern checks the variant and extracts its
    // fields, so the field assertions below can never be skipped.
    let Err(TopologyError::InvalidFieldReference {
        index,
        field_count,
        context,
        ..
    }) = graph.compile()
    else {
        panic!("expected TopologyError::InvalidFieldReference");
    };
    assert_eq!(index, 5);
    assert_eq!(field_count, 2);
    assert_eq!(context, "replace_field_indices");
}
2109
/// `compile` must surface the same out-of-range
/// `replace_after_finish_field_indices` error that `validate_keys` detects.
#[test]
fn test_compile_calls_validate_keys_and_catches_invalid_replace_after_finish_field_index() {
    let mut graph: StateGraph<StateWithBadAfterFinishIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "node_a",
            NodeFnUpdate(|_s: &StateWithBadAfterFinishIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadAfterFinishIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("node_a");
    graph.set_finish_point("node_a");

    // One pattern both asserts the variant and binds its fields.
    let Err(TopologyError::InvalidFieldReference {
        index,
        field_count,
        context,
        ..
    }) = graph.compile()
    else {
        panic!("expected TopologyError::InvalidFieldReference");
    };
    assert_eq!(index, 99);
    assert_eq!(field_count, 2);
    assert_eq!(context, "replace_after_finish_field_indices");
}
2142
#[test]
fn test_validate_keys_validates_reducer_indices_during_compile() {
    // Standalone validation and `compile` must both reject the same
    // out-of-range replace-field index, and agree on which index it is.
    let mut graph: StateGraph<StateWithBadReplaceIndex> = StateGraph::new();
    graph
        .add_node_simple(
            "process",
            NodeFnUpdate(|_s: &StateWithBadReplaceIndex| -> BoxResult<_> {
                Box::pin(async move { Ok(StateWithBadReplaceIndexUpdate::default()) })
            }),
        )
        .unwrap();
    graph.set_entry_point("process");
    graph.set_finish_point("process");

    let from_validate = graph.validate_keys();
    assert!(
        from_validate.is_err(),
        "validate_keys should detect invalid field index"
    );

    let from_compile = graph.compile();
    assert!(
        from_compile.is_err(),
        "compile should detect invalid field index"
    );

    let (
        Err(TopologyError::InvalidFieldReference {
            index: validate_index,
            ..
        }),
        Err(TopologyError::InvalidFieldReference {
            index: compile_index,
            ..
        }),
    ) = (from_validate, from_compile)
    else {
        panic!("Both methods should return InvalidFieldReference error");
    };
    assert_eq!(
        validate_index, compile_index,
        "Both methods should report the same invalid index"
    );
}
2190
/// Field-less state used by tests that only exercise graph wiring, retry
/// and timeout plumbing.
#[derive(Clone, Debug, Default)]
struct StateDummy;
2193
2194 impl crate::State for StateDummy {
2195 type Update = StateDummyUpdate;
2196 type FieldVersions = crate::state::FieldVersions;
2197
2198 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
2199 crate::FieldsChanged(0)
2200 }
2201
2202 fn reset_ephemeral(&mut self) {}
2203 }
2204
/// Empty update type paired with `StateDummy`; `apply` ignores it.
#[derive(Clone, Debug, Default)]
struct StateDummyUpdate;
2207
2208 #[tokio::test]
2211 async fn test_execute_with_retry_succeeds_first_attempt() {
2212 let policy = RetryPolicy {
2213 max_attempts: 3,
2214 initial_interval: std::time::Duration::from_millis(1),
2215 backoff_factor: 2.0,
2216 max_interval: std::time::Duration::from_secs(1),
2217 jitter: false,
2218 retry_on: None,
2219 };
2220 let config = crate::RunnableConfig::new();
2221
2222 let result = execute_with_retry(
2223 "test_node",
2224 &policy,
2225 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2226 Box::pin(async { Ok(crate::Command::end()) })
2227 },
2228 &StateDummy,
2229 &config,
2230 )
2231 .await;
2232
2233 result.unwrap();
2234 }
2235
2236 #[tokio::test]
2237 async fn test_execute_with_retry_succeeds_after_retries() {
2238 let policy = RetryPolicy {
2239 max_attempts: 3,
2240 initial_interval: std::time::Duration::from_millis(1),
2241 backoff_factor: 2.0,
2242 max_interval: std::time::Duration::from_secs(1),
2243 jitter: false,
2244 retry_on: None,
2245 };
2246 let config = crate::RunnableConfig::new();
2247
2248 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2249 let attempt_clone = Arc::clone(&attempt_count);
2250
2251 let result = execute_with_retry(
2252 "test_node",
2253 &policy,
2254 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2255 let counter = Arc::clone(&attempt_clone);
2256 Box::pin(async move {
2257 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2258 if n < 2 {
2259 Err(crate::JunctureError::execution("transient failure"))
2260 } else {
2261 Ok(crate::Command::end())
2262 }
2263 })
2264 },
2265 &StateDummy,
2266 &config,
2267 )
2268 .await;
2269
2270 result.unwrap();
2271 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
2272 }
2273
2274 #[tokio::test]
2275 async fn test_execute_with_retry_exhausts_attempts() {
2276 let policy = RetryPolicy {
2277 max_attempts: 3,
2278 initial_interval: std::time::Duration::from_millis(1),
2279 backoff_factor: 2.0,
2280 max_interval: std::time::Duration::from_secs(1),
2281 jitter: false,
2282 retry_on: None,
2283 };
2284 let config = crate::RunnableConfig::new();
2285
2286 let result = execute_with_retry(
2287 "test_node",
2288 &policy,
2289 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2290 Box::pin(async { Err(crate::JunctureError::execution("always fails")) })
2291 },
2292 &StateDummy,
2293 &config,
2294 )
2295 .await;
2296
2297 assert!(result.is_err());
2298 assert!(result.unwrap_err().is_execution());
2299 }
2300
2301 #[tokio::test]
2302 async fn test_execute_with_retry_does_not_retry_cancelled() {
2303 let policy = RetryPolicy {
2304 max_attempts: 3,
2305 initial_interval: std::time::Duration::from_millis(1),
2306 backoff_factor: 2.0,
2307 max_interval: std::time::Duration::from_secs(1),
2308 jitter: false,
2309 retry_on: None,
2310 };
2311 let config = crate::RunnableConfig::new();
2312
2313 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2314 let attempt_clone = Arc::clone(&attempt_count);
2315
2316 let result = execute_with_retry(
2317 "test_node",
2318 &policy,
2319 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2320 let counter = Arc::clone(&attempt_clone);
2321 Box::pin(async move {
2322 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2323 Err(crate::JunctureError::cancelled())
2324 })
2325 },
2326 &StateDummy,
2327 &config,
2328 )
2329 .await;
2330
2331 assert!(result.is_err());
2332 assert!(result.unwrap_err().is_cancelled());
2333 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2335 }
2336
2337 #[tokio::test]
2338 async fn test_execute_with_retry_does_not_retry_interrupt() {
2339 let policy = RetryPolicy {
2340 max_attempts: 3,
2341 initial_interval: std::time::Duration::from_millis(1),
2342 backoff_factor: 2.0,
2343 max_interval: std::time::Duration::from_secs(1),
2344 jitter: false,
2345 retry_on: None,
2346 };
2347 let config = crate::RunnableConfig::new();
2348
2349 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2350 let attempt_clone = Arc::clone(&attempt_count);
2351
2352 let result = execute_with_retry(
2353 "test_node",
2354 &policy,
2355 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2356 let counter = Arc::clone(&attempt_clone);
2357 Box::pin(async move {
2358 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2359 Err(crate::JunctureError::interrupt("user input needed"))
2360 })
2361 },
2362 &StateDummy,
2363 &config,
2364 )
2365 .await;
2366
2367 assert!(result.is_err());
2368 assert!(result.unwrap_err().is_interrupt());
2369 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2370 }
2371
2372 #[tokio::test]
2373 async fn test_execute_with_retry_custom_retry_on_predicate() {
2374 let policy = RetryPolicy {
2376 max_attempts: 3,
2377 initial_interval: std::time::Duration::from_millis(1),
2378 backoff_factor: 2.0,
2379 max_interval: std::time::Duration::from_secs(1),
2380 jitter: false,
2381 retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
2382 };
2383 let config = crate::RunnableConfig::new();
2384
2385 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2386 let attempt_clone = Arc::clone(&attempt_count);
2387
2388 let result = execute_with_retry(
2389 "test_node",
2390 &policy,
2391 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2392 let counter = Arc::clone(&attempt_clone);
2393 Box::pin(async move {
2394 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2395 Err(crate::JunctureError::execution("not a timeout"))
2397 })
2398 },
2399 &StateDummy,
2400 &config,
2401 )
2402 .await;
2403
2404 assert!(result.is_err());
2405 assert!(result.unwrap_err().is_execution());
2406 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2408 }
2409
2410 #[tokio::test]
2411 async fn test_execute_with_retry_custom_predicate_allows_retry() {
2412 let policy = RetryPolicy {
2414 max_attempts: 3,
2415 initial_interval: std::time::Duration::from_millis(1),
2416 backoff_factor: 2.0,
2417 max_interval: std::time::Duration::from_secs(1),
2418 jitter: false,
2419 retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
2420 };
2421 let config = crate::RunnableConfig::new();
2422
2423 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2424 let attempt_clone = Arc::clone(&attempt_count);
2425
2426 let result = execute_with_retry(
2427 "test_node",
2428 &policy,
2429 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2430 let counter = Arc::clone(&attempt_clone);
2431 Box::pin(async move {
2432 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2433 if n < 2 {
2434 Err(crate::JunctureError::timeout("timed out"))
2435 } else {
2436 Ok(crate::Command::end())
2437 }
2438 })
2439 },
2440 &StateDummy,
2441 &config,
2442 )
2443 .await;
2444
2445 result.unwrap();
2446 assert_eq!(attempt_count.load(std::sync::atomic::Ordering::Relaxed), 3);
2447 }
2448
#[test]
fn test_compute_delay_no_jitter() {
    use std::time::Duration;

    // Without jitter and below the cap, the base delay passes through unchanged.
    let delay = compute_delay(Duration::from_millis(100), false, Duration::from_secs(10));
    assert_eq!(delay, Duration::from_millis(100));
}
2456
#[test]
fn test_compute_delay_caps_at_max() {
    use std::time::Duration;

    // A base delay above the ceiling is clamped to the ceiling.
    let ceiling = Duration::from_secs(10);
    let delay = compute_delay(Duration::from_secs(20), false, ceiling);
    assert_eq!(delay, ceiling);
}
2464
#[test]
fn test_compute_delay_with_jitter_stays_within_range() {
    use std::time::Duration;

    let base = Duration::from_millis(100);
    let max = Duration::from_secs(10);

    // Jitter is random, so sample repeatedly; the test expects every sample
    // to fall within 75..=125 ms of a 100 ms base.
    for _sample in 0..100 {
        let millis = compute_delay(base, true, max).as_secs_f64() * 1000.0;
        assert!(
            (75.0..=125.0).contains(&millis),
            "jittered delay {millis}ms outside expected range [75, 125]"
        );
    }
}
2480
#[test]
fn test_compute_delay_jitter_capped_by_max() {
    use std::time::Duration;

    // The ceiling sits below the base delay, so every jittered sample must
    // be clamped to it.
    let max = Duration::from_millis(50);
    let base = Duration::from_millis(100);

    (0..100).for_each(|_| {
        let result = compute_delay(base, true, max);
        assert!(
            result <= max,
            "jittered delay {result:?} exceeded max {max:?}",
        );
    });
}
2494
#[test]
fn test_cap_delay_returns_min() {
    use std::time::Duration;

    let ceiling = Duration::from_secs(10);

    // Below the ceiling the delay is returned untouched...
    let short = Duration::from_secs(5);
    assert_eq!(cap_delay(short, ceiling), short);

    // ...above it, the ceiling wins.
    assert_eq!(cap_delay(Duration::from_secs(15), ceiling), ceiling);
}
2504
#[test]
fn test_retry_policy_should_retry_default_allows_execution_errors() {
    // With no custom predicate, ordinary execution failures are retryable.
    let retryable = RetryPolicy::default()
        .should_retry(&crate::JunctureError::execution("something went wrong"));
    assert!(retryable);
}
2511
#[test]
fn test_retry_policy_should_retry_default_blocks_cancelled() {
    // Cancellation must never be retried, even by the default policy.
    let retryable = RetryPolicy::default().should_retry(&crate::JunctureError::cancelled());
    assert!(!retryable);
}
2518
#[test]
fn test_retry_policy_should_retry_default_blocks_interrupt() {
    // Interrupts are handed back to the caller rather than retried.
    let retryable = RetryPolicy::default()
        .should_retry(&crate::JunctureError::interrupt("waiting for user"));
    assert!(!retryable);
}
2525
#[test]
fn test_retry_policy_should_retry_custom_predicate() {
    // A predicate that only accepts timeouts overrides the default rules.
    let timeouts_only = RetryPolicy {
        initial_interval: std::time::Duration::from_millis(100),
        jitter: false,
        retry_on: Some(Arc::new(|e: &crate::JunctureError| e.is_timeout())),
        ..RetryPolicy::default()
    };

    let on_timeout = timeouts_only.should_retry(&crate::JunctureError::timeout("slow"));
    let on_execution =
        timeouts_only.should_retry(&crate::JunctureError::execution("not timeout"));
    assert!(on_timeout);
    assert!(!on_execution);
}
2540
2541 #[tokio::test]
2542 async fn test_retrying_node_delegates_to_execute_with_retry() {
2543 use crate::node::NodeFnCommand;
2544
2545 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2546 let count_clone = Arc::clone(&call_count);
2547
2548 let inner: Arc<dyn crate::Node<StateDummy>> =
2549 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
2550 let counter = Arc::clone(&count_clone);
2551 Box::pin(async move {
2552 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2553 if n == 0 {
2554 Err(crate::JunctureError::execution("first try fails"))
2555 } else {
2556 Ok(crate::Command::end())
2557 }
2558 })
2559 })
2560 .into_node("inner");
2561
2562 let policy = RetryPolicy {
2563 max_attempts: 3,
2564 initial_interval: std::time::Duration::from_millis(1),
2565 backoff_factor: 2.0,
2566 max_interval: std::time::Duration::from_secs(1),
2567 jitter: false,
2568 retry_on: None,
2569 };
2570
2571 let retrying_node = RetryingNode::new(inner, policy);
2572 let config = crate::RunnableConfig::new();
2573
2574 let result = retrying_node.call(&StateDummy, &config).await;
2575 result.unwrap();
2576 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 2);
2577 }
2578
2579 #[tokio::test]
2580 async fn test_retrying_node_respects_max_attempts() {
2581 use crate::node::NodeFnCommand;
2582
2583 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2584 let count_clone = Arc::clone(&call_count);
2585
2586 let inner: Arc<dyn crate::Node<StateDummy>> =
2587 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
2588 let counter = Arc::clone(&count_clone);
2589 Box::pin(async move {
2590 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2591 Err(crate::JunctureError::execution("always fails"))
2592 })
2593 })
2594 .into_node("inner");
2595
2596 let policy = RetryPolicy {
2597 max_attempts: 5,
2598 initial_interval: std::time::Duration::from_millis(1),
2599 backoff_factor: 2.0,
2600 max_interval: std::time::Duration::from_secs(1),
2601 jitter: false,
2602 retry_on: None,
2603 };
2604
2605 let retrying_node = RetryingNode::new(inner, policy);
2606 let config = crate::RunnableConfig::new();
2607
2608 let result = retrying_node.call(&StateDummy, &config).await;
2609 let err = result.unwrap_err();
2610 assert!(err.is_execution());
2611 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 5);
2612 }
2613
2614 #[tokio::test]
2615 async fn test_retrying_node_with_jitter_enabled() {
2616 use crate::node::NodeFnCommand;
2617
2618 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2619 let count_clone = Arc::clone(&call_count);
2620
2621 let inner: Arc<dyn crate::Node<StateDummy>> =
2622 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
2623 let counter = Arc::clone(&count_clone);
2624 Box::pin(async move {
2625 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2626 if n < 2 {
2627 Err(crate::JunctureError::execution("retry me"))
2628 } else {
2629 Ok(crate::Command::end())
2630 }
2631 })
2632 })
2633 .into_node("inner");
2634
2635 let policy = RetryPolicy {
2636 max_attempts: 3,
2637 initial_interval: std::time::Duration::from_millis(1),
2638 backoff_factor: 2.0,
2639 max_interval: std::time::Duration::from_secs(1),
2640 jitter: true,
2641 retry_on: None,
2642 };
2643
2644 let retrying_node = RetryingNode::new(inner, policy);
2645 let config = crate::RunnableConfig::new();
2646
2647 let result = retrying_node.call(&StateDummy, &config).await;
2648 result.unwrap();
2649 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 3);
2650 }
2651
2652 #[tokio::test]
2653 async fn test_execute_with_retry_max_interval_capping() {
2654 let policy = RetryPolicy {
2656 max_attempts: 3,
2657 initial_interval: std::time::Duration::from_millis(50),
2658 backoff_factor: 100.0,
2659 max_interval: std::time::Duration::from_millis(80),
2660 jitter: false,
2661 retry_on: None,
2662 };
2663 let config = crate::RunnableConfig::new();
2664
2665 let start = crate::time::Instant::now();
2666 let attempt_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2667 let attempt_clone = Arc::clone(&attempt_count);
2668
2669 let result = execute_with_retry(
2670 "test_node",
2671 &policy,
2672 move |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2673 let counter = Arc::clone(&attempt_clone);
2674 Box::pin(async move {
2675 let n = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2676 if n < 2 {
2677 Err(crate::JunctureError::execution("fail"))
2678 } else {
2679 Ok(crate::Command::end())
2680 }
2681 })
2682 },
2683 &StateDummy,
2684 &config,
2685 )
2686 .await;
2687
2688 let elapsed = start.elapsed();
2689 result.unwrap();
2690 assert!(
2693 elapsed < std::time::Duration::from_secs(2),
2694 "max_interval capping should prevent very long waits, elapsed: {elapsed:?}"
2695 );
2696 }
2697
2698 #[tokio::test]
2701 async fn test_execute_with_timeout_succeeds_within_limit() {
2702 let config = crate::RunnableConfig::new();
2703
2704 let result = execute_with_timeout(
2705 "test_node",
2706 std::time::Duration::from_secs(10),
2707 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2708 Box::pin(async { Ok(crate::Command::end()) })
2709 },
2710 &StateDummy,
2711 &config,
2712 )
2713 .await;
2714
2715 result.unwrap();
2716 }
2717
2718 #[tokio::test]
2719 async fn test_execute_with_timeout_fires_on_slow_node() {
2720 let config = crate::RunnableConfig::new();
2721
2722 let result = execute_with_timeout(
2723 "slow_node",
2724 std::time::Duration::from_millis(10),
2725 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2726 Box::pin(async {
2727 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
2728 Ok(crate::Command::end())
2729 })
2730 },
2731 &StateDummy,
2732 &config,
2733 )
2734 .await;
2735
2736 let err = result.unwrap_err();
2737 assert!(err.is_node_timeout());
2738 }
2739
2740 #[tokio::test]
2741 async fn test_execute_with_timeout_passes_through_inner_error() {
2742 let config = crate::RunnableConfig::new();
2743
2744 let result = execute_with_timeout(
2745 "failing_node",
2746 std::time::Duration::from_secs(10),
2747 |_s: &StateDummy, _cfg: &crate::RunnableConfig| -> BoxResult<_> {
2748 Box::pin(async { Err(crate::JunctureError::execution("inner failure")) })
2749 },
2750 &StateDummy,
2751 &config,
2752 )
2753 .await;
2754
2755 let err = result.unwrap_err();
2756 assert!(err.is_execution());
2757 assert!(!err.is_node_timeout());
2758 }
2759
2760 #[tokio::test]
2761 async fn test_timeout_node_wrapper_integration() {
2762 use crate::node::NodeFnCommand;
2763
2764 let call_count = Arc::new(std::sync::atomic::AtomicU32::new(0));
2765 let count_clone = Arc::clone(&call_count);
2766
2767 let inner: Arc<dyn crate::Node<StateDummy>> =
2768 NodeFnCommand(move |_s: &StateDummy| -> BoxResult<_> {
2769 let counter = Arc::clone(&count_clone);
2770 Box::pin(async move {
2771 counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
2772 Ok(crate::Command::end())
2773 })
2774 })
2775 .into_node("inner");
2776
2777 let policy =
2778 crate::TimeoutPolicy::new().with_run_timeout(std::time::Duration::from_secs(10));
2779
2780 let timeout_node = TimeoutNode::new(inner, policy);
2781 let config = crate::RunnableConfig::new();
2782
2783 let result = timeout_node.call(&StateDummy, &config).await;
2784 result.unwrap();
2785 assert_eq!(call_count.load(std::sync::atomic::Ordering::Relaxed), 1);
2786 }
2787
2788 #[tokio::test]
2789 async fn test_timeout_node_fires_on_exceeded_duration() {
2790 use crate::node::NodeFnCommand;
2791
2792 let inner: Arc<dyn crate::Node<StateDummy>> =
2793 NodeFnCommand(|_s: &StateDummy| -> BoxResult<_> {
2794 Box::pin(async {
2795 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
2796 Ok(crate::Command::end())
2797 })
2798 })
2799 .into_node("inner");
2800
2801 let policy =
2802 crate::TimeoutPolicy::new().with_run_timeout(std::time::Duration::from_millis(10));
2803
2804 let timeout_node = TimeoutNode::new(inner, policy);
2805 let config = crate::RunnableConfig::new();
2806
2807 let result = timeout_node.call(&StateDummy, &config).await;
2808 let err = result.unwrap_err();
2809 assert!(err.is_node_timeout());
2810 }
2811
2812 #[tokio::test]
2813 async fn test_timeout_node_passes_through_inner_error() {
2814 use crate::node::NodeFnCommand;
2815
2816 let inner: Arc<dyn crate::Node<StateDummy>> =
2817 NodeFnCommand(|_s: &StateDummy| -> BoxResult<_> {
2818 Box::pin(async { Err(crate::JunctureError::execution("node failure")) })
2819 })
2820 .into_node("inner");
2821
2822 let policy =
2823 crate::TimeoutPolicy::new().with_run_timeout(std::time::Duration::from_secs(10));
2824
2825 let timeout_node = TimeoutNode::new(inner, policy);
2826 let config = crate::RunnableConfig::new();
2827
2828 let result = timeout_node.call(&StateDummy, &config).await;
2829 let err = result.unwrap_err();
2830 assert!(err.is_execution());
2831 assert!(!err.is_node_timeout());
2832 }
2833}
2834
2835