1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11pub mod verify;
12
13#[cfg(feature = "derive")]
14pub use rig_derive::ProviderClient;
15use std::fmt::Debug;
16use thiserror::Error;
17
18#[derive(Debug, Error)]
19#[non_exhaustive]
20pub enum ClientBuilderError {
21 #[error("reqwest error: {0}")]
22 HttpError(
23 #[from]
24 #[source]
25 reqwest::Error,
26 ),
27 #[error("invalid property: {0}")]
28 InvalidProperty(&'static str),
29}
30
31pub trait ProviderClient:
37 AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
38{
39 fn from_env() -> Self
42 where
43 Self: Sized;
44
45 fn boxed(self) -> Box<dyn ProviderClient>
47 where
48 Self: Sized + 'static,
49 {
50 Box::new(self)
51 }
52
53 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
56 where
57 Self: Sized,
58 Self: 'a,
59 {
60 Box::new(Self::from_env())
61 }
62
63 fn from_val(input: ProviderValue) -> Self
64 where
65 Self: Sized;
66
67 fn from_val_boxed<'a>(input: ProviderValue) -> Box<dyn ProviderClient + 'a>
70 where
71 Self: Sized,
72 Self: 'a,
73 {
74 Box::new(Self::from_val(input))
75 }
76}
77
/// Explicit configuration accepted by [`ProviderClient::from_val`].
///
/// Which variant a given provider expects is decided by that provider's `from_val`
/// implementation.
#[derive(Clone)]
pub enum ProviderValue {
    /// A single string value.
    Simple(String),
    /// An API key plus an optional secondary key.
    ApiKeyWithOptionalKey(String, Option<String>),
    /// An API key, a version string and a header value, in that order.
    ApiKeyWithVersionAndHeader(String, String, String),
}
84
85impl From<&str> for ProviderValue {
86 fn from(value: &str) -> Self {
87 Self::Simple(value.to_string())
88 }
89}
90
91impl From<String> for ProviderValue {
92 fn from(value: String) -> Self {
93 Self::Simple(value)
94 }
95}
96
97impl<P> From<(P, Option<P>)> for ProviderValue
98where
99 P: AsRef<str>,
100{
101 fn from((api_key, optional_key): (P, Option<P>)) -> Self {
102 Self::ApiKeyWithOptionalKey(
103 api_key.as_ref().to_string(),
104 optional_key.map(|x| x.as_ref().to_string()),
105 )
106 }
107}
108
109impl<P> From<(P, P, P)> for ProviderValue
110where
111 P: AsRef<str>,
112{
113 fn from((api_key, version, header): (P, P, P)) -> Self {
114 Self::ApiKeyWithVersionAndHeader(
115 api_key.as_ref().to_string(),
116 version.as_ref().to_string(),
117 header.as_ref().to_string(),
118 )
119 }
120}
121
122pub trait AsCompletion {
124 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
125 None
126 }
127}
128
129pub trait AsTranscription {
131 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
132 None
133 }
134}
135
136pub trait AsEmbeddings {
138 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
139 None
140 }
141}
142
/// Exposes a client's audio-generation capability, if it has one.
///
/// The method only exists with the `audio` feature; without it the trait is empty.
pub trait AsAudioGeneration {
    /// Returns a type-erased audio-generation client.
    ///
    /// The default implementation returns `None`, meaning "not supported".
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        None
    }
}
150
/// Exposes a client's image-generation capability, if it has one.
///
/// The method only exists with the `image` feature; without it the trait is empty.
pub trait AsImageGeneration {
    /// Returns a type-erased image-generation client.
    ///
    /// The default implementation returns `None`, meaning "not supported".
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        None
    }
}
158
159pub trait AsVerify {
161 fn as_verify(&self) -> Option<Box<dyn VerifyClientDyn>> {
162 None
163 }
164}
165
166#[cfg(not(feature = "audio"))]
167impl<T: ProviderClient> AsAudioGeneration for T {}
168
169#[cfg(not(feature = "image"))]
170impl<T: ProviderClient> AsImageGeneration for T {}
171
/// Implements the listed `As*` conversion traits for `$struct_`, using each trait's default
/// (capability-less) methods.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through `impl_audio_generation!`
/// and `impl_image_generation!`. Those expand to nothing when the matching feature is off,
/// because a blanket impl already covers that configuration.
///
/// All paths go through `$crate` rather than a hard-coded crate name. The macro therefore
/// works when the dependency is renamed, and it can be invoked by path
/// (`rig::impl_conversion_traits!(...)`) without first being imported.
///
/// ```ignore
/// impl_conversion_traits!(AsEmbeddings, AsTranscription for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
200
/// Implements `AsAudioGeneration` for `$struct_` using the trait's default method.
///
/// Uses `$crate` so the expansion resolves regardless of how the caller names this crate.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// No-op variant used when the `audio` feature is disabled.
///
/// In that configuration every `ProviderClient` already gets `AsAudioGeneration` from a
/// blanket impl, so emitting another impl here would conflict.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}
214
/// Implements `AsImageGeneration` for `$struct_` using the trait's default method.
///
/// Uses `$crate` so the expansion resolves regardless of how the caller names this crate.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// No-op variant used when the `image` feature is disabled.
///
/// In that configuration every `ProviderClient` already gets `AsImageGeneration` from a
/// blanket impl, so emitting another impl here would conflict.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
228
229pub use impl_audio_generation;
230pub use impl_conversion_traits;
231pub use impl_image_generation;
232
233#[cfg(feature = "audio")]
234use crate::client::audio_generation::AudioGenerationClientDyn;
235use crate::client::completion::CompletionClientDyn;
236use crate::client::embeddings::EmbeddingsClientDyn;
237#[cfg(feature = "image")]
238use crate::client::image_generation::ImageGenerationClientDyn;
239use crate::client::transcription::TranscriptionClientDyn;
240use crate::client::verify::VerifyClientDyn;
241
242#[cfg(feature = "audio")]
243pub use crate::client::audio_generation::AudioGenerationClient;
244pub use crate::client::completion::CompletionClient;
245pub use crate::client::embeddings::EmbeddingsClient;
246#[cfg(feature = "image")]
247pub use crate::client::image_generation::ImageGenerationClient;
248pub use crate::client::transcription::TranscriptionClient;
249pub use crate::client::verify::{VerifyClient, VerifyError};
250
#[cfg(test)]
mod tests {
    //! Provider-by-provider integration tests driven through the type-erased
    //! `ProviderClient` interface.
    //!
    //! Every test here is `#[ignore]`d because it sends real requests to provider APIs.
    //! Each provider is only exercised when its `env_variable` is set (see
    //! `ClientConfig::is_env_var_set`).

    use std::fs::File;
    use std::io::Read;

    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    use super::ProviderValue;
    use crate::OneOrMany;
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;

    /// System prompt shared by the tool-calling tests.
    const CALCULATOR_PREAMBLE: &str = "You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.";

    /// Audio fixture used by the transcription test (path relative to the crate root).
    const TRANSCRIPTION_FIXTURE: &str = "examples/audio/en-us-natural-speech.mp3";

    /// Per-provider test configuration.
    ///
    /// Each `*_model` field doubles as a capability declaration: `test_polymorphism` asserts
    /// that a field is `Some` exactly when the client exposes the matching `As*` capability.
    struct ClientConfig {
        /// Provider name used in assertion messages.
        name: &'static str,
        /// Builds the client from environment variables.
        factory_env: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        /// Builds the client from an explicit `ProviderValue`.
        // NOTE(review): no test reads this field yet, hence `allow(dead_code)`; it keeps each
        // provider's `from_val_boxed` constructor type-checked.
        #[allow(dead_code)]
        factory_val: Box<dyn Fn(ProviderValue) -> Box<dyn ProviderClient>>,
        /// Environment variable gating this provider's tests; empty means "always run".
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` pair used for audio generation.
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        fn default() -> Self {
            Self {
                name: "",
                factory_env: Box::new(|| panic!("Not implemented")),
                factory_val: Box::new(|_| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// `true` when the gating variable is unset-able (empty name) or present.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a fresh client via the environment factory.
        fn factory_env(&self) -> Box<dyn ProviderClient + '_> {
            (self.factory_env)()
        }
    }

    /// Every provider under test, with the models used for each capability it supports.
    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory_env: Box::new(anthropic::Client::from_env_boxed),
                factory_val: Box::new(anthropic::Client::from_val_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory_env: Box::new(cohere::Client::from_env_boxed),
                factory_val: Box::new(cohere::Client::from_val_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory_env: Box::new(gemini::Client::from_env_boxed),
                factory_val: Box::new(gemini::Client::from_val_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory_env: Box::new(huggingface::Client::from_env_boxed),
                factory_val: Box::new(huggingface::Client::from_val_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory_env: Box::new(openai::Client::from_env_boxed),
                factory_val: Box::new(openai::Client::from_val_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory_env: Box::new(openrouter::Client::from_env_boxed),
                factory_val: Box::new(openrouter::Client::from_val_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory_env: Box::new(together::Client::from_env_boxed),
                factory_val: Box::new(together::Client::from_val_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory_env: Box::new(xai::Client::from_env_boxed),
                factory_val: Box::new(xai::Client::from_val_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory_env: Box::new(azure::Client::from_env_boxed),
                factory_val: Box::new(azure::Client::from_val_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory_env: Box::new(deepseek::Client::from_env_boxed),
                factory_val: Box::new(deepseek::Client::from_val_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory_env: Box::new(galadriel::Client::from_env_boxed),
                factory_val: Box::new(galadriel::Client::from_val_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory_env: Box::new(groq::Client::from_env_boxed),
                factory_val: Box::new(groq::Client::from_val_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory_env: Box::new(hyperbolic::Client::from_env_boxed),
                factory_val: Box::new(hyperbolic::Client::from_val_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory_env: Box::new(mira::Client::from_env_boxed),
                factory_val: Box::new(mira::Client::from_val_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory_env: Box::new(moonshot::Client::from_env_boxed),
                factory_val: Box::new(moonshot::Client::from_val_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory_env: Box::new(ollama::Client::from_env_boxed),
                factory_val: Box::new(ollama::Client::from_val_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory_env: Box::new(perplexity::Client::from_env_boxed),
                factory_val: Box::new(perplexity::Client::from_val_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Returns `true` when `content` is a call to the `subtract` tool with `x = 2` and
    /// `y = 5`, i.e. the tool call expected for the prompt "Calculate 2 - 5".
    fn is_expected_subtract_call(content: &AssistantContent) -> bool {
        let AssistantContent::ToolCall(tc) = content else {
            return false;
        };

        if tc.function.name != Subtract::NAME {
            return false;
        }

        let arguments = serde_json::from_value::<OperationArgs>(tc.function.arguments.clone())
            .expect("Error parsing arguments");

        arguments.x == 2.0 && arguments.y == 5.0
    }

    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when prompting, {}", config.name, err)
            });

        match resp.choice.first() {
            AssistantContent::Text(text) => {
                assert!(
                    text.text.to_lowercase().contains("paris"),
                    "[{}]: Did not answer with Paris",
                    config.name
                );
            }
            // A non-text first choice is a legitimate test failure, not an impossible state.
            other => panic!(
                "[{}]: First choice wasn't a Text message, {:?}",
                config.name, other
            ),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when building prompt, {}",
                    config.name, err
                )
            });

        let resp = request.send().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        assert!(
            resp.choice
                .iter()
                .any(|content| is_expected_subtract_call(content)),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let mut stream = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when starting stream, {:?}",
                    config.name, err
                )
            });

        let mut received_chunk = false;

        while let Some(chunk) = stream.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Stream yielded an error chunk",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in stream.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
                AssistantContent::Reasoning(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .stream_completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when building prompt, {}",
                    config.name, err
                )
            });

        let mut stream = request.stream().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        let mut received_chunk = false;

        while let Some(chunk) = stream.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Stream yielded an error chunk",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            stream
                .choice
                .iter()
                .any(|content| is_expected_subtract_call(content)),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let resp = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice)
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a capability (`feature`) is exposed exactly when its model is configured.
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    #[test]
    #[ignore]
    pub fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory_env();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            )
        }
    }

    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory_env();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await.unwrap_or_else(|err| {
            panic!(
                "[{}]: Error occurred when sending request, {}",
                config.name, err
            )
        });

        assert_eq!(
            resp.document, TEST,
            "[{}]: Embedding echoed the wrong document",
            config.name
        );

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory_env();

        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory_env();

        let Some(client) = client.as_transcription() else {
            return;
        };

        // Transcription must use the transcription model, not the image-generation one.
        let model = config.transcription_model.unwrap_or_else(|| {
            panic!("{} does not have transcription_model set", config.name)
        });

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        let mut file = File::open(TRANSCRIPTION_FIXTURE).expect("failed to open audio fixture");

        // `read_to_end` grows the buffer; a single `read` into an empty Vec reads 0 bytes.
        let mut data = Vec::new();
        file.read_to_end(&mut data)
            .expect("failed to read audio fixture");

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments shared by the `add` and `subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool returning `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;

    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: Self::NAME.to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            Ok(args.x + args.y)
        }
    }

    /// Test tool returning `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtract;

    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            Ok(args.x - args.y)
        }
    }
}