rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13#[cfg(feature = "derive")]
14pub use rig_derive::ProviderClient;
15use std::fmt::Debug;
16use thiserror::Error;
17
18#[derive(Debug, Error)]
19#[non_exhaustive]
20pub enum ClientBuilderError {
21    #[error("reqwest error: {0}")]
22    HttpError(
23        #[from]
24        #[source]
25        reqwest::Error,
26    ),
27    #[error("invalid property: {0}")]
28    InvalidProperty(&'static str),
29}
30
31/// The base ProviderClient trait, facilitates conversion between client types
32/// and creating a client from the environment.
33///
34/// All conversion traits must be implemented, they are automatically
35/// implemented if the respective client trait is implemented.
36pub trait ProviderClient:
37    AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
38{
39    /// Create a client from the process's environment.
40    /// Panics if an environment is improperly configured.
41    fn from_env() -> Self
42    where
43        Self: Sized;
44
45    /// A helper method to box the client.
46    fn boxed(self) -> Box<dyn ProviderClient>
47    where
48        Self: Sized + 'static,
49    {
50        Box::new(self)
51    }
52
53    /// Create a boxed client from the process's environment.
54    /// Panics if an environment is improperly configured.
55    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
56    where
57        Self: Sized,
58        Self: 'a,
59    {
60        Box::new(Self::from_env())
61    }
62
63    fn from_val(input: ProviderValue) -> Self
64    where
65        Self: Sized;
66
67    /// Create a boxed client from the process's environment.
68    /// Panics if an environment is improperly configured.
69    fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
70    where
71        Self: Sized,
72        Self: 'a,
73    {
74        Box::new(Self::from_val(input))
75    }
76}
77
/// Explicit configuration used to construct a client via
/// [`ProviderClient::from_val`] instead of environment variables.
///
/// Which variant a particular provider expects is provider-specific.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (presumably the API key for most providers).
    Simple(String),
    /// An API key together with an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}

/// A borrowed string becomes [`ProviderValue::Simple`].
impl From<&str> for ProviderValue {
    fn from(value: &str) -> Self {
        ProviderValue::Simple(String::from(value))
    }
}

/// An owned string becomes [`ProviderValue::Simple`] without re-allocating.
impl From<String> for ProviderValue {
    fn from(value: String) -> Self {
        ProviderValue::Simple(value)
    }
}

/// `(api_key, optional_key)` becomes [`ProviderValue::ApiKeyWithOptionalKey`].
impl<P> From<(P, Option<P>)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from(pair: (P, Option<P>)) -> Self {
        let (key, extra) = pair;
        let key = key.as_ref().to_owned();
        let extra = extra.map(|second| second.as_ref().to_owned());
        ProviderValue::ApiKeyWithOptionalKey(key, extra)
    }
}

/// `(api_key, version, header)` becomes
/// [`ProviderValue::ApiKeyWithVersionAndHeader`].
impl<P> From<(P, P, P)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from(triple: (P, P, P)) -> Self {
        let (key, version, header) = triple;
        let own = |part: &P| part.as_ref().to_owned();
        ProviderValue::ApiKeyWithVersionAndHeader(own(&key), own(&version), own(&header))
    }
}
121
122/// Attempt to convert a ProviderClient to a CompletionClient
123pub trait AsCompletion {
124    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
125        None
126    }
127}
128
129/// Attempt to convert a ProviderClient to a TranscriptionClient
130pub trait AsTranscription {
131    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
132        None
133    }
134}
135
136/// Attempt to convert a ProviderClient to a EmbeddingsClient
137pub trait AsEmbeddings {
138    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
139        None
140    }
141}
142
143/// Attempt to convert a ProviderClient to a AudioGenerationClient
144pub trait AsAudioGeneration {
145    #[cfg(feature = "audio")]
146    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
147        None
148    }
149}
150
151/// Attempt to convert a ProviderClient to a ImageGenerationClient
152pub trait AsImageGeneration {
153    #[cfg(feature = "image")]
154    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
155        None
156    }
157}
158
159/// Attempt to convert a ProviderClient to a VerifyClient
160pub trait AsVerify {
161    fn as_verify(&self) -> Option<Box<dyn VerifyClientDyn>> {
162        None
163    }
164}
165
166#[cfg(not(feature = "audio"))]
167impl<T: ProviderClient> AsAudioGeneration for T {}
168
169#[cfg(not(feature = "image"))]
170impl<T: ProviderClient> AsImageGeneration for T {}
171
/// Implements the conversion traits for a given struct.
///
/// Each listed trait gets an empty `impl`, so the struct uses that trait's
/// default methods (which return `None`). `AsAudioGeneration` and
/// `AsImageGeneration` are routed through `impl_audio_generation!` and
/// `impl_image_generation!`. Those expand to nothing when rig is built without
/// the corresponding `audio`/`image` feature.
///
/// All paths go through `$crate`, so the macro works regardless of how the
/// calling crate names (or renames) rig.
///
/// ```rust,ignore
/// pub struct Client;
/// impl ProviderClient for Client {
///     ...
/// }
/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    // Entry point: the recursive call is `$crate`-qualified so it resolves even
    // when the caller has not imported this macro by name.
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
200
// The `audio`/`image` cfgs below sit on the macro *definitions*. The features
// rig itself was compiled with therefore decide whether these expand to a trait
// impl or to nothing; the calling crate's own features play no part.

/// Implements `AsAudioGeneration` (using its default methods) for `$struct_`.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// No-op variant used when rig is built without the `audio` feature. In that
/// build, `AsAudioGeneration` is blanket-implemented for every `ProviderClient`.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}

/// Implements `AsImageGeneration` (using its default methods) for `$struct_`.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// No-op variant used when rig is built without the `image` feature. In that
/// build, `AsImageGeneration` is blanket-implemented for every `ProviderClient`.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
228
229pub use impl_audio_generation;
230pub use impl_conversion_traits;
231pub use impl_image_generation;
232
233#[cfg(feature = "audio")]
234use crate::client::audio_generation::AudioGenerationClientDyn;
235use crate::client::completion::CompletionClientDyn;
236use crate::client::embeddings::EmbeddingsClientDyn;
237#[cfg(feature = "image")]
238use crate::client::image_generation::ImageGenerationClientDyn;
239use crate::client::transcription::TranscriptionClientDyn;
240use crate::client::verify::VerifyClientDyn;
241
242#[cfg(feature = "audio")]
243pub use crate::client::audio_generation::AudioGenerationClient;
244pub use crate::client::completion::CompletionClient;
245pub use crate::client::embeddings::EmbeddingsClient;
246#[cfg(feature = "image")]
247pub use crate::client::image_generation::ImageGenerationClient;
248pub use crate::client::transcription::TranscriptionClient;
249pub use crate::client::verify::{VerifyClient, VerifyError};
250
251#[cfg(test)]
252mod tests {
253    use crate::OneOrMany;
254    use crate::client::ProviderClient;
255    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
256    use crate::image_generation::ImageGenerationRequest;
257    use crate::message::AssistantContent;
258    use crate::providers::{
259        anthropic, azure, cohere, deepseek, galadriel, gemini, huggingface, hyperbolic, mira,
260        moonshot, openai, openrouter, together, xai,
261    };
262    use crate::streaming::StreamingCompletion;
263    use crate::tool::Tool;
264    use crate::transcription::TranscriptionRequest;
265    use futures::StreamExt;
266    use rig::message::Message;
267    use rig::providers::{groq, ollama, perplexity};
268    use serde::{Deserialize, Serialize};
269    use serde_json::json;
270    use std::fs::File;
271    use std::io::Read;
272
273    use super::ProviderValue;
274
275    struct ClientConfig {
276        name: &'static str,
277        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
278        // Not sure where we're going to be using this but I've added it for completeness
279        #[allow(dead_code)]
280        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
281        env_variable: &'static str,
282        completion_model: Option<&'static str>,
283        embeddings_model: Option<&'static str>,
284        transcription_model: Option<&'static str>,
285        image_generation_model: Option<&'static str>,
286        audio_generation_model: Option<(&'static str, &'static str)>,
287    }
288
289    impl Default for ClientConfig {
290        fn default() -> Self {
291            Self {
292                name: "",
293                factory_env: Box::new(|| panic!("Not implemented")),
294                factory_val: Box::new(|_| panic!("Not implemented")),
295                env_variable: "",
296                completion_model: None,
297                embeddings_model: None,
298                transcription_model: None,
299                image_generation_model: None,
300                audio_generation_model: None,
301            }
302        }
303    }
304
305    impl ClientConfig {
306        fn is_env_var_set(&self) -> bool {
307            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
308        }
309
310        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
311            self.factory_env.as_ref()()
312        }
313    }
314
315    fn providers() -> Vec<ClientConfig> {
316        vec![
317            ClientConfig {
318                name: "Anthropic",
319                factory_env: Box::new(anthropic::Client::from_env_boxed),
320                factory_val: Box::new(anthropic::Client::from_val_boxed),
321                env_variable: "ANTHROPIC_API_KEY",
322                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
323                ..Default::default()
324            },
325            ClientConfig {
326                name: "Cohere",
327                factory_env: Box::new(cohere::Client::from_env_boxed),
328                factory_val: Box::new(cohere::Client::from_val_boxed),
329                env_variable: "COHERE_API_KEY",
330                completion_model: Some(cohere::COMMAND_R),
331                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
332                ..Default::default()
333            },
334            ClientConfig {
335                name: "Gemini",
336                factory_env: Box::new(gemini::Client::from_env_boxed),
337                factory_val: Box::new(gemini::Client::from_val_boxed),
338                env_variable: "GEMINI_API_KEY",
339                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
340                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
341                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
342                ..Default::default()
343            },
344            ClientConfig {
345                name: "Huggingface",
346                factory_env: Box::new(huggingface::Client::from_env_boxed),
347                factory_val: Box::new(huggingface::Client::from_val_boxed),
348                env_variable: "HUGGINGFACE_API_KEY",
349                completion_model: Some(huggingface::PHI_4),
350                transcription_model: Some(huggingface::WHISPER_SMALL),
351                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
352                ..Default::default()
353            },
354            ClientConfig {
355                name: "OpenAI",
356                factory_env: Box::new(openai::Client::from_env_boxed),
357                factory_val: Box::new(openai::Client::from_val_boxed),
358                env_variable: "OPENAI_API_KEY",
359                completion_model: Some(openai::GPT_4O),
360                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
361                transcription_model: Some(openai::WHISPER_1),
362                image_generation_model: Some(openai::DALL_E_2),
363                audio_generation_model: Some((openai::TTS_1, "onyx")),
364            },
365            ClientConfig {
366                name: "OpenRouter",
367                factory_env: Box::new(openrouter::Client::from_env_boxed),
368                factory_val: Box::new(openrouter::Client::from_val_boxed),
369                env_variable: "OPENROUTER_API_KEY",
370                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
371                ..Default::default()
372            },
373            ClientConfig {
374                name: "Together",
375                factory_env: Box::new(together::Client::from_env_boxed),
376                factory_val: Box::new(together::Client::from_val_boxed),
377                env_variable: "TOGETHER_API_KEY",
378                completion_model: Some(together::ALPACA_7B),
379                embeddings_model: Some(together::BERT_BASE_UNCASED),
380                ..Default::default()
381            },
382            ClientConfig {
383                name: "XAI",
384                factory_env: Box::new(xai::Client::from_env_boxed),
385                factory_val: Box::new(xai::Client::from_val_boxed),
386                env_variable: "XAI_API_KEY",
387                completion_model: Some(xai::GROK_3_MINI),
388                embeddings_model: None,
389                ..Default::default()
390            },
391            ClientConfig {
392                name: "Azure",
393                factory_env: Box::new(azure::Client::from_env_boxed),
394                factory_val: Box::new(azure::Client::from_val_boxed),
395                env_variable: "AZURE_API_KEY",
396                completion_model: Some(azure::GPT_4O),
397                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
398                transcription_model: Some("whisper-1"),
399                image_generation_model: Some("dalle-2"),
400                audio_generation_model: Some(("tts-1", "onyx")),
401            },
402            ClientConfig {
403                name: "Deepseek",
404                factory_env: Box::new(deepseek::Client::from_env_boxed),
405                factory_val: Box::new(deepseek::Client::from_val_boxed),
406                env_variable: "DEEPSEEK_API_KEY",
407                completion_model: Some(deepseek::DEEPSEEK_CHAT),
408                ..Default::default()
409            },
410            ClientConfig {
411                name: "Galadriel",
412                factory_env: Box::new(galadriel::Client::from_env_boxed),
413                factory_val: Box::new(galadriel::Client::from_val_boxed),
414                env_variable: "GALADRIEL_API_KEY",
415                completion_model: Some(galadriel::GPT_4O),
416                ..Default::default()
417            },
418            ClientConfig {
419                name: "Groq",
420                factory_env: Box::new(groq::Client::from_env_boxed),
421                factory_val: Box::new(groq::Client::from_val_boxed),
422                env_variable: "GROQ_API_KEY",
423                completion_model: Some(groq::MIXTRAL_8X7B_32768),
424                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
425                ..Default::default()
426            },
427            ClientConfig {
428                name: "Hyperbolic",
429                factory_env: Box::new(hyperbolic::Client::from_env_boxed),
430                factory_val: Box::new(hyperbolic::Client::from_val_boxed),
431                env_variable: "HYPERBOLIC_API_KEY",
432                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
433                image_generation_model: Some(hyperbolic::SD1_5),
434                audio_generation_model: Some(("EN", "EN-US")),
435                ..Default::default()
436            },
437            ClientConfig {
438                name: "Mira",
439                factory_env: Box::new(mira::Client::from_env_boxed),
440                factory_val: Box::new(mira::Client::from_val_boxed),
441                env_variable: "MIRA_API_KEY",
442                completion_model: Some("gpt-4o"),
443                ..Default::default()
444            },
445            ClientConfig {
446                name: "Moonshot",
447                factory_env: Box::new(moonshot::Client::from_env_boxed),
448                factory_val: Box::new(moonshot::Client::from_val_boxed),
449                env_variable: "MOONSHOT_API_KEY",
450                completion_model: Some(moonshot::MOONSHOT_CHAT),
451                ..Default::default()
452            },
453            ClientConfig {
454                name: "Ollama",
455                factory_env: Box::new(ollama::Client::from_env_boxed),
456                factory_val: Box::new(ollama::Client::from_val_boxed),
457                env_variable: "OLLAMA_ENABLED",
458                completion_model: Some("llama3.1:8b"),
459                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
460                ..Default::default()
461            },
462            ClientConfig {
463                name: "Perplexity",
464                factory_env: Box::new(perplexity::Client::from_env_boxed),
465                factory_val: Box::new(perplexity::Client::from_val_boxed),
466                env_variable: "PERPLEXITY_API_KEY",
467                completion_model: Some(perplexity::SONAR),
468                ..Default::default()
469            },
470        ]
471    }
472
473    async fn test_completions_client(config: &ClientConfig) {
474        let client = config.factory_env();
475
476        let Some(client) = client.as_completion() else {
477            return;
478        };
479
480        let model = config
481            .completion_model
482            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));
483
484        let model = client.completion_model(model);
485
486        let resp = model
487            .completion_request(Message::user("Whats the capital of France?"))
488            .send()
489            .await;
490
491        assert!(
492            resp.is_ok(),
493            "[{}]: Error occurred when prompting, {}",
494            config.name,
495            resp.err().unwrap()
496        );
497
498        let resp = resp.unwrap();
499
500        match resp.choice.first() {
501            AssistantContent::Text(text) => {
502                assert!(text.text.to_lowercase().contains("paris"));
503            }
504            _ => {
505                unreachable!(
506                    "[{}]: First choice wasn't a Text message, {:?}",
507                    config.name,
508                    resp.choice.first()
509                );
510            }
511        }
512    }
513
514    #[tokio::test]
515    #[ignore]
516    async fn test_completions() {
517        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
518            test_completions_client(&p).await;
519        }
520    }
521
522    async fn test_tools_client(config: &ClientConfig) {
523        let client = config.factory_env();
524        let model = config
525            .completion_model
526            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
527
528        let Some(client) = client.as_completion() else {
529            return;
530        };
531
532        let model = client.agent(model)
533            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
534            .max_tokens(1024)
535            .tool(Adder)
536            .tool(Subtract)
537            .build();
538
539        let request = model.completion("Calculate 2 - 5", vec![]).await;
540
541        assert!(
542            request.is_ok(),
543            "[{}]: Error occurred when building prompt, {}",
544            config.name,
545            request.err().unwrap()
546        );
547
548        let resp = request.unwrap().send().await;
549
550        assert!(
551            resp.is_ok(),
552            "[{}]: Error occurred when prompting, {}",
553            config.name,
554            resp.err().unwrap()
555        );
556
557        let resp = resp.unwrap();
558
559        assert!(
560            resp.choice.iter().any(|content| match content {
561                AssistantContent::ToolCall(tc) => {
562                    if tc.function.name != Subtract::NAME {
563                        return false;
564                    }
565
566                    let arguments =
567                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
568                            .expect("Error parsing arguments");
569
570                    arguments.x == 2.0 && arguments.y == 5.0
571                }
572                _ => false,
573            }),
574            "[{}]: Model did not use the Subtract tool.",
575            config.name
576        )
577    }
578
579    #[tokio::test]
580    #[ignore]
581    async fn test_tools() {
582        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
583            test_tools_client(&p).await;
584        }
585    }
586
587    async fn test_streaming_client(config: &ClientConfig) {
588        let client = config.factory_env();
589
590        let Some(client) = client.as_completion() else {
591            return;
592        };
593
594        let model = config
595            .completion_model
596            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
597
598        let model = client.completion_model(model);
599
600        let resp = model.stream(CompletionRequest {
601            preamble: None,
602            tools: vec![],
603            documents: vec![],
604            temperature: None,
605            max_tokens: None,
606            additional_params: None,
607            chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
608        });
609
610        let mut resp = resp.await.unwrap();
611
612        let mut received_chunk = false;
613
614        while let Some(chunk) = resp.next().await {
615            received_chunk = true;
616            assert!(chunk.is_ok());
617        }
618
619        assert!(
620            received_chunk,
621            "[{}]: Failed to receive a chunk from stream",
622            config.name
623        );
624
625        for choice in resp.choice {
626            match choice {
627                AssistantContent::Text(text) => {
628                    assert!(
629                        text.text.to_lowercase().contains("paris"),
630                        "[{}]: Did not answer with Paris",
631                        config.name
632                    );
633                }
634                AssistantContent::ToolCall(_) => {}
635                AssistantContent::Reasoning(_) => {}
636            }
637        }
638    }
639
640    #[tokio::test]
641    #[ignore]
642    async fn test_streaming() {
643        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
644            test_streaming_client(&provider).await;
645        }
646    }
647
648    async fn test_streaming_tools_client(config: &ClientConfig) {
649        let client = config.factory_env();
650        let model = config
651            .completion_model
652            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
653
654        let Some(client) = client.as_completion() else {
655            return;
656        };
657
658        let model = client.agent(model)
659            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
660            .max_tokens(1024)
661            .tool(Adder)
662            .tool(Subtract)
663            .build();
664
665        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;
666
667        assert!(
668            request.is_ok(),
669            "[{}]: Error occurred when building prompt, {}",
670            config.name,
671            request.err().unwrap()
672        );
673
674        let resp = request.unwrap().stream().await;
675
676        assert!(
677            resp.is_ok(),
678            "[{}]: Error occurred when prompting, {}",
679            config.name,
680            resp.err().unwrap()
681        );
682
683        let mut resp = resp.unwrap();
684
685        let mut received_chunk = false;
686
687        while let Some(chunk) = resp.next().await {
688            received_chunk = true;
689            assert!(chunk.is_ok());
690        }
691
692        assert!(
693            received_chunk,
694            "[{}]: Failed to receive a chunk from stream",
695            config.name
696        );
697
698        assert!(
699            resp.choice.iter().any(|content| match content {
700                AssistantContent::ToolCall(tc) => {
701                    if tc.function.name != Subtract::NAME {
702                        return false;
703                    }
704
705                    let arguments =
706                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
707                            .expect("Error parsing arguments");
708
709                    arguments.x == 2.0 && arguments.y == 5.0
710                }
711                _ => false,
712            }),
713            "[{}]: Model did not use the Subtract tool.",
714            config.name
715        )
716    }
717
718    #[tokio::test]
719    #[ignore]
720    async fn test_streaming_tools() {
721        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
722            test_streaming_tools_client(&p).await;
723        }
724    }
725
726    async fn test_audio_generation_client(config: &ClientConfig) {
727        let client = config.factory_env();
728
729        let Some(client) = client.as_audio_generation() else {
730            return;
731        };
732
733        let (model, voice) = config
734            .audio_generation_model
735            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));
736
737        let model = client.audio_generation_model(model);
738
739        let request = model
740            .audio_generation_request()
741            .text("Hello world!")
742            .voice(voice);
743
744        let resp = request.send().await;
745
746        assert!(
747            resp.is_ok(),
748            "[{}]: Error occurred when sending request, {}",
749            config.name,
750            resp.err().unwrap()
751        );
752
753        let resp = resp.unwrap();
754
755        assert!(
756            !resp.audio.is_empty(),
757            "[{}]: Returned audio was empty",
758            config.name
759        );
760    }
761
762    #[tokio::test]
763    #[ignore]
764    async fn test_audio_generation() {
765        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
766            test_audio_generation_client(&p).await;
767        }
768    }
769
770    fn assert_feature<F, M>(
771        name: &str,
772        feature_name: &str,
773        model_name: &str,
774        feature: Option<F>,
775        model: Option<M>,
776    ) {
777        assert_eq!(
778            feature.is_some(),
779            model.is_some(),
780            "{} has{} implemented {} but config.{} is {}.",
781            name,
782            if feature.is_some() { "" } else { "n't" },
783            feature_name,
784            model_name,
785            if model.is_some() { "some" } else { "none" }
786        );
787    }
788
789    #[test]
790    #[ignore]
791    pub fn test_polymorphism() {
792        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
793            let client = config.factory_env();
794            assert_feature(
795                config.name,
796                "AsCompletion",
797                "completion_model",
798                client.as_completion(),
799                config.completion_model,
800            );
801
802            assert_feature(
803                config.name,
804                "AsEmbeddings",
805                "embeddings_model",
806                client.as_embeddings(),
807                config.embeddings_model,
808            );
809
810            assert_feature(
811                config.name,
812                "AsTranscription",
813                "transcription_model",
814                client.as_transcription(),
815                config.transcription_model,
816            );
817
818            assert_feature(
819                config.name,
820                "AsImageGeneration",
821                "image_generation_model",
822                client.as_image_generation(),
823                config.image_generation_model,
824            );
825
826            assert_feature(
827                config.name,
828                "AsAudioGeneration",
829                "audio_generation_model",
830                client.as_audio_generation(),
831                config.audio_generation_model,
832            )
833        }
834    }
835
836    async fn test_embed_client(config: &ClientConfig) {
837        const TEST: &str = "Hello world.";
838
839        let client = config.factory_env();
840
841        let Some(client) = client.as_embeddings() else {
842            return;
843        };
844
845        let model = config.embeddings_model.unwrap();
846
847        let model = client.embedding_model(model);
848
849        let resp = model.embed_text(TEST).await;
850
851        assert!(
852            resp.is_ok(),
853            "[{}]: Error occurred when sending request, {}",
854            config.name,
855            resp.err().unwrap()
856        );
857
858        let resp = resp.unwrap();
859
860        assert_eq!(resp.document, TEST);
861
862        assert!(
863            !resp.vec.is_empty(),
864            "[{}]: Returned embed was empty",
865            config.name
866        );
867    }
868
869    #[tokio::test]
870    #[ignore]
871    async fn test_embed() {
872        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
873            test_embed_client(&config).await;
874        }
875    }
876
877    async fn test_image_generation_client(config: &ClientConfig) {
878        let client = config.factory_env();
879        let Some(client) = client.as_image_generation() else {
880            return;
881        };
882
883        let model = config.image_generation_model.unwrap();
884
885        let model = client.image_generation_model(model);
886
887        let resp = model
888            .image_generation(ImageGenerationRequest {
889                prompt: "A castle sitting on a large hill.".to_string(),
890                width: 256,
891                height: 256,
892                additional_params: None,
893            })
894            .await;
895
896        assert!(
897            resp.is_ok(),
898            "[{}]: Error occurred when sending request, {}",
899            config.name,
900            resp.err().unwrap()
901        );
902
903        let resp = resp.unwrap();
904
905        assert!(
906            !resp.image.is_empty(),
907            "[{}]: Generated image was empty",
908            config.name
909        );
910    }
911
912    #[tokio::test]
913    #[ignore]
914    async fn test_image_generation() {
915        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
916            test_image_generation_client(&config).await;
917        }
918    }
919
920    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
921        let client = config.factory_env();
922        let Some(client) = client.as_transcription() else {
923            return;
924        };
925
926        let model = config.image_generation_model.unwrap();
927
928        let model = client.transcription_model(model);
929
930        let resp = model
931            .transcription(TranscriptionRequest {
932                data,
933                filename: "audio.mp3".to_string(),
934                language: "en".to_string(),
935                prompt: None,
936                temperature: None,
937                additional_params: None,
938            })
939            .await;
940
941        assert!(
942            resp.is_ok(),
943            "[{}]: Error occurred when sending request, {}",
944            config.name,
945            resp.err().unwrap()
946        );
947
948        let resp = resp.unwrap();
949
950        assert!(
951            !resp.text.is_empty(),
952            "[{}]: Returned transcription was empty",
953            config.name
954        );
955    }
956
957    #[tokio::test]
958    #[ignore]
959    async fn test_transcription() {
960        let mut file = File::open("examples/audio/en-us-natural-speech.mp3").unwrap();
961
962        let mut data = Vec::new();
963        let _ = file.read(&mut data);
964
965        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
966            test_transcription_client(&config, data.clone()).await;
967        }
968    }
969
    /// Arguments shared by the `Adder` and `Subtract` test tools (their `Tool::Args`),
    /// deserialized via `serde` from the `{"x": .., "y": ..}` object their definitions describe.
    #[derive(Deserialize)]
    struct OperationArgs {
        /// First operand (`x` in the tool schema).
        x: f32,
        /// Second operand (`y` in the tool schema).
        y: f32,
    }
975
    /// Error type for the arithmetic test tools (`Tool::Error` of `Adder` and `Subtract`).
    /// Neither tool's `call` currently returns it; it exists to satisfy the trait.
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;
979
980    #[derive(Deserialize, Serialize)]
981    struct Adder;
982    impl Tool for Adder {
983        const NAME: &'static str = "add";
984
985        type Error = MathError;
986        type Args = OperationArgs;
987        type Output = f32;
988
989        async fn definition(&self, _prompt: String) -> ToolDefinition {
990            ToolDefinition {
991                name: "add".to_string(),
992                description: "Add x and y together".to_string(),
993                parameters: json!({
994                    "type": "object",
995                    "properties": {
996                        "x": {
997                            "type": "number",
998                            "description": "The first number to add"
999                        },
1000                        "y": {
1001                            "type": "number",
1002                            "description": "The second number to add"
1003                        }
1004                    }
1005                }),
1006            }
1007        }
1008
1009        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1010            println!("[tool-call] Adding {} and {}", args.x, args.y);
1011            let result = args.x + args.y;
1012            Ok(result)
1013        }
1014    }
1015
1016    #[derive(Deserialize, Serialize)]
1017    struct Subtract;
1018    impl Tool for Subtract {
1019        const NAME: &'static str = "subtract";
1020
1021        type Error = MathError;
1022        type Args = OperationArgs;
1023        type Output = f32;
1024
1025        async fn definition(&self, _prompt: String) -> ToolDefinition {
1026            serde_json::from_value(json!({
1027                "name": "subtract",
1028                "description": "Subtract y from x (i.e.: x - y)",
1029                "parameters": {
1030                    "type": "object",
1031                    "properties": {
1032                        "x": {
1033                            "type": "number",
1034                            "description": "The number to subtract from"
1035                        },
1036                        "y": {
1037                            "type": "number",
1038                            "description": "The number to subtract"
1039                        }
1040                    }
1041                }
1042            }))
1043            .expect("Tool Definition")
1044        }
1045
1046        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1047            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
1048            let result = args.x - args.y;
1049            Ok(result)
1050        }
1051    }
1052}