1#[allow(deprecated)]
2#[cfg(feature = "audio")]
3use super::audio_generation::AudioGenerationClientDyn;
4#[cfg(feature = "image")]
5#[allow(deprecated)]
6use super::image_generation::ImageGenerationClientDyn;
7#[allow(deprecated)]
8#[cfg(feature = "audio")]
9use crate::audio_generation::AudioGenerationModelDyn;
10#[cfg(feature = "image")]
11#[allow(deprecated)]
12use crate::image_generation::ImageGenerationModelDyn;
13#[allow(deprecated)]
14use crate::{
15 OneOrMany,
16 agent::AgentBuilder,
17 client::{
18 Capabilities, Capability, Client, FinalCompletionResponse, Provider, ProviderClient,
19 completion::{CompletionClientDyn, CompletionModelHandle},
20 embeddings::EmbeddingsClientDyn,
21 transcription::TranscriptionClientDyn,
22 },
23 completion::{CompletionError, CompletionModelDyn, CompletionRequest},
24 embeddings::EmbeddingModelDyn,
25 message::Message,
26 providers::{
27 anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic, mira,
28 mistral, moonshot, ollama, openai, openrouter, perplexity, together, xai,
29 },
30 streaming::StreamingCompletionResponse,
31 transcription::TranscriptionModelDyn,
32 wasm_compat::{WasmCompatSend, WasmCompatSync},
33};
34use std::{any::Any, collections::HashMap};
35
36#[derive(Debug, thiserror::Error)]
37pub enum Error {
38 #[error("Provider '{0}' not found")]
39 NotFound(String),
40 #[error("Provider '{provider}' cannot be coerced to a '{role}'")]
41 NotCapable { provider: String, role: String },
42 #[error("Error generating response\n{0}")]
43 Completion(#[from] CompletionError),
44}
45
46#[deprecated(
47 since = "0.25.0",
48 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release."
49)]
50pub struct AnyClient {
51 client: Box<dyn Any + 'static>,
52 vtable: AnyClientVTable,
53}
54
55struct AnyClientVTable {
56 #[allow(deprecated)]
57 as_completion: fn(&dyn Any) -> Option<&&dyn CompletionClientDyn>,
58 #[allow(deprecated)]
59 as_embedding: fn(&dyn Any) -> Option<&&dyn EmbeddingsClientDyn>,
60 #[allow(deprecated)]
61 as_transcription: fn(&dyn Any) -> Option<&&dyn TranscriptionClientDyn>,
62 #[allow(deprecated)]
63 #[cfg(feature = "image")]
64 as_image_generation: fn(&dyn Any) -> Option<&&dyn ImageGenerationClientDyn>,
65 #[allow(deprecated)]
66 #[cfg(feature = "audio")]
67 as_audio_generation: fn(&dyn Any) -> Option<&&dyn AudioGenerationClientDyn>,
68}
69
#[allow(deprecated)]
impl AnyClient {
    /// Boxes `client` and builds the accessor table used by the `as_*` methods.
    ///
    /// For each capability, the provider's `Capabilities` associated type is checked through
    /// its `Capability::CAPABLE` constant. If it is `false`, the accessor is `|_| None`.
    /// Otherwise the accessor calls `downcast_ref` on the boxed value.
    ///
    /// NOTE(review): the `downcast_ref` target type is inferred from the vtable field's return
    /// type. That makes it `&'static dyn CompletionClientDyn` (and likewise for the other
    /// traits), but the boxed value is a `Client<Ext, H>`. `Any::downcast_ref` only succeeds
    /// on an exact type match, so these accessors appear to return `None` even for capable
    /// providers. Confirm this, and consider downcasting to `Client<Ext, H>` and coercing to
    /// the trait object where that impl exists.
    pub fn new<Ext, H>(client: Client<Ext, H>) -> Self
    where
        Ext: Provider + Capabilities + WasmCompatSend + WasmCompatSync + 'static,
        H: WasmCompatSend + WasmCompatSync + 'static,
        Client<Ext, H>: WasmCompatSend + WasmCompatSync + 'static,
    {
        Self {
            client: Box::new(client),
            vtable: AnyClientVTable {
                // Selected once per `Ext`; non-capturing closures coerce to `fn` pointers.
                as_completion: if <<Ext as Capabilities>::Completion as Capability>::CAPABLE {
                    |any| any.downcast_ref()
                } else {
                    |_| None
                },

                as_embedding: if <<Ext as Capabilities>::Embeddings as Capability>::CAPABLE {
                    |any| any.downcast_ref()
                } else {
                    |_| None
                },

                as_transcription: if <<Ext as Capabilities>::Transcription as Capability>::CAPABLE {
                    |any| any.downcast_ref()
                } else {
                    |_| None
                },

                #[cfg(feature = "image")]
                as_image_generation:
                    if <<Ext as Capabilities>::ImageGeneration as Capability>::CAPABLE {
                        |any| any.downcast_ref()
                    } else {
                        |_| None
                    },

                #[cfg(feature = "audio")]
                as_audio_generation:
                    if <<Ext as Capabilities>::AudioGeneration as Capability>::CAPABLE {
                        |any| any.downcast_ref()
                    } else {
                        |_| None
                    },
            },
        }
    }

    /// Returns the boxed client viewed as a [`CompletionClientDyn`], or `None` when the
    /// accessor selected in [`AnyClient::new`] produces nothing.
    pub fn as_completion(&self) -> Option<&dyn CompletionClientDyn> {
        (self.vtable.as_completion)(self.client.as_ref()).copied()
    }

    /// Returns the boxed client viewed as an [`EmbeddingsClientDyn`], or `None` when the
    /// accessor selected in [`AnyClient::new`] produces nothing.
    pub fn as_embedding(&self) -> Option<&dyn EmbeddingsClientDyn> {
        (self.vtable.as_embedding)(self.client.as_ref()).copied()
    }

    /// Returns the boxed client viewed as a [`TranscriptionClientDyn`], or `None` when the
    /// accessor selected in [`AnyClient::new`] produces nothing.
    pub fn as_transcription(&self) -> Option<&dyn TranscriptionClientDyn> {
        (self.vtable.as_transcription)(self.client.as_ref()).copied()
    }

    /// Returns the boxed client viewed as an `ImageGenerationClientDyn`, or `None` when the
    /// accessor selected in [`AnyClient::new`] produces nothing (`image` feature only).
    #[cfg(feature = "image")]
    pub fn as_image_generation(&self) -> Option<&dyn ImageGenerationClientDyn> {
        (self.vtable.as_image_generation)(self.client.as_ref()).copied()
    }

    /// Returns the boxed client viewed as an `AudioGenerationClientDyn`, or `None` when the
    /// accessor selected in [`AnyClient::new`] produces nothing (`audio` feature only).
    #[cfg(feature = "audio")]
    pub fn as_audio_generation(&self) -> Option<&dyn AudioGenerationClientDyn> {
        (self.vtable.as_audio_generation)(self.client.as_ref()).copied()
    }
}
140
141#[deprecated(
142 since = "0.25.0",
143 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release."
144)]
145#[derive(Debug, Clone)]
146pub struct ProviderFactory {
147 #[allow(deprecated)]
149 from_env: fn() -> Result<AnyClient, Error>,
150}
151
152#[allow(deprecated)]
153#[deprecated(
154 since = "0.25.0",
155 note = "`DynClientBuilder` and related features have been deprecated and will be removed in a future release."
156)]
157#[derive(Debug, Clone)]
158pub struct DynClientBuilder(HashMap<String, ProviderFactory>);
159
160#[allow(deprecated)]
161impl Default for DynClientBuilder {
162 fn default() -> Self {
163 Self(HashMap::with_capacity(32))
165 }
166}
167
/// Providers that `DynClientBuilder::new` registers automatically.
///
/// Each variant converts to its lowercase registry name (e.g. `OpenAI` -> `"openai"`).
/// Equality and hashing are derived so variants can be compared and used as map/set keys.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DefaultProviders {
    Anthropic,
    Cohere,
    Gemini,
    HuggingFace,
    OpenAI,
    OpenRouter,
    Together,
    XAI,
    Azure,
    DeepSeek,
    Galadriel,
    Groq,
    Hyperbolic,
    Moonshot,
    Mira,
    Mistral,
    Ollama,
    Perplexity,
}
190
191impl From<DefaultProviders> for &'static str {
192 fn from(value: DefaultProviders) -> Self {
193 use DefaultProviders::*;
194
195 match value {
196 Anthropic => "anthropic",
197 Cohere => "cohere",
198 Gemini => "gemini",
199 HuggingFace => "huggingface",
200 OpenAI => "openai",
201 OpenRouter => "openrouter",
202 Together => "together",
203 XAI => "xai",
204 Azure => "azure",
205 DeepSeek => "deepseek",
206 Galadriel => "galadriel",
207 Groq => "groq",
208 Hyperbolic => "hyperbolic",
209 Moonshot => "moonshot",
210 Mira => "mira",
211 Mistral => "mistral",
212 Ollama => "ollama",
213 Perplexity => "perplexity",
214 }
215 }
216}
217pub use DefaultProviders::*;
218
219impl std::fmt::Display for DefaultProviders {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 let s: &str = (*self).into();
222 f.write_str(s)
223 }
224}
225
226impl DefaultProviders {
227 fn all() -> impl Iterator<Item = Self> {
228 use DefaultProviders::*;
229
230 [
231 Anthropic,
232 Cohere,
233 Gemini,
234 HuggingFace,
235 OpenAI,
236 OpenRouter,
237 Together,
238 XAI,
239 Azure,
240 DeepSeek,
241 Galadriel,
242 Groq,
243 Hyperbolic,
244 Moonshot,
245 Mira,
246 Mistral,
247 Ollama,
248 Perplexity,
249 ]
250 .into_iter()
251 }
252
253 #[allow(deprecated)]
254 fn get_env_fn(self) -> fn() -> Result<AnyClient, Error> {
255 use DefaultProviders::*;
256
257 match self {
258 Anthropic => || Ok(AnyClient::new(anthropic::Client::from_env())),
259 Cohere => || Ok(AnyClient::new(cohere::Client::from_env())),
260 Gemini => || Ok(AnyClient::new(gemini::Client::from_env())),
261 HuggingFace => || Ok(AnyClient::new(huggingface::Client::from_env())),
262 OpenAI => || Ok(AnyClient::new(openai::Client::from_env())),
263 OpenRouter => || Ok(AnyClient::new(openrouter::Client::from_env())),
264 Together => || Ok(AnyClient::new(together::Client::from_env())),
265 XAI => || Ok(AnyClient::new(xai::Client::from_env())),
266 Azure => || Ok(AnyClient::new(azure::Client::from_env())),
267 DeepSeek => || Ok(AnyClient::new(deepseek::Client::from_env())),
268 Galadriel => || Ok(AnyClient::new(galadriel::Client::from_env())),
269 Groq => || Ok(AnyClient::new(groq::Client::from_env())),
270 Hyperbolic => || Ok(AnyClient::new(hyperbolic::Client::from_env())),
271 Moonshot => || Ok(AnyClient::new(moonshot::Client::from_env())),
272 Mira => || Ok(AnyClient::new(mira::Client::from_env())),
273 Mistral => || Ok(AnyClient::new(mistral::Client::from_env())),
274 Ollama => || Ok(AnyClient::new(ollama::Client::from_env())),
275 Perplexity => || Ok(AnyClient::new(perplexity::Client::from_env())),
276 }
277 }
278}
279
280#[allow(deprecated)]
281impl DynClientBuilder {
282 pub fn new() -> Self {
283 Self::default().register_all()
284 }
285
286 fn register_all(mut self) -> Self {
287 for provider in DefaultProviders::all() {
288 let from_env = provider.get_env_fn();
289 self.0
290 .insert(provider.to_string(), ProviderFactory { from_env });
291 }
292
293 self
294 }
295
296 fn to_key<Models>(provider_name: &'static str, model: &Models) -> String
297 where
298 Models: ToString,
299 {
300 format!("{provider_name}:{}", model.to_string())
301 }
302
303 pub fn register<Ext, H, Models>(mut self, provider_name: &'static str, model: Models) -> Self
304 where
305 Ext: Provider + Capabilities + WasmCompatSend + WasmCompatSync + 'static,
306 H: Default + WasmCompatSend + WasmCompatSync + 'static,
307 Client<Ext, H>: ProviderClient + WasmCompatSend + WasmCompatSync + 'static,
308 Models: ToString,
309 {
310 let key = Self::to_key(provider_name, &model);
311
312 let factory = ProviderFactory {
313 from_env: || Ok(AnyClient::new(Client::<Ext, H>::from_env())),
314 };
315
316 self.0.insert(key, factory);
317
318 self
319 }
320
321 pub fn from_env<T, Models>(
322 &self,
323 provider_name: &'static str,
324 model: Models,
325 ) -> Result<AnyClient, Error>
326 where
327 T: 'static,
328 Models: ToString,
329 {
330 let key = Self::to_key(provider_name, &model);
331
332 self.0
333 .get(&key)
334 .ok_or(Error::NotFound(key))
335 .and_then(|factory| (factory.from_env)())
336 }
337
338 pub fn factory<Models>(
339 &self,
340 provider_name: &'static str,
341 model: Models,
342 ) -> Option<&ProviderFactory>
343 where
344 Models: ToString,
345 {
346 let key = Self::to_key(provider_name, &model);
347
348 self.0.get(&key)
349 }
350
351 pub fn agent<Models>(
353 &self,
354 provider_name: impl Into<&'static str>,
355 model: Models,
356 ) -> Result<AgentBuilder<CompletionModelHandle<'_>>, Error>
357 where
358 Models: ToString,
359 {
360 let key = Self::to_key(provider_name.into(), &model);
361
362 let client = self
363 .0
364 .get(&key)
365 .ok_or_else(|| Error::NotFound(key.clone()))
366 .and_then(|factory| (factory.from_env)())?;
367
368 let completion = client.as_completion().ok_or(Error::NotCapable {
369 provider: key,
370 role: "Completion".into(),
371 })?;
372
373 Ok(completion.agent(&model.to_string()))
374 }
375
376 pub fn completion<Models>(
378 &self,
379 provider_name: &'static str,
380 model: Models,
381 ) -> Result<Box<dyn CompletionModelDyn>, Error>
382 where
383 Models: ToString,
384 {
385 let key = Self::to_key(provider_name, &model);
386
387 let client = self
388 .0
389 .get(&key)
390 .ok_or_else(|| Error::NotFound(key.clone()))
391 .and_then(|factory| (factory.from_env)())?;
392
393 let completion = client.as_completion().ok_or(Error::NotCapable {
394 provider: key,
395 role: "Embedding Model".into(),
396 })?;
397
398 Ok(completion.completion_model(&model.to_string()))
399 }
400
401 pub fn embeddings<Models>(
403 &self,
404 provider_name: &'static str,
405 model: Models,
406 ) -> Result<Box<dyn EmbeddingModelDyn>, Error>
407 where
408 Models: ToString,
409 {
410 let key = Self::to_key(provider_name, &model);
411
412 let client = self
413 .0
414 .get(&key)
415 .ok_or_else(|| Error::NotFound(key.clone()))
416 .and_then(|factory| (factory.from_env)())?;
417
418 let embeddings = client.as_embedding().ok_or(Error::NotCapable {
419 provider: key,
420 role: "Embedding Model".into(),
421 })?;
422
423 Ok(embeddings.embedding_model(&model.to_string()))
424 }
425
426 pub fn transcription<Models>(
428 &self,
429 provider_name: &'static str,
430 model: Models,
431 ) -> Result<Box<dyn TranscriptionModelDyn>, Error>
432 where
433 Models: ToString,
434 {
435 let key = Self::to_key(provider_name, &model);
436
437 let client = self
438 .0
439 .get(&key)
440 .ok_or_else(|| Error::NotFound(key.clone()))
441 .and_then(|factory| (factory.from_env)())?;
442
443 let transcription = client.as_transcription().ok_or(Error::NotCapable {
444 provider: key,
445 role: "transcription model".into(),
446 })?;
447
448 Ok(transcription.transcription_model(&model.to_string()))
449 }
450
451 #[cfg(feature = "image")]
452 pub fn image_generation<Models>(
453 &self,
454 provider_name: &'static str,
455 model: Models,
456 ) -> Result<Box<dyn ImageGenerationModelDyn>, Error>
457 where
458 Models: ToString,
459 {
460 let key = Self::to_key(provider_name, &model);
461
462 let client = self
463 .0
464 .get(&key)
465 .ok_or_else(|| Error::NotFound(key.clone()))
466 .and_then(|factory| (factory.from_env)())?;
467
468 let image_generation = client.as_image_generation().ok_or(Error::NotCapable {
469 provider: key,
470 role: "Image generation".into(),
471 })?;
472
473 Ok(image_generation.image_generation_model(&model.to_string()))
474 }
475
476 #[cfg(feature = "audio")]
477 pub fn audio_generation<Models>(
478 &self,
479 provider_name: &'static str,
480 model: Models,
481 ) -> Result<Box<dyn AudioGenerationModelDyn>, Error>
482 where
483 Models: ToString,
484 {
485 let key = Self::to_key(provider_name, &model);
486
487 let client = self
488 .0
489 .get(&key)
490 .ok_or_else(|| Error::NotFound(key.clone()))
491 .and_then(|factory| (factory.from_env)())?;
492
493 let audio_generation = client.as_audio_generation().ok_or(Error::NotCapable {
494 provider: key,
495 role: "Image generation".into(),
496 })?;
497
498 Ok(audio_generation.audio_generation_model(&model.to_string()))
499 }
500
501 pub async fn stream_completion<Models>(
503 &self,
504 provider_name: &'static str,
505 model: Models,
506 request: CompletionRequest,
507 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, Error>
508 where
509 Models: ToString,
510 {
511 let completion = self.completion(provider_name, model)?;
512
513 completion.stream(request).await.map_err(Error::Completion)
514 }
515
516 pub async fn stream_prompt<Models, Prompt>(
518 &self,
519 provider_name: impl Into<&'static str>,
520 model: Models,
521 prompt: Prompt,
522 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, Error>
523 where
524 Models: ToString,
525 Prompt: Into<Message> + WasmCompatSend,
526 {
527 let completion = self.completion(provider_name.into(), model)?;
528
529 let request = CompletionRequest {
530 preamble: None,
531 tools: vec![],
532 documents: vec![],
533 temperature: None,
534 max_tokens: None,
535 additional_params: None,
536 tool_choice: None,
537 chat_history: crate::OneOrMany::one(prompt.into()),
538 };
539
540 completion.stream(request).await.map_err(Error::Completion)
541 }
542
543 pub async fn stream_chat<Models, Prompt>(
545 &self,
546 provider_name: &'static str,
547 model: Models,
548 prompt: Prompt,
549 mut history: Vec<Message>,
550 ) -> Result<StreamingCompletionResponse<FinalCompletionResponse>, Error>
551 where
552 Models: ToString,
553 Prompt: Into<Message> + WasmCompatSend,
554 {
555 let completion = self.completion(provider_name, model)?;
556
557 history.push(prompt.into());
558 let request = CompletionRequest {
559 preamble: None,
560 tools: vec![],
561 documents: vec![],
562 temperature: None,
563 max_tokens: None,
564 additional_params: None,
565 tool_choice: None,
566 chat_history: OneOrMany::many(history)
567 .unwrap_or_else(|_| OneOrMany::one(Message::user(""))),
568 };
569
570 completion.stream(request).await.map_err(Error::Completion)
571 }
572}