1use crate::{
7 State, checkpoint::CheckpointNamespace, command::Command, config::RunnableConfig,
8 error::JunctureError, node::Node,
9};
10use std::sync::Arc;
11
12fn compute_child_namespace(
18 persistence: SubgraphPersistence,
19 name: &str,
20 parent_ns: Option<&CheckpointNamespace>,
21 thread_id: Option<&str>,
22) -> Option<CheckpointNamespace> {
23 match persistence {
24 SubgraphPersistence::Stateless => None,
25 SubgraphPersistence::PerThread => {
26 let thread_key = thread_id.unwrap_or("default");
27 let base = parent_ns.cloned().unwrap_or_default();
28 Some(base.child(name, thread_key))
29 }
30 SubgraphPersistence::Inherit => {
31 let invocation_id = uuid::Uuid::new_v4().to_string();
32 let base = parent_ns.cloned().unwrap_or_default();
33 Some(base.child(name, &invocation_id))
34 }
35 }
36}
37
38pub trait StateSubset<Parent: State>: State {
74 fn extract(parent: &Parent) -> Self;
87
88 fn map_update(update: Self::Update) -> Parent::Update;
102}
103
104#[derive(Clone, Debug, Default)]
108pub struct SubgraphConfig {
109 pub persistence: SubgraphPersistence,
111}
112
/// Controls which checkpoint namespace (if any) a subgraph invocation uses.
// `Hash` is derived in addition to the existing traits so the mode can be used
// as a `HashMap`/`HashSet` key.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SubgraphPersistence {
    /// Default mode: each invocation gets a child namespace `name:<fresh v4 UUID>`
    /// under the parent namespace, so concurrent/repeated invocations never share
    /// a namespace.
    #[default]
    Inherit,

    /// Each invocation uses the child namespace `name:<thread_id>` (or
    /// `name:default` when there is no thread id), which is stable across
    /// invocations on the same thread.
    PerThread,

    /// No child namespace is computed, and `SubgraphNode` clears any resume value
    /// before running the subgraph.
    Stateless,
}
128
129pub struct SubgraphMount<S: State> {
134 pub name: String,
136
137 pub config: SubgraphConfig,
139
140 pub node: Arc<dyn Node<S>>,
142}
143
144impl<S: State> std::fmt::Debug for SubgraphMount<S> {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("SubgraphMount")
147 .field("name", &self.name)
148 .field("config", &self.config)
149 .field("node", &"<node>")
150 .finish()
151 }
152}
153
154impl<S: State> SubgraphMount<S> {
155 #[must_use]
157 pub fn new(name: impl Into<String>, config: SubgraphConfig, node: Arc<dyn Node<S>>) -> Self {
158 Self {
159 name: name.into(),
160 config,
161 node,
162 }
163 }
164
165 #[must_use]
169 pub fn with_name(mut self, name: impl Into<String>) -> Self {
170 self.name = name.into();
171 self
172 }
173
174 #[must_use]
178 pub const fn with_config(mut self, config: SubgraphConfig) -> Self {
179 self.config = config;
180 self
181 }
182
183 #[must_use]
189 pub const fn with_persistence(mut self, persistence: SubgraphPersistence) -> Self {
190 self.config.persistence = persistence;
191 self
192 }
193}
194
195pub struct SubgraphNode<S: State, Sub: State> {
209 pub subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
211
212 pub name: String,
214
215 #[allow(
217 clippy::type_complexity,
218 reason = "requires type erasure for trait object"
219 )]
220 pub input_map: Arc<dyn Fn(&S) -> Sub + Send + Sync>,
221
222 #[allow(
224 clippy::type_complexity,
225 reason = "requires type erasure for trait object"
226 )]
227 pub output_map: Arc<dyn Fn(&Sub) -> S::Update + Send + Sync>,
228
229 pub config: SubgraphConfig,
231}
232
233impl<S: State, Sub: State> std::fmt::Debug for SubgraphNode<S, Sub> {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.debug_struct("SubgraphNode")
236 .field("subgraph", &"<graph>")
237 .field("name", &self.name)
238 .field("input_map", &"<fn>")
239 .field("output_map", &"<fn>")
240 .field("config", &self.config)
241 .finish()
242 }
243}
244
245impl<S: State, Sub: State> SubgraphNode<S, Sub> {
246 #[must_use]
248 #[allow(
249 clippy::type_complexity,
250 reason = "requires type erasure for trait object"
251 )]
252 pub fn new(
253 subgraph: Arc<crate::graph::CompiledGraph<Sub>>,
254 name: String,
255 #[allow(
256 clippy::type_complexity,
257 reason = "requires type erasure for trait object"
258 )]
259 input_map: Arc<dyn Fn(&S) -> Sub + Send + Sync>,
260 #[allow(
261 clippy::type_complexity,
262 reason = "requires type erasure for trait object"
263 )]
264 output_map: Arc<dyn Fn(&Sub) -> S::Update + Send + Sync>,
265 config: SubgraphConfig,
266 ) -> Self {
267 Self {
268 subgraph,
269 name,
270 input_map,
271 output_map,
272 config,
273 }
274 }
275}
276
277impl<S: State, Sub> Node<S> for SubgraphNode<S, Sub>
278where
279 Sub: State + serde::Serialize + for<'de> serde::Deserialize<'de>,
280 Sub::Update: serde::Serialize,
281{
282 fn call(
283 &self,
284 state: &S,
285 config: &RunnableConfig,
286 ) -> std::pin::Pin<
287 Box<dyn std::future::Future<Output = Result<Command<S>, JunctureError>> + Send + '_>,
288 > {
289 let config = config.clone();
290 let subgraph = Arc::clone(&self.subgraph);
291 let input_map = Arc::clone(&self.input_map);
292 let output_map = Arc::clone(&self.output_map);
293 let name = self.name.clone();
294 let persistence = self.config.persistence;
295 let state = state.clone();
296
297 Box::pin(async move {
298 let child_ns = compute_child_namespace(
301 persistence,
302 &name,
303 config.checkpoint_ns.as_ref(),
304 config.thread_id.as_deref(),
305 );
306
307 let mut child_config = config.clone();
309 child_config.checkpoint_ns = child_ns;
310
311 if matches!(persistence, SubgraphPersistence::Stateless) {
313 child_config.resume_value = None;
314 }
315 let should_resume = if let Some(checkpointer) = subgraph.checkpointer() {
324 checkpointer
325 .get_tuple(&child_config)
326 .await
327 .ok()
328 .flatten()
329 .is_some_and(|tuple| {
330 matches!(
331 tuple.metadata.source,
332 crate::checkpoint::CheckpointSource::Interrupt { .. }
333 )
334 })
335 } else {
336 false
337 };
338
339 let sub_output = if should_resume {
340 let resume_val = child_config.resume_value.clone().unwrap_or(
343 crate::interrupt::ResumeValue::Single(serde_json::Value::Null),
344 );
345 subgraph.resume(&child_config, resume_val).await
346 } else {
347 let sub_input = (input_map)(&state);
349 subgraph.invoke_async(sub_input, &child_config).await
350 };
351
352 let sub_output = match sub_output {
354 Ok(output) => output,
355 Err(e) if e.is_parent_command() => {
356 let target = e.parent_command_target().unwrap_or("END");
359 return Ok(Command::goto(target));
360 }
361 Err(e) if e.is_interrupt() => {
362 return Err(e);
365 }
366 Err(e) => {
367 return Err(JunctureError::subgraph(format!("{name}: {e}")));
369 }
370 };
371
372 let update = (output_map)(&sub_output.value);
374
375 Ok(Command::update(update))
376 })
377 }
378
379 fn name(&self) -> &str {
380 &self.name
381 }
382}
383
384#[cfg(test)]
385mod tests {
386 use super::*;
387 use crate::{node::IntoNode, node::NodeFnUpdate};
388
389 #[test]
390 fn test_subgraph_config_default() {
391 let config = SubgraphConfig::default();
392 assert_eq!(config.persistence, SubgraphPersistence::Inherit);
393 }
394
395 #[test]
396 fn test_subgraph_persistence_variants() {
397 let inherit = SubgraphPersistence::Inherit;
398 let per_thread = SubgraphPersistence::PerThread;
399 let stateless = SubgraphPersistence::Stateless;
400
401 assert_ne!(inherit, per_thread);
402 assert_ne!(inherit, stateless);
403 assert_ne!(per_thread, stateless);
404 }
405
406 #[test]
407 fn test_subgraph_mount_creation() {
408 let node = mock_node("test");
409 let mount = SubgraphMount::new("subgraph_test", SubgraphConfig::default(), node);
410
411 assert_eq!(mount.name, "subgraph_test");
412 assert_eq!(mount.config.persistence, SubgraphPersistence::Inherit);
413 }
414
415 #[test]
416 fn test_with_name_changes_name() {
417 let node = mock_node("test");
418 let mount =
419 SubgraphMount::new("original", SubgraphConfig::default(), node).with_name("renamed");
420
421 assert_eq!(mount.name, "renamed");
422 }
423
424 #[test]
425 fn test_with_config_replaces_config() {
426 let node = mock_node("test");
427 let custom_config = SubgraphConfig {
428 persistence: SubgraphPersistence::Stateless,
429 };
430 let mount =
431 SubgraphMount::new("sg", SubgraphConfig::default(), node).with_config(custom_config);
432
433 assert_eq!(mount.config.persistence, SubgraphPersistence::Stateless);
434 }
435
436 #[test]
437 fn test_with_persistence_sets_mode() {
438 let node = mock_node("test");
439 let mount = SubgraphMount::new("sg", SubgraphConfig::default(), node)
440 .with_persistence(SubgraphPersistence::PerThread);
441
442 assert_eq!(mount.config.persistence, SubgraphPersistence::PerThread);
443 }
444
445 #[test]
446 fn test_builder_chaining() {
447 let node = mock_node("test");
448 let mount = SubgraphMount::new("initial", SubgraphConfig::default(), node)
449 .with_name("chained")
450 .with_persistence(SubgraphPersistence::Stateless);
451
452 assert_eq!(mount.name, "chained");
453 assert_eq!(mount.config.persistence, SubgraphPersistence::Stateless);
454 }
455
456 #[test]
457 fn test_with_name_accepts_non_string_types() {
458 let node = mock_node("test");
459 let mount = SubgraphMount::new("x", SubgraphConfig::default(), node)
460 .with_name(String::from("from_string"));
461
462 assert_eq!(mount.name, "from_string");
463 }
464
465 #[test]
466 fn test_checkpoint_namespace_separator() {
467 let ns = crate::checkpoint::CheckpointNamespace::root();
469 let child = ns.child("node1", "id1");
470 let grandchild = child.child("node2", "id2");
471
472 assert_eq!(child.as_str(), "|node1:id1");
473 assert_eq!(grandchild.as_str(), "|node1:id1|node2:id2");
474
475 let parsed = crate::checkpoint::CheckpointNamespace::parse("|node1:id1|node2:id2");
477 assert_eq!(parsed.as_str(), "|node1:id1|node2:id2");
478
479 assert_eq!(ns.as_str(), "");
481 assert!(ns.is_root());
482 }
483
484 #[test]
487 fn test_stateless_namespace_is_none() {
488 let ns = compute_child_namespace(
489 SubgraphPersistence::Stateless,
490 "my_sub",
491 None,
492 Some("thread-42"),
493 );
494 assert!(ns.is_none(), "Stateless should return None for namespace");
495 }
496
497 #[test]
498 fn test_stateless_namespace_is_none_even_with_parent_ns() {
499 let parent = CheckpointNamespace::parse("|parent:abc");
500 let ns = compute_child_namespace(
501 SubgraphPersistence::Stateless,
502 "my_sub",
503 Some(&parent),
504 Some("thread-42"),
505 );
506 assert!(
507 ns.is_none(),
508 "Stateless should return None even with parent namespace"
509 );
510 }
511
512 #[test]
513 fn test_perthread_namespace_uses_thread_id() {
514 let ns = compute_child_namespace(
515 SubgraphPersistence::PerThread,
516 "my_sub",
517 None,
518 Some("thread-42"),
519 );
520 let ns = ns.expect("PerThread should produce a namespace");
521 assert_eq!(ns.as_str(), "|my_sub:thread-42");
522 }
523
524 #[test]
525 fn test_perthread_namespace_appends_to_parent_ns() {
526 let parent = CheckpointNamespace::parse("|parent:abc");
527 let ns = compute_child_namespace(
528 SubgraphPersistence::PerThread,
529 "my_sub",
530 Some(&parent),
531 Some("thread-42"),
532 );
533 let ns = ns.expect("PerThread should produce a namespace");
534 assert_eq!(ns.as_str(), "|parent:abc|my_sub:thread-42");
535 }
536
537 #[test]
538 fn test_perthread_namespace_falls_back_to_default() {
539 let ns = compute_child_namespace(SubgraphPersistence::PerThread, "my_sub", None, None);
540 let ns = ns.expect("PerThread should produce a namespace");
541 assert_eq!(ns.as_str(), "|my_sub:default");
542 }
543
544 #[test]
545 fn test_perthread_namespace_is_stable() {
546 let a = compute_child_namespace(SubgraphPersistence::PerThread, "sub", None, Some("t1"));
548 let b = compute_child_namespace(SubgraphPersistence::PerThread, "sub", None, Some("t1"));
549 assert_eq!(a, b);
550 }
551
552 #[test]
553 fn test_inherit_namespace_is_uuid_based() {
554 let ns = compute_child_namespace(
555 SubgraphPersistence::Inherit,
556 "my_sub",
557 None,
558 Some("thread-42"),
559 );
560 let ns = ns.expect("Inherit should produce a namespace");
561 let rendered = ns.as_str();
562 assert!(rendered.starts_with("|my_sub:"));
563 let uuid_part = rendered.strip_prefix("|my_sub:").expect("prefix present");
565 assert!(
566 uuid::Uuid::parse_str(uuid_part).is_ok(),
567 "suffix should be a valid UUID, got: {uuid_part}"
568 );
569 }
570
571 #[test]
572 fn test_inherit_namespace_appends_to_parent_ns() {
573 let parent = CheckpointNamespace::parse("|parent:abc");
574 let ns = compute_child_namespace(
575 SubgraphPersistence::Inherit,
576 "my_sub",
577 Some(&parent),
578 Some("thread-42"),
579 );
580 let ns = ns.expect("Inherit should produce a namespace");
581 let rendered = ns.as_str();
582 assert!(rendered.starts_with("|parent:abc|my_sub:"));
583 let uuid_part = rendered
584 .strip_prefix("|parent:abc|my_sub:")
585 .expect("prefix present");
586 assert!(
587 uuid::Uuid::parse_str(uuid_part).is_ok(),
588 "suffix should be a valid UUID, got: {uuid_part}"
589 );
590 }
591
592 #[test]
593 fn test_inherit_namespace_differs_between_invocations() {
594 let a = compute_child_namespace(SubgraphPersistence::Inherit, "sub", None, Some("t1"));
597 let b = compute_child_namespace(SubgraphPersistence::Inherit, "sub", None, Some("t1"));
598 assert_ne!(a, b, "Inherit mode should produce unique namespaces");
599 }
600
601 #[test]
602 fn send_fan_out_produces_unique_namespaces() {
603 let count = 10;
607 let namespaces: Vec<Option<CheckpointNamespace>> = (0..count)
608 .map(|_| {
609 compute_child_namespace(SubgraphPersistence::Inherit, "worker", None, Some("t1"))
610 })
611 .collect();
612
613 assert!(
615 namespaces.iter().all(Option::is_some),
616 "all Inherit invocations should produce a namespace"
617 );
618
619 let unique: std::collections::HashSet<String> = namespaces
621 .iter()
622 .map(|ns| {
623 ns.as_ref()
624 .map_or_else(String::new, CheckpointNamespace::as_str)
625 })
626 .collect();
627 assert_eq!(
628 unique.len(),
629 count,
630 "Send fan-out to subgraph must produce {count} distinct namespaces"
631 );
632
633 for ns in &namespaces {
635 let rendered = ns
636 .as_ref()
637 .map_or_else(String::new, CheckpointNamespace::as_str);
638 assert!(
639 rendered.starts_with("|worker:"),
640 "namespace should start with '|worker:', got: {rendered}"
641 );
642 let uuid_part = rendered.strip_prefix("|worker:").unwrap_or("");
643 assert!(
644 uuid::Uuid::parse_str(uuid_part).is_ok(),
645 "suffix must be a valid UUID, got: {uuid_part}"
646 );
647 }
648 }
649
650 fn make_transformer(name: &str) -> SubgraphTransformer {
653 SubgraphTransformer::new(name.to_string())
654 }
655
656 fn make_nested_transformer(name: &str, parent_ns: &[&str]) -> SubgraphTransformer {
657 let mut t = SubgraphTransformer::new(name.to_string());
658 for segment in parent_ns {
659 t.add_namespace((*segment).to_string());
660 }
661 t
662 }
663
664 #[test]
665 fn transform_updates_prefixes_node_name() {
666 let t = make_transformer("review");
667 let event = crate::stream::StreamEvent::<StateDummy>::Updates {
668 node: "agent".to_string(),
669 update: StateDummyUpdate,
670 step: 1,
671 };
672 let result = t.transform(&event).expect("should pass filter");
673 match result {
674 crate::stream::StreamEvent::Updates { node, .. } => {
675 assert_eq!(node, "review/agent");
676 }
677 other => panic!("expected Updates, got {other:?}"),
678 }
679 }
680
681 #[test]
682 fn transform_filtered_updates_prefixes_node_name() {
683 let t = make_transformer("review");
684 let event = crate::stream::StreamEvent::<StateDummy>::FilteredUpdates {
685 node: "agent".to_string(),
686 data: serde_json::json!({"key": "val"}),
687 step: 2,
688 };
689 let result = t.transform(&event).expect("should pass filter");
690 match result {
691 crate::stream::StreamEvent::FilteredUpdates { node, .. } => {
692 assert_eq!(node, "review/agent");
693 }
694 other => panic!("expected FilteredUpdates, got {other:?}"),
695 }
696 }
697
698 #[test]
699 fn transform_task_start_prefixes_node_name() {
700 let t = make_transformer("sub");
701 let event = crate::stream::StreamEvent::<StateDummy>::TaskStart {
702 node: "worker".to_string(),
703 task_id: "t1".to_string(),
704 step: 3,
705 };
706 let result = t.transform(&event).expect("should pass filter");
707 match result {
708 crate::stream::StreamEvent::TaskStart {
709 node,
710 task_id,
711 step,
712 } => {
713 assert_eq!(node, "sub/worker");
714 assert_eq!(task_id, "t1");
715 assert_eq!(step, 3);
716 }
717 other => panic!("expected TaskStart, got {other:?}"),
718 }
719 }
720
721 #[test]
722 fn transform_task_end_prefixes_node_name() {
723 let t = make_transformer("sub");
724 let event = crate::stream::StreamEvent::<StateDummy>::TaskEnd {
725 node: "worker".to_string(),
726 task_id: "t1".to_string(),
727 step: 3,
728 duration_ms: 150,
729 };
730 let result = t.transform(&event).expect("should pass filter");
731 match result {
732 crate::stream::StreamEvent::TaskEnd {
733 node, duration_ms, ..
734 } => {
735 assert_eq!(node, "sub/worker");
736 assert_eq!(duration_ms, 150);
737 }
738 other => panic!("expected TaskEnd, got {other:?}"),
739 }
740 }
741
742 #[test]
743 fn transform_task_detail_prefixes_node_name() {
744 let t = make_transformer("sub");
745 let event = crate::stream::StreamEvent::<StateDummy>::TaskDetail {
746 task_id: "t2".to_string(),
747 node: "inner".to_string(),
748 step: 4,
749 attempt: 1,
750 event: crate::stream::TaskEventType::Started,
751 };
752 let result = t.transform(&event).expect("should pass filter");
753 match result {
754 crate::stream::StreamEvent::TaskDetail { task_id, node, .. } => {
755 assert_eq!(task_id, "t2");
756 assert_eq!(node, "sub/inner");
757 }
758 other => panic!("expected TaskDetail, got {other:?}"),
759 }
760 }
761
762 #[test]
763 fn transform_custom_prefixes_node_and_ns() {
764 let t = make_transformer("review");
765 let event = crate::stream::StreamEvent::<StateDummy>::Custom {
766 node: "agent".to_string(),
767 data: serde_json::json!({"action": "thinking"}),
768 ns: vec!["old_ns".to_string()],
769 };
770 let result = t.transform(&event).expect("should pass filter");
771 match result {
772 crate::stream::StreamEvent::Custom { node, ns, .. } => {
773 assert_eq!(node, "review/agent");
774 assert_eq!(ns, vec!["review"]);
776 }
777 other => panic!("expected Custom, got {other:?}"),
778 }
779 }
780
781 #[test]
782 fn transform_interrupt_prefixes_node_and_ns() {
783 let t = make_transformer("review");
784 let event = crate::stream::StreamEvent::<StateDummy>::Interrupt {
785 node: "agent".to_string(),
786 payload: serde_json::json!({"question": "approve?"}),
787 resumable: true,
788 ns: vec![],
789 };
790 let result = t.transform(&event).expect("should pass filter");
791 match result {
792 crate::stream::StreamEvent::Interrupt {
793 node,
794 ns,
795 resumable,
796 ..
797 } => {
798 assert_eq!(node, "review/agent");
799 assert_eq!(ns, vec!["review"]);
800 assert!(resumable);
801 }
802 other => panic!("expected Interrupt, got {other:?}"),
803 }
804 }
805
806 #[test]
807 fn transform_messages_prefixes_node_in_metadata() {
808 let t = make_transformer("sub");
809 let event = crate::stream::StreamEvent::<StateDummy>::Messages {
810 chunk: crate::stream::MessageChunk {
811 content: "hello".to_string(),
812 tool_call_chunks: vec![],
813 usage_delta: None,
814 },
815 metadata: crate::stream::MessageStreamMetadata {
816 node: "llm".to_string(),
817 model: "gpt-4".to_string(),
818 tags: vec![],
819 ns: vec![],
820 },
821 };
822 let result = t.transform(&event).expect("should pass filter");
823 match result {
824 crate::stream::StreamEvent::Messages { metadata, .. } => {
825 assert_eq!(metadata.node, "sub/llm");
826 assert_eq!(metadata.ns, vec!["sub"]);
827 assert_eq!(metadata.model, "gpt-4");
828 }
829 other => panic!("expected Messages, got {other:?}"),
830 }
831 }
832
833 #[test]
836 fn transform_values_passes_through() {
837 let t = make_transformer("sub");
838 let event = crate::stream::StreamEvent::<StateDummy>::Values {
839 state: StateDummy,
840 step: 5,
841 };
842 let result = t.transform(&event).expect("should pass filter");
843 match result {
844 crate::stream::StreamEvent::Values { step, .. } => assert_eq!(step, 5),
845 other => panic!("expected Values, got {other:?}"),
846 }
847 }
848
849 #[test]
850 fn transform_end_passes_through() {
851 let t = make_transformer("sub");
852 let event = crate::stream::StreamEvent::<StateDummy>::End { output: StateDummy };
853 let result = t.transform(&event).expect("should pass filter");
854 match result {
855 crate::stream::StreamEvent::End { .. } => {}
856 other => panic!("expected End, got {other:?}"),
857 }
858 }
859
860 #[test]
861 fn transform_budget_exceeded_passes_through() {
862 let t = make_transformer("sub");
863 let event = crate::stream::StreamEvent::<StateDummy>::BudgetExceeded {
864 reason: crate::pregel::BudgetExceededReason::Steps {
865 used: 25,
866 limit: 25,
867 },
868 usage: crate::stream::BudgetUsage {
869 tokens_used: 1000,
870 cost_usd: 0.05,
871 duration_ms: 200,
872 steps_completed: 25,
873 },
874 };
875 let result = t.transform(&event).expect("should pass filter");
876 match result {
877 crate::stream::StreamEvent::BudgetExceeded { .. } => {}
878 other => panic!("expected BudgetExceeded, got {other:?}"),
879 }
880 }
881
882 #[test]
883 fn transform_checkpoint_saved_passes_through() {
884 let t = make_transformer("sub");
885 let event = crate::stream::StreamEvent::<StateDummy>::CheckpointSaved {
886 checkpoint_id: "cp-1".to_string(),
887 metadata: crate::checkpoint::CheckpointMetadata {
888 source: crate::checkpoint::CheckpointSource::Loop,
889 step: 1,
890 writes: std::collections::HashMap::new(),
891 parents: std::collections::HashMap::new(),
892 run_id: "run-1".to_string(),
893 },
894 step: 1,
895 };
896 let result = t.transform(&event).expect("should pass filter");
897 match result {
898 crate::stream::StreamEvent::CheckpointSaved { checkpoint_id, .. } => {
899 assert_eq!(checkpoint_id, "cp-1");
900 }
901 other => panic!("expected CheckpointSaved, got {other:?}"),
902 }
903 }
904
905 #[test]
908 fn transform_nested_namespace_prefixes_correctly() {
909 let t = make_nested_transformer("child", &["parent", "middle"]);
910 let event = crate::stream::StreamEvent::<StateDummy>::Updates {
911 node: "agent".to_string(),
912 update: StateDummyUpdate,
913 step: 1,
914 };
915 let result = t.transform(&event).expect("should pass filter");
916 match result {
917 crate::stream::StreamEvent::Updates { node, .. } => {
918 assert_eq!(node, "parent/middle/child/agent");
919 }
920 other => panic!("expected Updates, got {other:?}"),
921 }
922 }
923
924 #[test]
925 fn transform_nested_custom_sets_full_ns() {
926 let t = make_nested_transformer("child", &["parent"]);
927 let event = crate::stream::StreamEvent::<StateDummy>::Custom {
928 node: "agent".to_string(),
929 data: serde_json::json!({}),
930 ns: vec![],
931 };
932 let result = t.transform(&event).expect("should pass filter");
933 match result {
934 crate::stream::StreamEvent::Custom { node, ns, .. } => {
935 assert_eq!(node, "parent/child/agent");
936 assert_eq!(ns, vec!["parent", "child"]);
937 }
938 other => panic!("expected Custom, got {other:?}"),
939 }
940 }
941
942 #[test]
943 fn transform_nested_interrupt_sets_full_ns() {
944 let t = make_nested_transformer("grandchild", &["parent", "child"]);
945 let event = crate::stream::StreamEvent::<StateDummy>::Interrupt {
946 node: "agent".to_string(),
947 payload: serde_json::Value::Null,
948 resumable: false,
949 ns: vec!["old".to_string()],
950 };
951 let result = t.transform(&event).expect("should pass filter");
952 match result {
953 crate::stream::StreamEvent::Interrupt { node, ns, .. } => {
954 assert_eq!(node, "parent/child/grandchild/agent");
955 assert_eq!(ns, vec!["parent", "child", "grandchild"]);
956 }
957 other => panic!("expected Interrupt, got {other:?}"),
958 }
959 }
960
961 #[test]
964 fn transform_filter_rejects_non_matching_type() {
965 let t = SubgraphTransformer::new("sub".to_string())
966 .with_filter_types(vec!["updates".to_string()]);
967
968 let event = crate::stream::StreamEvent::<StateDummy>::TaskStart {
969 node: "worker".to_string(),
970 task_id: "t1".to_string(),
971 step: 1,
972 };
973 assert!(
974 t.transform(&event).is_none(),
975 "task_start should be filtered"
976 );
977 }
978
979 #[test]
980 fn transform_filter_allows_matching_type() {
981 let t = SubgraphTransformer::new("sub".to_string())
982 .with_filter_types(vec!["updates".to_string()]);
983
984 let event = crate::stream::StreamEvent::<StateDummy>::Updates {
985 node: "agent".to_string(),
986 update: StateDummyUpdate,
987 step: 1,
988 };
989 let result = t.transform(&event).expect("updates should pass filter");
990 match result {
991 crate::stream::StreamEvent::Updates { node, .. } => {
992 assert_eq!(node, "sub/agent");
993 }
994 other => panic!("expected Updates, got {other:?}"),
995 }
996 }
997
998 #[test]
999 fn transform_filter_empty_types_allows_all() {
1000 let t = SubgraphTransformer::new("sub".to_string()).with_filter_types(vec![]);
1001 let event = crate::stream::StreamEvent::<StateDummy>::End { output: StateDummy };
1002 assert!(
1003 t.transform(&event).is_some(),
1004 "empty filter should allow all"
1005 );
1006 }
1007
1008 #[test]
1011 fn to_emitter_creates_emitter_with_correct_ns() {
1012 let (tx, _rx) = tokio::sync::mpsc::channel(1);
1013 let t = SubgraphTransformer::new("review".to_string());
1014 let emitter = t.to_emitter::<StateDummy>(tx, crate::stream::StreamMode::Updates);
1015 assert_eq!(emitter.ns(), &["review"]);
1016 }
1017
1018 #[test]
1019 fn to_emitter_with_parent_ns() {
1020 let (tx, _rx) = tokio::sync::mpsc::channel(1);
1021 let t = make_nested_transformer("child", &["parent"]);
1022 let emitter = t.to_emitter::<StateDummy>(tx, crate::stream::StreamMode::Values);
1023 assert_eq!(emitter.ns(), &["parent", "child"]);
1024 }
1025
1026 #[test]
1027 fn to_emitter_with_deep_nesting() {
1028 let (tx, _rx) = tokio::sync::mpsc::channel(1);
1029 let t = make_nested_transformer("grandchild", &["root", "parent"]);
1030 let emitter = t.to_emitter::<StateDummy>(tx, crate::stream::StreamMode::Custom);
1031 assert_eq!(emitter.ns(), &["root", "parent", "grandchild"]);
1032 }
1033
1034 #[test]
1035 fn child_transformer_produces_correct_build_ns() {
1036 let parent = SubgraphTransformer::new("parent".to_string());
1037 let child = parent.child_transformer("child");
1038 let event = crate::stream::StreamEvent::<StateDummy>::Updates {
1039 node: "worker".to_string(),
1040 update: StateDummyUpdate,
1041 step: 1,
1042 };
1043 let result = child.transform(&event).expect("should pass filter");
1044 match result {
1045 crate::stream::StreamEvent::Updates { node, .. } => {
1046 assert_eq!(node, "parent/child/worker");
1047 }
1048 other => panic!("expected Updates, got {other:?}"),
1049 }
1050 }
1051
1052 #[test]
1053 fn child_transformer_three_level_deep() {
1054 let root = SubgraphTransformer::new("root".to_string());
1055 let middle = root.child_transformer("middle");
1056 let leaf = middle.child_transformer("leaf");
1057
1058 let event = crate::stream::StreamEvent::<StateDummy>::Custom {
1059 node: "agent".to_string(),
1060 data: serde_json::json!({"key": "val"}),
1061 ns: vec![],
1062 };
1063 let result = leaf.transform(&event).expect("should pass filter");
1064 match result {
1065 crate::stream::StreamEvent::Custom { node, ns, .. } => {
1066 assert_eq!(node, "root/middle/leaf/agent");
1067 assert_eq!(ns, vec!["root", "middle", "leaf"]);
1068 }
1069 other => panic!("expected Custom, got {other:?}"),
1070 }
1071 }
1072
1073 #[test]
1074 fn child_transformer_preserves_filter() {
1075 let root = SubgraphTransformer::new("root".to_string())
1076 .with_filter_types(vec!["custom".to_string()]);
1077 let child = root.child_transformer("child");
1078
1079 let updates_event = crate::stream::StreamEvent::<StateDummy>::Updates {
1081 node: "agent".to_string(),
1082 update: StateDummyUpdate,
1083 step: 1,
1084 };
1085 assert!(child.transform(&updates_event).is_none());
1086
1087 let custom_event = crate::stream::StreamEvent::<StateDummy>::Custom {
1089 node: "agent".to_string(),
1090 data: serde_json::json!({}),
1091 ns: vec![],
1092 };
1093 let result = child.transform(&custom_event).expect("custom should pass");
1094 match result {
1095 crate::stream::StreamEvent::Custom { node, ns, .. } => {
1096 assert_eq!(node, "root/child/agent");
1097 assert_eq!(ns, vec!["root", "child"]);
1098 }
1099 other => panic!("expected Custom, got {other:?}"),
1100 }
1101 }
1102
1103 #[test]
1106 fn nested_namespace_three_levels_deep() {
1107 let ns = crate::checkpoint::CheckpointNamespace::root();
1108 let level1 = ns.child("review", "uuid-1");
1109 let level2 = level1.child("detail", "uuid-2");
1110 let level3 = level2.child("sub", "uuid-3");
1111
1112 assert_eq!(level1.as_str(), "|review:uuid-1");
1113 assert_eq!(level2.as_str(), "|review:uuid-1|detail:uuid-2");
1114 assert_eq!(level3.as_str(), "|review:uuid-1|detail:uuid-2|sub:uuid-3");
1115 assert!(ns.is_root());
1116 assert!(!level1.is_root());
1117 assert!(!level3.is_root());
1118 }
1119
1120 #[test]
1121 fn nested_namespace_parse_roundtrip_three_levels() {
1122 let original = "|alpha:aaa|beta:bbb|gamma:ccc";
1123 let parsed = crate::checkpoint::CheckpointNamespace::parse(original);
1124 assert_eq!(parsed.as_str(), original);
1125
1126 assert_eq!(parsed.segments.len(), 3);
1128 assert_eq!(parsed.segments[0].node_name, "alpha");
1129 assert_eq!(parsed.segments[0].invocation_id, "aaa");
1130 assert_eq!(parsed.segments[1].node_name, "beta");
1131 assert_eq!(parsed.segments[1].invocation_id, "bbb");
1132 assert_eq!(parsed.segments[2].node_name, "gamma");
1133 assert_eq!(parsed.segments[2].invocation_id, "ccc");
1134
1135 let level4 = parsed.child("delta", "ddd");
1137 assert_eq!(level4.as_str(), "|alpha:aaa|beta:bbb|gamma:ccc|delta:ddd");
1138 }
1139
1140 #[test]
1141 fn nested_compute_child_namespace_chains_correctly() {
1142 let parent = CheckpointNamespace::parse("|review:uuid-1|detail:uuid-2");
1144
1145 let child_inherit = compute_child_namespace(
1147 SubgraphPersistence::Inherit,
1148 "sub",
1149 Some(&parent),
1150 Some("thread-1"),
1151 );
1152 let child_inherit = child_inherit.expect("Inherit should produce a namespace");
1153 let rendered = child_inherit.as_str();
1154 assert!(rendered.starts_with("|review:uuid-1|detail:uuid-2|sub:"));
1155 let uuid_part = rendered
1156 .strip_prefix("|review:uuid-1|detail:uuid-2|sub:")
1157 .expect("prefix present");
1158 assert!(
1159 uuid::Uuid::parse_str(uuid_part).is_ok(),
1160 "suffix should be a valid UUID, got: {uuid_part}"
1161 );
1162
1163 let child_perthread = compute_child_namespace(
1165 SubgraphPersistence::PerThread,
1166 "sub",
1167 Some(&parent),
1168 Some("thread-42"),
1169 );
1170 let child_perthread = child_perthread.expect("PerThread should produce a namespace");
1171 assert_eq!(
1172 child_perthread.as_str(),
1173 "|review:uuid-1|detail:uuid-2|sub:thread-42"
1174 );
1175
1176 let child_stateless = compute_child_namespace(
1178 SubgraphPersistence::Stateless,
1179 "sub",
1180 Some(&parent),
1181 Some("thread-1"),
1182 );
1183 assert!(
1184 child_stateless.is_none(),
1185 "Stateless should return None for namespace"
1186 );
1187 }
1188
1189 #[test]
1190 fn nested_namespace_different_uuids_at_each_level() {
1191 let ns = crate::checkpoint::CheckpointNamespace::root();
1192 let level1 = ns.child("review", "11111111-1111-1111-1111-111111111111");
1193 let level2 = level1.child("detail", "22222222-2222-2222-2222-222222222222");
1194 let level3 = level2.child("sub", "33333333-3333-3333-3333-333333333333");
1195
1196 let rendered = level3.as_str();
1197 assert_eq!(
1198 rendered,
1199 "|review:11111111-1111-1111-1111-111111111111\
1200 |detail:22222222-2222-2222-2222-222222222222\
1201 |sub:33333333-3333-3333-3333-333333333333"
1202 );
1203
1204 assert_ne!(level1.as_str(), level2.as_str());
1206 assert_ne!(level2.as_str(), level3.as_str());
1207 assert_ne!(level1.as_str(), level3.as_str());
1208
1209 assert_eq!(level1.segments.len(), 1);
1211 assert_eq!(level2.segments.len(), 2);
1212 assert_eq!(level3.segments.len(), 3);
1213
1214 assert_eq!(
1216 level3
1217 .parent()
1218 .as_ref()
1219 .map(crate::checkpoint::CheckpointNamespace::as_str),
1220 Some(level2.as_str())
1221 );
1222 assert_eq!(
1223 level2
1224 .parent()
1225 .as_ref()
1226 .map(crate::checkpoint::CheckpointNamespace::as_str),
1227 Some(level1.as_str())
1228 );
1229 assert_eq!(
1230 level1
1231 .parent()
1232 .as_ref()
1233 .map(crate::checkpoint::CheckpointNamespace::as_str),
1234 Some(String::new())
1235 );
1236 assert_eq!(ns.parent(), None);
1237 }
1238
    /// Builds a test node registered under `name`.
    ///
    /// The node's update function ignores the incoming `StateDummy` and always resolves to
    /// `Ok(StateDummyUpdate)`, so it is useful wherever a `Node<StateDummy>` is required
    /// but its behaviour is irrelevant.
    fn mock_node(name: &str) -> Arc<dyn crate::Node<StateDummy>> {
        NodeFnUpdate(|_s: &StateDummy| async move { Ok(StateDummyUpdate) }).into_node(name)
    }
1242
1243 #[derive(Clone, Debug, Default)]
1244 struct StateDummy;
1245
1246 impl crate::State for StateDummy {
1247 type Update = StateDummyUpdate;
1248 type FieldVersions = crate::state::FieldVersions;
1249
1250 fn apply(&mut self, _update: Self::Update) -> crate::FieldsChanged {
1251 crate::FieldsChanged(0)
1252 }
1253
1254 fn reset_ephemeral(&mut self) {}
1255 }
1256
1257 #[derive(Clone, Debug, Default)]
1258 struct StateDummyUpdate;
1259}
1260
1261#[derive(Clone)]
1266pub struct SubgraphTransformer {
1267 pub subgraph_name: String,
1269
1270 pub ns: Vec<String>,
1272
1273 #[allow(
1278 clippy::type_complexity,
1279 reason = "trait object requires full signature for filter closure"
1280 )]
1281 pub filter: Option<std::sync::Arc<dyn Fn(&serde_json::Value) -> bool + Send + Sync>>,
1282
1283 pub include_internal: bool,
1285}
1286
1287impl std::fmt::Debug for SubgraphTransformer {
1288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1289 f.debug_struct("SubgraphTransformer")
1290 .field("subgraph_name", &self.subgraph_name)
1291 .field("ns", &self.ns)
1292 .field("filter", &self.filter.as_ref().map(|_| "<fn>"))
1293 .field("include_internal", &self.include_internal)
1294 .finish()
1295 }
1296}
1297
1298impl SubgraphTransformer {
1299 #[must_use]
1305 pub const fn new(subgraph_name: String) -> Self {
1306 Self {
1307 subgraph_name,
1308 ns: Vec::new(),
1309 filter: None,
1310 include_internal: false,
1311 }
1312 }
1313
1314 #[must_use]
1321 pub fn with_filter(
1322 mut self,
1323 filter: impl Fn(&serde_json::Value) -> bool + Send + Sync + 'static,
1324 ) -> Self {
1325 self.filter = Some(std::sync::Arc::new(filter));
1326 self
1327 }
1328
1329 #[must_use]
1335 pub fn with_filter_types(mut self, types: Vec<String>) -> Self {
1336 if types.is_empty() {
1337 self.filter = None;
1338 } else {
1339 let filter = move |value: &serde_json::Value| {
1340 value
1341 .get("type")
1342 .and_then(|v| v.as_str())
1343 .is_some_and(|event_type| types.iter().any(|t| t == event_type))
1344 };
1345 self.filter = Some(std::sync::Arc::new(filter));
1346 }
1347 self
1348 }
1349
1350 #[must_use]
1356 pub const fn with_internal(mut self, include: bool) -> Self {
1357 self.include_internal = include;
1358 self
1359 }
1360
1361 #[must_use]
1381 pub fn transform<S: State>(
1382 &self,
1383 event: &crate::stream::StreamEvent<S>,
1384 ) -> Option<crate::stream::StreamEvent<S>> {
1385 if !self.passes_filter(event) {
1386 return None;
1387 }
1388 Some(self.apply_namespace(event))
1389 }
1390
1391 fn passes_filter<S: State>(&self, event: &crate::stream::StreamEvent<S>) -> bool {
1393 use crate::stream::StreamEvent;
1394
1395 let Some(ref filter) = self.filter else {
1396 return true;
1397 };
1398 let event_type = match event {
1399 StreamEvent::Values { .. } | StreamEvent::FilteredValues { .. } => "values",
1400 StreamEvent::Updates { .. } | StreamEvent::FilteredUpdates { .. } => "updates",
1401 StreamEvent::Messages { .. } => "messages",
1402 StreamEvent::Custom { .. } => "custom",
1403 StreamEvent::TaskStart { .. } => "task_start",
1404 StreamEvent::TaskEnd { .. } => "task_end",
1405 StreamEvent::Interrupt { .. } => "interrupt",
1406 StreamEvent::BudgetExceeded { .. } => "budget_exceeded",
1407 StreamEvent::End { .. } => "end",
1408 StreamEvent::Debug(_) => "debug",
1409 StreamEvent::Tools(_) => "tools",
1410 StreamEvent::CheckpointSaved { .. } => "checkpoint_saved",
1411 StreamEvent::TaskDetail { .. } => "task_detail",
1412 StreamEvent::Cancelled { .. } => "cancelled",
1413 };
1414 let filter_value = serde_json::json!({ "type": event_type });
1415 filter(&filter_value)
1416 }
1417
1418 fn build_ns(&self) -> (String, Vec<String>) {
1421 let ns_prefix = if self.ns.is_empty() {
1422 self.subgraph_name.clone()
1423 } else {
1424 format!("{}/{}", self.ns.join("/"), self.subgraph_name)
1425 };
1426 let full_ns = {
1427 let mut ns = self.ns.clone();
1428 ns.push(self.subgraph_name.clone());
1429 ns
1430 };
1431 (ns_prefix, full_ns)
1432 }
1433
1434 fn apply_namespace<S: State>(
1436 &self,
1437 event: &crate::stream::StreamEvent<S>,
1438 ) -> crate::stream::StreamEvent<S> {
1439 use crate::stream::StreamEvent;
1440
1441 let (ns_prefix, full_ns) = self.build_ns();
1442 let namespaced = |node: &str| -> String { format!("{ns_prefix}/{node}") };
1443
1444 match event.clone() {
1445 StreamEvent::Updates { node, update, step } => StreamEvent::Updates {
1446 node: namespaced(&node),
1447 update,
1448 step,
1449 },
1450 StreamEvent::FilteredUpdates { node, data, step } => StreamEvent::FilteredUpdates {
1451 node: namespaced(&node),
1452 data,
1453 step,
1454 },
1455 StreamEvent::TaskStart {
1456 node,
1457 task_id,
1458 step,
1459 } => StreamEvent::TaskStart {
1460 node: namespaced(&node),
1461 task_id,
1462 step,
1463 },
1464 StreamEvent::TaskEnd {
1465 node,
1466 task_id,
1467 step,
1468 duration_ms,
1469 } => StreamEvent::TaskEnd {
1470 node: namespaced(&node),
1471 task_id,
1472 step,
1473 duration_ms,
1474 },
1475 StreamEvent::TaskDetail {
1476 task_id,
1477 node,
1478 step,
1479 attempt,
1480 event: task_event,
1481 } => StreamEvent::TaskDetail {
1482 task_id,
1483 node: namespaced(&node),
1484 step,
1485 attempt,
1486 event: task_event,
1487 },
1488
1489 StreamEvent::Custom { node, data, .. } => StreamEvent::Custom {
1491 node: namespaced(&node),
1492 data,
1493 ns: full_ns,
1494 },
1495 StreamEvent::Interrupt {
1496 node,
1497 payload,
1498 resumable,
1499 ..
1500 } => StreamEvent::Interrupt {
1501 node: namespaced(&node),
1502 payload,
1503 resumable,
1504 ns: full_ns,
1505 },
1506
1507 StreamEvent::Messages {
1509 chunk,
1510 mut metadata,
1511 } => {
1512 metadata.node = namespaced(&metadata.node);
1513 metadata.ns = full_ns;
1514 StreamEvent::Messages { chunk, metadata }
1515 }
1516
1517 other => other,
1519 }
1520 }
1521
1522 pub fn add_namespace(&mut self, segment: String) {
1528 self.ns.push(segment);
1529 }
1530
1531 #[must_use]
1554 pub fn child_transformer(&self, child_name: &str) -> Self {
1555 let mut child = self.clone();
1556 child.ns.push(self.subgraph_name.clone());
1557 child.subgraph_name = child_name.to_string();
1558 child
1559 }
1560
1561 #[must_use]
1599 pub fn to_emitter<S: crate::State>(
1600 &self,
1601 tx: tokio::sync::mpsc::Sender<crate::stream::StreamEvent<S>>,
1602 mode: crate::stream::StreamMode,
1603 ) -> crate::stream::EventEmitter<S> {
1604 let mut emitter = crate::stream::EventEmitter::new(tx, mode);
1605 for segment in &self.ns {
1606 emitter = emitter.with_subgraph_ns(segment.clone());
1607 }
1608 emitter.with_subgraph_ns(self.subgraph_name.clone())
1609 }
1610}
1611
1612