rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15
16/// The base ProviderClient trait, facilitates conversion between client types
17/// and creating a client from the environment.
18///
19/// All conversion traits must be implemented, they are automatically
20/// implemented if the respective client trait is implemented.
21pub trait ProviderClient:
22    AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
23{
24    /// Create a client from the process's environment.
25    /// Panics if an environment is improperly configured.
26    fn from_env() -> Self
27    where
28        Self: Sized;
29
30    /// A helper method to box the client.
31    fn boxed(self) -> Box<dyn ProviderClient>
32    where
33        Self: Sized + 'static,
34    {
35        Box::new(self)
36    }
37
38    /// Create a boxed client from the process's environment.
39    /// Panics if an environment is improperly configured.
40    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
41    where
42        Self: Sized,
43        Self: 'a,
44    {
45        Box::new(Self::from_env())
46    }
47}
48
49/// Attempt to convert a ProviderClient to a CompletionClient
50pub trait AsCompletion {
51    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
52        None
53    }
54}
55
56/// Attempt to convert a ProviderClient to a TranscriptionClient
57pub trait AsTranscription {
58    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
59        None
60    }
61}
62
63/// Attempt to convert a ProviderClient to a EmbeddingsClient
64pub trait AsEmbeddings {
65    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
66        None
67    }
68}
69
/// Optional conversion from a provider client into an audio-generation client.
///
/// The conversion method only exists when the `audio` feature is enabled.
/// Without that feature the trait has no methods.
pub trait AsAudioGeneration {
    /// Returns this client as a boxed [`AudioGenerationClientDyn`], or `None`
    /// when the provider has no audio-generation support.
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
77
/// Optional conversion from a provider client into an image-generation client.
///
/// The conversion method only exists when the `image` feature is enabled.
/// Without that feature the trait has no methods.
pub trait AsImageGeneration {
    /// Returns this client as a boxed [`ImageGenerationClientDyn`], or `None`
    /// when the provider has no image-generation support.
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
85
86#[cfg(not(feature = "audio"))]
87impl<T: ProviderClient> AsAudioGeneration for T {}
88
89#[cfg(not(feature = "image"))]
90impl<T: ProviderClient> AsImageGeneration for T {}
91
/// Implements the capability-conversion traits for a provider client struct.
///
/// Each listed trait gets an empty `impl`, so the struct keeps the trait's
/// default methods. `AsAudioGeneration` and `AsImageGeneration` are routed
/// through `impl_audio_generation!` / `impl_image_generation!`. Those expand
/// to nothing when the `audio` / `image` feature is disabled, because the
/// traits are then covered by blanket implementations.
///
/// All generated paths go through `$crate`. The macro therefore works inside
/// this crate, from downstream crates, when the dependency is renamed, and
/// when invoked by path without being imported.
///
/// ```ignore
/// pub struct Client;
/// impl ProviderClient for Client {
///     ...
/// }
/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            // Qualified with `$crate` so the recursion resolves even when the
            // caller never imported this macro by name.
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
120
/// Implements `AsAudioGeneration` for `$struct_` with the trait's default method.
///
/// `impl_conversion_traits!` uses this variant when the `audio` feature is enabled.
/// The path goes through `$crate` so it resolves regardless of how the crate is named
/// by the caller.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}
128
/// No-op variant used when the `audio` feature is disabled.
///
/// In that configuration `AsAudioGeneration` is blanket-implemented for every
/// `ProviderClient`, so emitting an impl here would conflict with it.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($client_type:ident) => {};
}
134
/// Implements `AsImageGeneration` for `$struct_` with the trait's default method.
///
/// `impl_conversion_traits!` uses this variant when the `image` feature is enabled.
/// The path goes through `$crate` so it resolves regardless of how the crate is named
/// by the caller.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}
142
/// No-op variant used when the `image` feature is disabled.
///
/// In that configuration `AsImageGeneration` is blanket-implemented for every
/// `ProviderClient`, so emitting an impl here would conflict with it.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($client_type:ident) => {};
}
148
149pub use impl_audio_generation;
150pub use impl_conversion_traits;
151pub use impl_image_generation;
152
153#[cfg(feature = "audio")]
154use crate::client::audio_generation::AudioGenerationClientDyn;
155use crate::client::completion::CompletionClientDyn;
156use crate::client::embeddings::EmbeddingsClientDyn;
157#[cfg(feature = "image")]
158use crate::client::image_generation::ImageGenerationClientDyn;
159use crate::client::transcription::TranscriptionClientDyn;
160
161#[cfg(feature = "audio")]
162pub use crate::client::audio_generation::AudioGenerationClient;
163pub use crate::client::completion::CompletionClient;
164pub use crate::client::embeddings::EmbeddingsClient;
165#[cfg(feature = "image")]
166pub use crate::client::image_generation::ImageGenerationClient;
167pub use crate::client::transcription::TranscriptionClient;
168
/// Live integration tests that exercise provider clients against their real APIs.
///
/// Every test is `#[ignore]`d and must be run explicitly (`cargo test -- --ignored`).
/// A provider is only exercised when its `env_variable` is set, or when it has none.
/// The image- and audio-generation checks additionally require the `image` / `audio`
/// features.
#[cfg(test)]
mod tests {
    use crate::OneOrMany;
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    #[cfg(feature = "image")]
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// System prompt shared by the (streaming) tool-calling tests.
    const CALCULATOR_PREAMBLE: &str = "You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.";

    /// One provider under test and the models it is expected to support.
    ///
    /// A `Some` model means the matching `As*` conversion must succeed
    /// (checked by `test_polymorphism`).
    struct ClientConfig {
        /// Provider name used in assertion messages.
        name: &'static str,
        /// Builds a fresh client for this provider.
        factory: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        /// Variable that must be set for the provider to be tested; empty means "always".
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        // Only read by image-generation checks, which need the `image` feature.
        #[cfg_attr(not(feature = "image"), allow(dead_code))]
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` pair. Only read by audio checks, which need the `audio` feature.
        #[cfg_attr(not(feature = "audio"), allow(dead_code))]
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        fn default() -> Self {
            Self {
                name: "",
                factory: Box::new(|| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// Whether this provider should be tested in the current environment.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a new client via the stored factory.
        fn factory(&self) -> Box<dyn ProviderClient + '_> {
            self.factory.as_ref()()
        }
    }

    /// All providers covered by these tests.
    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory: Box::new(anthropic::Client::from_env_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory: Box::new(cohere::Client::from_env_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory: Box::new(gemini::Client::from_env_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory: Box::new(huggingface::Client::from_env_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory: Box::new(openai::Client::from_env_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory: Box::new(openrouter::Client::from_env_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory: Box::new(together::Client::from_env_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory: Box::new(xai::Client::from_env_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory: Box::new(azure::Client::from_env_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory: Box::new(deepseek::Client::from_env_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory: Box::new(galadriel::Client::from_env_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory: Box::new(groq::Client::from_env_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory: Box::new(hyperbolic::Client::from_env_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory: Box::new(mira::Client::from_env_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory: Box::new(moonshot::Client::from_env_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory: Box::new(ollama::Client::from_env_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory: Box::new(perplexity::Client::from_env_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Returns `true` when `content` is a `Subtract` tool call with `x = 2` and `y = 5`.
    fn is_expected_subtract_call(content: &AssistantContent) -> bool {
        let AssistantContent::ToolCall(tc) = content else {
            return false;
        };

        if tc.function.name != Subtract::NAME {
            return false;
        }

        let arguments = serde_json::from_value::<OperationArgs>(tc.function.arguments.clone())
            .expect("Error parsing arguments");

        arguments.x == 2.0 && arguments.y == 5.0
    }

    /// Prompts the completion model once and checks that the answer mentions Paris.
    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory();

        // Providers without completion support are skipped.
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        match resp.choice.first() {
            AssistantContent::Text(text) => {
                assert!(
                    text.text.to_lowercase().contains("paris"),
                    "[{}]: Did not answer with Paris",
                    config.name
                );
            }
            // A model may legitimately answer with other content, so this is a
            // test failure rather than an impossible branch.
            other => panic!(
                "[{}]: First choice wasn't a Text message, {:?}",
                config.name, other
            ),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    /// Checks that an agent with calculator tools calls `Subtract` for "2 - 5".
    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory();

        // Check support before requiring a model, as the other completion tests do.
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent.completion("Calculate 2 - 5", vec![]).await;

        assert!(
            request.is_ok(),
            "[{}]: Error occurred when building prompt, {}",
            config.name,
            request.err().unwrap()
        );

        let resp = request.unwrap().send().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    /// Streams a completion and checks that every chunk succeeds and the text mentions Paris.
    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let resp = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when streaming, {}",
            config.name,
            resp.err().unwrap()
        );

        let mut resp = resp.unwrap();

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Error occurred while reading the stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    /// Streaming variant of `test_tools_client`.
    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent.stream_completion("Calculate 2 - 5", vec![]).await;

        assert!(
            request.is_ok(),
            "[{}]: Error occurred when building prompt, {}",
            config.name,
            request.err().unwrap()
        );

        let resp = request.unwrap().stream().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when prompting, {}",
            config.name,
            resp.err().unwrap()
        );

        let mut resp = resp.unwrap();

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Error occurred while reading the stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    /// Generates speech for a short text and checks that audio bytes come back.
    #[cfg(feature = "audio")]
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let request = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice);

        let resp = request.send().await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[cfg(feature = "audio")]
    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a capability is available exactly when its model is configured.
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    #[test]
    #[ignore]
    fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            // `as_image_generation` only exists with the `image` feature.
            #[cfg(feature = "image")]
            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            // `as_audio_generation` only exists with the `audio` feature.
            #[cfg(feature = "audio")]
            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    /// Embeds a short text and checks that the document is echoed with a non-empty vector.
    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert_eq!(
            resp.document, TEST,
            "[{}]: Returned document does not match the input",
            config.name
        );

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    /// Generates a small image and checks that image bytes come back.
    #[cfg(feature = "image")]
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[cfg(feature = "image")]
    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    /// Transcribes `data` (MP3 bytes) and checks that some text comes back.
    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory();
        let Some(client) = client.as_transcription() else {
            return;
        };

        let model = config
            .transcription_model
            .unwrap_or_else(|| panic!("{} does not have transcription_model set", config.name));

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await;

        assert!(
            resp.is_ok(),
            "[{}]: Error occurred when sending request, {}",
            config.name,
            resp.err().unwrap()
        );

        let resp = resp.unwrap();

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        // Read the whole file. A single `Read::read` into an empty Vec reads nothing.
        let data = std::fs::read("examples/audio/en-us-natural-speech.mp3")
            .expect("Failed to read the transcription test audio file");

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments shared by the `Adder` and `Subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool that adds `x` and `y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            Ok(args.x + args.y)
        }
    }

    /// Test tool that computes `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtract;
    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            Ok(args.x - args.y)
        }
    }
}