1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15
16pub trait ProviderClient:
22 AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
23{
24 fn from_env() -> Self
27 where
28 Self: Sized;
29
30 fn boxed(self) -> Box<dyn ProviderClient>
32 where
33 Self: Sized + 'static,
34 {
35 Box::new(self)
36 }
37
38 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
41 where
42 Self: Sized,
43 Self: 'a,
44 {
45 Box::new(Self::from_env())
46 }
47}
48
49pub trait AsCompletion {
51 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
52 None
53 }
54}
55
56pub trait AsTranscription {
58 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
59 None
60 }
61}
62
63pub trait AsEmbeddings {
65 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
66 None
67 }
68}
69
70pub trait AsAudioGeneration {
72 #[cfg(feature = "audio")]
73 fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
74 None
75 }
76}
77
78pub trait AsImageGeneration {
80 #[cfg(feature = "image")]
81 fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
82 None
83 }
84}
85
86#[cfg(not(feature = "audio"))]
87impl<T: ProviderClient> AsAudioGeneration for T {}
88
89#[cfg(not(feature = "image"))]
90impl<T: ProviderClient> AsImageGeneration for T {}
91
/// Implements the listed capability-conversion traits (`AsCompletion`,
/// `AsEmbeddings`, `AsTranscription`, ...) for a client type, using each
/// trait's default methods.
///
/// `AsAudioGeneration` and `AsImageGeneration` are forwarded to the
/// `impl_audio_generation!` / `impl_image_generation!` macros, which only emit
/// an impl when the matching cargo feature is enabled.
///
/// Every path goes through `$crate`. This keeps the macro working regardless
/// of the name the invoking crate gives this library, and lets the internal
/// recursive `@impl` call resolve even when the caller invoked the macro by
/// path without importing it.
///
/// # Example
/// ```ignore
/// impl_conversion_traits!(AsEmbeddings, AsTranscription for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
120
/// Implements `AsAudioGeneration` for `$struct_` using the trait's default
/// (`None`-returning) method.
///
/// This variant is active when the `audio` feature is enabled. The path goes
/// through `$crate`, so it resolves no matter what the invoking crate calls
/// this library.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// No-op variant used when the `audio` feature is disabled. In that
/// configuration `AsAudioGeneration` is provided by a blanket impl for every
/// `ProviderClient`, so emitting an impl here would conflict with it.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}
134
/// Implements `AsImageGeneration` for `$struct_` using the trait's default
/// (`None`-returning) method.
///
/// This variant is active when the `image` feature is enabled. The path goes
/// through `$crate`, so it resolves no matter what the invoking crate calls
/// this library.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// No-op variant used when the `image` feature is disabled. In that
/// configuration `AsImageGeneration` is provided by a blanket impl for every
/// `ProviderClient`, so emitting an impl here would conflict with it.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
148
149pub use impl_audio_generation;
150pub use impl_conversion_traits;
151pub use impl_image_generation;
152
153#[cfg(feature = "audio")]
154use crate::client::audio_generation::AudioGenerationClientDyn;
155use crate::client::completion::CompletionClientDyn;
156use crate::client::embeddings::EmbeddingsClientDyn;
157#[cfg(feature = "image")]
158use crate::client::image_generation::ImageGenerationClientDyn;
159use crate::client::transcription::TranscriptionClientDyn;
160
161#[cfg(feature = "audio")]
162pub use crate::client::audio_generation::AudioGenerationClient;
163pub use crate::client::completion::CompletionClient;
164pub use crate::client::embeddings::EmbeddingsClient;
165#[cfg(feature = "image")]
166pub use crate::client::image_generation::ImageGenerationClient;
167pub use crate::client::transcription::TranscriptionClient;
168
#[cfg(test)]
mod tests {
    //! Live, provider-backed tests for the `ProviderClient` capability
    //! interface.
    //!
    //! Every test is `#[ignore]`d because it calls real provider APIs; run them
    //! with `cargo test -- --ignored`. A provider is only exercised when its
    //! `env_variable` is set (see `ClientConfig::is_env_var_set`).

    use crate::OneOrMany;
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    #[cfg(feature = "image")]
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::fs::File;
    use std::io::Read;

    /// System prompt shared by the tool-calling tests.
    const CALCULATOR_PREAMBLE: &str = "You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.";

    /// Description of one provider under test: how to build it, which env var
    /// gates it, and which model to use per capability. `None` means the
    /// provider is expected NOT to support that capability
    /// (checked by `test_polymorphism`).
    struct ClientConfig {
        name: &'static str,
        factory: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        // Only read by the `image`-feature-gated tests.
        #[cfg_attr(not(feature = "image"), allow(dead_code))]
        image_generation_model: Option<&'static str>,
        // `(model, voice)`; only read by the `audio`-feature-gated tests.
        #[cfg_attr(not(feature = "audio"), allow(dead_code))]
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        fn default() -> Self {
            Self {
                name: "",
                factory: Box::new(|| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// True when the provider should be tested: either no env var is
        /// required (empty name) or the named variable is present.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a fresh client for this provider.
        fn factory(&self) -> Box<dyn ProviderClient + '_> {
            self.factory.as_ref()()
        }
    }

    /// Unwraps a configured model, panicking with a provider-specific message
    /// when the config does not set one for `field`.
    fn require_model<T>(config: &ClientConfig, field: &str, model: Option<T>) -> T {
        model.unwrap_or_else(|| panic!("{} does not have {} set", config.name, field))
    }

    /// True if `content` is a call to the `Subtract` tool with `x = 2` and
    /// `y = 5`, i.e. the expected answer to "Calculate 2 - 5".
    fn is_expected_subtract_call(content: &AssistantContent) -> bool {
        let AssistantContent::ToolCall(tc) = content else {
            return false;
        };

        if tc.function.name != Subtract::NAME {
            return false;
        }

        let arguments = serde_json::from_value::<OperationArgs>(tc.function.arguments.clone())
            .expect("Error parsing arguments");

        arguments.x == 2.0 && arguments.y == 5.0
    }

    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory: Box::new(anthropic::Client::from_env_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory: Box::new(cohere::Client::from_env_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory: Box::new(gemini::Client::from_env_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory: Box::new(huggingface::Client::from_env_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory: Box::new(openai::Client::from_env_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory: Box::new(openrouter::Client::from_env_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory: Box::new(together::Client::from_env_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory: Box::new(xai::Client::from_env_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory: Box::new(azure::Client::from_env_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory: Box::new(deepseek::Client::from_env_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory: Box::new(galadriel::Client::from_env_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory: Box::new(groq::Client::from_env_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory: Box::new(hyperbolic::Client::from_env_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory: Box::new(mira::Client::from_env_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory: Box::new(moonshot::Client::from_env_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory: Box::new(ollama::Client::from_env_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory: Box::new(perplexity::Client::from_env_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = require_model(config, "completion_model", config.completion_model);
        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await
            .unwrap_or_else(|e| panic!("[{}]: Error occurred when prompting, {}", config.name, e));

        match resp.choice.first() {
            AssistantContent::Text(text) => {
                assert!(
                    text.text.to_lowercase().contains("paris"),
                    "[{}]: Did not answer with Paris",
                    config.name
                );
            }
            other => {
                unreachable!(
                    "[{}]: First choice wasn't a Text message, {:?}",
                    config.name, other
                );
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = require_model(config, "completion_model", config.completion_model);
        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when building prompt, {}", config.name, e)
            });

        let resp = request
            .send()
            .await
            .unwrap_or_else(|e| panic!("[{}]: Error occurred when prompting, {}", config.name, e));

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = require_model(config, "completion_model", config.completion_model);
        let model = client.completion_model(model);

        let mut resp = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when starting stream, {:?}", config.name, e)
            });

        let mut received_chunk = false;
        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            if let Err(e) = chunk {
                panic!("[{}]: Stream yielded an error, {:?}", config.name, e);
            }
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = require_model(config, "completion_model", config.completion_model);
        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .stream_completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when building prompt, {}", config.name, e)
            });

        let mut resp = request
            .stream()
            .await
            .unwrap_or_else(|e| panic!("[{}]: Error occurred when prompting, {}", config.name, e));

        let mut received_chunk = false;
        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            if let Err(e) = chunk {
                panic!("[{}]: Stream yielded an error, {:?}", config.name, e);
            }
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    // `as_audio_generation` only exists with the `audio` feature.
    #[cfg(feature = "audio")]
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) =
            require_model(config, "audio_generation_model", config.audio_generation_model);
        let model = client.audio_generation_model(model);

        let resp = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice)
            .send()
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when sending request, {}", config.name, e)
            });

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[cfg(feature = "audio")]
    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a provider implements a capability exactly when its
    /// config names a model for it.
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    #[test]
    #[ignore]
    pub fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            #[cfg(feature = "image")]
            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            #[cfg(feature = "audio")]
            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory();
        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = require_model(config, "embeddings_model", config.embeddings_model);
        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await.unwrap_or_else(|e| {
            panic!("[{}]: Error occurred when sending request, {}", config.name, e)
        });

        assert_eq!(
            resp.document, TEST,
            "[{}]: Embedding document did not match input",
            config.name
        );

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    // `as_image_generation` only exists with the `image` feature.
    #[cfg(feature = "image")]
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = require_model(config, "image_generation_model", config.image_generation_model);
        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when sending request, {}", config.name, e)
            });

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[cfg(feature = "image")]
    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory();
        let Some(client) = client.as_transcription() else {
            return;
        };

        // Must be the transcription model (not the image model) for this capability.
        let model = require_model(config, "transcription_model", config.transcription_model);
        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|e| {
                panic!("[{}]: Error occurred when sending request, {}", config.name, e)
            });

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        const AUDIO_PATH: &str = "examples/audio/en-us-natural-speech.mp3";

        // `read_to_end` fills the whole file; `read` into an empty Vec reads 0 bytes.
        let mut data = Vec::new();
        File::open(AUDIO_PATH)
            .and_then(|mut file| file.read_to_end(&mut data))
            .unwrap_or_else(|e| panic!("Failed to read {}: {}", AUDIO_PATH, e));

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments shared by the `Adder` and `Subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            let result = args.x + args.y;
            Ok(result)
        }
    }

    #[derive(Deserialize, Serialize)]
    struct Subtract;
    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            let result = args.x - args.y;
            Ok(result)
        }
    }
}