praxis_graph/nodes/
llm_node.rs1use crate::node::{EventSender, Node, NodeType};
2use anyhow::Result;
3use async_trait::async_trait;
4use futures::StreamExt;
5use praxis_llm::{ChatOptions, ChatRequest, LLMClient, Message, ToolChoice};
6use praxis_mcp::MCPToolExecutor;
7use praxis_types::GraphState;
8use std::sync::Arc;
9
10pub struct LLMNode {
11 client: Arc<dyn LLMClient>,
12 mcp_executor: Arc<MCPToolExecutor>,
13}
14
15impl LLMNode {
16 pub fn new(client: Arc<dyn LLMClient>, mcp_executor: Arc<MCPToolExecutor>) -> Self {
17 Self {
18 client,
19 mcp_executor,
20 }
21 }
22
23 fn convert_event(event: praxis_llm::StreamEvent) -> praxis_types::StreamEvent {
25 match event {
26 praxis_llm::StreamEvent::Reasoning { content } => {
27 praxis_types::StreamEvent::Reasoning { content }
28 }
29 praxis_llm::StreamEvent::Message { content } => {
30 praxis_types::StreamEvent::Message { content }
31 }
32 praxis_llm::StreamEvent::ToolCall {
33 index,
34 id,
35 name,
36 arguments,
37 } => praxis_types::StreamEvent::ToolCall {
38 index,
39 id,
40 name,
41 arguments,
42 },
43 praxis_llm::StreamEvent::Done { finish_reason } => {
44 praxis_types::StreamEvent::Done { finish_reason }
45 }
46 }
47 }
48}
49
50#[async_trait]
51impl Node for LLMNode {
52 async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
53 let tools = self.mcp_executor.get_llm_tools().await?;
55
56 let options = ChatOptions::new()
58 .tools(tools)
59 .tool_choice(ToolChoice::auto());
60
61 let request = ChatRequest::new(state.llm_config.model.clone(), state.messages.clone())
62 .with_options(options);
63
64 let mut stream = self.client.chat_completion_stream(request).await?;
66
67 let mut accumulated_tool_calls: Vec<praxis_llm::ToolCall> = Vec::new();
69 let mut tool_call_buffers: std::collections::HashMap<u32, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
70
71 while let Some(event_result) = stream.next().await {
73 let llm_event = event_result?;
74
75 let graph_event = Self::convert_event(llm_event.clone());
77 event_tx.send(graph_event).await?;
78
79 if let praxis_llm::StreamEvent::ToolCall { index, id, name, arguments } = llm_event {
81 let entry = tool_call_buffers.entry(index).or_insert((None, None, String::new()));
82
83 if let Some(id) = id {
84 entry.0 = Some(id);
85 }
86 if let Some(name) = name {
87 entry.1 = Some(name);
88 }
89 if let Some(args) = arguments {
90 entry.2.push_str(&args);
91 }
92 }
93 }
94
95 for (_, (id, name, arguments)) in tool_call_buffers {
97 if let (Some(id), Some(name)) = (id, name) {
98 accumulated_tool_calls.push(praxis_llm::ToolCall {
99 id,
100 tool_type: "function".to_string(),
101 function: praxis_llm::types::FunctionCall {
102 name,
103 arguments,
104 },
105 });
106 }
107 }
108
109 let assistant_message = if accumulated_tool_calls.is_empty() {
111 Message::AI {
112 content: None, tool_calls: None,
114 name: None,
115 }
116 } else {
117 Message::AI {
118 content: None,
119 tool_calls: Some(accumulated_tool_calls),
120 name: None,
121 }
122 };
123
124 state.add_message(assistant_message);
125
126 Ok(())
127 }
128
129 fn node_type(&self) -> NodeType {
130 NodeType::LLM
131 }
132}
133