1use crate::error::Result;
2use crate::{ModelClient, NihilityModelError};
3use async_openai::Client;
4use async_openai::config::{AzureConfig, OpenAIConfig};
5use secrecy::{ExposeSecret, SecretString};
6use serde::{Deserialize, Serialize, Serializer};
7use std::collections::HashMap;
8use tracing::debug;
9
10#[derive(Debug, Clone, Deserialize, Serialize)]
11pub struct NihilityModelConfig {
12 pub channel: HashMap<String, ApiConfig>,
13 pub model: HashMap<ModelType, ChannelModel>,
14}
15
16impl NihilityModelConfig {
17 pub fn model_client(
18 &self,
19 http_client: reqwest::Client,
20 model_type: ModelType,
21 ) -> Result<(ModelClient, String)> {
22 debug!("Building {:?} model client", model_type);
23 let channel_model = self
24 .model
25 .get(&model_type)
26 .ok_or(NihilityModelError::NoMatchModel(model_type.clone()))?;
27 debug!("Get channel model config: {:?}", channel_model);
28 let channel_config = self
29 .channel
30 .get(&channel_model.channel)
31 .ok_or(NihilityModelError::NoMatchModel(model_type))?;
32 debug!("Get channel config: {:?}", channel_config);
33 match &channel_config.r#type {
34 ApiConfigType::OpenAI => {
35 let config = OpenAIConfig::new()
36 .with_api_base(channel_config.api_base.clone())
37 .with_api_key(channel_config.api_key.expose_secret())
38 .with_org_id(channel_config.org_id.clone())
39 .with_project_id(channel_config.project_id.clone());
40 Ok((
41 ModelClient::OpenAI(Client::build(
42 http_client,
43 config,
44 backoff::ExponentialBackoff::default(),
45 )),
46 channel_model.model.clone(),
47 ))
48 }
49 ApiConfigType::Azure => {
50 let config = AzureConfig::new()
51 .with_api_base(channel_config.api_base.clone())
52 .with_api_key(channel_config.api_key.expose_secret())
53 .with_api_version(channel_config.api_version.clone())
54 .with_deployment_id(channel_config.deployment_id.clone());
55 Ok((
56 ModelClient::Azure(Client::build(
57 http_client,
58 config,
59 backoff::ExponentialBackoff::default(),
60 )),
61 channel_model.model.clone(),
62 ))
63 }
64 }
65 }
66}
67
68#[derive(Debug, Clone, Deserialize, Serialize)]
69pub struct ChannelModel {
70 pub channel: String,
71 pub model: String,
72}
73
74#[derive(Clone, Debug, Hash, PartialEq, Eq, Deserialize, Serialize)]
75pub enum ModelType {
76 TextSmall,
77 TextLarge,
78 VlSmall,
79 VlLarge,
80 Embeddings,
81}
82
83#[derive(Debug, Clone, Deserialize, Serialize, Default)]
84pub enum ApiConfigType {
85 #[default]
86 OpenAI,
87 Azure,
88}
89
90#[derive(Debug, Clone, Deserialize, Default)]
91pub struct ApiConfig {
92 pub r#type: ApiConfigType,
93 pub api_version: String,
94 pub deployment_id: String,
95 pub org_id: String,
96 pub project_id: String,
97 pub api_base: String,
98 api_key: SecretString,
99}
100
101impl Default for NihilityModelConfig {
102 fn default() -> Self {
103 let mut channel = HashMap::new();
104 channel.insert("test".to_string(), ApiConfig::default());
105 let mut model = HashMap::new();
106 model.insert(
107 ModelType::TextSmall,
108 ChannelModel {
109 channel: "test".to_string(),
110 model: "small".to_string(),
111 },
112 );
113 Self { channel, model }
114 }
115}
116
117impl Serialize for ApiConfig {
118 fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
119 where
120 S: Serializer,
121 {
122 use serde::ser::SerializeStruct;
123
124 let mut state = serializer.serialize_struct("ApiConfig", 7)?;
125 state.serialize_field("type", &self.r#type)?;
126 state.serialize_field("api_version", &self.api_version)?;
127 state.serialize_field("deployment_id", &self.deployment_id)?;
128 state.serialize_field("org_id", &self.org_id)?;
129 state.serialize_field("project_id", &self.project_id)?;
130 state.serialize_field("api_base", &self.api_base)?;
131 state.serialize_field("api_key", &"")?;
132 state.end()
133 }
134}