use std::time::Duration;
use crate::api::models::{ChatRequest, ChatResponse};
/// Errors produced by a call to the chat-completions endpoint.
#[derive(Debug, thiserror::Error)]
pub enum ApiCallError {
/// The server replied with a non-success HTTP status.
#[error("API status {status}: {body}")]
Status {
/// Numeric HTTP status code of the response.
status: u16,
/// Delay taken from the `Retry-After` response header, if one was present
/// and could be parsed as a whole number of seconds.
retry_after: Option<Duration>,
/// Response body text; empty when the body could not be read.
body: String,
},
/// A `reqwest::Error`, converted automatically through `From`. In this file
/// that conversion happens when sending the request fails.
#[error("network error: {0}")]
Network(#[from] reqwest::Error),
/// A success response could not be deserialized into a `ChatResponse`;
/// holds the underlying error rendered as text.
#[error("response decode error: {0}")]
Decode(String),
}
impl ApiCallError {
    /// Reports whether repeating the same request could plausibly succeed.
    ///
    /// HTTP statuses 408, 425, 429, 500, 502, 503 and 504 count as retryable,
    /// as do network errors that `reqwest` classifies as connect, timeout,
    /// request or body failures. Decode errors are never retryable.
    pub fn is_retryable(&self) -> bool {
        match self {
            Self::Decode(_) => false,
            Self::Network(err) => {
                err.is_connect() || err.is_timeout() || err.is_request() || err.is_body()
            }
            Self::Status { status: code, .. } => {
                [408, 425, 429, 500, 502, 503, 504].contains(code)
            }
        }
    }

    /// Returns the server-provided retry delay, which only `Status` errors
    /// can carry; every other variant yields `None`.
    pub fn retry_after(&self) -> Option<Duration> {
        if let Self::Status { retry_after, .. } = self {
            *retry_after
        } else {
            None
        }
    }

    /// Returns `true` only for a `Status` error whose code is 429.
    pub fn is_rate_limit(&self) -> bool {
        match self {
            Self::Status { status, .. } => *status == 429,
            Self::Network(_) | Self::Decode(_) => false,
        }
    }
}
/// Extracts the `Retry-After` delay from a set of response headers.
///
/// Only the delta-seconds form is recognised: an unsigned integer, with
/// surrounding whitespace ignored. A missing header, a value that cannot be
/// viewed as a string, or any other format (including an HTTP-date) yields
/// `None`.
fn parse_retry_after(headers: &reqwest::header::HeaderMap) -> Option<Duration> {
    let raw = headers.get(reqwest::header::RETRY_AFTER)?;
    let text = raw.to_str().ok()?;
    text.trim().parse::<u64>().ok().map(Duration::from_secs)
}
/// Client for an HTTP API that exposes a `/chat/completions` endpoint.
#[derive(Clone)]
pub struct OpenAiCompatibleProvider {
/// HTTP client used for every request made by this provider.
http: reqwest::Client,
/// URL prefix to which `/chat/completions` is appended.
base_url: String,
/// Credential sent as `Authorization: Bearer <api_key>`.
api_key: String,
/// Model identifier; also used to look up the model profile.
pub model: String,
/// Output-token limit, taken from the model profile or the configuration
/// at construction time and replaced by `switch_provider`.
pub max_tokens: u32,
/// Value reported by `context_window()`.
context_window_size: usize,
/// Value reported by `supports_tools()`; read from the model profile.
has_tool_support: bool,
/// Value reported by `supports_reasoning()`; read from the model profile.
has_reasoning: bool,
/// Additional header name/value pairs attached to every request.
extra_headers: std::collections::HashMap<String, String>,
}
impl OpenAiCompatibleProvider {
pub fn new(
base_url: String,
api_key: String,
model: String,
context_window_size: usize,
) -> anyhow::Result<Self> {
let profile = crate::api::model_profile::profile_for(&model);
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.pool_max_idle_per_host(10)
.pool_idle_timeout(std::time::Duration::from_secs(60))
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
http,
base_url,
api_key,
max_tokens: profile.max_output_tokens,
model,
context_window_size,
has_tool_support: profile.supports_tool_use,
has_reasoning: profile.supports_reasoning,
extra_headers: std::collections::HashMap::new(),
})
}
pub fn from_config(config: &crate::config::Config) -> anyhow::Result<Self> {
let profile = crate::api::model_profile::profile_for(&config.model);
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.pool_max_idle_per_host(10)
.pool_idle_timeout(std::time::Duration::from_secs(60))
.tcp_keepalive(std::time::Duration::from_secs(30))
.build()?;
Ok(Self {
http,
base_url: config.base_url.clone(),
api_key: config.api_key.clone(),
model: config.model.clone(),
max_tokens: config.max_tokens,
context_window_size: profile.context_window,
has_tool_support: profile.supports_tool_use,
has_reasoning: profile.supports_reasoning,
extra_headers: config.proxy_headers.clone(),
})
}
pub fn from_entry(
entry: &crate::config::ProviderEntry,
api_key: &str,
model: &str,
extra_headers: std::collections::HashMap<String, String>,
) -> anyhow::Result<Self> {
let profile = crate::api::model_profile::profile_for(model);
let mut provider = Self::new(
entry.base_url.clone(),
api_key.to_string(),
model.to_string(),
profile.context_window,
)?;
provider.extra_headers = extra_headers;
Ok(provider)
}
pub fn switch_provider(
&mut self,
base_url: String,
api_key: String,
model: String,
max_tokens: u32,
) {
self.base_url = base_url;
self.api_key = api_key;
self.model = model;
self.max_tokens = max_tokens;
}
pub fn base_url(&self) -> &str {
&self.base_url
}
pub fn api_key(&self) -> &str {
&self.api_key
}
pub async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, ApiCallError> {
let url = format!("{}/chat/completions", self.base_url());
tracing::trace!(base_url = self.base_url(), model = %self.model, "chat request");
let mut req = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json");
for (k, v) in &self.extra_headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = req.json(request).send().await?;
let status = resp.status();
if !status.is_success() {
let retry_after = parse_retry_after(resp.headers());
let body = resp.text().await.unwrap_or_default();
return Err(ApiCallError::Status {
status: status.as_u16(),
retry_after,
body,
});
}
let response: ChatResponse = resp
.json()
.await
.map_err(|e| ApiCallError::Decode(e.to_string()))?;
tracing::trace!(
response_id = %response.id,
base_url = %self.base_url(),
finish_reason = ?response.choices.first().and_then(|c| c.finish_reason.as_deref()),
first_choice_index = ?response.choices.first().map(|c| c.index),
total_tokens = ?response.usage.as_ref().map(|u| u.total_tokens),
choices = response.choices.len(),
"Non-streaming API response"
);
Ok(response)
}
pub async fn chat_stream(
&self,
request: &ChatRequest,
) -> Result<reqwest::Response, ApiCallError> {
let url = format!("{}/chat/completions", self.base_url);
let mut req = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json");
for (k, v) in &self.extra_headers {
req = req.header(k.as_str(), v.as_str());
}
let resp = req.json(request).send().await?;
let status = resp.status();
if !status.is_success() {
let retry_after = parse_retry_after(resp.headers());
let body = resp.text().await.unwrap_or_default();
return Err(ApiCallError::Status {
status: status.as_u16(),
retry_after,
body,
});
}
Ok(resp)
}
pub fn model_name(&self) -> &str {
&self.model
}
pub fn context_window(&self) -> usize {
self.context_window_size
}
pub fn supports_tools(&self) -> bool {
self.has_tool_support
}
pub fn supports_reasoning(&self) -> bool {
self.has_reasoning
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Status` error with the given code, no retry hint and an
    /// empty body.
    fn status_err(status: u16) -> ApiCallError {
        ApiCallError::Status {
            body: String::new(),
            retry_after: None,
            status,
        }
    }

    /// Builds a header map whose only entry is `Retry-After: <value>`.
    fn retry_after_headers(value: &str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(reqwest::header::RETRY_AFTER, value.parse().unwrap());
        headers
    }

    // Transient statuses must be reported as retryable.
    #[test]
    fn test_is_retryable_5xx_and_429() {
        let transient: [u16; 7] = [408, 425, 429, 500, 502, 503, 504];
        for &code in &transient {
            let retryable = status_err(code).is_retryable();
            assert!(retryable, "status {code} must be retryable");
        }
    }

    // Permanent client errors must not be retried.
    #[test]
    fn test_is_retryable_skips_4xx_permanent() {
        let permanent: [u16; 6] = [400, 401, 403, 404, 409, 422];
        for &code in &permanent {
            let retryable = status_err(code).is_retryable();
            assert!(!retryable, "status {code} must NOT be retryable");
        }
    }

    // Only 429 counts as a rate limit.
    #[test]
    fn test_is_rate_limit_only_429() {
        assert!(status_err(429).is_rate_limit());
        for &other in &[503u16, 500] {
            assert!(!status_err(other).is_rate_limit());
        }
    }

    // The retry hint stored in a `Status` error is returned unchanged.
    #[test]
    fn test_retry_after_propagates() {
        let seven_seconds = Duration::from_secs(7);
        let with_hint = ApiCallError::Status {
            status: 429,
            retry_after: Some(seven_seconds),
            body: String::new(),
        };
        assert_eq!(with_hint.retry_after(), Some(seven_seconds));

        let without_hint = status_err(429);
        assert_eq!(without_hint.retry_after(), None);
    }

    // Decode failures are never retried.
    #[test]
    fn test_decode_error_not_retryable() {
        let err = ApiCallError::Decode("bad json".into());
        assert!(!err.is_retryable());
    }

    // The delta-seconds form of `Retry-After` is parsed.
    #[test]
    fn test_parse_retry_after_seconds() {
        let headers = retry_after_headers("12");
        assert_eq!(parse_retry_after(&headers), Some(Duration::from_secs(12)));
    }

    // A missing header or a non-integer value yields `None`.
    #[test]
    fn test_parse_retry_after_missing_or_garbage() {
        let no_headers = reqwest::header::HeaderMap::new();
        assert_eq!(parse_retry_after(&no_headers), None);

        let http_date = retry_after_headers("Wed, 21 Oct 2026 07:28:00 GMT");
        assert_eq!(parse_retry_after(&http_date), None);
    }

    // Construction succeeds and the accessors reflect the inputs.
    #[test]
    fn test_openai_compatible_provider_creation() {
        let base_url = "https://api.example.com/v1".to_string();
        let api_key = "test-key".to_string();
        let model = "glm-4.7".to_string();

        let built = OpenAiCompatibleProvider::new(base_url, api_key, model, 128_000);
        assert!(built.is_ok());

        let provider = built.unwrap();
        assert_eq!(provider.model_name(), "glm-4.7");
        assert_eq!(provider.context_window(), 128_000);
        assert!(provider.supports_tools());
    }
}