nihility-model 0.2.2

Nihility project AI model module.
Documentation
use crate::error::Result;
use crate::{ModelClient, NihilityModelError};
use async_openai::Client;
use async_openai::config::{AzureConfig, OpenAIConfig};
use secrecy::{ExposeSecret, SecretString};
use serde::{Deserialize, Serialize, Serializer};
use std::collections::HashMap;
use tracing::debug;

/// Top-level configuration for the model module.
///
/// `channel` holds the API endpoint settings of each named channel, and `model` maps each
/// [`ModelType`] to the channel and model name that should serve it.
/// [`NihilityModelConfig::model_client`] resolves a `ModelType` through both maps.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct NihilityModelConfig {
    /// API endpoint settings, keyed by channel name.
    pub channel: HashMap<String, ApiConfig>,
    /// Channel/model binding used for each model type.
    pub model: HashMap<ModelType, ChannelModel>,
}

impl NihilityModelConfig {
    /// Builds an API client for the model configured under `model_type`.
    ///
    /// Resolution is two-step:
    /// 1. Find the [`ChannelModel`] registered for `model_type` in `self.model`.
    /// 2. Find the [`ApiConfig`] of the channel that entry names in `self.channel`.
    ///
    /// The channel's `type` then decides whether an OpenAI- or Azure-flavoured
    /// `async_openai` client is built. The client uses the supplied `http_client` and a
    /// default exponential backoff.
    ///
    /// Returns the client together with the model name configured for that channel.
    ///
    /// # Errors
    ///
    /// Returns [`NihilityModelError::NoMatchModel`] when either:
    /// - no model is configured for `model_type`, or
    /// - the channel referenced by that model entry is not present in `self.channel`.
    pub fn model_client(
        &self,
        http_client: reqwest::Client,
        model_type: ModelType,
    ) -> Result<(ModelClient, String)> {
        debug!("Building {:?} model client", model_type);
        // Build the error lazily: the happy path should not clone `model_type`.
        let channel_model = self
            .model
            .get(&model_type)
            .ok_or_else(|| NihilityModelError::NoMatchModel(model_type.clone()))?;
        debug!("Get channel model config: {:?}", channel_model);
        // NOTE(review): a model pointing at an unknown channel is reported with the same
        // variant as a missing model; a dedicated variant would make this misconfiguration
        // easier to diagnose.
        let channel_config = self
            .channel
            .get(&channel_model.channel)
            .ok_or_else(move || NihilityModelError::NoMatchModel(model_type))?;
        debug!("Get channel config: {:?}", channel_config);
        let client = match &channel_config.r#type {
            ApiConfigType::OpenAI => {
                // `api_version` / `deployment_id` are not used by the OpenAI flavour.
                let config = OpenAIConfig::new()
                    .with_api_base(channel_config.api_base.clone())
                    .with_api_key(channel_config.api_key.expose_secret())
                    .with_org_id(channel_config.org_id.clone())
                    .with_project_id(channel_config.project_id.clone());
                ModelClient::OpenAI(Client::build(
                    http_client,
                    config,
                    backoff::ExponentialBackoff::default(),
                ))
            }
            ApiConfigType::Azure => {
                // `org_id` / `project_id` are not used by the Azure flavour.
                let config = AzureConfig::new()
                    .with_api_base(channel_config.api_base.clone())
                    .with_api_key(channel_config.api_key.expose_secret())
                    .with_api_version(channel_config.api_version.clone())
                    .with_deployment_id(channel_config.deployment_id.clone());
                ModelClient::Azure(Client::build(
                    http_client,
                    config,
                    backoff::ExponentialBackoff::default(),
                ))
            }
        };
        Ok((client, channel_model.model.clone()))
    }
}

/// Binds a model type to a concrete model name on a configured channel.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChannelModel {
    /// Channel name; looked up as a key in `NihilityModelConfig::channel`.
    pub channel: String,
    /// Model name, returned alongside the built client by
    /// `NihilityModelConfig::model_client`.
    pub model: String,
}

/// Category of model a caller can request; used as the key of `NihilityModelConfig::model`
/// (hence the `Hash`/`Eq` derives).
// NOTE(review): `Vl*` presumably means vision-language models — confirm with the authors.
#[derive(Clone, Debug, Hash, PartialEq, Eq, Deserialize, Serialize)]
pub enum ModelType {
    TextSmall,
    TextLarge,
    VlSmall,
    VlLarge,
    Embeddings,
}

/// Which API flavour a channel speaks; selects the client type built by
/// `NihilityModelConfig::model_client`.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub enum ApiConfigType {
    /// OpenAI-style endpoint (the default).
    #[default]
    OpenAI,
    /// Azure OpenAI endpoint.
    Azure,
}

#[derive(Debug, Clone, Deserialize, Default)]
pub struct ApiConfig {
    pub r#type: ApiConfigType,
    pub api_version: String,
    pub deployment_id: String,
    pub org_id: String,
    pub project_id: String,
    pub api_base: String,
    api_key: SecretString,
}

impl Default for NihilityModelConfig {
    /// Builds a minimal placeholder configuration:
    /// - a single `"test"` channel holding a default [`ApiConfig`];
    /// - a `TextSmall` model entry that points at that channel with model name `"small"`.
    fn default() -> Self {
        let test_channel = String::from("test");
        let small_text_model = ChannelModel {
            channel: test_channel.clone(),
            model: String::from("small"),
        };
        Self {
            channel: HashMap::from([(test_channel, ApiConfig::default())]),
            model: HashMap::from([(ModelType::TextSmall, small_text_model)]),
        }
    }
}

impl Serialize for ApiConfig {
    /// Writes every field of the config, but always emits `api_key` as an empty string.
    ///
    /// The secret never appears in serialized output, while the field stays present so the
    /// output keeps the same shape as the input format.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeStruct;

        // Non-secret string fields, in the order they follow `type` in the output.
        let plain_fields: [(&'static str, &String); 5] = [
            ("api_version", &self.api_version),
            ("deployment_id", &self.deployment_id),
            ("org_id", &self.org_id),
            ("project_id", &self.project_id),
            ("api_base", &self.api_base),
        ];
        // `type` + the plain fields + the redacted `api_key`.
        let field_count = plain_fields.len() + 2;

        let mut out = serializer.serialize_struct("ApiConfig", field_count)?;
        out.serialize_field("type", &self.r#type)?;
        for (name, value) in plain_fields {
            out.serialize_field(name, value)?;
        }
        // Redact the credential: an empty string is written instead of the real key.
        out.serialize_field("api_key", "")?;
        out.end()
    }
}