praxis_graph/
graph.rs

1use crate::node::{Node, NodeType};
2use crate::nodes::{LLMNode, ToolNode};
3use crate::router::{NextNode, Router, SimpleRouter};
4use anyhow::Result;
5use praxis_llm::LLMClient;
6use praxis_mcp::MCPToolExecutor;
7use praxis_types::{GraphConfig, GraphInput, GraphState, StreamEvent};
8use std::sync::Arc;
9use std::time::Instant;
10use tokio::sync::mpsc;
11
12pub struct Graph {
13    llm_client: Arc<dyn LLMClient>,
14    mcp_executor: Arc<MCPToolExecutor>,
15    config: GraphConfig,
16}
17
18impl Graph {
19    pub fn new(
20        llm_client: Arc<dyn LLMClient>,
21        mcp_executor: Arc<MCPToolExecutor>,
22        config: GraphConfig,
23    ) -> Self {
24        Self {
25            llm_client,
26            mcp_executor,
27            config,
28        }
29    }
30
31    /// Spawn execution in background, return event receiver
32    pub fn spawn_run(&self, input: GraphInput) -> mpsc::Receiver<StreamEvent> {
33        let (tx, rx) = mpsc::channel(1000);
34
35        // Clone what we need for the spawned task
36        let llm_client = Arc::clone(&self.llm_client);
37        let mcp_executor = Arc::clone(&self.mcp_executor);
38        let config = self.config.clone();
39
40        tokio::spawn(async move {
41            if let Err(e) = Self::execute_loop(input, tx.clone(), llm_client, mcp_executor, config).await {
42                let _ = tx
43                    .send(StreamEvent::Error {
44                        message: e.to_string(),
45                        node_id: None,
46                    })
47                    .await;
48            }
49        });
50
51        rx
52    }
53
54    async fn execute_loop(
55        input: GraphInput,
56        event_tx: mpsc::Sender<StreamEvent>,
57        llm_client: Arc<dyn LLMClient>,
58        mcp_executor: Arc<MCPToolExecutor>,
59        config: GraphConfig,
60    ) -> Result<()> {
61        let start_time = Instant::now();
62
63        // Build initial state
64        let mut state = GraphState::from_input(input);
65
66        // Emit init event
67        event_tx
68            .send(StreamEvent::InitStream {
69                run_id: state.run_id.clone(),
70                conversation_id: state.conversation_id.clone(),
71                timestamp: chrono::Utc::now().timestamp_millis(),
72            })
73            .await?;
74
75        // Create nodes
76        let llm_node = LLMNode::new(llm_client, mcp_executor.clone());
77        let tool_node = ToolNode::new(mcp_executor);
78        let router = SimpleRouter;
79
80        let mut current_node = NodeType::LLM;
81        let mut iteration = 0;
82
83        loop {
84            // Guardrail: max iterations
85            if iteration >= config.max_iterations {
86                event_tx
87                    .send(StreamEvent::Error {
88                        message: format!("Max iterations ({}) reached", config.max_iterations),
89                        node_id: None,
90                    })
91                    .await?;
92                break;
93            }
94
95            // Execute current node
96            match current_node {
97                NodeType::LLM => {
98                    llm_node.execute(&mut state, event_tx.clone()).await?;
99                }
100                NodeType::Tool => {
101                    tool_node.execute(&mut state, event_tx.clone()).await?;
102                }
103            }
104
105            // Route to next node
106            let next = router.next(&state, current_node);
107
108            match next {
109                NextNode::End => break,
110                NextNode::LLM => current_node = NodeType::LLM,
111                NextNode::Tool => current_node = NodeType::Tool,
112            }
113
114            iteration += 1;
115        }
116
117        // Emit end event
118        let total_duration = start_time.elapsed().as_millis() as u64;
119        event_tx
120            .send(StreamEvent::EndStream {
121                status: "success".to_string(),
122                total_duration_ms: total_duration,
123            })
124            .await?;
125
126        // TODO: Persistence layer
127        // After execution, save state.messages to MongoDB
128        // Use AssistantMessage from praxis-llm::history
129
130        Ok(())
131    }
132}
133