Skip to main content

serdes_ai_graph/
node.rs

1//! Graph node types.
2
3use crate::error::GraphResult;
4use crate::state::{GraphRunContext, GraphState};
5use async_trait::async_trait;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10/// Result of a node execution.
11pub enum NodeResult<State, Deps, End> {
12    /// Continue to another node.
13    Next(Box<dyn BaseNode<State, Deps, End>>),
14    /// Continue to a named node.
15    NextNamed(String),
16    /// End the graph with a result.
17    End(End),
18}
19
20impl<State, Deps, End> NodeResult<State, Deps, End> {
21    /// Create a Next result with a node.
22    pub fn next<N: BaseNode<State, Deps, End> + 'static>(node: N) -> Self {
23        Self::Next(Box::new(node))
24    }
25
26    /// Create a NextNamed result.
27    pub fn next_named(name: impl Into<String>) -> Self {
28        Self::NextNamed(name.into())
29    }
30
31    /// Create an End result.
32    pub fn end(value: End) -> Self {
33        Self::End(value)
34    }
35}
36
/// Marker wrapping the final value produced when a graph run ends.
///
/// `Default` is derived, so `End<T>: Default` whenever `T: Default` (yielding `End(T::default())`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct End<T>(pub T);

impl<T> End<T> {
    /// Wraps `value` in an `End` marker.
    pub fn new(value: T) -> Self {
        Self(value)
    }

    /// Consumes the marker and returns the wrapped value.
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Borrows the wrapped value.
    pub fn value(&self) -> &T {
        &self.0
    }
}

/// Allows `value.into()` / `End::from(value)` as an alternative to [`End::new`].
impl<T> From<T> for End<T> {
    fn from(value: T) -> Self {
        Self(value)
    }
}
63
64/// Base trait for all graph nodes.
65#[async_trait]
66pub trait BaseNode<State, Deps = (), End = ()>: Send + Sync {
67    /// Get the node type name for debugging.
68    fn type_name(&self) -> &'static str {
69        std::any::type_name::<Self>()
70    }
71
72    /// Get a human-readable name for this node.
73    fn name(&self) -> &str {
74        self.type_name()
75    }
76
77    /// Execute this node.
78    async fn run(
79        &self,
80        ctx: &mut GraphRunContext<State, Deps>,
81    ) -> GraphResult<NodeResult<State, Deps, End>>;
82}
83
84/// Node trait alias for simple state-only nodes.
85#[async_trait]
86pub trait Node<State: GraphState>: Send + Sync {
87    /// Execute the node and return updated state.
88    async fn execute(&self, state: State) -> GraphResult<State>;
89
90    /// Get the node name.
91    fn name(&self) -> &str;
92}
93
94/// A node that runs a function.
95pub struct FunctionNode<State, F, Fut>
96where
97    F: Fn(State) -> Fut + Send + Sync,
98    Fut: Future<Output = GraphResult<State>> + Send,
99{
100    name: String,
101    func: F,
102    _phantom: PhantomData<State>,
103}
104
105impl<State, F, Fut> FunctionNode<State, F, Fut>
106where
107    F: Fn(State) -> Fut + Send + Sync,
108    Fut: Future<Output = GraphResult<State>> + Send,
109{
110    /// Create a new function node.
111    pub fn new(name: impl Into<String>, func: F) -> Self {
112        Self {
113            name: name.into(),
114            func,
115            _phantom: PhantomData,
116        }
117    }
118}
119
120#[async_trait]
121impl<State, F, Fut> Node<State> for FunctionNode<State, F, Fut>
122where
123    State: GraphState,
124    F: Fn(State) -> Fut + Send + Sync,
125    Fut: Future<Output = GraphResult<State>> + Send,
126{
127    async fn execute(&self, state: State) -> GraphResult<State> {
128        (self.func)(state).await
129    }
130
131    fn name(&self) -> &str {
132        &self.name
133    }
134}
135
/// A node that pairs an agent with a function folding the agent into the state.
///
/// The agent is held behind an [`Arc`]. The update function receives the current state and a
/// reference to the agent, and returns the new state.
pub struct AgentNode<State, Agent, UpdateFn>
where
    UpdateFn: Fn(State, &Agent) -> State + Send + Sync,
{
    name: String,
    agent: Arc<Agent>,
    update_state: UpdateFn,
    _phantom: PhantomData<State>,
}

impl<State, Agent, UpdateFn> AgentNode<State, Agent, UpdateFn>
where
    UpdateFn: Fn(State, &Agent) -> State + Send + Sync,
{
    /// Creates a new agent node; `agent` is moved into an [`Arc`].
    pub fn new(name: impl Into<String>, agent: Agent, update_state: UpdateFn) -> Self {
        Self {
            name: name.into(),
            agent: Arc::new(agent),
            update_state,
            _phantom: PhantomData,
        }
    }

    /// Returns the name this node was created with.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns a reference to the wrapped agent.
    pub fn agent(&self) -> &Agent {
        &self.agent
    }

    /// Applies the stored update function to `state` together with this node's agent and
    /// returns the resulting state.
    pub fn update_state(&self, state: State) -> State {
        // Parenthesised to call the stored closure field rather than this method.
        (self.update_state)(state, &*self.agent)
    }
}
167
/// A node that picks the name of the next node from the current state.
pub struct RouterNode<State, F>
where
    F: Fn(&State) -> String + Send + Sync,
{
    name: String,
    router: F,
    _phantom: PhantomData<State>,
}

impl<State, F> RouterNode<State, F>
where
    F: Fn(&State) -> String + Send + Sync,
{
    /// Creates a new router node that uses `router` to choose the next node name.
    pub fn new(name: impl Into<String>, router: F) -> Self {
        Self {
            name: name.into(),
            router,
            _phantom: PhantomData,
        }
    }

    /// Returns the name this router was created with.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the next node name chosen by the routing function for `state`.
    pub fn route(&self, state: &State) -> String {
        (self.router)(state)
    }
}
197
198/// A conditional node that branches based on state.
199#[allow(dead_code)]
200pub struct ConditionalNode<State, Cond, Then, Else>
201where
202    Cond: Fn(&State) -> bool + Send + Sync,
203    Then: BaseNode<State> + 'static,
204    Else: BaseNode<State> + 'static,
205{
206    name: String,
207    condition: Cond,
208    then_node: Box<Then>,
209    else_node: Box<Else>,
210    _phantom: PhantomData<State>,
211}
212
213impl<State, Cond, Then, Else> ConditionalNode<State, Cond, Then, Else>
214where
215    Cond: Fn(&State) -> bool + Send + Sync,
216    Then: BaseNode<State> + 'static,
217    Else: BaseNode<State> + 'static,
218{
219    /// Create a new conditional node.
220    pub fn new(name: impl Into<String>, condition: Cond, then_node: Then, else_node: Else) -> Self {
221        Self {
222            name: name.into(),
223            condition,
224            then_node: Box::new(then_node),
225            else_node: Box::new(else_node),
226            _phantom: PhantomData,
227        }
228    }
229}
230
231/// Node definition for registration in a graph.
232pub struct NodeDef<State, Deps = (), End = ()> {
233    /// Node name.
234    pub name: String,
235    /// The node implementation.
236    pub node: Box<dyn BaseNode<State, Deps, End>>,
237}
238
239impl<State, Deps, End> NodeDef<State, Deps, End> {
240    /// Create a new node definition.
241    pub fn new<N: BaseNode<State, Deps, End> + 'static>(name: impl Into<String>, node: N) -> Self {
242        Self {
243            name: name.into(),
244            node: Box::new(node),
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    #[test]
    fn test_end_marker() {
        let end = End::new(42);
        assert_eq!(end.value(), &42);
        assert_eq!(end.into_inner(), 42);
    }

    #[test]
    fn test_end_default() {
        assert_eq!(End::<i32>::default(), End(0));
    }

    #[test]
    fn test_node_result_variants() {
        // Previously these were only constructed; now the produced variant and payload are checked.
        let next_named: NodeResult<TestState, (), i32> = NodeResult::next_named("next");
        assert!(matches!(next_named, NodeResult::NextNamed(ref name) if name == "next"));

        let end: NodeResult<TestState, (), i32> = NodeResult::end(42);
        assert!(matches!(end, NodeResult::End(42)));
    }

    #[test]
    fn test_router_node() {
        let router = RouterNode::new("router", |state: &TestState| {
            if state.value > 0 {
                "positive".to_string()
            } else {
                "negative".to_string()
            }
        });

        assert_eq!(router.route(&TestState { value: 1 }), "positive");
        assert_eq!(router.route(&TestState { value: -1 }), "negative");
    }
}