rig/client/
builder.rs

1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::embeddings::embedding::EmbeddingModelDyn;
4use crate::providers::{
5    anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
6    moonshot, ollama, openai, openrouter, perplexity, together, xai,
7};
8use crate::transcription::TranscriptionModelDyn;
9use rig::completion::CompletionModelDyn;
10use std::collections::HashMap;
11use std::panic::{RefUnwindSafe, UnwindSafe};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum ClientBuildError {
16    #[error("factory error: {}", .0)]
17    FactoryError(String),
18    #[error("invalid id string: {}", .0)]
19    InvalidIdString(String),
20    #[error("unsupported feature: {} for {}", .1, .0)]
21    UnsupportedFeature(String, String),
22    #[error("unknown provider")]
23    UnknownProvider,
24}
25
26pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
27pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
28pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
29pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
30pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
31
32/// A dynamic client builder.
33/// Use this when you need to support creating any kind of client from a range of LLM providers (that Rig supports).
34/// Usage:
35/// ```rust
36/// use rig::{
37///     client::builder::DynClientBuilder, completion::Prompt, providers::anthropic::CLAUDE_3_7_SONNET,
38/// };
39/// #[tokio::main]
40/// async fn main() {
41///     let multi_client = DynClientBuilder::new();
42///     // set up OpenAI client
43///     let completion_openai = multi_client.agent("openai", "gpt-4o").unwrap();
44///     let agent_openai = completion_openai
45///         .preamble("You are a helpful assistant")
46///         .build();
47///     // set up Anthropic client
48///     let completion_anthropic = multi_client.agent("anthropic", CLAUDE_3_7_SONNET).unwrap();
49///     let agent_anthropic = completion_anthropic
50///         .preamble("You are a helpful assistant")
51///         .max_tokens(1024)
52///         .build();
53///     println!("Sending prompt: 'Hello world!'");
54///     let res_openai = agent_openai.prompt("Hello world!").await.unwrap();
55///     println!("Response from OpenAI (using gpt-4o): {res_openai}");
56///     let res_anthropic = agent_anthropic.prompt("Hello world!").await.unwrap();
57///     println!("Response from Anthropic (using Claude 3.7 Sonnet): {res_anthropic}");
58/// }
59/// ```
60pub struct DynClientBuilder {
61    registry: HashMap<String, ClientFactory>,
62}
63
64impl Default for DynClientBuilder {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl<'a> DynClientBuilder {
71    /// Generate a new instance of `DynClientBuilder`.
72    /// By default, every single possible client that can be registered
73    /// will be registered to the client builder.
74    pub fn new() -> Self {
75        Self {
76            registry: HashMap::new(),
77        }
78        .register_all(vec![
79            ClientFactory::new(
80                DefaultProviders::ANTHROPIC,
81                anthropic::Client::from_env_boxed,
82            ),
83            ClientFactory::new(DefaultProviders::COHERE, cohere::Client::from_env_boxed),
84            ClientFactory::new(DefaultProviders::GEMINI, gemini::Client::from_env_boxed),
85            ClientFactory::new(
86                DefaultProviders::HUGGINGFACE,
87                huggingface::Client::from_env_boxed,
88            ),
89            ClientFactory::new(DefaultProviders::OPENAI, openai::Client::from_env_boxed),
90            ClientFactory::new(
91                DefaultProviders::OPENROUTER,
92                openrouter::Client::from_env_boxed,
93            ),
94            ClientFactory::new(DefaultProviders::TOGETHER, together::Client::from_env_boxed),
95            ClientFactory::new(DefaultProviders::XAI, xai::Client::from_env_boxed),
96            ClientFactory::new(DefaultProviders::AZURE, azure::Client::from_env_boxed),
97            ClientFactory::new(DefaultProviders::DEEPSEEK, deepseek::Client::from_env_boxed),
98            ClientFactory::new(
99                DefaultProviders::GALADRIEL,
100                galadriel::Client::from_env_boxed,
101            ),
102            ClientFactory::new(DefaultProviders::GROQ, groq::Client::from_env_boxed),
103            ClientFactory::new(
104                DefaultProviders::HYPERBOLIC,
105                hyperbolic::Client::from_env_boxed,
106            ),
107            ClientFactory::new(DefaultProviders::MOONSHOT, moonshot::Client::from_env_boxed),
108            ClientFactory::new(DefaultProviders::MIRA, mira::Client::from_env_boxed),
109            ClientFactory::new(DefaultProviders::MISTRAL, mistral::Client::from_env_boxed),
110            ClientFactory::new(DefaultProviders::OLLAMA, ollama::Client::from_env_boxed),
111            ClientFactory::new(
112                DefaultProviders::PERPLEXITY,
113                perplexity::Client::from_env_boxed,
114            ),
115        ])
116    }
117
118    /// Generate a new instance of `DynClientBuilder` with no client factories registered.
119    pub fn empty() -> Self {
120        Self {
121            registry: HashMap::new(),
122        }
123    }
124
125    /// Register a new ClientFactory
126    pub fn register(mut self, client_factory: ClientFactory) -> Self {
127        self.registry
128            .insert(client_factory.name.clone(), client_factory);
129        self
130    }
131
132    /// Register multiple ClientFactories
133    pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
134        for factory in factories {
135            self.registry.insert(factory.name.clone(), factory);
136        }
137
138        self
139    }
140
141    /// Returns a (boxed) specific provider based on the given provider.
142    pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
143        let factory = self.get_factory(provider)?;
144        factory.build()
145    }
146
147    /// Parses a provider:model string to the provider and the model separately.
148    /// For example, `openai:gpt-4o` will return ("openai", "gpt-4o").
149    pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
150        let (provider, model) = id
151            .split_once(":")
152            .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
153
154        Ok((provider, model))
155    }
156
157    /// Returns a specific client factory (that exists in the registry).
158    fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
159        self.registry
160            .get(provider)
161            .ok_or(ClientBuildError::UnknownProvider)
162    }
163
164    /// Get a boxed completion model based on the provider and model.
165    pub fn completion(
166        &self,
167        provider: &str,
168        model: &str,
169    ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
170        let client = self.build(provider)?;
171
172        let completion = client
173            .as_completion()
174            .ok_or(ClientBuildError::UnsupportedFeature(
175                provider.to_string(),
176                "completion".to_owned(),
177            ))?;
178
179        Ok(completion.completion_model(model))
180    }
181
182    /// Get a boxed agent based on the provider and model..
183    pub fn agent(
184        &self,
185        provider: &str,
186        model: &str,
187    ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
188        let client = self.build(provider)?;
189
190        let client = client
191            .as_completion()
192            .ok_or(ClientBuildError::UnsupportedFeature(
193                provider.to_string(),
194                "completion".to_string(),
195            ))?;
196
197        Ok(client.agent(model))
198    }
199
200    /// Get a boxed embedding model based on the provider and model.
201    pub fn embeddings(
202        &self,
203        provider: &str,
204        model: &str,
205    ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
206        let client = self.build(provider)?;
207
208        let embeddings = client
209            .as_embeddings()
210            .ok_or(ClientBuildError::UnsupportedFeature(
211                provider.to_string(),
212                "embeddings".to_owned(),
213            ))?;
214
215        Ok(embeddings.embedding_model(model))
216    }
217
218    /// Get a boxed transcription model based on the provider and model.
219    pub fn transcription(
220        &self,
221        provider: &str,
222        model: &str,
223    ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
224        let client = self.build(provider)?;
225        let transcription =
226            client
227                .as_transcription()
228                .ok_or(ClientBuildError::UnsupportedFeature(
229                    provider.to_string(),
230                    "transcription".to_owned(),
231                ))?;
232
233        Ok(transcription.transcription_model(model))
234    }
235
236    /// Get the ID of a provider model based on a `provider:model` ID.
237    pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
238        let (provider, model) = self.parse(id)?;
239
240        Ok(ProviderModelId {
241            builder: self,
242            provider,
243            model,
244        })
245    }
246}
247
248pub struct ProviderModelId<'builder, 'id> {
249    builder: &'builder DynClientBuilder,
250    provider: &'id str,
251    model: &'id str,
252}
253
254impl<'builder> ProviderModelId<'builder, '_> {
255    pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
256        self.builder.completion(self.provider, self.model)
257    }
258
259    pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
260        self.builder.agent(self.provider, self.model)
261    }
262
263    pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
264        self.builder.embeddings(self.provider, self.model)
265    }
266
267    pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
268        self.builder.transcription(self.provider, self.model)
269    }
270}
271
#[cfg(feature = "image")]
mod image {
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// Boxed, type-erased image generation model.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Get a boxed image generation model based on the provider and model.
        ///
        /// # Errors
        /// Propagates errors from [`DynClientBuilder::build`], and returns
        /// [`ClientBuildError::UnsupportedFeature`] if the client does not support
        /// image generation.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // `ok_or_else` only allocates the error strings on the failure path.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "image_generation".to_string(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Builds the boxed image generation model named by this id.
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
307#[cfg(feature = "image")]
308pub use image::*;
309
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};

    /// Boxed, type-erased audio generation model.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Get a boxed audio generation model based on the provider and model.
        ///
        /// # Errors
        /// Propagates errors from [`DynClientBuilder::build`], and returns
        /// [`ClientBuildError::UnsupportedFeature`] if the client does not support
        /// audio generation.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // `ok_or_else` only allocates the error strings on the failure path.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_string(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Builds the boxed audio generation model named by this id.
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
345use crate::agent::AgentBuilder;
346use crate::client::completion::CompletionModelHandle;
347#[cfg(feature = "audio")]
348pub use audio::*;
349use rig::providers::mistral;
350
351pub struct ClientFactory {
352    pub name: String,
353    pub factory: Box<dyn Fn() -> Box<dyn ProviderClient>>,
354}
355
356impl UnwindSafe for ClientFactory {}
357impl RefUnwindSafe for ClientFactory {}
358
359impl ClientFactory {
360    pub fn new<F: 'static + Fn() -> Box<dyn ProviderClient>>(name: &str, func: F) -> Self {
361        Self {
362            name: name.to_string(),
363            factory: Box::new(func),
364        }
365    }
366
367    pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
368        std::panic::catch_unwind(|| (self.factory)())
369            .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
370    }
371}
372
/// Names under which [`DynClientBuilder::new`] registers its built-in provider factories.
///
/// Pass one of these as the `provider` argument of the builder methods, or use
/// it as the part before `:` in a `provider:model` id.
pub struct DefaultProviders;

impl DefaultProviders {
    pub const ANTHROPIC: &'static str = "anthropic";
    pub const AZURE: &'static str = "azure";
    pub const COHERE: &'static str = "cohere";
    pub const DEEPSEEK: &'static str = "deepseek";
    pub const GALADRIEL: &'static str = "galadriel";
    pub const GEMINI: &'static str = "gemini";
    pub const GROQ: &'static str = "groq";
    pub const HUGGINGFACE: &'static str = "huggingface";
    pub const HYPERBOLIC: &'static str = "hyperbolic";
    pub const MIRA: &'static str = "mira";
    pub const MISTRAL: &'static str = "mistral";
    pub const MOONSHOT: &'static str = "moonshot";
    pub const OLLAMA: &'static str = "ollama";
    pub const OPENAI: &'static str = "openai";
    pub const OPENROUTER: &'static str = "openrouter";
    pub const PERPLEXITY: &'static str = "perplexity";
    pub const TOGETHER: &'static str = "together";
    pub const XAI: &'static str = "xai";
}