1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum ClientBuilderError {
20 #[error("reqwest error: {0}")]
21 HttpError(
22 #[from]
23 #[source]
24 reqwest::Error,
25 ),
26 #[error("invalid property: {0}")]
27 InvalidProperty(&'static str),
28}
29
30pub trait ProviderClient:
36 AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
37{
38 fn from_env() -> Self
41 where
42 Self: Sized;
43
44 fn boxed(self) -> Box<dyn ProviderClient>
46 where
47 Self: Sized + 'static,
48 {
49 Box::new(self)
50 }
51
52 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
55 where
56 Self: Sized,
57 Self: 'a,
58 {
59 Box::new(Self::from_env())
60 }
61
62 fn from_val(input: ProviderValue) -> Self
63 where
64 Self: Sized;
65
66 fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
69 where
70 Self: Sized,
71 Self: 'a,
72 {
73 Box::new(Self::from_val(input))
74 }
75}
76
/// Explicit configuration handed to [`ProviderClient::from_val`].
///
/// How each variant is interpreted is up to the individual provider; the variant names
/// describe the intended layout of the strings.
///
/// NOTE(review): there is no `Debug` derive. If one is added, consider redacting the key
/// material so it cannot leak into logs.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string (presumably an API key — confirm per provider).
    Simple(String),
    /// An API key followed by an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}
83
84impl From<&str> for ProviderValue {
85 fn from(value: &str) -> Self {
86 Self::Simple(value.to_string())
87 }
88}
89
90impl From<String> for ProviderValue {
91 fn from(value: String) -> Self {
92 Self::Simple(value)
93 }
94}
95
96impl<P> From<(P, Option<P>)> for ProviderValue
97where
98 P: AsRef<str>,
99{
100 fn from((api_key, optional_key): (P, Option<P>)) -> Self {
101 Self::ApiKeyWithOptionalKey(
102 api_key.as_ref().to_string(),
103 optional_key.map(|x| x.as_ref().to_string()),
104 )
105 }
106}
107
108impl<P> From<(P, P, P)> for ProviderValue
109where
110 P: AsRef<str>,
111{
112 fn from((api_key, version, header): (P, P, P)) -> Self {
113 Self::ApiKeyWithVersionAndHeader(
114 api_key.as_ref().to_string(),
115 version.as_ref().to_string(),
116 header.as_ref().to_string(),
117 )
118 }
119}
120
121pub trait AsCompletion {
123 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
124 None
125 }
126}
127
128pub trait AsTranscription {
130 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
131 None
132 }
133}
134
135pub trait AsEmbeddings {
137 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
138 None
139 }
140}
141
/// Optional audio-generation capability of a [`ProviderClient`].
///
/// The accessor only exists when the `audio` feature is enabled; without it this trait has
/// no methods.
pub trait AsAudioGeneration {
    /// Returns a type-erased audio-generation client, or `None` when unsupported
    /// (the default).
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
149
/// Optional image-generation capability of a [`ProviderClient`].
///
/// The accessor only exists when the `image` feature is enabled; without it this trait has
/// no methods.
pub trait AsImageGeneration {
    /// Returns a type-erased image-generation client, or `None` when unsupported
    /// (the default).
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
157
158#[cfg(not(feature = "audio"))]
159impl<T: ProviderClient> AsAudioGeneration for T {}
160
161#[cfg(not(feature = "image"))]
162impl<T: ProviderClient> AsImageGeneration for T {}
163
/// Implements the listed `As*` capability traits from this crate's `client` module for
/// `$struct_`, using their default (`None`-returning) method bodies.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through `impl_audio_generation!`
/// and `impl_image_generation!`. Those expand to nothing when the matching cargo feature is
/// disabled, because a blanket impl already covers that case.
///
/// All paths go through `$crate` so the macro works whatever name the invoking crate
/// gives this crate (or when used inside this crate). The recursive call is also
/// path-qualified, so it resolves even when the macro itself was invoked by path without
/// being imported.
///
/// # Examples
///
/// ```ignore
/// impl_conversion_traits!(AsEmbeddings, AsTranscription for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
192
/// Implements `AsAudioGeneration` for `$struct_` with its default method body.
///
/// This is the `audio`-enabled variant; the path goes through `$crate` so it resolves
/// regardless of what name the invoking crate uses for this crate.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}
200
/// No-op variant of `impl_audio_generation!` used when the `audio` feature is disabled.
/// The blanket `impl<T: ProviderClient> AsAudioGeneration for T` provides the impl instead.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($target:ident) => {};
}
206
/// Implements `AsImageGeneration` for `$struct_` with its default method body.
///
/// This is the `image`-enabled variant; the path goes through `$crate` so it resolves
/// regardless of what name the invoking crate uses for this crate.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}
214
/// No-op variant of `impl_image_generation!` used when the `image` feature is disabled.
/// The blanket `impl<T: ProviderClient> AsImageGeneration for T` provides the impl instead.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($target:ident) => {};
}
220
221pub use impl_audio_generation;
222pub use impl_conversion_traits;
223pub use impl_image_generation;
224
225#[cfg(feature = "audio")]
226use crate::client::audio_generation::AudioGenerationClientDyn;
227use crate::client::completion::CompletionClientDyn;
228use crate::client::embeddings::EmbeddingsClientDyn;
229#[cfg(feature = "image")]
230use crate::client::image_generation::ImageGenerationClientDyn;
231use crate::client::transcription::TranscriptionClientDyn;
232
233#[cfg(feature = "audio")]
234pub use crate::client::audio_generation::AudioGenerationClient;
235pub use crate::client::completion::CompletionClient;
236pub use crate::client::embeddings::EmbeddingsClient;
237#[cfg(feature = "image")]
238pub use crate::client::image_generation::ImageGenerationClient;
239pub use crate::client::transcription::TranscriptionClient;
240
#[cfg(test)]
mod tests {
    //! Live, per-provider integration tests for the dynamic client API.
    //!
    //! Every test iterates over `providers()` and only exercises entries whose
    //! `env_variable` is set (see `ClientConfig::is_env_var_set`). All tests are
    //! `#[ignore]`d by default; run them with `cargo test -- --ignored`.
    //!
    //! Image and audio generation checks are only compiled with the `image` / `audio`
    //! features. This matches the feature gates on `AsImageGeneration::as_image_generation`
    //! and `AsAudioGeneration::as_audio_generation`.

    use super::ProviderValue;
    use crate::OneOrMany;
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    #[cfg(feature = "image")]
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// One provider under test: how to build its client and which model (if any) to use
    /// for each capability.
    ///
    /// A `None` model means the provider is expected not to support that capability;
    /// `test_polymorphism` checks this against the `As*` conversions.
    struct ClientConfig {
        /// Provider name used to prefix assertion messages.
        name: &'static str,
        /// Builds the client via the provider's `from_env_boxed` constructor.
        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        // Not read by any test yet; every provider still registers its `from_val_boxed`.
        #[allow(dead_code)]
        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
        /// Environment variable that must be present for the provider to be tested; an
        /// empty string means the provider is always tested.
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        // Only read by tests compiled with the `image` feature.
        #[cfg_attr(not(feature = "image"), allow(dead_code))]
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` pair used for audio generation.
        // Only read by tests compiled with the `audio` feature.
        #[cfg_attr(not(feature = "audio"), allow(dead_code))]
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        fn default() -> Self {
            Self {
                name: "",
                factory_env: Box::new(|| panic!("Not implemented")),
                factory_val: Box::new(|_| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// True when no variable is required or the required variable is present.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a fresh client through `factory_env`.
        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
            (self.factory_env)()
        }
    }

    /// The full provider matrix exercised by the tests below.
    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory_env: Box::new(anthropic::Client::from_env_boxed),
                factory_val: Box::new(anthropic::Client::from_val_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory_env: Box::new(cohere::Client::from_env_boxed),
                factory_val: Box::new(cohere::Client::from_val_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory_env: Box::new(gemini::Client::from_env_boxed),
                factory_val: Box::new(gemini::Client::from_val_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory_env: Box::new(huggingface::Client::from_env_boxed),
                factory_val: Box::new(huggingface::Client::from_val_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory_env: Box::new(openai::Client::from_env_boxed),
                factory_val: Box::new(openai::Client::from_val_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory_env: Box::new(openrouter::Client::from_env_boxed),
                factory_val: Box::new(openrouter::Client::from_val_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory_env: Box::new(together::Client::from_env_boxed),
                factory_val: Box::new(together::Client::from_val_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory_env: Box::new(xai::Client::from_env_boxed),
                factory_val: Box::new(xai::Client::from_val_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                embeddings_model: None,
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory_env: Box::new(azure::Client::from_env_boxed),
                factory_val: Box::new(azure::Client::from_val_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory_env: Box::new(deepseek::Client::from_env_boxed),
                factory_val: Box::new(deepseek::Client::from_val_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory_env: Box::new(galadriel::Client::from_env_boxed),
                factory_val: Box::new(galadriel::Client::from_val_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory_env: Box::new(groq::Client::from_env_boxed),
                factory_val: Box::new(groq::Client::from_val_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory_env: Box::new(hyperbolic::Client::from_env_boxed),
                factory_val: Box::new(hyperbolic::Client::from_val_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory_env: Box::new(mira::Client::from_env_boxed),
                factory_val: Box::new(mira::Client::from_val_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory_env: Box::new(moonshot::Client::from_env_boxed),
                factory_val: Box::new(moonshot::Client::from_val_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory_env: Box::new(ollama::Client::from_env_boxed),
                factory_val: Box::new(ollama::Client::from_val_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory_env: Box::new(perplexity::Client::from_env_boxed),
                factory_val: Box::new(perplexity::Client::from_val_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Asks a simple factual question and checks that the first choice is text that
    /// mentions "paris".
    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        match resp.choice.first() {
            AssistantContent::Text(text) => {
                assert!(
                    text.text.to_lowercase().contains("paris"),
                    "[{}]: Did not answer with Paris",
                    config.name
                );
            }
            _ => {
                // Reachable whenever the model answers with e.g. a tool call, so this is a
                // test failure rather than an impossible state.
                panic!(
                    "[{}]: First choice wasn't a Text message, {:?}",
                    config.name,
                    resp.choice.first()
                );
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    /// Gives an agent `Adder` and `Subtract` tools and checks that "2 - 5" produces a
    /// `subtract` tool call with `x = 2`, `y = 5`.
    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory_env();
        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = client.agent(model)
            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = model.completion("Calculate 2 - 5", vec![]).await;

        assert!(
            request.is_ok(),
            "[{}]: Error occurred when building prompt, {}",
            config.name,
            request.err().unwrap()
        );

        let resp = request.unwrap().send().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            resp.choice.iter().any(|content| match content {
                AssistantContent::ToolCall(tc) => {
                    if tc.function.name != Subtract::NAME {
                        return false;
                    }

                    let arguments =
                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
                            .expect("Error parsing arguments");

                    arguments.x == 2.0 && arguments.y == 5.0
                }
                _ => false,
            }),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    /// Streams the answer to a factual question, checks that at least one chunk arrives and
    /// that no chunk is an error, then checks the aggregated text choices mention "paris".
    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let mut resp = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when streaming, {:?}", config.name, e)
            });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Stream yielded an error chunk",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
                AssistantContent::Reasoning(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    /// Streaming counterpart of `test_tools_client`: the aggregated choices must contain a
    /// `subtract` tool call with `x = 2`, `y = 5`.
    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory_env();
        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = client.agent(model)
            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;

        assert!(
            request.is_ok(),
            "[{}]: Error occurred when building prompt, {}",
            config.name,
            request.err().unwrap()
        );

        let resp = request.unwrap().stream().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let mut resp = resp.unwrap();

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Stream yielded an error chunk",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(|content| match content {
                AssistantContent::ToolCall(tc) => {
                    if tc.function.name != Subtract::NAME {
                        return false;
                    }

                    let arguments =
                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
                            .expect("Error parsing arguments");

                    arguments.x == 2.0 && arguments.y == 5.0
                }
                _ => false,
            }),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    /// Synthesizes "Hello world!" with the configured `(model, voice)` and checks that the
    /// returned audio is non-empty.
    #[cfg(feature = "audio")]
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let request = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice);

        let resp = request.send().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[cfg(feature = "audio")]
    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a capability is implemented (`feature` is `Some`) exactly when the
    /// provider's config names a model for it (`model` is `Some`).
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    /// Checks that every provider's `As*` conversions agree with the models configured in
    /// `providers()`.
    #[test]
    #[ignore]
    pub fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory_env();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            // `as_image_generation` only exists with the `image` feature.
            #[cfg(feature = "image")]
            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            // `as_audio_generation` only exists with the `audio` feature.
            #[cfg(feature = "audio")]
            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    /// Embeds a short string and checks the echoed document and a non-empty vector.
    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory_env();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert_eq!(resp.document, TEST);

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    /// Generates a 256x256 image from a prompt and checks that image bytes come back.
    #[cfg(feature = "image")]
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory_env();
        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[cfg(feature = "image")]
    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    /// Transcribes `data` (MP3 bytes) with the provider's transcription model and checks
    /// that some text comes back.
    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory_env();
        let Some(client) = client.as_transcription() else {
            return;
        };

        let model = config.transcription_model.unwrap_or_else(|| {
            panic!("{} does not have transcription_model set", config.name)
        });

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        const AUDIO_PATH: &str = "examples/audio/en-us-natural-speech.mp3";

        // Read the whole file. A single `Read::read` into an empty `Vec` fills a zero-length
        // slice and would send no audio at all.
        let data = std::fs::read(AUDIO_PATH)
            .unwrap_or_else(|e| panic!("Failed to read {}: {}", AUDIO_PATH, e));

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments shared by the `Adder` and `Subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool computing `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool computing `x - y`; the tool-call tests expect the model to pick this one.
    #[derive(Deserialize, Serialize)]
    struct Subtract;
    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            let result = args.x - args.y;
            Ok(result)
        }
    }
}
1042}