rrag_graph/nodes/
tool.rs

1//! # Tool Node Implementation
2//!
3//! Tool nodes directly execute tools without agent reasoning.
4
5use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
6use crate::state::{GraphState, StateValue};
7use crate::tools::Tool;
8use crate::{RGraphError, RGraphResult};
9use async_trait::async_trait;
10use std::sync::Arc;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Settings that tell a [`ToolNode`] how to feed its tool and where to put the result.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ToolNodeConfig {
    /// Name of the tool this configuration is meant for.
    ///
    /// `ToolNode` itself does not read this field when executing; the tool instance is
    /// supplied separately to `ToolNode::new`.
    pub tool_name: String,
    /// Maps a graph-state key (map key) to the tool argument name (map value) it supplies.
    pub argument_mappings: std::collections::HashMap<String, String>,
    /// Graph-state key under which the tool's output is stored after a successful run.
    pub output_key: String,
}
23
24/// A node that executes a specific tool
25pub struct ToolNode {
26    id: NodeId,
27    name: String,
28    tool: Arc<dyn Tool>,
29    config: ToolNodeConfig,
30}
31
32impl ToolNode {
33    pub fn new(
34        id: impl Into<NodeId>,
35        name: impl Into<String>,
36        tool: Arc<dyn Tool>,
37        config: ToolNodeConfig,
38    ) -> Self {
39        Self {
40            id: id.into(),
41            name: name.into(),
42            tool,
43            config,
44        }
45    }
46}
47
48#[async_trait]
49impl Node for ToolNode {
50    async fn execute(
51        &self,
52        state: &mut GraphState,
53        context: &ExecutionContext,
54    ) -> RGraphResult<ExecutionResult> {
55        // Build arguments from state using mappings
56        let mut arguments = serde_json::Map::new();
57
58        for (state_key, arg_key) in &self.config.argument_mappings {
59            if let Ok(value) = state.get(state_key) {
60                let json_value: serde_json::Value = value.into();
61                arguments.insert(arg_key.clone(), json_value);
62            }
63        }
64
65        let arguments_json = serde_json::Value::Object(arguments);
66
67        // Execute the tool
68        match self.tool.execute(&arguments_json, state).await {
69            Ok(result) => {
70                // Store result in state
71                state.set_with_context(
72                    context.current_node.as_str(),
73                    &self.config.output_key,
74                    StateValue::from(result.output),
75                );
76                Ok(ExecutionResult::Continue)
77            }
78            Err(e) => Err(RGraphError::tool(e.to_string())),
79        }
80    }
81
82    fn id(&self) -> &NodeId {
83        &self.id
84    }
85
86    fn name(&self) -> &str {
87        &self.name
88    }
89
90    fn input_keys(&self) -> Vec<&str> {
91        self.config
92            .argument_mappings
93            .keys()
94            .map(|s| s.as_str())
95            .collect()
96    }
97
98    fn output_keys(&self) -> Vec<&str> {
99        vec![&self.config.output_key]
100    }
101}