1use chrono::{DateTime, Utc};
38use cortexai_core::errors::CrewError;
39use serde::{Deserialize, Serialize};
40use std::collections::HashMap;
41use std::sync::Arc;
42use tokio::sync::RwLock;
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct GraphState {
47 pub data: serde_json::Value,
49 pub metadata: GraphMetadata,
51}
52
53impl Default for GraphState {
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59impl GraphState {
60 pub fn new() -> Self {
62 Self {
63 data: serde_json::json!({}),
64 metadata: GraphMetadata::default(),
65 }
66 }
67
68 pub fn from_json(data: serde_json::Value) -> Self {
70 Self {
71 data,
72 metadata: GraphMetadata::default(),
73 }
74 }
75
76 pub fn get<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
78 self.data
79 .get(key)
80 .and_then(|v| serde_json::from_value(v.clone()).ok())
81 }
82
83 pub fn set<T: Serialize>(&mut self, key: &str, value: T) {
85 if let Some(obj) = self.data.as_object_mut() {
86 if let Ok(v) = serde_json::to_value(value) {
87 obj.insert(key.to_string(), v);
88 }
89 }
90 }
91
92 pub fn merge(&mut self, other: &GraphState) {
94 if let (Some(self_obj), Some(other_obj)) =
95 (self.data.as_object_mut(), other.data.as_object())
96 {
97 for (k, v) in other_obj {
98 self_obj.insert(k.clone(), v.clone());
99 }
100 }
101 }
102
103 pub fn raw(&self) -> &serde_json::Value {
105 &self.data
106 }
107}
108
/// Bookkeeping that the executor keeps alongside user data in a `GraphState`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct GraphMetadata {
    /// Number of node executions so far; the executor increments it just before each
    /// node runs.
    pub iterations: u32,
    /// Ids of executed nodes, in execution order.
    pub visited_nodes: Vec<String>,
    /// NOTE(review): not written by the executor code visible here; presumably set by
    /// callers or checkpoint integrations — confirm.
    pub checkpoint_id: Option<String>,
    /// Wall-clock run duration in milliseconds, measured from `started_at`; written by
    /// the executor when a run completes successfully (reaches END).
    pub execution_time_ms: u64,
    /// When the run started; `Graph::invoke` sets it to the current UTC time.
    pub started_at: Option<DateTime<Utc>>,
}

/// Outcome of a graph run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphResult {
    /// State at the point execution stopped.
    pub state: GraphState,
    /// How the run ended.
    pub status: GraphStatus,
    /// Human-readable reason when the run did not succeed; `None` on success.
    pub error: Option<String>,
}

/// How a graph run ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GraphStatus {
    /// Execution routed to END.
    Success,
    /// NOTE(review): not produced by the executor code visible here; node errors are
    /// returned as `Err` from `invoke` instead.
    Failed,
    /// Execution stopped after `config.max_iterations` node executions.
    MaxIterations,
    /// NOTE(review): not produced by the executor code visible here.
    Interrupted,
    /// NOTE(review): not produced by the executor code visible here.
    Paused,
}
149
/// A named node in a graph, pairing its id with the executor that runs it.
pub struct GraphNode {
    /// Node id; the builder inserts each node into the graph's map under this same id.
    pub id: String,
    /// Async function invoked with the current state when this node runs.
    pub executor: Arc<dyn NodeFn>,
}

// Hand-written because `dyn NodeFn` is not `Debug`; only the id is printed.
impl std::fmt::Debug for GraphNode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GraphNode").field("id", &self.id).finish()
    }
}

/// Async node behavior: consumes the current state and returns the next one.
///
/// Returning `Err` aborts the run; the graph executor propagates the error to its caller.
#[async_trait::async_trait]
pub trait NodeFn: Send + Sync {
    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError>;
}

/// Adapter that lets a closure `Fn(GraphState) -> impl Future<Output = Result<..>>`
/// act as a [`NodeFn`].
pub struct FnNode<F>(pub F);

#[async_trait::async_trait]
impl<F, Fut> NodeFn for FnNode<F>
where
    F: Fn(GraphState) -> Fut + Send + Sync,
    Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send,
{
    async fn call(&self, state: GraphState) -> Result<GraphState, CrewError> {
        (self.0)(state).await
    }
}
184
/// A transition out of a node.
#[derive(Clone)]
pub enum GraphEdge {
    /// Unconditional transition from `from` to `to`.
    Direct { from: String, to: String },
    /// Transition whose target is picked at run time by `router`, using the state
    /// produced by node `from`.
    Conditional {
        from: String,
        router: Arc<dyn EdgeRouter>,
    },
}

// Hand-written because `dyn EdgeRouter` is not `Debug`; routers are omitted.
impl std::fmt::Debug for GraphEdge {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Direct { from, to } => f
                .debug_struct("Direct")
                .field("from", from)
                .field("to", to)
                .finish(),
            Self::Conditional { from, .. } => {
                f.debug_struct("Conditional").field("from", from).finish()
            }
        }
    }
}

/// Chooses the id of the next node from the current state.
///
/// The returned id may be END. An id that matches no node makes the executor fail with
/// `TaskNotFound` when it tries to run it.
pub trait EdgeRouter: Send + Sync {
    fn route(&self, state: &GraphState) -> String;
}

/// Adapter that lets a closure `Fn(&GraphState) -> String` act as an [`EdgeRouter`].
pub struct FnRouter<F>(pub F);

impl<F> EdgeRouter for FnRouter<F>
where
    F: Fn(&GraphState) -> String + Send + Sync,
{
    fn route(&self, state: &GraphState) -> String {
        (self.0)(state)
    }
}
229
230pub struct ConditionRouter {
232 conditions: Vec<(Box<dyn Fn(&GraphState) -> bool + Send + Sync>, String)>,
233 default: String,
234}
235
236impl ConditionRouter {
237 pub fn new(default: impl Into<String>) -> Self {
239 Self {
240 conditions: Vec::new(),
241 default: default.into(),
242 }
243 }
244
245 pub fn when<F>(mut self, condition: F, target: impl Into<String>) -> Self
247 where
248 F: Fn(&GraphState) -> bool + Send + Sync + 'static,
249 {
250 self.conditions.push((Box::new(condition), target.into()));
251 self
252 }
253}
254
255impl EdgeRouter for ConditionRouter {
256 fn route(&self, state: &GraphState) -> String {
257 for (condition, target) in &self.conditions {
258 if condition(state) {
259 return target.clone();
260 }
261 }
262 self.default.clone()
263 }
264}
265
/// Snapshot of a run taken just before `next_node` executes, so the run can be resumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Store key; the graph executor uses `"{graph_id}_{iteration}"`.
    pub id: String,
    /// State as it was before `next_node` ran.
    pub state: GraphState,
    /// Node to execute first when resuming.
    pub next_node: String,
    /// When the snapshot was taken (UTC).
    pub created_at: DateTime<Utc>,
}

/// Persistence backend for [`Checkpoint`]s.
#[async_trait::async_trait]
pub trait CheckpointStore: Send + Sync {
    /// Persists `checkpoint` under `checkpoint.id`.
    async fn save(&self, checkpoint: Checkpoint) -> Result<(), CrewError>;
    /// Loads a checkpoint by id; `Ok(None)` means no such checkpoint exists.
    async fn load(&self, id: &str) -> Result<Option<Checkpoint>, CrewError>;
    /// Lists checkpoint ids belonging to `graph_id` (matching rules are up to the
    /// implementation).
    async fn list(&self, graph_id: &str) -> Result<Vec<String>, CrewError>;
    /// Removes a checkpoint by id.
    async fn delete(&self, id: &str) -> Result<(), CrewError>;
}
291
292#[derive(Default)]
294pub struct InMemoryCheckpointStore {
295 checkpoints: RwLock<HashMap<String, Checkpoint>>,
296}
297
298#[async_trait::async_trait]
299impl CheckpointStore for InMemoryCheckpointStore {
300 async fn save(&self, checkpoint: Checkpoint) -> Result<(), CrewError> {
301 self.checkpoints
302 .write()
303 .await
304 .insert(checkpoint.id.clone(), checkpoint);
305 Ok(())
306 }
307
308 async fn load(&self, id: &str) -> Result<Option<Checkpoint>, CrewError> {
309 Ok(self.checkpoints.read().await.get(id).cloned())
310 }
311
312 async fn list(&self, _graph_id: &str) -> Result<Vec<String>, CrewError> {
313 Ok(self.checkpoints.read().await.keys().cloned().collect())
314 }
315
316 async fn delete(&self, id: &str) -> Result<(), CrewError> {
317 self.checkpoints.write().await.remove(id);
318 Ok(())
319 }
320}
321
/// Reserved start-node id. Execution begins at the builder's entry node; this id is
/// used as a display label by `Graph::to_mermaid`.
pub const START: &str = "__start__";
/// Reserved node id meaning "stop": routing to it ends the run successfully. It is never
/// looked up in the node map and may be used as a direct-edge target without a node.
pub const END: &str = "__end__";
325
/// Execution settings for a graph run.
#[derive(Debug, Clone)]
pub struct GraphConfig {
    /// Upper bound on node executions in a single run.
    pub max_iterations: u32,
    /// Whether checkpoints are written during execution (also needs a store).
    pub checkpointing: bool,
    /// A checkpoint is written before a node runs whenever the iteration count is a
    /// multiple of this value.
    pub checkpoint_interval: u32,
    /// NOTE(review): not read by the executor code visible here.
    pub parallel_branches: bool,
    /// Per-node execution timeout in milliseconds; `None` means no timeout.
    pub node_timeout_ms: Option<u64>,
}

impl GraphConfig {
    /// Default cap on node executions per run.
    const DEFAULT_MAX_ITERATIONS: u32 = 100;
    /// Default spacing, in iterations, between checkpoints once checkpointing is on.
    const DEFAULT_CHECKPOINT_INTERVAL: u32 = 5;
}

impl Default for GraphConfig {
    /// 100 iterations, checkpointing off (interval 5 once enabled), `parallel_branches`
    /// off, and no node timeout.
    fn default() -> Self {
        Self {
            node_timeout_ms: None,
            parallel_branches: false,
            checkpoint_interval: Self::DEFAULT_CHECKPOINT_INTERVAL,
            checkpointing: false,
            max_iterations: Self::DEFAULT_MAX_ITERATIONS,
        }
    }
}
352
/// An executable graph of async nodes connected by direct and conditional edges.
///
/// Usually produced by `GraphBuilder::build`, which validates nodes and edges.
pub struct Graph {
    /// Graph id; also the prefix of checkpoint ids written during execution.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Nodes keyed by node id.
    pub nodes: HashMap<String, GraphNode>,
    /// Edges in insertion order; for a given source node the first matching edge wins.
    pub edges: Vec<GraphEdge>,
    /// Id of the node executed first.
    pub entry_node: String,
    /// Execution settings (iteration cap, checkpointing, timeouts).
    pub config: GraphConfig,
    /// Where checkpoints are saved/loaded; `None` disables saving and makes `resume` fail.
    pub checkpoint_store: Option<Arc<dyn CheckpointStore>>,
}

// Hand-written because nodes/edges/store hold trait objects. Shows node ids (in
// HashMap order); edges, config and the store are omitted.
impl std::fmt::Debug for Graph {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Graph")
            .field("id", &self.id)
            .field("name", &self.name)
            .field("nodes", &self.nodes.keys().collect::<Vec<_>>())
            .field("entry_node", &self.entry_node)
            .finish()
    }
}
381
382impl Graph {
383 pub async fn invoke(&self, initial_state: GraphState) -> Result<GraphResult, CrewError> {
385 let mut state = initial_state;
386 state.metadata.started_at = Some(Utc::now());
387 state.metadata.iterations = 0;
388
389 let mut current_node = self.entry_node.clone();
390
391 loop {
392 if state.metadata.iterations >= self.config.max_iterations {
394 return Ok(GraphResult {
395 state,
396 status: GraphStatus::MaxIterations,
397 error: Some(format!(
398 "Hit maximum iterations: {}",
399 self.config.max_iterations
400 )),
401 });
402 }
403
404 if current_node == END {
406 state.metadata.execution_time_ms = state
407 .metadata
408 .started_at
409 .map(|s| Utc::now().signed_duration_since(s).num_milliseconds() as u64)
410 .unwrap_or(0);
411 return Ok(GraphResult {
412 state,
413 status: GraphStatus::Success,
414 error: None,
415 });
416 }
417
418 let node = self.nodes.get(¤t_node).ok_or_else(|| {
420 CrewError::TaskNotFound(format!("Node not found: {}", current_node))
421 })?;
422
423 if self.config.checkpointing
425 && state
426 .metadata
427 .iterations
428 .is_multiple_of(self.config.checkpoint_interval)
429 {
430 if let Some(store) = &self.checkpoint_store {
431 let checkpoint = Checkpoint {
432 id: format!("{}_{}", self.id, state.metadata.iterations),
433 state: state.clone(),
434 next_node: current_node.clone(),
435 created_at: Utc::now(),
436 };
437 store.save(checkpoint).await?;
438 }
439 }
440
441 state.metadata.visited_nodes.push(current_node.clone());
443 state.metadata.iterations += 1;
444
445 state = match self.config.node_timeout_ms {
446 Some(timeout) => tokio::time::timeout(
447 std::time::Duration::from_millis(timeout),
448 node.executor.call(state),
449 )
450 .await
451 .map_err(|_| {
452 CrewError::ExecutionFailed(format!("Node {} timed out", current_node))
453 })??,
454 None => node.executor.call(state).await?,
455 };
456
457 current_node = self.find_next_node(¤t_node, &state)?;
459 }
460 }
461
462 pub async fn resume(&self, checkpoint_id: &str) -> Result<GraphResult, CrewError> {
464 let store = self.checkpoint_store.as_ref().ok_or_else(|| {
465 CrewError::InvalidConfiguration("Checkpointing not enabled".to_string())
466 })?;
467
468 let checkpoint = store.load(checkpoint_id).await?.ok_or_else(|| {
469 CrewError::TaskNotFound(format!("Checkpoint not found: {}", checkpoint_id))
470 })?;
471
472 let mut state = checkpoint.state;
474 let mut current_node = checkpoint.next_node;
475
476 loop {
477 if state.metadata.iterations >= self.config.max_iterations {
478 return Ok(GraphResult {
479 state,
480 status: GraphStatus::MaxIterations,
481 error: Some(format!(
482 "Hit maximum iterations: {}",
483 self.config.max_iterations
484 )),
485 });
486 }
487
488 if current_node == END {
489 state.metadata.execution_time_ms = state
490 .metadata
491 .started_at
492 .map(|s| Utc::now().signed_duration_since(s).num_milliseconds() as u64)
493 .unwrap_or(0);
494 return Ok(GraphResult {
495 state,
496 status: GraphStatus::Success,
497 error: None,
498 });
499 }
500
501 let node = self.nodes.get(¤t_node).ok_or_else(|| {
502 CrewError::TaskNotFound(format!("Node not found: {}", current_node))
503 })?;
504
505 state.metadata.visited_nodes.push(current_node.clone());
506 state.metadata.iterations += 1;
507
508 state = node.executor.call(state).await?;
509 current_node = self.find_next_node(¤t_node, &state)?;
510 }
511 }
512
513 pub fn find_next_node(&self, current: &str, state: &GraphState) -> Result<String, CrewError> {
515 for edge in &self.edges {
516 match edge {
517 GraphEdge::Direct { from, to } if from == current => {
518 return Ok(to.clone());
519 }
520 GraphEdge::Conditional { from, router } if from == current => {
521 return Ok(router.route(state));
522 }
523 _ => continue,
524 }
525 }
526
527 Ok(END.to_string())
529 }
530
531 pub fn stream(&self, initial_state: GraphState) -> GraphStream<'_> {
533 GraphStream {
534 graph: self,
535 state: Some(initial_state),
536 current_node: Some(self.entry_node.clone()),
537 finished: false,
538 }
539 }
540
541 pub fn to_mermaid(&self) -> String {
543 let mut lines = vec!["graph TD".to_string()];
544
545 for id in self.nodes.keys() {
546 let display_id = if id == START {
547 "START"
548 } else if id == END {
549 "END"
550 } else {
551 id
552 };
553 lines.push(format!(" {}[{}]", id.replace('-', "_"), display_id));
554 }
555
556 for edge in &self.edges {
557 match edge {
558 GraphEdge::Direct { from, to } => {
559 lines.push(format!(
560 " {} --> {}",
561 from.replace('-', "_"),
562 to.replace('-', "_")
563 ));
564 }
565 GraphEdge::Conditional { from, .. } => {
566 lines.push(format!(
567 " {} -.->|condition| ...",
568 from.replace('-', "_")
569 ));
570 }
571 }
572 }
573
574 lines.join("\n")
575 }
576}
577
578pub struct GraphStream<'a> {
580 graph: &'a Graph,
581 state: Option<GraphState>,
582 current_node: Option<String>,
583 finished: bool,
584}
585
586impl<'a> GraphStream<'a> {
587 pub async fn next(&mut self) -> Option<Result<(String, GraphState), CrewError>> {
589 if self.finished {
590 return None;
591 }
592
593 let current_node = self.current_node.take()?;
594 let mut state = self.state.take()?;
595
596 if current_node == END {
598 self.finished = true;
599 return Some(Ok((END.to_string(), state)));
600 }
601
602 if state.metadata.iterations >= self.graph.config.max_iterations {
604 self.finished = true;
605 return Some(Err(CrewError::ExecutionFailed(
606 "Max iterations reached".to_string(),
607 )));
608 }
609
610 let node = match self.graph.nodes.get(¤t_node) {
612 Some(n) => n,
613 None => {
614 self.finished = true;
615 return Some(Err(CrewError::TaskNotFound(current_node)));
616 }
617 };
618
619 state.metadata.visited_nodes.push(current_node.clone());
620 state.metadata.iterations += 1;
621
622 match node.executor.call(state).await {
623 Ok(new_state) => {
624 let next_node = match self.graph.find_next_node(¤t_node, &new_state) {
625 Ok(n) => n,
626 Err(e) => {
627 self.finished = true;
628 return Some(Err(e));
629 }
630 };
631
632 self.state = Some(new_state.clone());
633 self.current_node = Some(next_node);
634 Some(Ok((current_node, new_state)))
635 }
636 Err(e) => {
637 self.finished = true;
638 Some(Err(e))
639 }
640 }
641 }
642}
643
644pub struct GraphBuilder {
646 id: String,
647 name: String,
648 nodes: HashMap<String, GraphNode>,
649 edges: Vec<GraphEdge>,
650 entry_node: Option<String>,
651 config: GraphConfig,
652 checkpoint_store: Option<Arc<dyn CheckpointStore>>,
653}
654
655impl GraphBuilder {
656 pub fn new(id: impl Into<String>) -> Self {
658 let id = id.into();
659 Self {
660 name: id.clone(),
661 id,
662 nodes: HashMap::new(),
663 edges: Vec::new(),
664 entry_node: None,
665 config: GraphConfig::default(),
666 checkpoint_store: None,
667 }
668 }
669
670 pub fn name(mut self, name: impl Into<String>) -> Self {
672 self.name = name.into();
673 self
674 }
675
676 pub fn add_node<F, Fut>(mut self, id: impl Into<String>, func: F) -> Self
678 where
679 F: Fn(GraphState) -> Fut + Send + Sync + 'static,
680 Fut: std::future::Future<Output = Result<GraphState, CrewError>> + Send + 'static,
681 {
682 let id = id.into();
683 self.nodes.insert(
684 id.clone(),
685 GraphNode {
686 id: id.clone(),
687 executor: Arc::new(FnNode(func)),
688 },
689 );
690 self
691 }
692
693 pub fn add_node_executor(mut self, id: impl Into<String>, executor: Arc<dyn NodeFn>) -> Self {
695 let id = id.into();
696 self.nodes.insert(id.clone(), GraphNode { id, executor });
697 self
698 }
699
700 pub fn add_edge(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
702 self.edges.push(GraphEdge::Direct {
703 from: from.into(),
704 to: to.into(),
705 });
706 self
707 }
708
709 pub fn add_conditional_edge<F>(mut self, from: impl Into<String>, router: F) -> Self
711 where
712 F: Fn(&GraphState) -> String + Send + Sync + 'static,
713 {
714 self.edges.push(GraphEdge::Conditional {
715 from: from.into(),
716 router: Arc::new(FnRouter(router)),
717 });
718 self
719 }
720
721 pub fn add_conditional_edge_router(
723 mut self,
724 from: impl Into<String>,
725 router: ConditionRouter,
726 ) -> Self {
727 self.edges.push(GraphEdge::Conditional {
728 from: from.into(),
729 router: Arc::new(router),
730 });
731 self
732 }
733
734 pub fn set_entry(mut self, node_id: impl Into<String>) -> Self {
736 self.entry_node = Some(node_id.into());
737 self
738 }
739
740 pub fn max_iterations(mut self, max: u32) -> Self {
742 self.config.max_iterations = max;
743 self
744 }
745
746 pub fn with_checkpointing(mut self, store: Arc<dyn CheckpointStore>) -> Self {
748 self.config.checkpointing = true;
749 self.checkpoint_store = Some(store);
750 self
751 }
752
753 pub fn checkpoint_interval(mut self, interval: u32) -> Self {
755 self.config.checkpoint_interval = interval;
756 self
757 }
758
759 pub fn node_timeout_ms(mut self, timeout: u64) -> Self {
761 self.config.node_timeout_ms = Some(timeout);
762 self
763 }
764
765 pub fn build(self) -> Result<Graph, CrewError> {
767 let entry_node = self.entry_node.ok_or_else(|| {
768 CrewError::InvalidConfiguration("No entry node specified".to_string())
769 })?;
770
771 if !self.nodes.contains_key(&entry_node) {
772 return Err(CrewError::InvalidConfiguration(format!(
773 "Entry node '{}' not found",
774 entry_node
775 )));
776 }
777
778 for edge in &self.edges {
780 let from = match edge {
781 GraphEdge::Direct { from, .. } => from,
782 GraphEdge::Conditional { from, .. } => from,
783 };
784 if !self.nodes.contains_key(from) {
785 return Err(CrewError::InvalidConfiguration(format!(
786 "Edge source '{}' not found",
787 from
788 )));
789 }
790 if let GraphEdge::Direct { to, .. } = edge {
792 if to != END && !self.nodes.contains_key(to) {
793 return Err(CrewError::InvalidConfiguration(format!(
794 "Edge target '{}' not found",
795 to
796 )));
797 }
798 }
799 }
800
801 Ok(Graph {
802 id: self.id,
803 name: self.name,
804 nodes: self.nodes,
805 edges: self.edges,
806 entry_node,
807 config: self.config,
808 checkpoint_store: self.checkpoint_store,
809 })
810 }
811}
812
813pub struct StateGraph<S> {
815 graph: Graph,
816 _phantom: std::marker::PhantomData<S>,
817}
818
819impl<S: Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static> StateGraph<S> {
820 pub fn new(graph: Graph) -> Self {
822 Self {
823 graph,
824 _phantom: std::marker::PhantomData,
825 }
826 }
827
828 pub async fn invoke(&self, initial: S) -> Result<S, CrewError> {
830 let json = serde_json::to_value(&initial)
831 .map_err(|e| CrewError::ExecutionFailed(format!("Serialization error: {}", e)))?;
832 let state = GraphState::from_json(json);
833 let result = self.graph.invoke(state).await?;
834
835 if result.status != GraphStatus::Success {
836 return Err(CrewError::ExecutionFailed(
837 result
838 .error
839 .unwrap_or_else(|| "Graph execution failed".to_string()),
840 ));
841 }
842
843 serde_json::from_value(result.state.data)
844 .map_err(|e| CrewError::ExecutionFailed(format!("Deserialization error: {}", e)))
845 }
846}
847
// Unit tests for graph building, routing, iteration limits, checkpointing, Mermaid
// rendering and streaming.
#[cfg(test)]
mod tests {
    use super::*;

    // Two direct edges: step1 -> step2 -> END; both nodes run once.
    #[tokio::test]
    async fn test_simple_graph() {
        let graph = GraphBuilder::new("simple")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("step1_done", true);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                state.set("step2_done", true);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();

        assert_eq!(result.status, GraphStatus::Success);
        assert_eq!(result.state.get::<bool>("step1_done"), Some(true));
        assert_eq!(result.state.get::<bool>("step2_done"), Some(true));
        // One iteration per executed node.
        assert_eq!(result.state.metadata.iterations, 2);
    }

    // A conditional edge picks the branch from the "condition" flag in the state.
    #[tokio::test]
    async fn test_conditional_edge() {
        let graph = GraphBuilder::new("conditional")
            .add_node("check", |state: GraphState| async move { Ok(state) })
            .add_node("yes_path", |mut state: GraphState| async move {
                state.set("path", "yes");
                Ok(state)
            })
            .add_node("no_path", |mut state: GraphState| async move {
                state.set("path", "no");
                Ok(state)
            })
            .add_conditional_edge("check", |state| {
                if state.get::<bool>("condition").unwrap_or(false) {
                    "yes_path".to_string()
                } else {
                    "no_path".to_string()
                }
            })
            .add_edge("yes_path", END)
            .add_edge("no_path", END)
            .set_entry("check")
            .build()
            .unwrap();

        // Condition true -> yes branch.
        let mut state = GraphState::new();
        state.set("condition", true);
        let result = graph.invoke(state).await.unwrap();
        assert_eq!(result.state.get::<String>("path"), Some("yes".to_string()));

        // Condition false -> no branch.
        let mut state = GraphState::new();
        state.set("condition", false);
        let result = graph.invoke(state).await.unwrap();
        assert_eq!(result.state.get::<String>("path"), Some("no".to_string()));
    }

    // A self-loop that exits via a conditional edge once count reaches 5, well under the cap.
    #[tokio::test]
    async fn test_cycle_with_limit() {
        let graph = GraphBuilder::new("cycle")
            .add_node("increment", |mut state: GraphState| async move {
                let count: i32 = state.get("count").unwrap_or(0);
                state.set("count", count + 1);
                Ok(state)
            })
            .add_conditional_edge("increment", |state| {
                let count: i32 = state.get("count").unwrap_or(0);
                if count >= 5 {
                    END.to_string()
                } else {
                    "increment".to_string()
                }
            })
            .set_entry("increment")
            .max_iterations(100)
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();

        assert_eq!(result.status, GraphStatus::Success);
        assert_eq!(result.state.get::<i32>("count"), Some(5));
        assert_eq!(result.state.metadata.iterations, 5);
    }

    // An unconditional self-loop never reaches END and must stop at the iteration cap.
    #[tokio::test]
    async fn test_max_iterations_limit() {
        let graph = GraphBuilder::new("infinite")
            .add_node("loop", |state: GraphState| async move { Ok(state) })
            .add_edge("loop", "loop")
            .set_entry("loop")
            .max_iterations(10)
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();

        assert_eq!(result.status, GraphStatus::MaxIterations);
        assert_eq!(result.state.metadata.iterations, 10);
    }

    // Rules are tried in order; the first match wins and the default covers no match.
    #[tokio::test]
    async fn test_condition_router() {
        let router = ConditionRouter::new("default")
            .when(|s| s.get::<i32>("score").unwrap_or(0) >= 80, "excellent")
            .when(|s| s.get::<i32>("score").unwrap_or(0) >= 60, "good")
            .when(|s| s.get::<i32>("score").unwrap_or(0) >= 40, "pass");

        let mut state = GraphState::new();
        state.set("score", 85);
        assert_eq!(router.route(&state), "excellent");

        state.set("score", 65);
        assert_eq!(router.route(&state), "good");

        state.set("score", 30);
        assert_eq!(router.route(&state), "default");
    }

    // With interval 1 a checkpoint is saved before every node, so the store is non-empty.
    #[tokio::test]
    async fn test_checkpointing() {
        let store = Arc::new(InMemoryCheckpointStore::default());

        let graph = GraphBuilder::new("checkpoint_test")
            .add_node("step1", |mut state: GraphState| async move {
                state.set("step", 1);
                Ok(state)
            })
            .add_node("step2", |mut state: GraphState| async move {
                state.set("step", 2);
                Ok(state)
            })
            .add_edge("step1", "step2")
            .add_edge("step2", END)
            .set_entry("step1")
            .with_checkpointing(store.clone())
            .checkpoint_interval(1)
            .build()
            .unwrap();

        let result = graph.invoke(GraphState::new()).await.unwrap();
        assert_eq!(result.status, GraphStatus::Success);

        let checkpoints = store.list("checkpoint_test").await.unwrap();
        assert!(!checkpoints.is_empty());
    }

    // Smoke test: the Mermaid output has the header and mentions the node ids.
    #[test]
    fn test_mermaid_output() {
        let graph = GraphBuilder::new("mermaid_test")
            .add_node("start", |s| async { Ok(s) })
            .add_node("process", |s| async { Ok(s) })
            .add_node("end", |s| async { Ok(s) })
            .add_edge("start", "process")
            .add_edge("process", "end")
            .set_entry("start")
            .build()
            .unwrap();

        let mermaid = graph.to_mermaid();
        assert!(mermaid.contains("graph TD"));
        assert!(mermaid.contains("start"));
        assert!(mermaid.contains("process"));
    }

    // The stream yields each executed node id in order, then END.
    #[tokio::test]
    async fn test_stream_execution() {
        let graph = GraphBuilder::new("stream_test")
            .add_node("a", |mut s: GraphState| async move {
                s.set("a", true);
                Ok(s)
            })
            .add_node("b", |mut s: GraphState| async move {
                s.set("b", true);
                Ok(s)
            })
            .add_edge("a", "b")
            .add_edge("b", END)
            .set_entry("a")
            .build()
            .unwrap();

        let mut stream = graph.stream(GraphState::new());
        let mut steps = Vec::new();

        while let Some(result) = stream.next().await {
            let (node_id, _state) = result.unwrap();
            steps.push(node_id);
        }

        assert_eq!(steps, vec!["a", "b", END]);
    }
}