use std::future::Future;
use std::pin::Pin;
use anyhow::Result;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tracing::{debug, info};
use super::{AiClient, AiClientMetadata};
use crate::claude::{error::ClaudeError, model_config::get_model_registry};
/// One chat message in the OpenAI chat-completions wire format.
#[derive(Serialize, Debug)]
struct Message {
// Author role; `send_request` emits "system" and "user".
role: String,
// Message text.
content: String,
}
/// Request body for `POST /v1/chat/completions`.
///
/// Exactly one of `max_tokens` / `max_completion_tokens` is set per request:
/// GPT-5/o1-series models require the latter, everything else uses the former
/// (see `send_request`). `None` fields are omitted from the JSON entirely.
#[derive(Serialize, Debug)]
struct OpenAiRequest {
model: String,
messages: Vec<Message>,
// Classic output-token cap for non-GPT-5/o1 models.
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<i32>,
// Replacement cap used by GPT-5/o1-series models.
#[serde(skip_serializing_if = "Option::is_none")]
max_completion_tokens: Option<i32>,
// Omitted for GPT-5/o1-series models, which reject it.
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
// Always false here; this client does not consume streamed responses.
stream: bool,
}
/// One completion choice from the response; only the first is consumed.
#[derive(Deserialize, Debug)]
struct Choice {
message: ResponseMessage,
// Deserialized for completeness but not inspected.
#[allow(dead_code)] finish_reason: Option<String>,
}
/// Assistant message payload inside a `Choice`.
#[derive(Deserialize, Debug)]
struct ResponseMessage {
// Present on the wire but unused; only `content` is returned to callers.
#[allow(dead_code)] role: String,
content: String,
}
/// Top-level chat-completions response envelope.
#[derive(Deserialize, Debug)]
struct OpenAiResponse {
choices: Vec<Choice>,
// Echoed model name; logged at debug level only.
model: Option<String>,
// Token accounting; logged at debug level only.
usage: Option<Usage>,
}
/// Token usage reported by the server; captured for debug logging, not
/// otherwise read (hence the dead_code allowance).
#[derive(Deserialize, Debug)]
#[allow(dead_code)] struct Usage {
prompt_tokens: Option<i32>,
completion_tokens: Option<i32>,
total_tokens: Option<i32>,
}
/// AI client for any OpenAI-compatible chat-completions endpoint
/// (hosted OpenAI as well as local Ollama servers).
pub struct OpenAiAiClient {
// HTTP client from `super::build_http_client`.
client: Client,
// Bearer token; `None` for unauthenticated servers such as local Ollama.
api_key: Option<String>,
model: String,
// Base URL without the `/v1/chat/completions` suffix.
base_url: String,
// Explicit output-token cap; when `None` the model registry decides.
max_tokens: Option<i32>,
// Sampling temperature; dropped for GPT-5/o1-series requests.
temperature: Option<f32>,
// Optional (header, value) beta pair, forwarded to registry lookups
// and surfaced in metadata.
active_beta: Option<(String, String)>,
}
impl OpenAiAiClient {
    /// Creates a client for an arbitrary OpenAI-compatible endpoint.
    ///
    /// # Errors
    /// Fails only if the shared HTTP client cannot be constructed.
    pub fn new(
        model: String,
        api_key: Option<String>,
        base_url: String,
        max_tokens: Option<i32>,
        temperature: Option<f32>,
        active_beta: Option<(String, String)>,
    ) -> Result<Self> {
        let client = super::build_http_client()?;
        Ok(Self {
            client,
            api_key,
            model,
            base_url,
            max_tokens,
            temperature,
            active_beta,
        })
    }

    /// Convenience constructor for a local Ollama server.
    ///
    /// Defaults to `http://localhost:11434` when no base URL is given and
    /// applies conservative generation defaults (4096 tokens, temperature 0.1).
    pub fn new_ollama(
        model: String,
        base_url: Option<String>,
        active_beta: Option<(String, String)>,
    ) -> Result<Self> {
        Self::new(
            model,
            None,
            base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
            Some(4096),
            Some(0.1),
            active_beta,
        )
    }

    /// Convenience constructor for the hosted OpenAI API.
    ///
    /// Leaves `max_tokens` unset so the model registry supplies the limit.
    pub fn new_openai(
        model: String,
        api_key: String,
        active_beta: Option<(String, String)>,
    ) -> Result<Self> {
        Self::new(
            model,
            Some(api_key),
            "https://api.openai.com".to_string(),
            None,
            Some(0.1),
            active_beta,
        )
    }

    /// Effective output-token cap: an explicitly configured value wins,
    /// otherwise the (beta-aware) model registry decides.
    fn get_max_tokens(&self) -> i32 {
        if let Some(configured_max) = self.max_tokens {
            return configured_max;
        }
        super::registry_max_output_tokens(&self.model, &self.active_beta)
    }

    /// Builds the chat-completions endpoint URL from the configured base URL.
    ///
    /// All trailing slashes are stripped before appending the path, so
    /// "http://host/", "http://host//" and "http://host" all yield the same
    /// endpoint. (Previously only a single trailing slash was removed, which
    /// produced a `//v1/...` path for doubled slashes.)
    fn get_api_url(&self) -> Result<String> {
        let base = self.base_url.trim_end_matches('/');
        let url = format!("{base}/v1/chat/completions");
        debug!(base_url = %self.base_url, full_url = %url, "Constructed OpenAI-compatible API URL");
        Ok(url)
    }

    /// Heuristic Ollama detection: a loopback host or the absence of an API
    /// key is treated as a local/unauthenticated server.
    fn is_ollama(&self) -> bool {
        self.base_url.contains("localhost")
            || self.base_url.contains("127.0.0.1")
            || self.api_key.is_none()
    }

    /// Models (GPT-5 family and o1 reasoning models) that require
    /// `max_completion_tokens` and reject the `temperature` parameter.
    fn is_gpt5_series(&self) -> bool {
        self.model.starts_with("gpt-5") || self.model.starts_with("o1")
    }
}
impl AiClient for OpenAiAiClient {
    /// Sends a single (non-streaming) chat-completion request and returns the
    /// content of the first choice.
    ///
    /// An empty `system_prompt` is omitted from the message list rather than
    /// sent as an empty system message. GPT-5/o1-series models get the
    /// `max_completion_tokens` field and no `temperature`; all other models
    /// use the classic `max_tokens` + `temperature` pair.
    ///
    /// # Errors
    /// - `ClaudeError::NetworkError` if the HTTP request fails to send.
    /// - Whatever `super::check_error_response` maps non-success statuses to.
    /// - `ClaudeError::InvalidResponseFormat` if the body is not valid JSON
    ///   or contains no choices.
    fn send_request<'a>(
        &'a self,
        system_prompt: &'a str,
        user_prompt: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<String>> + Send + 'a>> {
        Box::pin(async move {
            debug!(
                system_prompt_len = system_prompt.len(),
                user_prompt_len = user_prompt.len(),
                model = %self.model,
                base_url = %self.base_url,
                is_ollama = self.is_ollama(),
                "Preparing OpenAI-compatible API request"
            );
            let mut messages = Vec::new();
            if !system_prompt.is_empty() {
                messages.push(Message {
                    role: "system".to_string(),
                    content: system_prompt.to_string(),
                });
            }
            messages.push(Message {
                role: "user".to_string(),
                content: user_prompt.to_string(),
            });
            let max_tokens = self.get_max_tokens();
            // GPT-5/o1-series models reject `max_tokens` and `temperature`.
            let request = if self.is_gpt5_series() {
                OpenAiRequest {
                    model: self.model.clone(),
                    messages,
                    max_tokens: None,
                    max_completion_tokens: Some(max_tokens),
                    temperature: None,
                    stream: false,
                }
            } else {
                OpenAiRequest {
                    model: self.model.clone(),
                    messages,
                    max_tokens: Some(max_tokens),
                    max_completion_tokens: None,
                    temperature: self.temperature,
                    stream: false,
                }
            };
            debug!(
                max_tokens = max_tokens,
                configured_temperature = ?self.temperature,
                effective_temperature = ?request.temperature,
                message_count = request.messages.len(),
                is_gpt5_series = self.is_gpt5_series(),
                uses_max_completion_tokens = self.is_gpt5_series(),
                "Built OpenAI-compatible request payload"
            );
            let api_url = self.get_api_url()?;
            info!(url = %api_url, model = %self.model, "Sending request to OpenAI-compatible API");
            let mut req_builder = self
                .client
                .post(&api_url)
                .header("Content-Type", "application/json")
                .json(&request);
            // Local Ollama servers need no auth header; only attach a bearer
            // token when a key is configured.
            if let Some(ref api_key) = self.api_key {
                req_builder = req_builder.header("Authorization", format!("Bearer {api_key}"));
            }
            let response = req_builder
                .send()
                .await
                .map_err(|e| ClaudeError::NetworkError(e.to_string()))?;
            let response = super::check_error_response(response).await?;
            let openai_response: OpenAiResponse = response
                .json()
                .await
                .map_err(|e| ClaudeError::InvalidResponseFormat(e.to_string()))?;
            debug!(
                choice_count = openai_response.choices.len(),
                model = ?openai_response.model,
                usage = ?openai_response.usage,
                "Received OpenAI-compatible API response"
            );
            // Only the first choice is consumed; an empty `choices` array is a
            // protocol error.
            let result = openai_response
                .choices
                .first()
                .map(|choice| choice.message.content.clone())
                .ok_or_else(|| {
                    ClaudeError::InvalidResponseFormat("No choices in response".to_string()).into()
                });
            super::log_response_success("OpenAI-compatible", &result);
            result
        })
    }

    /// Reports provider, model, and context/response limits.
    ///
    /// Registry lookups returning 0 (unknown model) fall back to conservative
    /// defaults (32768 input / 4096 output tokens). Each registry value is now
    /// queried once instead of twice (once in the condition, once in the
    /// branch).
    ///
    /// NOTE(review): this uses `registry.get_max_output_tokens` directly,
    /// unlike `get_max_tokens`, which goes through the beta-aware
    /// `super::registry_max_output_tokens` — confirm whether metadata should
    /// also honor `active_beta`.
    fn get_metadata(&self) -> AiClientMetadata {
        let registry = get_model_registry();
        let input_context = registry.get_input_context(&self.model);
        let max_context_length = if input_context > 0 {
            input_context
        } else {
            32768
        };
        let output_tokens = registry.get_max_output_tokens(&self.model);
        let max_response_length = if output_tokens > 0 {
            output_tokens
        } else {
            4096
        };
        let provider = if self.is_ollama() {
            "Ollama".to_string()
        } else {
            "OpenAI".to_string()
        };
        AiClientMetadata {
            provider,
            model: self.model.clone(),
            max_context_length,
            max_response_length,
            active_beta: self.active_beta.clone(),
        }
    }
}
// Unit tests for the OpenAI-compatible client. These cover construction
// defaults, URL building, Ollama detection, GPT-5/o1-series routing, token
// limits, metadata, and request serialization — no network access is made.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
// Default Ollama construction: localhost URL, no API key.
#[test]
fn new_ollama() {
let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
assert_eq!(client.model, "llama2");
assert_eq!(client.base_url, "http://localhost:11434");
assert!(client.api_key.is_none());
assert!(client.is_ollama());
}
// An explicit base URL overrides the localhost default; still detected as
// Ollama because no API key is set.
#[test]
fn new_ollama_custom_url() {
let client = OpenAiAiClient::new_ollama(
"codellama".to_string(),
Some("http://192.168.1.100:11434".to_string()),
None,
)
.unwrap();
assert_eq!(client.base_url, "http://192.168.1.100:11434");
assert!(client.is_ollama());
}
// Hosted OpenAI construction pins the api.openai.com base URL.
#[test]
fn new_openai() {
let client =
OpenAiAiClient::new_openai("gpt-4".to_string(), "sk-test123".to_string(), None)
.unwrap();
assert_eq!(client.model, "gpt-4");
assert_eq!(client.base_url, "https://api.openai.com");
assert_eq!(client.api_key, Some("sk-test123".to_string()));
assert!(!client.is_ollama());
}
// The chat-completions path is appended to the base URL.
#[test]
fn get_api_url() {
let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
let url = client.get_api_url().unwrap();
assert_eq!(url, "http://localhost:11434/v1/chat/completions");
}
// A trailing slash on the base URL must not produce a double slash.
#[test]
fn get_api_url_trailing_slash() {
let client = OpenAiAiClient::new(
"test-model".to_string(),
None,
"http://localhost:11434/".to_string(),
None,
None,
None,
)
.unwrap();
let url = client.get_api_url().unwrap();
assert_eq!(url, "http://localhost:11434/v1/chat/completions");
}
// is_ollama fires on "localhost", "127.0.0.1", or a missing API key — and
// only a remote URL WITH a key counts as hosted OpenAI.
#[test]
fn is_ollama_detection() {
let ollama_client = OpenAiAiClient::new(
"llama2".to_string(),
None,
"http://localhost:11434".to_string(),
None,
None,
None,
)
.unwrap();
assert!(ollama_client.is_ollama());
let local_client = OpenAiAiClient::new(
"llama2".to_string(),
Some("fake-key".to_string()),
"http://127.0.0.1:11434".to_string(),
None,
None,
None,
)
.unwrap();
assert!(local_client.is_ollama());
let no_key_client = OpenAiAiClient::new(
"llama2".to_string(),
None,
"http://remote-server.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(no_key_client.is_ollama());
let openai_client = OpenAiAiClient::new(
"gpt-4".to_string(),
Some("sk-real-key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(!openai_client.is_ollama());
}
// "gpt-5" prefix matches both bare and suffixed model names.
#[test]
fn gpt5_series_gpt5_models() {
let client = OpenAiAiClient::new(
"gpt-5-preview".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(client.is_gpt5_series());
let client2 = OpenAiAiClient::new(
"gpt-5".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(client2.is_gpt5_series());
}
// o1 reasoning models are routed the same way as the GPT-5 family.
#[test]
fn gpt5_series_o1_models() {
let client = OpenAiAiClient::new(
"o1-mini".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(client.is_gpt5_series());
let client2 = OpenAiAiClient::new(
"o1-preview".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(client2.is_gpt5_series());
}
// gpt-4 / gpt-4o-mini must NOT be treated as GPT-5 series.
#[test]
fn gpt5_series_regular_models_not_matched() {
let client = OpenAiAiClient::new(
"gpt-4".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(!client.is_gpt5_series());
let client2 = OpenAiAiClient::new(
"gpt-4o-mini".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
None,
None,
None,
)
.unwrap();
assert!(!client2.is_gpt5_series());
}
// An explicitly configured max_tokens takes precedence over the registry.
#[test]
fn get_max_tokens_configured_value_wins() {
let client = OpenAiAiClient::new(
"gpt-4".to_string(),
Some("key".to_string()),
"https://api.openai.com".to_string(),
Some(8192),
None,
None,
)
.unwrap();
assert_eq!(client.get_max_tokens(), 8192);
}
// With no configured cap, the registry must supply a positive limit.
#[test]
fn get_max_tokens_from_registry() {
let client =
OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
let tokens = client.get_max_tokens();
assert!(tokens > 0, "expected positive token limit, got {tokens}");
}
// Hosted clients report the "OpenAI" provider.
#[test]
fn get_metadata_openai() {
let client =
OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), None).unwrap();
let metadata = client.get_metadata();
assert_eq!(metadata.provider, "OpenAI");
assert_eq!(metadata.model, "gpt-4o");
assert!(metadata.active_beta.is_none());
}
// Local clients report the "Ollama" provider.
#[test]
fn get_metadata_ollama() {
let client = OpenAiAiClient::new_ollama("llama2".to_string(), None, None).unwrap();
let metadata = client.get_metadata();
assert_eq!(metadata.provider, "Ollama");
assert_eq!(metadata.model, "llama2");
}
// The active beta (header, value) pair is passed through to metadata.
#[test]
fn get_metadata_with_beta() {
let beta = Some(("anthropic-beta".to_string(), "output-128k".to_string()));
let client =
OpenAiAiClient::new_openai("gpt-4o".to_string(), "key".to_string(), beta).unwrap();
let metadata = client.get_metadata();
assert!(metadata.active_beta.is_some());
let (key, value) = metadata.active_beta.unwrap();
assert_eq!(key, "anthropic-beta");
assert_eq!(value, "output-128k");
}
// GPT-5-style payload serializes max_completion_tokens and omits max_tokens
// (skip_serializing_if drops the None field entirely).
#[test]
fn request_gpt5_uses_max_completion_tokens() {
let request = OpenAiRequest {
model: "gpt-5".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: "hello".to_string(),
}],
max_tokens: None,
max_completion_tokens: Some(4096),
temperature: None,
stream: false,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("max_completion_tokens"));
assert!(!json.contains("\"max_tokens\""));
}
// Classic payload serializes max_tokens + temperature and omits
// max_completion_tokens.
#[test]
fn request_regular_model_uses_max_tokens() {
let request = OpenAiRequest {
model: "gpt-4".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: "hello".to_string(),
}],
max_tokens: Some(4096),
max_completion_tokens: None,
temperature: Some(0.1),
stream: false,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"max_tokens\""));
assert!(!json.contains("max_completion_tokens"));
assert!(json.contains("\"temperature\""));
}
}