1use crate::core::{ExecutionContext, ExecutionResult, Node, NodeId};
6use crate::state::{GraphState, StateValue};
7use crate::tools::Tool;
8use crate::{RGraphError, RGraphResult};
9use async_trait::async_trait;
10use std::sync::Arc;
11
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Serialize};
14
/// Configuration describing how a [`ToolNode`] wires graph state into a tool
/// call and where it stores the tool's output.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ToolNodeConfig {
    /// Name of the tool.
    ///
    /// NOTE(review): nothing in this file reads this field (the node calls the
    /// `Tool` instance it was constructed with) — presumably informational;
    /// confirm against callers.
    pub tool_name: String,
    /// Mapping from a state key (map key) to the name of the tool argument
    /// (map value) that receives that state value.
    pub argument_mappings: std::collections::HashMap<String, String>,
    /// State key under which the tool's output is written after a successful
    /// call.
    pub output_key: String,
}
23
/// A graph node that invokes a [`Tool`] with arguments taken from the graph
/// state and writes the tool's output back into the state.
pub struct ToolNode {
    /// Identifier returned by `Node::id`.
    id: NodeId,
    /// Name returned by `Node::name`.
    name: String,
    /// The tool this node runs when executed.
    tool: Arc<dyn Tool>,
    /// Argument mappings and output key used when executing the tool.
    config: ToolNodeConfig,
}
31
32impl ToolNode {
33 pub fn new(
34 id: impl Into<NodeId>,
35 name: impl Into<String>,
36 tool: Arc<dyn Tool>,
37 config: ToolNodeConfig,
38 ) -> Self {
39 Self {
40 id: id.into(),
41 name: name.into(),
42 tool,
43 config,
44 }
45 }
46}
47
48#[async_trait]
49impl Node for ToolNode {
50 async fn execute(
51 &self,
52 state: &mut GraphState,
53 context: &ExecutionContext,
54 ) -> RGraphResult<ExecutionResult> {
55 let mut arguments = serde_json::Map::new();
57
58 for (state_key, arg_key) in &self.config.argument_mappings {
59 if let Ok(value) = state.get(state_key) {
60 let json_value: serde_json::Value = value.into();
61 arguments.insert(arg_key.clone(), json_value);
62 }
63 }
64
65 let arguments_json = serde_json::Value::Object(arguments);
66
67 match self.tool.execute(&arguments_json, state).await {
69 Ok(result) => {
70 state.set_with_context(
72 context.current_node.as_str(),
73 &self.config.output_key,
74 StateValue::from(result.output),
75 );
76 Ok(ExecutionResult::Continue)
77 }
78 Err(e) => Err(RGraphError::tool(e.to_string())),
79 }
80 }
81
82 fn id(&self) -> &NodeId {
83 &self.id
84 }
85
86 fn name(&self) -> &str {
87 &self.name
88 }
89
90 fn input_keys(&self) -> Vec<&str> {
91 self.config
92 .argument_mappings
93 .keys()
94 .map(|s| s.as_str())
95 .collect()
96 }
97
98 fn output_keys(&self) -> Vec<&str> {
99 vec![&self.config.output_key]
100 }
101}