Skip to main content

cortexai_crew/
subgraph.rs

//! Subgraphs / Nested Workflows
//!
//! Enables composing graphs within graphs for modular, reusable workflow components.
//!
//! ## Features
//!
//! - Nest graphs as nodes within parent graphs
//! - State mapping between parent and child graphs
//! - Parallel subgraph execution
//! - Isolated or shared state modes
//! - Subgraph libraries for reusable components
//!
//! ## Example
//!
//! A [`StateMapping`] decides which keys flow into and out of a subgraph. The default
//! mapping is fully isolated: the child starts from a fresh state and nothing is copied
//! back. The example below therefore uses `merge_all()` so each stage shares the whole state.
//!
//! ```rust,ignore
//! use cortexai_crew::graph::GraphBuilder;
//! use cortexai_crew::subgraph::{GraphBuilderSubgraphExt, StateMapping};
//!
//! // Create a reusable research subgraph
//! let research_graph = GraphBuilder::new("research")
//!     .add_node("search", search_node)
//!     .add_node("analyze", analyze_node)
//!     .add_edge("search", "analyze")
//!     .set_entry("search")
//!     .build()?;
//!
//! // Create a reusable writing subgraph
//! let writing_graph = GraphBuilder::new("writing")
//!     .add_node("draft", draft_node)
//!     .add_node("edit", edit_node)
//!     .add_edge("draft", "edit")
//!     .set_entry("draft")
//!     .build()?;
//!
//! // Compose into a parent graph
//! let main_graph = GraphBuilder::new("main")
//!     .add_node("init", init_node)
//!     .add_subgraph("research", research_graph, StateMapping::new().merge_all())
//!     .add_subgraph("writing", writing_graph, StateMapping::new().merge_all())
//!     .add_edge("init", "research")
//!     .add_edge("research", "writing")
//!     .set_entry("init")
//!     .build()?;
//! ```
44
45use crate::graph::{Graph, GraphState, GraphStatus, NodeFn};
46use cortexai_core::errors::CrewError;
47use serde::{Deserialize, Serialize};
48use std::collections::HashMap;
49use std::sync::Arc;
50
/// Describes how state moves between a parent graph and a nested child graph.
///
/// The default value is fully isolated: nothing is copied into the child and nothing is
/// copied back out.
#[derive(Debug, Clone, Default)]
pub struct StateMapping {
    /// Keys copied from the parent into the child, renamed: `parent_key -> child_key`.
    pub input_mapping: HashMap<String, String>,
    /// Keys copied from the child back into the parent, renamed: `child_key -> parent_key`.
    pub output_mapping: HashMap<String, String>,
    /// Keys copied in both directions under the same name.
    pub passthrough: Vec<String>,
    /// Bulk mode. On output, every child key is copied back. On input, every parent key is
    /// copied in, but only when `input_mapping` is empty.
    pub merge_all: bool,
    /// When bulk-merging on output, write each child key into the parent as `"{prefix}.{key}"`.
    pub output_prefix: Option<String>,
}
65
66impl StateMapping {
67    /// Create a new state mapping
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    /// Map a parent key to a child key on input
73    pub fn map_input(
74        mut self,
75        parent_key: impl Into<String>,
76        child_key: impl Into<String>,
77    ) -> Self {
78        self.input_mapping
79            .insert(parent_key.into(), child_key.into());
80        self
81    }
82
83    /// Map a child key back to a parent key on output
84    pub fn map_output(
85        mut self,
86        child_key: impl Into<String>,
87        parent_key: impl Into<String>,
88    ) -> Self {
89        self.output_mapping
90            .insert(child_key.into(), parent_key.into());
91        self
92    }
93
94    /// Add a passthrough key (same name in parent and child)
95    pub fn passthrough(mut self, key: impl Into<String>) -> Self {
96        self.passthrough.push(key.into());
97        self
98    }
99
100    /// Add multiple passthrough keys
101    pub fn passthrough_keys(mut self, keys: Vec<String>) -> Self {
102        self.passthrough.extend(keys);
103        self
104    }
105
106    /// Merge all child state back to parent
107    pub fn merge_all(mut self) -> Self {
108        self.merge_all = true;
109        self
110    }
111
112    /// Set prefix for merged keys
113    pub fn with_prefix(mut self, prefix: impl Into<String>) -> Self {
114        self.output_prefix = Some(prefix.into());
115        self.merge_all = true;
116        self
117    }
118
119    /// Apply input mapping: parent state -> child state
120    pub fn apply_input(&self, parent_state: &GraphState) -> GraphState {
121        let mut child_state = GraphState::new();
122
123        // Apply explicit mappings
124        if let Some(parent_obj) = parent_state.data.as_object() {
125            // Input mappings
126            for (parent_key, child_key) in &self.input_mapping {
127                if let Some(value) = parent_obj.get(parent_key) {
128                    child_state.set(child_key, value.clone());
129                }
130            }
131
132            // Passthrough keys
133            for key in &self.passthrough {
134                if let Some(value) = parent_obj.get(key) {
135                    child_state.set(key, value.clone());
136                }
137            }
138
139            // If merge_all and no specific input mappings, pass all parent state
140            if self.merge_all && self.input_mapping.is_empty() {
141                for (key, value) in parent_obj {
142                    child_state.set(key, value.clone());
143                }
144            }
145        }
146
147        child_state
148    }
149
150    /// Apply output mapping: child state -> updated parent state
151    pub fn apply_output(&self, parent_state: &GraphState, child_state: &GraphState) -> GraphState {
152        let mut result = parent_state.clone();
153
154        if let Some(child_obj) = child_state.data.as_object() {
155            // Output mappings
156            for (child_key, parent_key) in &self.output_mapping {
157                if let Some(value) = child_obj.get(child_key) {
158                    result.set(parent_key, value.clone());
159                }
160            }
161
162            // Passthrough keys
163            for key in &self.passthrough {
164                if let Some(value) = child_obj.get(key) {
165                    result.set(key, value.clone());
166                }
167            }
168
169            // Merge all with optional prefix
170            if self.merge_all {
171                for (key, value) in child_obj {
172                    let target_key = match &self.output_prefix {
173                        Some(prefix) => format!("{}.{}", prefix, key),
174                        None => key.clone(),
175                    };
176                    result.set(&target_key, value.clone());
177                }
178            }
179        }
180
181        result
182    }
183}
184
185/// A node that executes a subgraph
186pub struct SubgraphNode {
187    /// The subgraph to execute
188    graph: Arc<Graph>,
189    /// State mapping configuration
190    mapping: StateMapping,
191    /// Optional name for this subgraph instance
192    name: Option<String>,
193}
194
195impl SubgraphNode {
196    /// Create a new subgraph node
197    pub fn new(graph: Graph) -> Self {
198        Self {
199            graph: Arc::new(graph),
200            mapping: StateMapping::default(),
201            name: None,
202        }
203    }
204
205    /// Create with state mapping
206    pub fn with_mapping(graph: Graph, mapping: StateMapping) -> Self {
207        Self {
208            graph: Arc::new(graph),
209            mapping,
210            name: None,
211        }
212    }
213
214    /// Set a name for this subgraph instance
215    pub fn named(mut self, name: impl Into<String>) -> Self {
216        self.name = Some(name.into());
217        self
218    }
219
220    /// Get the subgraph
221    pub fn graph(&self) -> &Graph {
222        &self.graph
223    }
224}
225
226#[async_trait::async_trait]
227impl NodeFn for SubgraphNode {
228    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError> {
229        // Apply input mapping
230        let child_state = self.mapping.apply_input(&state);
231
232        // Execute subgraph
233        let result = self.graph.invoke(child_state).await?;
234
235        // Check for errors
236        if result.status != GraphStatus::Success {
237            return Err(CrewError::ExecutionFailed(format!(
238                "Subgraph '{}' failed: {}",
239                self.name.as_deref().unwrap_or(&self.graph.name),
240                result.error.unwrap_or_else(|| "Unknown error".to_string())
241            )));
242        }
243
244        // Apply output mapping
245        let final_state = self.mapping.apply_output(&state, &result.state);
246
247        Ok(final_state)
248    }
249}
250
251/// Parallel subgraph executor - runs multiple subgraphs concurrently
252pub struct ParallelSubgraphs {
253    /// Subgraphs to execute in parallel
254    subgraphs: Vec<(String, Arc<Graph>, StateMapping)>,
255    /// How to merge results
256    merge_strategy: MergeStrategy,
257}
258
/// Strategy for merging parallel subgraph results
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Apply each subgraph's output mapping in registration order. When keys conflict, the
    /// later subgraph wins.
    #[default]
    MergeAll,
    /// Copy every top-level key of each subgraph's final state into the parent as
    /// `"{subgraph_name}.{key}"`, ignoring the output mappings.
    Prefixed,
    /// Store each subgraph's whole final state in a JSON object keyed by subgraph name, under
    /// the parent key `"parallel_results"`, ignoring the output mappings.
    Collect,
    /// Reserved for a user-supplied merge function. This is not implemented yet; it currently
    /// behaves exactly like [`MergeStrategy::MergeAll`].
    Custom,
}
272
273impl ParallelSubgraphs {
274    /// Create a new parallel subgraph executor
275    pub fn new() -> Self {
276        Self {
277            subgraphs: Vec::new(),
278            merge_strategy: MergeStrategy::default(),
279        }
280    }
281
282    /// Add a subgraph
283    pub fn add(mut self, name: impl Into<String>, graph: Graph, mapping: StateMapping) -> Self {
284        self.subgraphs.push((name.into(), Arc::new(graph), mapping));
285        self
286    }
287
288    /// Set merge strategy
289    pub fn with_strategy(mut self, strategy: MergeStrategy) -> Self {
290        self.merge_strategy = strategy;
291        self
292    }
293
294    /// Execute all subgraphs in parallel
295    pub async fn execute(&self, state: &GraphState) -> Result<GraphState, CrewError> {
296        use futures::future::join_all;
297
298        // Prepare child states
299        let tasks: Vec<_> = self
300            .subgraphs
301            .iter()
302            .map(|(name, graph, mapping)| {
303                let child_state = mapping.apply_input(state);
304                let graph = Arc::clone(graph);
305                let name = name.clone();
306                let mapping = mapping.clone();
307
308                async move {
309                    let result = graph.invoke(child_state).await?;
310                    if result.status != GraphStatus::Success {
311                        return Err(CrewError::ExecutionFailed(format!(
312                            "Parallel subgraph '{}' failed: {}",
313                            name,
314                            result.error.unwrap_or_else(|| "Unknown error".to_string())
315                        )));
316                    }
317                    Ok((name, result.state, mapping))
318                }
319            })
320            .collect();
321
322        // Execute in parallel
323        let results = join_all(tasks).await;
324
325        // Check for errors
326        let mut successful: Vec<(String, GraphState, StateMapping)> = Vec::new();
327        for result in results {
328            successful.push(result?);
329        }
330
331        // Merge results based on strategy
332        let mut final_state = state.clone();
333        match self.merge_strategy {
334            MergeStrategy::MergeAll => {
335                for (_, child_state, mapping) in successful {
336                    final_state = mapping.apply_output(&final_state, &child_state);
337                }
338            }
339            MergeStrategy::Prefixed => {
340                for (name, child_state, _) in successful {
341                    if let Some(obj) = child_state.data.as_object() {
342                        for (key, value) in obj {
343                            final_state.set(&format!("{}.{}", name, key), value.clone());
344                        }
345                    }
346                }
347            }
348            MergeStrategy::Collect => {
349                let mut collected = serde_json::Map::new();
350                for (name, child_state, _) in successful {
351                    collected.insert(name, child_state.data);
352                }
353                final_state.set("parallel_results", serde_json::Value::Object(collected));
354            }
355            MergeStrategy::Custom => {
356                // Custom strategy would need a closure, which we can't easily support
357                // For now, fall back to MergeAll
358                for (_, child_state, mapping) in successful {
359                    final_state = mapping.apply_output(&final_state, &child_state);
360                }
361            }
362        }
363
364        Ok(final_state)
365    }
366}
367
368impl Default for ParallelSubgraphs {
369    fn default() -> Self {
370        Self::new()
371    }
372}
373
374#[async_trait::async_trait]
375impl NodeFn for ParallelSubgraphs {
376    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError> {
377        self.execute(&state).await
378    }
379}
380
381/// Subgraph library for reusable components
382#[derive(Default)]
383pub struct SubgraphLibrary {
384    graphs: HashMap<String, Arc<Graph>>,
385}
386
387impl SubgraphLibrary {
388    /// Create a new library
389    pub fn new() -> Self {
390        Self::default()
391    }
392
393    /// Register a graph in the library
394    pub fn register(&mut self, name: impl Into<String>, graph: Graph) {
395        self.graphs.insert(name.into(), Arc::new(graph));
396    }
397
398    /// Get a graph from the library
399    pub fn get(&self, name: &str) -> Option<Arc<Graph>> {
400        self.graphs.get(name).cloned()
401    }
402
403    /// Create a subgraph node from a library graph
404    pub fn create_node(&self, name: &str, mapping: StateMapping) -> Option<SubgraphNode> {
405        self.graphs.get(name).map(|graph| SubgraphNode {
406            graph: Arc::clone(graph),
407            mapping,
408            name: Some(name.to_string()),
409        })
410    }
411
412    /// List all registered graphs
413    pub fn list(&self) -> Vec<&str> {
414        self.graphs.keys().map(|s| s.as_str()).collect()
415    }
416
417    /// Check if a graph is registered
418    pub fn contains(&self, name: &str) -> bool {
419        self.graphs.contains_key(name)
420    }
421
422    /// Remove a graph from the library
423    pub fn remove(&mut self, name: &str) -> Option<Arc<Graph>> {
424        self.graphs.remove(name)
425    }
426}
427
428/// Conditional subgraph - executes one of several subgraphs based on state
429pub struct ConditionalSubgraph {
430    /// Branches with their conditions
431    branches: Vec<(Box<dyn Fn(&GraphState) -> bool + Send + Sync>, SubgraphNode)>,
432    /// Default branch if no condition matches
433    default: Option<SubgraphNode>,
434}
435
436impl ConditionalSubgraph {
437    /// Create a new conditional subgraph
438    pub fn new() -> Self {
439        Self {
440            branches: Vec::new(),
441            default: None,
442        }
443    }
444
445    /// Add a conditional branch
446    pub fn when<F>(mut self, condition: F, graph: Graph, mapping: StateMapping) -> Self
447    where
448        F: Fn(&GraphState) -> bool + Send + Sync + 'static,
449    {
450        self.branches.push((
451            Box::new(condition),
452            SubgraphNode::with_mapping(graph, mapping),
453        ));
454        self
455    }
456
457    /// Set the default branch
458    pub fn otherwise(mut self, graph: Graph, mapping: StateMapping) -> Self {
459        self.default = Some(SubgraphNode::with_mapping(graph, mapping));
460        self
461    }
462}
463
464impl Default for ConditionalSubgraph {
465    fn default() -> Self {
466        Self::new()
467    }
468}
469
470#[async_trait::async_trait]
471impl NodeFn for ConditionalSubgraph {
472    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError> {
473        // Find first matching condition
474        for (condition, subgraph) in &self.branches {
475            if condition(&state) {
476                return subgraph.call(state).await;
477            }
478        }
479
480        // Use default if no condition matched
481        if let Some(default) = &self.default {
482            return default.call(state).await;
483        }
484
485        // No match and no default - pass through unchanged
486        Ok(state)
487    }
488}
489
490/// Loop subgraph - executes a subgraph repeatedly until a condition is met
491pub struct LoopSubgraph {
492    /// The subgraph to loop
493    subgraph: SubgraphNode,
494    /// Continue condition (returns true to continue looping)
495    continue_condition: Box<dyn Fn(&GraphState) -> bool + Send + Sync>,
496    /// Maximum iterations
497    max_iterations: u32,
498}
499
500impl LoopSubgraph {
501    /// Create a new loop subgraph
502    pub fn new<F>(graph: Graph, mapping: StateMapping, continue_while: F) -> Self
503    where
504        F: Fn(&GraphState) -> bool + Send + Sync + 'static,
505    {
506        Self {
507            subgraph: SubgraphNode::with_mapping(graph, mapping),
508            continue_condition: Box::new(continue_while),
509            max_iterations: 100,
510        }
511    }
512
513    /// Set maximum iterations
514    pub fn max_iterations(mut self, max: u32) -> Self {
515        self.max_iterations = max;
516        self
517    }
518}
519
520#[async_trait::async_trait]
521impl NodeFn for LoopSubgraph {
522    async fn call(&self, mut state: GraphState) -> Result<GraphState, CrewError> {
523        let mut iterations = 0;
524
525        while (self.continue_condition)(&state) && iterations < self.max_iterations {
526            state = self.subgraph.call(state).await?;
527            iterations += 1;
528        }
529
530        if iterations >= self.max_iterations {
531            return Err(CrewError::ExecutionFailed(format!(
532                "Loop subgraph exceeded max iterations: {}",
533                self.max_iterations
534            )));
535        }
536
537        Ok(state)
538    }
539}
540
541/// Retry subgraph - retries a subgraph on failure
542pub struct RetrySubgraph {
543    /// The subgraph to retry
544    subgraph: SubgraphNode,
545    /// Maximum retry attempts
546    max_retries: u32,
547    /// Delay between retries in milliseconds
548    retry_delay_ms: u64,
549    /// Exponential backoff multiplier
550    backoff_multiplier: f64,
551}
552
553impl RetrySubgraph {
554    /// Create a new retry subgraph
555    pub fn new(graph: Graph, mapping: StateMapping) -> Self {
556        Self {
557            subgraph: SubgraphNode::with_mapping(graph, mapping),
558            max_retries: 3,
559            retry_delay_ms: 100,
560            backoff_multiplier: 2.0,
561        }
562    }
563
564    /// Set maximum retries
565    pub fn max_retries(mut self, max: u32) -> Self {
566        self.max_retries = max;
567        self
568    }
569
570    /// Set retry delay
571    pub fn retry_delay_ms(mut self, delay: u64) -> Self {
572        self.retry_delay_ms = delay;
573        self
574    }
575
576    /// Set backoff multiplier
577    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
578        self.backoff_multiplier = multiplier;
579        self
580    }
581}
582
583#[async_trait::async_trait]
584impl NodeFn for RetrySubgraph {
585    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError> {
586        let mut last_error = None;
587        let mut delay = self.retry_delay_ms;
588
589        for attempt in 0..=self.max_retries {
590            match self.subgraph.call(state.clone()).await {
591                Ok(result) => return Ok(result),
592                Err(e) => {
593                    last_error = Some(e);
594                    if attempt < self.max_retries {
595                        tokio::time::sleep(std::time::Duration::from_millis(delay)).await;
596                        delay = (delay as f64 * self.backoff_multiplier) as u64;
597                    }
598                }
599            }
600        }
601
602        Err(last_error.unwrap_or_else(|| {
603            CrewError::ExecutionFailed("Retry subgraph failed with unknown error".to_string())
604        }))
605    }
606}
607
608/// Extension trait for GraphBuilder to add subgraph nodes
609pub trait GraphBuilderSubgraphExt {
610    /// Add a subgraph as a node
611    fn add_subgraph(self, id: impl Into<String>, graph: Graph, mapping: StateMapping) -> Self;
612
613    /// Add a parallel subgraphs node
614    fn add_parallel_subgraphs(self, id: impl Into<String>, parallel: ParallelSubgraphs) -> Self;
615
616    /// Add a conditional subgraph node
617    fn add_conditional_subgraph(
618        self,
619        id: impl Into<String>,
620        conditional: ConditionalSubgraph,
621    ) -> Self;
622
623    /// Add a loop subgraph node
624    fn add_loop_subgraph(self, id: impl Into<String>, loop_sg: LoopSubgraph) -> Self;
625
626    /// Add a retry subgraph node
627    fn add_retry_subgraph(self, id: impl Into<String>, retry: RetrySubgraph) -> Self;
628}
629
630impl GraphBuilderSubgraphExt for crate::graph::GraphBuilder {
631    fn add_subgraph(self, id: impl Into<String>, graph: Graph, mapping: StateMapping) -> Self {
632        self.add_node_executor(id, Arc::new(SubgraphNode::with_mapping(graph, mapping)))
633    }
634
635    fn add_parallel_subgraphs(self, id: impl Into<String>, parallel: ParallelSubgraphs) -> Self {
636        self.add_node_executor(id, Arc::new(parallel))
637    }
638
639    fn add_conditional_subgraph(
640        self,
641        id: impl Into<String>,
642        conditional: ConditionalSubgraph,
643    ) -> Self {
644        self.add_node_executor(id, Arc::new(conditional))
645    }
646
647    fn add_loop_subgraph(self, id: impl Into<String>, loop_sg: LoopSubgraph) -> Self {
648        self.add_node_executor(id, Arc::new(loop_sg))
649    }
650
651    fn add_retry_subgraph(self, id: impl Into<String>, retry: RetrySubgraph) -> Self {
652        self.add_node_executor(id, Arc::new(retry))
653    }
654}
655
656/// Subgraph execution result with detailed info
657#[derive(Debug, Clone, Serialize, Deserialize)]
658pub struct SubgraphExecutionInfo {
659    /// Name of the subgraph
660    pub name: String,
661    /// Execution status
662    pub status: GraphStatus,
663    /// Input state (after mapping)
664    pub input_state: serde_json::Value,
665    /// Output state (before mapping back)
666    pub output_state: serde_json::Value,
667    /// Execution time in milliseconds
668    pub execution_time_ms: u64,
669    /// Number of nodes executed
670    pub nodes_executed: usize,
671}
672
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::{GraphBuilder, END};

    /// One-node graph named `adder_{amount}` that adds `amount` to `"value"` (missing = 0).
    fn create_adder_graph(amount: i32) -> Graph {
        GraphBuilder::new(format!("adder_{}", amount))
            .add_node("add", move |mut state: GraphState| async move {
                let current: i32 = state.get("value").unwrap_or(0);
                state.set("value", current + amount);
                Ok(state)
            })
            .add_edge("add", END)
            .set_entry("add")
            .build()
            .unwrap()
    }

    /// One-node graph named `multiplier_{factor}` that multiplies `"value"` by `factor`.
    fn create_multiplier_graph(factor: i32) -> Graph {
        GraphBuilder::new(format!("multiplier_{}", factor))
            .add_node("multiply", move |mut state: GraphState| async move {
                let current: i32 = state.get("value").unwrap_or(0);
                state.set("value", current * factor);
                Ok(state)
            })
            .add_edge("multiply", END)
            .set_entry("multiply")
            .build()
            .unwrap()
    }

    /// Fresh state holding only `"value" = value`.
    fn state_with_value(value: i32) -> GraphState {
        let mut state = GraphState::new();
        state.set("value", value);
        state
    }

    /// Mapping that shares `"value"` between parent and child.
    fn share_value() -> StateMapping {
        StateMapping::new().passthrough("value")
    }

    #[tokio::test]
    async fn test_simple_subgraph() {
        // merge_all shares the whole state in both directions.
        let node =
            SubgraphNode::with_mapping(create_adder_graph(5), StateMapping::new().merge_all());

        let out = node.call(state_with_value(10)).await.unwrap();
        assert_eq!(out.get::<i32>("value").unwrap(), 15);
    }

    #[tokio::test]
    async fn test_subgraph_with_mapping() {
        let adder = create_adder_graph(5);
        let mapping = StateMapping::new()
            .map_input("input_value", "value")
            .map_output("value", "output_value");
        let node = SubgraphNode::with_mapping(adder, mapping);

        let mut input = GraphState::new();
        input.set("input_value", 10);

        let out = node.call(input).await.unwrap();
        assert_eq!(out.get::<i32>("output_value").unwrap(), 15);
    }

    #[tokio::test]
    async fn test_nested_subgraphs() {
        // outer = inner (+5) followed by a local *2 step.
        let inner = create_adder_graph(5);
        let outer = GraphBuilder::new("outer")
            .add_subgraph("add_step", inner, share_value())
            .add_node("multiply", |mut state: GraphState| async move {
                let current: i32 = state.get("value").unwrap_or(0);
                state.set("value", current * 2);
                Ok(state)
            })
            .add_edge("add_step", "multiply")
            .add_edge("multiply", END)
            .set_entry("add_step")
            .build()
            .unwrap();

        let run = outer.invoke(state_with_value(10)).await.unwrap();
        assert_eq!(run.state.get::<i32>("value").unwrap(), 30); // (10 + 5) * 2
    }

    #[tokio::test]
    async fn test_parallel_subgraphs() {
        let adder = create_adder_graph(5);
        let multiplier = create_multiplier_graph(3);
        let parallel = ParallelSubgraphs::new()
            .add("adder", adder, share_value())
            .add("multiplier", multiplier, share_value())
            .with_strategy(MergeStrategy::Prefixed);

        let out = parallel.call(state_with_value(10)).await.unwrap();

        // Prefixed: each child's keys land under "<subgraph name>.<key>".
        assert_eq!(out.get::<i32>("adder.value").unwrap(), 15);
        assert_eq!(out.get::<i32>("multiplier.value").unwrap(), 30);
    }

    #[tokio::test]
    async fn test_parallel_subgraphs_collect() {
        let adder = create_adder_graph(5);
        let multiplier = create_multiplier_graph(3);
        let parallel = ParallelSubgraphs::new()
            .add("adder", adder, share_value())
            .add("multiplier", multiplier, share_value())
            .with_strategy(MergeStrategy::Collect);

        let out = parallel.call(state_with_value(10)).await.unwrap();

        // Collect: each child's full state is stored under parallel_results[<name>].
        let collected: serde_json::Value = out.get("parallel_results").unwrap();
        assert_eq!(collected["adder"]["value"].as_i64().unwrap(), 15);
        assert_eq!(collected["multiplier"]["value"].as_i64().unwrap(), 30);
    }

    #[tokio::test]
    async fn test_conditional_subgraph() {
        let adder = create_adder_graph(10);
        let multiplier = create_multiplier_graph(2);
        let conditional = ConditionalSubgraph::new()
            .when(
                |state| state.get::<bool>("use_addition").unwrap_or(false),
                adder,
                share_value(),
            )
            .otherwise(multiplier, share_value());

        // Matching branch: 5 + 10.
        let mut additive = state_with_value(5);
        additive.set("use_addition", true);
        let added = conditional.call(additive).await.unwrap();
        assert_eq!(added.get::<i32>("value").unwrap(), 15);

        // Fallback branch: 5 * 2.
        let mut multiplicative = state_with_value(5);
        multiplicative.set("use_addition", false);
        let multiplied = conditional.call(multiplicative).await.unwrap();
        assert_eq!(multiplied.get::<i32>("value").unwrap(), 10);
    }

    #[tokio::test]
    async fn test_loop_subgraph() {
        let adder = create_adder_graph(1);
        let looping = LoopSubgraph::new(adder, share_value(), |state| {
            state.get::<i32>("value").unwrap_or(0) < 5
        })
        .max_iterations(10);

        let out = looping.call(state_with_value(0)).await.unwrap();
        assert_eq!(out.get::<i32>("value").unwrap(), 5); // stops once value >= 5
    }

    #[tokio::test]
    async fn test_subgraph_library() {
        let mut library = SubgraphLibrary::new();
        library.register("add5", create_adder_graph(5));
        library.register("mult2", create_multiplier_graph(2));

        for known in ["add5", "mult2"] {
            assert!(library.contains(known));
        }
        assert!(!library.contains("unknown"));

        let node = library.create_node("add5", share_value()).unwrap();
        let out = node.call(state_with_value(10)).await.unwrap();
        assert_eq!(out.get::<i32>("value").unwrap(), 15);
    }

    #[tokio::test]
    async fn test_state_mapping_merge_all() {
        let graph = GraphBuilder::new("multi_output")
            .add_node("compute", |mut state: GraphState| async move {
                let current: i32 = state.get("value").unwrap_or(0);
                state.set("doubled", current * 2);
                state.set("tripled", current * 3);
                state.set("squared", current * current);
                Ok(state)
            })
            .add_edge("compute", END)
            .set_entry("compute")
            .build()
            .unwrap();
        let node = SubgraphNode::with_mapping(graph, share_value().merge_all());

        let out = node.call(state_with_value(5)).await.unwrap();
        for (key, expected) in [("doubled", 10), ("tripled", 15), ("squared", 25)] {
            assert_eq!(out.get::<i32>(key).unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_state_mapping_with_prefix() {
        let graph = GraphBuilder::new("outputs")
            .add_node("compute", |mut state: GraphState| async move {
                state.set("result", 42);
                Ok(state)
            })
            .add_edge("compute", END)
            .set_entry("compute")
            .build()
            .unwrap();
        let node = SubgraphNode::with_mapping(graph, StateMapping::new().with_prefix("child"));

        let out = node.call(GraphState::new()).await.unwrap();
        assert_eq!(out.get::<i32>("child.result").unwrap(), 42);
    }

    #[tokio::test]
    async fn test_graph_builder_extension() {
        let inner = create_adder_graph(10);
        let graph = GraphBuilder::new("with_subgraph")
            .add_node("init", |mut state: GraphState| async move {
                state.set("value", 5);
                Ok(state)
            })
            .add_subgraph("add", inner, share_value())
            .add_edge("init", "add")
            .add_edge("add", END)
            .set_entry("init")
            .build()
            .unwrap();

        let run = graph.invoke(GraphState::new()).await.unwrap();
        assert_eq!(run.state.get::<i32>("value").unwrap(), 15);
    }
}