use serde_json::{Value, json};
use crate::{
constants::OPENAI_CHAT_COMPLETION_ROUTE,
message::Message,
traits::{AsyncGenerateData, GenerateData, IsLLM},
};
/// Connection settings (model, API key, base URL) for an OpenAI-style
/// chat-completion endpoint.
///
/// `Debug` is implemented by hand so that the API key is never printed.
#[derive(Clone)]
pub struct OpenAILLM {
    /// Model identifier placed in the request body.
    model: String,
    /// Secret API key; formatted as `Bearer <api_key>` for authorization.
    api_key: String,
    /// Base URL that the chat-completion route is appended to.
    api_base: String,
}

// A derived `Debug` would print `api_key` verbatim, leaking the secret into
// any log line or panic message that formats this struct with `{:?}`.
impl std::fmt::Debug for OpenAILLM {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAILLM")
            .field("model", &self.model)
            .field("api_key", &"<redacted>")
            .field("api_base", &self.api_base)
            .finish()
    }
}
impl OpenAILLM {
    /// Creates a client configuration by copying the given settings into owned strings.
    ///
    /// # Arguments
    ///
    /// * `api_base` - base URL the chat-completion route is appended to.
    /// * `api_key` - secret key used to build the bearer credential.
    /// * `model` - model identifier placed in request bodies.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` is part of
    /// the public signature.
    pub fn new(
        api_base: &str,
        api_key: &str,
        model: &str,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync + 'static>> {
        let api_base = api_base.to_owned();
        let api_key = api_key.to_owned();
        let model = model.to_owned();
        Ok(Self {
            api_base,
            api_key,
            model,
        })
    }
}
impl IsLLM for OpenAILLM {
    /// Returns the API key formatted as a bearer credential: `Bearer <api_key>`.
    fn get_authorization_credentials(&self) -> String {
        format!("Bearer {}", self.api_key)
    }

    /// Returns the configured model identifier.
    fn get_model_ref(&self) -> &str {
        &self.model
    }

    /// Builds the chat-completion URL from `api_base` and `OPENAI_CHAT_COMPLETION_ROUTE`.
    ///
    /// Exactly one `/` separates the two parts. A base with a trailing slash
    /// (e.g. `https://api.openai.com/v1/`) therefore does not yield `//` in the
    /// path, and a base and route that both lack a slash are not fused together.
    fn get_chat_completion_request_url(&self) -> String {
        let base = self.api_base.trim_end_matches('/');
        let route = OPENAI_CHAT_COMPLETION_ROUTE.trim_start_matches('/');
        format!("{}/{}", base, route)
    }

    /// Builds the JSON request body holding the configured model and `message`
    /// as the only entry of `"messages"`.
    ///
    /// When `return_json` is `true`, `"response_format": {"type": "json_object"}`
    /// is also added to the body.
    fn get_request_body(&self, message: Message, return_json: bool) -> Value {
        let mut body = json!({
            "model": self.get_model_ref(),
            "messages": [message],
        });
        if return_json {
            // Indexing a `Value::Object` mutably inserts the key if it is absent.
            body["response_format"] = json!({ "type": "json_object" });
        }
        body
    }
}
// Opts `OpenAILLM` into `GenerateData` without overriding any methods; any
// behavior comes from the trait's provided (default) methods.
impl GenerateData for OpenAILLM {}
// Same for `AsyncGenerateData`: no methods are overridden here.
impl AsyncGenerateData for OpenAILLM {}