1use super::builder::NodeMetadata;
8use crate::info_span;
9#[cfg(target_family = "wasm")]
10use crate::tracing_wasm::WasmInstrument;
11use crate::{
12 JunctureError, State,
13 checkpoint::{
14 Checkpoint, CheckpointFilter, CheckpointMetadata, CheckpointSource, StateSnapshot,
15 },
16 config::RunnableConfig,
17 edge::TriggerTable,
18 interrupt::ResumeValue,
19 pregel::{BudgetTracker, PregelLoop},
20 state::{FromState, IntoState},
21 stream::{EventEmitter, StreamEvent, StreamMode},
22};
23use futures::Stream;
24use indexmap::IndexMap;
25use std::{pin::Pin, sync::Arc};
26use tokio::sync::mpsc;
27#[cfg(not(target_family = "wasm"))]
28use tracing::Instrument;
29
/// Capacity of the bounded consumer-facing event channel when `Messages`
/// streaming is requested, either directly or inside `StreamMode::Multi`
/// (see `stream_capacity`). NOTE(review): presumably larger because message
/// streaming emits many small events — confirm.
const CHANNEL_CAPACITY_MESSAGES: usize = 256;

/// Capacity of the bounded consumer-facing event channel for every other
/// stream mode (see `stream_capacity`).
const CHANNEL_CAPACITY_DEFAULT: usize = 32;
43
44fn stream_capacity(mode: &StreamMode) -> usize {
50 match mode {
51 StreamMode::Messages => CHANNEL_CAPACITY_MESSAGES,
52 StreamMode::Multi(modes) if modes.iter().any(|m| matches!(m, StreamMode::Messages)) => {
53 CHANNEL_CAPACITY_MESSAGES
54 }
55 _ => CHANNEL_CAPACITY_DEFAULT,
56 }
57}
58
/// Handle to a streaming graph run: the run's identifier plus the stream of
/// events the background execution produces.
pub struct StreamHandle<S: State> {
    /// Identifier of the run that produces `stream`.
    run_id: String,
    /// Events of the run. Failures that happen while the run executes are
    /// delivered as `Err` items on this stream.
    pub stream: Pin<Box<dyn Stream<Item = Result<StreamEvent<S>, JunctureError>> + Send>>,
}
86
87impl<S: State> std::fmt::Debug for StreamHandle<S> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("StreamHandle")
90 .field("run_id", &self.run_id)
91 .field("stream", &"<stream>")
92 .finish()
93 }
94}
95
96impl<S: State> StreamHandle<S> {
97 #[must_use]
99 pub fn run_id(&self) -> &str {
100 &self.run_id
101 }
102
103 #[must_use]
105 #[allow(
106 clippy::type_complexity,
107 reason = "return type mirrors StreamHandle fields"
108 )]
109 pub fn into_parts(
110 self,
111 ) -> (
112 String,
113 Pin<Box<dyn Stream<Item = Result<StreamEvent<S>, JunctureError>> + Send>>,
114 ) {
115 (self.run_id, self.stream)
116 }
117}
118
119#[derive(Clone)]
137pub struct CompiledGraph<S: State, I: IntoState<S> = S, O: FromState<S> = S> {
138 inner: Arc<CompiledGraphInner<S>>,
139 _input: std::marker::PhantomData<I>,
140 _output: std::marker::PhantomData<O>,
141}
142
143impl<S: State, I: IntoState<S>, O: FromState<S>> std::fmt::Debug for CompiledGraph<S, I, O> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("CompiledGraph")
146 .field("node_count", &self.inner.nodes.len())
147 .field("has_checkpointer", &self.inner.checkpointer.is_some())
148 .finish()
149 }
150}
151
152impl<S: State, I: IntoState<S>, O: FromState<S>> CompiledGraph<S, I, O> {
153 #[must_use]
155 pub(crate) fn new(
156 nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,
157 trigger_table: TriggerTable<S>,
158 builder_metadata: IndexMap<String, NodeMetadata>,
159 interrupt_before: Vec<String>,
160 interrupt_after: Vec<String>,
161 checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,
162 subgraphs: Vec<SubgraphInfo>,
163 ) -> Self {
164 Self {
165 inner: Arc::new(CompiledGraphInner {
166 nodes,
167 trigger_table,
168 builder_metadata,
169 checkpointer,
170 interrupt_before,
171 interrupt_after,
172 subgraphs,
173 active_invocations: std::sync::atomic::AtomicU64::new(0),
174 }),
175 _input: std::marker::PhantomData,
176 _output: std::marker::PhantomData,
177 }
178 }
179
180 fn build_error_handler_map(&self) -> std::collections::HashMap<String, String> {
185 self.inner
186 .builder_metadata
187 .iter()
188 .filter_map(|(node_name, meta)| {
189 meta.error_handler
190 .as_ref()
191 .map(|handler| (node_name.clone(), handler.clone()))
192 })
193 .collect()
194 }
195
196 fn build_retry_policy_map(
205 &self,
206 ) -> std::collections::HashMap<String, super::builder::RetryPolicy> {
207 self.inner
208 .builder_metadata
209 .iter()
210 .filter_map(|(node_name, meta)| {
211 meta.retry_policies
212 .first()
213 .map(|policy| (node_name.clone(), policy.clone()))
214 })
215 .collect()
216 }
217
218 fn build_timeout_policy_map(&self) -> std::collections::HashMap<String, crate::TimeoutPolicy> {
226 self.inner
227 .builder_metadata
228 .iter()
229 .filter_map(|(node_name, meta)| {
230 meta.timeout_policies
231 .first()
232 .cloned()
233 .map(|policy| (node_name.clone(), policy))
234 })
235 .collect()
236 }
237
238 fn build_circuit_breaker_map(
245 &self,
246 ) -> std::collections::HashMap<String, super::builder::CircuitBreakerConfig> {
247 self.inner
248 .builder_metadata
249 .iter()
250 .filter_map(|(node_name, meta)| {
251 meta.circuit_breaker
252 .as_ref()
253 .map(|config| (node_name.clone(), config.clone()))
254 })
255 .collect()
256 }
257
258 fn build_fallback_map(&self) -> std::collections::HashMap<String, String> {
265 self.inner
266 .builder_metadata
267 .iter()
268 .filter_map(|(node_name, meta)| {
269 meta.fallback_node
270 .as_ref()
271 .map(|fallback| (node_name.clone(), fallback.clone()))
272 })
273 .collect()
274 }
275
276 fn effective_config(&self, config: &RunnableConfig) -> RunnableConfig {
282 let mut effective = config.clone();
283 if effective.interrupt_before.is_none() && !self.inner.interrupt_before.is_empty() {
284 effective.interrupt_before = Some(self.inner.interrupt_before.clone());
285 }
286 if effective.interrupt_after.is_none() && !self.inner.interrupt_after.is_empty() {
287 effective.interrupt_after = Some(self.inner.interrupt_after.clone());
288 }
289 effective
290 }
291
292 fn deserialize_with_migration(
297 checkpoint: &crate::checkpoint::Checkpoint,
298 ) -> Result<S, JunctureError>
299 where
300 S: serde::de::DeserializeOwned,
301 {
302 let mut channel_values = checkpoint.channel_values.clone();
303 let checkpoint_version = checkpoint.schema_version;
304 let current_version = S::schema_version();
305 if checkpoint_version != current_version {
306 channel_values = S::migrate(checkpoint_version, channel_values);
307 }
308 serde_json::from_value(channel_values)
309 .map_err(|e| JunctureError::checkpoint(format!("failed to deserialize state: {e}")))
310 }
311
312 pub fn invoke(
327 &self,
328 input: I,
329 config: &RunnableConfig,
330 ) -> Result<GraphOutput<S, O>, JunctureError>
331 where
332 S: serde::de::DeserializeOwned + serde::Serialize,
333 S::Update: serde::Serialize,
334 O: FromState<S>,
335 {
336 let effective = self.effective_config(config);
337
338 let runtime = {
341 #[cfg(feature = "multi-thread")]
342 {
343 tokio::runtime::Runtime::new().map_err(|e| {
344 JunctureError::execution(format!("Failed to create runtime: {e}"))
345 })?
346 }
347 #[cfg(not(feature = "multi-thread"))]
348 {
349 tokio::runtime::Builder::new_current_thread()
350 .enable_all()
351 .build()
352 .map_err(|e| {
353 JunctureError::execution(format!("Failed to create runtime: {e}"))
354 })?
355 }
356 };
357
358 runtime.block_on(self.invoke_async_inner(input, &effective))
359 }
360
361 pub async fn invoke_async(
376 &self,
377 input: I,
378 config: &RunnableConfig,
379 ) -> Result<GraphOutput<S, O>, JunctureError>
380 where
381 S: serde::de::DeserializeOwned + serde::Serialize,
382 S::Update: serde::Serialize,
383 O: FromState<S>,
384 {
385 let effective = self.effective_config(config);
386 self.invoke_async_inner(input, &effective).await
387 }
388
389 #[expect(
391 clippy::too_many_lines,
392 reason = "invoke_async_inner orchestrates the full execution lifecycle: metadata extraction, PregelLoop creation, budget wiring, span creation, the tick/execute/after_tick loop, error handling, and metric emission. Splitting would scatter the linear flow across helper methods without improving readability."
393 )]
394 async fn invoke_async_inner(
395 &self,
396 input: I,
397 config: &RunnableConfig,
398 ) -> Result<GraphOutput<S, O>, JunctureError>
399 where
400 S: serde::de::DeserializeOwned + serde::Serialize,
401 S::Update: serde::Serialize,
402 O: FromState<S>,
403 {
404 let num_fields = 64;
406
407 let error_handler_map = self.build_error_handler_map();
409
410 let retry_policy_map = self.build_retry_policy_map();
412
413 let timeout_policy_map = self.build_timeout_policy_map();
415
416 let circuit_breaker_map = self.build_circuit_breaker_map();
418
419 let fallback_map = self.build_fallback_map();
421
422 let state_input = input.into_state();
424
425 let mut pregel = PregelLoop::with_error_handlers(
427 state_input,
428 self.inner.nodes.clone(),
429 self.inner.trigger_table.clone(),
430 config.clone(),
431 num_fields,
432 error_handler_map,
433 )?;
434
435 pregel.set_retry_policies(retry_policy_map);
436 pregel.set_timeout_policies(timeout_policy_map);
437 pregel.set_circuit_breaker_policies(circuit_breaker_map);
438 pregel.set_fallback_map(fallback_map);
439
440 if let Some(budget_config) = &pregel.runnable_config.budget {
442 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
443 pregel.set_budget_tracker(
444 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
445 );
446 }
447
448 let graph_name = config
451 .graph_name
452 .clone()
453 .unwrap_or_else(|| "unnamed".to_string());
454 let run_id = pregel.run_id().to_string();
455 let recursion_limit = pregel.runnable_config.recursion_limit;
456
457 async move {
458 let graph_start = crate::time::Instant::now();
459
460 if let Some(ref collector) = config.metrics_collector {
462 collector.inc_counter("juncture.graph.invocations", 1);
463
464 let active = self
465 .inner
466 .active_invocations
467 .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
468 + 1;
469 collector.set_gauge("juncture.graph.active_invocations", active);
470 }
471
472 let execution_result = async {
474 while pregel.tick()? {
475 let result = pregel.execute_superstep().await?;
476 pregel.after_tick(result).await?;
477 }
478 Ok::<(), JunctureError>(())
479 }
480 .await;
481
482 if let Some(ref collector) = config.metrics_collector {
484 let active = self
485 .inner
486 .active_invocations
487 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed)
488 - 1;
489 collector.set_gauge("juncture.graph.active_invocations", active);
490 }
491
492 let execution_result = match execution_result {
494 Ok(()) => Ok(()),
495 Err(e) => {
496 if let Some(ref collector) = config.metrics_collector {
498 collector.inc_counter("juncture.graph.errors", 1);
499 }
500 Err(e)
501 }
502 };
503
504 let steps = pregel.step();
506 let run_id = pregel.run_id().to_string();
507
508 let final_state = pregel.into_state();
510 let output = O::from_state(&final_state);
511
512 if let Some(ref collector) = config.metrics_collector {
514 #[allow(
515 clippy::cast_precision_loss,
516 reason = "Milliseconds as f64 is sufficient for histogram metrics; sub-millisecond precision is not required for graph duration tracking"
517 )]
518 collector.record_histogram(
519 "juncture.graph.duration_ms",
520 graph_start.elapsed().as_millis() as f64,
521 );
522 }
523
524 execution_result?;
525
526 Ok(GraphOutput {
527 value: final_state,
528 output,
529 interrupts: Vec::new(),
530 metadata: GraphOutputMetadata {
531 steps,
532 run_id,
533 checkpoint_id: config.checkpoint_id.clone(),
534 budget_usage: None,
535 },
536 })
537 }
538 .instrument(info_span!(
539 "juncture.graph.invoke",
540 "juncture.graph.name" = graph_name,
541 "juncture.run.id" = %run_id,
542 "juncture.recursion.limit" = recursion_limit,
543 ))
544 .await
545 }
546
547 pub async fn stream(
595 &self,
596 input: I,
597 config: &RunnableConfig,
598 mode: StreamMode,
599 ) -> Result<StreamHandle<S>, JunctureError>
600 where
601 S: Clone + Send + serde::de::DeserializeOwned + serde::Serialize + 'static,
602 S::Update: serde::Serialize,
603 {
604 self.stream_with_config(input, config, crate::stream::StreamConfig::new(mode))
605 .await
606 }
607
608 #[allow(
665 clippy::too_many_lines,
666 reason = "stream orchestration: channel setup, PregelLoop wiring, output_keys filtering, and event forwarding are inseparable"
667 )]
668 #[expect(
669 clippy::unused_async,
670 reason = "function signature follows async convention for consistency with invoke_async"
671 )]
672 pub async fn stream_with_config(
673 &self,
674 input: I,
675 config: &RunnableConfig,
676 stream_config: crate::stream::StreamConfig,
677 ) -> Result<StreamHandle<S>, JunctureError>
678 where
679 S: Clone + Send + serde::de::DeserializeOwned + serde::Serialize + 'static,
680 S::Update: serde::Serialize,
681 {
682 use futures::stream;
683
684 let effective = self.effective_config(config);
685 let num_fields = 64;
686 let mode = stream_config.mode.clone();
687 let output_keys = stream_config.output_keys;
688 let include_subgraphs = stream_config.include_subgraphs;
689 let subgraph_filter = stream_config.subgraph_filter;
690 let resumption = stream_config.resumption;
691
692 let capacity = stream_capacity(&mode);
695 let (tx, rx) = mpsc::channel(capacity);
696
697 let error_handler_map = self.build_error_handler_map();
699
700 let retry_policy_map = self.build_retry_policy_map();
702
703 let timeout_policy_map = self.build_timeout_policy_map();
705
706 let graph_name = effective
708 .graph_name
709 .clone()
710 .unwrap_or_else(|| "unnamed".to_string());
711
712 let state_input = input.into_state();
714 let mut pregel = PregelLoop::with_error_handlers(
715 state_input,
716 self.inner.nodes.clone(),
717 self.inner.trigger_table.clone(),
718 effective,
719 num_fields,
720 error_handler_map,
721 )?;
722
723 pregel.set_retry_policies(retry_policy_map);
724 pregel.set_timeout_policies(timeout_policy_map);
725
726 if let Some(budget_config) = &pregel.runnable_config.budget {
728 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
729 pregel.set_budget_tracker(
730 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
731 );
732 }
733
734 let run_id = pregel.run_id().to_string();
736 let recursion_limit = pregel.runnable_config.recursion_limit;
737
738 let (pregel_tx, mut pregel_rx) = mpsc::unbounded_channel();
743 pregel.set_stream_sender(pregel_tx);
744
745 tokio::spawn(
747 async move {
748 let tx_forward = tx.clone();
751 let mode_forward = mode.clone();
752 let output_keys_forward = output_keys.clone();
753 let resumption_forward = resumption.clone();
754 tokio::spawn(async move {
755 let (temp_tx, _temp_rx) = mpsc::channel(1);
757 let emitter = EventEmitter::new(temp_tx, mode_forward);
758
759 while let Some(event) = pregel_rx.recv().await {
760 if !emitter.should_emit(&event) {
761 continue;
762 }
763
764 let ns = event.namespace();
767 if !ns.is_empty() {
768 if !include_subgraphs {
769 continue;
770 }
771 if let Some(ref filter) = subgraph_filter
772 && let Some(first) = ns.first()
773 && !filter.contains(first)
774 {
775 continue;
776 }
777 }
778
779 if let Some(ref r) = resumption_forward {
783 let step = match &event {
784 StreamEvent::Values { step, .. }
785 | StreamEvent::FilteredValues { step, .. }
786 | StreamEvent::Updates { step, .. }
787 | StreamEvent::FilteredUpdates { step, .. } => Some(*step),
788 _ => None,
789 };
790 if let Some(s) = step
791 && r.should_skip(s)
792 {
793 continue;
794 }
795 }
796
797 let filtered = output_keys_forward.as_ref().and_then(|keys| match &event {
799 StreamEvent::Updates { node, update, step } => {
800 serde_json::to_value(update).ok().map(|json| {
801 StreamEvent::FilteredUpdates {
802 node: node.clone(),
803 data: crate::stream::filter_json_by_keys(json, keys),
804 step: *step,
805 }
806 })
807 }
808 _ => None,
809 });
810
811 if let Some(filtered_event) = filtered {
812 let _ = tx_forward.send(Ok(filtered_event)).await;
813 } else {
814 let _ = tx_forward.send(Ok(event)).await;
815 }
816 }
817 });
818
819 while matches!(pregel.tick(), Ok(true)) {
821 let step = pregel.step();
822
823 if matches!(mode, StreamMode::Values) {
826 let skip = resumption.as_ref().is_some_and(|r| r.should_skip(step));
827
828 if !skip {
829 let event = output_keys.as_ref().map_or_else(
830 || StreamEvent::Values {
831 state: pregel.snapshot_state(),
832 step,
833 },
834 |keys| {
835 let json = serde_json::to_value(pregel.snapshot_state())
836 .unwrap_or(serde_json::Value::Null);
837 StreamEvent::FilteredValues {
838 data: crate::stream::filter_json_by_keys(json, keys),
839 step,
840 }
841 },
842 );
843 let _ = tx.send(Ok(event)).await;
844 }
845 }
846
847 match pregel.execute_superstep().await {
849 Ok(result) => {
850 if let Err(e) = pregel.after_tick(result).await {
852 let _ = tx
854 .send(Ok(StreamEvent::End {
855 output: pregel.snapshot_state(),
856 }))
857 .await;
858 let _ = tx.send(Err(e)).await;
860 return;
861 }
862 }
863 Err(e) => {
864 let _ = tx
866 .send(Ok(StreamEvent::End {
867 output: pregel.snapshot_state(),
868 }))
869 .await;
870 let _ = tx.send(Err(e)).await;
872 return;
873 }
874 }
875 }
876
877 let final_state = pregel.into_state();
879 let _ = tx
880 .send(Ok(StreamEvent::End {
881 output: final_state,
882 }))
883 .await;
884 }
885 .instrument(info_span!(
886 "juncture.graph.invoke",
887 "juncture.graph.name" = graph_name,
888 "juncture.run.id" = %run_id,
889 "juncture.recursion.limit" = recursion_limit,
890 )),
891 );
892
893 Ok(StreamHandle {
895 run_id,
896 stream: Box::pin(stream::unfold(rx, |mut rx| async move {
897 rx.recv().await.map(|item| (item, rx))
898 })),
899 })
900 }
901
902 pub async fn execute_with_emitter(
944 &self,
945 input: S,
946 config: &RunnableConfig,
947 emitter: EventEmitter<S>,
948 ) -> Result<S, JunctureError>
949 where
950 S: Clone + Send + serde::de::DeserializeOwned + serde::Serialize + 'static,
951 S::Update: serde::Serialize,
952 {
953 let num_fields = 64;
954
955 let mut exec_config = self.effective_config(config);
957 if exec_config.run_id.is_none() {
959 exec_config.run_id = Some(uuid::Uuid::new_v4().to_string());
960 }
961
962 let graph_name = exec_config
964 .graph_name
965 .clone()
966 .unwrap_or_else(|| "unnamed".to_string());
967
968 let error_handler_map = self.build_error_handler_map();
969 let retry_policy_map = self.build_retry_policy_map();
970 let timeout_policy_map = self.build_timeout_policy_map();
971
972 let mut pregel = PregelLoop::with_error_handlers(
973 input,
974 self.inner.nodes.clone(),
975 self.inner.trigger_table.clone(),
976 exec_config,
977 num_fields,
978 error_handler_map,
979 )?;
980
981 pregel.set_retry_policies(retry_policy_map);
982 pregel.set_timeout_policies(timeout_policy_map);
983
984 if let Some(budget_config) = &pregel.runnable_config.budget {
986 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
987 pregel.set_budget_tracker(
988 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
989 );
990 }
991
992 if let Some(cp) = self.inner.checkpointer.clone() {
993 pregel.set_checkpointer(cp);
994 }
995
996 let mode = emitter.mode().clone();
997 let run_id = pregel.run_id().to_string();
998 let recursion_limit = pregel.runnable_config.recursion_limit;
999
1000 let (pregel_tx, mut pregel_rx) = mpsc::unbounded_channel();
1002 pregel.set_stream_sender(pregel_tx);
1003
1004 let emitter_clone = emitter.clone();
1006 tokio::spawn(async move {
1007 while let Some(event) = pregel_rx.recv().await {
1008 if emitter_clone.should_emit(&event) {
1009 emitter_clone.emit(event).await;
1010 }
1011 }
1012 });
1013
1014 async move {
1015 while pregel.tick()? {
1017 let step = pregel.step();
1018
1019 if matches!(mode, StreamMode::Values) {
1020 let event = StreamEvent::Values {
1021 state: pregel.snapshot_state(),
1022 step,
1023 };
1024 emitter.emit(event).await;
1025 }
1026
1027 let result = pregel.execute_superstep().await?;
1028 pregel.after_tick(result).await?;
1029 }
1030
1031 let final_state = pregel.into_state();
1033 emitter
1034 .emit(StreamEvent::End {
1035 output: final_state.clone(),
1036 })
1037 .await;
1038
1039 Ok(final_state)
1040 }
1041 .instrument(info_span!(
1042 "juncture.graph.invoke",
1043 "juncture.graph.name" = graph_name,
1044 "juncture.run.id" = %run_id,
1045 "juncture.recursion.limit" = recursion_limit,
1046 ))
1047 .await
1048 }
1049
1050 pub async fn resume(
1095 &self,
1096 config: &RunnableConfig,
1097 resume_value: ResumeValue,
1098 ) -> Result<GraphOutput<S, O>, JunctureError>
1099 where
1100 S: for<'de> serde::Deserialize<'de> + serde::Serialize,
1101 S::Update: serde::Serialize,
1102 O: FromState<S>,
1103 {
1104 let checkpointer =
1105 self.inner.checkpointer.as_ref().ok_or_else(|| {
1106 JunctureError::checkpoint("no checkpointer configured for resume")
1107 })?;
1108
1109 let tuple = checkpointer
1111 .get_tuple(config)
1112 .await
1113 .map_err(|e| JunctureError::checkpoint(format!("failed to load checkpoint: {e}")))?
1114 .ok_or_else(|| {
1115 JunctureError::checkpoint(format!(
1116 "checkpoint not found: thread_id={:?}, checkpoint_id={:?}",
1117 config.thread_id, config.checkpoint_id
1118 ))
1119 })?;
1120
1121 if !matches!(tuple.metadata.source, CheckpointSource::Interrupt { .. }) {
1124 return Err(JunctureError::checkpoint(format!(
1125 "resume() requires checkpoint from Interrupt source, got {:?}",
1126 tuple.metadata.source
1127 )));
1128 }
1129
1130 let state = Self::deserialize_with_migration(&tuple.checkpoint)?;
1132
1133 let mut resume_config = self.effective_config(config);
1135 resume_config.resume_value = Some(resume_value);
1136 if resume_config.run_id.is_none() {
1137 resume_config.run_id = Some(uuid::Uuid::new_v4().to_string());
1138 }
1139
1140 let graph_name = resume_config
1142 .graph_name
1143 .clone()
1144 .unwrap_or_else(|| "unnamed".to_string());
1145
1146 let num_fields = 64; let error_handler_map = self.build_error_handler_map();
1149 let retry_policy_map = self.build_retry_policy_map();
1150 let timeout_policy_map = self.build_timeout_policy_map();
1151 let mut pregel = crate::pregel::PregelLoop::with_error_handlers(
1152 state,
1153 self.inner.nodes.clone(),
1154 self.inner.trigger_table.clone(),
1155 resume_config,
1156 num_fields,
1157 error_handler_map,
1158 )?;
1159
1160 pregel.set_retry_policies(retry_policy_map);
1161 pregel.set_timeout_policies(timeout_policy_map);
1162
1163 if let Some(budget_config) = &pregel.runnable_config.budget {
1165 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
1166 pregel.set_budget_tracker(
1167 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
1168 );
1169 }
1170
1171 if let Some(cp) = self.inner.checkpointer.clone() {
1173 pregel.set_checkpointer(cp);
1174 }
1175
1176 let run_id = pregel.run_id().to_string();
1177 let recursion_limit = pregel.runnable_config.recursion_limit;
1178
1179 async move {
1180 while pregel.tick()? {
1182 let result = pregel.execute_superstep().await?;
1183 pregel.after_tick(result).await?;
1184 }
1185
1186 let steps = pregel.step();
1188 let run_id = pregel.run_id().to_string();
1189
1190 let final_state = pregel.into_state();
1192 let output = O::from_state(&final_state);
1193
1194 Ok(GraphOutput {
1195 value: final_state,
1196 output,
1197 interrupts: Vec::new(),
1198 metadata: GraphOutputMetadata {
1199 steps,
1200 run_id,
1201 checkpoint_id: config.checkpoint_id.clone(),
1202 budget_usage: None,
1203 },
1204 })
1205 }
1206 .instrument(info_span!(
1207 "juncture.graph.invoke",
1208 "juncture.graph.name" = graph_name,
1209 "juncture.run.id" = %run_id,
1210 "juncture.recursion.limit" = recursion_limit,
1211 ))
1212 .await
1213 }
1214
1215 pub async fn resume_single(
1240 &self,
1241 config: &RunnableConfig,
1242 value: serde_json::Value,
1243 ) -> Result<GraphOutput<S, O>, JunctureError>
1244 where
1245 S: for<'de> serde::Deserialize<'de> + serde::Serialize,
1246 S::Update: serde::Serialize,
1247 O: FromState<S>,
1248 {
1249 self.resume(config, ResumeValue::Single(value)).await
1250 }
1251
1252 pub async fn resume_stream(
1309 &self,
1310 config: &RunnableConfig,
1311 resume_value: ResumeValue,
1312 mode: StreamMode,
1313 ) -> Result<StreamHandle<S>, JunctureError>
1314 where
1315 S: Clone + Send + for<'de> serde::Deserialize<'de> + serde::Serialize + 'static,
1316 S::Update: serde::Serialize,
1317 {
1318 use futures::stream;
1319
1320 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1321 JunctureError::checkpoint("no checkpointer configured for resume_stream")
1322 })?;
1323
1324 let tuple = checkpointer
1326 .get_tuple(config)
1327 .await
1328 .map_err(|e| JunctureError::checkpoint(format!("failed to load checkpoint: {e}")))?
1329 .ok_or_else(|| {
1330 JunctureError::checkpoint(format!(
1331 "checkpoint not found: thread_id={:?}, checkpoint_id={:?}",
1332 config.thread_id, config.checkpoint_id
1333 ))
1334 })?;
1335
1336 if !matches!(tuple.metadata.source, CheckpointSource::Interrupt { .. }) {
1338 return Err(JunctureError::checkpoint(format!(
1339 "resume_stream() requires checkpoint from Interrupt source, got {:?}",
1340 tuple.metadata.source
1341 )));
1342 }
1343
1344 let state = Self::deserialize_with_migration(&tuple.checkpoint)?;
1346
1347 let mut resume_config = self.effective_config(config);
1349 resume_config.resume_value = Some(resume_value);
1350 if resume_config.run_id.is_none() {
1351 resume_config.run_id = Some(uuid::Uuid::new_v4().to_string());
1352 }
1353
1354 let num_fields = 64;
1356 let error_handler_map = self.build_error_handler_map();
1357 let retry_policy_map = self.build_retry_policy_map();
1358 let timeout_policy_map = self.build_timeout_policy_map();
1359 let mut pregel = PregelLoop::with_error_handlers(
1360 state,
1361 self.inner.nodes.clone(),
1362 self.inner.trigger_table.clone(),
1363 resume_config,
1364 num_fields,
1365 error_handler_map,
1366 )?;
1367
1368 pregel.set_retry_policies(retry_policy_map);
1369 pregel.set_timeout_policies(timeout_policy_map);
1370
1371 if let Some(budget_config) = &pregel.runnable_config.budget {
1373 let metrics_collector = pregel.runnable_config.metrics_collector.clone();
1374 pregel.set_budget_tracker(
1375 BudgetTracker::new(budget_config.clone()).with_metrics_collector(metrics_collector),
1376 );
1377 }
1378
1379 if let Some(cp) = self.inner.checkpointer.clone() {
1381 pregel.set_checkpointer(cp);
1382 }
1383
1384 let (_handle, rx, run_id) = Self::spawn_streaming_loop(pregel, mode);
1385
1386 Ok(StreamHandle {
1388 run_id,
1389 stream: Box::pin(stream::unfold(rx, |mut receiver| async move {
1390 receiver.recv().await.map(|item| (item, receiver))
1391 })),
1392 })
1393 }
1394
1395 #[allow(
1400 clippy::type_complexity,
1401 reason = "return type is a tuple of channel handle, receiver, and run_id which is clear in context"
1402 )]
1403 fn spawn_streaming_loop(
1404 mut pregel: PregelLoop<S>,
1405 mode: StreamMode,
1406 ) -> (
1407 tokio::task::JoinHandle<()>,
1408 mpsc::Receiver<Result<StreamEvent<S>, JunctureError>>,
1409 String,
1410 )
1411 where
1412 S: Clone + Send + for<'de> serde::Deserialize<'de> + serde::Serialize + 'static,
1413 S::Update: serde::Serialize,
1414 {
1415 let capacity = stream_capacity(&mode);
1418 let (tx, rx) = mpsc::channel(capacity);
1419
1420 let run_id = pregel.run_id().to_string();
1422 let graph_name = pregel
1423 .runnable_config
1424 .graph_name
1425 .clone()
1426 .unwrap_or_else(|| "unnamed".to_string());
1427 let recursion_limit = pregel.runnable_config.recursion_limit;
1428
1429 let (pregel_tx, mut pregel_rx) = mpsc::unbounded_channel();
1434 pregel.set_stream_sender(pregel_tx);
1435
1436 let handle = tokio::spawn(
1437 async move {
1438 let tx_forward = tx.clone();
1440 let mode_forward = mode.clone();
1441 tokio::spawn(async move {
1442 let (temp_tx, _temp_rx) = mpsc::channel(1);
1444 let emitter = EventEmitter::new(temp_tx, mode_forward);
1445
1446 while let Some(event) = pregel_rx.recv().await {
1447 if emitter.should_emit(&event) {
1448 let _ = tx_forward.send(Ok(event)).await;
1449 }
1450 }
1451 });
1452
1453 while matches!(pregel.tick(), Ok(true)) {
1455 let step = pregel.step();
1456
1457 if matches!(mode, StreamMode::Values) {
1459 let event = StreamEvent::Values {
1460 state: pregel.snapshot_state(),
1461 step,
1462 };
1463 let _ = tx.send(Ok(event)).await;
1464 }
1465
1466 match pregel.execute_superstep().await {
1468 Ok(result) => {
1469 if let Err(e) = pregel.after_tick(result).await {
1470 let _ = tx
1471 .send(Ok(StreamEvent::End {
1472 output: pregel.snapshot_state(),
1473 }))
1474 .await;
1475 let _ = tx.send(Err(e)).await;
1476 return;
1477 }
1478 }
1479 Err(e) => {
1480 let _ = tx
1481 .send(Ok(StreamEvent::End {
1482 output: pregel.snapshot_state(),
1483 }))
1484 .await;
1485 let _ = tx.send(Err(e)).await;
1486 return;
1487 }
1488 }
1489 }
1490
1491 let final_state = pregel.into_state();
1493 let _ = tx
1494 .send(Ok(StreamEvent::End {
1495 output: final_state,
1496 }))
1497 .await;
1498 }
1499 .instrument(info_span!(
1500 "juncture.graph.invoke",
1501 "juncture.graph.name" = graph_name,
1502 "juncture.run.id" = %run_id,
1503 "juncture.recursion.limit" = recursion_limit,
1504 )),
1505 );
1506
1507 (handle, rx, run_id)
1508 }
1509
1510 pub async fn get_state(
1519 &self,
1520 config: &RunnableConfig,
1521 ) -> Result<Option<StateSnapshot<S>>, JunctureError>
1522 where
1523 S: serde::de::DeserializeOwned,
1524 {
1525 let checkpointer =
1526 self.inner.checkpointer.as_ref().ok_or_else(|| {
1527 JunctureError::checkpoint("no checkpointer configured for get_state")
1528 })?;
1529
1530 let tuple = checkpointer
1531 .get_tuple(config)
1532 .await
1533 .map_err(|e| JunctureError::checkpoint(e.to_string()))?;
1534
1535 let Some(tuple) = tuple else {
1536 return Ok(None);
1537 };
1538
1539 let values = Self::deserialize_with_migration(&tuple.checkpoint)?;
1541
1542 let next: Vec<String> = tuple
1544 .checkpoint
1545 .pending_tasks
1546 .iter()
1547 .map(|t| t.node.clone())
1548 .collect();
1549
1550 let snapshot = StateSnapshot {
1551 values,
1552 next,
1553 config: tuple.config,
1554 metadata: tuple.metadata,
1555 created_at: tuple.checkpoint.created_at,
1556 parent_config: tuple.parent_config,
1557 tasks: vec![],
1558 };
1559
1560 Ok(Some(snapshot))
1561 }
1562
1563 #[expect(
1578 clippy::unused_async,
1579 reason = "async API consistency for checkpoint operations"
1580 )]
1581 pub async fn get_state_history(
1582 &self,
1583 _config: &RunnableConfig,
1584 filter: Option<CheckpointFilter>,
1585 ) -> Result<Vec<StateSnapshot<S>>, JunctureError> {
1586 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1587 JunctureError::checkpoint("no checkpointer configured for get_state_history")
1588 })?;
1589
1590 let _ = (checkpointer, filter);
1591
1592 Err(JunctureError::checkpoint(
1595 "get_state_history not yet implemented: requires checkpoint state recovery",
1596 ))
1597 }
1598
1599 pub async fn update_state(
1622 &self,
1623 config: &RunnableConfig,
1624 update: StateUpdate<S>,
1625 ) -> Result<RunnableConfig, JunctureError>
1626 where
1627 S: serde::de::DeserializeOwned + serde::Serialize,
1628 {
1629 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1630 JunctureError::checkpoint("no checkpointer configured for update_state")
1631 })?;
1632
1633 let tuple = checkpointer
1635 .get_tuple(config)
1636 .await
1637 .map_err(|e| JunctureError::checkpoint(e.to_string()))?;
1638
1639 let Some(tuple) = tuple else {
1640 return Err(JunctureError::checkpoint(
1641 "no checkpoint found for update_state",
1642 ));
1643 };
1644
1645 let mut state = Self::deserialize_with_migration(&tuple.checkpoint)?;
1647
1648 state.apply(update.update);
1650
1651 let updated_values = serde_json::to_value(&state).map_err(|e| {
1653 JunctureError::checkpoint(format!("failed to serialize updated state: {e}"))
1654 })?;
1655
1656 let mut writes = tuple.metadata.writes;
1658 if let Some(as_node) = update.as_node {
1659 writes.insert(as_node, serde_json::Value::Null);
1660 }
1661
1662 let updated_checkpoint = Checkpoint {
1664 channel_values: updated_values,
1665 ..tuple.checkpoint
1666 };
1667
1668 let metadata = CheckpointMetadata {
1670 source: CheckpointSource::Update,
1671 step: tuple.metadata.step + 1,
1672 writes,
1673 ..tuple.metadata
1674 };
1675
1676 checkpointer
1678 .put(config, updated_checkpoint, metadata)
1679 .await
1680 .map_err(|e| JunctureError::checkpoint(e.to_string()))
1681 }
1682
1683 #[expect(
1698 clippy::unused_async,
1699 reason = "async API consistency for checkpoint operations"
1700 )]
1701 pub async fn bulk_update_state(
1702 &self,
1703 _config: &RunnableConfig,
1704 updates: Vec<StateUpdate<S>>,
1705 ) -> Result<Vec<RunnableConfig>, JunctureError> {
1706 let checkpointer = self.inner.checkpointer.as_ref().ok_or_else(|| {
1707 JunctureError::checkpoint("no checkpointer configured for bulk_update_state")
1708 })?;
1709
1710 let _ = (checkpointer, updates);
1711
1712 Err(JunctureError::checkpoint(
1715 "bulk_update_state not yet implemented: requires checkpoint state recovery",
1716 ))
1717 }
1718
1719 #[must_use]
1730 pub fn get_graph(&self, xray: Option<usize>) -> DrawableGraph {
1731 let _ = xray;
1732
1733 self.to_drawable()
1736 }
1737
1738 #[must_use]
1743 pub fn get_subgraphs(&self) -> Vec<SubgraphInfo> {
1744 self.inner.subgraphs.clone()
1745 }
1746
1747 #[must_use]
1749 pub fn nodes(&self) -> &IndexMap<String, Arc<dyn crate::Node<S>>> {
1750 &self.inner.nodes
1751 }
1752
1753 #[must_use]
1755 pub fn trigger_table(&self) -> &TriggerTable<S> {
1756 &self.inner.trigger_table
1757 }
1758
1759 #[must_use]
1761 pub fn checkpointer(&self) -> Option<&Arc<dyn crate::checkpoint::CheckpointSaver>> {
1762 self.inner.checkpointer.as_ref()
1763 }
1764
1765 #[must_use]
1767 pub fn builder_metadata(&self) -> &IndexMap<String, NodeMetadata> {
1768 &self.inner.builder_metadata
1769 }
1770
1771 #[must_use]
1782 pub fn to_mermaid(&self) -> String {
1783 let mut lines = vec!["graph TD".to_string()];
1784
1785 for node_name in self.inner.nodes.keys() {
1787 lines.push(format!(" {node_name}[{node_name}]"));
1788 }
1789
1790 for (from, edges) in &self.inner.trigger_table.outgoing {
1792 for edge in edges {
1793 match edge {
1794 crate::edge::CompiledEdge::Fixed { target } => {
1795 lines.push(format!(" {from} --> {target}"));
1796 }
1797 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
1798 for (branch, target) in path_map.iter() {
1799 lines.push(format!(" {from} -->|{branch}| {target}"));
1800 }
1801 }
1802 }
1803 }
1804 }
1805
1806 if let Some(entry) = self.find_entry_point() {
1808 lines.push(format!(" START((start)) --> {entry}"));
1809 }
1810
1811 lines.join("\n")
1812 }
1813
1814 #[must_use]
1825 pub fn to_dot(&self) -> String {
1826 let mut lines = vec!["digraph juncture_graph {".to_string()];
1827 lines.push(" rankdir=LR;".to_string());
1828 lines.push(" node [shape=box];".to_string());
1829 lines.push(" START [shape=circle];".to_string());
1830 lines.push(" END [shape=doublecircle];".to_string());
1831 lines.push(String::new());
1832
1833 for node_name in self.inner.nodes.keys() {
1835 lines.push(format!(" {node_name};"));
1836 }
1837
1838 lines.push(String::new());
1839
1840 for (from, edges) in &self.inner.trigger_table.outgoing {
1842 for edge in edges {
1843 match edge {
1844 crate::edge::CompiledEdge::Fixed { target } => {
1845 lines.push(format!(" {from} -> {target};"));
1846 }
1847 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
1848 for (branch, target) in path_map.iter() {
1849 lines.push(format!(" {from} -> {target} [label=\"{branch}\"];"));
1850 }
1851 }
1852 }
1853 }
1854 }
1855
1856 if let Some(entry) = self.find_entry_point() {
1858 lines.push(format!(" START -> {entry};"));
1859 }
1860
1861 lines.push("}".to_string());
1862 lines.join("\n")
1863 }
1864
1865 #[must_use]
1877 pub fn to_json(&self) -> serde_json::Value {
1878 let drawable = self.to_drawable();
1879
1880 serde_json::json!({
1881 "nodes": drawable.nodes.into_iter().map(|n| {
1882 serde_json::json!({
1883 "name": n.name,
1884 "metadata": n.metadata,
1885 })
1886 }).collect::<Vec<_>>(),
1887 "edges": drawable.edges.into_iter().map(|e| {
1888 let mut edge = serde_json::json!({
1889 "from": e.from,
1890 "to": e.to,
1891 "conditional": e.conditional,
1892 });
1893 if let Some(label) = e.label {
1894 edge["label"] = serde_json::Value::String(label);
1895 }
1896 edge
1897 }).collect::<Vec<_>>(),
1898 })
1899 }
1900
1901 #[must_use]
1914 pub fn to_html(&self) -> String {
1915 let mermaid = self.to_mermaid();
1916 let json = self.to_json();
1917 let json_pretty = serde_json::to_string_pretty(&json).unwrap_or_default();
1918
1919 let mermaid_escaped = escape_html(&mermaid);
1921 let json_escaped = escape_html(&json_pretty);
1922
1923 format!(
1924 r#"<!DOCTYPE html>
1925<html lang="en">
1926<head>
1927 <meta charset="UTF-8">
1928 <meta name="viewport" content="width=device-width, initial-scale=1.0">
1929 <title>Juncture Graph Visualization</title>
1930 <script src="https://cdn.jsdelivr.net/npm/mermaid@11/dist/mermaid.min.js"></script>
1931 <style>
1932 body {{
1933 font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
1934 margin: 0;
1935 padding: 20px;
1936 background: #f5f5f5;
1937 }}
1938 .container {{
1939 max-width: 1200px;
1940 margin: 0 auto;
1941 }}
1942 h1 {{
1943 color: #333;
1944 margin-bottom: 20px;
1945 }}
1946 .graph-container {{
1947 background: white;
1948 border-radius: 8px;
1949 padding: 20px;
1950 box-shadow: 0 2px 4px rgba(0,0,0,0.1);
1951 margin-bottom: 20px;
1952 overflow-x: auto;
1953 }}
1954 .json-container {{
1955 background: #1e1e1e;
1956 color: #d4d4d4;
1957 border-radius: 8px;
1958 padding: 20px;
1959 overflow-x: auto;
1960 font-family: 'Monaco', 'Menlo', monospace;
1961 font-size: 13px;
1962 line-height: 1.5;
1963 }}
1964 .toggle-btn {{
1965 background: #007bff;
1966 color: white;
1967 border: none;
1968 padding: 10px 20px;
1969 border-radius: 4px;
1970 cursor: pointer;
1971 margin-bottom: 10px;
1972 }}
1973 .toggle-btn:hover {{
1974 background: #0056b3;
1975 }}
1976 #json-view {{
1977 display: none;
1978 }}
1979 </style>
1980</head>
1981<body>
1982 <div class="container">
1983 <h1>Juncture Graph Visualization</h1>
1984 <button class="toggle-btn" onclick="toggleJson()">Toggle JSON View</button>
1985 <div class="graph-container">
1986 <pre class="mermaid">
1987{mermaid_escaped}
1988 </pre>
1989 </div>
1990 <div id="json-view" class="json-container">
1991 <pre>{json_escaped}</pre>
1992 </div>
1993 </div>
1994 <script>
1995 mermaid.initialize({{ startOnLoad: true, theme: 'default' }});
1996 function toggleJson() {{
1997 const jsonView = document.getElementById('json-view');
1998 jsonView.style.display = jsonView.style.display === 'none' ? 'block' : 'none';
1999 }}
2000 </script>
2001</body>
2002</html>"#
2003 )
2004 }
2005
2006 fn to_drawable(&self) -> DrawableGraph {
2008 let mut nodes = Vec::new();
2009 let mut edges = Vec::new();
2010
2011 for node_name in self.inner.nodes.keys() {
2013 let metadata = self
2014 .inner
2015 .builder_metadata
2016 .get(node_name)
2017 .and_then(|m| m.metadata.clone())
2018 .unwrap_or_default();
2019
2020 nodes.push(DrawableNode {
2021 name: node_name.clone(),
2022 metadata,
2023 });
2024 }
2025
2026 for (from, edge_list) in &self.inner.trigger_table.outgoing {
2028 for edge in edge_list {
2029 match edge {
2030 crate::edge::CompiledEdge::Fixed { target } => {
2031 edges.push(DrawableEdge {
2032 from: from.clone(),
2033 to: target.clone(),
2034 conditional: false,
2035 label: None,
2036 });
2037 }
2038 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
2039 for (branch, target) in path_map.iter() {
2040 edges.push(DrawableEdge {
2041 from: from.clone(),
2042 to: target.clone(),
2043 conditional: true,
2044 label: Some(branch.clone()),
2045 });
2046 }
2047 }
2048 }
2049 }
2050 }
2051
2052 DrawableGraph { nodes, edges }
2053 }
2054
2055 #[must_use]
2066 pub fn display(&self) -> String {
2067 let mut lines = Vec::new();
2068
2069 lines.push("Juncture Graph".to_string());
2071 lines.push("=".repeat(40));
2072 lines.push(String::new());
2073
2074 if let Some(entry) = self.find_entry_point() {
2076 lines.push(format!("Entry: {entry}"));
2077 lines.push(String::new());
2078 }
2079
2080 lines.push("Nodes:".to_string());
2082 for node_name in self.inner.nodes.keys() {
2083 let has_retry = self
2084 .inner
2085 .builder_metadata
2086 .get(node_name)
2087 .is_some_and(|m| !m.retry_policies.is_empty());
2088 let has_timeout = self
2089 .inner
2090 .builder_metadata
2091 .get(node_name)
2092 .is_some_and(|m| !m.timeout_policies.is_empty());
2093 let has_circuit_breaker = self
2094 .inner
2095 .builder_metadata
2096 .get(node_name)
2097 .is_some_and(|m| m.circuit_breaker.is_some());
2098 let has_fallback = self
2099 .inner
2100 .builder_metadata
2101 .get(node_name)
2102 .is_some_and(|m| m.fallback_node.is_some());
2103
2104 let mut annotations = Vec::new();
2105 if has_retry {
2106 annotations.push("retry");
2107 }
2108 if has_timeout {
2109 annotations.push("timeout");
2110 }
2111 if has_circuit_breaker {
2112 annotations.push("circuit-breaker");
2113 }
2114 if has_fallback {
2115 annotations.push("fallback");
2116 }
2117
2118 if annotations.is_empty() {
2119 lines.push(format!(" - {node_name}"));
2120 } else {
2121 lines.push(format!(" - {node_name} [{}]", annotations.join(", ")));
2122 }
2123 }
2124 lines.push(String::new());
2125
2126 lines.push("Edges:".to_string());
2128 for (from, edges) in &self.inner.trigger_table.outgoing {
2129 for edge in edges {
2130 match edge {
2131 crate::edge::CompiledEdge::Fixed { target } => {
2132 lines.push(format!(" {from} --> {target}"));
2133 }
2134 crate::edge::CompiledEdge::Conditional { path_map, .. } => {
2135 for (branch, target) in path_map.iter() {
2136 lines.push(format!(" {from} -->|{branch}| {target}"));
2137 }
2138 }
2139 }
2140 }
2141 }
2142
2143 lines.join("\n")
2144 }
2145
2146 fn find_entry_point(&self) -> Option<String> {
2148 for (target, sources) in &self.inner.trigger_table.incoming {
2149 for source in sources {
2150 if matches!(source, crate::edge::TriggerSource::Edge { from } if from == "START") {
2151 return Some(target.clone());
2152 }
2153 }
2154 }
2155 None
2156 }
2157}
2158
/// Escapes the HTML-significant characters `&`, `<`, `>`, `"` and `'` in `s`.
///
/// The result is safe to embed in element content or in a quoted attribute value.
/// Escaping happens in a single pass, so the `&` of an inserted entity is never
/// escaped a second time. The previous chained `replace` calls mapped each
/// character to itself, which made escaping a no-op.
fn escape_html(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
2167
/// Shared internals of a [`CompiledGraph`].
///
/// `CompiledGraph` holds this behind an `Arc`, so cloning a compiled graph shares
/// one instance.
#[allow(dead_code, reason = "fields used through Arc, not directly")]
struct CompiledGraphInner<S: State> {
    /// Nodes keyed by node name, kept in insertion order (`IndexMap`).
    nodes: IndexMap<String, Arc<dyn crate::Node<S>>>,

    /// Compiled edges. `outgoing` is read when rendering edges; `incoming` is read
    /// to locate the entry node (the one triggered by an edge from `"START"`).
    trigger_table: TriggerTable<S>,

    /// Per-node builder metadata keyed by node name: retry/timeout policies,
    /// circuit breaker, fallback node and free-form metadata.
    builder_metadata: IndexMap<String, NodeMetadata>,

    /// Optional checkpoint saver. Checkpoint operations such as `update_state`
    /// return a checkpoint error when this is `None`.
    checkpointer: Option<Arc<dyn crate::checkpoint::CheckpointSaver>>,

    /// Node names listed as interrupt-before points. NOTE(review): the semantics
    /// are inferred from the field name; confirm against the execution loop.
    interrupt_before: Vec<String>,

    /// Node names listed as interrupt-after points. NOTE(review): the semantics
    /// are inferred from the field name; confirm against the execution loop.
    interrupt_after: Vec<String>,

    /// Mounted subgraph descriptors, returned by `get_subgraphs`.
    subgraphs: Vec<SubgraphInfo>,

    /// Atomic counter, presumably tracking in-flight invocations. Its readers
    /// and writers are outside this view; confirm.
    active_invocations: std::sync::atomic::AtomicU64,
}
2199
/// Result produced by a graph run. It is constructed outside this view.
#[derive(Debug)]
pub struct GraphOutput<S: State, O: FromState<S> = S> {
    /// Graph state value, presumably the state at the end of the run; confirm.
    pub value: S,

    /// Output of type `O` (a `FromState<S>` type, defaulting to `S` itself),
    /// presumably derived from the state; confirm.
    pub output: O,

    /// Interrupts reported by the run, if any.
    pub interrupts: Vec<InterruptInfo>,

    /// Run bookkeeping: step count, run id, checkpoint id and budget usage.
    pub metadata: GraphOutputMetadata,
}
2218
/// Description of one interrupt surfaced by a graph run.
#[derive(Clone, Debug)]
pub struct InterruptInfo {
    /// Name of the node associated with the interrupt.
    pub node: String,

    /// Arbitrary JSON payload attached to the interrupt.
    pub value: serde_json::Value,

    /// Optional interrupt identifier. NOTE(review): presumably used to target a
    /// specific interrupt on resume; confirm.
    pub id: Option<String>,
}
2233
/// Bookkeeping attached to a [`GraphOutput`].
#[derive(Clone, Debug)]
pub struct GraphOutputMetadata {
    /// Number of steps executed, presumably Pregel supersteps; confirm.
    pub steps: usize,

    /// Identifier of the run.
    pub run_id: String,

    /// Identifier of the associated checkpoint, if any. This is presumably `None`
    /// when no checkpointer is configured; confirm.
    pub checkpoint_id: Option<String>,

    /// Budget usage recorded for the run, if budget tracking produced any.
    pub budget_usage: Option<crate::pregel::BudgetUsage>,
}
2251
/// A manual state update for `update_state` / `bulk_update_state`.
#[derive(Clone, Debug)]
pub struct StateUpdate<S: State> {
    /// Update applied to the checkpointed state via the state's `apply`.
    pub update: S::Update,

    /// Optional label for the update. NOTE(review): `update_state` does not read
    /// it at present.
    pub label: Option<String>,

    /// Optional node name. `update_state` records it in the new checkpoint's
    /// metadata `writes` with a `null` value.
    pub as_node: Option<String>,
}
2267
/// Descriptor of a subgraph mounted in a compiled graph, as returned by
/// `get_subgraphs`.
#[derive(Clone, Debug)]
pub struct SubgraphInfo {
    /// Name the subgraph was mounted under.
    pub name: String,

    /// Persistence mode of the subgraph, for example `Inherit` or `PerThread`.
    pub persistence: crate::subgraph::SubgraphPersistence,
}
2280
/// Optional step-range and size constraints for state queries. The `Default`
/// value leaves every constraint unset.
#[derive(Clone, Debug, Default)]
pub struct StateFilter {
    /// Presumably keeps only entries after this step; confirm whether the bound
    /// is exclusive.
    pub after_step: Option<usize>,

    /// Presumably keeps only entries before this step; confirm whether the bound
    /// is exclusive.
    pub before_step: Option<usize>,

    /// Presumably caps the number of returned entries; confirm.
    pub limit: Option<usize>,
}
2295
/// Renderer-agnostic view of a compiled graph, as produced by `get_graph`.
#[derive(Clone, Debug)]
pub struct DrawableGraph {
    /// One entry per node, in node insertion order.
    pub nodes: Vec<DrawableNode>,

    /// One entry per fixed edge and one per conditional branch.
    pub edges: Vec<DrawableEdge>,
}
2307
/// A node in a [`DrawableGraph`].
#[derive(Clone, Debug)]
pub struct DrawableNode {
    /// Node name.
    pub name: String,

    /// Free-form metadata from the node's builder metadata. It is empty when the
    /// node has none.
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
2319
/// An edge in a [`DrawableGraph`].
#[derive(Clone, Debug)]
pub struct DrawableEdge {
    /// Source node name.
    pub from: String,

    /// Target node name.
    pub to: String,

    /// `true` when the edge comes from a conditional branch.
    pub conditional: bool,

    /// Branch key for conditional edges; `None` for fixed edges.
    pub label: Option<String>,
}
2337
2338#[cfg(test)]
2339mod tests {
2340 use super::*;
2341 use crate::{node::IntoNode, node::NodeFnUpdate};
2342
2343 #[test]
2344 fn test_compiled_graph_creation() {
2345 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2346 nodes.insert("test".to_string(), mock_node("test"));
2347
2348 let trigger_table = TriggerTable::new();
2349 let builder_metadata = IndexMap::new();
2350
2351 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2352 nodes,
2353 trigger_table,
2354 builder_metadata,
2355 vec![],
2356 vec![],
2357 None,
2358 vec![],
2359 );
2360 assert_eq!(compiled.nodes().len(), 1);
2361 }
2362
2363 #[test]
2364 fn test_to_mermaid() {
2365 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2366 nodes.insert("a".to_string(), mock_node("a"));
2367 nodes.insert("b".to_string(), mock_node("b"));
2368
2369 let mut trigger_table = TriggerTable::new();
2370 trigger_table.add_outgoing(
2371 "a".to_string(),
2372 crate::edge::CompiledEdge::Fixed {
2373 target: "b".to_string(),
2374 },
2375 );
2376
2377 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2378 nodes,
2379 trigger_table,
2380 IndexMap::new(),
2381 vec![],
2382 vec![],
2383 None,
2384 vec![],
2385 );
2386 let mermaid = compiled.to_mermaid();
2387
2388 assert!(mermaid.contains("graph TD"));
2389 assert!(mermaid.contains("a --> b"));
2390 }
2391
2392 #[test]
2393 fn test_to_dot() {
2394 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2395 nodes.insert("a".to_string(), mock_node("a"));
2396 nodes.insert("b".to_string(), mock_node("b"));
2397
2398 let mut trigger_table = TriggerTable::new();
2399 trigger_table.add_outgoing(
2400 "a".to_string(),
2401 crate::edge::CompiledEdge::Fixed {
2402 target: "b".to_string(),
2403 },
2404 );
2405
2406 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2407 nodes,
2408 trigger_table,
2409 IndexMap::new(),
2410 vec![],
2411 vec![],
2412 None,
2413 vec![],
2414 );
2415 let dot = compiled.to_dot();
2416
2417 assert!(dot.contains("digraph juncture_graph"));
2418 assert!(dot.contains("a -> b"));
2419 }
2420
2421 #[test]
2422 fn test_to_json() {
2423 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2424 nodes.insert("a".to_string(), mock_node("a"));
2425 nodes.insert("b".to_string(), mock_node("b"));
2426
2427 let mut trigger_table = TriggerTable::new();
2428 trigger_table.add_outgoing(
2429 "a".to_string(),
2430 crate::edge::CompiledEdge::Fixed {
2431 target: "b".to_string(),
2432 },
2433 );
2434
2435 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2436 nodes,
2437 trigger_table,
2438 IndexMap::new(),
2439 vec![],
2440 vec![],
2441 None,
2442 vec![],
2443 );
2444 let json = compiled.to_json();
2445
2446 assert!(json.is_object());
2447 assert!(json.get("nodes").is_some());
2448 assert!(json.get("edges").is_some());
2449 }
2450
2451 #[test]
2452 fn test_get_graph() {
2453 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2454 nodes.insert("a".to_string(), mock_node("a"));
2455
2456 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2457 nodes,
2458 TriggerTable::new(),
2459 IndexMap::new(),
2460 vec![],
2461 vec![],
2462 None,
2463 vec![],
2464 );
2465 let drawable = compiled.get_graph(None);
2466 assert_eq!(drawable.nodes.len(), 1);
2467
2468 let drawable_xray = compiled.get_graph(Some(2));
2469 assert_eq!(drawable_xray.nodes.len(), 1);
2470 }
2471
2472 #[test]
2473 fn test_get_subgraphs_empty() {
2474 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2475 nodes.insert("a".to_string(), mock_node("a"));
2476
2477 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2478 nodes,
2479 TriggerTable::new(),
2480 IndexMap::new(),
2481 vec![],
2482 vec![],
2483 None,
2484 vec![],
2485 );
2486 let subgraphs = compiled.get_subgraphs();
2487 assert!(subgraphs.is_empty());
2488 }
2489
2490 #[test]
2491 fn test_get_subgraphs_with_mounted_subgraphs() {
2492 use crate::subgraph::{SubgraphConfig, SubgraphMount, SubgraphPersistence};
2493
2494 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2495 nodes.insert("a".to_string(), mock_node("a"));
2496
2497 let sub_node = mock_node("sub_node");
2498 let mount_inherit = SubgraphMount::new(
2499 "child_graph",
2500 SubgraphConfig {
2501 persistence: SubgraphPersistence::Inherit,
2502 },
2503 Arc::clone(&sub_node),
2504 );
2505 let mount_per_thread = SubgraphMount::new(
2506 "worker_graph",
2507 SubgraphConfig {
2508 persistence: SubgraphPersistence::PerThread,
2509 },
2510 sub_node,
2511 );
2512
2513 let subgraphs = vec![
2514 super::SubgraphInfo {
2515 name: mount_inherit.name.clone(),
2516 persistence: mount_inherit.config.persistence,
2517 },
2518 super::SubgraphInfo {
2519 name: mount_per_thread.name.clone(),
2520 persistence: mount_per_thread.config.persistence,
2521 },
2522 ];
2523
2524 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2525 nodes,
2526 TriggerTable::new(),
2527 IndexMap::new(),
2528 vec![],
2529 vec![],
2530 None,
2531 subgraphs,
2532 );
2533
2534 let result = compiled.get_subgraphs();
2535 assert_eq!(result.len(), 2);
2536 assert_eq!(result[0].name, "child_graph");
2537 assert_eq!(result[0].persistence, SubgraphPersistence::Inherit);
2538 assert_eq!(result[1].name, "worker_graph");
2539 assert_eq!(result[1].persistence, SubgraphPersistence::PerThread);
2540 }
2541
2542 #[tokio::test]
2543 async fn test_resume_no_checkpointer() {
2544 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2545 nodes.insert("a".to_string(), mock_node("a"));
2546
2547 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2548 nodes,
2549 TriggerTable::new(),
2550 IndexMap::new(),
2551 vec![],
2552 vec![],
2553 None,
2554 vec![],
2555 );
2556 let config = RunnableConfig::new();
2557
2558 let result = compiled
2559 .resume(&config, ResumeValue::Single(serde_json::Value::Null))
2560 .await;
2561 assert!(result.is_err());
2562 assert!(result.unwrap_err().is_checkpoint());
2563 }
2564
2565 #[tokio::test]
2566 #[expect(
2567 clippy::too_many_lines,
2568 reason = "comprehensive test with multiple mock scenarios"
2569 )]
2570 async fn test_resume_validates_interrupt_source() {
2571 use crate::checkpoint::{
2572 Checkpoint, CheckpointMetadata, CheckpointSource, CheckpointTuple,
2573 };
2574 use std::collections::HashMap;
2575
2576 struct MockCheckpointer {
2578 checkpoint_source: CheckpointSource,
2579 }
2580
2581 #[async_trait::async_trait]
2582 impl crate::checkpoint::CheckpointSaver for MockCheckpointer {
2583 async fn get_tuple(
2584 &self,
2585 _config: &crate::config::RunnableConfig,
2586 ) -> Result<Option<CheckpointTuple>, crate::checkpoint::CheckpointError> {
2587 Ok(Some(CheckpointTuple {
2588 config: crate::config::RunnableConfig::new(),
2589 checkpoint: Checkpoint {
2590 id: "test_id".to_string(),
2591 channel_values: serde_json::json!({}),
2592 channel_versions: HashMap::new(),
2593 versions_seen: HashMap::new(),
2594 pending_tasks: Vec::new(),
2595 pending_sends: Vec::new(),
2596 pending_interrupts: Vec::new(),
2597 schema_version: 1,
2598 created_at: "2024-01-01T00:00:00Z".to_string(),
2599 v: 1,
2600 new_versions: HashMap::new(),
2601 counters_since_delta_snapshot: HashMap::new(),
2602 },
2603 metadata: CheckpointMetadata {
2604 source: self.checkpoint_source.clone(),
2605 step: 1,
2606 writes: HashMap::new(),
2607 parents: HashMap::new(),
2608 run_id: "test_run".to_string(),
2609 },
2610 pending_writes: Vec::new(),
2611 parent_config: None,
2612 }))
2613 }
2614
2615 async fn list(
2616 &self,
2617 _config: &crate::config::RunnableConfig,
2618 _filter: Option<crate::checkpoint::CheckpointFilter>,
2619 ) -> Result<Vec<CheckpointTuple>, crate::checkpoint::CheckpointError> {
2620 Ok(Vec::new())
2621 }
2622
2623 async fn put(
2624 &self,
2625 _config: &crate::config::RunnableConfig,
2626 _checkpoint: Checkpoint,
2627 _metadata: CheckpointMetadata,
2628 ) -> Result<crate::config::RunnableConfig, crate::checkpoint::CheckpointError>
2629 {
2630 Ok(crate::config::RunnableConfig::new())
2631 }
2632
2633 async fn put_writes(
2634 &self,
2635 _config: &crate::config::RunnableConfig,
2636 _writes: Vec<crate::checkpoint::PendingWrite>,
2637 _task_id: &str,
2638 ) -> Result<(), crate::checkpoint::CheckpointError> {
2639 Ok(())
2640 }
2641 }
2642
2643 let nodes = {
2645 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2646 nodes.insert("a".to_string(), mock_node("a"));
2647 nodes
2648 };
2649
2650 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2651 nodes.clone(),
2652 TriggerTable::new(),
2653 IndexMap::new(),
2654 vec![],
2655 vec![],
2656 Some(Arc::new(MockCheckpointer {
2657 checkpoint_source: CheckpointSource::Input,
2658 })),
2659 vec![],
2660 );
2661
2662 let config = RunnableConfig::new();
2663 let result = compiled
2664 .resume(&config, ResumeValue::Single(serde_json::json!("test")))
2665 .await;
2666
2667 assert!(result.is_err());
2668 let err = result.unwrap_err();
2669 assert!(err.is_checkpoint());
2670 assert!(
2671 err.to_string()
2672 .contains("resume() requires checkpoint from Interrupt source")
2673 );
2674 assert!(err.to_string().contains("Input"));
2675
2676 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2678 nodes.clone(),
2679 TriggerTable::new(),
2680 IndexMap::new(),
2681 vec![],
2682 vec![],
2683 Some(Arc::new(MockCheckpointer {
2684 checkpoint_source: CheckpointSource::Loop,
2685 })),
2686 vec![],
2687 );
2688
2689 let result = compiled
2690 .resume(&config, ResumeValue::Single(serde_json::json!("test")))
2691 .await;
2692
2693 assert!(result.is_err());
2694 let err = result.unwrap_err();
2695 assert!(err.is_checkpoint());
2696 assert!(
2697 err.to_string()
2698 .contains("resume() requires checkpoint from Interrupt source")
2699 );
2700 assert!(err.to_string().contains("Loop"));
2701
2702 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2704 nodes,
2705 TriggerTable::new(),
2706 IndexMap::new(),
2707 vec![],
2708 vec![],
2709 Some(Arc::new(MockCheckpointer {
2710 checkpoint_source: CheckpointSource::Interrupt {
2711 node: "test_node".to_string(),
2712 },
2713 })),
2714 vec![],
2715 );
2716
2717 let result = compiled
2718 .resume(&config, ResumeValue::Single(serde_json::json!("test")))
2719 .await;
2720
2721 if let Err(err) = result {
2724 assert!(
2725 !err.to_string()
2726 .contains("resume() requires checkpoint from Interrupt source")
2727 );
2728 }
2729 }
2730
2731 #[tokio::test]
2732 async fn test_resume_single_no_checkpointer() {
2733 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2734 nodes.insert("a".to_string(), mock_node("a"));
2735
2736 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2737 nodes,
2738 TriggerTable::new(),
2739 IndexMap::new(),
2740 vec![],
2741 vec![],
2742 None,
2743 vec![],
2744 );
2745 let config = RunnableConfig::new();
2746
2747 let result = compiled
2748 .resume_single(&config, serde_json::Value::Null)
2749 .await;
2750 assert!(result.is_err());
2751 assert!(result.unwrap_err().is_checkpoint());
2752 }
2753
2754 #[tokio::test]
2755 async fn test_resume_stream_no_checkpointer() {
2756 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2757 nodes.insert("a".to_string(), mock_node("a"));
2758
2759 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2760 nodes,
2761 TriggerTable::new(),
2762 IndexMap::new(),
2763 vec![],
2764 vec![],
2765 None,
2766 vec![],
2767 );
2768 let config = RunnableConfig::new();
2769
2770 let result = compiled
2771 .resume_stream(
2772 &config,
2773 ResumeValue::Single(serde_json::Value::Null),
2774 StreamMode::Values,
2775 )
2776 .await;
2777 let Err(err) = result else {
2778 panic!("expected checkpoint error, got stream");
2779 };
2780 assert!(err.is_checkpoint());
2781 }
2782
2783 #[tokio::test]
2784 #[expect(
2785 clippy::too_many_lines,
2786 reason = "mock checkpointer boilerplate inflates line count; extraction would hurt readability"
2787 )]
2788 async fn test_resume_stream_validates_interrupt_source() {
2789 use crate::checkpoint::{
2790 Checkpoint, CheckpointError, CheckpointMetadata, CheckpointSource, CheckpointTuple,
2791 };
2792 use std::collections::HashMap;
2793
2794 struct MockCheckpointer {
2795 checkpoint_source: CheckpointSource,
2796 }
2797
2798 #[async_trait::async_trait]
2799 impl crate::checkpoint::CheckpointSaver for MockCheckpointer {
2800 async fn get_tuple(
2801 &self,
2802 _config: &crate::config::RunnableConfig,
2803 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
2804 Ok(Some(CheckpointTuple {
2805 config: crate::config::RunnableConfig::new(),
2806 checkpoint: Checkpoint {
2807 id: "test_id".to_string(),
2808 channel_values: serde_json::json!({}),
2809 channel_versions: HashMap::new(),
2810 versions_seen: HashMap::new(),
2811 pending_tasks: Vec::new(),
2812 pending_sends: Vec::new(),
2813 pending_interrupts: Vec::new(),
2814 schema_version: 1,
2815 created_at: "2024-01-01T00:00:00Z".to_string(),
2816 v: 1,
2817 new_versions: HashMap::new(),
2818 counters_since_delta_snapshot: HashMap::new(),
2819 },
2820 metadata: CheckpointMetadata {
2821 source: self.checkpoint_source.clone(),
2822 step: 1,
2823 writes: HashMap::new(),
2824 parents: HashMap::new(),
2825 run_id: "test_run".to_string(),
2826 },
2827 pending_writes: Vec::new(),
2828 parent_config: None,
2829 }))
2830 }
2831
2832 async fn list(
2833 &self,
2834 _config: &crate::config::RunnableConfig,
2835 _filter: Option<crate::checkpoint::CheckpointFilter>,
2836 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
2837 Ok(Vec::new())
2838 }
2839
2840 async fn put(
2841 &self,
2842 _config: &crate::config::RunnableConfig,
2843 _checkpoint: Checkpoint,
2844 _metadata: CheckpointMetadata,
2845 ) -> Result<crate::config::RunnableConfig, CheckpointError> {
2846 Ok(crate::config::RunnableConfig::new())
2847 }
2848
2849 async fn put_writes(
2850 &self,
2851 _config: &crate::config::RunnableConfig,
2852 _writes: Vec<crate::checkpoint::PendingWrite>,
2853 _task_id: &str,
2854 ) -> Result<(), CheckpointError> {
2855 Ok(())
2856 }
2857 }
2858
2859 let nodes = {
2860 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2861 nodes.insert("a".to_string(), mock_node("a"));
2862 nodes
2863 };
2864
2865 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2867 nodes.clone(),
2868 TriggerTable::new(),
2869 IndexMap::new(),
2870 vec![],
2871 vec![],
2872 Some(Arc::new(MockCheckpointer {
2873 checkpoint_source: CheckpointSource::Input,
2874 })),
2875 vec![],
2876 );
2877
2878 let config = RunnableConfig::new();
2879 let result = compiled
2880 .resume_stream(
2881 &config,
2882 ResumeValue::Single(serde_json::json!("test")),
2883 StreamMode::Values,
2884 )
2885 .await;
2886
2887 assert!(result.is_err());
2888 let Err(err) = result else {
2889 panic!("expected checkpoint error, got stream");
2890 };
2891 assert!(err.is_checkpoint());
2892 assert!(
2893 err.to_string()
2894 .contains("resume_stream() requires checkpoint from Interrupt source"),
2895 "Expected interrupt source validation error, got: {err}"
2896 );
2897
2898 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2900 nodes,
2901 TriggerTable::new(),
2902 IndexMap::new(),
2903 vec![],
2904 vec![],
2905 Some(Arc::new(MockCheckpointer {
2906 checkpoint_source: CheckpointSource::Interrupt {
2907 node: "test_node".to_string(),
2908 },
2909 })),
2910 vec![],
2911 );
2912
2913 let result = compiled
2914 .resume_stream(
2915 &config,
2916 ResumeValue::Single(serde_json::json!("test")),
2917 StreamMode::Values,
2918 )
2919 .await;
2920
2921 if let Err(err) = result {
2924 assert!(
2925 !err.to_string()
2926 .contains("resume_stream() requires checkpoint from Interrupt source"),
2927 "Interrupt source should pass validation: {err}"
2928 );
2929 }
2930 }
2931
2932 #[tokio::test]
2933 async fn test_get_state_no_checkpointer() {
2934 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2935 nodes.insert("a".to_string(), mock_node("a"));
2936
2937 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2938 nodes,
2939 TriggerTable::new(),
2940 IndexMap::new(),
2941 vec![],
2942 vec![],
2943 None,
2944 vec![],
2945 );
2946 let config = RunnableConfig::new();
2947
2948 let result = compiled.get_state(&config).await;
2949 assert!(result.is_err());
2950 assert!(result.unwrap_err().is_checkpoint());
2951 }
2952
2953 #[tokio::test]
2954 async fn test_get_state_history_no_checkpointer() {
2955 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2956 nodes.insert("a".to_string(), mock_node("a"));
2957
2958 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2959 nodes,
2960 TriggerTable::new(),
2961 IndexMap::new(),
2962 vec![],
2963 vec![],
2964 None,
2965 vec![],
2966 );
2967 let config = RunnableConfig::new();
2968
2969 let result = compiled.get_state_history(&config, None).await;
2970 assert!(result.is_err());
2971 assert!(result.unwrap_err().is_checkpoint());
2972 }
2973
2974 #[tokio::test]
2975 async fn test_update_state_no_checkpointer() {
2976 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
2977 nodes.insert("a".to_string(), mock_node("a"));
2978
2979 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
2980 nodes,
2981 TriggerTable::new(),
2982 IndexMap::new(),
2983 vec![],
2984 vec![],
2985 None,
2986 vec![],
2987 );
2988 let config = RunnableConfig::new();
2989
2990 let update = StateUpdate {
2991 update: StateDummyUpdate,
2992 label: None,
2993 as_node: None,
2994 };
2995
2996 let result = compiled.update_state(&config, update).await;
2997 assert!(result.is_err());
2998 assert!(result.unwrap_err().is_checkpoint());
2999 }
3000
3001 #[tokio::test]
3002 async fn test_update_state_no_checkpoint_found() {
3003 use crate::checkpoint::{Checkpoint, CheckpointError, CheckpointMetadata, CheckpointTuple};
3004
3005 struct NoCheckpointCheckpointer;
3006
3007 #[async_trait::async_trait]
3008 impl crate::checkpoint::CheckpointSaver for NoCheckpointCheckpointer {
3009 async fn get_tuple(
3010 &self,
3011 _config: &crate::config::RunnableConfig,
3012 ) -> Result<Option<CheckpointTuple>, CheckpointError> {
3013 Ok(None)
3014 }
3015
3016 async fn list(
3017 &self,
3018 _config: &crate::config::RunnableConfig,
3019 _filter: Option<crate::checkpoint::CheckpointFilter>,
3020 ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
3021 Ok(Vec::new())
3022 }
3023
3024 async fn put(
3025 &self,
3026 _config: &crate::config::RunnableConfig,
3027 _checkpoint: Checkpoint,
3028 _metadata: CheckpointMetadata,
3029 ) -> Result<crate::config::RunnableConfig, CheckpointError> {
3030 Ok(crate::config::RunnableConfig::new())
3031 }
3032
3033 async fn put_writes(
3034 &self,
3035 _config: &crate::config::RunnableConfig,
3036 _writes: Vec<crate::checkpoint::PendingWrite>,
3037 _task_id: &str,
3038 ) -> Result<(), CheckpointError> {
3039 Ok(())
3040 }
3041 }
3042
3043 let nodes = {
3044 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3045 nodes.insert("a".to_string(), mock_node("a"));
3046 nodes
3047 };
3048
3049 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3050 nodes,
3051 TriggerTable::new(),
3052 IndexMap::new(),
3053 vec![],
3054 vec![],
3055 Some(Arc::new(NoCheckpointCheckpointer)),
3056 vec![],
3057 );
3058
3059 let config = RunnableConfig::new();
3060 let update = StateUpdate {
3061 update: StateDummyUpdate,
3062 label: None,
3063 as_node: None,
3064 };
3065
3066 let result = compiled.update_state(&config, update).await;
3067 assert!(result.is_err());
3068 let err = result.unwrap_err();
3069 assert!(err.is_checkpoint());
3070 assert!(
3071 err.to_string().contains("no checkpoint found"),
3072 "Expected 'no checkpoint found' error, got: {err}"
3073 );
3074 }
3075
/// Happy path for `update_state`.
///
/// The stub saver always reports an existing checkpoint at step 5 with source
/// `Loop`. The test records every `put` call and checks that exactly one write
/// happens, tagged `CheckpointSource::Update` at step 6.
#[tokio::test]
#[expect(
    clippy::too_many_lines,
    reason = "mock checkpointer boilerplate inflates line count; extraction would hurt readability"
)]
async fn test_update_state_success() {
    use crate::checkpoint::{
        Checkpoint, CheckpointError, CheckpointMetadata, CheckpointSource, CheckpointTuple,
    };
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    // One record per `put` call, keeping only the metadata fields asserted below.
    #[derive(Clone)]
    enum ObservedCall {
        Put { source: CheckpointSource, step: i64 },
    }

    // Saver stub that returns a fixed checkpoint from `get_tuple` and logs each
    // `put` into the shared `observed` vector.
    struct MockCheckpointer {
        observed: Arc<Mutex<Vec<ObservedCall>>>,
    }

    #[async_trait::async_trait]
    impl crate::checkpoint::CheckpointSaver for MockCheckpointer {
        // Always returns checkpoint "cp_123" at step 5 with source `Loop`.
        // The step value is what the assertion at the end expects to see incremented.
        async fn get_tuple(
            &self,
            _config: &crate::config::RunnableConfig,
        ) -> Result<Option<CheckpointTuple>, CheckpointError> {
            Ok(Some(CheckpointTuple {
                config: crate::config::RunnableConfig::new(),
                checkpoint: Checkpoint {
                    id: "cp_123".to_string(),
                    channel_values: serde_json::Value::Null,
                    channel_versions: HashMap::new(),
                    versions_seen: HashMap::new(),
                    pending_tasks: Vec::new(),
                    pending_sends: Vec::new(),
                    pending_interrupts: Vec::new(),
                    schema_version: 1,
                    created_at: "2024-01-01T00:00:00Z".to_string(),
                    v: 1,
                    new_versions: HashMap::new(),
                    counters_since_delta_snapshot: HashMap::new(),
                },
                metadata: CheckpointMetadata {
                    source: CheckpointSource::Loop,
                    step: 5,
                    writes: HashMap::new(),
                    parents: HashMap::new(),
                    run_id: "run_abc".to_string(),
                },
                pending_writes: Vec::new(),
                parent_config: None,
            }))
        }

        async fn list(
            &self,
            _config: &crate::config::RunnableConfig,
            _filter: Option<crate::checkpoint::CheckpointFilter>,
        ) -> Result<Vec<CheckpointTuple>, CheckpointError> {
            Ok(Vec::new())
        }

        // Records the source and step of every write.
        // The mutex guard is recovered even if the lock was poisoned.
        async fn put(
            &self,
            _config: &crate::config::RunnableConfig,
            _checkpoint: Checkpoint,
            metadata: CheckpointMetadata,
        ) -> Result<crate::config::RunnableConfig, CheckpointError> {
            self.observed
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .push(ObservedCall::Put {
                    source: metadata.source,
                    step: metadata.step,
                });
            Ok(crate::config::RunnableConfig::new())
        }

        async fn put_writes(
            &self,
            _config: &crate::config::RunnableConfig,
            _writes: Vec<crate::checkpoint::PendingWrite>,
            _task_id: &str,
        ) -> Result<(), CheckpointError> {
            Ok(())
        }
    }

    // Shared with the saver so the test can inspect recorded calls afterwards.
    let observed = Arc::new(Mutex::new(Vec::new()));
    let nodes = {
        let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
        nodes.insert("a".to_string(), mock_node("a"));
        nodes
    };

    let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
        nodes,
        TriggerTable::new(),
        IndexMap::new(),
        vec![],
        vec![],
        Some(Arc::new(MockCheckpointer {
            observed: Arc::clone(&observed),
        })),
        vec![],
    );

    let config = RunnableConfig::new();
    let update = StateUpdate {
        update: StateDummyUpdate,
        label: Some("manual fix".to_string()),
        as_node: Some("admin".to_string()),
    };

    let result = compiled.update_state(&config, update).await;
    assert!(result.is_ok(), "update_state should succeed");

    // Exactly one write is expected: an `Update`-sourced checkpoint one step
    // after the stored one (5 -> 6).
    let calls = observed
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    assert_eq!(calls.len(), 1, "Expected exactly one put call");
    match &calls[0] {
        ObservedCall::Put { source, step } => {
            assert!(
                matches!(source, CheckpointSource::Update),
                "Expected Update source, got {source:?}"
            );
            assert_eq!(*step, 6, "Expected step to be incremented from 5 to 6");
        }
    }
}
3209
3210 #[tokio::test]
3211 async fn test_bulk_update_state_no_checkpointer() {
3212 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3213 nodes.insert("a".to_string(), mock_node("a"));
3214
3215 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3216 nodes,
3217 TriggerTable::new(),
3218 IndexMap::new(),
3219 vec![],
3220 vec![],
3221 None,
3222 vec![],
3223 );
3224 let config = RunnableConfig::new();
3225
3226 let updates = vec![StateUpdate {
3227 update: StateDummyUpdate,
3228 label: None,
3229 as_node: None,
3230 }];
3231
3232 let result = compiled.bulk_update_state(&config, updates).await;
3233 assert!(result.is_err());
3234 assert!(result.unwrap_err().is_checkpoint());
3235 }
3236
/// A `StateUpdate` literal keeps the exact label and `as_node` values it was
/// built with. The original only checked `is_some()`, which would also pass
/// if the wrong strings were stored.
#[test]
fn test_state_update_creation() {
    let update: StateUpdate<StateDummy> = StateUpdate {
        update: StateDummyUpdate,
        label: Some("test update".to_string()),
        as_node: Some("my_node".to_string()),
    };

    assert_eq!(update.label.as_deref(), Some("test update"));
    assert_eq!(update.as_node.as_deref(), Some("my_node"));
}
3248
/// A `SubgraphInfo` literal exposes the name it was built with.
#[test]
fn test_subgraph_info_creation() {
    let name = "my_subgraph".to_string();
    let info = SubgraphInfo {
        persistence: crate::subgraph::SubgraphPersistence::Inherit,
        name,
    };

    assert_eq!(info.name, "my_subgraph");
}
3258
/// The default `StateFilter` applies no step bounds and no limit.
#[test]
fn test_state_filter_default() {
    let StateFilter {
        after_step,
        before_step,
        limit,
    } = StateFilter::default();

    assert!(after_step.is_none());
    assert!(before_step.is_none());
    assert!(limit.is_none());
}
3266
/// Explicitly supplied `StateFilter` bounds and limit are stored verbatim.
#[test]
fn test_state_filter_with_values() {
    let StateFilter {
        after_step,
        before_step,
        limit,
    } = StateFilter {
        after_step: Some(5),
        before_step: Some(10),
        limit: Some(20),
    };

    assert_eq!(after_step, Some(5));
    assert_eq!(before_step, Some(10));
    assert_eq!(limit, Some(20));
}
3279
3280 #[tokio::test]
3281 async fn test_stream_values_mode() {
3282 use futures::StreamExt;
3283
3284 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3286 nodes.insert("node_a".to_string(), mock_node("node_a"));
3287
3288 let mut trigger_table = TriggerTable::new();
3289 trigger_table.add_incoming(
3291 "node_a".to_string(),
3292 crate::edge::TriggerSource::Edge {
3293 from: crate::edge::START.to_string(),
3294 },
3295 );
3296
3297 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3298 nodes,
3299 trigger_table,
3300 IndexMap::new(),
3301 vec![],
3302 vec![],
3303 None,
3304 vec![],
3305 );
3306 let config = RunnableConfig::new();
3307
3308 let handle = compiled
3309 .stream(StateDummy, &config, StreamMode::Values)
3310 .await
3311 .expect("stream should succeed");
3312
3313 let mut events = Vec::new();
3315 let mut stream = handle.stream;
3316 while let Some(result) = stream.next().await {
3317 events.push(result.expect("stream event should be Ok"));
3318 }
3319
3320 let has_values = events
3322 .iter()
3323 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. }));
3324 let has_end = events
3325 .iter()
3326 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3327
3328 assert!(has_values, "Expected Values events in Values mode");
3329 assert!(has_end, "Expected End event");
3330 }
3331
3332 #[tokio::test]
3333 async fn test_stream_updates_mode() {
3334 use futures::StreamExt;
3335
3336 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3338 nodes.insert("node_a".to_string(), mock_node("node_a"));
3339
3340 let mut trigger_table = TriggerTable::new();
3341 trigger_table.add_incoming(
3342 "node_a".to_string(),
3343 crate::edge::TriggerSource::Edge {
3344 from: crate::edge::START.to_string(),
3345 },
3346 );
3347
3348 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3349 nodes,
3350 trigger_table,
3351 IndexMap::new(),
3352 vec![],
3353 vec![],
3354 None,
3355 vec![],
3356 );
3357 let config = RunnableConfig::new();
3358
3359 let handle = compiled
3360 .stream(StateDummy, &config, StreamMode::Updates)
3361 .await
3362 .expect("stream should succeed");
3363
3364 let mut events = Vec::new();
3366 let mut stream = handle.stream;
3367 while let Some(result) = stream.next().await {
3368 events.push(result.expect("stream event should be Ok"));
3369 }
3370
3371 let has_updates = events
3373 .iter()
3374 .any(|e| matches!(e, crate::stream::StreamEvent::Updates { .. }));
3375 let has_end = events
3376 .iter()
3377 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3378
3379 assert!(has_updates, "Expected Updates events in Updates mode");
3380 assert!(has_end, "Expected End event");
3381 }
3382
3383 #[tokio::test]
3384 async fn test_stream_debug_mode() {
3385 use futures::StreamExt;
3386
3387 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3389 nodes.insert("node_a".to_string(), mock_node("node_a"));
3390
3391 let mut trigger_table = TriggerTable::new();
3392 trigger_table.add_incoming(
3393 "node_a".to_string(),
3394 crate::edge::TriggerSource::Edge {
3395 from: crate::edge::START.to_string(),
3396 },
3397 );
3398
3399 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3400 nodes,
3401 trigger_table,
3402 IndexMap::new(),
3403 vec![],
3404 vec![],
3405 None,
3406 vec![],
3407 );
3408 let config = RunnableConfig::new();
3409
3410 let handle = compiled
3411 .stream(StateDummy, &config, StreamMode::Debug)
3412 .await
3413 .expect("stream should succeed");
3414
3415 let mut events = Vec::new();
3417 let mut stream = handle.stream;
3418 while let Some(result) = stream.next().await {
3419 events.push(result.expect("stream event should be Ok"));
3420 }
3421
3422 let has_debug = events
3424 .iter()
3425 .any(|e| matches!(e, crate::stream::StreamEvent::Debug(_)));
3426 let has_end = events
3427 .iter()
3428 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3429
3430 assert!(has_debug, "Expected Debug events in Debug mode");
3431 assert!(has_end, "Expected End event");
3432 }
3433
3434 #[tokio::test]
3435 async fn test_stream_end_event() {
3436 use futures::StreamExt;
3437
3438 let mut nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> = IndexMap::new();
3440 nodes.insert("node_a".to_string(), mock_node("node_a"));
3441
3442 let mut trigger_table = TriggerTable::new();
3443 trigger_table.add_incoming(
3444 "node_a".to_string(),
3445 crate::edge::TriggerSource::Edge {
3446 from: crate::edge::START.to_string(),
3447 },
3448 );
3449
3450 let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
3451 nodes,
3452 trigger_table,
3453 IndexMap::new(),
3454 vec![],
3455 vec![],
3456 None,
3457 vec![],
3458 );
3459 let config = RunnableConfig::new();
3460
3461 let handle = compiled
3462 .stream(StateDummy, &config, StreamMode::Values)
3463 .await
3464 .expect("stream should succeed");
3465
3466 let mut events = Vec::new();
3468 let mut stream = handle.stream;
3469 while let Some(result) = stream.next().await {
3470 events.push(result.expect("stream event should be Ok"));
3471 }
3472
3473 assert!(!events.is_empty(), "Stream should emit events");
3475
3476 let end_events: Vec<_> = events
3477 .iter()
3478 .filter_map(|e| {
3479 if let crate::stream::StreamEvent::End { output } = e {
3480 Some(output.clone())
3481 } else {
3482 None
3483 }
3484 })
3485 .collect();
3486
3487 assert!(!end_events.is_empty(), "Expected at least one End event");
3488
3489 for state in end_events {
3491 let _cloned_state = state.clone();
3492 }
3493 }
3494
3495 fn mock_node(name: &str) -> Arc<dyn crate::Node<StateDummy>> {
3496 NodeFnUpdate(|_s: &StateDummy| async move { Ok(StateDummyUpdate) }).into_node(name)
3497 }
3498
3499 #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize)]
3500 #[serde(crate = "serde")]
3501 struct StateDummy;
3502
3503 impl crate::State for StateDummy {
3504 type Update = StateDummyUpdate;
3505 type FieldVersions = crate::state::FieldVersions;
3506
3507 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
3508 crate::FieldsChanged(0)
3509 }
3510
3511 fn reset_ephemeral(&mut self) {}
3512 }
3513
3514 #[derive(Clone, Debug, Default, serde::Serialize)]
3515 struct StateDummyUpdate;
3516
3517 #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize, PartialEq)]
3522 #[serde(crate = "serde")]
3523 struct StateV2 {
3524 value: i32,
3525 label: String,
3526 }
3527
3528 impl crate::State for StateV2 {
3529 type Update = StateV2Update;
3530 type FieldVersions = crate::state::FieldVersions;
3531
3532 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
3533 crate::FieldsChanged(0)
3534 }
3535
3536 fn reset_ephemeral(&mut self) {}
3537
3538 fn schema_version() -> u32 {
3539 2
3540 }
3541
3542 fn migrate(from_version: u32, value: serde_json::Value) -> serde_json::Value {
3543 let mut map = match value {
3544 serde_json::Value::Object(m) => m,
3545 other => return other,
3546 };
3547 if from_version < 2 {
3548 map.insert(
3549 "label".to_string(),
3550 serde_json::Value::String("migrated".to_string()),
3551 );
3552 }
3553 serde_json::Value::Object(map)
3554 }
3555 }
3556
3557 #[derive(Clone, Debug, Default, serde::Serialize)]
3558 struct StateV2Update;
3559
/// A checkpoint written at schema version 1 is run through `StateV2::migrate`
/// before deserializing, so the missing `label` gets backfilled.
#[test]
fn test_deserialize_with_migration_applies_migration_when_versions_differ() {
    use std::collections::HashMap;

    // Version-1 payload: it has `value` but no `label`.
    let legacy_values = serde_json::json!({"value": 42});
    let checkpoint = crate::checkpoint::Checkpoint {
        id: "test_id".to_string(),
        channel_values: legacy_values,
        channel_versions: HashMap::new(),
        versions_seen: HashMap::new(),
        pending_tasks: Vec::new(),
        pending_sends: Vec::new(),
        pending_interrupts: Vec::new(),
        schema_version: 1,
        created_at: "2024-01-01T00:00:00Z".to_string(),
        v: 1,
        new_versions: HashMap::new(),
        counters_since_delta_snapshot: HashMap::new(),
    };

    let migrated: StateV2 = CompiledGraph::<StateV2>::deserialize_with_migration(&checkpoint)
        .expect("deserialization with migration should succeed");

    assert_eq!(
        migrated,
        StateV2 {
            value: 42,
            label: "migrated".to_string(),
        }
    );
}
3587
/// A checkpoint already at the current schema version (2) is deserialized
/// as-is, so the stored `label` survives untouched.
#[test]
fn test_deserialize_with_migration_skips_migration_when_versions_match() {
    use std::collections::HashMap;

    let current_values = serde_json::json!({"value": 7, "label": "original"});
    let checkpoint = crate::checkpoint::Checkpoint {
        id: "test_id".to_string(),
        channel_values: current_values,
        channel_versions: HashMap::new(),
        versions_seen: HashMap::new(),
        pending_tasks: Vec::new(),
        pending_sends: Vec::new(),
        pending_interrupts: Vec::new(),
        schema_version: 2,
        created_at: "2024-01-01T00:00:00Z".to_string(),
        v: 1,
        new_versions: HashMap::new(),
        counters_since_delta_snapshot: HashMap::new(),
    };

    let restored: StateV2 = CompiledGraph::<StateV2>::deserialize_with_migration(&checkpoint)
        .expect("deserialization should succeed");

    assert_eq!(
        restored,
        StateV2 {
            value: 7,
            label: "original".to_string(),
        }
    );
}
3615
/// The default `CompileConfig` declares no interrupt points.
#[test]
fn test_compile_config_default_is_empty() {
    let super::super::CompileConfig {
        interrupt_before,
        interrupt_after,
    } = super::super::CompileConfig::default();

    assert!(interrupt_before.is_empty());
    assert!(interrupt_after.is_empty());
}
3622
/// `compile_with_config` keeps the node set intact and carries the configured
/// interrupt lists into the compiled graph.
///
/// The original only asserted the node count, so it would still pass if the
/// interrupt lists were silently dropped. The stored lists are now checked
/// through `effective_config`: with an empty runtime config, the compile-time
/// lists should surface unchanged.
#[test]
fn test_compile_with_config_stores_interrupts() {
    let mut graph = super::super::StateGraph::<StateDummy>::new();
    graph
        .add_node_simple(
            "human_review",
            NodeFnUpdate(
                |_s: &StateDummy| -> std::pin::Pin<
                    Box<
                        dyn std::future::Future<
                                Output = Result<StateDummyUpdate, crate::JunctureError>,
                            > + Send,
                    >,
                > { Box::pin(async move { Ok(StateDummyUpdate) }) },
            ),
        )
        .unwrap();
    graph.set_entry_point("human_review");
    graph.set_finish_point("human_review");

    let config = super::super::CompileConfig {
        interrupt_before: vec!["human_review".to_string()],
        interrupt_after: vec!["human_review".to_string()],
    };

    let compiled = graph.compile_with_config(config).unwrap();
    assert_eq!(compiled.nodes().len(), 1);

    // A fresh runtime config sets no interrupts, so the compile-time lists
    // should surface unchanged.
    let runtime_config = RunnableConfig::new();
    let effective = compiled.effective_config(&runtime_config);
    assert_eq!(
        effective.interrupt_before,
        Some(vec!["human_review".to_string()]),
        "compile-time interrupt_before should be stored"
    );
    assert_eq!(
        effective.interrupt_after,
        Some(vec!["human_review".to_string()]),
        "compile-time interrupt_after should be stored"
    );
}
3651
/// When the runtime config sets no interrupts, `effective_config` falls back
/// to the lists given at compile time.
#[test]
fn test_effective_config_uses_compile_time_defaults() {
    let nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> =
        std::iter::once(("a".to_string(), mock_node("a"))).collect();
    let compile_before = vec!["node_a".to_string()];
    let compile_after = vec!["node_b".to_string()];

    let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
        nodes,
        TriggerTable::new(),
        IndexMap::new(),
        compile_before.clone(),
        compile_after.clone(),
        None,
        vec![],
    );

    let runtime = RunnableConfig::new();
    let effective = compiled.effective_config(&runtime);
    assert_eq!(effective.interrupt_before, Some(compile_before));
    assert_eq!(effective.interrupt_after, Some(compile_after));
}
3673
/// Interrupt lists set on the runtime config replace the compile-time ones.
#[test]
fn test_effective_config_runtime_overrides_compile_time() {
    let nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> =
        std::iter::once(("a".to_string(), mock_node("a"))).collect();

    let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
        nodes,
        TriggerTable::new(),
        IndexMap::new(),
        vec!["compile_before".to_string()],
        vec!["compile_after".to_string()],
        None,
        vec![],
    );

    let runtime_before = vec!["runtime_before".to_string()];
    let runtime_after = vec!["runtime_after".to_string()];
    let runtime = RunnableConfig::new()
        .with_interrupt_before(runtime_before.clone())
        .with_interrupt_after(runtime_after.clone());

    let effective = compiled.effective_config(&runtime);
    assert_eq!(effective.interrupt_before, Some(runtime_before));
    assert_eq!(effective.interrupt_after, Some(runtime_after));
}
3704
/// Empty compile-time interrupt lists plus an empty runtime config leave both
/// effective interrupt settings unset.
#[test]
fn test_effective_config_empty_compile_time_no_override() {
    let nodes: IndexMap<String, Arc<dyn crate::Node<StateDummy>>> =
        std::iter::once(("a".to_string(), mock_node("a"))).collect();

    let compiled: CompiledGraph<StateDummy> = CompiledGraph::new(
        nodes,
        TriggerTable::new(),
        IndexMap::new(),
        vec![],
        vec![],
        None,
        vec![],
    );

    let runtime = RunnableConfig::new();
    let effective = compiled.effective_config(&runtime);
    assert!(effective.interrupt_before.is_none());
    assert!(effective.interrupt_after.is_none());
}
3726
3727 #[derive(Clone, Debug, Default, serde::Deserialize, serde::Serialize, PartialEq)]
3731 #[serde(crate = "serde")]
3732 struct MultiFieldState {
3733 messages: Vec<String>,
3734 count: i32,
3735 label: String,
3736 }
3737
3738 impl crate::State for MultiFieldState {
3739 type Update = MultiFieldStateUpdate;
3740 type FieldVersions = crate::state::FieldVersions;
3741
3742 fn apply(&mut self, update: Self::Update) -> crate::FieldsChanged {
3743 let mut mask = 0u64;
3744 if let Some(messages) = update.messages {
3745 self.messages = messages;
3746 mask |= 1;
3747 }
3748 if let Some(count) = update.count {
3749 self.count = count;
3750 mask |= 1 << 1;
3751 }
3752 if let Some(label) = update.label {
3753 self.label = label;
3754 mask |= 1 << 2;
3755 }
3756 crate::FieldsChanged(mask)
3757 }
3758
3759 fn reset_ephemeral(&mut self) {}
3760 }
3761
3762 #[derive(Clone, Debug, Default, serde::Serialize)]
3763 struct MultiFieldStateUpdate {
3764 messages: Option<Vec<String>>,
3765 count: Option<i32>,
3766 label: Option<String>,
3767 }
3768
3769 fn multi_field_node(name: &str) -> Arc<dyn crate::Node<MultiFieldState>> {
3770 NodeFnUpdate(|_s: &MultiFieldState| async move {
3771 Ok(MultiFieldStateUpdate {
3772 messages: Some(vec!["hello".to_string()]),
3773 count: Some(1),
3774 label: Some("updated".to_string()),
3775 })
3776 })
3777 .into_node(name)
3778 }
3779
3780 fn build_multi_field_graph() -> CompiledGraph<MultiFieldState> {
3781 let mut nodes: IndexMap<String, Arc<dyn crate::Node<MultiFieldState>>> = IndexMap::new();
3782 nodes.insert("node_a".to_string(), multi_field_node("node_a"));
3783
3784 let mut trigger_table = TriggerTable::new();
3785 trigger_table.add_incoming(
3786 "node_a".to_string(),
3787 crate::edge::TriggerSource::Edge {
3788 from: crate::edge::START.to_string(),
3789 },
3790 );
3791
3792 CompiledGraph::new(
3793 nodes,
3794 trigger_table,
3795 IndexMap::new(),
3796 vec![],
3797 vec![],
3798 None,
3799 vec![],
3800 )
3801 }
3802
3803 #[tokio::test]
3804 async fn test_stream_with_config_no_output_keys_emits_values() {
3805 use futures::StreamExt;
3806
3807 let compiled = build_multi_field_graph();
3808 let config = RunnableConfig::new();
3809
3810 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values);
3811
3812 let handle = compiled
3813 .stream_with_config(
3814 MultiFieldState {
3815 messages: vec![],
3816 count: 0,
3817 label: String::new(),
3818 },
3819 &config,
3820 stream_config,
3821 )
3822 .await
3823 .expect("stream_with_config should succeed");
3824
3825 let mut events = Vec::new();
3826 let mut stream = handle.stream;
3827 while let Some(result) = stream.next().await {
3828 events.push(result.expect("stream event should be Ok"));
3829 }
3830
3831 let has_values = events
3833 .iter()
3834 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. }));
3835 assert!(has_values, "Expected Values events without output_keys");
3836 }
3837
3838 #[tokio::test]
3839 async fn test_stream_with_config_output_keys_emits_filtered_values() {
3840 use futures::StreamExt;
3841
3842 let compiled = build_multi_field_graph();
3843 let config = RunnableConfig::new();
3844
3845 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values)
3846 .with_output_keys(vec!["messages".to_string()]);
3847
3848 let handle = compiled
3849 .stream_with_config(
3850 MultiFieldState {
3851 messages: vec![],
3852 count: 0,
3853 label: String::new(),
3854 },
3855 &config,
3856 stream_config,
3857 )
3858 .await
3859 .expect("stream_with_config should succeed");
3860
3861 let mut events = Vec::new();
3862 let mut stream = handle.stream;
3863 while let Some(result) = stream.next().await {
3864 events.push(result.expect("stream event should be Ok"));
3865 }
3866
3867 let filtered: Vec<_> = events
3869 .iter()
3870 .filter_map(|e| {
3871 if let crate::stream::StreamEvent::FilteredValues { data, .. } = e {
3872 Some(data.clone())
3873 } else {
3874 None
3875 }
3876 })
3877 .collect();
3878
3879 assert!(
3880 !filtered.is_empty(),
3881 "Expected FilteredValues events with output_keys set"
3882 );
3883
3884 for data in &filtered {
3885 assert!(
3887 data.get("messages").is_some(),
3888 "FilteredValues should contain 'messages' key"
3889 );
3890 assert!(
3891 data.get("count").is_none(),
3892 "FilteredValues should not contain 'count' key"
3893 );
3894 assert!(
3895 data.get("label").is_none(),
3896 "FilteredValues should not contain 'label' key"
3897 );
3898 }
3899 }
3900
3901 #[tokio::test]
3902 async fn test_stream_delegates_to_stream_with_config() {
3903 use futures::StreamExt;
3904
3905 let compiled = build_multi_field_graph();
3906 let config = RunnableConfig::new();
3907
3908 let handle = compiled
3911 .stream(
3912 MultiFieldState {
3913 messages: vec![],
3914 count: 0,
3915 label: String::new(),
3916 },
3917 &config,
3918 StreamMode::Values,
3919 )
3920 .await
3921 .expect("stream should succeed");
3922
3923 let mut events = Vec::new();
3924 let mut stream = handle.stream;
3925 while let Some(result) = stream.next().await {
3926 events.push(result.expect("stream event should be Ok"));
3927 }
3928
3929 let has_values = events
3930 .iter()
3931 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. }));
3932 let has_end = events
3933 .iter()
3934 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
3935
3936 assert!(has_values, "stream() should emit Values events");
3937 assert!(has_end, "stream() should emit End event");
3938 }
3939
3940 #[tokio::test]
3941 async fn test_stream_with_config_output_keys_multiple_keys() {
3942 use futures::StreamExt;
3943
3944 let compiled = build_multi_field_graph();
3945 let config = RunnableConfig::new();
3946
3947 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values)
3948 .with_output_keys(vec!["messages".to_string(), "count".to_string()]);
3949
3950 let handle = compiled
3951 .stream_with_config(
3952 MultiFieldState {
3953 messages: vec![],
3954 count: 0,
3955 label: String::new(),
3956 },
3957 &config,
3958 stream_config,
3959 )
3960 .await
3961 .expect("stream_with_config should succeed");
3962
3963 let mut events = Vec::new();
3964 let mut stream = handle.stream;
3965 while let Some(result) = stream.next().await {
3966 events.push(result.expect("stream event should be Ok"));
3967 }
3968
3969 let filtered: Vec<_> = events
3970 .iter()
3971 .filter_map(|e| {
3972 if let crate::stream::StreamEvent::FilteredValues { data, .. } = e {
3973 Some(data.clone())
3974 } else {
3975 None
3976 }
3977 })
3978 .collect();
3979
3980 assert!(!filtered.is_empty());
3981
3982 for data in &filtered {
3983 assert!(
3984 data.get("messages").is_some(),
3985 "Should contain 'messages' key"
3986 );
3987 assert!(data.get("count").is_some(), "Should contain 'count' key");
3988 assert!(
3989 data.get("label").is_none(),
3990 "Should not contain 'label' key"
3991 );
3992 }
3993 }
3994
3995 #[tokio::test]
3996 async fn test_stream_with_config_updates_mode_output_keys() {
3997 use futures::StreamExt;
3998
3999 let compiled = build_multi_field_graph();
4000 let config = RunnableConfig::new();
4001
4002 let stream_config = crate::stream::StreamConfig::new(StreamMode::Updates)
4003 .with_output_keys(vec!["messages".to_string()]);
4004
4005 let handle = compiled
4006 .stream_with_config(
4007 MultiFieldState {
4008 messages: vec![],
4009 count: 0,
4010 label: String::new(),
4011 },
4012 &config,
4013 stream_config,
4014 )
4015 .await
4016 .expect("stream_with_config should succeed");
4017
4018 let mut events = Vec::new();
4019 let mut stream = handle.stream;
4020 while let Some(result) = stream.next().await {
4021 events.push(result.expect("stream event should be Ok"));
4022 }
4023
4024 let filtered_updates: Vec<_> = events
4026 .iter()
4027 .filter_map(|e| {
4028 if let crate::stream::StreamEvent::FilteredUpdates { data, .. } = e {
4029 Some(data.clone())
4030 } else {
4031 None
4032 }
4033 })
4034 .collect();
4035
4036 assert!(
4037 !filtered_updates.is_empty(),
4038 "Expected FilteredUpdates events in Updates mode with output_keys"
4039 );
4040
4041 for data in &filtered_updates {
4042 assert!(
4043 data.get("messages").is_some(),
4044 "FilteredUpdates should contain 'messages' key"
4045 );
4046 assert!(
4048 data.get("count").is_none(),
4049 "FilteredUpdates should not contain 'count' key"
4050 );
4051 assert!(
4052 data.get("label").is_none(),
4053 "FilteredUpdates should not contain 'label' key"
4054 );
4055 }
4056 }
4057
/// Filtering on a single key keeps that key and drops every other top-level key.
#[test]
fn test_filter_json_by_keys() {
    let source = serde_json::json!({
        "messages": ["hello"],
        "count": 42,
        "label": "test"
    });
    let keep = ["messages".to_string()];

    let filtered = crate::stream::filter_json_by_keys(source, &keep);

    // Presence flags in the order messages, count, label.
    let present: Vec<bool> = ["messages", "count", "label"]
        .iter()
        .map(|key| filtered.get(*key).is_some())
        .collect();
    assert_eq!(present, [true, false, false]);
}
4071
/// Filtering on several keys keeps each requested key with its value intact.
#[test]
fn test_filter_json_by_keys_multiple() {
    let source = serde_json::json!({"a": 1, "b": 2, "c": 3});
    let keep = ["a".to_string(), "c".to_string()];

    let filtered = crate::stream::filter_json_by_keys(source, &keep);

    assert_eq!(filtered.get("a").unwrap(), 1);
    assert!(filtered.get("b").is_none());
    assert_eq!(filtered.get("c").unwrap(), 3);
}
4086
/// An empty key list leaves the JSON value unchanged.
#[test]
fn test_filter_json_by_keys_empty_keys() {
    let original = serde_json::json!({"a": 1});
    let unchanged = crate::stream::filter_json_by_keys(original.clone(), &[]);
    assert_eq!(unchanged, original);
}
4093
/// A non-object JSON value is passed through untouched, whatever keys are given.
#[test]
fn test_filter_json_by_keys_non_object() {
    let scalar = serde_json::json!("hello");
    let keep = ["a".to_string()];
    let unchanged = crate::stream::filter_json_by_keys(scalar.clone(), &keep);
    assert_eq!(unchanged, scalar);
}
4100
/// A `Custom` event reports its `ns` path, outermost segment first.
#[test]
fn test_stream_event_namespace_custom_has_ns() {
    let path = vec!["child_graph".to_string(), "sub_node:uuid".to_string()];
    let event: StreamEvent<StateDummy> = StreamEvent::Custom {
        ns: path,
        node: "sub_node".to_string(),
        data: serde_json::json!({"x": 1}),
    };

    let ns = event.namespace();
    assert_eq!(ns.len(), 2);
    assert_eq!(ns[0], "child_graph");
}
4113
/// A `Messages` event takes its namespace from the message metadata's `ns`.
#[test]
fn test_stream_event_namespace_messages_has_ns() {
    let chunk = crate::stream::MessageChunk {
        content: "hi".to_string(),
        tool_call_chunks: vec![],
        usage_delta: None,
    };
    let metadata = crate::stream::MessageStreamMetadata {
        node: "llm".to_string(),
        model: "gpt-4".to_string(),
        tags: vec![],
        ns: vec!["child_graph".to_string()],
    };
    let event: StreamEvent<StateDummy> = StreamEvent::Messages { chunk, metadata };

    let ns = event.namespace();
    assert_eq!(ns.len(), 1);
    assert_eq!(ns[0], "child_graph");
}
4132
/// An `Interrupt` event reports its single-segment namespace.
#[test]
fn test_stream_event_namespace_interrupt_has_ns() {
    let interrupt: StreamEvent<StateDummy> = StreamEvent::Interrupt {
        ns: vec!["subgraph_a".to_string()],
        node: "review".to_string(),
        payload: serde_json::Value::Null,
        resumable: true,
    };

    let depth = interrupt.namespace().len();
    assert_eq!(depth, 1);
}
4143
/// `Values` events carry no namespace.
#[test]
fn test_stream_event_namespace_values_is_empty() {
    let values_event: StreamEvent<StateDummy> = StreamEvent::Values {
        step: 0,
        state: StateDummy,
    };
    let ns = values_event.namespace();
    assert!(ns.is_empty());
}
4152
/// `Updates` events carry no namespace.
#[test]
fn test_stream_event_namespace_updates_is_empty() {
    let updates_event: StreamEvent<StateDummy> = StreamEvent::Updates {
        step: 0,
        node: "n".to_string(),
        update: StateDummyUpdate,
    };
    let ns = updates_event.namespace();
    assert!(ns.is_empty());
}
4162
/// `End` events carry no namespace.
#[test]
fn test_stream_event_namespace_end_is_empty() {
    let end_event: StreamEvent<StateDummy> = StreamEvent::End {
        output: StateDummy,
    };
    let ns = end_event.namespace();
    assert!(ns.is_empty());
}
4168
/// `TaskStart` events carry no namespace.
#[test]
fn test_stream_event_namespace_task_start_is_empty() {
    let task_event: StreamEvent<StateDummy> = StreamEvent::TaskStart {
        step: 0,
        task_id: "t".to_string(),
        node: "n".to_string(),
    };
    let ns = task_event.namespace();
    assert!(ns.is_empty());
}
4178
/// `Debug` events carry no namespace.
#[test]
fn test_stream_event_namespace_debug_is_empty() {
    let superstep = crate::stream::DebugEvent::SuperstepStart {
        step: 0,
        pending_nodes: vec![],
    };
    let debug_event: StreamEvent<StateDummy> = StreamEvent::Debug(superstep);
    let ns = debug_event.namespace();
    assert!(ns.is_empty());
}
4188
/// With `include_subgraphs = false`, events with a non-empty namespace are
/// skipped and top-level events are kept.
///
/// NOTE(review): the skip rule is evaluated inline here rather than through
/// production streaming code.
#[test]
fn test_subgraph_filter_default_excludes_subgraph_events() {
    let include_subgraphs = false;

    let top_level_event: StreamEvent<StateDummy> = StreamEvent::Values {
        state: StateDummy,
        step: 0,
    };
    let subgraph_event: StreamEvent<StateDummy> = StreamEvent::Custom {
        node: "sub_node".to_string(),
        data: serde_json::json!({}),
        ns: vec!["child_graph".to_string()],
    };

    // Skip an event iff it belongs to a subgraph and subgraphs are excluded.
    let is_skipped = |event: &StreamEvent<StateDummy>| -> bool {
        !(event.namespace().is_empty() || include_subgraphs)
    };

    assert!(top_level_event.namespace().is_empty());
    assert!(!subgraph_event.namespace().is_empty());

    assert!(
        is_skipped(&subgraph_event),
        "subgraph events should be skipped when include_subgraphs=false"
    );
    assert!(
        !is_skipped(&top_level_event),
        "top-level events should not be skipped"
    );
}
4229
/// With `include_subgraphs = true` and no name filter, subgraph events pass.
#[test]
fn test_subgraph_filter_include_all_passes() {
    let include_subgraphs = true;
    let subgraph_filter: Option<Vec<String>> = None;

    let subgraph_event: StreamEvent<StateDummy> = StreamEvent::Custom {
        node: "sub_node".to_string(),
        data: serde_json::json!({}),
        ns: vec!["child_graph".to_string()],
    };

    // Only the `include_subgraphs` gate applies when no name filter is set.
    let skipped = !(subgraph_event.namespace().is_empty() || include_subgraphs);
    assert!(
        !skipped,
        "include_subgraphs=true should not skip subgraph events"
    );

    assert!(subgraph_filter.is_none());
}
4253
4254 #[test]
4256 fn test_subgraph_filter_by_name_passes_matching() {
4257 let matching_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4258 node: "sub_node".to_string(),
4259 data: serde_json::json!({}),
4260 ns: vec!["child_a".to_string()],
4261 };
4262
4263 let non_matching_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4264 node: "sub_node".to_string(),
4265 data: serde_json::json!({}),
4266 ns: vec!["child_b".to_string()],
4267 };
4268
4269 let include_subgraphs = true;
4270 let subgraph_filter = Some(vec!["child_a".to_string()]);
4271
4272 let ns = matching_event.namespace();
4274 let should_skip = if ns.is_empty() {
4275 false
4276 } else if !include_subgraphs {
4277 true
4278 } else if let Some(ref filter) = subgraph_filter {
4279 ns.first().is_some_and(|first| !filter.contains(first))
4280 } else {
4281 false
4282 };
4283 assert!(!should_skip, "matching subgraph event should pass filter");
4284
4285 let ns = non_matching_event.namespace();
4287 let should_skip = if ns.is_empty() {
4288 false
4289 } else if !include_subgraphs {
4290 true
4291 } else if let Some(ref filter) = subgraph_filter {
4292 ns.first().is_some_and(|first| !filter.contains(first))
4293 } else {
4294 false
4295 };
4296 assert!(
4297 should_skip,
4298 "non-matching subgraph event should be filtered out"
4299 );
4300 }
4301
4302 #[test]
4305 fn test_subgraph_filter_applies_to_messages_events() {
4306 let subgraph_messages: StreamEvent<StateDummy> = StreamEvent::Messages {
4307 chunk: crate::stream::MessageChunk {
4308 content: "token".to_string(),
4309 tool_call_chunks: vec![],
4310 usage_delta: None,
4311 },
4312 metadata: crate::stream::MessageStreamMetadata {
4313 node: "llm".to_string(),
4314 model: "gpt-4".to_string(),
4315 tags: vec![],
4316 ns: vec!["sub_llm".to_string()],
4317 },
4318 };
4319
4320 let include_subgraphs = false;
4321 assert!(!subgraph_messages.namespace().is_empty());
4322
4323 let ns = subgraph_messages.namespace();
4324 let should_skip = !ns.is_empty() && !include_subgraphs;
4325 assert!(
4326 should_skip,
4327 "subgraph Messages events should be filtered when include_subgraphs=false"
4328 );
4329 }
4330
4331 #[test]
4334 fn test_subgraph_filter_applies_to_interrupt_events() {
4335 let subgraph_interrupt: StreamEvent<StateDummy> = StreamEvent::Interrupt {
4336 node: "review".to_string(),
4337 payload: serde_json::Value::Null,
4338 resumable: true,
4339 ns: vec!["sub_review".to_string()],
4340 };
4341
4342 let include_subgraphs = false;
4343 assert!(!subgraph_interrupt.namespace().is_empty());
4344
4345 let ns = subgraph_interrupt.namespace();
4346 let should_skip = !ns.is_empty() && !include_subgraphs;
4347 assert!(
4348 should_skip,
4349 "subgraph Interrupt events should be filtered when include_subgraphs=false"
4350 );
4351 }
4352
4353 #[test]
4358 fn test_nested_subgraph_default_excludes_nested_events() {
4359 let nested_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4361 node: "deep_node".to_string(),
4362 data: serde_json::json!({}),
4363 ns: vec!["parent".to_string(), "child".to_string()],
4364 };
4365
4366 let include_subgraphs = false;
4367
4368 assert_eq!(nested_event.namespace(), &["parent", "child"]);
4370 assert!(!nested_event.namespace().is_empty());
4371
4372 let should_skip = !nested_event.namespace().is_empty() && !include_subgraphs;
4373 assert!(
4374 should_skip,
4375 "nested subgraph events should be skipped when include_subgraphs=false"
4376 );
4377 }
4378
4379 #[test]
4382 fn test_nested_subgraph_include_all_passes() {
4383 let emitter_ns = vec!["parent".to_string(), "child".to_string()];
4386
4387 let nested_custom_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4389 node: "inner".to_string(),
4390 data: serde_json::json!({"k": "v"}),
4391 ns: emitter_ns,
4392 };
4393
4394 let include_subgraphs = true;
4395 let subgraph_filter: Option<Vec<String>> = None;
4396
4397 let should_skip = !nested_custom_event.namespace().is_empty() && !include_subgraphs;
4398 assert!(
4399 !should_skip,
4400 "nested subgraph events should pass when include_subgraphs=true"
4401 );
4402 assert!(subgraph_filter.is_none());
4403 }
4404
4405 #[test]
4408 fn test_nested_subgraph_filter_matches_outermost_name() {
4409 let nested_event: StreamEvent<StateDummy> = StreamEvent::Custom {
4411 node: "deep".to_string(),
4412 data: serde_json::json!({}),
4413 ns: vec![
4414 "parent".to_string(),
4415 "child".to_string(),
4416 "grandchild".to_string(),
4417 ],
4418 };
4419
4420 let include_subgraphs = true;
4421 let subgraph_filter = Some(vec!["parent".to_string()]);
4423
4424 let ns = nested_event.namespace();
4425 let should_skip = if ns.is_empty() {
4426 false
4427 } else if !include_subgraphs {
4428 true
4429 } else if let Some(ref filter) = subgraph_filter {
4430 ns.first().is_some_and(|first| !filter.contains(first))
4431 } else {
4432 false
4433 };
4434
4435 assert!(
4436 !should_skip,
4437 "nested event from parent should pass when parent is in filter"
4438 );
4439
4440 let subgraph_filter_other = Some(vec!["other".to_string()]);
4442 let should_skip_other = if ns.is_empty() {
4443 false
4444 } else if !include_subgraphs {
4445 true
4446 } else if let Some(ref filter) = subgraph_filter_other {
4447 ns.first().is_some_and(|first| !filter.contains(first))
4448 } else {
4449 false
4450 };
4451
4452 assert!(
4453 should_skip_other,
4454 "nested event should be skipped when outermost name does not match filter"
4455 );
4456 }
4457
4458 #[test]
4461 fn test_nested_subgraph_messages_filtering() {
4462 let nested_messages: StreamEvent<StateDummy> = StreamEvent::Messages {
4463 chunk: crate::stream::MessageChunk {
4464 content: "nested_token".to_string(),
4465 tool_call_chunks: vec![],
4466 usage_delta: None,
4467 },
4468 metadata: crate::stream::MessageStreamMetadata {
4469 node: "llm".to_string(),
4470 model: "gpt-4".to_string(),
4471 tags: vec![],
4472 ns: vec!["outer".to_string(), "inner".to_string()],
4473 },
4474 };
4475
4476 let include_subgraphs = false;
4477
4478 assert_eq!(
4480 nested_messages.namespace(),
4481 &["outer", "inner"],
4482 "Messages events should expose full nested namespace via metadata.ns"
4483 );
4484
4485 let should_skip = !nested_messages.namespace().is_empty() && !include_subgraphs;
4486 assert!(
4487 should_skip,
4488 "nested subgraph Messages events should be filtered when include_subgraphs=false"
4489 );
4490
4491 let include_subgraphs_true = true;
4493 let should_pass = nested_messages.namespace().is_empty() || include_subgraphs_true;
4494 assert!(
4495 should_pass,
4496 "nested subgraph Messages events should pass when include_subgraphs=true"
4497 );
4498 }
4499
4500 #[test]
4503 fn test_subgraph_transformer_to_emitter_nested_ns() {
4504 let transformer = crate::SubgraphTransformer::new("child".to_string());
4505 let transformer = transformer.child_transformer("grandchild");
4506
4507 let (tx, _rx) = tokio::sync::mpsc::channel(16);
4508 let emitter = transformer.to_emitter::<StateDummy>(tx, crate::stream::StreamMode::Values);
4509
4510 assert_eq!(emitter.ns(), &["child", "grandchild"]);
4512 }
4513
4514 #[test]
4517 fn test_transformer_child_chain_three_levels() {
4518 use crate::stream::StreamEvent;
4519
4520 let grandparent = crate::SubgraphTransformer::new("grandparent".to_string());
4521 let parent = grandparent.child_transformer("parent");
4522 let child = parent.child_transformer("child");
4523
4524 let event = StreamEvent::<StateDummy>::TaskStart {
4526 node: "worker".to_string(),
4527 task_id: "t1".to_string(),
4528 step: 1,
4529 };
4530
4531 let result = child.transform(&event).expect("should pass filter");
4532 match result {
4533 StreamEvent::TaskStart { node, .. } => {
4534 assert_eq!(node, "grandparent/parent/child/worker");
4535 }
4536 other => panic!("expected TaskStart, got {other:?}"),
4537 }
4538
4539 let custom_event = StreamEvent::<StateDummy>::Custom {
4541 node: "agent".to_string(),
4542 data: serde_json::json!({}),
4543 ns: vec![],
4544 };
4545 let result = child.transform(&custom_event).expect("custom should pass");
4546 match result {
4547 StreamEvent::Custom { node, ns, .. } => {
4548 assert_eq!(node, "grandparent/parent/child/agent");
4549 assert_eq!(ns, vec!["grandparent", "parent", "child"]);
4550 }
4551 other => panic!("expected Custom, got {other:?}"),
4552 }
4553 }
4554
4555 #[test]
4558 fn test_stream_config_subgraph_builder_methods() {
4559 let cfg = crate::stream::StreamConfig::new(StreamMode::Values);
4560 assert!(!cfg.include_subgraphs);
4561 assert!(cfg.subgraph_filter.is_none());
4562
4563 let cfg = cfg.with_subgraphs(true);
4564 assert!(cfg.include_subgraphs);
4565
4566 let cfg = cfg.with_subgraph_filter(vec!["sub_a".to_string()]);
4567 assert_eq!(cfg.subgraph_filter.as_ref().map(Vec::len), Some(1));
4568 assert_eq!(
4569 cfg.subgraph_filter
4570 .as_ref()
4571 .and_then(|f| f.first().cloned()),
4572 Some("sub_a".to_string())
4573 );
4574 }
4575
4576 #[tokio::test]
4579 async fn test_stream_default_config_no_subgraph_events() {
4580 use futures::StreamExt;
4581
4582 let compiled = build_multi_field_graph();
4583 let config = RunnableConfig::new();
4584 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values);
4585
4586 let handle = compiled
4587 .stream_with_config(
4588 MultiFieldState {
4589 messages: vec![],
4590 count: 0,
4591 label: String::new(),
4592 },
4593 &config,
4594 stream_config,
4595 )
4596 .await
4597 .expect("stream_with_config should succeed");
4598
4599 let mut events = Vec::new();
4600 let mut stream = handle.stream;
4601 while let Some(result) = stream.next().await {
4602 events.push(result.expect("stream event should be Ok"));
4603 }
4604
4605 for event in &events {
4607 assert!(
4608 event.namespace().is_empty(),
4609 "Expected no subgraph events, but found one with ns: {:?}",
4610 event.namespace()
4611 );
4612 }
4613 }
4614
4615 fn build_two_step_graph() -> CompiledGraph<MultiFieldState> {
4621 let node_a = NodeFnUpdate(|s: &MultiFieldState| {
4622 let messages = s.messages.clone();
4623 let count = s.count;
4624 let label = s.label.clone();
4625 async move {
4626 Ok(MultiFieldStateUpdate {
4627 messages: Some(messages),
4628 count: Some(count + 1),
4629 label: Some(label),
4630 })
4631 }
4632 })
4633 .into_node("node_a");
4634
4635 let node_b = NodeFnUpdate(|s: &MultiFieldState| {
4636 let messages = s.messages.clone();
4637 let count = s.count;
4638 let label = s.label.clone();
4639 async move {
4640 Ok(MultiFieldStateUpdate {
4641 messages: Some(messages),
4642 count: Some(count + 10),
4643 label: Some(label),
4644 })
4645 }
4646 })
4647 .into_node("node_b");
4648
4649 let mut nodes: IndexMap<String, Arc<dyn crate::Node<MultiFieldState>>> = IndexMap::new();
4650 nodes.insert("node_a".to_string(), node_a);
4651 nodes.insert("node_b".to_string(), node_b);
4652
4653 let mut trigger_table = TriggerTable::new();
4654 trigger_table.add_incoming(
4656 "node_a".to_string(),
4657 crate::edge::TriggerSource::Edge {
4658 from: crate::edge::START.to_string(),
4659 },
4660 );
4661 trigger_table.add_outgoing(
4663 "node_a".to_string(),
4664 crate::edge::CompiledEdge::Fixed {
4665 target: "node_b".to_string(),
4666 },
4667 );
4668
4669 CompiledGraph::new(
4670 nodes,
4671 trigger_table,
4672 IndexMap::new(),
4673 vec![],
4674 vec![],
4675 None,
4676 vec![],
4677 )
4678 }
4679
4680 #[tokio::test]
4681 async fn test_resumption_skips_values_at_or_before_last_step() {
4682 use futures::StreamExt;
4683
4684 let compiled = build_two_step_graph();
4685 let config = RunnableConfig::new();
4686
4687 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, Some(0));
4688 let stream_config =
4689 crate::stream::StreamConfig::new(StreamMode::Values).with_resumption(resumption);
4690
4691 let handle = compiled
4692 .stream_with_config(
4693 MultiFieldState {
4694 messages: vec![],
4695 count: 0,
4696 label: String::new(),
4697 },
4698 &config,
4699 stream_config,
4700 )
4701 .await
4702 .expect("stream_with_config should succeed");
4703
4704 let mut events = Vec::new();
4705 let mut stream = handle.stream;
4706 while let Some(result) = stream.next().await {
4707 events.push(result.expect("stream event should be Ok"));
4708 }
4709
4710 let values_steps: Vec<usize> = events
4712 .iter()
4713 .filter_map(|e| match e {
4714 crate::stream::StreamEvent::Values { step, .. } => Some(*step),
4715 _ => None,
4716 })
4717 .collect();
4718
4719 assert!(
4720 !values_steps.contains(&0),
4721 "Values at step 0 should be skipped, got steps: {values_steps:?}"
4722 );
4723
4724 let has_end = events
4725 .iter()
4726 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4727 assert!(has_end, "End event must always be emitted");
4728 }
4729
4730 #[tokio::test]
4731 async fn test_resumption_allows_values_after_last_step() {
4732 use futures::StreamExt;
4733
4734 let compiled = build_two_step_graph();
4735 let config = RunnableConfig::new();
4736
4737 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, Some(5));
4738 let stream_config =
4739 crate::stream::StreamConfig::new(StreamMode::Values).with_resumption(resumption);
4740
4741 let handle = compiled
4742 .stream_with_config(
4743 MultiFieldState {
4744 messages: vec![],
4745 count: 0,
4746 label: String::new(),
4747 },
4748 &config,
4749 stream_config,
4750 )
4751 .await
4752 .expect("stream_with_config should succeed");
4753
4754 let mut events = Vec::new();
4755 let mut stream = handle.stream;
4756 while let Some(result) = stream.next().await {
4757 events.push(result.expect("stream event should be Ok"));
4758 }
4759
4760 let values_steps: Vec<usize> = events
4762 .iter()
4763 .filter_map(|e| match e {
4764 crate::stream::StreamEvent::Values { step, .. } => Some(*step),
4765 _ => None,
4766 })
4767 .collect();
4768
4769 assert!(
4770 values_steps.is_empty(),
4771 "All Values should be skipped with last_step=5, got steps: {values_steps:?}"
4772 );
4773
4774 let has_end = events
4775 .iter()
4776 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4777 assert!(
4778 has_end,
4779 "End event must always be emitted even when all steps are skipped"
4780 );
4781 }
4782
4783 #[tokio::test]
4784 async fn test_resumption_none_last_step_allows_all_events() {
4785 use futures::StreamExt;
4786
4787 let compiled = build_two_step_graph();
4788 let config = RunnableConfig::new();
4789
4790 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, None);
4791 let stream_config =
4792 crate::stream::StreamConfig::new(StreamMode::Values).with_resumption(resumption);
4793
4794 let handle = compiled
4795 .stream_with_config(
4796 MultiFieldState {
4797 messages: vec![],
4798 count: 0,
4799 label: String::new(),
4800 },
4801 &config,
4802 stream_config,
4803 )
4804 .await
4805 .expect("stream_with_config should succeed");
4806
4807 let mut events = Vec::new();
4808 let mut stream = handle.stream;
4809 while let Some(result) = stream.next().await {
4810 events.push(result.expect("stream event should be Ok"));
4811 }
4812
4813 assert!(
4815 events
4816 .iter()
4817 .any(|e| matches!(e, crate::stream::StreamEvent::Values { .. })),
4818 "Values events should be emitted when last_step is None"
4819 );
4820
4821 let has_end = events
4822 .iter()
4823 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4824 assert!(has_end, "End event must be present");
4825 }
4826
4827 #[tokio::test]
4828 async fn test_resumption_skips_updates_at_or_before_last_step() {
4829 use futures::StreamExt;
4830
4831 let compiled = build_two_step_graph();
4832 let config = RunnableConfig::new();
4833
4834 let resumption = crate::stream::StreamResumption::new("run1".to_string(), None, Some(0));
4835 let stream_config =
4836 crate::stream::StreamConfig::new(StreamMode::Updates).with_resumption(resumption);
4837
4838 let handle = compiled
4839 .stream_with_config(
4840 MultiFieldState {
4841 messages: vec![],
4842 count: 0,
4843 label: String::new(),
4844 },
4845 &config,
4846 stream_config,
4847 )
4848 .await
4849 .expect("stream_with_config should succeed");
4850
4851 let mut events = Vec::new();
4852 let mut stream = handle.stream;
4853 while let Some(result) = stream.next().await {
4854 events.push(result.expect("stream event should be Ok"));
4855 }
4856
4857 let updates_steps: Vec<usize> = events
4859 .iter()
4860 .filter_map(|e| match e {
4861 crate::stream::StreamEvent::Updates { step, .. }
4862 | crate::stream::StreamEvent::FilteredUpdates { step, .. } => Some(*step),
4863 _ => None,
4864 })
4865 .collect();
4866
4867 assert!(
4868 !updates_steps.contains(&0),
4869 "Updates at step 0 should be skipped, got steps: {updates_steps:?}"
4870 );
4871
4872 let has_end = events
4873 .iter()
4874 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4875 assert!(has_end, "End event must always be emitted");
4876 }
4877
4878 #[tokio::test]
4879 async fn test_resumption_no_resumption_emits_all_events() {
4880 use futures::StreamExt;
4881
4882 let compiled = build_two_step_graph();
4883 let config = RunnableConfig::new();
4884
4885 let stream_config = crate::stream::StreamConfig::new(StreamMode::Values);
4887
4888 let handle = compiled
4889 .stream_with_config(
4890 MultiFieldState {
4891 messages: vec![],
4892 count: 0,
4893 label: String::new(),
4894 },
4895 &config,
4896 stream_config,
4897 )
4898 .await
4899 .expect("stream_with_config should succeed");
4900
4901 let mut events = Vec::new();
4902 let mut stream = handle.stream;
4903 while let Some(result) = stream.next().await {
4904 events.push(result.expect("stream event should be Ok"));
4905 }
4906
4907 let values_count = events
4908 .iter()
4909 .filter(|e| matches!(e, crate::stream::StreamEvent::Values { .. }))
4910 .count();
4911
4912 assert!(
4913 values_count >= 1,
4914 "At least one Values event expected without resumption"
4915 );
4916
4917 let has_end = events
4918 .iter()
4919 .any(|e| matches!(e, crate::stream::StreamEvent::End { .. }));
4920 assert!(has_end, "End event must be present");
4921 }
4922}
4923
4924