rig/client/
mod.rs

1//! This module provides traits for defining and creating provider clients.
2//! Clients are used to create models for completion, embeddings, etc.
3//! Dyn-compatible traits have been provided to allow for more provider-agnostic code.
4
5pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13#[cfg(feature = "derive")]
14pub use rig_derive::ProviderClient;
15use std::fmt::Debug;
16use thiserror::Error;
17
18#[derive(Debug, Error)]
19#[non_exhaustive]
20pub enum ClientBuilderError {
21    #[error("reqwest error: {0}")]
22    HttpError(
23        #[from]
24        #[source]
25        reqwest::Error,
26    ),
27    #[error("invalid property: {0}")]
28    InvalidProperty(&'static str),
29}
30
31/// The base ProviderClient trait, facilitates conversion between client types
32/// and creating a client from the environment.
33///
34/// All conversion traits must be implemented, they are automatically
35/// implemented if the respective client trait is implemented.
36pub trait ProviderClient:
37    AsCompletion
38    + AsTranscription
39    + AsEmbeddings
40    + AsImageGeneration
41    + AsAudioGeneration
42    + Debug
43    + WasmCompatSend
44    + WasmCompatSync
45{
46    /// Create a client from the process's environment.
47    /// Panics if an environment is improperly configured.
48    fn from_env() -> Self
49    where
50        Self: Sized;
51
52    /// A helper method to box the client.
53    fn boxed(self) -> Box<dyn ProviderClient>
54    where
55        Self: Sized + 'static,
56    {
57        Box::new(self)
58    }
59
60    /// Create a boxed client from the process's environment.
61    /// Panics if an environment is improperly configured.
62    fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
63    where
64        Self: Sized,
65        Self: 'a,
66    {
67        Box::new(Self::from_env())
68    }
69
70    fn from_val(input: ProviderValue) -> Self
71    where
72        Self: Sized;
73
74    /// Create a boxed client from the process's environment.
75    /// Panics if an environment is improperly configured.
76    fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
77    where
78        Self: Sized,
79        Self: 'a,
80    {
81        Box::new(Self::from_val(input))
82    }
83}
84
/// An explicit value used to construct a provider client through
/// [`ProviderClient::from_val`], as an alternative to reading the environment.
///
/// Prefer the `From` conversions below over building variants by hand:
/// - `&str` / `String` → [`ProviderValue::Simple`]
/// - `(api_key, Option<key>)` → [`ProviderValue::ApiKeyWithOptionalKey`]
/// - `(api_key, version, header)` → [`ProviderValue::ApiKeyWithVersionAndHeader`]
///
/// NOTE(review): presumably each provider accepts only some of these variants;
/// confirm against the individual `from_val` implementations. The type does not
/// derive `Debug`, which keeps key material out of debug output — confirm that
/// is intended before adding a derive.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (presumably an API key).
    Simple(String),
    /// An API key followed by an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}

/// Wraps a borrowed string in [`ProviderValue::Simple`].
impl From<&str> for ProviderValue {
    fn from(value: &str) -> Self {
        Self::Simple(value.to_string())
    }
}

/// Wraps an owned string in [`ProviderValue::Simple`].
impl From<String> for ProviderValue {
    fn from(value: String) -> Self {
        Self::Simple(value)
    }
}

/// Converts `(api_key, optional_key)` into [`ProviderValue::ApiKeyWithOptionalKey`].
impl<P> From<(P, Option<P>)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from((api_key, optional_key): (P, Option<P>)) -> Self {
        Self::ApiKeyWithOptionalKey(
            api_key.as_ref().to_string(),
            optional_key.map(|x| x.as_ref().to_string()),
        )
    }
}

/// Converts `(api_key, version, header)` into
/// [`ProviderValue::ApiKeyWithVersionAndHeader`].
impl<P> From<(P, P, P)> for ProviderValue
where
    P: AsRef<str>,
{
    fn from((api_key, version, header): (P, P, P)) -> Self {
        Self::ApiKeyWithVersionAndHeader(
            api_key.as_ref().to_string(),
            version.as_ref().to_string(),
            header.as_ref().to_string(),
        )
    }
}
128
129/// Attempt to convert a ProviderClient to a CompletionClient
130pub trait AsCompletion {
131    fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
132        None
133    }
134}
135
136/// Attempt to convert a ProviderClient to a TranscriptionClient
137pub trait AsTranscription {
138    fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
139        None
140    }
141}
142
143/// Attempt to convert a ProviderClient to a EmbeddingsClient
144pub trait AsEmbeddings {
145    fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
146        None
147    }
148}
149
150/// Attempt to convert a ProviderClient to a AudioGenerationClient
151pub trait AsAudioGeneration {
152    #[cfg(feature = "audio")]
153    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
154        None
155    }
156}
157
158/// Attempt to convert a ProviderClient to a ImageGenerationClient
159pub trait AsImageGeneration {
160    #[cfg(feature = "image")]
161    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
162        None
163    }
164}
165
166/// Attempt to convert a ProviderClient to a VerifyClient
167pub trait AsVerify {
168    fn as_verify(&self) -> Option<Box<dyn VerifyClientDyn>> {
169        None
170    }
171}
172
173#[cfg(not(feature = "audio"))]
174impl<T: ProviderClient> AsAudioGeneration for T {}
175
176#[cfg(not(feature = "image"))]
177impl<T: ProviderClient> AsImageGeneration for T {}
178
179#[macro_export]
180macro_rules! impl_conversion_traits {
181    ($( $trait_:ident ),* for $($type_spec:tt)+) => {
182        impl_conversion_traits!(@expand_traits [$($trait_)+] $($type_spec)+);
183    };
184
185    (@expand_traits [$trait_:ident $($rest_traits:ident)*] $($type_spec:tt)+) => {
186        impl_conversion_traits!(@impl $trait_ for $($type_spec)+);
187        impl_conversion_traits!(@expand_traits [$($rest_traits)*] $($type_spec)+);
188    };
189
190    (@expand_traits [] $($type_spec:tt)+) => {};
191
192    (@impl AsAudioGeneration for $($type_spec:tt)+) => {
193        rig::client::impl_audio_generation!($($type_spec)+);
194    };
195
196    (@impl AsImageGeneration for $($type_spec:tt)+) => {
197        rig::client::impl_image_generation!($($type_spec)+);
198    };
199
200    (@impl $trait_:ident for $($type_spec:tt)+) => {
201        impl_conversion_traits!(@impl_trait $trait_ for $($type_spec)+);
202    };
203
204    (@impl_trait $trait_:ident for $struct_:ident) => {
205        impl rig::client::$trait_ for $struct_ {}
206    };
207
208    (@impl_trait $trait_:ident for $struct_:ident<$($generics:tt),*>) => {
209        impl<$($generics),*> rig::client::$trait_ for $struct_<$($generics),*> {}
210    };
211}
212
213#[cfg(feature = "audio")]
214#[macro_export]
215macro_rules! impl_audio_generation {
216    ($struct_:ident) => {
217        impl rig::client::AsAudioGeneration for $struct_ {}
218    };
219    ($struct_:ident<$($generics:tt),*>) => {
220        impl<$($generics),*> rig::client::AsAudioGeneration for $struct_<$($generics),*> {}
221    };
222}
223
224#[cfg(not(feature = "audio"))]
225#[macro_export]
226macro_rules! impl_audio_generation {
227    ($($tokens:tt)*) => {};
228}
229
230#[cfg(feature = "image")]
231#[macro_export]
232macro_rules! impl_image_generation {
233    ($struct_:ident) => {
234        impl rig::client::AsImageGeneration for $struct_ {}
235    };
236    ($struct_:ident<$($generics:tt),*>) => {
237        impl<$($generics),*> rig::client::AsImageGeneration for $struct_<$($generics),*> {}
238    };
239}
240
241#[cfg(not(feature = "image"))]
242#[macro_export]
243macro_rules! impl_image_generation {
244    ($($tokens:tt)*) => {};
245}
246
247pub use impl_audio_generation;
248pub use impl_conversion_traits;
249pub use impl_image_generation;
250
251#[cfg(feature = "audio")]
252use crate::client::audio_generation::AudioGenerationClientDyn;
253use crate::client::completion::CompletionClientDyn;
254use crate::client::embeddings::EmbeddingsClientDyn;
255#[cfg(feature = "image")]
256use crate::client::image_generation::ImageGenerationClientDyn;
257use crate::client::transcription::TranscriptionClientDyn;
258use crate::client::verify::VerifyClientDyn;
259
260#[cfg(feature = "audio")]
261pub use crate::client::audio_generation::AudioGenerationClient;
262pub use crate::client::completion::CompletionClient;
263pub use crate::client::embeddings::EmbeddingsClient;
264#[cfg(feature = "image")]
265pub use crate::client::image_generation::ImageGenerationClient;
266pub use crate::client::transcription::TranscriptionClient;
267pub use crate::client::verify::{VerifyClient, VerifyError};
268use crate::wasm_compat::{WasmCompatSend, WasmCompatSync};
269
270#[cfg(test)]
271mod tests {
272    use crate::OneOrMany;
273    use crate::client::ProviderClient;
274    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
275    use crate::image_generation::ImageGenerationRequest;
276    use crate::message::AssistantContent;
277    use crate::providers::{
278        anthropic, azure, cohere, deepseek, galadriel, gemini, huggingface, hyperbolic, mira,
279        moonshot, openai, openrouter, together, xai,
280    };
281    use crate::streaming::StreamingCompletion;
282    use crate::tool::Tool;
283    use crate::transcription::TranscriptionRequest;
284    use futures::StreamExt;
285    use rig::message::Message;
286    use rig::providers::{groq, ollama, perplexity};
287    use serde::{Deserialize, Serialize};
288    use serde_json::json;
289    use std::fs::File;
290    use std::io::Read;
291
292    use super::ProviderValue;
293
294    struct ClientConfig {
295        name: &'static str,
296        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
297        // Not sure where we're going to be using this but I've added it for completeness
298        #[allow(dead_code)]
299        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
300        env_variable: &'static str,
301        completion_model: Option<&'static str>,
302        embeddings_model: Option<&'static str>,
303        transcription_model: Option<&'static str>,
304        image_generation_model: Option<&'static str>,
305        audio_generation_model: Option<(&'static str, &'static str)>,
306    }
307
308    impl Default for ClientConfig {
309        fn default() -> Self {
310            Self {
311                name: "",
312                factory_env: Box::new(|| panic!("Not implemented")),
313                factory_val: Box::new(|_| panic!("Not implemented")),
314                env_variable: "",
315                completion_model: None,
316                embeddings_model: None,
317                transcription_model: None,
318                image_generation_model: None,
319                audio_generation_model: None,
320            }
321        }
322    }
323
324    impl ClientConfig {
325        fn is_env_var_set(&self) -> bool {
326            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
327        }
328
329        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
330            self.factory_env.as_ref()()
331        }
332    }
333
334    fn providers() -> Vec<ClientConfig> {
335        vec![
336            ClientConfig {
337                name: "Anthropic",
338                factory_env: Box::new(anthropic::Client::<reqwest::Client>::from_env_boxed),
339                factory_val: Box::new(anthropic::Client::<reqwest::Client>::from_val_boxed),
340                env_variable: "ANTHROPIC_API_KEY",
341                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
342                ..Default::default()
343            },
344            ClientConfig {
345                name: "Cohere",
346                factory_env: Box::new(cohere::Client::<reqwest::Client>::from_env_boxed),
347                factory_val: Box::new(cohere::Client::<reqwest::Client>::from_val_boxed),
348                env_variable: "COHERE_API_KEY",
349                completion_model: Some(cohere::COMMAND_R),
350                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
351                ..Default::default()
352            },
353            ClientConfig {
354                name: "Gemini",
355                factory_env: Box::new(gemini::Client::<reqwest::Client>::from_env_boxed),
356                factory_val: Box::new(gemini::Client::<reqwest::Client>::from_val_boxed),
357                env_variable: "GEMINI_API_KEY",
358                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
359                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
360                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
361                ..Default::default()
362            },
363            ClientConfig {
364                name: "Huggingface",
365                factory_env: Box::new(huggingface::Client::<reqwest::Client>::from_env_boxed),
366                factory_val: Box::new(huggingface::Client::<reqwest::Client>::from_val_boxed),
367                env_variable: "HUGGINGFACE_API_KEY",
368                completion_model: Some(huggingface::PHI_4),
369                transcription_model: Some(huggingface::WHISPER_SMALL),
370                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
371                ..Default::default()
372            },
373            ClientConfig {
374                name: "OpenAI",
375                factory_env: Box::new(openai::Client::<reqwest::Client>::from_env_boxed),
376                factory_val: Box::new(openai::Client::<reqwest::Client>::from_val_boxed),
377                env_variable: "OPENAI_API_KEY",
378                completion_model: Some(openai::GPT_4O),
379                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
380                transcription_model: Some(openai::WHISPER_1),
381                image_generation_model: Some(openai::DALL_E_2),
382                audio_generation_model: Some((openai::TTS_1, "onyx")),
383            },
384            ClientConfig {
385                name: "OpenRouter",
386                factory_env: Box::new(openrouter::Client::<reqwest::Client>::from_env_boxed),
387                factory_val: Box::new(openrouter::Client::<reqwest::Client>::from_val_boxed),
388                env_variable: "OPENROUTER_API_KEY",
389                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
390                ..Default::default()
391            },
392            ClientConfig {
393                name: "Together",
394                factory_env: Box::new(together::Client::<reqwest::Client>::from_env_boxed),
395                factory_val: Box::new(together::Client::<reqwest::Client>::from_val_boxed),
396                env_variable: "TOGETHER_API_KEY",
397                completion_model: Some(together::ALPACA_7B),
398                embeddings_model: Some(together::BERT_BASE_UNCASED),
399                ..Default::default()
400            },
401            ClientConfig {
402                name: "XAI",
403                factory_env: Box::new(xai::Client::<reqwest::Client>::from_env_boxed),
404                factory_val: Box::new(xai::Client::<reqwest::Client>::from_val_boxed),
405                env_variable: "XAI_API_KEY",
406                completion_model: Some(xai::GROK_3_MINI),
407                embeddings_model: None,
408                ..Default::default()
409            },
410            ClientConfig {
411                name: "Azure",
412                factory_env: Box::new(azure::Client::<reqwest::Client>::from_env_boxed),
413                factory_val: Box::new(azure::Client::<reqwest::Client>::from_val_boxed),
414                env_variable: "AZURE_API_KEY",
415                completion_model: Some(azure::GPT_4O),
416                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
417                transcription_model: Some("whisper-1"),
418                image_generation_model: Some("dalle-2"),
419                audio_generation_model: Some(("tts-1", "onyx")),
420            },
421            ClientConfig {
422                name: "Deepseek",
423                factory_env: Box::new(deepseek::Client::<reqwest::Client>::from_env_boxed),
424                factory_val: Box::new(deepseek::Client::<reqwest::Client>::from_val_boxed),
425                env_variable: "DEEPSEEK_API_KEY",
426                completion_model: Some(deepseek::DEEPSEEK_CHAT),
427                ..Default::default()
428            },
429            ClientConfig {
430                name: "Galadriel",
431                factory_env: Box::new(galadriel::Client::<reqwest::Client>::from_env_boxed),
432                factory_val: Box::new(galadriel::Client::<reqwest::Client>::from_val_boxed),
433                env_variable: "GALADRIEL_API_KEY",
434                completion_model: Some(galadriel::GPT_4O),
435                ..Default::default()
436            },
437            ClientConfig {
438                name: "Groq",
439                factory_env: Box::new(groq::Client::<reqwest::Client>::from_env_boxed),
440                factory_val: Box::new(groq::Client::<reqwest::Client>::from_val_boxed),
441                env_variable: "GROQ_API_KEY",
442                completion_model: Some(groq::MIXTRAL_8X7B_32768),
443                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
444                ..Default::default()
445            },
446            ClientConfig {
447                name: "Hyperbolic",
448                factory_env: Box::new(hyperbolic::Client::<reqwest::Client>::from_env_boxed),
449                factory_val: Box::new(hyperbolic::Client::<reqwest::Client>::from_val_boxed),
450                env_variable: "HYPERBOLIC_API_KEY",
451                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
452                image_generation_model: Some(hyperbolic::SD1_5),
453                audio_generation_model: Some(("EN", "EN-US")),
454                ..Default::default()
455            },
456            ClientConfig {
457                name: "Mira",
458                factory_env: Box::new(mira::Client::<reqwest::Client>::from_env_boxed),
459                factory_val: Box::new(mira::Client::<reqwest::Client>::from_val_boxed),
460                env_variable: "MIRA_API_KEY",
461                completion_model: Some("gpt-4o"),
462                ..Default::default()
463            },
464            ClientConfig {
465                name: "Moonshot",
466                factory_env: Box::new(moonshot::Client::<reqwest::Client>::from_env_boxed),
467                factory_val: Box::new(moonshot::Client::<reqwest::Client>::from_val_boxed),
468                env_variable: "MOONSHOT_API_KEY",
469                completion_model: Some(moonshot::MOONSHOT_CHAT),
470                ..Default::default()
471            },
472            ClientConfig {
473                name: "Ollama",
474                factory_env: Box::new(ollama::Client::<reqwest::Client>::from_env_boxed),
475                factory_val: Box::new(ollama::Client::<reqwest::Client>::from_val_boxed),
476                env_variable: "OLLAMA_ENABLED",
477                completion_model: Some("llama3.1:8b"),
478                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
479                ..Default::default()
480            },
481            ClientConfig {
482                name: "Perplexity",
483                factory_env: Box::new(perplexity::Client::<reqwest::Client>::from_env_boxed),
484                factory_val: Box::new(perplexity::Client::<reqwest::Client>::from_val_boxed),
485                env_variable: "PERPLEXITY_API_KEY",
486                completion_model: Some(perplexity::SONAR),
487                ..Default::default()
488            },
489        ]
490    }
491
492    async fn test_completions_client(config: &ClientConfig) {
493        let client = config.factory_env();
494
495        let Some(client) = client.as_completion() else {
496            return;
497        };
498
499        let model = config
500            .completion_model
501            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));
502
503        let model = client.completion_model(model);
504
505        let resp = model
506            .completion_request(Message::user("Whats the capital of France?"))
507            .send()
508            .await;
509
510        assert!(
511            resp.is_ok(),
512            "[{}]: Error occurred when prompting, {}",
513            config.name,
514            resp.err().unwrap()
515        );
516
517        let resp = resp.unwrap();
518
519        match resp.choice.first() {
520            AssistantContent::Text(text) => {
521                assert!(text.text.to_lowercase().contains("paris"));
522            }
523            _ => {
524                unreachable!(
525                    "[{}]: First choice wasn't a Text message, {:?}",
526                    config.name,
527                    resp.choice.first()
528                );
529            }
530        }
531    }
532
533    #[tokio::test]
534    #[ignore]
535    async fn test_completions() {
536        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
537            test_completions_client(&p).await;
538        }
539    }
540
541    async fn test_tools_client(config: &ClientConfig) {
542        let client = config.factory_env();
543        let model = config
544            .completion_model
545            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
546
547        let Some(client) = client.as_completion() else {
548            return;
549        };
550
551        let model = client.agent(model)
552            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
553            .max_tokens(1024)
554            .tool(Adder)
555            .tool(Subtract)
556            .build();
557
558        let request = model.completion("Calculate 2 - 5", vec![]).await;
559
560        assert!(
561            request.is_ok(),
562            "[{}]: Error occurred when building prompt, {}",
563            config.name,
564            request.err().unwrap()
565        );
566
567        let resp = request.unwrap().send().await;
568
569        assert!(
570            resp.is_ok(),
571            "[{}]: Error occurred when prompting, {}",
572            config.name,
573            resp.err().unwrap()
574        );
575
576        let resp = resp.unwrap();
577
578        assert!(
579            resp.choice.iter().any(|content| match content {
580                AssistantContent::ToolCall(tc) => {
581                    if tc.function.name != Subtract::NAME {
582                        return false;
583                    }
584
585                    let arguments =
586                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
587                            .expect("Error parsing arguments");
588
589                    arguments.x == 2.0 && arguments.y == 5.0
590                }
591                _ => false,
592            }),
593            "[{}]: Model did not use the Subtract tool.",
594            config.name
595        )
596    }
597
598    #[tokio::test]
599    #[ignore]
600    async fn test_tools() {
601        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
602            test_tools_client(&p).await;
603        }
604    }
605
606    async fn test_streaming_client(config: &ClientConfig) {
607        let client = config.factory_env();
608
609        let Some(client) = client.as_completion() else {
610            return;
611        };
612
613        let model = config
614            .completion_model
615            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
616
617        let model = client.completion_model(model);
618
619        let resp = model.stream(CompletionRequest {
620            preamble: None,
621            tools: vec![],
622            documents: vec![],
623            temperature: None,
624            max_tokens: None,
625            additional_params: None,
626            tool_choice: None,
627            chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
628        });
629
630        let mut resp = resp.await.unwrap();
631
632        let mut received_chunk = false;
633
634        while let Some(chunk) = resp.next().await {
635            received_chunk = true;
636            assert!(chunk.is_ok());
637        }
638
639        assert!(
640            received_chunk,
641            "[{}]: Failed to receive a chunk from stream",
642            config.name
643        );
644
645        for choice in resp.choice {
646            match choice {
647                AssistantContent::Text(text) => {
648                    assert!(
649                        text.text.to_lowercase().contains("paris"),
650                        "[{}]: Did not answer with Paris",
651                        config.name
652                    );
653                }
654                AssistantContent::ToolCall(_) => {}
655                AssistantContent::Reasoning(_) => {}
656            }
657        }
658    }
659
660    #[tokio::test]
661    #[ignore]
662    async fn test_streaming() {
663        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
664            test_streaming_client(&provider).await;
665        }
666    }
667
668    async fn test_streaming_tools_client(config: &ClientConfig) {
669        let client = config.factory_env();
670        let model = config
671            .completion_model
672            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));
673
674        let Some(client) = client.as_completion() else {
675            return;
676        };
677
678        let model = client.agent(model)
679            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
680            .max_tokens(1024)
681            .tool(Adder)
682            .tool(Subtract)
683            .build();
684
685        let request = model.stream_completion("Calculate 2 - 5", vec![]).await;
686
687        assert!(
688            request.is_ok(),
689            "[{}]: Error occurred when building prompt, {}",
690            config.name,
691            request.err().unwrap()
692        );
693
694        let resp = request.unwrap().stream().await;
695
696        assert!(
697            resp.is_ok(),
698            "[{}]: Error occurred when prompting, {}",
699            config.name,
700            resp.err().unwrap()
701        );
702
703        let mut resp = resp.unwrap();
704
705        let mut received_chunk = false;
706
707        while let Some(chunk) = resp.next().await {
708            received_chunk = true;
709            assert!(chunk.is_ok());
710        }
711
712        assert!(
713            received_chunk,
714            "[{}]: Failed to receive a chunk from stream",
715            config.name
716        );
717
718        assert!(
719            resp.choice.iter().any(|content| match content {
720                AssistantContent::ToolCall(tc) => {
721                    if tc.function.name != Subtract::NAME {
722                        return false;
723                    }
724
725                    let arguments =
726                        serde_json::from_value::<OperationArgs>((tc.function.arguments).clone())
727                            .expect("Error parsing arguments");
728
729                    arguments.x == 2.0 && arguments.y == 5.0
730                }
731                _ => false,
732            }),
733            "[{}]: Model did not use the Subtract tool.",
734            config.name
735        )
736    }
737
738    #[tokio::test]
739    #[ignore]
740    async fn test_streaming_tools() {
741        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
742            test_streaming_tools_client(&p).await;
743        }
744    }
745
746    async fn test_audio_generation_client(config: &ClientConfig) {
747        let client = config.factory_env();
748
749        let Some(client) = client.as_audio_generation() else {
750            return;
751        };
752
753        let (model, voice) = config
754            .audio_generation_model
755            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));
756
757        let model = client.audio_generation_model(model);
758
759        let request = model
760            .audio_generation_request()
761            .text("Hello world!")
762            .voice(voice);
763
764        let resp = request.send().await;
765
766        assert!(
767            resp.is_ok(),
768            "[{}]: Error occurred when sending request, {}",
769            config.name,
770            resp.err().unwrap()
771        );
772
773        let resp = resp.unwrap();
774
775        assert!(
776            !resp.audio.is_empty(),
777            "[{}]: Returned audio was empty",
778            config.name
779        );
780    }
781
782    #[tokio::test]
783    #[ignore]
784    async fn test_audio_generation() {
785        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
786            test_audio_generation_client(&p).await;
787        }
788    }
789
790    fn assert_feature<F, M>(
791        name: &str,
792        feature_name: &str,
793        model_name: &str,
794        feature: Option<F>,
795        model: Option<M>,
796    ) {
797        assert_eq!(
798            feature.is_some(),
799            model.is_some(),
800            "{} has{} implemented {} but config.{} is {}.",
801            name,
802            if feature.is_some() { "" } else { "n't" },
803            feature_name,
804            model_name,
805            if model.is_some() { "some" } else { "none" }
806        );
807    }
808
809    #[test]
810    #[ignore]
811    pub fn test_polymorphism() {
812        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
813            let client = config.factory_env();
814            assert_feature(
815                config.name,
816                "AsCompletion",
817                "completion_model",
818                client.as_completion(),
819                config.completion_model,
820            );
821
822            assert_feature(
823                config.name,
824                "AsEmbeddings",
825                "embeddings_model",
826                client.as_embeddings(),
827                config.embeddings_model,
828            );
829
830            assert_feature(
831                config.name,
832                "AsTranscription",
833                "transcription_model",
834                client.as_transcription(),
835                config.transcription_model,
836            );
837
838            assert_feature(
839                config.name,
840                "AsImageGeneration",
841                "image_generation_model",
842                client.as_image_generation(),
843                config.image_generation_model,
844            );
845
846            assert_feature(
847                config.name,
848                "AsAudioGeneration",
849                "audio_generation_model",
850                client.as_audio_generation(),
851                config.audio_generation_model,
852            )
853        }
854    }
855
856    async fn test_embed_client(config: &ClientConfig) {
857        const TEST: &str = "Hello world.";
858
859        let client = config.factory_env();
860
861        let Some(client) = client.as_embeddings() else {
862            return;
863        };
864
865        let model = config.embeddings_model.unwrap();
866
867        let model = client.embedding_model(model);
868
869        let resp = model.embed_text(TEST).await;
870
871        assert!(
872            resp.is_ok(),
873            "[{}]: Error occurred when sending request, {}",
874            config.name,
875            resp.err().unwrap()
876        );
877
878        let resp = resp.unwrap();
879
880        assert_eq!(resp.document, TEST);
881
882        assert!(
883            !resp.vec.is_empty(),
884            "[{}]: Returned embed was empty",
885            config.name
886        );
887    }
888
889    #[tokio::test]
890    #[ignore]
891    async fn test_embed() {
892        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
893            test_embed_client(&config).await;
894        }
895    }
896
897    async fn test_image_generation_client(config: &ClientConfig) {
898        let client = config.factory_env();
899        let Some(client) = client.as_image_generation() else {
900            return;
901        };
902
903        let model = config.image_generation_model.unwrap();
904
905        let model = client.image_generation_model(model);
906
907        let resp = model
908            .image_generation(ImageGenerationRequest {
909                prompt: "A castle sitting on a large hill.".to_string(),
910                width: 256,
911                height: 256,
912                additional_params: None,
913            })
914            .await;
915
916        assert!(
917            resp.is_ok(),
918            "[{}]: Error occurred when sending request, {}",
919            config.name,
920            resp.err().unwrap()
921        );
922
923        let resp = resp.unwrap();
924
925        assert!(
926            !resp.image.is_empty(),
927            "[{}]: Generated image was empty",
928            config.name
929        );
930    }
931
932    #[tokio::test]
933    #[ignore]
934    async fn test_image_generation() {
935        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
936            test_image_generation_client(&config).await;
937        }
938    }
939
940    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
941        let client = config.factory_env();
942        let Some(client) = client.as_transcription() else {
943            return;
944        };
945
946        let model = config.image_generation_model.unwrap();
947
948        let model = client.transcription_model(model);
949
950        let resp = model
951            .transcription(TranscriptionRequest {
952                data,
953                filename: "audio.mp3".to_string(),
954                language: None,
955                prompt: None,
956                temperature: None,
957                additional_params: None,
958            })
959            .await;
960
961        assert!(
962            resp.is_ok(),
963            "[{}]: Error occurred when sending request, {}",
964            config.name,
965            resp.err().unwrap()
966        );
967
968        let resp = resp.unwrap();
969
970        assert!(
971            !resp.text.is_empty(),
972            "[{}]: Returned transcription was empty",
973            config.name
974        );
975    }
976
977    #[tokio::test]
978    #[ignore]
979    async fn test_transcription() {
980        let mut file = File::open("examples/audio/en-us-natural-speech.mp3").unwrap();
981
982        let mut data = Vec::new();
983        let _ = file.read(&mut data);
984
985        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
986            test_transcription_client(&config, data.clone()).await;
987        }
988    }
989
    /// Arguments shared by the `Adder` and `Subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        /// Left-hand operand.
        x: f32,
        /// Right-hand operand (the value subtracted by `Subtract`).
        y: f32,
    }
995
996    #[derive(Debug, thiserror::Error)]
997    #[error("Math error")]
998    struct MathError;
999
1000    #[derive(Deserialize, Serialize)]
1001    struct Adder;
1002    impl Tool for Adder {
1003        const NAME: &'static str = "add";
1004
1005        type Error = MathError;
1006        type Args = OperationArgs;
1007        type Output = f32;
1008
1009        async fn definition(&self, _prompt: String) -> ToolDefinition {
1010            ToolDefinition {
1011                name: "add".to_string(),
1012                description: "Add x and y together".to_string(),
1013                parameters: json!({
1014                    "type": "object",
1015                    "properties": {
1016                        "x": {
1017                            "type": "number",
1018                            "description": "The first number to add"
1019                        },
1020                        "y": {
1021                            "type": "number",
1022                            "description": "The second number to add"
1023                        }
1024                    }
1025                }),
1026            }
1027        }
1028
1029        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1030            println!("[tool-call] Adding {} and {}", args.x, args.y);
1031            let result = args.x + args.y;
1032            Ok(result)
1033        }
1034    }
1035
1036    #[derive(Deserialize, Serialize)]
1037    struct Subtract;
1038    impl Tool for Subtract {
1039        const NAME: &'static str = "subtract";
1040
1041        type Error = MathError;
1042        type Args = OperationArgs;
1043        type Output = f32;
1044
1045        async fn definition(&self, _prompt: String) -> ToolDefinition {
1046            serde_json::from_value(json!({
1047                "name": "subtract",
1048                "description": "Subtract y from x (i.e.: x - y)",
1049                "parameters": {
1050                    "type": "object",
1051                    "properties": {
1052                        "x": {
1053                            "type": "number",
1054                            "description": "The number to subtract from"
1055                        },
1056                        "y": {
1057                            "type": "number",
1058                            "description": "The number to subtract"
1059                        }
1060                    }
1061                }
1062            }))
1063            .expect("Tool Definition")
1064        }
1065
1066        async fn call(&self, args: Self::Args) -> anyhow::Result<Self::Output, Self::Error> {
1067            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
1068            let result = args.x - args.y;
1069            Ok(result)
1070        }
1071    }
1072}