// hivemind/core/graph.rs

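//! `TaskGraph` - Static, immutable DAG representing planned intent.
//!
//! A `TaskGraph` is created by the Planner and becomes immutable once execution begins.
//! It represents what should happen, not what has happened.
//!
//! Graphs move through `Draft` -> `Validated` -> `Locked`; only a `Draft` graph accepts
//! new tasks, dependencies, or scopes.
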
use super::scope::Scope;
use super::verification::CheckConfig;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::{HashMap, HashSet};
use uuid::Uuid;

13/// Retry policy for a task.
14#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
15pub struct RetryPolicy {
16    /// Maximum number of retry attempts.
17    pub max_retries: u32,
18    /// Whether to escalate to human on final failure.
19    pub escalate_on_failure: bool,
20}
21
22impl Default for RetryPolicy {
23    fn default() -> Self {
24        Self {
25            max_retries: 3,
26            escalate_on_failure: true,
27        }
28    }
29}
30
31/// Success criteria for a task.
32#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
33pub struct SuccessCriteria {
34    /// Human-readable description of success.
35    pub description: String,
36    /// Automated checks to run (command patterns).
37    pub checks: Vec<CheckConfig>,
38}
39
40impl SuccessCriteria {
41    /// Creates new success criteria.
42    #[must_use]
43    pub fn new(description: impl Into<String>) -> Self {
44        Self {
45            description: description.into(),
46            checks: Vec::new(),
47        }
48    }
49
50    /// Adds an automated check.
51    #[must_use]
52    pub fn with_check(mut self, check: impl Into<CheckConfig>) -> Self {
53        self.checks.push(check.into());
54        self
55    }
56}
57
58/// A task node within a `TaskGraph`.
59#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
60pub struct GraphTask {
61    /// Unique task ID.
62    pub id: Uuid,
63    /// Task title.
64    pub title: String,
65    /// Task description/objective.
66    pub description: Option<String>,
67    /// Success criteria.
68    pub criteria: SuccessCriteria,
69    /// Retry policy.
70    pub retry_policy: RetryPolicy,
71    /// Ordered checkpoint IDs for attempt execution.
72    #[serde(default = "GraphTask::default_checkpoints")]
73    pub checkpoints: Vec<String>,
74    /// Required scope (optional at graph creation).
75    pub scope: Option<Scope>,
76}
77
78impl GraphTask {
79    fn default_checkpoints() -> Vec<String> {
80        vec!["checkpoint-1".to_string()]
81    }
82
83    /// Creates a new graph task.
84    #[must_use]
85    pub fn new(title: impl Into<String>, criteria: SuccessCriteria) -> Self {
86        Self {
87            id: Uuid::new_v4(),
88            title: title.into(),
89            description: None,
90            criteria,
91            retry_policy: RetryPolicy::default(),
92            checkpoints: Self::default_checkpoints(),
93            scope: None,
94        }
95    }
96
97    /// Sets the task description.
98    #[must_use]
99    pub fn with_description(mut self, desc: impl Into<String>) -> Self {
100        self.description = Some(desc.into());
101        self
102    }
103
104    /// Sets the retry policy.
105    #[must_use]
106    pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
107        self.retry_policy = policy;
108        self
109    }
110
111    /// Sets ordered checkpoints for this task.
112    #[must_use]
113    pub fn with_checkpoints(mut self, checkpoints: Vec<String>) -> Self {
114        self.checkpoints = if checkpoints.is_empty() {
115            Self::default_checkpoints()
116        } else {
117            checkpoints
118        };
119        self
120    }
121
122    /// Sets the scope.
123    #[must_use]
124    pub fn with_scope(mut self, scope: Scope) -> Self {
125        self.scope = Some(scope);
126        self
127    }
128}
129
130/// State of a `TaskGraph`.
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
132#[serde(rename_all = "lowercase")]
133pub enum GraphState {
134    /// Graph is being built, can be modified.
135    Draft,
136    /// Graph is validated and ready for execution.
137    Validated,
138    /// Graph is locked and immutable (execution started).
139    Locked,
140}
141
142/// A `TaskGraph` - static DAG of planned intent.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct TaskGraph {
145    /// Unique graph ID.
146    pub id: Uuid,
147    /// Associated project ID.
148    pub project_id: Uuid,
149    /// Graph name.
150    pub name: String,
151    /// Graph description.
152    pub description: Option<String>,
153    /// Current state.
154    pub state: GraphState,
155    /// Task nodes.
156    pub tasks: HashMap<Uuid, GraphTask>,
157    /// Dependencies (`task_id` -> set of dependency `task_ids`).
158    pub dependencies: HashMap<Uuid, HashSet<Uuid>>,
159    /// Creation timestamp.
160    pub created_at: DateTime<Utc>,
161    /// Last update timestamp.
162    pub updated_at: DateTime<Utc>,
163}
164
165impl TaskGraph {
166    /// Creates a new draft `TaskGraph`.
167    #[must_use]
168    pub fn new(project_id: Uuid, name: impl Into<String>) -> Self {
169        let now = Utc::now();
170        Self {
171            id: Uuid::new_v4(),
172            project_id,
173            name: name.into(),
174            description: None,
175            state: GraphState::Draft,
176            tasks: HashMap::new(),
177            dependencies: HashMap::new(),
178            created_at: now,
179            updated_at: now,
180        }
181    }
182
183    /// Checks if the graph is modifiable.
184    #[must_use]
185    pub fn is_modifiable(&self) -> bool {
186        self.state == GraphState::Draft
187    }
188
189    /// Adds a task to the graph.
190    ///
191    /// # Errors
192    /// Returns an error if the graph is locked.
193    pub fn add_task(&mut self, task: GraphTask) -> Result<Uuid, GraphError> {
194        if !self.is_modifiable() {
195            return Err(GraphError::GraphLocked);
196        }
197
198        let id = task.id;
199        self.tasks.insert(id, task);
200        self.dependencies.insert(id, HashSet::new());
201        self.updated_at = Utc::now();
202        Ok(id)
203    }
204
205    /// Adds a dependency between tasks.
206    ///
207    /// # Errors
208    /// Returns an error if the graph is locked, tasks don't exist, or would create a cycle.
209    pub fn add_dependency(&mut self, from: Uuid, to: Uuid) -> Result<(), GraphError> {
210        if !self.is_modifiable() {
211            return Err(GraphError::GraphLocked);
212        }
213
214        if !self.tasks.contains_key(&from) {
215            return Err(GraphError::TaskNotFound(from));
216        }
217        if !self.tasks.contains_key(&to) {
218            return Err(GraphError::TaskNotFound(to));
219        }
220
221        // Check for self-dependency
222        if from == to {
223            return Err(GraphError::CycleDetected);
224        }
225
226        // Add dependency tentatively
227        self.dependencies.entry(from).or_default().insert(to);
228
229        // Check for cycles
230        if self.has_cycle() {
231            // Rollback
232            self.dependencies.entry(from).or_default().remove(&to);
233            return Err(GraphError::CycleDetected);
234        }
235
236        self.updated_at = Utc::now();
237        Ok(())
238    }
239
240    /// Sets scope for a task.
241    ///
242    /// # Errors
243    /// Returns an error if the graph is locked or task doesn't exist.
244    pub fn set_scope(&mut self, task_id: Uuid, scope: Scope) -> Result<(), GraphError> {
245        if !self.is_modifiable() {
246            return Err(GraphError::GraphLocked);
247        }
248
249        let task = self
250            .tasks
251            .get_mut(&task_id)
252            .ok_or(GraphError::TaskNotFound(task_id))?;
253
254        task.scope = Some(scope);
255        self.updated_at = Utc::now();
256        Ok(())
257    }
258
259    /// Validates the graph and transitions to Validated state.
260    ///
261    /// # Errors
262    /// Returns an error if validation fails.
263    pub fn validate(&mut self) -> Result<(), GraphError> {
264        if self.state != GraphState::Draft {
265            return Err(GraphError::InvalidStateTransition);
266        }
267
268        // Must have at least one task
269        if self.tasks.is_empty() {
270            return Err(GraphError::EmptyGraph);
271        }
272
273        // Check for cycles (should already be prevented, but double-check)
274        if self.has_cycle() {
275            return Err(GraphError::CycleDetected);
276        }
277
278        // All dependencies must reference existing tasks
279        for (task_id, deps) in &self.dependencies {
280            if !self.tasks.contains_key(task_id) {
281                return Err(GraphError::TaskNotFound(*task_id));
282            }
283            for dep in deps {
284                if !self.tasks.contains_key(dep) {
285                    return Err(GraphError::TaskNotFound(*dep));
286                }
287            }
288        }
289
290        self.state = GraphState::Validated;
291        self.updated_at = Utc::now();
292        Ok(())
293    }
294
295    /// Locks the graph (immutable after this).
296    ///
297    /// # Errors
298    /// Returns an error if the graph is not validated.
299    pub fn lock(&mut self) -> Result<(), GraphError> {
300        if self.state != GraphState::Validated {
301            return Err(GraphError::InvalidStateTransition);
302        }
303
304        self.state = GraphState::Locked;
305        self.updated_at = Utc::now();
306        Ok(())
307    }
308
309    /// Checks if the graph contains a cycle using DFS.
310    fn has_cycle(&self) -> bool {
311        let mut visited = HashSet::new();
312        let mut rec_stack = HashSet::new();
313
314        for task_id in self.tasks.keys() {
315            if self.has_cycle_util(*task_id, &mut visited, &mut rec_stack) {
316                return true;
317            }
318        }
319        false
320    }
321
322    fn has_cycle_util(
323        &self,
324        node: Uuid,
325        visited: &mut HashSet<Uuid>,
326        rec_stack: &mut HashSet<Uuid>,
327    ) -> bool {
328        if rec_stack.contains(&node) {
329            return true;
330        }
331        if visited.contains(&node) {
332            return false;
333        }
334
335        visited.insert(node);
336        rec_stack.insert(node);
337
338        if let Some(deps) = self.dependencies.get(&node) {
339            for dep in deps {
340                if self.has_cycle_util(*dep, visited, rec_stack) {
341                    return true;
342                }
343            }
344        }
345
346        rec_stack.remove(&node);
347        false
348    }
349
350    /// Returns tasks in topological order (dependencies first).
351    #[must_use]
352    pub fn topological_order(&self) -> Vec<Uuid> {
353        let mut result = Vec::new();
354        let mut visited = HashSet::new();
355
356        for task_id in self.tasks.keys() {
357            self.topological_visit(*task_id, &mut visited, &mut result);
358        }
359
360        result
361    }
362
363    fn topological_visit(&self, node: Uuid, visited: &mut HashSet<Uuid>, result: &mut Vec<Uuid>) {
364        if visited.contains(&node) {
365            return;
366        }
367
368        visited.insert(node);
369
370        if let Some(deps) = self.dependencies.get(&node) {
371            for dep in deps {
372                self.topological_visit(*dep, visited, result);
373            }
374        }
375
376        result.push(node);
377    }
378
379    /// Gets tasks that have no dependencies (entry points).
380    #[must_use]
381    pub fn root_tasks(&self) -> Vec<Uuid> {
382        self.tasks
383            .keys()
384            .filter(|id| self.dependencies.get(*id).is_none_or(HashSet::is_empty))
385            .copied()
386            .collect()
387    }
388
389    /// Gets tasks that depend on a given task.
390    #[must_use]
391    pub fn dependents(&self, task_id: Uuid) -> Vec<Uuid> {
392        self.dependencies
393            .iter()
394            .filter(|(_, deps)| deps.contains(&task_id))
395            .map(|(id, _)| *id)
396            .collect()
397    }
398}
399
400/// Errors that can occur during graph operations.
401#[derive(Debug, Clone, PartialEq, Eq)]
402pub enum GraphError {
403    /// Graph is locked and cannot be modified.
404    GraphLocked,
405    /// Task not found.
406    TaskNotFound(Uuid),
407    /// Cycle detected in dependencies.
408    CycleDetected,
409    /// Invalid state transition.
410    InvalidStateTransition,
411    /// Graph is empty.
412    EmptyGraph,
413}
414
415impl std::fmt::Display for GraphError {
416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        match self {
418            Self::GraphLocked => write!(f, "Graph is locked and cannot be modified"),
419            Self::TaskNotFound(id) => write!(f, "Task not found: {id}"),
420            Self::CycleDetected => write!(f, "Cycle detected in task dependencies"),
421            Self::InvalidStateTransition => write!(f, "Invalid state transition"),
422            Self::EmptyGraph => write!(f, "Graph must contain at least one task"),
423        }
424    }
425}
426
427impl std::error::Error for GraphError {}
428
#[cfg(test)]
mod tests {
    use super::*;

    fn test_graph() -> TaskGraph {
        TaskGraph::new(Uuid::new_v4(), "test-graph")
    }

    fn test_task(title: &str) -> GraphTask {
        GraphTask::new(title, SuccessCriteria::new("Task completed"))
    }

    /// Index of `id` within `order`; panics if absent so ordering assertions fail loudly.
    fn position_of(order: &[Uuid], id: Uuid) -> usize {
        order
            .iter()
            .position(|&x| x == id)
            .expect("task missing from topological order")
    }

    #[test]
    fn create_graph() {
        let graph = test_graph();
        assert_eq!(graph.state, GraphState::Draft);
        assert!(graph.tasks.is_empty());
    }

    #[test]
    fn add_tasks() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();

        assert_eq!(graph.tasks.len(), 2);
        assert!(graph.tasks.contains_key(&t1));
        assert!(graph.tasks.contains_key(&t2));
    }

    #[test]
    fn add_dependencies() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();

        // t2 depends on t1
        graph.add_dependency(t2, t1).unwrap();

        assert!(graph.dependencies[&t2].contains(&t1));
    }

    #[test]
    fn add_dependency_requires_existing_tasks() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let missing = Uuid::new_v4();

        assert_eq!(
            graph.add_dependency(t1, missing),
            Err(GraphError::TaskNotFound(missing))
        );
        assert_eq!(
            graph.add_dependency(missing, t1),
            Err(GraphError::TaskNotFound(missing))
        );
    }

    #[test]
    fn prevent_cycles() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t2).unwrap();

        // This would create a cycle: t1 -> t2 -> t3 -> t1
        let result = graph.add_dependency(t1, t3);
        assert_eq!(result, Err(GraphError::CycleDetected));
    }

    #[test]
    fn rejected_cycle_leaves_graph_unchanged() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t2).unwrap();
        let before = graph.dependencies.clone();

        assert_eq!(graph.add_dependency(t1, t3), Err(GraphError::CycleDetected));
        assert_eq!(graph.dependencies, before);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn prevent_self_dependency() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();

        let result = graph.add_dependency(t1, t1);
        assert_eq!(result, Err(GraphError::CycleDetected));
    }

    #[test]
    fn topological_order() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        // t2 depends on t1, t3 depends on t2
        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t2).unwrap();

        let order = graph.topological_order();

        // t1 must come before t2, t2 must come before t3
        assert!(position_of(&order, t1) < position_of(&order, t2));
        assert!(position_of(&order, t2) < position_of(&order, t3));
    }

    #[test]
    fn topological_order_handles_diamond() {
        let mut graph = test_graph();

        let base = graph.add_task(test_task("Base")).unwrap();
        let left = graph.add_task(test_task("Left")).unwrap();
        let right = graph.add_task(test_task("Right")).unwrap();
        let top = graph.add_task(test_task("Top")).unwrap();

        graph.add_dependency(left, base).unwrap();
        graph.add_dependency(right, base).unwrap();
        graph.add_dependency(top, left).unwrap();
        graph.add_dependency(top, right).unwrap();

        let order = graph.topological_order();

        // Every task exactly once, each after its dependencies.
        assert_eq!(order.len(), 4);
        assert!(position_of(&order, base) < position_of(&order, left));
        assert!(position_of(&order, base) < position_of(&order, right));
        assert!(position_of(&order, left) < position_of(&order, top));
        assert!(position_of(&order, right) < position_of(&order, top));
    }

    #[test]
    fn root_tasks() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t1).unwrap();

        let roots = graph.root_tasks();
        assert_eq!(roots.len(), 1);
        assert!(roots.contains(&t1));
    }

    #[test]
    fn validate_and_lock() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();

        assert!(graph.validate().is_ok());
        assert_eq!(graph.state, GraphState::Validated);

        assert!(graph.lock().is_ok());
        assert_eq!(graph.state, GraphState::Locked);
    }

    #[test]
    fn lifecycle_transitions_are_ordered() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();

        // Cannot lock a draft, cannot validate twice, cannot edit once validated.
        assert_eq!(graph.lock(), Err(GraphError::InvalidStateTransition));
        graph.validate().unwrap();
        assert_eq!(graph.validate(), Err(GraphError::InvalidStateTransition));
        assert!(!graph.is_modifiable());
        assert_eq!(
            graph.add_task(test_task("Task 2")),
            Err(GraphError::GraphLocked)
        );
    }

    #[test]
    fn cannot_modify_locked_graph() {
        let mut graph = test_graph();
        graph.add_task(test_task("Task 1")).unwrap();
        graph.validate().unwrap();
        graph.lock().unwrap();

        let result = graph.add_task(test_task("Task 2"));
        assert_eq!(result, Err(GraphError::GraphLocked));
    }

    #[test]
    fn cannot_validate_empty_graph() {
        let mut graph = test_graph();
        let result = graph.validate();
        assert_eq!(result, Err(GraphError::EmptyGraph));
    }

    #[test]
    fn validate_rejects_dangling_dependency() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let missing = Uuid::new_v4();

        // Bypass `add_dependency` through the public field.
        graph.dependencies.get_mut(&t1).unwrap().insert(missing);

        assert_eq!(graph.validate(), Err(GraphError::TaskNotFound(missing)));
        assert_eq!(graph.state, GraphState::Draft);
    }

    #[test]
    fn validate_detects_cycle_built_through_public_fields() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.dependencies.get_mut(&t1).unwrap().insert(t2);

        assert_eq!(graph.validate(), Err(GraphError::CycleDetected));
    }

    #[test]
    fn empty_checkpoints_fall_back_to_default() {
        let task = test_task("Task 1").with_checkpoints(Vec::new());
        assert_eq!(task.checkpoints, vec!["checkpoint-1".to_string()]);
    }

    #[test]
    fn graph_state_serializes_lowercase() {
        let json = serde_json::to_string(&GraphState::Validated).unwrap();
        assert_eq!(json, "\"validated\"");
    }

    #[test]
    fn graph_serialization() {
        let mut graph = test_graph();
        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        graph.add_dependency(t2, t1).unwrap();

        let json = serde_json::to_string(&graph).unwrap();
        let restored: TaskGraph = serde_json::from_str(&json).unwrap();

        assert_eq!(graph.id, restored.id);
        assert_eq!(graph.tasks.len(), restored.tasks.len());
    }

    #[test]
    fn dependents() {
        let mut graph = test_graph();

        let t1 = graph.add_task(test_task("Task 1")).unwrap();
        let t2 = graph.add_task(test_task("Task 2")).unwrap();
        let t3 = graph.add_task(test_task("Task 3")).unwrap();

        graph.add_dependency(t2, t1).unwrap();
        graph.add_dependency(t3, t1).unwrap();

        let dependents = graph.dependents(t1);
        assert_eq!(dependents.len(), 2);
        assert!(dependents.contains(&t2));
        assert!(dependents.contains(&t3));
    }
}