use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use bytes::Bytes;
use futures::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::messages::{Message, ToolCall};
use crate::provider::{Provider, ProviderError};
/// Chat-completion [`Provider`] for any server that speaks the OpenAI
/// `/chat/completions` wire format.
///
/// Results of the most recent calls are kept in shared slots so that callers can
/// fetch them after a request (or a stream) has finished.
pub struct OpenAICompatibleProvider {
    /// Model name sent with every request.
    model: String,
    /// Secret sent as a bearer token.
    api_key: String,
    /// API root; `/chat/completions` is appended to it.
    base_url: String,
    /// Attempt budget used by the retry loop of non-streaming completions.
    max_retries: u32,
    /// Request timeout, in seconds.
    timeout: u64,
    /// HTTP client shared by all requests made through this provider.
    client: Client,
    /// Assistant message assembled by the most recent completed stream.
    pub last_stream_message: Arc<Mutex<Option<Message>>>,
    /// Token usage reported by the most recent request, if the server sent any.
    pub last_usage: Arc<Mutex<Option<crate::messages::Usage>>>,
}
impl OpenAICompatibleProvider {
    /// Creates a provider for an OpenAI-compatible chat-completions endpoint.
    ///
    /// * `model` – model name sent as `"model"` in every request.
    /// * `api_key` – sent as an `Authorization: Bearer` token.
    /// * `base_url` – API root (e.g. `https://api.openai.com/v1`); a trailing `/` is optional.
    /// * `max_retries` – attempt budget for retryable failures.
    /// * `timeout` – client-wide request timeout in seconds.
    pub fn new(
        model: String,
        api_key: String,
        base_url: String,
        max_retries: u32,
        timeout: u64,
    ) -> Self {
        // If the configured client cannot be built, fall back to `Client::default()`
        // (which does not carry the timeout configured here).
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(timeout))
            .build()
            .unwrap_or_default();
        Self {
            model,
            api_key,
            base_url,
            max_retries,
            timeout,
            client,
            last_stream_message: Arc::new(Mutex::new(None)),
            last_usage: Arc::new(std::sync::Mutex::new(None)),
        }
    }

    /// Builds the JSON request body for `/chat/completions`.
    ///
    /// `tools` and `tool_choice` are only sent when at least one tool is offered:
    /// the OpenAI API rejects an empty `tools` array and any `tool_choice` without
    /// tools. An empty `tool_choice` string is omitted as well. `max_tokens` is sent
    /// only when `Some`, and `"stream": true` only when `stream` is set.
    fn build_body(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
        stream: bool,
    ) -> serde_json::Value {
        // Widening f32 -> f64 turns 0.7 into 0.699999988079071; re-parsing the f32's
        // shortest decimal form sends the value the caller actually wrote. Non-finite
        // values still serialize to `null`, as before.
        let temperature_json = temperature
            .to_string()
            .parse::<f64>()
            .unwrap_or_else(|_| f64::from(temperature));
        let mut body = serde_json::json!({
            "model": self.model,
            "messages": messages.iter().map(|m| m.to_api_dict()).collect::<Vec<_>>(),
            "temperature": temperature_json,
        });
        if let Some(tool_defs) = tools.filter(|t| !t.is_empty()) {
            body["tools"] = serde_json::Value::Array(tool_defs.to_vec());
            if !tool_choice.is_empty() {
                body["tool_choice"] = serde_json::Value::String(tool_choice.to_string());
            }
        }
        if let Some(mt) = max_tokens {
            body["max_tokens"] = serde_json::Value::Number(serde_json::Number::from(mt));
        }
        if stream {
            // NOTE(review): OpenAI only emits a streaming usage chunk when
            // `stream_options: {"include_usage": true}` is requested. It is not sent here
            // because not every compatible server may accept it — confirm before enabling.
            body["stream"] = serde_json::Value::Bool(true);
        }
        body
    }

    /// Request headers: JSON content type plus the bearer token.
    ///
    /// Surrounding whitespace is trimmed from the key; keys read from env/config often
    /// carry a trailing newline, which is not a legal header byte. If the key is still
    /// not a valid header value, `Authorization` is omitted, so the server answers with
    /// an auth error instead of this call panicking.
    fn headers(&self) -> reqwest::header::HeaderMap {
        use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};

        let mut headers = HeaderMap::new();
        if let Ok(mut auth) = HeaderValue::from_str(&format!("Bearer {}", self.api_key.trim())) {
            // Keep the secret out of the header map's `Debug` output.
            auth.set_sensitive(true);
            headers.insert(AUTHORIZATION, auth);
        }
        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
        headers
    }

    /// Chat-completions endpoint: `base_url` without trailing slashes + `/chat/completions`.
    fn url(&self) -> String {
        format!("{}/chat/completions", self.base_url.trim_end_matches('/'))
    }
}
#[async_trait]
impl Provider for OpenAICompatibleProvider {
    /// Sends a non-streaming chat completion request, retrying transient failures.
    ///
    /// At least one attempt is always made; up to `max_retries` attempts in total.
    /// HTTP 429, HTTP 5xx, timeouts and transport errors are retried with exponential
    /// backoff (1s, 2s, 4s, … capped at 64s). Any other non-success status is returned
    /// immediately.
    ///
    /// Token usage from the response (when present) is stored for `last_usage()`. The
    /// previous value is cleared first, so a response without usage never reports stale
    /// numbers.
    ///
    /// # Errors
    /// * [`ProviderError::Api`] – non-success status (after retries for 429/5xx).
    /// * [`ProviderError::Timeout`] / [`ProviderError::Http`] – transport failure on the
    ///   last attempt.
    /// * [`ProviderError::Other`] – the body is not JSON or has no `choices[0].message`.
    async fn chat_completion(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Message, ProviderError> {
        let body = self.build_body(messages, tools, tool_choice, max_tokens, temperature, false);
        let url = self.url();
        set_slot(&self.last_usage, None);
        // `max_retries == 0` would otherwise skip the request entirely.
        let attempts = self.max_retries.max(1);
        let mut last_error: Option<ProviderError> = None;
        for attempt in 0..attempts {
            let error = match self.client.post(&url).headers(self.headers()).json(&body).send().await {
                Ok(resp) if resp.status().is_success() => {
                    let data: serde_json::Value = resp
                        .json()
                        .await
                        .map_err(|e| ProviderError::Other(format!("JSON parse error: {}", e)))?;
                    if let Some(usage) = data.get("usage").filter(|u| u.is_object()) {
                        set_slot(&self.last_usage, Some(parse_usage(usage)));
                    }
                    // Some compatible servers answer 200 with an error object; report that
                    // instead of building a message from `null`.
                    let message = data
                        .get("choices")
                        .and_then(|choices| choices.get(0))
                        .and_then(|choice| choice.get("message"))
                        .filter(|m| m.is_object())
                        .ok_or_else(|| {
                            ProviderError::Other(format!(
                                "Response has no choices[0].message: {}",
                                data
                            ))
                        })?;
                    return Ok(Message::from_api_dict(message));
                }
                Ok(resp) => {
                    let status = resp.status().as_u16();
                    let text = resp.text().await.unwrap_or_default();
                    let retryable = status == 429 || (500..=599).contains(&status);
                    if !retryable {
                        return Err(ProviderError::Api { status, body: text });
                    }
                    ProviderError::Api { status, body: text }
                }
                Err(e) if e.is_timeout() => {
                    ProviderError::Timeout(format!("Request timed out ({}s)", self.timeout))
                }
                Err(e) => ProviderError::Http(format!("Request failed: {}", e)),
            };
            last_error = Some(error);
            if attempt + 1 < attempts {
                tokio::time::sleep(retry_backoff(attempt)).await;
            }
        }
        Err(last_error.unwrap_or_else(|| ProviderError::Other("Provider call failed after all retries".into())))
    }

    /// Starts a streaming chat completion and returns the text deltas as a stream.
    ///
    /// The HTTP exchange runs on a spawned task. Each non-empty `content` delta is
    /// forwarded as `Ok(text)`. A request, status or read failure is forwarded as a
    /// single `Err` and ends the stream.
    ///
    /// When the response completes, the assembled assistant message is stored for
    /// `last_stream_message()` before the stream closes. Tool calls are ordered by their
    /// stream `index`. Both result slots are cleared up front, so a failed stream never
    /// exposes data from an earlier call.
    async fn chat_completion_stream(
        &self,
        messages: &[Message],
        tools: Option<&[serde_json::Value]>,
        tool_choice: &str,
        max_tokens: Option<u32>,
        temperature: f32,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<String, ProviderError>> + Send>>, ProviderError> {
        let body = self.build_body(messages, tools, tool_choice, max_tokens, temperature, true);
        let (tx, rx) = mpsc::unbounded_channel();
        let client = self.client.clone();
        let url = self.url();
        let headers = self.headers();
        let timeout_dur = std::time::Duration::from_secs(self.timeout);
        let last_stream = self.last_stream_message.clone();
        let last_stream_usage = self.last_usage.clone();
        set_slot(&last_stream, None);
        set_slot(&last_stream_usage, None);
        tokio::spawn(async move {
            // NOTE(review): per reqwest docs the request timeout runs until the body has been
            // read, so a generation streaming longer than `timeout` seconds ends with a
            // read error — confirm this is intended.
            let result = client.post(&url).headers(headers).json(&body).timeout(timeout_dur).send().await;
            let response = match result {
                Ok(r) => r,
                Err(e) => {
                    let _ = tx.send(Err(ProviderError::Stream(format!("Request failed: {}", e))));
                    return;
                }
            };
            if !response.status().is_success() {
                let status = response.status().as_u16();
                let text = response.text().await.unwrap_or_default();
                let _ = tx.send(Err(ProviderError::Api { status, body: text }));
                return;
            }

            let mut content = String::new();
            let mut tool_call_map: HashMap<usize, ToolCallBuilder> = HashMap::new();
            // Buffer raw bytes and decode only complete lines. A multi-byte UTF-8 character
            // may be split across network chunks, while the '\n' delimiter (ASCII) never
            // occurs inside a multi-byte sequence.
            let mut pending: Vec<u8> = Vec::new();
            let mut done = false;
            let mut body_stream = response.bytes_stream();
            'read: while let Some(chunk_result) = body_stream.next().await {
                let chunk: Bytes = match chunk_result {
                    Ok(c) => c,
                    Err(e) => {
                        let _ = tx.send(Err(ProviderError::Stream(format!("Read error: {}", e))));
                        return;
                    }
                };
                pending.extend_from_slice(&chunk);
                while let Some(newline) = pending.iter().position(|&b| b == b'\n') {
                    let line_bytes: Vec<u8> = pending.drain(..=newline).collect();
                    let line = String::from_utf8_lossy(&line_bytes);
                    if apply_sse_line(&line, &mut content, &mut tool_call_map, &tx, &last_stream_usage) {
                        done = true;
                        break 'read;
                    }
                }
            }
            // A final event without a trailing newline would otherwise be lost.
            if !done && !pending.is_empty() {
                let line = String::from_utf8_lossy(&pending).into_owned();
                apply_sse_line(&line, &mut content, &mut tool_call_map, &tx, &last_stream_usage);
            }
            let content = if content.is_empty() { None } else { Some(content) };
            let msg = Message::new_assistant(content, finish_tool_calls(tool_call_map));
            set_slot(&last_stream, Some(msg));
            // Close the channel only after the message is stored. A consumer that sees
            // end-of-stream can then read `last_stream_message()` immediately.
            drop(tx);
        });
        Ok(Box::pin(UnboundedReceiverStream::new(rx)))
    }

    /// Takes the message stored by the last completed stream, leaving `None` behind.
    fn last_stream_message(&self) -> Option<Message> {
        self.last_stream_message.lock().ok().and_then(|mut guard| guard.take())
    }

    /// Takes the most recently recorded token usage, leaving `None` behind.
    fn last_usage(&self) -> Option<crate::messages::Usage> {
        self.last_usage.lock().ok().and_then(|mut guard| guard.take())
    }

    /// Embeds `text` with the `text-embedding-3-small` model via `<base_url>/embeddings`.
    ///
    /// Returns `None` in any of these cases:
    /// * a transport failure or non-success status;
    /// * an unparsable body;
    /// * a missing `data[0].embedding` array;
    /// * an embedding with no numeric entries.
    async fn embed(&self, text: &str) -> Option<Vec<f32>> {
        // The separator must be added explicitly: trimming leaves no trailing slash, and
        // `…/v1embeddings` is not a valid endpoint.
        let url = format!("{}/embeddings", self.base_url.trim_end_matches('/'));
        let body = serde_json::json!({
            "model": "text-embedding-3-small",
            "input": text
        });
        let resp = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await
            .ok()?;
        if !resp.status().is_success() {
            return None;
        }
        let data: serde_json::Value = resp.json().await.ok()?;
        let embedding = data["data"][0]["embedding"].as_array()?;
        let vec: Vec<f32> = embedding
            .iter()
            .filter_map(|v| v.as_f64().map(|f| f as f32))
            .collect();
        if vec.is_empty() { None } else { Some(vec) }
    }
}

/// Largest backoff exponent, capping the retry delay at 2^6 = 64 seconds.
const RETRY_BACKOFF_MAX_EXPONENT: u32 = 6;

/// Delay before the retry that follows attempt number `attempt` (0-based): 1s, 2s, 4s, …
///
/// The cap keeps delays bounded and stops `2u64.pow` from overflowing when
/// `max_retries` is large.
fn retry_backoff(attempt: u32) -> std::time::Duration {
    std::time::Duration::from_secs(2u64.pow(attempt.min(RETRY_BACKOFF_MAX_EXPONENT)))
}

/// Overwrites a shared result slot. A poisoned lock is skipped (best effort).
fn set_slot<T>(slot: &Mutex<Option<T>>, value: Option<T>) {
    if let Ok(mut guard) = slot.lock() {
        *guard = value;
    }
}

/// Reads `prompt_tokens`, `completion_tokens` and `total_tokens` from a `usage` object.
/// Missing or non-integer counts become 0.
fn parse_usage(usage: &serde_json::Value) -> crate::messages::Usage {
    let count = |key: &str| usage.get(key).and_then(|v| v.as_u64()).unwrap_or(0);
    crate::messages::Usage {
        prompt_tokens: count("prompt_tokens"),
        completion_tokens: count("completion_tokens"),
        total_tokens: count("total_tokens"),
        ..Default::default()
    }
}

/// Applies one line of the SSE response body to the stream state.
///
/// Lines without a `data:` field, and payloads that are not JSON, are ignored.
/// * A `usage` object is stored in `usage_slot`. `null` placeholders are skipped.
/// * Non-empty `content` deltas are appended to `content` and forwarded on `tx`.
/// * Tool-call deltas are merged into `tool_calls`, keyed by their `index`.
///
/// Returns `true` when the line is the `data: [DONE]` terminator.
fn apply_sse_line(
    line: &str,
    content: &mut String,
    tool_calls: &mut HashMap<usize, ToolCallBuilder>,
    tx: &mpsc::UnboundedSender<Result<String, ProviderError>>,
    usage_slot: &Mutex<Option<crate::messages::Usage>>,
) -> bool {
    let payload = match line.trim().strip_prefix("data:") {
        Some(rest) => rest.trim(),
        None => return false,
    };
    if payload == "[DONE]" {
        return true;
    }
    let event: serde_json::Value = match serde_json::from_str(payload) {
        Ok(v) => v,
        Err(_) => return false,
    };
    if let Some(usage) = event.get("usage").filter(|u| u.is_object()) {
        set_slot(usage_slot, Some(parse_usage(usage)));
    }
    let delta = match event.get("choices").and_then(|choices| choices.get(0)) {
        Some(choice) => &choice["delta"],
        None => return false,
    };
    if let Some(piece) = delta.get("content").and_then(|v| v.as_str()) {
        if !piece.is_empty() {
            content.push_str(piece);
            let _ = tx.send(Ok(piece.to_string()));
        }
    }
    if let Some(fragments) = delta.get("tool_calls").and_then(|v| v.as_array()) {
        for fragment in fragments {
            let index = fragment.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
            let builder = tool_calls.entry(index).or_insert_with(|| ToolCallBuilder {
                id: String::new(),
                name: String::new(),
                arguments: String::new(),
            });
            if let Some(id) = fragment.get("id").and_then(|v| v.as_str()).filter(|id| !id.is_empty()) {
                builder.id = id.to_string();
            }
            let function = &fragment["function"];
            if let Some(name) = function["name"].as_str() {
                builder.name.push_str(name);
            }
            if let Some(args) = function["arguments"].as_str() {
                builder.arguments.push_str(args);
            }
        }
    }
    false
}

/// Turns the accumulated tool-call fragments into tool calls, ordered by stream `index`.
///
/// Argument text that is not valid JSON becomes an empty JSON object. Returns `None`
/// when no tool call was streamed.
fn finish_tool_calls(builders: HashMap<usize, ToolCallBuilder>) -> Option<Vec<ToolCall>> {
    if builders.is_empty() {
        return None;
    }
    let mut indexed: Vec<(usize, ToolCallBuilder)> = builders.into_iter().collect();
    // The stream `index` is the model's call order; ids are opaque and unordered.
    indexed.sort_unstable_by_key(|(index, _)| *index);
    let calls = indexed
        .into_iter()
        .map(|(_, b)| {
            let arguments = serde_json::from_str(&b.arguments)
                .unwrap_or_else(|_| serde_json::Value::Object(Default::default()));
            ToolCall { id: b.id, name: b.name, arguments }
        })
        .collect();
    Some(calls)
}
/// Collects the streamed fragments of a single tool call until the stream ends.
struct ToolCallBuilder {
    /// Call id; the latest non-empty id received for this call wins.
    id: String,
    /// Function name, concatenated from every name fragment received.
    name: String,
    /// Raw argument text (expected to be JSON), concatenated from every fragment.
    arguments: String,
}
/// Builds a boxed [`Provider`] from configuration values.
///
/// `provider_type` is matched case-insensitively, ignoring surrounding whitespace.
/// Both `"openai"` and `"openai-compatible"` produce an [`OpenAICompatibleProvider`]
/// configured with 3 retry attempts and a 120-second timeout. A missing or blank
/// `base_url` selects the public OpenAI endpoint.
///
/// # Errors
/// Returns a human-readable message when `provider_type` is not recognised.
pub fn create_provider(
    provider_type: &str,
    model: &str,
    api_key: &str,
    base_url: Option<&str>,
) -> Result<Box<dyn Provider>, String> {
    const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
    const DEFAULT_MAX_RETRIES: u32 = 3;
    const DEFAULT_TIMEOUT_SECS: u64 = 120;

    match provider_type.trim().to_lowercase().as_str() {
        "openai" | "openai-compatible" => {
            // An empty value (e.g. an unset config field) would yield a relative URL.
            let base_url = base_url
                .map(str::trim)
                .filter(|url| !url.is_empty())
                .unwrap_or(DEFAULT_BASE_URL);
            Ok(Box::new(OpenAICompatibleProvider::new(
                model.to_string(),
                api_key.to_string(),
                base_url.to_string(),
                DEFAULT_MAX_RETRIES,
                DEFAULT_TIMEOUT_SECS,
            )))
        }
        _ => Err(format!(
            "Unknown provider type: '{}'. Supported: openai, openai-compatible.",
            provider_type
        )),
    }
}