praxis_graph/nodes/
llm_node.rs1use crate::node::{EventSender, Node, NodeType};
2use anyhow::Result;
3use async_trait::async_trait;
4use futures::StreamExt;
5use praxis_llm::{ChatClient, ChatOptions, ChatRequest, Message, ToolChoice};
6use praxis_mcp::MCPToolExecutor;
7use crate::types::GraphState;
8use std::sync::Arc;
9
10pub struct LLMNode {
11 client: Arc<dyn ChatClient>,
12 mcp_executor: Arc<MCPToolExecutor>,
13}
14
15impl LLMNode {
16 pub fn new(client: Arc<dyn ChatClient>, mcp_executor: Arc<MCPToolExecutor>) -> Self {
17 Self {
18 client,
19 mcp_executor,
20 }
21 }
22
23 fn convert_event(event: praxis_llm::StreamEvent) -> crate::types::StreamEvent {
26 event.into()
27 }
28}
29
30#[async_trait]
31impl Node for LLMNode {
32 async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
33 let tools = self.mcp_executor.get_llm_tools().await?;
35
36 let options = ChatOptions::new()
38 .tools(tools)
39 .tool_choice(ToolChoice::auto());
40
41 let request = ChatRequest::new(state.llm_config.model.clone(), state.messages.clone())
42 .with_options(options);
43
44 let mut stream = self.client.chat_stream(request).await?;
46
47 let mut accumulated_tool_calls: Vec<praxis_llm::ToolCall> = Vec::new();
49 let mut tool_call_buffers: std::collections::HashMap<u32, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
50
51 while let Some(event_result) = stream.next().await {
53 let llm_event = event_result?;
54
55 let graph_event = Self::convert_event(llm_event.clone());
57 event_tx.send(graph_event).await?;
58
59 if let praxis_llm::StreamEvent::ToolCall { index, id, name, arguments } = llm_event {
61 let entry = tool_call_buffers.entry(index).or_insert((None, None, String::new()));
62
63 if let Some(id) = id {
64 entry.0 = Some(id);
65 }
66 if let Some(name) = name {
67 entry.1 = Some(name);
68 }
69 if let Some(args) = arguments {
70 entry.2.push_str(&args);
71 }
72 }
73 }
74
75 for (_, (id, name, arguments)) in tool_call_buffers {
77 if let (Some(id), Some(name)) = (id, name) {
78 accumulated_tool_calls.push(praxis_llm::ToolCall {
79 id,
80 tool_type: "function".to_string(),
81 function: praxis_llm::types::FunctionCall {
82 name,
83 arguments,
84 },
85 });
86 }
87 }
88
89 let assistant_message = if accumulated_tool_calls.is_empty() {
91 Message::AI {
92 content: None, tool_calls: None,
94 name: None,
95 }
96 } else {
97 Message::AI {
98 content: None,
99 tool_calls: Some(accumulated_tool_calls),
100 name: None,
101 }
102 };
103
104 state.add_message(assistant_message);
105
106 Ok(())
107 }
108
109 fn node_type(&self) -> NodeType {
110 NodeType::LLM
111 }
112}
113