// rig/client/mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15
16/// The base ProviderClient trait, facilitates conversion between client types
17/// and creating a client from the environment.
18///
19/// All conversion traits must be implemented, they are automatically
20/// implemented if the respective client trait is implemented.
21pub trait ProviderClient:
22    AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
23{
24    /// Create a client from the process's environment.
25    /// Panics if an environment is improperly configured.
26    fn from_env() -> Self
27    where
28        Self: Sized;
29
30    /// A helper method to box the client.
31    fn boxed(self) -> Box<dyn ProviderClient>
32    where
33        Self: Sized + 'static,
34    {
35        Box::new(self)
36    }
37
38    /// Create a boxed client from the process's environment.
39    /// Panics if an environment is improperly configured.
40    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
41    where
42        Self: Sized,
43        Self: 'a,
44    {
45        Box::new(Self::from_env())
46    }
47}
48
49/// Attempt to convert a ProviderClient to a CompletionClient
50pub trait AsCompletion {
51    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
52        None
53    }
54}
55
56/// Attempt to convert a ProviderClient to a TranscriptionClient
57pub trait AsTranscription {
58    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
59        None
60    }
61}
62
63/// Attempt to convert a ProviderClient to a EmbeddingsClient
64pub trait AsEmbeddings {
65    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
66        None
67    }
68}
69
/// Attempts to view a `ProviderClient` as an audio generation client.
///
/// The conversion method only exists when the `audio` feature is enabled; without that
/// feature the trait has no methods.
pub trait AsAudioGeneration {
    /// Returns a type-erased audio generation client, or `None` when audio generation is
    /// unsupported (the default).
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
77
/// Attempts to view a `ProviderClient` as an image generation client.
///
/// The conversion method only exists when the `image` feature is enabled; without that
/// feature the trait has no methods.
pub trait AsImageGeneration {
    /// Returns a type-erased image generation client, or `None` when image generation is
    /// unsupported (the default).
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
85
86#[cfg(not(feature = "audio"))]
87impl<T: ProviderClient> AsAudioGeneration for T {}
88
89#[cfg(not(feature = "image"))]
90impl<T: ProviderClient> AsImageGeneration for T {}
91
/// Implements the listed `As*` conversion traits for a client struct using each trait's
/// default ("unsupported") implementation.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through `impl_audio_generation!`
/// and `impl_image_generation!`. When the corresponding cargo feature is disabled, those
/// helpers expand to nothing.
///
/// All paths go through `$crate`, so the macro works when invoked by path
/// (`rig::client::impl_conversion_traits!`) and when the dependency is renamed.
///
/// ```ignore
/// pub struct Client;
/// impl ProviderClient for Client {
///     ...
/// }
/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            // Recurse through `$crate` so the macro need not be in scope by its bare name.
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
120
/// Implements `AsAudioGeneration` for `$struct_` with its default (unsupported) behaviour.
///
/// Uses `$crate` so the expansion resolves even if the dependency is renamed downstream.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// No-op without the `audio` feature: a blanket impl already covers `AsAudioGeneration`.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}
134
/// Implements `AsImageGeneration` for `$struct_` with its default (unsupported) behaviour.
///
/// Uses `$crate` so the expansion resolves even if the dependency is renamed downstream.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// No-op without the `image` feature: a blanket impl already covers `AsImageGeneration`.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
148
149pub use impl_audio_generation;
150pub use impl_conversion_traits;
151pub use impl_image_generation;
152
153#[cfg(feature = "audio")]
154use crate::client::audio_generation::AudioGenerationClientDyn;
155use crate::client::completion::CompletionClientDyn;
156use crate::client::embeddings::EmbeddingsClientDyn;
157#[cfg(feature = "image")]
158use crate::client::image_generation::ImageGenerationClientDyn;
159use crate::client::transcription::TranscriptionClientDyn;
160
161#[cfg(feature = "audio")]
162pub use crate::client::audio_generation::AudioGenerationClient;
163pub use crate::client::completion::CompletionClient;
164pub use crate::client::embeddings::EmbeddingsClient;
165#[cfg(feature = "image")]
166pub use crate::client::image_generation::ImageGenerationClient;
167pub use crate::client::transcription::TranscriptionClient;
168
#[cfg(test)]
mod tests {
    // Live integration tests that exercise every provider through the dyn-compatible
    // `ProviderClient` API. They are all `#[ignore]`d because they call real provider APIs.
    // A provider is only tested when its `env_variable` is set.

    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    #[cfg(feature = "image")]
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use crate::OneOrMany;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// Preamble shared by the (streaming) tool-calling tests.
    const CALCULATOR_PREAMBLE: &str = "You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.";

    /// Per-provider test configuration.
    ///
    /// A `None` model means the provider is not expected to support that capability;
    /// `test_polymorphism` checks this against the client's `as_*` conversions.
    struct ClientConfig {
        name: &'static str,
        factory: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        /// Environment variable that must be set for this provider to be tested.
        /// An empty string means the provider is always tested.
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        // Only read by image tests, which need the `image` feature.
        #[cfg_attr(not(feature = "image"), allow(dead_code))]
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` pair. Only read by audio tests, which need the `audio` feature.
        #[cfg_attr(not(feature = "audio"), allow(dead_code))]
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        fn default() -> Self {
            Self {
                name: "",
                factory: Box::new(|| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// `true` when no env var is required or the required one is present.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a fresh client through the configured factory.
        fn factory(&self) -> Box<dyn ProviderClient + '_> {
            (self.factory)()
        }
    }

    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory: Box::new(anthropic::Client::from_env_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory: Box::new(cohere::Client::from_env_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory: Box::new(gemini::Client::from_env_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory: Box::new(huggingface::Client::from_env_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory: Box::new(openai::Client::from_env_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory: Box::new(openrouter::Client::from_env_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory: Box::new(together::Client::from_env_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory: Box::new(xai::Client::from_env_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                embeddings_model: Some(xai::EMBEDDING_V1),
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory: Box::new(azure::Client::from_env_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory: Box::new(deepseek::Client::from_env_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory: Box::new(galadriel::Client::from_env_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory: Box::new(groq::Client::from_env_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory: Box::new(hyperbolic::Client::from_env_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory: Box::new(mira::Client::from_env_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory: Box::new(moonshot::Client::from_env_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory: Box::new(ollama::Client::from_env_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory: Box::new(perplexity::Client::from_env_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Providers whose credentials are available in the current environment.
    fn configured_providers() -> impl Iterator<Item = ClientConfig> {
        providers().into_iter().filter(ClientConfig::is_env_var_set)
    }

    /// `true` if `content` is a call to the `subtract` tool with `x = 2` and `y = 5`.
    fn is_expected_subtract_call(content: &AssistantContent) -> bool {
        let AssistantContent::ToolCall(tc) = content else {
            return false;
        };

        if tc.function.name != Subtract::NAME {
            return false;
        }

        let arguments = serde_json::from_value::<OperationArgs>(tc.function.arguments.clone())
            .expect("Error parsing arguments");

        arguments.x == 2.0 && arguments.y == 5.0
    }

    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when prompting, {}", config.name, err)
            });

        match resp.choice.first() {
            AssistantContent::Text(text) => assert!(
                text.text.to_lowercase().contains("paris"),
                "[{}]: Did not answer with Paris",
                config.name
            ),
            other => panic!(
                "[{}]: First choice wasn't a Text message, {:?}",
                config.name, other
            ),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for config in configured_providers() {
            test_completions_client(&config).await;
        }
    }

    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory();

        // Check the capability first so providers without completion support are skipped
        // rather than failing on a missing model.
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when building prompt, {}", config.name, err)
            });

        let resp = request.send().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for config in configured_providers() {
            test_tools_client(&config).await;
        }
    }

    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let mut resp = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when starting the stream, {:?}", config.name, err)
            });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            if let Err(err) = chunk {
                panic!("[{}]: Error occurred while streaming, {:?}", config.name, err);
            }
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        // The aggregated choice is only complete once the stream has been fully drained.
        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for config in configured_providers() {
            test_streaming_client(&config).await;
        }
    }

    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .stream_completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when building prompt, {}", config.name, err)
            });

        let mut resp = request.stream().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            if let Err(err) = chunk {
                panic!("[{}]: Error occurred while streaming, {:?}", config.name, err);
            }
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for config in configured_providers() {
            test_streaming_tools_client(&config).await;
        }
    }

    // `as_audio_generation` only exists with the `audio` feature.
    #[cfg(feature = "audio")]
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let resp = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice)
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when sending request, {}", config.name, err)
            });

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[cfg(feature = "audio")]
    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for config in configured_providers() {
            test_audio_generation_client(&config).await;
        }
    }

    /// Asserts that a capability is implemented exactly when its model is configured.
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    #[test]
    #[ignore]
    fn test_polymorphism() {
        for config in configured_providers() {
            let client = config.factory();

            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            #[cfg(feature = "image")]
            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            #[cfg(feature = "audio")]
            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when sending request, {}", config.name, err)
        });

        assert_eq!(
            resp.document, TEST,
            "[{}]: Embedded document does not match the input",
            config.name
        );

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in configured_providers() {
            test_embed_client(&config).await;
        }
    }

    // `as_image_generation` only exists with the `image` feature.
    #[cfg(feature = "image")]
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when sending request, {}", config.name, err)
            });

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[cfg(feature = "image")]
    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in configured_providers() {
            test_image_generation_client(&config).await;
        }
    }

    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory();

        let Some(client) = client.as_transcription() else {
            return;
        };

        // Use the transcription model; the image model is unrelated and may be unset.
        let model = config
            .transcription_model
            .unwrap_or_else(|| panic!("{} does not have transcription_model set", config.name));

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when sending request, {}", config.name, err)
            });

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        // Read the whole file. `Read::read` into an empty `Vec` reads zero bytes.
        let data = std::fs::read("examples/audio/en-us-natural-speech.mp3")
            .expect("Failed to read examples/audio/en-us-natural-speech.mp3");

        for config in configured_providers() {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    },
                    "required": ["x", "y"]
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            Ok(args.x + args.y)
        }
    }

    #[derive(Deserialize, Serialize)]
    struct Subtract;

    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Subtract y from x (i.e.: x - y)".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    },
                    "required": ["x", "y"]
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            Ok(args.x - args.y)
        }
    }
}