1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13#[cfg(feature = "derive")]
14pub use rig_derive::ProviderClient;
15use std::fmt::Debug;
16use thiserror::Error;
17
18#[derive(Debug, Error)]
19#[non_exhaustive]
20pub enum ClientBuilderError {
21 #[error("reqwest error: {0}")]
22 HttpError(
23 #[from]
24 #[source]
25 reqwest::Error,
26 ),
27 #[error("invalid property: {0}")]
28 InvalidProperty(&'static str),
29}
30
31pub trait ProviderClient:
37 AsCompletion
38 + AsTranscription
39 + AsEmbeddings
40 + AsImageGeneration
41 + AsAudioGeneration
42 + Debug
43 + WasmCompatSend
44 + WasmCompatSync
45{
46 fn from_env() -> Self
49 where
50 Self: Sized;
51
52 fn boxed(self) -> Box<dyn ProviderClient>
54 where
55 Self: Sized + 'static,
56 {
57 Box::new(self)
58 }
59
60 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
63 where
64 Self: Sized,
65 Self: 'a,
66 {
67 Box::new(Self::from_env())
68 }
69
70 fn from_val(input: ProviderValue) -> Self
71 where
72 Self: Sized;
73
74 fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
77 where
78 Self: Sized,
79 Self: 'a,
80 {
81 Box::new(Self::from_val(input))
82 }
83}
84
/// Explicit configuration accepted by `ProviderClient::from_val`.
///
/// Usually built with `.into()` from:
/// - a `&str` / `String` -> [`ProviderValue::Simple`]
/// - an `(api_key, Option<key>)` pair -> [`ProviderValue::ApiKeyWithOptionalKey`]
/// - an `(api_key, version, header)` triple -> [`ProviderValue::ApiKeyWithVersionAndHeader`]
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (presumably an API key; interpretation is up
    /// to the provider).
    Simple(String),
    /// An API key plus an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}

impl From<&str> for ProviderValue {
    fn from(value: &str) -> Self {
        ProviderValue::Simple(value.to_owned())
    }
}

impl From<String> for ProviderValue {
    fn from(value: String) -> Self {
        ProviderValue::Simple(value)
    }
}

impl<P: AsRef<str>> From<(P, Option<P>)> for ProviderValue {
    fn from(pair: (P, Option<P>)) -> Self {
        let (api_key, optional_key) = pair;
        let optional_key = optional_key.map(|key| key.as_ref().to_owned());
        ProviderValue::ApiKeyWithOptionalKey(api_key.as_ref().to_owned(), optional_key)
    }
}

impl<P: AsRef<str>> From<(P, P, P)> for ProviderValue {
    fn from(triple: (P, P, P)) -> Self {
        let (api_key, version, header) = triple;
        let owned = |part: P| part.as_ref().to_owned();
        ProviderValue::ApiKeyWithVersionAndHeader(owned(api_key), owned(version), owned(header))
    }
}
128
129pub trait AsCompletion {
131 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
132 None
133 }
134}
135
136pub trait AsTranscription {
138 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
139 None
140 }
141}
142
143pub trait AsEmbeddings {
145 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
146 None
147 }
148}
149
150pub trait AsAudioGeneration {
152 #[cfg(feature = "audio")]
153 fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
154 None
155 }
156}
157
158pub trait AsImageGeneration {
160 #[cfg(feature = "image")]
161 fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
162 None
163 }
164}
165
166pub trait AsVerify {
168 fn as_verify(&self) -> Option<Box<dyn VerifyClientDyn>> {
169 None
170 }
171}
172
173#[cfg(not(feature = "audio"))]
174impl<T: ProviderClient> AsAudioGeneration for T {}
175
176#[cfg(not(feature = "image"))]
177impl<T: ProviderClient> AsImageGeneration for T {}
178
/// Implements the listed `As*` capability traits for a client type, using
/// each trait's default (capability-absent) method bodies.
///
/// Usage: `impl_conversion_traits!(AsEmbeddings, AsTranscription for Client<T>);`
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, which expand to
/// nothing when the matching cargo feature is disabled.
///
/// Every path, including the recursive self-invocations, goes through
/// `$crate`. The macro therefore works when a dependent renames this crate,
/// and when it is invoked by path without being imported.
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $($type_spec:tt)+) => {
        // `*` (not `+`) so the transcriber repetition matches the matcher's.
        $crate::impl_conversion_traits!(@expand_traits [$($trait_)*] $($type_spec)+);
    };

    // Peel off one trait at a time and implement it.
    (@expand_traits [$trait_:ident $($rest_traits:ident)*] $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@impl $trait_ for $($type_spec)+);
        $crate::impl_conversion_traits!(@expand_traits [$($rest_traits)*] $($type_spec)+);
    };

    (@expand_traits [] $($type_spec:tt)+) => {};

    // Feature-gated traits delegate to their cfg-aware helper macros.
    (@impl AsAudioGeneration for $($type_spec:tt)+) => {
        $crate::client::impl_audio_generation!($($type_spec)+);
    };

    (@impl AsImageGeneration for $($type_spec:tt)+) => {
        $crate::client::impl_image_generation!($($type_spec)+);
    };

    (@impl $trait_:ident for $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@impl_trait $trait_ for $($type_spec)+);
    };

    (@impl_trait $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };

    (@impl_trait $trait_:ident for $struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::$trait_ for $struct_<$($generics),*> {}
    };
}
212
/// Implements `AsAudioGeneration` for a client type using the trait's default
/// method (which reports no audio-generation support).
///
/// Accepts a bare type name or a name with simple generic parameters, e.g.
/// `impl_audio_generation!(Client<T>)`. The trait path uses `$crate` so the
/// macro works from dependent crates regardless of how they name this crate.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
    ($struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::AsAudioGeneration for $struct_<$($generics),*> {}
    };
}

/// Without the `audio` feature, `AsAudioGeneration` is blanket-implemented
/// for every `ProviderClient`, so this macro expands to nothing.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($($tokens:tt)*) => {};
}
229
/// Implements `AsImageGeneration` for a client type using the trait's default
/// method (which reports no image-generation support).
///
/// Accepts a bare type name or a name with simple generic parameters, e.g.
/// `impl_image_generation!(Client<T>)`. The trait path uses `$crate` so the
/// macro works from dependent crates regardless of how they name this crate.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
    ($struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::AsImageGeneration for $struct_<$($generics),*> {}
    };
}

/// Without the `image` feature, `AsImageGeneration` is blanket-implemented
/// for every `ProviderClient`, so this macro expands to nothing.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($($tokens:tt)*) => {};
}
246
247pub use impl_audio_generation;
248pub use impl_conversion_traits;
249pub use impl_image_generation;
250
251#[cfg(feature = "audio")]
252use crate::client::audio_generation::AudioGenerationClientDyn;
253use crate::client::completion::CompletionClientDyn;
254use crate::client::embeddings::EmbeddingsClientDyn;
255#[cfg(feature = "image")]
256use crate::client::image_generation::ImageGenerationClientDyn;
257use crate::client::transcription::TranscriptionClientDyn;
258use crate::client::verify::VerifyClientDyn;
259
260#[cfg(feature = "audio")]
261pub use crate::client::audio_generation::AudioGenerationClient;
262pub use crate::client::completion::CompletionClient;
263pub use crate::client::embeddings::EmbeddingsClient;
264#[cfg(feature = "image")]
265pub use crate::client::image_generation::ImageGenerationClient;
266pub use crate::client::transcription::TranscriptionClient;
267pub use crate::client::verify::{VerifyClient, VerifyError};
268use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
269
270#[cfg(test)]
271mod tests {
272 use crate::OneOrMany;
273 use crate::client::ProviderClient;
274 use crate::completion::{Completion, CompletionRequest, ToolDefinition};
275 use crate::image_generation::ImageGenerationRequest;
276 use crate::message::AssistantContent;
277 use crate::providers::{
278 anthropic, azure, cohere, deepseek, galadriel, gemini, huggingface, hyperbolic, mira,
279 moonshot, openai, openrouter, together, xai,
280 };
281 use crate::streaming::StreamingCompletion;
282 use crate::tool::Tool;
283 use crate::transcription::TranscriptionRequest;
284 use futures::StreamExt;
285 use rig::message::Message;
286 use rig::providers::{groq, ollama, perplexity};
287 use serde::{Deserialize, Serialize};
288 use serde_json::json;
289 use std::fs::File;
290 use std::io::Read;
291
292 use super::ProviderValue;
293
294 struct ClientConfig {
295 name: &'static str,
296 factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
297 #[allow(dead_code)]
299 factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
300 env_variable: &'static str,
301 completion_model: Option<&'static str>,
302 embeddings_model: Option<&'static str>,
303 transcription_model: Option<&'static str>,
304 image_generation_model: Option<&'static str>,
305 audio_generation_model: Option<(&'static str, &'static str)>,
306 }
307
308 impl Default for ClientConfig {
309 fn default() -> Self {
310 Self {
311 name: "",
312 factory_env: Box::new(|| panic!("Not implemented")),
313 factory_val: Box::new(|_| panic!("Not implemented")),
314 env_variable: "",
315 completion_model: None,
316 embeddings_model: None,
317 transcription_model: None,
318 image_generation_model: None,
319 audio_generation_model: None,
320 }
321 }
322 }
323
324 impl ClientConfig {
325 fn is_env_var_set(&self) -> bool {
326 self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
327 }
328
329 fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
330 self.factory_env.as_ref()()
331 }
332 }
333
334 fn providers() -> Vec<ClientConfig> {
335 vec![
336 ClientConfig {
337 name: "Anthropic",
338 factory_env: Box::new(anthropic::Client::<reqwest::Client>::from_env_boxed),
339 factory_val: Box::new(anthropic::Client::<reqwest::Client>::from_val_boxed),
340 env_variable: "ANTHROPIC_API_KEY",
341 completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
342 ..Default::default()
343 },
344 ClientConfig {
345 name: "Cohere",
346 factory_env: Box::new(cohere::Client::<reqwest::Client>::from_env_boxed),
347 factory_val: Box::new(cohere::Client::<reqwest::Client>::from_val_boxed),
348 env_variable: "COHERE_API_KEY",
349 completion_model: Some(cohere::COMMAND_R),
350 embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
351 ..Default::default()
352 },
353 ClientConfig {
354 name: "Gemini",
355 factory_env: Box::new(gemini::Client::<reqwest::Client>::from_env_boxed),
356 factory_val: Box::new(gemini::Client::<reqwest::Client>::from_val_boxed),
357 env_variable: "GEMINI_API_KEY",
358 completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
359 embeddings_model: Some(gemini::embedding::EMBEDDING_001),
360 transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
361 ..Default::default()
362 },
363 ClientConfig {
364 name: "Huggingface",
365 factory_env: Box::new(huggingface::Client::<reqwest::Client>::from_env_boxed),
366 factory_val: Box::new(huggingface::Client::<reqwest::Client>::from_val_boxed),
367 env_variable: "HUGGINGFACE_API_KEY",
368 completion_model: Some(huggingface::PHI_4),
369 transcription_model: Some(huggingface::WHISPER_SMALL),
370 image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
371 ..Default::default()
372 },
373 ClientConfig {
374 name: "OpenAI",
375 factory_env: Box::new(openai::Client::<reqwest::Client>::from_env_boxed),
376 factory_val: Box::new(openai::Client::<reqwest::Client>::from_val_boxed),
377 env_variable: "OPENAI_API_KEY",
378 completion_model: Some(openai::GPT_4O),
379 embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
380 transcription_model: Some(openai::WHISPER_1),
381 image_generation_model: Some(openai::DALL_E_2),
382 audio_generation_model: Some((openai::TTS_1, "onyx")),
383 },
384 ClientConfig {
385 name: "OpenRouter",
386 factory_env: Box::new(openrouter::Client::<reqwest::Client>::from_env_boxed),
387 factory_val: Box::new(openrouter::Client::<reqwest::Client>::from_val_boxed),
388 env_variable: "OPENROUTER_API_KEY",
389 completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
390 ..Default::default()
391 },
392 ClientConfig {
393 name: "Together",
394 factory_env: Box::new(together::Client::<reqwest::Client>::from_env_boxed),
395 factory_val: Box::new(together::Client::<reqwest::Client>::from_val_boxed),
396 env_variable: "TOGETHER_API_KEY",
397 completion_model: Some(together::ALPACA_7B),
398 embeddings_model: Some(together::BERT_BASE_UNCASED),
399 ..Default::default()
400 },
401 ClientConfig {
402 name: "XAI",
403 factory_env: Box::new(xai::Client::<reqwest::Client>::from_env_boxed),
404 factory_val: Box::new(xai::Client::<reqwest::Client>::from_val_boxed),
405 env_variable: "XAI_API_KEY",
406 completion_model: Some(xai::GROK_3_MINI),
407 embeddings_model: None,
408 ..Default::default()
409 },
410 ClientConfig {
411 name: "Azure",
412 factory_env: Box::new(azure::Client::<reqwest::Client>::from_env_boxed),
413 factory_val: Box::new(azure::Client::<reqwest::Client>::from_val_boxed),
414 env_variable: "AZURE_API_KEY",
415 completion_model: Some(azure::GPT_4O),
416 embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
417 transcription_model: Some("whisper-1"),
418 image_generation_model: Some("dalle-2"),
419 audio_generation_model: Some(("tts-1", "onyx")),
420 },
421 ClientConfig {
422 name: "Deepseek",
423 factory_env: Box::new(deepseek::Client::<reqwest::Client>::from_env_boxed),
424 factory_val: Box::new(deepseek::Client::<reqwest::Client>::from_val_boxed),
425 env_variable: "DEEPSEEK_API_KEY",
426 completion_model: Some(deepseek::DEEPSEEK_CHAT),
427 ..Default::default()
428 },
429 ClientConfig {
430 name: "Galadriel",
431 factory_env: Box::new(galadriel::Client::<reqwest::Client>::from_env_boxed),
432 factory_val: Box::new(galadriel::Client::<reqwest::Client>::from_val_boxed),
433 env_variable: "GALADRIEL_API_KEY",
434 completion_model: Some(galadriel::GPT_4O),
435 ..Default::default()
436 },
437 ClientConfig {
438 name: "Groq",
439 factory_env: Box::new(groq::Client::<reqwest::Client>::from_env_boxed),
440 factory_val: Box::new(groq::Client::<reqwest::Client>::from_val_boxed),
441 env_variable: "GROQ_API_KEY",
442 completion_model: Some(groq::MIXTRAL_8X7B_32768),
443 transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
444 ..Default::default()
445 },
446 ClientConfig {
447 name: "Hyperbolic",
448 factory_env: Box::new(hyperbolic::Client::<reqwest::Client>::from_env_boxed),
449 factory_val: Box::new(hyperbolic::Client::<reqwest::Client>::from_val_boxed),
450 env_variable: "HYPERBOLIC_API_KEY",
451 completion_model: Some(hyperbolic::LLAMA_3_1_8B),
452 image_generation_model: Some(hyperbolic::SD1_5),
453 audio_generation_model: Some(("EN", "EN-US")),
454 ..Default::default()
455 },
456 ClientConfig {
457 name: "Mira",
458 factory_env: Box::new(mira::Client::<reqwest::Client>::from_env_boxed),
459 factory_val: Box::new(mira::Client::<reqwest::Client>::from_val_boxed),
460 env_variable: "MIRA_API_KEY",
461 completion_model: Some("gpt-4o"),
462 ..Default::default()
463 },
464 ClientConfig {
465 name: "Moonshot",
466 factory_env: Box::new(moonshot::Client::<reqwest::Client>::from_env_boxed),
467 factory_val: Box::new(moonshot::Client::<reqwest::Client>::from_val_boxed),
468 env_variable: "MOONSHOT_API_KEY",
469 completion_model: Some(moonshot::MOONSHOT_CHAT),
470 ..Default::default()
471 },
472 ClientConfig {
473 name: "Ollama",
474 factory_env: Box::new(ollama::Client::<reqwest::Client>::from_env_boxed),
475 factory_val: Box::new(ollama::Client::<reqwest::Client>::from_val_boxed),
476 env_variable: "OLLAMA_ENABLED",
477 completion_model: Some("llama3.1:8b"),
478 embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
479 ..Default::default()
480 },
481 ClientConfig {
482 name: "Perplexity",
483 factory_env: Box::new(perplexity::Client::<reqwest::Client>::from_env_boxed),
484 factory_val: Box::new(perplexity::Client::<reqwest::Client>::from_val_boxed),
485 env_variable: "PERPLEXITY_API_KEY",
486 completion_model: Some(perplexity::SONAR),
487 ..Default::default()
488 },
489 ]
490 }
491
492 async fn test_completions_client(config: &ClientConfig) {
493 let client = config.factory_env();
494
495 let Some(client) = client.as_completion() else {
496 return;
497 };
498
499 let model = config
500 .completion_model
501 .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));
502
503 let model = client.completion_model(model);
504
505 let resp = model
506 .completion_request(Message::user("Whats the capital of France?"))
507 .send()
508 .await;
509
510 assert!(
511 resp.is_ok(),
512 "[{}]: Error occurred when prompting, {}",
513 config.name,
514 resp.err().unwrap()
515 );
516
517 let resp = resp.unwrap();
518
519 match resp.choice.first() {
520 AssistantContent::Text(text) => {
521 assert!(text.text.to_lowercase().contains("paris"));
522 }
523 _ => {
524 unreachable!(
525 "[{}]: First choice wasn't a Text message, {:?}",
526 config.name,
527 resp.choice.first()
528 );
529 }
530 }
531 }
532
533 #[tokio::test]
534 #[ignore]
535 async fn test_completions() {
536 for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
537 test_completions_client(&p).await;
538 }
539 }
540
541 async fn test_tools_client(config: &ClientConfig) {
542 let client = config.factory_env();
543 let model = config
544 .completion_model
545 .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
546
547 let Some(client) = client.as_completion() else {
548 return;
549 };
550
551 let model = client.agent(model)
552 .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
553 .max_tokens(1024)
554 .tool(Adder)
555 .tool(Subtract)
556 .build();
557
558 let request = model.completion("Calculate 2 - 5", vec![]).await;
559
560 assert!(
561 request.is_ok(),
562 "[{}]: Error occurred when building prompt, {}",
563 config.name,
564 request.err().unwrap()
565 );
566
567 let resp = request.unwrap().send().await;
568
569 assert!(
570 resp.is_ok(),
571 "[{}]: Error occurred when prompting, {}",
572 config.name,
573 resp.err().unwrap()
574 );
575
576 let resp = resp.unwrap();
577
578 assert!(
579 resp.choice.iter().any(|content| match content {
580 AssistantContent::ToolCall(tc) => {
581 if tc.function.name != Subtract::NAME {
582 return false;
583 }
584
585 let arguments =
586 serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
587 .expect("Error parsing arguments");
588
589 arguments.x == 2.0 && arguments.y == 5.0
590 }
591 _ => false,
592 }),
593 "[{}]: Model did not use the Subtract tool.",
594 config.name
595 )
596 }
597
598 #[tokio::test]
599 #[ignore]
600 async fn test_tools() {
601 for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
602 test_tools_client(&p).await;
603 }
604 }
605
606 async fn test_streaming_client(config: &ClientConfig) {
607 let client = config.factory_env();
608
609 let Some(client) = client.as_completion() else {
610 return;
611 };
612
613 let model = config
614 .completion_model
615 .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
616
617 let model = client.completion_model(model);
618
619 let resp = model.stream(CompletionRequest {
620 preamble: None,
621 tools: vec![],
622 documents: vec![],
623 temperature: None,
624 max_tokens: None,
625 additional_params: None,
626 tool_choice: None,
627 chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
628 });
629
630 let mut resp = resp.await.unwrap();
631
632 let mut received_chunk = false;
633
634 while let Some(chunk) = resp.next().await {
635 received_chunk = true;
636 assert!(chunk.is_ok());
637 }
638
639 assert!(
640 received_chunk,
641 "[{}]: Failed to receive a chunk from stream",
642 config.name
643 );
644
645 for choice in resp.choice {
646 match choice {
647 AssistantContent::Text(text) => {
648 assert!(
649 text.text.to_lowercase().contains("paris"),
650 "[{}]: Did not answer with Paris",
651 config.name
652 );
653 }
654 AssistantContent::ToolCall(_) => {}
655 AssistantContent::Reasoning(_) => {}
656 }
657 }
658 }
659
660 #[tokio::test]
661 #[ignore]
662 async fn test_streaming() {
663 for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
664 test_streaming_client(&provider).await;
665 }
666 }
667
668 async fn test_streaming_tools_client(config: &ClientConfig) {
669 let client = config.factory_env();
670 let model = config
671 .completion_model
672 .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
673
674 let Some(client) = client.as_completion() else {
675 return;
676 };
677
678 let model = client.agent(model)
679 .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
680 .max_tokens(1024)
681 .tool(Adder)
682 .tool(Subtract)
683 .build();
684
685 let request = model.stream_completion("Calculate 2 - 5", vec![]).await;
686
687 assert!(
688 request.is_ok(),
689 "[{}]: Error occurred when building prompt, {}",
690 config.name,
691 request.err().unwrap()
692 );
693
694 let resp = request.unwrap().stream().await;
695
696 assert!(
697 resp.is_ok(),
698 "[{}]: Error occurred when prompting, {}",
699 config.name,
700 resp.err().unwrap()
701 );
702
703 let mut resp = resp.unwrap();
704
705 let mut received_chunk = false;
706
707 while let Some(chunk) = resp.next().await {
708 received_chunk = true;
709 assert!(chunk.is_ok());
710 }
711
712 assert!(
713 received_chunk,
714 "[{}]: Failed to receive a chunk from stream",
715 config.name
716 );
717
718 assert!(
719 resp.choice.iter().any(|content| match content {
720 AssistantContent::ToolCall(tc) => {
721 if tc.function.name != Subtract::NAME {
722 return false;
723 }
724
725 let arguments =
726 serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
727 .expect("Error parsing arguments");
728
729 arguments.x == 2.0 && arguments.y == 5.0
730 }
731 _ => false,
732 }),
733 "[{}]: Model did not use the Subtract tool.",
734 config.name
735 )
736 }
737
738 #[tokio::test]
739 #[ignore]
740 async fn test_streaming_tools() {
741 for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
742 test_streaming_tools_client(&p).await;
743 }
744 }
745
746 async fn test_audio_generation_client(config: &ClientConfig) {
747 let client = config.factory_env();
748
749 let Some(client) = client.as_audio_generation() else {
750 return;
751 };
752
753 let (model, voice) = config
754 .audio_generation_model
755 .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));
756
757 let model = client.audio_generation_model(model);
758
759 let request = model
760 .audio_generation_request()
761 .text("Hello world!")
762 .voice(voice);
763
764 let resp = request.send().await;
765
766 assert!(
767 resp.is_ok(),
768 "[{}]: Error occurred when sending request, {}",
769 config.name,
770 resp.err().unwrap()
771 );
772
773 let resp = resp.unwrap();
774
775 assert!(
776 !resp.audio.is_empty(),
777 "[{}]: Returned audio was empty",
778 config.name
779 );
780 }
781
782 #[tokio::test]
783 #[ignore]
784 async fn test_audio_generation() {
785 for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
786 test_audio_generation_client(&p).await;
787 }
788 }
789
790 fn assert_feature<F, M>(
791 name: &str,
792 feature_name: &str,
793 model_name: &str,
794 feature: Option<F>,
795 model: Option<M>,
796 ) {
797 assert_eq!(
798 feature.is_some(),
799 model.is_some(),
800 "{} has{} implemented {} but config.{} is {}.",
801 name,
802 if feature.is_some() { "" } else { "n't" },
803 feature_name,
804 model_name,
805 if model.is_some() { "some" } else { "none" }
806 );
807 }
808
809 #[test]
810 #[ignore]
811 pub fn test_polymorphism() {
812 for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
813 let client = config.factory_env();
814 assert_feature(
815 config.name,
816 "AsCompletion",
817 "completion_model",
818 client.as_completion(),
819 config.completion_model,
820 );
821
822 assert_feature(
823 config.name,
824 "AsEmbeddings",
825 "embeddings_model",
826 client.as_embeddings(),
827 config.embeddings_model,
828 );
829
830 assert_feature(
831 config.name,
832 "AsTranscription",
833 "transcription_model",
834 client.as_transcription(),
835 config.transcription_model,
836 );
837
838 assert_feature(
839 config.name,
840 "AsImageGeneration",
841 "image_generation_model",
842 client.as_image_generation(),
843 config.image_generation_model,
844 );
845
846 assert_feature(
847 config.name,
848 "AsAudioGeneration",
849 "audio_generation_model",
850 client.as_audio_generation(),
851 config.audio_generation_model,
852 )
853 }
854 }
855
856 async fn test_embed_client(config: &ClientConfig) {
857 const TEST: &str = "Hello world.";
858
859 let client = config.factory_env();
860
861 let Some(client) = client.as_embeddings() else {
862 return;
863 };
864
865 let model = config.embeddings_model.unwrap();
866
867 let model = client.embedding_model(model);
868
869 let resp = model.embed_text(TEST).await;
870
871 assert!(
872 resp.is_ok(),
873 "[{}]: Error occurred when sending request, {}",
874 config.name,
875 resp.err().unwrap()
876 );
877
878 let resp = resp.unwrap();
879
880 assert_eq!(resp.document, TEST);
881
882 assert!(
883 !resp.vec.is_empty(),
884 "[{}]: Returned embed was empty",
885 config.name
886 );
887 }
888
889 #[tokio::test]
890 #[ignore]
891 async fn test_embed() {
892 for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
893 test_embed_client(&config).await;
894 }
895 }
896
897 async fn test_image_generation_client(config: &ClientConfig) {
898 let client = config.factory_env();
899 let Some(client) = client.as_image_generation() else {
900 return;
901 };
902
903 let model = config.image_generation_model.unwrap();
904
905 let model = client.image_generation_model(model);
906
907 let resp = model
908 .image_generation(ImageGenerationRequest {
909 prompt: "A castle sitting on a large hill.".to_string(),
910 width: 256,
911 height: 256,
912 additional_params: None,
913 })
914 .await;
915
916 assert!(
917 resp.is_ok(),
918 "[{}]: Error occurred when sending request, {}",
919 config.name,
920 resp.err().unwrap()
921 );
922
923 let resp = resp.unwrap();
924
925 assert!(
926 !resp.image.is_empty(),
927 "[{}]: Generated image was empty",
928 config.name
929 );
930 }
931
932 #[tokio::test]
933 #[ignore]
934 async fn test_image_generation() {
935 for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
936 test_image_generation_client(&config).await;
937 }
938 }
939
940 async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
941 let client = config.factory_env();
942 let Some(client) = client.as_transcription() else {
943 return;
944 };
945
946 let model = config.image_generation_model.unwrap();
947
948 let model = client.transcription_model(model);
949
950 let resp = model
951 .transcription(TranscriptionRequest {
952 data,
953 filename: "audio.mp3".to_string(),
954 language: None,
955 prompt: None,
956 temperature: None,
957 additional_params: None,
958 })
959 .await;
960
961 assert!(
962 resp.is_ok(),
963 "[{}]: Error occurred when sending request, {}",
964 config.name,
965 resp.err().unwrap()
966 );
967
968 let resp = resp.unwrap();
969
970 assert!(
971 !resp.text.is_empty(),
972 "[{}]: Returned transcription was empty",
973 config.name
974 );
975 }
976
977 #[tokio::test]
978 #[ignore]
979 async fn test_transcription() {
980 let mut file = File::open("examples/audio/en-us-natural-speech.mp3").unwrap();
981
982 let mut data = Vec::new();
983 let _ = file.read(&mut data);
984
985 for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
986 test_transcription_client(&config, data.clone()).await;
987 }
988 }
989
990 #[derive(Deserialize)]
991 struct OperationArgs {
992 x: f32,
993 y: f32,
994 }
995
996 #[derive(Debug, thiserror::Error)]
997 #[error("Math error")]
998 struct MathError;
999
1000 #[derive(Deserialize, Serialize)]
1001 struct Adder;
1002 impl Tool for Adder {
1003 const NAME: &'static str = "add";
1004
1005 type Error = MathError;
1006 type Args = OperationArgs;
1007 type Output = f32;
1008
1009 async fn definition(&self, _prompt: String) -> ToolDefinition {
1010 ToolDefinition {
1011 name: "add".to_string(),
1012 description: "Add x and y together".to_string(),
1013 parameters: json!({
1014 "type": "object",
1015 "properties": {
1016 "x": {
1017 "type": "number",
1018 "description": "The first number to add"
1019 },
1020 "y": {
1021 "type": "number",
1022 "description": "The second number to add"
1023 }
1024 }
1025 }),
1026 }
1027 }
1028
1029 async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1030 println!("[tool-call] Adding {} and {}", args.x, args.y);
1031 let result = args.x + args.y;
1032 Ok(result)
1033 }
1034 }
1035
1036 #[derive(Deserialize, Serialize)]
1037 struct Subtract;
1038 impl Tool for Subtract {
1039 const NAME: &'static str = "subtract";
1040
1041 type Error = MathError;
1042 type Args = OperationArgs;
1043 type Output = f32;
1044
1045 async fn definition(&self, _prompt: String) -> ToolDefinition {
1046 serde_json::from_value(json!({
1047 "name": "subtract",
1048 "description": "Subtract y from x (i.e.: x - y)",
1049 "parameters": {
1050 "type": "object",
1051 "properties": {
1052 "x": {
1053 "type": "number",
1054 "description": "The number to subtract from"
1055 },
1056 "y": {
1057 "type": "number",
1058 "description": "The number to subtract"
1059 }
1060 }
1061 }
1062 }))
1063 .expect("Tool Definition")
1064 }
1065
1066 async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1067 println!("[tool-call] Subtracting {} from {}", args.y, args.x);
1068 let result = args.x - args.y;
1069 Ok(result)
1070 }
1071 }
1072}