// omni_dev/claude/ai.rs
1//! AI client trait and metadata definitions.
2
3pub mod bedrock;
4pub mod claude;
5pub mod claude_cli;
6pub mod openai;
7
8use std::future::Future;
9use std::pin::Pin;
10use std::time::Duration;
11
12use anyhow::{Context, Result};
13use reqwest::Client;
14use serde_json::Value;
15
16use crate::claude::error::ClaudeError;
17use crate::claude::model_config::get_model_registry;
18
/// HTTP request timeout for AI API calls.
///
/// Set to 5 minutes to accommodate large prompts and long model responses
/// (up to 64k output tokens) while preventing indefinite hangs.
///
/// Applied to every client constructed via [`build_http_client`].
pub(crate) const REQUEST_TIMEOUT: Duration = Duration::from_secs(300);
24
/// Metadata about an AI client implementation.
///
/// Produced by each backend's [`AiClient::get_metadata`]; the `provider`
/// string is also the dispatch key for [`AiClientMetadata::prompt_style`].
#[derive(Clone, Debug)]
pub struct AiClientMetadata {
    /// Service provider name (e.g. `"OpenAI"`, `"Anthropic"`).
    /// Matched case-sensitively by [`AiClientMetadata::prompt_style`].
    pub provider: String,
    /// Model identifier.
    pub model: String,
    /// Maximum context length supported.
    // presumably measured in tokens, mirroring the registry's
    // input-context lookups — TODO confirm
    pub max_context_length: usize,
    /// Maximum token response length supported.
    pub max_response_length: usize,
    /// Active beta header, if any: (key, value).
    pub active_beta: Option<(String, String)>,
}
39
/// Prompt formatting families for AI providers.
///
/// Determines provider-specific prompt behaviour (e.g., how template
/// instructions are phrased). Parse once at the boundary via
/// [`AiClientMetadata::prompt_style`] and match on the enum downstream.
/// Unknown providers fall back to [`PromptStyle::Claude`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PromptStyle {
    /// Claude models handle "literal template" instructions correctly.
    Claude,
    /// OpenAI-compatible models (OpenAI, Ollama) need different formatting.
    OpenAi,
}
52
53impl AiClientMetadata {
54    /// Derives the prompt style from the provider name.
55    ///
56    /// Matches against the exact strings set by each [`AiClient`] implementation:
57    /// - `"OpenAI"` and `"Ollama"` → [`PromptStyle::OpenAi`]
58    /// - `"Anthropic"` and `"Anthropic Bedrock"` → [`PromptStyle::Claude`]
59    ///
60    /// Unrecognised provider strings default to [`PromptStyle::Claude`].
61    #[must_use]
62    pub fn prompt_style(&self) -> PromptStyle {
63        match self.provider.as_str() {
64            "OpenAI" | "Ollama" => PromptStyle::OpenAi,
65            _ => PromptStyle::Claude,
66        }
67    }
68}
69
70// ── Shared helpers for AI client implementations ────────────────────
71
72/// Builds an HTTP client with the standard request timeout.
73pub(crate) fn build_http_client() -> Result<Client> {
74    Client::builder()
75        .timeout(REQUEST_TIMEOUT)
76        .build()
77        .context("Failed to build HTTP client")
78}
79
80/// Returns the maximum output tokens for a model from the registry,
81/// respecting beta overrides.
82#[must_use]
83pub(crate) fn registry_max_output_tokens(
84    model: &str,
85    active_beta: &Option<(String, String)>,
86) -> i32 {
87    let registry = get_model_registry();
88    if let Some((_, value)) = active_beta {
89        registry.get_max_output_tokens_with_beta(model, value) as i32
90    } else {
91        registry.get_max_output_tokens(model) as i32
92    }
93}
94
95/// Returns the (input context length, max response length) for a model
96/// from the registry, respecting beta overrides.
97#[must_use]
98pub(crate) fn registry_model_limits(
99    model: &str,
100    active_beta: &Option<(String, String)>,
101) -> (usize, usize) {
102    let registry = get_model_registry();
103    match active_beta {
104        Some((_, value)) => (
105            registry.get_input_context_with_beta(model, value),
106            registry.get_max_output_tokens_with_beta(model, value),
107        ),
108        None => (
109            registry.get_input_context(model),
110            registry.get_max_output_tokens(model),
111        ),
112    }
113}
114
115/// Checks an HTTP response for error status and returns a structured error
116/// if non-success.
117///
118/// On success, returns the response unchanged for further processing.
119/// On failure, reads the error body and returns a
120/// [`ClaudeError::ApiRequestFailed`].
121pub(crate) async fn check_error_response(response: reqwest::Response) -> Result<reqwest::Response> {
122    if response.status().is_success() {
123        return Ok(response);
124    }
125    let status = response.status();
126    let error_text = response.text().await.unwrap_or_else(|e| {
127        tracing::debug!("Failed to read error response body: {e}");
128        String::new()
129    });
130    Err(ClaudeError::ApiRequestFailed(format!("HTTP {status}: {error_text}")).into())
131}
132
133/// Logs successful text extraction from an AI API response.
134pub(crate) fn log_response_success(provider: &str, result: &Result<String>) {
135    if let Ok(text) = result {
136        tracing::debug!(
137            response_len = text.len(),
138            "Successfully extracted text content from {} API response",
139            provider
140        );
141        tracing::debug!(
142            response_content = %text,
143            "{} API response content",
144            provider
145        );
146    }
147}
148
/// Capabilities advertised by an [`AiClient`] implementation.
///
/// Used by call sites to decide whether to attach a structured-response
/// schema (or other backend-specific request options) before dispatching.
/// The default value is the conservative "nothing supported" baseline so
/// new fields can be added without forcing existing implementations to
/// update.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct AiClientCapabilities {
    /// Whether the backend can enforce a JSON Schema on its response.
    ///
    /// When `true`, the call site may set
    /// [`RequestOptions::response_schema`]; the backend will hand the schema
    /// to its underlying API (e.g. `claude -p --json-schema <file>`) and the
    /// API re-prompts until the model produces a validating response.
    pub supports_response_schema: bool,
}
166
/// Whether the response should be formatted as YAML (default) or JSON
/// matching a schema.
///
/// Used by the prompts module to swap the format-specific portion of a
/// structured prompt without rewriting the semantic instructions. Derive
/// the right variant from backend capabilities via
/// [`ResponseFormat::from_capabilities`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ResponseFormat {
    /// Plain YAML, with the prompt asking the model to emit a fenced or
    /// bare YAML document.
    #[default]
    Yaml,
    /// JSON object that matches a schema attached via
    /// [`RequestOptions::response_schema`]. The prompt drops the YAML
    /// structure literal and tells the model to return only the JSON
    /// object.
    JsonSchema,
}
184
185impl ResponseFormat {
186    /// Returns the response format that should be used given a backend's
187    /// capabilities.
188    #[must_use]
189    pub fn from_capabilities(caps: &AiClientCapabilities) -> Self {
190        if caps.supports_response_schema {
191            Self::JsonSchema
192        } else {
193            Self::Yaml
194        }
195    }
196}
197
/// Per-request options passed to [`AiClient::send_request_with_options`].
///
/// Schema and other knobs live on the request, not the client, so a shared
/// client cannot leak settings between concurrent calls. Backends that do
/// not support an option are expected to ignore it (and the call site is
/// expected to consult [`AiClient::capabilities`] before setting it).
#[derive(Clone, Debug, Default)]
pub struct RequestOptions {
    /// Optional JSON Schema (as a `serde_json::Value`) constraining the
    /// model's response. Only honoured by backends whose
    /// [`AiClientCapabilities::supports_response_schema`] is `true`.
    pub response_schema: Option<Value>,
}
211
212impl RequestOptions {
213    /// Returns a new [`RequestOptions`] with [`response_schema`] set.
214    #[must_use]
215    pub fn with_response_schema(mut self, schema: Value) -> Self {
216        self.response_schema = Some(schema);
217        self
218    }
219}
220
/// Trait for AI service clients.
///
/// Request methods return `Pin<Box<dyn Future<…>>>` rather than `async fn`
/// so the trait stays object-safe and implementations can be used behind
/// `dyn AiClient`.
pub trait AiClient: Send + Sync {
    /// Sends a request to the AI service and returns the raw response.
    ///
    /// The returned string is the model's raw text output; callers are
    /// responsible for any parsing.
    fn send_request<'a>(
        &'a self,
        system_prompt: &'a str,
        user_prompt: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>;

    /// Returns metadata about the AI client implementation.
    fn get_metadata(&self) -> AiClientMetadata;

    /// Returns the optional capabilities advertised by this backend.
    ///
    /// The default implementation returns the all-disabled baseline so
    /// existing backends remain source-compatible. Backends that gain new
    /// capabilities (e.g. structured-output enforcement) should override
    /// this method.
    fn capabilities(&self) -> AiClientCapabilities {
        AiClientCapabilities::default()
    }

    /// Sends a request with optional per-request settings.
    ///
    /// The default implementation drops `options` and dispatches via
    /// [`send_request`](Self::send_request). Backends that honour any field
    /// in [`RequestOptions`] (e.g. `response_schema`) override this method.
    /// Backends that don't honour an option must ignore it; call sites
    /// should consult [`capabilities`](Self::capabilities) before setting
    /// options that not all backends support.
    fn send_request_with_options<'a>(
        &'a self,
        system_prompt: &'a str,
        user_prompt: &'a str,
        _options: RequestOptions,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
        // Default path: ignore the options entirely and reuse the plain call.
        self.send_request(system_prompt, user_prompt)
    }
}
260
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a metadata fixture with the given provider name and dummy
    /// values for every other field.
    fn meta(provider: &str) -> AiClientMetadata {
        AiClientMetadata {
            provider: provider.to_string(),
            model: "test-model".to_string(),
            max_context_length: 1024,
            max_response_length: 1024,
            active_beta: None,
        }
    }

    // ── prompt_style: provider string → PromptStyle mapping ──────────

    #[test]
    fn prompt_style_openai() {
        assert_eq!(meta("OpenAI").prompt_style(), PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_ollama() {
        assert_eq!(meta("Ollama").prompt_style(), PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_anthropic() {
        assert_eq!(meta("Anthropic").prompt_style(), PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_bedrock() {
        assert_eq!(
            meta("Anthropic Bedrock").prompt_style(),
            PromptStyle::Claude
        );
    }

    #[test]
    fn prompt_style_unknown_defaults_to_claude() {
        assert_eq!(meta("SomeNewProvider").prompt_style(), PromptStyle::Claude);
    }

    /// Ensure case-sensitive matching: "openai" (lowercase) is not a known provider
    /// string and must not silently match as OpenAI.
    #[test]
    fn prompt_style_case_sensitive() {
        assert_eq!(meta("openai").prompt_style(), PromptStyle::Claude);
        assert_eq!(meta("ollama").prompt_style(), PromptStyle::Claude);
    }

    // ── capabilities / response-format defaults ──────────────────────

    #[test]
    fn capabilities_default_is_all_disabled() {
        let caps = AiClientCapabilities::default();
        assert!(!caps.supports_response_schema);
    }

    #[test]
    fn response_format_default_is_yaml() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_from_capabilities_disabled_picks_yaml() {
        let caps = AiClientCapabilities::default();
        assert_eq!(
            ResponseFormat::from_capabilities(&caps),
            ResponseFormat::Yaml
        );
    }

    #[test]
    fn response_format_from_capabilities_enabled_picks_json_schema() {
        let caps = AiClientCapabilities {
            supports_response_schema: true,
        };
        assert_eq!(
            ResponseFormat::from_capabilities(&caps),
            ResponseFormat::JsonSchema
        );
    }

    // ── RequestOptions builder behaviour ─────────────────────────────

    #[test]
    fn request_options_with_response_schema_sets_field() {
        let value = serde_json::json!({"type": "object"});
        let opts = RequestOptions::default().with_response_schema(value.clone());
        assert_eq!(opts.response_schema, Some(value));
    }

    #[test]
    fn request_options_default_has_no_schema() {
        let opts = RequestOptions::default();
        assert!(opts.response_schema.is_none());
    }
}