use super::traits::{ChatRequest, ChatResponse, FinishReason, LlmProvider, TokenUsage};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// LLM provider that talks to an OpenAI-style HTTP API.
///
/// Holds the credentials, the API root, and the HTTP client used by the
/// `LlmProvider` implementation for this type.
pub struct OpenAiProvider {
    /// Name returned by `LlmProvider::name` for this instance.
    name: Arc<str>,
    /// Token sent as `Bearer <api_key>` in the `Authorization` header.
    api_key: Arc<str>,
    /// API root that endpoint paths such as `/chat/completions` are appended to.
    base_url: Arc<str>,
    /// HTTP client used for every request made by this provider.
    client: Client,
}
impl OpenAiProvider {
    /// Creates a provider with the given display `name` and `api_key`.
    ///
    /// `base_url` overrides the API root; when `None`, the public OpenAI
    /// endpoint `https://api.openai.com/v1` is used. Trailing slashes on a
    /// supplied `base_url` are removed so that endpoint paths appended later
    /// (e.g. `/chat/completions`) never produce a double slash.
    pub fn new(name: Arc<str>, api_key: Arc<str>, base_url: Option<Arc<str>>) -> Self {
        let base_url: Arc<str> = match base_url {
            // Only reallocate when there is actually something to trim.
            Some(url) if url.ends_with('/') => Arc::from(url.trim_end_matches('/')),
            Some(url) => url,
            None => Arc::from("https://api.openai.com/v1"),
        };

        Self {
            name,
            api_key,
            base_url,
            client: Client::new(),
        }
    }
}
#[async_trait]
impl LlmProvider for OpenAiProvider {
fn name(&self) -> &str {
&self.name
}
fn provider_type(&self) -> &str {
"openai"
}
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
let url = format!("{}/chat/completions", self.base_url);
let messages: Vec<OpenAiMessage> = request
.messages
.into_iter()
.map(|m| OpenAiMessage {
role: match m.role {
super::traits::MessageRole::System => "system".to_string(),
super::traits::MessageRole::User => "user".to_string(),
super::traits::MessageRole::Assistant => "assistant".to_string(),
super::traits::MessageRole::Tool => "tool".to_string(),
},
content: m.content,
name: m.name,
tool_call_id: m.tool_call_id,
})
.collect();
let body = OpenAiRequest {
model: request.model,
messages,
max_tokens: request.max_tokens,
temperature: request.temperature,
};
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to send request to OpenAI")?;
if !response.status().is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("OpenAI API error: {}", error_text);
}
let api_response: OpenAiResponse = response
.json()
.await
.context("Failed to parse OpenAI response")?;
let choice = api_response
.choices
.into_iter()
.next()
.context("No choices in OpenAI response")?;
Ok(ChatResponse {
id: api_response.id,
model: api_response.model,
content: choice.message.content,
finish_reason: match choice.finish_reason.as_str() {
"stop" => FinishReason::Stop,
"length" => FinishReason::Length,
"tool_calls" => FinishReason::ToolCalls,
_ => FinishReason::Stop,
},
usage: TokenUsage {
prompt_tokens: api_response.usage.prompt_tokens,
completion_tokens: api_response.usage.completion_tokens,
total_tokens: api_response.usage.total_tokens,
},
tool_calls: None,
})
}
async fn is_available(&self) -> bool {
if self.api_key.is_empty() {
return false;
}
let url = format!("{}/models", self.base_url);
let response = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.send()
.await;
match response {
Ok(resp) => resp.status().is_success(),
Err(_) => false,
}
}
}
/// JSON body posted to the `/chat/completions` endpoint.
#[derive(Debug, Serialize)]
struct OpenAiRequest {
    /// Identifier of the model to run.
    model: String,
    /// Conversation messages, in the order they were supplied.
    messages: Vec<OpenAiMessage>,
    /// Optional token limit; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Optional sampling temperature; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
}
#[derive(Debug, Serialize, Deserialize)]
struct OpenAiMessage {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
/// Top-level JSON object parsed from a successful `/chat/completions` reply.
///
/// Every field listed here is required; a reply missing any of them fails
/// to deserialize.
#[derive(Debug, Deserialize)]
struct OpenAiResponse {
    /// Identifier of the completion, copied into `ChatResponse::id`.
    id: String,
    /// Model name reported by the server, copied into `ChatResponse::model`.
    model: String,
    /// Generated alternatives; only the first one is consumed by `chat`.
    choices: Vec<OpenAiChoice>,
    /// Token accounting for the call.
    usage: OpenAiUsage,
}
/// One generated alternative inside an `OpenAiResponse`.
#[derive(Debug, Deserialize)]
struct OpenAiChoice {
    /// The message produced for this choice.
    message: OpenAiMessage,
    /// Raw finish reason string (e.g. `"stop"`, `"length"`, `"tool_calls"`),
    /// mapped to `FinishReason` by `chat`.
    finish_reason: String,
}
/// Token counts parsed from the `usage` object of a response; copied
/// field-for-field into `TokenUsage` by `chat`.
#[derive(Debug, Deserialize)]
struct OpenAiUsage {
    /// Value of the `prompt_tokens` field.
    prompt_tokens: u32,
    /// Value of the `completion_tokens` field.
    completion_tokens: u32,
    /// Value of the `total_tokens` field.
    total_tokens: u32,
}