Skip to main content

forge_agent/workflow/
tasks.rs

1//! Built-in task implementations for common workflow operations.
2//!
3//! Provides pre-built task types for graph queries, agent loops, shell commands,
4//! and simple function wrapping.
5
6use crate::workflow::task::{CompensationAction, TaskContext, TaskError, TaskId, TaskResult, WorkflowTask};
7use crate::workflow::tools::{FallbackHandler, FallbackResult, ToolError, ToolInvocation, ToolRegistry};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::future::Future;
15use std::time::Duration;
16use std::process::Command;
17
18/// Task that wraps an async function for easy workflow definition.
19///
20/// Useful for simple workflows without custom task types.
21///
22/// # Example
23///
24/// ```ignore
25/// use forge_agent::workflow::tasks::FunctionTask;
26/// use forge_agent::workflow::TaskId;
27///
28/// let task = FunctionTask::new(
29///     TaskId::new("my_task"),
30///     "My Task".to_string(),
31///     |ctx| async {
32///         // Do work here
33///         Ok(TaskResult::Success)
34///     }
35/// );
36/// ```
37pub struct FunctionTask {
38    id: TaskId,
39    name: String,
40    f: Box<dyn Fn(&TaskContext) -> Pin<Box<dyn Future<Output = Result<TaskResult, TaskError>> + Send>> + Send + Sync>,
41}
42
43impl FunctionTask {
44    /// Creates a new FunctionTask with the given ID, name, and async function.
45    pub fn new<F, Fut>(id: TaskId, name: String, f: F) -> Self
46    where
47        F: Fn(&TaskContext) -> Fut + Send + Sync + 'static,
48        Fut: Future<Output = Result<TaskResult, TaskError>> + Send + 'static,
49    {
50        Self {
51            id,
52            name,
53            f: Box::new(move |ctx| Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = Result<TaskResult, TaskError>> + Send>>),
54        }
55    }
56}
57
58#[async_trait]
59impl WorkflowTask for FunctionTask {
60    async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
61        (self.f)(context).await
62    }
63
64    fn id(&self) -> TaskId {
65        self.id.clone()
66    }
67
68    fn name(&self) -> &str {
69        &self.name
70    }
71}
72
73/// Types of graph queries supported by GraphQueryTask.
74#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
75pub enum GraphQueryType {
76    /// Find a symbol by name
77    FindSymbol,
78    /// Find references to a symbol
79    References,
80    /// Analyze impact of changes to a symbol
81    ImpactAnalysis,
82}
83
84/// Task that executes graph queries using the Forge SDK.
85///
86/// Queries the code graph for symbols, references, or impact analysis.
87pub struct GraphQueryTask {
88    id: TaskId,
89    name: String,
90    query_type: GraphQueryType,
91    target: String,
92}
93
94impl GraphQueryTask {
95    /// Creates a new GraphQueryTask for finding a symbol.
96    pub fn find_symbol(target: impl Into<String>) -> Self {
97        Self::new(GraphQueryType::FindSymbol, target)
98    }
99
100    /// Creates a new GraphQueryTask for finding references.
101    pub fn references(target: impl Into<String>) -> Self {
102        Self::new(GraphQueryType::References, target)
103    }
104
105    /// Creates a new GraphQueryTask for impact analysis.
106    pub fn impact_analysis(target: impl Into<String>) -> Self {
107        Self::new(GraphQueryType::ImpactAnalysis, target)
108    }
109
110    fn new(query_type: GraphQueryType, target: impl Into<String>) -> Self {
111        let target_str = target.into();
112        Self {
113            id: TaskId::new(format!("graph_query_{:?}", query_type)),
114            name: format!("Graph Query: {:?}", query_type),
115            query_type,
116            target: target_str,
117        }
118    }
119
120    /// Creates a GraphQueryTask with a custom ID.
121    pub fn with_id(id: TaskId, query_type: GraphQueryType, target: impl Into<String>) -> Self {
122        Self {
123            id,
124            name: format!("Graph Query: {:?}", query_type),
125            query_type,
126            target: target.into(),
127        }
128    }
129}
130
131#[async_trait]
132impl WorkflowTask for GraphQueryTask {
133    async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
134        // Phase 8 stub - all graph queries return success
135        // Actual Forge SDK integration will be in Phase 10
136        match self.query_type {
137            GraphQueryType::FindSymbol => {
138                Ok(TaskResult::Success)
139            }
140            GraphQueryType::References => {
141                Ok(TaskResult::Success)
142            }
143            GraphQueryType::ImpactAnalysis => {
144                Ok(TaskResult::Success)
145            }
146        }
147    }
148
149    fn id(&self) -> TaskId {
150        self.id.clone()
151    }
152
153    fn name(&self) -> &str {
154        &self.name
155    }
156
157    fn compensation(&self) -> Option<CompensationAction> {
158        // Graph queries are read-only operations with no side effects
159        Some(CompensationAction::skip("Read-only graph query - no undo needed"))
160    }
161}
162
163/// Task that executes an agent loop for AI-driven operations.
164///
165/// Wraps the AgentLoop as a workflow task for multi-step AI operations.
166pub struct AgentLoopTask {
167    id: TaskId,
168    name: String,
169    query: String,
170}
171
172impl AgentLoopTask {
173    /// Creates a new AgentLoopTask with the given query.
174    pub fn new(id: TaskId, name: String, query: impl Into<String>) -> Self {
175        Self {
176            id,
177            name,
178            query: query.into(),
179        }
180    }
181
182    /// Gets the query for this task.
183    pub fn query(&self) -> &str {
184        &self.query
185    }
186}
187
188#[async_trait]
189impl WorkflowTask for AgentLoopTask {
190    async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
191        // Stub implementation - actual AgentLoop integration in Phase 10
192        // For now, just return success to indicate the task structure is valid
193        Ok(TaskResult::Success)
194    }
195
196    fn id(&self) -> TaskId {
197        self.id.clone()
198    }
199
200    fn name(&self) -> &str {
201        &self.name
202    }
203
204    fn compensation(&self) -> Option<CompensationAction> {
205        // AgentLoop is read-only in v0.4 - no compensation needed
206        // Future versions may implement undo for mutations
207        Some(CompensationAction::skip("Read-only agent loop - no undo needed in v0.4"))
208    }
209}
210
/// Configuration for shell command execution.
///
/// Holds the program, its arguments, and the optional working directory,
/// environment variables, and timeout used by shell command tasks.
/// Values are set through consuming builder methods:
///
/// ```ignore
/// let config = ShellCommandConfig::new("cargo")
///     .args(vec!["build".to_string()])
///     .env("RUST_LOG", "debug");
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ShellCommandConfig {
    /// The command to execute
    pub command: String,
    /// Command arguments
    pub args: Vec<String>,
    /// Optional working directory for command execution
    pub working_dir: Option<PathBuf>,
    /// Environment variables to set for the command
    pub env: HashMap<String, String>,
    /// Optional timeout for command execution
    pub timeout: Option<Duration>,
}

impl ShellCommandConfig {
    /// Creates a new ShellCommandConfig with the given command.
    ///
    /// Arguments and environment start empty; working directory and timeout
    /// start unset.
    ///
    /// # Arguments
    ///
    /// * `command` - The command to execute (e.g., "echo", "ls", "cargo")
    pub fn new(command: impl Into<String>) -> Self {
        Self {
            command: command.into(),
            args: Vec::new(),
            working_dir: None,
            env: HashMap::new(),
            timeout: None,
        }
    }

    /// Replaces the command arguments with `args`.
    ///
    /// # Returns
    ///
    /// Self for builder pattern chaining
    #[must_use]
    pub fn args(mut self, args: Vec<String>) -> Self {
        self.args = args;
        self
    }

    /// Sets the working directory for command execution.
    ///
    /// # Returns
    ///
    /// Self for builder pattern chaining
    #[must_use]
    pub fn working_dir(mut self, path: impl Into<PathBuf>) -> Self {
        self.working_dir = Some(path.into());
        self
    }

    /// Adds an environment variable for the command.
    ///
    /// Setting the same `key` again overwrites the earlier value.
    ///
    /// # Returns
    ///
    /// Self for builder pattern chaining
    #[must_use]
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.env.insert(key.into(), value.into());
        self
    }

    /// Sets the timeout for command execution.
    ///
    /// # Returns
    ///
    /// Self for builder pattern chaining
    #[must_use]
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.timeout = Some(duration);
        self
    }
}
302
303/// Task that executes shell commands using tokio::process.
304///
305/// Executes external shell commands with configurable working directory,
306/// environment variables, and timeout settings. Supports process
307/// compensation for rollback operations.
308pub struct ShellCommandTask {
309    id: TaskId,
310    name: String,
311    config: ShellCommandConfig,
312    /// Last spawned process ID (for compensation)
313    last_pid: Arc<std::sync::Mutex<Option<u32>>>,
314}
315
316impl ShellCommandTask {
317    /// Creates a new ShellCommandTask with the given command.
318    ///
319    /// # Arguments
320    ///
321    /// * `id` - Task identifier
322    /// * `name` - Human-readable task name
323    /// * `command` - Command to execute (e.g., "echo", "ls", "cargo")
324    pub fn new(id: TaskId, name: String, command: impl Into<String>) -> Self {
325        Self {
326            id,
327            name,
328            config: ShellCommandConfig::new(command),
329            last_pid: Arc::new(std::sync::Mutex::new(None)),
330        }
331    }
332
333    /// Creates a new ShellCommandTask with a ShellCommandConfig.
334    ///
335    /// # Arguments
336    ///
337    /// * `id` - Task identifier
338    /// * `name` - Human-readable task name
339    /// * `config` - Shell command configuration
340    pub fn with_config(id: TaskId, name: String, config: ShellCommandConfig) -> Self {
341        Self {
342            id,
343            name,
344            config,
345            last_pid: Arc::new(std::sync::Mutex::new(None)),
346        }
347    }
348
349    /// Sets the arguments for the shell command.
350    ///
351    /// # Deprecated
352    ///
353    /// Use `with_config()` and `ShellCommandConfig::args()` instead.
354    #[deprecated(since = "0.4.0", note = "Use with_config() instead for better configurability")]
355    pub fn with_args(mut self, args: Vec<String>) -> Self {
356        self.config.args = args;
357        self
358    }
359
360    /// Gets the command for this task.
361    pub fn command(&self) -> &str {
362        &self.config.command
363    }
364
365    /// Gets the arguments for this task.
366    pub fn args(&self) -> &[String] {
367        &self.config.args
368    }
369
370    /// Gets the configuration for this task.
371    pub fn config(&self) -> &ShellCommandConfig {
372        &self.config
373    }
374}
375
376#[async_trait]
377impl WorkflowTask for ShellCommandTask {
378    async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
379        // Build the tokio process command
380        let mut cmd = tokio::process::Command::new(&self.config.command);
381
382        // Apply arguments
383        cmd.args(&self.config.args);
384
385        // Apply working directory if configured
386        if let Some(ref working_dir) = self.config.working_dir {
387            cmd.current_dir(working_dir);
388        }
389
390        // Apply environment variables
391        for (key, value) in &self.config.env {
392            cmd.env(key, value);
393        }
394
395        // Spawn the process
396        let child = cmd.spawn().map_err(|e| TaskError::Io(e))?;
397
398        // Store the process ID for compensation
399        if let Some(pid) = child.id() {
400            let mut last_pid = self.last_pid.lock().unwrap();
401            *last_pid = Some(pid);
402        }
403
404        // Wait for output with optional timeout
405        let output = if let Some(timeout) = self.config.timeout {
406            tokio::time::timeout(timeout, child.wait_with_output())
407                .await
408                .map_err(|_| TaskError::Timeout(format!("Command timed out after {:?}", timeout)))?
409                .map_err(TaskError::Io)?
410        } else {
411            child.wait_with_output().await.map_err(TaskError::Io)?
412        };
413
414        // Check exit status
415        if output.status.success() {
416            Ok(TaskResult::Success)
417        } else {
418            let exit_code = output.status.code().unwrap_or(-1);
419            let stderr = String::from_utf8_lossy(&output.stderr);
420            let error_msg = if !stderr.is_empty() {
421                format!("exit code: {}, stderr: {}", exit_code, stderr)
422            } else {
423                format!("exit code: {}", exit_code)
424            };
425            Ok(TaskResult::Failed(error_msg))
426        }
427    }
428
429    fn id(&self) -> TaskId {
430        self.id.clone()
431    }
432
433    fn name(&self) -> &str {
434        &self.name
435    }
436
437    fn compensation(&self) -> Option<CompensationAction> {
438        // Check if we spawned a process
439        let pid_guard = self.last_pid.lock().unwrap();
440        if let Some(pid) = *pid_guard {
441            // Return undo compensation for process termination
442            Some(CompensationAction::undo(format!(
443                "Terminate spawned process: {}",
444                pid
445            )))
446        } else {
447            // No process was spawned
448            Some(CompensationAction::skip("No process was spawned"))
449        }
450    }
451}
452
453/// Task that edits a file (stub for Phase 11).
454///
455/// Demonstrates the Saga compensation pattern with undo functionality.
456/// In Phase 11, this will be implemented with actual file editing.
457pub struct FileEditTask {
458    id: TaskId,
459    name: String,
460    file_path: PathBuf,
461    original_content: String,
462    new_content: String,
463}
464
465impl FileEditTask {
466    /// Creates a new FileEditTask.
467    ///
468    /// # Arguments
469    ///
470    /// * `id` - Task identifier
471    /// * `name` - Human-readable task name
472    /// * `file_path` - Path to the file to edit
473    /// * `original_content` - Original content (for rollback)
474    /// * `new_content` - New content to write
475    pub fn new(
476        id: TaskId,
477        name: String,
478        file_path: PathBuf,
479        original_content: String,
480        new_content: String,
481    ) -> Self {
482        Self {
483            id,
484            name,
485            file_path,
486            original_content,
487            new_content,
488        }
489    }
490
491    /// Gets the file path.
492    pub fn file_path(&self) -> &PathBuf {
493        &self.file_path
494    }
495
496    /// Gets the original content.
497    pub fn original_content(&self) -> &str {
498        &self.original_content
499    }
500
501    /// Gets the new content.
502    pub fn new_content(&self) -> &str {
503        &self.new_content
504    }
505}
506
507#[async_trait]
508impl WorkflowTask for FileEditTask {
509    async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
510        // Phase 8 stub - actual file editing will be implemented in Phase 11
511        // For now, return Success to indicate the task structure is valid
512        Ok(TaskResult::Success)
513    }
514
515    fn id(&self) -> TaskId {
516        self.id.clone()
517    }
518
519    fn name(&self) -> &str {
520        &self.name
521    }
522
523    fn compensation(&self) -> Option<CompensationAction> {
524        // Return undo compensation that restores original content
525        // This demonstrates the Saga compensation pattern
526        Some(CompensationAction::undo(format!(
527            "Restore original content of {}",
528            self.file_path.display()
529        )))
530    }
531}
532
533/// Task that invokes a registered tool from the ToolRegistry.
534///
535/// ToolTask executes external tools (magellan, cargo, splice, etc.) with
536/// configurable fallback handlers for error recovery.
537///
538/// # Example
539///
540/// ```ignore
541/// use forge_agent::workflow::tasks::ToolTask;
542/// use forge_agent::workflow::tools::ToolInvocation;
543/// use forge_agent::workflow::TaskId;
544///
545/// let task = ToolTask::new(
546///     TaskId::new("tool_task"),
547///     "Magellan Query".to_string(),
548///     "magellan"
549/// )
550/// .args(vec!["find".to_string(), "--name".to_string(), "symbol".to_string()]);
551/// ```
552pub struct ToolTask {
553    /// Task identifier
554    id: TaskId,
555    /// Human-readable task name
556    name: String,
557    /// Tool invocation specification
558    invocation: ToolInvocation,
559    /// Optional fallback handler for error recovery
560    fallback: Option<Arc<dyn FallbackHandler>>,
561}
562
563impl ToolTask {
564    /// Creates a new ToolTask.
565    ///
566    /// # Arguments
567    ///
568    /// * `id` - Task identifier
569    /// * `name` - Human-readable task name
570    /// * `tool_name` - Name of the registered tool to invoke
571    ///
572    /// # Example
573    ///
574    /// ```
575    /// use forge_agent::workflow::tasks::ToolTask;
576    /// use forge_agent::workflow::TaskId;
577    ///
578    /// let task = ToolTask::new(
579    ///     TaskId::new("tool_task"),
580    ///     "Query Magellan".to_string(),
581    ///     "magellan"
582    /// );
583    /// ```
584    pub fn new(id: TaskId, name: String, tool_name: impl Into<String>) -> Self {
585        Self {
586            id,
587            name,
588            invocation: ToolInvocation::new(tool_name),
589            fallback: None,
590        }
591    }
592
593    /// Sets the arguments for the tool invocation.
594    ///
595    /// # Arguments
596    ///
597    /// * `args` - Vector of argument strings
598    ///
599    /// # Returns
600    ///
601    /// Self for builder pattern chaining
602    ///
603    /// # Example
604    ///
605    /// ```
606    /// use forge_agent::workflow::tasks::ToolTask;
607    /// use forge_agent::workflow::TaskId;
608    ///
609    /// let task = ToolTask::new(
610    ///     TaskId::new("tool_task"),
611    ///     "Query Magellan".to_string(),
612    ///     "magellan"
613    /// )
614    /// .args(vec!["find".to_string(), "--name".to_string(), "symbol".to_string()]);
615    /// ```
616    pub fn args(mut self, args: Vec<String>) -> Self {
617        self.invocation = self.invocation.args(args);
618        self
619    }
620
621    /// Sets the working directory for the tool invocation.
622    ///
623    /// # Arguments
624    ///
625    /// * `dir` - Working directory path
626    ///
627    /// # Returns
628    ///
629    /// Self for builder pattern chaining
630    ///
631    /// # Example
632    ///
633    /// ```
634    /// use forge_agent::workflow::tasks::ToolTask;
635    /// use forge_agent::workflow::TaskId;
636    ///
637    /// let task = ToolTask::new(
638    ///     TaskId::new("tool_task"),
639    ///     "Run cargo".to_string(),
640    ///     "cargo"
641    /// )
642    /// .working_dir("/home/user/project");
643    /// ```
644    pub fn working_dir(mut self, dir: impl Into<PathBuf>) -> Self {
645        self.invocation = self.invocation.working_dir(dir);
646        self
647    }
648
649    /// Adds an environment variable to the tool invocation.
650    ///
651    /// # Arguments
652    ///
653    /// * `key` - Environment variable name
654    /// * `value` - Environment variable value
655    ///
656    /// # Returns
657    ///
658    /// Self for builder pattern chaining
659    ///
660    /// # Example
661    ///
662    /// ```
663    /// use forge_agent::workflow::tasks::ToolTask;
664    /// use forge_agent::workflow::TaskId;
665    ///
666    /// let task = ToolTask::new(
667    ///     TaskId::new("tool_task"),
668    ///     "Run cargo".to_string(),
669    ///     "cargo"
670    /// )
671    /// .env("RUST_LOG", "debug");
672    /// ```
673    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
674        self.invocation = self.invocation.env(key, value);
675        self
676    }
677
678    /// Sets the fallback handler for error recovery.
679    ///
680    /// # Arguments
681    ///
682    /// * `handler` - Fallback handler to use on tool failure
683    ///
684    /// # Returns
685    ///
686    /// Self for builder pattern chaining
687    ///
688    /// # Example
689    ///
690    /// ```
691    /// use forge_agent::workflow::tasks::ToolTask;
692    /// use forge_agent::workflow::tools::RetryFallback;
693    /// use forge_agent::workflow::TaskId;
694    ///
695    /// let task = ToolTask::new(
696    ///     TaskId::new("tool_task"),
697    ///     "Query Magellan".to_string(),
698    ///     "magellan"
699    /// )
700    /// .with_fallback(Box::new(RetryFallback::new(3, 100)));
701    /// ```
702    pub fn with_fallback(mut self, handler: Box<dyn FallbackHandler>) -> Self {
703        self.fallback = Some(Arc::from(handler));
704        self
705    }
706
707    /// Gets the tool name for this task.
708    pub fn tool_name(&self) -> &str {
709        &self.invocation.tool_name
710    }
711
712    /// Gets the invocation for this task.
713    pub fn invocation(&self) -> &ToolInvocation {
714        &self.invocation
715    }
716
717    /// Records a fallback activation event to the audit log.
718    async fn record_fallback_event(
719        context: &TaskContext,
720        tool_name: &str,
721        error: &ToolError,
722        fallback_action: &str,
723    ) {
724        use crate::audit::AuditEvent;
725        use chrono::Utc;
726
727        let event = AuditEvent::WorkflowToolFallback {
728            timestamp: Utc::now(),
729            workflow_id: context.workflow_id.clone(),
730            task_id: context.task_id.as_str().to_string(),
731            tool_name: tool_name.to_string(),
732            error: error.to_string(),
733            fallback_action: fallback_action.to_string(),
734        };
735
736        // Note: We can't directly record from here because we don't have mutable access
737        // The executor will need to record fallback events based on task results
738        // For now, we'll just drop the event on the floor
739        // TODO: This is a limitation of the current design
740        eprintln!("Fallback event: {} -> {}", tool_name, fallback_action);
741    }
742}
743
744#[async_trait]
745impl WorkflowTask for ToolTask {
746    async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
747        // Get the tool registry from context
748        let registry = context.tool_registry
749            .as_ref()
750            .ok_or_else(|| TaskError::ExecutionFailed(
751                "ToolRegistry not available in TaskContext".to_string()
752            ))?;
753
754        // Try to invoke the tool
755        let invocation_result = registry.invoke(&self.invocation).await;
756
757        match invocation_result {
758            Ok(result) => {
759                if result.result.success {
760                    Ok(TaskResult::Success)
761                } else {
762                    Ok(TaskResult::Failed(result.result.stderr))
763                }
764            }
765            Err(error) => {
766                // Try fallback handler if available
767                if let Some(ref fallback) = self.fallback {
768                    match fallback.handle(&error, &self.invocation).await {
769                        FallbackResult::Retry(retry_invocation) => {
770                            // Record fallback activation
771                            Self::record_fallback_event(
772                                context,
773                                &self.invocation.tool_name,
774                                &error,
775                                "Retry"
776                            ).await;
777
778                            // Retry with modified invocation
779                            match registry.invoke(&retry_invocation).await {
780                                Ok(retry_result) => {
781                                    if retry_result.result.success {
782                                        Ok(TaskResult::Success)
783                                    } else {
784                                        Ok(TaskResult::Failed(retry_result.result.stderr))
785                                    }
786                                }
787                                Err(retry_error) => {
788                                    Ok(TaskResult::Failed(format!(
789                                        "Tool {} failed after retry: {}",
790                                        self.invocation.tool_name,
791                                        retry_error
792                                    )))
793                                }
794                            }
795                        }
796                        FallbackResult::Skip(result) => {
797                            // Record fallback activation
798                            Self::record_fallback_event(
799                                context,
800                                &self.invocation.tool_name,
801                                &error,
802                                "Skip"
803                            ).await;
804
805                            Ok(result)
806                        }
807                        FallbackResult::Fail(fail_error) => {
808                            // Record fallback activation
809                            Self::record_fallback_event(
810                                context,
811                                &self.invocation.tool_name,
812                                &error,
813                                &format!("Fail: {}", fail_error)
814                            ).await;
815
816                            Ok(TaskResult::Failed(format!(
817                                "Tool {} failed: {}",
818                                self.invocation.tool_name,
819                                fail_error
820                            )))
821                        }
822                    }
823                } else {
824                    // No fallback handler, return error
825                    Ok(TaskResult::Failed(format!(
826                        "Tool {} failed: {}",
827                        self.invocation.tool_name,
828                        error
829                    )))
830                }
831            }
832        }
833    }
834
835    fn id(&self) -> TaskId {
836        self.id.clone()
837    }
838
839    fn name(&self) -> &str {
840        &self.name
841    }
842
843    fn compensation(&self) -> Option<CompensationAction> {
844        // Tool side effects are handled by ProcessGuard in the tool registry
845        // Return skip compensation
846        Some(CompensationAction::skip(
847            "Tool side effects handled by ProcessGuard"
848        ))
849    }
850}
851
852#[cfg(test)]
853mod tests {
854    use super::*;
855
856    #[tokio::test]
857    async fn test_function_task() {
858        let task = FunctionTask::new(
859            TaskId::new("test_task"),
860            "Test Task".to_string(),
861            |_ctx| async { Ok(TaskResult::Success) },
862        );
863
864        let context = TaskContext::new("workflow_1", TaskId::new("test_task"));
865        let result = task.execute(&context).await.unwrap();
866
867        assert_eq!(result, TaskResult::Success);
868        assert_eq!(task.id(), TaskId::new("test_task"));
869        assert_eq!(task.name(), "Test Task");
870    }
871
872    #[tokio::test]
873    async fn test_agent_loop_task() {
874        let task = AgentLoopTask::new(
875            TaskId::new("agent_task"),
876            "Agent Task".to_string(),
877            "Find all functions",
878        );
879
880        assert_eq!(task.id(), TaskId::new("agent_task"));
881        assert_eq!(task.name(), "Agent Task");
882        assert_eq!(task.query(), "Find all functions");
883
884        let context = TaskContext::new("workflow_1", TaskId::new("agent_task"));
885        let result = task.execute(&context).await.unwrap();
886        assert_eq!(result, TaskResult::Success);
887    }
888
889    #[tokio::test]
890    async fn test_graph_query_task() {
891        let task = GraphQueryTask::find_symbol("process_data");
892
893        assert_eq!(task.query_type, GraphQueryType::FindSymbol);
894        assert_eq!(task.target, "process_data");
895
896        let context = TaskContext::new("workflow_1", task.id());
897        let result = task.execute(&context).await.unwrap();
898        assert_eq!(result, TaskResult::Success);
899    }
900
901    #[tokio::test]
902    async fn test_graph_query_references() {
903        let task = GraphQueryTask::references("my_function");
904
905        assert_eq!(task.query_type, GraphQueryType::References);
906        assert_eq!(task.target, "my_function");
907    }
908
909    #[tokio::test]
910    async fn test_graph_query_impact() {
911        let task = GraphQueryTask::impact_analysis("struct_name");
912
913        assert_eq!(task.query_type, GraphQueryType::ImpactAnalysis);
914        assert_eq!(task.target, "struct_name");
915    }
916
917    #[tokio::test]
918    async fn test_graph_query_with_custom_id() {
919        let task = GraphQueryTask::with_id(
920            TaskId::new("custom_id"),
921            GraphQueryType::FindSymbol,
922            "my_symbol",
923        );
924
925        assert_eq!(task.id(), TaskId::new("custom_id"));
926        assert_eq!(task.target, "my_symbol");
927    }
928
929    #[tokio::test]
930    async fn test_shell_command_task_stub() {
931        let task = ShellCommandTask::new(
932            TaskId::new("shell_task"),
933            "Shell Task".to_string(),
934            "echo",
935        ).with_args(vec!["hello".to_string(), "world".to_string()]);
936
937        assert_eq!(task.id(), TaskId::new("shell_task"));
938        assert_eq!(task.command(), "echo");
939        assert_eq!(task.args(), &["hello", "world"]);
940
941        let context = TaskContext::new("workflow_1", task.id());
942        let result = task.execute(&context).await.unwrap();
943        assert_eq!(result, TaskResult::Success);
944    }
945
946    #[tokio::test]
947    async fn test_shell_task_args_default() {
948        let task = ShellCommandTask::new(
949            TaskId::new("shell_task"),
950            "Shell Task".to_string(),
951            "ls",
952        );
953
954        assert_eq!(task.args().len(), 0);
955        assert!(task.args().is_empty());
956    }
957
958    #[tokio::test]
959    async fn test_shell_command_with_working_dir() {
960        // Create a temporary directory for testing
961        let temp_dir = std::env::temp_dir();
962        let test_file = temp_dir.join("test_shell_command.txt");
963
964        // Create a test file in the temp directory
965        std::fs::write(&test_file, "test content").unwrap();
966
967        // Create a task that lists files in the temp directory
968        let config = ShellCommandConfig::new("ls")
969            .args(vec![temp_dir.to_string_lossy().to_string()])
970            .working_dir(&temp_dir);
971
972        let task = ShellCommandTask::with_config(
973            TaskId::new("shell_task"),
974            "Shell Task".to_string(),
975            config,
976        );
977
978        let context = TaskContext::new("workflow_1", task.id());
979        let result = task.execute(&context).await.unwrap();
980
981        // Command should succeed
982        assert_eq!(result, TaskResult::Success);
983
984        // Clean up
985        std::fs::remove_file(&test_file).ok();
986    }
987
988    #[tokio::test]
989    async fn test_shell_command_with_env() {
990        // Create a task that reads an environment variable
991        let config = ShellCommandConfig::new("sh")
992            .args(vec!["-c".to_string(), "echo $TEST_VAR".to_string()])
993            .env("TEST_VAR", "test_value");
994
995        let task = ShellCommandTask::with_config(
996            TaskId::new("shell_task"),
997            "Shell Task".to_string(),
998            config,
999        );
1000
1001        let context = TaskContext::new("workflow_1", task.id());
1002        let result = task.execute(&context).await.unwrap();
1003
1004        // Command should succeed
1005        assert_eq!(result, TaskResult::Success);
1006    }
1007
1008    #[tokio::test]
1009    async fn test_shell_command_compensation() {
1010        // Create a task that spawns a long-running process
1011        // For testing, we use echo which exits immediately
1012        let task = ShellCommandTask::new(
1013            TaskId::new("shell_task"),
1014            "Shell Task".to_string(),
1015            "echo",
1016        ).with_args(vec!["test".to_string()]);
1017
1018        // Before execution, compensation should indicate no process spawned
1019        let compensation = task.compensation();
1020        assert!(compensation.is_some());
1021        assert_eq!(compensation.unwrap().action_type, crate::workflow::task::CompensationType::Skip);
1022
1023        // Execute the task
1024        let context = TaskContext::new("workflow_1", task.id());
1025        let result = task.execute(&context).await.unwrap();
1026        assert_eq!(result, TaskResult::Success);
1027
1028        // After execution, compensation should indicate process termination
1029        let compensation = task.compensation();
1030        assert!(compensation.is_some());
1031        assert_eq!(compensation.unwrap().action_type, crate::workflow::task::CompensationType::UndoFunction);
1032    }
1033
1034    #[tokio::test]
1035    async fn test_graph_query_compensation_skip() {
1036        let task = GraphQueryTask::find_symbol("my_function");
1037
1038        // Graph queries should have Skip compensation
1039        let compensation = task.compensation();
1040        assert!(compensation.is_some());
1041        assert_eq!(compensation.unwrap().action_type, crate::workflow::task::CompensationType::Skip);
1042    }
1043
1044    #[tokio::test]
1045    async fn test_agent_loop_compensation_skip() {
1046        let task = AgentLoopTask::new(
1047            TaskId::new("agent_task"),
1048            "Agent Task".to_string(),
1049            "Find all functions",
1050        );
1051
1052        // AgentLoop should have Skip compensation in v0.4
1053        let compensation = task.compensation();
1054        assert!(compensation.is_some());
1055        assert_eq!(compensation.unwrap().action_type, crate::workflow::task::CompensationType::Skip);
1056    }
1057
1058    #[tokio::test]
1059    async fn test_file_edit_compensation_undo() {
1060        let task = FileEditTask::new(
1061            TaskId::new("file_edit"),
1062            "Edit File".to_string(),
1063            PathBuf::from("/tmp/test.txt"),
1064            "original".to_string(),
1065            "new".to_string(),
1066        );
1067
1068        // FileEdit should have UndoFunction compensation
1069        let compensation = task.compensation();
1070        assert!(compensation.is_some());
1071        assert_eq!(compensation.unwrap().action_type, crate::workflow::task::CompensationType::UndoFunction);
1072    }
1073
1074    // ============== ToolTask Tests ==============
1075
1076    #[tokio::test]
1077    async fn test_tool_task_creation() {
1078        let task = ToolTask::new(
1079            TaskId::new("tool_task"),
1080            "Echo Tool".to_string(),
1081            "echo"
1082        );
1083
1084        assert_eq!(task.id(), TaskId::new("tool_task"));
1085        assert_eq!(task.name(), "Echo Tool");
1086        assert_eq!(task.tool_name(), "echo");
1087        assert!(task.invocation().args.is_empty());
1088        assert!(task.fallback.is_none());
1089    }
1090
1091    #[tokio::test]
1092    async fn test_tool_task_with_args() {
1093        let task = ToolTask::new(
1094            TaskId::new("tool_task"),
1095            "Echo Tool".to_string(),
1096            "echo"
1097        )
1098        .args(vec!["hello".to_string(), "world".to_string()]);
1099
1100        assert_eq!(task.invocation().args.len(), 2);
1101        assert_eq!(task.invocation().args[0], "hello");
1102        assert_eq!(task.invocation().args[1], "world");
1103    }
1104
1105    #[tokio::test]
1106    async fn test_tool_task_with_working_dir() {
1107        let task = ToolTask::new(
1108            TaskId::new("tool_task"),
1109            "Cargo Test".to_string(),
1110            "cargo"
1111        )
1112        .working_dir("/home/user/project");
1113
1114        assert_eq!(
1115            task.invocation().working_dir,
1116            Some(PathBuf::from("/home/user/project"))
1117        );
1118    }
1119
1120    #[tokio::test]
1121    async fn test_tool_task_with_env() {
1122        let task = ToolTask::new(
1123            TaskId::new("tool_task"),
1124            "Cargo Test".to_string(),
1125            "cargo"
1126        )
1127        .env("RUST_LOG", "debug");
1128
1129        assert_eq!(task.invocation().env.len(), 1);
1130        assert_eq!(task.invocation().env.get("RUST_LOG"), Some(&"debug".to_string()));
1131    }
1132
1133    #[tokio::test]
1134    async fn test_tool_task_builder_pattern() {
1135        let task = ToolTask::new(
1136            TaskId::new("tool_task"),
1137            "Cargo Test".to_string(),
1138            "cargo"
1139        )
1140        .args(vec!["test".to_string()])
1141        .working_dir("/tmp")
1142        .env("TEST_VAR", "value");
1143
1144        assert_eq!(task.invocation().args.len(), 1);
1145        assert_eq!(task.invocation().working_dir, Some(PathBuf::from("/tmp")));
1146        assert_eq!(task.invocation().env.get("TEST_VAR"), Some(&"value".to_string()));
1147    }
1148
1149    #[tokio::test]
1150    async fn test_tool_task_compensation() {
1151        let task = ToolTask::new(
1152            TaskId::new("tool_task"),
1153            "Echo Tool".to_string(),
1154            "echo"
1155        );
1156
1157        // ToolTask should have Skip compensation
1158        let compensation = task.compensation();
1159        assert!(compensation.is_some());
1160        assert_eq!(compensation.unwrap().action_type, crate::workflow::task::CompensationType::Skip);
1161    }
1162
1163    #[tokio::test]
1164    async fn test_tool_task_execution() {
1165        use std::sync::Arc;
1166
1167        // Create a tool registry with echo
1168        let mut registry = crate::workflow::tools::ToolRegistry::new();
1169        registry.register(crate::workflow::tools::Tool::new("echo", "echo")).unwrap();
1170
1171        // Create a context with the tool registry
1172        let context = TaskContext::new("workflow_1", TaskId::new("tool_task"))
1173            .with_tool_registry(Arc::new(registry));
1174
1175        // Create a tool task
1176        let task = ToolTask::new(
1177            TaskId::new("tool_task"),
1178            "Echo Tool".to_string(),
1179            "echo"
1180        )
1181        .args(vec!["test".to_string()]);
1182
1183        // Execute the task
1184        let result = task.execute(&context).await.unwrap();
1185        assert_eq!(result, TaskResult::Success);
1186    }
1187
1188    #[tokio::test]
1189    async fn test_tool_task_with_fallback() {
1190        use std::sync::Arc;
1191
1192        // Create a tool registry with echo
1193        let mut registry = crate::workflow::tools::ToolRegistry::new();
1194        registry.register(crate::workflow::tools::Tool::new("echo", "echo")).unwrap();
1195
1196        // Create a context with the tool registry
1197        let context = TaskContext::new("workflow_1", TaskId::new("tool_task"))
1198            .with_tool_registry(Arc::new(registry));
1199
1200        // Create a tool task with skip fallback
1201        let task = ToolTask::new(
1202            TaskId::new("tool_task"),
1203            "Nonexistent Tool".to_string(),
1204            "nonexistent"  // Tool not registered
1205        )
1206        .with_fallback(Box::new(crate::workflow::tools::SkipFallback::skip()));
1207
1208        // Execute the task - should use fallback
1209        let result = task.execute(&context).await.unwrap();
1210        assert_eq!(result, TaskResult::Skipped);
1211    }
1212
1213    #[tokio::test]
1214    async fn test_standard_tools() {
1215        use crate::workflow::tools::ToolRegistry;
1216
1217        let registry = ToolRegistry::with_standard_tools();
1218
1219        // Check that at least some tools might be registered
1220        // (we can't assume all tools are present on the system)
1221        let tool_count = registry.len();
1222        eprintln!("Standard tools registered: {}", tool_count);
1223
1224        // Registry should be created successfully (even if no tools found)
1225        // This is a basic smoke test
1226        assert!(tool_count >= 0);
1227    }
1228
1229    #[tokio::test]
1230    async fn test_tool_invoke_from_workflow() {
1231        use crate::workflow::dag::Workflow;
1232        use crate::workflow::executor::WorkflowExecutor;
1233        use crate::workflow::tools::{Tool, ToolRegistry};
1234        use std::sync::Arc;
1235
1236        // Create a workflow with a tool task
1237        let mut workflow = Workflow::new();
1238        let task_id = TaskId::new("tool_task");
1239
1240        // Create tool registry with echo
1241        let mut registry = ToolRegistry::new();
1242        registry.register(Tool::new("echo", "echo")).unwrap();
1243
1244        let tool_task = ToolTask::new(
1245            task_id.clone(),
1246            "Echo Tool".to_string(),
1247            "echo"
1248        )
1249        .args(vec!["hello".to_string()]);
1250
1251        workflow.add_task(Box::new(tool_task));
1252
1253        // Create executor with tool registry
1254        let mut executor = WorkflowExecutor::new(workflow)
1255            .with_tool_registry(registry);
1256
1257        // Execute the workflow
1258        let result = executor.execute().await.unwrap();
1259        assert!(result.success);
1260        assert!(result.completed_tasks.contains(&task_id));
1261    }
1262
1263    #[tokio::test]
1264    async fn test_tool_fallback_audit_event() {
1265        use crate::audit::{AuditEvent, AuditLog};
1266
1267        // Create an audit log
1268        let audit_log = AuditLog::new();
1269
1270        // Create a tool registry with echo
1271        let mut registry = crate::workflow::tools::ToolRegistry::new();
1272        registry.register(crate::workflow::tools::Tool::new("echo", "echo")).unwrap();
1273
1274        // Create a context with the tool registry and audit log
1275        let context = TaskContext::new("workflow_1", TaskId::new("tool_task"))
1276            .with_tool_registry(Arc::new(registry))
1277            .with_audit_log(audit_log);
1278
1279        // Create a tool task with skip fallback
1280        let task = ToolTask::new(
1281            TaskId::new("tool_task"),
1282            "Nonexistent Tool".to_string(),
1283            "nonexistent"  // Tool not registered
1284        )
1285        .with_fallback(Box::new(crate::workflow::tools::SkipFallback::skip()));
1286
1287        // Execute the task - should trigger fallback
1288        let result = task.execute(&context).await.unwrap();
1289        assert_eq!(result, TaskResult::Skipped);
1290
1291        // Note: Audit event recording from within tasks is a limitation of the current design
1292        // The executor records events, but tasks can't easily record to the audit log
1293        // without mutable access. For now, we just verify the fallback works correctly.
1294    }
1295}