Skip to main content

aster/tools/
workflow_integration.rs

1//! 工具钩子系统集成示例
2//!
3//! 展示如何在 aster-rust 工具执行流程中集成三阶段工作流和钩子系统
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8
9use crate::tools::hooks::{HookContext, HookTrigger, ToolHookManager};
10use crate::tools::{Tool, ToolContext, ToolError, ToolResult};
11
12/// 工作流集成工具 - 演示如何在工具执行中使用钩子系统
13#[derive(Clone)]
14pub struct WorkflowIntegratedTool {
15    name: String,
16    description: String,
17    hook_manager: Option<Arc<ToolHookManager>>,
18}
19
20impl Default for WorkflowIntegratedTool {
21    fn default() -> Self {
22        Self {
23            name: "workflow_integrated_tool".to_string(),
24            description: "Demonstrates three-stage workflow integration with hook system"
25                .to_string(),
26            hook_manager: None,
27        }
28    }
29}
30
31impl WorkflowIntegratedTool {
32    /// 创建带钩子管理器的工具实例
33    pub fn with_hook_manager(mut self, hook_manager: Arc<ToolHookManager>) -> Self {
34        self.hook_manager = Some(hook_manager);
35        self
36    }
37
38    /// Pre-Action 阶段:执行前的上下文刷新和检查
39    async fn pre_action(
40        &self,
41        context: &ToolContext,
42        params: &serde_json::Value,
43    ) -> Result<String, ToolError> {
44        if let Some(hook_manager) = &self.hook_manager {
45            let hook_context = HookContext::new(self.name.clone(), params.clone(), context.clone());
46
47            // 触发 Pre-Execution 钩子
48            hook_manager
49                .trigger_hooks(HookTrigger::PreExecution, &hook_context)
50                .await
51                .map_err(|e| {
52                    ToolError::execution_failed(format!("Pre-action hook failed: {}", e))
53                })?;
54        }
55
56        // 模拟上下文刷新逻辑
57        let context_info = format!(
58            "🔄 Pre-Action 上下文刷新:\n\n工作目录: {:?}\n会话ID: {}\n用户: {}\n\n⚠️ 准备执行工具操作,请确认目标明确",
59            context.working_directory,
60            if context.session_id.is_empty() { "未知" } else { &context.session_id },
61            context.user.as_deref().unwrap_or("未知")
62        );
63
64        Ok(context_info)
65    }
66
67    /// Post-Action 阶段:执行后的状态更新和学习
68    async fn post_action(
69        &self,
70        context: &ToolContext,
71        params: &serde_json::Value,
72        result: &ToolResult,
73        error: Option<&ToolError>,
74    ) -> Result<String, ToolError> {
75        if let Some(hook_manager) = &self.hook_manager {
76            let mut hook_context =
77                HookContext::new(self.name.clone(), params.clone(), context.clone())
78                    .with_result(result.clone());
79
80            if let Some(err) = error {
81                hook_context = hook_context.with_error(err.to_string());
82
83                // 触发错误钩子
84                hook_manager
85                    .trigger_hooks(HookTrigger::OnError, &hook_context)
86                    .await
87                    .map_err(|e| {
88                        ToolError::execution_failed(format!("Error hook failed: {}", e))
89                    })?;
90            } else {
91                // 触发 Post-Execution 钩子
92                hook_manager
93                    .trigger_hooks(HookTrigger::PostExecution, &hook_context)
94                    .await
95                    .map_err(|e| {
96                        ToolError::execution_failed(format!("Post-action hook failed: {}", e))
97                    })?;
98            }
99        }
100
101        // 生成 Post-Action 消息
102        let mut message = "📝 Post-Action 状态更新:\n\n".to_string();
103
104        if let Some(err) = error {
105            message.push_str(&format!("🚨 错误处理: {}\n", err));
106            message.push_str("- 错误已记录到错误跟踪系统\n");
107            message.push_str("- 建议检查输入参数和执行环境\n");
108        } else {
109            message.push_str("✅ 操作成功完成\n");
110            message.push_str("- 结果已记录到进度日志\n");
111        }
112
113        message.push_str("\n💡 下一步建议:\n");
114        message.push_str("- 如果完成了某个阶段,请更新任务计划\n");
115        message.push_str("- 有重要发现请记录到 findings.md\n");
116        message.push_str("- 继续下一个计划步骤\n");
117
118        Ok(message)
119    }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct WorkflowParams {
124    pub action: String,
125    pub description: String,
126    pub simulate_error: Option<bool>,
127}
128
129#[async_trait]
130impl Tool for WorkflowIntegratedTool {
131    fn name(&self) -> &str {
132        &self.name
133    }
134
135    fn description(&self) -> &str {
136        &self.description
137    }
138
139    fn input_schema(&self) -> serde_json::Value {
140        serde_json::json!({
141            "type": "object",
142            "properties": {
143                "action": {
144                    "type": "string",
145                    "description": "Action to perform (e.g., 'analyze', 'process', 'generate')"
146                },
147                "description": {
148                    "type": "string",
149                    "description": "Detailed description of what to do"
150                },
151                "simulate_error": {
152                    "type": "boolean",
153                    "description": "Whether to simulate an error for testing (optional)"
154                }
155            },
156            "required": ["action", "description"]
157        })
158    }
159
160    async fn execute(
161        &self,
162        params: serde_json::Value,
163        context: &ToolContext,
164    ) -> Result<ToolResult, ToolError> {
165        let params: WorkflowParams = serde_json::from_value(params.clone())
166            .map_err(|e| ToolError::invalid_params(e.to_string()))?;
167
168        // === Pre-Action 阶段 ===
169        let pre_action_info = self
170            .pre_action(context, &serde_json::to_value(&params).unwrap())
171            .await?;
172
173        // === Action 阶段 ===
174        let mut result_content = format!("🔄 执行操作: {}\n\n", params.action);
175        result_content.push_str(&format!("描述: {}\n\n", params.description));
176        result_content.push_str(&format!("Pre-Action 信息:\n{}\n\n", pre_action_info));
177
178        // 模拟实际工作
179        let action_result = if params.simulate_error.unwrap_or(false) {
180            Err(ToolError::execution_failed("模拟错误:操作失败"))
181        } else {
182            result_content.push_str("✅ 操作执行成功\n");
183            result_content.push_str(&format!(
184                "时间: {}\n",
185                chrono::Utc::now().format("%Y-%m-%d %H:%M:%S")
186            ));
187
188            Ok(ToolResult::success(&result_content)
189                .with_metadata("action", serde_json::json!(params.action))
190                .with_metadata("workflow_stage", serde_json::json!("action_completed")))
191        };
192
193        // === Post-Action 阶段 ===
194        let post_action_info = match &action_result {
195            Ok(result) => {
196                self.post_action(
197                    context,
198                    &serde_json::to_value(&params).unwrap(),
199                    result,
200                    None,
201                )
202                .await?
203            }
204            Err(error) => {
205                self.post_action(
206                    context,
207                    &serde_json::to_value(&params).unwrap(),
208                    &ToolResult::error("Action failed"),
209                    Some(error),
210                )
211                .await?
212            }
213        };
214
215        // 合并结果
216        match action_result {
217            Ok(mut result) => {
218                let final_content = format!("{}\n\n{}", result.content(), post_action_info);
219                result = result.with_content(final_content);
220                Ok(result)
221            }
222            Err(error) => {
223                // 即使操作失败,也要返回包含 Post-Action 信息的结果
224                let error_content = format!("❌ 操作失败: {}\n\n{}", error, post_action_info);
225                Ok(ToolResult::error(&error_content)
226                    .with_metadata("error", serde_json::json!(error.to_string()))
227                    .with_metadata("post_action_info", serde_json::json!(post_action_info)))
228            }
229        }
230    }
231}
232
233/// 工作流集成工具的构建器
234pub struct WorkflowIntegratedToolBuilder {
235    tool: WorkflowIntegratedTool,
236}
237
238impl WorkflowIntegratedToolBuilder {
239    pub fn new() -> Self {
240        Self {
241            tool: WorkflowIntegratedTool::default(),
242        }
243    }
244
245    pub fn with_name(mut self, name: String) -> Self {
246        self.tool.name = name;
247        self
248    }
249
250    pub fn with_description(mut self, description: String) -> Self {
251        self.tool.description = description;
252        self
253    }
254
255    pub fn with_hook_manager(mut self, hook_manager: Arc<ToolHookManager>) -> Self {
256        self.tool.hook_manager = Some(hook_manager);
257        self
258    }
259
260    pub fn build(self) -> WorkflowIntegratedTool {
261        self.tool
262    }
263}
264
265impl Default for WorkflowIntegratedToolBuilder {
266    fn default() -> Self {
267        Self::new()
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Context with a fixed working directory, session id and user.
    fn create_test_context() -> ToolContext {
        ToolContext::new(PathBuf::from("/tmp"))
            .with_session_id("test-session")
            .with_user("test-user")
    }

    #[tokio::test]
    async fn test_workflow_integrated_tool_success() {
        let tool = WorkflowIntegratedTool::default();
        let context = create_test_context();

        let params = serde_json::json!({
            "action": "analyze",
            "description": "分析测试数据",
            "simulate_error": false
        });

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
        assert!(result.content().contains("Pre-Action 信息"));
        assert!(result.content().contains("Post-Action 状态更新"));
    }

    #[tokio::test]
    async fn test_workflow_integrated_tool_error() {
        let tool = WorkflowIntegratedTool::default();
        let context = create_test_context();

        let params = serde_json::json!({
            "action": "process",
            "description": "处理错误测试",
            "simulate_error": true
        });

        // A simulated failure is reported as an error result, not as `Err`.
        let result = tool.execute(params, &context).await.unwrap();
        assert!(!result.is_success());
        assert!(result.content().contains("操作失败"));
        assert!(result.content().contains("Post-Action 状态更新"));
        assert!(result.content().contains("错误处理"));
    }

    #[tokio::test]
    async fn test_workflow_integrated_tool_simulate_error_is_optional() {
        let tool = WorkflowIntegratedTool::default();
        let context = create_test_context();

        // `simulate_error` is not in the schema's `required` list; omitting it means success.
        let params = serde_json::json!({
            "action": "generate",
            "description": "omit optional flag"
        });

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());
    }

    #[tokio::test]
    async fn test_workflow_integrated_tool_rejects_missing_action() {
        let tool = WorkflowIntegratedTool::default();
        let context = create_test_context();

        let params = serde_json::json!({
            "description": "missing required action"
        });

        assert!(tool.execute(params, &context).await.is_err());
    }

    #[tokio::test]
    async fn test_workflow_integrated_tool_with_hooks() {
        let hook_manager = Arc::new(ToolHookManager::new(true));
        hook_manager.register_default_hooks().await;

        let tool = WorkflowIntegratedTool::default().with_hook_manager(hook_manager.clone());

        let context = create_test_context();

        let params = serde_json::json!({
            "action": "test",
            "description": "测试钩子集成",
            "simulate_error": false
        });

        let result = tool.execute(params, &context).await.unwrap();
        assert!(result.is_success());

        // NOTE(review): `hook_count` appears to count hooks *registered* for a trigger,
        // not invocations. These asserts pin the default hook set the run used; they
        // do not prove the hooks actually fired.
        assert_eq!(hook_manager.hook_count(HookTrigger::PreExecution).await, 2); // LoggingHook + FileOperationHook
        assert_eq!(hook_manager.hook_count(HookTrigger::PostExecution).await, 1); // LoggingHook
    }

    #[tokio::test]
    async fn test_workflow_builder() {
        let hook_manager = Arc::new(ToolHookManager::new(true));

        let tool = WorkflowIntegratedToolBuilder::new()
            .with_name("custom_workflow_tool".to_string())
            .with_description("自定义工作流工具".to_string())
            .with_hook_manager(hook_manager)
            .build();

        assert_eq!(tool.name(), "custom_workflow_tool");
        assert_eq!(tool.description(), "自定义工作流工具");
        assert!(tool.hook_manager.is_some());
    }
}