1use dashmap::DashMap;
2use std::sync::{Arc, Mutex};
3use std::time::Duration;
4use tokio::time::timeout;
5
6use crate::{
7 context::Context,
8 error::{GraphError, Result},
9 storage::Session,
10 task::{NextAction, Task, TaskResult},
11};
12
13pub type EdgeCondition = Arc<dyn Fn(&Context) -> bool + Send + Sync>;
15
16#[derive(Clone)]
18pub struct Edge {
19 pub from: String,
20 pub to: String,
21 pub condition: Option<EdgeCondition>,
22}
23
24pub struct Graph {
26 pub id: String,
27 tasks: DashMap<String, Arc<dyn Task>>,
28 edges: Mutex<Vec<Edge>>,
29 start_task_id: Mutex<Option<String>>,
30 task_timeout: Duration,
31}
32
33impl Graph {
34 pub fn new(id: impl Into<String>) -> Self {
35 Self {
36 id: id.into(),
37 tasks: DashMap::new(),
38 edges: Mutex::new(Vec::new()),
39 start_task_id: Mutex::new(None),
40 task_timeout: Duration::from_secs(300), }
42 }
43
44 pub fn set_task_timeout(&mut self, timeout: Duration) {
46 self.task_timeout = timeout;
47 }
48
49 pub fn add_task(&self, task: Arc<dyn Task>) -> &Self {
51 let task_id = task.id().to_string();
52 let is_first = self.tasks.is_empty();
53 self.tasks.insert(task_id.clone(), task);
54
55 if is_first {
57 *self.start_task_id.lock().unwrap() = Some(task_id);
58 }
59
60 self
61 }
62
63 pub fn set_start_task(&self, task_id: impl Into<String>) -> &Self {
65 let task_id = task_id.into();
66 if self.tasks.contains_key(&task_id) {
67 *self.start_task_id.lock().unwrap() = Some(task_id);
68 }
69 self
70 }
71
72 pub fn add_edge(&self, from: impl Into<String>, to: impl Into<String>) -> &Self {
74 self.edges.lock().unwrap().push(Edge {
75 from: from.into(),
76 to: to.into(),
77 condition: None,
78 });
79 self
80 }
81
82 pub fn add_conditional_edge<F>(
85 &self,
86 from: impl Into<String>,
87 condition: F,
88 yes: impl Into<String>,
89 no: impl Into<String>,
90 ) -> &Self
91 where
92 F: Fn(&Context) -> bool + Send + Sync + 'static,
93 {
94 let from = from.into();
95 let yes_to = yes.into();
96 let no_to = no.into();
97
98 let predicate: EdgeCondition = Arc::new(condition);
99
100 let mut edges = self.edges.lock().unwrap();
101
102 edges.push(Edge {
104 from: from.clone(),
105 to: yes_to,
106 condition: Some(predicate),
107 });
108
109 edges.push(Edge {
111 from,
112 to: no_to,
113 condition: None,
114 });
115
116 self
117 }
118
119 pub async fn execute_session(&self, session: &mut Session) -> Result<ExecutionResult> {
122 tracing::info!(
123 graph_id = %self.id,
124 session_id = %session.id,
125 current_task = %session.current_task_id,
126 "Starting graph execution"
127 );
128
129 let result = self
131 .execute_single_task(&session.current_task_id, session.context.clone())
132 .await?;
133
134 match &result.next_action {
136 NextAction::Continue => {
137 session.status_message = result.status_message.clone();
139
140 if let Some(next_task_id) = self.find_next_task(&result.task_id, &session.context) {
142 session.current_task_id = next_task_id.clone();
143 Ok(ExecutionResult {
144 response: result.response,
145 status: ExecutionStatus::Paused {
146 next_task_id,
147 reason: "Task completed, continuing to next task".to_string(),
148 },
149 })
150 } else {
151 session.current_task_id = result.task_id.clone();
153 Ok(ExecutionResult {
154 response: result.response,
155 status: ExecutionStatus::Paused {
156 next_task_id: result.task_id.clone(),
157 reason: "No outgoing edge found from current task".to_string(),
158 },
159 })
160 }
161 }
162 NextAction::ContinueAndExecute => {
163 session.status_message = result.status_message.clone();
165
166 if let Some(next_task_id) = self.find_next_task(&result.task_id, &session.context) {
168 session.current_task_id = next_task_id;
171
172 return Box::pin(self.execute_session(session)).await;
174 } else {
175 session.current_task_id = result.task_id.clone();
177 Ok(ExecutionResult {
178 response: result.response,
179 status: ExecutionStatus::Paused {
180 next_task_id: result.task_id.clone(),
181 reason: "No outgoing edge found from current task".to_string(),
182 },
183 })
184 }
185 }
186 NextAction::WaitForInput => {
187 session.status_message = result.status_message.clone();
189 session.current_task_id = result.task_id.clone();
191 Ok(ExecutionResult {
192 response: result.response,
193 status: ExecutionStatus::WaitingForInput,
194 })
195 }
196 NextAction::End => {
197 session.status_message = result.status_message.clone();
199 session.current_task_id = result.task_id.clone();
200 Ok(ExecutionResult {
201 response: result.response,
202 status: ExecutionStatus::Completed,
203 })
204 }
205 NextAction::GoTo(target_id) => {
206 session.status_message = result.status_message.clone();
208 if self.tasks.contains_key(target_id) {
209 session.current_task_id = target_id.clone();
210 Ok(ExecutionResult {
211 response: result.response,
212 status: ExecutionStatus::Paused {
213 next_task_id: target_id.clone(),
214 reason: "Task requested jump to specific task".to_string(),
215 },
216 })
217 } else {
218 Err(GraphError::TaskNotFound(target_id.clone()))
219 }
220 }
221 NextAction::GoBack => {
222 session.status_message = result.status_message.clone();
224 session.current_task_id = result.task_id.clone();
226 Ok(ExecutionResult {
227 response: result.response,
228 status: ExecutionStatus::WaitingForInput,
229 })
230 }
231 }
232 }
233
234 async fn execute_single_task(&self, task_id: &str, context: Context) -> Result<TaskResult> {
236 tracing::debug!(
237 task_id = %task_id,
238 "Executing single task"
239 );
240
241 let task = self
242 .tasks
243 .get(task_id)
244 .ok_or_else(|| GraphError::TaskNotFound(task_id.to_string()))?;
245
246 let task_future = task.run(context);
248 let mut result = match timeout(self.task_timeout, task_future).await {
249 Ok(Ok(result)) => result,
250 Ok(Err(e)) => return Err(GraphError::TaskExecutionFailed(
251 format!("Task '{}' failed: {}", task_id, e)
252 )),
253 Err(_) => return Err(GraphError::TaskExecutionFailed(
254 format!("Task '{}' timed out after {:?}", task_id, self.task_timeout)
255 )),
256 };
257
258 result.task_id = task_id.to_string();
260
261 Ok(result)
262 }
263
264 pub async fn execute(&self, task_id: &str, context: Context) -> Result<TaskResult> {
266 let task = self
267 .tasks
268 .get(task_id)
269 .ok_or_else(|| GraphError::TaskNotFound(task_id.to_string()))?;
270
271 let mut result = task.run(context.clone()).await?;
272
273 result.task_id = task_id.to_string();
275
276 match &result.next_action {
278 NextAction::Continue => {
279 if result.response.is_some() {
282 Ok(result)
283 } else {
284 if let Some(next_task_id) = self.find_next_task(task_id, &context) {
286 Box::pin(self.execute(&next_task_id, context)).await
287 } else {
288 Ok(result)
289 }
290 }
291 }
292 NextAction::GoTo(target_id) => {
293 if self.tasks.contains_key(target_id) {
294 Box::pin(self.execute(target_id, context)).await
295 } else {
296 Err(GraphError::TaskNotFound(target_id.clone()))
297 }
298 }
299 _ => Ok(result),
300 }
301 }
302
303 pub fn find_next_task(&self, current_task_id: &str, context: &Context) -> Option<String> {
305 let edges = self.edges.lock().unwrap();
306
307 let mut fallback: Option<String> = None;
308 for edge in edges.iter().filter(|e| e.from == current_task_id) {
309 match &edge.condition {
310 Some(pred) if pred(context) => return Some(edge.to.clone()),
311 None if fallback.is_none() => fallback = Some(edge.to.clone()),
312 _ => {}
313 }
314 }
315 fallback
316 }
317
318 pub fn start_task_id(&self) -> Option<String> {
320 self.start_task_id.lock().unwrap().clone()
321 }
322
323 pub fn get_task(&self, task_id: &str) -> Option<Arc<dyn Task>> {
325 self.tasks.get(task_id).map(|entry| entry.clone())
326 }
327}
328
329pub struct GraphBuilder {
331 graph: Graph,
332}
333
334impl GraphBuilder {
335 pub fn new(id: impl Into<String>) -> Self {
336 Self {
337 graph: Graph::new(id),
338 }
339 }
340
341 pub fn add_task(self, task: Arc<dyn Task>) -> Self {
342 self.graph.add_task(task);
343 self
344 }
345
346 pub fn add_edge(self, from: impl Into<String>, to: impl Into<String>) -> Self {
347 self.graph.add_edge(from, to);
348 self
349 }
350
351 pub fn add_conditional_edge<F>(
352 self,
353 from: impl Into<String>,
354 condition: F,
355 yes: impl Into<String>,
356 no: impl Into<String>,
357 ) -> Self
358 where
359 F: Fn(&Context) -> bool + Send + Sync + 'static,
360 {
361 self.graph.add_conditional_edge(from, condition, yes, no);
362 self
363 }
364
365 pub fn set_start_task(self, task_id: impl Into<String>) -> Self {
366 self.graph.set_start_task(task_id);
367 self
368 }
369
370 pub fn build(self) -> Graph {
371 if self.graph.tasks.is_empty() {
373 tracing::warn!("Building graph with no tasks");
374 }
375
376 let task_count = self.graph.tasks.len();
378 if task_count > 1 {
379 let all_task_ids: Vec<String> = self.graph.tasks.iter()
381 .map(|t| t.key().clone())
382 .collect();
383
384 let edges = self.graph.edges.lock().unwrap();
386 let mut connected_tasks = std::collections::HashSet::new();
387
388 for edge in edges.iter() {
389 connected_tasks.insert(edge.from.clone());
390 connected_tasks.insert(edge.to.clone());
391 }
392 drop(edges); for task_id in all_task_ids {
396 if !connected_tasks.contains(&task_id) {
397 tracing::warn!(
398 task_id = %task_id,
399 "Task has no edges - it may be unreachable"
400 );
401 }
402 }
403 }
404
405 self.graph
406 }
407}
408
/// Outcome of one [`Graph::execute_session`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecutionResult {
    /// Response produced by the last task that ran, if any.
    pub response: Option<String>,
    /// Where execution stopped and why.
    pub status: ExecutionStatus,
}

/// Why a session run stopped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionStatus {
    /// Execution paused; the session's current task was set to `next_task_id`.
    Paused {
        /// Task the session now points at.
        next_task_id: String,
        /// Human-readable explanation of the pause.
        reason: String,
    },
    /// The session stays on the task that just ran and waits for input.
    WaitingForInput,
    /// A task returned `NextAction::End`.
    Completed,
    /// Failure with a message. Not constructed in this module; presumably set
    /// by callers (TODO confirm).
    Error(String),
}