1use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
7use crate::state::{GraphState, StateValue};
8use crate::tools::Tool;
9use crate::{RGraphError, RGraphResult};
10use async_trait::async_trait;
11use std::collections::HashMap;
12use std::sync::Arc;
13
14#[cfg(feature = "serde")]
15use serde::{Deserialize, Serialize};
16
/// Configuration for an [`AgentNode`].
///
/// A ready-to-use baseline is available through [`Default`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AgentNodeConfig {
    /// Display name of the agent; returned by `Node::name` and embedded in
    /// the simulated response text.
    pub name: String,

    /// Text of the system message that opens every conversation.
    pub system_prompt: String,

    /// Names of tools associated with the agent (set via
    /// `AgentNode::with_tools`).
    // NOTE(review): not consulted when dispatching tool calls in this file;
    // dispatch uses the tools registered through `AgentNode::with_tool`.
    pub tools: Vec<String>,

    /// Upper bound on the number of reasoning-loop iterations per execution.
    pub max_steps: usize,

    /// Sampling temperature; `AgentNode::with_temperature` clamps it to
    /// `0.0..=2.0`.
    pub temperature: f32,

    /// Optional cap on generated tokens.
    // NOTE(review): not read by the code visible in this file — presumably
    // forwarded to a model backend; confirm.
    pub max_tokens: Option<usize>,

    /// Whether structured output is requested.
    // NOTE(review): not read by the code visible in this file — confirm usage.
    pub structured_output: bool,

    /// Additional instructions for the agent.
    // NOTE(review): not read by the code visible in this file — confirm usage.
    pub instructions: Vec<String>,
}
45
46impl Default for AgentNodeConfig {
47 fn default() -> Self {
48 Self {
49 name: "assistant".to_string(),
50 system_prompt: "You are a helpful AI assistant.".to_string(),
51 tools: Vec::new(),
52 max_steps: 10,
53 temperature: 0.7,
54 max_tokens: Some(1000),
55 structured_output: false,
56 instructions: Vec::new(),
57 }
58 }
59}
60
61pub struct AgentNode {
63 id: NodeId,
64 config: AgentNodeConfig,
65 tools: HashMap<String, Arc<dyn Tool>>,
66}
67
68impl AgentNode {
69 pub fn new(id: impl Into<NodeId>, config: AgentNodeConfig) -> Self {
71 Self {
72 id: id.into(),
73 config,
74 tools: HashMap::new(),
75 }
76 }
77
78 pub fn with_tool(mut self, name: String, tool: Arc<dyn Tool>) -> Self {
80 self.tools.insert(name, tool);
81 self
82 }
83
84 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
86 self.config.system_prompt = prompt.into();
87 self
88 }
89
90 pub fn with_tools(mut self, tools: Vec<String>) -> Self {
92 self.config.tools = tools;
93 self
94 }
95
96 pub fn with_temperature(mut self, temperature: f32) -> Self {
98 self.config.temperature = temperature.clamp(0.0, 2.0);
99 self
100 }
101
102 async fn reasoning_loop(
104 &self,
105 state: &mut GraphState,
106 _context: &ExecutionContext,
107 initial_input: &str,
108 ) -> RGraphResult<String> {
109 let mut conversation_history = Vec::new();
110 let mut step_count = 0;
111
112 conversation_history.push(AgentMessage {
114 role: MessageRole::System,
115 content: self.config.system_prompt.clone(),
116 tool_calls: None,
117 });
118
119 conversation_history.push(AgentMessage {
121 role: MessageRole::User,
122 content: initial_input.to_string(),
123 tool_calls: None,
124 });
125
126 loop {
127 if step_count >= self.config.max_steps {
128 break;
129 }
130
131 step_count += 1;
132
133 let agent_response = self.generate_response(&conversation_history, state).await?;
135
136 if let Some(tool_calls) = &agent_response.tool_calls {
138 let mut tool_results = Vec::new();
140
141 for tool_call in tool_calls {
142 if let Some(tool) = self.tools.get(&tool_call.name) {
143 match tool.execute(&tool_call.arguments, state).await {
144 Ok(result) => {
145 tool_results.push(ToolCallResult {
146 call_id: tool_call.id.clone(),
147 name: tool_call.name.clone(),
148 result: result.output,
149 success: true,
150 error: None,
151 });
152 }
153 Err(e) => {
154 tool_results.push(ToolCallResult {
155 call_id: tool_call.id.clone(),
156 name: tool_call.name.clone(),
157 result: serde_json::Value::Null,
158 success: false,
159 error: Some(e.to_string()),
160 });
161 }
162 }
163 } else {
164 tool_results.push(ToolCallResult {
165 call_id: tool_call.id.clone(),
166 name: tool_call.name.clone(),
167 result: serde_json::Value::Null,
168 success: false,
169 error: Some(format!("Tool '{}' not found", tool_call.name)),
170 });
171 }
172 }
173
174 conversation_history.push(agent_response);
176
177 for tool_result in tool_results {
179 conversation_history.push(AgentMessage {
180 role: MessageRole::Tool,
181 content: if tool_result.success {
182 serde_json::to_string_pretty(&tool_result.result)
183 .unwrap_or_else(|_| "Tool execution completed".to_string())
184 } else {
185 format!(
186 "Error: {}",
187 tool_result
188 .error
189 .unwrap_or_else(|| "Unknown error".to_string())
190 )
191 },
192 tool_calls: None,
193 });
194 }
195 } else {
196 conversation_history.push(agent_response.clone());
198 return Ok(agent_response.content);
199 }
200 }
201
202 conversation_history
204 .iter()
205 .filter(|msg| msg.role == MessageRole::Assistant)
206 .last()
207 .map(|msg| msg.content.clone())
208 .unwrap_or_else(|| "Maximum reasoning steps reached without conclusion".to_string())
209 .pipe(Ok)
210 }
211
212 async fn generate_response(
214 &self,
215 conversation: &[AgentMessage],
216 _state: &GraphState,
217 ) -> RGraphResult<AgentMessage> {
218 let empty_string = String::new();
222 let last_user_message = conversation
223 .iter()
224 .filter(|msg| msg.role == MessageRole::User)
225 .last()
226 .map(|msg| &msg.content)
227 .unwrap_or(&empty_string);
228
229 if self.should_use_tools(last_user_message) && !self.tools.is_empty() {
231 let tool_name = self.tools.keys().next().unwrap().clone();
233
234 Ok(AgentMessage {
235 role: MessageRole::Assistant,
236 content: format!(
237 "I'll help you with that. Let me use the {} tool.",
238 tool_name
239 ),
240 tool_calls: Some(vec![ToolCall {
241 id: uuid::Uuid::new_v4().to_string(),
242 name: tool_name,
243 arguments: serde_json::json!({
244 "query": last_user_message
245 }),
246 }]),
247 })
248 } else {
249 Ok(AgentMessage {
251 role: MessageRole::Assistant,
252 content: format!(
253 "Based on your request '{}', I can provide assistance. This is a simulated response from the {} agent.",
254 last_user_message,
255 self.config.name
256 ),
257 tool_calls: None,
258 })
259 }
260 }
261
262 fn should_use_tools(&self, input: &str) -> bool {
264 let tool_keywords = ["search", "calculate", "analyze", "find", "lookup", "query"];
266 let input_lower = input.to_lowercase();
267
268 tool_keywords
269 .iter()
270 .any(|keyword| input_lower.contains(keyword))
271 }
272}
273
274#[async_trait]
275impl Node for AgentNode {
276 async fn execute(
277 &self,
278 state: &mut GraphState,
279 context: &ExecutionContext,
280 ) -> RGraphResult<ExecutionResult> {
281 let input = state
283 .get("user_input")
284 .or_else(|_| state.get("query"))
285 .or_else(|_| state.get("prompt"))
286 .map_err(|_| {
287 RGraphError::node(
288 self.id.as_str(),
289 "No input found in state (expected 'user_input', 'query', or 'prompt')",
290 )
291 })?;
292
293 let input_text = match input {
294 StateValue::String(s) => s,
295 _ => {
296 return Err(RGraphError::node(
297 self.id.as_str(),
298 "Input must be a string",
299 ))
300 }
301 };
302
303 let response = self.reasoning_loop(state, context, &input_text).await?;
305
306 state.set_with_context(
308 context.current_node.as_str(),
309 "agent_response",
310 response.clone(),
311 );
312
313 state.set_with_context(context.current_node.as_str(), "output", response);
315
316 Ok(ExecutionResult::Continue)
317 }
318
319 fn id(&self) -> &NodeId {
320 &self.id
321 }
322
323 fn name(&self) -> &str {
324 &self.config.name
325 }
326
327 fn input_keys(&self) -> Vec<&str> {
328 vec!["user_input", "query", "prompt"]
329 }
330
331 fn output_keys(&self) -> Vec<&str> {
332 vec!["agent_response", "output"]
333 }
334
335 fn validate(&self, state: &GraphState) -> RGraphResult<()> {
336 if !state.contains_key("user_input")
338 && !state.contains_key("query")
339 && !state.contains_key("prompt")
340 {
341 return Err(RGraphError::validation(
342 "Agent node requires 'user_input', 'query', or 'prompt' in state",
343 ));
344 }
345
346 Ok(())
347 }
348}
349
350#[derive(Debug, Clone, PartialEq)]
352#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
353pub struct AgentMessage {
354 pub role: MessageRole,
355 pub content: String,
356 pub tool_calls: Option<Vec<ToolCall>>,
357}
358
/// Author of an [`AgentMessage`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MessageRole {
    /// The system prompt that opens a conversation.
    System,
    /// Input supplied by the user.
    User,
    /// A message produced by the agent.
    Assistant,
    /// The outcome of a tool call, fed back to the agent.
    Tool,
}
368
369#[derive(Debug, Clone, PartialEq)]
371#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
372pub struct ToolCall {
373 pub id: String,
374 pub name: String,
375 pub arguments: serde_json::Value,
376}
377
378#[derive(Debug, Clone)]
380#[allow(dead_code)]
381struct ToolCallResult {
382 pub call_id: String,
383 pub name: String,
384 pub result: serde_json::Value,
385 pub success: bool,
386 pub error: Option<String>,
387}
388
/// Postfix function application: `value.pipe(f)` evaluates `f(value)`.
///
/// Lets a constructor such as `Ok` be applied at the end of a method chain.
trait Pipe<T> {
    /// Consumes `self` and returns the result of calling `f` on it.
    fn pipe<U, F: FnOnce(T) -> U>(self, f: F) -> U;
}

/// Every type can be piped into a function that accepts it by value.
impl<T> Pipe<T> for T {
    fn pipe<U, F: FnOnce(T) -> U>(self, f: F) -> U {
        f(self)
    }
}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ExecutionContext;
    use crate::tools::{Tool, ToolError, ToolResult};

    /// Tool stub that always succeeds with a fixed JSON payload.
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        async fn execute(
            &self,
            _arguments: &serde_json::Value,
            _state: &GraphState,
        ) -> Result<ToolResult, ToolError> {
            let output = serde_json::json!({
                "tool": self.name,
                "result": "mock result"
            });

            Ok(ToolResult {
                output,
                metadata: HashMap::new(),
            })
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool for testing"
        }
    }

    /// A freshly built node reports the id it was given and the default name.
    #[tokio::test]
    async fn test_agent_node_creation() {
        let agent = AgentNode::new("test_agent", AgentNodeConfig::default());

        assert_eq!(agent.id().as_str(), "test_agent");
        assert_eq!(agent.name(), "assistant");
    }

    /// `with_tool` stores the tool under the supplied name.
    #[tokio::test]
    async fn test_agent_node_with_tools() {
        let search_tool = Arc::new(MockTool {
            name: String::from("search"),
        });

        let agent = AgentNode::new("test_agent", AgentNodeConfig::default())
            .with_tool(String::from("search"), search_tool);

        assert!(agent.tools.contains_key("search"));
    }

    /// Executing with a plain greeting continues the graph and stores a response.
    #[tokio::test]
    async fn test_agent_execution() {
        let agent = AgentNode::new("test_agent", AgentNodeConfig::default());
        let context = ExecutionContext::new("test_graph".to_string(), NodeId::new("test_agent"));

        let mut state = GraphState::new();
        state.set("user_input", "Hello, how can you help me?");

        let result = agent.execute(&mut state, &context).await.unwrap();

        assert!(matches!(result, ExecutionResult::Continue));
        assert!(state.contains_key("agent_response"));
    }

    /// Tool keywords are detected; ordinary greetings are not.
    #[test]
    fn test_should_use_tools() {
        let agent = AgentNode::new("test_agent", AgentNodeConfig::default());

        for input in ["Please search for information", "Can you calculate this?"].iter() {
            assert!(agent.should_use_tools(input));
        }
        assert!(!agent.should_use_tools("Hello there"));
    }

    /// `AgentMessage` exposes the values it was built with.
    #[test]
    fn test_agent_message() {
        let message = AgentMessage {
            role: MessageRole::User,
            content: String::from("Test message"),
            tool_calls: None,
        };

        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "Test message");
        assert!(message.tool_calls.is_none());
    }

    /// `ToolCall` exposes the values it was built with.
    #[test]
    fn test_tool_call() {
        let tool_call = ToolCall {
            id: String::from("test-123"),
            name: String::from("search"),
            arguments: serde_json::json!({"query": "test"}),
        };

        assert_eq!(tool_call.id, "test-123");
        assert_eq!(tool_call.name, "search");
        assert_eq!(tool_call.arguments["query"], "test");
    }
}