rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13#[cfg(feature = "derive")]
14pub use rig_derive::ProviderClient;
15use std::fmt::Debug;
16use thiserror::Error;
17
18#[derive(Debug, Error)]
19#[non_exhaustive]
20pub enum ClientBuilderError {
21    #[error("reqwest error: {0}")]
22    HttpError(
23        #[from]
24        #[source]
25        reqwest::Error,
26    ),
27    #[error("invalid property: {0}")]
28    InvalidProperty(&'static str),
29}
30
31/// The base ProviderClient trait, facilitates conversion between client types
32/// and creating a client from the environment.
33///
34/// All conversion traits must be implemented, they are automatically
35/// implemented if the respective client trait is implemented.
36pub trait ProviderClient:
37    AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
38{
39    /// Create a client from the process's environment.
40    /// Panics if an environment is improperly configured.
41    fn from_env() -> Self
42    where
43        Self: Sized;
44
45    /// A helper method to box the client.
46    fn boxed(self) -> Box<dyn ProviderClient>
47    where
48        Self: Sized + 'static,
49    {
50        Box::new(self)
51    }
52
53    /// Create a boxed client from the process's environment.
54    /// Panics if an environment is improperly configured.
55    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
56    where
57        Self: Sized,
58        Self: 'a,
59    {
60        Box::new(Self::from_env())
61    }
62
63    fn from_val(input: ProviderValue) -> Self
64    where
65        Self: Sized;
66
67    /// Create a boxed client from the process's environment.
68    /// Panics if an environment is improperly configured.
69    fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
70    where
71        Self: Sized,
72        Self: 'a,
73    {
74        Box::new(Self::from_val(input))
75    }
76}
77
/// A value used to construct a provider client explicitly, via
/// [`ProviderClient::from_val`], instead of reading the process environment.
///
/// Build one with `From`/`Into` (where `P: AsRef<str>`):
///
/// * `&str` / `String` → [`ProviderValue::Simple`]
/// * `(P, Option<P>)` → [`ProviderValue::ApiKeyWithOptionalKey`]
/// * `(P, P, P)` → [`ProviderValue::ApiKeyWithVersionAndHeader`]
///
/// Which variant a given provider accepts is decided by that provider's `from_val`.
///
/// `Debug` is not derived: the `ApiKey*` variants carry API keys, and a derived
/// `Debug` would print them verbatim.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (presumably an API key; the meaning is provider-specific).
    Simple(String),
    /// An API key followed by an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}

impl From<&str> for ProviderValue {
    fn from(value: &str) -> Self {
        Self::Simple(value.to_string())
    }
}

impl From<String> for ProviderValue {
    fn from(value: String) -> Self {
        Self::Simple(value)
    }
}

/// Converts `(api_key, optional_key)`.
///
/// Because both elements share the type `P`, `(key, None).into()` type-checks without
/// annotating the `None`.
impl<P> From<(P, Option<P>)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from((api_key, optional_key): (P, Option<P>)) -> Self {
        Self::ApiKeyWithOptionalKey(
            api_key.as_ref().to_string(),
            optional_key.map(|x| x.as_ref().to_string()),
        )
    }
}

/// Converts `(api_key, version, header)`.
impl<P> From<(P, P, P)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from((api_key, version, header): (P, P, P)) -> Self {
        Self::ApiKeyWithVersionAndHeader(
            api_key.as_ref().to_string(),
            version.as_ref().to_string(),
            header.as_ref().to_string(),
        )
    }
}
121
122/// Attempt to convert a ProviderClient to a CompletionClient
123pub trait AsCompletion {
124    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
125        None
126    }
127}
128
129/// Attempt to convert a ProviderClient to a TranscriptionClient
130pub trait AsTranscription {
131    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
132        None
133    }
134}
135
136/// Attempt to convert a ProviderClient to a EmbeddingsClient
137pub trait AsEmbeddings {
138    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
139        None
140    }
141}
142
/// Attempt to convert a ProviderClient to an AudioGenerationClient.
///
/// The `as_audio_generation` method only exists when the `audio` feature is enabled.
/// Without that feature this trait has no methods and is implemented for every
/// [`ProviderClient`] by a blanket impl in this module.
pub trait AsAudioGeneration {
    /// Returns this client as a boxed [`AudioGenerationClientDyn`], or `None` (the
    /// default) if audio generation is unsupported.
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        None
    }
}
150
/// Attempt to convert a ProviderClient to an ImageGenerationClient.
///
/// The `as_image_generation` method only exists when the `image` feature is enabled.
/// Without that feature this trait has no methods and is implemented for every
/// [`ProviderClient`] by a blanket impl in this module.
pub trait AsImageGeneration {
    /// Returns this client as a boxed [`ImageGenerationClientDyn`], or `None` (the
    /// default) if image generation is unsupported.
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        None
    }
}
158
159/// Attempt to convert a ProviderClient to a VerifyClient
160pub trait AsVerify {
161    fn as_verify(&self) -> Option<Box<dyn VerifyClientDyn>> {
162        None
163    }
164}
165
166#[cfg(not(feature = "audio"))]
167impl<T: ProviderClient> AsAudioGeneration for T {}
168
169#[cfg(not(feature = "image"))]
170impl<T: ProviderClient> AsImageGeneration for T {}
171
/// Implements the listed `client::As*` conversion traits for a type, using each
/// trait's default methods.
///
/// ```ignore
/// impl_conversion_traits!(AsCompletion, AsEmbeddings for Client);
/// impl_conversion_traits!(AsCompletion, AsTranscription for Client<T>);
/// ```
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, which expand to nothing when the
/// matching crate feature is disabled (a blanket impl covers that case). Generic
/// parameters must be plain, comma-separated tokens: `Client<T>` and `Client<'a>` work,
/// `Client<T: Bound>` does not.
///
/// All paths, including the recursive self-invocations, go through `$crate`. The macro
/// therefore works no matter what name the caller gave this crate, and without the
/// caller importing the macro by name.
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@expand_traits [$($trait_)*] $($type_spec)+);
    };

    // Peel one trait off the list per step.
    (@expand_traits [$trait_:ident $($rest_traits:ident)*] $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@impl $trait_ for $($type_spec)+);
        $crate::impl_conversion_traits!(@expand_traits [$($rest_traits)*] $($type_spec)+);
    };

    (@expand_traits [] $($type_spec:tt)+) => {};

    // Feature-gated traits are delegated so they compile away when the feature is off.
    (@impl AsAudioGeneration for $($type_spec:tt)+) => {
        $crate::impl_audio_generation!($($type_spec)+);
    };

    (@impl AsImageGeneration for $($type_spec:tt)+) => {
        $crate::impl_image_generation!($($type_spec)+);
    };

    (@impl $trait_:ident for $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@impl_trait $trait_ for $($type_spec)+);
    };

    (@impl_trait $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };

    (@impl_trait $trait_:ident for $struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::$trait_ for $struct_<$($generics),*> {}
    };
}
205
/// Implements `AsAudioGeneration` (with its default method) for a type.
///
/// Accepts `Type` or `Type<A, B, ...>` with plain generic parameters. This variant is
/// compiled when the `audio` feature is enabled. The path goes through `$crate` so it
/// resolves regardless of the name the caller gave this crate.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
    ($struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::AsAudioGeneration for $struct_<$($generics),*> {}
    };
}

/// No-op stand-in used when the `audio` feature is disabled.
///
/// In that configuration `AsAudioGeneration` is implemented for every `ProviderClient`
/// by a blanket impl, so emitting another impl here would overlap with it.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($($tokens:tt)*) => {};
}
222
/// Implements `AsImageGeneration` (with its default method) for a type.
///
/// Accepts `Type` or `Type<A, B, ...>` with plain generic parameters. This variant is
/// compiled when the `image` feature is enabled. The path goes through `$crate` so it
/// resolves regardless of the name the caller gave this crate.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
    ($struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::AsImageGeneration for $struct_<$($generics),*> {}
    };
}

/// No-op stand-in used when the `image` feature is disabled.
///
/// In that configuration `AsImageGeneration` is implemented for every `ProviderClient`
/// by a blanket impl, so emitting another impl here would overlap with it.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($($tokens:tt)*) => {};
}
239
240pub use impl_audio_generation;
241pub use impl_conversion_traits;
242pub use impl_image_generation;
243
244#[cfg(feature = "audio")]
245use crate::client::audio_generation::AudioGenerationClientDyn;
246use crate::client::completion::CompletionClientDyn;
247use crate::client::embeddings::EmbeddingsClientDyn;
248#[cfg(feature = "image")]
249use crate::client::image_generation::ImageGenerationClientDyn;
250use crate::client::transcription::TranscriptionClientDyn;
251use crate::client::verify::VerifyClientDyn;
252
253#[cfg(feature = "audio")]
254pub use crate::client::audio_generation::AudioGenerationClient;
255pub use crate::client::completion::CompletionClient;
256pub use crate::client::embeddings::EmbeddingsClient;
257#[cfg(feature = "image")]
258pub use crate::client::image_generation::ImageGenerationClient;
259pub use crate::client::transcription::TranscriptionClient;
260pub use crate::client::verify::{VerifyClient, VerifyError};
261
262#[cfg(test)]
263mod tests {
264    use crate::OneOrMany;
265    use crate::client::ProviderClient;
266    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
267    use crate::image_generation::ImageGenerationRequest;
268    use crate::message::AssistantContent;
269    use crate::providers::{
270        anthropic, azure, cohere, deepseek, galadriel, gemini, huggingface, hyperbolic, mira,
271        moonshot, openai, openrouter, together, xai,
272    };
273    use crate::streaming::StreamingCompletion;
274    use crate::tool::Tool;
275    use crate::transcription::TranscriptionRequest;
276    use futures::StreamExt;
277    use rig::message::Message;
278    use rig::providers::{groq, ollama, perplexity};
279    use serde::{Deserialize, Serialize};
280    use serde_json::json;
281    use std::fs::File;
282    use std::io::Read;
283
284    use super::ProviderValue;
285
286    struct ClientConfig {
287        name: &'static str,
288        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
289        // Not sure where we're going to be using this but I've added it for completeness
290        #[allow(dead_code)]
291        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
292        env_variable: &'static str,
293        completion_model: Option<&'static str>,
294        embeddings_model: Option<&'static str>,
295        transcription_model: Option<&'static str>,
296        image_generation_model: Option<&'static str>,
297        audio_generation_model: Option<(&'static str, &'static str)>,
298    }
299
300    impl Default for ClientConfig {
301        fn default() -> Self {
302            Self {
303                name: "",
304                factory_env: Box::new(|| panic!("Not implemented")),
305                factory_val: Box::new(|_| panic!("Not implemented")),
306                env_variable: "",
307                completion_model: None,
308                embeddings_model: None,
309                transcription_model: None,
310                image_generation_model: None,
311                audio_generation_model: None,
312            }
313        }
314    }
315
316    impl ClientConfig {
317        fn is_env_var_set(&self) -> bool {
318            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
319        }
320
321        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
322            self.factory_env.as_ref()()
323        }
324    }
325
326    fn providers() -> Vec<ClientConfig> {
327        vec![
328            ClientConfig {
329                name: "Anthropic",
330                factory_env: Box::new(anthropic::Client::<reqwest::Client>::from_env_boxed),
331                factory_val: Box::new(anthropic::Client::<reqwest::Client>::from_val_boxed),
332                env_variable: "ANTHROPIC_API_KEY",
333                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
334                ..Default::default()
335            },
336            ClientConfig {
337                name: "Cohere",
338                factory_env: Box::new(cohere::Client::<reqwest::Client>::from_env_boxed),
339                factory_val: Box::new(cohere::Client::<reqwest::Client>::from_val_boxed),
340                env_variable: "COHERE_API_KEY",
341                completion_model: Some(cohere::COMMAND_R),
342                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
343                ..Default::default()
344            },
345            ClientConfig {
346                name: "Gemini",
347                factory_env: Box::new(gemini::Client::<reqwest::Client>::from_env_boxed),
348                factory_val: Box::new(gemini::Client::<reqwest::Client>::from_val_boxed),
349                env_variable: "GEMINI_API_KEY",
350                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
351                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
352                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
353                ..Default::default()
354            },
355            ClientConfig {
356                name: "Huggingface",
357                factory_env: Box::new(huggingface::Client::<reqwest::Client>::from_env_boxed),
358                factory_val: Box::new(huggingface::Client::<reqwest::Client>::from_val_boxed),
359                env_variable: "HUGGINGFACE_API_KEY",
360                completion_model: Some(huggingface::PHI_4),
361                transcription_model: Some(huggingface::WHISPER_SMALL),
362                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
363                ..Default::default()
364            },
365            ClientConfig {
366                name: "OpenAI",
367                factory_env: Box::new(openai::Client::<reqwest::Client>::from_env_boxed),
368                factory_val: Box::new(openai::Client::<reqwest::Client>::from_val_boxed),
369                env_variable: "OPENAI_API_KEY",
370                completion_model: Some(openai::GPT_4O),
371                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
372                transcription_model: Some(openai::WHISPER_1),
373                image_generation_model: Some(openai::DALL_E_2),
374                audio_generation_model: Some((openai::TTS_1, "onyx")),
375            },
376            ClientConfig {
377                name: "OpenRouter",
378                factory_env: Box::new(openrouter::Client::<reqwest::Client>::from_env_boxed),
379                factory_val: Box::new(openrouter::Client::<reqwest::Client>::from_val_boxed),
380                env_variable: "OPENROUTER_API_KEY",
381                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
382                ..Default::default()
383            },
384            ClientConfig {
385                name: "Together",
386                factory_env: Box::new(together::Client::<reqwest::Client>::from_env_boxed),
387                factory_val: Box::new(together::Client::<reqwest::Client>::from_val_boxed),
388                env_variable: "TOGETHER_API_KEY",
389                completion_model: Some(together::ALPACA_7B),
390                embeddings_model: Some(together::BERT_BASE_UNCASED),
391                ..Default::default()
392            },
393            ClientConfig {
394                name: "XAI",
395                factory_env: Box::new(xai::Client::<reqwest::Client>::from_env_boxed),
396                factory_val: Box::new(xai::Client::<reqwest::Client>::from_val_boxed),
397                env_variable: "XAI_API_KEY",
398                completion_model: Some(xai::GROK_3_MINI),
399                embeddings_model: None,
400                ..Default::default()
401            },
402            ClientConfig {
403                name: "Azure",
404                factory_env: Box::new(azure::Client::<reqwest::Client>::from_env_boxed),
405                factory_val: Box::new(azure::Client::<reqwest::Client>::from_val_boxed),
406                env_variable: "AZURE_API_KEY",
407                completion_model: Some(azure::GPT_4O),
408                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
409                transcription_model: Some("whisper-1"),
410                image_generation_model: Some("dalle-2"),
411                audio_generation_model: Some(("tts-1", "onyx")),
412            },
413            ClientConfig {
414                name: "Deepseek",
415                factory_env: Box::new(deepseek::Client::<reqwest::Client>::from_env_boxed),
416                factory_val: Box::new(deepseek::Client::<reqwest::Client>::from_val_boxed),
417                env_variable: "DEEPSEEK_API_KEY",
418                completion_model: Some(deepseek::DEEPSEEK_CHAT),
419                ..Default::default()
420            },
421            ClientConfig {
422                name: "Galadriel",
423                factory_env: Box::new(galadriel::Client::<reqwest::Client>::from_env_boxed),
424                factory_val: Box::new(galadriel::Client::<reqwest::Client>::from_val_boxed),
425                env_variable: "GALADRIEL_API_KEY",
426                completion_model: Some(galadriel::GPT_4O),
427                ..Default::default()
428            },
429            ClientConfig {
430                name: "Groq",
431                factory_env: Box::new(groq::Client::<reqwest::Client>::from_env_boxed),
432                factory_val: Box::new(groq::Client::<reqwest::Client>::from_val_boxed),
433                env_variable: "GROQ_API_KEY",
434                completion_model: Some(groq::MIXTRAL_8X7B_32768),
435                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
436                ..Default::default()
437            },
438            ClientConfig {
439                name: "Hyperbolic",
440                factory_env: Box::new(hyperbolic::Client::<reqwest::Client>::from_env_boxed),
441                factory_val: Box::new(hyperbolic::Client::<reqwest::Client>::from_val_boxed),
442                env_variable: "HYPERBOLIC_API_KEY",
443                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
444                image_generation_model: Some(hyperbolic::SD1_5),
445                audio_generation_model: Some(("EN", "EN-US")),
446                ..Default::default()
447            },
448            ClientConfig {
449                name: "Mira",
450                factory_env: Box::new(mira::Client::<reqwest::Client>::from_env_boxed),
451                factory_val: Box::new(mira::Client::<reqwest::Client>::from_val_boxed),
452                env_variable: "MIRA_API_KEY",
453                completion_model: Some("gpt-4o"),
454                ..Default::default()
455            },
456            ClientConfig {
457                name: "Moonshot",
458                factory_env: Box::new(moonshot::Client::<reqwest::Client>::from_env_boxed),
459                factory_val: Box::new(moonshot::Client::<reqwest::Client>::from_val_boxed),
460                env_variable: "MOONSHOT_API_KEY",
461                completion_model: Some(moonshot::MOONSHOT_CHAT),
462                ..Default::default()
463            },
464            ClientConfig {
465                name: "Ollama",
466                factory_env: Box::new(ollama::Client::<reqwest::Client>::from_env_boxed),
467                factory_val: Box::new(ollama::Client::<reqwest::Client>::from_val_boxed),
468                env_variable: "OLLAMA_ENABLED",
469                completion_model: Some("llama3.1:8b"),
470                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
471                ..Default::default()
472            },
473            ClientConfig {
474                name: "Perplexity",
475                factory_env: Box::new(perplexity::Client::<reqwest::Client>::from_env_boxed),
476                factory_val: Box::new(perplexity::Client::<reqwest::Client>::from_val_boxed),
477                env_variable: "PERPLEXITY_API_KEY",
478                completion_model: Some(perplexity::SONAR),
479                ..Default::default()
480            },
481        ]
482    }
483
484    async fn test_completions_client(config: &ClientConfig) {
485        let client = config.factory_env();
486
487        let Some(client) = client.as_completion() else {
488            return;
489        };
490
491        let model = config
492            .completion_model
493            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));
494
495        let model = client.completion_model(model);
496
497        let resp = model
498            .completion_request(Message::user("Whats the capital of France?"))
499            .send()
500            .await;
501
502        assert!(
503            resp.is_ok(),
504            "[{}]: Error occurred when prompting, {}",
505            config.name,
506            resp.err().unwrap()
507        );
508
509        let resp = resp.unwrap();
510
511        match resp.choice.first() {
512            AssistantContent::Text(text) => {
513                assert!(text.text.to_lowercase().contains("paris"));
514            }
515            _ => {
516                unreachable!(
517                    "[{}]: First choice wasn't a Text message, {:?}",
518                    config.name,
519                    resp.choice.first()
520                );
521            }
522        }
523    }
524
525    #[tokio::test]
526    #[ignore]
527    async fn test_completions() {
528        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
529            test_completions_client(&p).await;
530        }
531    }
532
533    async fn test_tools_client(config: &ClientConfig) {
534        let client = config.factory_env();
535        let model = config
536            .completion_model
537            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
538
539        let Some(client) = client.as_completion() else {
540            return;
541        };
542
543        let model = client.agent(model)
544            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
545            .max_tokens(1024)
546            .tool(Adder)
547            .tool(Subtract)
548            .build();
549
550        let request = model.completion("Calculate 2 - 5", vec![]).await;
551
552        assert!(
553            request.is_ok(),
554            "[{}]: Error occurred when building prompt, {}",
555            config.name,
556            request.err().unwrap()
557        );
558
559        let resp = request.unwrap().send().await;
560
561        assert!(
562            resp.is_ok(),
563            "[{}]: Error occurred when prompting, {}",
564            config.name,
565            resp.err().unwrap()
566        );
567
568        let resp = resp.unwrap();
569
570        assert!(
571            resp.choice.iter().any(|content| match content {
572                AssistantContent::ToolCall(tc) => {
573                    if tc.function.name != Subtract::NAME {
574                        return false;
575                    }
576
577                    let arguments =
578                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
579                            .expect("Error parsing arguments");
580
581                    arguments.x == 2.0 && arguments.y == 5.0
582                }
583                _ => false,
584            }),
585            "[{}]: Model did not use the Subtract tool.",
586            config.name
587        )
588    }
589
590    #[tokio::test]
591    #[ignore]
592    async fn test_tools() {
593        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
594            test_tools_client(&p).await;
595        }
596    }
597
598    async fn test_streaming_client(config: &ClientConfig) {
599        let client = config.factory_env();
600
601        let Some(client) = client.as_completion() else {
602            return;
603        };
604
605        let model = config
606            .completion_model
607            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
608
609        let model = client.completion_model(model);
610
611        let resp = model.stream(CompletionRequest {
612            preamble: None,
613            tools: vec![],
614            documents: vec![],
615            temperature: None,
616            max_tokens: None,
617            additional_params: None,
618            tool_choice: None,
619            chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
620        });
621
622        let mut resp = resp.await.unwrap();
623
624        let mut received_chunk = false;
625
626        while let Some(chunk) = resp.next().await {
627            received_chunk = true;
628            assert!(chunk.is_ok());
629        }
630
631        assert!(
632            received_chunk,
633            "[{}]: Failed to receive a chunk from stream",
634            config.name
635        );
636
637        for choice in resp.choice {
638            match choice {
639                AssistantContent::Text(text) => {
640                    assert!(
641                        text.text.to_lowercase().contains("paris"),
642                        "[{}]: Did not answer with Paris",
643                        config.name
644                    );
645                }
646                AssistantContent::ToolCall(_) => {}
647                AssistantContent::Reasoning(_) => {}
648            }
649        }
650    }
651
652    #[tokio::test]
653    #[ignore]
654    async fn test_streaming() {
655        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
656            test_streaming_client(&provider).await;
657        }
658    }
659
660    async fn test_streaming_tools_client(config: &ClientConfig) {
661        let client = config.factory_env();
662        let model = config
663            .completion_model
664            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
665
666        let Some(client) = client.as_completion() else {
667            return;
668        };
669
670        let model = client.agent(model)
671            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
672            .max_tokens(1024)
673            .tool(Adder)
674            .tool(Subtract)
675            .build();
676
677        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;
678
679        assert!(
680            request.is_ok(),
681            "[{}]: Error occurred when building prompt, {}",
682            config.name,
683            request.err().unwrap()
684        );
685
686        let resp = request.unwrap().stream().await;
687
688        assert!(
689            resp.is_ok(),
690            "[{}]: Error occurred when prompting, {}",
691            config.name,
692            resp.err().unwrap()
693        );
694
695        let mut resp = resp.unwrap();
696
697        let mut received_chunk = false;
698
699        while let Some(chunk) = resp.next().await {
700            received_chunk = true;
701            assert!(chunk.is_ok());
702        }
703
704        assert!(
705            received_chunk,
706            "[{}]: Failed to receive a chunk from stream",
707            config.name
708        );
709
710        assert!(
711            resp.choice.iter().any(|content| match content {
712                AssistantContent::ToolCall(tc) => {
713                    if tc.function.name != Subtract::NAME {
714                        return false;
715                    }
716
717                    let arguments =
718                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
719                            .expect("Error parsing arguments");
720
721                    arguments.x == 2.0 && arguments.y == 5.0
722                }
723                _ => false,
724            }),
725            "[{}]: Model did not use the Subtract tool.",
726            config.name
727        )
728    }
729
730    #[tokio::test]
731    #[ignore]
732    async fn test_streaming_tools() {
733        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
734            test_streaming_tools_client(&p).await;
735        }
736    }
737
738    async fn test_audio_generation_client(config: &ClientConfig) {
739        let client = config.factory_env();
740
741        let Some(client) = client.as_audio_generation() else {
742            return;
743        };
744
745        let (model, voice) = config
746            .audio_generation_model
747            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));
748
749        let model = client.audio_generation_model(model);
750
751        let request = model
752            .audio_generation_request()
753            .text("Hello world!")
754            .voice(voice);
755
756        let resp = request.send().await;
757
758        assert!(
759            resp.is_ok(),
760            "[{}]: Error occurred when sending request, {}",
761            config.name,
762            resp.err().unwrap()
763        );
764
765        let resp = resp.unwrap();
766
767        assert!(
768            !resp.audio.is_empty(),
769            "[{}]: Returned audio was empty",
770            config.name
771        );
772    }
773
774    #[tokio::test]
775    #[ignore]
776    async fn test_audio_generation() {
777        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
778            test_audio_generation_client(&p).await;
779        }
780    }
781
782    fn assert_feature<F, M>(
783        name: &str,
784        feature_name: &str,
785        model_name: &str,
786        feature: Option<F>,
787        model: Option<M>,
788    ) {
789        assert_eq!(
790            feature.is_some(),
791            model.is_some(),
792            "{} has{} implemented {} but config.{} is {}.",
793            name,
794            if feature.is_some() { "" } else { "n't" },
795            feature_name,
796            model_name,
797            if model.is_some() { "some" } else { "none" }
798        );
799    }
800
801    #[test]
802    #[ignore]
803    pub fn test_polymorphism() {
804        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
805            let client = config.factory_env();
806            assert_feature(
807                config.name,
808                "AsCompletion",
809                "completion_model",
810                client.as_completion(),
811                config.completion_model,
812            );
813
814            assert_feature(
815                config.name,
816                "AsEmbeddings",
817                "embeddings_model",
818                client.as_embeddings(),
819                config.embeddings_model,
820            );
821
822            assert_feature(
823                config.name,
824                "AsTranscription",
825                "transcription_model",
826                client.as_transcription(),
827                config.transcription_model,
828            );
829
830            assert_feature(
831                config.name,
832                "AsImageGeneration",
833                "image_generation_model",
834                client.as_image_generation(),
835                config.image_generation_model,
836            );
837
838            assert_feature(
839                config.name,
840                "AsAudioGeneration",
841                "audio_generation_model",
842                client.as_audio_generation(),
843                config.audio_generation_model,
844            )
845        }
846    }
847
848    async fn test_embed_client(config: &ClientConfig) {
849        const TEST: &str = "Hello world.";
850
851        let client = config.factory_env();
852
853        let Some(client) = client.as_embeddings() else {
854            return;
855        };
856
857        let model = config.embeddings_model.unwrap();
858
859        let model = client.embedding_model(model);
860
861        let resp = model.embed_text(TEST).await;
862
863        assert!(
864            resp.is_ok(),
865            "[{}]: Error occurred when sending request, {}",
866            config.name,
867            resp.err().unwrap()
868        );
869
870        let resp = resp.unwrap();
871
872        assert_eq!(resp.document, TEST);
873
874        assert!(
875            !resp.vec.is_empty(),
876            "[{}]: Returned embed was empty",
877            config.name
878        );
879    }
880
881    #[tokio::test]
882    #[ignore]
883    async fn test_embed() {
884        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
885            test_embed_client(&config).await;
886        }
887    }
888
889    async fn test_image_generation_client(config: &ClientConfig) {
890        let client = config.factory_env();
891        let Some(client) = client.as_image_generation() else {
892            return;
893        };
894
895        let model = config.image_generation_model.unwrap();
896
897        let model = client.image_generation_model(model);
898
899        let resp = model
900            .image_generation(ImageGenerationRequest {
901                prompt: "A castle sitting on a large hill.".to_string(),
902                width: 256,
903                height: 256,
904                additional_params: None,
905            })
906            .await;
907
908        assert!(
909            resp.is_ok(),
910            "[{}]: Error occurred when sending request, {}",
911            config.name,
912            resp.err().unwrap()
913        );
914
915        let resp = resp.unwrap();
916
917        assert!(
918            !resp.image.is_empty(),
919            "[{}]: Generated image was empty",
920            config.name
921        );
922    }
923
924    #[tokio::test]
925    #[ignore]
926    async fn test_image_generation() {
927        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
928            test_image_generation_client(&config).await;
929        }
930    }
931
932    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
933        let client = config.factory_env();
934        let Some(client) = client.as_transcription() else {
935            return;
936        };
937
938        let model = config.image_generation_model.unwrap();
939
940        let model = client.transcription_model(model);
941
942        let resp = model
943            .transcription(TranscriptionRequest {
944                data,
945                filename: "audio.mp3".to_string(),
946                language: "en".to_string(),
947                prompt: None,
948                temperature: None,
949                additional_params: None,
950            })
951            .await;
952
953        assert!(
954            resp.is_ok(),
955            "[{}]: Error occurred when sending request, {}",
956            config.name,
957            resp.err().unwrap()
958        );
959
960        let resp = resp.unwrap();
961
962        assert!(
963            !resp.text.is_empty(),
964            "[{}]: Returned transcription was empty",
965            config.name
966        );
967    }
968
969    #[tokio::test]
970    #[ignore]
971    async fn test_transcription() {
972        let mut file = File::open("examples/audio/en-us-natural-speech.mp3").unwrap();
973
974        let mut data = Vec::new();
975        let _ = file.read(&mut data);
976
977        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
978            test_transcription_client(&config, data.clone()).await;
979        }
980    }
981
    /// Deserializable arguments shared by the `Adder` and `Subtract` test tools
    /// (their `Tool::Args` type). Both fields are required.
    #[derive(Deserialize)]
    struct OperationArgs {
        /// Left-hand operand (`x + y` for `Adder`, `x - y` for `Subtract`).
        x: f32,
        /// Right-hand operand.
        y: f32,
    }
987
    /// Error type for the `Adder` and `Subtract` test tools (their `Tool::Error`).
    ///
    /// Neither tool's `call` currently returns it; it displays as "Math error".
    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;
991
992    #[derive(Deserialize, Serialize)]
993    struct Adder;
994    impl Tool for Adder {
995        const NAME: &'static str = "add";
996
997        type Error = MathError;
998        type Args = OperationArgs;
999        type Output = f32;
1000
1001        async fn definition(&self, _prompt: String) -> ToolDefinition {
1002            ToolDefinition {
1003                name: "add".to_string(),
1004                description: "Add x and y together".to_string(),
1005                parameters: json!({
1006                    "type": "object",
1007                    "properties": {
1008                        "x": {
1009                            "type": "number",
1010                            "description": "The first number to add"
1011                        },
1012                        "y": {
1013                            "type": "number",
1014                            "description": "The second number to add"
1015                        }
1016                    }
1017                }),
1018            }
1019        }
1020
1021        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1022            println!("[tool-call] Adding {} and {}", args.x, args.y);
1023            let result = args.x + args.y;
1024            Ok(result)
1025        }
1026    }
1027
1028    #[derive(Deserialize, Serialize)]
1029    struct Subtract;
1030    impl Tool for Subtract {
1031        const NAME: &'static str = "subtract";
1032
1033        type Error = MathError;
1034        type Args = OperationArgs;
1035        type Output = f32;
1036
1037        async fn definition(&self, _prompt: String) -> ToolDefinition {
1038            serde_json::from_value(json!({
1039                "name": "subtract",
1040                "description": "Subtract y from x (i.e.: x - y)",
1041                "parameters": {
1042                    "type": "object",
1043                    "properties": {
1044                        "x": {
1045                            "type": "number",
1046                            "description": "The number to subtract from"
1047                        },
1048                        "y": {
1049                            "type": "number",
1050                            "description": "The number to subtract"
1051                        }
1052                    }
1053                }
1054            }))
1055            .expect("Tool Definition")
1056        }
1057
1058        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1059            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
1060            let result = args.x - args.y;
1061            Ok(result)
1062        }
1063    }
1064}