1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::completion::{CompletionRequest, GetTokenUsage, Message, Usage};
4use crate::embeddings::embedding::EmbeddingModelDyn;
5use crate::providers::{
6 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
7 moonshot, ollama, openai, openrouter, perplexity, together, xai,
8};
9use crate::streaming::StreamingCompletionResponse;
10use crate::transcription::TranscriptionModelDyn;
11use rig::completion::CompletionModelDyn;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::panic::{RefUnwindSafe, UnwindSafe};
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18pub enum ClientBuildError {
19 #[error("factory error: {}", .0)]
20 FactoryError(String),
21 #[error("invalid id string: {}", .0)]
22 InvalidIdString(String),
23 #[error("unsupported feature: {} for {}", .1, .0)]
24 UnsupportedFeature(String, String),
25 #[error("unknown provider")]
26 UnknownProvider,
27}
28
29pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
30pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
31pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
32pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
33pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
34
35pub struct DynClientBuilder {
64 registry: HashMap<String, ClientFactory>,
65}
66
67impl Default for DynClientBuilder {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl<'a> DynClientBuilder {
74 pub fn new() -> Self {
78 Self {
79 registry: HashMap::new(),
80 }
81 .register_all(vec![
82 ClientFactory::new(
83 DefaultProviders::ANTHROPIC,
84 anthropic::Client::<reqwest::Client>::from_env_boxed,
85 anthropic::Client::<reqwest::Client>::from_val_boxed,
86 ),
87 ClientFactory::new(
88 DefaultProviders::COHERE,
89 cohere::Client::<reqwest::Client>::from_env_boxed,
90 cohere::Client::<reqwest::Client>::from_val_boxed,
91 ),
92 ClientFactory::new(
93 DefaultProviders::GEMINI,
94 gemini::Client::<reqwest::Client>::from_env_boxed,
95 gemini::Client::<reqwest::Client>::from_val_boxed,
96 ),
97 ClientFactory::new(
98 DefaultProviders::HUGGINGFACE,
99 huggingface::Client::<reqwest::Client>::from_env_boxed,
100 huggingface::Client::<reqwest::Client>::from_val_boxed,
101 ),
102 ClientFactory::new(
103 DefaultProviders::OPENAI,
104 openai::Client::<reqwest::Client>::from_env_boxed,
105 openai::Client::<reqwest::Client>::from_val_boxed,
106 ),
107 ClientFactory::new(
108 DefaultProviders::OPENROUTER,
109 openrouter::Client::<reqwest::Client>::from_env_boxed,
110 openrouter::Client::<reqwest::Client>::from_val_boxed,
111 ),
112 ClientFactory::new(
113 DefaultProviders::TOGETHER,
114 together::Client::<reqwest::Client>::from_env_boxed,
115 together::Client::<reqwest::Client>::from_val_boxed,
116 ),
117 ClientFactory::new(
118 DefaultProviders::XAI,
119 xai::Client::<reqwest::Client>::from_env_boxed,
120 xai::Client::<reqwest::Client>::from_val_boxed,
121 ),
122 ClientFactory::new(
123 DefaultProviders::AZURE,
124 azure::Client::<reqwest::Client>::from_env_boxed,
125 azure::Client::<reqwest::Client>::from_val_boxed,
126 ),
127 ClientFactory::new(
128 DefaultProviders::DEEPSEEK,
129 deepseek::Client::<reqwest::Client>::from_env_boxed,
130 deepseek::Client::<reqwest::Client>::from_val_boxed,
131 ),
132 ClientFactory::new(
133 DefaultProviders::GALADRIEL,
134 galadriel::Client::<reqwest::Client>::from_env_boxed,
135 galadriel::Client::<reqwest::Client>::from_val_boxed,
136 ),
137 ClientFactory::new(
138 DefaultProviders::GROQ,
139 groq::Client::<reqwest::Client>::from_env_boxed,
140 groq::Client::<reqwest::Client>::from_val_boxed,
141 ),
142 ClientFactory::new(
143 DefaultProviders::HYPERBOLIC,
144 hyperbolic::Client::<reqwest::Client>::from_env_boxed,
145 hyperbolic::Client::<reqwest::Client>::from_val_boxed,
146 ),
147 ClientFactory::new(
148 DefaultProviders::MOONSHOT,
149 moonshot::Client::<reqwest::Client>::from_env_boxed,
150 moonshot::Client::<reqwest::Client>::from_val_boxed,
151 ),
152 ClientFactory::new(
153 DefaultProviders::MIRA,
154 mira::Client::<reqwest::Client>::from_env_boxed,
155 mira::Client::<reqwest::Client>::from_val_boxed,
156 ),
157 ClientFactory::new(
158 DefaultProviders::MISTRAL,
159 mistral::Client::<reqwest::Client>::from_env_boxed,
160 mistral::Client::<reqwest::Client>::from_val_boxed,
161 ),
162 ClientFactory::new(
163 DefaultProviders::OLLAMA,
164 ollama::Client::<reqwest::Client>::from_env_boxed,
165 ollama::Client::<reqwest::Client>::from_val_boxed,
166 ),
167 ClientFactory::new(
168 DefaultProviders::PERPLEXITY,
169 perplexity::Client::<reqwest::Client>::from_env_boxed,
170 perplexity::Client::<reqwest::Client>::from_val_boxed,
171 ),
172 ])
173 }
174
175 pub fn empty() -> Self {
177 Self {
178 registry: HashMap::new(),
179 }
180 }
181
182 pub fn register(mut self, client_factory: ClientFactory) -> Self {
184 self.registry
185 .insert(client_factory.name.clone(), client_factory);
186 self
187 }
188
189 pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
191 for factory in factories {
192 self.registry.insert(factory.name.clone(), factory);
193 }
194
195 self
196 }
197
198 pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
200 let factory = self.get_factory(provider)?;
201 factory.build()
202 }
203
204 pub fn build_val(
206 &self,
207 provider: &str,
208 provider_value: ProviderValue,
209 ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
210 let factory = self.get_factory(provider)?;
211 factory.build_from_val(provider_value)
212 }
213
214 pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
217 let (provider, model) = id
218 .split_once(":")
219 .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
220
221 Ok((provider, model))
222 }
223
224 fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
226 self.registry
227 .get(provider)
228 .ok_or(ClientBuildError::UnknownProvider)
229 }
230
231 pub fn completion(
233 &self,
234 provider: &str,
235 model: &str,
236 ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
237 let client = self.build(provider)?;
238
239 let completion = client
240 .as_completion()
241 .ok_or(ClientBuildError::UnsupportedFeature(
242 provider.to_string(),
243 "completion".to_owned(),
244 ))?;
245
246 Ok(completion.completion_model(model))
247 }
248
249 pub fn agent(
251 &self,
252 provider: &str,
253 model: &str,
254 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
255 let client = self.build(provider)?;
256
257 let client = client
258 .as_completion()
259 .ok_or(ClientBuildError::UnsupportedFeature(
260 provider.to_string(),
261 "completion".to_string(),
262 ))?;
263
264 Ok(client.agent(model))
265 }
266
267 pub fn agent_with_api_key_val<P>(
269 &self,
270 provider: &str,
271 model: &str,
272 provider_value: P,
273 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError>
274 where
275 P: Into<ProviderValue>,
276 {
277 let client = self.build_val(provider, provider_value.into())?;
278
279 let client = client
280 .as_completion()
281 .ok_or(ClientBuildError::UnsupportedFeature(
282 provider.to_string(),
283 "completion".to_string(),
284 ))?;
285
286 Ok(client.agent(model))
287 }
288
289 pub fn embeddings(
291 &self,
292 provider: &str,
293 model: &str,
294 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
295 let client = self.build(provider)?;
296
297 let embeddings = client
298 .as_embeddings()
299 .ok_or(ClientBuildError::UnsupportedFeature(
300 provider.to_string(),
301 "embeddings".to_owned(),
302 ))?;
303
304 Ok(embeddings.embedding_model(model))
305 }
306
307 pub fn embeddings_with_api_key_val<P>(
309 &self,
310 provider: &str,
311 model: &str,
312 provider_value: P,
313 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError>
314 where
315 P: Into<ProviderValue>,
316 {
317 let client = self.build_val(provider, provider_value.into())?;
318
319 let embeddings = client
320 .as_embeddings()
321 .ok_or(ClientBuildError::UnsupportedFeature(
322 provider.to_string(),
323 "embeddings".to_owned(),
324 ))?;
325
326 Ok(embeddings.embedding_model(model))
327 }
328
329 pub fn transcription(
331 &self,
332 provider: &str,
333 model: &str,
334 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
335 let client = self.build(provider)?;
336 let transcription =
337 client
338 .as_transcription()
339 .ok_or(ClientBuildError::UnsupportedFeature(
340 provider.to_string(),
341 "transcription".to_owned(),
342 ))?;
343
344 Ok(transcription.transcription_model(model))
345 }
346
347 pub fn transcription_with_api_key_val<P>(
349 &self,
350 provider: &str,
351 model: &str,
352 provider_value: P,
353 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError>
354 where
355 P: Into<ProviderValue>,
356 {
357 let client = self.build_val(provider, provider_value.into())?;
358 let transcription =
359 client
360 .as_transcription()
361 .ok_or(ClientBuildError::UnsupportedFeature(
362 provider.to_string(),
363 "transcription".to_owned(),
364 ))?;
365
366 Ok(transcription.transcription_model(model))
367 }
368
369 pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
371 let (provider, model) = self.parse(id)?;
372
373 Ok(ProviderModelId {
374 builder: self,
375 provider,
376 model,
377 })
378 }
379
380 pub async fn stream_completion(
390 &self,
391 provider: &str,
392 model: &str,
393 request: CompletionRequest,
394 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
395 let client = self.build(provider)?;
396 let completion = client
397 .as_completion()
398 .ok_or(ClientBuildError::UnsupportedFeature(
399 provider.to_string(),
400 "completion".to_string(),
401 ))?;
402
403 let model = completion.completion_model(model);
404 model
405 .stream(request)
406 .await
407 .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
408 }
409
410 pub async fn stream_prompt(
420 &self,
421 provider: &str,
422 model: &str,
423 prompt: impl Into<Message> + Send,
424 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
425 let client = self.build(provider)?;
426 let completion = client
427 .as_completion()
428 .ok_or(ClientBuildError::UnsupportedFeature(
429 provider.to_string(),
430 "completion".to_string(),
431 ))?;
432
433 let model = completion.completion_model(model);
434 let request = CompletionRequest {
435 preamble: None,
436 tools: vec![],
437 documents: vec![],
438 temperature: None,
439 max_tokens: None,
440 additional_params: None,
441 tool_choice: None,
442 chat_history: crate::OneOrMany::one(prompt.into()),
443 };
444
445 model
446 .stream(request)
447 .await
448 .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
449 }
450
451 pub async fn stream_chat(
462 &self,
463 provider: &str,
464 model: &str,
465 prompt: impl Into<Message> + Send,
466 chat_history: Vec<Message>,
467 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
468 let client = self.build(provider)?;
469 let completion = client
470 .as_completion()
471 .ok_or(ClientBuildError::UnsupportedFeature(
472 provider.to_string(),
473 "completion".to_string(),
474 ))?;
475
476 let model = completion.completion_model(model);
477 let mut history = chat_history;
478 history.push(prompt.into());
479
480 let request = CompletionRequest {
481 preamble: None,
482 tools: vec![],
483 documents: vec![],
484 temperature: None,
485 max_tokens: None,
486 additional_params: None,
487 tool_choice: None,
488 chat_history: crate::OneOrMany::many(history)
489 .unwrap_or_else(|_| crate::OneOrMany::one(Message::user(""))),
490 };
491
492 model
493 .stream(request)
494 .await
495 .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
496 }
497}
498
499#[derive(Debug, Deserialize, Clone, Serialize)]
501pub struct FinalCompletionResponse {
502 pub usage: Option<Usage>,
503}
504
505impl GetTokenUsage for FinalCompletionResponse {
506 fn token_usage(&self) -> Option<Usage> {
507 self.usage
508 }
509}
510
511pub struct ProviderModelId<'builder, 'id> {
512 builder: &'builder DynClientBuilder,
513 provider: &'id str,
514 model: &'id str,
515}
516
517impl<'builder> ProviderModelId<'builder, '_> {
518 pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
519 self.builder.completion(self.provider, self.model)
520 }
521
522 pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
523 self.builder.agent(self.provider, self.model)
524 }
525
526 pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
527 self.builder.embeddings(self.provider, self.model)
528 }
529
530 pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
531 self.builder.transcription(self.provider, self.model)
532 }
533
534 pub async fn stream_completion(
542 self,
543 request: CompletionRequest,
544 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
545 self.builder
546 .stream_completion(self.provider, self.model, request)
547 .await
548 }
549
550 pub async fn stream_prompt(
558 self,
559 prompt: impl Into<Message> + Send,
560 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
561 self.builder
562 .stream_prompt(self.provider, self.model, prompt)
563 .await
564 }
565
566 pub async fn stream_chat(
575 self,
576 prompt: impl Into<Message> + Send,
577 chat_history: Vec<Message>,
578 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, ClientBuildError> {
579 self.builder
580 .stream_chat(self.provider, self.model, prompt, chat_history)
581 .await
582 }
583}
584
#[cfg(feature = "image")]
mod image {
    // `crate::` (not `rig::`) for consistency with the sibling `audio` module.
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// Type-erased image generation model.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Creates image generation model `model` from the provider's environment-built client.
        ///
        /// # Errors
        /// Errors from [`DynClientBuilder::build`], or
        /// [`ClientBuildError::UnsupportedFeature`] if the client has no image generation
        /// capability.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;
            // `ok_or_else` so the error Strings are only allocated on the failure path.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "image_generation".to_string(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Builds the image generation model this id names; see
        /// [`DynClientBuilder::image_generation`].
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
620#[cfg(feature = "image")]
621pub use image::*;
622
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};

    /// Type-erased audio generation model.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Creates audio generation model `model` from the provider's environment-built client.
        ///
        /// # Errors
        /// Errors from [`DynClientBuilder::build`], or
        /// [`ClientBuildError::UnsupportedFeature`] if the client has no audio generation
        /// capability.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;
            // `ok_or_else` so the error Strings are only allocated on the failure path.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Builds the audio generation model this id names; see
        /// [`DynClientBuilder::audio_generation`].
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
658use crate::agent::AgentBuilder;
659use crate::client::completion::CompletionModelHandle;
660#[cfg(feature = "audio")]
661pub use audio::*;
662use rig::providers::mistral;
663
664use super::ProviderValue;
665
666pub struct ClientFactory {
667 pub name: String,
668 pub factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
669 pub factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
670}
671
672impl UnwindSafe for ClientFactory {}
673impl RefUnwindSafe for ClientFactory {}
674
675impl ClientFactory {
676 pub fn new<F1, F2>(name: &str, func_env: F1, func_val: F2) -> Self
677 where
678 F1: 'static + Fn() -> Box<dyn ProviderClient>,
679 F2: 'static + Fn(ProviderValue) -> Box<dyn ProviderClient>,
680 {
681 Self {
682 name: name.to_string(),
683 factory_env: Box::new(func_env),
684 factory_val: Box::new(func_val),
685 }
686 }
687
688 pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
689 std::panic::catch_unwind(|| (self.factory_env)())
690 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
691 }
692
693 pub fn build_from_val(
694 &self,
695 val: ProviderValue,
696 ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
697 std::panic::catch_unwind(|| (self.factory_val)(val))
698 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
699 }
700}
701
/// Registry names of the providers that [`DynClientBuilder::new`] registers.
///
/// These are the values accepted as the `provider` part of a `"provider:model"` id.
pub struct DefaultProviders;

// Kept as `&'static str`: eliding the lifetime on associated consts triggers the
// `elided_lifetimes_in_associated_constant` future-compatibility lint.
impl DefaultProviders {
    pub const ANTHROPIC: &'static str = "anthropic";
    pub const AZURE: &'static str = "azure";
    pub const COHERE: &'static str = "cohere";
    pub const DEEPSEEK: &'static str = "deepseek";
    pub const GALADRIEL: &'static str = "galadriel";
    pub const GEMINI: &'static str = "gemini";
    pub const GROQ: &'static str = "groq";
    pub const HUGGINGFACE: &'static str = "huggingface";
    pub const HYPERBOLIC: &'static str = "hyperbolic";
    pub const MIRA: &'static str = "mira";
    pub const MISTRAL: &'static str = "mistral";
    pub const MOONSHOT: &'static str = "moonshot";
    pub const OLLAMA: &'static str = "ollama";
    pub const OPENAI: &'static str = "openai";
    pub const OPENROUTER: &'static str = "openrouter";
    pub const PERPLEXITY: &'static str = "perplexity";
    pub const TOGETHER: &'static str = "together";
    pub const XAI: &'static str = "xai";
}
721 pub const PERPLEXITY: &'static str = "perplexity";
722}