1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::completion::{CompletionRequest, Message};
4use crate::embeddings::embedding::EmbeddingModelDyn;
5use crate::providers::{
6 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
7 moonshot, ollama, openai, openrouter, perplexity, together, xai,
8};
9use crate::streaming::StreamingCompletionResponse;
10use crate::transcription::TranscriptionModelDyn;
11use rig::completion::CompletionModelDyn;
12use std::collections::HashMap;
13use std::panic::{RefUnwindSafe, UnwindSafe};
14use thiserror::Error;
15
/// Errors produced while resolving a provider and building or using its client.
#[derive(Debug, Error)]
pub enum ClientBuildError {
    /// A client factory or a downstream streaming call failed; the payload is a textual
    /// description of the failure.
    #[error("factory error: {}", .0)]
    FactoryError(String),
    /// The identifier could not be split into a provider part and a model part; carries the
    /// offending input.
    #[error("invalid id string: {}", .0)]
    InvalidIdString(String),
    /// The provider (first field) does not offer the requested feature (second field).
    #[error("unsupported feature: {} for {}", .1, .0)]
    UnsupportedFeature(String, String),
    /// No factory is registered under the requested provider name.
    #[error("unknown provider")]
    UnknownProvider,
}
27
/// Boxed, dynamically dispatched completion model.
pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
/// Agent builder parameterised over a [`CompletionModelHandle`].
pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
/// Agent parameterised over a [`CompletionModelHandle`].
pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
/// Boxed, dynamically dispatched embedding model.
pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
/// Boxed, dynamically dispatched transcription model.
pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
33
/// Registry of [`ClientFactory`] values used to build provider clients that are selected by
/// name at runtime.
pub struct DynClientBuilder {
    /// Registered factories, keyed by each factory's `name`.
    registry: HashMap<String, ClientFactory>,
}
65
66impl Default for DynClientBuilder {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl<'a> DynClientBuilder {
73 pub fn new() -> Self {
77 Self {
78 registry: HashMap::new(),
79 }
80 .register_all(vec![
81 ClientFactory::new(
82 DefaultProviders::ANTHROPIC,
83 anthropic::Client::from_env_boxed,
84 anthropic::Client::from_val_boxed,
85 ),
86 ClientFactory::new(
87 DefaultProviders::COHERE,
88 cohere::Client::from_env_boxed,
89 cohere::Client::from_val_boxed,
90 ),
91 ClientFactory::new(
92 DefaultProviders::GEMINI,
93 gemini::Client::from_env_boxed,
94 gemini::Client::from_val_boxed,
95 ),
96 ClientFactory::new(
97 DefaultProviders::HUGGINGFACE,
98 huggingface::Client::from_env_boxed,
99 huggingface::Client::from_val_boxed,
100 ),
101 ClientFactory::new(
102 DefaultProviders::OPENAI,
103 openai::Client::from_env_boxed,
104 openai::Client::from_val_boxed,
105 ),
106 ClientFactory::new(
107 DefaultProviders::OPENROUTER,
108 openrouter::Client::from_env_boxed,
109 openrouter::Client::from_val_boxed,
110 ),
111 ClientFactory::new(
112 DefaultProviders::TOGETHER,
113 together::Client::from_env_boxed,
114 together::Client::from_val_boxed,
115 ),
116 ClientFactory::new(
117 DefaultProviders::XAI,
118 xai::Client::from_env_boxed,
119 xai::Client::from_val_boxed,
120 ),
121 ClientFactory::new(
122 DefaultProviders::AZURE,
123 azure::Client::from_env_boxed,
124 azure::Client::from_val_boxed,
125 ),
126 ClientFactory::new(
127 DefaultProviders::DEEPSEEK,
128 deepseek::Client::from_env_boxed,
129 deepseek::Client::from_val_boxed,
130 ),
131 ClientFactory::new(
132 DefaultProviders::GALADRIEL,
133 galadriel::Client::from_env_boxed,
134 galadriel::Client::from_val_boxed,
135 ),
136 ClientFactory::new(
137 DefaultProviders::GROQ,
138 groq::Client::from_env_boxed,
139 groq::Client::from_val_boxed,
140 ),
141 ClientFactory::new(
142 DefaultProviders::HYPERBOLIC,
143 hyperbolic::Client::from_env_boxed,
144 hyperbolic::Client::from_val_boxed,
145 ),
146 ClientFactory::new(
147 DefaultProviders::MOONSHOT,
148 moonshot::Client::from_env_boxed,
149 moonshot::Client::from_val_boxed,
150 ),
151 ClientFactory::new(
152 DefaultProviders::MIRA,
153 mira::Client::from_env_boxed,
154 mira::Client::from_val_boxed,
155 ),
156 ClientFactory::new(
157 DefaultProviders::MISTRAL,
158 mistral::Client::from_env_boxed,
159 mistral::Client::from_val_boxed,
160 ),
161 ClientFactory::new(
162 DefaultProviders::OLLAMA,
163 ollama::Client::from_env_boxed,
164 ollama::Client::from_val_boxed,
165 ),
166 ClientFactory::new(
167 DefaultProviders::PERPLEXITY,
168 perplexity::Client::from_env_boxed,
169 perplexity::Client::from_val_boxed,
170 ),
171 ])
172 }
173
174 pub fn empty() -> Self {
176 Self {
177 registry: HashMap::new(),
178 }
179 }
180
181 pub fn register(mut self, client_factory: ClientFactory) -> Self {
183 self.registry
184 .insert(client_factory.name.clone(), client_factory);
185 self
186 }
187
188 pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
190 for factory in factories {
191 self.registry.insert(factory.name.clone(), factory);
192 }
193
194 self
195 }
196
197 pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
199 let factory = self.get_factory(provider)?;
200 factory.build()
201 }
202
203 pub fn build_val(
205 &self,
206 provider: &str,
207 provider_value: ProviderValue,
208 ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
209 let factory = self.get_factory(provider)?;
210 factory.build_from_val(provider_value)
211 }
212
213 pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
216 let (provider, model) = id
217 .split_once(":")
218 .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
219
220 Ok((provider, model))
221 }
222
223 fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
225 self.registry
226 .get(provider)
227 .ok_or(ClientBuildError::UnknownProvider)
228 }
229
230 pub fn completion(
232 &self,
233 provider: &str,
234 model: &str,
235 ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
236 let client = self.build(provider)?;
237
238 let completion = client
239 .as_completion()
240 .ok_or(ClientBuildError::UnsupportedFeature(
241 provider.to_string(),
242 "completion".to_owned(),
243 ))?;
244
245 Ok(completion.completion_model(model))
246 }
247
248 pub fn agent(
250 &self,
251 provider: &str,
252 model: &str,
253 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
254 let client = self.build(provider)?;
255
256 let client = client
257 .as_completion()
258 .ok_or(ClientBuildError::UnsupportedFeature(
259 provider.to_string(),
260 "completion".to_string(),
261 ))?;
262
263 Ok(client.agent(model))
264 }
265
266 pub fn agent_with_api_key_val<P>(
268 &self,
269 provider: &str,
270 model: &str,
271 provider_value: P,
272 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError>
273 where
274 P: Into<ProviderValue>,
275 {
276 let client = self.build_val(provider, provider_value.into())?;
277
278 let client = client
279 .as_completion()
280 .ok_or(ClientBuildError::UnsupportedFeature(
281 provider.to_string(),
282 "completion".to_string(),
283 ))?;
284
285 Ok(client.agent(model))
286 }
287
288 pub fn embeddings(
290 &self,
291 provider: &str,
292 model: &str,
293 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
294 let client = self.build(provider)?;
295
296 let embeddings = client
297 .as_embeddings()
298 .ok_or(ClientBuildError::UnsupportedFeature(
299 provider.to_string(),
300 "embeddings".to_owned(),
301 ))?;
302
303 Ok(embeddings.embedding_model(model))
304 }
305
306 pub fn embeddings_with_api_key_val<P>(
308 &self,
309 provider: &str,
310 model: &str,
311 provider_value: P,
312 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError>
313 where
314 P: Into<ProviderValue>,
315 {
316 let client = self.build_val(provider, provider_value.into())?;
317
318 let embeddings = client
319 .as_embeddings()
320 .ok_or(ClientBuildError::UnsupportedFeature(
321 provider.to_string(),
322 "embeddings".to_owned(),
323 ))?;
324
325 Ok(embeddings.embedding_model(model))
326 }
327
328 pub fn transcription(
330 &self,
331 provider: &str,
332 model: &str,
333 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
334 let client = self.build(provider)?;
335 let transcription =
336 client
337 .as_transcription()
338 .ok_or(ClientBuildError::UnsupportedFeature(
339 provider.to_string(),
340 "transcription".to_owned(),
341 ))?;
342
343 Ok(transcription.transcription_model(model))
344 }
345
346 pub fn transcription_with_api_key_val<P>(
348 &self,
349 provider: &str,
350 model: &str,
351 provider_value: P,
352 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError>
353 where
354 P: Into<ProviderValue>,
355 {
356 let client = self.build_val(provider, provider_value.into())?;
357 let transcription =
358 client
359 .as_transcription()
360 .ok_or(ClientBuildError::UnsupportedFeature(
361 provider.to_string(),
362 "transcription".to_owned(),
363 ))?;
364
365 Ok(transcription.transcription_model(model))
366 }
367
368 pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
370 let (provider, model) = self.parse(id)?;
371
372 Ok(ProviderModelId {
373 builder: self,
374 provider,
375 model,
376 })
377 }
378
379 pub async fn stream_completion(
389 &self,
390 provider: &str,
391 model: &str,
392 request: CompletionRequest,
393 ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
394 let client = self.build(provider)?;
395 let completion = client
396 .as_completion()
397 .ok_or(ClientBuildError::UnsupportedFeature(
398 provider.to_string(),
399 "completion".to_string(),
400 ))?;
401
402 let model = completion.completion_model(model);
403 model
404 .stream(request)
405 .await
406 .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
407 }
408
409 pub async fn stream_prompt(
419 &self,
420 provider: &str,
421 model: &str,
422 prompt: impl Into<Message> + Send,
423 ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
424 let client = self.build(provider)?;
425 let completion = client
426 .as_completion()
427 .ok_or(ClientBuildError::UnsupportedFeature(
428 provider.to_string(),
429 "completion".to_string(),
430 ))?;
431
432 let model = completion.completion_model(model);
433 let request = CompletionRequest {
434 preamble: None,
435 tools: vec![],
436 documents: vec![],
437 temperature: None,
438 max_tokens: None,
439 additional_params: None,
440 chat_history: crate::OneOrMany::one(prompt.into()),
441 };
442
443 model
444 .stream(request)
445 .await
446 .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
447 }
448
449 pub async fn stream_chat(
460 &self,
461 provider: &str,
462 model: &str,
463 prompt: impl Into<Message> + Send,
464 chat_history: Vec<Message>,
465 ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
466 let client = self.build(provider)?;
467 let completion = client
468 .as_completion()
469 .ok_or(ClientBuildError::UnsupportedFeature(
470 provider.to_string(),
471 "completion".to_string(),
472 ))?;
473
474 let model = completion.completion_model(model);
475 let mut history = chat_history;
476 history.push(prompt.into());
477
478 let request = CompletionRequest {
479 preamble: None,
480 tools: vec![],
481 documents: vec![],
482 temperature: None,
483 max_tokens: None,
484 additional_params: None,
485 chat_history: crate::OneOrMany::many(history)
486 .unwrap_or_else(|_| crate::OneOrMany::one(Message::user(""))),
487 };
488
489 model
490 .stream(request)
491 .await
492 .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
493 }
494}
495
/// A parsed provider/model pair bound to the [`DynClientBuilder`] that produced it.
pub struct ProviderModelId<'builder, 'id> {
    /// Builder used to create models for this id.
    builder: &'builder DynClientBuilder,
    /// Provider part of the id (the text before the separator).
    provider: &'id str,
    /// Model part of the id (the text after the separator).
    model: &'id str,
}
501
502impl<'builder> ProviderModelId<'builder, '_> {
503 pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
504 self.builder.completion(self.provider, self.model)
505 }
506
507 pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
508 self.builder.agent(self.provider, self.model)
509 }
510
511 pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
512 self.builder.embeddings(self.provider, self.model)
513 }
514
515 pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
516 self.builder.transcription(self.provider, self.model)
517 }
518
519 pub async fn stream_completion(
527 self,
528 request: CompletionRequest,
529 ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
530 self.builder
531 .stream_completion(self.provider, self.model, request)
532 .await
533 }
534
535 pub async fn stream_prompt(
543 self,
544 prompt: impl Into<Message> + Send,
545 ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
546 self.builder
547 .stream_prompt(self.provider, self.model, prompt)
548 .await
549 }
550
551 pub async fn stream_chat(
560 self,
561 prompt: impl Into<Message> + Send,
562 chat_history: Vec<Message>,
563 ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
564 self.builder
565 .stream_chat(self.provider, self.model, prompt, chat_history)
566 .await
567 }
568}
569
/// Image-generation support for [`DynClientBuilder`], compiled only with the `image` feature.
#[cfg(feature = "image")]
mod image {
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// Boxed, dynamically dispatched image-generation model.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Builds a boxed image-generation model named `model` for `provider`.
        ///
        /// # Errors
        ///
        /// Fails when the client cannot be built, or with
        /// [`ClientBuildError::UnsupportedFeature`] when the client has no image-generation
        /// support.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;
            // `ok_or_else` keeps the two error strings from being allocated on success.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "image_generation".to_string(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Resolves this id to an image-generation model through
        /// [`DynClientBuilder::image_generation`].
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
605#[cfg(feature = "image")]
606pub use image::*;
607
/// Audio-generation support for [`DynClientBuilder`], compiled only with the `audio` feature.
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};

    /// Boxed, dynamically dispatched audio-generation model.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Builds a boxed audio-generation model named `model` for `provider`.
        ///
        /// # Errors
        ///
        /// Fails when the client cannot be built, or with
        /// [`ClientBuildError::UnsupportedFeature`] when the client has no audio-generation
        /// support.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;
            // `ok_or_else` keeps the two error strings from being allocated on success.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Resolves this id to an audio-generation model through
        /// [`DynClientBuilder::audio_generation`].
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
643use crate::agent::AgentBuilder;
644use crate::client::completion::CompletionModelHandle;
645#[cfg(feature = "audio")]
646pub use audio::*;
647use rig::providers::mistral;
648
649use super::ProviderValue;
650
/// A named pair of constructors for one provider's client.
pub struct ClientFactory {
    /// Name the factory is registered under in a [`DynClientBuilder`].
    pub name: String,
    /// Constructor that takes no arguments.
    /// NOTE(review): presumably reads its configuration from the environment — confirm against
    /// the providers' `from_env_boxed`.
    pub factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
    /// Constructor that takes an explicit [`ProviderValue`].
    pub factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
}

// Manual unwind-safety assertions: `build` and `build_from_val` run the boxed closures inside
// `std::panic::catch_unwind`, which requires the captured `&ClientFactory` to be unwind safe.
// NOTE(review): nothing here checks that the stored closures really are unwind safe — confirm.
impl UnwindSafe for ClientFactory {}
impl RefUnwindSafe for ClientFactory {}
659
660impl ClientFactory {
661 pub fn new<F1, F2>(name: &str, func_env: F1, func_val: F2) -> Self
662 where
663 F1: 'static + Fn() -> Box<dyn ProviderClient>,
664 F2: 'static + Fn(ProviderValue) -> Box<dyn ProviderClient>,
665 {
666 Self {
667 name: name.to_string(),
668 factory_env: Box::new(func_env),
669 factory_val: Box::new(func_val),
670 }
671 }
672
673 pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
674 std::panic::catch_unwind(|| (self.factory_env)())
675 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
676 }
677
678 pub fn build_from_val(
679 &self,
680 val: ProviderValue,
681 ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
682 std::panic::catch_unwind(|| (self.factory_val)(val))
683 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
684 }
685}
686
/// Namespace for the provider names under which `DynClientBuilder::new` registers its factories.
pub struct DefaultProviders;
impl DefaultProviders {
    /// Registry key for the Anthropic provider.
    pub const ANTHROPIC: &'static str = "anthropic";
    /// Registry key for the Cohere provider.
    pub const COHERE: &'static str = "cohere";
    /// Registry key for the Gemini provider.
    pub const GEMINI: &'static str = "gemini";
    /// Registry key for the Hugging Face provider.
    pub const HUGGINGFACE: &'static str = "huggingface";
    /// Registry key for the OpenAI provider.
    pub const OPENAI: &'static str = "openai";
    /// Registry key for the OpenRouter provider.
    pub const OPENROUTER: &'static str = "openrouter";
    /// Registry key for the Together provider.
    pub const TOGETHER: &'static str = "together";
    /// Registry key for the xAI provider.
    pub const XAI: &'static str = "xai";
    /// Registry key for the Azure provider.
    pub const AZURE: &'static str = "azure";
    /// Registry key for the DeepSeek provider.
    pub const DEEPSEEK: &'static str = "deepseek";
    /// Registry key for the Galadriel provider.
    pub const GALADRIEL: &'static str = "galadriel";
    /// Registry key for the Groq provider.
    pub const GROQ: &'static str = "groq";
    /// Registry key for the Hyperbolic provider.
    pub const HYPERBOLIC: &'static str = "hyperbolic";
    /// Registry key for the Moonshot provider.
    pub const MOONSHOT: &'static str = "moonshot";
    /// Registry key for the Mira provider.
    pub const MIRA: &'static str = "mira";
    /// Registry key for the Mistral provider.
    pub const MISTRAL: &'static str = "mistral";
    /// Registry key for the Ollama provider.
    pub const OLLAMA: &'static str = "ollama";
    /// Registry key for the Perplexity provider.
    pub const PERPLEXITY: &'static str = "perplexity";
}