1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15
16pub trait ProviderClient:
22 AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
23{
24 fn from_env() -> Self
27 where
28 Self: Sized;
29
30 fn boxed(self) -> Box<dyn ProviderClient>
32 where
33 Self: Sized + 'static,
34 {
35 Box::new(self)
36 }
37
38 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
41 where
42 Self: Sized,
43 Self: 'a,
44 {
45 Box::new(Self::from_env())
46 }
47
48 fn from_val(input: ProviderValue) -> Self
49 where
50 Self: Sized;
51
52 fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
55 where
56 Self: Sized,
57 Self: 'a,
58 {
59 Box::new(Self::from_val(input))
60 }
61}
62
/// Explicit credentials/configuration accepted by [`ProviderClient::from_val`].
///
/// How each field is interpreted is up to the individual provider.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (e.g. an API key).
    Simple(String),
    /// An API key followed by an optional second key; the second key's meaning is
    /// provider-specific.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}
69
70impl From<&str> for ProviderValue {
71 fn from(value: &str) -> Self {
72 Self::Simple(value.to_string())
73 }
74}
75
76impl From<String> for ProviderValue {
77 fn from(value: String) -> Self {
78 Self::Simple(value)
79 }
80}
81
82impl<P> From<(P, Option<P>)> for ProviderValue
83where
84 P: AsRef<str>,
85{
86 fn from((api_key, optional_key): (P, Option<P>)) -> Self {
87 Self::ApiKeyWithOptionalKey(
88 api_key.as_ref().to_string(),
89 optional_key.map(|x| x.as_ref().to_string()),
90 )
91 }
92}
93
94impl<P> From<(P, P, P)> for ProviderValue
95where
96 P: AsRef<str>,
97{
98 fn from((api_key, version, header): (P, P, P)) -> Self {
99 Self::ApiKeyWithVersionAndHeader(
100 api_key.as_ref().to_string(),
101 version.as_ref().to_string(),
102 header.as_ref().to_string(),
103 )
104 }
105}
106
107pub trait AsCompletion {
109 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
110 None
111 }
112}
113
114pub trait AsTranscription {
116 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
117 None
118 }
119}
120
121pub trait AsEmbeddings {
123 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
124 None
125 }
126}
127
/// Optional, runtime-checked access to a provider's audio generation capability.
///
/// The conversion method only exists when the `audio` cargo feature is enabled; without
/// it this trait declares no methods at all.
pub trait AsAudioGeneration {
    /// Returns a type-erased audio generation client, or `None` when unsupported
    /// (the default).
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
135
/// Optional, runtime-checked access to a provider's image generation capability.
///
/// The conversion method only exists when the `image` cargo feature is enabled; without
/// it this trait declares no methods at all.
pub trait AsImageGeneration {
    /// Returns a type-erased image generation client, or `None` when unsupported
    /// (the default).
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
143
144#[cfg(not(feature = "audio"))]
145impl<T: ProviderClient> AsAudioGeneration for T {}
146
147#[cfg(not(feature = "image"))]
148impl<T: ProviderClient> AsImageGeneration for T {}
149
/// Implements the listed `As*` conversion traits for a provider client type, using each
/// trait's default (capability-absent) method.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, which expand to nothing when the
/// corresponding cargo feature is disabled.
///
/// All paths go through `$crate`, so the macro resolves correctly no matter what name the
/// invoking crate uses for this library, and the recursive call works even when the
/// macro was invoked by path rather than imported.
///
/// # Example
///
/// ```ignore
/// impl_conversion_traits!(AsEmbeddings, AsTranscription for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    // Audio support is feature-gated; delegate so the impl disappears with the feature.
    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    // Image support is feature-gated; delegate so the impl disappears with the feature.
    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
178
/// Implements `AsAudioGeneration` for `$struct_` using the trait's default method.
///
/// This definition is active because the `audio` feature is enabled. The path goes
/// through `$crate`, so it resolves regardless of how the invoking crate names this
/// library.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}
186
/// No-op stand-in used when the `audio` feature is disabled: the blanket impl already
/// covers `AsAudioGeneration`, so emitting an impl here would conflict.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($client:ident) => {};
}
192
/// Implements `AsImageGeneration` for `$struct_` using the trait's default method.
///
/// This definition is active because the `image` feature is enabled. The path goes
/// through `$crate`, so it resolves regardless of how the invoking crate names this
/// library.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}
200
/// No-op stand-in used when the `image` feature is disabled: the blanket impl already
/// covers `AsImageGeneration`, so emitting an impl here would conflict.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($client:ident) => {};
}
206
207pub use impl_audio_generation;
208pub use impl_conversion_traits;
209pub use impl_image_generation;
210
211#[cfg(feature = "audio")]
212use crate::client::audio_generation::AudioGenerationClientDyn;
213use crate::client::completion::CompletionClientDyn;
214use crate::client::embeddings::EmbeddingsClientDyn;
215#[cfg(feature = "image")]
216use crate::client::image_generation::ImageGenerationClientDyn;
217use crate::client::transcription::TranscriptionClientDyn;
218
219#[cfg(feature = "audio")]
220pub use crate::client::audio_generation::AudioGenerationClient;
221pub use crate::client::completion::CompletionClient;
222pub use crate::client::embeddings::EmbeddingsClient;
223#[cfg(feature = "image")]
224pub use crate::client::image_generation::ImageGenerationClient;
225pub use crate::client::transcription::TranscriptionClient;
226
#[cfg(test)]
mod tests {
    //! Live, network-backed tests that drive every provider through the type-erased
    //! `ProviderClient` interface.
    //!
    //! Every test is `#[ignore]`d because it performs real API calls. A provider is only
    //! exercised when its gating environment variable is present (see
    //! `ClientConfig::is_env_var_set`), and each capability test skips providers whose
    //! `as_*` conversion returns `None`.

    use crate::OneOrMany;
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    #[cfg(feature = "image")]
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::fs::File;
    use std::io::Read;

    use super::ProviderValue;

    /// Per-provider test configuration.
    ///
    /// Each `*_model` field also declares whether the provider is expected to expose the
    /// matching capability: `test_polymorphism` asserts the field is `Some` exactly when
    /// the corresponding `as_*` conversion returns `Some`.
    struct ClientConfig {
        /// Provider name used in assertion messages.
        name: &'static str,
        /// Builds the client from environment variables.
        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        /// Builds the client from an explicit `ProviderValue`; not exercised by any test
        /// yet, hence the `dead_code` allowance.
        #[allow(dead_code)]
        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
        /// Environment variable gating this provider's tests; empty means "always run".
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        // Only read by the `image`-feature-gated tests.
        #[cfg_attr(not(feature = "image"), allow(dead_code))]
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` pair; only read by the `audio`-feature-gated tests.
        #[cfg_attr(not(feature = "audio"), allow(dead_code))]
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        /// Placeholder config whose factories panic; entries in `providers()` override
        /// every field they need via struct-update syntax.
        fn default() -> Self {
            Self {
                name: "",
                factory_env: Box::new(|| panic!("Not implemented")),
                factory_val: Box::new(|_| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// Returns `true` when the provider should be tested: either no gating variable
        /// is configured, or `std::env::var` can read it.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Instantiates the provider client through `factory_env`.
        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
            self.factory_env.as_ref()()
        }
    }

    /// The full matrix of providers and the models each test should use.
    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory_env: Box::new(anthropic::Client::from_env_boxed),
                factory_val: Box::new(anthropic::Client::from_val_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory_env: Box::new(cohere::Client::from_env_boxed),
                factory_val: Box::new(cohere::Client::from_val_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory_env: Box::new(gemini::Client::from_env_boxed),
                factory_val: Box::new(gemini::Client::from_val_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory_env: Box::new(huggingface::Client::from_env_boxed),
                factory_val: Box::new(huggingface::Client::from_val_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory_env: Box::new(openai::Client::from_env_boxed),
                factory_val: Box::new(openai::Client::from_val_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory_env: Box::new(openrouter::Client::from_env_boxed),
                factory_val: Box::new(openrouter::Client::from_val_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory_env: Box::new(together::Client::from_env_boxed),
                factory_val: Box::new(together::Client::from_val_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory_env: Box::new(xai::Client::from_env_boxed),
                factory_val: Box::new(xai::Client::from_val_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                embeddings_model: None,
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory_env: Box::new(azure::Client::from_env_boxed),
                factory_val: Box::new(azure::Client::from_val_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory_env: Box::new(deepseek::Client::from_env_boxed),
                factory_val: Box::new(deepseek::Client::from_val_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory_env: Box::new(galadriel::Client::from_env_boxed),
                factory_val: Box::new(galadriel::Client::from_val_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory_env: Box::new(groq::Client::from_env_boxed),
                factory_val: Box::new(groq::Client::from_val_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory_env: Box::new(hyperbolic::Client::from_env_boxed),
                factory_val: Box::new(hyperbolic::Client::from_val_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory_env: Box::new(mira::Client::from_env_boxed),
                factory_val: Box::new(mira::Client::from_val_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory_env: Box::new(moonshot::Client::from_env_boxed),
                factory_val: Box::new(moonshot::Client::from_val_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory_env: Box::new(ollama::Client::from_env_boxed),
                factory_val: Box::new(ollama::Client::from_val_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory_env: Box::new(perplexity::Client::from_env_boxed),
                factory_val: Box::new(perplexity::Client::from_val_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Sends a single-turn "capital of France" prompt and checks the answer mentions Paris.
    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        match resp.choice.first() {
            AssistantContent::Text(text) => {
                assert!(
                    text.text.to_lowercase().contains("paris"),
                    "[{}]: Did not answer with Paris",
                    config.name
                );
            }
            _ => {
                unreachable!(
                    "[{}]: First choice wasn't a Text message, {:?}",
                    config.name,
                    resp.choice.first()
                );
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    /// Asks a calculator agent for `2 - 5` and checks it calls `subtract` with x=2, y=5.
    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory_env();

        // Check the capability first, consistent with the other capability tests.
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.agent(model)
            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = model.completion("Calculate 2 - 5", vec![]).await;

        assert!(
            request.is_ok(),
            "[{}]: Error occurred when building prompt, {}",
            config.name,
            request.err().unwrap()
        );

        let resp = request.unwrap().send().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            resp.choice.iter().any(|content| match content {
                AssistantContent::ToolCall(tc) => {
                    if tc.function.name != Subtract::NAME {
                        return false;
                    }

                    let arguments =
                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
                            .expect("Error parsing arguments");

                    arguments.x == 2.0 && arguments.y == 5.0
                }
                _ => false,
            }),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    /// Streams a "capital of France" answer, requiring at least one chunk, no error
    /// chunks, and that every aggregated text choice mentions Paris.
    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let resp = model.stream(CompletionRequest {
            preamble: None,
            tools: vec![],
            documents: vec![],
            temperature: None,
            max_tokens: None,
            additional_params: None,
            chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
        });

        let mut resp = resp.await.unwrap();

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Received an error chunk from stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    /// Streaming variant of `test_tools_client`: the aggregated response must contain a
    /// `subtract` tool call with x=2, y=5.
    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory_env();

        // Check the capability first, consistent with the other capability tests.
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.agent(model)
            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;

        assert!(
            request.is_ok(),
            "[{}]: Error occurred when building prompt, {}",
            config.name,
            request.err().unwrap()
        );

        let resp = request.unwrap().stream().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let mut resp = resp.unwrap();

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Received an error chunk from stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(|content| match content {
                AssistantContent::ToolCall(tc) => {
                    if tc.function.name != Subtract::NAME {
                        return false;
                    }

                    let arguments =
                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
                            .expect("Error parsing arguments");

                    arguments.x == 2.0 && arguments.y == 5.0
                }
                _ => false,
            }),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    /// Synthesizes "Hello world!" with the configured `(model, voice)` and checks that
    /// non-empty audio comes back. Requires the `audio` feature, which provides
    /// `as_audio_generation`.
    #[cfg(feature = "audio")]
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let request = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice);

        let resp = request.send().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[cfg(feature = "audio")]
    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a provider exposes a capability (`feature` is `Some`) exactly when its
    /// config declares a model for it (`model` is `Some`).
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    #[test]
    #[ignore]
    pub fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory_env();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            // `as_image_generation` only exists with the `image` feature.
            #[cfg(feature = "image")]
            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            // `as_audio_generation` only exists with the `audio` feature.
            #[cfg(feature = "audio")]
            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    /// Embeds a short sentence and checks the document is echoed back with a non-empty
    /// vector.
    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory_env();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert_eq!(resp.document, TEST);

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    /// Generates a 256x256 image and checks the returned bytes are non-empty. Requires
    /// the `image` feature, which provides `as_image_generation`.
    #[cfg(feature = "image")]
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory_env();
        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[cfg(feature = "image")]
    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    /// Transcribes the sample MP3 in `data` and checks the returned text is non-empty.
    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory_env();
        let Some(client) = client.as_transcription() else {
            return;
        };

        // Use the transcription model (previously this wrongly read `image_generation_model`).
        let model = config
            .transcription_model
            .unwrap_or_else(|| panic!("{} does not have transcription_model set", config.name));

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        let mut file = File::open("examples/audio/en-us-natural-speech.mp3")
            .expect("Failed to open the sample audio file");

        // `read` into an empty Vec reads zero bytes; `read_to_end` grows the buffer and
        // reads the whole file, and its error is surfaced instead of being discarded.
        let mut data = Vec::new();
        file.read_to_end(&mut data)
            .expect("Failed to read the sample audio file");

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments shared by the `add` and `subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool computing `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool computing `x - y`; the tool-calling tests expect it to be chosen.
    #[derive(Deserialize, Serialize)]
    struct Subtract;
    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            let result = args.x - args.y;
            Ok(result)
        }
    }
}