rig/client/
builder.rs

1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::completion::{CompletionRequest, Message};
4use crate::embeddings::embedding::EmbeddingModelDyn;
5use crate::providers::{
6    anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
7    moonshot, ollama, openai, openrouter, perplexity, together, xai,
8};
9use crate::streaming::StreamingCompletionResponse;
10use crate::transcription::TranscriptionModelDyn;
11use rig::completion::CompletionModelDyn;
12use std::collections::HashMap;
13use std::panic::{RefUnwindSafe, UnwindSafe};
14use thiserror::Error;
15
/// Errors returned while resolving a provider or building a client/model through
/// [`DynClientBuilder`].
#[derive(Debug, Error)]
pub enum ClientBuildError {
    /// A provider factory panicked while building a client. In this file the variant is
    /// also used to wrap errors returned when starting a streaming completion.
    #[error("factory error: {}", .0)]
    FactoryError(String),
    /// A `provider:model` ID string contained no `:` separator; carries the offending string.
    #[error("invalid id string: {}", .0)]
    InvalidIdString(String),
    /// The provider does not expose the requested capability.
    /// Field `0` is the provider name and field `1` is the feature name; note that the
    /// message prints the feature first.
    #[error("unsupported feature: {} for {}", .1, .0)]
    UnsupportedFeature(String, String),
    /// No factory is registered under the requested provider name.
    #[error("unknown provider")]
    UnknownProvider,
}
27
/// A boxed, dynamically dispatched completion model.
pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
/// An agent builder over a type-erased completion model handle.
pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
/// An agent over a type-erased completion model handle.
pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
/// A boxed, dynamically dispatched embedding model.
pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
/// A boxed, dynamically dispatched transcription model.
pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
33
/// A dynamic client builder.
///
/// Use this when the provider is only known at runtime and you need to create clients,
/// models or agents for any of the LLM providers that Rig supports. Providers are looked
/// up by name in an internal registry of [`ClientFactory`] values.
///
/// # Examples
///
/// ```rust
/// use rig::{
///     client::builder::DynClientBuilder, completion::Prompt, providers::anthropic::CLAUDE_3_7_SONNET,
/// };
/// #[tokio::main]
/// async fn main() {
///     let multi_client = DynClientBuilder::new();
///     // set up OpenAI client
///     let completion_openai = multi_client.agent("openai", "gpt-4o").unwrap();
///     let agent_openai = completion_openai
///         .preamble("You are a helpful assistant")
///         .build();
///     // set up Anthropic client
///     let completion_anthropic = multi_client.agent("anthropic", CLAUDE_3_7_SONNET).unwrap();
///     let agent_anthropic = completion_anthropic
///         .preamble("You are a helpful assistant")
///         .max_tokens(1024)
///         .build();
///     println!("Sending prompt: 'Hello world!'");
///     let res_openai = agent_openai.prompt("Hello world!").await.unwrap();
///     println!("Response from OpenAI (using gpt-4o): {res_openai}");
///     let res_anthropic = agent_anthropic.prompt("Hello world!").await.unwrap();
///     println!("Response from Anthropic (using Claude 3.7 Sonnet): {res_anthropic}");
/// }
/// ```
pub struct DynClientBuilder {
    // Registered factories, keyed by the factory's `name` (the provider name).
    registry: HashMap<String, ClientFactory>,
}
65
66impl Default for DynClientBuilder {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl<'a> DynClientBuilder {
73    /// Generate a new instance of `DynClientBuilder`.
74    /// By default, every single possible client that can be registered
75    /// will be registered to the client builder.
76    pub fn new() -> Self {
77        Self {
78            registry: HashMap::new(),
79        }
80        .register_all(vec![
81            ClientFactory::new(
82                DefaultProviders::ANTHROPIC,
83                anthropic::Client::from_env_boxed,
84                anthropic::Client::from_val_boxed,
85            ),
86            ClientFactory::new(
87                DefaultProviders::COHERE,
88                cohere::Client::from_env_boxed,
89                cohere::Client::from_val_boxed,
90            ),
91            ClientFactory::new(
92                DefaultProviders::GEMINI,
93                gemini::Client::from_env_boxed,
94                gemini::Client::from_val_boxed,
95            ),
96            ClientFactory::new(
97                DefaultProviders::HUGGINGFACE,
98                huggingface::Client::from_env_boxed,
99                huggingface::Client::from_val_boxed,
100            ),
101            ClientFactory::new(
102                DefaultProviders::OPENAI,
103                openai::Client::from_env_boxed,
104                openai::Client::from_val_boxed,
105            ),
106            ClientFactory::new(
107                DefaultProviders::OPENROUTER,
108                openrouter::Client::from_env_boxed,
109                openrouter::Client::from_val_boxed,
110            ),
111            ClientFactory::new(
112                DefaultProviders::TOGETHER,
113                together::Client::from_env_boxed,
114                together::Client::from_val_boxed,
115            ),
116            ClientFactory::new(
117                DefaultProviders::XAI,
118                xai::Client::from_env_boxed,
119                xai::Client::from_val_boxed,
120            ),
121            ClientFactory::new(
122                DefaultProviders::AZURE,
123                azure::Client::from_env_boxed,
124                azure::Client::from_val_boxed,
125            ),
126            ClientFactory::new(
127                DefaultProviders::DEEPSEEK,
128                deepseek::Client::from_env_boxed,
129                deepseek::Client::from_val_boxed,
130            ),
131            ClientFactory::new(
132                DefaultProviders::GALADRIEL,
133                galadriel::Client::from_env_boxed,
134                galadriel::Client::from_val_boxed,
135            ),
136            ClientFactory::new(
137                DefaultProviders::GROQ,
138                groq::Client::from_env_boxed,
139                groq::Client::from_val_boxed,
140            ),
141            ClientFactory::new(
142                DefaultProviders::HYPERBOLIC,
143                hyperbolic::Client::from_env_boxed,
144                hyperbolic::Client::from_val_boxed,
145            ),
146            ClientFactory::new(
147                DefaultProviders::MOONSHOT,
148                moonshot::Client::from_env_boxed,
149                moonshot::Client::from_val_boxed,
150            ),
151            ClientFactory::new(
152                DefaultProviders::MIRA,
153                mira::Client::from_env_boxed,
154                mira::Client::from_val_boxed,
155            ),
156            ClientFactory::new(
157                DefaultProviders::MISTRAL,
158                mistral::Client::from_env_boxed,
159                mistral::Client::from_val_boxed,
160            ),
161            ClientFactory::new(
162                DefaultProviders::OLLAMA,
163                ollama::Client::from_env_boxed,
164                ollama::Client::from_val_boxed,
165            ),
166            ClientFactory::new(
167                DefaultProviders::PERPLEXITY,
168                perplexity::Client::from_env_boxed,
169                perplexity::Client::from_val_boxed,
170            ),
171        ])
172    }
173
174    /// Generate a new instance of `DynClientBuilder` with no client factories registered.
175    pub fn empty() -> Self {
176        Self {
177            registry: HashMap::new(),
178        }
179    }
180
181    /// Register a new ClientFactory
182    pub fn register(mut self, client_factory: ClientFactory) -> Self {
183        self.registry
184            .insert(client_factory.name.clone(), client_factory);
185        self
186    }
187
188    /// Register multiple ClientFactories
189    pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
190        for factory in factories {
191            self.registry.insert(factory.name.clone(), factory);
192        }
193
194        self
195    }
196
197    /// Returns a (boxed) specific provider based on the given provider.
198    pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
199        let factory = self.get_factory(provider)?;
200        factory.build()
201    }
202
203    /// Returns a (boxed) specific provider based on the given provider.
204    pub fn build_val(
205        &self,
206        provider: &str,
207        provider_value: ProviderValue,
208    ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
209        let factory = self.get_factory(provider)?;
210        factory.build_from_val(provider_value)
211    }
212
213    /// Parses a provider:model string to the provider and the model separately.
214    /// For example, `openai:gpt-4o` will return ("openai", "gpt-4o").
215    pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
216        let (provider, model) = id
217            .split_once(":")
218            .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
219
220        Ok((provider, model))
221    }
222
223    /// Returns a specific client factory (that exists in the registry).
224    fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
225        self.registry
226            .get(provider)
227            .ok_or(ClientBuildError::UnknownProvider)
228    }
229
230    /// Get a boxed completion model based on the provider and model.
231    pub fn completion(
232        &self,
233        provider: &str,
234        model: &str,
235    ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
236        let client = self.build(provider)?;
237
238        let completion = client
239            .as_completion()
240            .ok_or(ClientBuildError::UnsupportedFeature(
241                provider.to_string(),
242                "completion".to_owned(),
243            ))?;
244
245        Ok(completion.completion_model(model))
246    }
247
248    /// Get a boxed agent based on the provider and model..
249    pub fn agent(
250        &self,
251        provider: &str,
252        model: &str,
253    ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
254        let client = self.build(provider)?;
255
256        let client = client
257            .as_completion()
258            .ok_or(ClientBuildError::UnsupportedFeature(
259                provider.to_string(),
260                "completion".to_string(),
261            ))?;
262
263        Ok(client.agent(model))
264    }
265
266    /// Get a boxed agent based on the provider and model, as well as an API key.
267    pub fn agent_with_api_key_val<P>(
268        &self,
269        provider: &str,
270        model: &str,
271        provider_value: P,
272    ) -> Result<BoxAgentBuilder<'a>, ClientBuildError>
273    where
274        P: Into<ProviderValue>,
275    {
276        let client = self.build_val(provider, provider_value.into())?;
277
278        let client = client
279            .as_completion()
280            .ok_or(ClientBuildError::UnsupportedFeature(
281                provider.to_string(),
282                "completion".to_string(),
283            ))?;
284
285        Ok(client.agent(model))
286    }
287
288    /// Get a boxed embedding model based on the provider and model.
289    pub fn embeddings(
290        &self,
291        provider: &str,
292        model: &str,
293    ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
294        let client = self.build(provider)?;
295
296        let embeddings = client
297            .as_embeddings()
298            .ok_or(ClientBuildError::UnsupportedFeature(
299                provider.to_string(),
300                "embeddings".to_owned(),
301            ))?;
302
303        Ok(embeddings.embedding_model(model))
304    }
305
306    /// Get a boxed embedding model based on the provider and model.
307    pub fn embeddings_with_api_key_val<P>(
308        &self,
309        provider: &str,
310        model: &str,
311        provider_value: P,
312    ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError>
313    where
314        P: Into<ProviderValue>,
315    {
316        let client = self.build_val(provider, provider_value.into())?;
317
318        let embeddings = client
319            .as_embeddings()
320            .ok_or(ClientBuildError::UnsupportedFeature(
321                provider.to_string(),
322                "embeddings".to_owned(),
323            ))?;
324
325        Ok(embeddings.embedding_model(model))
326    }
327
328    /// Get a boxed transcription model based on the provider and model.
329    pub fn transcription(
330        &self,
331        provider: &str,
332        model: &str,
333    ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
334        let client = self.build(provider)?;
335        let transcription =
336            client
337                .as_transcription()
338                .ok_or(ClientBuildError::UnsupportedFeature(
339                    provider.to_string(),
340                    "transcription".to_owned(),
341                ))?;
342
343        Ok(transcription.transcription_model(model))
344    }
345
346    /// Get a boxed transcription model based on the provider and model.
347    pub fn transcription_with_api_key_val<P>(
348        &self,
349        provider: &str,
350        model: &str,
351        provider_value: P,
352    ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError>
353    where
354        P: Into<ProviderValue>,
355    {
356        let client = self.build_val(provider, provider_value.into())?;
357        let transcription =
358            client
359                .as_transcription()
360                .ok_or(ClientBuildError::UnsupportedFeature(
361                    provider.to_string(),
362                    "transcription".to_owned(),
363                ))?;
364
365        Ok(transcription.transcription_model(model))
366    }
367
368    /// Get the ID of a provider model based on a `provider:model` ID.
369    pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
370        let (provider, model) = self.parse(id)?;
371
372        Ok(ProviderModelId {
373            builder: self,
374            provider,
375            model,
376        })
377    }
378
379    /// Stream a completion request to the specified provider and model.
380    ///
381    /// # Arguments
382    /// * `provider` - The name of the provider (e.g., "openai", "anthropic")
383    /// * `model` - The name of the model (e.g., "gpt-4o", "claude-3-sonnet")
384    /// * `request` - The completion request containing prompt, parameters, etc.
385    ///
386    /// # Returns
387    /// A future that resolves to a streaming completion response
388    pub async fn stream_completion(
389        &self,
390        provider: &str,
391        model: &str,
392        request: CompletionRequest,
393    ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
394        let client = self.build(provider)?;
395        let completion = client
396            .as_completion()
397            .ok_or(ClientBuildError::UnsupportedFeature(
398                provider.to_string(),
399                "completion".to_string(),
400            ))?;
401
402        let model = completion.completion_model(model);
403        model
404            .stream(request)
405            .await
406            .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
407    }
408
409    /// Stream a simple prompt to the specified provider and model.
410    ///
411    /// # Arguments
412    /// * `provider` - The name of the provider (e.g., "openai", "anthropic")
413    /// * `model` - The name of the model (e.g., "gpt-4o", "claude-3-sonnet")
414    /// * `prompt` - The prompt to send to the model
415    ///
416    /// # Returns
417    /// A future that resolves to a streaming completion response
418    pub async fn stream_prompt(
419        &self,
420        provider: &str,
421        model: &str,
422        prompt: impl Into<Message> + Send,
423    ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
424        let client = self.build(provider)?;
425        let completion = client
426            .as_completion()
427            .ok_or(ClientBuildError::UnsupportedFeature(
428                provider.to_string(),
429                "completion".to_string(),
430            ))?;
431
432        let model = completion.completion_model(model);
433        let request = CompletionRequest {
434            preamble: None,
435            tools: vec![],
436            documents: vec![],
437            temperature: None,
438            max_tokens: None,
439            additional_params: None,
440            chat_history: crate::OneOrMany::one(prompt.into()),
441        };
442
443        model
444            .stream(request)
445            .await
446            .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
447    }
448
449    /// Stream a chat with history to the specified provider and model.
450    ///
451    /// # Arguments
452    /// * `provider` - The name of the provider (e.g., "openai", "anthropic")
453    /// * `model` - The name of the model (e.g., "gpt-4o", "claude-3-sonnet")
454    /// * `prompt` - The new prompt to send to the model
455    /// * `chat_history` - The chat history to include with the request
456    ///
457    /// # Returns
458    /// A future that resolves to a streaming completion response
459    pub async fn stream_chat(
460        &self,
461        provider: &str,
462        model: &str,
463        prompt: impl Into<Message> + Send,
464        chat_history: Vec<Message>,
465    ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
466        let client = self.build(provider)?;
467        let completion = client
468            .as_completion()
469            .ok_or(ClientBuildError::UnsupportedFeature(
470                provider.to_string(),
471                "completion".to_string(),
472            ))?;
473
474        let model = completion.completion_model(model);
475        let mut history = chat_history;
476        history.push(prompt.into());
477
478        let request = CompletionRequest {
479            preamble: None,
480            tools: vec![],
481            documents: vec![],
482            temperature: None,
483            max_tokens: None,
484            additional_params: None,
485            chat_history: crate::OneOrMany::many(history)
486                .unwrap_or_else(|_| crate::OneOrMany::one(Message::user(""))),
487        };
488
489        model
490            .stream(request)
491            .await
492            .map_err(|e| ClientBuildError::FactoryError(e.to_string()))
493    }
494}
495
/// A parsed `provider:model` identifier bound to the [`DynClientBuilder`] that produced it.
///
/// Created by [`DynClientBuilder::id`]; its methods forward to the builder's methods of
/// the same purpose using the stored provider and model names.
pub struct ProviderModelId<'builder, 'id> {
    // The builder used to resolve the provider and construct models.
    builder: &'builder DynClientBuilder,
    // Provider half of the ID (the text before the first `:`).
    provider: &'id str,
    // Model half of the ID (the text after the first `:`).
    model: &'id str,
}
501
502impl<'builder> ProviderModelId<'builder, '_> {
503    pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
504        self.builder.completion(self.provider, self.model)
505    }
506
507    pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
508        self.builder.agent(self.provider, self.model)
509    }
510
511    pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
512        self.builder.embeddings(self.provider, self.model)
513    }
514
515    pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
516        self.builder.transcription(self.provider, self.model)
517    }
518
519    /// Stream a completion request using this provider and model.
520    ///
521    /// # Arguments
522    /// * `request` - The completion request containing prompt, parameters, etc.
523    ///
524    /// # Returns
525    /// A future that resolves to a streaming completion response
526    pub async fn stream_completion(
527        self,
528        request: CompletionRequest,
529    ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
530        self.builder
531            .stream_completion(self.provider, self.model, request)
532            .await
533    }
534
535    /// Stream a simple prompt using this provider and model.
536    ///
537    /// # Arguments
538    /// * `prompt` - The prompt to send to the model
539    ///
540    /// # Returns
541    /// A future that resolves to a streaming completion response
542    pub async fn stream_prompt(
543        self,
544        prompt: impl Into<Message> + Send,
545    ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
546        self.builder
547            .stream_prompt(self.provider, self.model, prompt)
548            .await
549    }
550
551    /// Stream a chat with history using this provider and model.
552    ///
553    /// # Arguments
554    /// * `prompt` - The new prompt to send to the model
555    /// * `chat_history` - The chat history to include with the request
556    ///
557    /// # Returns
558    /// A future that resolves to a streaming completion response
559    pub async fn stream_chat(
560        self,
561        prompt: impl Into<Message> + Send,
562        chat_history: Vec<Message>,
563    ) -> Result<StreamingCompletionResponse<()>, ClientBuildError> {
564        self.builder
565            .stream_chat(self.provider, self.model, prompt, chat_history)
566            .await
567    }
568}
569
#[cfg(feature = "image")]
mod image {
    // All paths go through `crate::` so this module resolves the same way as the
    // sibling `audio` module.
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// A boxed, dynamically dispatched image generation model.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Get a boxed image generation model for the given provider and model name.
        ///
        /// # Errors
        /// * [`ClientBuildError::UnknownProvider`] if `provider` is not registered.
        /// * [`ClientBuildError::FactoryError`] if the provider factory fails.
        /// * [`ClientBuildError::UnsupportedFeature`] if the provider's client does not
        ///   offer image generation.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // Build the error lazily so the two `String`s are only allocated on failure.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "image_generation".to_string(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Build a boxed image generation model for this provider/model pair.
        ///
        /// Forwards to [`DynClientBuilder::image_generation`].
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
605#[cfg(feature = "image")]
606pub use image::*;
607
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::DynClientBuilder;
    use crate::client::builder::{ClientBuildError, ProviderModelId};

    /// A boxed, dynamically dispatched audio generation model.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Get a boxed audio generation model for the given provider and model name.
        ///
        /// # Errors
        /// * [`ClientBuildError::UnknownProvider`] if `provider` is not registered.
        /// * [`ClientBuildError::FactoryError`] if the provider factory fails.
        /// * [`ClientBuildError::UnsupportedFeature`] if the provider's client does not
        ///   offer audio generation.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // Build the error lazily so the two `String`s are only allocated on failure.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Build a boxed audio generation model for this provider/model pair.
        ///
        /// Forwards to [`DynClientBuilder::audio_generation`].
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
643use crate::agent::AgentBuilder;
644use crate::client::completion::CompletionModelHandle;
645#[cfg(feature = "audio")]
646pub use audio::*;
647use rig::providers::mistral;
648
649use super::ProviderValue;
650
/// A named pair of constructors for one provider's client.
///
/// Stored in the [`DynClientBuilder`] registry under `name`.
pub struct ClientFactory {
    /// Provider name; used as the registry key.
    pub name: String,
    /// Constructor that takes no arguments.
    pub factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
    /// Constructor that takes an explicit [`ProviderValue`].
    pub factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
}

// The boxed closures are not unwind-safe by default. These impls assert unwind safety so
// that `ClientFactory::build` / `build_from_val` can borrow `self` inside
// `std::panic::catch_unwind`.
impl UnwindSafe for ClientFactory {}
impl RefUnwindSafe for ClientFactory {}
659
660impl ClientFactory {
661    pub fn new<F1, F2>(name: &str, func_env: F1, func_val: F2) -> Self
662    where
663        F1: 'static + Fn() -> Box<dyn ProviderClient>,
664        F2: 'static + Fn(ProviderValue) -> Box<dyn ProviderClient>,
665    {
666        Self {
667            name: name.to_string(),
668            factory_env: Box::new(func_env),
669            factory_val: Box::new(func_val),
670        }
671    }
672
673    pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
674        std::panic::catch_unwind(|| (self.factory_env)())
675            .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
676    }
677
678    pub fn build_from_val(
679        &self,
680        val: ProviderValue,
681    ) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
682        std::panic::catch_unwind(|| (self.factory_val)(val))
683            .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
684    }
685}
686
/// Namespace for the provider names registered by default, listed alphabetically.
///
/// Each constant is the string key under which the matching provider is registered.
pub struct DefaultProviders;

impl DefaultProviders {
    /// Registry key for Anthropic.
    pub const ANTHROPIC: &'static str = "anthropic";
    /// Registry key for Azure.
    pub const AZURE: &'static str = "azure";
    /// Registry key for Cohere.
    pub const COHERE: &'static str = "cohere";
    /// Registry key for DeepSeek.
    pub const DEEPSEEK: &'static str = "deepseek";
    /// Registry key for Galadriel.
    pub const GALADRIEL: &'static str = "galadriel";
    /// Registry key for Gemini.
    pub const GEMINI: &'static str = "gemini";
    /// Registry key for Groq.
    pub const GROQ: &'static str = "groq";
    /// Registry key for Hugging Face.
    pub const HUGGINGFACE: &'static str = "huggingface";
    /// Registry key for Hyperbolic.
    pub const HYPERBOLIC: &'static str = "hyperbolic";
    /// Registry key for Mira.
    pub const MIRA: &'static str = "mira";
    /// Registry key for Mistral.
    pub const MISTRAL: &'static str = "mistral";
    /// Registry key for Moonshot.
    pub const MOONSHOT: &'static str = "moonshot";
    /// Registry key for Ollama.
    pub const OLLAMA: &'static str = "ollama";
    /// Registry key for OpenAI.
    pub const OPENAI: &'static str = "openai";
    /// Registry key for OpenRouter.
    pub const OPENROUTER: &'static str = "openrouter";
    /// Registry key for Perplexity.
    pub const PERPLEXITY: &'static str = "perplexity";
    /// Registry key for Together.
    pub const TOGETHER: &'static str = "together";
    /// Registry key for xAI.
    pub const XAI: &'static str = "xai";
}