rrag_graph/
nodes.rs

1//! # Graph Node Implementations
2//!
3//! This module provides various types of nodes that can be used in workflow graphs,
4//! including agent nodes, tool nodes, condition nodes, and transform nodes.
5
6pub mod agent;
7pub mod condition;
8pub mod tool;
9pub mod transform;
10
11// Re-export node types
12pub use agent::{AgentNode, AgentNodeConfig};
13pub use condition::{ConditionNode, ConditionNodeConfig};
14pub use tool::{ToolNode, ToolNodeConfig};
15pub use transform::{TransformNode, TransformNodeConfig};
16
17use crate::core::NodeId;
18use crate::{RGraphError, RGraphResult};
19use std::collections::HashMap;
20
21#[cfg(feature = "serde")]
22use serde::{Deserialize, Serialize};
23
/// Base configuration shared by all node types.
///
/// Construct with [`NodeConfig::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeConfig {
    /// Identifier of the node.
    pub id: NodeId,

    /// Human-readable display name.
    ///
    /// The default `NodeBuilder::validate_config` rejects an empty name.
    pub name: String,

    /// Optional free-form description.
    pub description: Option<String>,

    /// Input key mappings, keyed by graph state key and mapping to this node's
    /// input key (`state_key -> node_input_key`).
    pub input_mappings: HashMap<String, String>,

    /// Output key mappings, keyed by this node's output key and mapping to a
    /// graph state key (`node_output_key -> state_key`).
    pub output_mappings: HashMap<String, String>,

    /// Node-specific configuration as arbitrary JSON; `Value::Null` when unset.
    pub config: serde_json::Value,

    /// Whether this node can be retried on failure.
    pub retryable: bool,

    /// Maximum number of retry attempts.
    ///
    /// NOTE(review): presumably only consulted when `retryable` is `true` —
    /// confirm against the executor.
    pub max_retries: usize,

    /// Tags for organizing and filtering nodes.
    pub tags: Vec<String>,
}
55
56impl NodeConfig {
57    /// Create a new node configuration
58    pub fn new(id: impl Into<NodeId>, name: impl Into<String>) -> Self {
59        Self {
60            id: id.into(),
61            name: name.into(),
62            description: None,
63            input_mappings: HashMap::new(),
64            output_mappings: HashMap::new(),
65            config: serde_json::Value::Null,
66            retryable: false,
67            max_retries: 0,
68            tags: Vec::new(),
69        }
70    }
71
72    /// Set the description
73    pub fn with_description(mut self, description: impl Into<String>) -> Self {
74        self.description = Some(description.into());
75        self
76    }
77
78    /// Add an input mapping
79    pub fn with_input_mapping(
80        mut self,
81        state_key: impl Into<String>,
82        node_input_key: impl Into<String>,
83    ) -> Self {
84        self.input_mappings
85            .insert(state_key.into(), node_input_key.into());
86        self
87    }
88
89    /// Add an output mapping
90    pub fn with_output_mapping(
91        mut self,
92        node_output_key: impl Into<String>,
93        state_key: impl Into<String>,
94    ) -> Self {
95        self.output_mappings
96            .insert(node_output_key.into(), state_key.into());
97        self
98    }
99
100    /// Set the configuration
101    pub fn with_config(mut self, config: serde_json::Value) -> Self {
102        self.config = config;
103        self
104    }
105
106    /// Make the node retryable
107    pub fn with_retries(mut self, max_retries: usize) -> Self {
108        self.retryable = true;
109        self.max_retries = max_retries;
110        self
111    }
112
113    /// Add tags
114    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
115        self.tags = tags;
116        self
117    }
118
119    /// Add a single tag
120    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
121        self.tags.push(tag.into());
122        self
123    }
124}
125
/// Descriptive metadata about a node implementation.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeMetadata {
    /// Identifier of the node.
    pub id: NodeId,

    /// Display name of the node.
    pub name: String,

    /// Optional human-readable description.
    pub description: Option<String>,

    /// Keys the node expects as input.
    pub input_keys: Vec<String>,

    /// Keys the node produces as output.
    pub output_keys: Vec<String>,

    /// Name of the node type.
    ///
    /// NOTE(review): presumably matches `NodeBuilder::node_type` for the
    /// builder that creates this node — confirm.
    pub node_type: String,

    /// Version string of the node implementation (e.g. `"1.0.0"`).
    pub version: String,

    /// Additional free-form metadata as JSON values, keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
}
154
155/// Trait for node builders that can create nodes from configuration
156pub trait NodeBuilder: Send + Sync {
157    /// The type of node this builder creates
158    type Node: crate::core::Node;
159
160    /// Build a node from configuration
161    fn build(&self, config: NodeConfig) -> RGraphResult<Self::Node>;
162
163    /// Get the node type this builder creates
164    fn node_type(&self) -> &str;
165
166    /// Validate the configuration
167    fn validate_config(&self, config: &NodeConfig) -> RGraphResult<()> {
168        // Default implementation - can be overridden
169        if config.name.is_empty() {
170            return Err(RGraphError::validation("Node name cannot be empty"));
171        }
172        Ok(())
173    }
174}
175
176/// Registry for node builders (placeholder implementation)
177pub struct NodeBuilderRegistry {
178    _placeholder: bool,
179}
180
181impl NodeBuilderRegistry {
182    /// Create a new registry
183    pub fn new() -> Self {
184        Self { _placeholder: true }
185    }
186
187    /// Register a node builder (placeholder)
188    pub fn register<B>(&mut self, _node_type: String, _builder: B)
189    where
190        B: NodeBuilder + 'static,
191        B::Node: crate::core::Node + 'static,
192    {
193        // This would need proper type erasure in a real implementation
194        // For now, this is a placeholder
195    }
196
197    /// Get available node types
198    pub fn node_types(&self) -> Vec<String> {
199        vec![]
200    }
201}
202
203impl Default for NodeBuilderRegistry {
204    fn default() -> Self {
205        Self::new()
206    }
207}
208
/// Test-only helpers, including a trivial `PassThroughNode` implementation.
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use crate::core::{ExecutionContext, ExecutionResult, Node};
    use crate::state::GraphState;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// A node that writes one fixed string value into the graph state and then
    /// lets execution continue.
    pub struct PassThroughNode {
        /// Identifier reported by `Node::id`.
        id: NodeId,
        /// Display name reported by `Node::name`.
        name: String,
        /// State key written on every execution.
        output_key: String,
        /// Value stored under `output_key`.
        output_value: String,
    }

    impl PassThroughNode {
        /// Build a shared node that stores `output_value` under `output_key`.
        pub fn new(
            id: impl Into<NodeId>,
            name: impl Into<String>,
            output_key: impl Into<String>,
            output_value: impl Into<String>,
        ) -> Arc<Self> {
            let node = Self {
                id: id.into(),
                name: name.into(),
                output_key: output_key.into(),
                output_value: output_value.into(),
            };
            Arc::new(node)
        }
    }

    #[async_trait]
    impl Node for PassThroughNode {
        /// Store the configured value in `state` and signal `Continue`.
        async fn execute(
            &self,
            state: &mut GraphState,
            _context: &ExecutionContext,
        ) -> crate::RGraphResult<ExecutionResult> {
            state.set(&self.output_key, &self.output_value);
            Ok(ExecutionResult::Continue)
        }

        fn id(&self) -> &NodeId {
            &self.id
        }

        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn output_keys(&self) -> Vec<&str> {
            vec![self.output_key.as_str()]
        }
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_node_config_creation() {
        let config = NodeConfig::new("test_node", "Test Node")
            .with_description("A test node")
            .with_input_mapping("user_input", "prompt")
            .with_output_mapping("result", "node_output")
            .with_config(json!({"temperature": 0.7}))
            .with_retries(3)
            .with_tag("test");

        assert_eq!(config.id.as_str(), "test_node");
        assert_eq!(config.name, "Test Node");
        assert_eq!(config.description, Some("A test node".to_string()));
        assert_eq!(
            config.input_mappings.get("user_input"),
            Some(&"prompt".to_string())
        );
        assert_eq!(
            config.output_mappings.get("result"),
            Some(&"node_output".to_string())
        );
        assert!(config.retryable);
        assert_eq!(config.max_retries, 3);
        assert!(config.tags.contains(&"test".to_string()));
    }

    #[test]
    fn test_node_config_defaults() {
        let config = NodeConfig::new("bare", "Bare");

        assert_eq!(config.description, None);
        assert!(config.input_mappings.is_empty());
        assert!(config.output_mappings.is_empty());
        assert!(config.config.is_null());
        assert!(!config.retryable);
        assert_eq!(config.max_retries, 0);
        assert!(config.tags.is_empty());
    }

    #[test]
    fn test_node_metadata() {
        let metadata = NodeMetadata {
            id: NodeId::new("test_node"),
            name: "Test Node".to_string(),
            description: Some("A test node".to_string()),
            input_keys: vec!["input".to_string()],
            output_keys: vec!["output".to_string()],
            node_type: "test".to_string(),
            version: "1.0.0".to_string(),
            metadata: HashMap::new(),
        };

        assert_eq!(metadata.id.as_str(), "test_node");
        assert_eq!(metadata.name, "Test Node");
        assert_eq!(metadata.node_type, "test");
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.input_keys.len(), 1);
        assert_eq!(metadata.output_keys.len(), 1);
    }

    #[test]
    fn test_node_builder_registry() {
        let registry = NodeBuilderRegistry::new();
        assert!(registry.node_types().is_empty());

        // `Default` must agree with `new`.
        assert!(NodeBuilderRegistry::default().node_types().is_empty());
    }

    #[tokio::test]
    async fn test_pass_through_node() {
        // `Node` must be in scope to call the trait method `execute`, and
        // `ExecutionResult` is not brought in by `use super::*`.
        use crate::core::{ExecutionContext, ExecutionResult, Node};
        use crate::state::{GraphState, StateValue};
        use test_utils::PassThroughNode;

        let node = PassThroughNode::new("test", "Test", "output", "test_value");
        let mut state = GraphState::new();
        let context = ExecutionContext::new("graph1".to_string(), NodeId::new("test"));

        let result = node.execute(&mut state, &context).await.unwrap();

        assert!(matches!(result, ExecutionResult::Continue));
        assert_eq!(
            state.get("output").unwrap(),
            StateValue::String("test_value".to_string())
        );
    }
}