Skip to main content

serdes_ai_graph/
iter.rs

1//! Graph iteration support.
2
3use crate::error::GraphError;
4use crate::node::{BaseNode, NodeResult};
5use crate::state::{generate_run_id, GraphRunContext, GraphRunResult, GraphState};
6use std::marker::PhantomData;
7
8/// Iterator for stepping through a graph.
9pub struct GraphIter<'a, State, Deps, End>
10where
11    State: GraphState,
12    Deps: Clone + Send + Sync + 'static,
13    End: Clone + Send + Sync + 'static,
14{
15    ctx: GraphRunContext<State, Deps>,
16    current: Option<Box<dyn BaseNode<State, Deps, End>>>,
17    finished: bool,
18    result: Option<End>,
19    history: Vec<String>,
20    _phantom: PhantomData<&'a ()>,
21}
22
23impl<'a, State, Deps, End> GraphIter<'a, State, Deps, End>
24where
25    State: GraphState,
26    Deps: Clone + Send + Sync + 'static,
27    End: Clone + Send + Sync + 'static,
28{
29    /// Create a new graph iterator.
30    pub fn new<N: BaseNode<State, Deps, End> + Clone + 'static>(
31        start: N,
32        state: State,
33        deps: Deps,
34    ) -> Self {
35        let run_id = generate_run_id();
36        Self {
37            ctx: GraphRunContext::new(state, deps, run_id),
38            current: Some(Box::new(start)),
39            finished: false,
40            result: None,
41            history: Vec::new(),
42            _phantom: PhantomData,
43        }
44    }
45
46    /// Execute the next step.
47    pub async fn step(&mut self) -> Option<StepResult<State>> {
48        if self.finished {
49            return None;
50        }
51
52        let current = self.current.take()?;
53        self.ctx.increment_step();
54
55        let node_name = current.name().to_string();
56        self.history.push(node_name.clone());
57
58        match current.run(&mut self.ctx).await {
59            Ok(NodeResult::Next(next)) => {
60                self.current = Some(next);
61                Some(StepResult::Continue { node: node_name })
62            }
63            Ok(NodeResult::NextNamed(name)) => {
64                // Named transitions require external graph lookup
65                self.finished = true;
66                Some(StepResult::NamedTransition {
67                    node: node_name,
68                    next: name,
69                })
70            }
71            Ok(NodeResult::End(_end)) => {
72                self.finished = true;
73                Some(StepResult::Finished { node: node_name })
74            }
75            Err(e) => {
76                self.finished = true;
77                Some(StepResult::Error(e))
78            }
79        }
80    }
81
82    /// Get the current state.
83    pub fn state(&self) -> &State {
84        &self.ctx.state
85    }
86
87    /// Get mutable state.
88    pub fn state_mut(&mut self) -> &mut State {
89        &mut self.ctx.state
90    }
91
92    /// Get the current step number.
93    pub fn step_count(&self) -> u32 {
94        self.ctx.step
95    }
96
97    /// Check if finished.
98    pub fn is_finished(&self) -> bool {
99        self.finished
100    }
101
102    /// Get the history of visited nodes.
103    pub fn history(&self) -> &[String] {
104        &self.history
105    }
106
107    /// Consume and get the result.
108    pub fn into_result(self) -> Option<GraphRunResult<State, End>> {
109        self.result.map(|r| {
110            GraphRunResult::new(r, self.ctx.state, self.ctx.step, self.ctx.run_id)
111                .with_history(self.history)
112        })
113    }
114}
115
116/// Result of a single step.
117#[derive(Debug)]
118pub enum StepResult<State> {
119    /// Graph continues to next node.
120    Continue {
121        /// Node that was executed.
122        node: String,
123    },
124    /// Named transition (requires lookup).
125    NamedTransition {
126        /// Node that was executed.
127        node: String,
128        /// Name of next node.
129        next: String,
130    },
131    /// Graph finished.
132    Finished {
133        /// Final node.
134        node: String,
135    },
136    /// Error occurred.
137    Error(GraphError),
138    /// Phantom state type holder.
139    #[doc(hidden)]
140    _State(PhantomData<State>),
141}
142
143impl<State> StepResult<State> {
144    /// Check if this step finished the graph.
145    pub fn is_finished(&self) -> bool {
146        matches!(self, Self::Finished { .. })
147    }
148
149    /// Check if this step had an error.
150    pub fn is_error(&self) -> bool {
151        matches!(self, Self::Error(_))
152    }
153
154    /// Get the node name if applicable.
155    pub fn node(&self) -> Option<&str> {
156        match self {
157            Self::Continue { node } => Some(node),
158            Self::NamedTransition { node, .. } => Some(node),
159            Self::Finished { node } => Some(node),
160            _ => None,
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder state used only to instantiate `StepResult<State>`.
    #[derive(Debug, Clone, Default)]
    struct TestState {
        _value: i32,
    }

    #[test]
    fn test_step_result_is_finished() {
        let step = StepResult::<TestState>::Finished {
            node: String::from("test"),
        };
        assert!(step.is_finished());
    }

    #[test]
    fn test_step_result_is_error() {
        let step = StepResult::<TestState>::Error(GraphError::NoEntryNode);
        assert!(step.is_error());
    }

    #[test]
    fn test_step_result_node() {
        let step = StepResult::<TestState>::Continue {
            node: String::from("my_node"),
        };
        assert_eq!(step.node(), Some("my_node"));
    }
}