// forge_agent/workflow/task.rs
//! Task abstraction and execution trait for the workflow system.
//!
//! Defines the core task interface that all workflow tasks must implement,
//! along with supporting types for task identification, execution context,
//! and result reporting.
6
7use async_trait::async_trait;
8use forge_core::Forge;
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12
13// Import ToolRegistry for task-level tool access
14use crate::workflow::tools::ToolRegistry;
15
16// Import AuditLog for audit event recording
17use crate::audit::AuditLog;
18
19/// Unique identifier for a workflow task.
20///
21/// TaskId wraps a string identifier and implements the necessary traits
22/// for use as a HashMap key and graph node identifier.
23#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
24pub struct TaskId(String);
25
26impl TaskId {
27    /// Creates a new TaskId from a string.
28    pub fn new(id: impl Into<String>) -> Self {
29        Self(id.into())
30    }
31
32    /// Returns the underlying string identifier.
33    pub fn as_str(&self) -> &str {
34        &self.0
35    }
36
37    /// Consumes the TaskId and returns the underlying string.
38    pub fn into_inner(self) -> String {
39        self.0
40    }
41}
42
43impl fmt::Display for TaskId {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        write!(f, "{}", self.0)
46    }
47}
48
49impl From<String> for TaskId {
50    fn from(s: String) -> Self {
51        Self(s)
52    }
53}
54
55impl From<&str> for TaskId {
56    fn from(s: &str) -> Self {
57        Self(s.to_string())
58    }
59}
60
61/// Dependency strength between tasks.
62///
63/// Hard dependencies must complete successfully before the dependent
64/// task can execute. Soft dependencies represent preference but not
65/// requirements (not yet enforced in v0.1).
66#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
67pub enum Dependency {
68    /// Task must complete successfully (blocking dependency)
69    Hard,
70    /// Task should complete if possible (non-blocking, planned for v0.2)
71    Soft,
72}
73
74/// Execution result for a workflow task.
75///
76/// Captures the outcome of task execution for audit logging and
77/// workflow coordination.
78#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
79pub enum TaskResult {
80    /// Task completed successfully
81    Success,
82    /// Task failed with an error message
83    Failed(String),
84    /// Task was skipped (e.g., due to failed hard dependency)
85    Skipped,
86    /// Task result with compensation action (for Saga rollback)
87    WithCompensation {
88        /// The actual result of task execution
89        result: Box<TaskResult>,
90        /// Compensation action to undo task side effects
91        compensation: CompensationAction,
92    },
93}
94
95/// Execution context provided to workflow tasks.
96///
97/// Provides access to the Forge SDK for graph operations,
98/// metadata about the current workflow execution, and cancellation token.
99#[derive(Clone)]
100pub struct TaskContext {
101    /// Optional Forge instance for graph queries
102    pub forge: Option<Forge>,
103    /// Workflow identifier for this execution
104    pub workflow_id: String,
105    /// Task identifier for this execution
106    pub task_id: TaskId,
107    /// Optional cancellation token for cooperative cancellation
108    cancellation_token: Option<crate::workflow::cancellation::CancellationToken>,
109    /// Optional task timeout duration
110    pub task_timeout: Option<std::time::Duration>,
111    /// Optional tool registry for tool invocation
112    pub tool_registry: Option<Arc<ToolRegistry>>,
113    /// Optional audit log for recording events (cloned from executor)
114    pub audit_log: Option<AuditLog>,
115}
116
117impl TaskContext {
118    /// Creates a new TaskContext.
119    pub fn new(workflow_id: impl Into<String>, task_id: TaskId) -> Self {
120        Self {
121            forge: None,
122            workflow_id: workflow_id.into(),
123            task_id,
124            cancellation_token: None,
125            task_timeout: None,
126            tool_registry: None,
127            audit_log: None,
128        }
129    }
130
131    /// Sets the Forge instance for graph operations.
132    pub fn with_forge(mut self, forge: Forge) -> Self {
133        self.forge = Some(forge);
134        self
135    }
136
137    /// Sets the cancellation token for cooperative cancellation.
138    ///
139    /// # Arguments
140    ///
141    /// * `token` - The cancellation token to check during task execution
142    ///
143    /// # Returns
144    ///
145    /// The context with cancellation token set (for builder pattern)
146    ///
147    /// # Example
148    ///
149    /// ```ignore
150    /// use forge_agent::workflow::{CancellationTokenSource, TaskContext};
151    ///
152    /// let source = CancellationTokenSource::new();
153    /// let context = TaskContext::new("workflow-1", task_id)
154    ///     .with_cancellation_token(source.token());
155    /// ```
156    pub fn with_cancellation_token(
157        mut self,
158        token: crate::workflow::cancellation::CancellationToken,
159    ) -> Self {
160        self.cancellation_token = Some(token);
161        self
162    }
163
164    /// Returns a reference to the cancellation token if set.
165    ///
166    /// Tasks can use this to check for cancellation during execution.
167    ///
168    /// # Example
169    ///
170    /// ```ignore
171    /// async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
172    ///     if let Some(token) = context.cancellation_token() {
173    ///         if token.is_cancelled() {
174    ///             return Ok(TaskResult::Skipped);
175    ///         }
176    ///     }
177    ///     // ... do work
178    /// }
179    /// ```
180    pub fn cancellation_token(&self) -> Option<&crate::workflow::cancellation::CancellationToken> {
181        self.cancellation_token.as_ref()
182    }
183
184    /// Sets the task timeout for this context.
185    ///
186    /// # Arguments
187    ///
188    /// * `timeout` - The timeout duration for the task
189    ///
190    /// # Returns
191    ///
192    /// The context with task timeout set (for builder pattern)
193    ///
194    /// # Example
195    ///
196    /// ```ignore
197    /// use std::time::Duration;
198    ///
199    /// let context = TaskContext::new("workflow-1", task_id)
200    ///     .with_task_timeout(Duration::from_secs(30));
201    /// ```
202    pub fn with_task_timeout(mut self, timeout: std::time::Duration) -> Self {
203        self.task_timeout = Some(timeout);
204        self
205    }
206
207    /// Sets the tool registry for tool invocation.
208    ///
209    /// # Arguments
210    ///
211    /// * `registry` - The tool registry to use for tool invocation
212    ///
213    /// # Returns
214    ///
215    /// The context with tool registry set (for builder pattern)
216    ///
217    /// # Example
218    ///
219    /// ```ignore
220    /// use forge_agent::workflow::tools::ToolRegistry;
221    /// use std::sync::Arc;
222    ///
223    /// let registry = Arc::new(ToolRegistry::new());
224    /// let context = TaskContext::new("workflow-1", task_id)
225    ///     .with_tool_registry(registry);
226    /// ```
227    pub fn with_tool_registry(mut self, registry: Arc<ToolRegistry>) -> Self {
228        self.tool_registry = Some(registry);
229        self
230    }
231
232    /// Returns a reference to the tool registry if set.
233    ///
234    /// # Returns
235    ///
236    /// - `Some(&Arc<ToolRegistry>)` if tool registry is set
237    /// - `None` if no tool registry
238    ///
239    /// # Example
240    ///
241    /// ```ignore
242    /// if let Some(registry) = context.tool_registry() {
243    ///     // Use tool registry
244    /// }
245    /// ```
246    pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
247        self.tool_registry.as_ref()
248    }
249
250    /// Sets the audit log for event recording.
251    ///
252    /// # Arguments
253    ///
254    /// * `audit_log` - The audit log to use for event recording
255    ///
256    /// # Returns
257    ///
258    /// The context with audit log set (for builder pattern)
259    ///
260    /// # Example
261    ///
262    /// ```ignore
263    /// use forge_agent::audit::AuditLog;
264    ///
265    /// let audit_log = AuditLog::new();
266    /// let context = TaskContext::new("workflow-1", task_id)
267    ///     .with_audit_log(audit_log);
268    /// ```
269    pub fn with_audit_log(mut self, audit_log: AuditLog) -> Self {
270        self.audit_log = Some(audit_log);
271        self
272    }
273
274    /// Returns a mutable reference to the audit log if set.
275    ///
276    /// # Returns
277    ///
278    /// - `Some(&mut AuditLog)` if audit log is set
279    /// - `None` if no audit log
280    ///
281    /// # Example
282    ///
283    /// ```ignore
284    /// if let Some(audit_log) = context.audit_log_mut() {
285    ///     // Use audit log
286    /// }
287    /// ```
288    pub fn audit_log_mut(&mut self) -> Option<&mut AuditLog> {
289        self.audit_log.as_mut()
290    }
291
292    /// Returns the task timeout duration if set.
293    ///
294    /// # Example
295    ///
296    /// ```ignore
297    /// if let Some(timeout) = context.task_timeout() {
298    ///     println!("Task timeout: {:?}", timeout);
299    /// }
300    /// ```
301    pub fn task_timeout(&self) -> Option<std::time::Duration> {
302        self.task_timeout
303    }
304}
305
306/// Compensation action that undoes task side effects.
307///
308/// Describes how to compensate a task during workflow rollback using the
309/// Saga pattern. This is a simplified version for use in TaskResult.
310/// The full implementation with undo functions is in the rollback module.
311#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
312pub struct CompensationAction {
313    /// Type of compensation action
314    pub action_type: CompensationType,
315    /// Human-readable description of the compensation
316    pub description: String,
317}
318
319impl CompensationAction {
320    /// Creates a new CompensationAction.
321    pub fn new(action_type: CompensationType, description: impl Into<String>) -> Self {
322        Self {
323            action_type,
324            description: description.into(),
325        }
326    }
327
328    /// Creates a Skip compensation (no undo needed).
329    pub fn skip(description: impl Into<String>) -> Self {
330        Self::new(CompensationType::Skip, description)
331    }
332
333    /// Creates a Retry compensation (recommends retry instead of undo).
334    pub fn retry(description: impl Into<String>) -> Self {
335        Self::new(CompensationType::Retry, description)
336    }
337
338    /// Creates an UndoFunction compensation.
339    pub fn undo(description: impl Into<String>) -> Self {
340        Self::new(CompensationType::UndoFunction, description)
341    }
342}
343
344/// Type of compensation action for task rollback.
345#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
346pub enum CompensationType {
347    /// Execute undo function to compensate (e.g., delete created file)
348    UndoFunction,
349    /// No compensation needed (read-only operation)
350    Skip,
351    /// Recommend retry instead of compensation (transient failure)
352    Retry,
353}
354
355/// Error types for task execution.
356#[derive(thiserror::Error, Debug)]
357pub enum TaskError {
358    /// Task execution failed with a message
359    #[error("Task execution failed: {0}")]
360    ExecutionFailed(String),
361
362    /// Task dependency failed
363    #[error("Dependency {dependency} failed: {reason}")]
364    DependencyFailed {
365        dependency: String,
366        reason: String,
367    },
368
369    /// Task was skipped due to workflow state
370    #[error("Task skipped: {0}")]
371    Skipped(String),
372
373    /// Task exceeded time limit
374    #[error("Task timeout: {0}")]
375    Timeout(String),
376
377    /// I/O error during task execution
378    #[error("I/O error: {0}")]
379    Io(#[from] std::io::Error),
380
381    /// Generic error wrapper
382    #[error("Task error: {0}")]
383    Other(#[from] anyhow::Error),
384}
385
386/// Trait for workflow task execution.
387///
388/// All workflow tasks must implement this trait to enable execution
389/// by the WorkflowExecutor. Tasks are executed asynchronously in
390/// topological order based on their dependencies.
391#[async_trait]
392pub trait WorkflowTask: Send + Sync {
393    /// Executes the task with the provided context.
394    ///
395    /// # Arguments
396    ///
397    /// * `context` - Execution context with Forge instance and metadata
398    ///
399    /// # Returns
400    ///
401    /// Returns `Ok(TaskResult)` on success, or `Err(TaskError)` on failure.
402    async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError>;
403
404    /// Returns the unique task identifier.
405    fn id(&self) -> TaskId;
406
407    /// Returns the human-readable task name.
408    fn name(&self) -> &str;
409
410    /// Returns the list of task dependencies.
411    ///
412    /// Default implementation returns an empty vector (no dependencies).
413    fn dependencies(&self) -> Vec<TaskId> {
414        Vec::new()
415    }
416
417    /// Returns the compensation action for this task (if any).
418    ///
419    /// Compensation actions are used during workflow rollback to undo
420    /// task side effects using the Saga pattern. Tasks that don't have
421    /// side effects (e.g., read-only queries) should return None.
422    ///
423    /// Default implementation returns None (no compensation).
424    ///
425    /// # Returns
426    ///
427    /// - `Some(CompensationAction)` - Task can be compensated
428    /// - `None` - Task has no compensation (will be skipped during rollback)
429    fn compensation(&self) -> Option<CompensationAction> {
430        None
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_task_id_equality() {
        let id1 = TaskId::new("task-1");
        let id2 = TaskId::new("task-1");
        let id3 = TaskId::new("task-2");

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_task_id_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();

        set.insert(TaskId::new("task-1"));
        set.insert(TaskId::new("task-1"));
        set.insert(TaskId::new("task-2"));

        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_task_id_from_string() {
        let id1: TaskId = "task-1".into();
        let id2: TaskId = TaskId::from(String::from("task-1"));

        assert_eq!(id1, id2);
        assert_eq!(id1.as_str(), "task-1");
        assert_eq!(id2.into_inner(), "task-1");
    }

    #[test]
    fn test_task_id_display() {
        let id = TaskId::new("task-1");
        assert_eq!(format!("{}", id), "task-1");
    }

    #[test]
    fn test_dependency_variants() {
        let hard = Dependency::Hard;
        let soft = Dependency::Soft;

        assert_ne!(hard, soft);
    }

    #[test]
    fn test_task_result_variants() {
        let success = TaskResult::Success;
        let failed = TaskResult::Failed("error".to_string());
        let skipped = TaskResult::Skipped;

        assert_eq!(success, TaskResult::Success);
        assert_eq!(failed, TaskResult::Failed("error".to_string()));
        assert_eq!(skipped, TaskResult::Skipped);
        assert_ne!(success, skipped);
    }

    #[test]
    fn test_task_result_with_compensation() {
        let wrapped = TaskResult::WithCompensation {
            result: Box::new(TaskResult::Success),
            compensation: CompensationAction::undo("Delete file"),
        };

        assert_eq!(wrapped.clone(), wrapped);
        match &wrapped {
            TaskResult::WithCompensation {
                result,
                compensation,
            } => {
                assert_eq!(**result, TaskResult::Success);
                assert_eq!(compensation.action_type, CompensationType::UndoFunction);
                assert_eq!(compensation.description, "Delete file");
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }

    #[test]
    fn test_task_context_creation() {
        let task_id = TaskId::new("task-1");
        let context = TaskContext::new("workflow-1", task_id.clone());

        assert_eq!(context.workflow_id, "workflow-1");
        assert_eq!(context.task_id, task_id);
        // Every optional capability starts out unset.
        assert!(context.forge.is_none());
        assert!(context.cancellation_token().is_none());
        assert!(context.task_timeout().is_none());
        assert!(context.tool_registry().is_none());
        assert!(context.audit_log.is_none());
    }

    #[test]
    fn test_context_without_cancellation_token() {
        let task_id = TaskId::new("task-1");
        let context = TaskContext::new("workflow-1", task_id);

        // Cancellation token should be None by default
        assert!(context.cancellation_token().is_none());
    }

    #[test]
    fn test_context_with_cancellation_token() {
        use crate::workflow::cancellation::CancellationTokenSource;

        let task_id = TaskId::new("task-1");
        let source = CancellationTokenSource::new();

        let context =
            TaskContext::new("workflow-1", task_id).with_cancellation_token(source.token());

        // Cancellation token should be accessible and not yet cancelled
        let retrieved_token = context
            .cancellation_token()
            .expect("token was set via with_cancellation_token");
        assert!(!retrieved_token.is_cancelled());

        // Cancel source
        source.cancel();

        // Retrieved token should observe the cancellation
        assert!(retrieved_token.is_cancelled());
    }

    #[test]
    fn test_context_builder_pattern() {
        use crate::workflow::cancellation::CancellationTokenSource;
        use std::time::Duration;

        let task_id = TaskId::new("task-1");
        let source = CancellationTokenSource::new();
        let timeout = Duration::from_secs(5);

        // Builder methods chain, and each setting survives the chain.
        let context = TaskContext::new("workflow-1", task_id.clone())
            .with_cancellation_token(source.token())
            .with_task_timeout(timeout);

        assert!(context.cancellation_token().is_some());
        assert_eq!(context.task_timeout(), Some(timeout));
        assert_eq!(context.workflow_id, "workflow-1");
        assert_eq!(context.task_id, task_id);
    }

    #[test]
    fn test_context_without_task_timeout() {
        let task_id = TaskId::new("task-1");
        let context = TaskContext::new("workflow-1", task_id);

        // Task timeout should be None by default
        assert!(context.task_timeout().is_none());
    }

    #[test]
    fn test_context_with_task_timeout() {
        use std::time::Duration;

        let task_id = TaskId::new("task-1");
        let timeout = Duration::from_secs(30);

        let context = TaskContext::new("workflow-1", task_id).with_task_timeout(timeout);

        // Task timeout should be accessible
        assert_eq!(context.task_timeout(), Some(timeout));
    }

    #[test]
    fn test_context_task_timeout_accessor() {
        use std::time::Duration;

        let task_id = TaskId::new("task-1");
        let timeout = Duration::from_millis(5000);

        let context = TaskContext::new("workflow-1", task_id).with_task_timeout(timeout);

        // The public field and the accessor must agree
        assert_eq!(context.task_timeout, Some(timeout));
        assert_eq!(context.task_timeout(), Some(Duration::from_millis(5000)));
    }

    // Mock task for testing the WorkflowTask trait. It deliberately does not
    // override `dependencies` or `compensation`, so the trait defaults are exercised.
    struct MockTask {
        id: TaskId,
        name: String,
    }

    #[async_trait]
    impl WorkflowTask for MockTask {
        async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
            Ok(TaskResult::Success)
        }

        fn id(&self) -> TaskId {
            self.id.clone()
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_task_trait() {
        let task = MockTask {
            id: TaskId::new("task-1"),
            name: "Test Task".to_string(),
        };

        assert_eq!(task.id(), TaskId::new("task-1"));
        assert_eq!(task.name(), "Test Task");
        // Default trait implementations: no dependencies, no compensation
        assert!(task.dependencies().is_empty());
        assert!(task.compensation().is_none());

        let context = TaskContext::new("workflow-1", task.id());
        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    // Nothing here is awaited, so a plain #[test] suffices (no async runtime needed).
    #[test]
    fn test_task_with_dependencies() {
        struct TaskWithDeps {
            id: TaskId,
            name: String,
            deps: Vec<TaskId>,
        }

        #[async_trait]
        impl WorkflowTask for TaskWithDeps {
            async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
                Ok(TaskResult::Success)
            }

            fn id(&self) -> TaskId {
                self.id.clone()
            }

            fn name(&self) -> &str {
                &self.name
            }

            fn dependencies(&self) -> Vec<TaskId> {
                self.deps.clone()
            }
        }

        let task = TaskWithDeps {
            id: TaskId::new("task-b"),
            name: "Task B".to_string(),
            deps: vec![TaskId::new("task-a")],
        };

        assert_eq!(task.dependencies(), vec![TaskId::new("task-a")]);
    }

    // Nothing here is awaited, so a plain #[test] suffices (no async runtime needed).
    #[test]
    fn test_task_compensation_integration() {
        // Test task with compensation
        struct TaskWithCompensation {
            id: TaskId,
            name: String,
            compensation: Option<CompensationAction>,
        }

        #[async_trait]
        impl WorkflowTask for TaskWithCompensation {
            async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
                Ok(TaskResult::Success)
            }

            fn id(&self) -> TaskId {
                self.id.clone()
            }

            fn name(&self) -> &str {
                &self.name
            }

            fn compensation(&self) -> Option<CompensationAction> {
                self.compensation.clone()
            }
        }

        // Create task with skip compensation
        let task = TaskWithCompensation {
            id: TaskId::new("task-1"),
            name: "Test Task".to_string(),
            compensation: Some(CompensationAction::skip("No action needed")),
        };

        // Verify compensation is accessible
        let comp = task.compensation().expect("compensation was configured");
        assert_eq!(comp.action_type, CompensationType::Skip);
        assert_eq!(comp.description, "No action needed");

        // Create task with no compensation
        let task_no_comp = TaskWithCompensation {
            id: TaskId::new("task-2"),
            name: "Task Without Compensation".to_string(),
            compensation: None,
        };

        assert!(task_no_comp.compensation().is_none());
    }

    #[test]
    fn test_compensation_action_to_tool_compensation() {
        use crate::workflow::rollback::ToolCompensation;

        // Test conversion from CompensationAction to ToolCompensation
        let skip_action = CompensationAction::skip("Skip this");
        let tool_comp: ToolCompensation = skip_action.into();
        assert_eq!(tool_comp.description, "Skip this");

        let retry_action = CompensationAction::retry("Retry later");
        let tool_comp: ToolCompensation = retry_action.into();
        assert_eq!(tool_comp.description, "Retry later");

        let undo_action = CompensationAction::undo("Delete file");
        let tool_comp: ToolCompensation = undo_action.into();
        // Undo becomes skip with note about no undo function
        assert!(tool_comp.description.contains("no undo function available"));
    }
}