use super::chat::{ChatCompletion, ChatModel};
use super::message::{ChatMessage, ChatMessageContent, TokenUsage};
use anyhow::Error;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json;
use std::collections::HashMap;
use log::info;
#[derive(Serialize, Deserialize, Clone)]
struct OpenAIMessage {
role: String,
content: String,
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
/// Optional per-category breakdown of input-side token counts.
///
/// Every counter is an `Option`, so keys the server leaves out deserialize as `None`.
#[derive(Default, Deserialize)]
struct InputTokenDetails {
    /// Value of the `audio_tokens` key, when reported.
    audio_tokens: Option<usize>,
    /// Value of the `cache_read` key, when reported (presumably cached input tokens).
    cache_read: Option<usize>,
    /// Value of the `reasoning_tokens` key, when reported.
    reasoning_tokens: Option<usize>,
}
/// Optional per-category breakdown of output-side token counts.
///
/// Both counters are `Option`s, so keys the server leaves out deserialize as `None`.
#[derive(Default, Deserialize)]
struct OutputTokenDetails {
    /// Value of the `cache_write` key, when reported.
    cache_write: Option<usize>,
    /// Value of the `reasoning_tokens` key, when reported.
    reasoning_tokens: Option<usize>,
}
#[derive(Deserialize, Default)]
struct OpenAIUsage {
prompt_tokens: usize,
completion_tokens: usize,
total_tokens: usize,
input_tokens_details: Option<InputTokenDetails>,
output_tokens_details: Option<OutputTokenDetails>,
}
/// Token accounting in the `input_tokens` / `output_tokens` naming scheme.
///
/// Every field is optional; absent keys deserialize as `None`.
#[derive(Default, Deserialize)]
struct OpenAIResponsesUsage {
    /// Tokens consumed by the input, when reported.
    input_tokens: Option<usize>,
    /// Tokens generated as output, when reported.
    output_tokens: Option<usize>,
    /// Total tokens, when reported.
    total_tokens: Option<usize>,
    /// Optional input-side breakdown.
    input_tokens_details: Option<InputTokenDetails>,
    /// Optional output-side breakdown.
    output_tokens_details: Option<OutputTokenDetails>,
}
#[derive(Deserialize)]
struct OpenAIResponse {
id: Option<String>,
object: Option<String>,
created: Option<u64>,
model: Option<String>,
choices: Vec<OpenAIChoice>, usage: Option<OpenAIUsage>,
output: Option<Vec<OpenAIChoice>>,
}
#[derive(Deserialize)]
struct OpenAIChoice {
index: u32,
message: OpenAIMessage,
finish_reason: String,
}
/// Which OpenAI HTTP API family an [`OpenAIChatModel`] is configured for.
///
/// This must be `pub` because it appears in the public
/// [`OpenAIChatModel::with_api_type`] signature. A private type there triggers the
/// `private_interfaces` lint (a hard E0446 error on older compilers), and callers
/// outside this module could not name a variant to pass in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpenAIApiType {
    /// The Chat Completions API (`/chat/completions`).
    ChatCompletions,
    /// The Responses API.
    /// NOTE(review): `invoke` does not currently consult the API type.
    Responses,
}
/// Chat model client for OpenAI-compatible HTTP APIs.
///
/// Build one with `new` and refine it with the `with_*` builder methods.
#[derive(Clone)]
pub struct OpenAIChatModel {
    // Connection settings.
    /// HTTP client used to send requests.
    client: Client,
    /// API key, sent as a bearer token in the `Authorization` header.
    api_key: String,
    /// Base URL that endpoint paths such as `/chat/completions` are appended to.
    base_url: String,
    /// Extra HTTP headers added to every request.
    additional_headers: HashMap<String, String>,

    // Generation settings.
    /// Model identifier placed in the request body.
    model_name: Option<String>,
    /// Sampling temperature, sent only when set.
    temperature: Option<f32>,
    /// Maximum number of tokens to generate, sent only when set.
    max_tokens: Option<u32>,
    /// Selected API family.
    api_type: OpenAIApiType,
    /// Extra top-level fields merged into the JSON request body.
    additional_params: HashMap<String, serde_json::Value>,
}
impl OpenAIChatModel {
pub fn new(api_key: String, base_url: Option<String>) -> Self {
Self {
client: Client::new(),
api_key,
base_url: base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
model_name: None,
temperature: Some(0.7),
max_tokens: None,
api_type: OpenAIApiType::ChatCompletions,
additional_headers: HashMap::new(),
additional_params: HashMap::new(),
}
}
pub fn model_name(&self) -> Option<&String> {
self.model_name.as_ref()
}
pub fn base_url(&self) -> &String {
&self.base_url
}
pub fn temperature(&self) -> Option<f32> {
self.temperature
}
pub fn max_tokens(&self) -> Option<u32> {
self.max_tokens
}
pub fn with_model(mut self, model_name: String) -> Self {
self.model_name = Some(model_name);
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_api_type(mut self, api_type: OpenAIApiType) -> Self {
self.api_type = api_type;
self
}
pub fn with_additional_header(mut self, key: String, value: String) -> Self {
self.additional_headers.insert(key, value);
self
}
pub fn with_additional_param(mut self, key: String, value: serde_json::Value) -> Self {
self.additional_params.insert(key, value);
self
}
fn _get_request_payload(&self, messages: &[OpenAIMessage]) -> Result<serde_json::Value, Error> {
Ok(serde_json::json!({"messages": messages}))
}
fn _convert_message_to_dict(&self, message: &OpenAIMessage) -> Result<serde_json::Value, Error> {
Ok(serde_json::to_value(message)?)
}
fn _construct_responses_api_payload(&self, messages: &[OpenAIMessage]) -> Result<serde_json::Value, Error> {
Ok(serde_json::json!({"messages": messages}))
}
fn _create_usage_metadata(&self, usage: &OpenAIUsage) -> TokenUsage {
TokenUsage {
prompt_tokens: usage.prompt_tokens,
completion_tokens: usage.completion_tokens,
total_tokens: usage.total_tokens,
}
}
fn _create_usage_metadata_responses(&self, usage: &OpenAIResponsesUsage) -> TokenUsage {
TokenUsage {
prompt_tokens: usage.input_tokens.unwrap_or(0),
completion_tokens: usage.output_tokens.unwrap_or(0),
total_tokens: usage.total_tokens.unwrap_or(0),
}
}
fn _convert_dict_to_message(&self, message_dict: serde_json::Value) -> Result<ChatMessage, Error> {
let role = message_dict.get("role").and_then(|v| v.as_str()).unwrap_or("assistant");
let content = message_dict.get("content").and_then(|v| v.as_str()).unwrap_or("").to_string();
let chat_content = ChatMessageContent {
content,
name: None,
additional_kwargs: HashMap::new(),
};
match role {
"system" => Ok(ChatMessage::System(chat_content)),
"user" => Ok(ChatMessage::Human(chat_content)),
"assistant" => Ok(ChatMessage::AIMessage(chat_content)),
"tool" => Ok(ChatMessage::ToolMessage(chat_content)),
_ => Ok(ChatMessage::AIMessage(chat_content)),
}
}
}
impl ChatModel for OpenAIChatModel {
fn model_name(&self) -> Option<&str> {
self.model_name.as_deref()
}
fn base_url(&self) -> String {
self.base_url.to_string()
}
fn invoke(&self, messages: Vec<ChatMessage>) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<ChatCompletion, Error>> + Send + '_>> {
let messages = messages;
let client = self.client.clone();
let api_key = self.api_key.clone();
let base_url = self.base_url.clone();
let model_name = self.model_name.clone();
let temperature = self.temperature;
let max_tokens = self.max_tokens;
let additional_headers = self.additional_headers.clone();
let additional_params = self.additional_params.clone();
Box::pin(async move {
let openai_messages: Vec<OpenAIMessage> = messages
.into_iter()
.map(|msg| match msg {
ChatMessage::System(content) => OpenAIMessage {
role: "system".to_string(),
content: content.content,
name: content.name,
tool_call_id: None,
},
ChatMessage::Human(content) => OpenAIMessage {
role: "user".to_string(),
content: content.content,
name: content.name,
tool_call_id: None,
},
ChatMessage::AIMessage(content) => OpenAIMessage {
role: "assistant".to_string(),
content: content.content,
name: content.name,
tool_call_id: None,
},
ChatMessage::ToolMessage(content) => {
info!("Converting tool message: role=tool, content={}", content.content);
let tool_call_id = content.additional_kwargs.get("tool_call_id")
.and_then(|v| v.as_str())
.unwrap_or("default_tool_call_id").to_string();
OpenAIMessage {
role: "tool".to_string(),
content: content.content,
name: content.name,
tool_call_id: Some(tool_call_id),
}
},
})
.collect();
let mut request_body = serde_json::json!({
"messages": openai_messages,
"model": model_name.clone().unwrap_or("".to_string()),
});
if let Some(temp) = temperature {
request_body["temperature"] = serde_json::json!(temp);
}
if let Some(max) = max_tokens {
request_body["max_tokens"] = serde_json::json!(max);
}
for (key, value) in additional_params {
request_body[key] = value;
}
let api_url = format!("{}/chat/completions", base_url);
let mut request = client.post(&api_url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json");
for (key, value) in additional_headers {
request = request.header(key, value);
}
let response = request.json(&request_body).send().await?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await?;
return Err(Error::msg(format!("API request failed: {} - {}", status, error_text)));
}
let response: OpenAIResponse = response.json().await?;
let chat_message = match response.choices.first() {
Some(choice) => {
let message = &choice.message;
match message.role.as_str() {
"assistant" => ChatMessage::AIMessage(ChatMessageContent {
content: message.content.clone(),
name: message.name.clone(),
additional_kwargs: HashMap::new(),
}),
_ => {
return Err(Error::msg(format!("Unexpected message role: {}", message.role)));
}
}
},
None => {
match &response.output {
Some(outputs) => {
match outputs.first() {
Some(choice) => {
let message = &choice.message;
ChatMessage::AIMessage(ChatMessageContent {
content: message.content.clone(),
name: message.name.clone(),
additional_kwargs: HashMap::new(),
})
},
None => return Err(Error::msg("No output returned from API")),
}
},
None => return Err(Error::msg("No choices or output returned from API")),
}
},
};
let usage = match &response.usage {
Some(openai_usage) => {
Some(TokenUsage {
prompt_tokens: openai_usage.prompt_tokens,
completion_tokens: openai_usage.completion_tokens,
total_tokens: openai_usage.total_tokens,
})
},
None => None,
};
let model_name_str = response.model.as_deref().unwrap_or("unknown");
Ok(ChatCompletion {
message: chat_message,
usage,
model_name: model_name_str.to_string(),
})
})
}
}