Skip to main content

ai_agent/utils/hooks/
api_query_hook_helper.rs

1// Source: ~/claudecode/openclaudecode/src/utils/hooks/apiQueryHookHelper.ts
2#![allow(dead_code)]
3
4use serde::{Deserialize, Serialize};
5use std::future::Future;
6use std::sync::Arc;
7use uuid::Uuid;
8
9use crate::types::{Message, MessageRole};
10
/// Ordered system-prompt segments, one owned `String` per segment.
pub type SystemPrompt = Vec<String>;
13
14/// Context for REPL hooks (both post-sampling and stop hooks)
15#[derive(Clone)]
16pub struct ReplHookContext {
17    /// Full message history including assistant responses
18    pub messages: Vec<Message>,
19    /// System prompt
20    pub system_prompt: SystemPrompt,
21    /// User context key-value pairs
22    pub user_context: std::collections::HashMap<String, String>,
23    /// System context key-value pairs
24    pub system_context: std::collections::HashMap<String, String>,
25    /// Tool use context
26    pub tool_use_context: Arc<crate::utils::hooks::can_use_tool::ToolUseContext>,
27    /// Query source identifier
28    pub query_source: Option<String>,
29    /// Optional: message count for API queries
30    pub query_message_count: Option<usize>,
31}
32
33/// Configuration for an API query hook
34pub struct ApiQueryHookConfig<TResult> {
35    /// Query source name
36    pub name: String,
37    /// Whether this hook should run
38    pub should_run: Box<
39        dyn Fn(&ReplHookContext) -> std::pin::Pin<Box<dyn Future<Output = bool> + Send>>
40            + Send
41            + Sync,
42    >,
43    /// Build the complete message list to send to the API
44    pub build_messages: Box<dyn Fn(&ReplHookContext) -> Vec<Message> + Send + Sync>,
45    /// Optional: override system prompt (defaults to context.system_prompt)
46    pub system_prompt: Option<SystemPrompt>,
47    /// Optional: whether to use tools from context (defaults to true)
48    pub use_tools: Option<bool>,
49    /// Parse the response content into a result
50    pub parse_response: Box<dyn Fn(&str, &ReplHookContext) -> TResult + Send + Sync>,
51    /// Log the result
52    pub log_result: Box<dyn Fn(ApiQueryResult<TResult>, &ReplHookContext) + Send + Sync>,
53    /// Get the model to use (lazy loaded)
54    pub get_model: Box<dyn Fn(&ReplHookContext) -> String + Send + Sync>,
55}
56
/// Outcome of one API query hook run, handed to `ApiQueryHookConfig::log_result`.
pub enum ApiQueryResult<TResult> {
    /// The API call succeeded and `parse_response` returned normally.
    Success {
        /// Name of the hook (`ApiQueryHookConfig::name`).
        query_name: String,
        /// Value produced by `parse_response`.
        result: TResult,
        /// Message id reported by the API.
        message_id: String,
        /// Model the request was sent to.
        model: String,
        /// Identifier generated fresh for this run.
        uuid: String,
    },
    /// The API call failed, or `parse_response` panicked.
    Error {
        /// Name of the hook (`ApiQueryHookConfig::name`).
        query_name: String,
        /// What went wrong.
        error: Box<dyn std::error::Error + Send + Sync>,
        /// Identifier generated fresh for this run.
        uuid: String,
    },
}
72
73/// Create an API query hook from the given configuration.
74/// Returns an async function that executes the hook when called.
75pub fn create_api_query_hook<TResult: 'static>(
76    config: ApiQueryHookConfig<TResult>,
77) -> Box<dyn Fn(ReplHookContext) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>
78{
79    let config = Arc::new(config);
80    Box::new(move |context: ReplHookContext| {
81        let config = config.clone();
82        Box::pin(async move {
83            let should_run = (config.should_run)(&context).await;
84            if !should_run {
85                return;
86            }
87
88            let uuid = Uuid::new_v4().to_string();
89
90            // Build messages using the config's build_messages function
91            let messages = (config.build_messages)(&context);
92            // Note: we can't mutate context directly in Rust; the caller
93            // would need to handle query_message_count tracking externally
94
95            // Use config's system prompt if provided, otherwise use context's
96            let system_prompt = config
97                .system_prompt
98                .clone()
99                .unwrap_or_else(|| context.system_prompt.clone());
100
101            // Use config's tools preference (defaults to true = use context tools)
102            // In Rust, tool access would be through the tool_use_context
103
104            // Get model (lazy loaded)
105            let model = (config.get_model)(&context);
106
107            // Make API call - this would use the actual query function
108            // The TS version calls queryModelWithoutStreaming
109            let response_result =
110                query_model_without_streaming_impl(&messages, &system_prompt, &model, &context)
111                    .await;
112
113            match response_result {
114                Ok(response) => {
115                    // Extract text content from response JSON
116                    let content = extract_text_content(&response.content).trim().to_string();
117
118                    let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
119                        (config.parse_response)(&content, &context)
120                    }));
121
122                    match result {
123                        Ok(parsed_result) => {
124                            (config.log_result)(
125                                ApiQueryResult::Success {
126                                    query_name: config.name.clone(),
127                                    result: parsed_result,
128                                    message_id: response.message_id,
129                                    model,
130                                    uuid,
131                                },
132                                &context,
133                            );
134                        }
135                        Err(err) => {
136                            let error = if let Some(s) = err.downcast_ref::<String>() {
137                                Box::new(std::io::Error::new(std::io::ErrorKind::Other, s.clone()))
138                            } else if let Some(s) = err.downcast_ref::<&str>() {
139                                Box::new(std::io::Error::new(
140                                    std::io::ErrorKind::Other,
141                                    s.to_string(),
142                                ))
143                            } else {
144                                Box::new(std::io::Error::new(
145                                    std::io::ErrorKind::Other,
146                                    "Unknown panic in parse_response",
147                                ))
148                            };
149                            (config.log_result)(
150                                ApiQueryResult::Error {
151                                    query_name: config.name.clone(),
152                                    error,
153                                    uuid,
154                                },
155                                &context,
156                            );
157                        }
158                    }
159                }
160                Err(error) => {
161                    log_error(&format!("API query hook error: {}", error));
162                    (config.log_result)(
163                        ApiQueryResult::Error {
164                            query_name: config.name.clone(),
165                            error,
166                            uuid,
167                        },
168                        &context,
169                    );
170                }
171            }
172        })
173    })
174}
175
/// Raw outcome of a successful Messages API call, as consumed by the hook.
struct ApiResponse {
    /// Message id reported by the API.
    message_id: String,
    /// The full response body serialized back to a JSON string (not just the
    /// text); the text is extracted later.
    content: String,
}
181
/// Resolve the API credential from the environment.
///
/// `AI_AUTH_TOKEN`, `ANTHROPIC_API_KEY` and `ANTHROPIC_AUTH_TOKEN` are read in
/// that order. The first one holding a non-empty value wins. Unset, empty and
/// non-UTF-8 values are skipped.
///
/// # Errors
/// Returns a human-readable message when none of the variables yields a key.
fn get_api_key() -> Result<String, String> {
    const CANDIDATES: [&str; 3] = ["AI_AUTH_TOKEN", "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"];

    CANDIDATES
        .iter()
        .filter_map(|name| std::env::var(name).ok())
        .find(|value| !value.is_empty())
        .ok_or_else(|| {
            "No API key found. Set AI_AUTH_TOKEN, ANTHROPIC_API_KEY, or ANTHROPIC_AUTH_TOKEN"
                .to_string()
        })
}
203
204/// Convert a MessageRole to its API-compatible string representation.
205fn role_to_api_string(role: &MessageRole) -> &'static str {
206    match role {
207        MessageRole::User => "user",
208        MessageRole::Assistant => "assistant",
209        MessageRole::Tool => "tool",
210        MessageRole::System => "system",
211    }
212}
213
214/// Make a real non-streaming Anthropic Messages API call.
215/// Follows the same pattern as `make_away_api_request` in away_summary.rs.
216async fn query_model_without_streaming_impl(
217    messages: &[Message],
218    system_prompt: &SystemPrompt,
219    model: &str,
220    _context: &ReplHookContext,
221) -> Result<ApiResponse, Box<dyn std::error::Error + Send + Sync>> {
222    let api_key = get_api_key().map_err(|e| {
223        Box::<dyn std::error::Error + Send + Sync>::from(std::io::Error::new(
224            std::io::ErrorKind::Other,
225            e,
226        ))
227    })?;
228
229    let base_url = std::env::var("AI_API_BASE_URL")
230        .ok()
231        .unwrap_or_else(|| "https://api.anthropic.com".to_string());
232    let url = format!("{}/v1/messages", base_url);
233
234    // Determine if this is Anthropic API or a third-party API
235    let is_anthropic = base_url.contains("anthropic.com");
236
237    // Build API messages from Message structs
238    let api_messages: Vec<serde_json::Value> = messages
239        .iter()
240        .map(|m| {
241            let mut msg_obj = serde_json::json!({
242                "role": role_to_api_string(&m.role),
243                "content": &m.content
244            });
245
246            // Add tool_call_id for tool role messages
247            if m.role == MessageRole::Tool {
248                if let Some(ref tool_call_id) = m.tool_call_id {
249                    msg_obj["tool_use_id"] = serde_json::json!(tool_call_id);
250                }
251            }
252
253            msg_obj
254        })
255        .collect();
256
257    // Build system prompt as Anthropic format
258    let system_prompt_value = serde_json::json!({
259        "type": "text",
260        "text": system_prompt.join("\n")
261    });
262
263    // Build request body
264    let request_body = serde_json::json!({
265        "model": model,
266        "max_tokens": 4096,
267        "system": system_prompt_value,
268        "messages": api_messages,
269        "temperature": 0.0,
270    });
271
272    let client = reqwest::Client::new();
273    let request_builder = if is_anthropic {
274        client
275            .post(&url)
276            .header("x-api-key", &api_key)
277            .header("anthropic-version", "2023-06-01")
278            .header("Content-Type", "application/json")
279            .header("User-Agent", crate::utils::http::get_user_agent())
280            .json(&request_body)
281    } else {
282        client
283            .post(&url)
284            .header("Authorization", format!("Bearer {}", api_key))
285            .header("Content-Type", "application/json")
286            .header("User-Agent", crate::utils::http::get_user_agent())
287            .json(&request_body)
288    };
289
290    let response = request_builder
291        .send()
292        .await
293        .map_err(|e| {
294            Box::<dyn std::error::Error + Send + Sync>::from(std::io::Error::new(
295                std::io::ErrorKind::ConnectionRefused,
296                format!("API request failed: {}", e),
297            ))
298        })?;
299
300    let status = response.status();
301    if !status.is_success() {
302        let error_text = response.text().await.unwrap_or_default();
303        return Err(Box::<dyn std::error::Error + Send + Sync>::from(
304            std::io::Error::new(
305                std::io::ErrorKind::Other,
306                format!("API error {}: {}", status, error_text),
307            ),
308        ));
309    }
310
311    // Parse JSON response
312    let response_json: serde_json::Value = response
313        .json()
314        .await
315        .map_err(|e| {
316            Box::<dyn std::error::Error + Send + Sync>::from(std::io::Error::new(
317                std::io::ErrorKind::InvalidData,
318                format!("Failed to parse API response: {}", e),
319            ))
320        })?;
321
322    // Check for API error in response body
323    if let Some(error) = response_json.get("error") {
324        let error_msg = error
325            .get("message")
326            .and_then(|m| m.as_str())
327            .unwrap_or("Unknown error");
328        return Err(Box::<dyn std::error::Error + Send + Sync>::from(
329            std::io::Error::new(
330                std::io::ErrorKind::Other,
331                format!("API error: {}", error_msg),
332            ),
333        ));
334    }
335
336    // Extract message ID
337    let message_id = response_json
338        .get("id")
339        .and_then(|id| id.as_str())
340        .unwrap_or("unknown")
341        .to_string();
342
343    // Extract raw JSON content for downstream text extraction
344    let content = serde_json::to_string(&response_json).unwrap_or_default();
345
346    Ok(ApiResponse {
347        message_id,
348        content,
349    })
350}
351
352/// Extract text content from API response JSON.
353/// Supports both Anthropic format (content array of blocks with text)
354/// and OpenAI-compatible format (choices[0].message.content).
355/// Falls back to raw string if JSON parsing fails or format is unrecognized.
356fn extract_text_content(response_json: &str) -> String {
357    let Ok(response) = serde_json::from_str::<serde_json::Value>(response_json) else {
358        return response_json.to_string();
359    };
360
361    // OpenAI-compatible: response.choices[0].message.content
362    if let Some(content) = response
363        .get("choices")
364        .and_then(|c| c.as_array())
365        .and_then(|c| c.first())
366        .and_then(|c| c.get("message"))
367        .and_then(|m| m.get("content"))
368        .and_then(|c| c.as_str())
369    {
370        return content.to_string();
371    }
372
373    // Anthropic: response.content[N].text blocks
374    if let Some(blocks) = response.get("content").and_then(|c| c.as_array()) {
375        let mut texts = Vec::new();
376        for block in blocks {
377            if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
378                texts.push(text.to_string());
379            }
380        }
381        if !texts.is_empty() {
382            return texts.join("\n");
383        }
384    }
385
386    response_json.to_string()
387}
388
389/// Log an error (simplified version of logError)
390fn log_error(msg: &str) {
391    log::error!("{}", msg);
392}
393
394/// Create a system prompt from a list of strings
395pub fn as_system_prompt(parts: Vec<&str>) -> SystemPrompt {
396    parts.iter().map(|s| s.to_string()).collect()
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Hook context with no history, no prompt and empty maps, plus a
    /// non-interactive tool-use context for session `"test"`.
    fn minimal_context() -> ReplHookContext {
        ReplHookContext {
            messages: Vec::new(),
            system_prompt: Vec::new(),
            user_context: HashMap::new(),
            system_context: HashMap::new(),
            tool_use_context: Arc::new(crate::utils::hooks::can_use_tool::ToolUseContext {
                session_id: "test".to_string(),
                cwd: None,
                is_non_interactive_session: true,
                options: None,
            }),
            query_source: None,
            query_message_count: None,
        }
    }

    // ---- extract_text_content ----------------------------------------------

    #[test]
    fn test_extract_text_content_anthropic() {
        let body = r#"{
            "id": "msg_abc123",
            "content": [
                {"type": "text", "text": "Hello from Anthropic"},
                {"type": "text", "text": "Second block"}
            ]
        }"#;
        assert_eq!(
            extract_text_content(body),
            "Hello from Anthropic\nSecond block"
        );
    }

    #[test]
    fn test_extract_text_content_anthropic_single_block() {
        let body = r#"{
            "id": "msg_abc123",
            "content": [
                {"type": "text", "text": "Single block response"}
            ]
        }"#;
        assert_eq!(extract_text_content(body), "Single block response");
    }

    #[test]
    fn test_extract_text_content_anthropic_with_tool_use_blocks() {
        // tool_use blocks carry no `text`, so only the text blocks are kept.
        let body = r#"{
            "id": "msg_xyz",
            "content": [
                {"type": "text", "text": "Let me check that for you."},
                {"type": "tool_use", "id": "tool_1", "name": "Read", "input": {"path": "file.txt"}},
                {"type": "text", "text": "Here is the result."}
            ]
        }"#;
        assert_eq!(
            extract_text_content(body),
            "Let me check that for you.\nHere is the result."
        );
    }

    #[test]
    fn test_extract_text_content_openai() {
        let body = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello from OpenAI compatible"
                    }
                }
            ]
        }"#;
        assert_eq!(extract_text_content(body), "Hello from OpenAI compatible");
    }

    #[test]
    fn test_extract_text_content_fallback_invalid_json() {
        let raw = "this is not json at all";
        assert_eq!(extract_text_content(raw), raw);
    }

    #[test]
    fn test_extract_text_content_fallback_unknown_format() {
        let body = r#"{
            "foo": "bar",
            "data": "no content or choices here"
        }"#;
        // JSON matching neither known shape comes back unchanged.
        let extracted = extract_text_content(body);
        for needle in ["foo", "bar"].iter() {
            assert!(extracted.contains(*needle), "missing {:?}", needle);
        }
    }

    // ---- small helpers -------------------------------------------------------

    #[test]
    fn test_role_to_api_string() {
        let cases = [
            (MessageRole::User, "user"),
            (MessageRole::Assistant, "assistant"),
            (MessageRole::Tool, "tool"),
            (MessageRole::System, "system"),
        ];
        for (role, expected) in cases.iter() {
            assert_eq!(role_to_api_string(role), *expected);
        }
    }

    #[test]
    fn test_as_system_prompt() {
        let prompt = as_system_prompt(vec!["line 1", "line 2", "line 3"]);
        assert_eq!(prompt, vec!["line 1", "line 2", "line 3"]);
    }

    // ---- create_api_query_hook ------------------------------------------------

    #[tokio::test]
    async fn test_create_api_query_hook_should_run_false() {
        // When should_run resolves to false the hook must return before logging.
        let logged = Arc::new(AtomicBool::new(false));
        let logged_in_hook = Arc::clone(&logged);
        let hook = create_api_query_hook(ApiQueryHookConfig {
            name: "test_hook".to_string(),
            should_run: Box::new(|_| Box::pin(async { false })),
            build_messages: Box::new(|_| vec![]),
            system_prompt: None,
            use_tools: None,
            parse_response: Box::new(|_, _| ()),
            log_result: Box::new(move |_, _| {
                logged_in_hook.store(true, Ordering::SeqCst);
            }),
            get_model: Box::new(|_| "test-model".to_string()),
        });

        hook(minimal_context()).await;

        assert!(
            !logged.load(Ordering::SeqCst),
            "log_result should not be called when should_run is false"
        );
    }

    #[tokio::test]
    async fn test_create_api_query_hook_calls_impl() {
        // With should_run = true the hook reaches the real HTTP path. Without an
        // API key that path fails, which still exercises the log_result wiring.
        let log_result_called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&log_result_called);
        let hook = create_api_query_hook(ApiQueryHookConfig {
            name: "wiring_test".to_string(),
            should_run: Box::new(|_| Box::pin(async { true })),
            build_messages: Box::new(|_| {
                vec![Message {
                    role: MessageRole::User,
                    content: "test".to_string(),
                    ..Default::default()
                }]
            }),
            system_prompt: Some(vec!["system prompt".to_string()]),
            use_tools: None,
            parse_response: Box::new(|_, _| ()),
            log_result: Box::new(move |result, _| {
                flag.store(true, Ordering::SeqCst);
                // An Error is expected without credentials; a Success (real
                // credentials present) is acceptable as well.
                match result {
                    ApiQueryResult::Error { error, .. } => {
                        let _ = error.to_string();
                    }
                    ApiQueryResult::Success { .. } => {}
                }
            }),
            get_model: Box::new(|_| "claude-sonnet-4-5-20250514".to_string()),
        });

        hook(minimal_context()).await;

        assert!(
            log_result_called.load(Ordering::SeqCst),
            "log_result should have been called"
        );
    }
}