1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::embeddings::embedding::EmbeddingModelDyn;
4use crate::providers::{
5 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
6 moonshot, ollama, openai, openrouter, perplexity, together, xai,
7};
8use crate::transcription::TranscriptionModelDyn;
9use rig::completion::CompletionModelDyn;
10use std::collections::HashMap;
11use std::panic::{RefUnwindSafe, UnwindSafe};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum ClientBuildError {
16 #[error("factory error: {}", .0)]
17 FactoryError(String),
18 #[error("invalid id string: {}", .0)]
19 InvalidIdString(String),
20 #[error("unsupported feature: {} for {}", .1, .0)]
21 UnsupportedFeature(String, String),
22 #[error("unknown provider")]
23 UnknownProvider,
24}
25
26pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
27pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
28pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
29pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
30pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
31
32pub struct DynClientBuilder {
61 registry: HashMap<String, ClientFactory>,
62}
63
64impl Default for DynClientBuilder {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl<'a> DynClientBuilder {
71 pub fn new() -> Self {
75 Self {
76 registry: HashMap::new(),
77 }
78 .register_all(vec![
79 ClientFactory::new(
80 DefaultProviders::ANTHROPIC,
81 anthropic::Client::from_env_boxed,
82 anthropic::Client::from_val_boxed,
83 ),
84 ClientFactory::new(
85 DefaultProviders::COHERE,
86 cohere::Client::from_env_boxed,
87 cohere::Client::from_val_boxed,
88 ),
89 ClientFactory::new(
90 DefaultProviders::GEMINI,
91 gemini::Client::from_env_boxed,
92 gemini::Client::from_val_boxed,
93 ),
94 ClientFactory::new(
95 DefaultProviders::HUGGINGFACE,
96 huggingface::Client::from_env_boxed,
97 huggingface::Client::from_val_boxed,
98 ),
99 ClientFactory::new(
100 DefaultProviders::OPENAI,
101 openai::Client::from_env_boxed,
102 openai::Client::from_val_boxed,
103 ),
104 ClientFactory::new(
105 DefaultProviders::OPENROUTER,
106 openrouter::Client::from_env_boxed,
107 openrouter::Client::from_val_boxed,
108 ),
109 ClientFactory::new(
110 DefaultProviders::TOGETHER,
111 together::Client::from_env_boxed,
112 together::Client::from_val_boxed,
113 ),
114 ClientFactory::new(
115 DefaultProviders::XAI,
116 xai::Client::from_env_boxed,
117 xai::Client::from_val_boxed,
118 ),
119 ClientFactory::new(
120 DefaultProviders::AZURE,
121 azure::Client::from_env_boxed,
122 azure::Client::from_val_boxed,
123 ),
124 ClientFactory::new(
125 DefaultProviders::DEEPSEEK,
126 deepseek::Client::from_env_boxed,
127 deepseek::Client::from_val_boxed,
128 ),
129 ClientFactory::new(
130 DefaultProviders::GALADRIEL,
131 galadriel::Client::from_env_boxed,
132 galadriel::Client::from_val_boxed,
133 ),
134 ClientFactory::new(
135 DefaultProviders::GROQ,
136 groq::Client::from_env_boxed,
137 groq::Client::from_val_boxed,
138 ),
139 ClientFactory::new(
140 DefaultProviders::HYPERBOLIC,
141 hyperbolic::Client::from_env_boxed,
142 hyperbolic::Client::from_val_boxed,
143 ),
144 ClientFactory::new(
145 DefaultProviders::MOONSHOT,
146 moonshot::Client::from_env_boxed,
147 moonshot::Client::from_val_boxed,
148 ),
149 ClientFactory::new(
150 DefaultProviders::MIRA,
151 mira::Client::from_env_boxed,
152 mira::Client::from_val_boxed,
153 ),
154 ClientFactory::new(
155 DefaultProviders::MISTRAL,
156 mistral::Client::from_env_boxed,
157 mistral::Client::from_val_boxed,
158 ),
159 ClientFactory::new(
160 DefaultProviders::OLLAMA,
161 ollama::Client::from_env_boxed,
162 ollama::Client::from_val_boxed,
163 ),
164 ClientFactory::new(
165 DefaultProviders::PERPLEXITY,
166 perplexity::Client::from_env_boxed,
167 perplexity::Client::from_val_boxed,
168 ),
169 ])
170 }
171
172 pub fn empty() -> Self {
174 Self {
175 registry: HashMap::new(),
176 }
177 }
178
179 pub fn register(mut self, client_factory: ClientFactory) -> Self {
181 self.registry
182 .insert(client_factory.name.clone(), client_factory);
183 self
184 }
185
186 pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
188 for factory in factories {
189 self.registry.insert(factory.name.clone(), factory);
190 }
191
192 self
193 }
194
195 pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
197 let factory = self.get_factory(provider)?;
198 factory.build()
199 }
200
201 pub fn build_val(
203 &self,
204 provider: &str,
205 provider_value: ProviderValue,
206 ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
207 let factory = self.get_factory(provider)?;
208 factory.build_from_val(provider_value)
209 }
210
211 pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
214 let (provider, model) = id
215 .split_once(":")
216 .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
217
218 Ok((provider, model))
219 }
220
221 fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
223 self.registry
224 .get(provider)
225 .ok_or(ClientBuildError::UnknownProvider)
226 }
227
228 pub fn completion(
230 &self,
231 provider: &str,
232 model: &str,
233 ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
234 let client = self.build(provider)?;
235
236 let completion = client
237 .as_completion()
238 .ok_or(ClientBuildError::UnsupportedFeature(
239 provider.to_string(),
240 "completion".to_owned(),
241 ))?;
242
243 Ok(completion.completion_model(model))
244 }
245
246 pub fn agent(
248 &self,
249 provider: &str,
250 model: &str,
251 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
252 let client = self.build(provider)?;
253
254 let client = client
255 .as_completion()
256 .ok_or(ClientBuildError::UnsupportedFeature(
257 provider.to_string(),
258 "completion".to_string(),
259 ))?;
260
261 Ok(client.agent(model))
262 }
263
264 pub fn agent_with_api_key_val<P>(
266 &self,
267 provider: &str,
268 model: &str,
269 provider_value: P,
270 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError>
271 where
272 P: Into<ProviderValue>,
273 {
274 let client = self.build_val(provider, provider_value.into())?;
275
276 let client = client
277 .as_completion()
278 .ok_or(ClientBuildError::UnsupportedFeature(
279 provider.to_string(),
280 "completion".to_string(),
281 ))?;
282
283 Ok(client.agent(model))
284 }
285
286 pub fn embeddings(
288 &self,
289 provider: &str,
290 model: &str,
291 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
292 let client = self.build(provider)?;
293
294 let embeddings = client
295 .as_embeddings()
296 .ok_or(ClientBuildError::UnsupportedFeature(
297 provider.to_string(),
298 "embeddings".to_owned(),
299 ))?;
300
301 Ok(embeddings.embedding_model(model))
302 }
303
304 pub fn embeddings_with_api_key_val<P>(
306 &self,
307 provider: &str,
308 model: &str,
309 provider_value: P,
310 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError>
311 where
312 P: Into<ProviderValue>,
313 {
314 let client = self.build_val(provider, provider_value.into())?;
315
316 let embeddings = client
317 .as_embeddings()
318 .ok_or(ClientBuildError::UnsupportedFeature(
319 provider.to_string(),
320 "embeddings".to_owned(),
321 ))?;
322
323 Ok(embeddings.embedding_model(model))
324 }
325
326 pub fn transcription(
328 &self,
329 provider: &str,
330 model: &str,
331 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
332 let client = self.build(provider)?;
333 let transcription =
334 client
335 .as_transcription()
336 .ok_or(ClientBuildError::UnsupportedFeature(
337 provider.to_string(),
338 "transcription".to_owned(),
339 ))?;
340
341 Ok(transcription.transcription_model(model))
342 }
343
344 pub fn transcription_with_api_key_val<P>(
346 &self,
347 provider: &str,
348 model: &str,
349 provider_value: P,
350 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError>
351 where
352 P: Into<ProviderValue>,
353 {
354 let client = self.build_val(provider, provider_value.into())?;
355 let transcription =
356 client
357 .as_transcription()
358 .ok_or(ClientBuildError::UnsupportedFeature(
359 provider.to_string(),
360 "transcription".to_owned(),
361 ))?;
362
363 Ok(transcription.transcription_model(model))
364 }
365
366 pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
368 let (provider, model) = self.parse(id)?;
369
370 Ok(ProviderModelId {
371 builder: self,
372 provider,
373 model,
374 })
375 }
376}
377
378pub struct ProviderModelId<'builder, 'id> {
379 builder: &'builder DynClientBuilder,
380 provider: &'id str,
381 model: &'id str,
382}
383
384impl<'builder> ProviderModelId<'builder, '_> {
385 pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
386 self.builder.completion(self.provider, self.model)
387 }
388
389 pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
390 self.builder.agent(self.provider, self.model)
391 }
392
393 pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
394 self.builder.embeddings(self.provider, self.model)
395 }
396
397 pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
398 self.builder.transcription(self.provider, self.model)
399 }
400}
401
#[cfg(feature = "image")]
mod image {
    // Import from this file's module via one consistent `crate::` path
    // (the `audio` module below uses the same path).
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// Boxed image-generation model trait object, as returned by
    /// `DynClientBuilder::image_generation`.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Builds the client for `provider` with its no-argument factory and returns a
        /// boxed image-generation model for `model`.
        ///
        /// # Errors
        ///
        /// * [`ClientBuildError::UnknownProvider`] if no factory is registered for `provider`.
        /// * [`ClientBuildError::FactoryError`] if the factory panics.
        /// * [`ClientBuildError::UnsupportedFeature`] if the client's `as_image_generation()`
        ///   yields `None`.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // Build the error lazily: `ok_or` would allocate both strings even on success.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "image_generation".to_string(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Builds a boxed image-generation model for this id; see
        /// `DynClientBuilder::image_generation`.
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
437#[cfg(feature = "image")]
438pub use image::*;
439
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};

    /// Boxed audio-generation model trait object, as returned by
    /// `DynClientBuilder::audio_generation`.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Builds the client for `provider` with its no-argument factory and returns a
        /// boxed audio-generation model for `model`.
        ///
        /// # Errors
        ///
        /// * [`ClientBuildError::UnknownProvider`] if no factory is registered for `provider`.
        /// * [`ClientBuildError::FactoryError`] if the factory panics.
        /// * [`ClientBuildError::UnsupportedFeature`] if the client's `as_audio_generation()`
        ///   yields `None`.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // Build the error lazily: `ok_or` would allocate both strings even on success.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Builds a boxed audio-generation model for this id; see
        /// `DynClientBuilder::audio_generation`.
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
475use crate::agent::AgentBuilder;
476use crate::client::completion::CompletionModelHandle;
477#[cfg(feature = "audio")]
478pub use audio::*;
479use rig::providers::mistral;
480
481use super::ProviderValue;
482
483pub struct ClientFactory {
484 pub name: String,
485 pub factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
486 pub factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
487}
488
// Manual marker impls. `build` and `build_from_val` run the stored factory closures
// inside `std::panic::catch_unwind`, which requires an `UnwindSafe` closure. The boxed
// `dyn Fn` fields carry no unwind-safety bound of their own, so presumably these impls
// exist to satisfy that requirement.
// NOTE(review): this is an unchecked assertion that a panicking factory leaves no
// shared state in a broken, observable condition — confirm for custom factories.
impl UnwindSafe for ClientFactory {}
impl RefUnwindSafe for ClientFactory {}
491
492impl ClientFactory {
493 pub fn new<F1, F2>(name: &str, func_env: F1, func_val: F2) -> Self
494 where
495 F1: 'static + Fn() -> Box<dyn ProviderClient>,
496 F2: 'static + Fn(ProviderValue) -> Box<dyn ProviderClient>,
497 {
498 Self {
499 name: name.to_string(),
500 factory_env: Box::new(func_env),
501 factory_val: Box::new(func_val),
502 }
503 }
504
505 pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
506 std::panic::catch_unwind(|| (self.factory_env)())
507 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
508 }
509
510 pub fn build_from_val(
511 &self,
512 val: ProviderValue,
513 ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
514 std::panic::catch_unwind(|| (self.factory_val)(val))
515 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
516 }
517}
518
/// Provider names under which `DynClientBuilder::new` registers its built-in factories.
///
/// Use them as the `provider` argument of the builder methods, or as the part before
/// the `:` in a `provider:model` id.
pub struct DefaultProviders;

// Listed alphabetically. The explicit `'static` is kept on purpose: eliding it in
// associated consts triggers the `elided_lifetimes_in_associated_constant` lint.
impl DefaultProviders {
    pub const ANTHROPIC: &'static str = "anthropic";
    pub const AZURE: &'static str = "azure";
    pub const COHERE: &'static str = "cohere";
    pub const DEEPSEEK: &'static str = "deepseek";
    pub const GALADRIEL: &'static str = "galadriel";
    pub const GEMINI: &'static str = "gemini";
    pub const GROQ: &'static str = "groq";
    pub const HUGGINGFACE: &'static str = "huggingface";
    pub const HYPERBOLIC: &'static str = "hyperbolic";
    pub const MIRA: &'static str = "mira";
    pub const MISTRAL: &'static str = "mistral";
    pub const MOONSHOT: &'static str = "moonshot";
    pub const OLLAMA: &'static str = "ollama";
    pub const OPENAI: &'static str = "openai";
    pub const OPENROUTER: &'static str = "openrouter";
    pub const PERPLEXITY: &'static str = "perplexity";
    pub const TOGETHER: &'static str = "together";
    pub const XAI: &'static str = "xai";
}