1use crate::node::{Node, NodeType};
2use crate::nodes::{LLMNode, ToolNode};
3use crate::router::{NextNode, Router, SimpleRouter};
4use crate::builder::PersistenceConfig;
5use anyhow::Result;
6use praxis_llm::LLMClient;
7use praxis_mcp::MCPToolExecutor;
8use crate::types::{GraphConfig, GraphInput, GraphState, StreamEvent};
9use std::sync::Arc;
10use std::time::Instant;
11use tokio::sync::mpsc;
12
/// Identifies where a run's messages are persisted.
///
/// Passed to [`Graph::spawn_run`]. Messages are only saved when the graph was also
/// built with a persistence backend; otherwise this context is ignored.
#[derive(Debug, Clone)]
pub struct PersistenceContext {
    /// Thread the persisted messages are recorded under.
    pub thread_id: String,
    /// User the persisted messages are recorded for.
    pub user_id: String,
}
18
19pub struct Graph {
20 llm_client: Arc<dyn LLMClient>,
21 mcp_executor: Arc<MCPToolExecutor>,
22 config: GraphConfig,
23 persistence: Option<Arc<PersistenceConfig>>,
24}
25
26impl Graph {
27 pub fn new(
28 llm_client: Arc<dyn LLMClient>,
29 mcp_executor: Arc<MCPToolExecutor>,
30 config: GraphConfig,
31 ) -> Self {
32 Self {
33 llm_client,
34 mcp_executor,
35 config,
36 persistence: None,
37 }
38 }
39
40 pub(crate) fn new_with_persistence(
41 llm_client: Arc<dyn LLMClient>,
42 mcp_executor: Arc<MCPToolExecutor>,
43 config: GraphConfig,
44 persistence: Option<PersistenceConfig>,
45 ) -> Self {
46 Self {
47 llm_client,
48 mcp_executor,
49 config,
50 persistence: persistence.map(Arc::new),
51 }
52 }
53
54 pub fn builder() -> crate::builder::GraphBuilder {
56 crate::builder::GraphBuilder::new()
57 }
58
59 pub fn spawn_run(
61 &self,
62 input: GraphInput,
63 persistence_ctx: Option<PersistenceContext>,
64 ) -> mpsc::Receiver<StreamEvent> {
65 let (tx, rx) = mpsc::channel(1000);
66
67 let llm_client = Arc::clone(&self.llm_client);
69 let mcp_executor = Arc::clone(&self.mcp_executor);
70 let config = self.config.clone();
71 let persistence = self.persistence.clone();
72
73 tokio::spawn(async move {
74 if let Err(e) = Self::execute_loop(
75 input,
76 tx.clone(),
77 llm_client,
78 mcp_executor,
79 config,
80 persistence,
81 persistence_ctx,
82 ).await {
83 let _ = tx
84 .send(StreamEvent::Error {
85 message: e.to_string(),
86 node_id: None,
87 })
88 .await;
89 }
90 });
91
92 rx
93 }
94
95 async fn execute_loop(
96 input: GraphInput,
97 event_tx: mpsc::Sender<StreamEvent>,
98 llm_client: Arc<dyn LLMClient>,
99 mcp_executor: Arc<MCPToolExecutor>,
100 config: GraphConfig,
101 persistence: Option<Arc<PersistenceConfig>>,
102 ctx: Option<PersistenceContext>,
103 ) -> Result<()> {
104 let start_time = Instant::now();
105
106 let mut state = GraphState::from_input(input);
108
109 let mut accumulator: Option<praxis_persist::EventAccumulator<StreamEvent>> = match (&persistence, &ctx) {
111 (Some(_), Some(c)) => Some(praxis_persist::EventAccumulator::new(
112 c.thread_id.clone(),
113 c.user_id.clone(),
114 )),
115 _ => None,
116 };
117
118 let init_event = StreamEvent::InitStream {
120 run_id: state.run_id.clone(),
121 conversation_id: state.conversation_id.clone(),
122 timestamp: chrono::Utc::now().timestamp_millis(),
123 };
124 event_tx.send(init_event.clone()).await?;
125
126 let llm_node = LLMNode::new(llm_client, mcp_executor.clone());
128 let tool_node = ToolNode::new(mcp_executor);
129 let router = SimpleRouter;
130
131 let mut current_node = NodeType::LLM;
132 let mut iteration = 0;
133
134 loop {
135 if iteration >= config.max_iterations {
137 let error_event = StreamEvent::Error {
138 message: format!("Max iterations ({}) reached", config.max_iterations),
139 node_id: None,
140 };
141 event_tx.send(error_event.clone()).await?;
142 break;
143 }
144
145 match current_node {
147 NodeType::LLM => {
148 llm_node.execute(&mut state, event_tx.clone()).await?;
149 }
150 NodeType::Tool => {
151 tool_node.execute(&mut state, event_tx.clone()).await?;
152 }
153 }
154
155 let next = router.next(&state, current_node);
157
158 match next {
159 NextNode::End => break,
160 NextNode::LLM => current_node = NodeType::LLM,
161 NextNode::Tool => current_node = NodeType::Tool,
162 }
163
164 iteration += 1;
165 }
166
167 let total_duration = start_time.elapsed().as_millis() as u64;
169 let end_event = StreamEvent::EndStream {
170 status: "success".to_string(),
171 total_duration_ms: total_duration,
172 };
173 event_tx.send(end_event.clone()).await?;
174
175 if let Some(ref mut acc) = accumulator {
177 if let Some(completed_msg) = acc.push_and_check_transition(&end_event) {
178 if let Some(ref p) = persistence {
179 p.client.save_message(completed_msg).await?;
180 }
181 }
182 }
183
184 if let Some(mut acc) = accumulator {
186 if let Some(final_msg) = acc.finalize() {
187 if let Some(ref p) = persistence {
188 p.client.save_message(final_msg).await?;
189 }
190 }
191 }
192
193 Ok(())
194 }
195}
196