use crate::error::Result;
use crate::{ModelClient, NihilityModelError};
use async_openai::Client;
use async_openai::config::{AzureConfig, OpenAIConfig};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize, Serializer};
use std::collections::HashMap;
use tracing::debug;
/// Top-level model configuration.
///
/// `channel` holds the named API endpoints, and `model` assigns each [`ModelType`] to one of
/// those channels together with the model name to request there. The two maps are combined
/// by [`NihilityModelConfig::model_client`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NihilityModelConfig {
    /// API endpoint settings, keyed by channel name.
    pub channel: HashMap<String, ApiConfig>,
    /// Channel/model assignment for each model role.
    pub model: HashMap<ModelType, ChannelModel>,
}
impl NihilityModelConfig {
pub fn model_client(
&self,
http_client: reqwest::Client,
model_type: ModelType,
) -> Result<(ModelClient, String)> {
debug!("Building {:?} model client", model_type);
let channel_model = self
.model
.get(&model_type)
.ok_or(NihilityModelError::NoMatchModel(model_type.clone()))?;
debug!("Get channel model config: {:?}", channel_model);
let channel_config = self
.channel
.get(&channel_model.channel)
.ok_or(NihilityModelError::NoMatchModel(model_type))?;
debug!("Get channel config: {:?}", channel_config);
match &channel_config.r#type {
ApiConfigType::OpenAI => {
let config = OpenAIConfig::new()
.with_api_base(channel_config.api_base.clone())
.with_api_key(channel_config.api_key.expose_secret())
.with_org_id(channel_config.org_id.clone())
.with_project_id(channel_config.project_id.clone());
Ok((
ModelClient::OpenAI(Client::build(
http_client,
config,
backoff::ExponentialBackoff::default(),
)),
channel_model.model.clone(),
))
}
ApiConfigType::Azure => {
let config = AzureConfig::new()
.with_api_base(channel_config.api_base.clone())
.with_api_key(channel_config.api_key.expose_secret())
.with_api_version(channel_config.api_version.clone())
.with_deployment_id(channel_config.deployment_id.clone());
Ok((
ModelClient::Azure(Client::build(
http_client,
config,
backoff::ExponentialBackoff::default(),
)),
channel_model.model.clone(),
))
}
}
}
}
/// Assignment of one model role to a channel and a model name.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ChannelModel {
    /// Name of the channel; expected to be a key of `NihilityModelConfig::channel`.
    pub channel: String,
    /// Model name handed back to callers alongside the built client.
    pub model: String,
}
/// Logical model role requested by callers.
///
/// Each role is mapped to a channel/model pair through `NihilityModelConfig::model`; the
/// type is hashable so it can serve as that map's key.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ModelType {
    TextSmall,
    TextLarge,
    VlSmall,
    VlLarge,
    Embeddings,
}
/// Client flavour to build for a channel.
///
/// Selects between `OpenAIConfig` and `AzureConfig` in `NihilityModelConfig::model_client`.
/// Defaults to [`ApiConfigType::OpenAI`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum ApiConfigType {
    #[default]
    OpenAI,
    Azure,
}
#[derive(Debug, Clone, Deserialize, Default)]
pub struct ApiConfig {
pub r#type: ApiConfigType,
pub api_version: String,
pub deployment_id: String,
pub org_id: String,
pub project_id: String,
pub api_base: String,
api_key: SecretString,
}
impl Default for NihilityModelConfig {
    /// Returns a placeholder configuration.
    ///
    /// It contains a single channel named `"test"` with default [`ApiConfig`] settings, and
    /// routes [`ModelType::TextSmall`] to model `"small"` on that channel. No other model
    /// roles are configured.
    fn default() -> Self {
        const PLACEHOLDER_CHANNEL: &str = "test";
        let channel = HashMap::from([(PLACEHOLDER_CHANNEL.to_string(), ApiConfig::default())]);
        let small_text = ChannelModel {
            channel: PLACEHOLDER_CHANNEL.to_string(),
            model: "small".to_string(),
        };
        let model = HashMap::from([(ModelType::TextSmall, small_text)]);
        Self { channel, model }
    }
}
impl Serialize for ApiConfig {
    /// Serializes the channel settings with the API key redacted.
    ///
    /// Fields are written in this order: `type`, `api_version`, `deployment_id`, `org_id`,
    /// `project_id`, `api_base`, `api_key`. `api_key` is always written as an empty string,
    /// so the secret never reaches serialized output.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeStruct;

        // Placeholder emitted instead of the real secret.
        const REDACTED_API_KEY: &str = "";

        let plain_fields: [(&'static str, &String); 5] = [
            ("api_version", &self.api_version),
            ("deployment_id", &self.deployment_id),
            ("org_id", &self.org_id),
            ("project_id", &self.project_id),
            ("api_base", &self.api_base),
        ];
        // `type` + the plain string fields + `api_key`.
        let field_count = 1 + plain_fields.len() + 1;

        let mut record = serializer.serialize_struct("ApiConfig", field_count)?;
        record.serialize_field("type", &self.r#type)?;
        for (name, value) in plain_fields {
            record.serialize_field(name, value)?;
        }
        record.serialize_field("api_key", REDACTED_API_KEY)?;
        record.end()
    }
}