praxis_graph/nodes/
tool_node.rs1use crate::node::{EventSender, Node, NodeType};
2use anyhow::Result;
3use async_trait::async_trait;
4use praxis_mcp::{MCPToolExecutor, ToolResponse};
5use crate::types::{GraphState, StreamEvent};
6use std::sync::Arc;
7use std::time::Instant;
8
9pub struct ToolNode {
10 mcp_executor: Arc<MCPToolExecutor>,
11}
12
13impl ToolNode {
14 pub fn new(mcp_executor: Arc<MCPToolExecutor>) -> Self {
15 Self { mcp_executor }
16 }
17}
18
19#[async_trait]
20impl Node for ToolNode {
21 async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
22 let tool_calls = state.get_pending_tool_calls();
24
25 if tool_calls.is_empty() {
26 return Ok(());
27 }
28
29 for tool_call in tool_calls {
31 let start = Instant::now();
32
33 let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
35
36 match self
37 .mcp_executor
38 .execute_tool(&tool_call.function.name, args)
39 .await
40 {
41 Ok(responses) => {
42 let result = ToolResponse::join_responses(&responses);
44
45 event_tx
47 .send(StreamEvent::ToolResult {
48 tool_call_id: tool_call.id.clone(),
49 result: result.clone(),
50 is_error: false,
51 duration_ms: start.elapsed().as_millis() as u64,
52 })
53 .await?;
54
55 state.add_tool_result(tool_call.id, result);
57 }
58 Err(e) => {
59 let error_msg = format!("Tool execution failed: {}", e);
61
62 event_tx
63 .send(StreamEvent::ToolResult {
64 tool_call_id: tool_call.id.clone(),
65 result: error_msg.clone(),
66 is_error: true,
67 duration_ms: start.elapsed().as_millis() as u64,
68 })
69 .await?;
70
71 state.add_tool_result(tool_call.id, error_msg);
73 }
74 }
75 }
76
77 Ok(())
78 }
79
80 fn node_type(&self) -> NodeType {
81 NodeType::Tool
82 }
83}
84