use reqwest::Client;
use serde_json::{json, Value};
use crate::error::ProviderError;
use crate::models::{ContentBlock, Conversation, LlmResponse, MessageRole};
use crate::providers::LlmProvider;
/// API root used when [`OpenAiProvider::new`] is given no `base_url`.
const DEFAULT_BASE_URL: &str = "https://api.openai.com/v1";
/// [`LlmProvider`] implementation that posts conversations to a Chat Completions
/// endpoint (`{base_url}/chat/completions`).
///
/// NOTE(review): the struct does not derive `Debug`; if one is ever added, make sure
/// `api_key` is redacted rather than printed.
pub struct OpenAiProvider {
/// HTTP client created once in `new` and reused for every request.
client: Client,
/// Secret sent as a bearer token in the `Authorization` header of each request.
api_key: String,
/// Model identifier placed in the request body; also reported as the response model
/// when the server's reply omits a `model` field.
model: String,
/// API root; trailing slashes are trimmed before `/chat/completions` is appended.
base_url: String,
}
impl OpenAiProvider {
    /// Builds a provider that calls `model`, authenticating with `api_key`.
    ///
    /// `base_url` overrides the API root that requests are sent to; passing `None`
    /// selects [`DEFAULT_BASE_URL`]. A fresh HTTP client is created here and reused
    /// for every subsequent request.
    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
        let base_url = match base_url {
            Some(url) => url,
            None => String::from(DEFAULT_BASE_URL),
        };

        Self {
            api_key,
            model,
            base_url,
            client: Client::new(),
        }
    }
}
#[async_trait::async_trait]
impl LlmProvider for OpenAiProvider {
async fn complete(&self, conversation: &Conversation) -> Result<LlmResponse, ProviderError> {
let mut messages = Vec::new();
for msg in &conversation.messages {
let role = match msg.role {
MessageRole::System => "system",
MessageRole::User => "user",
MessageRole::Assistant => "assistant",
MessageRole::Tool => "tool",
};
let mut message = json!({
"role": role,
"content": msg.content,
});
if let Some(ref tool_call_id) = msg.tool_call_id {
message["tool_call_id"] = json!(tool_call_id);
}
messages.push(message);
}
let mut body = json!({
"model": self.model,
"messages": messages,
});
if let Some(ref tools) = conversation.tools {
let openai_tools: Vec<Value> = tools.iter().map(|t| {
json!({
"type": "function",
"function": {
"name": t.name,
"description": t.description,
"parameters": t.parameters,
}
})
}).collect();
body["tools"] = json!(openai_tools);
}
let url = format!(
"{}/chat/completions",
self.base_url.trim_end_matches('/')
);
let response = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&body)
.send()
.await?;
let status = response.status();
let retry_after = response
.headers()
.get("retry-after")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok());
if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
return Err(ProviderError::Http {
status: status.as_u16(),
body: error_body,
retry_after,
});
}
let data: Value = response.json().await.map_err(|e| ProviderError::Parse(e.to_string()))?;
let choice = data["choices"][0].clone();
let message = &choice["message"];
let model = data["model"].as_str().unwrap_or(&self.model).to_string();
let mut content_blocks = Vec::new();
if let Some(text) = message["content"].as_str()
&& !text.is_empty()
{
content_blocks.push(ContentBlock::Text { text: text.to_string() });
}
if let Some(tool_calls) = message["tool_calls"].as_array() {
for tc in tool_calls {
let id = tc["id"].as_str().unwrap_or("").to_string();
let function = &tc["function"];
let name = function["name"].as_str().unwrap_or("").to_string();
let arguments = function["arguments"].as_str().unwrap_or("{}").to_string();
content_blocks.push(ContentBlock::ToolCall { id, name, arguments });
}
}
let usage = &data["usage"];
let input_tokens = usage["prompt_tokens"].as_u64().map(|v| v as u32);
let output_tokens = usage["completion_tokens"].as_u64().map(|v| v as u32);
Ok(LlmResponse {
content: content_blocks,
input_tokens,
output_tokens,
model,
})
}
fn provider_name(&self) -> &str {
"openai"
}
fn model_name(&self) -> &str {
&self.model
}
}