1pub mod bedrock;
4pub mod claude;
5pub mod openai;
6
7use std::future::Future;
8use std::pin::Pin;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use reqwest::Client;
13
14use crate::claude::error::ClaudeError;
15use crate::claude::model_config::get_model_registry;
16
/// Per-request timeout (5 minutes) applied to the HTTP client built by
/// [`build_http_client`].
pub(crate) const REQUEST_TIMEOUT: Duration = Duration::from_secs(300);
22
/// Metadata describing a configured AI client: which provider and model it
/// targets and the token limits it operates under.
#[derive(Clone, Debug)]
pub struct AiClientMetadata {
    /// Provider name, e.g. "OpenAI", "Ollama", "Anthropic". Exact casing
    /// matters — see [`AiClientMetadata::prompt_style`].
    pub provider: String,
    /// Model identifier sent to the provider.
    pub model: String,
    /// Maximum input context length, in tokens.
    pub max_context_length: usize,
    /// Maximum response length, in tokens.
    pub max_response_length: usize,
    /// Active beta feature, if any. The second element is the beta value
    /// passed to the registry's `*_with_beta` lookups; the first element's
    /// role is not visible here — NOTE(review): confirm it is the beta
    /// header/feature name.
    pub active_beta: Option<(String, String)>,
}
37
/// Which prompt/request format a provider expects.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PromptStyle {
    /// Claude (Anthropic) style; the default for unrecognized providers.
    Claude,
    /// OpenAI-compatible style; also used by Ollama.
    OpenAi,
}
50
51impl AiClientMetadata {
52 #[must_use]
60 pub fn prompt_style(&self) -> PromptStyle {
61 match self.provider.as_str() {
62 "OpenAI" | "Ollama" => PromptStyle::OpenAi,
63 _ => PromptStyle::Claude,
64 }
65 }
66}
67
68pub(crate) fn build_http_client() -> Result<Client> {
72 Client::builder()
73 .timeout(REQUEST_TIMEOUT)
74 .build()
75 .context("Failed to build HTTP client")
76}
77
78#[must_use]
81pub(crate) fn registry_max_output_tokens(
82 model: &str,
83 active_beta: &Option<(String, String)>,
84) -> i32 {
85 let registry = get_model_registry();
86 if let Some((_, value)) = active_beta {
87 registry.get_max_output_tokens_with_beta(model, value) as i32
88 } else {
89 registry.get_max_output_tokens(model) as i32
90 }
91}
92
93#[must_use]
96pub(crate) fn registry_model_limits(
97 model: &str,
98 active_beta: &Option<(String, String)>,
99) -> (usize, usize) {
100 let registry = get_model_registry();
101 match active_beta {
102 Some((_, value)) => (
103 registry.get_input_context_with_beta(model, value),
104 registry.get_max_output_tokens_with_beta(model, value),
105 ),
106 None => (
107 registry.get_input_context(model),
108 registry.get_max_output_tokens(model),
109 ),
110 }
111}
112
113pub(crate) async fn check_error_response(response: reqwest::Response) -> Result<reqwest::Response> {
120 if response.status().is_success() {
121 return Ok(response);
122 }
123 let status = response.status();
124 let error_text = response.text().await.unwrap_or_else(|e| {
125 tracing::debug!("Failed to read error response body: {e}");
126 String::new()
127 });
128 Err(ClaudeError::ApiRequestFailed(format!("HTTP {status}: {error_text}")).into())
129}
130
131pub(crate) fn log_response_success(provider: &str, result: &Result<String>) {
133 if let Ok(text) = result {
134 tracing::debug!(
135 response_len = text.len(),
136 "Successfully extracted text content from {} API response",
137 provider
138 );
139 tracing::debug!(
140 response_content = %text,
141 "{} API response content",
142 provider
143 );
144 }
145}
146
/// Common interface implemented by every AI provider client in this module.
pub trait AiClient: Send + Sync {
    /// Sends a system + user prompt pair to the provider and resolves to
    /// the extracted response text.
    ///
    /// Returns a boxed future rather than using `async fn` — presumably to
    /// keep the trait usable behind `dyn AiClient`; confirm against trait
    /// object usage elsewhere in the crate.
    fn send_request<'a>(
        &'a self,
        system_prompt: &'a str,
        user_prompt: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>;

    /// Returns this client's provider/model metadata.
    fn get_metadata(&self) -> AiClientMetadata;
}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds metadata for `provider` with placeholder values for the
    /// remaining fields.
    fn metadata_for(provider: &str) -> AiClientMetadata {
        AiClientMetadata {
            provider: provider.into(),
            model: String::from("test-model"),
            max_context_length: 1024,
            max_response_length: 1024,
            active_beta: None,
        }
    }

    #[test]
    fn prompt_style_openai() {
        let style = metadata_for("OpenAI").prompt_style();
        assert_eq!(style, PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_ollama() {
        let style = metadata_for("Ollama").prompt_style();
        assert_eq!(style, PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_anthropic() {
        let style = metadata_for("Anthropic").prompt_style();
        assert_eq!(style, PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_bedrock() {
        let style = metadata_for("Anthropic Bedrock").prompt_style();
        assert_eq!(style, PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_unknown_defaults_to_claude() {
        let style = metadata_for("SomeNewProvider").prompt_style();
        assert_eq!(style, PromptStyle::Claude);
    }

    /// The provider match is case-sensitive: lowercase spellings fall back
    /// to the Claude style.
    #[test]
    fn prompt_style_case_sensitive() {
        for provider in ["openai", "ollama"] {
            assert_eq!(metadata_for(provider).prompt_style(), PromptStyle::Claude);
        }
    }
}