1#![allow(dead_code)]
3
4use serde::{Deserialize, Serialize};
5use std::future::Future;
6use std::sync::Arc;
7use uuid::Uuid;
8
9use crate::types::{Message, MessageRole};
10
/// A system prompt stored as an ordered list of text segments.
///
/// The segments are kept separate so callers can assemble a prompt piecewise; the request
/// builder in this module joins them with `"\n"` before sending.
pub type SystemPrompt = Vec<String>;
13
14#[derive(Clone)]
16pub struct ReplHookContext {
17 pub messages: Vec<Message>,
19 pub system_prompt: SystemPrompt,
21 pub user_context: std::collections::HashMap<String, String>,
23 pub system_context: std::collections::HashMap<String, String>,
25 pub tool_use_context: Arc<crate::utils::hooks::can_use_tool::ToolUseContext>,
27 pub query_source: Option<String>,
29 pub query_message_count: Option<usize>,
31}
32
33pub struct ApiQueryHookConfig<TResult> {
35 pub name: String,
37 pub should_run: Box<
39 dyn Fn(&ReplHookContext) -> std::pin::Pin<Box<dyn Future<Output = bool> + Send>>
40 + Send
41 + Sync,
42 >,
43 pub build_messages: Box<dyn Fn(&ReplHookContext) -> Vec<Message> + Send + Sync>,
45 pub system_prompt: Option<SystemPrompt>,
47 pub use_tools: Option<bool>,
49 pub parse_response: Box<dyn Fn(&str, &ReplHookContext) -> TResult + Send + Sync>,
51 pub log_result: Box<dyn Fn(ApiQueryResult<TResult>, &ReplHookContext) + Send + Sync>,
53 pub get_model: Box<dyn Fn(&ReplHookContext) -> String + Send + Sync>,
55}
56
/// Outcome of one run of a hook built by [`create_api_query_hook`], passed to `log_result`.
///
/// `Debug` is derived (requiring `TResult: Debug` only for the impl) so callers can log or
/// assert on results directly.
#[derive(Debug)]
pub enum ApiQueryResult<TResult> {
    /// The model replied and `parse_response` returned normally.
    Success {
        /// The hook config's `name`.
        query_name: String,
        /// Value returned by `parse_response`.
        result: TResult,
        /// The response's `id` field, or `"unknown"` when the API omitted it.
        message_id: String,
        /// Model name that was requested.
        model: String,
        /// Random (v4) UUID generated for this run.
        uuid: String,
    },
    /// The request or API call failed, or `parse_response` panicked.
    Error {
        /// The hook config's `name`.
        query_name: String,
        /// Description of the failure.
        error: Box<dyn std::error::Error + Send + Sync>,
        /// Random (v4) UUID generated for this run.
        uuid: String,
    },
}
72
73pub fn create_api_query_hook<TResult: 'static>(
76 config: ApiQueryHookConfig<TResult>,
77) -> Box<dyn Fn(ReplHookContext) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>
78{
79 let config = Arc::new(config);
80 Box::new(move |context: ReplHookContext| {
81 let config = config.clone();
82 Box::pin(async move {
83 let should_run = (config.should_run)(&context).await;
84 if !should_run {
85 return;
86 }
87
88 let uuid = Uuid::new_v4().to_string();
89
90 let messages = (config.build_messages)(&context);
92 let system_prompt = config
97 .system_prompt
98 .clone()
99 .unwrap_or_else(|| context.system_prompt.clone());
100
101 let model = (config.get_model)(&context);
106
107 let response_result =
110 query_model_without_streaming_impl(&messages, &system_prompt, &model, &context)
111 .await;
112
113 match response_result {
114 Ok(response) => {
115 let content = extract_text_content(&response.content).trim().to_string();
117
118 let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
119 (config.parse_response)(&content, &context)
120 }));
121
122 match result {
123 Ok(parsed_result) => {
124 (config.log_result)(
125 ApiQueryResult::Success {
126 query_name: config.name.clone(),
127 result: parsed_result,
128 message_id: response.message_id,
129 model,
130 uuid,
131 },
132 &context,
133 );
134 }
135 Err(err) => {
136 let error = if let Some(s) = err.downcast_ref::<String>() {
137 Box::new(std::io::Error::new(std::io::ErrorKind::Other, s.clone()))
138 } else if let Some(s) = err.downcast_ref::<&str>() {
139 Box::new(std::io::Error::new(
140 std::io::ErrorKind::Other,
141 s.to_string(),
142 ))
143 } else {
144 Box::new(std::io::Error::new(
145 std::io::ErrorKind::Other,
146 "Unknown panic in parse_response",
147 ))
148 };
149 (config.log_result)(
150 ApiQueryResult::Error {
151 query_name: config.name.clone(),
152 error,
153 uuid,
154 },
155 &context,
156 );
157 }
158 }
159 }
160 Err(error) => {
161 log_error(&format!("API query hook error: {}", error));
162 (config.log_result)(
163 ApiQueryResult::Error {
164 query_name: config.name.clone(),
165 error,
166 uuid,
167 },
168 &context,
169 );
170 }
171 }
172 })
173 })
174}
175
/// The parts of a successful non-streaming API reply that the hook needs.
struct ApiResponse {
    /// The response's `id` field (`"unknown"` when absent).
    message_id: String,
    /// The response JSON as text; decoded later by `extract_text_content`.
    content: String,
}
181
/// Returns the first non-empty API credential found in the environment.
///
/// Variables are checked in priority order: `AI_AUTH_TOKEN`, `ANTHROPIC_API_KEY`,
/// `ANTHROPIC_AUTH_TOKEN`. A variable that is unset, empty, or not valid Unicode is skipped.
///
/// # Errors
/// Returns a human-readable message naming all three variables when none yields a key.
fn get_api_key() -> Result<String, String> {
    const KEY_VARS: [&str; 3] = ["AI_AUTH_TOKEN", "ANTHROPIC_API_KEY", "ANTHROPIC_AUTH_TOKEN"];

    KEY_VARS
        .iter()
        .copied()
        .find_map(|name| std::env::var(name).ok().filter(|value| !value.is_empty()))
        .ok_or_else(|| {
            "No API key found. Set AI_AUTH_TOKEN, ANTHROPIC_API_KEY, or ANTHROPIC_AUTH_TOKEN"
                .to_string()
        })
}
203
204fn role_to_api_string(role: &MessageRole) -> &'static str {
206 match role {
207 MessageRole::User => "user",
208 MessageRole::Assistant => "assistant",
209 MessageRole::Tool => "tool",
210 MessageRole::System => "system",
211 }
212}
213
214async fn query_model_without_streaming_impl(
217 messages: &[Message],
218 system_prompt: &SystemPrompt,
219 model: &str,
220 _context: &ReplHookContext,
221) -> Result<ApiResponse, Box<dyn std::error::Error + Send + Sync>> {
222 let api_key = get_api_key().map_err(|e| {
223 Box::<dyn std::error::Error + Send + Sync>::from(std::io::Error::new(
224 std::io::ErrorKind::Other,
225 e,
226 ))
227 })?;
228
229 let base_url = std::env::var("AI_API_BASE_URL")
230 .ok()
231 .unwrap_or_else(|| "https://api.anthropic.com".to_string());
232 let url = format!("{}/v1/messages", base_url);
233
234 let is_anthropic = base_url.contains("anthropic.com");
236
237 let api_messages: Vec<serde_json::Value> = messages
239 .iter()
240 .map(|m| {
241 let mut msg_obj = serde_json::json!({
242 "role": role_to_api_string(&m.role),
243 "content": &m.content
244 });
245
246 if m.role == MessageRole::Tool {
248 if let Some(ref tool_call_id) = m.tool_call_id {
249 msg_obj["tool_use_id"] = serde_json::json!(tool_call_id);
250 }
251 }
252
253 msg_obj
254 })
255 .collect();
256
257 let system_prompt_value = serde_json::json!({
259 "type": "text",
260 "text": system_prompt.join("\n")
261 });
262
263 let request_body = serde_json::json!({
265 "model": model,
266 "max_tokens": 4096,
267 "system": system_prompt_value,
268 "messages": api_messages,
269 "temperature": 0.0,
270 });
271
272 let client = reqwest::Client::new();
273 let request_builder = if is_anthropic {
274 client
275 .post(&url)
276 .header("x-api-key", &api_key)
277 .header("anthropic-version", "2023-06-01")
278 .header("Content-Type", "application/json")
279 .header("User-Agent", crate::utils::http::get_user_agent())
280 .json(&request_body)
281 } else {
282 client
283 .post(&url)
284 .header("Authorization", format!("Bearer {}", api_key))
285 .header("Content-Type", "application/json")
286 .header("User-Agent", crate::utils::http::get_user_agent())
287 .json(&request_body)
288 };
289
290 let response = request_builder
291 .send()
292 .await
293 .map_err(|e| {
294 Box::<dyn std::error::Error + Send + Sync>::from(std::io::Error::new(
295 std::io::ErrorKind::ConnectionRefused,
296 format!("API request failed: {}", e),
297 ))
298 })?;
299
300 let status = response.status();
301 if !status.is_success() {
302 let error_text = response.text().await.unwrap_or_default();
303 return Err(Box::<dyn std::error::Error + Send + Sync>::from(
304 std::io::Error::new(
305 std::io::ErrorKind::Other,
306 format!("API error {}: {}", status, error_text),
307 ),
308 ));
309 }
310
311 let response_json: serde_json::Value = response
313 .json()
314 .await
315 .map_err(|e| {
316 Box::<dyn std::error::Error + Send + Sync>::from(std::io::Error::new(
317 std::io::ErrorKind::InvalidData,
318 format!("Failed to parse API response: {}", e),
319 ))
320 })?;
321
322 if let Some(error) = response_json.get("error") {
324 let error_msg = error
325 .get("message")
326 .and_then(|m| m.as_str())
327 .unwrap_or("Unknown error");
328 return Err(Box::<dyn std::error::Error + Send + Sync>::from(
329 std::io::Error::new(
330 std::io::ErrorKind::Other,
331 format!("API error: {}", error_msg),
332 ),
333 ));
334 }
335
336 let message_id = response_json
338 .get("id")
339 .and_then(|id| id.as_str())
340 .unwrap_or("unknown")
341 .to_string();
342
343 let content = serde_json::to_string(&response_json).unwrap_or_default();
345
346 Ok(ApiResponse {
347 message_id,
348 content,
349 })
350}
351
352fn extract_text_content(response_json: &str) -> String {
357 let Ok(response) = serde_json::from_str::<serde_json::Value>(response_json) else {
358 return response_json.to_string();
359 };
360
361 if let Some(content) = response
363 .get("choices")
364 .and_then(|c| c.as_array())
365 .and_then(|c| c.first())
366 .and_then(|c| c.get("message"))
367 .and_then(|m| m.get("content"))
368 .and_then(|c| c.as_str())
369 {
370 return content.to_string();
371 }
372
373 if let Some(blocks) = response.get("content").and_then(|c| c.as_array()) {
375 let mut texts = Vec::new();
376 for block in blocks {
377 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
378 texts.push(text.to_string());
379 }
380 }
381 if !texts.is_empty() {
382 return texts.join("\n");
383 }
384 }
385
386 response_json.to_string()
387}
388
389fn log_error(msg: &str) {
391 log::error!("{}", msg);
392}
393
394pub fn as_system_prompt(parts: Vec<&str>) -> SystemPrompt {
396 parts.iter().map(|s| s.to_string()).collect()
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Builds the minimal, non-interactive hook context shared by the hook tests.
    fn test_context() -> ReplHookContext {
        ReplHookContext {
            messages: vec![],
            system_prompt: vec![],
            user_context: std::collections::HashMap::new(),
            system_context: std::collections::HashMap::new(),
            tool_use_context: Arc::new(crate::utils::hooks::can_use_tool::ToolUseContext {
                session_id: "test".to_string(),
                cwd: None,
                is_non_interactive_session: true,
                options: None,
            }),
            query_source: None,
            query_message_count: None,
        }
    }

    #[test]
    fn test_extract_text_content_anthropic() {
        let response = r#"{
            "id": "msg_abc123",
            "content": [
                {"type": "text", "text": "Hello from Anthropic"},
                {"type": "text", "text": "Second block"}
            ]
        }"#;
        let result = extract_text_content(response);
        assert_eq!(result, "Hello from Anthropic\nSecond block");
    }

    #[test]
    fn test_extract_text_content_anthropic_single_block() {
        let response = r#"{
            "id": "msg_abc123",
            "content": [
                {"type": "text", "text": "Single block response"}
            ]
        }"#;
        let result = extract_text_content(response);
        assert_eq!(result, "Single block response");
    }

    #[test]
    fn test_extract_text_content_openai() {
        let response = r#"{
            "id": "chatcmpl-123",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Hello from OpenAI compatible"
                    }
                }
            ]
        }"#;
        let result = extract_text_content(response);
        assert_eq!(result, "Hello from OpenAI compatible");
    }

    #[test]
    fn test_extract_text_content_fallback_invalid_json() {
        let raw = "this is not json at all";
        let result = extract_text_content(raw);
        assert_eq!(result, raw);
    }

    #[test]
    fn test_extract_text_content_fallback_unknown_format() {
        let response = r#"{
            "foo": "bar",
            "data": "no content or choices here"
        }"#;
        let result = extract_text_content(response);
        // Unknown shapes fall back to the raw body.
        assert!(result.contains("foo"));
        assert!(result.contains("bar"));
    }

    #[test]
    fn test_role_to_api_string() {
        assert_eq!(role_to_api_string(&MessageRole::User), "user");
        assert_eq!(role_to_api_string(&MessageRole::Assistant), "assistant");
        assert_eq!(role_to_api_string(&MessageRole::Tool), "tool");
        assert_eq!(role_to_api_string(&MessageRole::System), "system");
    }

    #[test]
    fn test_as_system_prompt() {
        let prompt = as_system_prompt(vec!["line 1", "line 2", "line 3"]);
        assert_eq!(prompt, vec!["line 1", "line 2", "line 3"]);
    }

    /// When `should_run` resolves to `false`, the hook must not report anything.
    #[tokio::test]
    async fn test_create_api_query_hook_should_run_false() {
        let logged = Arc::new(AtomicBool::new(false));
        let logged_clone = Arc::clone(&logged);
        let hook = create_api_query_hook(ApiQueryHookConfig {
            name: "test_hook".to_string(),
            should_run: Box::new(|_| Box::pin(async { false })),
            build_messages: Box::new(|_| vec![]),
            system_prompt: None,
            use_tools: None,
            parse_response: Box::new(|_, _| ()),
            log_result: Box::new(move |_, _| {
                logged_clone.store(true, Ordering::SeqCst);
            }),
            get_model: Box::new(|_| "test-model".to_string()),
        });

        hook(test_context()).await;
        assert!(
            !logged.load(Ordering::SeqCst),
            "log_result should not be called when should_run is false"
        );
    }

    /// Wiring test: `log_result` must be called whether the request succeeds or fails.
    ///
    /// NOTE(review): no mock server is used. With an API key in the environment this makes a
    /// real network request; without one it exercises the missing-key error path.
    #[tokio::test]
    async fn test_create_api_query_hook_calls_impl() {
        let hook_called = Arc::new(AtomicBool::new(false));
        let hook_called_clone = Arc::clone(&hook_called);
        let hook = create_api_query_hook(ApiQueryHookConfig {
            name: "wiring_test".to_string(),
            should_run: Box::new(|_| Box::pin(async { true })),
            build_messages: Box::new(|_| {
                vec![Message {
                    role: MessageRole::User,
                    content: "test".to_string(),
                    ..Default::default()
                }]
            }),
            system_prompt: Some(vec!["system prompt".to_string()]),
            use_tools: None,
            parse_response: Box::new(|_, _| ()),
            log_result: Box::new(move |result, _| {
                hook_called_clone.store(true, Ordering::SeqCst);
                match result {
                    // Exercise the error's `Display` impl on the failure path.
                    ApiQueryResult::Error { error, .. } => {
                        let _ = error.to_string();
                    }
                    ApiQueryResult::Success { .. } => {}
                }
            }),
            get_model: Box::new(|_| "claude-sonnet-4-5-20250514".to_string()),
        });

        hook(test_context()).await;
        assert!(
            hook_called.load(Ordering::SeqCst),
            "log_result should have been called"
        );
    }

    #[test]
    fn test_extract_text_content_anthropic_with_tool_use_blocks() {
        // Non-text blocks (tool_use) are skipped; text blocks are joined with newlines.
        let response = r#"{
            "id": "msg_xyz",
            "content": [
                {"type": "text", "text": "Let me check that for you."},
                {"type": "tool_use", "id": "tool_1", "name": "Read", "input": {"path": "file.txt"}},
                {"type": "text", "text": "Here is the result."}
            ]
        }"#;
        let result = extract_text_content(response);
        assert_eq!(result, "Let me check that for you.\nHere is the result.");
    }
}
596}