use serde_json::{Value, json};
use crate::{
message::Message,
traits::{AsyncGenerateData, GenerateData, IsLLM},
};
/// Chat-completion client configuration for a single Azure OpenAI deployment.
///
/// The deployment id is embedded in the request URL and also stored as the
/// model reference.
#[derive(Clone)]
pub struct AzureOpenAILLM {
    /// Deployment id; exposed as the model reference.
    model: String,
    /// Fully-qualified chat-completions endpoint, including the `api-version` query.
    base_url: String,
    /// Secret API key; deliberately omitted from the `Debug` output.
    api_key: String,
}

impl std::fmt::Debug for AzureOpenAILLM {
    /// Formats the client with the API key redacted, so values of this type
    /// can be logged or `dbg!`-ed without leaking the secret.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AzureOpenAILLM")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("api_key", &"<redacted>")
            .finish()
    }
}

impl AzureOpenAILLM {
    /// Creates a client for `deployment_id` on the Azure resource at `api_base`.
    ///
    /// The request URL has the form
    /// `{api_base}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}`.
    /// Any trailing `/` on `api_base` is stripped, so `https://res.openai.azure.com`
    /// and `https://res.openai.azure.com/` produce the same URL.
    ///
    /// `deployment_id` and `api_version` are inserted verbatim (not percent-encoded).
    pub fn new(api_base: &str, api_key: &str, deployment_id: &str, api_version: &str) -> Self {
        // Without this, an endpoint copied with a trailing slash yields `...//openai/...`.
        let api_base = api_base.trim_end_matches('/');
        let base_url = format!(
            "{}/openai/deployments/{}/chat/completions?api-version={}",
            api_base, deployment_id, api_version
        );
        Self {
            model: deployment_id.to_string(),
            base_url,
            api_key: api_key.to_string(),
        }
    }
}
impl IsLLM for AzureOpenAILLM {
    /// Returns the API key passed to [`AzureOpenAILLM::new`], unmodified.
    fn get_authorization_credentials(&self) -> String {
        // NOTE(review): Azure OpenAI key-based auth expects an `api-key: <key>` header
        // rather than `Authorization: Bearer <key>`; confirm how the caller of this
        // method builds the request header.
        self.api_key.clone()
    }

    /// Returns the deployment id, which serves as the model reference for Azure.
    fn get_model_ref(&self) -> &str {
        &self.model
    }

    /// Returns the full chat-completions endpoint URL built in [`AzureOpenAILLM::new`].
    fn get_chat_completion_request_url(&self) -> String {
        self.base_url.clone()
    }

    /// Builds the chat-completion request body, with `message` as the only entry of `messages`.
    ///
    /// When `return_json` is true, `response_format` is also set to
    /// `{"type": "json_object"}`. No `model` field is sent, because the deployment
    /// is already selected by the request URL.
    fn get_request_body(&self, message: Message, return_json: bool) -> Value {
        let mut body = json!({ "messages": [message] });
        if return_json {
            // `body` is a JSON object, so index-assignment inserts the new key.
            body["response_format"] = json!({ "type": "json_object" });
        }
        body
    }
}
// Opt into the synchronous generation API; the empty body means every method
// comes from the trait's default implementations, driven by the `IsLLM` impl above.
impl GenerateData for AzureOpenAILLM {}
// Opt into the asynchronous generation API; likewise relies entirely on the
// trait's default method implementations.
impl AsyncGenerateData for AzureOpenAILLM {}