Skip to main content

aster/hooks/
executor.rs

1//! Hook 执行器
2//!
3//! 执行各种类型的 hooks
4
5use super::registry::global_registry;
6use super::types::*;
7use std::collections::HashMap;
8use std::process::Stdio;
9use std::time::Duration;
10use tokio::io::AsyncWriteExt;
11use tokio::process::Command;
12use tokio::time::timeout;
13use tracing::warn;
14
15/// 替换命令中的环境变量占位符
16fn replace_command_variables(command: &str, input: &HookInput) -> String {
17    command
18        .replace("$TOOL_NAME", input.tool_name.as_deref().unwrap_or(""))
19        .replace(
20            "$EVENT",
21            &input.event.map(|e| e.to_string()).unwrap_or_default(),
22        )
23        .replace("$SESSION_ID", input.session_id.as_deref().unwrap_or(""))
24}
25
26/// 执行 Command Hook
27async fn execute_command_hook(hook: &CommandHookConfig, input: &HookInput) -> HookResult {
28    let timeout_duration = Duration::from_millis(hook.timeout);
29    let command = replace_command_variables(&hook.command, input);
30
31    // 准备环境变量
32    let mut env: HashMap<String, String> = std::env::vars().collect();
33    env.extend(hook.env.clone());
34    env.insert(
35        "CLAUDE_HOOK_EVENT".to_string(),
36        input.event.map(|e| e.to_string()).unwrap_or_default(),
37    );
38    env.insert(
39        "CLAUDE_HOOK_TOOL_NAME".to_string(),
40        input.tool_name.clone().unwrap_or_default(),
41    );
42    env.insert(
43        "CLAUDE_HOOK_SESSION_ID".to_string(),
44        input.session_id.clone().unwrap_or_default(),
45    );
46
47    // 准备输入 JSON
48    let input_json = serde_json::to_string(input).unwrap_or_default();
49
50    let mut cmd = Command::new("sh");
51    cmd.arg("-c")
52        .arg(&command)
53        .envs(&env)
54        .stdin(Stdio::piped())
55        .stdout(Stdio::piped())
56        .stderr(Stdio::piped());
57
58    let result = timeout(timeout_duration, async {
59        let mut child = match cmd.spawn() {
60            Ok(c) => c,
61            Err(e) => {
62                return HookResult::failure(format!("Failed to spawn: {}", e));
63            }
64        };
65
66        // 写入 stdin
67        if let Some(mut stdin) = child.stdin.take() {
68            let _ = stdin.write_all(input_json.as_bytes()).await;
69        }
70
71        match child.wait_with_output().await {
72            Ok(output) => {
73                let stdout = String::from_utf8_lossy(&output.stdout).to_string();
74                let stderr = String::from_utf8_lossy(&output.stderr).to_string();
75
76                if !output.status.success() {
77                    // 尝试解析 JSON 输出以获取阻塞消息
78                    if let Ok(json) = serde_json::from_str::<serde_json::Value>(&stdout) {
79                        if json.get("blocked").and_then(|v| v.as_bool()) == Some(true) {
80                            let message = json
81                                .get("message")
82                                .and_then(|v| v.as_str())
83                                .unwrap_or("Blocked by hook")
84                                .to_string();
85                            return HookResult::blocked(message);
86                        }
87                    }
88                    return HookResult::failure(if stderr.is_empty() {
89                        format!("Hook exited with code {:?}", output.status.code())
90                    } else {
91                        stderr
92                    });
93                }
94
95                HookResult::success(Some(stdout))
96            }
97            Err(e) => HookResult::failure(format!("Failed to wait: {}", e)),
98        }
99    })
100    .await;
101
102    match result {
103        Ok(r) => r,
104        Err(_) => HookResult::failure("Hook execution timed out".to_string()),
105    }
106}
107
108/// 执行 URL Hook
109async fn execute_url_hook(hook: &UrlHookConfig, input: &HookInput) -> HookResult {
110    let timeout_duration = Duration::from_millis(hook.timeout);
111
112    let payload = serde_json::json!({
113        "event": input.event,
114        "toolName": input.tool_name,
115        "toolInput": input.tool_input,
116        "toolOutput": input.tool_output,
117        "message": input.message,
118        "sessionId": input.session_id,
119        "timestamp": chrono::Utc::now().to_rfc3339(),
120        "tool_use_id": input.tool_use_id,
121        "error": input.error,
122        "error_type": input.error_type,
123        "is_interrupt": input.is_interrupt,
124        "is_timeout": input.is_timeout,
125        "agent_id": input.agent_id,
126        "agent_type": input.agent_type,
127        "result": input.result,
128        "notification_type": input.notification_type,
129        "source": input.source,
130        "reason": input.reason,
131        "trigger": input.trigger,
132        "currentTokens": input.current_tokens,
133    });
134
135    let client = reqwest::Client::new();
136    let mut request = match hook.method {
137        HttpMethod::Get => client.get(&hook.url),
138        HttpMethod::Post => client.post(&hook.url),
139        HttpMethod::Put => client.put(&hook.url),
140        HttpMethod::Patch => client.patch(&hook.url),
141    };
142
143    request = request
144        .header("Content-Type", "application/json")
145        .header("User-Agent", "Aster-Hooks/1.0");
146
147    for (key, value) in &hook.headers {
148        request = request.header(key, value);
149    }
150
151    if hook.method != HttpMethod::Get {
152        request = request.json(&payload);
153    }
154
155    let result = timeout(timeout_duration, request.send()).await;
156
157    match result {
158        Ok(Ok(response)) => {
159            if !response.status().is_success() {
160                let status = response.status();
161                let text = response.text().await.unwrap_or_default();
162                return HookResult::failure(format!("HTTP {}: {}", status, text));
163            }
164
165            let text = response.text().await.unwrap_or_default();
166
167            // 尝试解析 JSON 响应
168            if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
169                if json.get("blocked").and_then(|v| v.as_bool()) == Some(true) {
170                    let message = json
171                        .get("message")
172                        .and_then(|v| v.as_str())
173                        .unwrap_or("Blocked by hook")
174                        .to_string();
175                    return HookResult::blocked(message);
176                }
177            }
178
179            HookResult::success(Some(text))
180        }
181        Ok(Err(e)) => HookResult::failure(format!("Request failed: {}", e)),
182        Err(_) => HookResult::failure("Hook request timed out".to_string()),
183    }
184}
185
186/// 执行 MCP Hook(占位实现)
187async fn execute_mcp_hook(hook: &McpHookConfig, _input: &HookInput) -> HookResult {
188    // TODO: 实现 MCP 工具调用
189    warn!(
190        "MCP hook not fully implemented: server={}, tool={}",
191        hook.server, hook.tool
192    );
193    HookResult::success(None)
194}
195
196/// 执行 Prompt Hook(占位实现)
197async fn execute_prompt_hook(_hook: &PromptHookConfig, _input: &HookInput) -> HookResult {
198    // TODO: 实现 LLM 提示评估
199    warn!("Prompt hook not fully implemented");
200    HookResult::success(None)
201}
202
203/// 执行 Agent Hook(占位实现)
204async fn execute_agent_hook(hook: &AgentHookConfig, _input: &HookInput) -> HookResult {
205    // TODO: 实现代理验证器
206    warn!("Agent hook not fully implemented: type={}", hook.agent_type);
207    HookResult::success(None)
208}
209
210/// 执行单个 hook
211async fn execute_hook(hook: &HookConfig, input: &HookInput) -> HookResult {
212    match hook {
213        HookConfig::Command(c) => execute_command_hook(c, input).await,
214        HookConfig::Url(c) => execute_url_hook(c, input).await,
215        HookConfig::Mcp(c) => execute_mcp_hook(c, input).await,
216        HookConfig::Prompt(c) => execute_prompt_hook(c, input).await,
217        HookConfig::Agent(c) => execute_agent_hook(c, input).await,
218    }
219}
220
221/// 运行所有匹配的 hooks
222pub async fn run_hooks(input: HookInput) -> Vec<HookResult> {
223    let event = match input.event {
224        Some(e) => e,
225        None => return vec![],
226    };
227
228    let registry = global_registry();
229    let matching_hooks = registry.get_matching(event, input.tool_name.as_deref());
230    let mut results = Vec::new();
231
232    for hook in &matching_hooks {
233        let result = execute_hook(hook, &input).await;
234        let is_blocked = result.blocked;
235        let is_blocking = hook.is_blocking();
236        results.push(result);
237
238        // 如果 hook 阻塞且是 blocking 类型,停止执行后续 hooks
239        if is_blocked && is_blocking {
240            break;
241        }
242    }
243
244    results
245}
246
247/// 检查是否有任何 hook 阻塞操作
248pub fn is_blocked(results: &[HookResult]) -> (bool, Option<String>) {
249    for result in results {
250        if result.blocked {
251            return (true, result.block_message.clone());
252        }
253    }
254    (false, None)
255}
256
257/// PreToolUse hook 辅助函数
258pub async fn run_pre_tool_use_hooks(
259    tool_name: &str,
260    tool_input: Option<serde_json::Value>,
261    session_id: Option<String>,
262) -> (bool, Option<String>) {
263    let results = run_hooks(HookInput {
264        event: Some(HookEvent::PreToolUse),
265        tool_name: Some(tool_name.to_string()),
266        tool_input,
267        session_id,
268        ..Default::default()
269    })
270    .await;
271
272    let (blocked, message) = is_blocked(&results);
273    (!blocked, message)
274}
275
276/// PostToolUse hook 辅助函数
277pub async fn run_post_tool_use_hooks(
278    tool_name: &str,
279    tool_input: Option<serde_json::Value>,
280    tool_output: String,
281    session_id: Option<String>,
282) {
283    let _ = run_hooks(HookInput {
284        event: Some(HookEvent::PostToolUse),
285        tool_name: Some(tool_name.to_string()),
286        tool_input,
287        tool_output: Some(tool_output),
288        session_id,
289        ..Default::default()
290    })
291    .await;
292}
293
294/// UserPromptSubmit hook
295pub async fn run_user_prompt_submit_hooks(
296    prompt: &str,
297    session_id: Option<String>,
298) -> (bool, Option<String>) {
299    let results = run_hooks(HookInput {
300        event: Some(HookEvent::UserPromptSubmit),
301        message: Some(prompt.to_string()),
302        session_id,
303        ..Default::default()
304    })
305    .await;
306
307    let (blocked, message) = is_blocked(&results);
308    (!blocked, message)
309}
310
311/// Stop hook
312pub async fn run_stop_hooks(reason: Option<String>, session_id: Option<String>) {
313    let _ = run_hooks(HookInput {
314        event: Some(HookEvent::Stop),
315        message: reason,
316        session_id,
317        ..Default::default()
318    })
319    .await;
320}
321
322/// PreCompact hook
323pub async fn run_pre_compact_hooks(
324    session_id: Option<String>,
325    current_tokens: Option<u64>,
326    trigger: Option<CompactTrigger>,
327) -> (bool, Option<String>) {
328    let results = run_hooks(HookInput {
329        event: Some(HookEvent::PreCompact),
330        current_tokens,
331        trigger,
332        session_id,
333        ..Default::default()
334    })
335    .await;
336
337    let (blocked, message) = is_blocked(&results);
338    (!blocked, message)
339}
340
341/// PostToolUseFailure hook
342#[allow(clippy::too_many_arguments)]
343pub async fn run_post_tool_use_failure_hooks(
344    tool_name: &str,
345    tool_input: Option<serde_json::Value>,
346    tool_use_id: String,
347    error: String,
348    error_type: HookErrorType,
349    is_interrupt: bool,
350    is_timeout: bool,
351    session_id: Option<String>,
352) {
353    let _ = run_hooks(HookInput {
354        event: Some(HookEvent::PostToolUseFailure),
355        tool_name: Some(tool_name.to_string()),
356        tool_input,
357        tool_use_id: Some(tool_use_id),
358        error: Some(error),
359        error_type: Some(error_type),
360        is_interrupt: Some(is_interrupt),
361        is_timeout: Some(is_timeout),
362        session_id,
363        ..Default::default()
364    })
365    .await;
366}
367
368/// SessionStart hook
369pub async fn run_session_start_hooks(session_id: String, source: Option<SessionSource>) {
370    let _ = run_hooks(HookInput {
371        event: Some(HookEvent::SessionStart),
372        source,
373        session_id: Some(session_id),
374        ..Default::default()
375    })
376    .await;
377}
378
379/// SessionEnd hook
380pub async fn run_session_end_hooks(session_id: String, reason: Option<SessionEndReason>) {
381    let _ = run_hooks(HookInput {
382        event: Some(HookEvent::SessionEnd),
383        reason,
384        session_id: Some(session_id),
385        ..Default::default()
386    })
387    .await;
388}
389
390/// SubagentStart hook
391pub async fn run_subagent_start_hooks(
392    agent_id: String,
393    agent_type: String,
394    session_id: Option<String>,
395) {
396    let _ = run_hooks(HookInput {
397        event: Some(HookEvent::SubagentStart),
398        agent_id: Some(agent_id),
399        agent_type: Some(agent_type),
400        session_id,
401        ..Default::default()
402    })
403    .await;
404}
405
406/// SubagentStop hook
407pub async fn run_subagent_stop_hooks(
408    agent_id: String,
409    agent_type: String,
410    result: Option<serde_json::Value>,
411    session_id: Option<String>,
412) {
413    let _ = run_hooks(HookInput {
414        event: Some(HookEvent::SubagentStop),
415        agent_id: Some(agent_id),
416        agent_type: Some(agent_type),
417        result,
418        session_id,
419        ..Default::default()
420    })
421    .await;
422}
423
424/// PermissionRequest hook
425pub async fn run_permission_request_hooks(
426    tool_name: &str,
427    tool_input: Option<serde_json::Value>,
428    tool_use_id: Option<String>,
429    session_id: Option<String>,
430) -> (Option<HookDecision>, Option<String>) {
431    let results = run_hooks(HookInput {
432        event: Some(HookEvent::PermissionRequest),
433        tool_name: Some(tool_name.to_string()),
434        tool_input,
435        tool_use_id,
436        session_id,
437        ..Default::default()
438    })
439    .await;
440
441    for result in &results {
442        if let Some(output) = &result.output {
443            if let Ok(json) = serde_json::from_str::<serde_json::Value>(output) {
444                if let Some(decision) = json.get("decision").and_then(|v| v.as_str()) {
445                    let d = match decision {
446                        "allow" => HookDecision::Allow,
447                        "deny" => HookDecision::Deny,
448                        _ => continue,
449                    };
450                    let message = json
451                        .get("message")
452                        .and_then(|v| v.as_str())
453                        .map(|s| s.to_string());
454                    return (Some(d), message);
455                }
456            }
457        }
458    }
459
460    (None, None)
461}
462
463/// Notification hook
464pub async fn run_notification_hooks(
465    message: &str,
466    notification_type: Option<NotificationType>,
467    session_id: Option<String>,
468) {
469    let _ = run_hooks(HookInput {
470        event: Some(HookEvent::Notification),
471        message: Some(message.to_string()),
472        notification_type,
473        session_id,
474        ..Default::default()
475    })
476    .await;
477}