Skip to main content

ai_agent/utils/hooks/
exec_prompt_hook.rs

1// Source: ~/claudecode/openclaudecode/src/utils/hooks/execPromptHook.ts
2#![allow(dead_code)]
3
4use std::sync::Arc;
5use uuid::Uuid;
6
7use crate::types::Message;
8use crate::utils::hooks::hook_helpers::{HookResponse, add_arguments_to_prompt, hook_response_schema};
9
/// Outcome of executing a prompt hook.
///
/// - `Success`: the model judged the hook's condition to be met.
/// - `Blocking`: the condition was not met; carries the model's reason.
/// - `Cancelled`: the hook was aborted before a verdict was used (e.g. its timeout elapsed).
/// - `NonBlockingError`: the hook could not be evaluated (API failure, unparsable or
///   schema-violating reply), reported in a command-like `stderr`/`stdout`/`exit_code` shape.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// The hook's condition was met.
    Success {
        hook_name: String,
        hook_event: String,
        tool_use_id: String,
    },
    /// The hook's condition was not met.
    Blocking {
        /// Human-readable message describing why the hook blocked.
        blocking_error: String,
        /// The hook definition that produced the block (the raw prompt for prompt hooks).
        command: String,
        /// Set to `true` by prompt hooks when they block.
        prevent_continuation: bool,
        /// The reason reported by the model, without the message prefix.
        stop_reason: String,
    },
    /// The hook was aborted.
    Cancelled,
    /// The hook failed to produce a usable verdict.
    NonBlockingError {
        hook_name: String,
        hook_event: String,
        tool_use_id: String,
        /// Description of what went wrong.
        stderr: String,
        /// Raw model reply when one was received, otherwise empty.
        stdout: String,
        exit_code: i32,
    },
}
33
/// Configuration of a prompt hook: a prompt the model evaluates as a yes/no condition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptHook {
    /// Prompt sent to the model; `$ARGUMENTS` is replaced with the hook's JSON input.
    pub prompt: String,
    /// Timeout in seconds; `exec_prompt_hook` uses 30 seconds when this is `None`.
    pub timeout: Option<u64>,
    /// Model to query; `None` selects the default small/fast model.
    pub model: Option<String>,
}
43
44/// Execute a prompt-based hook using an LLM
45pub async fn exec_prompt_hook(
46    hook: &PromptHook,
47    hook_name: &str,
48    hook_event: &str,
49    json_input: &str,
50    _signal: tokio::sync::watch::Receiver<bool>,
51    tool_use_context: Arc<crate::utils::hooks::can_use_tool::ToolUseContext>,
52    messages: Option<&[Message]>,
53    tool_use_id: Option<String>,
54) -> HookResult {
55    // Use provided tool_use_id or generate a new one
56    let effective_tool_use_id = tool_use_id.unwrap_or_else(|| format!("hook-{}", Uuid::new_v4()));
57
58    // Replace $ARGUMENTS with the JSON input
59    let processed_prompt = add_arguments_to_prompt(&hook.prompt, json_input);
60    log_for_debugging(&format!(
61        "Hooks: Processing prompt hook with prompt: {}",
62        processed_prompt.chars().take(200).collect::<String>()
63    ));
64
65    // Create user message directly
66    let user_message = create_user_message(&processed_prompt);
67
68    // Prepend conversation history if provided
69    let messages_to_query: Vec<serde_json::Value> = if let Some(msgs) = messages {
70        let mut msg_vec: Vec<serde_json::Value> = msgs.iter().map(|m| message_to_json(m)).collect();
71        msg_vec.push(message_to_json_user(&user_message));
72        msg_vec
73    } else {
74        vec![message_to_json_user(&user_message)]
75    };
76
77    log_for_debugging(&format!(
78        "Hooks: Querying model with {} messages",
79        messages_to_query.len()
80    ));
81
82    // Query the model with a small fast model
83    let hook_timeout_ms = hook.timeout.map_or(30_000, |t| t * 1000);
84
85    // Create abort channel
86    let (abort_tx, abort_rx) = tokio::sync::watch::channel(false);
87
88    // Setup timeout
89    let timeout_handle = tokio::spawn(async move {
90        tokio::time::sleep(tokio::time::Duration::from_millis(hook_timeout_ms)).await;
91        let _ = abort_tx.send(true);
92    });
93
94    // Build the query
95    let model = hook.model.clone().unwrap_or_else(get_small_fast_model);
96    let system_prompt = r#"You are evaluating a hook in Claude Code.
97
98Your response must be a JSON object matching one of the following schemas:
991. If the condition is met, return: {"ok": true}
1002. If the condition is not met, return: {"ok": false, "reason": "Reason for why it is not met}"#;
101
102    // Make the API call
103    let response =
104        query_model_without_streaming(&messages_to_query, system_prompt, &model, &tool_use_context)
105            .await;
106
107    timeout_handle.abort();
108
109    // Check if aborted
110    if *abort_rx.borrow() {
111        return HookResult::Cancelled;
112    }
113
114    match response {
115        Ok(content) => {
116            // Update response length for spinner display (not applicable in Rust)
117            let full_response = content.trim();
118            log_for_debugging(&format!("Hooks: Model response: {}", full_response));
119
120            // Parse JSON response
121            let json = match serde_json::from_str::<serde_json::Value>(full_response) {
122                Ok(j) => j,
123                Err(_) => {
124                    log_for_debugging(&format!(
125                        "Hooks: error parsing response as JSON: {}",
126                        full_response
127                    ));
128                    return HookResult::NonBlockingError {
129                        hook_name: hook_name.to_string(),
130                        hook_event: hook_event.to_string(),
131                        tool_use_id: effective_tool_use_id,
132                        stderr: "JSON validation failed".to_string(),
133                        stdout: full_response.to_string(),
134                        exit_code: 1,
135                    };
136                }
137            };
138
139            // Validate against hook response schema
140            let parsed = serde_json::from_value::<HookResponse>(json.clone());
141            match parsed {
142                Ok(hook_resp) => {
143                    // Failed to meet condition
144                    if !hook_resp.ok {
145                        let reason = hook_resp.reason.unwrap_or_default();
146                        log_for_debugging(&format!(
147                            "Hooks: Prompt hook condition was not met: {}",
148                            reason
149                        ));
150                        return HookResult::Blocking {
151                            blocking_error: format!(
152                                "Prompt hook condition was not met: {}",
153                                reason
154                            ),
155                            command: hook.prompt.clone(),
156                            prevent_continuation: true,
157                            stop_reason: reason,
158                        };
159                    }
160
161                    // Condition was met
162                    log_for_debugging("Hooks: Prompt hook condition was met");
163                    return HookResult::Success {
164                        hook_name: hook_name.to_string(),
165                        hook_event: hook_event.to_string(),
166                        tool_use_id: effective_tool_use_id,
167                    };
168                }
169                Err(err) => {
170                    log_for_debugging(&format!(
171                        "Hooks: model response does not conform to expected schema: {}",
172                        err
173                    ));
174                    return HookResult::NonBlockingError {
175                        hook_name: hook_name.to_string(),
176                        hook_event: hook_event.to_string(),
177                        tool_use_id: effective_tool_use_id,
178                        stderr: format!("Schema validation failed: {}", err),
179                        stdout: full_response.to_string(),
180                        exit_code: 1,
181                    };
182                }
183            }
184        }
185        Err(e) => {
186            log_for_debugging(&format!("Hooks: Prompt hook error: {}", e));
187            return HookResult::NonBlockingError {
188                hook_name: hook_name.to_string(),
189                hook_event: hook_event.to_string(),
190                tool_use_id: effective_tool_use_id,
191                stderr: format!("Error executing prompt hook: {}", e),
192                stdout: String::new(),
193                exit_code: 1,
194            };
195        }
196    }
197}
198
199/// Create a user message with the given content
200fn create_user_message(content: &str) -> Message {
201    Message {
202        role: crate::types::api_types::MessageRole::User,
203        content: content.to_string(),
204        attachments: None,
205        tool_call_id: None,
206        tool_calls: None,
207        is_error: None,
208        is_meta: None,
209        is_api_error_message: None,
210        error_details: None,
211        uuid: None,
212    }
213}
214
215/// Convert Message to JSON value
216fn message_to_json(msg: &Message) -> serde_json::Value {
217    serde_json::json!({
218        "role": msg.role.as_str(),
219        "content": &msg.content
220    })
221}
222
223/// Convert user message struct to JSON value (forces role to "user")
224fn message_to_json_user(msg: &Message) -> serde_json::Value {
225    serde_json::json!({
226        "role": "user",
227        "content": &msg.content
228    })
229}
230
/// Name of the default model used for lightweight hook evaluations.
///
/// This is a fixed Claude 3 Haiku snapshot; no configuration is consulted.
fn get_small_fast_model() -> String {
    const DEFAULT_SMALL_FAST_MODEL: &str = "claude-3-haiku-20240307";
    DEFAULT_SMALL_FAST_MODEL.to_owned()
}
235
236/// Query model without streaming — makes a real non-streaming API call
237async fn query_model_without_streaming(
238    messages: &[serde_json::Value],
239    system_prompt: &str,
240    model: &str,
241    _tool_use_context: &crate::utils::hooks::can_use_tool::ToolUseContext,
242) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
243    // Get API credentials
244    let base_url = std::env::var("AI_API_BASE_URL").unwrap_or_else(|_| "https://api.anthropic.com".to_string());
245    let api_key = std::env::var("AI_AUTH_TOKEN")
246        .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
247        .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
248        .map_err(|e| format!("No API key found: {}", e))?;
249
250    let url = format!("{}/v1/messages", base_url);
251    let is_anthropic = base_url.contains("anthropic.com");
252
253    let request_body = serde_json::json!({
254        "model": model,
255        "max_tokens": 4096,
256        "system": [{"type": "text", "text": system_prompt}],
257        "messages": messages,
258        "temperature": 0.0,
259        "output": {
260            "type": "json_schema",
261            "name": "hook_response",
262            "schema": hook_response_schema(),
263            "strict": true
264        }
265    });
266
267    let client = reqwest::Client::new();
268    let mut req_builder = client.post(&url).json(&request_body)
269        .header("Content-Type", "application/json");
270
271    if is_anthropic {
272        req_builder = req_builder
273            .header("x-api-key", &api_key)
274            .header("anthropic-version", "2023-06-01");
275    } else {
276        req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
277    }
278
279    let response = req_builder.send().await?;
280    let status = response.status();
281    let body = response.text().await?;
282
283    if !status.is_success() {
284        return Err(format!("API error {}: {}", status, body).into());
285    }
286
287    // Extract text content from response
288    let parsed: serde_json::Value = serde_json::from_str(&body)
289        .map_err(|e| format!("Failed to parse API response: {}", e))?;
290
291    let text = extract_text(&parsed);
292    if text.is_empty() {
293        return Err("Empty response from model".into());
294    }
295
296    Ok(text)
297}
298
299/// Extract text content from an API response (supports both Anthropic and OpenAI formats)
300fn extract_text(response: &serde_json::Value) -> String {
301    // OpenAI format: choices[].message.content
302    if let Some(content) = response.get("choices").and_then(|c| c.as_array())
303        .and_then(|c| c.first())
304        .and_then(|c| c.get("message"))
305        .and_then(|m| m.get("content"))
306        .and_then(|c| c.as_str()) {
307        return content.to_string();
308    }
309    // Anthropic format: content[].text
310    if let Some(blocks) = response.get("content").and_then(|c| c.as_array()) {
311        let mut texts = Vec::new();
312        for block in blocks {
313            if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
314                texts.push(text.to_string());
315            }
316        }
317        if !texts.is_empty() {
318            return texts.join("\n");
319        }
320    }
321    String::new()
322}
323
324/// Log for debugging
325fn log_for_debugging(msg: &str) {
326    log::debug!("{}", msg);
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::api_types::MessageRole;

    /// Build a message with only `role` and `content` set; every metadata field is `None`.
    fn make_message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            attachments: None,
            tool_call_id: None,
            tool_calls: None,
            is_error: None,
            is_meta: None,
            is_api_error_message: None,
            error_details: None,
            uuid: None,
        }
    }

    #[test]
    fn test_extract_text_anthropic() {
        let response = serde_json::json!({
            "content": [
                {"type": "text", "text": "Hello from Anthropic"},
                {"type": "text", "text": "Second block"}
            ]
        });
        assert_eq!(extract_text(&response), "Hello from Anthropic\nSecond block");
    }

    #[test]
    fn test_extract_text_anthropic_single_block() {
        let response = serde_json::json!({
            "content": [
                {"type": "text", "text": "Single block"}
            ]
        });
        assert_eq!(extract_text(&response), "Single block");
    }

    #[test]
    fn test_extract_text_openai() {
        let response = serde_json::json!({
            "choices": [
                {
                    "message": {
                        "content": "Hello from OpenAI"
                    }
                }
            ]
        });
        assert_eq!(extract_text(&response), "Hello from OpenAI");
    }

    #[test]
    fn test_extract_text_openai_null_content_falls_back_to_anthropic() {
        let response = serde_json::json!({
            "choices": [{"message": {"content": null}}],
            "content": [{"type": "text", "text": "fallback"}]
        });
        assert_eq!(extract_text(&response), "fallback");
    }

    #[test]
    fn test_extract_text_empty() {
        let response = serde_json::json!({});
        assert_eq!(extract_text(&response), "");
    }

    #[test]
    fn test_extract_text_no_text_blocks() {
        let response = serde_json::json!({
            "content": [
                {"type": "tool_use", "name": "some_tool", "input": {}}
            ]
        });
        assert_eq!(extract_text(&response), "");
    }

    // Exercises `message_to_json` (role passthrough) on a user message; the
    // role-forcing helper is covered by `test_message_to_json_user_forces_user_role`.
    #[test]
    fn test_message_to_json_keeps_user_role() {
        let json = message_to_json(&make_message(MessageRole::User, "test content"));
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "test content");
    }

    #[test]
    fn test_message_to_json_assistant() {
        let json = message_to_json(&make_message(MessageRole::Assistant, "assistant reply"));
        assert_eq!(json["role"], "assistant");
        assert_eq!(json["content"], "assistant reply");
    }

    #[test]
    fn test_message_to_json_user_forces_user_role() {
        let json =
            message_to_json_user(&make_message(MessageRole::Assistant, "should be user"));
        assert_eq!(json["role"], "user");
        assert_eq!(json["content"], "should be user");
    }

    #[test]
    fn test_create_user_message() {
        let msg = create_user_message("hello");
        assert_eq!(msg.role.as_str(), "user");
        assert_eq!(msg.content, "hello");
        assert!(msg.uuid.is_none());
    }

    #[test]
    fn test_role_to_str() {
        assert_eq!(MessageRole::User.as_str(), "user");
        assert_eq!(MessageRole::Assistant.as_str(), "assistant");
        assert_eq!(MessageRole::Tool.as_str(), "tool");
        assert_eq!(MessageRole::System.as_str(), "system");
    }
}