//! praxis_graph/nodes/llm_node.rs

1use crate::node::{EventSender, Node, NodeType};
2use anyhow::Result;
3use async_trait::async_trait;
4use futures::StreamExt;
5use praxis_llm::{ChatOptions, ChatRequest, LLMClient, Message, ToolChoice};
6use praxis_mcp::MCPToolExecutor;
7use praxis_types::GraphState;
8use std::sync::Arc;
9
/// Graph node that performs a single LLM chat turn: it streams the model's
/// output to the client and exposes MCP-provided tools to the model.
pub struct LLMNode {
    // LLM backend used for streaming chat completions.
    client: Arc<dyn LLMClient>,
    // Executor that surfaces tools from all connected MCP servers.
    mcp_executor: Arc<MCPToolExecutor>,
}
14
15impl LLMNode {
16    pub fn new(client: Arc<dyn LLMClient>, mcp_executor: Arc<MCPToolExecutor>) -> Self {
17        Self { 
18            client,
19            mcp_executor,
20        }
21    }
22
23    /// Convert praxis_llm::StreamEvent to praxis_types::StreamEvent
24    fn convert_event(event: praxis_llm::StreamEvent) -> praxis_types::StreamEvent {
25        match event {
26            praxis_llm::StreamEvent::Reasoning { content } => {
27                praxis_types::StreamEvent::Reasoning { content }
28            }
29            praxis_llm::StreamEvent::Message { content } => {
30                praxis_types::StreamEvent::Message { content }
31            }
32            praxis_llm::StreamEvent::ToolCall {
33                index,
34                id,
35                name,
36                arguments,
37            } => praxis_types::StreamEvent::ToolCall {
38                index,
39                id,
40                name,
41                arguments,
42            },
43            praxis_llm::StreamEvent::Done { finish_reason } => {
44                praxis_types::StreamEvent::Done { finish_reason }
45            }
46        }
47    }
48}
49
50#[async_trait]
51impl Node for LLMNode {
52    async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
53        // Get tools from all connected MCP servers
54        let tools = self.mcp_executor.get_llm_tools().await?;
55        
56        // Build chat request from state
57        let options = ChatOptions::new()
58            .tools(tools)
59            .tool_choice(ToolChoice::auto());
60
61        let request = ChatRequest::new(state.llm_config.model.clone(), state.messages.clone())
62            .with_options(options);
63
64        // Call LLM with streaming
65        let mut stream = self.client.chat_completion_stream(request).await?;
66
67        // Track tool calls as they stream in
68        let mut accumulated_tool_calls: Vec<praxis_llm::ToolCall> = Vec::new();
69        let mut tool_call_buffers: std::collections::HashMap<u32, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
70
71        // Forward events and accumulate tool calls
72        while let Some(event_result) = stream.next().await {
73            let llm_event = event_result?;
74
75            // Convert and forward to client
76            let graph_event = Self::convert_event(llm_event.clone());
77            event_tx.send(graph_event).await?;
78
79            // Accumulate tool calls for state
80            if let praxis_llm::StreamEvent::ToolCall { index, id, name, arguments } = llm_event {
81                let entry = tool_call_buffers.entry(index).or_insert((None, None, String::new()));
82                
83                if let Some(id) = id {
84                    entry.0 = Some(id);
85                }
86                if let Some(name) = name {
87                    entry.1 = Some(name);
88                }
89                if let Some(args) = arguments {
90                    entry.2.push_str(&args);
91                }
92            }
93        }
94
95        // Build final tool calls from accumulated data
96        for (_, (id, name, arguments)) in tool_call_buffers {
97            if let (Some(id), Some(name)) = (id, name) {
98                accumulated_tool_calls.push(praxis_llm::ToolCall {
99                    id,
100                    tool_type: "function".to_string(),
101                    function: praxis_llm::types::FunctionCall {
102                        name,
103                        arguments,
104                    },
105                });
106            }
107        }
108
109        // Add assistant message to state
110        let assistant_message = if accumulated_tool_calls.is_empty() {
111            Message::AI {
112                content: None, // Content was streamed, not accumulated here
113                tool_calls: None,
114                name: None,
115            }
116        } else {
117            Message::AI {
118                content: None,
119                tool_calls: Some(accumulated_tool_calls),
120                name: None,
121            }
122        };
123
124        state.add_message(assistant_message);
125
126        Ok(())
127    }
128
129    fn node_type(&self) -> NodeType {
130        NodeType::LLM
131    }
132}
133