praxis_graph/
graph.rs

1use crate::node::{Node, NodeType};
2use crate::nodes::{LLMNode, ToolNode};
3use crate::router::{NextNode, Router, SimpleRouter};
4use crate::builder::PersistenceConfig;
5use anyhow::Result;
6use praxis_llm::LLMClient;
7use praxis_mcp::MCPToolExecutor;
8use crate::types::{GraphConfig, GraphInput, GraphState, StreamEvent};
9use std::sync::Arc;
10use std::time::Instant;
11use tokio::sync::mpsc;
12
/// Identifiers that tie a single graph run to persisted conversation data.
///
/// A value of this type is handed to [`Graph::spawn_run`]; when the graph was
/// built with a persistence configuration, both fields are passed to the event
/// accumulator that collects messages to save.
#[derive(Debug, Clone)]
pub struct PersistenceContext {
    /// Thread identifier, forwarded as the accumulator's first constructor argument.
    pub thread_id: String,
    /// User identifier, forwarded as the accumulator's second constructor argument.
    pub user_id: String,
}
18
19pub struct Graph {
20    llm_client: Arc<dyn LLMClient>,
21    mcp_executor: Arc<MCPToolExecutor>,
22    config: GraphConfig,
23    persistence: Option<Arc<PersistenceConfig>>,
24}
25
26impl Graph {
27    pub fn new(
28        llm_client: Arc<dyn LLMClient>,
29        mcp_executor: Arc<MCPToolExecutor>,
30        config: GraphConfig,
31    ) -> Self {
32        Self {
33            llm_client,
34            mcp_executor,
35            config,
36            persistence: None,
37        }
38    }
39    
40    pub(crate) fn new_with_persistence(
41        llm_client: Arc<dyn LLMClient>,
42        mcp_executor: Arc<MCPToolExecutor>,
43        config: GraphConfig,
44        persistence: Option<PersistenceConfig>,
45    ) -> Self {
46        Self {
47            llm_client,
48            mcp_executor,
49            config,
50            persistence: persistence.map(Arc::new),
51        }
52    }
53    
54    /// Create a builder for fluent construction
55    pub fn builder() -> crate::builder::GraphBuilder {
56        crate::builder::GraphBuilder::new()
57    }
58
59    /// Spawn execution in background, return event receiver
60    pub fn spawn_run(
61        &self,
62        input: GraphInput,
63        persistence_ctx: Option<PersistenceContext>,
64    ) -> mpsc::Receiver<StreamEvent> {
65        let (tx, rx) = mpsc::channel(1000);
66
67        // Clone what we need for the spawned task
68        let llm_client = Arc::clone(&self.llm_client);
69        let mcp_executor = Arc::clone(&self.mcp_executor);
70        let config = self.config.clone();
71        let persistence = self.persistence.clone();
72
73        tokio::spawn(async move {
74            if let Err(e) = Self::execute_loop(
75                input,
76                tx.clone(),
77                llm_client,
78                mcp_executor,
79                config,
80                persistence,
81                persistence_ctx,
82            ).await {
83                let _ = tx
84                    .send(StreamEvent::Error {
85                        message: e.to_string(),
86                        node_id: None,
87                    })
88                    .await;
89            }
90        });
91
92        rx
93    }
94
95    async fn execute_loop(
96        input: GraphInput,
97        event_tx: mpsc::Sender<StreamEvent>,
98        llm_client: Arc<dyn LLMClient>,
99        mcp_executor: Arc<MCPToolExecutor>,
100        config: GraphConfig,
101        persistence: Option<Arc<PersistenceConfig>>,
102        ctx: Option<PersistenceContext>,
103    ) -> Result<()> {
104        let start_time = Instant::now();
105
106        // Build initial state
107        let mut state = GraphState::from_input(input);
108
109        // Create Observer if persistence enabled
110        let mut accumulator: Option<praxis_persist::EventAccumulator<StreamEvent>> = match (&persistence, &ctx) {
111            (Some(_), Some(c)) => Some(praxis_persist::EventAccumulator::new(
112                c.thread_id.clone(),
113                c.user_id.clone(),
114            )),
115            _ => None,
116        };
117
118        // Emit init event
119        let init_event = StreamEvent::InitStream {
120            run_id: state.run_id.clone(),
121            conversation_id: state.conversation_id.clone(),
122            timestamp: chrono::Utc::now().timestamp_millis(),
123        };
124        event_tx.send(init_event.clone()).await?;
125
126        // Create nodes
127        let llm_node = LLMNode::new(llm_client, mcp_executor.clone());
128        let tool_node = ToolNode::new(mcp_executor);
129        let router = SimpleRouter;
130
131        let mut current_node = NodeType::LLM;
132        let mut iteration = 0;
133
134        loop {
135            // Guardrail: max iterations
136            if iteration >= config.max_iterations {
137                let error_event = StreamEvent::Error {
138                    message: format!("Max iterations ({}) reached", config.max_iterations),
139                    node_id: None,
140                };
141                event_tx.send(error_event.clone()).await?;
142                break;
143            }
144
145            // Execute current node (this emits events via event_tx)
146            match current_node {
147                NodeType::LLM => {
148                    llm_node.execute(&mut state, event_tx.clone()).await?;
149                }
150                NodeType::Tool => {
151                    tool_node.execute(&mut state, event_tx.clone()).await?;
152                }
153            }
154
155            // Route to next node
156            let next = router.next(&state, current_node);
157
158            match next {
159                NextNode::End => break,
160                NextNode::LLM => current_node = NodeType::LLM,
161                NextNode::Tool => current_node = NodeType::Tool,
162            }
163
164            iteration += 1;
165        }
166
167        // Emit end event
168        let total_duration = start_time.elapsed().as_millis() as u64;
169        let end_event = StreamEvent::EndStream {
170            status: "success".to_string(),
171            total_duration_ms: total_duration,
172        };
173        event_tx.send(end_event.clone()).await?;
174        
175        // Observer pattern: check for final transition
176        if let Some(ref mut acc) = accumulator {
177            if let Some(completed_msg) = acc.push_and_check_transition(&end_event) {
178                if let Some(ref p) = persistence {
179                    p.client.save_message(completed_msg).await?;
180                }
181            }
182        }
183
184        // Finalize any remaining buffer
185        if let Some(mut acc) = accumulator {
186            if let Some(final_msg) = acc.finalize() {
187                if let Some(ref p) = persistence {
188                    p.client.save_message(final_msg).await?;
189                }
190            }
191        }
192
193        Ok(())
194    }
195}
196