Skip to main content

cortexai_crew/
graph.rs

//! Graph-based Workflow Engine with Cycle Support
//!
//! A LangGraph-inspired execution engine that supports:
//! - Cycles for iterative reasoning loops
//! - Conditional edges with dynamic routing
//! - State checkpointing and recovery
//! - A `parallel_branches` configuration flag (branches currently execute sequentially:
//!   only the first outgoing edge registered for a node is followed)
//! - Maximum iteration limits to prevent infinite loops
//!
//! ## Example
//!
//! ```rust,ignore
//! use cortexai_crew::graph::{GraphBuilder, GraphState, END};
//!
//! // Create a reasoning loop that iterates until done
//! let graph = GraphBuilder::new("reasoning_loop")
//!     .add_node("think", think_node)
//!     .add_node("act", act_node)
//!     .add_node("evaluate", evaluate_node)
//!     .add_edge("think", "act")
//!     .add_edge("act", "evaluate")
//!     // Conditional: loop back to think, or route to END to finish
//!     .add_conditional_edge("evaluate", |state: &GraphState| {
//!         if state.get::<bool>("done").unwrap_or(false) {
//!             END.to_string()
//!         } else {
//!             "think".to_string() // Loop back
//!         }
//!     })
//!     .set_entry("think")
//!     .build()?;
//!
//! let result = graph.invoke(initial_state).await?;
//! ```
36
37use chrono::{DateTime, Utc};
38use cortexai_core::errors::CrewError;
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use std::sync::Arc;
42use tokio::sync::RwLock;
43
44/// Graph state - the data that flows through the graph
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct GraphState {
47    /// State data as JSON
48    pub data: serde_json::Value,
49    /// Execution metadata
50    pub metadata: GraphMetadata,
51}
52
53impl Default for GraphState {
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59impl GraphState {
60    /// Create empty state
61    pub fn new() -> Self {
62        Self {
63            data: serde_json::json!({}),
64            metadata: GraphMetadata::default(),
65        }
66    }
67
68    /// Create from JSON data
69    pub fn from_json(data: serde_json::Value) -> Self {
70        Self {
71            data,
72            metadata: GraphMetadata::default(),
73        }
74    }
75
76    /// Get a value from state
77    pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
78        self.data
79            .get(key)
80            .and_then(|v| serde_json::from_value(v.clone()).ok())
81    }
82
83    /// Set a value in state
84    pub fn set<T: Serialize>(&mut self, key: &str, value: T) {
85        if let Some(obj) = self.data.as_object_mut() {
86            if let Ok(v) = serde_json::to_value(value) {
87                obj.insert(key.to_string(), v);
88            }
89        }
90    }
91
92    /// Merge another state's data into this one
93    pub fn merge(&mut self, other: &GraphState) {
94        if let (Some(self_obj), Some(other_obj)) =
95            (self.data.as_object_mut(), other.data.as_object())
96        {
97            for (k, v) in other_obj {
98                self_obj.insert(k.clone(), v.clone());
99            }
100        }
101    }
102
103    /// Get raw JSON data
104    pub fn raw(&self) -> &serde_json::Value {
105        &self.data
106    }
107}
108
109/// Execution metadata
110#[derive(Debug, Clone, Default, Serialize, Deserialize)]
111pub struct GraphMetadata {
112    /// Number of iterations executed
113    pub iterations: u32,
114    /// Nodes visited in order
115    pub visited_nodes: Vec<String>,
116    /// Current checkpoint ID
117    pub checkpoint_id: Option<String>,
118    /// Total execution time in ms
119    pub execution_time_ms: u64,
120    /// Start time
121    pub started_at: Option<DateTime<Utc>>,
122}
123
124/// Result of graph execution
125#[derive(Debug, Clone, Serialize, Deserialize)]
126pub struct GraphResult {
127    /// Final state
128    pub state: GraphState,
129    /// Execution status
130    pub status: GraphStatus,
131    /// Error message if failed
132    pub error: Option<String>,
133}
134
135/// Graph execution status
136#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
137pub enum GraphStatus {
138    /// Successfully completed
139    Success,
140    /// Failed with error
141    Failed,
142    /// Hit maximum iterations
143    MaxIterations,
144    /// Interrupted by user
145    Interrupted,
146    /// Waiting for input
147    Paused,
148}
149
150/// A node in the graph
151pub struct GraphNode {
152    /// Node ID
153    pub id: String,
154    /// Node executor
155    pub executor: Arc<dyn NodeFn>,
156}
157
158impl std::fmt::Debug for GraphNode {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        f.debug_struct("GraphNode").field("id", &self.id).finish()
161    }
162}
163
164/// Function type for node execution
165#[async_trait::async_trait]
166pub trait NodeFn: Send + Sync {
167    /// Execute the node and return updated state
168    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError>;
169}
170
171/// Simple function wrapper for node execution
172pub struct FnNode<F>(pub F);
173
174#[async_trait::async_trait]
175impl<F, Fut> NodeFn for FnNode<F>
176where
177    F: Fn(GraphState) -> Fut + Send + Sync,
178    Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send,
179{
180    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError> {
181        (self.0)(state).await
182    }
183}
184
185/// Edge types in the graph
186#[derive(Clone)]
187pub enum GraphEdge {
188    /// Direct edge to a single node
189    Direct { from: String, to: String },
190    /// Conditional edge with dynamic routing
191    Conditional {
192        from: String,
193        router: Arc<dyn EdgeRouter>,
194    },
195}
196
197impl std::fmt::Debug for GraphEdge {
198    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
199        match self {
200            Self::Direct { from, to } => f
201                .debug_struct("Direct")
202                .field("from", from)
203                .field("to", to)
204                .finish(),
205            Self::Conditional { from, .. } => {
206                f.debug_struct("Conditional").field("from", from).finish()
207            }
208        }
209    }
210}
211
212/// Router for conditional edges
213pub trait EdgeRouter: Send + Sync {
214    /// Determine next node based on state
215    fn route(&self, state: &GraphState) -> String;
216}
217
218/// Simple function-based router
219pub struct FnRouter<F>(pub F);
220
221impl<F> EdgeRouter for FnRouter<F>
222where
223    F: Fn(&GraphState) -> String + Send + Sync,
224{
225    fn route(&self, state: &GraphState) -> String {
226        (self.0)(state)
227    }
228}
229
230/// Condition-based router
231pub struct ConditionRouter {
232    conditions: Vec<(Box<dyn Fn(&GraphState) -> bool + Send + Sync>, String)>,
233    default: String,
234}
235
236impl ConditionRouter {
237    /// Create a new condition router
238    pub fn new(default: impl Into<String>) -> Self {
239        Self {
240            conditions: Vec::new(),
241            default: default.into(),
242        }
243    }
244
245    /// Add a condition
246    pub fn when<F>(mut self, condition: F, target: impl Into<String>) -> Self
247    where
248        F: Fn(&GraphState) -> bool + Send + Sync + 'static,
249    {
250        self.conditions.push((Box::new(condition), target.into()));
251        self
252    }
253}
254
255impl EdgeRouter for ConditionRouter {
256    fn route(&self, state: &GraphState) -> String {
257        for (condition, target) in &self.conditions {
258            if condition(state) {
259                return target.clone();
260            }
261        }
262        self.default.clone()
263    }
264}
265
266/// State checkpoint for recovery
267#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct Checkpoint {
269    /// Checkpoint ID
270    pub id: String,
271    /// State at checkpoint
272    pub state: GraphState,
273    /// Node that was about to execute
274    pub next_node: String,
275    /// Timestamp
276    pub created_at: DateTime<Utc>,
277}
278
279/// Checkpoint storage trait
280#[async_trait::async_trait]
281pub trait CheckpointStore: Send + Sync {
282    /// Save a checkpoint
283    async fn save(&self, checkpoint: Checkpoint) -> Result<(), CrewError>;
284    /// Load a checkpoint by ID
285    async fn load(&self, id: &str) -> Result<Option<Checkpoint>, CrewError>;
286    /// List all checkpoints for a graph
287    async fn list(&self, graph_id: &str) -> Result<Vec<String>, CrewError>;
288    /// Delete a checkpoint
289    async fn delete(&self, id: &str) -> Result<(), CrewError>;
290}
291
292/// In-memory checkpoint store
293#[derive(Default)]
294pub struct InMemoryCheckpointStore {
295    checkpoints: RwLock<HashMap<String, Checkpoint>>,
296}
297
298#[async_trait::async_trait]
299impl CheckpointStore for InMemoryCheckpointStore {
300    async fn save(&self, checkpoint: Checkpoint) -> Result<(), CrewError> {
301        self.checkpoints
302            .write()
303            .await
304            .insert(checkpoint.id.clone(), checkpoint);
305        Ok(())
306    }
307
308    async fn load(&self, id: &str) -> Result<Option<Checkpoint>, CrewError> {
309        Ok(self.checkpoints.read().await.get(id).cloned())
310    }
311
312    async fn list(&self, _graph_id: &str) -> Result<Vec<String>, CrewError> {
313        Ok(self.checkpoints.read().await.keys().cloned().collect())
314    }
315
316    async fn delete(&self, id: &str) -> Result<(), CrewError> {
317        self.checkpoints.write().await.remove(id);
318        Ok(())
319    }
320}
321
/// Reserved ID for the start of a graph (rendered as `START` in Mermaid output).
pub const START: &str = "__start__";
/// Reserved ID that ends execution when a node routes to it.
pub const END: &str = "__end__";

/// Tunables controlling how a graph executes.
#[derive(Clone, Debug)]
pub struct GraphConfig {
    /// Upper bound on node executions per run; guards against infinite cycles.
    pub max_iterations: u32,
    /// Whether checkpoints are written to the configured store.
    pub checkpointing: bool,
    /// Checkpoint whenever the iteration count is a multiple of this value.
    pub checkpoint_interval: u32,
    /// Request parallel execution of independent branches (not consulted by the sequential
    /// execution loop in this module).
    pub parallel_branches: bool,
    /// Optional per-node time limit, in milliseconds.
    pub node_timeout_ms: Option<u64>,
}

impl Default for GraphConfig {
    // 100 iterations, checkpointing off (every 5 iterations once enabled), no parallelism,
    // and no per-node timeout.
    fn default() -> Self {
        GraphConfig {
            node_timeout_ms: None,
            parallel_branches: false,
            checkpoint_interval: 5,
            checkpointing: false,
            max_iterations: 100,
        }
    }
}
352
353/// The main graph structure
354pub struct Graph {
355    /// Graph ID
356    pub id: String,
357    /// Graph name
358    pub name: String,
359    /// Nodes in the graph
360    pub nodes: HashMap<String, GraphNode>,
361    /// Edges in the graph
362    pub edges: Vec<GraphEdge>,
363    /// Entry node ID
364    pub entry_node: String,
365    /// Configuration
366    pub config: GraphConfig,
367    /// Checkpoint store
368    pub checkpoint_store: Option<Arc<dyn CheckpointStore>>,
369}
370
371impl std::fmt::Debug for Graph {
372    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
373        f.debug_struct("Graph")
374            .field("id", &self.id)
375            .field("name", &self.name)
376            .field("nodes", &self.nodes.keys().collect::<Vec<_>>())
377            .field("entry_node", &self.entry_node)
378            .finish()
379    }
380}
381
382impl Graph {
383    /// Execute the graph with initial state
384    pub async fn invoke(&self, initial_state: GraphState) -> Result<GraphResult, CrewError> {
385        let mut state = initial_state;
386        state.metadata.started_at = Some(Utc::now());
387        state.metadata.iterations = 0;
388
389        let mut current_node = self.entry_node.clone();
390
391        loop {
392            // Check max iterations
393            if state.metadata.iterations >= self.config.max_iterations {
394                return Ok(GraphResult {
395                    state,
396                    status: GraphStatus::MaxIterations,
397                    error: Some(format!(
398                        "Hit maximum iterations: {}",
399                        self.config.max_iterations
400                    )),
401                });
402            }
403
404            // Check for END node
405            if current_node == END {
406                state.metadata.execution_time_ms = state
407                    .metadata
408                    .started_at
409                    .map(|s| Utc::now().signed_duration_since(s).num_milliseconds() as u64)
410                    .unwrap_or(0);
411                return Ok(GraphResult {
412                    state,
413                    status: GraphStatus::Success,
414                    error: None,
415                });
416            }
417
418            // Get node
419            let node = self.nodes.get(&current_node).ok_or_else(|| {
420                CrewError::TaskNotFound(format!("Node not found: {}", current_node))
421            })?;
422
423            // Save checkpoint if enabled
424            if self.config.checkpointing
425                && state
426                    .metadata
427                    .iterations
428                    .is_multiple_of(self.config.checkpoint_interval)
429            {
430                if let Some(store) = &self.checkpoint_store {
431                    let checkpoint = Checkpoint {
432                        id: format!("{}_{}", self.id, state.metadata.iterations),
433                        state: state.clone(),
434                        next_node: current_node.clone(),
435                        created_at: Utc::now(),
436                    };
437                    store.save(checkpoint).await?;
438                }
439            }
440
441            // Execute node
442            state.metadata.visited_nodes.push(current_node.clone());
443            state.metadata.iterations += 1;
444
445            state = match self.config.node_timeout_ms {
446                Some(timeout) => tokio::time::timeout(
447                    std::time::Duration::from_millis(timeout),
448                    node.executor.call(state),
449                )
450                .await
451                .map_err(|_| {
452                    CrewError::ExecutionFailed(format!("Node {} timed out", current_node))
453                })??,
454                None => node.executor.call(state).await?,
455            };
456
457            // Find next node
458            current_node = self.find_next_node(&current_node, &state)?;
459        }
460    }
461
462    /// Resume from a checkpoint
463    pub async fn resume(&self, checkpoint_id: &str) -> Result<GraphResult, CrewError> {
464        let store = self.checkpoint_store.as_ref().ok_or_else(|| {
465            CrewError::InvalidConfiguration("Checkpointing not enabled".to_string())
466        })?;
467
468        let checkpoint = store.load(checkpoint_id).await?.ok_or_else(|| {
469            CrewError::TaskNotFound(format!("Checkpoint not found: {}", checkpoint_id))
470        })?;
471
472        // Continue from checkpoint
473        let mut state = checkpoint.state;
474        let mut current_node = checkpoint.next_node;
475
476        loop {
477            if state.metadata.iterations >= self.config.max_iterations {
478                return Ok(GraphResult {
479                    state,
480                    status: GraphStatus::MaxIterations,
481                    error: Some(format!(
482                        "Hit maximum iterations: {}",
483                        self.config.max_iterations
484                    )),
485                });
486            }
487
488            if current_node == END {
489                state.metadata.execution_time_ms = state
490                    .metadata
491                    .started_at
492                    .map(|s| Utc::now().signed_duration_since(s).num_milliseconds() as u64)
493                    .unwrap_or(0);
494                return Ok(GraphResult {
495                    state,
496                    status: GraphStatus::Success,
497                    error: None,
498                });
499            }
500
501            let node = self.nodes.get(&current_node).ok_or_else(|| {
502                CrewError::TaskNotFound(format!("Node not found: {}", current_node))
503            })?;
504
505            state.metadata.visited_nodes.push(current_node.clone());
506            state.metadata.iterations += 1;
507
508            state = node.executor.call(state).await?;
509            current_node = self.find_next_node(&current_node, &state)?;
510        }
511    }
512
513    /// Find the next node based on edges
514    pub fn find_next_node(&self, current: &str, state: &GraphState) -> Result<String, CrewError> {
515        for edge in &self.edges {
516            match edge {
517                GraphEdge::Direct { from, to } if from == current => {
518                    return Ok(to.clone());
519                }
520                GraphEdge::Conditional { from, router } if from == current => {
521                    return Ok(router.route(state));
522                }
523                _ => continue,
524            }
525        }
526
527        // No outgoing edge = implicit END
528        Ok(END.to_string())
529    }
530
531    /// Stream execution, yielding state after each node
532    pub fn stream(&self, initial_state: GraphState) -> GraphStream<'_> {
533        GraphStream {
534            graph: self,
535            state: Some(initial_state),
536            current_node: Some(self.entry_node.clone()),
537            finished: false,
538        }
539    }
540
541    /// Get a visual representation of the graph (Mermaid format)
542    pub fn to_mermaid(&self) -> String {
543        let mut lines = vec!["graph TD".to_string()];
544
545        for id in self.nodes.keys() {
546            let display_id = if id == START {
547                "START"
548            } else if id == END {
549                "END"
550            } else {
551                id
552            };
553            lines.push(format!("    {}[{}]", id.replace('-', "_"), display_id));
554        }
555
556        for edge in &self.edges {
557            match edge {
558                GraphEdge::Direct { from, to } => {
559                    lines.push(format!(
560                        "    {} --> {}",
561                        from.replace('-', "_"),
562                        to.replace('-', "_")
563                    ));
564                }
565                GraphEdge::Conditional { from, .. } => {
566                    lines.push(format!(
567                        "    {} -.->|condition| ...",
568                        from.replace('-', "_")
569                    ));
570                }
571            }
572        }
573
574        lines.join("\n")
575    }
576}
577
578/// Streaming graph execution
579pub struct GraphStream<'a> {
580    graph: &'a Graph,
581    state: Option<GraphState>,
582    current_node: Option<String>,
583    finished: bool,
584}
585
586impl<'a> GraphStream<'a> {
587    /// Get next state update
588    pub async fn next(&mut self) -> Option<Result<(String, GraphState), CrewError>> {
589        if self.finished {
590            return None;
591        }
592
593        let current_node = self.current_node.take()?;
594        let mut state = self.state.take()?;
595
596        // Check for END
597        if current_node == END {
598            self.finished = true;
599            return Some(Ok((END.to_string(), state)));
600        }
601
602        // Check max iterations
603        if state.metadata.iterations >= self.graph.config.max_iterations {
604            self.finished = true;
605            return Some(Err(CrewError::ExecutionFailed(
606                "Max iterations reached".to_string(),
607            )));
608        }
609
610        // Get and execute node
611        let node = match self.graph.nodes.get(&current_node) {
612            Some(n) => n,
613            None => {
614                self.finished = true;
615                return Some(Err(CrewError::TaskNotFound(current_node)));
616            }
617        };
618
619        state.metadata.visited_nodes.push(current_node.clone());
620        state.metadata.iterations += 1;
621
622        match node.executor.call(state).await {
623            Ok(new_state) => {
624                let next_node = match self.graph.find_next_node(&current_node, &new_state) {
625                    Ok(n) => n,
626                    Err(e) => {
627                        self.finished = true;
628                        return Some(Err(e));
629                    }
630                };
631
632                self.state = Some(new_state.clone());
633                self.current_node = Some(next_node);
634                Some(Ok((current_node, new_state)))
635            }
636            Err(e) => {
637                self.finished = true;
638                Some(Err(e))
639            }
640        }
641    }
642}
643
644/// Builder for creating graphs
645pub struct GraphBuilder {
646    id: String,
647    name: String,
648    nodes: HashMap<String, GraphNode>,
649    edges: Vec<GraphEdge>,
650    entry_node: Option<String>,
651    config: GraphConfig,
652    checkpoint_store: Option<Arc<dyn CheckpointStore>>,
653}
654
655impl GraphBuilder {
656    /// Create a new graph builder
657    pub fn new(id: impl Into<String>) -> Self {
658        let id = id.into();
659        Self {
660            name: id.clone(),
661            id,
662            nodes: HashMap::new(),
663            edges: Vec::new(),
664            entry_node: None,
665            config: GraphConfig::default(),
666            checkpoint_store: None,
667        }
668    }
669
670    /// Set graph name
671    pub fn name(mut self, name: impl Into<String>) -> Self {
672        self.name = name.into();
673        self
674    }
675
676    /// Add a node with an async function
677    pub fn add_node<F, Fut>(mut self, id: impl Into<String>, func: F) -> Self
678    where
679        F: Fn(GraphState) -> Fut + Send + Sync + 'static,
680        Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
681    {
682        let id = id.into();
683        self.nodes.insert(
684            id.clone(),
685            GraphNode {
686                id: id.clone(),
687                executor: Arc::new(FnNode(func)),
688            },
689        );
690        self
691    }
692
693    /// Add a node with a custom executor
694    pub fn add_node_executor(mut self, id: impl Into<String>, executor: Arc<dyn NodeFn>) -> Self {
695        let id = id.into();
696        self.nodes.insert(id.clone(), GraphNode { id, executor });
697        self
698    }
699
700    /// Add a direct edge between two nodes
701    pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
702        self.edges.push(GraphEdge::Direct {
703            from: from.into(),
704            to: to.into(),
705        });
706        self
707    }
708
709    /// Add a conditional edge with a routing function
710    pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
711    where
712        F: Fn(&GraphState) -> String + Send + Sync + 'static,
713    {
714        self.edges.push(GraphEdge::Conditional {
715            from: from.into(),
716            router: Arc::new(FnRouter(router)),
717        });
718        self
719    }
720
721    /// Add a conditional edge with a ConditionRouter
722    pub fn add_conditional_edge_router(
723        mut self,
724        from: impl Into<String>,
725        router: ConditionRouter,
726    ) -> Self {
727        self.edges.push(GraphEdge::Conditional {
728            from: from.into(),
729            router: Arc::new(router),
730        });
731        self
732    }
733
734    /// Set the entry node
735    pub fn set_entry(mut self, node_id: impl Into<String>) -> Self {
736        self.entry_node = Some(node_id.into());
737        self
738    }
739
740    /// Set maximum iterations
741    pub fn max_iterations(mut self, max: u32) -> Self {
742        self.config.max_iterations = max;
743        self
744    }
745
746    /// Enable checkpointing
747    pub fn with_checkpointing(mut self, store: Arc<dyn CheckpointStore>) -> Self {
748        self.config.checkpointing = true;
749        self.checkpoint_store = Some(store);
750        self
751    }
752
753    /// Set checkpoint interval
754    pub fn checkpoint_interval(mut self, interval: u32) -> Self {
755        self.config.checkpoint_interval = interval;
756        self
757    }
758
759    /// Set node timeout
760    pub fn node_timeout_ms(mut self, timeout: u64) -> Self {
761        self.config.node_timeout_ms = Some(timeout);
762        self
763    }
764
765    /// Build the graph
766    pub fn build(self) -> Result<Graph, CrewError> {
767        let entry_node = self.entry_node.ok_or_else(|| {
768            CrewError::InvalidConfiguration("No entry node specified".to_string())
769        })?;
770
771        if !self.nodes.contains_key(&entry_node) {
772            return Err(CrewError::InvalidConfiguration(format!(
773                "Entry node '{}' not found",
774                entry_node
775            )));
776        }
777
778        // Validate edges
779        for edge in &self.edges {
780            let from = match edge {
781                GraphEdge::Direct { from, .. } => from,
782                GraphEdge::Conditional { from, .. } => from,
783            };
784            if !self.nodes.contains_key(from) {
785                return Err(CrewError::InvalidConfiguration(format!(
786                    "Edge source '{}' not found",
787                    from
788                )));
789            }
790            // Note: We don't validate 'to' for conditional edges since they're dynamic
791            if let GraphEdge::Direct { to, .. } = edge {
792                if to != END && !self.nodes.contains_key(to) {
793                    return Err(CrewError::InvalidConfiguration(format!(
794                        "Edge target '{}' not found",
795                        to
796                    )));
797                }
798            }
799        }
800
801        Ok(Graph {
802            id: self.id,
803            name: self.name,
804            nodes: self.nodes,
805            edges: self.edges,
806            entry_node,
807            config: self.config,
808            checkpoint_store: self.checkpoint_store,
809        })
810    }
811}
812
813/// StateGraph - a higher-level API for common patterns
814pub struct StateGraph<S> {
815    graph: Graph,
816    _phantom: std::marker::PhantomData<S>,
817}
818
819impl<S: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static> StateGraph<S> {
820    /// Create from a graph
821    pub fn new(graph: Graph) -> Self {
822        Self {
823            graph,
824            _phantom: std::marker::PhantomData,
825        }
826    }
827
828    /// Invoke with typed state
829    pub async fn invoke(&self, initial: S) -> Result<S, CrewError> {
830        let json = serde_json::to_value(&initial)
831            .map_err(|e| CrewError::ExecutionFailed(format!("Serialization error: {}", e)))?;
832        let state = GraphState::from_json(json);
833        let result = self.graph.invoke(state).await?;
834
835        if result.status != GraphStatus::Success {
836            return Err(CrewError::ExecutionFailed(
837                result
838                    .error
839                    .unwrap_or_else(|| "Graph execution failed".to_string()),
840            ));
841        }
842
843        serde_json::from_value(result.state.data)
844            .map_err(|e| CrewError::ExecutionFailed(format!("Deserialization error: {}", e)))
845    }
846}
847
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_simple_graph() {
        let graph = GraphBuilder::new("simple")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("step1_done", true);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                state.set("step2_done", true);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();

        assert_eq!(result.status, GraphStatus::Success);
        assert_eq!(result.state.get::<bool>("step1_done"), Some(true));
        assert_eq!(result.state.get::<bool>("step2_done"), Some(true));
        assert_eq!(result.state.metadata.iterations, 2);
    }

    #[tokio::test]
    async fn test_conditional_edge() {
        let graph = GraphBuilder::new("conditional")
            .add_node("check", |state: GraphState| async move { Ok(state) })
            .add_node("yes_path", |mut state: GraphState| async move {
                state.set("path", "yes");
                Ok(state)
            })
            .add_node("no_path", |mut state: GraphState| async move {
                state.set("path", "no");
                Ok(state)
            })
            .add_conditional_edge("check", |state| {
                if state.get::<bool>("condition").unwrap_or(false) {
                    "yes_path".to_string()
                } else {
                    "no_path".to_string()
                }
            })
            .add_edge("yes_path", END)
            .add_edge("no_path", END)
            .set_entry("check")
            .build()
            .unwrap();

        for (condition, expected) in [(true, "yes"), (false, "no")] {
            let mut state = GraphState::new();
            state.set("condition", condition);
            let result = graph.invoke(state).await.unwrap();
            assert_eq!(result.state.get::<String>("path"), Some(expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = GraphBuilder::new("cycle")
            .add_node("increment", |mut state: GraphState| async move {
                let count: i32 = state.get("count").unwrap_or(0);
                state.set("count", count + 1);
                Ok(state)
            })
            .add_conditional_edge("increment", |state| {
                let count: i32 = state.get("count").unwrap_or(0);
                if count >= 5 {
                    END.to_string()
                } else {
                    "increment".to_string() // Loop back
                }
            })
            .set_entry("increment")
            .max_iterations(100)
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();

        assert_eq!(result.status, GraphStatus::Success);
        assert_eq!(result.state.get::<i32>("count"), Some(5));
        assert_eq!(result.state.metadata.iterations, 5);
    }

    #[tokio::test]
    async fn test_max_iterations_limit() {
        let graph = GraphBuilder::new("infinite")
            .add_node("loop", |state: GraphState| async move { Ok(state) })
            .add_edge("loop", "loop") // Infinite loop
            .set_entry("loop")
            .max_iterations(10)
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();

        assert_eq!(result.status, GraphStatus::MaxIterations);
        assert_eq!(result.state.metadata.iterations, 10);
    }

    // Routing is synchronous, so no async runtime is needed here.
    #[test]
    fn test_condition_router() {
        let router = ConditionRouter::new("default")
            .when(|s| s.get::<i32>("score").unwrap_or(0) >= 80, "excellent")
            .when(|s| s.get::<i32>("score").unwrap_or(0) >= 60, "good")
            .when(|s| s.get::<i32>("score").unwrap_or(0) >= 40, "pass");

        let mut state = GraphState::new();
        state.set("score", 85);
        assert_eq!(router.route(&state), "excellent");

        state.set("score", 65);
        assert_eq!(router.route(&state), "good");

        state.set("score", 30);
        assert_eq!(router.route(&state), "default");
    }

    #[test]
    fn test_state_merge_overwrites_and_keeps_keys() {
        let mut base = GraphState::new();
        base.set("a", 1);
        base.set("b", 2);

        let mut other = GraphState::new();
        other.set("b", 20);
        other.set("c", 30);

        base.merge(&other);
        assert_eq!(base.get::<i32>("a"), Some(1));
        assert_eq!(base.get::<i32>("b"), Some(20));
        assert_eq!(base.get::<i32>("c"), Some(30));
    }

    #[tokio::test]
    async fn test_checkpointing() {
        let store = Arc::new(InMemoryCheckpointStore::default());

        let graph = GraphBuilder::new("checkpoint_test")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("step", 1);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                state.set("step", 2);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .with_checkpointing(store.clone())
            .checkpoint_interval(1)
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();
        assert_eq!(result.status, GraphStatus::Success);

        // Verify checkpoints were created
        let checkpoints = store.list("checkpoint_test").await.unwrap();
        assert!(!checkpoints.is_empty());
    }

    #[tokio::test]
    async fn test_resume_from_checkpoint() {
        let store = Arc::new(InMemoryCheckpointStore::default());

        let graph = GraphBuilder::new("resume_test")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("step", 1);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                state.set("step", 2);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .with_checkpointing(store.clone())
            .checkpoint_interval(1)
            .build()
            .unwrap();

        graph.invoke(GraphState::new()).await.unwrap();

        // "resume_test_1" was saved after step1 ran, just before step2.
        let resumed = graph.resume("resume_test_1").await.unwrap();
        assert_eq!(resumed.status, GraphStatus::Success);
        assert_eq!(resumed.state.get::<i32>("step"), Some(2));
        assert_eq!(resumed.state.metadata.visited_nodes, vec!["step1", "step2"]);

        assert!(graph.resume("no_such_checkpoint").await.is_err());
    }

    #[test]
    fn test_build_validation() {
        let noop = |s: GraphState| async move { Ok(s) };

        // No entry node.
        assert!(GraphBuilder::new("v").add_node("a", noop).build().is_err());
        // Entry node that was never added.
        assert!(GraphBuilder::new("v")
            .add_node("a", noop)
            .set_entry("missing")
            .build()
            .is_err());
        // Direct edge to an unknown node.
        assert!(GraphBuilder::new("v")
            .add_node("a", noop)
            .add_edge("a", "missing")
            .set_entry("a")
            .build()
            .is_err());
        // Edges to END are allowed.
        assert!(GraphBuilder::new("v")
            .add_node("a", noop)
            .add_edge("a", END)
            .set_entry("a")
            .build()
            .is_ok());
    }

    #[test]
    fn test_mermaid_output() {
        let graph = GraphBuilder::new("mermaid_test")
            .add_node("start", |s| async { Ok(s) })
            .add_node("process", |s| async { Ok(s) })
            .add_node("end", |s| async { Ok(s) })
            .add_edge("start", "process")
            .add_edge("process", "end")
            .set_entry("start")
            .build()
            .unwrap();

        let mermaid = graph.to_mermaid();
        assert!(mermaid.contains("graph TD"));
        assert!(mermaid.contains("start"));
        assert!(mermaid.contains("process"));
        assert!(mermaid.contains("start --> process"));
    }

    #[tokio::test]
    async fn test_stream_execution() {
        let graph = GraphBuilder::new("stream_test")
            .add_node("a", |mut s: GraphState| async move {
                s.set("a", true);
                Ok(s)
            })
            .add_node("b", |mut s: GraphState| async move {
                s.set("b", true);
                Ok(s)
            })
            .add_edge("a", "b")
            .add_edge("b", END)
            .set_entry("a")
            .build()
            .unwrap();

        let mut stream = graph.stream(GraphState::new());
        let mut steps = Vec::new();

        while let Some(result) = stream.next().await {
            let (node_id, _state) = result.unwrap();
            steps.push(node_id);
        }

        assert_eq!(steps, vec!["a", "b", END]);
    }
}