1use crate::node::{Node, NodeType};
2use crate::nodes::{LLMNode, ToolNode};
3use crate::router::{NextNode, Router, SimpleRouter};
4use anyhow::Result;
5use praxis_llm::LLMClient;
6use praxis_mcp::MCPToolExecutor;
7use praxis_types::{GraphConfig, GraphInput, GraphState, StreamEvent};
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::mpsc;
11
12pub struct Graph {
13 llm_client: Arc<dyn LLMClient>,
14 mcp_executor: Arc<MCPToolExecutor>,
15 config: GraphConfig,
16}
17
18impl Graph {
19 pub fn new(
20 llm_client: Arc<dyn LLMClient>,
21 mcp_executor: Arc<MCPToolExecutor>,
22 config: GraphConfig,
23 ) -> Self {
24 Self {
25 llm_client,
26 mcp_executor,
27 config,
28 }
29 }
30
31 pub fn spawn_run(&self, input: GraphInput) -> mpsc::Receiver<StreamEvent> {
33 let (tx, rx) = mpsc::channel(1000);
34
35 let llm_client = Arc::clone(&self.llm_client);
37 let mcp_executor = Arc::clone(&self.mcp_executor);
38 let config = self.config.clone();
39
40 tokio::spawn(async move {
41 if let Err(e) = Self::execute_loop(input, tx.clone(), llm_client, mcp_executor, config).await {
42 let _ = tx
43 .send(StreamEvent::Error {
44 message: e.to_string(),
45 node_id: None,
46 })
47 .await;
48 }
49 });
50
51 rx
52 }
53
54 async fn execute_loop(
55 input: GraphInput,
56 event_tx: mpsc::Sender<StreamEvent>,
57 llm_client: Arc<dyn LLMClient>,
58 mcp_executor: Arc<MCPToolExecutor>,
59 config: GraphConfig,
60 ) -> Result<()> {
61 let start_time = Instant::now();
62
63 let mut state = GraphState::from_input(input);
65
66 event_tx
68 .send(StreamEvent::InitStream {
69 run_id: state.run_id.clone(),
70 conversation_id: state.conversation_id.clone(),
71 timestamp: chrono::Utc::now().timestamp_millis(),
72 })
73 .await?;
74
75 let llm_node = LLMNode::new(llm_client, mcp_executor.clone());
77 let tool_node = ToolNode::new(mcp_executor);
78 let router = SimpleRouter;
79
80 let mut current_node = NodeType::LLM;
81 let mut iteration = 0;
82
83 loop {
84 if iteration >= config.max_iterations {
86 event_tx
87 .send(StreamEvent::Error {
88 message: format!("Max iterations ({}) reached", config.max_iterations),
89 node_id: None,
90 })
91 .await?;
92 break;
93 }
94
95 match current_node {
97 NodeType::LLM => {
98 llm_node.execute(&mut state, event_tx.clone()).await?;
99 }
100 NodeType::Tool => {
101 tool_node.execute(&mut state, event_tx.clone()).await?;
102 }
103 }
104
105 let next = router.next(&state, current_node);
107
108 match next {
109 NextNode::End => break,
110 NextNode::LLM => current_node = NodeType::LLM,
111 NextNode::Tool => current_node = NodeType::Tool,
112 }
113
114 iteration += 1;
115 }
116
117 let total_duration = start_time.elapsed().as_millis() as u64;
119 event_tx
120 .send(StreamEvent::EndStream {
121 status: "success".to_string(),
122 total_duration_ms: total_duration,
123 })
124 .await?;
125
126 Ok(())
131 }
132}
133