// praxis_graph/graph.rs

use std::sync::Arc;
use std::time::Instant;

use anyhow::Result;
use tokio::sync::mpsc;

use praxis_llm::{LLMClient, ReasoningClient};
use praxis_mcp::MCPToolExecutor;

#[cfg(feature = "observability")]
use crate::builder::ObserverConfig;
use crate::builder::PersistenceConfig;
use crate::node::{Node, NodeType};
use crate::nodes::{LLMNode, ToolNode};
use crate::router::{NextNode, Router, SimpleRouter};
use crate::types::{GraphConfig, GraphInput, GraphState, StreamEvent};
15
/// Context for persistence operations
///
/// Identifies where persisted messages belong. Passed into
/// `Graph::spawn_run` when the caller wants node outputs saved.
pub struct PersistenceContext {
    /// Conversation/thread identifier the saved messages attach to.
    pub thread_id: String,
    /// Identifier of the user that owns this run.
    pub user_id: String,
}
21
/// Agent execution graph: an LLM node and a Tool node connected by a router.
///
/// All heavy collaborators are held behind `Arc` so `spawn_run` can hand
/// cheap clones to a detached background task.
pub struct Graph {
    /// Client used by the LLM node for completions.
    llm_client: Arc<dyn LLMClient>,
    /// Optional client for reasoning-capable models; used when set.
    reasoning_client: Option<Arc<dyn praxis_llm::ReasoningClient>>,
    /// Executor for MCP tool calls.
    mcp_executor: Arc<MCPToolExecutor>,
    /// Run configuration (e.g. `max_iterations` guardrail).
    config: GraphConfig,
    /// Optional persistence backend; writes are fire-and-forget.
    persistence: Option<Arc<PersistenceConfig>>,
    #[cfg(feature = "observability")]
    /// Optional tracer; observations are fire-and-forget.
    observer: Option<Arc<ObserverConfig>>,
}
31
32impl Graph {
33    pub fn new(
34        llm_client: Arc<dyn LLMClient>,
35        mcp_executor: Arc<MCPToolExecutor>,
36        config: GraphConfig,
37    ) -> Self {
38        Self {
39            llm_client,
40            reasoning_client: None,
41            mcp_executor,
42            config,
43            persistence: None,
44            #[cfg(feature = "observability")]
45            observer: None,
46        }
47    }
48    
49    pub(crate) fn new_with_config(
50        llm_client: Arc<dyn LLMClient>,
51        reasoning_client: Option<Arc<dyn praxis_llm::ReasoningClient>>,
52        mcp_executor: Arc<MCPToolExecutor>,
53        config: GraphConfig,
54        persistence: Option<PersistenceConfig>,
55        #[cfg(feature = "observability")]
56        observer: Option<ObserverConfig>,
57    ) -> Self {
58        Self {
59            llm_client,
60            reasoning_client,
61            mcp_executor,
62            config,
63            persistence: persistence.map(Arc::new),
64            #[cfg(feature = "observability")]
65            observer: observer.map(Arc::new),
66        }
67    }
68    
69    /// Create a builder for fluent construction
70    pub fn builder() -> crate::builder::GraphBuilder {
71        crate::builder::GraphBuilder::new()
72    }
73
74    /// Spawn execution in background, return event receiver
75    pub fn spawn_run(
76        &self,
77        input: GraphInput,
78        persistence_ctx: Option<PersistenceContext>,
79    ) -> mpsc::Receiver<StreamEvent> {
80        let (tx, rx) = mpsc::channel(1000);
81
82        // Clone what we need for the spawned task
83        let llm_client = Arc::clone(&self.llm_client);
84        let reasoning_client = self.reasoning_client.clone();
85        let mcp_executor = Arc::clone(&self.mcp_executor);
86        let config = self.config.clone();
87        let persistence = self.persistence.clone();
88        #[cfg(feature = "observability")]
89        let observer = self.observer.clone();
90
91        tokio::spawn(async move {
92            if let Err(e) = Self::execute_loop(
93                input,
94                tx.clone(),
95                llm_client,
96                reasoning_client,
97                mcp_executor,
98                config,
99                persistence,
100                #[cfg(feature = "observability")]
101                observer,
102                persistence_ctx,
103            ).await {
104                let _ = tx
105                    .send(StreamEvent::Error {
106                        message: e.to_string(),
107                        node_id: None,
108                    })
109                    .await;
110            }
111        });
112
113        rx
114    }
115
116    async fn execute_loop(
117        input: GraphInput,
118        event_tx: mpsc::Sender<StreamEvent>,
119        llm_client: Arc<dyn LLMClient>,
120        reasoning_client: Option<Arc<dyn ReasoningClient>>,
121        mcp_executor: Arc<MCPToolExecutor>,
122        config: GraphConfig,
123        persistence: Option<Arc<PersistenceConfig>>,
124        #[cfg(feature = "observability")]
125        observer: Option<Arc<ObserverConfig>>,
126        ctx: Option<PersistenceContext>,
127    ) -> Result<()> {
128        let start_time = Instant::now();
129
130        // Build initial state
131        let mut state = GraphState::from_input(input);
132
133        // Initialize tracing if observer is configured
134        #[cfg(feature = "observability")]
135        if let Some(ref obs) = observer {
136            let obs_clone = Arc::clone(&obs.observer);
137            let run_id = state.run_id.clone();
138            let conversation_id = state.conversation_id.clone();
139            tokio::spawn(async move {
140                if let Err(e) = obs_clone.trace_start(run_id, conversation_id).await {
141                    tracing::error!("Failed to start trace: {}", e);
142                }
143            });
144        }
145
146        // Emit init event
147        let init_event = StreamEvent::InitStream {
148            run_id: state.run_id.clone(),
149            conversation_id: state.conversation_id.clone(),
150            timestamp: chrono::Utc::now().timestamp_millis(),
151        };
152        event_tx.send(init_event.clone()).await?;
153
154        // Create nodes
155        let mut llm_node = LLMNode::new(llm_client.clone(), mcp_executor.clone());
156        
157        if let Some(reasoning_client) = reasoning_client.clone() {
158            llm_node = llm_node.with_reasoning_client(reasoning_client);
159        }
160        let tool_node = ToolNode::new(mcp_executor);
161        let router = SimpleRouter;
162
163        let mut current_node = NodeType::LLM;
164        let mut iteration = 0;
165
166        loop {
167            // Guardrail: max iterations
168            if iteration >= config.max_iterations {
169                let error_event = StreamEvent::Error {
170                    message: format!("Max iterations ({}) reached", config.max_iterations),
171                    node_id: None,
172                };
173                event_tx.send(error_event.clone()).await?;
174                break;
175            }
176
177            let node_start = Instant::now();
178            
179            // Store state snapshot before execution for observation
180            let messages_before = state.messages.len();
181
182            // Execute current node (this emits events via event_tx)
183            match current_node {
184                NodeType::LLM => {
185                    llm_node.execute(&mut state, event_tx.clone()).await?;
186                }
187                NodeType::Tool => {
188                    tool_node.execute(&mut state, event_tx.clone()).await?;
189                }
190            }
191
192            let node_duration = node_start.elapsed().as_millis() as u64;
193
194            // After node execution: persistence + observability (fire-and-forget)
195            Self::handle_post_node_execution(
196                &state,
197                current_node,
198                node_start,
199                node_duration,
200                messages_before,
201                &persistence,
202                #[cfg(feature = "observability")]
203                &observer,
204                &ctx,
205            ).await;
206
207            // Route to next node
208            let next = router.next(&state, current_node);
209
210            match next {
211                NextNode::End => break,
212                NextNode::LLM => current_node = NodeType::LLM,
213                NextNode::Tool => current_node = NodeType::Tool,
214            }
215
216            iteration += 1;
217        }
218
219        // Emit end event
220        let total_duration = start_time.elapsed().as_millis() as u64;
221        let end_event = StreamEvent::EndStream {
222            status: "success".to_string(),
223            total_duration_ms: total_duration,
224        };
225        event_tx.send(end_event.clone()).await?;
226        
227        // Finalize tracing
228        #[cfg(feature = "observability")]
229        if let Some(ref obs) = observer {
230            let obs_clone = Arc::clone(&obs.observer);
231            let run_id = state.run_id.clone();
232            tokio::spawn(async move {
233                if let Err(e) = obs_clone.trace_end(run_id, "success".to_string(), total_duration).await {
234                    tracing::error!("Failed to end trace: {}", e);
235                }
236            });
237        }
238
239        Ok(())
240    }
241
242    /// Handle post-node execution: persistence and observability
243    async fn handle_post_node_execution(
244        state: &GraphState,
245        node_type: NodeType,
246        node_start: Instant,
247        #[allow(unused_variables)]
248        node_duration: u64,
249        messages_before: usize,
250        persistence: &Option<Arc<PersistenceConfig>>,
251        #[cfg(feature = "observability")]
252        observer: &Option<Arc<ObserverConfig>>,
253        ctx: &Option<PersistenceContext>,
254    ) {
255        // Extract messages added by this node
256        let new_messages = if state.messages.len() > messages_before {
257            &state.messages[messages_before..]
258        } else {
259            &[]
260        };
261
262        // Persistence: save messages
263        // For LLM nodes, use structured outputs if available; otherwise fallback to messages
264        if let (Some(persist), Some(context)) = (persistence, ctx) {
265            if node_type == NodeType::LLM && state.last_outputs.is_some() {
266                // New approach: Save structured outputs (reasoning + message separately)
267                if let Some(outputs) = &state.last_outputs {
268                    for output in outputs {
269                        let db_message = Self::convert_output_to_db(
270                            output,
271                            &context.thread_id,
272                            &context.user_id,
273                        );
274                        
275                        if let Some(db_msg) = db_message {
276                            let client = Arc::clone(&persist.client);
277                            tokio::spawn(async move {
278                                if let Err(e) = client.save_message(db_msg).await {
279                                    tracing::error!("Failed to save output to database: {}", e);
280                                }
281                            });
282                        }
283                    }
284                }
285            } else {
286                // Fallback: Save messages directly (for Tool nodes or old LLM nodes)
287                for msg in new_messages {
288                    let db_message = Self::convert_message_to_db(
289                        msg,
290                        &context.thread_id,
291                        &context.user_id,
292                        node_type,
293                    );
294                    
295                    if let Some(db_msg) = db_message {
296                        let client = Arc::clone(&persist.client);
297                        tokio::spawn(async move {
298                            if let Err(e) = client.save_message(db_msg).await {
299                                tracing::error!("Failed to save message: {}", e);
300                            }
301                        });
302                    }
303                }
304            }
305        }
306
307        // Observability: send observation
308        #[cfg(feature = "observability")]
309        if let Some(obs) = observer {
310            let observation = Self::create_observation(
311                state,
312                node_type,
313                node_start,
314                node_duration,
315                new_messages,
316            );
317
318            if let Some(obs_data) = observation {
319                let obs_clone = Arc::clone(&obs.observer);
320                tokio::spawn(async move {
321                    let result = match obs_data.node_type.as_str() {
322                        "llm" => obs_clone.trace_llm_node(obs_data).await,
323                        "tool" => obs_clone.trace_tool_node(obs_data).await,
324                        _ => Ok(()),
325                    };
326                    
327                    if let Err(e) = result {
328                        tracing::error!("Failed to trace node execution: {}", e);
329                    }
330                });
331            }
332        }
333    }
334
335    /// Convert GraphOutput to DBMessage
336    fn convert_output_to_db(
337        output: &crate::types::GraphOutput,
338        thread_id: &str,
339        user_id: &str,
340    ) -> Option<praxis_persist::DBMessage> {
341        use crate::types::GraphOutput;
342        use praxis_persist::{MessageRole, MessageType};
343
344        match output {
345            GraphOutput::Reasoning { id, content } => {
346                Some(praxis_persist::DBMessage {
347                    id: uuid::Uuid::new_v4().to_string(),
348                    thread_id: thread_id.to_string(),
349                    user_id: user_id.to_string(),
350                    role: MessageRole::Assistant,
351                    message_type: MessageType::Reasoning,
352                    content: content.clone(),
353                    tool_call_id: None,
354                    tool_name: None,
355                    arguments: None,
356                    reasoning_id: Some(id.clone()),
357                    created_at: chrono::Utc::now(),
358                    duration_ms: None,
359                })
360            }
361            GraphOutput::Message { id, content, tool_calls } => {
362                if let Some(calls) = tool_calls {
363                    // Save first tool call (expand to handle all in production)
364                    if let Some(first_call) = calls.first() {
365                        Some(praxis_persist::DBMessage {
366                            id: uuid::Uuid::new_v4().to_string(),
367                            thread_id: thread_id.to_string(),
368                            user_id: user_id.to_string(),
369                            role: MessageRole::Assistant,
370                            message_type: MessageType::ToolCall,
371                            content: String::new(),
372                            tool_call_id: Some(first_call.id.clone()),
373                            tool_name: Some(first_call.function.name.clone()),
374                            arguments: serde_json::from_str(&first_call.function.arguments).ok(),
375                            reasoning_id: Some(id.clone()),
376                            created_at: chrono::Utc::now(),
377                            duration_ms: None,
378                        })
379                    } else {
380                        None
381                    }
382                } else if !content.is_empty() {
383                    Some(praxis_persist::DBMessage {
384                        id: uuid::Uuid::new_v4().to_string(),
385                        thread_id: thread_id.to_string(),
386                        user_id: user_id.to_string(),
387                        role: MessageRole::Assistant,
388                        message_type: MessageType::Message,
389                        content: content.clone(),
390                        tool_call_id: None,
391                        tool_name: None,
392                        arguments: None,
393                        reasoning_id: Some(id.clone()),
394                        created_at: chrono::Utc::now(),
395                        duration_ms: None,
396                    })
397                } else {
398                    None
399                }
400            }
401        }
402    }
403    
404    /// Convert praxis-llm Message to praxis-persist DBMessage
405    fn convert_message_to_db(
406        msg: &praxis_llm::Message,
407        thread_id: &str,
408        user_id: &str,
409        _node_type: NodeType,
410    ) -> Option<praxis_persist::DBMessage> {
411        use praxis_llm::Message;
412        use praxis_persist::{MessageRole, MessageType};
413
414        match msg {
415            Message::AI { content, tool_calls, .. } => {
416                if let Some(calls) = tool_calls {
417                    // Save tool calls as separate messages
418                    // For simplicity, we'll create a message for the first tool call
419                    // In production, you might want to handle all tool calls
420                    if let Some(first_call) = calls.first() {
421                        Some(praxis_persist::DBMessage {
422                            id: uuid::Uuid::new_v4().to_string(),
423                            thread_id: thread_id.to_string(),
424                            user_id: user_id.to_string(),
425                            role: MessageRole::Assistant,
426                            message_type: MessageType::ToolCall,
427                            content: String::new(),
428                            tool_call_id: Some(first_call.id.clone()),
429                            tool_name: Some(first_call.function.name.clone()),
430                            arguments: serde_json::from_str(&first_call.function.arguments).ok(),
431                            reasoning_id: None,
432                            created_at: chrono::Utc::now(),
433                            duration_ms: None,
434                        })
435                    } else {
436                        None
437                    }
438                } else if let Some(content) = content {
439                    Some(praxis_persist::DBMessage {
440                        id: uuid::Uuid::new_v4().to_string(),
441                        thread_id: thread_id.to_string(),
442                        user_id: user_id.to_string(),
443                        role: MessageRole::Assistant,
444                        message_type: MessageType::Message,
445                        content: content.as_text().unwrap_or("").to_string(),
446                        tool_call_id: None,
447                        tool_name: None,
448                        arguments: None,
449                        reasoning_id: None,
450                        created_at: chrono::Utc::now(),
451                        duration_ms: None,
452                    })
453                } else {
454                    None
455                }
456            }
457            Message::Tool { tool_call_id, content } => {
458                Some(praxis_persist::DBMessage {
459                    id: uuid::Uuid::new_v4().to_string(),
460                    thread_id: thread_id.to_string(),
461                    user_id: user_id.to_string(),
462                    role: MessageRole::Assistant,
463                    message_type: MessageType::ToolResult,
464                    content: content.as_text().unwrap_or("").to_string(),
465                    tool_call_id: Some(tool_call_id.clone()),
466                    tool_name: None,
467                    arguments: None,
468                    reasoning_id: None,
469                    created_at: chrono::Utc::now(),
470                    duration_ms: None,
471                })
472            }
473            _ => None,
474            }
475        }
476
477    /// Create observation data for tracing
478    #[cfg(feature = "observability")]
479    fn create_observation(
480        state: &GraphState,
481        node_type: NodeType,
482        _node_start: Instant,
483        node_duration: u64,
484        new_messages: &[praxis_llm::Message],
485    ) -> Option<praxis_observability::NodeObservation> {
486        use praxis_observability::{NodeObservation, NodeObservationData, NodeOutput, LangfuseMessage, ToolCallInfo, ToolResultInfo};
487        use crate::types::GraphOutput;
488
489        let span_id = uuid::Uuid::new_v4().to_string();
490        let started_at = chrono::Utc::now() - chrono::Duration::milliseconds(node_duration as i64);
491
492        match node_type {
493            NodeType::LLM => {
494                let input_count = state.messages.len() - new_messages.len();
495                
496                tracing::info!(
497                    "LLM observation - total messages: {}, input_count: {}, new_messages: {}",
498                    state.messages.len(),
499                    input_count,
500                    new_messages.len()
501                );
502                
503                let input_messages: Vec<LangfuseMessage> = state.messages[..input_count]
504                    .iter()
505                    .filter_map(Self::convert_to_langfuse_message)
506                    .collect();
507
508                // Use structured outputs if available
509                let outputs = if let Some(ref last_outputs) = state.last_outputs {
510                    last_outputs.iter().map(|output| {
511                        match output {
512                            GraphOutput::Reasoning { id, content } => {
513                                NodeOutput::Reasoning {
514                                    id: id.clone(),
515                                    content: content.clone(),
516                                }
517                            }
518                            GraphOutput::Message { id, content, tool_calls } => {
519                                if tool_calls.is_some() {
520                                    NodeOutput::ToolCalls {
521                                        calls: tool_calls.as_ref().unwrap().iter().map(|call| {
522                                            ToolCallInfo {
523                                                id: call.id.clone(),
524                                                name: call.function.name.clone(),
525                                                arguments: serde_json::from_str(&call.function.arguments)
526                                                    .unwrap_or(serde_json::json!({})),
527                                            }
528                                        }).collect(),
529                                    }
530                                } else {
531                                    NodeOutput::Message {
532                                        id: id.clone(),
533                                        content: content.clone(),
534                                    }
535                                }
536                            }
537                        }
538                    }).collect()
539                } else {
540                    // Fallback: convert from new_messages
541                    vec![]
542                };
543
544                if outputs.is_empty() {
545                    tracing::warn!("No outputs available for LLM observation");
546                    return None;
547                }
548
549                tracing::info!(
550                    "Created LLM observation: input_messages={}, outputs={}",
551                    input_messages.len(),
552                    outputs.len()
553                );
554
555                Some(NodeObservation {
556                    span_id,
557                    run_id: state.run_id.clone(),
558                    conversation_id: state.conversation_id.clone(),
559                    node_type: "llm".to_string(),
560                    started_at,
561                    duration_ms: node_duration,
562                    data: NodeObservationData::Llm {
563                        input_messages,
564                        outputs,
565                        model: state.llm_config.model.clone(),
566                        usage: None,
567                    },
568                    metadata: std::collections::HashMap::new(),
569                })
570            }
571            NodeType::Tool => {
572                // Extract tool calls from previous AI message
573                let tool_calls: Vec<ToolCallInfo> = state.messages
574                    .iter()
575                    .rev()
576                    .find_map(|msg| match msg {
577                        praxis_llm::Message::AI { tool_calls: Some(calls), .. } => {
578                            Some(calls.iter().map(|call| ToolCallInfo {
579                                id: call.id.clone(),
580                                name: call.function.name.clone(),
581                                arguments: serde_json::from_str(&call.function.arguments)
582                                    .unwrap_or(serde_json::json!({})),
583                            }).collect())
584                        }
585                        _ => None,
586                    })?;
587
588                // Extract tool results from new messages
589                let tool_results: Vec<ToolResultInfo> = new_messages
590                    .iter()
591                    .filter_map(|msg| match msg {
592                        praxis_llm::Message::Tool { tool_call_id, content } => {
593                            Some(ToolResultInfo {
594                                tool_call_id: tool_call_id.clone(),
595                                tool_name: "unknown".to_string(), // TODO: track tool name
596                                result: content.as_text().unwrap_or("").to_string(),
597                                is_error: false,
598                                duration_ms: 0, // TODO: track individual tool duration
599                            })
600                        }
601                        _ => None,
602                    })
603                    .collect();
604
605                tracing::debug!(
606                    "Creating Tool observation: tool_calls_count={}, tool_results_count={}",
607                    tool_calls.len(),
608                    tool_results.len()
609                );
610
611                Some(NodeObservation {
612                    span_id,
613                    run_id: state.run_id.clone(),
614                    conversation_id: state.conversation_id.clone(),
615                    node_type: "tool".to_string(),
616                    started_at,
617                    duration_ms: node_duration,
618                    data: NodeObservationData::Tool {
619                        tool_calls,
620                        tool_results,
621                    },
622                    metadata: std::collections::HashMap::new(),
623                })
624            }
625        }
626    }
627
628    /// Convert praxis-llm Message to Langfuse format
629    #[cfg(feature = "observability")]
630    fn convert_to_langfuse_message(msg: &praxis_llm::Message) -> Option<praxis_observability::LangfuseMessage> {
631        use praxis_observability::{LangfuseMessage, ToolCallInfo};
632
633        match msg {
634            praxis_llm::Message::System { content, .. } => Some(LangfuseMessage {
635                role: "system".to_string(),
636                content: content.as_text().unwrap_or("").to_string(),
637                name: None,
638                tool_call_id: None,
639                tool_calls: None,
640            }),
641            praxis_llm::Message::Human { content, .. } => Some(LangfuseMessage {
642                role: "user".to_string(),
643                content: content.as_text().unwrap_or("").to_string(),
644                name: None,
645                tool_call_id: None,
646                tool_calls: None,
647            }),
648            praxis_llm::Message::AI { content, tool_calls, .. } => {
649                let tool_calls_converted = tool_calls.as_ref().map(|calls| {
650                    calls.iter().map(|call| ToolCallInfo {
651                        id: call.id.clone(),
652                        name: call.function.name.clone(),
653                        arguments: serde_json::from_str(&call.function.arguments)
654                            .unwrap_or(serde_json::json!({})),
655                    }).collect()
656                });
657
658                Some(LangfuseMessage {
659                    role: "assistant".to_string(),
660                    content: content.as_ref()
661                        .and_then(|c| c.as_text())
662                        .unwrap_or("")
663                        .to_string(),
664                    name: None,
665                    tool_call_id: None,
666                    tool_calls: tool_calls_converted,
667                })
668            }
669            praxis_llm::Message::Tool { tool_call_id, content } => Some(LangfuseMessage {
670                role: "tool".to_string(),
671                content: content.as_text().unwrap_or("").to_string(),
672                name: None,
673                tool_call_id: Some(tool_call_id.clone()),
674                tool_calls: None,
675            }),
676        }
677    }
678}
679