Skip to main content

serdes_ai_graph/
state.rs

1//! Graph state types.
2
3use serde::{de::DeserializeOwned, Serialize};
4use std::fmt::Debug;
5
/// Marker trait for values that can serve as graph state.
///
/// A state type must be cloneable, safe to send and share across threads,
/// printable with `{:?}`, and `'static` (it holds no non-`'static` borrows).
/// State that is persisted should additionally be serializable.
pub trait GraphState: Clone + Send + Sync + Debug + 'static {}

// Every type satisfying the bounds is automatically a `GraphState`;
// no manual impls are ever required.
impl<T: Clone + Send + Sync + Debug + 'static> GraphState for T {}
14
/// Context passed to nodes during execution.
///
/// Bundles the current state and dependencies with the bookkeeping used to
/// enforce a per-run step budget.
#[derive(Debug, Clone)]
pub struct GraphRunContext<State, Deps = ()> {
    /// Current state.
    pub state: State,
    /// Dependencies.
    pub deps: Deps,
    /// Number of steps taken so far (starts at 0).
    pub step: u32,
    /// Unique run identifier.
    pub run_id: String,
    /// Step budget; once `step` reaches this value,
    /// [`is_max_steps_reached`](Self::is_max_steps_reached) returns `true`.
    pub max_steps: u32,
}

impl<State, Deps> GraphRunContext<State, Deps> {
    /// Step budget used by [`new`](Self::new) when none is set explicitly.
    pub const DEFAULT_MAX_STEPS: u32 = 100;

    /// Create a new context at step 0 with the default step budget
    /// ([`DEFAULT_MAX_STEPS`](Self::DEFAULT_MAX_STEPS)).
    pub fn new(state: State, deps: Deps, run_id: impl Into<String>) -> Self {
        Self {
            state,
            deps,
            step: 0,
            run_id: run_id.into(),
            max_steps: Self::DEFAULT_MAX_STEPS,
        }
    }

    /// Set the maximum number of steps, returning the updated context.
    #[must_use]
    pub fn with_max_steps(mut self, max: u32) -> Self {
        self.max_steps = max;
        self
    }

    /// Increment the step counter.
    ///
    /// Saturates at `u32::MAX`. A plain `+= 1` would panic in debug builds and
    /// wrap to 0 in release builds, silently resetting the step budget.
    pub fn increment_step(&mut self) {
        self.step = self.step.saturating_add(1);
    }

    /// Return `true` once `step` has reached or passed `max_steps`.
    pub fn is_max_steps_reached(&self) -> bool {
        self.step >= self.max_steps
    }
}
58
59impl<State: Default, Deps: Default> Default for GraphRunContext<State, Deps> {
60    fn default() -> Self {
61        Self {
62            state: State::default(),
63            deps: Deps::default(),
64            step: 0,
65            run_id: generate_run_id(),
66            max_steps: 100,
67        }
68    }
69}
70
/// Outcome of a graph run.
#[derive(Debug, Clone)]
pub struct GraphRunResult<State, End = ()> {
    /// Final result value.
    pub result: End,
    /// State as it stood when the run finished.
    pub state: State,
    /// How many steps were executed.
    pub steps: u32,
    /// Names of visited nodes; empty unless supplied via `with_history`.
    pub history: Vec<String>,
    /// Identifier of the run that produced this result.
    pub run_id: String,
}

impl<State, End> GraphRunResult<State, End> {
    /// Build a result with an empty visit history.
    pub fn new(result: End, state: State, steps: u32, run_id: impl Into<String>) -> Self {
        let run_id: String = run_id.into();
        Self {
            history: Vec::default(),
            run_id,
            result,
            state,
            steps,
        }
    }

    /// Replace the visit history, consuming and returning the result.
    pub fn with_history(self, history: Vec<String>) -> Self {
        Self { history, ..self }
    }
}
104
/// Generate a run ID that is unique within the current process.
///
/// The ID has the form `run-<hex>`, where `<hex>` is normally the wall-clock
/// time in nanoseconds since the Unix epoch. Clock resolution is often coarser
/// than a nanosecond, so two calls can observe the same instant. To prevent
/// collisions, each value is forced to be strictly greater than the previous
/// one issued by this process.
///
/// If the system clock reads earlier than the epoch, the timestamp falls back
/// to 0. The IDs remain unique because of the monotonic bump.
pub fn generate_run_id() -> String {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::SystemTime;

    // Last value handed out by this process; issued values strictly increase.
    static LAST_ISSUED: AtomicU64 = AtomicU64::new(0);

    // u64 nanoseconds cover dates up to roughly the year 2554.
    let now = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .map(|elapsed| u64::try_from(elapsed.as_nanos()).unwrap_or(u64::MAX))
        .unwrap_or_default();

    // Use the current time unless it would not advance past the last issued value.
    let next_after = |last: u64| now.max(last.saturating_add(1));

    // The closure always returns `Some`, so `Ok` and `Err` both carry the
    // previous value; recomputing from it yields exactly the value stored.
    let previous = LAST_ISSUED
        .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |last| {
            Some(next_after(last))
        })
        .unwrap_or_else(|last| last);

    format!("run-{:x}", next_after(previous))
}
114
115/// Trait for serializable state (for persistence).
116pub trait PersistableState: GraphState + Serialize + DeserializeOwned {}
117
118impl<T> PersistableState for T where T: GraphState + Serialize + DeserializeOwned {}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    #[test]
    fn test_graph_state_trait() {
        // Compile-time check that a derive-only struct satisfies `GraphState`.
        fn assert_graph_state<T: GraphState>(_: &T) {}

        let state = TestState { value: 42 };
        assert_graph_state(&state);
        let cloned = state.clone();
        assert_eq!(cloned.value, 42);
    }

    #[test]
    fn test_run_context() {
        let mut ctx = GraphRunContext::new(TestState { value: 0 }, (), "test-run");

        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.run_id, "test-run");
        ctx.increment_step();
        assert_eq!(ctx.step, 1);
    }

    #[test]
    fn test_max_steps() {
        let mut ctx = GraphRunContext::new(TestState::default(), (), "test").with_max_steps(5);
        assert_eq!(ctx.max_steps, 5);

        // The budget is not reached until the fifth increment.
        for _ in 0..4 {
            ctx.increment_step();
            assert!(!ctx.is_max_steps_reached());
        }
        ctx.increment_step();
        assert!(ctx.is_max_steps_reached());
    }

    #[test]
    fn test_default_context() {
        let ctx: GraphRunContext<TestState> = GraphRunContext::default();

        assert_eq!(ctx.step, 0);
        assert_eq!(ctx.max_steps, 100);
        assert_eq!(ctx.state.value, 0);
        assert!(ctx.run_id.starts_with("run-"));
    }

    #[test]
    fn test_run_result_history() {
        let result = GraphRunResult::new("end", TestState { value: 3 }, 2, "run-x")
            .with_history(vec!["start".to_string(), "finish".to_string()]);

        assert_eq!(result.result, "end");
        assert_eq!(result.state.value, 3);
        assert_eq!(result.steps, 2);
        assert_eq!(result.run_id, "run-x");
        assert_eq!(result.history, vec!["start", "finish"]);
    }

    #[test]
    fn test_generate_run_id() {
        let id = generate_run_id();
        let hex = id.strip_prefix("run-").expect("run ID must start with `run-`");

        assert!(!hex.is_empty());
        assert!(u128::from_str_radix(hex, 16).is_ok());
    }
}