1use async_trait::async_trait;
8use forge_core::Forge;
9use serde::{Deserialize, Serialize};
10use std::fmt;
11use std::sync::Arc;
12
13use crate::workflow::tools::ToolRegistry;
15
16use crate::audit::AuditLog;
18
19#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
24pub struct TaskId(String);
25
26impl TaskId {
27 pub fn new(id: impl Into<String>) -> Self {
29 Self(id.into())
30 }
31
32 pub fn as_str(&self) -> &str {
34 &self.0
35 }
36
37 pub fn into_inner(self) -> String {
39 self.0
40 }
41}
42
43impl fmt::Display for TaskId {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 write!(f, "{}", self.0)
46 }
47}
48
49impl From<String> for TaskId {
50 fn from(s: String) -> Self {
51 Self(s)
52 }
53}
54
55impl From<&str> for TaskId {
56 fn from(s: &str) -> Self {
57 Self(s.to_string())
58 }
59}
60
61#[derive(Clone, Debug, Copy, PartialEq, Eq, Serialize, Deserialize)]
67pub enum Dependency {
68 Hard,
70 Soft,
72}
73
74#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
79pub enum TaskResult {
80 Success,
82 Failed(String),
84 Skipped,
86 WithCompensation {
88 result: Box<TaskResult>,
90 compensation: CompensationAction,
92 },
93}
94
95#[derive(Clone)]
100pub struct TaskContext {
101 pub forge: Option<Forge>,
103 pub workflow_id: String,
105 pub task_id: TaskId,
107 cancellation_token: Option<crate::workflow::cancellation::CancellationToken>,
109 pub task_timeout: Option<std::time::Duration>,
111 pub tool_registry: Option<Arc<ToolRegistry>>,
113 pub audit_log: Option<AuditLog>,
115}
116
117impl TaskContext {
118 pub fn new(workflow_id: impl Into<String>, task_id: TaskId) -> Self {
120 Self {
121 forge: None,
122 workflow_id: workflow_id.into(),
123 task_id,
124 cancellation_token: None,
125 task_timeout: None,
126 tool_registry: None,
127 audit_log: None,
128 }
129 }
130
131 pub fn with_forge(mut self, forge: Forge) -> Self {
133 self.forge = Some(forge);
134 self
135 }
136
137 pub fn with_cancellation_token(
157 mut self,
158 token: crate::workflow::cancellation::CancellationToken,
159 ) -> Self {
160 self.cancellation_token = Some(token);
161 self
162 }
163
164 pub fn cancellation_token(&self) -> Option<&crate::workflow::cancellation::CancellationToken> {
181 self.cancellation_token.as_ref()
182 }
183
184 pub fn with_task_timeout(mut self, timeout: std::time::Duration) -> Self {
203 self.task_timeout = Some(timeout);
204 self
205 }
206
207 pub fn with_tool_registry(mut self, registry: Arc<ToolRegistry>) -> Self {
228 self.tool_registry = Some(registry);
229 self
230 }
231
232 pub fn tool_registry(&self) -> Option<&Arc<ToolRegistry>> {
247 self.tool_registry.as_ref()
248 }
249
250 pub fn with_audit_log(mut self, audit_log: AuditLog) -> Self {
270 self.audit_log = Some(audit_log);
271 self
272 }
273
274 pub fn audit_log_mut(&mut self) -> Option<&mut AuditLog> {
289 self.audit_log.as_mut()
290 }
291
292 pub fn task_timeout(&self) -> Option<std::time::Duration> {
302 self.task_timeout
303 }
304}
305
306#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
312pub struct CompensationAction {
313 pub action_type: CompensationType,
315 pub description: String,
317}
318
319impl CompensationAction {
320 pub fn new(action_type: CompensationType, description: impl Into<String>) -> Self {
322 Self {
323 action_type,
324 description: description.into(),
325 }
326 }
327
328 pub fn skip(description: impl Into<String>) -> Self {
330 Self::new(CompensationType::Skip, description)
331 }
332
333 pub fn retry(description: impl Into<String>) -> Self {
335 Self::new(CompensationType::Retry, description)
336 }
337
338 pub fn undo(description: impl Into<String>) -> Self {
340 Self::new(CompensationType::UndoFunction, description)
341 }
342}
343
344#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
346pub enum CompensationType {
347 UndoFunction,
349 Skip,
351 Retry,
353}
354
355#[derive(thiserror::Error, Debug)]
357pub enum TaskError {
358 #[error("Task execution failed: {0}")]
360 ExecutionFailed(String),
361
362 #[error("Dependency {dependency} failed: {reason}")]
364 DependencyFailed {
365 dependency: String,
366 reason: String,
367 },
368
369 #[error("Task skipped: {0}")]
371 Skipped(String),
372
373 #[error("Task timeout: {0}")]
375 Timeout(String),
376
377 #[error("I/O error: {0}")]
379 Io(#[from] std::io::Error),
380
381 #[error("Task error: {0}")]
383 Other(#[from] anyhow::Error),
384}
385
386#[async_trait]
392pub trait WorkflowTask: Send + Sync {
393 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError>;
403
404 fn id(&self) -> TaskId;
406
407 fn name(&self) -> &str;
409
410 fn dependencies(&self) -> Vec<TaskId> {
414 Vec::new()
415 }
416
417 fn compensation(&self) -> Option<CompensationAction> {
430 None
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::cancellation::CancellationTokenSource;
    use std::collections::HashSet;
    use std::time::Duration;

    #[test]
    fn test_task_id_equality() {
        let id1 = TaskId::new("task-1");
        let id2 = TaskId::new("task-1");
        let id3 = TaskId::new("task-2");

        assert_eq!(id1, id2);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_task_id_hash() {
        let mut set = HashSet::new();

        set.insert(TaskId::new("task-1"));
        set.insert(TaskId::new("task-1"));
        set.insert(TaskId::new("task-2"));

        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_task_id_from_string() {
        let id1: TaskId = "task-1".into();
        let id2 = TaskId::from(String::from("task-1"));

        assert_eq!(id1, id2);
        assert_eq!(id1.as_str(), "task-1");
        assert_eq!(id2.into_inner(), "task-1");
    }

    #[test]
    fn test_task_id_display() {
        let id = TaskId::new("task-1");
        assert_eq!(id.to_string(), "task-1");
    }

    #[test]
    fn test_dependency_variants() {
        assert_ne!(Dependency::Hard, Dependency::Soft);
    }

    #[test]
    fn test_task_result_variants() {
        let success = TaskResult::Success;
        let failed = TaskResult::Failed("error".to_string());
        let skipped = TaskResult::Skipped;

        assert_eq!(success, TaskResult::Success);
        assert_eq!(failed, TaskResult::Failed("error".to_string()));
        assert_eq!(skipped, TaskResult::Skipped);
    }

    #[test]
    fn test_task_context_creation() {
        let task_id = TaskId::new("task-1");
        let context = TaskContext::new("workflow-1", task_id.clone());

        assert_eq!(context.workflow_id, "workflow-1");
        assert_eq!(context.task_id, task_id);
        // Every optional service defaults to "not attached".
        assert!(context.forge.is_none());
        assert!(context.cancellation_token().is_none());
        assert!(context.task_timeout().is_none());
        assert!(context.tool_registry().is_none());
        assert!(context.audit_log.is_none());
    }

    #[test]
    fn test_context_without_cancellation_token() {
        let context = TaskContext::new("workflow-1", TaskId::new("task-1"));

        assert!(context.cancellation_token().is_none());
    }

    #[test]
    fn test_context_with_cancellation_token() {
        let source = CancellationTokenSource::new();

        let context = TaskContext::new("workflow-1", TaskId::new("task-1"))
            .with_cancellation_token(source.token());

        let token = context
            .cancellation_token()
            .expect("token was attached via with_cancellation_token");
        assert!(!token.is_cancelled());

        source.cancel();

        // The stored token must observe cancellation of its source.
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_context_builder_pattern() {
        let source = CancellationTokenSource::new();

        let context = TaskContext::new("workflow-1", TaskId::new("task-1"))
            .with_cancellation_token(source.token());

        assert!(context.cancellation_token().is_some());
        assert_eq!(context.workflow_id, "workflow-1");
    }

    #[test]
    fn test_context_without_task_timeout() {
        let context = TaskContext::new("workflow-1", TaskId::new("task-1"));

        assert!(context.task_timeout().is_none());
    }

    #[test]
    fn test_context_with_task_timeout() {
        let timeout = Duration::from_secs(30);

        let context =
            TaskContext::new("workflow-1", TaskId::new("task-1")).with_task_timeout(timeout);

        assert_eq!(context.task_timeout(), Some(timeout));
    }

    #[test]
    fn test_context_task_timeout_accessor() {
        let timeout = Duration::from_millis(5000);

        let context =
            TaskContext::new("workflow-1", TaskId::new("task-1")).with_task_timeout(timeout);

        // Field and accessor must agree.
        assert_eq!(context.task_timeout, Some(timeout));
        assert_eq!(context.task_timeout(), Some(Duration::from_millis(5000)));
    }

    /// Minimal task that relies on every default trait method.
    struct MockTask {
        id: TaskId,
        name: String,
    }

    #[async_trait]
    impl WorkflowTask for MockTask {
        async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
            Ok(TaskResult::Success)
        }

        fn id(&self) -> TaskId {
            self.id.clone()
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_task_trait() {
        let task = MockTask {
            id: TaskId::new("task-1"),
            name: "Test Task".to_string(),
        };

        assert_eq!(task.id(), TaskId::new("task-1"));
        assert_eq!(task.name(), "Test Task");
        // Exercises the trait's default implementations.
        assert!(task.dependencies().is_empty());
        assert!(task.compensation().is_none());

        let context = TaskContext::new("workflow-1", task.id());
        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[test]
    fn test_task_with_dependencies() {
        struct TaskWithDeps {
            id: TaskId,
            name: String,
            deps: Vec<TaskId>,
        }

        #[async_trait]
        impl WorkflowTask for TaskWithDeps {
            async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
                Ok(TaskResult::Success)
            }

            fn id(&self) -> TaskId {
                self.id.clone()
            }

            fn name(&self) -> &str {
                &self.name
            }

            fn dependencies(&self) -> Vec<TaskId> {
                self.deps.clone()
            }
        }

        let task = TaskWithDeps {
            id: TaskId::new("task-b"),
            name: "Task B".to_string(),
            deps: vec![TaskId::new("task-a")],
        };

        assert_eq!(task.dependencies(), vec![TaskId::new("task-a")]);
    }

    #[test]
    fn test_task_compensation_integration() {
        struct TaskWithCompensation {
            id: TaskId,
            name: String,
            compensation: Option<CompensationAction>,
        }

        #[async_trait]
        impl WorkflowTask for TaskWithCompensation {
            async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
                Ok(TaskResult::Success)
            }

            fn id(&self) -> TaskId {
                self.id.clone()
            }

            fn name(&self) -> &str {
                &self.name
            }

            fn compensation(&self) -> Option<CompensationAction> {
                self.compensation.clone()
            }
        }

        let task = TaskWithCompensation {
            id: TaskId::new("task-1"),
            name: "Test Task".to_string(),
            compensation: Some(CompensationAction::skip("No action needed")),
        };

        let comp = task.compensation().expect("compensation was configured");
        assert_eq!(comp.action_type, CompensationType::Skip);
        assert_eq!(comp.description, "No action needed");

        let task_no_comp = TaskWithCompensation {
            id: TaskId::new("task-2"),
            name: "Task Without Compensation".to_string(),
            compensation: None,
        };

        assert!(task_no_comp.compensation().is_none());
    }

    #[test]
    fn test_compensation_action_to_tool_compensation() {
        use crate::workflow::rollback::ToolCompensation;

        let tool_comp: ToolCompensation = CompensationAction::skip("Skip this").into();
        assert_eq!(tool_comp.description, "Skip this");

        let tool_comp: ToolCompensation = CompensationAction::retry("Retry later").into();
        assert_eq!(tool_comp.description, "Retry later");

        // Undo actions carry no executable undo function through this conversion.
        let tool_comp: ToolCompensation = CompensationAction::undo("Delete file").into();
        assert!(tool_comp.description.contains("no undo function available"));
    }
}