1#![allow(dead_code)]
3
4use std::sync::Arc;
5use uuid::Uuid;
6
7use crate::types::Message;
8use crate::utils::hooks::hook_helpers::{HookResponse, add_arguments_to_prompt, hook_response_schema};
9
/// Outcome of evaluating a single prompt hook (produced by `exec_prompt_hook`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookResult {
    /// The model reported that the hook's condition was met.
    Success {
        hook_name: String,
        hook_event: String,
        tool_use_id: String,
    },
    /// The model reported that the hook's condition was not met.
    Blocking {
        /// Human-readable message, e.g. "Prompt hook condition was not met: <reason>".
        blocking_error: String,
        /// The hook's configured prompt text.
        command: String,
        /// Whether the caller should stop continuing; `exec_prompt_hook` sets this to `true`.
        prevent_continuation: bool,
        /// The model-supplied reason (empty when the model gave none).
        stop_reason: String,
    },
    /// The hook was cancelled before a verdict was available (e.g. it timed out).
    Cancelled,
    /// The hook could not be evaluated: the request failed or the model's reply was
    /// not valid hook-response JSON. Does not block the caller.
    NonBlockingError {
        hook_name: String,
        hook_event: String,
        tool_use_id: String,
        /// Short description of what went wrong.
        stderr: String,
        /// The raw model reply, when one was received.
        stdout: String,
        exit_code: i32,
    },
}
33
/// Configuration of a prompt-based hook whose condition is judged by a model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptHook {
    /// Prompt sent to the model; combined with the hook's JSON input by
    /// `add_arguments_to_prompt` before being sent.
    pub prompt: String,
    /// Timeout in seconds; `exec_prompt_hook` falls back to 30 seconds when unset.
    pub timeout: Option<u64>,
    /// Model override; when `None`, the small/fast default model is used.
    pub model: Option<String>,
}
43
44pub async fn exec_prompt_hook(
46 hook: &PromptHook,
47 hook_name: &str,
48 hook_event: &str,
49 json_input: &str,
50 _signal: tokio::sync::watch::Receiver<bool>,
51 tool_use_context: Arc<crate::utils::hooks::can_use_tool::ToolUseContext>,
52 messages: Option<&[Message]>,
53 tool_use_id: Option<String>,
54) -> HookResult {
55 let effective_tool_use_id = tool_use_id.unwrap_or_else(|| format!("hook-{}", Uuid::new_v4()));
57
58 let processed_prompt = add_arguments_to_prompt(&hook.prompt, json_input);
60 log_for_debugging(&format!(
61 "Hooks: Processing prompt hook with prompt: {}",
62 processed_prompt.chars().take(200).collect::<String>()
63 ));
64
65 let user_message = create_user_message(&processed_prompt);
67
68 let messages_to_query: Vec<serde_json::Value> = if let Some(msgs) = messages {
70 let mut msg_vec: Vec<serde_json::Value> = msgs.iter().map(|m| message_to_json(m)).collect();
71 msg_vec.push(message_to_json_user(&user_message));
72 msg_vec
73 } else {
74 vec![message_to_json_user(&user_message)]
75 };
76
77 log_for_debugging(&format!(
78 "Hooks: Querying model with {} messages",
79 messages_to_query.len()
80 ));
81
82 let hook_timeout_ms = hook.timeout.map_or(30_000, |t| t * 1000);
84
85 let (abort_tx, abort_rx) = tokio::sync::watch::channel(false);
87
88 let timeout_handle = tokio::spawn(async move {
90 tokio::time::sleep(tokio::time::Duration::from_millis(hook_timeout_ms)).await;
91 let _ = abort_tx.send(true);
92 });
93
94 let model = hook.model.clone().unwrap_or_else(get_small_fast_model);
96 let system_prompt = r#"You are evaluating a hook in Claude Code.
97
98Your response must be a JSON object matching one of the following schemas:
991. If the condition is met, return: {"ok": true}
1002. If the condition is not met, return: {"ok": false, "reason": "Reason for why it is not met}"#;
101
102 let response =
104 query_model_without_streaming(&messages_to_query, system_prompt, &model, &tool_use_context)
105 .await;
106
107 timeout_handle.abort();
108
109 if *abort_rx.borrow() {
111 return HookResult::Cancelled;
112 }
113
114 match response {
115 Ok(content) => {
116 let full_response = content.trim();
118 log_for_debugging(&format!("Hooks: Model response: {}", full_response));
119
120 let json = match serde_json::from_str::<serde_json::Value>(full_response) {
122 Ok(j) => j,
123 Err(_) => {
124 log_for_debugging(&format!(
125 "Hooks: error parsing response as JSON: {}",
126 full_response
127 ));
128 return HookResult::NonBlockingError {
129 hook_name: hook_name.to_string(),
130 hook_event: hook_event.to_string(),
131 tool_use_id: effective_tool_use_id,
132 stderr: "JSON validation failed".to_string(),
133 stdout: full_response.to_string(),
134 exit_code: 1,
135 };
136 }
137 };
138
139 let parsed = serde_json::from_value::<HookResponse>(json.clone());
141 match parsed {
142 Ok(hook_resp) => {
143 if !hook_resp.ok {
145 let reason = hook_resp.reason.unwrap_or_default();
146 log_for_debugging(&format!(
147 "Hooks: Prompt hook condition was not met: {}",
148 reason
149 ));
150 return HookResult::Blocking {
151 blocking_error: format!(
152 "Prompt hook condition was not met: {}",
153 reason
154 ),
155 command: hook.prompt.clone(),
156 prevent_continuation: true,
157 stop_reason: reason,
158 };
159 }
160
161 log_for_debugging("Hooks: Prompt hook condition was met");
163 return HookResult::Success {
164 hook_name: hook_name.to_string(),
165 hook_event: hook_event.to_string(),
166 tool_use_id: effective_tool_use_id,
167 };
168 }
169 Err(err) => {
170 log_for_debugging(&format!(
171 "Hooks: model response does not conform to expected schema: {}",
172 err
173 ));
174 return HookResult::NonBlockingError {
175 hook_name: hook_name.to_string(),
176 hook_event: hook_event.to_string(),
177 tool_use_id: effective_tool_use_id,
178 stderr: format!("Schema validation failed: {}", err),
179 stdout: full_response.to_string(),
180 exit_code: 1,
181 };
182 }
183 }
184 }
185 Err(e) => {
186 log_for_debugging(&format!("Hooks: Prompt hook error: {}", e));
187 return HookResult::NonBlockingError {
188 hook_name: hook_name.to_string(),
189 hook_event: hook_event.to_string(),
190 tool_use_id: effective_tool_use_id,
191 stderr: format!("Error executing prompt hook: {}", e),
192 stdout: String::new(),
193 exit_code: 1,
194 };
195 }
196 }
197}
198
199fn create_user_message(content: &str) -> Message {
201 Message {
202 role: crate::types::api_types::MessageRole::User,
203 content: content.to_string(),
204 attachments: None,
205 tool_call_id: None,
206 tool_calls: None,
207 is_error: None,
208 is_meta: None,
209 is_api_error_message: None,
210 error_details: None,
211 uuid: None,
212 }
213}
214
215fn message_to_json(msg: &Message) -> serde_json::Value {
217 serde_json::json!({
218 "role": msg.role.as_str(),
219 "content": &msg.content
220 })
221}
222
223fn message_to_json_user(msg: &Message) -> serde_json::Value {
225 serde_json::json!({
226 "role": "user",
227 "content": &msg.content
228 })
229}
230
/// Returns the model used when a prompt hook does not name one explicitly.
fn get_small_fast_model() -> String {
    const SMALL_FAST_MODEL: &str = "claude-3-haiku-20240307";
    SMALL_FAST_MODEL.to_owned()
}
235
236async fn query_model_without_streaming(
238 messages: &[serde_json::Value],
239 system_prompt: &str,
240 model: &str,
241 _tool_use_context: &crate::utils::hooks::can_use_tool::ToolUseContext,
242) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
243 let base_url = std::env::var("AI_API_BASE_URL").unwrap_or_else(|_| "https://api.anthropic.com".to_string());
245 let api_key = std::env::var("AI_AUTH_TOKEN")
246 .or_else(|_| std::env::var("ANTHROPIC_API_KEY"))
247 .or_else(|_| std::env::var("ANTHROPIC_AUTH_TOKEN"))
248 .map_err(|e| format!("No API key found: {}", e))?;
249
250 let url = format!("{}/v1/messages", base_url);
251 let is_anthropic = base_url.contains("anthropic.com");
252
253 let request_body = serde_json::json!({
254 "model": model,
255 "max_tokens": 4096,
256 "system": [{"type": "text", "text": system_prompt}],
257 "messages": messages,
258 "temperature": 0.0,
259 "output": {
260 "type": "json_schema",
261 "name": "hook_response",
262 "schema": hook_response_schema(),
263 "strict": true
264 }
265 });
266
267 let client = reqwest::Client::new();
268 let mut req_builder = client.post(&url).json(&request_body)
269 .header("Content-Type", "application/json");
270
271 if is_anthropic {
272 req_builder = req_builder
273 .header("x-api-key", &api_key)
274 .header("anthropic-version", "2023-06-01");
275 } else {
276 req_builder = req_builder.header("Authorization", format!("Bearer {}", api_key));
277 }
278
279 let response = req_builder.send().await?;
280 let status = response.status();
281 let body = response.text().await?;
282
283 if !status.is_success() {
284 return Err(format!("API error {}: {}", status, body).into());
285 }
286
287 let parsed: serde_json::Value = serde_json::from_str(&body)
289 .map_err(|e| format!("Failed to parse API response: {}", e))?;
290
291 let text = extract_text(&parsed);
292 if text.is_empty() {
293 return Err("Empty response from model".into());
294 }
295
296 Ok(text)
297}
298
299fn extract_text(response: &serde_json::Value) -> String {
301 if let Some(content) = response.get("choices").and_then(|c| c.as_array())
303 .and_then(|c| c.first())
304 .and_then(|c| c.get("message"))
305 .and_then(|m| m.get("content"))
306 .and_then(|c| c.as_str()) {
307 return content.to_string();
308 }
309 if let Some(blocks) = response.get("content").and_then(|c| c.as_array()) {
311 let mut texts = Vec::new();
312 for block in blocks {
313 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
314 texts.push(text.to_string());
315 }
316 }
317 if !texts.is_empty() {
318 return texts.join("\n");
319 }
320 }
321 String::new()
322}
323
324fn log_for_debugging(msg: &str) {
326 log::debug!("{}", msg);
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::api_types::MessageRole;

    /// Builds a message with the given role and content and no optional metadata.
    fn message(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            attachments: None,
            tool_call_id: None,
            tool_calls: None,
            is_error: None,
            is_meta: None,
            is_api_error_message: None,
            error_details: None,
            uuid: None,
        }
    }

    #[test]
    fn test_extract_text_anthropic() {
        let body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Hello from Anthropic"},
                {"type": "text", "text": "Second block"}
            ]
        });
        assert_eq!(extract_text(&body), "Hello from Anthropic\nSecond block");
    }

    #[test]
    fn test_extract_text_anthropic_single_block() {
        let body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Single block"}
            ]
        });
        assert_eq!(extract_text(&body), "Single block");
    }

    #[test]
    fn test_extract_text_openai() {
        let body = serde_json::json!({
            "choices": [
                {"message": {"content": "Hello from OpenAI"}}
            ]
        });
        assert_eq!(extract_text(&body), "Hello from OpenAI");
    }

    #[test]
    fn test_extract_text_empty() {
        assert_eq!(extract_text(&serde_json::json!({})), "");
    }

    #[test]
    fn test_extract_text_no_text_blocks() {
        let body = serde_json::json!({
            "content": [
                {"type": "tool_use", "name": "some_tool", "input": {}}
            ]
        });
        assert_eq!(extract_text(&body), "");
    }

    #[test]
    fn test_message_to_json_user() {
        let value = message_to_json(&message(MessageRole::User, "test content"));
        assert_eq!(value["role"], "user");
        assert_eq!(value["content"], "test content");
    }

    #[test]
    fn test_message_to_json_assistant() {
        let value = message_to_json(&message(MessageRole::Assistant, "assistant reply"));
        assert_eq!(value["role"], "assistant");
        assert_eq!(value["content"], "assistant reply");
    }

    #[test]
    fn test_message_to_json_user_forces_user_role() {
        let value = message_to_json_user(&message(MessageRole::Assistant, "should be user"));
        assert_eq!(value["role"], "user");
        assert_eq!(value["content"], "should be user");
    }

    #[test]
    fn test_role_to_str() {
        let cases = [
            (MessageRole::User, "user"),
            (MessageRole::Assistant, "assistant"),
            (MessageRole::Tool, "tool"),
            (MessageRole::System, "system"),
        ];
        for (role, expected) in cases.iter() {
            assert_eq!(role.as_str(), *expected);
        }
    }
}