rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15use thiserror::Error;
16
17#[derive(Debug, Error)]
18#[non_exhaustive]
19pub enum ClientBuilderError {
20    #[error("reqwest error: {0}")]
21    HttpError(
22        #[from]
23        #[source]
24        reqwest::Error,
25    ),
26    #[error("invalid property: {0}")]
27    InvalidProperty(&'static str),
28}
29
30/// The base ProviderClient trait, facilitates conversion between client types
31/// and creating a client from the environment.
32///
33/// All conversion traits must be implemented, they are automatically
34/// implemented if the respective client trait is implemented.
35pub trait ProviderClient:
36    AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
37{
38    /// Create a client from the process's environment.
39    /// Panics if an environment is improperly configured.
40    fn from_env() -> Self
41    where
42        Self: Sized;
43
44    /// A helper method to box the client.
45    fn boxed(self) -> Box<dyn ProviderClient>
46    where
47        Self: Sized + 'static,
48    {
49        Box::new(self)
50    }
51
52    /// Create a boxed client from the process's environment.
53    /// Panics if an environment is improperly configured.
54    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
55    where
56        Self: Sized,
57        Self: 'a,
58    {
59        Box::new(Self::from_env())
60    }
61
62    fn from_val(input: ProviderValue) -> Self
63    where
64        Self: Sized;
65
66    /// Create a boxed client from the process's environment.
67    /// Panics if an environment is improperly configured.
68    fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
69    where
70        Self: Sized,
71        Self: 'a,
72    {
73        Box::new(Self::from_val(input))
74    }
75}
76
/// Input accepted by [`ProviderClient::from_val`].
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value.
    Simple(String),
    /// An API key followed by an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version and a header, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}

/// A borrowed string becomes [`ProviderValue::Simple`] (the text is copied).
impl From<&str> for ProviderValue {
    fn from(value: &str) -> Self {
        ProviderValue::Simple(String::from(value))
    }
}

/// An owned string becomes [`ProviderValue::Simple`] without copying.
impl From<String> for ProviderValue {
    fn from(value: String) -> Self {
        ProviderValue::Simple(value)
    }
}

/// `(api_key, optional_key)` becomes [`ProviderValue::ApiKeyWithOptionalKey`].
impl<P> From<(P, Option<P>)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from(pair: (P, Option<P>)) -> Self {
        let (api_key, optional_key) = pair;
        let api_key = api_key.as_ref().to_owned();
        let optional_key = optional_key.map(|key| key.as_ref().to_owned());
        ProviderValue::ApiKeyWithOptionalKey(api_key, optional_key)
    }
}

/// `(api_key, version, header)` becomes
/// [`ProviderValue::ApiKeyWithVersionAndHeader`].
impl<P> From<(P, P, P)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from(triple: (P, P, P)) -> Self {
        let (api_key, version, header) = triple;
        let owned = |part: P| part.as_ref().to_owned();
        ProviderValue::ApiKeyWithVersionAndHeader(owned(api_key), owned(version), owned(header))
    }
}
120
121/// Attempt to convert a ProviderClient to a CompletionClient
122pub trait AsCompletion {
123    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
124        None
125    }
126}
127
128/// Attempt to convert a ProviderClient to a TranscriptionClient
129pub trait AsTranscription {
130    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
131        None
132    }
133}
134
135/// Attempt to convert a ProviderClient to a EmbeddingsClient
136pub trait AsEmbeddings {
137    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
138        None
139    }
140}
141
/// Conversion hook from a provider client to a dyn-compatible audio generation
/// client.
///
/// The trait always exists, but its method is only compiled when the `audio`
/// feature is enabled.
pub trait AsAudioGeneration {
    /// Returns the client as a boxed [`AudioGenerationClientDyn`]; the
    /// provided default returns `None`.
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        None
    }
}
149
/// Conversion hook from a provider client to a dyn-compatible image generation
/// client.
///
/// The trait always exists, but its method is only compiled when the `image`
/// feature is enabled.
pub trait AsImageGeneration {
    /// Returns the client as a boxed [`ImageGenerationClientDyn`]; the
    /// provided default returns `None`.
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        None
    }
}
157
158#[cfg(not(feature = "audio"))]
159impl<T: ProviderClient> AsAudioGeneration for T {}
160
161#[cfg(not(feature = "image"))]
162impl<T: ProviderClient> AsImageGeneration for T {}
163
/// Implements the listed conversion traits for a struct with empty impl
/// blocks, i.e. using each trait's provided default methods.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, which expand to
/// nothing when the matching feature is disabled.
///
/// All paths are rooted at `$crate`, so the expansion does not depend on the
/// name the crate is imported under or on the macro having been imported
/// into the calling scope.
///
/// ```rust,ignore
/// pub struct Client;
/// impl ProviderClient for Client {
///     ...
/// }
/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    // Entry point: fan out to one `@impl` invocation per listed trait.
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    // Feature-gated traits are delegated to their dedicated helper macros.
    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::impl_image_generation!($struct_);
    };

    // Any other trait gets an empty impl block.
    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
192
/// Implements `AsAudioGeneration` for a struct with an empty impl block.
///
/// This variant is compiled when the `audio` feature is enabled. The trait
/// path is rooted at `$crate` so it resolves regardless of the name the
/// crate is imported under.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// Variant compiled when the `audio` feature is disabled: expands to nothing.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}
206
/// Implements `AsImageGeneration` for a struct with an empty impl block.
///
/// This variant is compiled when the `image` feature is enabled. The trait
/// path is rooted at `$crate` so it resolves regardless of the name the
/// crate is imported under.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// Variant compiled when the `image` feature is disabled: expands to nothing.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
220
221pub use impl_audio_generation;
222pub use impl_conversion_traits;
223pub use impl_image_generation;
224
225#[cfg(feature = "audio")]
226use crate::client::audio_generation::AudioGenerationClientDyn;
227use crate::client::completion::CompletionClientDyn;
228use crate::client::embeddings::EmbeddingsClientDyn;
229#[cfg(feature = "image")]
230use crate::client::image_generation::ImageGenerationClientDyn;
231use crate::client::transcription::TranscriptionClientDyn;
232
233#[cfg(feature = "audio")]
234pub use crate::client::audio_generation::AudioGenerationClient;
235pub use crate::client::completion::CompletionClient;
236pub use crate::client::embeddings::EmbeddingsClient;
237#[cfg(feature = "image")]
238pub use crate::client::image_generation::ImageGenerationClient;
239pub use crate::client::transcription::TranscriptionClient;
240
241#[cfg(test)]
242mod tests {
243    use crate::OneOrMany;
244    use crate::client::ProviderClient;
245    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
246    use crate::image_generation::ImageGenerationRequest;
247    use crate::message::AssistantContent;
248    use crate::providers::{
249        anthropic, azure, cohere, deepseek, galadriel, gemini, huggingface, hyperbolic, mira,
250        moonshot, openai, openrouter, together, xai,
251    };
252    use crate::streaming::StreamingCompletion;
253    use crate::tool::Tool;
254    use crate::transcription::TranscriptionRequest;
255    use futures::StreamExt;
256    use rig::message::Message;
257    use rig::providers::{groq, ollama, perplexity};
258    use serde::{Deserialize, Serialize};
259    use serde_json::json;
260    use std::fs::File;
261    use std::io::Read;
262
263    use super::ProviderValue;
264
265    struct ClientConfig {
266        name: &'static str,
267        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
268        // Not sure where we're going to be using this but I've added it for completeness
269        #[allow(dead_code)]
270        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
271        env_variable: &'static str,
272        completion_model: Option<&'static str>,
273        embeddings_model: Option<&'static str>,
274        transcription_model: Option<&'static str>,
275        image_generation_model: Option<&'static str>,
276        audio_generation_model: Option<(&'static str, &'static str)>,
277    }
278
279    impl Default for ClientConfig {
280        fn default() -> Self {
281            Self {
282                name: "",
283                factory_env: Box::new(|| panic!("Not implemented")),
284                factory_val: Box::new(|_| panic!("Not implemented")),
285                env_variable: "",
286                completion_model: None,
287                embeddings_model: None,
288                transcription_model: None,
289                image_generation_model: None,
290                audio_generation_model: None,
291            }
292        }
293    }
294
295    impl ClientConfig {
296        fn is_env_var_set(&self) -> bool {
297            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
298        }
299
300        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
301            self.factory_env.as_ref()()
302        }
303    }
304
305    fn providers() -> Vec<ClientConfig> {
306        vec![
307            ClientConfig {
308                name: "Anthropic",
309                factory_env: Box::new(anthropic::Client::from_env_boxed),
310                factory_val: Box::new(anthropic::Client::from_val_boxed),
311                env_variable: "ANTHROPIC_API_KEY",
312                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
313                ..Default::default()
314            },
315            ClientConfig {
316                name: "Cohere",
317                factory_env: Box::new(cohere::Client::from_env_boxed),
318                factory_val: Box::new(cohere::Client::from_val_boxed),
319                env_variable: "COHERE_API_KEY",
320                completion_model: Some(cohere::COMMAND_R),
321                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
322                ..Default::default()
323            },
324            ClientConfig {
325                name: "Gemini",
326                factory_env: Box::new(gemini::Client::from_env_boxed),
327                factory_val: Box::new(gemini::Client::from_val_boxed),
328                env_variable: "GEMINI_API_KEY",
329                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
330                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
331                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
332                ..Default::default()
333            },
334            ClientConfig {
335                name: "Huggingface",
336                factory_env: Box::new(huggingface::Client::from_env_boxed),
337                factory_val: Box::new(huggingface::Client::from_val_boxed),
338                env_variable: "HUGGINGFACE_API_KEY",
339                completion_model: Some(huggingface::PHI_4),
340                transcription_model: Some(huggingface::WHISPER_SMALL),
341                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
342                ..Default::default()
343            },
344            ClientConfig {
345                name: "OpenAI",
346                factory_env: Box::new(openai::Client::from_env_boxed),
347                factory_val: Box::new(openai::Client::from_val_boxed),
348                env_variable: "OPENAI_API_KEY",
349                completion_model: Some(openai::GPT_4O),
350                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
351                transcription_model: Some(openai::WHISPER_1),
352                image_generation_model: Some(openai::DALL_E_2),
353                audio_generation_model: Some((openai::TTS_1, "onyx")),
354            },
355            ClientConfig {
356                name: "OpenRouter",
357                factory_env: Box::new(openrouter::Client::from_env_boxed),
358                factory_val: Box::new(openrouter::Client::from_val_boxed),
359                env_variable: "OPENROUTER_API_KEY",
360                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
361                ..Default::default()
362            },
363            ClientConfig {
364                name: "Together",
365                factory_env: Box::new(together::Client::from_env_boxed),
366                factory_val: Box::new(together::Client::from_val_boxed),
367                env_variable: "TOGETHER_API_KEY",
368                completion_model: Some(together::ALPACA_7B),
369                embeddings_model: Some(together::BERT_BASE_UNCASED),
370                ..Default::default()
371            },
372            ClientConfig {
373                name: "XAI",
374                factory_env: Box::new(xai::Client::from_env_boxed),
375                factory_val: Box::new(xai::Client::from_val_boxed),
376                env_variable: "XAI_API_KEY",
377                completion_model: Some(xai::GROK_3_MINI),
378                embeddings_model: None,
379                ..Default::default()
380            },
381            ClientConfig {
382                name: "Azure",
383                factory_env: Box::new(azure::Client::from_env_boxed),
384                factory_val: Box::new(azure::Client::from_val_boxed),
385                env_variable: "AZURE_API_KEY",
386                completion_model: Some(azure::GPT_4O),
387                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
388                transcription_model: Some("whisper-1"),
389                image_generation_model: Some("dalle-2"),
390                audio_generation_model: Some(("tts-1", "onyx")),
391            },
392            ClientConfig {
393                name: "Deepseek",
394                factory_env: Box::new(deepseek::Client::from_env_boxed),
395                factory_val: Box::new(deepseek::Client::from_val_boxed),
396                env_variable: "DEEPSEEK_API_KEY",
397                completion_model: Some(deepseek::DEEPSEEK_CHAT),
398                ..Default::default()
399            },
400            ClientConfig {
401                name: "Galadriel",
402                factory_env: Box::new(galadriel::Client::from_env_boxed),
403                factory_val: Box::new(galadriel::Client::from_val_boxed),
404                env_variable: "GALADRIEL_API_KEY",
405                completion_model: Some(galadriel::GPT_4O),
406                ..Default::default()
407            },
408            ClientConfig {
409                name: "Groq",
410                factory_env: Box::new(groq::Client::from_env_boxed),
411                factory_val: Box::new(groq::Client::from_val_boxed),
412                env_variable: "GROQ_API_KEY",
413                completion_model: Some(groq::MIXTRAL_8X7B_32768),
414                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
415                ..Default::default()
416            },
417            ClientConfig {
418                name: "Hyperbolic",
419                factory_env: Box::new(hyperbolic::Client::from_env_boxed),
420                factory_val: Box::new(hyperbolic::Client::from_val_boxed),
421                env_variable: "HYPERBOLIC_API_KEY",
422                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
423                image_generation_model: Some(hyperbolic::SD1_5),
424                audio_generation_model: Some(("EN", "EN-US")),
425                ..Default::default()
426            },
427            ClientConfig {
428                name: "Mira",
429                factory_env: Box::new(mira::Client::from_env_boxed),
430                factory_val: Box::new(mira::Client::from_val_boxed),
431                env_variable: "MIRA_API_KEY",
432                completion_model: Some("gpt-4o"),
433                ..Default::default()
434            },
435            ClientConfig {
436                name: "Moonshot",
437                factory_env: Box::new(moonshot::Client::from_env_boxed),
438                factory_val: Box::new(moonshot::Client::from_val_boxed),
439                env_variable: "MOONSHOT_API_KEY",
440                completion_model: Some(moonshot::MOONSHOT_CHAT),
441                ..Default::default()
442            },
443            ClientConfig {
444                name: "Ollama",
445                factory_env: Box::new(ollama::Client::from_env_boxed),
446                factory_val: Box::new(ollama::Client::from_val_boxed),
447                env_variable: "OLLAMA_ENABLED",
448                completion_model: Some("llama3.1:8b"),
449                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
450                ..Default::default()
451            },
452            ClientConfig {
453                name: "Perplexity",
454                factory_env: Box::new(perplexity::Client::from_env_boxed),
455                factory_val: Box::new(perplexity::Client::from_val_boxed),
456                env_variable: "PERPLEXITY_API_KEY",
457                completion_model: Some(perplexity::SONAR),
458                ..Default::default()
459            },
460        ]
461    }
462
463    async fn test_completions_client(config: &ClientConfig) {
464        let client = config.factory_env();
465
466        let Some(client) = client.as_completion() else {
467            return;
468        };
469
470        let model = config
471            .completion_model
472            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));
473
474        let model = client.completion_model(model);
475
476        let resp = model
477            .completion_request(Message::user("Whats the capital of France?"))
478            .send()
479            .await;
480
481        assert!(
482            resp.is_ok(),
483            "[{}]: Error occurred when prompting, {}",
484            config.name,
485            resp.err().unwrap()
486        );
487
488        let resp = resp.unwrap();
489
490        match resp.choice.first() {
491            AssistantContent::Text(text) => {
492                assert!(text.text.to_lowercase().contains("paris"));
493            }
494            _ => {
495                unreachable!(
496                    "[{}]: First choice wasn't a Text message, {:?}",
497                    config.name,
498                    resp.choice.first()
499                );
500            }
501        }
502    }
503
504    #[tokio::test]
505    #[ignore]
506    async fn test_completions() {
507        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
508            test_completions_client(&p).await;
509        }
510    }
511
512    async fn test_tools_client(config: &ClientConfig) {
513        let client = config.factory_env();
514        let model = config
515            .completion_model
516            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
517
518        let Some(client) = client.as_completion() else {
519            return;
520        };
521
522        let model = client.agent(model)
523            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
524            .max_tokens(1024)
525            .tool(Adder)
526            .tool(Subtract)
527            .build();
528
529        let request = model.completion("Calculate 2 - 5", vec![]).await;
530
531        assert!(
532            request.is_ok(),
533            "[{}]: Error occurred when building prompt, {}",
534            config.name,
535            request.err().unwrap()
536        );
537
538        let resp = request.unwrap().send().await;
539
540        assert!(
541            resp.is_ok(),
542            "[{}]: Error occurred when prompting, {}",
543            config.name,
544            resp.err().unwrap()
545        );
546
547        let resp = resp.unwrap();
548
549        assert!(
550            resp.choice.iter().any(|content| match content {
551                AssistantContent::ToolCall(tc) => {
552                    if tc.function.name != Subtract::NAME {
553                        return false;
554                    }
555
556                    let arguments =
557                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
558                            .expect("Error parsing arguments");
559
560                    arguments.x == 2.0 && arguments.y == 5.0
561                }
562                _ => false,
563            }),
564            "[{}]: Model did not use the Subtract tool.",
565            config.name
566        )
567    }
568
569    #[tokio::test]
570    #[ignore]
571    async fn test_tools() {
572        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
573            test_tools_client(&p).await;
574        }
575    }
576
577    async fn test_streaming_client(config: &ClientConfig) {
578        let client = config.factory_env();
579
580        let Some(client) = client.as_completion() else {
581            return;
582        };
583
584        let model = config
585            .completion_model
586            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
587
588        let model = client.completion_model(model);
589
590        let resp = model.stream(CompletionRequest {
591            preamble: None,
592            tools: vec![],
593            documents: vec![],
594            temperature: None,
595            max_tokens: None,
596            additional_params: None,
597            chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
598        });
599
600        let mut resp = resp.await.unwrap();
601
602        let mut received_chunk = false;
603
604        while let Some(chunk) = resp.next().await {
605            received_chunk = true;
606            assert!(chunk.is_ok());
607        }
608
609        assert!(
610            received_chunk,
611            "[{}]: Failed to receive a chunk from stream",
612            config.name
613        );
614
615        for choice in resp.choice {
616            match choice {
617                AssistantContent::Text(text) => {
618                    assert!(
619                        text.text.to_lowercase().contains("paris"),
620                        "[{}]: Did not answer with Paris",
621                        config.name
622                    );
623                }
624                AssistantContent::ToolCall(_) => {}
625                AssistantContent::Reasoning(_) => {}
626            }
627        }
628    }
629
630    #[tokio::test]
631    #[ignore]
632    async fn test_streaming() {
633        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
634            test_streaming_client(&provider).await;
635        }
636    }
637
638    async fn test_streaming_tools_client(config: &ClientConfig) {
639        let client = config.factory_env();
640        let model = config
641            .completion_model
642            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
643
644        let Some(client) = client.as_completion() else {
645            return;
646        };
647
648        let model = client.agent(model)
649            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
650            .max_tokens(1024)
651            .tool(Adder)
652            .tool(Subtract)
653            .build();
654
655        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;
656
657        assert!(
658            request.is_ok(),
659            "[{}]: Error occurred when building prompt, {}",
660            config.name,
661            request.err().unwrap()
662        );
663
664        let resp = request.unwrap().stream().await;
665
666        assert!(
667            resp.is_ok(),
668            "[{}]: Error occurred when prompting, {}",
669            config.name,
670            resp.err().unwrap()
671        );
672
673        let mut resp = resp.unwrap();
674
675        let mut received_chunk = false;
676
677        while let Some(chunk) = resp.next().await {
678            received_chunk = true;
679            assert!(chunk.is_ok());
680        }
681
682        assert!(
683            received_chunk,
684            "[{}]: Failed to receive a chunk from stream",
685            config.name
686        );
687
688        assert!(
689            resp.choice.iter().any(|content| match content {
690                AssistantContent::ToolCall(tc) => {
691                    if tc.function.name != Subtract::NAME {
692                        return false;
693                    }
694
695                    let arguments =
696                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
697                            .expect("Error parsing arguments");
698
699                    arguments.x == 2.0 && arguments.y == 5.0
700                }
701                _ => false,
702            }),
703            "[{}]: Model did not use the Subtract tool.",
704            config.name
705        )
706    }
707
708    #[tokio::test]
709    #[ignore]
710    async fn test_streaming_tools() {
711        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
712            test_streaming_tools_client(&p).await;
713        }
714    }
715
716    async fn test_audio_generation_client(config: &ClientConfig) {
717        let client = config.factory_env();
718
719        let Some(client) = client.as_audio_generation() else {
720            return;
721        };
722
723        let (model, voice) = config
724            .audio_generation_model
725            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));
726
727        let model = client.audio_generation_model(model);
728
729        let request = model
730            .audio_generation_request()
731            .text("Hello world!")
732            .voice(voice);
733
734        let resp = request.send().await;
735
736        assert!(
737            resp.is_ok(),
738            "[{}]: Error occurred when sending request, {}",
739            config.name,
740            resp.err().unwrap()
741        );
742
743        let resp = resp.unwrap();
744
745        assert!(
746            !resp.audio.is_empty(),
747            "[{}]: Returned audio was empty",
748            config.name
749        );
750    }
751
752    #[tokio::test]
753    #[ignore]
754    async fn test_audio_generation() {
755        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
756            test_audio_generation_client(&p).await;
757        }
758    }
759
760    fn assert_feature<F, M>(
761        name: &str,
762        feature_name: &str,
763        model_name: &str,
764        feature: Option<F>,
765        model: Option<M>,
766    ) {
767        assert_eq!(
768            feature.is_some(),
769            model.is_some(),
770            "{} has{} implemented {} but config.{} is {}.",
771            name,
772            if feature.is_some() { "" } else { "n't" },
773            feature_name,
774            model_name,
775            if model.is_some() { "some" } else { "none" }
776        );
777    }
778
779    #[test]
780    #[ignore]
781    pub fn test_polymorphism() {
782        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
783            let client = config.factory_env();
784            assert_feature(
785                config.name,
786                "AsCompletion",
787                "completion_model",
788                client.as_completion(),
789                config.completion_model,
790            );
791
792            assert_feature(
793                config.name,
794                "AsEmbeddings",
795                "embeddings_model",
796                client.as_embeddings(),
797                config.embeddings_model,
798            );
799
800            assert_feature(
801                config.name,
802                "AsTranscription",
803                "transcription_model",
804                client.as_transcription(),
805                config.transcription_model,
806            );
807
808            assert_feature(
809                config.name,
810                "AsImageGeneration",
811                "image_generation_model",
812                client.as_image_generation(),
813                config.image_generation_model,
814            );
815
816            assert_feature(
817                config.name,
818                "AsAudioGeneration",
819                "audio_generation_model",
820                client.as_audio_generation(),
821                config.audio_generation_model,
822            )
823        }
824    }
825
826    async fn test_embed_client(config: &ClientConfig) {
827        const TEST: &str = "Hello world.";
828
829        let client = config.factory_env();
830
831        let Some(client) = client.as_embeddings() else {
832            return;
833        };
834
835        let model = config.embeddings_model.unwrap();
836
837        let model = client.embedding_model(model);
838
839        let resp = model.embed_text(TEST).await;
840
841        assert!(
842            resp.is_ok(),
843            "[{}]: Error occurred when sending request, {}",
844            config.name,
845            resp.err().unwrap()
846        );
847
848        let resp = resp.unwrap();
849
850        assert_eq!(resp.document, TEST);
851
852        assert!(
853            !resp.vec.is_empty(),
854            "[{}]: Returned embed was empty",
855            config.name
856        );
857    }
858
859    #[tokio::test]
860    #[ignore]
861    async fn test_embed() {
862        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
863            test_embed_client(&config).await;
864        }
865    }
866
867    async fn test_image_generation_client(config: &ClientConfig) {
868        let client = config.factory_env();
869        let Some(client) = client.as_image_generation() else {
870            return;
871        };
872
873        let model = config.image_generation_model.unwrap();
874
875        let model = client.image_generation_model(model);
876
877        let resp = model
878            .image_generation(ImageGenerationRequest {
879                prompt: "A castle sitting on a large hill.".to_string(),
880                width: 256,
881                height: 256,
882                additional_params: None,
883            })
884            .await;
885
886        assert!(
887            resp.is_ok(),
888            "[{}]: Error occurred when sending request, {}",
889            config.name,
890            resp.err().unwrap()
891        );
892
893        let resp = resp.unwrap();
894
895        assert!(
896            !resp.image.is_empty(),
897            "[{}]: Generated image was empty",
898            config.name
899        );
900    }
901
902    #[tokio::test]
903    #[ignore]
904    async fn test_image_generation() {
905        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
906            test_image_generation_client(&config).await;
907        }
908    }
909
910    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
911        let client = config.factory_env();
912        let Some(client) = client.as_transcription() else {
913            return;
914        };
915
916        let model = config.image_generation_model.unwrap();
917
918        let model = client.transcription_model(model);
919
920        let resp = model
921            .transcription(TranscriptionRequest {
922                data,
923                filename: "audio.mp3".to_string(),
924                language: "en".to_string(),
925                prompt: None,
926                temperature: None,
927                additional_params: None,
928            })
929            .await;
930
931        assert!(
932            resp.is_ok(),
933            "[{}]: Error occurred when sending request, {}",
934            config.name,
935            resp.err().unwrap()
936        );
937
938        let resp = resp.unwrap();
939
940        assert!(
941            !resp.text.is_empty(),
942            "[{}]: Returned transcription was empty",
943            config.name
944        );
945    }
946
947    #[tokio::test]
948    #[ignore]
949    async fn test_transcription() {
950        let mut file = File::open("examples/audio/en-us-natural-speech.mp3").unwrap();
951
952        let mut data = Vec::new();
953        let _ = file.read(&mut data);
954
955        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
956            test_transcription_client(&config, data.clone()).await;
957        }
958    }
959
960    #[derive(Deserialize)]
961    struct OperationArgs {
962        x: f32,
963        y: f32,
964    }
965
966    #[derive(Debug, thiserror::Error)]
967    #[error("Math error")]
968    struct MathError;
969
970    #[derive(Deserialize, Serialize)]
971    struct Adder;
972    impl Tool for Adder {
973        const NAME: &'static str = "add";
974
975        type Error = MathError;
976        type Args = OperationArgs;
977        type Output = f32;
978
979        async fn definition(&self, _prompt: String) -> ToolDefinition {
980            ToolDefinition {
981                name: "add".to_string(),
982                description: "Add x and y together".to_string(),
983                parameters: json!({
984                    "type": "object",
985                    "properties": {
986                        "x": {
987                            "type": "number",
988                            "description": "The first number to add"
989                        },
990                        "y": {
991                            "type": "number",
992                            "description": "The second number to add"
993                        }
994                    }
995                }),
996            }
997        }
998
999        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1000            println!("[tool-call] Adding {} and {}", args.x, args.y);
1001            let result = args.x + args.y;
1002            Ok(result)
1003        }
1004    }
1005
1006    #[derive(Deserialize, Serialize)]
1007    struct Subtract;
1008    impl Tool for Subtract {
1009        const NAME: &'static str = "subtract";
1010
1011        type Error = MathError;
1012        type Args = OperationArgs;
1013        type Output = f32;
1014
1015        async fn definition(&self, _prompt: String) -> ToolDefinition {
1016            serde_json::from_value(json!({
1017                "name": "subtract",
1018                "description": "Subtract y from x (i.e.: x - y)",
1019                "parameters": {
1020                    "type": "object",
1021                    "properties": {
1022                        "x": {
1023                            "type": "number",
1024                            "description": "The number to subtract from"
1025                        },
1026                        "y": {
1027                            "type": "number",
1028                            "description": "The number to subtract"
1029                        }
1030                    }
1031                }
1032            }))
1033            .expect("Tool Definition")
1034        }
1035
1036        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1037            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
1038            let result = args.x - args.y;
1039            Ok(result)
1040        }
1041    }
1042}