1pub mod bedrock;
4pub mod claude;
5pub mod claude_cli;
6pub mod openai;
7
8use std::future::Future;
9use std::pin::Pin;
10use std::time::Duration;
11
12use anyhow::{Context, Result};
13use reqwest::Client;
14use serde_json::Value;
15
16use crate::claude::error::ClaudeError;
17use crate::claude::model_config::get_model_registry;
18
/// Timeout handed to every HTTP client created by [`build_http_client`].
///
/// Five minutes, expressed in seconds (300).
pub(crate) const REQUEST_TIMEOUT: Duration = Duration::from_secs(5 * 60);
24
/// Descriptive information about an AI client: which provider and model it
/// uses, and the size limits that apply to it.
#[derive(Debug, Clone)]
pub struct AiClientMetadata {
    /// Provider name. [`AiClientMetadata::prompt_style`] matches on this
    /// string exactly (case-sensitively).
    pub provider: String,
    /// Identifier of the model in use.
    pub model: String,
    /// Maximum context size.
    // NOTE(review): presumably measured in tokens — confirm against callers.
    pub max_context_length: usize,
    /// Maximum response size.
    // NOTE(review): presumably measured in tokens — confirm against callers.
    pub max_response_length: usize,
    /// Currently active beta feature, if any, as a pair of strings.
    // NOTE(review): looks like a (name, value) pair — confirm the meaning of
    // each element where the metadata is produced.
    pub active_beta: Option<(String, String)>,
}
39
/// Prompt flavour selected for a provider by
/// [`AiClientMetadata::prompt_style`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromptStyle {
    /// Chosen for every provider that is not explicitly mapped to
    /// [`PromptStyle::OpenAi`].
    Claude,
    /// Chosen for the `"OpenAI"` and `"Ollama"` providers.
    OpenAi,
}
52
53impl AiClientMetadata {
54 #[must_use]
62 pub fn prompt_style(&self) -> PromptStyle {
63 match self.provider.as_str() {
64 "OpenAI" | "Ollama" => PromptStyle::OpenAi,
65 _ => PromptStyle::Claude,
66 }
67 }
68}
69
70pub(crate) fn build_http_client() -> Result<Client> {
74 Client::builder()
75 .timeout(REQUEST_TIMEOUT)
76 .build()
77 .context("Failed to build HTTP client")
78}
79
80#[must_use]
83pub(crate) fn registry_max_output_tokens(
84 model: &str,
85 active_beta: &Option<(String, String)>,
86) -> i32 {
87 let registry = get_model_registry();
88 if let Some((_, value)) = active_beta {
89 registry.get_max_output_tokens_with_beta(model, value) as i32
90 } else {
91 registry.get_max_output_tokens(model) as i32
92 }
93}
94
95#[must_use]
98pub(crate) fn registry_model_limits(
99 model: &str,
100 active_beta: &Option<(String, String)>,
101) -> (usize, usize) {
102 let registry = get_model_registry();
103 match active_beta {
104 Some((_, value)) => (
105 registry.get_input_context_with_beta(model, value),
106 registry.get_max_output_tokens_with_beta(model, value),
107 ),
108 None => (
109 registry.get_input_context(model),
110 registry.get_max_output_tokens(model),
111 ),
112 }
113}
114
115pub(crate) async fn check_error_response(response: reqwest::Response) -> Result<reqwest::Response> {
122 if response.status().is_success() {
123 return Ok(response);
124 }
125 let status = response.status();
126 let error_text = response.text().await.unwrap_or_else(|e| {
127 tracing::debug!("Failed to read error response body: {e}");
128 String::new()
129 });
130 Err(ClaudeError::ApiRequestFailed(format!("HTTP {status}: {error_text}")).into())
131}
132
133pub(crate) fn log_response_success(provider: &str, result: &Result<String>) {
135 if let Ok(text) = result {
136 tracing::debug!(
137 response_len = text.len(),
138 "Successfully extracted text content from {} API response",
139 provider
140 );
141 tracing::debug!(
142 response_content = %text,
143 "{} API response content",
144 provider
145 );
146 }
147}
148
/// Optional features an AI client may support.
///
/// The derived `Default` leaves every capability switched off.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct AiClientCapabilities {
    /// Whether the client supports a response schema. When `true`,
    /// [`ResponseFormat::from_capabilities`] selects
    /// [`ResponseFormat::JsonSchema`].
    pub supports_response_schema: bool,
}
166
/// Response format selected from a client's capabilities.
///
/// See [`ResponseFormat::from_capabilities`] for how a variant is picked.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum ResponseFormat {
    /// The default; used when the client does not support a response schema.
    #[default]
    Yaml,
    /// Used when the client reports `supports_response_schema`.
    JsonSchema,
}
184
185impl ResponseFormat {
186 #[must_use]
189 pub fn from_capabilities(caps: &AiClientCapabilities) -> Self {
190 if caps.supports_response_schema {
191 Self::JsonSchema
192 } else {
193 Self::Yaml
194 }
195 }
196}
197
198#[derive(Clone, Debug, Default)]
205pub struct RequestOptions {
206 pub response_schema: Option<Value>,
210}
211
212impl RequestOptions {
213 #[must_use]
215 pub fn with_response_schema(mut self, schema: Value) -> Self {
216 self.response_schema = Some(schema);
217 self
218 }
219}
220
221pub trait AiClient: Send + Sync {
223 fn send_request<'a>(
225 &'a self,
226 system_prompt: &'a str,
227 user_prompt: &'a str,
228 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>;
229
230 fn get_metadata(&self) -> AiClientMetadata;
232
233 fn capabilities(&self) -> AiClientCapabilities {
240 AiClientCapabilities::default()
241 }
242
243 fn send_request_with_options<'a>(
252 &'a self,
253 system_prompt: &'a str,
254 user_prompt: &'a str,
255 _options: RequestOptions,
256 ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
257 self.send_request(system_prompt, user_prompt)
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds metadata for `provider`, with placeholder values for every
    /// other field.
    fn meta(provider: &str) -> AiClientMetadata {
        AiClientMetadata {
            provider: provider.to_owned(),
            model: String::from("test-model"),
            max_context_length: 1024,
            max_response_length: 1024,
            active_beta: None,
        }
    }

    /// Shorthand for the prompt style resolved for `provider`.
    fn style_for(provider: &str) -> PromptStyle {
        meta(provider).prompt_style()
    }

    #[test]
    fn prompt_style_openai() {
        assert_eq!(style_for("OpenAI"), PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_ollama() {
        assert_eq!(style_for("Ollama"), PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_anthropic() {
        assert_eq!(style_for("Anthropic"), PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_bedrock() {
        assert_eq!(style_for("Anthropic Bedrock"), PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_unknown_defaults_to_claude() {
        assert_eq!(style_for("SomeNewProvider"), PromptStyle::Claude);
    }

    // Provider matching is exact: lowercase spellings are treated like any
    // other unknown provider.
    #[test]
    fn prompt_style_case_sensitive() {
        for lowercase in ["openai", "ollama"] {
            assert_eq!(style_for(lowercase), PromptStyle::Claude);
        }
    }

    #[test]
    fn capabilities_default_is_all_disabled() {
        assert!(!AiClientCapabilities::default().supports_response_schema);
    }

    #[test]
    fn response_format_default_is_yaml() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_from_capabilities_disabled_picks_yaml() {
        let picked = ResponseFormat::from_capabilities(&AiClientCapabilities::default());
        assert_eq!(picked, ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_from_capabilities_enabled_picks_json_schema() {
        let schema_capable = AiClientCapabilities {
            supports_response_schema: true,
        };
        let picked = ResponseFormat::from_capabilities(&schema_capable);
        assert_eq!(picked, ResponseFormat::JsonSchema);
    }

    #[test]
    fn request_options_with_response_schema_sets_field() {
        let schema = serde_json::json!({"type": "object"});
        let opts = RequestOptions::default().with_response_schema(schema.clone());
        assert_eq!(opts.response_schema, Some(schema));
    }

    #[test]
    fn request_options_default_has_no_schema() {
        assert!(RequestOptions::default().response_schema.is_none());
    }
}