// praxis_graph/nodes/llm_node.rs

use std::collections::BTreeMap;
use std::sync::Arc;

use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;

use praxis_llm::{ChatClient, ChatOptions, ChatRequest, Message, ToolChoice};
use praxis_mcp::MCPToolExecutor;

use crate::node::{EventSender, Node, NodeType};
use crate::types::GraphState;
9
/// Graph node that performs one LLM chat turn, advertising MCP-provided tools
/// to the model and streaming the response back to the client.
pub struct LLMNode {
    // Chat backend used to stream completions.
    client: Arc<dyn ChatClient>,
    // Executor that aggregates tools from all connected MCP servers.
    mcp_executor: Arc<MCPToolExecutor>,
}
14
15impl LLMNode {
16    pub fn new(client: Arc<dyn ChatClient>, mcp_executor: Arc<MCPToolExecutor>) -> Self {
17        Self { 
18            client,
19            mcp_executor,
20        }
21    }
22
23    /// Convert praxis_llm::StreamEvent to Graph StreamEvent
24    /// Uses automatic From trait conversion
25    fn convert_event(event: praxis_llm::StreamEvent) -> crate::types::StreamEvent {
26        event.into()
27    }
28}
29
30#[async_trait]
31impl Node for LLMNode {
32    async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
33        // Get tools from all connected MCP servers
34        let tools = self.mcp_executor.get_llm_tools().await?;
35        
36        // Build chat request from state
37        let options = ChatOptions::new()
38            .tools(tools)
39            .tool_choice(ToolChoice::auto());
40
41        let request = ChatRequest::new(state.llm_config.model.clone(), state.messages.clone())
42            .with_options(options);
43
44        // Call LLM with streaming
45        let mut stream = self.client.chat_stream(request).await?;
46
47        // Track tool calls as they stream in
48        let mut accumulated_tool_calls: Vec<praxis_llm::ToolCall> = Vec::new();
49        let mut tool_call_buffers: std::collections::HashMap<u32, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
50
51        // Forward events and accumulate tool calls
52        while let Some(event_result) = stream.next().await {
53            let llm_event = event_result?;
54
55            // Convert and forward to client
56            let graph_event = Self::convert_event(llm_event.clone());
57            event_tx.send(graph_event).await?;
58
59            // Accumulate tool calls for state
60            if let praxis_llm::StreamEvent::ToolCall { index, id, name, arguments } = llm_event {
61                let entry = tool_call_buffers.entry(index).or_insert((None, None, String::new()));
62                
63                if let Some(id) = id {
64                    entry.0 = Some(id);
65                }
66                if let Some(name) = name {
67                    entry.1 = Some(name);
68                }
69                if let Some(args) = arguments {
70                    entry.2.push_str(&args);
71                }
72            }
73        }
74
75        // Build final tool calls from accumulated data
76        for (_, (id, name, arguments)) in tool_call_buffers {
77            if let (Some(id), Some(name)) = (id, name) {
78                accumulated_tool_calls.push(praxis_llm::ToolCall {
79                    id,
80                    tool_type: "function".to_string(),
81                    function: praxis_llm::types::FunctionCall {
82                        name,
83                        arguments,
84                    },
85                });
86            }
87        }
88
89        // Add assistant message to state
90        let assistant_message = if accumulated_tool_calls.is_empty() {
91            Message::AI {
92                content: None, // Content was streamed, not accumulated here
93                tool_calls: None,
94                name: None,
95            }
96        } else {
97            Message::AI {
98                content: None,
99                tool_calls: Some(accumulated_tool_calls),
100                name: None,
101            }
102        };
103
104        state.add_message(assistant_message);
105
106        Ok(())
107    }
108
109    fn node_type(&self) -> NodeType {
110        NodeType::LLM
111    }
112}
113