1pub mod agent;
7pub mod condition;
8pub mod tool;
9pub mod transform;
10
11pub use agent::{AgentNode, AgentNodeConfig};
13pub use condition::{ConditionNode, ConditionNodeConfig};
14pub use tool::{ToolNode, ToolNodeConfig};
15pub use transform::{TransformNode, TransformNodeConfig};
16
17use crate::core::NodeId;
18use crate::{RGraphError, RGraphResult};
19use std::collections::HashMap;
20
21#[cfg(feature = "serde")]
22use serde::{Deserialize, Serialize};
23
/// Declarative settings for a single graph node.
///
/// Create one with [`NodeConfig::new`] and refine it through the chained
/// `with_*` builder methods.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeConfig {
    /// Identifier of the node.
    pub id: NodeId,

    /// Name of the node. The default [`NodeBuilder::validate_config`]
    /// rejects an empty name.
    pub name: String,

    /// Optional free-text description; `None` until
    /// [`NodeConfig::with_description`] is called.
    pub description: Option<String>,

    /// Input mappings, keyed by state key with the node input key as value
    /// (see [`NodeConfig::with_input_mapping`]).
    pub input_mappings: HashMap<String, String>,

    /// Output mappings, keyed by node output key with the state key as value
    /// (see [`NodeConfig::with_output_mapping`]).
    pub output_mappings: HashMap<String, String>,

    /// Free-form JSON configuration; `Null` when created via
    /// [`NodeConfig::new`]. How it is interpreted is not defined here —
    /// presumably by the concrete node type; confirm against the builders.
    pub config: serde_json::Value,

    /// Retry flag; `false` by default and set to `true` by
    /// [`NodeConfig::with_retries`].
    pub retryable: bool,

    /// Retry limit; `0` by default and set by [`NodeConfig::with_retries`].
    // NOTE(review): how the executor consumes `retryable`/`max_retries` is
    // not visible in this file — confirm before relying on exact semantics.
    pub max_retries: usize,

    /// Tags attached to the node (see [`NodeConfig::with_tags`] and
    /// [`NodeConfig::with_tag`]).
    pub tags: Vec<String>,
}
55
56impl NodeConfig {
57 pub fn new(id: impl Into<NodeId>, name: impl Into<String>) -> Self {
59 Self {
60 id: id.into(),
61 name: name.into(),
62 description: None,
63 input_mappings: HashMap::new(),
64 output_mappings: HashMap::new(),
65 config: serde_json::Value::Null,
66 retryable: false,
67 max_retries: 0,
68 tags: Vec::new(),
69 }
70 }
71
72 pub fn with_description(mut self, description: impl Into<String>) -> Self {
74 self.description = Some(description.into());
75 self
76 }
77
78 pub fn with_input_mapping(
80 mut self,
81 state_key: impl Into<String>,
82 node_input_key: impl Into<String>,
83 ) -> Self {
84 self.input_mappings
85 .insert(state_key.into(), node_input_key.into());
86 self
87 }
88
89 pub fn with_output_mapping(
91 mut self,
92 node_output_key: impl Into<String>,
93 state_key: impl Into<String>,
94 ) -> Self {
95 self.output_mappings
96 .insert(node_output_key.into(), state_key.into());
97 self
98 }
99
100 pub fn with_config(mut self, config: serde_json::Value) -> Self {
102 self.config = config;
103 self
104 }
105
106 pub fn with_retries(mut self, max_retries: usize) -> Self {
108 self.retryable = true;
109 self.max_retries = max_retries;
110 self
111 }
112
113 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
115 self.tags = tags;
116 self
117 }
118
119 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
121 self.tags.push(tag.into());
122 self
123 }
124}
125
/// Descriptive information about a node.
///
/// A plain data record: every field is public and nothing in this file
/// derives or validates the values.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeMetadata {
    /// Identifier of the node.
    pub id: NodeId,

    /// Name of the node.
    pub name: String,

    /// Optional free-text description.
    pub description: Option<String>,

    /// Input key names.
    // NOTE(review): presumably the state keys the node reads — confirm.
    pub input_keys: Vec<String>,

    /// Output key names.
    // NOTE(review): presumably the state keys the node writes — confirm.
    pub output_keys: Vec<String>,

    /// Node type name.
    // NOTE(review): presumably matches `NodeBuilder::node_type` — confirm.
    pub node_type: String,

    /// Version string (the tests in this file use `"1.0.0"`).
    pub version: String,

    /// Additional entries, keyed by name, holding arbitrary JSON values.
    pub metadata: HashMap<String, serde_json::Value>,
}
154
155pub trait NodeBuilder: Send + Sync {
157 type Node: crate::core::Node;
159
160 fn build(&self, config: NodeConfig) -> RGraphResult<Self::Node>;
162
163 fn node_type(&self) -> &str;
165
166 fn validate_config(&self, config: &NodeConfig) -> RGraphResult<()> {
168 if config.name.is_empty() {
170 return Err(RGraphError::validation("Node name cannot be empty"));
171 }
172 Ok(())
173 }
174}
175
176pub struct NodeBuilderRegistry {
178 _placeholder: bool,
179}
180
181impl NodeBuilderRegistry {
182 pub fn new() -> Self {
184 Self { _placeholder: true }
185 }
186
187 pub fn register<B>(&mut self, _node_type: String, _builder: B)
189 where
190 B: NodeBuilder + 'static,
191 B::Node: crate::core::Node + 'static,
192 {
193 }
196
197 pub fn node_types(&self) -> Vec<String> {
199 vec![]
200 }
201}
202
203impl Default for NodeBuilderRegistry {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use crate::core::{ExecutionContext, ExecutionResult, Node};
    use crate::state::GraphState;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Minimal test node: each execution stores a fixed value under a fixed
    /// state key and asks the graph to continue.
    pub struct PassThroughNode {
        id: NodeId,
        name: String,
        output_key: String,
        output_value: String,
    }

    impl PassThroughNode {
        /// Creates the node already wrapped in an [`Arc`].
        ///
        /// `output_key` is the state key written on execution and
        /// `output_value` the value stored under it.
        pub fn new(
            id: impl Into<NodeId>,
            name: impl Into<String>,
            output_key: impl Into<String>,
            output_value: impl Into<String>,
        ) -> Arc<Self> {
            let node = Self {
                id: id.into(),
                name: name.into(),
                output_key: output_key.into(),
                output_value: output_value.into(),
            };
            Arc::new(node)
        }
    }

    #[async_trait]
    impl Node for PassThroughNode {
        /// Writes the configured value into `state`, then continues.
        async fn execute(
            &self,
            state: &mut GraphState,
            _context: &ExecutionContext,
        ) -> crate::RGraphResult<ExecutionResult> {
            let (key, value) = (&self.output_key, &self.output_value);
            state.set(key, value);
            Ok(ExecutionResult::Continue)
        }

        /// Identifier given at construction.
        fn id(&self) -> &NodeId {
            &self.id
        }

        /// Name given at construction.
        fn name(&self) -> &str {
            self.name.as_str()
        }

        /// The single state key this node writes.
        fn output_keys(&self) -> Vec<&str> {
            vec![self.output_key.as_str()]
        }
    }
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Every builder method on `NodeConfig` lands in the expected field.
    #[test]
    fn test_node_config_creation() {
        let config = NodeConfig::new("test_node", "Test Node")
            .with_description("A test node")
            .with_input_mapping("user_input", "prompt")
            .with_output_mapping("result", "node_output")
            .with_config(json!({"temperature": 0.7}))
            .with_retries(3)
            .with_tag("test");

        assert_eq!(config.id.as_str(), "test_node");
        assert_eq!(config.name, "Test Node");
        assert_eq!(config.description, Some("A test node".to_string()));
        assert_eq!(
            config.input_mappings.get("user_input"),
            Some(&"prompt".to_string())
        );
        assert_eq!(
            config.output_mappings.get("result"),
            Some(&"node_output".to_string())
        );
        assert!(config.retryable);
        assert_eq!(config.max_retries, 3);
        assert!(config.tags.contains(&"test".to_string()));
    }

    /// `NodeMetadata` is a plain record: fields read back as written.
    #[test]
    fn test_node_metadata() {
        let metadata = NodeMetadata {
            id: NodeId::new("test_node"),
            name: "Test Node".to_string(),
            description: Some("A test node".to_string()),
            input_keys: vec!["input".to_string()],
            output_keys: vec!["output".to_string()],
            node_type: "test".to_string(),
            version: "1.0.0".to_string(),
            metadata: HashMap::new(),
        };

        assert_eq!(metadata.id.as_str(), "test_node");
        assert_eq!(metadata.name, "Test Node");
        assert_eq!(metadata.node_type, "test");
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.input_keys.len(), 1);
        assert_eq!(metadata.output_keys.len(), 1);
    }

    /// A freshly created registry reports no node types.
    #[test]
    fn test_node_builder_registry() {
        // No mutation happens here, so the binding does not need `mut`.
        let registry = NodeBuilderRegistry::new();
        assert_eq!(registry.node_types().len(), 0);

        assert!(registry.node_types().is_empty());
    }

    /// Executing a `PassThroughNode` writes its value and continues.
    #[tokio::test]
    async fn test_pass_through_node() {
        // `execute` is a method of the `Node` trait, so the trait must be in
        // scope to call it; `ExecutionResult` is needed for the `matches!`
        // pattern. Neither is brought in by `use super::*`.
        use crate::core::{ExecutionContext, ExecutionResult, Node};
        use crate::state::{GraphState, StateValue};
        use test_utils::PassThroughNode;

        let node = PassThroughNode::new("test", "Test", "output", "test_value");
        let mut state = GraphState::new();
        let context = ExecutionContext::new("graph1".to_string(), NodeId::new("test"));

        let result = node.execute(&mut state, &context).await.unwrap();

        assert!(matches!(result, ExecutionResult::Continue));
        assert_eq!(
            state.get("output").unwrap(),
            StateValue::String("test_value".to_string())
        );
    }
}