use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use solo_core::{Error, LlmClient, Message, Result, Role};
use zeroize::Zeroizing;
use super::retry::{
RetryConfig, exp_backoff_with_jitter, is_retryable_reqwest_err, is_retryable_status,
parse_retry_after,
};
/// Base URL used by `OpenAIClient::new` until overridden with `with_base_url`.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// Appended to the (slash-trimmed) base URL to form the request endpoint in `complete`.
const CHAT_COMPLETIONS_PATH: &str = "/chat/completions";
/// Fallback model used by `build_openai_client_from_env` when `OPENAI_MODEL` is not provided.
const DEFAULT_MODEL: &str = "gpt-4o-mini";
/// Initial `max_tokens` sent with every request; override with `with_max_tokens`.
const DEFAULT_MAX_TOKENS: u32 = 1024;
/// Timeout (seconds) passed to the reqwest client builder in `OpenAIClient::new`;
/// override with `with_timeout`.
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Client for an OpenAI-compatible `/chat/completions` endpoint.
///
/// Built with `OpenAIClient::new` and configured through the `with_*`
/// builder methods. Cloning shares the API key through an `Arc`, so clones do
/// not duplicate the secret.
#[derive(Clone)]
pub struct OpenAIClient {
/// HTTP client carrying the configured request timeout.
http: reqwest::Client,
/// Bearer token; `Zeroizing` wipes the buffer when the last `Arc` is dropped.
api_key: Arc<Zeroizing<String>>,
/// Model name sent in each request and reported by `LlmClient::name`.
model: String,
/// Value sent as `max_tokens` in each request.
max_tokens: u32,
/// API root without trailing slashes when set via `with_base_url`.
base_url: String,
/// Retry/backoff policy used by `complete`.
retry: RetryConfig,
}
impl OpenAIClient {
    /// Creates a client for `model` that authenticates with `api_key`.
    ///
    /// Starts from the public OpenAI base URL, `DEFAULT_MAX_TOKENS`, a
    /// `DEFAULT_TIMEOUT_SECS` request timeout and `RetryConfig::default()`.
    ///
    /// # Errors
    /// Returns an LLM error if the underlying `reqwest` client cannot be built.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Result<Self> {
        let timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECS);
        let http = reqwest::Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| Error::llm(format!("build reqwest client: {e}")))?;
        let secret = Zeroizing::new(api_key.into());
        Ok(Self {
            api_key: Arc::new(secret),
            model: model.into(),
            http,
            base_url: String::from(DEFAULT_BASE_URL),
            max_tokens: DEFAULT_MAX_TOKENS,
            retry: RetryConfig::default(),
        })
    }

    /// Sets the `max_tokens` value sent with every request.
    pub fn with_max_tokens(self, n: u32) -> Self {
        Self { max_tokens: n, ..self }
    }

    /// Model name this client sends to the API.
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// API root that `/chat/completions` is appended to.
    pub fn base_url(&self) -> &str {
        self.base_url.as_str()
    }

    /// Replaces the HTTP client with one using `timeout` as its request timeout.
    ///
    /// # Errors
    /// Returns an LLM error if the new `reqwest` client cannot be built.
    pub fn with_timeout(self, timeout: Duration) -> Result<Self> {
        let http = reqwest::Client::builder()
            .timeout(timeout)
            .build()
            .map_err(|e| Error::llm(format!("rebuild reqwest client: {e}")))?;
        Ok(Self { http, ..self })
    }

    /// Points the client at `base_url`, dropping any trailing `/` characters so
    /// that appending `/chat/completions` never yields a double slash.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        let mut url: String = base_url.into();
        let kept_len = url.trim_end_matches('/').len();
        url.truncate(kept_len);
        self.base_url = url;
        self
    }

    /// Replaces the retry/backoff policy used by `complete`.
    pub fn with_retry_config(self, retry: RetryConfig) -> Self {
        Self { retry, ..self }
    }
}
#[async_trait]
impl LlmClient for OpenAIClient {
fn name(&self) -> &str {
&self.model
}
async fn complete(&self, messages: &[Message]) -> Result<Message> {
let body = OpenAIRequest {
model: &self.model,
max_tokens: self.max_tokens,
messages: messages.iter().map(to_openai_message).collect(),
};
let url = format!("{}{}", self.base_url, CHAT_COMPLETIONS_PATH);
let mut attempt: u32 = 0;
loop {
let send_res = self
.http
.post(&url)
.bearer_auth(self.api_key.as_str())
.header("content-type", "application/json")
.json(&body)
.send()
.await;
match send_res {
Ok(resp) => {
let status = resp.status();
if status.is_success() {
let parsed: OpenAIResponse = resp.json().await.map_err(|e| {
Error::llm(format!("openai response parse: {e}"))
})?;
let text = parsed
.choices
.into_iter()
.next()
.and_then(|c| c.message.content)
.ok_or_else(|| {
Error::llm(
"openai response had no choices[0].message.content"
.to_string(),
)
})?;
return Ok(Message {
role: Role::Assistant,
content: text,
});
}
let retry_after_hdr = resp
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let body_text = resp.text().await.unwrap_or_default();
if attempt < self.retry.max_retries
&& is_retryable_status(status.as_u16())
{
let delay = parse_retry_after(
retry_after_hdr.as_deref(),
self.retry.max_delay,
)
.unwrap_or_else(|| {
exp_backoff_with_jitter(attempt + 1, &self.retry)
});
tracing::warn!(
attempt = attempt + 1,
status = %status,
delay_ms = delay.as_millis() as u64,
"openai retryable HTTP error; backing off"
);
tokio::time::sleep(delay).await;
attempt += 1;
continue;
}
return Err(Error::llm(format!(
"openai HTTP {}: {}",
status,
truncate(&body_text, 500)
)));
}
Err(e) => {
if attempt < self.retry.max_retries
&& is_retryable_reqwest_err(&e)
{
let delay = exp_backoff_with_jitter(attempt + 1, &self.retry);
tracing::warn!(
attempt = attempt + 1,
error = %e,
delay_ms = delay.as_millis() as u64,
"openai retryable network error; backing off"
);
tokio::time::sleep(delay).await;
attempt += 1;
continue;
}
return Err(Error::llm(format!("openai request: {e}")));
}
}
}
}
}
/// Builds an OpenAI-compatible LLM client from environment variables.
///
/// - `OPENAI_API_KEY` (required): when unset, empty or non-Unicode, returns
///   `Ok(None)` so callers can fall back to another provider.
/// - `OPENAI_MODEL` (optional): model name. Unset, empty or non-Unicode falls
///   back to `DEFAULT_MODEL`, matching how the other two variables treat
///   empty values.
/// - `OPENAI_BASE_URL` (optional): overrides the API root when non-empty.
///
/// When `super::ollama::is_ollama_base_url` recognises the base URL, the
/// client is wrapped in `OllamaClient`.
///
/// # Errors
/// Propagates failures from `OpenAIClient::new` (HTTP client construction).
pub fn build_openai_client_from_env() -> Result<Option<Arc<dyn LlmClient>>> {
    let Some(key) = non_empty_env("OPENAI_API_KEY") else {
        return Ok(None);
    };
    eprintln!(
        "warning: reading OPENAI_API_KEY from env — visible via /proc on Linux. \
         File-based key support is a planned follow-up."
    );
    // Previously an empty OPENAI_MODEL (e.g. `OPENAI_MODEL=` in a .env file) was
    // sent verbatim as `"model": ""`, which the API presumably rejects.
    let model = non_empty_env("OPENAI_MODEL").unwrap_or_else(|| DEFAULT_MODEL.to_string());
    let mut client = OpenAIClient::new(key, model)?;
    if let Some(base) = non_empty_env("OPENAI_BASE_URL") {
        client = client.with_base_url(base);
    }
    let shared: Arc<dyn LlmClient> = if super::ollama::is_ollama_base_url(client.base_url()) {
        Arc::new(super::ollama::OllamaClient::wrap(client))
    } else {
        Arc::new(client)
    };
    Ok(Some(shared))
}

/// Reads environment variable `name`, treating unset, non-Unicode and empty
/// values alike as absent.
fn non_empty_env(name: &str) -> Option<String> {
    std::env::var(name).ok().filter(|value| !value.is_empty())
}
/// JSON body POSTed to `/chat/completions` by `complete`.
///
/// Field order here is the key order in the serialized JSON.
#[derive(Debug, Serialize)]
struct OpenAIRequest<'a> {
/// Borrowed from the client's configured model name.
model: &'a str,
/// Completion token cap.
/// NOTE(review): some newer OpenAI models reportedly require
/// `max_completion_tokens` instead of `max_tokens` — confirm against target models.
max_tokens: u32,
/// Conversation, converted from core messages by `to_openai_message`.
messages: Vec<OpenAIMessage>,
}
/// One chat message on the wire, serialized as `{"role": ..., "content": ...}`.
///
/// Field order is serialization order: `role` precedes `content`.
#[derive(Debug, Serialize)]
struct OpenAIMessage {
    /// Role string; `to_openai_message` fills it with `"system"`, `"user"` or `"assistant"`.
    role: &'static str,
    /// Message text copied from the core `Message`.
    content: String,
}
/// Subset of the chat-completions response body that this client reads.
///
/// Undeclared fields such as `id` and `model` are ignored during
/// deserialization; the unit tests feed payloads containing them.
#[derive(Debug, Deserialize)]
struct OpenAIResponse {
/// Returned choices; `complete` uses only the first one. A missing
/// `choices` key deserializes to an empty vec via `#[serde(default)]`.
#[serde(default)]
choices: Vec<OpenAIChoice>,
}
/// One entry of `choices`. Other keys such as `index` and `finish_reason` are ignored.
#[derive(Debug, Deserialize)]
struct OpenAIChoice {
/// Required: a choice without `message` fails deserialization (no default).
message: OpenAIChoiceMessage,
}
/// The assistant message inside a choice; only `content` is read.
#[derive(Debug, Deserialize)]
struct OpenAIChoiceMessage {
/// Reply text. `None` when the key is missing or `null`; the tests show
/// `null` alongside `tool_calls`.
#[serde(default)]
content: Option<String>,
}
/// Converts a core [`Message`] into its chat-completions wire form.
///
/// Each `Role` variant maps to the matching lowercase role string, and the
/// content is cloned verbatim.
fn to_openai_message(m: &Message) -> OpenAIMessage {
    let wire_role: &'static str = match m.role {
        Role::System => "system",
        Role::User => "user",
        Role::Assistant => "assistant",
    };
    OpenAIMessage {
        role: wire_role,
        content: m.content.clone(),
    }
}
/// Shortens `s` to at most `max` characters (Unicode scalar values).
///
/// When `s` is longer than `max`, the first `max - 1` characters are kept and
/// `…` is appended, so the result never exceeds `max` characters. `max == 0`
/// yields an empty string. The old `max - 1` underflowed there, panicking in
/// debug builds.
fn truncate(s: &str, max: usize) -> String {
    if max == 0 {
        return String::new();
    }
    // `nth(max)` stops after `max + 1` chars instead of counting the whole
    // (possibly large) string.
    if s.char_indices().nth(max).is_none() {
        return s.to_string();
    }
    // Byte offset where the (max - 1)-th char starts, i.e. the end of the kept prefix.
    let cut = s.char_indices().nth(max - 1).map_or(s.len(), |(i, _)| i);
    let mut out = String::with_capacity(cut + '…'.len_utf8());
    out.push_str(&s[..cut]);
    out.push('…');
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes this module's tests that read or mutate `OPENAI_API_KEY`,
    /// so they cannot observe each other's changes. It does not coordinate
    /// with env-touching tests in other modules.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Locks [`ENV_LOCK`], recovering from poisoning so one failing test does
    /// not cascade into unrelated failures.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Snapshots `OPENAI_API_KEY` under [`ENV_LOCK`] and restores the original
    /// value (or its absence) on drop. A developer's real key therefore
    /// survives these tests, e.g. for `--include-ignored` smoke runs.
    struct ApiKeyEnvGuard {
        previous: Option<OsString>,
        // Fields drop after `Drop::drop` runs, so the restore happens under the lock.
        _lock: MutexGuard<'static, ()>,
    }

    impl ApiKeyEnvGuard {
        fn acquire() -> Self {
            let lock = lock_env();
            Self {
                previous: std::env::var_os("OPENAI_API_KEY"),
                _lock: lock,
            }
        }
    }

    impl Drop for ApiKeyEnvGuard {
        fn drop(&mut self) {
            // SAFETY: ENV_LOCK is still held, so no other test in this module
            // touches the environment concurrently. Tests in other modules are
            // not coordinated, same as before this guard existed.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var("OPENAI_API_KEY", value),
                    None => std::env::remove_var("OPENAI_API_KEY"),
                }
            }
        }
    }

    #[test]
    fn message_mapping_preserves_three_roles() {
        let msgs = vec![
            Message::system("you are helpful"),
            Message::user("hi"),
            Message::assistant("hello"),
        ];
        let mapped: Vec<OpenAIMessage> = msgs.iter().map(to_openai_message).collect();
        assert_eq!(mapped.len(), 3);
        assert_eq!(mapped[0].role, "system");
        assert_eq!(mapped[0].content, "you are helpful");
        assert_eq!(mapped[1].role, "user");
        assert_eq!(mapped[1].content, "hi");
        assert_eq!(mapped[2].role, "assistant");
        assert_eq!(mapped[2].content, "hello");
    }

    #[test]
    fn response_parses_choices_zero_content() {
        let raw = r#"{
            "id": "chatcmpl-1",
            "model": "gpt-4o-mini",
            "choices": [{
                "index": 0,
                "message": { "role": "assistant", "content": "hello world" },
                "finish_reason": "stop"
            }]
        }"#;
        let parsed: OpenAIResponse = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.choices.len(), 1);
        assert_eq!(parsed.choices[0].message.content.as_deref(), Some("hello world"));
    }

    #[test]
    fn response_with_tool_call_has_null_content() {
        let raw = r#"{
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{"id":"x","type":"function","function":{"name":"f","arguments":"{}"}}]
                },
                "finish_reason": "tool_calls"
            }]
        }"#;
        let parsed: OpenAIResponse = serde_json::from_str(raw).unwrap();
        assert!(parsed.choices[0].message.content.is_none());
    }

    #[test]
    fn response_with_no_choices_yields_error_in_complete_path() {
        let raw = r#"{ "choices": [] }"#;
        let parsed: OpenAIResponse = serde_json::from_str(raw).unwrap();
        let text = parsed
            .choices
            .into_iter()
            .next()
            .and_then(|c| c.message.content);
        assert!(text.is_none(), "no choices → None for text");
    }

    #[test]
    fn name_returns_configured_model() {
        let c = OpenAIClient::new("dummy", "gpt-test-model").unwrap();
        assert_eq!(c.name(), "gpt-test-model");
    }

    #[test]
    fn with_max_tokens_overrides_default() {
        let c = OpenAIClient::new("dummy", "m")
            .unwrap()
            .with_max_tokens(2048);
        assert_eq!(c.max_tokens, 2048);
    }

    #[test]
    fn with_base_url_strips_trailing_slashes() {
        let c = OpenAIClient::new("dummy", "m")
            .unwrap()
            .with_base_url("http://localhost:1234/v1//");
        assert_eq!(c.base_url, "http://localhost:1234/v1");
    }

    #[test]
    fn with_base_url_keeps_clean_url_unchanged() {
        let c = OpenAIClient::new("dummy", "m")
            .unwrap()
            .with_base_url("https://api.openai.com/v1");
        assert_eq!(c.base_url, "https://api.openai.com/v1");
    }

    #[test]
    fn truncate_leaves_short_input_untouched() {
        assert_eq!(truncate("hello", 5), "hello");
        assert_eq!(truncate("", 3), "");
    }

    #[test]
    fn truncate_counts_chars_and_appends_ellipsis() {
        assert_eq!(truncate("hello world", 5), "hell…");
        assert_eq!(truncate("héllo wörld", 3), "hé…");
    }

    #[test]
    fn build_from_env_returns_none_when_key_missing() {
        let _env = ApiKeyEnvGuard::acquire();
        // SAFETY: ENV_LOCK is held by `_env` for the whole test.
        unsafe {
            std::env::remove_var("OPENAI_API_KEY");
        }
        let r = build_openai_client_from_env().unwrap();
        assert!(r.is_none());
    }

    #[test]
    fn build_from_env_returns_none_when_key_empty() {
        let _env = ApiKeyEnvGuard::acquire();
        // SAFETY: ENV_LOCK is held by `_env` for the whole test.
        unsafe {
            std::env::set_var("OPENAI_API_KEY", "");
        }
        let r = build_openai_client_from_env().unwrap();
        assert!(r.is_none());
    }

    #[tokio::test]
    #[ignore]
    async fn openai_smoke_real_api() {
        // Read under the lock so the env-mutating tests cannot race this read.
        // The guard is dropped before any `.await`.
        let key = {
            let _env = lock_env();
            std::env::var("OPENAI_API_KEY")
        };
        let Ok(key) = key else {
            eprintln!("OPENAI_API_KEY not set; skipping");
            return;
        };
        let model =
            std::env::var("OPENAI_MODEL").unwrap_or_else(|_| DEFAULT_MODEL.to_string());
        let mut client = OpenAIClient::new(key, model).unwrap();
        if let Ok(base) = std::env::var("OPENAI_BASE_URL") {
            if !base.is_empty() {
                client = client.with_base_url(base);
            }
        }
        let resp = client
            .complete(&[Message::user("Reply with the single word: ok")])
            .await
            .expect("openai round-trip");
        assert_eq!(resp.role, Role::Assistant);
        assert!(!resp.content.is_empty());
    }
}