1use super::scope::Scope;
7use super::verification::CheckConfig;
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet};
11use uuid::Uuid;
12
13#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
15pub struct RetryPolicy {
16 pub max_retries: u32,
18 pub escalate_on_failure: bool,
20}
21
22impl Default for RetryPolicy {
23 fn default() -> Self {
24 Self {
25 max_retries: 3,
26 escalate_on_failure: true,
27 }
28 }
29}
30
31#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
33pub struct SuccessCriteria {
34 pub description: String,
36 pub checks: Vec<CheckConfig>,
38}
39
40impl SuccessCriteria {
41 #[must_use]
43 pub fn new(description: impl Into<String>) -> Self {
44 Self {
45 description: description.into(),
46 checks: Vec::new(),
47 }
48 }
49
50 #[must_use]
52 pub fn with_check(mut self, check: impl Into<CheckConfig>) -> Self {
53 self.checks.push(check.into());
54 self
55 }
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct GraphTask {
61 pub id: Uuid,
63 pub title: String,
65 pub description: Option<String>,
67 pub criteria: SuccessCriteria,
69 pub retry_policy: RetryPolicy,
71 #[serde(default = "GraphTask::default_checkpoints")]
73 pub checkpoints: Vec<String>,
74 pub scope: Option<Scope>,
76}
77
78impl GraphTask {
79 fn default_checkpoints() -> Vec<String> {
80 vec!["checkpoint-1".to_string()]
81 }
82
83 #[must_use]
85 pub fn new(title: impl Into<String>, criteria: SuccessCriteria) -> Self {
86 Self {
87 id: Uuid::new_v4(),
88 title: title.into(),
89 description: None,
90 criteria,
91 retry_policy: RetryPolicy::default(),
92 checkpoints: Self::default_checkpoints(),
93 scope: None,
94 }
95 }
96
97 #[must_use]
99 pub fn with_description(mut self, desc: impl Into<String>) -> Self {
100 self.description = Some(desc.into());
101 self
102 }
103
104 #[must_use]
106 pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
107 self.retry_policy = policy;
108 self
109 }
110
111 #[must_use]
113 pub fn with_checkpoints(mut self, checkpoints: Vec<String>) -> Self {
114 self.checkpoints = if checkpoints.is_empty() {
115 Self::default_checkpoints()
116 } else {
117 checkpoints
118 };
119 self
120 }
121
122 #[must_use]
124 pub fn with_scope(mut self, scope: Scope) -> Self {
125 self.scope = Some(scope);
126 self
127 }
128}
129
130#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132#[serde(rename_all = "lowercase")]
133pub enum GraphState {
134 Draft,
136 Validated,
138 Locked,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct TaskGraph {
145 pub id: Uuid,
147 pub project_id: Uuid,
149 pub name: String,
151 pub description: Option<String>,
153 pub state: GraphState,
155 pub tasks: HashMap<Uuid, GraphTask>,
157 pub dependencies: HashMap<Uuid, HashSet<Uuid>>,
159 pub created_at: DateTime<Utc>,
161 pub updated_at: DateTime<Utc>,
163}
164
165impl TaskGraph {
166 #[must_use]
168 pub fn new(project_id: Uuid, name: impl Into<String>) -> Self {
169 let now = Utc::now();
170 Self {
171 id: Uuid::new_v4(),
172 project_id,
173 name: name.into(),
174 description: None,
175 state: GraphState::Draft,
176 tasks: HashMap::new(),
177 dependencies: HashMap::new(),
178 created_at: now,
179 updated_at: now,
180 }
181 }
182
183 #[must_use]
185 pub fn is_modifiable(&self) -> bool {
186 self.state == GraphState::Draft
187 }
188
189 pub fn add_task(&mut self, task: GraphTask) -> Result<Uuid, GraphError> {
194 if !self.is_modifiable() {
195 return Err(GraphError::GraphLocked);
196 }
197
198 let id = task.id;
199 self.tasks.insert(id, task);
200 self.dependencies.insert(id, HashSet::new());
201 self.updated_at = Utc::now();
202 Ok(id)
203 }
204
205 pub fn add_dependency(&mut self, from: Uuid, to: Uuid) -> Result<(), GraphError> {
210 if !self.is_modifiable() {
211 return Err(GraphError::GraphLocked);
212 }
213
214 if !self.tasks.contains_key(&from) {
215 return Err(GraphError::TaskNotFound(from));
216 }
217 if !self.tasks.contains_key(&to) {
218 return Err(GraphError::TaskNotFound(to));
219 }
220
221 if from == to {
223 return Err(GraphError::CycleDetected);
224 }
225
226 self.dependencies.entry(from).or_default().insert(to);
228
229 if self.has_cycle() {
231 self.dependencies.entry(from).or_default().remove(&to);
233 return Err(GraphError::CycleDetected);
234 }
235
236 self.updated_at = Utc::now();
237 Ok(())
238 }
239
240 pub fn set_scope(&mut self, task_id: Uuid, scope: Scope) -> Result<(), GraphError> {
245 if !self.is_modifiable() {
246 return Err(GraphError::GraphLocked);
247 }
248
249 let task = self
250 .tasks
251 .get_mut(&task_id)
252 .ok_or(GraphError::TaskNotFound(task_id))?;
253
254 task.scope = Some(scope);
255 self.updated_at = Utc::now();
256 Ok(())
257 }
258
259 pub fn validate(&mut self) -> Result<(), GraphError> {
264 if self.state != GraphState::Draft {
265 return Err(GraphError::InvalidStateTransition);
266 }
267
268 if self.tasks.is_empty() {
270 return Err(GraphError::EmptyGraph);
271 }
272
273 if self.has_cycle() {
275 return Err(GraphError::CycleDetected);
276 }
277
278 for (task_id, deps) in &self.dependencies {
280 if !self.tasks.contains_key(task_id) {
281 return Err(GraphError::TaskNotFound(*task_id));
282 }
283 for dep in deps {
284 if !self.tasks.contains_key(dep) {
285 return Err(GraphError::TaskNotFound(*dep));
286 }
287 }
288 }
289
290 self.state = GraphState::Validated;
291 self.updated_at = Utc::now();
292 Ok(())
293 }
294
295 pub fn lock(&mut self) -> Result<(), GraphError> {
300 if self.state != GraphState::Validated {
301 return Err(GraphError::InvalidStateTransition);
302 }
303
304 self.state = GraphState::Locked;
305 self.updated_at = Utc::now();
306 Ok(())
307 }
308
309 fn has_cycle(&self) -> bool {
311 let mut visited = HashSet::new();
312 let mut rec_stack = HashSet::new();
313
314 for task_id in self.tasks.keys() {
315 if self.has_cycle_util(*task_id, &mut visited, &mut rec_stack) {
316 return true;
317 }
318 }
319 false
320 }
321
322 fn has_cycle_util(
323 &self,
324 node: Uuid,
325 visited: &mut HashSet<Uuid>,
326 rec_stack: &mut HashSet<Uuid>,
327 ) -> bool {
328 if rec_stack.contains(&node) {
329 return true;
330 }
331 if visited.contains(&node) {
332 return false;
333 }
334
335 visited.insert(node);
336 rec_stack.insert(node);
337
338 if let Some(deps) = self.dependencies.get(&node) {
339 for dep in deps {
340 if self.has_cycle_util(*dep, visited, rec_stack) {
341 return true;
342 }
343 }
344 }
345
346 rec_stack.remove(&node);
347 false
348 }
349
350 #[must_use]
352 pub fn topological_order(&self) -> Vec<Uuid> {
353 let mut result = Vec::new();
354 let mut visited = HashSet::new();
355
356 for task_id in self.tasks.keys() {
357 self.topological_visit(*task_id, &mut visited, &mut result);
358 }
359
360 result
361 }
362
363 fn topological_visit(&self, node: Uuid, visited: &mut HashSet<Uuid>, result: &mut Vec<Uuid>) {
364 if visited.contains(&node) {
365 return;
366 }
367
368 visited.insert(node);
369
370 if let Some(deps) = self.dependencies.get(&node) {
371 for dep in deps {
372 self.topological_visit(*dep, visited, result);
373 }
374 }
375
376 result.push(node);
377 }
378
379 #[must_use]
381 pub fn root_tasks(&self) -> Vec<Uuid> {
382 self.tasks
383 .keys()
384 .filter(|id| self.dependencies.get(*id).is_none_or(HashSet::is_empty))
385 .copied()
386 .collect()
387 }
388
389 #[must_use]
391 pub fn dependents(&self, task_id: Uuid) -> Vec<Uuid> {
392 self.dependencies
393 .iter()
394 .filter(|(_, deps)| deps.contains(&task_id))
395 .map(|(id, _)| *id)
396 .collect()
397 }
398}
399
400#[derive(Debug, Clone, PartialEq, Eq)]
402pub enum GraphError {
403 GraphLocked,
405 TaskNotFound(Uuid),
407 CycleDetected,
409 InvalidStateTransition,
411 EmptyGraph,
413}
414
415impl std::fmt::Display for GraphError {
416 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417 match self {
418 Self::GraphLocked => write!(f, "Graph is locked and cannot be modified"),
419 Self::TaskNotFound(id) => write!(f, "Task not found: {id}"),
420 Self::CycleDetected => write!(f, "Cycle detected in task dependencies"),
421 Self::InvalidStateTransition => write!(f, "Invalid state transition"),
422 Self::EmptyGraph => write!(f, "Graph must contain at least one task"),
423 }
424 }
425}
426
427impl std::error::Error for GraphError {}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh draft graph under a random project id.
    fn test_graph() -> TaskGraph {
        TaskGraph::new(Uuid::new_v4(), "test-graph")
    }

    /// Task with the given title and a trivial success criteria.
    fn test_task(title: &str) -> GraphTask {
        GraphTask::new(title, SuccessCriteria::new("Task completed"))
    }

    #[test]
    fn create_graph() {
        let graph = test_graph();
        assert_eq!(graph.state, GraphState::Draft);
        assert!(graph.tasks.is_empty());
    }

    #[test]
    fn add_tasks() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();

        assert_eq!(graph.tasks.len(), 2);
        assert!(graph.tasks.contains_key(&t1));
        assert!(graph.tasks.contains_key(&t2));
    }

    #[test]
    fn add_dependencies() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();

        // t2 depends on t1.
        graph.add_dependency(t2, t1).unwrap();

        assert!(graph.dependencies[&t2].contains(&t1));
    }

    #[test]
    fn prevent_cycles() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t2).unwrap();

        // t1 -> t3 would close t1 -> t3 -> t2 -> t1.
        let result = graph.add_dependency(t1, t3);
        assert_eq!(result, Err(GraphError::CycleDetected));
    }

    #[test]
    fn prevent_self_dependency() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();

        let result = graph.add_dependency(t1, t1);
        assert_eq!(result, Err(GraphError::CycleDetected));
    }

    #[test]
    fn topological_order() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        // Chain: t3 depends on t2, which depends on t1.
        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t2).unwrap();

        let order = graph.topological_order();

        let pos1 = order.iter().position(|&x| x == t1).unwrap();
        let pos2 = order.iter().position(|&x| x == t2).unwrap();
        let pos3 = order.iter().position(|&x| x == t3).unwrap();

        assert!(pos1 < pos2);
        assert!(pos2 < pos3);
    }

    #[test]
    fn root_tasks() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t1).unwrap();

        let roots = graph.root_tasks();
        assert_eq!(roots.len(), 1);
        assert!(roots.contains(&t1));
    }

    #[test]
    fn validate_and_lock() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();

        assert!(graph.validate().is_ok());
        assert_eq!(graph.state, GraphState::Validated);

        assert!(graph.lock().is_ok());
        assert_eq!(graph.state, GraphState::Locked);
    }

    #[test]
    fn cannot_modify_locked_graph() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();
        graph.validate().unwrap();
        graph.lock().unwrap();

        let result = graph.add_task(test_task("Task 2"));
        assert_eq!(result, Err(GraphError::GraphLocked));
    }

    #[test]
    fn cannot_validate_empty_graph() {
        let mut graph = test_graph();
        let result = graph.validate();
        assert_eq!(result, Err(GraphError::EmptyGraph));
    }

    #[test]
    fn graph_serialization() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        graph.add_dependency(t2, t1).unwrap();

        let json = serde_json::to_string(&graph).unwrap();
        let restored: TaskGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(graph.id, restored.id);
        assert_eq!(graph.tasks.len(), restored.tasks.len());
    }

    #[test]
    fn dependents() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t1).unwrap();

        let dependents = graph.dependents(t1);
        assert_eq!(dependents.len(), 2);
        assert!(dependents.contains(&t2));
        assert!(dependents.contains(&t3));
    }

    #[test]
    fn lock_requires_validation() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();

        assert_eq!(graph.lock(), Err(GraphError::InvalidStateTransition));
        assert_eq!(graph.state, GraphState::Draft);
    }

    #[test]
    fn cannot_validate_twice() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();
        graph.validate().unwrap();

        assert_eq!(graph.validate(), Err(GraphError::InvalidStateTransition));
        assert_eq!(graph.state, GraphState::Validated);
    }

    #[test]
    fn validated_graph_rejects_dependencies() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        graph.validate().unwrap();

        assert_eq!(graph.add_dependency(t2, t1), Err(GraphError::GraphLocked));
    }

    #[test]
    fn dependency_on_unknown_task() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let ghost = Uuid::new_v4();

        assert_eq!(graph.add_dependency(t1, ghost), Err(GraphError::TaskNotFound(ghost)));
        assert_eq!(graph.add_dependency(ghost, t1), Err(GraphError::TaskNotFound(ghost)));
        assert!(graph.dependencies[&t1].is_empty());
    }

    #[test]
    fn rejected_cycle_leaves_graph_unchanged() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        graph.add_dependency(t2, t1).unwrap();

        assert_eq!(graph.add_dependency(t1, t2), Err(GraphError::CycleDetected));
        assert!(graph.dependencies[&t1].is_empty());
        assert!(graph.dependencies[&t2].contains(&t1));
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn duplicate_dependency_is_idempotent() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t2, t1).unwrap();

        assert_eq!(graph.dependencies[&t2].len(), 1);
    }

    #[test]
    fn graph_state_serializes_lowercase() {
        assert_eq!(serde_json::to_string(&GraphState::Locked).unwrap(), "\"locked\"");
        let state: GraphState = serde_json::from_str("\"validated\"").unwrap();
        assert_eq!(state, GraphState::Validated);
    }
}