// omni_dev/claude/ai.rs
1//! AI client trait and metadata definitions.
2
3pub mod bedrock;
4pub mod claude;
5pub mod openai;
6
7use std::future::Future;
8use std::pin::Pin;
9use std::time::Duration;
10
11use anyhow::{Context, Result};
12use reqwest::Client;
13
14use crate::claude::error::ClaudeError;
15use crate::claude::model_config::get_model_registry;
16
/// HTTP request timeout for AI API calls.
///
/// Set to 5 minutes to accommodate large prompts and long model responses
/// (up to 64k output tokens) while preventing indefinite hangs.
///
/// Applied client-wide by [`build_http_client`], so it covers every
/// request made through clients built by this module.
pub(crate) const REQUEST_TIMEOUT: Duration = Duration::from_secs(300);
22
/// Metadata about an AI client implementation.
///
/// Produced by [`AiClient::get_metadata`]. Downstream code should derive
/// the prompt family once via [`AiClientMetadata::prompt_style`] rather
/// than string-matching `provider` directly.
#[derive(Clone, Debug)]
pub struct AiClientMetadata {
    /// Service provider name (e.g. `"OpenAI"`, `"Anthropic"`; see
    /// [`AiClientMetadata::prompt_style`] for the recognised values).
    pub provider: String,
    /// Model identifier.
    pub model: String,
    /// Maximum context length supported.
    pub max_context_length: usize,
    /// Maximum token response length supported.
    pub max_response_length: usize,
    /// Active beta header, if any: (key, value).
    pub active_beta: Option<(String, String)>,
}
37
/// Prompt formatting families for AI providers.
///
/// Determines provider-specific prompt behaviour (e.g., how template
/// instructions are phrased). Parse once at the boundary via
/// [`AiClientMetadata::prompt_style`] and match on the enum downstream.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PromptStyle {
    /// Claude models handle "literal template" instructions correctly.
    /// Also the fallback for provider strings that are not recognised.
    Claude,
    /// OpenAI-compatible models (OpenAI, Ollama) need different formatting.
    OpenAi,
}
50
51impl AiClientMetadata {
52    /// Derives the prompt style from the provider name.
53    ///
54    /// Matches against the exact strings set by each [`AiClient`] implementation:
55    /// - `"OpenAI"` and `"Ollama"` → [`PromptStyle::OpenAi`]
56    /// - `"Anthropic"` and `"Anthropic Bedrock"` → [`PromptStyle::Claude`]
57    ///
58    /// Unrecognised provider strings default to [`PromptStyle::Claude`].
59    #[must_use]
60    pub fn prompt_style(&self) -> PromptStyle {
61        match self.provider.as_str() {
62            "OpenAI" | "Ollama" => PromptStyle::OpenAi,
63            _ => PromptStyle::Claude,
64        }
65    }
66}
67
68// ── Shared helpers for AI client implementations ────────────────────
69
70/// Builds an HTTP client with the standard request timeout.
71pub(crate) fn build_http_client() -> Result<Client> {
72    Client::builder()
73        .timeout(REQUEST_TIMEOUT)
74        .build()
75        .context("Failed to build HTTP client")
76}
77
78/// Returns the maximum output tokens for a model from the registry,
79/// respecting beta overrides.
80#[must_use]
81pub(crate) fn registry_max_output_tokens(
82    model: &str,
83    active_beta: &Option<(String, String)>,
84) -> i32 {
85    let registry = get_model_registry();
86    if let Some((_, value)) = active_beta {
87        registry.get_max_output_tokens_with_beta(model, value) as i32
88    } else {
89        registry.get_max_output_tokens(model) as i32
90    }
91}
92
93/// Returns the (input context length, max response length) for a model
94/// from the registry, respecting beta overrides.
95#[must_use]
96pub(crate) fn registry_model_limits(
97    model: &str,
98    active_beta: &Option<(String, String)>,
99) -> (usize, usize) {
100    let registry = get_model_registry();
101    match active_beta {
102        Some((_, value)) => (
103            registry.get_input_context_with_beta(model, value),
104            registry.get_max_output_tokens_with_beta(model, value),
105        ),
106        None => (
107            registry.get_input_context(model),
108            registry.get_max_output_tokens(model),
109        ),
110    }
111}
112
113/// Checks an HTTP response for error status and returns a structured error
114/// if non-success.
115///
116/// On success, returns the response unchanged for further processing.
117/// On failure, reads the error body and returns a
118/// [`ClaudeError::ApiRequestFailed`].
119pub(crate) async fn check_error_response(response: reqwest::Response) -> Result<reqwest::Response> {
120    if response.status().is_success() {
121        return Ok(response);
122    }
123    let status = response.status();
124    let error_text = response.text().await.unwrap_or_else(|e| {
125        tracing::debug!("Failed to read error response body: {e}");
126        String::new()
127    });
128    Err(ClaudeError::ApiRequestFailed(format!("HTTP {status}: {error_text}")).into())
129}
130
131/// Logs successful text extraction from an AI API response.
132pub(crate) fn log_response_success(provider: &str, result: &Result<String>) {
133    if let Ok(text) = result {
134        tracing::debug!(
135            response_len = text.len(),
136            "Successfully extracted text content from {} API response",
137            provider
138        );
139        tracing::debug!(
140            response_content = %text,
141            "{} API response content",
142            provider
143        );
144    }
145}
146
/// Trait for AI service clients.
///
/// `Send + Sync` bounds allow implementations to be shared across threads.
pub trait AiClient: Send + Sync {
    /// Sends a request to the AI service and returns the raw response.
    ///
    /// Returns a boxed future (hand-rolled async-in-trait pattern) so the
    /// trait can be used as a trait object; the `'a` lifetime ties the
    /// future to the borrowed prompts and receiver.
    fn send_request<'a>(
        &'a self,
        system_prompt: &'a str,
        user_prompt: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>;

    /// Returns metadata about the AI client implementation.
    fn get_metadata(&self) -> AiClientMetadata;
}
159
160#[cfg(test)]
161mod tests {
162    use super::*;
163
164    fn meta(provider: &str) -> AiClientMetadata {
165        AiClientMetadata {
166            provider: provider.to_string(),
167            model: "test-model".to_string(),
168            max_context_length: 1024,
169            max_response_length: 1024,
170            active_beta: None,
171        }
172    }
173
174    #[test]
175    fn prompt_style_openai() {
176        assert_eq!(meta("OpenAI").prompt_style(), PromptStyle::OpenAi);
177    }
178
179    #[test]
180    fn prompt_style_ollama() {
181        assert_eq!(meta("Ollama").prompt_style(), PromptStyle::OpenAi);
182    }
183
184    #[test]
185    fn prompt_style_anthropic() {
186        assert_eq!(meta("Anthropic").prompt_style(), PromptStyle::Claude);
187    }
188
189    #[test]
190    fn prompt_style_bedrock() {
191        assert_eq!(
192            meta("Anthropic Bedrock").prompt_style(),
193            PromptStyle::Claude
194        );
195    }
196
197    #[test]
198    fn prompt_style_unknown_defaults_to_claude() {
199        assert_eq!(meta("SomeNewProvider").prompt_style(), PromptStyle::Claude);
200    }
201
202    /// Ensure case-sensitive matching: "openai" (lowercase) is not a known provider
203    /// string and must not silently match as OpenAI.
204    #[test]
205    fn prompt_style_case_sensitive() {
206        assert_eq!(meta("openai").prompt_style(), PromptStyle::Claude);
207        assert_eq!(meta("ollama").prompt_style(), PromptStyle::Claude);
208    }
209}