praxis_graph/nodes/
tool_node.rs

1use crate::node::{EventSender, Node, NodeType};
2use anyhow::Result;
3use async_trait::async_trait;
4use praxis_mcp::{MCPToolExecutor, ToolResponse};
5use crate::types::{GraphState, StreamEvent};
6use std::sync::Arc;
7use std::time::Instant;
8
9pub struct ToolNode {
10    mcp_executor: Arc<MCPToolExecutor>,
11}
12
13impl ToolNode {
14    pub fn new(mcp_executor: Arc<MCPToolExecutor>) -> Self {
15        Self { mcp_executor }
16    }
17}
18
19#[async_trait]
20impl Node for ToolNode {
21    async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
22        // Get pending tool calls from state
23        let tool_calls = state.get_pending_tool_calls();
24
25        if tool_calls.is_empty() {
26            return Ok(());
27        }
28
29        // Execute each tool call
30        for tool_call in tool_calls {
31            let start = Instant::now();
32
33            // Parse arguments from string to Value
34            let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments)?;
35            
36            match self
37                .mcp_executor
38                .execute_tool(&tool_call.function.name, args)
39                .await
40            {
41                Ok(responses) => {
42                    // Join all responses into a single result string
43                    let result = ToolResponse::join_responses(&responses);
44                    
45                    // Success: emit result event
46                    event_tx
47                        .send(StreamEvent::ToolResult {
48                            tool_call_id: tool_call.id.clone(),
49                            result: result.clone(),
50                            is_error: false,
51                            duration_ms: start.elapsed().as_millis() as u64,
52                        })
53                        .await?;
54
55                    // Add tool result to state
56                    state.add_tool_result(tool_call.id, result);
57                }
58                Err(e) => {
59                    // Tool failed (resilient) - emit error result
60                    let error_msg = format!("Tool execution failed: {}", e);
61
62                    event_tx
63                        .send(StreamEvent::ToolResult {
64                            tool_call_id: tool_call.id.clone(),
65                            result: error_msg.clone(),
66                            is_error: true,
67                            duration_ms: start.elapsed().as_millis() as u64,
68                        })
69                        .await?;
70
71                    // Add error result to state so LLM can see it
72                    state.add_tool_result(tool_call.id, error_msg);
73                }
74            }
75        }
76
77        Ok(())
78    }
79
80    fn node_type(&self) -> NodeType {
81        NodeType::Tool
82    }
83}
84