1use crate::node::{Node, NodeType};
2use crate::nodes::{LLMNode, ToolNode};
3use crate::router::{NextNode, Router, SimpleRouter};
4use crate::builder::PersistenceConfig;
5use praxis_llm::ReasoningClient;
6#[cfg(feature = "observability")]
7use crate::builder::ObserverConfig;
8use anyhow::Result;
9use praxis_llm::LLMClient;
10use praxis_mcp::MCPToolExecutor;
11use crate::types::{GraphConfig, GraphInput, GraphState, StreamEvent};
12use std::sync::Arc;
13use std::time::Instant;
14use tokio::sync::mpsc;
15
/// Identifies the thread and user that persisted messages are stored under.
///
/// A value is handed to `Graph::spawn_run`; when persistence is configured,
/// both fields are copied into every database message written during the run.
#[derive(Debug, Clone)]
pub struct PersistenceContext {
    /// Written to the `thread_id` field of each persisted message.
    pub thread_id: String,
    /// Written to the `user_id` field of each persisted message.
    pub user_id: String,
}
21
22pub struct Graph {
23 llm_client: Arc<dyn LLMClient>,
24 reasoning_client: Option<Arc<dyn praxis_llm::ReasoningClient>>,
25 mcp_executor: Arc<MCPToolExecutor>,
26 config: GraphConfig,
27 persistence: Option<Arc<PersistenceConfig>>,
28 #[cfg(feature = "observability")]
29 observer: Option<Arc<ObserverConfig>>,
30}
31
32impl Graph {
33 pub fn new(
34 llm_client: Arc<dyn LLMClient>,
35 mcp_executor: Arc<MCPToolExecutor>,
36 config: GraphConfig,
37 ) -> Self {
38 Self {
39 llm_client,
40 reasoning_client: None,
41 mcp_executor,
42 config,
43 persistence: None,
44 #[cfg(feature = "observability")]
45 observer: None,
46 }
47 }
48
49 pub(crate) fn new_with_config(
50 llm_client: Arc<dyn LLMClient>,
51 reasoning_client: Option<Arc<dyn praxis_llm::ReasoningClient>>,
52 mcp_executor: Arc<MCPToolExecutor>,
53 config: GraphConfig,
54 persistence: Option<PersistenceConfig>,
55 #[cfg(feature = "observability")]
56 observer: Option<ObserverConfig>,
57 ) -> Self {
58 Self {
59 llm_client,
60 reasoning_client,
61 mcp_executor,
62 config,
63 persistence: persistence.map(Arc::new),
64 #[cfg(feature = "observability")]
65 observer: observer.map(Arc::new),
66 }
67 }
68
69 pub fn builder() -> crate::builder::GraphBuilder {
71 crate::builder::GraphBuilder::new()
72 }
73
74 pub fn spawn_run(
76 &self,
77 input: GraphInput,
78 persistence_ctx: Option<PersistenceContext>,
79 ) -> mpsc::Receiver<StreamEvent> {
80 let (tx, rx) = mpsc::channel(1000);
81
82 let llm_client = Arc::clone(&self.llm_client);
84 let reasoning_client = self.reasoning_client.clone();
85 let mcp_executor = Arc::clone(&self.mcp_executor);
86 let config = self.config.clone();
87 let persistence = self.persistence.clone();
88 #[cfg(feature = "observability")]
89 let observer = self.observer.clone();
90
91 tokio::spawn(async move {
92 if let Err(e) = Self::execute_loop(
93 input,
94 tx.clone(),
95 llm_client,
96 reasoning_client,
97 mcp_executor,
98 config,
99 persistence,
100 #[cfg(feature = "observability")]
101 observer,
102 persistence_ctx,
103 ).await {
104 let _ = tx
105 .send(StreamEvent::Error {
106 message: e.to_string(),
107 node_id: None,
108 })
109 .await;
110 }
111 });
112
113 rx
114 }
115
116 async fn execute_loop(
117 input: GraphInput,
118 event_tx: mpsc::Sender<StreamEvent>,
119 llm_client: Arc<dyn LLMClient>,
120 reasoning_client: Option<Arc<dyn ReasoningClient>>,
121 mcp_executor: Arc<MCPToolExecutor>,
122 config: GraphConfig,
123 persistence: Option<Arc<PersistenceConfig>>,
124 #[cfg(feature = "observability")]
125 observer: Option<Arc<ObserverConfig>>,
126 ctx: Option<PersistenceContext>,
127 ) -> Result<()> {
128 let start_time = Instant::now();
129
130 let mut state = GraphState::from_input(input);
132
133 #[cfg(feature = "observability")]
135 if let Some(ref obs) = observer {
136 let obs_clone = Arc::clone(&obs.observer);
137 let run_id = state.run_id.clone();
138 let conversation_id = state.conversation_id.clone();
139 tokio::spawn(async move {
140 if let Err(e) = obs_clone.trace_start(run_id, conversation_id).await {
141 tracing::error!("Failed to start trace: {}", e);
142 }
143 });
144 }
145
146 let init_event = StreamEvent::InitStream {
148 run_id: state.run_id.clone(),
149 conversation_id: state.conversation_id.clone(),
150 timestamp: chrono::Utc::now().timestamp_millis(),
151 };
152 event_tx.send(init_event.clone()).await?;
153
154 let mut llm_node = LLMNode::new(llm_client.clone(), mcp_executor.clone());
156
157 if let Some(reasoning_client) = reasoning_client.clone() {
158 llm_node = llm_node.with_reasoning_client(reasoning_client);
159 }
160 let tool_node = ToolNode::new(mcp_executor);
161 let router = SimpleRouter;
162
163 let mut current_node = NodeType::LLM;
164 let mut iteration = 0;
165
166 loop {
167 if iteration >= config.max_iterations {
169 let error_event = StreamEvent::Error {
170 message: format!("Max iterations ({}) reached", config.max_iterations),
171 node_id: None,
172 };
173 event_tx.send(error_event.clone()).await?;
174 break;
175 }
176
177 let node_start = Instant::now();
178
179 let messages_before = state.messages.len();
181
182 match current_node {
184 NodeType::LLM => {
185 llm_node.execute(&mut state, event_tx.clone()).await?;
186 }
187 NodeType::Tool => {
188 tool_node.execute(&mut state, event_tx.clone()).await?;
189 }
190 }
191
192 let node_duration = node_start.elapsed().as_millis() as u64;
193
194 Self::handle_post_node_execution(
196 &state,
197 current_node,
198 node_start,
199 node_duration,
200 messages_before,
201 &persistence,
202 #[cfg(feature = "observability")]
203 &observer,
204 &ctx,
205 ).await;
206
207 let next = router.next(&state, current_node);
209
210 match next {
211 NextNode::End => break,
212 NextNode::LLM => current_node = NodeType::LLM,
213 NextNode::Tool => current_node = NodeType::Tool,
214 }
215
216 iteration += 1;
217 }
218
219 let total_duration = start_time.elapsed().as_millis() as u64;
221 let end_event = StreamEvent::EndStream {
222 status: "success".to_string(),
223 total_duration_ms: total_duration,
224 };
225 event_tx.send(end_event.clone()).await?;
226
227 #[cfg(feature = "observability")]
229 if let Some(ref obs) = observer {
230 let obs_clone = Arc::clone(&obs.observer);
231 let run_id = state.run_id.clone();
232 tokio::spawn(async move {
233 if let Err(e) = obs_clone.trace_end(run_id, "success".to_string(), total_duration).await {
234 tracing::error!("Failed to end trace: {}", e);
235 }
236 });
237 }
238
239 Ok(())
240 }
241
242 async fn handle_post_node_execution(
244 state: &GraphState,
245 node_type: NodeType,
246 node_start: Instant,
247 #[allow(unused_variables)]
248 node_duration: u64,
249 messages_before: usize,
250 persistence: &Option<Arc<PersistenceConfig>>,
251 #[cfg(feature = "observability")]
252 observer: &Option<Arc<ObserverConfig>>,
253 ctx: &Option<PersistenceContext>,
254 ) {
255 let new_messages = if state.messages.len() > messages_before {
257 &state.messages[messages_before..]
258 } else {
259 &[]
260 };
261
262 if let (Some(persist), Some(context)) = (persistence, ctx) {
265 if node_type == NodeType::LLM && state.last_outputs.is_some() {
266 if let Some(outputs) = &state.last_outputs {
268 for output in outputs {
269 let db_message = Self::convert_output_to_db(
270 output,
271 &context.thread_id,
272 &context.user_id,
273 );
274
275 if let Some(db_msg) = db_message {
276 let client = Arc::clone(&persist.client);
277 tokio::spawn(async move {
278 if let Err(e) = client.save_message(db_msg).await {
279 tracing::error!("Failed to save output to database: {}", e);
280 }
281 });
282 }
283 }
284 }
285 } else {
286 for msg in new_messages {
288 let db_message = Self::convert_message_to_db(
289 msg,
290 &context.thread_id,
291 &context.user_id,
292 node_type,
293 );
294
295 if let Some(db_msg) = db_message {
296 let client = Arc::clone(&persist.client);
297 tokio::spawn(async move {
298 if let Err(e) = client.save_message(db_msg).await {
299 tracing::error!("Failed to save message: {}", e);
300 }
301 });
302 }
303 }
304 }
305 }
306
307 #[cfg(feature = "observability")]
309 if let Some(obs) = observer {
310 let observation = Self::create_observation(
311 state,
312 node_type,
313 node_start,
314 node_duration,
315 new_messages,
316 );
317
318 if let Some(obs_data) = observation {
319 let obs_clone = Arc::clone(&obs.observer);
320 tokio::spawn(async move {
321 let result = match obs_data.node_type.as_str() {
322 "llm" => obs_clone.trace_llm_node(obs_data).await,
323 "tool" => obs_clone.trace_tool_node(obs_data).await,
324 _ => Ok(()),
325 };
326
327 if let Err(e) = result {
328 tracing::error!("Failed to trace node execution: {}", e);
329 }
330 });
331 }
332 }
333 }
334
335 fn convert_output_to_db(
337 output: &crate::types::GraphOutput,
338 thread_id: &str,
339 user_id: &str,
340 ) -> Option<praxis_persist::DBMessage> {
341 use crate::types::GraphOutput;
342 use praxis_persist::{MessageRole, MessageType};
343
344 match output {
345 GraphOutput::Reasoning { id, content } => {
346 Some(praxis_persist::DBMessage {
347 id: uuid::Uuid::new_v4().to_string(),
348 thread_id: thread_id.to_string(),
349 user_id: user_id.to_string(),
350 role: MessageRole::Assistant,
351 message_type: MessageType::Reasoning,
352 content: content.clone(),
353 tool_call_id: None,
354 tool_name: None,
355 arguments: None,
356 reasoning_id: Some(id.clone()),
357 created_at: chrono::Utc::now(),
358 duration_ms: None,
359 })
360 }
361 GraphOutput::Message { id, content, tool_calls } => {
362 if let Some(calls) = tool_calls {
363 if let Some(first_call) = calls.first() {
365 Some(praxis_persist::DBMessage {
366 id: uuid::Uuid::new_v4().to_string(),
367 thread_id: thread_id.to_string(),
368 user_id: user_id.to_string(),
369 role: MessageRole::Assistant,
370 message_type: MessageType::ToolCall,
371 content: String::new(),
372 tool_call_id: Some(first_call.id.clone()),
373 tool_name: Some(first_call.function.name.clone()),
374 arguments: serde_json::from_str(&first_call.function.arguments).ok(),
375 reasoning_id: Some(id.clone()),
376 created_at: chrono::Utc::now(),
377 duration_ms: None,
378 })
379 } else {
380 None
381 }
382 } else if !content.is_empty() {
383 Some(praxis_persist::DBMessage {
384 id: uuid::Uuid::new_v4().to_string(),
385 thread_id: thread_id.to_string(),
386 user_id: user_id.to_string(),
387 role: MessageRole::Assistant,
388 message_type: MessageType::Message,
389 content: content.clone(),
390 tool_call_id: None,
391 tool_name: None,
392 arguments: None,
393 reasoning_id: Some(id.clone()),
394 created_at: chrono::Utc::now(),
395 duration_ms: None,
396 })
397 } else {
398 None
399 }
400 }
401 }
402 }
403
404 fn convert_message_to_db(
406 msg: &praxis_llm::Message,
407 thread_id: &str,
408 user_id: &str,
409 _node_type: NodeType,
410 ) -> Option<praxis_persist::DBMessage> {
411 use praxis_llm::Message;
412 use praxis_persist::{MessageRole, MessageType};
413
414 match msg {
415 Message::AI { content, tool_calls, .. } => {
416 if let Some(calls) = tool_calls {
417 if let Some(first_call) = calls.first() {
421 Some(praxis_persist::DBMessage {
422 id: uuid::Uuid::new_v4().to_string(),
423 thread_id: thread_id.to_string(),
424 user_id: user_id.to_string(),
425 role: MessageRole::Assistant,
426 message_type: MessageType::ToolCall,
427 content: String::new(),
428 tool_call_id: Some(first_call.id.clone()),
429 tool_name: Some(first_call.function.name.clone()),
430 arguments: serde_json::from_str(&first_call.function.arguments).ok(),
431 reasoning_id: None,
432 created_at: chrono::Utc::now(),
433 duration_ms: None,
434 })
435 } else {
436 None
437 }
438 } else if let Some(content) = content {
439 Some(praxis_persist::DBMessage {
440 id: uuid::Uuid::new_v4().to_string(),
441 thread_id: thread_id.to_string(),
442 user_id: user_id.to_string(),
443 role: MessageRole::Assistant,
444 message_type: MessageType::Message,
445 content: content.as_text().unwrap_or("").to_string(),
446 tool_call_id: None,
447 tool_name: None,
448 arguments: None,
449 reasoning_id: None,
450 created_at: chrono::Utc::now(),
451 duration_ms: None,
452 })
453 } else {
454 None
455 }
456 }
457 Message::Tool { tool_call_id, content } => {
458 Some(praxis_persist::DBMessage {
459 id: uuid::Uuid::new_v4().to_string(),
460 thread_id: thread_id.to_string(),
461 user_id: user_id.to_string(),
462 role: MessageRole::Assistant,
463 message_type: MessageType::ToolResult,
464 content: content.as_text().unwrap_or("").to_string(),
465 tool_call_id: Some(tool_call_id.clone()),
466 tool_name: None,
467 arguments: None,
468 reasoning_id: None,
469 created_at: chrono::Utc::now(),
470 duration_ms: None,
471 })
472 }
473 _ => None,
474 }
475 }
476
477 #[cfg(feature = "observability")]
479 fn create_observation(
480 state: &GraphState,
481 node_type: NodeType,
482 _node_start: Instant,
483 node_duration: u64,
484 new_messages: &[praxis_llm::Message],
485 ) -> Option<praxis_observability::NodeObservation> {
486 use praxis_observability::{NodeObservation, NodeObservationData, NodeOutput, LangfuseMessage, ToolCallInfo, ToolResultInfo};
487 use crate::types::GraphOutput;
488
489 let span_id = uuid::Uuid::new_v4().to_string();
490 let started_at = chrono::Utc::now() - chrono::Duration::milliseconds(node_duration as i64);
491
492 match node_type {
493 NodeType::LLM => {
494 let input_count = state.messages.len() - new_messages.len();
495
496 tracing::info!(
497 "LLM observation - total messages: {}, input_count: {}, new_messages: {}",
498 state.messages.len(),
499 input_count,
500 new_messages.len()
501 );
502
503 let input_messages: Vec<LangfuseMessage> = state.messages[..input_count]
504 .iter()
505 .filter_map(Self::convert_to_langfuse_message)
506 .collect();
507
508 let outputs = if let Some(ref last_outputs) = state.last_outputs {
510 last_outputs.iter().map(|output| {
511 match output {
512 GraphOutput::Reasoning { id, content } => {
513 NodeOutput::Reasoning {
514 id: id.clone(),
515 content: content.clone(),
516 }
517 }
518 GraphOutput::Message { id, content, tool_calls } => {
519 if tool_calls.is_some() {
520 NodeOutput::ToolCalls {
521 calls: tool_calls.as_ref().unwrap().iter().map(|call| {
522 ToolCallInfo {
523 id: call.id.clone(),
524 name: call.function.name.clone(),
525 arguments: serde_json::from_str(&call.function.arguments)
526 .unwrap_or(serde_json::json!({})),
527 }
528 }).collect(),
529 }
530 } else {
531 NodeOutput::Message {
532 id: id.clone(),
533 content: content.clone(),
534 }
535 }
536 }
537 }
538 }).collect()
539 } else {
540 vec![]
542 };
543
544 if outputs.is_empty() {
545 tracing::warn!("No outputs available for LLM observation");
546 return None;
547 }
548
549 tracing::info!(
550 "Created LLM observation: input_messages={}, outputs={}",
551 input_messages.len(),
552 outputs.len()
553 );
554
555 Some(NodeObservation {
556 span_id,
557 run_id: state.run_id.clone(),
558 conversation_id: state.conversation_id.clone(),
559 node_type: "llm".to_string(),
560 started_at,
561 duration_ms: node_duration,
562 data: NodeObservationData::Llm {
563 input_messages,
564 outputs,
565 model: state.llm_config.model.clone(),
566 usage: None,
567 },
568 metadata: std::collections::HashMap::new(),
569 })
570 }
571 NodeType::Tool => {
572 let tool_calls: Vec<ToolCallInfo> = state.messages
574 .iter()
575 .rev()
576 .find_map(|msg| match msg {
577 praxis_llm::Message::AI { tool_calls: Some(calls), .. } => {
578 Some(calls.iter().map(|call| ToolCallInfo {
579 id: call.id.clone(),
580 name: call.function.name.clone(),
581 arguments: serde_json::from_str(&call.function.arguments)
582 .unwrap_or(serde_json::json!({})),
583 }).collect())
584 }
585 _ => None,
586 })?;
587
588 let tool_results: Vec<ToolResultInfo> = new_messages
590 .iter()
591 .filter_map(|msg| match msg {
592 praxis_llm::Message::Tool { tool_call_id, content } => {
593 Some(ToolResultInfo {
594 tool_call_id: tool_call_id.clone(),
595 tool_name: "unknown".to_string(), result: content.as_text().unwrap_or("").to_string(),
597 is_error: false,
598 duration_ms: 0, })
600 }
601 _ => None,
602 })
603 .collect();
604
605 tracing::debug!(
606 "Creating Tool observation: tool_calls_count={}, tool_results_count={}",
607 tool_calls.len(),
608 tool_results.len()
609 );
610
611 Some(NodeObservation {
612 span_id,
613 run_id: state.run_id.clone(),
614 conversation_id: state.conversation_id.clone(),
615 node_type: "tool".to_string(),
616 started_at,
617 duration_ms: node_duration,
618 data: NodeObservationData::Tool {
619 tool_calls,
620 tool_results,
621 },
622 metadata: std::collections::HashMap::new(),
623 })
624 }
625 }
626 }
627
628 #[cfg(feature = "observability")]
630 fn convert_to_langfuse_message(msg: &praxis_llm::Message) -> Option<praxis_observability::LangfuseMessage> {
631 use praxis_observability::{LangfuseMessage, ToolCallInfo};
632
633 match msg {
634 praxis_llm::Message::System { content, .. } => Some(LangfuseMessage {
635 role: "system".to_string(),
636 content: content.as_text().unwrap_or("").to_string(),
637 name: None,
638 tool_call_id: None,
639 tool_calls: None,
640 }),
641 praxis_llm::Message::Human { content, .. } => Some(LangfuseMessage {
642 role: "user".to_string(),
643 content: content.as_text().unwrap_or("").to_string(),
644 name: None,
645 tool_call_id: None,
646 tool_calls: None,
647 }),
648 praxis_llm::Message::AI { content, tool_calls, .. } => {
649 let tool_calls_converted = tool_calls.as_ref().map(|calls| {
650 calls.iter().map(|call| ToolCallInfo {
651 id: call.id.clone(),
652 name: call.function.name.clone(),
653 arguments: serde_json::from_str(&call.function.arguments)
654 .unwrap_or(serde_json::json!({})),
655 }).collect()
656 });
657
658 Some(LangfuseMessage {
659 role: "assistant".to_string(),
660 content: content.as_ref()
661 .and_then(|c| c.as_text())
662 .unwrap_or("")
663 .to_string(),
664 name: None,
665 tool_call_id: None,
666 tool_calls: tool_calls_converted,
667 })
668 }
669 praxis_llm::Message::Tool { tool_call_id, content } => Some(LangfuseMessage {
670 role: "tool".to_string(),
671 content: content.as_text().unwrap_or("").to_string(),
672 name: None,
673 tool_call_id: Some(tool_call_id.clone()),
674 tool_calls: None,
675 }),
676 }
677 }
678}
679