// rig/client/builder.rs

1#[allow(deprecated)]
2#[cfg(feature = "audio")]
3use super::audio_generation::AudioGenerationClientDyn;
4#[cfg(feature = "image")]
5#[allow(deprecated)]
6use super::image_generation::ImageGenerationClientDyn;
7#[allow(deprecated)]
8#[cfg(feature = "audio")]
9use crate::audio_generation::AudioGenerationModelDyn;
10#[cfg(feature = "image")]
11#[allow(deprecated)]
12use crate::image_generation::ImageGenerationModelDyn;
13#[allow(deprecated)]
14use crate::{
15    OneOrMany,
16    agent::AgentBuilder,
17    client::{
18        Capabilities, Capability, Client, FinalCompletionResponse, Provider, ProviderClient,
19        completion::{CompletionClientDyn, CompletionModelHandle},
20        embeddings::EmbeddingsClientDyn,
21        transcription::TranscriptionClientDyn,
22    },
23    completion::{CompletionError, CompletionModelDyn, CompletionRequest},
24    embeddings::EmbeddingModelDyn,
25    message::Message,
26    providers::{
27        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
28        mistral, moonshot, ollama, openai, openrouter, perplexity, together, xai,
29    },
30    streaming::StreamingCompletionResponse,
31    transcription::TranscriptionModelDyn,
32    wasm_compat::{WasmCompatSend, WasmCompatSync},
33};
34use std::{any::Any, collections::HashMap};
35
/// Errors returned by [`DynClientBuilder`] when resolving or using a provider.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// No factory is registered for the requested lookup key.
    #[error("Provider '{0}' not found")]
    NotFound(String),
    /// A client was built but does not offer the requested capability.
    /// `provider` is the lookup key; `role` names the missing capability.
    #[error("Provider '{provider}' cannot be coerced to a '{role}'")]
    NotCapable { provider: String, role: String },
    /// Wraps a [`CompletionError`] (convertible via `From`), e.g. from a
    /// failed streaming request.
    #[error("Error generating response\n{0}")]
    Completion(#[from] CompletionError),
}
45
/// A type-erased provider client.
///
/// The concrete `Client<Ext, H>` is boxed as `dyn Any`. `vtable` holds one
/// accessor per optional capability; each accessor tries to view the boxed
/// client as the matching `*ClientDyn` trait object.
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release."
)]
pub struct AnyClient {
    /// The boxed `Client<Ext, H>` handed to `AnyClient::new`.
    client: Box<dyn Any + 'static>,
    /// Capability accessors selected in `AnyClient::new`.
    vtable: AnyClientVTable,
}
54
/// Per-capability accessors for an `AnyClient`'s boxed client.
///
/// Each function receives the boxed client as `&dyn Any` and returns a
/// reference to the corresponding `*ClientDyn` trait-object reference, or
/// `None` when the capability is unavailable.
///
/// NOTE(review): the `Option<&&dyn _>` return type means the closures stored
/// here downcast to `&dyn _`, not to the concrete `Client<Ext, H>` that is
/// actually boxed — confirm these can ever return `Some`.
struct AnyClientVTable {
    /// Completion capability accessor.
    #[allow(deprecated)]
    as_completion: fn(&dyn Any) -> Option<&&dyn CompletionClientDyn>,
    /// Embeddings capability accessor.
    #[allow(deprecated)]
    as_embedding: fn(&dyn Any) -> Option<&&dyn EmbeddingsClientDyn>,
    /// Transcription capability accessor.
    #[allow(deprecated)]
    as_transcription: fn(&dyn Any) -> Option<&&dyn TranscriptionClientDyn>,
    /// Image-generation capability accessor (`image` feature only).
    #[allow(deprecated)]
    #[cfg(feature = "image")]
    as_image_generation: fn(&dyn Any) -> Option<&&dyn ImageGenerationClientDyn>,
    /// Audio-generation capability accessor (`audio` feature only).
    #[allow(deprecated)]
    #[cfg(feature = "audio")]
    as_audio_generation: fn(&dyn Any) -> Option<&&dyn AudioGenerationClientDyn>,
}
69
#[allow(deprecated)]
impl AnyClient {
    /// Type-erases `client`, choosing an accessor for each capability.
    ///
    /// For every capability the provider's `Capabilities` associated type is
    /// checked via `Capability::CAPABLE`. Unsupported capabilities get an
    /// accessor that always returns `None`.
    ///
    /// NOTE(review): each `|any| any.downcast_ref()` closure is inferred to
    /// downcast to `&dyn …ClientDyn` (from the vtable field's
    /// `Option<&&dyn …>` return type). However, the box holds a
    /// `Client<Ext, H>`, and `Any::downcast_ref` only succeeds on an exact
    /// type match. These accessors therefore look as though they can never
    /// return `Some`. Confirm with a test before relying on them.
    pub fn new<Ext, H>(client: Client<Ext, H>) -> Self
    where
        Ext: Provider + Capabilities + WasmCompatSend + WasmCompatSync + 'static,
        H: WasmCompatSend + WasmCompatSync + 'static,
        Client<Ext, H>: WasmCompatSend + WasmCompatSync + 'static,
    {
        Self {
            client: Box::new(client),
            vtable: AnyClientVTable {
                // Both branches are non-capturing closures, so each coerces to
                // the vtable field's `fn` pointer type.
                as_completion: if <<Ext as Capabilities>::Completion as Capability>::CAPABLE {
                    |any| any.downcast_ref()
                } else {
                    |_| None
                },

                as_embedding: if <<Ext as Capabilities>::Embeddings as Capability>::CAPABLE {
                    |any| any.downcast_ref()
                } else {
                    |_| None
                },

                as_transcription: if <<Ext as Capabilities>::Transcription as Capability>::CAPABLE {
                    |any| any.downcast_ref()
                } else {
                    |_| None
                },

                #[cfg(feature = "image")]
                as_image_generation:
                    if <<Ext as Capabilities>::ImageGeneration as Capability>::CAPABLE {
                        |any| any.downcast_ref()
                    } else {
                        |_| None
                    },

                #[cfg(feature = "audio")]
                as_audio_generation:
                    if <<Ext as Capabilities>::AudioGeneration as Capability>::CAPABLE {
                        |any| any.downcast_ref()
                    } else {
                        |_| None
                    },
            },
        }
    }

    /// Views the client as a [`CompletionClientDyn`], if the accessor yields one.
    pub fn as_completion(&self) -> Option<&dyn CompletionClientDyn> {
        // `copied` turns `Option<&&dyn _>` into `Option<&dyn _>`.
        (self.vtable.as_completion)(self.client.as_ref()).copied()
    }

    /// Views the client as an [`EmbeddingsClientDyn`], if the accessor yields one.
    pub fn as_embedding(&self) -> Option<&dyn EmbeddingsClientDyn> {
        (self.vtable.as_embedding)(self.client.as_ref()).copied()
    }

    /// Views the client as a [`TranscriptionClientDyn`], if the accessor yields one.
    pub fn as_transcription(&self) -> Option<&dyn TranscriptionClientDyn> {
        (self.vtable.as_transcription)(self.client.as_ref()).copied()
    }

    /// Views the client as an `ImageGenerationClientDyn`, if the accessor yields one.
    #[cfg(feature = "image")]
    pub fn as_image_generation(&self) -> Option<&dyn ImageGenerationClientDyn> {
        (self.vtable.as_image_generation)(self.client.as_ref()).copied()
    }

    /// Views the client as an `AudioGenerationClientDyn`, if the accessor yields one.
    #[cfg(feature = "audio")]
    pub fn as_audio_generation(&self) -> Option<&dyn AudioGenerationClientDyn> {
        (self.vtable.as_audio_generation)(self.client.as_ref()).copied()
    }
}
140
141#[deprecated(
142    since = "0.25.0",
143    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release."
144)]
145#[derive(Debug, Clone)]
146pub struct ProviderFactory {
147    /// Create a client from environment variables
148    #[allow(deprecated)]
149    from_env: fn() -> Result<AnyClient, Error>,
150}
151
/// Registry of [`ProviderFactory`]s keyed by a lookup string.
///
/// Keys come in two forms:
/// - `DynClientBuilder::new` inserts the built-in providers under their bare
///   names (e.g. `"openai"`).
/// - `DynClientBuilder::register` inserts entries keyed as
///   `"{provider}:{model}"`.
#[allow(deprecated)]
#[deprecated(
    since = "0.25.0",
    note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release."
)]
#[derive(Debug, Clone)]
pub struct DynClientBuilder(HashMap<String, ProviderFactory>);
159
160#[allow(deprecated)]
161impl Default for DynClientBuilder {
162    fn default() -> Self {
163        // Give it a capacity ~the number of providers we have from the start
164        Self(HashMap::with_capacity(32))
165    }
166}
167
/// Built-in providers that `DynClientBuilder::new` registers automatically.
///
/// Each variant converts to its lowercase registry name through
/// `From<DefaultProviders> for &'static str` and `Display`. The type is a
/// plain `Copy` tag, so it also supports equality and hashing (for example,
/// as a map key).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DefaultProviders {
    Anthropic,
    Cohere,
    Gemini,
    HuggingFace,
    OpenAI,
    OpenRouter,
    Together,
    XAI,
    Azure,
    DeepSeek,
    Galadriel,
    Groq,
    Hyperbolic,
    Moonshot,
    Mira,
    Mistral,
    Ollama,
    Perplexity,
}
190
191impl From<DefaultProviders> for &'static str {
192    fn from(value: DefaultProviders) -> Self {
193        use DefaultProviders::*;
194
195        match value {
196            Anthropic => "anthropic",
197            Cohere => "cohere",
198            Gemini => "gemini",
199            HuggingFace => "huggingface",
200            OpenAI => "openai",
201            OpenRouter => "openrouter",
202            Together => "together",
203            XAI => "xai",
204            Azure => "azure",
205            DeepSeek => "deepseek",
206            Galadriel => "galadriel",
207            Groq => "groq",
208            Hyperbolic => "hyperbolic",
209            Moonshot => "moonshot",
210            Mira => "mira",
211            Mistral => "mistral",
212            Ollama => "ollama",
213            Perplexity => "perplexity",
214        }
215    }
216}
217pub use DefaultProviders::*;
218
219impl std::fmt::Display for DefaultProviders {
220    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221        let s: &str = (*self).into();
222        f.write_str(s)
223    }
224}
225
226impl DefaultProviders {
227    fn all() -> impl Iterator<Item = Self> {
228        use DefaultProviders::*;
229
230        [
231            Anthropic,
232            Cohere,
233            Gemini,
234            HuggingFace,
235            OpenAI,
236            OpenRouter,
237            Together,
238            XAI,
239            Azure,
240            DeepSeek,
241            Galadriel,
242            Groq,
243            Hyperbolic,
244            Moonshot,
245            Mira,
246            Mistral,
247            Ollama,
248            Perplexity,
249        ]
250        .into_iter()
251    }
252
253    #[allow(deprecated)]
254    fn get_env_fn(self) -> fn() -> Result<AnyClient, Error> {
255        use DefaultProviders::*;
256
257        match self {
258            Anthropic => || Ok(AnyClient::new(anthropic::Client::from_env())),
259            Cohere => || Ok(AnyClient::new(cohere::Client::from_env())),
260            Gemini => || Ok(AnyClient::new(gemini::Client::from_env())),
261            HuggingFace => || Ok(AnyClient::new(huggingface::Client::from_env())),
262            OpenAI => || Ok(AnyClient::new(openai::Client::from_env())),
263            OpenRouter => || Ok(AnyClient::new(openrouter::Client::from_env())),
264            Together => || Ok(AnyClient::new(together::Client::from_env())),
265            XAI => || Ok(AnyClient::new(xai::Client::from_env())),
266            Azure => || Ok(AnyClient::new(azure::Client::from_env())),
267            DeepSeek => || Ok(AnyClient::new(deepseek::Client::from_env())),
268            Galadriel => || Ok(AnyClient::new(galadriel::Client::from_env())),
269            Groq => || Ok(AnyClient::new(groq::Client::from_env())),
270            Hyperbolic => || Ok(AnyClient::new(hyperbolic::Client::from_env())),
271            Moonshot => || Ok(AnyClient::new(moonshot::Client::from_env())),
272            Mira => || Ok(AnyClient::new(mira::Client::from_env())),
273            Mistral => || Ok(AnyClient::new(mistral::Client::from_env())),
274            Ollama => || Ok(AnyClient::new(ollama::Client::from_env())),
275            Perplexity => || Ok(AnyClient::new(perplexity::Client::from_env())),
276        }
277    }
278}
279
280#[allow(deprecated)]
281impl DynClientBuilder {
282    pub fn new() -> Self {
283        Self::default().register_all()
284    }
285
286    fn register_all(mut self) -> Self {
287        for provider in DefaultProviders::all() {
288            let from_env = provider.get_env_fn();
289            self.0
290                .insert(provider.to_string(), ProviderFactory { from_env });
291        }
292
293        self
294    }
295
296    fn to_key<Models>(provider_name: &'static str, model: &Models) -> String
297    where
298        Models: ToString,
299    {
300        format!("{provider_name}:{}", model.to_string())
301    }
302
303    pub fn register<Ext, H, Models>(mut self, provider_name: &'static str, model: Models) -> Self
304    where
305        Ext: Provider + Capabilities + WasmCompatSend + WasmCompatSync + 'static,
306        H: Default + WasmCompatSend + WasmCompatSync + 'static,
307        Client<Ext, H>: ProviderClient + WasmCompatSend + WasmCompatSync + 'static,
308        Models: ToString,
309    {
310        let key = Self::to_key(provider_name, &model);
311
312        let factory = ProviderFactory {
313            from_env: || Ok(AnyClient::new(Client::<Ext, H>::from_env())),
314        };
315
316        self.0.insert(key, factory);
317
318        self
319    }
320
321    pub fn from_env<T, Models>(
322        &self,
323        provider_name: &'static str,
324        model: Models,
325    ) -> Result<AnyClient, Error>
326    where
327        T: 'static,
328        Models: ToString,
329    {
330        let key = Self::to_key(provider_name, &model);
331
332        self.0
333            .get(&key)
334            .ok_or(Error::NotFound(key))
335            .and_then(|factory| (factory.from_env)())
336    }
337
338    pub fn factory<Models>(
339        &self,
340        provider_name: &'static str,
341        model: Models,
342    ) -> Option<&ProviderFactory>
343    where
344        Models: ToString,
345    {
346        let key = Self::to_key(provider_name, &model);
347
348        self.0.get(&key)
349    }
350
351    /// Get a boxed agent based on the provider and model, as well as an API key.
352    pub fn agent<Models>(
353        &self,
354        provider_name: impl Into<&'static str>,
355        model: Models,
356    ) -> Result<AgentBuilder<CompletionModelHandle<'_>>, Error>
357    where
358        Models: ToString,
359    {
360        let key = Self::to_key(provider_name.into(), &model);
361
362        let client = self
363            .0
364            .get(&key)
365            .ok_or_else(|| Error::NotFound(key.clone()))
366            .and_then(|factory| (factory.from_env)())?;
367
368        let completion = client.as_completion().ok_or(Error::NotCapable {
369            provider: key,
370            role: "Completion".into(),
371        })?;
372
373        Ok(completion.agent(&model.to_string()))
374    }
375
376    /// Get a boxed completion model based on the provider and model.
377    pub fn completion<Models>(
378        &self,
379        provider_name: &'static str,
380        model: Models,
381    ) -> Result<Box<dyn CompletionModelDyn>, Error>
382    where
383        Models: ToString,
384    {
385        let key = Self::to_key(provider_name, &model);
386
387        let client = self
388            .0
389            .get(&key)
390            .ok_or_else(|| Error::NotFound(key.clone()))
391            .and_then(|factory| (factory.from_env)())?;
392
393        let completion = client.as_completion().ok_or(Error::NotCapable {
394            provider: key,
395            role: "Embedding Model".into(),
396        })?;
397
398        Ok(completion.completion_model(&model.to_string()))
399    }
400
401    /// Get a boxed embedding model based on the provider and model.
402    pub fn embeddings<Models>(
403        &self,
404        provider_name: &'static str,
405        model: Models,
406    ) -> Result<Box<dyn EmbeddingModelDyn>, Error>
407    where
408        Models: ToString,
409    {
410        let key = Self::to_key(provider_name, &model);
411
412        let client = self
413            .0
414            .get(&key)
415            .ok_or_else(|| Error::NotFound(key.clone()))
416            .and_then(|factory| (factory.from_env)())?;
417
418        let embeddings = client.as_embedding().ok_or(Error::NotCapable {
419            provider: key,
420            role: "Embedding Model".into(),
421        })?;
422
423        Ok(embeddings.embedding_model(&model.to_string()))
424    }
425
426    /// Get a boxed transcription model based on the provider and model.
427    pub fn transcription<Models>(
428        &self,
429        provider_name: &'static str,
430        model: Models,
431    ) -> Result<Box<dyn TranscriptionModelDyn>, Error>
432    where
433        Models: ToString,
434    {
435        let key = Self::to_key(provider_name, &model);
436
437        let client = self
438            .0
439            .get(&key)
440            .ok_or_else(|| Error::NotFound(key.clone()))
441            .and_then(|factory| (factory.from_env)())?;
442
443        let transcription = client.as_transcription().ok_or(Error::NotCapable {
444            provider: key,
445            role: "transcription model".into(),
446        })?;
447
448        Ok(transcription.transcription_model(&model.to_string()))
449    }
450
451    #[cfg(feature = "image")]
452    pub fn image_generation<Models>(
453        &self,
454        provider_name: &'static str,
455        model: Models,
456    ) -> Result<Box<dyn ImageGenerationModelDyn>, Error>
457    where
458        Models: ToString,
459    {
460        let key = Self::to_key(provider_name, &model);
461
462        let client = self
463            .0
464            .get(&key)
465            .ok_or_else(|| Error::NotFound(key.clone()))
466            .and_then(|factory| (factory.from_env)())?;
467
468        let image_generation = client.as_image_generation().ok_or(Error::NotCapable {
469            provider: key,
470            role: "Image generation".into(),
471        })?;
472
473        Ok(image_generation.image_generation_model(&model.to_string()))
474    }
475
476    #[cfg(feature = "audio")]
477    pub fn audio_generation<Models>(
478        &self,
479        provider_name: &'static str,
480        model: Models,
481    ) -> Result<Box<dyn AudioGenerationModelDyn>, Error>
482    where
483        Models: ToString,
484    {
485        let key = Self::to_key(provider_name, &model);
486
487        let client = self
488            .0
489            .get(&key)
490            .ok_or_else(|| Error::NotFound(key.clone()))
491            .and_then(|factory| (factory.from_env)())?;
492
493        let audio_generation = client.as_audio_generation().ok_or(Error::NotCapable {
494            provider: key,
495            role: "Image generation".into(),
496        })?;
497
498        Ok(audio_generation.audio_generation_model(&model.to_string()))
499    }
500
501    /// Stream a completion request to the specified provider and model.
502    pub async fn stream_completion<Models>(
503        &self,
504        provider_name: &'static str,
505        model: Models,
506        request: CompletionRequest,
507    ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, Error>
508    where
509        Models: ToString,
510    {
511        let completion = self.completion(provider_name, model)?;
512
513        completion.stream(request).await.map_err(Error::Completion)
514    }
515
516    /// Stream a simple prompt to the specified provider and model.
517    pub async fn stream_prompt<Models, Prompt>(
518        &self,
519        provider_name: impl Into<&'static str>,
520        model: Models,
521        prompt: Prompt,
522    ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, Error>
523    where
524        Models: ToString,
525        Prompt: Into<Message> + WasmCompatSend,
526    {
527        let completion = self.completion(provider_name.into(), model)?;
528
529        let request = CompletionRequest {
530            preamble: None,
531            tools: vec![],
532            documents: vec![],
533            temperature: None,
534            max_tokens: None,
535            additional_params: None,
536            tool_choice: None,
537            chat_history: crate::OneOrMany::one(prompt.into()),
538        };
539
540        completion.stream(request).await.map_err(Error::Completion)
541    }
542
543    /// Stream a chat with history to the specified provider and model.
544    pub async fn stream_chat<Models, Prompt>(
545        &self,
546        provider_name: &'static str,
547        model: Models,
548        prompt: Prompt,
549        mut history: Vec<Message>,
550    ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, Error>
551    where
552        Models: ToString,
553        Prompt: Into<Message> + WasmCompatSend,
554    {
555        let completion = self.completion(provider_name, model)?;
556
557        history.push(prompt.into());
558        let request = CompletionRequest {
559            preamble: None,
560            tools: vec![],
561            documents: vec![],
562            temperature: None,
563            max_tokens: None,
564            additional_params: None,
565            tool_choice: None,
566            chat_history: OneOrMany::many(history)
567                .unwrap_or_else(|_| OneOrMany::one(Message::user(""))),
568        };
569
570        completion.stream(request).await.map_err(Error::Completion)
571    }
572}