// praxis_graph/nodes/llm_node.rs

1use crate::node::{EventSender, Node, NodeType};
2use crate::types::GraphOutput;
3use anyhow::Result;
4use async_trait::async_trait;
5use futures::StreamExt;
6use praxis_llm::{ChatClient, ReasoningClient, ChatOptions, ChatRequest, ResponseRequest, ReasoningConfig, Message, ToolChoice};
7use praxis_mcp::MCPToolExecutor;
8use crate::types::GraphState;
9use std::pin::Pin;
10use std::sync::Arc;
11
/// Graph node that drives a single LLM turn: it opens a streaming request,
/// forwards stream events to the client, and records the accumulated result
/// on the graph state (see the `Node` impl below).
pub struct LLMNode {
    // Chat Completions client; always present, used as the default path.
    client: Arc<dyn ChatClient>,
    // Optional Reasoning (Responses) API client; set via `with_reasoning_client`.
    reasoning_client: Option<Arc<dyn ReasoningClient>>,
    // Supplies MCP tool definitions attached to chat requests.
    mcp_executor: Arc<MCPToolExecutor>,
}
17
18impl LLMNode {
19    pub fn new(client: Arc<dyn ChatClient>, mcp_executor: Arc<MCPToolExecutor>) -> Self {
20        let reasoning_client = None; // We'll set this from client if it implements both traits
21        Self { 
22            client,
23            reasoning_client,
24            mcp_executor,
25        }
26    }
27
28    pub fn with_reasoning_client(mut self, reasoning_client: Arc<dyn ReasoningClient>) -> Self {
29        self.reasoning_client = Some(reasoning_client);
30        self
31    }
32
33    /// Convert praxis_llm::StreamEvent to Graph StreamEvent
34    /// Uses automatic From trait conversion
35    fn convert_event(event: praxis_llm::StreamEvent) -> crate::types::StreamEvent {
36        event.into()
37    }
38
39    /// Check if model should use Reasoning API
40    fn is_reasoning_model(model: &str) -> bool {
41        model.starts_with("gpt-5") || model.starts_with("o")
42    }
43    
44    /// Template Method: Create stream based on model configuration
45    async fn create_stream(
46        &self,
47        state: &GraphState,
48    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>> {
49        let model = &state.llm_config.model;
50        let use_reasoning_api = Self::is_reasoning_model(model) && self.reasoning_client.is_some();
51        
52        tracing::info!(
53            "LLM_NODE: Creating stream with model={}, use_reasoning_api={}",
54            model,
55            use_reasoning_api
56        );
57        
58        if use_reasoning_api {
59            self.create_reasoning_stream(state).await
60        } else {
61            self.create_chat_stream(state).await
62        }
63    }
64    
65    async fn create_reasoning_stream(
66        &self,
67        state: &GraphState,
68    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>> {
69        let reasoning_config = state.llm_config.reasoning_effort
70            .as_ref()
71            .map(|effort| match effort.as_str() {
72                "low" => ReasoningConfig::low(),
73                "high" => ReasoningConfig::high(),
74                _ => ReasoningConfig::medium(),
75            });
76
77        let request = ResponseRequest::new(
78            state.llm_config.model.clone(),
79            state.messages.clone()
80        );
81        let request = if let Some(config) = reasoning_config {
82            request.with_reasoning(config)
83        } else {
84            request
85        };
86
87        self.reasoning_client
88            .as_ref()
89            .unwrap()
90            .reason_stream(request)
91            .await
92    }
93    
94    async fn create_chat_stream(
95        &self,
96        state: &GraphState,
97    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>> {
98        let tools = self.mcp_executor.get_llm_tools().await?;
99        
100        let mut options = ChatOptions::new()
101            .tools(tools)
102            .tool_choice(ToolChoice::auto());
103
104        if let Some(temp) = state.llm_config.temperature {
105            options = options.temperature(temp);
106        }
107        if let Some(max_tokens) = state.llm_config.max_tokens {
108            options = options.max_tokens(max_tokens);
109        }
110
111        let request = ChatRequest::new(
112            state.llm_config.model.clone(),
113            state.messages.clone()
114        ).with_options(options);
115
116        self.client.chat_stream(request).await
117    }
118    
119    /// Template Method: Process stream and return structured outputs
120    async fn process_stream(
121        &self,
122        mut stream: Pin<Box<dyn futures::Stream<Item = Result<praxis_llm::StreamEvent>> + Send>>,
123        event_tx: EventSender,
124    ) -> Result<Vec<GraphOutput>> {
125        let mut reasoning_content = String::new();
126        let mut message_content = String::new();
127        let mut tool_call_buffers: std::collections::HashMap<u32, (Option<String>, Option<String>, String)> = std::collections::HashMap::new();
128
129        // Forward events and accumulate content separately
130        while let Some(event_result) = stream.next().await {
131            let llm_event = event_result?;
132
133            // Convert and forward to client
134            let graph_event = Self::convert_event(llm_event.clone());
135            event_tx.send(graph_event).await?;
136
137            // Accumulate based on event type (keep reasoning and message separate)
138            match llm_event {
139                praxis_llm::StreamEvent::Reasoning { content } => {
140                    reasoning_content.push_str(&content);
141                }
142                praxis_llm::StreamEvent::Message { content } => {
143                    message_content.push_str(&content);
144                }
145                praxis_llm::StreamEvent::ToolCall { index, id, name, arguments } => {
146                let entry = tool_call_buffers.entry(index).or_insert((None, None, String::new()));
147                
148                if let Some(id) = id {
149                    entry.0 = Some(id);
150                }
151                if let Some(name) = name {
152                    entry.1 = Some(name);
153                }
154                if let Some(args) = arguments {
155                    entry.2.push_str(&args);
156                }
157            }
158                _ => {}
159            }
160        }
161
162        // Build output items
163        let mut outputs = Vec::new();
164        
165        // Add reasoning output if present
166        if !reasoning_content.is_empty() {
167            outputs.push(GraphOutput::reasoning(
168                format!("rs_{}", uuid::Uuid::new_v4()),
169                reasoning_content,
170            ));
171        }
172        
173        // Build tool calls
174        let tool_calls: Vec<praxis_llm::ToolCall> = tool_call_buffers
175            .into_iter()
176            .filter_map(|(_, (id, name, arguments))| {
177            if let (Some(id), Some(name)) = (id, name) {
178                    Some(praxis_llm::ToolCall {
179                    id,
180                    tool_type: "function".to_string(),
181                    function: praxis_llm::types::FunctionCall {
182                        name,
183                        arguments,
184                    },
185                    })
186                } else {
187                    None
188                }
189            })
190            .collect();
191        
192        // Add message output if present
193        if !message_content.is_empty() || !tool_calls.is_empty() {
194            if tool_calls.is_empty() {
195                outputs.push(GraphOutput::message(
196                    format!("msg_{}", uuid::Uuid::new_v4()),
197                    message_content,
198                ));
199            } else {
200                outputs.push(GraphOutput::message_with_tools(
201                    format!("msg_{}", uuid::Uuid::new_v4()),
202                    message_content,
203                    tool_calls,
204                ));
205            }
206        }
207        
208        Ok(outputs)
209    }
210    
211    /// Template Method: Save outputs to state
212    fn save_outputs(&self, state: &mut GraphState, outputs: &[GraphOutput]) -> Result<()> {
213        // Concatenate all content for backward compatibility
214        let mut combined_content = String::new();
215        let mut combined_tool_calls = Vec::new();
216        
217        for output in outputs {
218            match output {
219                GraphOutput::Reasoning { content, .. } => {
220                    combined_content.push_str(content);
221                }
222                GraphOutput::Message { content, tool_calls, .. } => {
223                    combined_content.push_str(content);
224                    if let Some(calls) = tool_calls {
225                        combined_tool_calls.extend(calls.clone());
226                    }
227                }
228            }
229        }
230
231        // Add assistant message to state
232        let content = if !combined_content.is_empty() {
233            Some(praxis_llm::Content::Text(combined_content))
234        } else {
235            None
236        };
237        
238        let tool_calls = if !combined_tool_calls.is_empty() {
239            Some(combined_tool_calls)
240        } else {
241            None
242        };
243        
244        let assistant_message = Message::AI {
245            content,
246            tool_calls,
247                name: None,
248        };
249
250        state.add_message(assistant_message);
251
252        Ok(())
253    }
254}
255
256#[async_trait]
257impl Node for LLMNode {
258    /// Template Method Pattern: Execute node with structured steps
259    async fn execute(&self, state: &mut GraphState, event_tx: EventSender) -> Result<()> {
260        // Step 1: Create stream (Chat or Reasoning API)
261        let stream = self.create_stream(state).await?;
262        
263        // Step 2: Process stream and get structured outputs
264        let outputs = self.process_stream(stream, event_tx).await?;
265        
266        // Step 3: Save outputs to state
267        self.save_outputs(state, &outputs)?;
268        
269        // Store outputs in state for later use by graph
270        state.last_outputs = Some(outputs);
271        
272        Ok(())
273    }
274
275    fn node_type(&self) -> NodeType {
276        NodeType::LLM
277    }
278}
279
280