rrag_graph/nodes/agent.rs

//! # Agent Node Implementation
//!
//! Agent nodes represent autonomous AI agents that can reason, make decisions,
//! and use tools to accomplish tasks.
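//!
//! A minimal construction sketch (illustrative only; the crate path and the
//! tool value are assumptions, not part of this module):
//!
//! ```ignore
//! use std::sync::Arc;
//! use rrag_graph::nodes::agent::{AgentNode, AgentNodeConfig};
//!
//! // `my_search_tool` is a hypothetical value implementing the `Tool` trait.
//! let agent = AgentNode::new("research_agent", AgentNodeConfig::default())
//!     .with_system_prompt("You are a research assistant.")
//!     .with_temperature(0.3)
//!     .with_tool("search".to_string(), Arc::new(my_search_tool));
//! ```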

use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
use crate::state::{GraphState, StateValue};
use crate::tools::Tool;
use crate::{RGraphError, RGraphResult};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

/// Configuration for an agent node
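///
/// A sketch of overriding selected defaults via struct update syntax (the
/// field values shown are illustrative):
///
/// ```ignore
/// let config = AgentNodeConfig {
///     name: "researcher".to_string(),
///     system_prompt: "You are a careful research assistant.".to_string(),
///     max_steps: 5,
///     temperature: 0.2,
///     ..Default::default()
/// };
/// ```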
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AgentNodeConfig {
    /// Agent name
    pub name: String,

    /// System prompt for the agent
    pub system_prompt: String,

    /// Available tools
    pub tools: Vec<String>,

    /// Maximum number of reasoning steps
    pub max_steps: usize,

    /// Temperature for generation
    pub temperature: f32,

    /// Maximum tokens for generation
    pub max_tokens: Option<usize>,

    /// Whether to use structured output
    pub structured_output: bool,

    /// Custom instructions
    pub instructions: Vec<String>,
}

impl Default for AgentNodeConfig {
    fn default() -> Self {
        Self {
            name: "assistant".to_string(),
            system_prompt: "You are a helpful AI assistant.".to_string(),
            tools: Vec::new(),
            max_steps: 10,
            temperature: 0.7,
            max_tokens: Some(1000),
            structured_output: false,
            instructions: Vec::new(),
        }
    }
}

/// An agent node that can reason and use tools
pub struct AgentNode {
    id: NodeId,
    config: AgentNodeConfig,
    tools: HashMap<String, Arc<dyn Tool>>,
}

impl AgentNode {
    /// Create a new agent node
    pub fn new(id: impl Into<NodeId>, config: AgentNodeConfig) -> Self {
        Self {
            id: id.into(),
            config,
            tools: HashMap::new(),
        }
    }

    /// Add a tool to the agent
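    ///
    /// The `name` is the key later used to resolve `ToolCall::name` during the
    /// reasoning loop. A usage sketch (`MySearchTool` is a hypothetical `Tool`
    /// implementation):
    ///
    /// ```ignore
    /// let agent = AgentNode::new("agent", AgentNodeConfig::default())
    ///     .with_tool("search".to_string(), Arc::new(MySearchTool::default()));
    /// ```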
    pub fn with_tool(mut self, name: String, tool: Arc<dyn Tool>) -> Self {
        self.tools.insert(name, tool);
        self
    }

    /// Set the system prompt
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.config.system_prompt = prompt.into();
        self
    }

    /// Set the list of tool names available to the agent (replaces any existing list)
    pub fn with_tools(mut self, tools: Vec<String>) -> Self {
        self.config.tools = tools;
        self
    }

    /// Set the sampling temperature, clamped to the range 0.0..=2.0
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.config.temperature = temperature.clamp(0.0, 2.0);
        self
    }

    /// Execute the agent's reasoning loop
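    ///
    /// The conversation is seeded with the system prompt and the user input,
    /// then advances for up to `max_steps` iterations: each step either yields
    /// a final assistant answer (returned immediately) or a set of tool calls,
    /// whose results are appended as `MessageRole::Tool` messages before the
    /// next step.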
    async fn reasoning_loop(
        &self,
        state: &mut GraphState,
        _context: &ExecutionContext,
        initial_input: &str,
    ) -> RGraphResult<String> {
        let mut conversation_history = Vec::new();
        let mut step_count = 0;

        // Add system prompt
        conversation_history.push(AgentMessage {
            role: MessageRole::System,
            content: self.config.system_prompt.clone(),
            tool_calls: None,
        });

        // Add user input
        conversation_history.push(AgentMessage {
            role: MessageRole::User,
            content: initial_input.to_string(),
            tool_calls: None,
        });

        loop {
            if step_count >= self.config.max_steps {
                break;
            }

            step_count += 1;

            // Generate agent response
            let agent_response = self.generate_response(&conversation_history, state).await?;

            // Check if agent wants to use tools
            if let Some(tool_calls) = &agent_response.tool_calls {
                // Execute tool calls
                let mut tool_results = Vec::new();

                for tool_call in tool_calls {
                    if let Some(tool) = self.tools.get(&tool_call.name) {
                        match tool.execute(&tool_call.arguments, state).await {
                            Ok(result) => {
                                tool_results.push(ToolCallResult {
                                    call_id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    result: result.output,
                                    success: true,
                                    error: None,
                                });
                            }
                            Err(e) => {
                                tool_results.push(ToolCallResult {
                                    call_id: tool_call.id.clone(),
                                    name: tool_call.name.clone(),
                                    result: serde_json::Value::Null,
                                    success: false,
                                    error: Some(e.to_string()),
                                });
                            }
                        }
                    } else {
                        tool_results.push(ToolCallResult {
                            call_id: tool_call.id.clone(),
                            name: tool_call.name.clone(),
                            result: serde_json::Value::Null,
                            success: false,
                            error: Some(format!("Tool '{}' not found", tool_call.name)),
                        });
                    }
                }

                // Add assistant message with tool calls
                conversation_history.push(agent_response);

                // Add tool results
                for tool_result in tool_results {
                    conversation_history.push(AgentMessage {
                        role: MessageRole::Tool,
                        content: if tool_result.success {
                            serde_json::to_string_pretty(&tool_result.result)
                                .unwrap_or_else(|_| "Tool execution completed".to_string())
                        } else {
                            format!(
                                "Error: {}",
                                tool_result
                                    .error
                                    .unwrap_or_else(|| "Unknown error".to_string())
                            )
                        },
                        tool_calls: None,
                    });
                }
            } else {
                // Agent provided final response
                conversation_history.push(agent_response.clone());
                return Ok(agent_response.content);
            }
        }

        // If we exit the loop without a final response, return the last agent message
        Ok(conversation_history
            .iter()
            .filter(|msg| msg.role == MessageRole::Assistant)
            .last()
            .map(|msg| msg.content.clone())
            .unwrap_or_else(|| "Maximum reasoning steps reached without conclusion".to_string()))
    }

    /// Generate a response from the agent
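    ///
    /// Note: this is a simulated stand-in for an LLM call. It either emits a
    /// single tool call (when `should_use_tools` matches and at least one tool
    /// is registered) or returns a canned direct reply; a real backend would
    /// build the request from `conversation` and the node's generation settings.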
    async fn generate_response(
        &self,
        conversation: &[AgentMessage],
        _state: &GraphState,
    ) -> RGraphResult<AgentMessage> {
        // In a real implementation, this would call an LLM API
        // For now, we'll simulate an agent response

        let last_user_message = conversation
            .iter()
            .filter(|msg| msg.role == MessageRole::User)
            .last()
            .map(|msg| msg.content.as_str())
            .unwrap_or("");

        // Check if we should use tools based on the input
        if self.should_use_tools(last_user_message) && !self.tools.is_empty() {
            // Simulate tool usage decision
            let tool_name = self.tools.keys().next().unwrap().clone();

            Ok(AgentMessage {
                role: MessageRole::Assistant,
                content: format!(
                    "I'll help you with that. Let me use the {} tool.",
                    tool_name
                ),
                tool_calls: Some(vec![ToolCall {
                    id: uuid::Uuid::new_v4().to_string(),
                    name: tool_name,
                    arguments: serde_json::json!({
                        "query": last_user_message
                    }),
                }]),
            })
        } else {
            // Generate a direct response
            Ok(AgentMessage {
                role: MessageRole::Assistant,
                content: format!(
                    "Based on your request '{}', I can provide assistance. This is a simulated response from the {} agent.",
                    last_user_message,
                    self.config.name
                ),
                tool_calls: None,
            })
        }
    }

    /// Determine if the agent should use tools for this input
    fn should_use_tools(&self, input: &str) -> bool {
        // Simple heuristic - in a real implementation this would be more sophisticated
        let tool_keywords = ["search", "calculate", "analyze", "find", "lookup", "query"];
        let input_lower = input.to_lowercase();

        tool_keywords
            .iter()
            .any(|keyword| input_lower.contains(keyword))
    }
}

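// Node contract: the agent reads the first of `user_input`, `query`, or `prompt`
// found in the graph state and writes its reply under `agent_response` and `output`.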
#[async_trait]
impl Node for AgentNode {
    async fn execute(
        &self,
        state: &mut GraphState,
        context: &ExecutionContext,
    ) -> RGraphResult<ExecutionResult> {
        // Get input from state
        let input = state
            .get("user_input")
            .or_else(|_| state.get("query"))
            .or_else(|_| state.get("prompt"))
            .map_err(|_| {
                RGraphError::node(
                    self.id.as_str(),
                    "No input found in state (expected 'user_input', 'query', or 'prompt')",
                )
            })?;

        let input_text = match input {
            StateValue::String(s) => s,
            _ => {
                return Err(RGraphError::node(
                    self.id.as_str(),
                    "Input must be a string",
                ))
            }
        };

        // Execute reasoning loop
        let response = self.reasoning_loop(state, context, &input_text).await?;

        // Store the response in state
        state.set_with_context(
            context.current_node.as_str(),
            "agent_response",
            response.clone(),
        );

        // Also store in a generic output key
        state.set_with_context(context.current_node.as_str(), "output", response);

        Ok(ExecutionResult::Continue)
    }

    fn id(&self) -> &NodeId {
        &self.id
    }

    fn name(&self) -> &str {
        &self.config.name
    }

    fn input_keys(&self) -> Vec<&str> {
        vec!["user_input", "query", "prompt"]
    }

    fn output_keys(&self) -> Vec<&str> {
        vec!["agent_response", "output"]
    }

    fn validate(&self, state: &GraphState) -> RGraphResult<()> {
        // Check that we have input
        if !state.contains_key("user_input")
            && !state.contains_key("query")
            && !state.contains_key("prompt")
        {
            return Err(RGraphError::validation(
                "Agent node requires 'user_input', 'query', or 'prompt' in state",
            ));
        }

        Ok(())
    }
}

/// Message in agent conversation
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct AgentMessage {
    pub role: MessageRole,
    pub content: String,
    pub tool_calls: Option<Vec<ToolCall>>,
}

/// Role of message in conversation
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}

/// Tool call from agent
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ToolCall {
    pub id: String,
    pub name: String,
    pub arguments: serde_json::Value,
}

/// Result of a tool call
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct ToolCallResult {
    pub call_id: String,
    pub name: String,
    pub result: serde_json::Value,
    pub success: bool,
    pub error: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ExecutionContext;
    use crate::tools::{Tool, ToolError, ToolResult};

    // Mock tool for testing
    struct MockTool {
        name: String,
    }

    #[async_trait]
    impl Tool for MockTool {
        async fn execute(
            &self,
            _arguments: &serde_json::Value,
            _state: &GraphState,
        ) -> Result<ToolResult, ToolError> {
            Ok(ToolResult {
                output: serde_json::json!({
                    "tool": self.name,
                    "result": "mock result"
                }),
                metadata: HashMap::new(),
            })
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "Mock tool for testing"
        }
    }

    #[tokio::test]
    async fn test_agent_node_creation() {
        let config = AgentNodeConfig::default();
        let agent = AgentNode::new("test_agent", config);

        assert_eq!(agent.id().as_str(), "test_agent");
        assert_eq!(agent.name(), "assistant");
    }

    #[tokio::test]
    async fn test_agent_node_with_tools() {
        let config = AgentNodeConfig::default();
        let tool = Arc::new(MockTool {
            name: "search".to_string(),
        });

        let agent = AgentNode::new("test_agent", config).with_tool("search".to_string(), tool);

        assert!(agent.tools.contains_key("search"));
    }

    #[tokio::test]
    async fn test_agent_execution() {
        let config = AgentNodeConfig::default();
        let agent = AgentNode::new("test_agent", config);

        let mut state = GraphState::new();
        state.set("user_input", "Hello, how can you help me?");

        let context = ExecutionContext::new("test_graph".to_string(), NodeId::new("test_agent"));
        let result = agent.execute(&mut state, &context).await.unwrap();

        assert!(matches!(result, ExecutionResult::Continue));
        assert!(state.contains_key("agent_response"));
    }

    #[test]
    fn test_should_use_tools() {
        let config = AgentNodeConfig::default();
        let agent = AgentNode::new("test_agent", config);

        assert!(agent.should_use_tools("Please search for information"));
        assert!(agent.should_use_tools("Can you calculate this?"));
        assert!(!agent.should_use_tools("Hello there"));
    }

    #[test]
    fn test_agent_message() {
        let message = AgentMessage {
            role: MessageRole::User,
            content: "Test message".to_string(),
            tool_calls: None,
        };

        assert_eq!(message.role, MessageRole::User);
        assert_eq!(message.content, "Test message");
        assert!(message.tool_calls.is_none());
    }

    #[test]
    fn test_tool_call() {
        let tool_call = ToolCall {
            id: "test-123".to_string(),
            name: "search".to_string(),
            arguments: serde_json::json!({"query": "test"}),
        };

        assert_eq!(tool_call.id, "test-123");
        assert_eq!(tool_call.name, "search");
        assert_eq!(tool_call.arguments["query"], "test");
    }
}