1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13#[cfg(feature = "derive")]
14pub use rig_derive::ProviderClient;
15use std::fmt::Debug;
16use thiserror::Error;
17
18#[derive(Debug, Error)]
19#[non_exhaustive]
20pub enum ClientBuilderError {
21 #[error("reqwest error: {0}")]
22 HttpError(
23 #[from]
24 #[source]
25 reqwest::Error,
26 ),
27 #[error("invalid property: {0}")]
28 InvalidProperty(&'static str),
29}
30
31pub trait ProviderClient:
37 AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
38{
39 fn from_env() -> Self
42 where
43 Self: Sized;
44
45 fn boxed(self) -> Box<dyn ProviderClient>
47 where
48 Self: Sized + 'static,
49 {
50 Box::new(self)
51 }
52
53 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
56 where
57 Self: Sized,
58 Self: 'a,
59 {
60 Box::new(Self::from_env())
61 }
62
63 fn from_val(input: ProviderValue) -> Self
64 where
65 Self: Sized;
66
67 fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
70 where
71 Self: Sized,
72 Self: 'a,
73 {
74 Box::new(Self::from_val(input))
75 }
76}
77
/// Explicit configuration handed to [`ProviderClient::from_val`].
///
/// Which variant a given provider understands is provider-specific.
/// `Debug` is intentionally not derived here; NOTE(review): presumably to avoid
/// printing credentials — confirm before adding it.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value (typically an API key — provider-specific).
    Simple(String),
    /// An API key together with an optional second key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}

impl From<&str> for ProviderValue {
    /// Wraps a borrowed string as [`ProviderValue::Simple`].
    fn from(value: &str) -> Self {
        ProviderValue::Simple(value.to_owned())
    }
}

impl From<String> for ProviderValue {
    /// Wraps an owned string as [`ProviderValue::Simple`] without copying.
    fn from(value: String) -> Self {
        ProviderValue::Simple(value)
    }
}

/// `(api_key, optional_key)` becomes [`ProviderValue::ApiKeyWithOptionalKey`].
impl<P: AsRef<str>> From<(P, Option<P>)> for ProviderValue {
    fn from(pair: (P, Option<P>)) -> Self {
        let (api_key, optional_key) = pair;
        let api_key = String::from(api_key.as_ref());
        let optional_key = optional_key.map(|key| String::from(key.as_ref()));
        ProviderValue::ApiKeyWithOptionalKey(api_key, optional_key)
    }
}

/// `(api_key, version, header)` becomes
/// [`ProviderValue::ApiKeyWithVersionAndHeader`].
impl<P: AsRef<str>> From<(P, P, P)> for ProviderValue {
    fn from(triple: (P, P, P)) -> Self {
        let (api_key, version, header) = triple;
        ProviderValue::ApiKeyWithVersionAndHeader(
            String::from(api_key.as_ref()),
            String::from(version.as_ref()),
            String::from(header.as_ref()),
        )
    }
}
121
122pub trait AsCompletion {
124 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
125 None
126 }
127}
128
129pub trait AsTranscription {
131 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
132 None
133 }
134}
135
136pub trait AsEmbeddings {
138 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
139 None
140 }
141}
142
/// Runtime capability accessor for audio generation.
///
/// Without the `audio` feature this trait has no methods at all.
pub trait AsAudioGeneration {
    /// Returns a type-erased audio-generation client, or `None` (the default)
    /// when the provider does not support audio generation.
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
150
/// Runtime capability accessor for image generation.
///
/// Without the `image` feature this trait has no methods at all.
pub trait AsImageGeneration {
    /// Returns a type-erased image-generation client, or `None` (the default)
    /// when the provider does not support image generation.
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
158
159pub trait AsVerify {
161 fn as_verify(&self) -> Option<Box<dyn VerifyClientDyn>> {
162 None
163 }
164}
165
166#[cfg(not(feature = "audio"))]
167impl<T: ProviderClient> AsAudioGeneration for T {}
168
169#[cfg(not(feature = "image"))]
170impl<T: ProviderClient> AsImageGeneration for T {}
171
/// Implements a list of `As*` capability traits (from the `client` module) for
/// a type, using each trait's default (`None`-returning) accessors.
///
/// Usage: `impl_conversion_traits!(AsEmbeddings, AsTranscription for Client<T>);`
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, so they respect the
/// `audio` / `image` feature flags. The type may be a bare identifier or an
/// identifier with simple generic parameters (`Name<A, B>`).
///
/// Every recursive or sibling invocation goes through `$crate::…`, so the macro
/// works however the caller brings it into scope (by `use` or by full path) and
/// regardless of what name the caller gives this crate.
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $($type_spec:tt)+) => {
        // `*` (not `+`) so an empty trait list reaches the terminating arm
        // instead of failing with "this must repeat at least once".
        $crate::impl_conversion_traits!(@expand_traits [$($trait_)*] $($type_spec)+);
    };

    // Peel one trait off the list and recurse on the rest.
    (@expand_traits [$trait_:ident $($rest_traits:ident)*] $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@impl $trait_ for $($type_spec)+);
        $crate::impl_conversion_traits!(@expand_traits [$($rest_traits)*] $($type_spec)+);
    };

    (@expand_traits [] $($type_spec:tt)+) => {};

    // Feature-gated capabilities delegate to their own macros.
    (@impl AsAudioGeneration for $($type_spec:tt)+) => {
        $crate::client::impl_audio_generation!($($type_spec)+);
    };

    (@impl AsImageGeneration for $($type_spec:tt)+) => {
        $crate::client::impl_image_generation!($($type_spec)+);
    };

    (@impl $trait_:ident for $($type_spec:tt)+) => {
        $crate::impl_conversion_traits!(@impl_trait $trait_ for $($type_spec)+);
    };

    (@impl_trait $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };

    (@impl_trait $trait_:ident for $struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::$trait_ for $struct_<$($generics),*> {}
    };
}
205
/// Implements `AsAudioGeneration` for a client type using the trait's default
/// (`None`-returning) accessor.
///
/// Accepts a bare type name (`Client`) or a type with simple generic
/// parameters (`Client<T>`). The path is `$crate`-relative so the expansion
/// does not depend on the name under which the caller imported this crate.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
    ($struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::AsAudioGeneration for $struct_<$($generics),*> {}
    };
}

/// No-op variant used when the `audio` feature is disabled; in that
/// configuration `AsAudioGeneration` is provided by a blanket impl over every
/// `ProviderClient`, so emitting an impl here would conflict with it.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($($tokens:tt)*) => {};
}
222
/// Implements `AsImageGeneration` for a client type using the trait's default
/// (`None`-returning) accessor.
///
/// Accepts a bare type name (`Client`) or a type with simple generic
/// parameters (`Client<T>`). The path is `$crate`-relative so the expansion
/// does not depend on the name under which the caller imported this crate.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
    ($struct_:ident<$($generics:tt),*>) => {
        impl<$($generics),*> $crate::client::AsImageGeneration for $struct_<$($generics),*> {}
    };
}

/// No-op variant used when the `image` feature is disabled; in that
/// configuration `AsImageGeneration` is provided by a blanket impl over every
/// `ProviderClient`, so emitting an impl here would conflict with it.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($($tokens:tt)*) => {};
}
239
240pub use impl_audio_generation;
241pub use impl_conversion_traits;
242pub use impl_image_generation;
243
244#[cfg(feature = "audio")]
245use crate::client::audio_generation::AudioGenerationClientDyn;
246use crate::client::completion::CompletionClientDyn;
247use crate::client::embeddings::EmbeddingsClientDyn;
248#[cfg(feature = "image")]
249use crate::client::image_generation::ImageGenerationClientDyn;
250use crate::client::transcription::TranscriptionClientDyn;
251use crate::client::verify::VerifyClientDyn;
252
253#[cfg(feature = "audio")]
254pub use crate::client::audio_generation::AudioGenerationClient;
255pub use crate::client::completion::CompletionClient;
256pub use crate::client::embeddings::EmbeddingsClient;
257#[cfg(feature = "image")]
258pub use crate::client::image_generation::ImageGenerationClient;
259pub use crate::client::transcription::TranscriptionClient;
260pub use crate::client::verify::{VerifyClient, VerifyError};
261
#[cfg(test)]
mod tests {
    //! Live, provider-agnostic integration tests.
    //!
    //! Every provider is driven through the type-erased `ProviderClient`
    //! interface. All entry points are `#[ignore]`d because they call real
    //! provider APIs; a provider only takes part when the environment variable
    //! named in its `ClientConfig::env_variable` is set (see
    //! `ClientConfig::is_env_var_set`).

    use super::ProviderValue;
    use crate::OneOrMany;
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// One provider under test: how to construct it and which model (if any)
    /// to use for each capability.
    ///
    /// A capability's model field must be `Some` exactly when the client
    /// exposes that capability; `test_polymorphism` enforces this.
    struct ClientConfig {
        /// Provider name used to prefix assertion messages.
        name: &'static str,
        /// Builds the client from environment configuration.
        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        /// Builds the client from an explicit `ProviderValue`.
        // No test calls this yet; it is kept so each provider entry records its
        // `from_val_boxed` constructor and keeps that path compiling.
        #[allow(dead_code)]
        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
        /// Environment variable that enables this provider; an empty string
        /// means the provider is always enabled.
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` used for audio generation.
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        /// An empty config whose factories panic; meant to be used with
        /// struct-update syntax so each provider sets only what it supports.
        fn default() -> Self {
            Self {
                name: "",
                factory_env: Box::new(|| panic!("Not implemented")),
                factory_val: Box::new(|_| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// Whether this provider should run: either it needs no environment
        /// variable, or the variable is present.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a fresh client through the environment factory.
        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
            (self.factory_env)()
        }
    }

    /// Every provider known to the suite, in the order the tests visit them.
    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory_env: Box::new(anthropic::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(anthropic::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory_env: Box::new(cohere::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(cohere::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory_env: Box::new(gemini::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(gemini::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory_env: Box::new(huggingface::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(huggingface::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory_env: Box::new(openai::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(openai::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory_env: Box::new(openrouter::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(openrouter::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory_env: Box::new(together::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(together::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory_env: Box::new(xai::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(xai::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                embeddings_model: None,
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory_env: Box::new(azure::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(azure::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory_env: Box::new(deepseek::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(deepseek::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory_env: Box::new(galadriel::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(galadriel::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory_env: Box::new(groq::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(groq::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory_env: Box::new(hyperbolic::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(hyperbolic::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory_env: Box::new(mira::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(mira::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory_env: Box::new(moonshot::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(moonshot::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory_env: Box::new(ollama::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(ollama::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory_env: Box::new(perplexity::Client::<reqwest::Client>::from_env_boxed),
                factory_val: Box::new(perplexity::Client::<reqwest::Client>::from_val_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Sends one plain completion request and checks the first choice is a
    /// text answer mentioning Paris.
    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when prompting, {}", config.name, err)
            });

        match resp.choice.first() {
            AssistantContent::Text(text) => {
                assert!(
                    text.text.to_lowercase().contains("paris"),
                    "[{}]: Did not answer with Paris",
                    config.name
                );
            }
            // The model chose a different content kind: a test failure, not an
            // impossible state, so `panic!` rather than `unreachable!`.
            other => {
                panic!(
                    "[{}]: First choice wasn't a Text message, {:?}",
                    config.name, other
                );
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    /// Returns `true` if `content` is a call to the `subtract` tool whose
    /// arguments are `x = 2`, `y = 5` (i.e. the model computed `2 - 5`).
    ///
    /// # Panics
    /// If the tool-call arguments cannot be parsed as `OperationArgs`.
    fn is_subtract_call(content: &AssistantContent) -> bool {
        let AssistantContent::ToolCall(tc) = content else {
            return false;
        };

        if tc.function.name != Subtract::NAME {
            return false;
        }

        let arguments = serde_json::from_value::<OperationArgs>(tc.function.arguments.clone())
            .expect("Error parsing arguments");

        arguments.x == 2.0 && arguments.y == 5.0
    }

    /// Asks a calculator agent for `2 - 5` and checks it calls `subtract`.
    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory_env();
        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = client.agent(model)
            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = model
            .completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when building prompt, {}",
                    config.name, err
                )
            });

        let resp = request.send().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        assert!(
            resp.choice.iter().any(is_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    /// Streams a completion, checks every chunk is `Ok`, and checks any text
    /// in the aggregated response mentions Paris.
    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let mut resp = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                tool_choice: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when starting the stream, {:?}",
                    config.name, err
                )
            });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Received an error chunk from the stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
                AssistantContent::Reasoning(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    /// Streaming variant of `test_tools_client`: the aggregated streamed
    /// response must contain a `subtract(2, 5)` tool call.
    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory_env();
        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = client.agent(model)
            .preamble("You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.")
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = model
            .stream_completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when building prompt, {}",
                    config.name, err
                )
            });

        let mut resp = request.stream().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Received an error chunk from the stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(is_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    /// Generates speech for a short phrase and checks audio bytes came back.
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let request = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice);

        let resp = request.send().await.unwrap_or_else(|err| {
            panic!(
                "[{}]: Error occurred when sending request, {}",
                config.name, err
            )
        });

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a capability accessor (`feature`) and the configured model
    /// for it (`model`) are either both present or both absent.
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    /// Checks every provider's capability accessors agree with its config.
    #[test]
    #[ignore]
    pub fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory_env();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    /// Embeds a short string and checks the document round-trips and the
    /// returned vector is non-empty.
    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory_env();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await.unwrap_or_else(|err| {
            panic!(
                "[{}]: Error occurred when sending request, {}",
                config.name, err
            )
        });

        assert_eq!(resp.document, TEST);

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    /// Requests a small image and checks a non-empty payload came back.
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory_env();
        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    /// Transcribes `data` (MP3 bytes) and checks non-empty text came back.
    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory_env();
        let Some(client) = client.as_transcription() else {
            return;
        };

        // Must be the transcription model: providers such as Gemini and Groq
        // support transcription but have no image-generation model.
        let model = config.transcription_model.unwrap_or_else(|| {
            panic!("{} does not have transcription_model set", config.name)
        });

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        // Load the whole file: a single `Read::read` into an empty `Vec` has no
        // room to write into and returns 0 bytes, which would send empty audio.
        let data = std::fs::read("examples/audio/en-us-natural-speech.mp3")
            .expect("failed to read examples/audio/en-us-natural-speech.mp3");

        assert!(!data.is_empty(), "Test audio file is empty");

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments shared by the `add` and `subtract` tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool computing `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    /// Test tool computing `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtract;
    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            let result = args.x - args.y;
            Ok(result)
        }
    }
}