// graph_flow/graph.rs

1use dashmap::DashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4use tokio::time::timeout;
5
6use crate::{
7    context::Context,
8    error::{GraphError, Result},
9    storage::Session,
10    task::{NextAction, Task, TaskResult},
11};
12
13/// Type alias for edge condition functions
14pub type EdgeCondition = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
15
16/// Edge between tasks in the graph
17#[derive(Clone)]
18pub struct Edge {
19    pub from: String,
20    pub to: String,
21    pub condition: Option<EdgeCondition>,
22}
23
24/// A graph of tasks that can be executed
25pub struct Graph {
26    pub id: String,
27    tasks: DashMap<String, Arc<dyn Task>>,
28    edges: Mutex<Vec<Edge>>,
29    start_task_id: Mutex<Option<String>>,
30    task_timeout: Duration,
31}
32
33impl Graph {
34    pub fn new(id: impl Into<String>) -> Self {
35        Self {
36            id: id.into(),
37            tasks: DashMap::new(),
38            edges: Mutex::new(Vec::new()),
39            start_task_id: Mutex::new(None),
40            task_timeout: Duration::from_secs(300), // Default 5 minute timeout
41        }
42    }
43    
44    /// Set the timeout duration for task execution
45    pub fn set_task_timeout(&mut self, timeout: Duration) {
46        self.task_timeout = timeout;
47    }
48
49    /// Add a task to the graph
50    pub fn add_task(&self, task: Arc<dyn Task>) -> &Self {
51        let task_id = task.id().to_string();
52        let is_first = self.tasks.is_empty();
53        self.tasks.insert(task_id.clone(), task);
54
55        // Set as start task if it's the first one
56        if is_first {
57            *self.start_task_id.lock().unwrap() = Some(task_id);
58        }
59
60        self
61    }
62
63    /// Set the starting task
64    pub fn set_start_task(&self, task_id: impl Into<String>) -> &Self {
65        let task_id = task_id.into();
66        if self.tasks.contains_key(&task_id) {
67            *self.start_task_id.lock().unwrap() = Some(task_id);
68        }
69        self
70    }
71
72    /// Add an edge between tasks
73    pub fn add_edge(&self, from: impl Into<String>, to: impl Into<String>) -> &Self {
74        self.edges.lock().unwrap().push(Edge {
75            from: from.into(),
76            to: to.into(),
77            condition: None,
78        });
79        self
80    }
81
82    /// Add a conditional edge with an explicit `else` branch.
83    /// `yes` is taken when `condition(ctx)` returns `true`; otherwise `no` is chosen.
84    pub fn add_conditional_edge<F>(
85        &self,
86        from: impl Into<String>,
87        condition: F,
88        yes: impl Into<String>,
89        no: impl Into<String>,
90    ) -> &Self
91    where
92        F: Fn(&Context) -> bool + Send + Sync + 'static,
93    {
94        let from = from.into();
95        let yes_to = yes.into();
96        let no_to = no.into();
97
98        let predicate: EdgeCondition = Arc::new(condition);
99
100        let mut edges = self.edges.lock().unwrap();
101
102        // "yes" branch
103        edges.push(Edge {
104            from: from.clone(),
105            to: yes_to,
106            condition: Some(predicate),
107        });
108
109        // "else" branch (unconditional fallback)
110        edges.push(Edge {
111            from,
112            to: no_to,
113            condition: None,
114        });
115
116        self
117    }
118
119    /// Execute the graph with session management
120    /// This method manages the session state and returns a simple status
121    pub async fn execute_session(&self, session: &mut Session) -> Result<ExecutionResult> {
122        tracing::info!(
123            graph_id = %self.id,
124            session_id = %session.id,
125            current_task = %session.current_task_id,
126            "Starting graph execution"
127        );
128        
129        // Execute ONLY the current task (not the full recursive chain)
130        let result = self
131            .execute_single_task(&session.current_task_id, session.context.clone())
132            .await?;
133
134        // Handle next action at the session level
135        match &result.next_action {
136            NextAction::Continue => {
137                // Update session status message if provided
138                session.status_message = result.status_message.clone();
139
140                // Find the next task but don't execute it
141                if let Some(next_task_id) = self.find_next_task(&result.task_id, &session.context) {
142                    session.current_task_id = next_task_id.clone();
143                    Ok(ExecutionResult {
144                        response: result.response,
145                        status: ExecutionStatus::Paused { 
146                            next_task_id,
147                            reason: "Task completed, continuing to next task".to_string(),
148                        },
149                    })
150                } else {
151                    // No next task found, stay at current task
152                    session.current_task_id = result.task_id.clone();
153                    Ok(ExecutionResult {
154                        response: result.response,
155                        status: ExecutionStatus::Paused { 
156                            next_task_id: result.task_id.clone(),
157                            reason: "No outgoing edge found from current task".to_string(),
158                        },
159                    })
160                }
161            }
162            NextAction::ContinueAndExecute => {
163                // Update session status message if provided
164                session.status_message = result.status_message.clone();
165
166                // Find the next task and execute it immediately (recursive behavior)
167                if let Some(next_task_id) = self.find_next_task(&result.task_id, &session.context) {
168                    // Instead of using the old execute method that clones context,
169                    // continue executing in session mode to preserve context updates
170                    session.current_task_id = next_task_id;
171
172                    // Recursively call execute_session to maintain proper context sharing
173                    return Box::pin(self.execute_session(session)).await;
174                } else {
175                    // No next task found, stay at current task
176                    session.current_task_id = result.task_id.clone();
177                    Ok(ExecutionResult {
178                        response: result.response,
179                        status: ExecutionStatus::Paused { 
180                            next_task_id: result.task_id.clone(),
181                            reason: "No outgoing edge found from current task".to_string(),
182                        },
183                    })
184                }
185            }
186            NextAction::WaitForInput => {
187                // Update session status message if provided
188                session.status_message = result.status_message.clone();
189                // Stay at the current task
190                session.current_task_id = result.task_id.clone();
191                Ok(ExecutionResult {
192                    response: result.response,
193                    status: ExecutionStatus::WaitingForInput,
194                })
195            }
196            NextAction::End => {
197                // Update session status message if provided
198                session.status_message = result.status_message.clone();
199                session.current_task_id = result.task_id.clone();
200                Ok(ExecutionResult {
201                    response: result.response,
202                    status: ExecutionStatus::Completed,
203                })
204            }
205            NextAction::GoTo(target_id) => {
206                // Update session status message if provided
207                session.status_message = result.status_message.clone();
208                if self.tasks.contains_key(target_id) {
209                    session.current_task_id = target_id.clone();
210                    Ok(ExecutionResult {
211                        response: result.response,
212                        status: ExecutionStatus::Paused { 
213                            next_task_id: target_id.clone(),
214                            reason: "Task requested jump to specific task".to_string(),
215                        },
216                    })
217                } else {
218                    Err(GraphError::TaskNotFound(target_id.clone()))
219                }
220            }
221            NextAction::GoBack => {
222                // Update session status message if provided
223                session.status_message = result.status_message.clone();
224                // For now, stay at current task - could implement back navigation logic later
225                session.current_task_id = result.task_id.clone();
226                Ok(ExecutionResult {
227                    response: result.response,
228                    status: ExecutionStatus::WaitingForInput,
229                })
230            }
231        }
232    }
233
234    /// Execute a single task without following Continue actions
235    async fn execute_single_task(&self, task_id: &str, context: Context) -> Result<TaskResult> {
236        tracing::debug!(
237            task_id = %task_id,
238            "Executing single task"
239        );
240        
241        let task = self
242            .tasks
243            .get(task_id)
244            .ok_or_else(|| GraphError::TaskNotFound(task_id.to_string()))?;
245
246        // Execute task with timeout
247        let task_future = task.run(context);
248        let mut result = match timeout(self.task_timeout, task_future).await {
249            Ok(Ok(result)) => result,
250            Ok(Err(e)) => return Err(GraphError::TaskExecutionFailed(
251                format!("Task '{}' failed: {}", task_id, e)
252            )),
253            Err(_) => return Err(GraphError::TaskExecutionFailed(
254                format!("Task '{}' timed out after {:?}", task_id, self.task_timeout)
255            )),
256        };
257
258        // Set the task_id in the result to track which task generated it
259        result.task_id = task_id.to_string();
260
261        Ok(result)
262    }
263
264    /// Execute the graph starting from a specific task
265    pub async fn execute(&self, task_id: &str, context: Context) -> Result<TaskResult> {
266        let task = self
267            .tasks
268            .get(task_id)
269            .ok_or_else(|| GraphError::TaskNotFound(task_id.to_string()))?;
270
271        let mut result = task.run(context.clone()).await?;
272
273        // Set the task_id in the result to track which task generated it
274        result.task_id = task_id.to_string();
275
276        // Handle next action
277        match &result.next_action {
278            NextAction::Continue => {
279                // If this task has a response, stop here and don't continue to next task
280                // This allows the response to be returned to the user
281                if result.response.is_some() {
282                    Ok(result)
283                } else {
284                    // Find the next task based on edges
285                    if let Some(next_task_id) = self.find_next_task(task_id, &context) {
286                        Box::pin(self.execute(&next_task_id, context)).await
287                    } else {
288                        Ok(result)
289                    }
290                }
291            }
292            NextAction::GoTo(target_id) => {
293                if self.tasks.contains_key(target_id) {
294                    Box::pin(self.execute(target_id, context)).await
295                } else {
296                    Err(GraphError::TaskNotFound(target_id.clone()))
297                }
298            }
299            _ => Ok(result),
300        }
301    }
302
303    /// Find the next task based on edges and conditions
304    pub fn find_next_task(&self, current_task_id: &str, context: &Context) -> Option<String> {
305        let edges = self.edges.lock().unwrap();
306
307        let mut fallback: Option<String> = None;
308        for edge in edges.iter().filter(|e| e.from == current_task_id) {
309            match &edge.condition {
310                Some(pred) if pred(context) => return Some(edge.to.clone()),
311                None if fallback.is_none() => fallback = Some(edge.to.clone()),
312                _ => {}
313            }
314        }
315        fallback
316    }
317
318    /// Get the start task ID
319    pub fn start_task_id(&self) -> Option<String> {
320        self.start_task_id.lock().unwrap().clone()
321    }
322
323    /// Get a task by ID
324    pub fn get_task(&self, task_id: &str) -> Option<Arc<dyn Task>> {
325        self.tasks.get(task_id).map(|entry| entry.clone())
326    }
327}
328
329/// Builder for creating graphs
330pub struct GraphBuilder {
331    graph: Graph,
332}
333
334impl GraphBuilder {
335    pub fn new(id: impl Into<String>) -> Self {
336        Self {
337            graph: Graph::new(id),
338        }
339    }
340
341    pub fn add_task(self, task: Arc<dyn Task>) -> Self {
342        self.graph.add_task(task);
343        self
344    }
345
346    pub fn add_edge(self, from: impl Into<String>, to: impl Into<String>) -> Self {
347        self.graph.add_edge(from, to);
348        self
349    }
350
351    pub fn add_conditional_edge<F>(
352        self,
353        from: impl Into<String>,
354        condition: F,
355        yes: impl Into<String>,
356        no: impl Into<String>,
357    ) -> Self
358    where
359        F: Fn(&Context) -> bool + Send + Sync + 'static,
360    {
361        self.graph.add_conditional_edge(from, condition, yes, no);
362        self
363    }
364
365    pub fn set_start_task(self, task_id: impl Into<String>) -> Self {
366        self.graph.set_start_task(task_id);
367        self
368    }
369
370    pub fn build(self) -> Graph {
371        // Validate the graph before returning
372        if self.graph.tasks.is_empty() {
373            tracing::warn!("Building graph with no tasks");
374        }
375        
376        // Check for orphaned tasks (tasks with no incoming or outgoing edges)
377        let task_count = self.graph.tasks.len();
378        if task_count > 1 {
379            // Collect task IDs first
380            let all_task_ids: Vec<String> = self.graph.tasks.iter()
381                .map(|t| t.key().clone())
382                .collect();
383            
384            // Then check edges
385            let edges = self.graph.edges.lock().unwrap();
386            let mut connected_tasks = std::collections::HashSet::new();
387            
388            for edge in edges.iter() {
389                connected_tasks.insert(edge.from.clone());
390                connected_tasks.insert(edge.to.clone());
391            }
392            drop(edges); // Explicitly drop the lock
393            
394            // Now check for orphaned tasks
395            for task_id in all_task_ids {
396                if !connected_tasks.contains(&task_id) {
397                    tracing::warn!(
398                        task_id = %task_id,
399                        "Task has no edges - it may be unreachable"
400                    );
401                }
402            }
403        }
404        
405        self.graph
406    }
407}
408
409/// Status of graph execution
410#[derive(Debug, Clone)]
411pub struct ExecutionResult {
412    pub response: Option<String>,
413    pub status: ExecutionStatus,
414}
415
/// State of a workflow after an execution step.
#[derive(Debug, Clone)]
pub enum ExecutionStatus {
    /// Execution stopped after a task; `next_task_id` is the task that runs on
    /// the next step and `reason` explains why execution stopped here.
    Paused {
        next_task_id: String,
        reason: String,
    },
    /// Waiting for user input to continue.
    WaitingForInput,
    /// The workflow completed successfully.
    Completed,
    /// An error occurred during execution; the payload is its description.
    Error(String),
}