chronicle_proxy/
builder.rs

1use std::{path::Path, sync::Arc, time::Duration};
2
3use error_stack::{Report, ResultExt};
4
5#[cfg(feature = "postgres")]
6use crate::database::postgres::PostgresDatabase;
7#[cfg(feature = "sqlite")]
8use crate::database::sqlite::SqliteDatabase;
9use crate::{
10    config::{AliasConfig, ApiKeyConfig, CustomProviderConfig, ProxyConfig},
11    database::{load_providers_from_database, logging::start_database_logger, Database},
12    providers::{
13        anthropic::Anthropic, anyscale::Anyscale, deepinfra::DeepInfra, fireworks::Fireworks,
14        groq::Groq, mistral::Mistral, ollama::Ollama, openai::OpenAi, together::Together,
15        ChatModelProvider,
16    },
17    Error, ProviderLookup, Proxy,
18};
19
/// A three-state setting for an optional component that can also be switched off entirely.
///
/// Unlike a plain `Option<T>`, this distinguishes "disabled" from "enabled, but no explicit
/// value was supplied".
// Every use of this type in this file is gated on the `aws-bedrock` feature, so without that
// feature the type is unused; silence the resulting dead-code warning only in that case.
#[cfg_attr(not(feature = "aws-bedrock"), allow(dead_code))]
enum MaybeEmpty<T> {
    /// The component is disabled.
    None,
    /// The component is enabled, but no value was supplied for it.
    Empty,
    /// The component is enabled and should use the supplied value.
    Some(T),
}
25
26impl<T> MaybeEmpty<T> {
27    /// Convert this `MaybeEmpty` into an `Option<T>`, setting the Empty case to None.
28    fn to_option(self) -> Option<T> {
29        match self {
30            MaybeEmpty::None => None,
31            MaybeEmpty::Empty => None,
32            MaybeEmpty::Some(t) => Some(t),
33        }
34    }
35
36    /// Return true if the value is Empty or Some
37    fn is_set(&self) -> bool {
38        matches!(self, MaybeEmpty::Some(_) | MaybeEmpty::Empty)
39    }
40}
41
42/// A builder for [Proxy]
43pub struct ProxyBuilder {
44    database: Option<Database>,
45    config: ProxyConfig,
46    load_config_from_database: bool,
47    client: Option<reqwest::Client>,
48    providers: Vec<Arc<dyn ChatModelProvider>>,
49
50    anthropic: Option<String>,
51    anyscale: Option<String>,
52    #[cfg(feature = "aws-bedrock")]
53    aws_bedrock: MaybeEmpty<aws_sdk_bedrockruntime::Client>,
54    deepinfra: Option<String>,
55    fireworks: Option<String>,
56    groq: Option<String>,
57    mistral: Option<String>,
58    ollama: Option<String>,
59    openai: Option<String>,
60    together: Option<String>,
61}
62
63impl ProxyBuilder {
64    /// Create a new builder
65    pub fn new() -> Self {
66        Self {
67            database: None,
68            config: ProxyConfig::default(),
69            load_config_from_database: true,
70            client: None,
71            providers: Vec::new(),
72
73            anthropic: Some(String::new()),
74            anyscale: Some(String::new()),
75            #[cfg(feature = "aws-bedrock")]
76            aws_bedrock: MaybeEmpty::Empty,
77            deepinfra: Some(String::new()),
78            fireworks: Some(String::new()),
79            groq: Some(String::new()),
80            mistral: Some(String::new()),
81            ollama: Some(String::new()),
82            openai: Some(String::new()),
83            together: Some(String::new()),
84        }
85    }
86
87    /// Set the database with a pre-made adapter
88    pub fn with_database(mut self, database: Database) -> Self {
89        self.database = Some(database);
90        self
91    }
92
93    #[cfg(feature = "sqlite")]
94    /// Use this SQLite database
95    pub fn with_sqlite_pool(mut self, pool: sqlx::SqlitePool) -> Self {
96        self.database = Some(SqliteDatabase::new(pool));
97        self
98    }
99
100    #[cfg(feature = "postgres")]
101    /// Use this PostgreSQL database pool
102    pub fn with_postgres_pool(mut self, pool: sqlx::PgPool) -> Self {
103        self.database = Some(PostgresDatabase::new(pool));
104        self
105    }
106
107    /// Load configuration for custom providers, aliases, and API keys from the database. If a
108    /// database pool is provided, this defaults to true.
109    pub fn load_config_from_database(mut self, load: bool) -> Self {
110        self.load_config_from_database = load;
111        self
112    }
113
114    /// Enable or disable logging to the database. Logging requires `with_database` to have been
115    /// called.
116    pub fn log_to_database(mut self, log: bool) -> Self {
117        self.config.log_to_database = Some(log);
118        self
119    }
120
121    /// Merge this configuration into the current one.
122    pub fn with_config(mut self, config: ProxyConfig) -> Self {
123        self.config.default_timeout = config.default_timeout.or(self.config.default_timeout);
124        self.config.log_to_database = config.log_to_database.or(self.config.log_to_database);
125        if config.user_agent.is_some() {
126            self.config.user_agent = config.user_agent;
127        }
128        self.config.providers.extend(config.providers);
129        self.config.aliases.extend(config.aliases);
130        self.config.api_keys.extend(config.api_keys);
131        self
132    }
133
134    /// Read a configuration file from this path and merge it into the current configuration.
135    pub async fn with_config_from_path(self, path: &Path) -> Result<Self, Report<Error>> {
136        let data = tokio::fs::read_to_string(path)
137            .await
138            .change_context(Error::ReadingConfig)?;
139        let config: ProxyConfig = toml::from_str(&data).change_context(Error::ReadingConfig)?;
140
141        Ok(self.with_config(config))
142    }
143
144    /// Add an [AliasConfig] to the [Proxy]
145    pub fn with_alias(mut self, alias: AliasConfig) -> Self {
146        self.config.aliases.push(alias);
147        self
148    }
149
150    /// Add multiple [AliasConfig] objects to the [Proxy]
151    pub fn with_aliases(mut self, aliases: Vec<AliasConfig>) -> Self {
152        self.config.aliases.extend(aliases);
153        self
154    }
155
156    /// Add an [ApiKeyConfig] to the proxy
157    pub fn with_api_key(mut self, key: ApiKeyConfig) -> Self {
158        self.config.api_keys.push(key);
159        self
160    }
161
162    /// Add multiple [ApiKeyConfig] objects to the proxy
163    pub fn with_api_keys(mut self, keys: Vec<ApiKeyConfig>) -> Self {
164        self.config.api_keys.extend(keys);
165        self
166    }
167
168    /// Add a custom provider to the list of providers
169    pub fn with_custom_provider(mut self, config: CustomProviderConfig) -> Self {
170        self.config.providers.push(config);
171        self
172    }
173
174    /// Add multiple custom providers to the list of providers
175    pub fn with_custom_providers(mut self, configs: Vec<CustomProviderConfig>) -> Self {
176        self.config.providers.extend(configs);
177        self
178    }
179
180    /// Add a precreated provider to the list of providers. This can be used to create your own
181    /// custom providers that require capabilities not provided by the [CustomProviderConfig].
182    pub fn with_provider(mut self, provider: Arc<dyn ChatModelProvider>) -> Self {
183        self.providers.push(provider);
184        self
185    }
186
187    /// Enable the OpenAI provider, if it was disabled by [without_default_providers]
188    pub fn with_openai(mut self, token: Option<String>) -> Self {
189        self.openai = token.or(Some(String::new()));
190        self
191    }
192
193    /// Enable the Anyscale provider, if it was disabled by [without_default_providers]
194    pub fn with_anyscale(mut self, token: Option<String>) -> Self {
195        self.anyscale = token.or(Some(String::new()));
196        self
197    }
198
199    /// Enable the Anthropic provider, if it was disabled by [without_default_providers]
200    pub fn with_anthropic(mut self, token: Option<String>) -> Self {
201        self.anthropic = token.or(Some(String::new()));
202        self
203    }
204
205    #[cfg(feature = "aws-bedrock")]
206    /// Enable the AWS Bedrock provider, possibly passing a custom client.
207    pub fn with_aws_bedrock(mut self, client: Option<aws_sdk_bedrockruntime::Client>) -> Self {
208        self.aws_bedrock = match client {
209            Some(client) => MaybeEmpty::Some(client),
210            None => MaybeEmpty::Empty,
211        };
212        self
213    }
214
215    /// Enable the DeepInfra provider, if it was disabled by [without_default_providers]
216    pub fn with_deepinfra(mut self, token: Option<String>) -> Self {
217        self.deepinfra = token.or(Some(String::new()));
218        self
219    }
220
221    /// Enable the Fireworks provider, if it was disabled by [without_default_providers]
222    pub fn with_fireworks(mut self, token: Option<String>) -> Self {
223        self.fireworks = token.or(Some(String::new()));
224        self
225    }
226
227    /// Enable the Groq provider, if it was disabled by [without_default_providers]
228    pub fn with_groq(mut self, token: Option<String>) -> Self {
229        self.groq = token.or(Some(String::new()));
230        self
231    }
232
233    /// Enable the Mistral provider, if it was disabled by [without_default_providers]
234    pub fn with_mistral(mut self, token: Option<String>) -> Self {
235        self.mistral = token.or(Some(String::new()));
236        self
237    }
238
239    /// Enable the Together provider, if it was disabled by [without_default_providers]
240    pub fn with_together(mut self, token: Option<String>) -> Self {
241        self.together = token.or(Some(String::new()));
242        self
243    }
244
245    /// Enable the Ollama provider, if it was disabled by [without_default_providers]
246    pub fn with_ollama(mut self, url: Option<String>) -> Self {
247        self.ollama = url.or(Some(String::new()));
248        self
249    }
250
251    /// Don't load the default providers
252    pub fn without_default_providers(mut self) -> Self {
253        self.anthropic = None;
254        self.anyscale = None;
255        #[cfg(feature = "aws-bedrock")]
256        {
257            self.aws_bedrock = MaybeEmpty::None;
258        }
259        self.deepinfra = None;
260        self.fireworks = None;
261        self.groq = None;
262        self.mistral = None;
263        self.openai = None;
264        self.ollama = None;
265        self.together = None;
266        self
267    }
268
269    /// Set the user agent that will be used for HTTP requests. This only applies if
270    /// `with_client` is not used.
271    pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
272        self.config.user_agent = Some(user_agent.into());
273        self
274    }
275
276    /// Supply a custom [reqwest::Client] that the proxy will use to make requests.
277    pub fn with_client(mut self, client: reqwest::Client) -> Self {
278        self.client = Some(client);
279        self
280    }
281
282    /// Build the proxy from the supplied options.
283    pub async fn build(self) -> Result<Proxy, Report<Error>> {
284        let mut providers = self.providers;
285        let mut provider_configs = self.config.providers;
286        let mut api_keys = self.config.api_keys;
287        let mut aliases = self.config.aliases;
288        let logger = if let Some(db) = &self.database {
289            if self.load_config_from_database {
290                let db_providers =
291                    load_providers_from_database(db.as_ref(), "chronicle_custom_providers").await?;
292                let db_aliases = db
293                    .load_aliases_from_database("chronicle_aliases", "chronicle_alias_providers")
294                    .await?;
295                let db_api_keys = db
296                    .load_api_key_configs_from_database("chronicle_api_keys")
297                    .await?;
298
299                provider_configs.extend(db_providers);
300                aliases.extend(db_aliases);
301                api_keys.extend(db_api_keys);
302            }
303
304            let logger = if self.config.log_to_database.unwrap_or(false) {
305                Some(start_database_logger(
306                    db.clone(),
307                    100,
308                    Duration::from_secs(1),
309                ))
310            } else {
311                None
312            };
313
314            logger
315        } else {
316            None
317        };
318
319        let client = self.client.unwrap_or_else(|| {
320            reqwest::Client::builder()
321                .user_agent(self.config.user_agent.as_deref().unwrap_or("chronicle"))
322                .timeout(
323                    self.config
324                        .default_timeout
325                        .unwrap_or(Duration::from_secs(60)),
326                )
327                .build()
328                .unwrap()
329        });
330
331        providers.extend(
332            provider_configs
333                .into_iter()
334                .map(|c| Arc::new(c.into_provider(client.clone())) as Arc<dyn ChatModelProvider>),
335        );
336
337        fn empty_to_none(s: String) -> Option<String> {
338            if s.is_empty() {
339                None
340            } else {
341                Some(s)
342            }
343        }
344
345        if let Some(token) = self.anthropic {
346            providers.push(
347                Arc::new(Anthropic::new(client.clone(), empty_to_none(token)))
348                    as Arc<dyn ChatModelProvider>,
349            );
350        }
351
352        if let Some(token) = self.anyscale {
353            providers.push(
354                Arc::new(Anyscale::new(client.clone(), empty_to_none(token)))
355                    as Arc<dyn ChatModelProvider>,
356            );
357        }
358
359        #[cfg(feature = "aws-bedrock")]
360        if self.aws_bedrock.is_set() {
361            providers.push(Arc::new(
362                crate::providers::aws_bedrock::AwsBedrock::new(self.aws_bedrock.to_option()).await,
363            ) as Arc<dyn ChatModelProvider>);
364        }
365
366        if let Some(token) = self.deepinfra {
367            providers.push(
368                Arc::new(DeepInfra::new(client.clone(), empty_to_none(token)))
369                    as Arc<dyn ChatModelProvider>,
370            );
371        }
372
373        if let Some(token) = self.fireworks {
374            providers.push(
375                Arc::new(Fireworks::new(client.clone(), empty_to_none(token)))
376                    as Arc<dyn ChatModelProvider>,
377            );
378        }
379
380        if let Some(token) = self.groq {
381            providers.push(Arc::new(Groq::new(client.clone(), empty_to_none(token)))
382                as Arc<dyn ChatModelProvider>);
383        }
384
385        if let Some(token) = self.mistral {
386            providers.push(Arc::new(Mistral::new(client.clone(), empty_to_none(token)))
387                as Arc<dyn ChatModelProvider>);
388        }
389
390        if let Some(url) = self.ollama {
391            providers.push(Arc::new(Ollama::new(client.clone(), empty_to_none(url)))
392                as Arc<dyn ChatModelProvider>);
393        }
394
395        if let Some(token) = self.openai {
396            providers.push(Arc::new(OpenAi::new(client.clone(), empty_to_none(token)))
397                as Arc<dyn ChatModelProvider>);
398        }
399
400        if let Some(token) = self.together {
401            providers.push(
402                Arc::new(Together::new(client.clone(), empty_to_none(token)))
403                    as Arc<dyn ChatModelProvider>,
404            );
405        }
406
407        let (log_tx, log_task) = logger.unzip();
408
409        let api_keys = api_keys
410            .into_iter()
411            .map(|mut config| {
412                if config.source == "env" {
413                    let value = std::env::var(&config.value).map_err(|_| {
414                        Error::MissingApiKeyEnv(config.name.clone(), config.value.clone())
415                    })?;
416
417                    config.value = value;
418                }
419
420                Ok::<_, Error>(config)
421            })
422            .collect::<Result<Vec<_>, Error>>()?;
423
424        let lookup = ProviderLookup::new(providers, aliases, api_keys);
425
426        Ok(Proxy {
427            lookup,
428            default_timeout: self.config.default_timeout,
429            log_tx,
430            log_task,
431        })
432    }
433}