rig/client/
builder.rs

1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::embeddings::embedding::EmbeddingModelDyn;
4use crate::providers::{
5    anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
6    moonshot, ollama, openai, openrouter, perplexity, together, xai,
7};
8use crate::transcription::TranscriptionModelDyn;
9use rig::completion::CompletionModelDyn;
10use std::collections::HashMap;
11use std::panic::{RefUnwindSafe, UnwindSafe};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum ClientBuildError {
16    #[error("factory error: {}", .0)]
17    FactoryError(String),
18    #[error("invalid id string: {}", .0)]
19    InvalidIdString(String),
20    #[error("unsupported feature: {} for {}", .1, .0)]
21    UnsupportedFeature(String, String),
22    #[error("unknown provider")]
23    UnknownProvider,
24}
25
26pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
27pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
28pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
29pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
30pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
31
32/// A dynamic client builder.
33/// Use this when you need to support creating any kind of client from a range of LLM providers (that Rig supports).
34/// Usage:
35/// ```rust
36/// use rig::{
37///     client::builder::DynClientBuilder, completion::Prompt, providers::anthropic::CLAUDE_3_7_SONNET,
38/// };
39/// #[tokio::main]
40/// async fn main() {
41///     let multi_client = DynClientBuilder::new();
42///     // set up OpenAI client
43///     let completion_openai = multi_client.agent("openai", "gpt-4o").unwrap();
44///     let agent_openai = completion_openai
45///         .preamble("You are a helpful assistant")
46///         .build();
47///     // set up Anthropic client
48///     let completion_anthropic = multi_client.agent("anthropic", CLAUDE_3_7_SONNET).unwrap();
49///     let agent_anthropic = completion_anthropic
50///         .preamble("You are a helpful assistant")
51///         .max_tokens(1024)
52///         .build();
53///     println!("Sending prompt: 'Hello world!'");
54///     let res_openai = agent_openai.prompt("Hello world!").await.unwrap();
55///     println!("Response from OpenAI (using gpt-4o): {res_openai}");
56///     let res_anthropic = agent_anthropic.prompt("Hello world!").await.unwrap();
57///     println!("Response from Anthropic (using Claude 3.7 Sonnet): {res_anthropic}");
58/// }
59/// ```
60pub struct DynClientBuilder {
61    registry: HashMap<String, ClientFactory>,
62}
63
64impl Default for DynClientBuilder {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl<'a> DynClientBuilder {
71    /// Generate a new instance of `DynClientBuilder`.
72    /// By default, every single possible client that can be registered
73    /// will be registered to the client builder.
74    pub fn new() -> Self {
75        Self {
76            registry: HashMap::new(),
77        }
78        .register_all(vec![
79            ClientFactory::new(
80                DefaultProviders::ANTHROPIC,
81                anthropic::Client::from_env_boxed,
82                anthropic::Client::from_val_boxed,
83            ),
84            ClientFactory::new(
85                DefaultProviders::COHERE,
86                cohere::Client::from_env_boxed,
87                cohere::Client::from_val_boxed,
88            ),
89            ClientFactory::new(
90                DefaultProviders::GEMINI,
91                gemini::Client::from_env_boxed,
92                gemini::Client::from_val_boxed,
93            ),
94            ClientFactory::new(
95                DefaultProviders::HUGGINGFACE,
96                huggingface::Client::from_env_boxed,
97                huggingface::Client::from_val_boxed,
98            ),
99            ClientFactory::new(
100                DefaultProviders::OPENAI,
101                openai::Client::from_env_boxed,
102                openai::Client::from_val_boxed,
103            ),
104            ClientFactory::new(
105                DefaultProviders::OPENROUTER,
106                openrouter::Client::from_env_boxed,
107                openrouter::Client::from_val_boxed,
108            ),
109            ClientFactory::new(
110                DefaultProviders::TOGETHER,
111                together::Client::from_env_boxed,
112                together::Client::from_val_boxed,
113            ),
114            ClientFactory::new(
115                DefaultProviders::XAI,
116                xai::Client::from_env_boxed,
117                xai::Client::from_val_boxed,
118            ),
119            ClientFactory::new(
120                DefaultProviders::AZURE,
121                azure::Client::from_env_boxed,
122                azure::Client::from_val_boxed,
123            ),
124            ClientFactory::new(
125                DefaultProviders::DEEPSEEK,
126                deepseek::Client::from_env_boxed,
127                deepseek::Client::from_val_boxed,
128            ),
129            ClientFactory::new(
130                DefaultProviders::GALADRIEL,
131                galadriel::Client::from_env_boxed,
132                galadriel::Client::from_val_boxed,
133            ),
134            ClientFactory::new(
135                DefaultProviders::GROQ,
136                groq::Client::from_env_boxed,
137                groq::Client::from_val_boxed,
138            ),
139            ClientFactory::new(
140                DefaultProviders::HYPERBOLIC,
141                hyperbolic::Client::from_env_boxed,
142                hyperbolic::Client::from_val_boxed,
143            ),
144            ClientFactory::new(
145                DefaultProviders::MOONSHOT,
146                moonshot::Client::from_env_boxed,
147                moonshot::Client::from_val_boxed,
148            ),
149            ClientFactory::new(
150                DefaultProviders::MIRA,
151                mira::Client::from_env_boxed,
152                mira::Client::from_val_boxed,
153            ),
154            ClientFactory::new(
155                DefaultProviders::MISTRAL,
156                mistral::Client::from_env_boxed,
157                mistral::Client::from_val_boxed,
158            ),
159            ClientFactory::new(
160                DefaultProviders::OLLAMA,
161                ollama::Client::from_env_boxed,
162                ollama::Client::from_val_boxed,
163            ),
164            ClientFactory::new(
165                DefaultProviders::PERPLEXITY,
166                perplexity::Client::from_env_boxed,
167                perplexity::Client::from_val_boxed,
168            ),
169        ])
170    }
171
172    /// Generate a new instance of `DynClientBuilder` with no client factories registered.
173    pub fn empty() -> Self {
174        Self {
175            registry: HashMap::new(),
176        }
177    }
178
179    /// Register a new ClientFactory
180    pub fn register(mut self, client_factory: ClientFactory) -> Self {
181        self.registry
182            .insert(client_factory.name.clone(), client_factory);
183        self
184    }
185
186    /// Register multiple ClientFactories
187    pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
188        for factory in factories {
189            self.registry.insert(factory.name.clone(), factory);
190        }
191
192        self
193    }
194
195    /// Returns a (boxed) specific provider based on the given provider.
196    pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
197        let factory = self.get_factory(provider)?;
198        factory.build()
199    }
200
201    /// Returns a (boxed) specific provider based on the given provider.
202    pub fn build_val(
203        &self,
204        provider: &str,
205        provider_value: ProviderValue,
206    ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
207        let factory = self.get_factory(provider)?;
208        factory.build_from_val(provider_value)
209    }
210
211    /// Parses a provider:model string to the provider and the model separately.
212    /// For example, `openai:gpt-4o` will return ("openai", "gpt-4o").
213    pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
214        let (provider, model) = id
215            .split_once(":")
216            .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
217
218        Ok((provider, model))
219    }
220
221    /// Returns a specific client factory (that exists in the registry).
222    fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
223        self.registry
224            .get(provider)
225            .ok_or(ClientBuildError::UnknownProvider)
226    }
227
228    /// Get a boxed completion model based on the provider and model.
229    pub fn completion(
230        &self,
231        provider: &str,
232        model: &str,
233    ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
234        let client = self.build(provider)?;
235
236        let completion = client
237            .as_completion()
238            .ok_or(ClientBuildError::UnsupportedFeature(
239                provider.to_string(),
240                "completion".to_owned(),
241            ))?;
242
243        Ok(completion.completion_model(model))
244    }
245
246    /// Get a boxed agent based on the provider and model..
247    pub fn agent(
248        &self,
249        provider: &str,
250        model: &str,
251    ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
252        let client = self.build(provider)?;
253
254        let client = client
255            .as_completion()
256            .ok_or(ClientBuildError::UnsupportedFeature(
257                provider.to_string(),
258                "completion".to_string(),
259            ))?;
260
261        Ok(client.agent(model))
262    }
263
264    /// Get a boxed agent based on the provider and model, as well as an API key.
265    pub fn agent_with_api_key_val<P>(
266        &self,
267        provider: &str,
268        model: &str,
269        provider_value: P,
270    ) -> Result<BoxAgentBuilder<'a>, ClientBuildError>
271    where
272        P: Into<ProviderValue>,
273    {
274        let client = self.build_val(provider, provider_value.into())?;
275
276        let client = client
277            .as_completion()
278            .ok_or(ClientBuildError::UnsupportedFeature(
279                provider.to_string(),
280                "completion".to_string(),
281            ))?;
282
283        Ok(client.agent(model))
284    }
285
286    /// Get a boxed embedding model based on the provider and model.
287    pub fn embeddings(
288        &self,
289        provider: &str,
290        model: &str,
291    ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
292        let client = self.build(provider)?;
293
294        let embeddings = client
295            .as_embeddings()
296            .ok_or(ClientBuildError::UnsupportedFeature(
297                provider.to_string(),
298                "embeddings".to_owned(),
299            ))?;
300
301        Ok(embeddings.embedding_model(model))
302    }
303
304    /// Get a boxed embedding model based on the provider and model.
305    pub fn embeddings_with_api_key_val<P>(
306        &self,
307        provider: &str,
308        model: &str,
309        provider_value: P,
310    ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError>
311    where
312        P: Into<ProviderValue>,
313    {
314        let client = self.build_val(provider, provider_value.into())?;
315
316        let embeddings = client
317            .as_embeddings()
318            .ok_or(ClientBuildError::UnsupportedFeature(
319                provider.to_string(),
320                "embeddings".to_owned(),
321            ))?;
322
323        Ok(embeddings.embedding_model(model))
324    }
325
326    /// Get a boxed transcription model based on the provider and model.
327    pub fn transcription(
328        &self,
329        provider: &str,
330        model: &str,
331    ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
332        let client = self.build(provider)?;
333        let transcription =
334            client
335                .as_transcription()
336                .ok_or(ClientBuildError::UnsupportedFeature(
337                    provider.to_string(),
338                    "transcription".to_owned(),
339                ))?;
340
341        Ok(transcription.transcription_model(model))
342    }
343
344    /// Get a boxed transcription model based on the provider and model.
345    pub fn transcription_with_api_key_val<P>(
346        &self,
347        provider: &str,
348        model: &str,
349        provider_value: P,
350    ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError>
351    where
352        P: Into<ProviderValue>,
353    {
354        let client = self.build_val(provider, provider_value.into())?;
355        let transcription =
356            client
357                .as_transcription()
358                .ok_or(ClientBuildError::UnsupportedFeature(
359                    provider.to_string(),
360                    "transcription".to_owned(),
361                ))?;
362
363        Ok(transcription.transcription_model(model))
364    }
365
366    /// Get the ID of a provider model based on a `provider:model` ID.
367    pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
368        let (provider, model) = self.parse(id)?;
369
370        Ok(ProviderModelId {
371            builder: self,
372            provider,
373            model,
374        })
375    }
376}
377
378pub struct ProviderModelId<'builder, 'id> {
379    builder: &'builder DynClientBuilder,
380    provider: &'id str,
381    model: &'id str,
382}
383
384impl<'builder> ProviderModelId<'builder, '_> {
385    pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
386        self.builder.completion(self.provider, self.model)
387    }
388
389    pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
390        self.builder.agent(self.provider, self.model)
391    }
392
393    pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
394        self.builder.embeddings(self.provider, self.model)
395    }
396
397    pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
398        self.builder.transcription(self.provider, self.model)
399    }
400}
401
#[cfg(feature = "image")]
mod image {
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// Boxed, type-erased image generation model.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Get a boxed image generation model for `provider` and `model`.
        ///
        /// # Errors
        /// Any error from [`DynClientBuilder::build`], or
        /// [`ClientBuildError::UnsupportedFeature`] if the provider's client does
        /// not support image generation.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;
            // Build the error lazily so the success path allocates nothing.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "image_generation".to_string(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Build a boxed image generation model for this id's provider and model.
        ///
        /// Errors are those of [`DynClientBuilder::image_generation`].
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
437#[cfg(feature = "image")]
438pub use image::*;
439
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};

    /// Boxed, type-erased audio generation model.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Get a boxed audio generation model for `provider` and `model`.
        ///
        /// # Errors
        /// Any error from [`DynClientBuilder::build`], or
        /// [`ClientBuildError::UnsupportedFeature`] if the provider's client does
        /// not support audio generation.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;
            // Build the error lazily so the success path allocates nothing.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Build a boxed audio generation model for this id's provider and model.
        ///
        /// Errors are those of [`DynClientBuilder::audio_generation`].
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
475use crate::agent::AgentBuilder;
476use crate::client::completion::CompletionModelHandle;
477#[cfg(feature = "audio")]
478pub use audio::*;
479use rig::providers::mistral;
480
481use super::ProviderValue;
482
483pub struct ClientFactory {
484    pub name: String,
485    pub factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
486    pub factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
487}
488
489impl UnwindSafe for ClientFactory {}
490impl RefUnwindSafe for ClientFactory {}
491
492impl ClientFactory {
493    pub fn new<F1, F2>(name: &str, func_env: F1, func_val: F2) -> Self
494    where
495        F1: 'static + Fn() -> Box<dyn ProviderClient>,
496        F2: 'static + Fn(ProviderValue) -> Box<dyn ProviderClient>,
497    {
498        Self {
499            name: name.to_string(),
500            factory_env: Box::new(func_env),
501            factory_val: Box::new(func_val),
502        }
503    }
504
505    pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
506        std::panic::catch_unwind(|| (self.factory_env)())
507            .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
508    }
509
510    pub fn build_from_val(
511        &self,
512        val: ProviderValue,
513    ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
514        std::panic::catch_unwind(|| (self.factory_val)(val))
515            .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
516    }
517}
518
/// Names of Rig's built-in providers.
///
/// [`DynClientBuilder::new`] registers a factory under each of these names. They
/// are the `provider` strings accepted by the builder's methods and the prefix of a
/// `provider:model` id.
pub struct DefaultProviders;

impl DefaultProviders {
    // Kept in alphabetical order.
    pub const ANTHROPIC: &'static str = "anthropic";
    pub const AZURE: &'static str = "azure";
    pub const COHERE: &'static str = "cohere";
    pub const DEEPSEEK: &'static str = "deepseek";
    pub const GALADRIEL: &'static str = "galadriel";
    pub const GEMINI: &'static str = "gemini";
    pub const GROQ: &'static str = "groq";
    pub const HUGGINGFACE: &'static str = "huggingface";
    pub const HYPERBOLIC: &'static str = "hyperbolic";
    pub const MIRA: &'static str = "mira";
    pub const MISTRAL: &'static str = "mistral";
    pub const MOONSHOT: &'static str = "moonshot";
    pub const OLLAMA: &'static str = "ollama";
    pub const OPENAI: &'static str = "openai";
    pub const OPENROUTER: &'static str = "openrouter";
    pub const PERPLEXITY: &'static str = "perplexity";
    pub const TOGETHER: &'static str = "together";
    pub const XAI: &'static str = "xai";
}