// rig/client/mod.rs

//! This module provides traits for defining and creating provider clients.
//! Clients are used to create models for completion, embeddings, transcription,
//! image generation and audio generation.
//! Dyn-compatible traits are provided so that code can be written against any provider.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15
16/// The base ProviderClient trait, facilitates conversion between client types
17/// and creating a client from the environment.
18///
19/// All conversion traits must be implemented, they are automatically
20/// implemented if the respective client trait is implemented.
21pub trait ProviderClient:
22    AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
23{
24    /// Create a client from the process's environment.
25    /// Panics if an environment is improperly configured.
26    fn from_env() -> Self
27    where
28        Self: Sized;
29
30    /// A helper method to box the client.
31    fn boxed(self) -> Box<dyn ProviderClient>
32    where
33        Self: Sized + 'static,
34    {
35        Box::new(self)
36    }
37
38    /// Create a boxed client from the process's environment.
39    /// Panics if an environment is improperly configured.
40    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
41    where
42        Self: Sized,
43        Self: 'a,
44    {
45        Box::new(Self::from_env())
46    }
47
48    fn from_val(input: ProviderValue) -> Self
49    where
50        Self: Sized;
51
52    /// Create a boxed client from the process's environment.
53    /// Panics if an environment is improperly configured.
54    fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
55    where
56        Self: Sized,
57        Self: 'a,
58    {
59        Box::new(Self::from_val(input))
60    }
61}
62
/// Explicit input for constructing a provider client via
/// `ProviderClient::from_val`, as an alternative to reading the environment.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (presumably an API key).
    Simple(String),
    /// `(api_key, optional_key)`: an API key plus an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// `(api_key, version, header)`.
    ApiKeyWithVersionAndHeader(String, String, String),
}

/// A borrowed string becomes [`ProviderValue::Simple`].
impl From<&str> for ProviderValue {
    fn from(value: &str) -> Self {
        Self::from(value.to_owned())
    }
}

/// An owned string becomes [`ProviderValue::Simple`].
impl From<String> for ProviderValue {
    fn from(value: String) -> Self {
        ProviderValue::Simple(value)
    }
}

/// An `(api_key, optional_key)` pair becomes [`ProviderValue::ApiKeyWithOptionalKey`].
impl<P> From<(P, Option<P>)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from(pair: (P, Option<P>)) -> Self {
        let (key, extra) = pair;
        let owned_key = key.as_ref().to_owned();
        let owned_extra = extra.map(|inner| inner.as_ref().to_owned());
        ProviderValue::ApiKeyWithOptionalKey(owned_key, owned_extra)
    }
}

/// An `(api_key, version, header)` triple becomes
/// [`ProviderValue::ApiKeyWithVersionAndHeader`].
impl<P> From<(P, P, P)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from(triple: (P, P, P)) -> Self {
        let (key, version, header) = triple;
        let [key, version, header] = [key, version, header].map(|part| part.as_ref().to_owned());
        ProviderValue::ApiKeyWithVersionAndHeader(key, version, header)
    }
}
106
107/// Attempt to convert a ProviderClient to a CompletionClient
108pub trait AsCompletion {
109    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
110        None
111    }
112}
113
114/// Attempt to convert a ProviderClient to a TranscriptionClient
115pub trait AsTranscription {
116    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
117        None
118    }
119}
120
121/// Attempt to convert a ProviderClient to a EmbeddingsClient
122pub trait AsEmbeddings {
123    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
124        None
125    }
126}
127
/// Attempts to view a [`ProviderClient`] as an audio-generation client.
///
/// The conversion method only exists when the `audio` feature is enabled;
/// without that feature this trait has no methods.
pub trait AsAudioGeneration {
    /// Returns a boxed, dyn-compatible audio-generation client, or `None`
    /// (the default) when the provider cannot generate audio.
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
135
/// Attempts to view a [`ProviderClient`] as an image-generation client.
///
/// The conversion method only exists when the `image` feature is enabled;
/// without that feature this trait has no methods.
pub trait AsImageGeneration {
    /// Returns a boxed, dyn-compatible image-generation client, or `None`
    /// (the default) when the provider cannot generate images.
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
143
144#[cfg(not(feature = "audio"))]
145impl<T: ProviderClient> AsAudioGeneration for T {}
146
147#[cfg(not(feature = "image"))]
148impl<T: ProviderClient> AsImageGeneration for T {}
149
/// Implements the listed conversion traits for a client struct using their
/// default methods, i.e. the client reports no support for those capabilities.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, which only emit an impl
/// when the matching `audio` / `image` feature is enabled.
///
/// All paths go through `$crate`. The macro therefore works however this crate
/// is named in the caller's `Cargo.toml`, and without first `use`-ing the macro.
///
/// ```rust,ignore
/// pub struct Client;
/// impl ProviderClient for Client {
///     ...
/// }
/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
178
/// Implements `AsAudioGeneration` (with its default, `None`-returning method)
/// for a client struct. This variant is active when the `audio` feature is
/// enabled. The path goes through `$crate` so it resolves however this crate
/// is named by the caller.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// No-op variant used when the `audio` feature is disabled. In that case a
/// blanket impl in `client` already provides `AsAudioGeneration` for every
/// `ProviderClient`.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}
192
/// Implements `AsImageGeneration` (with its default, `None`-returning method)
/// for a client struct. This variant is active when the `image` feature is
/// enabled. The path goes through `$crate` so it resolves however this crate
/// is named by the caller.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// No-op variant used when the `image` feature is disabled. In that case a
/// blanket impl in `client` already provides `AsImageGeneration` for every
/// `ProviderClient`.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
206
207pub use impl_audio_generation;
208pub use impl_conversion_traits;
209pub use impl_image_generation;
210
211#[cfg(feature = "audio")]
212use crate::client::audio_generation::AudioGenerationClientDyn;
213use crate::client::completion::CompletionClientDyn;
214use crate::client::embeddings::EmbeddingsClientDyn;
215#[cfg(feature = "image")]
216use crate::client::image_generation::ImageGenerationClientDyn;
217use crate::client::transcription::TranscriptionClientDyn;
218
219#[cfg(feature = "audio")]
220pub use crate::client::audio_generation::AudioGenerationClient;
221pub use crate::client::completion::CompletionClient;
222pub use crate::client::embeddings::EmbeddingsClient;
223#[cfg(feature = "image")]
224pub use crate::client::image_generation::ImageGenerationClient;
225pub use crate::client::transcription::TranscriptionClient;
226
227#[cfg(test)]
228mod tests {
229    use crate::OneOrMany;
230    use crate::client::ProviderClient;
231    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
232    use crate::image_generation::ImageGenerationRequest;
233    use crate::message::AssistantContent;
234    use crate::providers::{
235        anthropic, azure, cohere, deepseek, galadriel, gemini, huggingface, hyperbolic, mira,
236        moonshot, openai, openrouter, together, xai,
237    };
238    use crate::streaming::StreamingCompletion;
239    use crate::tool::Tool;
240    use crate::transcription::TranscriptionRequest;
241    use futures::StreamExt;
242    use rig::message::Message;
243    use rig::providers::{groq, ollama, perplexity};
244    use serde::{Deserialize, Serialize};
245    use serde_json::json;
246    use std::fs::File;
247    use std::io::Read;
248
249    use super::ProviderValue;
250
251    struct ClientConfig {
252        name: &'static str,
253        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
254        // Not sure where we're going to be using this but I've added it for completeness
255        #[allow(dead_code)]
256        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
257        env_variable: &'static str,
258        completion_model: Option<&'static str>,
259        embeddings_model: Option<&'static str>,
260        transcription_model: Option<&'static str>,
261        image_generation_model: Option<&'static str>,
262        audio_generation_model: Option<(&'static str, &'static str)>,
263    }
264
265    impl Default for ClientConfig {
266        fn default() -> Self {
267            Self {
268                name: "",
269                factory_env: Box::new(|| panic!("Not implemented")),
270                factory_val: Box::new(|_| panic!("Not implemented")),
271                env_variable: "",
272                completion_model: None,
273                embeddings_model: None,
274                transcription_model: None,
275                image_generation_model: None,
276                audio_generation_model: None,
277            }
278        }
279    }
280
281    impl ClientConfig {
282        fn is_env_var_set(&self) -> bool {
283            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
284        }
285
286        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
287            self.factory_env.as_ref()()
288        }
289    }
290
291    fn providers() -> Vec<ClientConfig> {
292        vec![
293            ClientConfig {
294                name: "Anthropic",
295                factory_env: Box::new(anthropic::Client::from_env_boxed),
296                factory_val: Box::new(anthropic::Client::from_val_boxed),
297                env_variable: "ANTHROPIC_API_KEY",
298                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
299                ..Default::default()
300            },
301            ClientConfig {
302                name: "Cohere",
303                factory_env: Box::new(cohere::Client::from_env_boxed),
304                factory_val: Box::new(cohere::Client::from_val_boxed),
305                env_variable: "COHERE_API_KEY",
306                completion_model: Some(cohere::COMMAND_R),
307                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
308                ..Default::default()
309            },
310            ClientConfig {
311                name: "Gemini",
312                factory_env: Box::new(gemini::Client::from_env_boxed),
313                factory_val: Box::new(gemini::Client::from_val_boxed),
314                env_variable: "GEMINI_API_KEY",
315                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
316                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
317                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
318                ..Default::default()
319            },
320            ClientConfig {
321                name: "Huggingface",
322                factory_env: Box::new(huggingface::Client::from_env_boxed),
323                factory_val: Box::new(huggingface::Client::from_val_boxed),
324                env_variable: "HUGGINGFACE_API_KEY",
325                completion_model: Some(huggingface::PHI_4),
326                transcription_model: Some(huggingface::WHISPER_SMALL),
327                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
328                ..Default::default()
329            },
330            ClientConfig {
331                name: "OpenAI",
332                factory_env: Box::new(openai::Client::from_env_boxed),
333                factory_val: Box::new(openai::Client::from_val_boxed),
334                env_variable: "OPENAI_API_KEY",
335                completion_model: Some(openai::GPT_4O),
336                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
337                transcription_model: Some(openai::WHISPER_1),
338                image_generation_model: Some(openai::DALL_E_2),
339                audio_generation_model: Some((openai::TTS_1, "onyx")),
340            },
341            ClientConfig {
342                name: "OpenRouter",
343                factory_env: Box::new(openrouter::Client::from_env_boxed),
344                factory_val: Box::new(openrouter::Client::from_val_boxed),
345                env_variable: "OPENROUTER_API_KEY",
346                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
347                ..Default::default()
348            },
349            ClientConfig {
350                name: "Together",
351                factory_env: Box::new(together::Client::from_env_boxed),
352                factory_val: Box::new(together::Client::from_val_boxed),
353                env_variable: "TOGETHER_API_KEY",
354                completion_model: Some(together::ALPACA_7B),
355                embeddings_model: Some(together::BERT_BASE_UNCASED),
356                ..Default::default()
357            },
358            ClientConfig {
359                name: "XAI",
360                factory_env: Box::new(xai::Client::from_env_boxed),
361                factory_val: Box::new(xai::Client::from_val_boxed),
362                env_variable: "XAI_API_KEY",
363                completion_model: Some(xai::GROK_3_MINI),
364                embeddings_model: None,
365                ..Default::default()
366            },
367            ClientConfig {
368                name: "Azure",
369                factory_env: Box::new(azure::Client::from_env_boxed),
370                factory_val: Box::new(azure::Client::from_val_boxed),
371                env_variable: "AZURE_API_KEY",
372                completion_model: Some(azure::GPT_4O),
373                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
374                transcription_model: Some("whisper-1"),
375                image_generation_model: Some("dalle-2"),
376                audio_generation_model: Some(("tts-1", "onyx")),
377            },
378            ClientConfig {
379                name: "Deepseek",
380                factory_env: Box::new(deepseek::Client::from_env_boxed),
381                factory_val: Box::new(deepseek::Client::from_val_boxed),
382                env_variable: "DEEPSEEK_API_KEY",
383                completion_model: Some(deepseek::DEEPSEEK_CHAT),
384                ..Default::default()
385            },
386            ClientConfig {
387                name: "Galadriel",
388                factory_env: Box::new(galadriel::Client::from_env_boxed),
389                factory_val: Box::new(galadriel::Client::from_val_boxed),
390                env_variable: "GALADRIEL_API_KEY",
391                completion_model: Some(galadriel::GPT_4O),
392                ..Default::default()
393            },
394            ClientConfig {
395                name: "Groq",
396                factory_env: Box::new(groq::Client::from_env_boxed),
397                factory_val: Box::new(groq::Client::from_val_boxed),
398                env_variable: "GROQ_API_KEY",
399                completion_model: Some(groq::MIXTRAL_8X7B_32768),
400                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
401                ..Default::default()
402            },
403            ClientConfig {
404                name: "Hyperbolic",
405                factory_env: Box::new(hyperbolic::Client::from_env_boxed),
406                factory_val: Box::new(hyperbolic::Client::from_val_boxed),
407                env_variable: "HYPERBOLIC_API_KEY",
408                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
409                image_generation_model: Some(hyperbolic::SD1_5),
410                audio_generation_model: Some(("EN", "EN-US")),
411                ..Default::default()
412            },
413            ClientConfig {
414                name: "Mira",
415                factory_env: Box::new(mira::Client::from_env_boxed),
416                factory_val: Box::new(mira::Client::from_val_boxed),
417                env_variable: "MIRA_API_KEY",
418                completion_model: Some("gpt-4o"),
419                ..Default::default()
420            },
421            ClientConfig {
422                name: "Moonshot",
423                factory_env: Box::new(moonshot::Client::from_env_boxed),
424                factory_val: Box::new(moonshot::Client::from_val_boxed),
425                env_variable: "MOONSHOT_API_KEY",
426                completion_model: Some(moonshot::MOONSHOT_CHAT),
427                ..Default::default()
428            },
429            ClientConfig {
430                name: "Ollama",
431                factory_env: Box::new(ollama::Client::from_env_boxed),
432                factory_val: Box::new(ollama::Client::from_val_boxed),
433                env_variable: "OLLAMA_ENABLED",
434                completion_model: Some("llama3.1:8b"),
435                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
436                ..Default::default()
437            },
438            ClientConfig {
439                name: "Perplexity",
440                factory_env: Box::new(perplexity::Client::from_env_boxed),
441                factory_val: Box::new(perplexity::Client::from_val_boxed),
442                env_variable: "PERPLEXITY_API_KEY",
443                completion_model: Some(perplexity::SONAR),
444                ..Default::default()
445            },
446        ]
447    }
448
449    async fn test_completions_client(config: &ClientConfig) {
450        let client = config.factory_env();
451
452        let Some(client) = client.as_completion() else {
453            return;
454        };
455
456        let model = config
457            .completion_model
458            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));
459
460        let model = client.completion_model(model);
461
462        let resp = model
463            .completion_request(Message::user("Whats the capital of France?"))
464            .send()
465            .await;
466
467        assert!(
468            resp.is_ok(),
469            "[{}]: Error occurred when prompting, {}",
470            config.name,
471            resp.err().unwrap()
472        );
473
474        let resp = resp.unwrap();
475
476        match resp.choice.first() {
477            AssistantContent::Text(text) => {
478                assert!(text.text.to_lowercase().contains("paris"));
479            }
480            _ => {
481                unreachable!(
482                    "[{}]: First choice wasn't a Text message, {:?}",
483                    config.name,
484                    resp.choice.first()
485                );
486            }
487        }
488    }
489
490    #[tokio::test]
491    #[ignore]
492    async fn test_completions() {
493        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
494            test_completions_client(&p).await;
495        }
496    }
497
498    async fn test_tools_client(config: &ClientConfig) {
499        let client = config.factory_env();
500        let model = config
501            .completion_model
502            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
503
504        let Some(client) = client.as_completion() else {
505            return;
506        };
507
508        let model = client.agent(model)
509            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
510            .max_tokens(1024)
511            .tool(Adder)
512            .tool(Subtract)
513            .build();
514
515        let request = model.completion("Calculate 2 - 5", vec![]).await;
516
517        assert!(
518            request.is_ok(),
519            "[{}]: Error occurred when building prompt, {}",
520            config.name,
521            request.err().unwrap()
522        );
523
524        let resp = request.unwrap().send().await;
525
526        assert!(
527            resp.is_ok(),
528            "[{}]: Error occurred when prompting, {}",
529            config.name,
530            resp.err().unwrap()
531        );
532
533        let resp = resp.unwrap();
534
535        assert!(
536            resp.choice.iter().any(|content| match content {
537                AssistantContent::ToolCall(tc) => {
538                    if tc.function.name != Subtract::NAME {
539                        return false;
540                    }
541
542                    let arguments =
543                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
544                            .expect("Error parsing arguments");
545
546                    arguments.x == 2.0 && arguments.y == 5.0
547                }
548                _ => false,
549            }),
550            "[{}]: Model did not use the Subtract tool.",
551            config.name
552        )
553    }
554
555    #[tokio::test]
556    #[ignore]
557    async fn test_tools() {
558        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
559            test_tools_client(&p).await;
560        }
561    }
562
563    async fn test_streaming_client(config: &ClientConfig) {
564        let client = config.factory_env();
565
566        let Some(client) = client.as_completion() else {
567            return;
568        };
569
570        let model = config
571            .completion_model
572            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
573
574        let model = client.completion_model(model);
575
576        let resp = model.stream(CompletionRequest {
577            preamble: None,
578            tools: vec![],
579            documents: vec![],
580            temperature: None,
581            max_tokens: None,
582            additional_params: None,
583            chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
584        });
585
586        let mut resp = resp.await.unwrap();
587
588        let mut received_chunk = false;
589
590        while let Some(chunk) = resp.next().await {
591            received_chunk = true;
592            assert!(chunk.is_ok());
593        }
594
595        assert!(
596            received_chunk,
597            "[{}]: Failed to receive a chunk from stream",
598            config.name
599        );
600
601        for choice in resp.choice {
602            match choice {
603                AssistantContent::Text(text) => {
604                    assert!(
605                        text.text.to_lowercase().contains("paris"),
606                        "[{}]: Did not answer with Paris",
607                        config.name
608                    );
609                }
610                AssistantContent::ToolCall(_) => {}
611            }
612        }
613    }
614
615    #[tokio::test]
616    #[ignore]
617    async fn test_streaming() {
618        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
619            test_streaming_client(&provider).await;
620        }
621    }
622
623    async fn test_streaming_tools_client(config: &ClientConfig) {
624        let client = config.factory_env();
625        let model = config
626            .completion_model
627            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
628
629        let Some(client) = client.as_completion() else {
630            return;
631        };
632
633        let model = client.agent(model)
634            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
635            .max_tokens(1024)
636            .tool(Adder)
637            .tool(Subtract)
638            .build();
639
640        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;
641
642        assert!(
643            request.is_ok(),
644            "[{}]: Error occurred when building prompt, {}",
645            config.name,
646            request.err().unwrap()
647        );
648
649        let resp = request.unwrap().stream().await;
650
651        assert!(
652            resp.is_ok(),
653            "[{}]: Error occurred when prompting, {}",
654            config.name,
655            resp.err().unwrap()
656        );
657
658        let mut resp = resp.unwrap();
659
660        let mut received_chunk = false;
661
662        while let Some(chunk) = resp.next().await {
663            received_chunk = true;
664            assert!(chunk.is_ok());
665        }
666
667        assert!(
668            received_chunk,
669            "[{}]: Failed to receive a chunk from stream",
670            config.name
671        );
672
673        assert!(
674            resp.choice.iter().any(|content| match content {
675                AssistantContent::ToolCall(tc) => {
676                    if tc.function.name != Subtract::NAME {
677                        return false;
678                    }
679
680                    let arguments =
681                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
682                            .expect("Error parsing arguments");
683
684                    arguments.x == 2.0 && arguments.y == 5.0
685                }
686                _ => false,
687            }),
688            "[{}]: Model did not use the Subtract tool.",
689            config.name
690        )
691    }
692
693    #[tokio::test]
694    #[ignore]
695    async fn test_streaming_tools() {
696        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
697            test_streaming_tools_client(&p).await;
698        }
699    }
700
701    async fn test_audio_generation_client(config: &ClientConfig) {
702        let client = config.factory_env();
703
704        let Some(client) = client.as_audio_generation() else {
705            return;
706        };
707
708        let (model, voice) = config
709            .audio_generation_model
710            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));
711
712        let model = client.audio_generation_model(model);
713
714        let request = model
715            .audio_generation_request()
716            .text("Hello world!")
717            .voice(voice);
718
719        let resp = request.send().await;
720
721        assert!(
722            resp.is_ok(),
723            "[{}]: Error occurred when sending request, {}",
724            config.name,
725            resp.err().unwrap()
726        );
727
728        let resp = resp.unwrap();
729
730        assert!(
731            !resp.audio.is_empty(),
732            "[{}]: Returned audio was empty",
733            config.name
734        );
735    }
736
737    #[tokio::test]
738    #[ignore]
739    async fn test_audio_generation() {
740        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
741            test_audio_generation_client(&p).await;
742        }
743    }
744
745    fn assert_feature<F, M>(
746        name: &str,
747        feature_name: &str,
748        model_name: &str,
749        feature: Option<F>,
750        model: Option<M>,
751    ) {
752        assert_eq!(
753            feature.is_some(),
754            model.is_some(),
755            "{} has{} implemented {} but config.{} is {}.",
756            name,
757            if feature.is_some() { "" } else { "n't" },
758            feature_name,
759            model_name,
760            if model.is_some() { "some" } else { "none" }
761        );
762    }
763
764    #[test]
765    #[ignore]
766    pub fn test_polymorphism() {
767        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
768            let client = config.factory_env();
769            assert_feature(
770                config.name,
771                "AsCompletion",
772                "completion_model",
773                client.as_completion(),
774                config.completion_model,
775            );
776
777            assert_feature(
778                config.name,
779                "AsEmbeddings",
780                "embeddings_model",
781                client.as_embeddings(),
782                config.embeddings_model,
783            );
784
785            assert_feature(
786                config.name,
787                "AsTranscription",
788                "transcription_model",
789                client.as_transcription(),
790                config.transcription_model,
791            );
792
793            assert_feature(
794                config.name,
795                "AsImageGeneration",
796                "image_generation_model",
797                client.as_image_generation(),
798                config.image_generation_model,
799            );
800
801            assert_feature(
802                config.name,
803                "AsAudioGeneration",
804                "audio_generation_model",
805                client.as_audio_generation(),
806                config.audio_generation_model,
807            )
808        }
809    }
810
811    async fn test_embed_client(config: &ClientConfig) {
812        const TEST: &str = "Hello world.";
813
814        let client = config.factory_env();
815
816        let Some(client) = client.as_embeddings() else {
817            return;
818        };
819
820        let model = config.embeddings_model.unwrap();
821
822        let model = client.embedding_model(model);
823
824        let resp = model.embed_text(TEST).await;
825
826        assert!(
827            resp.is_ok(),
828            "[{}]: Error occurred when sending request, {}",
829            config.name,
830            resp.err().unwrap()
831        );
832
833        let resp = resp.unwrap();
834
835        assert_eq!(resp.document, TEST);
836
837        assert!(
838            !resp.vec.is_empty(),
839            "[{}]: Returned embed was empty",
840            config.name
841        );
842    }
843
844    #[tokio::test]
845    #[ignore]
846    async fn test_embed() {
847        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
848            test_embed_client(&config).await;
849        }
850    }
851
852    async fn test_image_generation_client(config: &ClientConfig) {
853        let client = config.factory_env();
854        let Some(client) = client.as_image_generation() else {
855            return;
856        };
857
858        let model = config.image_generation_model.unwrap();
859
860        let model = client.image_generation_model(model);
861
862        let resp = model
863            .image_generation(ImageGenerationRequest {
864                prompt: "A castle sitting on a large hill.".to_string(),
865                width: 256,
866                height: 256,
867                additional_params: None,
868            })
869            .await;
870
871        assert!(
872            resp.is_ok(),
873            "[{}]: Error occurred when sending request, {}",
874            config.name,
875            resp.err().unwrap()
876        );
877
878        let resp = resp.unwrap();
879
880        assert!(
881            !resp.image.is_empty(),
882            "[{}]: Generated image was empty",
883            config.name
884        );
885    }
886
887    #[tokio::test]
888    #[ignore]
889    async fn test_image_generation() {
890        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
891            test_image_generation_client(&config).await;
892        }
893    }
894
895    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
896        let client = config.factory_env();
897        let Some(client) = client.as_transcription() else {
898            return;
899        };
900
901        let model = config.image_generation_model.unwrap();
902
903        let model = client.transcription_model(model);
904
905        let resp = model
906            .transcription(TranscriptionRequest {
907                data,
908                filename: "audio.mp3".to_string(),
909                language: "en".to_string(),
910                prompt: None,
911                temperature: None,
912                additional_params: None,
913            })
914            .await;
915
916        assert!(
917            resp.is_ok(),
918            "[{}]: Error occurred when sending request, {}",
919            config.name,
920            resp.err().unwrap()
921        );
922
923        let resp = resp.unwrap();
924
925        assert!(
926            !resp.text.is_empty(),
927            "[{}]: Returned transcription was empty",
928            config.name
929        );
930    }
931
932    #[tokio::test]
933    #[ignore]
934    async fn test_transcription() {
935        let mut file = File::open("examples/audio/en-us-natural-speech.mp3").unwrap();
936
937        let mut data = Vec::new();
938        let _ = file.read(&mut data);
939
940        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
941            test_transcription_client(&config, data.clone()).await;
942        }
943    }
944
945    #[derive(Deserialize)]
946    struct OperationArgs {
947        x: f32,
948        y: f32,
949    }
950
951    #[derive(Debug, thiserror::Error)]
952    #[error("Math error")]
953    struct MathError;
954
955    #[derive(Deserialize, Serialize)]
956    struct Adder;
957    impl Tool for Adder {
958        const NAME: &'static str = "add";
959
960        type Error = MathError;
961        type Args = OperationArgs;
962        type Output = f32;
963
964        async fn definition(&self, _prompt: String) -> ToolDefinition {
965            ToolDefinition {
966                name: "add".to_string(),
967                description: "Add x and y together".to_string(),
968                parameters: json!({
969                    "type": "object",
970                    "properties": {
971                        "x": {
972                            "type": "number",
973                            "description": "The first number to add"
974                        },
975                        "y": {
976                            "type": "number",
977                            "description": "The second number to add"
978                        }
979                    }
980                }),
981            }
982        }
983
984        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
985            println!("[tool-call] Adding {} and {}", args.x, args.y);
986            let result = args.x + args.y;
987            Ok(result)
988        }
989    }
990
991    #[derive(Deserialize, Serialize)]
992    struct Subtract;
993    impl Tool for Subtract {
994        const NAME: &'static str = "subtract";
995
996        type Error = MathError;
997        type Args = OperationArgs;
998        type Output = f32;
999
1000        async fn definition(&self, _prompt: String) -> ToolDefinition {
1001            serde_json::from_value(json!({
1002                "name": "subtract",
1003                "description": "Subtract y from x (i.e.: x - y)",
1004                "parameters": {
1005                    "type": "object",
1006                    "properties": {
1007                        "x": {
1008                            "type": "number",
1009                            "description": "The number to subtract from"
1010                        },
1011                        "y": {
1012                            "type": "number",
1013                            "description": "The number to subtract"
1014                        }
1015                    }
1016                }
1017            }))
1018            .expect("Tool Definition")
1019        }
1020
1021        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1022            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
1023            let result = args.x - args.y;
1024            Ok(result)
1025        }
1026    }
1027}