Skip to main content

nihility_model/
config.rs

1use crate::error::Result;
2use crate::{ModelClient, NihilityModelError};
3use async_openai::Client;
4use async_openai::config::{AzureConfig, OpenAIConfig};
5use secrecy::{ExposeSecret, SecretString};
6use serde::{Deserialize, Serialize, Serializer};
7use std::collections::HashMap;
8use tracing::debug;
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
11pub struct NihilityModelConfig {
12    pub channel: HashMap<String, ApiConfig>,
13    pub model: HashMap<ModelType, ChannelModel>,
14}
15
16impl NihilityModelConfig {
17    pub fn model_client(
18        &self,
19        http_client: reqwest::Client,
20        model_type: ModelType,
21    ) -> Result<(ModelClient, String)> {
22        debug!("Building {:?} model client", model_type);
23        let channel_model = self
24            .model
25            .get(&model_type)
26            .ok_or(NihilityModelError::NoMatchModel(model_type.clone()))?;
27        debug!("Get channel model config: {:?}", channel_model);
28        let channel_config = self
29            .channel
30            .get(&channel_model.channel)
31            .ok_or(NihilityModelError::NoMatchModel(model_type))?;
32        debug!("Get channel config: {:?}", channel_config);
33        match &channel_config.r#type {
34            ApiConfigType::OpenAI => {
35                let config = OpenAIConfig::new()
36                    .with_api_base(channel_config.api_base.clone())
37                    .with_api_key(channel_config.api_key.expose_secret())
38                    .with_org_id(channel_config.org_id.clone())
39                    .with_project_id(channel_config.project_id.clone());
40                Ok((
41                    ModelClient::OpenAI(Client::build(
42                        http_client,
43                        config,
44                        backoff::ExponentialBackoff::default(),
45                    )),
46                    channel_model.model.clone(),
47                ))
48            }
49            ApiConfigType::Azure => {
50                let config = AzureConfig::new()
51                    .with_api_base(channel_config.api_base.clone())
52                    .with_api_key(channel_config.api_key.expose_secret())
53                    .with_api_version(channel_config.api_version.clone())
54                    .with_deployment_id(channel_config.deployment_id.clone());
55                Ok((
56                    ModelClient::Azure(Client::build(
57                        http_client,
58                        config,
59                        backoff::ExponentialBackoff::default(),
60                    )),
61                    channel_model.model.clone(),
62                ))
63            }
64        }
65    }
66}
67
68#[derive(Debug, Clone, Deserialize, Serialize)]
69pub struct ChannelModel {
70    pub channel: String,
71    pub model: String,
72}
73
74#[derive(Clone, Debug, Hash, PartialEq, Eq, Deserialize, Serialize)]
75pub enum ModelType {
76    TextSmall,
77    TextLarge,
78    VlSmall,
79    VlLarge,
80    Embeddings,
81}
82
83#[derive(Debug, Clone, Deserialize, Serialize, Default)]
84pub enum ApiConfigType {
85    #[default]
86    OpenAI,
87    Azure,
88}
89
90#[derive(Debug, Clone, Deserialize, Default)]
91pub struct ApiConfig {
92    pub r#type: ApiConfigType,
93    pub api_version: String,
94    pub deployment_id: String,
95    pub org_id: String,
96    pub project_id: String,
97    pub api_base: String,
98    api_key: SecretString,
99}
100
101impl Default for NihilityModelConfig {
102    fn default() -> Self {
103        let mut channel = HashMap::new();
104        channel.insert("test".to_string(), ApiConfig::default());
105        let mut model = HashMap::new();
106        model.insert(
107            ModelType::TextSmall,
108            ChannelModel {
109                channel: "test".to_string(),
110                model: "small".to_string(),
111            },
112        );
113        Self { channel, model }
114    }
115}
116
117impl Serialize for ApiConfig {
118    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
119    where
120        S: Serializer,
121    {
122        use serde::ser::SerializeStruct;
123
124        let mut state = serializer.serialize_struct("ApiConfig", 7)?;
125        state.serialize_field("type", &self.r#type)?;
126        state.serialize_field("api_version", &self.api_version)?;
127        state.serialize_field("deployment_id", &self.deployment_id)?;
128        state.serialize_field("org_id", &self.org_id)?;
129        state.serialize_field("project_id", &self.project_id)?;
130        state.serialize_field("api_base", &self.api_base)?;
131        state.serialize_field("api_key", &"")?;
132        state.end()
133    }
134}