1use crate::agent::Agent;
2use crate::client::ProviderClient;
3use crate::embeddings::embedding::EmbeddingModelDyn;
4use crate::providers::{
5 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
6 moonshot, ollama, openai, openrouter, perplexity, together, xai,
7};
8use crate::transcription::TranscriptionModelDyn;
9use rig::completion::CompletionModelDyn;
10use std::collections::HashMap;
11use std::panic::{RefUnwindSafe, UnwindSafe};
12use thiserror::Error;
13
14#[derive(Debug, Error)]
15pub enum ClientBuildError {
16 #[error("factory error: {}", .0)]
17 FactoryError(String),
18 #[error("invalid id string: {}", .0)]
19 InvalidIdString(String),
20 #[error("unsupported feature: {} for {}", .1, .0)]
21 UnsupportedFeature(String, String),
22 #[error("unknown provider")]
23 UnknownProvider,
24}
25
26pub type BoxCompletionModel<'a> = Box<dyn CompletionModelDyn + 'a>;
27pub type BoxAgentBuilder<'a> = AgentBuilder<CompletionModelHandle<'a>>;
28pub type BoxAgent<'a> = Agent<CompletionModelHandle<'a>>;
29pub type BoxEmbeddingModel<'a> = Box<dyn EmbeddingModelDyn + 'a>;
30pub type BoxTranscriptionModel<'a> = Box<dyn TranscriptionModelDyn + 'a>;
31
32pub struct DynClientBuilder {
61 registry: HashMap<String, ClientFactory>,
62}
63
64impl Default for DynClientBuilder {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl<'a> DynClientBuilder {
71 pub fn new() -> Self {
75 Self {
76 registry: HashMap::new(),
77 }
78 .register_all(vec![
79 ClientFactory::new(
80 DefaultProviders::ANTHROPIC,
81 anthropic::Client::from_env_boxed,
82 ),
83 ClientFactory::new(DefaultProviders::COHERE, cohere::Client::from_env_boxed),
84 ClientFactory::new(DefaultProviders::GEMINI, gemini::Client::from_env_boxed),
85 ClientFactory::new(
86 DefaultProviders::HUGGINGFACE,
87 huggingface::Client::from_env_boxed,
88 ),
89 ClientFactory::new(DefaultProviders::OPENAI, openai::Client::from_env_boxed),
90 ClientFactory::new(
91 DefaultProviders::OPENROUTER,
92 openrouter::Client::from_env_boxed,
93 ),
94 ClientFactory::new(DefaultProviders::TOGETHER, together::Client::from_env_boxed),
95 ClientFactory::new(DefaultProviders::XAI, xai::Client::from_env_boxed),
96 ClientFactory::new(DefaultProviders::AZURE, azure::Client::from_env_boxed),
97 ClientFactory::new(DefaultProviders::DEEPSEEK, deepseek::Client::from_env_boxed),
98 ClientFactory::new(
99 DefaultProviders::GALADRIEL,
100 galadriel::Client::from_env_boxed,
101 ),
102 ClientFactory::new(DefaultProviders::GROQ, groq::Client::from_env_boxed),
103 ClientFactory::new(
104 DefaultProviders::HYPERBOLIC,
105 hyperbolic::Client::from_env_boxed,
106 ),
107 ClientFactory::new(DefaultProviders::MOONSHOT, moonshot::Client::from_env_boxed),
108 ClientFactory::new(DefaultProviders::MIRA, mira::Client::from_env_boxed),
109 ClientFactory::new(DefaultProviders::MISTRAL, mistral::Client::from_env_boxed),
110 ClientFactory::new(DefaultProviders::OLLAMA, ollama::Client::from_env_boxed),
111 ClientFactory::new(
112 DefaultProviders::PERPLEXITY,
113 perplexity::Client::from_env_boxed,
114 ),
115 ])
116 }
117
118 pub fn empty() -> Self {
120 Self {
121 registry: HashMap::new(),
122 }
123 }
124
125 pub fn register(mut self, client_factory: ClientFactory) -> Self {
127 self.registry
128 .insert(client_factory.name.clone(), client_factory);
129 self
130 }
131
132 pub fn register_all(mut self, factories: impl IntoIterator<Item = ClientFactory>) -> Self {
134 for factory in factories {
135 self.registry.insert(factory.name.clone(), factory);
136 }
137
138 self
139 }
140
141 pub fn build(&self, provider: &str) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
143 let factory = self.get_factory(provider)?;
144 factory.build()
145 }
146
147 pub fn parse(&self, id: &'a str) -> Result<(&'a str, &'a str), ClientBuildError> {
150 let (provider, model) = id
151 .split_once(":")
152 .ok_or(ClientBuildError::InvalidIdString(id.to_string()))?;
153
154 Ok((provider, model))
155 }
156
157 fn get_factory(&self, provider: &str) -> Result<&ClientFactory, ClientBuildError> {
159 self.registry
160 .get(provider)
161 .ok_or(ClientBuildError::UnknownProvider)
162 }
163
164 pub fn completion(
166 &self,
167 provider: &str,
168 model: &str,
169 ) -> Result<BoxCompletionModel<'a>, ClientBuildError> {
170 let client = self.build(provider)?;
171
172 let completion = client
173 .as_completion()
174 .ok_or(ClientBuildError::UnsupportedFeature(
175 provider.to_string(),
176 "completion".to_owned(),
177 ))?;
178
179 Ok(completion.completion_model(model))
180 }
181
182 pub fn agent(
184 &self,
185 provider: &str,
186 model: &str,
187 ) -> Result<BoxAgentBuilder<'a>, ClientBuildError> {
188 let client = self.build(provider)?;
189
190 let client = client
191 .as_completion()
192 .ok_or(ClientBuildError::UnsupportedFeature(
193 provider.to_string(),
194 "completion".to_string(),
195 ))?;
196
197 Ok(client.agent(model))
198 }
199
200 pub fn embeddings(
202 &self,
203 provider: &str,
204 model: &str,
205 ) -> Result<Box<dyn EmbeddingModelDyn + 'a>, ClientBuildError> {
206 let client = self.build(provider)?;
207
208 let embeddings = client
209 .as_embeddings()
210 .ok_or(ClientBuildError::UnsupportedFeature(
211 provider.to_string(),
212 "embeddings".to_owned(),
213 ))?;
214
215 Ok(embeddings.embedding_model(model))
216 }
217
218 pub fn transcription(
220 &self,
221 provider: &str,
222 model: &str,
223 ) -> Result<Box<dyn TranscriptionModelDyn + 'a>, ClientBuildError> {
224 let client = self.build(provider)?;
225 let transcription =
226 client
227 .as_transcription()
228 .ok_or(ClientBuildError::UnsupportedFeature(
229 provider.to_string(),
230 "transcription".to_owned(),
231 ))?;
232
233 Ok(transcription.transcription_model(model))
234 }
235
236 pub fn id<'id>(&'a self, id: &'id str) -> Result<ProviderModelId<'a, 'id>, ClientBuildError> {
238 let (provider, model) = self.parse(id)?;
239
240 Ok(ProviderModelId {
241 builder: self,
242 provider,
243 model,
244 })
245 }
246}
247
248pub struct ProviderModelId<'builder, 'id> {
249 builder: &'builder DynClientBuilder,
250 provider: &'id str,
251 model: &'id str,
252}
253
254impl<'builder> ProviderModelId<'builder, '_> {
255 pub fn completion(self) -> Result<BoxCompletionModel<'builder>, ClientBuildError> {
256 self.builder.completion(self.provider, self.model)
257 }
258
259 pub fn agent(self) -> Result<BoxAgentBuilder<'builder>, ClientBuildError> {
260 self.builder.agent(self.provider, self.model)
261 }
262
263 pub fn embedding(self) -> Result<BoxEmbeddingModel<'builder>, ClientBuildError> {
264 self.builder.embeddings(self.provider, self.model)
265 }
266
267 pub fn transcription(self) -> Result<BoxTranscriptionModel<'builder>, ClientBuildError> {
268 self.builder.transcription(self.provider, self.model)
269 }
270}
271
#[cfg(feature = "image")]
mod image {
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};
    use crate::image_generation::ImageGenerationModelDyn;

    /// Boxed, dynamically dispatched image generation model.
    pub type BoxImageGenerationModel<'a> = Box<dyn ImageGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Builds `provider`'s client and returns its image generation model named `model`.
        ///
        /// # Errors
        ///
        /// Any error from `DynClientBuilder::build`, or
        /// [`ClientBuildError::UnsupportedFeature`] (`"image_generation"`) if the client has
        /// no image generation capability.
        pub fn image_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxImageGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // `ok_or_else` so the error strings are only allocated on the failure path.
            let image = client.as_image_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_owned(),
                    "image_generation".to_owned(),
                )
            })?;

            Ok(image.image_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Same as `DynClientBuilder::image_generation(provider, model)`.
        pub fn image_generation(
            self,
        ) -> Result<BoxImageGenerationModel<'builder>, ClientBuildError> {
            self.builder.image_generation(self.provider, self.model)
        }
    }
}
307#[cfg(feature = "image")]
308pub use image::*;
309
#[cfg(feature = "audio")]
mod audio {
    use crate::audio_generation::AudioGenerationModelDyn;
    use crate::client::builder::{ClientBuildError, DynClientBuilder, ProviderModelId};

    /// Boxed, dynamically dispatched audio generation model.
    pub type BoxAudioGenerationModel<'a> = Box<dyn AudioGenerationModelDyn + 'a>;

    impl DynClientBuilder {
        /// Builds `provider`'s client and returns its audio generation model named `model`.
        ///
        /// # Errors
        ///
        /// Any error from `DynClientBuilder::build`, or
        /// [`ClientBuildError::UnsupportedFeature`] (`"audio_generation"`) if the client has
        /// no audio generation capability.
        pub fn audio_generation<'a>(
            &self,
            provider: &str,
            model: &str,
        ) -> Result<BoxAudioGenerationModel<'a>, ClientBuildError> {
            let client = self.build(provider)?;

            // `ok_or_else` so the error strings are only allocated on the failure path.
            let audio = client.as_audio_generation().ok_or_else(|| {
                ClientBuildError::UnsupportedFeature(
                    provider.to_owned(),
                    "audio_generation".to_owned(),
                )
            })?;

            Ok(audio.audio_generation_model(model))
        }
    }

    impl<'builder> ProviderModelId<'builder, '_> {
        /// Same as `DynClientBuilder::audio_generation(provider, model)`.
        pub fn audio_generation(
            self,
        ) -> Result<BoxAudioGenerationModel<'builder>, ClientBuildError> {
            self.builder.audio_generation(self.provider, self.model)
        }
    }
}
345use crate::agent::AgentBuilder;
346use crate::client::completion::CompletionModelHandle;
347#[cfg(feature = "audio")]
348pub use audio::*;
349use rig::providers::mistral;
350
351pub struct ClientFactory {
352 pub name: String,
353 pub factory: Box<dyn Fn() -> Box<dyn ProviderClient>>,
354}
355
356impl UnwindSafe for ClientFactory {}
357impl RefUnwindSafe for ClientFactory {}
358
359impl ClientFactory {
360 pub fn new<F: 'static + Fn() -> Box<dyn ProviderClient>>(name: &str, func: F) -> Self {
361 Self {
362 name: name.to_string(),
363 factory: Box::new(func),
364 }
365 }
366
367 pub fn build(&self) -> Result<Box<dyn ProviderClient>, ClientBuildError> {
368 std::panic::catch_unwind(|| (self.factory)())
369 .map_err(|e| ClientBuildError::FactoryError(format!("{e:?}")))
370 }
371}
372
/// Registry names of the providers registered by `DynClientBuilder::new()`.
///
/// These are the `provider` part of a `provider:model` id (e.g. `"openai:gpt-4o"`).
pub struct DefaultProviders;

impl DefaultProviders {
    // Associated consts are implicitly `'static`, so plain `&str` is the same type
    // (clippy: `redundant_static_lifetimes`).
    pub const ANTHROPIC: &str = "anthropic";
    pub const COHERE: &str = "cohere";
    pub const GEMINI: &str = "gemini";
    pub const HUGGINGFACE: &str = "huggingface";
    pub const OPENAI: &str = "openai";
    pub const OPENROUTER: &str = "openrouter";
    pub const TOGETHER: &str = "together";
    pub const XAI: &str = "xai";
    pub const AZURE: &str = "azure";
    pub const DEEPSEEK: &str = "deepseek";
    pub const GALADRIEL: &str = "galadriel";
    pub const GROQ: &str = "groq";
    pub const HYPERBOLIC: &str = "hyperbolic";
    pub const MOONSHOT: &str = "moonshot";
    pub const MIRA: &str = "mira";
    pub const MISTRAL: &str = "mistral";
    pub const OLLAMA: &str = "ollama";
    pub const PERPLEXITY: &str = "perplexity";
}