1use std::{path::Path, sync::Arc, time::Duration};
2
3use error_stack::{Report, ResultExt};
4
5#[cfg(feature = "postgres")]
6use crate::database::postgres::PostgresDatabase;
7#[cfg(feature = "sqlite")]
8use crate::database::sqlite::SqliteDatabase;
9use crate::{
10 config::{AliasConfig, ApiKeyConfig, CustomProviderConfig, ProxyConfig},
11 database::{load_providers_from_database, logging::start_database_logger, Database},
12 providers::{
13 anthropic::Anthropic, anyscale::Anyscale, deepinfra::DeepInfra, fireworks::Fireworks,
14 groq::Groq, mistral::Mistral, ollama::Ollama, openai::OpenAi, together::Together,
15 ChatModelProvider,
16 },
17 Error, ProviderLookup, Proxy,
18};
19
/// A three-state setting: not configured at all (`None`), enabled without a value (`Empty`),
/// or enabled with a caller-supplied value (`Some`).
enum MaybeEmpty<T> {
    /// The setting is not configured.
    None,
    /// The setting is enabled, but no value was supplied.
    Empty,
    /// The setting is enabled with the contained value.
    Some(T),
}
25
26impl<T> MaybeEmpty<T> {
27 fn to_option(self) -> Option<T> {
29 match self {
30 MaybeEmpty::None => None,
31 MaybeEmpty::Empty => None,
32 MaybeEmpty::Some(t) => Some(t),
33 }
34 }
35
36 fn is_set(&self) -> bool {
38 matches!(self, MaybeEmpty::Some(_) | MaybeEmpty::Empty)
39 }
40}
41
42pub struct ProxyBuilder {
44 database: Option<Database>,
45 config: ProxyConfig,
46 load_config_from_database: bool,
47 client: Option<reqwest::Client>,
48 providers: Vec<Arc<dyn ChatModelProvider>>,
49
50 anthropic: Option<String>,
51 anyscale: Option<String>,
52 #[cfg(feature = "aws-bedrock")]
53 aws_bedrock: MaybeEmpty<aws_sdk_bedrockruntime::Client>,
54 deepinfra: Option<String>,
55 fireworks: Option<String>,
56 groq: Option<String>,
57 mistral: Option<String>,
58 ollama: Option<String>,
59 openai: Option<String>,
60 together: Option<String>,
61}
62
63impl ProxyBuilder {
64 pub fn new() -> Self {
66 Self {
67 database: None,
68 config: ProxyConfig::default(),
69 load_config_from_database: true,
70 client: None,
71 providers: Vec::new(),
72
73 anthropic: Some(String::new()),
74 anyscale: Some(String::new()),
75 #[cfg(feature = "aws-bedrock")]
76 aws_bedrock: MaybeEmpty::Empty,
77 deepinfra: Some(String::new()),
78 fireworks: Some(String::new()),
79 groq: Some(String::new()),
80 mistral: Some(String::new()),
81 ollama: Some(String::new()),
82 openai: Some(String::new()),
83 together: Some(String::new()),
84 }
85 }
86
87 pub fn with_database(mut self, database: Database) -> Self {
89 self.database = Some(database);
90 self
91 }
92
93 #[cfg(feature = "sqlite")]
94 pub fn with_sqlite_pool(mut self, pool: sqlx::SqlitePool) -> Self {
96 self.database = Some(SqliteDatabase::new(pool));
97 self
98 }
99
100 #[cfg(feature = "postgres")]
101 pub fn with_postgres_pool(mut self, pool: sqlx::PgPool) -> Self {
103 self.database = Some(PostgresDatabase::new(pool));
104 self
105 }
106
107 pub fn load_config_from_database(mut self, load: bool) -> Self {
110 self.load_config_from_database = load;
111 self
112 }
113
114 pub fn log_to_database(mut self, log: bool) -> Self {
117 self.config.log_to_database = Some(log);
118 self
119 }
120
121 pub fn with_config(mut self, config: ProxyConfig) -> Self {
123 self.config.default_timeout = config.default_timeout.or(self.config.default_timeout);
124 self.config.log_to_database = config.log_to_database.or(self.config.log_to_database);
125 if config.user_agent.is_some() {
126 self.config.user_agent = config.user_agent;
127 }
128 self.config.providers.extend(config.providers);
129 self.config.aliases.extend(config.aliases);
130 self.config.api_keys.extend(config.api_keys);
131 self
132 }
133
134 pub async fn with_config_from_path(self, path: &Path) -> Result<Self, Report<Error>> {
136 let data = tokio::fs::read_to_string(path)
137 .await
138 .change_context(Error::ReadingConfig)?;
139 let config: ProxyConfig = toml::from_str(&data).change_context(Error::ReadingConfig)?;
140
141 Ok(self.with_config(config))
142 }
143
144 pub fn with_alias(mut self, alias: AliasConfig) -> Self {
146 self.config.aliases.push(alias);
147 self
148 }
149
150 pub fn with_aliases(mut self, aliases: Vec<AliasConfig>) -> Self {
152 self.config.aliases.extend(aliases);
153 self
154 }
155
156 pub fn with_api_key(mut self, key: ApiKeyConfig) -> Self {
158 self.config.api_keys.push(key);
159 self
160 }
161
162 pub fn with_api_keys(mut self, keys: Vec<ApiKeyConfig>) -> Self {
164 self.config.api_keys.extend(keys);
165 self
166 }
167
168 pub fn with_custom_provider(mut self, config: CustomProviderConfig) -> Self {
170 self.config.providers.push(config);
171 self
172 }
173
174 pub fn with_custom_providers(mut self, configs: Vec<CustomProviderConfig>) -> Self {
176 self.config.providers.extend(configs);
177 self
178 }
179
180 pub fn with_provider(mut self, provider: Arc<dyn ChatModelProvider>) -> Self {
183 self.providers.push(provider);
184 self
185 }
186
187 pub fn with_openai(mut self, token: Option<String>) -> Self {
189 self.openai = token.or(Some(String::new()));
190 self
191 }
192
193 pub fn with_anyscale(mut self, token: Option<String>) -> Self {
195 self.anyscale = token.or(Some(String::new()));
196 self
197 }
198
199 pub fn with_anthropic(mut self, token: Option<String>) -> Self {
201 self.anthropic = token.or(Some(String::new()));
202 self
203 }
204
205 #[cfg(feature = "aws-bedrock")]
206 pub fn with_aws_bedrock(mut self, client: Option<aws_sdk_bedrockruntime::Client>) -> Self {
208 self.aws_bedrock = match client {
209 Some(client) => MaybeEmpty::Some(client),
210 None => MaybeEmpty::Empty,
211 };
212 self
213 }
214
215 pub fn with_deepinfra(mut self, token: Option<String>) -> Self {
217 self.deepinfra = token.or(Some(String::new()));
218 self
219 }
220
221 pub fn with_fireworks(mut self, token: Option<String>) -> Self {
223 self.fireworks = token.or(Some(String::new()));
224 self
225 }
226
227 pub fn with_groq(mut self, token: Option<String>) -> Self {
229 self.groq = token.or(Some(String::new()));
230 self
231 }
232
233 pub fn with_mistral(mut self, token: Option<String>) -> Self {
235 self.mistral = token.or(Some(String::new()));
236 self
237 }
238
239 pub fn with_together(mut self, token: Option<String>) -> Self {
241 self.together = token.or(Some(String::new()));
242 self
243 }
244
245 pub fn with_ollama(mut self, url: Option<String>) -> Self {
247 self.ollama = url.or(Some(String::new()));
248 self
249 }
250
251 pub fn without_default_providers(mut self) -> Self {
253 self.anthropic = None;
254 self.anyscale = None;
255 #[cfg(feature = "aws-bedrock")]
256 {
257 self.aws_bedrock = MaybeEmpty::None;
258 }
259 self.deepinfra = None;
260 self.fireworks = None;
261 self.groq = None;
262 self.mistral = None;
263 self.openai = None;
264 self.ollama = None;
265 self.together = None;
266 self
267 }
268
269 pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
272 self.config.user_agent = Some(user_agent.into());
273 self
274 }
275
276 pub fn with_client(mut self, client: reqwest::Client) -> Self {
278 self.client = Some(client);
279 self
280 }
281
282 pub async fn build(self) -> Result<Proxy, Report<Error>> {
284 let mut providers = self.providers;
285 let mut provider_configs = self.config.providers;
286 let mut api_keys = self.config.api_keys;
287 let mut aliases = self.config.aliases;
288 let logger = if let Some(db) = &self.database {
289 if self.load_config_from_database {
290 let db_providers =
291 load_providers_from_database(db.as_ref(), "chronicle_custom_providers").await?;
292 let db_aliases = db
293 .load_aliases_from_database("chronicle_aliases", "chronicle_alias_providers")
294 .await?;
295 let db_api_keys = db
296 .load_api_key_configs_from_database("chronicle_api_keys")
297 .await?;
298
299 provider_configs.extend(db_providers);
300 aliases.extend(db_aliases);
301 api_keys.extend(db_api_keys);
302 }
303
304 let logger = if self.config.log_to_database.unwrap_or(false) {
305 Some(start_database_logger(
306 db.clone(),
307 100,
308 Duration::from_secs(1),
309 ))
310 } else {
311 None
312 };
313
314 logger
315 } else {
316 None
317 };
318
319 let client = self.client.unwrap_or_else(|| {
320 reqwest::Client::builder()
321 .user_agent(self.config.user_agent.as_deref().unwrap_or("chronicle"))
322 .timeout(
323 self.config
324 .default_timeout
325 .unwrap_or(Duration::from_secs(60)),
326 )
327 .build()
328 .unwrap()
329 });
330
331 providers.extend(
332 provider_configs
333 .into_iter()
334 .map(|c| Arc::new(c.into_provider(client.clone())) as Arc<dyn ChatModelProvider>),
335 );
336
337 fn empty_to_none(s: String) -> Option<String> {
338 if s.is_empty() {
339 None
340 } else {
341 Some(s)
342 }
343 }
344
345 if let Some(token) = self.anthropic {
346 providers.push(
347 Arc::new(Anthropic::new(client.clone(), empty_to_none(token)))
348 as Arc<dyn ChatModelProvider>,
349 );
350 }
351
352 if let Some(token) = self.anyscale {
353 providers.push(
354 Arc::new(Anyscale::new(client.clone(), empty_to_none(token)))
355 as Arc<dyn ChatModelProvider>,
356 );
357 }
358
359 #[cfg(feature = "aws-bedrock")]
360 if self.aws_bedrock.is_set() {
361 providers.push(Arc::new(
362 crate::providers::aws_bedrock::AwsBedrock::new(self.aws_bedrock.to_option()).await,
363 ) as Arc<dyn ChatModelProvider>);
364 }
365
366 if let Some(token) = self.deepinfra {
367 providers.push(
368 Arc::new(DeepInfra::new(client.clone(), empty_to_none(token)))
369 as Arc<dyn ChatModelProvider>,
370 );
371 }
372
373 if let Some(token) = self.fireworks {
374 providers.push(
375 Arc::new(Fireworks::new(client.clone(), empty_to_none(token)))
376 as Arc<dyn ChatModelProvider>,
377 );
378 }
379
380 if let Some(token) = self.groq {
381 providers.push(Arc::new(Groq::new(client.clone(), empty_to_none(token)))
382 as Arc<dyn ChatModelProvider>);
383 }
384
385 if let Some(token) = self.mistral {
386 providers.push(Arc::new(Mistral::new(client.clone(), empty_to_none(token)))
387 as Arc<dyn ChatModelProvider>);
388 }
389
390 if let Some(url) = self.ollama {
391 providers.push(Arc::new(Ollama::new(client.clone(), empty_to_none(url)))
392 as Arc<dyn ChatModelProvider>);
393 }
394
395 if let Some(token) = self.openai {
396 providers.push(Arc::new(OpenAi::new(client.clone(), empty_to_none(token)))
397 as Arc<dyn ChatModelProvider>);
398 }
399
400 if let Some(token) = self.together {
401 providers.push(
402 Arc::new(Together::new(client.clone(), empty_to_none(token)))
403 as Arc<dyn ChatModelProvider>,
404 );
405 }
406
407 let (log_tx, log_task) = logger.unzip();
408
409 let api_keys = api_keys
410 .into_iter()
411 .map(|mut config| {
412 if config.source == "env" {
413 let value = std::env::var(&config.value).map_err(|_| {
414 Error::MissingApiKeyEnv(config.name.clone(), config.value.clone())
415 })?;
416
417 config.value = value;
418 }
419
420 Ok::<_, Error>(config)
421 })
422 .collect::<Result<Vec<_>, Error>>()?;
423
424 let lookup = ProviderLookup::new(providers, aliases, api_keys);
425
426 Ok(Proxy {
427 lookup,
428 default_timeout: self.config.default_timeout,
429 log_tx,
430 log_task,
431 })
432 }
433}