pub mod bedrock;
pub mod claude;
pub mod claude_cli;
pub mod openai;
use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use anyhow::{Context, Result};
use reqwest::Client;
use serde_json::Value;
use crate::claude::error::ClaudeError;
use crate::claude::model_config::get_model_registry;
/// Timeout applied to every outbound provider HTTP request (5 minutes),
/// sized generously because long generations can take minutes to stream back.
pub(crate) const REQUEST_TIMEOUT: Duration = Duration::from_secs(300);
/// Static description of a configured AI client: the provider and model it
/// talks to, its token limits, and any active beta feature.
#[derive(Clone, Debug)]
pub struct AiClientMetadata {
/// Provider display name (e.g. "OpenAI", "Ollama"); matched case-sensitively
/// by `prompt_style`.
pub provider: String,
/// Model identifier used for registry lookups and API requests.
pub model: String,
/// Maximum input context length (presumably in tokens — confirm against registry).
pub max_context_length: usize,
/// Maximum response length (presumably in tokens — confirm against registry).
pub max_response_length: usize,
/// Active beta feature as a pair; only the second element is used for registry
/// lookups elsewhere in this module (first is presumably a header/feature name).
pub active_beta: Option<(String, String)>,
}
/// Which prompt-formatting convention a provider expects.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PromptStyle {
/// Anthropic/Claude-style prompting; the fallback for unknown providers.
Claude,
/// OpenAI-style prompting; used for the "OpenAI" and "Ollama" providers.
OpenAi,
}
impl AiClientMetadata {
    /// Select the prompt formatting convention for this provider.
    ///
    /// The "OpenAI" and "Ollama" providers (matched case-sensitively) use the
    /// OpenAI style; every other provider falls back to the Claude style.
    #[must_use]
    pub fn prompt_style(&self) -> PromptStyle {
        if matches!(self.provider.as_str(), "OpenAI" | "Ollama") {
            PromptStyle::OpenAi
        } else {
            PromptStyle::Claude
        }
    }
}
/// Construct a `reqwest::Client` configured with the module-wide request timeout.
///
/// # Errors
/// Returns an error (with context) if the client builder fails to initialize.
pub(crate) fn build_http_client() -> Result<Client> {
    let builder = Client::builder().timeout(REQUEST_TIMEOUT);
    builder.build().context("Failed to build HTTP client")
}
/// Look up the maximum output-token budget for `model`, honoring an active beta.
///
/// When a beta is active, the beta-aware registry lookup is used (only the
/// second element of the pair — the beta value — participates in the lookup).
///
/// The registry reports `usize`; the previous `as i32` cast would silently wrap
/// to a negative number for values above `i32::MAX` on 64-bit targets, so the
/// conversion now saturates at `i32::MAX` instead.
#[must_use]
pub(crate) fn registry_max_output_tokens(
    model: &str,
    active_beta: &Option<(String, String)>,
) -> i32 {
    let registry = get_model_registry();
    let max_tokens = match active_beta {
        Some((_, value)) => registry.get_max_output_tokens_with_beta(model, value),
        None => registry.get_max_output_tokens(model),
    };
    // Saturating conversion: never produce a wrapped/negative token budget.
    i32::try_from(max_tokens).unwrap_or(i32::MAX)
}
/// Return the `(input_context, max_output_tokens)` limits for `model`,
/// consulting the beta-aware registry entries when a beta is active.
#[must_use]
pub(crate) fn registry_model_limits(
    model: &str,
    active_beta: &Option<(String, String)>,
) -> (usize, usize) {
    let registry = get_model_registry();
    if let Some((_, value)) = active_beta {
        (
            registry.get_input_context_with_beta(model, value),
            registry.get_max_output_tokens_with_beta(model, value),
        )
    } else {
        (
            registry.get_input_context(model),
            registry.get_max_output_tokens(model),
        )
    }
}
/// Pass a successful HTTP response through unchanged; convert any non-success
/// status into a `ClaudeError::ApiRequestFailed` carrying the status and body.
pub(crate) async fn check_error_response(response: reqwest::Response) -> Result<reqwest::Response> {
    let status = response.status();
    if !status.is_success() {
        // Best effort: an unreadable body still yields a status-only error message.
        let error_text = match response.text().await {
            Ok(body) => body,
            Err(e) => {
                tracing::debug!("Failed to read error response body: {e}");
                String::new()
            }
        };
        return Err(ClaudeError::ApiRequestFailed(format!("HTTP {status}: {error_text}")).into());
    }
    Ok(response)
}
/// Emit debug-level logs for a successful provider response: one event with the
/// response length, one with the full content. `Err` results are ignored here.
pub(crate) fn log_response_success(provider: &str, result: &Result<String>) {
    let Ok(text) = result else {
        return;
    };
    tracing::debug!(
        response_len = text.len(),
        "Successfully extracted text content from {} API response",
        provider
    );
    tracing::debug!(
        response_content = %text,
        "{} API response content",
        provider
    );
}
/// Optional features an `AiClient` implementation supports.
/// `Default` yields every capability disabled.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct AiClientCapabilities {
/// Whether the client can enforce a response schema natively
/// (selects `ResponseFormat::JsonSchema` in `from_capabilities`).
pub supports_response_schema: bool,
}
/// Structured-output format to request from a client.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum ResponseFormat {
/// YAML fallback; the default, used when native schema support is absent.
#[default]
Yaml,
/// Native JSON-schema-constrained output.
JsonSchema,
}
impl ResponseFormat {
    /// Pick the strongest response format the client's capabilities allow:
    /// JSON schema when natively supported, otherwise the YAML fallback.
    #[must_use]
    pub fn from_capabilities(caps: &AiClientCapabilities) -> Self {
        if caps.supports_response_schema {
            return Self::JsonSchema;
        }
        Self::Yaml
    }
}
/// Per-request options passed to `AiClient::send_request_with_options`.
#[derive(Clone, Debug, Default)]
pub struct RequestOptions {
/// JSON schema the response should conform to, for clients that support one.
pub response_schema: Option<Value>,
}
impl RequestOptions {
    /// Builder-style setter: consume `self` and return it with the response
    /// schema attached, overwriting any schema set previously.
    #[must_use]
    pub fn with_response_schema(mut self, schema: Value) -> Self {
        self.response_schema.replace(schema);
        self
    }
}
/// Common interface implemented by every provider-specific AI client.
/// `Send + Sync` bounds allow clients to be shared across async tasks.
pub trait AiClient: Send + Sync {
/// Send a system/user prompt pair and resolve to the raw response text.
fn send_request<'a>(
&'a self,
system_prompt: &'a str,
user_prompt: &'a str,
) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>>;
/// Describe this client's provider, model, and token limits.
fn get_metadata(&self) -> AiClientMetadata;
/// Capabilities this client supports; defaults to everything disabled.
fn capabilities(&self) -> AiClientCapabilities {
AiClientCapabilities::default()
}
/// Like `send_request`, but accepts extra per-request options. The default
/// implementation ignores `_options` and delegates to `send_request`;
/// implementations that honor options are expected to override this.
fn send_request_with_options<'a>(
&'a self,
system_prompt: &'a str,
user_prompt: &'a str,
_options: RequestOptions,
) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
self.send_request(system_prompt, user_prompt)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build metadata for `provider` with fixed placeholder values; only the
    /// provider field matters for the `prompt_style` tests.
    fn meta(provider: &str) -> AiClientMetadata {
        AiClientMetadata {
            provider: provider.to_owned(),
            model: "test-model".to_owned(),
            max_context_length: 1024,
            max_response_length: 1024,
            active_beta: None,
        }
    }

    #[test]
    fn prompt_style_openai() {
        let style = meta("OpenAI").prompt_style();
        assert_eq!(style, PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_ollama() {
        let style = meta("Ollama").prompt_style();
        assert_eq!(style, PromptStyle::OpenAi);
    }

    #[test]
    fn prompt_style_anthropic() {
        let style = meta("Anthropic").prompt_style();
        assert_eq!(style, PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_bedrock() {
        let style = meta("Anthropic Bedrock").prompt_style();
        assert_eq!(style, PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_unknown_defaults_to_claude() {
        let style = meta("SomeNewProvider").prompt_style();
        assert_eq!(style, PromptStyle::Claude);
    }

    #[test]
    fn prompt_style_case_sensitive() {
        // Lowercase provider names are not recognized and fall back to Claude.
        for provider in ["openai", "ollama"] {
            assert_eq!(meta(provider).prompt_style(), PromptStyle::Claude);
        }
    }

    #[test]
    fn capabilities_default_is_all_disabled() {
        assert!(!AiClientCapabilities::default().supports_response_schema);
    }

    #[test]
    fn response_format_default_is_yaml() {
        assert_eq!(ResponseFormat::default(), ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_from_capabilities_disabled_picks_yaml() {
        let format = ResponseFormat::from_capabilities(&AiClientCapabilities::default());
        assert_eq!(format, ResponseFormat::Yaml);
    }

    #[test]
    fn response_format_from_capabilities_enabled_picks_json_schema() {
        let caps = AiClientCapabilities {
            supports_response_schema: true,
        };
        let format = ResponseFormat::from_capabilities(&caps);
        assert_eq!(format, ResponseFormat::JsonSchema);
    }

    #[test]
    fn request_options_with_response_schema_sets_field() {
        let schema = serde_json::json!({"type": "object"});
        let opts = RequestOptions::default().with_response_schema(schema.clone());
        assert_eq!(opts.response_schema, Some(schema));
    }

    #[test]
    fn request_options_default_has_no_schema() {
        assert!(RequestOptions::default().response_schema.is_none());
    }
}