1use super::builder::NodeMetadata;
8use crate::info_span;
9#[cfg(target_family = "wasm")]
10use crate::tracing_wasm::WasmInstrument;
11use crate::{
12 JunctureError, State,
13 checkpoint::{
14 Checkpoint, CheckpointFilter, CheckpointMetadata, CheckpointSource, StateSnapshot,
15 },
16 config::RunnableConfig,
17 edge::TriggerTable,
18 interrupt::ResumeValue,
19 pregel::{BudgetTracker, PregelLoop},
20 state::{FromState, IntoState},
21 stream::{EventEmitter, StreamEvent, StreamMode},
22};
23use futures::Stream;
24use indexmap::IndexMap;
25use std::{pin::Pin, sync::Arc};
26use tokio::sync::mpsc;
27#[cfg(not(target_family = "wasm"))]
28use tracing::Instrument;
29
30const CHANNEL_CAPACITY_MESSAGES: usize = 256;
36
37const CHANNEL_CAPACITY_DEFAULT: usize = 32;
43
44fn stream_capacity(mode: &StreamMode) -> usize {
50 match mode {
51 StreamMode::Messages => CHANNEL_CAPACITY_MESSAGES,
52 StreamMode::Multi(modes) if modes.iter().any(|m| matches!(m, StreamMode::Messages)) => {
53 CHANNEL_CAPACITY_MESSAGES
54 }
55 _ => CHANNEL_CAPACITY_DEFAULT,
56 }
57}
58
/// Handle to a graph run whose events are delivered as a stream.
///
/// Pairs the run identifier with the stream fed by the spawned execution
/// task. Each stream item is either a [`StreamEvent`] or a [`JunctureError`]
/// raised while running the graph.
pub struct StreamHandle<S: State> {
    /// Identifier of the run that produces this stream.
    run_id: String,
    /// Events for the run; each item is an event or an execution error.
    pub stream: Pin<Box<dyn Stream<Item = Result<StreamEvent<S>, JunctureError>> + Send>>,
}
86
87impl<S: State> std::fmt::Debug for StreamHandle<S> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("StreamHandle")
90 .field("run_id", &self.run_id)
91 .field("stream", &"<stream>")
92 .finish()
93 }
94}
95
96impl<S: State> StreamHandle<S> {
97 #[must_use]
99 pub fn run_id(&self) -> &str {
100 &self.run_id
101 }
102
103 #[must_use]
105 #[allow(
106 clippy::type_complexity,
107 reason = "return type mirrors StreamHandle fields"
108 )]
109 pub fn into_parts(
110 self,
111 ) -> (
112 String,
113 Pin<Box<dyn Stream<Item = Result<StreamEvent<S>, JunctureError>> + Send>>,
114 ) {
115 (self.run_id, self.stream)
116 }
117}
118
119#[derive(Clone)]
137pub struct CompiledGraph<S: State, I: IntoState<S> = S, O: FromState<S> = S> {
138 inner: Arc<CompiledGraphInner<S>>,
139 _input: std::marker::PhantomData<I>,
140 _output: std::marker::PhantomData<O>,
141}
142
143impl<S: State, I: IntoState<S>, O: FromState<S>> std::fmt::Debug for CompiledGraph<S, I, O> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("CompiledGraph")
146 .field("node_count", &self.inner.nodes.len())
147 .field("has_checkpointer", &self.inner.checkpointer.is_some())
148 .finish()
149 }
150}
151
152impl<S: State, I: IntoState<S>, O: FromState<S>> CompiledGraph<S, I, O> {
153 #[must_use]
155 pub(crate) fn new(
156 nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,
157 trigger_table: TriggerTable<S>,
158 builder_metadata: IndexMap<String, NodeMetadata>,
159 interrupt_before: Vec<String>,
160 interrupt_after: Vec<String>,
161 checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
162 subgraphs: Vec<SubgraphInfo>,
163 ) -> Self {
164 Self {
165 inner: Arc::new(CompiledGraphInner {
166 nodes,
167 trigger_table,
168 builder_metadata,
169 checkpointer,
170 interrupt_before,
171 interrupt_after,
172 subgraphs,
173 active_invocations: std::sync::atomic::AtomicU64::new(0),
174 }),
175 _input: std::marker::PhantomData,
176 _output: std::marker::PhantomData,
177 }
178 }
179
180 fn build_error_handler_map(&self) -> std::collections::HashMap<String, String> {
185 self.inner
186 .builder_metadata
187 .iter()
188 .filter_map(|(node_name, meta)| {
189 meta.error_handler
190 .as_ref()
191 .map(|handler| (node_name.clone(), handler.clone()))
192 })
193 .collect()
194 }
195
196 fn build_retry_policy_map(
205 &self,
206 ) -> std::collections::HashMap<String, super::builder::RetryPolicy> {
207 self.inner
208 .builder_metadata
209 .iter()
210 .filter_map(|(node_name, meta)| {
211 meta.retry_policies
212 .first()
213 .map(|policy| (node_name.clone(), policy.clone()))
214 })
215 .collect()
216 }
217
218 fn build_timeout_policy_map(&self) -> std::collections::HashMap<String, crate::TimeoutPolicy> {
226 self.inner
227 .builder_metadata
228 .iter()
229 .filter_map(|(node_name, meta)| {
230 meta.timeout_policies
231 .first()
232 .cloned()
233 .map(|policy| (node_name.clone(), policy))
234 })
235 .collect()
236 }
237
238 fn effective_config(&self, config: &RunnableConfig) -> RunnableConfig {
244 let mut effective = config.clone();
245 if effective.interrupt_before.is_none() && !self.inner.interrupt_before.is_empty() {
246 effective.interrupt_before = Some(self.inner.interrupt_before.clone());
247 }
248 if effective.interrupt_after.is_none() && !self.inner.interrupt_after.is_empty() {
249 effective.interrupt_after = Some(self.inner.interrupt_after.clone());
250 }
251 effective
252 }
253
254 fn deserialize_with_migration(
259 checkpoint: &crate::checkpoint::Checkpoint,
260 ) -> Result<S, JunctureError>
261 where
262 S: serde::de::DeserializeOwned,
263 {
264 let mut channel_values = checkpoint.channel_values.clone();
265 let checkpoint_version = checkpoint.schema_version;
266 let current_version = S::schema_version();
267 if checkpoint_version != current_version {
268 channel_values = S::migrate(checkpoint_version, channel_values);
269 }
270 serde_json::from_value(channel_values)
271 .map_err(|e| JunctureError::checkpoint(format!("failed to deserialize state: {e}")))
272 }
273
274 pub fn invoke(
289 &self,
290 input: I,
291 config: &RunnableConfig,
292 ) -> Result<GraphOutput<S, O>, JunctureError>
293 where
294 S: serde::de::DeserializeOwned + serde::Serialize,
295 S::Update: serde::Serialize,
296 O: FromState<S>,
297 {
298 let effective = self.effective_config(config);
299
300 let runtime = {
303 #[cfg(feature = "multi-thread")]
304 {
305 tokio::runtime::Runtime::new().map_err(|e| {
306 JunctureError::execution(format!("Failed to create runtime: {e}"))
307 })?
308 }
309 #[cfg(not(feature = "multi-thread"))]
310 {
311 tokio::runtime::Builder::new_current_thread()
312 .enable_all()
313 .build()
314 .map_err(|e| {
315 JunctureError::execution(format!("Failed to create runtime: {e}"))
316 })?
317 }
318 };
319
320 runtime.block_on(self.invoke_async_inner(input, &effective))
321 }
322
323 pub async fn invoke_async(
338 &self,
339 input: I,
340 config: &RunnableConfig,
341 ) -> Result<GraphOutput<S, O>, JunctureError>
342 where
343 S: serde::de::DeserializeOwned + serde::Serialize,
344 S::Update: serde::Serialize,
345 O: FromState<S>,
346 {
347 let effective = self.effective_config(config);
348 self.invoke_async_inner(input, &effective).await
349 }
350
351 async fn invoke_async_inner(
353 &self,
354 input: I,
355 config: &RunnableConfig,
356 ) -> Result<GraphOutput<S, O>, JunctureError>
357 where
358 S: serde::de::DeserializeOwned + serde::Serialize,
359 S::Update: serde::Serialize,
360 O: FromState<S>,
361 {
362 let num_fields = 64;
364
365 let error_handler_map = self.build_error_handler_map();
367
368 let retry_policy_map = self.build_retry_policy_map();
370
371 let timeout_policy_map = self.build_timeout_policy_map();
373
374 let state_input = input.into_state();
376
377 let mut pregel = PregelLoop::with_error_handlers(
379 state_input,
380 self.inner.nodes.clone(),
381 self.inner.trigger_table.clone(),
382 config.clone(),
383 num_fields,
384 error_handler_map,
385 )?;
386
387 pregel.set_retry_policies(retry_policy_map);
388 pregel.set_timeout_policies(timeout_policy_map);
389
390 if let Some(budget_config) = &pregel.runnable_config.budget {
392 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
393 pregel.set_budget_tracker(
394 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
395 );
396 }
397
398 let graph_name = config
401 .graph_name
402 .clone()
403 .unwrap_or_else(|| "unnamed".to_string());
404 let run_id = pregel.run_id().to_string();
405 let recursion_limit = pregel.runnable_config.recursion_limit;
406
407 async move {
408 let graph_start = crate::time::Instant::now();
409
410 if let Some(ref collector) = config.metrics_collector {
412 collector.inc_counter("juncture.graph.invocations", 1);
413
414 let active = self
415 .inner
416 .active_invocations
417 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
418 + 1;
419 collector.set_gauge("juncture.graph.active_invocations", active);
420 }
421
422 let execution_result = async {
424 while pregel.tick()? {
425 let result = pregel.execute_superstep().await?;
426 pregel.after_tick(result).await?;
427 }
428 Ok::<(), JunctureError>(())
429 }
430 .await;
431
432 if let Some(ref collector) = config.metrics_collector {
434 let active = self
435 .inner
436 .active_invocations
437 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
438 - 1;
439 collector.set_gauge("juncture.graph.active_invocations", active);
440 }
441
442 let execution_result = match execution_result {
444 Ok(()) => Ok(()),
445 Err(e) => {
446 if let Some(ref collector) = config.metrics_collector {
448 collector.inc_counter("juncture.graph.errors", 1);
449 }
450 Err(e)
451 }
452 };
453
454 let steps = pregel.step();
456 let run_id = pregel.run_id().to_string();
457
458 let final_state = pregel.into_state();
460 let output = O::from_state(&final_state);
461
462 if let Some(ref collector) = config.metrics_collector {
464 #[allow(
465 clippy::cast_precision_loss,
466 reason = "Milliseconds as f64 is sufficient for histogram metrics; sub-millisecond precision is not required for graph duration tracking"
467 )]
468 collector.record_histogram(
469 "juncture.graph.duration_ms",
470 graph_start.elapsed().as_millis() as f64,
471 );
472 }
473
474 execution_result?;
475
476 Ok(GraphOutput {
477 value: final_state,
478 output,
479 interrupts: Vec::new(),
480 metadata: GraphOutputMetadata {
481 steps,
482 run_id,
483 checkpoint_id: config.checkpoint_id.clone(),
484 budget_usage: None,
485 },
486 })
487 }
488 .instrument(info_span!(
489 "juncture.graph.invoke",
490 "juncture.graph.name" = graph_name,
491 "juncture.run.id" = %run_id,
492 "juncture.recursion.limit" = recursion_limit,
493 ))
494 .await
495 }
496
497 pub async fn stream(
545 &self,
546 input: I,
547 config: &RunnableConfig,
548 mode: StreamMode,
549 ) -> Result<StreamHandle<S>, JunctureError>
550 where
551 S: Clone + Send + serde::de::DeserializeOwned + serde::Serialize + 'static,
552 S::Update: serde::Serialize,
553 {
554 self.stream_with_config(input, config, crate::stream::StreamConfig::new(mode))
555 .await
556 }
557
558 #[allow(
615 clippy::too_many_lines,
616 reason = "stream orchestration: channel setup, PregelLoop wiring, output_keys filtering, and event forwarding are inseparable"
617 )]
618 #[expect(
619 clippy::unused_async,
620 reason = "function signature follows async convention for consistency with invoke_async"
621 )]
622 pub async fn stream_with_config(
623 &self,
624 input: I,
625 config: &RunnableConfig,
626 stream_config: crate::stream::StreamConfig,
627 ) -> Result<StreamHandle<S>, JunctureError>
628 where
629 S: Clone + Send + serde::de::DeserializeOwned + serde::Serialize + 'static,
630 S::Update: serde::Serialize,
631 {
632 use futures::stream;
633
634 let effective = self.effective_config(config);
635 let num_fields = 64;
636 let mode = stream_config.mode.clone();
637 let output_keys = stream_config.output_keys;
638 let include_subgraphs = stream_config.include_subgraphs;
639 let subgraph_filter = stream_config.subgraph_filter;
640 let resumption = stream_config.resumption;
641
642 let capacity = stream_capacity(&mode);
645 let (tx, rx) = mpsc::channel(capacity);
646
647 let error_handler_map = self.build_error_handler_map();
649
650 let retry_policy_map = self.build_retry_policy_map();
652
653 let timeout_policy_map = self.build_timeout_policy_map();
655
656 let graph_name = effective
658 .graph_name
659 .clone()
660 .unwrap_or_else(|| "unnamed".to_string());
661
662 let state_input = input.into_state();
664 let mut pregel = PregelLoop::with_error_handlers(
665 state_input,
666 self.inner.nodes.clone(),
667 self.inner.trigger_table.clone(),
668 effective,
669 num_fields,
670 error_handler_map,
671 )?;
672
673 pregel.set_retry_policies(retry_policy_map);
674 pregel.set_timeout_policies(timeout_policy_map);
675
676 if let Some(budget_config) = &pregel.runnable_config.budget {
678 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
679 pregel.set_budget_tracker(
680 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
681 );
682 }
683
684 let run_id = pregel.run_id().to_string();
686 let recursion_limit = pregel.runnable_config.recursion_limit;
687
688 let (pregel_tx, mut pregel_rx) = mpsc::unbounded_channel();
693 pregel.set_stream_sender(pregel_tx);
694
695 tokio::spawn(
697 async move {
698 let tx_forward = tx.clone();
701 let mode_forward = mode.clone();
702 let output_keys_forward = output_keys.clone();
703 let resumption_forward = resumption.clone();
704 tokio::spawn(async move {
705 let (temp_tx, _temp_rx) = mpsc::channel(1);
707 let emitter = EventEmitter::new(temp_tx, mode_forward);
708
709 while let Some(event) = pregel_rx.recv().await {
710 if !emitter.should_emit(&event) {
711 continue;
712 }
713
714 let ns = event.namespace();
717 if !ns.is_empty() {
718 if !include_subgraphs {
719 continue;
720 }
721 if let Some(ref filter) = subgraph_filter
722 && let Some(first) = ns.first()
723 && !filter.contains(first)
724 {
725 continue;
726 }
727 }
728
729 if let Some(ref r) = resumption_forward {
733 let step = match &event {
734 StreamEvent::Values { step, .. }
735 | StreamEvent::FilteredValues { step, .. }
736 | StreamEvent::Updates { step, .. }
737 | StreamEvent::FilteredUpdates { step, .. } => Some(*step),
738 _ => None,
739 };
740 if let Some(s) = step
741 && r.should_skip(s)
742 {
743 continue;
744 }
745 }
746
747 let filtered = output_keys_forward.as_ref().and_then(|keys| match &event {
749 StreamEvent::Updates { node, update, step } => {
750 serde_json::to_value(update).ok().map(|json| {
751 StreamEvent::FilteredUpdates {
752 node: node.clone(),
753 data: crate::stream::filter_json_by_keys(json, keys),
754 step: *step,
755 }
756 })
757 }
758 _ => None,
759 });
760
761 if let Some(filtered_event) = filtered {
762 let _ = tx_forward.send(Ok(filtered_event)).await;
763 } else {
764 let _ = tx_forward.send(Ok(event)).await;
765 }
766 }
767 });
768
769 while matches!(pregel.tick(), Ok(true)) {
771 let step = pregel.step();
772
773 if matches!(mode, StreamMode::Values) {
776 let skip = resumption.as_ref().is_some_and(|r| r.should_skip(step));
777
778 if !skip {
779 let event = output_keys.as_ref().map_or_else(
780 || StreamEvent::Values {
781 state: pregel.snapshot_state(),
782 step,
783 },
784 |keys| {
785 let json = serde_json::to_value(pregel.snapshot_state())
786 .unwrap_or(serde_json::Value::Null);
787 StreamEvent::FilteredValues {
788 data: crate::stream::filter_json_by_keys(json, keys),
789 step,
790 }
791 },
792 );
793 let _ = tx.send(Ok(event)).await;
794 }
795 }
796
797 match pregel.execute_superstep().await {
799 Ok(result) => {
800 if let Err(e) = pregel.after_tick(result).await {
802 let _ = tx
804 .send(Ok(StreamEvent::End {
805 output: pregel.snapshot_state(),
806 }))
807 .await;
808 let _ = tx.send(Err(e)).await;
810 return;
811 }
812 }
813 Err(e) => {
814 let _ = tx
816 .send(Ok(StreamEvent::End {
817 output: pregel.snapshot_state(),
818 }))
819 .await;
820 let _ = tx.send(Err(e)).await;
822 return;
823 }
824 }
825 }
826
827 let final_state = pregel.into_state();
829 let _ = tx
830 .send(Ok(StreamEvent::End {
831 output: final_state,
832 }))
833 .await;
834 }
835 .instrument(info_span!(
836 "juncture.graph.invoke",
837 "juncture.graph.name" = graph_name,
838 "juncture.run.id" = %run_id,
839 "juncture.recursion.limit" = recursion_limit,
840 )),
841 );
842
843 Ok(StreamHandle {
845 run_id,
846 stream: Box::pin(stream::unfold(rx, |mut rx| async move {
847 rx.recv().await.map(|item| (item, rx))
848 })),
849 })
850 }
851
852 pub async fn execute_with_emitter(
894 &self,
895 input: S,
896 config: &RunnableConfig,
897 emitter: EventEmitter<S>,
898 ) -> Result<S, JunctureError>
899 where
900 S: Clone + Send + serde::de::DeserializeOwned + serde::Serialize + 'static,
901 S::Update: serde::Serialize,
902 {
903 let num_fields = 64;
904
905 let mut exec_config = self.effective_config(config);
907 if exec_config.run_id.is_none() {
909 exec_config.run_id = Some(uuid::Uuid::new_v4().to_string());
910 }
911
912 let graph_name = exec_config
914 .graph_name
915 .clone()
916 .unwrap_or_else(|| "unnamed".to_string());
917
918 let error_handler_map = self.build_error_handler_map();
919 let retry_policy_map = self.build_retry_policy_map();
920 let timeout_policy_map = self.build_timeout_policy_map();
921
922 let mut pregel = PregelLoop::with_error_handlers(
923 input,
924 self.inner.nodes.clone(),
925 self.inner.trigger_table.clone(),
926 exec_config,
927 num_fields,
928 error_handler_map,
929 )?;
930
931 pregel.set_retry_policies(retry_policy_map);
932 pregel.set_timeout_policies(timeout_policy_map);
933
934 if let Some(budget_config) = &pregel.runnable_config.budget {
936 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
937 pregel.set_budget_tracker(
938 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
939 );
940 }
941
942 if let Some(cp) = self.inner.checkpointer.clone() {
943 pregel.set_checkpointer(cp);
944 }
945
946 let mode = emitter.mode().clone();
947 let run_id = pregel.run_id().to_string();
948 let recursion_limit = pregel.runnable_config.recursion_limit;
949
950 let (pregel_tx, mut pregel_rx) = mpsc::unbounded_channel();
952 pregel.set_stream_sender(pregel_tx);
953
954 let emitter_clone = emitter.clone();
956 tokio::spawn(async move {
957 while let Some(event) = pregel_rx.recv().await {
958 if emitter_clone.should_emit(&event) {
959 emitter_clone.emit(event).await;
960 }
961 }
962 });
963
964 async move {
965 while pregel.tick()? {
967 let step = pregel.step();
968
969 if matches!(mode, StreamMode::Values) {
970 let event = StreamEvent::Values {
971 state: pregel.snapshot_state(),
972 step,
973 };
974 emitter.emit(event).await;
975 }
976
977 let result = pregel.execute_superstep().await?;
978 pregel.after_tick(result).await?;
979 }
980
981 let final_state = pregel.into_state();
983 emitter
984 .emit(StreamEvent::End {
985 output: final_state.clone(),
986 })
987 .await;
988
989 Ok(final_state)
990 }
991 .instrument(info_span!(
992 "juncture.graph.invoke",
993 "juncture.graph.name" = graph_name,
994 "juncture.run.id" = %run_id,
995 "juncture.recursion.limit" = recursion_limit,
996 ))
997 .await
998 }
999
1000 pub async fn resume(
1045 &self,
1046 config: &RunnableConfig,
1047 resume_value: ResumeValue,
1048 ) -> Result<GraphOutput<S, O>, JunctureError>
1049 where
1050 S: for<'de> serde::Deserialize<'de> + serde::Serialize,
1051 S::Update: serde::Serialize,
1052 O: FromState<S>,
1053 {
1054 let checkpointer =
1055 self.inner.checkpointer.as_ref().ok_or_else(|| {
1056 JunctureError::checkpoint("no checkpointer configured for resume")
1057 })?;
1058
1059 let tuple = checkpointer
1061 .get_tuple(config)
1062 .await
1063 .map_err(|e| JunctureError::checkpoint(format!("failed to load checkpoint: {e}")))?
1064 .ok_or_else(|| {
1065 JunctureError::checkpoint(format!(
1066 "checkpoint not found: thread_id={:?}, checkpoint_id={:?}",
1067 config.thread_id, config.checkpoint_id
1068 ))
1069 })?;
1070
1071 if !matches!(tuple.metadata.source, CheckpointSource::Interrupt { .. }) {
1074 return Err(JunctureError::checkpoint(format!(
1075 "resume() requires checkpoint from Interrupt source, got {:?}",
1076 tuple.metadata.source
1077 )));
1078 }
1079
1080 let state = Self::deserialize_with_migration(&tuple.checkpoint)?;
1082
1083 let mut resume_config = self.effective_config(config);
1085 resume_config.resume_value = Some(resume_value);
1086 if resume_config.run_id.is_none() {
1087 resume_config.run_id = Some(uuid::Uuid::new_v4().to_string());
1088 }
1089
1090 let graph_name = resume_config
1092 .graph_name
1093 .clone()
1094 .unwrap_or_else(|| "unnamed".to_string());
1095
1096 let num_fields = 64; let error_handler_map = self.build_error_handler_map();
1099 let retry_policy_map = self.build_retry_policy_map();
1100 let timeout_policy_map = self.build_timeout_policy_map();
1101 let mut pregel = crate::pregel::PregelLoop::with_error_handlers(
1102 state,
1103 self.inner.nodes.clone(),
1104 self.inner.trigger_table.clone(),
1105 resume_config,
1106 num_fields,
1107 error_handler_map,
1108 )?;
1109
1110 pregel.set_retry_policies(retry_policy_map);
1111 pregel.set_timeout_policies(timeout_policy_map);
1112
1113 if let Some(budget_config) = &pregel.runnable_config.budget {
1115 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
1116 pregel.set_budget_tracker(
1117 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
1118 );
1119 }
1120
1121 if let Some(cp) = self.inner.checkpointer.clone() {
1123 pregel.set_checkpointer(cp);
1124 }
1125
1126 let run_id = pregel.run_id().to_string();
1127 let recursion_limit = pregel.runnable_config.recursion_limit;
1128
1129 async move {
1130 while pregel.tick()? {
1132 let result = pregel.execute_superstep().await?;
1133 pregel.after_tick(result).await?;
1134 }
1135
1136 let steps = pregel.step();
1138 let run_id = pregel.run_id().to_string();
1139
1140 let final_state = pregel.into_state();
1142 let output = O::from_state(&final_state);
1143
1144 Ok(GraphOutput {
1145 value: final_state,
1146 output,
1147 interrupts: Vec::new(),
1148 metadata: GraphOutputMetadata {
1149 steps,
1150 run_id,
1151 checkpoint_id: config.checkpoint_id.clone(),
1152 budget_usage: None,
1153 },
1154 })
1155 }
1156 .instrument(info_span!(
1157 "juncture.graph.invoke",
1158 "juncture.graph.name" = graph_name,
1159 "juncture.run.id" = %run_id,
1160 "juncture.recursion.limit" = recursion_limit,
1161 ))
1162 .await
1163 }
1164
1165 pub async fn resume_single(
1190 &self,
1191 config: &RunnableConfig,
1192 value: serde_json::Value,
1193 ) -> Result<GraphOutput<S, O>, JunctureError>
1194 where
1195 S: for<'de> serde::Deserialize<'de> + serde::Serialize,
1196 S::Update: serde::Serialize,
1197 O: FromState<S>,
1198 {
1199 self.resume(config, ResumeValue::Single(value)).await
1200 }
1201
1202 pub async fn resume_stream(
1259 &self,
1260 config: &RunnableConfig,
1261 resume_value: ResumeValue,
1262 mode: StreamMode,
1263 ) -> Result<StreamHandle<S>, JunctureError>
1264 where
1265 S: Clone + Send + for<'de> serde::Deserialize<'de> + serde::Serialize + 'static,
1266 S::Update: serde::Serialize,
1267 {
1268 use futures::stream;
1269
1270 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1271 JunctureError::checkpoint("no checkpointer configured for resume_stream")
1272 })?;
1273
1274 let tuple = checkpointer
1276 .get_tuple(config)
1277 .await
1278 .map_err(|e| JunctureError::checkpoint(format!("failed to load checkpoint: {e}")))?
1279 .ok_or_else(|| {
1280 JunctureError::checkpoint(format!(
1281 "checkpoint not found: thread_id={:?}, checkpoint_id={:?}",
1282 config.thread_id, config.checkpoint_id
1283 ))
1284 })?;
1285
1286 if !matches!(tuple.metadata.source, CheckpointSource::Interrupt { .. }) {
1288 return Err(JunctureError::checkpoint(format!(
1289 "resume_stream() requires checkpoint from Interrupt source, got {:?}",
1290 tuple.metadata.source
1291 )));
1292 }
1293
1294 let state = Self::deserialize_with_migration(&tuple.checkpoint)?;
1296
1297 let mut resume_config = self.effective_config(config);
1299 resume_config.resume_value = Some(resume_value);
1300 if resume_config.run_id.is_none() {
1301 resume_config.run_id = Some(uuid::Uuid::new_v4().to_string());
1302 }
1303
1304 let num_fields = 64;
1306 let error_handler_map = self.build_error_handler_map();
1307 let retry_policy_map = self.build_retry_policy_map();
1308 let timeout_policy_map = self.build_timeout_policy_map();
1309 let mut pregel = PregelLoop::with_error_handlers(
1310 state,
1311 self.inner.nodes.clone(),
1312 self.inner.trigger_table.clone(),
1313 resume_config,
1314 num_fields,
1315 error_handler_map,
1316 )?;
1317
1318 pregel.set_retry_policies(retry_policy_map);
1319 pregel.set_timeout_policies(timeout_policy_map);
1320
1321 if let Some(budget_config) = &pregel.runnable_config.budget {
1323 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
1324 pregel.set_budget_tracker(
1325 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
1326 );
1327 }
1328
1329 if let Some(cp) = self.inner.checkpointer.clone() {
1331 pregel.set_checkpointer(cp);
1332 }
1333
1334 let (_handle, rx, run_id) = Self::spawn_streaming_loop(pregel, mode);
1335
1336 Ok(StreamHandle {
1338 run_id,
1339 stream: Box::pin(stream::unfold(rx, |mut receiver| async move {
1340 receiver.recv().await.map(|item| (item, receiver))
1341 })),
1342 })
1343 }
1344
1345 #[allow(
1350 clippy::type_complexity,
1351 reason = "return type is a tuple of channel handle, receiver, and run_id which is clear in context"
1352 )]
1353 fn spawn_streaming_loop(
1354 mut pregel: PregelLoop<S>,
1355 mode: StreamMode,
1356 ) -> (
1357 tokio::task::JoinHandle<()>,
1358 mpsc::Receiver<Result<StreamEvent<S>, JunctureError>>,
1359 String,
1360 )
1361 where
1362 S: Clone + Send + for<'de> serde::Deserialize<'de> + serde::Serialize + 'static,
1363 S::Update: serde::Serialize,
1364 {
1365 let capacity = stream_capacity(&mode);
1368 let (tx, rx) = mpsc::channel(capacity);
1369
1370 let run_id = pregel.run_id().to_string();
1372 let graph_name = pregel
1373 .runnable_config
1374 .graph_name
1375 .clone()
1376 .unwrap_or_else(|| "unnamed".to_string());
1377 let recursion_limit = pregel.runnable_config.recursion_limit;
1378
1379 let (pregel_tx, mut pregel_rx) = mpsc::unbounded_channel();
1384 pregel.set_stream_sender(pregel_tx);
1385
1386 let handle = tokio::spawn(
1387 async move {
1388 let tx_forward = tx.clone();
1390 let mode_forward = mode.clone();
1391 tokio::spawn(async move {
1392 let (temp_tx, _temp_rx) = mpsc::channel(1);
1394 let emitter = EventEmitter::new(temp_tx, mode_forward);
1395
1396 while let Some(event) = pregel_rx.recv().await {
1397 if emitter.should_emit(&event) {
1398 let _ = tx_forward.send(Ok(event)).await;
1399 }
1400 }
1401 });
1402
1403 while matches!(pregel.tick(), Ok(true)) {
1405 let step = pregel.step();
1406
1407 if matches!(mode, StreamMode::Values) {
1409 let event = StreamEvent::Values {
1410 state: pregel.snapshot_state(),
1411 step,
1412 };
1413 let _ = tx.send(Ok(event)).await;
1414 }
1415
1416 match pregel.execute_superstep().await {
1418 Ok(result) => {
1419 if let Err(e) = pregel.after_tick(result).await {
1420 let _ = tx
1421 .send(Ok(StreamEvent::End {
1422 output: pregel.snapshot_state(),
1423 }))
1424 .await;
1425 let _ = tx.send(Err(e)).await;
1426 return;
1427 }
1428 }
1429 Err(e) => {
1430 let _ = tx
1431 .send(Ok(StreamEvent::End {
1432 output: pregel.snapshot_state(),
1433 }))
1434 .await;
1435 let _ = tx.send(Err(e)).await;
1436 return;
1437 }
1438 }
1439 }
1440
1441 let final_state = pregel.into_state();
1443 let _ = tx
1444 .send(Ok(StreamEvent::End {
1445 output: final_state,
1446 }))
1447 .await;
1448 }
1449 .instrument(info_span!(
1450 "juncture.graph.invoke",
1451 "juncture.graph.name" = graph_name,
1452 "juncture.run.id" = %run_id,
1453 "juncture.recursion.limit" = recursion_limit,
1454 )),
1455 );
1456
1457 (handle, rx, run_id)
1458 }
1459
1460 pub async fn get_state(
1469 &self,
1470 config: &RunnableConfig,
1471 ) -> Result<Option<StateSnapshot<S>>, JunctureError>
1472 where
1473 S: serde::de::DeserializeOwned,
1474 {
1475 let checkpointer =
1476 self.inner.checkpointer.as_ref().ok_or_else(|| {
1477 JunctureError::checkpoint("no checkpointer configured for get_state")
1478 })?;
1479
1480 let tuple = checkpointer
1481 .get_tuple(config)
1482 .await
1483 .map_err(|e| JunctureError::checkpoint(e.to_string()))?;
1484
1485 let Some(tuple) = tuple else {
1486 return Ok(None);
1487 };
1488
1489 let values = Self::deserialize_with_migration(&tuple.checkpoint)?;
1491
1492 let next: Vec<String> = tuple
1494 .checkpoint
1495 .pending_tasks
1496 .iter()
1497 .map(|t| t.node.clone())
1498 .collect();
1499
1500 let snapshot = StateSnapshot {
1501 values,
1502 next,
1503 config: tuple.config,
1504 metadata: tuple.metadata,
1505 created_at: tuple.checkpoint.created_at,
1506 parent_config: tuple.parent_config,
1507 tasks: vec![],
1508 };
1509
1510 Ok(Some(snapshot))
1511 }
1512
1513 #[expect(
1528 clippy::unused_async,
1529 reason = "async API consistency for checkpoint operations"
1530 )]
1531 pub async fn get_state_history(
1532 &self,
1533 _config: &RunnableConfig,
1534 filter: Option<CheckpointFilter>,
1535 ) -> Result<Vec<StateSnapshot<S>>, JunctureError> {
1536 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1537 JunctureError::checkpoint("no checkpointer configured for get_state_history")
1538 })?;
1539
1540 let _ = (checkpointer, filter);
1541
1542 Err(JunctureError::checkpoint(
1545 "get_state_history not yet implemented: requires checkpoint state recovery",
1546 ))
1547 }
1548
1549 pub async fn update_state(
1572 &self,
1573 config: &RunnableConfig,
1574 update: StateUpdate<S>,
1575 ) -> Result<RunnableConfig, JunctureError>
1576 where
1577 S: serde::de::DeserializeOwned + serde::Serialize,
1578 {
1579 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1580 JunctureError::checkpoint("no checkpointer configured for update_state")
1581 })?;
1582
1583 let tuple = checkpointer
1585 .get_tuple(config)
1586 .await
1587 .map_err(|e| JunctureError::checkpoint(e.to_string()))?;
1588
1589 let Some(tuple) = tuple else {
1590 return Err(JunctureError::checkpoint(
1591 "no checkpoint found for update_state",
1592 ));
1593 };
1594
1595 let mut state = Self::deserialize_with_migration(&tuple.checkpoint)?;
1597
1598 state.apply(update.update);
1600
1601 let updated_values = serde_json::to_value(&state).map_err(|e| {
1603 JunctureError::checkpoint(format!("failed to serialize updated state: {e}"))
1604 })?;
1605
1606 let mut writes = tuple.metadata.writes;
1608 if let Some(as_node) = update.as_node {
1609 writes.insert(as_node, serde_json::Value::Null);
1610 }
1611
1612 let updated_checkpoint = Checkpoint {
1614 channel_values: updated_values,
1615 ..tuple.checkpoint
1616 };
1617
1618 let metadata = CheckpointMetadata {
1620 source: CheckpointSource::Update,
1621 step: tuple.metadata.step + 1,
1622 writes,
1623 ..tuple.metadata
1624 };
1625
1626 checkpointer
1628 .put(config, updated_checkpoint, metadata)
1629 .await
1630 .map_err(|e| JunctureError::checkpoint(e.to_string()))
1631 }
1632
1633 #[expect(
1648 clippy::unused_async,
1649 reason = "async API consistency for checkpoint operations"
1650 )]
1651 pub async fn bulk_update_state(
1652 &self,
1653 _config: &RunnableConfig,
1654 updates: Vec<StateUpdate<S>>,
1655 ) -> Result<Vec<RunnableConfig>, JunctureError> {
1656 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1657 JunctureError::checkpoint("no checkpointer configured for bulk_update_state")
1658 })?;
1659
1660 let _ = (checkpointer, updates);
1661
1662 Err(JunctureError::checkpoint(
1665 "bulk_update_state not yet implemented: requires checkpoint state recovery",
1666 ))
1667 }
1668
1669 #[must_use]
1680 pub fn get_graph(&self, xray: Option<usize>) -> DrawableGraph {
1681 let _ = xray;
1682
1683 self.to_drawable()
1686 }
1687
1688 #[must_use]
1693 pub fn get_subgraphs(&self) -> Vec<SubgraphInfo> {
1694 self.inner.subgraphs.clone()
1695 }
1696
1697 #[must_use]
1699 pub fn nodes(&self) -> &IndexMap<String, Arc<dyn crate::Node<S>>> {
1700 &self.inner.nodes
1701 }
1702
1703 #[must_use]
1705 pub fn trigger_table(&self) -> &TriggerTable<S> {
1706 &self.inner.trigger_table
1707 }
1708
1709 #[must_use]
1711 pub fn checkpointer(&self) -> Option<&Arc<dyn crate::checkpoint::CheckpointSaver>> {
1712 self.inner.checkpointer.as_ref()
1713 }
1714
1715 #[must_use]
1717 pub fn builder_metadata(&self) -> &IndexMap<String, NodeMetadata> {
1718 &self.inner.builder_metadata
1719 }
1720
1721 #[must_use]
1732 pub fn to_mermaid(&self) -> String {
1733 let mut lines = vec!["graph TD".to_string()];
1734
1735 for node_name in self.inner.nodes.keys() {
1737 lines.push(format!(" {node_name}[{node_name}]"));
1738 }
1739
1740 for (from, edges) in &self.inner.trigger_table.outgoing {
1742 for edge in edges {
1743 match edge {
1744 crate::edge::CompiledEdge::Fixed { target } => {
1745 lines.push(format!(" {from} --> {target}"));
1746 }
1747 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
1748 for (branch, target) in path_map.iter() {
1749 lines.push(format!(" {from} -->|{branch}| {target}"));
1750 }
1751 }
1752 }
1753 }
1754 }
1755
1756 if let Some(entry) = self.find_entry_point() {
1758 lines.push(format!(" START((start)) --> {entry}"));
1759 }
1760
1761 lines.join("\n")
1762 }
1763
1764 #[must_use]
1775 pub fn to_dot(&self) -> String {
1776 let mut lines = vec!["digraph juncture_graph {".to_string()];
1777 lines.push(" rankdir=LR;".to_string());
1778 lines.push(" node [shape=box];".to_string());
1779 lines.push(" START [shape=circle];".to_string());
1780 lines.push(" END [shape=doublecircle];".to_string());
1781 lines.push(String::new());
1782
1783 for node_name in self.inner.nodes.keys() {
1785 lines.push(format!(" {node_name};"));
1786 }
1787
1788 lines.push(String::new());
1789
1790 for (from, edges) in &self.inner.trigger_table.outgoing {
1792 for edge in edges {
1793 match edge {
1794 crate::edge::CompiledEdge::Fixed { target } => {
1795 lines.push(format!(" {from} -> {target};"));
1796 }
1797 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
1798 for (branch, target) in path_map.iter() {
1799 lines.push(format!(" {from} -> {target} [label=\"{branch}\"];"));
1800 }
1801 }
1802 }
1803 }
1804 }
1805
1806 if let Some(entry) = self.find_entry_point() {
1808 lines.push(format!(" START -> {entry};"));
1809 }
1810
1811 lines.push("}".to_string());
1812 lines.join("\n")
1813 }
1814
1815 #[must_use]
1827 pub fn to_json(&self) -> serde_json::Value {
1828 let drawable = self.to_drawable();
1829
1830 serde_json::json!({
1831 "nodes": drawable.nodes.into_iter().map(|n| {
1832 serde_json::json!({
1833 "name": n.name,
1834 "metadata": n.metadata,
1835 })
1836 }).collect::<Vec<_>>(),
1837 "edges": drawable.edges.into_iter().map(|e| {
1838 let mut edge = serde_json::json!({
1839 "from": e.from,
1840 "to": e.to,
1841 "conditional": e.conditional,
1842 });
1843 if let Some(label) = e.label {
1844 edge["label"] = serde_json::Value::String(label);
1845 }
1846 edge
1847 }).collect::<Vec<_>>(),
1848 })
1849 }
1850
1851 fn to_drawable(&self) -> DrawableGraph {
1853 let mut nodes = Vec::new();
1854 let mut edges = Vec::new();
1855
1856 for node_name in self.inner.nodes.keys() {
1858 let metadata = self
1859 .inner
1860 .builder_metadata
1861 .get(node_name)
1862 .and_then(|m| m.metadata.clone())
1863 .unwrap_or_default();
1864
1865 nodes.push(DrawableNode {
1866 name: node_name.clone(),
1867 metadata,
1868 });
1869 }
1870
1871 for (from, edge_list) in &self.inner.trigger_table.outgoing {
1873 for edge in edge_list {
1874 match edge {
1875 crate::edge::CompiledEdge::Fixed { target } => {
1876 edges.push(DrawableEdge {
1877 from: from.clone(),
1878 to: target.clone(),
1879 conditional: false,
1880 label: None,
1881 });
1882 }
1883 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
1884 for (branch, target) in path_map.iter() {
1885 edges.push(DrawableEdge {
1886 from: from.clone(),
1887 to: target.clone(),
1888 conditional: true,
1889 label: Some(branch.clone()),
1890 });
1891 }
1892 }
1893 }
1894 }
1895 }
1896
1897 DrawableGraph { nodes, edges }
1898 }
1899
1900 fn find_entry_point(&self) -> Option<String> {
1902 for (target, sources) in &self.inner.trigger_table.incoming {
1903 for source in sources {
1904 if matches!(source, crate::edge::TriggerSource::Edge { from } if from == "START") {
1905 return Some(target.clone());
1906 }
1907 }
1908 }
1909 None
1910 }
1911}
1912
/// Shared payload of a [`CompiledGraph`].
///
/// It is held behind an `Arc`, so cloning the graph shares one copy of this
/// data rather than duplicating it.
#[allow(dead_code, reason = "fields used through Arc, not directly")]
struct CompiledGraphInner<S: State> {
    /// Compiled nodes keyed by node name.
    nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,

    /// Outgoing edges and incoming trigger sources for each node.
    trigger_table: TriggerTable<S>,

    /// Per-node builder metadata. Its optional `metadata` map is copied into
    /// the drawable view.
    builder_metadata: IndexMap<String, NodeMetadata>,

    /// Checkpoint backend. When `None`, checkpoint-dependent operations return
    /// a checkpoint error.
    checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,

    /// Node names configured for "interrupt before".
    /// NOTE(review): semantics are handled outside this chunk; confirm.
    interrupt_before: Vec<String>,

    /// Node names configured for "interrupt after".
    /// NOTE(review): semantics are handled outside this chunk; confirm.
    interrupt_after: Vec<String>,

    /// Descriptors of mounted subgraphs, returned by `get_subgraphs`.
    subgraphs: Vec<SubgraphInfo>,

    /// Presumably a count of in-flight invocations. This is inferred from the
    /// name only; its updates are outside this chunk (TODO confirm).
    active_invocations: std::sync::atomic::AtomicU64,
}
1944
/// Result returned by a graph run.
///
/// Carries both the raw state and its conversion to the caller's output type
/// `O`.
#[derive(Debug)]
pub struct GraphOutput<S: State, O: FromState<S> = S> {
    /// The graph state value. Presumably the state at the end of the run;
    /// the producer is outside this chunk.
    pub value: S,

    /// The output value, of a type implementing `FromState<S>`.
    pub output: O,

    /// Interrupts raised during the run. Presumably empty when the run
    /// finished without pausing.
    pub interrupts: Vec<InterruptInfo>,

    /// Run bookkeeping: step count, run id, checkpoint id and budget usage.
    pub metadata: GraphOutputMetadata,
}
1963
/// Describes one interrupt raised during a graph run.
#[derive(Clone, Debug)]
pub struct InterruptInfo {
    /// Name of the node associated with the interrupt.
    pub node: String,

    /// JSON payload attached to the interrupt.
    pub value: serde_json::Value,

    /// Optional interrupt identifier. Presumably used to target a specific
    /// interrupt when resuming; confirm against the resume path.
    pub id: Option<String>,
}
1978
/// Bookkeeping information about a graph run.
#[derive(Clone, Debug)]
pub struct GraphOutputMetadata {
    /// Number of steps executed. Presumably the number of super-steps;
    /// confirm in the execution loop.
    pub steps: usize,

    /// Identifier of the run.
    pub run_id: String,

    /// Identifier of the associated checkpoint, if one was written.
    pub checkpoint_id: Option<String>,

    /// Budget consumption. Presumably `None` when no budget was tracked.
    pub budget_usage: Option<crate::pregel::BudgetUsage>,
}
1996
/// A manual state edit applied through `update_state` (or, once implemented,
/// `bulk_update_state`).
///
/// `update_state` writes a new checkpoint tagged with the `Update` source and
/// a step one past the previous checkpoint.
#[derive(Clone, Debug)]
pub struct StateUpdate<S: State> {
    /// The state delta to apply.
    pub update: S::Update,

    /// Optional human-readable label for the edit. It is not read in the
    /// visible part of `update_state`.
    pub label: Option<String>,

    /// Node to attribute the edit to. When set, `update_state` records this
    /// name in the checkpoint metadata `writes` with a `null` value.
    pub as_node: Option<String>,
}
2012
/// Describes a subgraph mounted in a compiled graph, as returned by
/// `get_subgraphs`.
#[derive(Clone, Debug)]
pub struct SubgraphInfo {
    /// Mount name of the subgraph.
    pub name: String,

    /// How the subgraph's checkpoints are persisted, e.g. `Inherit` or
    /// `PerThread`.
    pub persistence: crate::subgraph::SubgraphPersistence,
}
2025
/// Optional constraints for selecting historical states.
///
/// Every field defaults to `None`, meaning "no constraint". Whether the step
/// bounds are inclusive or exclusive is decided by the consumer, which is
/// outside this chunk (TODO confirm).
#[derive(Clone, Debug, Default)]
pub struct StateFilter {
    /// Keep only states after this step.
    pub after_step: Option<usize>,

    /// Keep only states before this step.
    pub before_step: Option<usize>,

    /// Maximum number of states to return.
    pub limit: Option<usize>,
}
2040
/// Renderer-neutral view of a compiled graph, produced by `get_graph`.
#[derive(Clone, Debug)]
pub struct DrawableGraph {
    /// One entry per compiled node, in node-table order.
    pub nodes: Vec<DrawableNode>,

    /// One entry per fixed edge and per conditional branch.
    pub edges: Vec<DrawableEdge>,
}
2052
/// A node in a [`DrawableGraph`].
#[derive(Clone, Debug)]
pub struct DrawableNode {
    /// Node name.
    pub name: String,

    /// Copy of the builder-supplied metadata; empty when none was recorded.
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
2064
/// A directed edge in a [`DrawableGraph`].
#[derive(Clone, Debug)]
pub struct DrawableEdge {
    /// Source node name.
    pub from: String,

    /// Target node name.
    pub to: String,

    /// `true` when this edge is one branch of a conditional edge.
    pub conditional: bool,

    /// Branch key for conditional edges; `None` for fixed edges.
    pub label: Option<String>,
}
2082
2083#[cfg(test)]
2084mod tests {
2085 use super::*;
2086 use crate::{node::IntoNode, node::NodeFnUpdate};
2087
    // A graph compiled from a single node exposes exactly that node via `nodes()`.
    #[test]
    fn test_compiled_graph_creation() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("test".to_string(), mock_node("test"));

        let trigger_table = TriggerTable::new();
        let builder_metadata = IndexMap::new();

        // Argument order of `CompiledGraph::new` (its definition is outside this chunk):
        // - nodes, trigger table, builder metadata;
        // - two name lists, presumably interrupt_before and interrupt_after;
        // - the optional checkpointer;
        // - the subgraph descriptors.
        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            trigger_table,
            builder_metadata,
            vec![],
            vec![],
            None,
            vec![],
        );
        assert_eq!(compiled.nodes().len(), 1);
    }
2107
    // A fixed edge a -> b is rendered as a Mermaid `-->` arrow under a `graph TD` header.
    #[test]
    fn test_to_mermaid() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));
        nodes.insert("b".to_string(), mock_node("b"));

        let mut trigger_table = TriggerTable::new();
        trigger_table.add_outgoing(
            "a".to_string(),
            crate::edge::CompiledEdge::Fixed {
                target: "b".to_string(),
            },
        );

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            trigger_table,
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let mermaid = compiled.to_mermaid();

        assert!(mermaid.contains("graph TD"));
        assert!(mermaid.contains("a --> b"));
    }
2136
    // A fixed edge a -> b is rendered as a DOT `->` arrow inside the named digraph.
    #[test]
    fn test_to_dot() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));
        nodes.insert("b".to_string(), mock_node("b"));

        let mut trigger_table = TriggerTable::new();
        trigger_table.add_outgoing(
            "a".to_string(),
            crate::edge::CompiledEdge::Fixed {
                target: "b".to_string(),
            },
        );

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            trigger_table,
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let dot = compiled.to_dot();

        assert!(dot.contains("digraph juncture_graph"));
        assert!(dot.contains("a -> b"));
    }
2165
    // The JSON export is an object carrying both `nodes` and `edges` keys.
    #[test]
    fn test_to_json() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));
        nodes.insert("b".to_string(), mock_node("b"));

        let mut trigger_table = TriggerTable::new();
        trigger_table.add_outgoing(
            "a".to_string(),
            crate::edge::CompiledEdge::Fixed {
                target: "b".to_string(),
            },
        );

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            trigger_table,
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let json = compiled.to_json();

        assert!(json.is_object());
        assert!(json.get("nodes").is_some());
        assert!(json.get("edges").is_some());
    }
2195
    // `get_graph` returns the same node set whether or not an xray depth is given,
    // because xray expansion is not implemented.
    #[test]
    fn test_get_graph() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let drawable = compiled.get_graph(None);
        assert_eq!(drawable.nodes.len(), 1);

        let drawable_xray = compiled.get_graph(Some(2));
        assert_eq!(drawable_xray.nodes.len(), 1);
    }
2216
    // A graph compiled with no subgraph descriptors reports none.
    #[test]
    fn test_get_subgraphs_empty() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let subgraphs = compiled.get_subgraphs();
        assert!(subgraphs.is_empty());
    }
2234
    // Subgraph descriptors passed at construction are returned by `get_subgraphs`
    // unchanged and in order, including their persistence mode.
    #[test]
    fn test_get_subgraphs_with_mounted_subgraphs() {
        use crate::subgraph::{SubgraphConfig, SubgraphMount, SubgraphPersistence};

        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        // Two mounts sharing one mock node, differing only in name and persistence.
        let sub_node = mock_node("sub_node");
        let mount_inherit = SubgraphMount::new(
            "child_graph",
            SubgraphConfig {
                persistence: SubgraphPersistence::Inherit,
            },
            Arc::clone(&sub_node),
        );
        let mount_per_thread = SubgraphMount::new(
            "worker_graph",
            SubgraphConfig {
                persistence: SubgraphPersistence::PerThread,
            },
            sub_node,
        );

        let subgraphs = vec![
            super::SubgraphInfo {
                name: mount_inherit.name.clone(),
                persistence: mount_inherit.config.persistence,
            },
            super::SubgraphInfo {
                name: mount_per_thread.name.clone(),
                persistence: mount_per_thread.config.persistence,
            },
        ];

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            subgraphs,
        );

        let result = compiled.get_subgraphs();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].name, "child_graph");
        assert_eq!(result[0].persistence, SubgraphPersistence::Inherit);
        assert_eq!(result[1].name, "worker_graph");
        assert_eq!(result[1].persistence, SubgraphPersistence::PerThread);
    }
2286
    // `resume` without a configured checkpointer fails with a checkpoint error.
    #[tokio::test]
    async fn test_resume_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let result = compiled
            .resume(&config, ResumeValue::Single(serde_json::Value::Null))
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_checkpoint());
    }
2309
    // `resume` must reject checkpoints whose source is not `Interrupt`, and the
    // error must name the offending source.
    // - `Input` and `Loop` sources are rejected.
    // - An `Interrupt` source must get past that validation; any later error is
    //   tolerated.
    #[tokio::test]
    #[expect(
        clippy::too_many_lines,
        reason = "comprehensive test with multiple mock scenarios"
    )]
    async fn test_resume_validates_interrupt_source() {
        use crate::checkpoint::{
            Checkpoint, CheckpointMetadata, CheckpointSource, CheckpointTuple,
        };
        use std::collections::HashMap;

        // Serves a fixed, empty step-1 checkpoint whose metadata source is configurable.
        struct MockCheckpointer {
            checkpoint_source: CheckpointSource,
        }

        #[async_trait::async_trait]
        impl crate::checkpoint::CheckpointSaver for MockCheckpointer {
            async fn get_tuple(
                &self,
                _config: &crate::config::RunnableConfig,
            ) -> Result<Option<CheckpointTuple>, crate::checkpoint::CheckpointError> {
                Ok(Some(CheckpointTuple {
                    config: crate::config::RunnableConfig::new(),
                    checkpoint: Checkpoint {
                        id: "test_id".to_string(),
                        channel_values: serde_json::json!({}),
                        channel_versions: HashMap::new(),
                        versions_seen: HashMap::new(),
                        pending_tasks: Vec::new(),
                        pending_sends: Vec::new(),
                        pending_interrupts: Vec::new(),
                        schema_version: 1,
                        created_at: "2024-01-01T00:00:00Z".to_string(),
                        v: 1,
                        new_versions: HashMap::new(),
                        counters_since_delta_snapshot: HashMap::new(),
                    },
                    metadata: CheckpointMetadata {
                        source: self.checkpoint_source.clone(),
                        step: 1,
                        writes: HashMap::new(),
                        parents: HashMap::new(),
                        run_id: "test_run".to_string(),
                    },
                    pending_writes: Vec::new(),
                    parent_config: None,
                }))
            }

            async fn list(
                &self,
                _config: &crate::config::RunnableConfig,
                _filter: Option<crate::checkpoint::CheckpointFilter>,
            ) -> Result<Vec<CheckpointTuple>, crate::checkpoint::CheckpointError> {
                Ok(Vec::new())
            }

            async fn put(
                &self,
                _config: &crate::config::RunnableConfig,
                _checkpoint: Checkpoint,
                _metadata: CheckpointMetadata,
            ) -> Result<crate::config::RunnableConfig, crate::checkpoint::CheckpointError>
            {
                Ok(crate::config::RunnableConfig::new())
            }

            async fn put_writes(
                &self,
                _config: &crate::config::RunnableConfig,
                _writes: Vec<crate::checkpoint::PendingWrite>,
                _task_id: &str,
            ) -> Result<(), crate::checkpoint::CheckpointError> {
                Ok(())
            }
        }

        // The same single-node table is reused for every scenario.
        let nodes = {
            let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
            nodes.insert("a".to_string(), mock_node("a"));
            nodes
        };

        // Scenario 1: `Input` source is rejected.
        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes.clone(),
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            Some(Arc::new(MockCheckpointer {
                checkpoint_source: CheckpointSource::Input,
            })),
            vec![],
        );

        let config = RunnableConfig::new();
        let result = compiled
            .resume(&config, ResumeValue::Single(serde_json::json!("test")))
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_checkpoint());
        assert!(
            err.to_string()
                .contains("resume() requires checkpoint from Interrupt source")
        );
        assert!(err.to_string().contains("Input"));

        // Scenario 2: `Loop` source is rejected.
        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes.clone(),
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            Some(Arc::new(MockCheckpointer {
                checkpoint_source: CheckpointSource::Loop,
            })),
            vec![],
        );

        let result = compiled
            .resume(&config, ResumeValue::Single(serde_json::json!("test")))
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_checkpoint());
        assert!(
            err.to_string()
                .contains("resume() requires checkpoint from Interrupt source")
        );
        assert!(err.to_string().contains("Loop"));

        // Scenario 3: `Interrupt` source passes validation.
        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            Some(Arc::new(MockCheckpointer {
                checkpoint_source: CheckpointSource::Interrupt {
                    node: "test_node".to_string(),
                },
            })),
            vec![],
        );

        let result = compiled
            .resume(&config, ResumeValue::Single(serde_json::json!("test")))
            .await;

        // The mock checkpoint is minimal, so a later failure is acceptable.
        // Only the source-validation error must not appear.
        if let Err(err) = result {
            assert!(
                !err.to_string()
                    .contains("resume() requires checkpoint from Interrupt source")
            );
        }
    }
2475
    // `resume_single` without a configured checkpointer fails with a checkpoint error.
    #[tokio::test]
    async fn test_resume_single_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let result = compiled
            .resume_single(&config, serde_json::Value::Null)
            .await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_checkpoint());
    }
2498
    // `resume_stream` without a configured checkpointer fails up front with a
    // checkpoint error instead of returning a stream.
    #[tokio::test]
    async fn test_resume_stream_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let result = compiled
            .resume_stream(
                &config,
                ResumeValue::Single(serde_json::Value::Null),
                StreamMode::Values,
            )
            .await;
        let Err(err) = result else {
            panic!("expected checkpoint error, got stream");
        };
        assert!(err.is_checkpoint());
    }
2527
2528 #[tokio::test]
2529 #[expect(
2530 clippy::too_many_lines,
2531 reason = "mock checkpointer boilerplate inflates line count; extraction would hurt readability"
2532 )]
2533 async fn test_resume_stream_validates_interrupt_source() {
2534 use crate::checkpoint::{
2535 Checkpoint, CheckpointError, CheckpointMetadata, CheckpointSource, CheckpointTuple,
2536 };
2537 use std::collections::HashMap;
2538
2539 struct MockCheckpointer {
2540 checkpoint_source: CheckpointSource,
2541 }
2542
2543 #[async_trait::async_trait]
2544 impl crate::checkpoint::CheckpointSaver for MockCheckpointer {
2545 async fn get_tuple(
2546 &self,
2547 _config: &crate::config::RunnableConfig,
2548 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
2549 Ok(Some(CheckpointTuple {
2550 config: crate::config::RunnableConfig::new(),
2551 checkpoint: Checkpoint {
2552 id: "test_id".to_string(),
2553 channel_values: serde_json::json!({}),
2554 channel_versions: HashMap::new(),
2555 versions_seen: HashMap::new(),
2556 pending_tasks: Vec::new(),
2557 pending_sends: Vec::new(),
2558 pending_interrupts: Vec::new(),
2559 schema_version: 1,
2560 created_at: "2024-01-01T00:00:00Z".to_string(),
2561 v: 1,
2562 new_versions: HashMap::new(),
2563 counters_since_delta_snapshot: HashMap::new(),
2564 },
2565 metadata: CheckpointMetadata {
2566 source: self.checkpoint_source.clone(),
2567 step: 1,
2568 writes: HashMap::new(),
2569 parents: HashMap::new(),
2570 run_id: "test_run".to_string(),
2571 },
2572 pending_writes: Vec::new(),
2573 parent_config: None,
2574 }))
2575 }
2576
2577 async fn list(
2578 &self,
2579 _config: &crate::config::RunnableConfig,
2580 _filter: Option<crate::checkpoint::CheckpointFilter>,
2581 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
2582 Ok(Vec::new())
2583 }
2584
2585 async fn put(
2586 &self,
2587 _config: &crate::config::RunnableConfig,
2588 _checkpoint: Checkpoint,
2589 _metadata: CheckpointMetadata,
2590 ) -> Result<crate::config::RunnableConfig, CheckpointError> {
2591 Ok(crate::config::RunnableConfig::new())
2592 }
2593
2594 async fn put_writes(
2595 &self,
2596 _config: &crate::config::RunnableConfig,
2597 _writes: Vec<crate::checkpoint::PendingWrite>,
2598 _task_id: &str,
2599 ) -> Result<(), CheckpointError> {
2600 Ok(())
2601 }
2602 }
2603
2604 let nodes = {
2605 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2606 nodes.insert("a".to_string(), mock_node("a"));
2607 nodes
2608 };
2609
2610 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2612 nodes.clone(),
2613 TriggerTable::new(),
2614 IndexMap::new(),
2615 vec![],
2616 vec![],
2617 Some(Arc::new(MockCheckpointer {
2618 checkpoint_source: CheckpointSource::Input,
2619 })),
2620 vec![],
2621 );
2622
2623 let config = RunnableConfig::new();
2624 let result = compiled
2625 .resume_stream(
2626 &config,
2627 ResumeValue::Single(serde_json::json!("test")),
2628 StreamMode::Values,
2629 )
2630 .await;
2631
2632 assert!(result.is_err());
2633 let Err(err) = result else {
2634 panic!("expected checkpoint error, got stream");
2635 };
2636 assert!(err.is_checkpoint());
2637 assert!(
2638 err.to_string()
2639 .contains("resume_stream() requires checkpoint from Interrupt source"),
2640 "Expected interrupt source validation error, got: {err}"
2641 );
2642
2643 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2645 nodes,
2646 TriggerTable::new(),
2647 IndexMap::new(),
2648 vec![],
2649 vec![],
2650 Some(Arc::new(MockCheckpointer {
2651 checkpoint_source: CheckpointSource::Interrupt {
2652 node: "test_node".to_string(),
2653 },
2654 })),
2655 vec![],
2656 );
2657
2658 let result = compiled
2659 .resume_stream(
2660 &config,
2661 ResumeValue::Single(serde_json::json!("test")),
2662 StreamMode::Values,
2663 )
2664 .await;
2665
2666 if let Err(err) = result {
2669 assert!(
2670 !err.to_string()
2671 .contains("resume_stream() requires checkpoint from Interrupt source"),
2672 "Interrupt source should pass validation: {err}"
2673 );
2674 }
2675 }
2676
    // `get_state` without a configured checkpointer fails with a checkpoint error.
    #[tokio::test]
    async fn test_get_state_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let result = compiled.get_state(&config).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_checkpoint());
    }
2697
    // `get_state_history` without a configured checkpointer fails with a checkpoint error.
    #[tokio::test]
    async fn test_get_state_history_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let result = compiled.get_state_history(&config, None).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_checkpoint());
    }
2718
    // `update_state` without a configured checkpointer fails with a checkpoint error.
    #[tokio::test]
    async fn test_update_state_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let update = StateUpdate {
            update: StateDummyUpdate,
            label: None,
            as_node: None,
        };

        let result = compiled.update_state(&config, update).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_checkpoint());
    }
2745
    // When the checkpointer has no checkpoint for the thread, `update_state`
    // fails with a checkpoint error mentioning "no checkpoint found".
    #[tokio::test]
    async fn test_update_state_no_checkpoint_found() {
        use crate::checkpoint::{Checkpoint, CheckpointError, CheckpointMetadata, CheckpointTuple};

        // A checkpointer whose lookups always come back empty.
        struct NoCheckpointCheckpointer;

        #[async_trait::async_trait]
        impl crate::checkpoint::CheckpointSaver for NoCheckpointCheckpointer {
            async fn get_tuple(
                &self,
                _config: &crate::config::RunnableConfig,
            ) -> Result<Option<CheckpointTuple>, CheckpointError> {
                Ok(None)
            }

            async fn list(
                &self,
                _config: &crate::config::RunnableConfig,
                _filter: Option<crate::checkpoint::CheckpointFilter>,
            ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
                Ok(Vec::new())
            }

            async fn put(
                &self,
                _config: &crate::config::RunnableConfig,
                _checkpoint: Checkpoint,
                _metadata: CheckpointMetadata,
            ) -> Result<crate::config::RunnableConfig, CheckpointError> {
                Ok(crate::config::RunnableConfig::new())
            }

            async fn put_writes(
                &self,
                _config: &crate::config::RunnableConfig,
                _writes: Vec<crate::checkpoint::PendingWrite>,
                _task_id: &str,
            ) -> Result<(), CheckpointError> {
                Ok(())
            }
        }

        let nodes = {
            let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
            nodes.insert("a".to_string(), mock_node("a"));
            nodes
        };

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            Some(Arc::new(NoCheckpointCheckpointer)),
            vec![],
        );

        let config = RunnableConfig::new();
        let update = StateUpdate {
            update: StateDummyUpdate,
            label: None,
            as_node: None,
        };

        let result = compiled.update_state(&config, update).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_checkpoint());
        assert!(
            err.to_string().contains("no checkpoint found"),
            "Expected 'no checkpoint found' error, got: {err}"
        );
    }
2820
2821 #[tokio::test]
2822 #[expect(
2823 clippy::too_many_lines,
2824 reason = "mock checkpointer boilerplate inflates line count; extraction would hurt readability"
2825 )]
2826 async fn test_update_state_success() {
2827 use crate::checkpoint::{
2828 Checkpoint, CheckpointError, CheckpointMetadata, CheckpointSource, CheckpointTuple,
2829 };
2830 use std::collections::HashMap;
2831 use std::sync::{Arc, Mutex};
2832
2833 #[derive(Clone)]
2834 enum ObservedCall {
2835 Put { source: CheckpointSource, step: i64 },
2836 }
2837
2838 struct MockCheckpointer {
2839 observed: Arc<Mutex<Vec<ObservedCall>>>,
2840 }
2841
2842 #[async_trait::async_trait]
2843 impl crate::checkpoint::CheckpointSaver for MockCheckpointer {
2844 async fn get_tuple(
2845 &self,
2846 _config: &crate::config::RunnableConfig,
2847 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
2848 Ok(Some(CheckpointTuple {
2849 config: crate::config::RunnableConfig::new(),
2850 checkpoint: Checkpoint {
2851 id: "cp_123".to_string(),
2852 channel_values: serde_json::Value::Null,
2853 channel_versions: HashMap::new(),
2854 versions_seen: HashMap::new(),
2855 pending_tasks: Vec::new(),
2856 pending_sends: Vec::new(),
2857 pending_interrupts: Vec::new(),
2858 schema_version: 1,
2859 created_at: "2024-01-01T00:00:00Z".to_string(),
2860 v: 1,
2861 new_versions: HashMap::new(),
2862 counters_since_delta_snapshot: HashMap::new(),
2863 },
2864 metadata: CheckpointMetadata {
2865 source: CheckpointSource::Loop,
2866 step: 5,
2867 writes: HashMap::new(),
2868 parents: HashMap::new(),
2869 run_id: "run_abc".to_string(),
2870 },
2871 pending_writes: Vec::new(),
2872 parent_config: None,
2873 }))
2874 }
2875
2876 async fn list(
2877 &self,
2878 _config: &crate::config::RunnableConfig,
2879 _filter: Option<crate::checkpoint::CheckpointFilter>,
2880 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
2881 Ok(Vec::new())
2882 }
2883
2884 async fn put(
2885 &self,
2886 _config: &crate::config::RunnableConfig,
2887 _checkpoint: Checkpoint,
2888 metadata: CheckpointMetadata,
2889 ) -> Result<crate::config::RunnableConfig, CheckpointError> {
2890 self.observed
2891 .lock()
2892 .unwrap_or_else(std::sync::PoisonError::into_inner)
2893 .push(ObservedCall::Put {
2894 source: metadata.source,
2895 step: metadata.step,
2896 });
2897 Ok(crate::config::RunnableConfig::new())
2898 }
2899
2900 async fn put_writes(
2901 &self,
2902 _config: &crate::config::RunnableConfig,
2903 _writes: Vec<crate::checkpoint::PendingWrite>,
2904 _task_id: &str,
2905 ) -> Result<(), CheckpointError> {
2906 Ok(())
2907 }
2908 }
2909
2910 let observed = Arc::new(Mutex::new(Vec::new()));
2911 let nodes = {
2912 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2913 nodes.insert("a".to_string(), mock_node("a"));
2914 nodes
2915 };
2916
2917 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2918 nodes,
2919 TriggerTable::new(),
2920 IndexMap::new(),
2921 vec![],
2922 vec![],
2923 Some(Arc::new(MockCheckpointer {
2924 observed: Arc::clone(&observed),
2925 })),
2926 vec![],
2927 );
2928
2929 let config = RunnableConfig::new();
2930 let update = StateUpdate {
2931 update: StateDummyUpdate,
2932 label: Some("manual fix".to_string()),
2933 as_node: Some("admin".to_string()),
2934 };
2935
2936 let result = compiled.update_state(&config, update).await;
2937 assert!(result.is_ok(), "update_state should succeed");
2938
2939 let calls = observed
2941 .lock()
2942 .unwrap_or_else(std::sync::PoisonError::into_inner);
2943 assert_eq!(calls.len(), 1, "Expected exactly one put call");
2944 match &calls[0] {
2945 ObservedCall::Put { source, step } => {
2946 assert!(
2947 matches!(source, CheckpointSource::Update),
2948 "Expected Update source, got {source:?}"
2949 );
2950 assert_eq!(*step, 6, "Expected step to be incremented from 5 to 6");
2951 }
2952 }
2953 }
2954
    // `bulk_update_state` without a configured checkpointer fails with a checkpoint error.
    #[tokio::test]
    async fn test_bulk_update_state_no_checkpointer() {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));

        let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
            nodes,
            TriggerTable::new(),
            IndexMap::new(),
            vec![],
            vec![],
            None,
            vec![],
        );
        let config = RunnableConfig::new();

        let updates = vec![StateUpdate {
            update: StateDummyUpdate,
            label: None,
            as_node: None,
        }];

        let result = compiled.bulk_update_state(&config, updates).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_checkpoint());
    }
2981
    // A `StateUpdate` built with a label and an `as_node` keeps both.
    #[test]
    fn test_state_update_creation() {
        let update: StateUpdate<StateDummy> = StateUpdate {
            update: StateDummyUpdate,
            label: Some("test update".to_string()),
            as_node: Some("my_node".to_string()),
        };

        assert!(update.label.is_some());
        assert!(update.as_node.is_some());
    }
2993
    // A `SubgraphInfo` keeps the name it was constructed with.
    #[test]
    fn test_subgraph_info_creation() {
        let info = SubgraphInfo {
            name: "my_subgraph".to_string(),
            persistence: crate::subgraph::SubgraphPersistence::Inherit,
        };

        assert_eq!(info.name, "my_subgraph");
    }
3003
3004 #[test]
3005 fn test_state_filter_default() {
3006 let filter = StateFilter::default();
3007 assert!(filter.after_step.is_none());
3008 assert!(filter.before_step.is_none());
3009 assert!(filter.limit.is_none());
3010 }
3011
3012 #[test]
3013 fn test_state_filter_with_values() {
3014 let filter = StateFilter {
3015 after_step: Some(5),
3016 before_step: Some(10),
3017 limit: Some(20),
3018 };
3019
3020 assert_eq!(filter.after_step, Some(5));
3021 assert_eq!(filter.before_step, Some(10));
3022 assert_eq!(filter.limit, Some(20));
3023 }
3024
3025 #[tokio::test]
3026 async fn test_stream_values_mode() {
3027 use futures::StreamExt;
3028
3029 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3031 nodes.insert("node_a".to_string(), mock_node("node_a"));
3032
3033 let mut trigger_table = TriggerTable::new();
3034 trigger_table.add_incoming(
3036 "node_a".to_string(),
3037 crate::edge::TriggerSource::Edge {
3038 from: crate::edge::START.to_string(),
3039 },
3040 );
3041
3042 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3043 nodes,
3044 trigger_table,
3045 IndexMap::new(),
3046 vec![],
3047 vec![],
3048 None,
3049 vec![],
3050 );
3051 let config = RunnableConfig::new();
3052
3053 let handle = compiled
3054 .stream(StateDummy, &config, StreamMode::Values)
3055 .await
3056 .expect("stream should succeed");
3057
3058 let mut events = Vec::new();
3060 let mut stream = handle.stream;
3061 while let Some(result) = stream.next().await {
3062 events.push(result.expect("stream event should be Ok"));
3063 }
3064
3065 let has_values = events
3067 .iter()
3068 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. }));
3069 let has_end = events
3070 .iter()
3071 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3072
3073 assert!(has_values, "Expected Values events in Values mode");
3074 assert!(has_end, "Expected End event");
3075 }
3076
3077 #[tokio::test]
3078 async fn test_stream_updates_mode() {
3079 use futures::StreamExt;
3080
3081 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3083 nodes.insert("node_a".to_string(), mock_node("node_a"));
3084
3085 let mut trigger_table = TriggerTable::new();
3086 trigger_table.add_incoming(
3087 "node_a".to_string(),
3088 crate::edge::TriggerSource::Edge {
3089 from: crate::edge::START.to_string(),
3090 },
3091 );
3092
3093 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3094 nodes,
3095 trigger_table,
3096 IndexMap::new(),
3097 vec![],
3098 vec![],
3099 None,
3100 vec![],
3101 );
3102 let config = RunnableConfig::new();
3103
3104 let handle = compiled
3105 .stream(StateDummy, &config, StreamMode::Updates)
3106 .await
3107 .expect("stream should succeed");
3108
3109 let mut events = Vec::new();
3111 let mut stream = handle.stream;
3112 while let Some(result) = stream.next().await {
3113 events.push(result.expect("stream event should be Ok"));
3114 }
3115
3116 let has_updates = events
3118 .iter()
3119 .any(|e| matches!(e, crate::stream::StreamEvent::Updates { .. }));
3120 let has_end = events
3121 .iter()
3122 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3123
3124 assert!(has_updates, "Expected Updates events in Updates mode");
3125 assert!(has_end, "Expected End event");
3126 }
3127
3128 #[tokio::test]
3129 async fn test_stream_debug_mode() {
3130 use futures::StreamExt;
3131
3132 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3134 nodes.insert("node_a".to_string(), mock_node("node_a"));
3135
3136 let mut trigger_table = TriggerTable::new();
3137 trigger_table.add_incoming(
3138 "node_a".to_string(),
3139 crate::edge::TriggerSource::Edge {
3140 from: crate::edge::START.to_string(),
3141 },
3142 );
3143
3144 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3145 nodes,
3146 trigger_table,
3147 IndexMap::new(),
3148 vec![],
3149 vec![],
3150 None,
3151 vec![],
3152 );
3153 let config = RunnableConfig::new();
3154
3155 let handle = compiled
3156 .stream(StateDummy, &config, StreamMode::Debug)
3157 .await
3158 .expect("stream should succeed");
3159
3160 let mut events = Vec::new();
3162 let mut stream = handle.stream;
3163 while let Some(result) = stream.next().await {
3164 events.push(result.expect("stream event should be Ok"));
3165 }
3166
3167 let has_debug = events
3169 .iter()
3170 .any(|e| matches!(e, crate::stream::StreamEvent::Debug(_)));
3171 let has_end = events
3172 .iter()
3173 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3174
3175 assert!(has_debug, "Expected Debug events in Debug mode");
3176 assert!(has_end, "Expected End event");
3177 }
3178
3179 #[tokio::test]
3180 async fn test_stream_end_event() {
3181 use futures::StreamExt;
3182
3183 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3185 nodes.insert("node_a".to_string(), mock_node("node_a"));
3186
3187 let mut trigger_table = TriggerTable::new();
3188 trigger_table.add_incoming(
3189 "node_a".to_string(),
3190 crate::edge::TriggerSource::Edge {
3191 from: crate::edge::START.to_string(),
3192 },
3193 );
3194
3195 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3196 nodes,
3197 trigger_table,
3198 IndexMap::new(),
3199 vec![],
3200 vec![],
3201 None,
3202 vec![],
3203 );
3204 let config = RunnableConfig::new();
3205
3206 let handle = compiled
3207 .stream(StateDummy, &config, StreamMode::Values)
3208 .await
3209 .expect("stream should succeed");
3210
3211 let mut events = Vec::new();
3213 let mut stream = handle.stream;
3214 while let Some(result) = stream.next().await {
3215 events.push(result.expect("stream event should be Ok"));
3216 }
3217
3218 assert!(!events.is_empty(), "Stream should emit events");
3220
3221 let end_events: Vec<_> = events
3222 .iter()
3223 .filter_map(|e| {
3224 if let crate::stream::StreamEvent::End { output } = e {
3225 Some(output.clone())
3226 } else {
3227 None
3228 }
3229 })
3230 .collect();
3231
3232 assert!(!end_events.is_empty(), "Expected at least one End event");
3233
3234 for state in end_events {
3236 let _cloned_state = state.clone();
3237 }
3238 }
3239
3240 fn mock_node(name: &str) -> Arc<dyn crate::Node<StateDummy>> {
3241 NodeFnUpdate(|_s: &StateDummy| async move { Ok(StateDummyUpdate) }).into_node(name)
3242 }
3243
3244 #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
3245 #[serde(crate = "serde")]
3246 struct StateDummy;
3247
3248 impl crate::State for StateDummy {
3249 type Update = StateDummyUpdate;
3250 type FieldVersions = crate::state::FieldVersions;
3251
3252 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
3253 crate::FieldsChanged(0)
3254 }
3255
3256 fn reset_ephemeral(&mut self) {}
3257 }
3258
3259 #[derive(Clone, Debug, Default, serde::Serialize)]
3260 struct StateDummyUpdate;
3261
3262 #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize, PartialEq)]
3267 #[serde(crate = "serde")]
3268 struct StateV2 {
3269 value: i32,
3270 label: String,
3271 }
3272
3273 impl crate::State for StateV2 {
3274 type Update = StateV2Update;
3275 type FieldVersions = crate::state::FieldVersions;
3276
3277 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
3278 crate::FieldsChanged(0)
3279 }
3280
3281 fn reset_ephemeral(&mut self) {}
3282
3283 fn schema_version() -> u32 {
3284 2
3285 }
3286
3287 fn migrate(from_version: u32, value: serde_json::Value) -> serde_json::Value {
3288 let mut map = match value {
3289 serde_json::Value::Object(m) => m,
3290 other => return other,
3291 };
3292 if from_version < 2 {
3293 map.insert(
3294 "label".to_string(),
3295 serde_json::Value::String("migrated".to_string()),
3296 );
3297 }
3298 serde_json::Value::Object(map)
3299 }
3300 }
3301
3302 #[derive(Clone, Debug, Default, serde::Serialize)]
3303 struct StateV2Update;
3304
3305 #[test]
3306 fn test_deserialize_with_migration_applies_migration_when_versions_differ() {
3307 use std::collections::HashMap;
3308
3309 let checkpoint = crate::checkpoint::Checkpoint {
3311 id: "test_id".to_string(),
3312 channel_values: serde_json::json!({"value": 42}),
3313 channel_versions: HashMap::new(),
3314 versions_seen: HashMap::new(),
3315 pending_tasks: Vec::new(),
3316 pending_sends: Vec::new(),
3317 pending_interrupts: Vec::new(),
3318 schema_version: 1, created_at: "2024-01-01T00:00:00Z".to_string(),
3320 v: 1,
3321 new_versions: HashMap::new(),
3322 counters_since_delta_snapshot: HashMap::new(),
3323 };
3324
3325 let state: StateV2 = CompiledGraph::<StateV2>::deserialize_with_migration(&checkpoint)
3326 .expect("deserialization with migration should succeed");
3327
3328 assert_eq!(state.value, 42);
3330 assert_eq!(state.label, "migrated");
3331 }
3332
3333 #[test]
3334 fn test_deserialize_with_migration_skips_migration_when_versions_match() {
3335 use std::collections::HashMap;
3336
3337 let checkpoint = crate::checkpoint::Checkpoint {
3339 id: "test_id".to_string(),
3340 channel_values: serde_json::json!({"value": 7, "label": "original"}),
3341 channel_versions: HashMap::new(),
3342 versions_seen: HashMap::new(),
3343 pending_tasks: Vec::new(),
3344 pending_sends: Vec::new(),
3345 pending_interrupts: Vec::new(),
3346 schema_version: 2, created_at: "2024-01-01T00:00:00Z".to_string(),
3348 v: 1,
3349 new_versions: HashMap::new(),
3350 counters_since_delta_snapshot: HashMap::new(),
3351 };
3352
3353 let state: StateV2 = CompiledGraph::<StateV2>::deserialize_with_migration(&checkpoint)
3354 .expect("deserialization should succeed");
3355
3356 assert_eq!(state.value, 7);
3358 assert_eq!(state.label, "original");
3359 }
3360
3361 #[test]
3362 fn test_compile_config_default_is_empty() {
3363 let config = super::super::CompileConfig::default();
3364 assert!(config.interrupt_before.is_empty());
3365 assert!(config.interrupt_after.is_empty());
3366 }
3367
3368 #[test]
3369 fn test_compile_with_config_stores_interrupts() {
3370 let mut graph = super::super::StateGraph::<StateDummy>::new();
3371 graph
3372 .add_node_simple(
3373 "human_review",
3374 NodeFnUpdate(
3375 |_s: &StateDummy| -> std::pin::Pin<
3376 Box<
3377 dyn std::future::Future<
3378 Output = Result<StateDummyUpdate, crate::JunctureError>,
3379 > + Send,
3380 >,
3381 > { Box::pin(async move { Ok(StateDummyUpdate) }) },
3382 ),
3383 )
3384 .unwrap();
3385 graph.set_entry_point("human_review");
3386 graph.set_finish_point("human_review");
3387
3388 let config = super::super::CompileConfig {
3389 interrupt_before: vec!["human_review".to_string()],
3390 interrupt_after: vec!["human_review".to_string()],
3391 };
3392
3393 let compiled = graph.compile_with_config(config).unwrap();
3394 assert_eq!(compiled.nodes().len(), 1);
3395 }
3396
3397 #[test]
3398 fn test_effective_config_uses_compile_time_defaults() {
3399 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3400 nodes.insert("a".to_string(), mock_node("a"));
3401
3402 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3403 nodes,
3404 TriggerTable::new(),
3405 IndexMap::new(),
3406 vec!["node_a".to_string()],
3407 vec!["node_b".to_string()],
3408 None,
3409 vec![],
3410 );
3411
3412 let config = RunnableConfig::new();
3414 let effective = compiled.effective_config(&config);
3415 assert_eq!(effective.interrupt_before, Some(vec!["node_a".to_string()]));
3416 assert_eq!(effective.interrupt_after, Some(vec!["node_b".to_string()]));
3417 }
3418
3419 #[test]
3420 fn test_effective_config_runtime_overrides_compile_time() {
3421 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3422 nodes.insert("a".to_string(), mock_node("a"));
3423
3424 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3425 nodes,
3426 TriggerTable::new(),
3427 IndexMap::new(),
3428 vec!["compile_before".to_string()],
3429 vec!["compile_after".to_string()],
3430 None,
3431 vec![],
3432 );
3433
3434 let config = RunnableConfig::new()
3436 .with_interrupt_before(vec!["runtime_before".to_string()])
3437 .with_interrupt_after(vec!["runtime_after".to_string()]);
3438
3439 let effective = compiled.effective_config(&config);
3440 assert_eq!(
3441 effective.interrupt_before,
3442 Some(vec!["runtime_before".to_string()])
3443 );
3444 assert_eq!(
3445 effective.interrupt_after,
3446 Some(vec!["runtime_after".to_string()])
3447 );
3448 }
3449
3450 #[test]
3451 fn test_effective_config_empty_compile_time_no_override() {
3452 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3453 nodes.insert("a".to_string(), mock_node("a"));
3454
3455 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3456 nodes,
3457 TriggerTable::new(),
3458 IndexMap::new(),
3459 vec![],
3460 vec![],
3461 None,
3462 vec![],
3463 );
3464
3465 let config = RunnableConfig::new();
3467 let effective = compiled.effective_config(&config);
3468 assert!(effective.interrupt_before.is_none());
3469 assert!(effective.interrupt_after.is_none());
3470 }
3471
3472 #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize, PartialEq)]
3476 #[serde(crate = "serde")]
3477 struct MultiFieldState {
3478 messages: Vec<String>,
3479 count: i32,
3480 label: String,
3481 }
3482
3483 impl crate::State for MultiFieldState {
3484 type Update = MultiFieldStateUpdate;
3485 type FieldVersions = crate::state::FieldVersions;
3486
3487 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
3488 let mut mask = 0u64;
3489 if let Some(messages) = update.messages {
3490 self.messages = messages;
3491 mask |= 1;
3492 }
3493 if let Some(count) = update.count {
3494 self.count = count;
3495 mask |= 1 << 1;
3496 }
3497 if let Some(label) = update.label {
3498 self.label = label;
3499 mask |= 1 << 2;
3500 }
3501 crate::FieldsChanged(mask)
3502 }
3503
3504 fn reset_ephemeral(&mut self) {}
3505 }
3506
3507 #[derive(Clone, Debug, Default, serde::Serialize)]
3508 struct MultiFieldStateUpdate {
3509 messages: Option<Vec<String>>,
3510 count: Option<i32>,
3511 label: Option<String>,
3512 }
3513
3514 fn multi_field_node(name: &str) -> Arc<dyn crate::Node<MultiFieldState>> {
3515 NodeFnUpdate(|_s: &MultiFieldState| async move {
3516 Ok(MultiFieldStateUpdate {
3517 messages: Some(vec!["hello".to_string()]),
3518 count: Some(1),
3519 label: Some("updated".to_string()),
3520 })
3521 })
3522 .into_node(name)
3523 }
3524
3525 fn build_multi_field_graph() -> CompiledGraph<MultiFieldState> {
3526 let mut nodes: IndexMap<String, Arc<dyn crate::Node<MultiFieldState>>> = IndexMap::new();
3527 nodes.insert("node_a".to_string(), multi_field_node("node_a"));
3528
3529 let mut trigger_table = TriggerTable::new();
3530 trigger_table.add_incoming(
3531 "node_a".to_string(),
3532 crate::edge::TriggerSource::Edge {
3533 from: crate::edge::START.to_string(),
3534 },
3535 );
3536
3537 CompiledGraph::new(
3538 nodes,
3539 trigger_table,
3540 IndexMap::new(),
3541 vec![],
3542 vec![],
3543 None,
3544 vec![],
3545 )
3546 }
3547
3548 #[tokio::test]
3549 async fn test_stream_with_config_no_output_keys_emits_values() {
3550 use futures::StreamExt;
3551
3552 let compiled = build_multi_field_graph();
3553 let config = RunnableConfig::new();
3554
3555 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values);
3556
3557 let handle = compiled
3558 .stream_with_config(
3559 MultiFieldState {
3560 messages: vec![],
3561 count: 0,
3562 label: String::new(),
3563 },
3564 &config,
3565 stream_config,
3566 )
3567 .await
3568 .expect("stream_with_config should succeed");
3569
3570 let mut events = Vec::new();
3571 let mut stream = handle.stream;
3572 while let Some(result) = stream.next().await {
3573 events.push(result.expect("stream event should be Ok"));
3574 }
3575
3576 let has_values = events
3578 .iter()
3579 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. }));
3580 assert!(has_values, "Expected Values events without output_keys");
3581 }
3582
3583 #[tokio::test]
3584 async fn test_stream_with_config_output_keys_emits_filtered_values() {
3585 use futures::StreamExt;
3586
3587 let compiled = build_multi_field_graph();
3588 let config = RunnableConfig::new();
3589
3590 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values)
3591 .with_output_keys(vec!["messages".to_string()]);
3592
3593 let handle = compiled
3594 .stream_with_config(
3595 MultiFieldState {
3596 messages: vec![],
3597 count: 0,
3598 label: String::new(),
3599 },
3600 &config,
3601 stream_config,
3602 )
3603 .await
3604 .expect("stream_with_config should succeed");
3605
3606 let mut events = Vec::new();
3607 let mut stream = handle.stream;
3608 while let Some(result) = stream.next().await {
3609 events.push(result.expect("stream event should be Ok"));
3610 }
3611
3612 let filtered: Vec<_> = events
3614 .iter()
3615 .filter_map(|e| {
3616 if let crate::stream::StreamEvent::FilteredValues { data, .. } = e {
3617 Some(data.clone())
3618 } else {
3619 None
3620 }
3621 })
3622 .collect();
3623
3624 assert!(
3625 !filtered.is_empty(),
3626 "Expected FilteredValues events with output_keys set"
3627 );
3628
3629 for data in &filtered {
3630 assert!(
3632 data.get("messages").is_some(),
3633 "FilteredValues should contain 'messages' key"
3634 );
3635 assert!(
3636 data.get("count").is_none(),
3637 "FilteredValues should not contain 'count' key"
3638 );
3639 assert!(
3640 data.get("label").is_none(),
3641 "FilteredValues should not contain 'label' key"
3642 );
3643 }
3644 }
3645
3646 #[tokio::test]
3647 async fn test_stream_delegates_to_stream_with_config() {
3648 use futures::StreamExt;
3649
3650 let compiled = build_multi_field_graph();
3651 let config = RunnableConfig::new();
3652
3653 let handle = compiled
3656 .stream(
3657 MultiFieldState {
3658 messages: vec![],
3659 count: 0,
3660 label: String::new(),
3661 },
3662 &config,
3663 StreamMode::Values,
3664 )
3665 .await
3666 .expect("stream should succeed");
3667
3668 let mut events = Vec::new();
3669 let mut stream = handle.stream;
3670 while let Some(result) = stream.next().await {
3671 events.push(result.expect("stream event should be Ok"));
3672 }
3673
3674 let has_values = events
3675 .iter()
3676 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. }));
3677 let has_end = events
3678 .iter()
3679 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3680
3681 assert!(has_values, "stream() should emit Values events");
3682 assert!(has_end, "stream() should emit End event");
3683 }
3684
3685 #[tokio::test]
3686 async fn test_stream_with_config_output_keys_multiple_keys() {
3687 use futures::StreamExt;
3688
3689 let compiled = build_multi_field_graph();
3690 let config = RunnableConfig::new();
3691
3692 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values)
3693 .with_output_keys(vec!["messages".to_string(), "count".to_string()]);
3694
3695 let handle = compiled
3696 .stream_with_config(
3697 MultiFieldState {
3698 messages: vec![],
3699 count: 0,
3700 label: String::new(),
3701 },
3702 &config,
3703 stream_config,
3704 )
3705 .await
3706 .expect("stream_with_config should succeed");
3707
3708 let mut events = Vec::new();
3709 let mut stream = handle.stream;
3710 while let Some(result) = stream.next().await {
3711 events.push(result.expect("stream event should be Ok"));
3712 }
3713
3714 let filtered: Vec<_> = events
3715 .iter()
3716 .filter_map(|e| {
3717 if let crate::stream::StreamEvent::FilteredValues { data, .. } = e {
3718 Some(data.clone())
3719 } else {
3720 None
3721 }
3722 })
3723 .collect();
3724
3725 assert!(!filtered.is_empty());
3726
3727 for data in &filtered {
3728 assert!(
3729 data.get("messages").is_some(),
3730 "Should contain 'messages' key"
3731 );
3732 assert!(data.get("count").is_some(), "Should contain 'count' key");
3733 assert!(
3734 data.get("label").is_none(),
3735 "Should not contain 'label' key"
3736 );
3737 }
3738 }
3739
3740 #[tokio::test]
3741 async fn test_stream_with_config_updates_mode_output_keys() {
3742 use futures::StreamExt;
3743
3744 let compiled = build_multi_field_graph();
3745 let config = RunnableConfig::new();
3746
3747 let stream_config = crate::stream::StreamConfig::new(StreamMode::Updates)
3748 .with_output_keys(vec!["messages".to_string()]);
3749
3750 let handle = compiled
3751 .stream_with_config(
3752 MultiFieldState {
3753 messages: vec![],
3754 count: 0,
3755 label: String::new(),
3756 },
3757 &config,
3758 stream_config,
3759 )
3760 .await
3761 .expect("stream_with_config should succeed");
3762
3763 let mut events = Vec::new();
3764 let mut stream = handle.stream;
3765 while let Some(result) = stream.next().await {
3766 events.push(result.expect("stream event should be Ok"));
3767 }
3768
3769 let filtered_updates: Vec<_> = events
3771 .iter()
3772 .filter_map(|e| {
3773 if let crate::stream::StreamEvent::FilteredUpdates { data, .. } = e {
3774 Some(data.clone())
3775 } else {
3776 None
3777 }
3778 })
3779 .collect();
3780
3781 assert!(
3782 !filtered_updates.is_empty(),
3783 "Expected FilteredUpdates events in Updates mode with output_keys"
3784 );
3785
3786 for data in &filtered_updates {
3787 assert!(
3788 data.get("messages").is_some(),
3789 "FilteredUpdates should contain 'messages' key"
3790 );
3791 assert!(
3793 data.get("count").is_none(),
3794 "FilteredUpdates should not contain 'count' key"
3795 );
3796 assert!(
3797 data.get("label").is_none(),
3798 "FilteredUpdates should not contain 'label' key"
3799 );
3800 }
3801 }
3802
3803 #[test]
3804 fn test_filter_json_by_keys() {
3805 let json = serde_json::json!({
3806 "messages": ["hello"],
3807 "count": 42,
3808 "label": "test"
3809 });
3810
3811 let filtered = crate::stream::filter_json_by_keys(json, &["messages".to_string()]);
3812 assert!(filtered.get("messages").is_some());
3813 assert!(filtered.get("count").is_none());
3814 assert!(filtered.get("label").is_none());
3815 }
3816
3817 #[test]
3818 fn test_filter_json_by_keys_multiple() {
3819 let json = serde_json::json!({
3820 "a": 1,
3821 "b": 2,
3822 "c": 3
3823 });
3824
3825 let filtered =
3826 crate::stream::filter_json_by_keys(json, &["a".to_string(), "c".to_string()]);
3827 assert_eq!(filtered.get("a").unwrap(), 1);
3828 assert!(filtered.get("b").is_none());
3829 assert_eq!(filtered.get("c").unwrap(), 3);
3830 }
3831
3832 #[test]
3833 fn test_filter_json_by_keys_empty_keys() {
3834 let json = serde_json::json!({"a": 1});
3835 let filtered = crate::stream::filter_json_by_keys(json.clone(), &[]);
3836 assert_eq!(json, filtered);
3837 }
3838
3839 #[test]
3840 fn test_filter_json_by_keys_non_object() {
3841 let json = serde_json::json!("hello");
3842 let filtered = crate::stream::filter_json_by_keys(json.clone(), &["a".to_string()]);
3843 assert_eq!(json, filtered);
3844 }
3845
3846 #[test]
3849 fn test_stream_event_namespace_custom_has_ns() {
3850 let event: StreamEvent<StateDummy> = StreamEvent::Custom {
3851 node: "sub_node".to_string(),
3852 data: serde_json::json!({"x": 1}),
3853 ns: vec!["child_graph".to_string(), "sub_node:uuid".to_string()],
3854 };
3855 assert_eq!(event.namespace().len(), 2);
3856 assert_eq!(event.namespace()[0], "child_graph");
3857 }
3858
3859 #[test]
3860 fn test_stream_event_namespace_messages_has_ns() {
3861 let event: StreamEvent<StateDummy> = StreamEvent::Messages {
3862 chunk: crate::stream::MessageChunk {
3863 content: "hi".to_string(),
3864 tool_call_chunks: vec![],
3865 usage_delta: None,
3866 },
3867 metadata: crate::stream::MessageStreamMetadata {
3868 node: "llm".to_string(),
3869 model: "gpt-4".to_string(),
3870 tags: vec![],
3871 ns: vec!["child_graph".to_string()],
3872 },
3873 };
3874 assert_eq!(event.namespace().len(), 1);
3875 assert_eq!(event.namespace()[0], "child_graph");
3876 }
3877
3878 #[test]
3879 fn test_stream_event_namespace_interrupt_has_ns() {
3880 let event: StreamEvent<StateDummy> = StreamEvent::Interrupt {
3881 node: "review".to_string(),
3882 payload: serde_json::Value::Null,
3883 resumable: true,
3884 ns: vec!["subgraph_a".to_string()],
3885 };
3886 assert_eq!(event.namespace().len(), 1);
3887 }
3888
3889 #[test]
3890 fn test_stream_event_namespace_values_is_empty() {
3891 let event: StreamEvent<StateDummy> = StreamEvent::Values {
3892 state: StateDummy,
3893 step: 0,
3894 };
3895 assert!(event.namespace().is_empty());
3896 }
3897
3898 #[test]
3899 fn test_stream_event_namespace_updates_is_empty() {
3900 let event: StreamEvent<StateDummy> = StreamEvent::Updates {
3901 node: "n".to_string(),
3902 update: StateDummyUpdate,
3903 step: 0,
3904 };
3905 assert!(event.namespace().is_empty());
3906 }
3907
3908 #[test]
3909 fn test_stream_event_namespace_end_is_empty() {
3910 let event: StreamEvent<StateDummy> = StreamEvent::End { output: StateDummy };
3911 assert!(event.namespace().is_empty());
3912 }
3913
3914 #[test]
3915 fn test_stream_event_namespace_task_start_is_empty() {
3916 let event: StreamEvent<StateDummy> = StreamEvent::TaskStart {
3917 node: "n".to_string(),
3918 task_id: "t".to_string(),
3919 step: 0,
3920 };
3921 assert!(event.namespace().is_empty());
3922 }
3923
3924 #[test]
3925 fn test_stream_event_namespace_debug_is_empty() {
3926 let event: StreamEvent<StateDummy> =
3927 StreamEvent::Debug(crate::stream::DebugEvent::SuperstepStart {
3928 step: 0,
3929 pending_nodes: vec![],
3930 });
3931 assert!(event.namespace().is_empty());
3932 }
3933
3934 #[test]
3939 fn test_subgraph_filter_default_excludes_subgraph_events() {
3940 let subgraph_event: StreamEvent<StateDummy> = StreamEvent::Custom {
3942 node: "sub_node".to_string(),
3943 data: serde_json::json!({}),
3944 ns: vec!["child_graph".to_string()],
3945 };
3946
3947 let top_level_event: StreamEvent<StateDummy> = StreamEvent::Values {
3949 state: StateDummy,
3950 step: 0,
3951 };
3952
3953 let include_subgraphs = false;
3954 assert!(top_level_event.namespace().is_empty());
3958 assert!(!subgraph_event.namespace().is_empty());
3960
3961 let ns = subgraph_event.namespace();
3964 let should_skip = !ns.is_empty() && !include_subgraphs;
3965 assert!(
3966 should_skip,
3967 "subgraph events should be skipped when include_subgraphs=false"
3968 );
3969
3970 let ns = top_level_event.namespace();
3971 let should_skip = !ns.is_empty() && !include_subgraphs;
3972 assert!(!should_skip, "top-level events should not be skipped");
3973 }
3974
3975 #[test]
3978 fn test_subgraph_filter_include_all_passes() {
3979 let subgraph_event: StreamEvent<StateDummy> = StreamEvent::Custom {
3980 node: "sub_node".to_string(),
3981 data: serde_json::json!({}),
3982 ns: vec!["child_graph".to_string()],
3983 };
3984
3985 let include_subgraphs = true;
3986 let subgraph_filter: Option<Vec<String>> = None;
3987
3988 let ns = subgraph_event.namespace();
3989 let should_skip = !ns.is_empty() && !include_subgraphs;
3990 assert!(
3991 !should_skip,
3992 "include_subgraphs=true should not skip subgraph events"
3993 );
3994
3995 assert!(subgraph_filter.is_none());
3997 }
3998
3999 #[test]
4001 fn test_subgraph_filter_by_name_passes_matching() {
4002 let matching_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4003 node: "sub_node".to_string(),
4004 data: serde_json::json!({}),
4005 ns: vec!["child_a".to_string()],
4006 };
4007
4008 let non_matching_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4009 node: "sub_node".to_string(),
4010 data: serde_json::json!({}),
4011 ns: vec!["child_b".to_string()],
4012 };
4013
4014 let include_subgraphs = true;
4015 let subgraph_filter = Some(vec!["child_a".to_string()]);
4016
4017 let ns = matching_event.namespace();
4019 let should_skip = if ns.is_empty() {
4020 false
4021 } else if !include_subgraphs {
4022 true
4023 } else if let Some(ref filter) = subgraph_filter {
4024 ns.first().is_some_and(|first| !filter.contains(first))
4025 } else {
4026 false
4027 };
4028 assert!(!should_skip, "matching subgraph event should pass filter");
4029
4030 let ns = non_matching_event.namespace();
4032 let should_skip = if ns.is_empty() {
4033 false
4034 } else if !include_subgraphs {
4035 true
4036 } else if let Some(ref filter) = subgraph_filter {
4037 ns.first().is_some_and(|first| !filter.contains(first))
4038 } else {
4039 false
4040 };
4041 assert!(
4042 should_skip,
4043 "non-matching subgraph event should be filtered out"
4044 );
4045 }
4046
4047 #[test]
4050 fn test_subgraph_filter_applies_to_messages_events() {
4051 let subgraph_messages: StreamEvent<StateDummy> = StreamEvent::Messages {
4052 chunk: crate::stream::MessageChunk {
4053 content: "token".to_string(),
4054 tool_call_chunks: vec![],
4055 usage_delta: None,
4056 },
4057 metadata: crate::stream::MessageStreamMetadata {
4058 node: "llm".to_string(),
4059 model: "gpt-4".to_string(),
4060 tags: vec![],
4061 ns: vec!["sub_llm".to_string()],
4062 },
4063 };
4064
4065 let include_subgraphs = false;
4066 assert!(!subgraph_messages.namespace().is_empty());
4067
4068 let ns = subgraph_messages.namespace();
4069 let should_skip = !ns.is_empty() && !include_subgraphs;
4070 assert!(
4071 should_skip,
4072 "subgraph Messages events should be filtered when include_subgraphs=false"
4073 );
4074 }
4075
4076 #[test]
4079 fn test_subgraph_filter_applies_to_interrupt_events() {
4080 let subgraph_interrupt: StreamEvent<StateDummy> = StreamEvent::Interrupt {
4081 node: "review".to_string(),
4082 payload: serde_json::Value::Null,
4083 resumable: true,
4084 ns: vec!["sub_review".to_string()],
4085 };
4086
4087 let include_subgraphs = false;
4088 assert!(!subgraph_interrupt.namespace().is_empty());
4089
4090 let ns = subgraph_interrupt.namespace();
4091 let should_skip = !ns.is_empty() && !include_subgraphs;
4092 assert!(
4093 should_skip,
4094 "subgraph Interrupt events should be filtered when include_subgraphs=false"
4095 );
4096 }
4097
4098 #[test]
4103 fn test_nested_subgraph_default_excludes_nested_events() {
4104 let nested_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4106 node: "deep_node".to_string(),
4107 data: serde_json::json!({}),
4108 ns: vec!["parent".to_string(), "child".to_string()],
4109 };
4110
4111 let include_subgraphs = false;
4112
4113 assert_eq!(nested_event.namespace(), &["parent", "child"]);
4115 assert!(!nested_event.namespace().is_empty());
4116
4117 let should_skip = !nested_event.namespace().is_empty() && !include_subgraphs;
4118 assert!(
4119 should_skip,
4120 "nested subgraph events should be skipped when include_subgraphs=false"
4121 );
4122 }
4123
4124 #[test]
4127 fn test_nested_subgraph_include_all_passes() {
4128 let emitter_ns = vec!["parent".to_string(), "child".to_string()];
4131
4132 let nested_custom_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4134 node: "inner".to_string(),
4135 data: serde_json::json!({"k": "v"}),
4136 ns: emitter_ns,
4137 };
4138
4139 let include_subgraphs = true;
4140 let subgraph_filter: Option<Vec<String>> = None;
4141
4142 let should_skip = !nested_custom_event.namespace().is_empty() && !include_subgraphs;
4143 assert!(
4144 !should_skip,
4145 "nested subgraph events should pass when include_subgraphs=true"
4146 );
4147 assert!(subgraph_filter.is_none());
4148 }
4149
4150 #[test]
4153 fn test_nested_subgraph_filter_matches_outermost_name() {
4154 let nested_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4156 node: "deep".to_string(),
4157 data: serde_json::json!({}),
4158 ns: vec![
4159 "parent".to_string(),
4160 "child".to_string(),
4161 "grandchild".to_string(),
4162 ],
4163 };
4164
4165 let include_subgraphs = true;
4166 let subgraph_filter = Some(vec!["parent".to_string()]);
4168
4169 let ns = nested_event.namespace();
4170 let should_skip = if ns.is_empty() {
4171 false
4172 } else if !include_subgraphs {
4173 true
4174 } else if let Some(ref filter) = subgraph_filter {
4175 ns.first().is_some_and(|first| !filter.contains(first))
4176 } else {
4177 false
4178 };
4179
4180 assert!(
4181 !should_skip,
4182 "nested event from parent should pass when parent is in filter"
4183 );
4184
4185 let subgraph_filter_other = Some(vec!["other".to_string()]);
4187 let should_skip_other = if ns.is_empty() {
4188 false
4189 } else if !include_subgraphs {
4190 true
4191 } else if let Some(ref filter) = subgraph_filter_other {
4192 ns.first().is_some_and(|first| !filter.contains(first))
4193 } else {
4194 false
4195 };
4196
4197 assert!(
4198 should_skip_other,
4199 "nested event should be skipped when outermost name does not match filter"
4200 );
4201 }
4202
#[test]
fn test_nested_subgraph_messages_filtering() {
    // Messages events carry their namespace in `metadata.ns`; this one is
    // two subgraphs deep (outer -> inner).
    let metadata = crate::stream::MessageStreamMetadata {
        node: "llm".to_string(),
        model: "gpt-4".to_string(),
        tags: vec![],
        ns: vec!["outer".to_string(), "inner".to_string()],
    };
    let chunk = crate::stream::MessageChunk {
        content: "nested_token".to_string(),
        tool_call_chunks: vec![],
        usage_delta: None,
    };
    let nested_messages: StreamEvent<StateDummy> = StreamEvent::Messages { chunk, metadata };

    let namespace = nested_messages.namespace();
    assert_eq!(
        namespace,
        &["outer", "inner"],
        "Messages events should expose full nested namespace via metadata.ns"
    );
    let is_nested = !namespace.is_empty();

    // Subgraphs excluded: a nested event must be skipped.
    let include_subgraphs = false;
    assert!(
        is_nested && !include_subgraphs,
        "nested subgraph Messages events should be filtered when include_subgraphs=false"
    );

    // Subgraphs included: the same event must pass through.
    let include_subgraphs = true;
    assert!(
        !is_nested || include_subgraphs,
        "nested subgraph Messages events should pass when include_subgraphs=true"
    );
}
4244
#[test]
fn test_subgraph_transformer_to_emitter_nested_ns() {
    // Two-level transformer chain: child -> grandchild.
    let root = crate::SubgraphTransformer::new("child".to_string());
    let nested = root.child_transformer("grandchild");

    // The receiver is kept alive but never read; only the emitter's ns matters.
    let (sender, _receiver) = tokio::sync::mpsc::channel(16);
    let emitter = nested.to_emitter::<StateDummy>(sender, crate::stream::StreamMode::Values);

    // The emitter's namespace lists every level, outermost first.
    assert_eq!(emitter.ns(), &["child", "grandchild"]);
}
4258
#[test]
fn test_transformer_child_chain_three_levels() {
    use crate::stream::StreamEvent;

    // Chain of three transformers: grandparent -> parent -> child.
    let grandparent = crate::SubgraphTransformer::new("grandparent".to_string());
    let parent = grandparent.child_transformer("parent");
    let child = parent.child_transformer("child");

    // TaskStart: the node name gains the full path prefix.
    let task_start = StreamEvent::<StateDummy>::TaskStart {
        node: "worker".to_string(),
        task_id: "t1".to_string(),
        step: 1,
    };
    let transformed = child.transform(&task_start).expect("should pass filter");
    if let StreamEvent::TaskStart { node, .. } = &transformed {
        assert_eq!(node, "grandparent/parent/child/worker");
    } else {
        panic!("expected TaskStart, got {transformed:?}");
    }

    // Custom: both the node name and the namespace reflect every level.
    let custom = StreamEvent::<StateDummy>::Custom {
        node: "agent".to_string(),
        data: serde_json::json!({}),
        ns: vec![],
    };
    let transformed = child.transform(&custom).expect("custom should pass");
    if let StreamEvent::Custom { node, ns, .. } = &transformed {
        assert_eq!(node, "grandparent/parent/child/agent");
        assert_eq!(ns, &["grandparent", "parent", "child"]);
    } else {
        panic!("expected Custom, got {transformed:?}");
    }
}
4299
#[test]
fn test_stream_config_subgraph_builder_methods() {
    // A fresh config excludes subgraph events and carries no filter.
    let config = crate::stream::StreamConfig::new(StreamMode::Values);
    assert!(!config.include_subgraphs);
    assert!(config.subgraph_filter.is_none());

    // Opt in to subgraph events.
    let config = config.with_subgraphs(true);
    assert!(config.include_subgraphs);

    // Restrict to a single named subgraph; the name is stored as given.
    let config = config.with_subgraph_filter(vec!["sub_a".to_string()]);
    let filter = config.subgraph_filter.as_deref();
    assert_eq!(filter.map(<[String]>::len), Some(1));
    assert_eq!(
        filter.and_then(<[String]>::first).map(String::as_str),
        Some("sub_a")
    );
}
4320
4321 #[tokio::test]
4324 async fn test_stream_default_config_no_subgraph_events() {
4325 use futures::StreamExt;
4326
4327 let compiled = build_multi_field_graph();
4328 let config = RunnableConfig::new();
4329 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values);
4330
4331 let handle = compiled
4332 .stream_with_config(
4333 MultiFieldState {
4334 messages: vec![],
4335 count: 0,
4336 label: String::new(),
4337 },
4338 &config,
4339 stream_config,
4340 )
4341 .await
4342 .expect("stream_with_config should succeed");
4343
4344 let mut events = Vec::new();
4345 let mut stream = handle.stream;
4346 while let Some(result) = stream.next().await {
4347 events.push(result.expect("stream event should be Ok"));
4348 }
4349
4350 for event in &events {
4352 assert!(
4353 event.namespace().is_empty(),
4354 "Expected no subgraph events, but found one with ns: {:?}",
4355 event.namespace()
4356 );
4357 }
4358 }
4359
4360 fn build_two_step_graph() -> CompiledGraph<MultiFieldState> {
4366 let node_a = NodeFnUpdate(|s: &MultiFieldState| {
4367 let messages = s.messages.clone();
4368 let count = s.count;
4369 let label = s.label.clone();
4370 async move {
4371 Ok(MultiFieldStateUpdate {
4372 messages: Some(messages),
4373 count: Some(count + 1),
4374 label: Some(label),
4375 })
4376 }
4377 })
4378 .into_node("node_a");
4379
4380 let node_b = NodeFnUpdate(|s: &MultiFieldState| {
4381 let messages = s.messages.clone();
4382 let count = s.count;
4383 let label = s.label.clone();
4384 async move {
4385 Ok(MultiFieldStateUpdate {
4386 messages: Some(messages),
4387 count: Some(count + 10),
4388 label: Some(label),
4389 })
4390 }
4391 })
4392 .into_node("node_b");
4393
4394 let mut nodes: IndexMap<String, Arc<dyn crate::Node<MultiFieldState>>> = IndexMap::new();
4395 nodes.insert("node_a".to_string(), node_a);
4396 nodes.insert("node_b".to_string(), node_b);
4397
4398 let mut trigger_table = TriggerTable::new();
4399 trigger_table.add_incoming(
4401 "node_a".to_string(),
4402 crate::edge::TriggerSource::Edge {
4403 from: crate::edge::START.to_string(),
4404 },
4405 );
4406 trigger_table.add_outgoing(
4408 "node_a".to_string(),
4409 crate::edge::CompiledEdge::Fixed {
4410 target: "node_b".to_string(),
4411 },
4412 );
4413
4414 CompiledGraph::new(
4415 nodes,
4416 trigger_table,
4417 IndexMap::new(),
4418 vec![],
4419 vec![],
4420 None,
4421 vec![],
4422 )
4423 }
4424
4425 #[tokio::test]
4426 async fn test_resumption_skips_values_at_or_before_last_step() {
4427 use futures::StreamExt;
4428
4429 let compiled = build_two_step_graph();
4430 let config = RunnableConfig::new();
4431
4432 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, Some(0));
4433 let stream_config =
4434 crate::stream::StreamConfig::new(StreamMode::Values).with_resumption(resumption);
4435
4436 let handle = compiled
4437 .stream_with_config(
4438 MultiFieldState {
4439 messages: vec![],
4440 count: 0,
4441 label: String::new(),
4442 },
4443 &config,
4444 stream_config,
4445 )
4446 .await
4447 .expect("stream_with_config should succeed");
4448
4449 let mut events = Vec::new();
4450 let mut stream = handle.stream;
4451 while let Some(result) = stream.next().await {
4452 events.push(result.expect("stream event should be Ok"));
4453 }
4454
4455 let values_steps: Vec<usize> = events
4457 .iter()
4458 .filter_map(|e| match e {
4459 crate::stream::StreamEvent::Values { step, .. } => Some(*step),
4460 _ => None,
4461 })
4462 .collect();
4463
4464 assert!(
4465 !values_steps.contains(&0),
4466 "Values at step 0 should be skipped, got steps: {values_steps:?}"
4467 );
4468
4469 let has_end = events
4470 .iter()
4471 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4472 assert!(has_end, "End event must always be emitted");
4473 }
4474
4475 #[tokio::test]
4476 async fn test_resumption_allows_values_after_last_step() {
4477 use futures::StreamExt;
4478
4479 let compiled = build_two_step_graph();
4480 let config = RunnableConfig::new();
4481
4482 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, Some(5));
4483 let stream_config =
4484 crate::stream::StreamConfig::new(StreamMode::Values).with_resumption(resumption);
4485
4486 let handle = compiled
4487 .stream_with_config(
4488 MultiFieldState {
4489 messages: vec![],
4490 count: 0,
4491 label: String::new(),
4492 },
4493 &config,
4494 stream_config,
4495 )
4496 .await
4497 .expect("stream_with_config should succeed");
4498
4499 let mut events = Vec::new();
4500 let mut stream = handle.stream;
4501 while let Some(result) = stream.next().await {
4502 events.push(result.expect("stream event should be Ok"));
4503 }
4504
4505 let values_steps: Vec<usize> = events
4507 .iter()
4508 .filter_map(|e| match e {
4509 crate::stream::StreamEvent::Values { step, .. } => Some(*step),
4510 _ => None,
4511 })
4512 .collect();
4513
4514 assert!(
4515 values_steps.is_empty(),
4516 "All Values should be skipped with last_step=5, got steps: {values_steps:?}"
4517 );
4518
4519 let has_end = events
4520 .iter()
4521 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4522 assert!(
4523 has_end,
4524 "End event must always be emitted even when all steps are skipped"
4525 );
4526 }
4527
4528 #[tokio::test]
4529 async fn test_resumption_none_last_step_allows_all_events() {
4530 use futures::StreamExt;
4531
4532 let compiled = build_two_step_graph();
4533 let config = RunnableConfig::new();
4534
4535 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, None);
4536 let stream_config =
4537 crate::stream::StreamConfig::new(StreamMode::Values).with_resumption(resumption);
4538
4539 let handle = compiled
4540 .stream_with_config(
4541 MultiFieldState {
4542 messages: vec![],
4543 count: 0,
4544 label: String::new(),
4545 },
4546 &config,
4547 stream_config,
4548 )
4549 .await
4550 .expect("stream_with_config should succeed");
4551
4552 let mut events = Vec::new();
4553 let mut stream = handle.stream;
4554 while let Some(result) = stream.next().await {
4555 events.push(result.expect("stream event should be Ok"));
4556 }
4557
4558 assert!(
4560 events
4561 .iter()
4562 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. })),
4563 "Values events should be emitted when last_step is None"
4564 );
4565
4566 let has_end = events
4567 .iter()
4568 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4569 assert!(has_end, "End event must be present");
4570 }
4571
4572 #[tokio::test]
4573 async fn test_resumption_skips_updates_at_or_before_last_step() {
4574 use futures::StreamExt;
4575
4576 let compiled = build_two_step_graph();
4577 let config = RunnableConfig::new();
4578
4579 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, Some(0));
4580 let stream_config =
4581 crate::stream::StreamConfig::new(StreamMode::Updates).with_resumption(resumption);
4582
4583 let handle = compiled
4584 .stream_with_config(
4585 MultiFieldState {
4586 messages: vec![],
4587 count: 0,
4588 label: String::new(),
4589 },
4590 &config,
4591 stream_config,
4592 )
4593 .await
4594 .expect("stream_with_config should succeed");
4595
4596 let mut events = Vec::new();
4597 let mut stream = handle.stream;
4598 while let Some(result) = stream.next().await {
4599 events.push(result.expect("stream event should be Ok"));
4600 }
4601
4602 let updates_steps: Vec<usize> = events
4604 .iter()
4605 .filter_map(|e| match e {
4606 crate::stream::StreamEvent::Updates { step, .. }
4607 | crate::stream::StreamEvent::FilteredUpdates { step, .. } => Some(*step),
4608 _ => None,
4609 })
4610 .collect();
4611
4612 assert!(
4613 !updates_steps.contains(&0),
4614 "Updates at step 0 should be skipped, got steps: {updates_steps:?}"
4615 );
4616
4617 let has_end = events
4618 .iter()
4619 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4620 assert!(has_end, "End event must always be emitted");
4621 }
4622
4623 #[tokio::test]
4624 async fn test_resumption_no_resumption_emits_all_events() {
4625 use futures::StreamExt;
4626
4627 let compiled = build_two_step_graph();
4628 let config = RunnableConfig::new();
4629
4630 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values);
4632
4633 let handle = compiled
4634 .stream_with_config(
4635 MultiFieldState {
4636 messages: vec![],
4637 count: 0,
4638 label: String::new(),
4639 },
4640 &config,
4641 stream_config,
4642 )
4643 .await
4644 .expect("stream_with_config should succeed");
4645
4646 let mut events = Vec::new();
4647 let mut stream = handle.stream;
4648 while let Some(result) = stream.next().await {
4649 events.push(result.expect("stream event should be Ok"));
4650 }
4651
4652 let values_count = events
4653 .iter()
4654 .filter(|e| matches!(e, crate::stream::StreamEvent::Values { .. }))
4655 .count();
4656
4657 assert!(
4658 values_count >= 1,
4659 "At least one Values event expected without resumption"
4660 );
4661
4662 let has_end = events
4663 .iter()
4664 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4665 assert!(has_end, "End event must be present");
4666 }
4667}
4668
4669