Skip to main content

rust_langgraph/
nodes.rs

//! Node abstraction for graph execution.
//!
//! Nodes are the computational units in a LangGraph. Each node takes state
//! as input and produces updated state as output.

6use crate::config::Config;
7use crate::errors::Result;
8use crate::state::State;
9use async_trait::async_trait;
10use std::fmt::Debug;
11use std::future::Future;
12use std::sync::Arc;
13
14/// The core trait for graph nodes.
15///
16/// A node is a unit of computation that takes state and produces
17/// updated state. Nodes can be async functions, closures, or
18/// custom types implementing this trait.
19///
20/// # Example
21///
22/// ```rust
23/// use rust_langgraph::{Node, Config, Error};
24/// use async_trait::async_trait;
25///
26/// #[derive(Clone)]
27/// struct MyState {
28///     count: i32,
29/// }
30///
31/// struct IncrementNode;
32///
33/// #[async_trait]
34/// impl Node<MyState> for IncrementNode {
35///     async fn invoke(&self, mut state: MyState, _config: &Config) -> Result<MyState, Error> {
36///         state.count += 1;
37///         Ok(state)
38///     }
39/// }
40/// ```
41#[async_trait]
42pub trait Node<S: State>: Send + Sync {
43    /// Execute the node with the given state and configuration.
44    ///
45    /// # Arguments
46    ///
47    /// * `state` - The current state
48    /// * `config` - Execution configuration
49    ///
50    /// # Returns
51    ///
52    /// The updated state or an error
53    async fn invoke(&self, state: S, config: &Config) -> Result<S>;
54}
55
56// Implement Node for async closures
57#[async_trait]
58impl<S, F, Fut> Node<S> for F
59where
60    S: State,
61    F: Fn(S, &Config) -> Fut + Send + Sync,
62    Fut: Future<Output = Result<S>> + Send,
63{
64    async fn invoke(&self, state: S, config: &Config) -> Result<S> {
65        self(state, config).await
66    }
67}
68
69/// Type alias for boxed nodes
70pub type NodeBox<S> = Box<dyn Node<S>>;
71
72/// Type alias for arc'd nodes (more efficient for shared ownership)
73pub type NodeArc<S> = Arc<dyn Node<S>>;
74
75/// A node in the Pregel execution engine.
76///
77/// PregelNode wraps a user's node with metadata about which channels
78/// it reads from and writes to, and what triggers it.
79#[derive(Clone)]
80pub struct PregelNode<S: State> {
81    /// The name of this node
82    pub name: String,
83
84    /// The channels this node reads from
85    pub channels: Vec<String>,
86
87    /// The channels that trigger this node when written to
88    pub triggers: Vec<String>,
89
90    /// The actual node computation
91    pub bound: NodeArc<S>,
92
93    /// The channels this node writes to
94    pub writers: Vec<ChannelWrite>,
95}
96
97impl<S: State> Debug for PregelNode<S> {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("PregelNode")
100            .field("name", &self.name)
101            .field("channels", &self.channels)
102            .field("triggers", &self.triggers)
103            .field("bound", &"<node>")
104            .field("writers", &self.writers)
105            .finish()
106    }
107}
108
109impl<S: State> PregelNode<S> {
110    /// Create a new PregelNode
111    pub fn new(
112        name: impl Into<String>,
113        channels: Vec<String>,
114        triggers: Vec<String>,
115        bound: NodeArc<S>,
116        writers: Vec<ChannelWrite>,
117    ) -> Self {
118        Self {
119            name: name.into(),
120            channels,
121            triggers,
122            bound,
123            writers,
124        }
125    }
126
127    /// Create a PregelNode from a concrete node implementation
128    pub fn from_node(
129        name: impl Into<String>,
130        channels: Vec<String>,
131        triggers: Vec<String>,
132        bound: impl Node<S> + 'static,
133        writers: Vec<ChannelWrite>,
134    ) -> Self {
135        Self {
136            name: name.into(),
137            channels,
138            triggers,
139            bound: Arc::new(bound),
140            writers,
141        }
142    }
143
144    /// Check if this node is triggered by the given channel writes
145    pub fn is_triggered(&self, written_channels: &[String]) -> bool {
146        self.triggers.iter().any(|t| written_channels.contains(t))
147    }
148}
149
/// Specification for writing to a channel after node execution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelWrite {
    /// The channel to write to
    pub channel: String,

    /// Whether to skip writing if the value is None (`true` when built via
    /// [`ChannelWrite::new`])
    pub skip_none: bool,

    /// Optional mapper function name
    pub mapper: Option<String>,
}

impl ChannelWrite {
    /// Creates a write specification targeting `channel`.
    ///
    /// Defaults: `skip_none` is `true` and no mapper is set.
    #[must_use]
    pub fn new(channel: impl Into<String>) -> Self {
        Self {
            channel: channel.into(),
            skip_none: true,
            mapper: None,
        }
    }

    /// Sets whether None values are skipped instead of written.
    #[must_use]
    pub fn with_skip_none(mut self, skip: bool) -> Self {
        self.skip_none = skip;
        self
    }

    /// Sets the mapper function name (stored in `mapper`).
    ///
    /// Completes the builder so every field can be configured fluently.
    #[must_use]
    pub fn with_mapper(mut self, mapper: impl Into<String>) -> Self {
        self.mapper = Some(mapper.into());
        self
    }
}
179
180/// Helper to create a simple node from an async function
181pub fn node_fn<S, F, Fut>(f: F) -> impl Node<S>
182where
183    S: State,
184    F: Fn(S, &Config) -> Fut + Send + Sync + 'static,
185    Fut: Future<Output = Result<S>> + Send + 'static,
186{
187    f
188}
189
190/// Helper to create a node that doesn't use config
191pub fn simple_node<S, F, Fut>(f: F) -> impl Node<S>
192where
193    S: State,
194    F: Fn(S) -> Fut + Send + Sync + 'static,
195    Fut: Future<Output = Result<S>> + Send + 'static,
196{
197    move |state: S, _config: &Config| f(state)
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::State as StateTrait;
    use serde::{Deserialize, Serialize};

    /// Minimal state used by the tests; `merge` sums the counters.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct TestState {
        count: i32,
    }

    impl StateTrait for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    #[tokio::test]
    async fn test_node_from_closure() {
        let node = |mut state: TestState, _config: &Config| async move {
            state.count += 1;
            Ok(state)
        };

        let state = TestState { count: 0 };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.count, 1);
    }

    #[tokio::test]
    async fn test_node_fn() {
        let node = node_fn(|mut state: TestState, _config: &Config| async move {
            state.count -= 1;
            Ok(state)
        });

        let result = node
            .invoke(TestState { count: 3 }, &Config::default())
            .await
            .unwrap();
        assert_eq!(result.count, 2);
    }

    #[tokio::test]
    async fn test_simple_node() {
        let node = simple_node(|mut state: TestState| async move {
            state.count += 10;
            Ok(state)
        });

        let state = TestState { count: 5 };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.count, 15);
    }

    /// A hand-written node that doubles the counter.
    struct CustomNode;

    #[async_trait]
    impl Node<TestState> for CustomNode {
        async fn invoke(&self, mut state: TestState, _config: &Config) -> Result<TestState> {
            state.count *= 2;
            Ok(state)
        }
    }

    #[tokio::test]
    async fn test_custom_node() {
        let node = CustomNode;
        let state = TestState { count: 5 };
        let result = node.invoke(state, &Config::default()).await.unwrap();
        assert_eq!(result.count, 10);
    }

    #[tokio::test]
    async fn test_pregel_node_invokes_bound_through_clone() {
        let node = PregelNode::from_node(
            "double",
            vec![],
            vec![],
            CustomNode,
            vec![ChannelWrite::new("out")],
        );
        // A clone shares the same underlying node via the `Arc`.
        let cloned = node.clone();
        let result = cloned
            .bound
            .invoke(TestState { count: 4 }, &Config::default())
            .await
            .unwrap();
        assert_eq!(result.count, 8);
        assert_eq!(cloned.name, "double");
        assert_eq!(cloned.writers.len(), 1);
    }

    #[test]
    fn test_pregel_node_is_triggered() {
        let node = PregelNode::from_node(
            "test",
            vec!["in".to_string()],
            vec!["trigger_a".to_string(), "trigger_b".to_string()],
            |state: TestState, _: &Config| async move { Ok(state) },
            vec![],
        );

        assert!(node.is_triggered(&["trigger_a".to_string()]));
        assert!(node.is_triggered(&["trigger_b".to_string()]));
        assert!(node.is_triggered(&["trigger_a".to_string(), "other".to_string()]));
        assert!(!node.is_triggered(&["other".to_string()]));
        assert!(!node.is_triggered(&[]));
    }

    #[test]
    fn test_pregel_node_debug_hides_bound() {
        let node = PregelNode::from_node("dbg", vec![], vec![], CustomNode, vec![]);
        let rendered = format!("{:?}", node);
        assert!(rendered.contains("PregelNode"));
        assert!(rendered.contains("\"dbg\""));
        assert!(rendered.contains("<node>"));
    }

    #[test]
    fn test_channel_write() {
        let write = ChannelWrite::new("output").with_skip_none(false);
        assert_eq!(write.channel, "output");
        assert!(!write.skip_none);
    }

    #[test]
    fn test_channel_write_defaults() {
        let write = ChannelWrite::new("output");
        assert!(write.skip_none);
        assert!(write.mapper.is_none());
    }
}
279        let write = ChannelWrite::new("output").with_skip_none(false);
280        assert_eq!(write.channel, "output");
281        assert!(!write.skip_none);
282    }
283}