1use crate::workflow::task::{CompensationAction, TaskContext, TaskError, TaskId, TaskResult, WorkflowTask};
7use crate::workflow::tools::{FallbackHandler, FallbackResult, ToolError, ToolInvocation, ToolRegistry};
8use async_trait::async_trait;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::pin::Pin;
13use std::sync::Arc;
14use std::future::Future;
15use std::time::Duration;
16use std::process::Command;
17
18pub struct FunctionTask {
38 id: TaskId,
39 name: String,
40 f: Box<dyn Fn(&TaskContext) -> Pin<Box<dyn Future<Output = Result<TaskResult, TaskError>> + Send>> + Send + Sync>,
41}
42
43impl FunctionTask {
44 pub fn new<F, Fut>(id: TaskId, name: String, f: F) -> Self
46 where
47 F: Fn(&TaskContext) -> Fut + Send + Sync + 'static,
48 Fut: Future<Output = Result<TaskResult, TaskError>> + Send + 'static,
49 {
50 Self {
51 id,
52 name,
53 f: Box::new(move |ctx| Box::pin(f(ctx)) as Pin<Box<dyn Future<Output = Result<TaskResult, TaskError>> + Send>>),
54 }
55 }
56}
57
58#[async_trait]
59impl WorkflowTask for FunctionTask {
60 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
61 (self.f)(context).await
62 }
63
64 fn id(&self) -> TaskId {
65 self.id.clone()
66 }
67
68 fn name(&self) -> &str {
69 &self.name
70 }
71}
72
/// Kind of code-graph query performed by a [`GraphQueryTask`].
///
/// Derives `Serialize`/`Deserialize`. Renaming or reordering variants may change
/// the serialized representation, depending on the serializer used.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub enum GraphQueryType {
    /// Selected by [`GraphQueryTask::find_symbol`].
    FindSymbol,
    /// Selected by [`GraphQueryTask::references`].
    References,
    /// Selected by [`GraphQueryTask::impact_analysis`].
    ImpactAnalysis,
}
83
84pub struct GraphQueryTask {
88 id: TaskId,
89 name: String,
90 query_type: GraphQueryType,
91 target: String,
92}
93
94impl GraphQueryTask {
95 pub fn find_symbol(target: impl Into<String>) -> Self {
97 Self::new(GraphQueryType::FindSymbol, target)
98 }
99
100 pub fn references(target: impl Into<String>) -> Self {
102 Self::new(GraphQueryType::References, target)
103 }
104
105 pub fn impact_analysis(target: impl Into<String>) -> Self {
107 Self::new(GraphQueryType::ImpactAnalysis, target)
108 }
109
110 fn new(query_type: GraphQueryType, target: impl Into<String>) -> Self {
111 let target_str = target.into();
112 Self {
113 id: TaskId::new(format!("graph_query_{:?}", query_type)),
114 name: format!("Graph Query: {:?}", query_type),
115 query_type,
116 target: target_str,
117 }
118 }
119
120 pub fn with_id(id: TaskId, query_type: GraphQueryType, target: impl Into<String>) -> Self {
122 Self {
123 id,
124 name: format!("Graph Query: {:?}", query_type),
125 query_type,
126 target: target.into(),
127 }
128 }
129}
130
131#[async_trait]
132impl WorkflowTask for GraphQueryTask {
133 async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
134 match self.query_type {
137 GraphQueryType::FindSymbol => {
138 Ok(TaskResult::Success)
139 }
140 GraphQueryType::References => {
141 Ok(TaskResult::Success)
142 }
143 GraphQueryType::ImpactAnalysis => {
144 Ok(TaskResult::Success)
145 }
146 }
147 }
148
149 fn id(&self) -> TaskId {
150 self.id.clone()
151 }
152
153 fn name(&self) -> &str {
154 &self.name
155 }
156
157 fn compensation(&self) -> Option<CompensationAction> {
158 Some(CompensationAction::skip("Read-only graph query - no undo needed"))
160 }
161}
162
163pub struct AgentLoopTask {
167 id: TaskId,
168 name: String,
169 query: String,
170}
171
172impl AgentLoopTask {
173 pub fn new(id: TaskId, name: String, query: impl Into<String>) -> Self {
175 Self {
176 id,
177 name,
178 query: query.into(),
179 }
180 }
181
182 pub fn query(&self) -> &str {
184 &self.query
185 }
186}
187
188#[async_trait]
189impl WorkflowTask for AgentLoopTask {
190 async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
191 Ok(TaskResult::Success)
194 }
195
196 fn id(&self) -> TaskId {
197 self.id.clone()
198 }
199
200 fn name(&self) -> &str {
201 &self.name
202 }
203
204 fn compensation(&self) -> Option<CompensationAction> {
205 Some(CompensationAction::skip("Read-only agent loop - no undo needed in v0.4"))
208 }
209}
210
/// Spawn options for a [`ShellCommandTask`], built with consuming builder methods.
#[derive(Clone, Debug, PartialEq)]
pub struct ShellCommandConfig {
    /// Program to run. It is passed to `Command::new` directly, not through a shell.
    pub command: String,
    /// Arguments passed to the program, in order.
    pub args: Vec<String>,
    /// Working directory for the child; `None` inherits the parent's.
    pub working_dir: Option<PathBuf>,
    /// Extra environment variables set on the child.
    pub env: HashMap<String, String>,
    /// Maximum time to wait for the child; `None` waits indefinitely.
    pub timeout: Option<Duration>,
}

impl ShellCommandConfig {
    /// Starts a config for `command` with no arguments, no working-directory
    /// override, no extra environment and no timeout.
    pub fn new(command: impl Into<String>) -> Self {
        let command = command.into();
        Self {
            command,
            args: Vec::new(),
            working_dir: None,
            env: HashMap::new(),
            timeout: None,
        }
    }

    /// Replaces the argument list.
    pub fn args(self, args: Vec<String>) -> Self {
        Self { args, ..self }
    }

    /// Sets the child's working directory.
    pub fn working_dir(self, path: impl Into<PathBuf>) -> Self {
        Self {
            working_dir: Some(path.into()),
            ..self
        }
    }

    /// Adds (or overwrites) one environment variable.
    pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (key, value) = (key.into(), value.into());
        self.env.insert(key, value);
        self
    }

    /// Sets the maximum time to wait for the command to finish.
    pub fn timeout(self, duration: Duration) -> Self {
        Self {
            timeout: Some(duration),
            ..self
        }
    }
}
302
303pub struct ShellCommandTask {
309 id: TaskId,
310 name: String,
311 config: ShellCommandConfig,
312 last_pid: Arc<std::sync::Mutex<Option<u32>>>,
314}
315
316impl ShellCommandTask {
317 pub fn new(id: TaskId, name: String, command: impl Into<String>) -> Self {
325 Self {
326 id,
327 name,
328 config: ShellCommandConfig::new(command),
329 last_pid: Arc::new(std::sync::Mutex::new(None)),
330 }
331 }
332
333 pub fn with_config(id: TaskId, name: String, config: ShellCommandConfig) -> Self {
341 Self {
342 id,
343 name,
344 config,
345 last_pid: Arc::new(std::sync::Mutex::new(None)),
346 }
347 }
348
349 #[deprecated(since = "0.4.0", note = "Use with_config() instead for better configurability")]
355 pub fn with_args(mut self, args: Vec<String>) -> Self {
356 self.config.args = args;
357 self
358 }
359
360 pub fn command(&self) -> &str {
362 &self.config.command
363 }
364
365 pub fn args(&self) -> &[String] {
367 &self.config.args
368 }
369
370 pub fn config(&self) -> &ShellCommandConfig {
372 &self.config
373 }
374}
375
376#[async_trait]
377impl WorkflowTask for ShellCommandTask {
378 async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
379 let mut cmd = tokio::process::Command::new(&self.config.command);
381
382 cmd.args(&self.config.args);
384
385 if let Some(ref working_dir) = self.config.working_dir {
387 cmd.current_dir(working_dir);
388 }
389
390 for (key, value) in &self.config.env {
392 cmd.env(key, value);
393 }
394
395 let child = cmd.spawn().map_err(|e| TaskError::Io(e))?;
397
398 if let Some(pid) = child.id() {
400 let mut last_pid = self.last_pid.lock().unwrap();
401 *last_pid = Some(pid);
402 }
403
404 let output = if let Some(timeout) = self.config.timeout {
406 tokio::time::timeout(timeout, child.wait_with_output())
407 .await
408 .map_err(|_| TaskError::Timeout(format!("Command timed out after {:?}", timeout)))?
409 .map_err(TaskError::Io)?
410 } else {
411 child.wait_with_output().await.map_err(TaskError::Io)?
412 };
413
414 if output.status.success() {
416 Ok(TaskResult::Success)
417 } else {
418 let exit_code = output.status.code().unwrap_or(-1);
419 let stderr = String::from_utf8_lossy(&output.stderr);
420 let error_msg = if !stderr.is_empty() {
421 format!("exit code: {}, stderr: {}", exit_code, stderr)
422 } else {
423 format!("exit code: {}", exit_code)
424 };
425 Ok(TaskResult::Failed(error_msg))
426 }
427 }
428
429 fn id(&self) -> TaskId {
430 self.id.clone()
431 }
432
433 fn name(&self) -> &str {
434 &self.name
435 }
436
437 fn compensation(&self) -> Option<CompensationAction> {
438 let pid_guard = self.last_pid.lock().unwrap();
440 if let Some(pid) = *pid_guard {
441 Some(CompensationAction::undo(format!(
443 "Terminate spawned process: {}",
444 pid
445 )))
446 } else {
447 Some(CompensationAction::skip("No process was spawned"))
449 }
450 }
451}
452
453pub struct FileEditTask {
458 id: TaskId,
459 name: String,
460 file_path: PathBuf,
461 original_content: String,
462 new_content: String,
463}
464
465impl FileEditTask {
466 pub fn new(
476 id: TaskId,
477 name: String,
478 file_path: PathBuf,
479 original_content: String,
480 new_content: String,
481 ) -> Self {
482 Self {
483 id,
484 name,
485 file_path,
486 original_content,
487 new_content,
488 }
489 }
490
491 pub fn file_path(&self) -> &PathBuf {
493 &self.file_path
494 }
495
496 pub fn original_content(&self) -> &str {
498 &self.original_content
499 }
500
501 pub fn new_content(&self) -> &str {
503 &self.new_content
504 }
505}
506
507#[async_trait]
508impl WorkflowTask for FileEditTask {
509 async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
510 Ok(TaskResult::Success)
513 }
514
515 fn id(&self) -> TaskId {
516 self.id.clone()
517 }
518
519 fn name(&self) -> &str {
520 &self.name
521 }
522
523 fn compensation(&self) -> Option<CompensationAction> {
524 Some(CompensationAction::undo(format!(
527 "Restore original content of {}",
528 self.file_path.display()
529 )))
530 }
531}
532
533pub struct ToolTask {
553 id: TaskId,
555 name: String,
557 invocation: ToolInvocation,
559 fallback: Option<Arc<dyn FallbackHandler>>,
561}
562
563impl ToolTask {
564 pub fn new(id: TaskId, name: String, tool_name: impl Into<String>) -> Self {
585 Self {
586 id,
587 name,
588 invocation: ToolInvocation::new(tool_name),
589 fallback: None,
590 }
591 }
592
593 pub fn args(mut self, args: Vec<String>) -> Self {
617 self.invocation = self.invocation.args(args);
618 self
619 }
620
621 pub fn working_dir(mut self, dir: impl Into<PathBuf>) -> Self {
645 self.invocation = self.invocation.working_dir(dir);
646 self
647 }
648
649 pub fn env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
674 self.invocation = self.invocation.env(key, value);
675 self
676 }
677
678 pub fn with_fallback(mut self, handler: Box<dyn FallbackHandler>) -> Self {
703 self.fallback = Some(Arc::from(handler));
704 self
705 }
706
707 pub fn tool_name(&self) -> &str {
709 &self.invocation.tool_name
710 }
711
712 pub fn invocation(&self) -> &ToolInvocation {
714 &self.invocation
715 }
716
717 async fn record_fallback_event(
719 context: &TaskContext,
720 tool_name: &str,
721 error: &ToolError,
722 fallback_action: &str,
723 ) {
724 use crate::audit::AuditEvent;
725 use chrono::Utc;
726
727 let event = AuditEvent::WorkflowToolFallback {
728 timestamp: Utc::now(),
729 workflow_id: context.workflow_id.clone(),
730 task_id: context.task_id.as_str().to_string(),
731 tool_name: tool_name.to_string(),
732 error: error.to_string(),
733 fallback_action: fallback_action.to_string(),
734 };
735
736 eprintln!("Fallback event: {} -> {}", tool_name, fallback_action);
741 }
742}
743
744#[async_trait]
745impl WorkflowTask for ToolTask {
746 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
747 let registry = context.tool_registry
749 .as_ref()
750 .ok_or_else(|| TaskError::ExecutionFailed(
751 "ToolRegistry not available in TaskContext".to_string()
752 ))?;
753
754 let invocation_result = registry.invoke(&self.invocation).await;
756
757 match invocation_result {
758 Ok(result) => {
759 if result.result.success {
760 Ok(TaskResult::Success)
761 } else {
762 Ok(TaskResult::Failed(result.result.stderr))
763 }
764 }
765 Err(error) => {
766 if let Some(ref fallback) = self.fallback {
768 match fallback.handle(&error, &self.invocation).await {
769 FallbackResult::Retry(retry_invocation) => {
770 Self::record_fallback_event(
772 context,
773 &self.invocation.tool_name,
774 &error,
775 "Retry"
776 ).await;
777
778 match registry.invoke(&retry_invocation).await {
780 Ok(retry_result) => {
781 if retry_result.result.success {
782 Ok(TaskResult::Success)
783 } else {
784 Ok(TaskResult::Failed(retry_result.result.stderr))
785 }
786 }
787 Err(retry_error) => {
788 Ok(TaskResult::Failed(format!(
789 "Tool {} failed after retry: {}",
790 self.invocation.tool_name,
791 retry_error
792 )))
793 }
794 }
795 }
796 FallbackResult::Skip(result) => {
797 Self::record_fallback_event(
799 context,
800 &self.invocation.tool_name,
801 &error,
802 "Skip"
803 ).await;
804
805 Ok(result)
806 }
807 FallbackResult::Fail(fail_error) => {
808 Self::record_fallback_event(
810 context,
811 &self.invocation.tool_name,
812 &error,
813 &format!("Fail: {}", fail_error)
814 ).await;
815
816 Ok(TaskResult::Failed(format!(
817 "Tool {} failed: {}",
818 self.invocation.tool_name,
819 fail_error
820 )))
821 }
822 }
823 } else {
824 Ok(TaskResult::Failed(format!(
826 "Tool {} failed: {}",
827 self.invocation.tool_name,
828 error
829 )))
830 }
831 }
832 }
833 }
834
835 fn id(&self) -> TaskId {
836 self.id.clone()
837 }
838
839 fn name(&self) -> &str {
840 &self.name
841 }
842
843 fn compensation(&self) -> Option<CompensationAction> {
844 Some(CompensationAction::skip(
847 "Tool side effects handled by ProcessGuard"
848 ))
849 }
850}
851
#[cfg(test)]
mod tests {
    // `Arc`, `ToolRegistry`, etc. come in through this glob from the parent module.
    use super::*;

    #[tokio::test]
    async fn test_function_task() {
        let task = FunctionTask::new(
            TaskId::new("test_task"),
            "Test Task".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        );

        let context = TaskContext::new("workflow_1", TaskId::new("test_task"));
        let result = task.execute(&context).await.unwrap();

        assert_eq!(result, TaskResult::Success);
        assert_eq!(task.id(), TaskId::new("test_task"));
        assert_eq!(task.name(), "Test Task");
    }

    #[tokio::test]
    async fn test_agent_loop_task() {
        let task = AgentLoopTask::new(
            TaskId::new("agent_task"),
            "Agent Task".to_string(),
            "Find all functions",
        );

        assert_eq!(task.id(), TaskId::new("agent_task"));
        assert_eq!(task.name(), "Agent Task");
        assert_eq!(task.query(), "Find all functions");

        let context = TaskContext::new("workflow_1", TaskId::new("agent_task"));
        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_graph_query_task() {
        let task = GraphQueryTask::find_symbol("process_data");

        assert_eq!(task.query_type, GraphQueryType::FindSymbol);
        assert_eq!(task.target, "process_data");

        let context = TaskContext::new("workflow_1", task.id());
        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_graph_query_references() {
        let task = GraphQueryTask::references("my_function");

        assert_eq!(task.query_type, GraphQueryType::References);
        assert_eq!(task.target, "my_function");
    }

    #[tokio::test]
    async fn test_graph_query_impact() {
        let task = GraphQueryTask::impact_analysis("struct_name");

        assert_eq!(task.query_type, GraphQueryType::ImpactAnalysis);
        assert_eq!(task.target, "struct_name");
    }

    #[tokio::test]
    async fn test_graph_query_with_custom_id() {
        let task = GraphQueryTask::with_id(
            TaskId::new("custom_id"),
            GraphQueryType::FindSymbol,
            "my_symbol",
        );

        assert_eq!(task.id(), TaskId::new("custom_id"));
        assert_eq!(task.target, "my_symbol");
    }

    #[tokio::test]
    #[allow(deprecated)] // deliberately exercises the deprecated `with_args` builder
    async fn test_shell_command_task_stub() {
        let task = ShellCommandTask::new(
            TaskId::new("shell_task"),
            "Shell Task".to_string(),
            "echo",
        )
        .with_args(vec!["hello".to_string(), "world".to_string()]);

        assert_eq!(task.id(), TaskId::new("shell_task"));
        assert_eq!(task.command(), "echo");
        assert_eq!(task.args(), &["hello", "world"]);

        let context = TaskContext::new("workflow_1", task.id());
        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_shell_task_args_default() {
        let task = ShellCommandTask::new(
            TaskId::new("shell_task"),
            "Shell Task".to_string(),
            "ls",
        );

        assert_eq!(task.args().len(), 0);
        assert!(task.args().is_empty());
    }

    #[tokio::test]
    async fn test_shell_command_with_working_dir() {
        let temp_dir = std::env::temp_dir();
        let test_file = temp_dir.join("test_shell_command.txt");

        std::fs::write(&test_file, "test content").unwrap();

        let config = ShellCommandConfig::new("ls")
            .args(vec![temp_dir.to_string_lossy().to_string()])
            .working_dir(&temp_dir);

        let task = ShellCommandTask::with_config(
            TaskId::new("shell_task"),
            "Shell Task".to_string(),
            config,
        );

        let context = TaskContext::new("workflow_1", task.id());
        let result = task.execute(&context).await.unwrap();

        assert_eq!(result, TaskResult::Success);

        std::fs::remove_file(&test_file).ok();
    }

    #[tokio::test]
    async fn test_shell_command_with_env() {
        let config = ShellCommandConfig::new("sh")
            .args(vec!["-c".to_string(), "echo $TEST_VAR".to_string()])
            .env("TEST_VAR", "test_value");

        let task = ShellCommandTask::with_config(
            TaskId::new("shell_task"),
            "Shell Task".to_string(),
            config,
        );

        let context = TaskContext::new("workflow_1", task.id());
        let result = task.execute(&context).await.unwrap();

        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    #[allow(deprecated)] // deliberately exercises the deprecated `with_args` builder
    async fn test_shell_command_compensation() {
        let task = ShellCommandTask::new(
            TaskId::new("shell_task"),
            "Shell Task".to_string(),
            "echo",
        )
        .with_args(vec!["test".to_string()]);

        // Before execution no process exists, so compensation is a skip.
        let compensation = task.compensation();
        assert!(compensation.is_some());
        assert_eq!(
            compensation.unwrap().action_type,
            crate::workflow::task::CompensationType::Skip
        );

        let context = TaskContext::new("workflow_1", task.id());
        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);

        // After execution the recorded PID turns it into an undo action.
        let compensation = task.compensation();
        assert!(compensation.is_some());
        assert_eq!(
            compensation.unwrap().action_type,
            crate::workflow::task::CompensationType::UndoFunction
        );
    }

    #[tokio::test]
    async fn test_graph_query_compensation_skip() {
        let task = GraphQueryTask::find_symbol("my_function");

        let compensation = task.compensation();
        assert!(compensation.is_some());
        assert_eq!(
            compensation.unwrap().action_type,
            crate::workflow::task::CompensationType::Skip
        );
    }

    #[tokio::test]
    async fn test_agent_loop_compensation_skip() {
        let task = AgentLoopTask::new(
            TaskId::new("agent_task"),
            "Agent Task".to_string(),
            "Find all functions",
        );

        let compensation = task.compensation();
        assert!(compensation.is_some());
        assert_eq!(
            compensation.unwrap().action_type,
            crate::workflow::task::CompensationType::Skip
        );
    }

    #[tokio::test]
    async fn test_file_edit_compensation_undo() {
        let task = FileEditTask::new(
            TaskId::new("file_edit"),
            "Edit File".to_string(),
            PathBuf::from("/tmp/test.txt"),
            "original".to_string(),
            "new".to_string(),
        );

        let compensation = task.compensation();
        assert!(compensation.is_some());
        assert_eq!(
            compensation.unwrap().action_type,
            crate::workflow::task::CompensationType::UndoFunction
        );
    }

    #[tokio::test]
    async fn test_tool_task_creation() {
        let task = ToolTask::new(TaskId::new("tool_task"), "Echo Tool".to_string(), "echo");

        assert_eq!(task.id(), TaskId::new("tool_task"));
        assert_eq!(task.name(), "Echo Tool");
        assert_eq!(task.tool_name(), "echo");
        assert!(task.invocation().args.is_empty());
        assert!(task.fallback.is_none());
    }

    #[tokio::test]
    async fn test_tool_task_with_args() {
        let task = ToolTask::new(TaskId::new("tool_task"), "Echo Tool".to_string(), "echo")
            .args(vec!["hello".to_string(), "world".to_string()]);

        assert_eq!(task.invocation().args.len(), 2);
        assert_eq!(task.invocation().args[0], "hello");
        assert_eq!(task.invocation().args[1], "world");
    }

    #[tokio::test]
    async fn test_tool_task_with_working_dir() {
        let task = ToolTask::new(TaskId::new("tool_task"), "Cargo Test".to_string(), "cargo")
            .working_dir("/home/user/project");

        assert_eq!(
            task.invocation().working_dir,
            Some(PathBuf::from("/home/user/project"))
        );
    }

    #[tokio::test]
    async fn test_tool_task_with_env() {
        let task = ToolTask::new(TaskId::new("tool_task"), "Cargo Test".to_string(), "cargo")
            .env("RUST_LOG", "debug");

        assert_eq!(task.invocation().env.len(), 1);
        assert_eq!(
            task.invocation().env.get("RUST_LOG"),
            Some(&"debug".to_string())
        );
    }

    #[tokio::test]
    async fn test_tool_task_builder_pattern() {
        let task = ToolTask::new(TaskId::new("tool_task"), "Cargo Test".to_string(), "cargo")
            .args(vec!["test".to_string()])
            .working_dir("/tmp")
            .env("TEST_VAR", "value");

        assert_eq!(task.invocation().args.len(), 1);
        assert_eq!(task.invocation().working_dir, Some(PathBuf::from("/tmp")));
        assert_eq!(
            task.invocation().env.get("TEST_VAR"),
            Some(&"value".to_string())
        );
    }

    #[tokio::test]
    async fn test_tool_task_compensation() {
        let task = ToolTask::new(TaskId::new("tool_task"), "Echo Tool".to_string(), "echo");

        let compensation = task.compensation();
        assert!(compensation.is_some());
        assert_eq!(
            compensation.unwrap().action_type,
            crate::workflow::task::CompensationType::Skip
        );
    }

    #[tokio::test]
    async fn test_tool_task_execution() {
        let mut registry = crate::workflow::tools::ToolRegistry::new();
        registry
            .register(crate::workflow::tools::Tool::new("echo", "echo"))
            .unwrap();

        let context = TaskContext::new("workflow_1", TaskId::new("tool_task"))
            .with_tool_registry(Arc::new(registry));

        let task = ToolTask::new(TaskId::new("tool_task"), "Echo Tool".to_string(), "echo")
            .args(vec!["test".to_string()]);

        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_tool_task_with_fallback() {
        let mut registry = crate::workflow::tools::ToolRegistry::new();
        registry
            .register(crate::workflow::tools::Tool::new("echo", "echo"))
            .unwrap();

        let context = TaskContext::new("workflow_1", TaskId::new("tool_task"))
            .with_tool_registry(Arc::new(registry));

        // "nonexistent" is not registered, so the registry errors and the fallback runs.
        let task = ToolTask::new(
            TaskId::new("tool_task"),
            "Nonexistent Tool".to_string(),
            "nonexistent",
        )
        .with_fallback(Box::new(crate::workflow::tools::SkipFallback::skip()));

        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Skipped);
    }

    #[tokio::test]
    async fn test_standard_tools() {
        // Smoke test: only checks that the standard registry can be built and
        // queried. The number of registered tools is not pinned here.
        let registry = ToolRegistry::with_standard_tools();

        let tool_count = registry.len();
        eprintln!("Standard tools registered: {}", tool_count);
    }

    #[tokio::test]
    async fn test_tool_invoke_from_workflow() {
        use crate::workflow::dag::Workflow;
        use crate::workflow::executor::WorkflowExecutor;
        use crate::workflow::tools::Tool;

        let mut workflow = Workflow::new();
        let task_id = TaskId::new("tool_task");

        let mut registry = ToolRegistry::new();
        registry.register(Tool::new("echo", "echo")).unwrap();

        let tool_task = ToolTask::new(task_id.clone(), "Echo Tool".to_string(), "echo")
            .args(vec!["hello".to_string()]);

        workflow.add_task(Box::new(tool_task));

        let mut executor = WorkflowExecutor::new(workflow).with_tool_registry(registry);

        let result = executor.execute().await.unwrap();
        assert!(result.success);
        assert!(result.completed_tasks.contains(&task_id));
    }

    #[tokio::test]
    async fn test_tool_fallback_audit_event() {
        use crate::audit::AuditLog;

        let audit_log = AuditLog::new();

        let mut registry = crate::workflow::tools::ToolRegistry::new();
        registry
            .register(crate::workflow::tools::Tool::new("echo", "echo"))
            .unwrap();

        let context = TaskContext::new("workflow_1", TaskId::new("tool_task"))
            .with_tool_registry(Arc::new(registry))
            .with_audit_log(audit_log);

        let task = ToolTask::new(
            TaskId::new("tool_task"),
            "Nonexistent Tool".to_string(),
            "nonexistent",
        )
        .with_fallback(Box::new(crate::workflow::tools::SkipFallback::skip()));

        let result = task.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Skipped);

        // NOTE(review): nothing here yet asserts that a `WorkflowToolFallback`
        // event reached the audit log. `ToolTask::record_fallback_event` currently
        // builds the event but only prints to stderr.
    }
}