// praxis_graph/nodes/llm_node.rs
llm_node.rs1use crate::node::{EventSender, Node, NodeType};
2use crate::types::GraphOutput;
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6use praxis_llm::{ChatClient, ReasoningClient, ChatOptions, ChatRequest, ResponseRequest, ReasoningConfig, Message, ToolChoice};
7use praxis_mcp::MCPToolExecutor;
8use crate::types::GraphState;
9use std::pin::Pin;
10use std::sync::Arc;
11
/// Graph node that drives a single LLM turn: it builds a streaming request
/// from the current `GraphState`, forwards stream events to the graph
/// consumer, and records the assistant's reply (including any tool calls)
/// back into the state.
pub struct LLMNode {
    // Chat-completions client used for non-reasoning models.
    client: Arc<dyn ChatClient>,
    // Optional client for the reasoning API; only used when the configured
    // model is a reasoning model (see `is_reasoning_model`).
    reasoning_client: Option<Arc<dyn ReasoningClient>>,
    // Supplies MCP tool definitions attached to chat requests.
    mcp_executor: Arc<MCPToolExecutor>,
}
17
18impl LLMNode {
19 pub fn new(client: Arc<dyn ChatClient>, mcp_executor: Arc<MCPToolExecutor>) -> Self {
20 let reasoning_client = None; Self {
22 client,
23 reasoning_client,
24 mcp_executor,
25 }
26 }
27
28 pub fn with_reasoning_client(mut self, reasoning_client: Arc<dyn ReasoningClient>) -> Self {
29 self.reasoning_client = Some(reasoning_client);
30 self
31 }
32
33 fn convert_event(event: praxis_llm::StreamEvent) -> crate::types::StreamEvent {
36 event.into()
37 }
38
39 fn is_reasoning_model(model: &str) -> bool {
41 model.starts_with("gpt-5") || model.starts_with("o")
42 }
43
44 async fn create_stream(
46 &self,
47 state: &GraphState,
48 ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>> {
49 let model = &state.llm_config.model;
50 let use_reasoning_api = Self::is_reasoning_model(model) && self.reasoning_client.is_some();
51
52 tracing::info!(
53 "LLM_NODE: Creating stream with model={}, use_reasoning_api={}",
54 model,
55 use_reasoning_api
56 );
57
58 if use_reasoning_api {
59 self.create_reasoning_stream(state).await
60 } else {
61 self.create_chat_stream(state).await
62 }
63 }
64
65 async fn create_reasoning_stream(
66 &self,
67 state: &GraphState,
68 ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>> {
69 let reasoning_config = state.llm_config.reasoning_effort
70 .as_ref()
71 .map(|effort| match effort.as_str() {
72 "low" => ReasoningConfig::low(),
73 "high" => ReasoningConfig::high(),
74 _ => ReasoningConfig::medium(),
75 });
76
77 let request = ResponseRequest::new(
78 state.llm_config.model.clone(),
79 state.messages.clone()
80 );
81 let request = if let Some(config) = reasoning_config {
82 request.with_reasoning(config)
83 } else {
84 request
85 };
86
87 self.reasoning_client
88 .as_ref()
89 .unwrap()
90 .reason_stream(request)
91 .await
92 }
93
94 async fn create_chat_stream(
95 &self,
96 state: &GraphState,
97 ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>> {
98 let tools = self.mcp_executor.get_llm_tools().await?;
99
100 let mut options = ChatOptions::new()
101 .tools(tools)
102 .tool_choice(ToolChoice::auto());
103
104 if let Some(temp) = state.llm_config.temperature {
105 options = options.temperature(temp);
106 }
107 if let Some(max_tokens) = state.llm_config.max_tokens {
108 options = options.max_tokens(max_tokens);
109 }
110
111 let request = ChatRequest::new(
112 state.llm_config.model.clone(),
113 state.messages.clone()
114 ).with_options(options);
115
116 self.client.chat_stream(request).await
117 }
118
119 async fn process_stream(
121 &self,
122 mut stream: Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>,
123 event_tx: EventSender,
124 ) -> Result<Vec<GraphOutput>> {
125 let mut reasoning_content = String::new();
126 let mut message_content = String::new();
127 let mut tool_call_buffers: std::collections::HashMap<u32, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
128
129 while let Some(event_result) = stream.next().await {
131 let llm_event = event_result?;
132
133 let graph_event = Self::convert_event(llm_event.clone());
135 event_tx.send(graph_event).await?;
136
137 match llm_event {
139 praxis_llm::StreamEvent::Reasoning { content } => {
140 reasoning_content.push_str(&content);
141 }
142 praxis_llm::StreamEvent::Message { content } => {
143 message_content.push_str(&content);
144 }
145 praxis_llm::StreamEvent::ToolCall { index, id, name, arguments } => {
146 let entry = tool_call_buffers.entry(index).or_insert((None, None, String::new()));
147
148 if let Some(id) = id {
149 entry.0 = Some(id);
150 }
151 if let Some(name) = name {
152 entry.1 = Some(name);
153 }
154 if let Some(args) = arguments {
155 entry.2.push_str(&args);
156 }
157 }
158 _ => {}
159 }
160 }
161
162 let mut outputs = Vec::new();
164
165 if !reasoning_content.is_empty() {
167 outputs.push(GraphOutput::reasoning(
168 format!("rs_{}", uuid::Uuid::new_v4()),
169 reasoning_content,
170 ));
171 }
172
173 let tool_calls: Vec<praxis_llm::ToolCall> = tool_call_buffers
175 .into_iter()
176 .filter_map(|(_, (id, name, arguments))| {
177 if let (Some(id), Some(name)) = (id, name) {
178 Some(praxis_llm::ToolCall {
179 id,
180 tool_type: "function".to_string(),
181 function: praxis_llm::types::FunctionCall {
182 name,
183 arguments,
184 },
185 })
186 } else {
187 None
188 }
189 })
190 .collect();
191
192 if !message_content.is_empty() || !tool_calls.is_empty() {
194 if tool_calls.is_empty() {
195 outputs.push(GraphOutput::message(
196 format!("msg_{}", uuid::Uuid::new_v4()),
197 message_content,
198 ));
199 } else {
200 outputs.push(GraphOutput::message_with_tools(
201 format!("msg_{}", uuid::Uuid::new_v4()),
202 message_content,
203 tool_calls,
204 ));
205 }
206 }
207
208 Ok(outputs)
209 }
210
211 fn save_outputs(&self, state: &mut GraphState, outputs: &[GraphOutput]) -> Result<()> {
213 let mut combined_content = String::new();
215 let mut combined_tool_calls = Vec::new();
216
217 for output in outputs {
218 match output {
219 GraphOutput::Reasoning { content, .. } => {
220 combined_content.push_str(content);
221 }
222 GraphOutput::Message { content, tool_calls, .. } => {
223 combined_content.push_str(content);
224 if let Some(calls) = tool_calls {
225 combined_tool_calls.extend(calls.clone());
226 }
227 }
228 }
229 }
230
231 let content = if !combined_content.is_empty() {
233 Some(praxis_llm::Content::Text(combined_content))
234 } else {
235 None
236 };
237
238 let tool_calls = if !combined_tool_calls.is_empty() {
239 Some(combined_tool_calls)
240 } else {
241 None
242 };
243
244 let assistant_message = Message::AI {
245 content,
246 tool_calls,
247 name: None,
248 };
249
250 state.add_message(assistant_message);
251
252 Ok(())
253 }
254}
255
256#[async_trait]
257impl Node for LLMNode {
258 async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
260 let stream = self.create_stream(state).await?;
262
263 let outputs = self.process_stream(stream, event_tx).await?;
265
266 self.save_outputs(state, &outputs)?;
268
269 state.last_outputs = Some(outputs);
271
272 Ok(())
273 }
274
275 fn node_type(&self) -> NodeType {
276 NodeType::LLM
277 }
278}
279
280