1pub mod audio_generation;
6pub mod builder;
7pub mod completion;
8pub mod embeddings;
9pub mod image_generation;
10pub mod transcription;
11
12#[cfg(feature = "derive")]
13pub use rig_derive::ProviderClient;
14use std::fmt::Debug;
15
16pub trait ProviderClient:
22 AsCompletion + AsTranscription + AsEmbeddings + AsImageGeneration + AsAudioGeneration + Debug
23{
24 fn from_env() -> Self
27 where
28 Self: Sized;
29
30 fn boxed(self) -> Box<dyn ProviderClient>
32 where
33 Self: Sized + 'static,
34 {
35 Box::new(self)
36 }
37
38 fn from_env_boxed<'a>() -> Box<dyn ProviderClient + 'a>
41 where
42 Self: Sized,
43 Self: 'a,
44 {
45 Box::new(Self::from_env())
46 }
47}
48
49pub trait AsCompletion {
51 fn as_completion(&self) -> Option<Box<dyn CompletionClientDyn>> {
52 None
53 }
54}
55
56pub trait AsTranscription {
58 fn as_transcription(&self) -> Option<Box<dyn TranscriptionClientDyn>> {
59 None
60 }
61}
62
63pub trait AsEmbeddings {
65 fn as_embeddings(&self) -> Option<Box<dyn EmbeddingsClientDyn>> {
66 None
67 }
68}
69
/// Capability accessor for audio generation.
///
/// The accessor only exists with the `audio` cargo feature; without it this
/// trait is empty.
pub trait AsAudioGeneration {
    /// Returns this client as a type-erased audio-generation client, or `None`
    /// (the default) when the provider does not support it.
    #[cfg(feature = "audio")]
    fn as_audio_generation(&self) -> Option<Box<dyn AudioGenerationClientDyn>> {
        Option::None
    }
}
77
/// Capability accessor for image generation.
///
/// The accessor only exists with the `image` cargo feature; without it this
/// trait is empty.
pub trait AsImageGeneration {
    /// Returns this client as a type-erased image-generation client, or `None`
    /// (the default) when the provider does not support it.
    #[cfg(feature = "image")]
    fn as_image_generation(&self) -> Option<Box<dyn ImageGenerationClientDyn>> {
        Option::None
    }
}
85
86#[cfg(not(feature = "audio"))]
87impl<T: ProviderClient> AsAudioGeneration for T {}
88
89#[cfg(not(feature = "image"))]
90impl<T: ProviderClient> AsImageGeneration for T {}
91
/// Implements each listed `As*` capability trait for `$struct_` with an empty
/// body, so the trait's default (`None`-returning) accessor is used, i.e. the
/// capability is marked as unsupported.
///
/// `AsAudioGeneration` and `AsImageGeneration` are routed through
/// `impl_audio_generation!` / `impl_image_generation!`, which expand to nothing
/// when the corresponding cargo feature is disabled.
///
/// ```ignore
/// impl_conversion_traits!(AsEmbeddings, AsTranscription, AsImageGeneration for Client);
/// ```
#[macro_export]
macro_rules! impl_conversion_traits {
    // Paths go through `$crate` (rather than a hard-coded crate name) so the
    // expansion resolves regardless of how the caller named or imported this crate,
    // including the recursive call below.
    ($( $trait_:ident ),* for $struct_:ident ) => {
        $(
            $crate::impl_conversion_traits!(@impl $trait_ for $struct_);
        )*
    };

    (@impl AsAudioGeneration for $struct_:ident ) => {
        $crate::client::impl_audio_generation!($struct_);
    };

    (@impl AsImageGeneration for $struct_:ident ) => {
        $crate::client::impl_image_generation!($struct_);
    };

    (@impl $trait_:ident for $struct_:ident) => {
        impl $crate::client::$trait_ for $struct_ {}
    };
}
120
/// Marks `$struct_` as not supporting audio generation by implementing
/// `AsAudioGeneration` with its default `None`-returning accessor.
#[cfg(feature = "audio")]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {
        impl $crate::client::AsAudioGeneration for $struct_ {}
    };
}

/// Without the `audio` feature, the blanket `AsAudioGeneration` impl for every
/// `ProviderClient` already applies, so this expands to nothing.
#[cfg(not(feature = "audio"))]
#[macro_export]
macro_rules! impl_audio_generation {
    ($struct_:ident) => {};
}

/// Marks `$struct_` as not supporting image generation by implementing
/// `AsImageGeneration` with its default `None`-returning accessor.
#[cfg(feature = "image")]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {
        impl $crate::client::AsImageGeneration for $struct_ {}
    };
}

/// Without the `image` feature, the blanket `AsImageGeneration` impl for every
/// `ProviderClient` already applies, so this expands to nothing.
#[cfg(not(feature = "image"))]
#[macro_export]
macro_rules! impl_image_generation {
    ($struct_:ident) => {};
}
148
149pub use impl_audio_generation;
150pub use impl_conversion_traits;
151pub use impl_image_generation;
152
153#[cfg(feature = "audio")]
154use crate::client::audio_generation::AudioGenerationClientDyn;
155use crate::client::completion::CompletionClientDyn;
156use crate::client::embeddings::EmbeddingsClientDyn;
157#[cfg(feature = "image")]
158use crate::client::image_generation::ImageGenerationClientDyn;
159use crate::client::transcription::TranscriptionClientDyn;
160
161#[cfg(feature = "audio")]
162pub use crate::client::audio_generation::AudioGenerationClient;
163pub use crate::client::completion::CompletionClient;
164pub use crate::client::embeddings::EmbeddingsClient;
165#[cfg(feature = "image")]
166pub use crate::client::image_generation::ImageGenerationClient;
167pub use crate::client::transcription::TranscriptionClient;
168
#[cfg(test)]
mod tests {
    use crate::client::ProviderClient;
    use crate::completion::{Completion, CompletionRequest, ToolDefinition};
    #[cfg(feature = "image")]
    use crate::image_generation::ImageGenerationRequest;
    use crate::message::{AssistantContent, Message};
    use crate::providers::{
        anthropic, azure, cohere, deepseek, galadriel, gemini, groq, huggingface, hyperbolic,
        mira, moonshot, ollama, openai, openrouter, perplexity, together, xai,
    };
    use crate::streaming::StreamingCompletion;
    use crate::tool::Tool;
    use crate::transcription::TranscriptionRequest;
    use crate::OneOrMany;
    use futures::StreamExt;
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    /// System prompt shared by the tool-calling tests.
    const CALCULATOR_PREAMBLE: &str = "You are a calculator here to help the user perform arithmetic operations. Use the tools provided to answer the user's question.";

    /// Per-provider settings for the live integration tests in this module.
    ///
    /// Each `*_model` field names the model used for that capability; `None`
    /// means the provider is not expected to support it (checked by
    /// `test_polymorphism`).
    struct ClientConfig {
        /// Provider name used in assertion messages.
        name: &'static str,
        /// Builds the provider client (normally `Client::from_env_boxed`).
        factory: Box<dyn Fn() -> Box<dyn ProviderClient>>,
        /// Environment variable that must be set for the provider to be tested;
        /// an empty string means the provider is always tested.
        env_variable: &'static str,
        completion_model: Option<&'static str>,
        embeddings_model: Option<&'static str>,
        transcription_model: Option<&'static str>,
        // Only read by tests gated on the `image` feature.
        #[cfg_attr(not(feature = "image"), allow(dead_code))]
        image_generation_model: Option<&'static str>,
        /// `(model, voice)` pair for audio generation.
        // Only read by tests gated on the `audio` feature.
        #[cfg_attr(not(feature = "audio"), allow(dead_code))]
        audio_generation_model: Option<(&'static str, &'static str)>,
    }

    impl Default for ClientConfig {
        fn default() -> Self {
            Self {
                name: "",
                factory: Box::new(|| panic!("Not implemented")),
                env_variable: "",
                completion_model: None,
                embeddings_model: None,
                transcription_model: None,
                image_generation_model: None,
                audio_generation_model: None,
            }
        }
    }

    impl ClientConfig {
        /// Whether this provider should be tested: `env_variable` is empty or is
        /// present in the process environment.
        fn is_env_var_set(&self) -> bool {
            self.env_variable.is_empty() || std::env::var(self.env_variable).is_ok()
        }

        /// Builds a fresh client with the configured factory.
        fn factory(&self) -> Box<dyn ProviderClient + '_> {
            (self.factory)()
        }
    }

    /// All providers exercised by the integration tests, with the models to use.
    fn providers() -> Vec<ClientConfig> {
        vec![
            ClientConfig {
                name: "Anthropic",
                factory: Box::new(anthropic::Client::from_env_boxed),
                env_variable: "ANTHROPIC_API_KEY",
                completion_model: Some(anthropic::CLAUDE_3_5_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Cohere",
                factory: Box::new(cohere::Client::from_env_boxed),
                env_variable: "COHERE_API_KEY",
                completion_model: Some(cohere::COMMAND_R),
                embeddings_model: Some(cohere::EMBED_ENGLISH_LIGHT_V2),
                ..Default::default()
            },
            ClientConfig {
                name: "Gemini",
                factory: Box::new(gemini::Client::from_env_boxed),
                env_variable: "GEMINI_API_KEY",
                completion_model: Some(gemini::completion::GEMINI_2_0_FLASH),
                embeddings_model: Some(gemini::embedding::EMBEDDING_001),
                transcription_model: Some(gemini::transcription::GEMINI_2_0_FLASH),
                ..Default::default()
            },
            ClientConfig {
                name: "Huggingface",
                factory: Box::new(huggingface::Client::from_env_boxed),
                env_variable: "HUGGINGFACE_API_KEY",
                completion_model: Some(huggingface::PHI_4),
                transcription_model: Some(huggingface::WHISPER_SMALL),
                image_generation_model: Some(huggingface::STABLE_DIFFUSION_3),
                ..Default::default()
            },
            ClientConfig {
                name: "OpenAI",
                factory: Box::new(openai::Client::from_env_boxed),
                env_variable: "OPENAI_API_KEY",
                completion_model: Some(openai::GPT_4O),
                embeddings_model: Some(openai::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some(openai::WHISPER_1),
                image_generation_model: Some(openai::DALL_E_2),
                audio_generation_model: Some((openai::TTS_1, "onyx")),
            },
            ClientConfig {
                name: "OpenRouter",
                factory: Box::new(openrouter::Client::from_env_boxed),
                env_variable: "OPENROUTER_API_KEY",
                completion_model: Some(openrouter::CLAUDE_3_7_SONNET),
                ..Default::default()
            },
            ClientConfig {
                name: "Together",
                factory: Box::new(together::Client::from_env_boxed),
                env_variable: "TOGETHER_API_KEY",
                completion_model: Some(together::ALPACA_7B),
                embeddings_model: Some(together::BERT_BASE_UNCASED),
                ..Default::default()
            },
            ClientConfig {
                name: "XAI",
                factory: Box::new(xai::Client::from_env_boxed),
                env_variable: "XAI_API_KEY",
                completion_model: Some(xai::GROK_3_MINI),
                embeddings_model: Some(xai::EMBEDDING_V1),
                ..Default::default()
            },
            ClientConfig {
                name: "Azure",
                factory: Box::new(azure::Client::from_env_boxed),
                env_variable: "AZURE_API_KEY",
                completion_model: Some(azure::GPT_4O),
                embeddings_model: Some(azure::TEXT_EMBEDDING_ADA_002),
                transcription_model: Some("whisper-1"),
                image_generation_model: Some("dalle-2"),
                audio_generation_model: Some(("tts-1", "onyx")),
            },
            ClientConfig {
                name: "Deepseek",
                factory: Box::new(deepseek::Client::from_env_boxed),
                env_variable: "DEEPSEEK_API_KEY",
                completion_model: Some(deepseek::DEEPSEEK_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Galadriel",
                factory: Box::new(galadriel::Client::from_env_boxed),
                env_variable: "GALADRIEL_API_KEY",
                completion_model: Some(galadriel::GPT_4O),
                ..Default::default()
            },
            ClientConfig {
                name: "Groq",
                factory: Box::new(groq::Client::from_env_boxed),
                env_variable: "GROQ_API_KEY",
                completion_model: Some(groq::MIXTRAL_8X7B_32768),
                transcription_model: Some(groq::DISTIL_WHISPER_LARGE_V3),
                ..Default::default()
            },
            ClientConfig {
                name: "Hyperbolic",
                factory: Box::new(hyperbolic::Client::from_env_boxed),
                env_variable: "HYPERBOLIC_API_KEY",
                completion_model: Some(hyperbolic::LLAMA_3_1_8B),
                image_generation_model: Some(hyperbolic::SD1_5),
                audio_generation_model: Some(("EN", "EN-US")),
                ..Default::default()
            },
            ClientConfig {
                name: "Mira",
                factory: Box::new(mira::Client::from_env_boxed),
                env_variable: "MIRA_API_KEY",
                completion_model: Some("gpt-4o"),
                ..Default::default()
            },
            ClientConfig {
                name: "Moonshot",
                factory: Box::new(moonshot::Client::from_env_boxed),
                env_variable: "MOONSHOT_API_KEY",
                completion_model: Some(moonshot::MOONSHOT_CHAT),
                ..Default::default()
            },
            ClientConfig {
                name: "Ollama",
                factory: Box::new(ollama::Client::from_env_boxed),
                env_variable: "OLLAMA_ENABLED",
                completion_model: Some("llama3.1:8b"),
                embeddings_model: Some(ollama::NOMIC_EMBED_TEXT),
                ..Default::default()
            },
            ClientConfig {
                name: "Perplexity",
                factory: Box::new(perplexity::Client::from_env_boxed),
                env_variable: "PERPLEXITY_API_KEY",
                completion_model: Some(perplexity::SONAR),
                ..Default::default()
            },
        ]
    }

    /// Returns `true` when `content` is a call to the `subtract` tool with
    /// `x = 2` and `y = 5` (the call expected for the prompt "Calculate 2 - 5").
    fn is_expected_subtract_call(content: &AssistantContent) -> bool {
        let AssistantContent::ToolCall(tc) = content else {
            return false;
        };

        if tc.function.name != Subtract::NAME {
            return false;
        }

        let arguments = serde_json::from_value::<OperationArgs>(tc.function.arguments.clone())
            .expect("Error parsing arguments");

        arguments.x == 2.0 && arguments.y == 5.0
    }

    /// Asks a simple factual question and checks the first choice mentions Paris.
    /// Providers without completion support are skipped.
    async fn test_completions_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have completion_model set", config.name));

        let model = client.completion_model(model);

        let resp = model
            .completion_request(Message::user("Whats the capital of France?"))
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when prompting, {}", config.name, err)
            });

        match resp.choice.first() {
            AssistantContent::Text(text) => assert!(
                text.text.to_lowercase().contains("paris"),
                "[{}]: Did not answer with Paris",
                config.name
            ),
            other => panic!(
                "[{}]: First choice wasn't a Text message, {:?}",
                config.name, other
            ),
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_completions() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_completions_client(&p).await;
        }
    }

    /// Builds a calculator agent with `add`/`subtract` tools and checks the model
    /// calls `subtract` for "Calculate 2 - 5".
    async fn test_tools_client(config: &ClientConfig) {
        let client = config.factory();

        // Check the capability before the model so providers without completion
        // support are skipped instead of failing on a missing model name.
        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when building prompt, {}",
                    config.name, err
                )
            });

        let resp = request.send().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_tools_client(&p).await;
        }
    }

    /// Streams an answer to a factual question, checks every chunk is `Ok`, and
    /// that any aggregated text mentions Paris.
    async fn test_streaming_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let model = client.completion_model(model);

        let mut resp = model
            .stream(CompletionRequest {
                preamble: None,
                tools: vec![],
                documents: vec![],
                temperature: None,
                max_tokens: None,
                additional_params: None,
                chat_history: OneOrMany::one(Message::user("What is the capital of France?")),
            })
            .await
            .unwrap_or_else(|err| {
                panic!("[{}]: Error occurred when streaming, {:?}", config.name, err)
            });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Received an error chunk from stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        for choice in resp.choice {
            match choice {
                AssistantContent::Text(text) => {
                    assert!(
                        text.text.to_lowercase().contains("paris"),
                        "[{}]: Did not answer with Paris",
                        config.name
                    );
                }
                AssistantContent::ToolCall(_) => {}
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming() {
        for provider in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_client(&provider).await;
        }
    }

    /// Streaming variant of `test_tools_client`.
    async fn test_streaming_tools_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_completion() else {
            return;
        };

        let model = config
            .completion_model
            .unwrap_or_else(|| panic!("{} does not have the model set.", config.name));

        let agent = client
            .agent(model)
            .preamble(CALCULATOR_PREAMBLE)
            .max_tokens(1024)
            .tool(Adder)
            .tool(Subtract)
            .build();

        let request = agent
            .stream_completion("Calculate 2 - 5", vec![])
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when building prompt, {}",
                    config.name, err
                )
            });

        let mut resp = request.stream().await.unwrap_or_else(|err| {
            panic!("[{}]: Error occurred when prompting, {}", config.name, err)
        });

        let mut received_chunk = false;

        while let Some(chunk) = resp.next().await {
            received_chunk = true;
            assert!(
                chunk.is_ok(),
                "[{}]: Received an error chunk from stream",
                config.name
            );
        }

        assert!(
            received_chunk,
            "[{}]: Failed to receive a chunk from stream",
            config.name
        );

        assert!(
            resp.choice.iter().any(is_expected_subtract_call),
            "[{}]: Model did not use the Subtract tool.",
            config.name
        )
    }

    #[tokio::test]
    #[ignore]
    async fn test_streaming_tools() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_streaming_tools_client(&p).await;
        }
    }

    /// Generates speech for a short phrase and checks audio bytes are returned.
    /// `as_audio_generation` only exists with the `audio` feature.
    #[cfg(feature = "audio")]
    async fn test_audio_generation_client(config: &ClientConfig) {
        let client = config.factory();

        let Some(client) = client.as_audio_generation() else {
            return;
        };

        let (model, voice) = config
            .audio_generation_model
            .unwrap_or_else(|| panic!("{} doesn't have the model set", config.name));

        let model = client.audio_generation_model(model);

        let resp = model
            .audio_generation_request()
            .text("Hello world!")
            .voice(voice)
            .send()
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.audio.is_empty(),
            "[{}]: Returned audio was empty",
            config.name
        );
    }

    #[cfg(feature = "audio")]
    #[tokio::test]
    #[ignore]
    async fn test_audio_generation() {
        for p in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_audio_generation_client(&p).await;
        }
    }

    /// Asserts that a provider exposes a capability (`feature`) exactly when its
    /// config names a model for it (`model`).
    fn assert_feature<F, M>(
        name: &str,
        feature_name: &str,
        model_name: &str,
        feature: Option<F>,
        model: Option<M>,
    ) {
        assert_eq!(
            feature.is_some(),
            model.is_some(),
            "{} has{} implemented {} but config.{} is {}.",
            name,
            if feature.is_some() { "" } else { "n't" },
            feature_name,
            model_name,
            if model.is_some() { "some" } else { "none" }
        );
    }

    #[test]
    #[ignore]
    fn test_polymorphism() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            let client = config.factory();
            assert_feature(
                config.name,
                "AsCompletion",
                "completion_model",
                client.as_completion(),
                config.completion_model,
            );

            assert_feature(
                config.name,
                "AsEmbeddings",
                "embeddings_model",
                client.as_embeddings(),
                config.embeddings_model,
            );

            assert_feature(
                config.name,
                "AsTranscription",
                "transcription_model",
                client.as_transcription(),
                config.transcription_model,
            );

            // The image/audio accessors only exist when their features are enabled.
            #[cfg(feature = "image")]
            assert_feature(
                config.name,
                "AsImageGeneration",
                "image_generation_model",
                client.as_image_generation(),
                config.image_generation_model,
            );

            #[cfg(feature = "audio")]
            assert_feature(
                config.name,
                "AsAudioGeneration",
                "audio_generation_model",
                client.as_audio_generation(),
                config.audio_generation_model,
            );
        }
    }

    /// Embeds a short sentence and checks the echoed document and a non-empty vector.
    async fn test_embed_client(config: &ClientConfig) {
        const TEST: &str = "Hello world.";

        let client = config.factory();

        let Some(client) = client.as_embeddings() else {
            return;
        };

        let model = config
            .embeddings_model
            .unwrap_or_else(|| panic!("{} does not have embeddings_model set", config.name));

        let model = client.embedding_model(model);

        let resp = model.embed_text(TEST).await.unwrap_or_else(|err| {
            panic!(
                "[{}]: Error occurred when sending request, {}",
                config.name, err
            )
        });

        assert_eq!(
            resp.document, TEST,
            "[{}]: Returned document did not match the input",
            config.name
        );

        assert!(
            !resp.vec.is_empty(),
            "[{}]: Returned embed was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_embed() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_embed_client(&config).await;
        }
    }

    /// Generates a small image and checks non-empty image bytes are returned.
    /// `as_image_generation` only exists with the `image` feature.
    #[cfg(feature = "image")]
    async fn test_image_generation_client(config: &ClientConfig) {
        let client = config.factory();
        let Some(client) = client.as_image_generation() else {
            return;
        };

        let model = config.image_generation_model.unwrap_or_else(|| {
            panic!("{} does not have image_generation_model set", config.name)
        });

        let model = client.image_generation_model(model);

        let resp = model
            .image_generation(ImageGenerationRequest {
                prompt: "A castle sitting on a large hill.".to_string(),
                width: 256,
                height: 256,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.image.is_empty(),
            "[{}]: Generated image was empty",
            config.name
        );
    }

    #[cfg(feature = "image")]
    #[tokio::test]
    #[ignore]
    async fn test_image_generation() {
        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_image_generation_client(&config).await;
        }
    }

    /// Transcribes `data` (an mp3 payload) and checks non-empty text is returned.
    async fn test_transcription_client(config: &ClientConfig, data: Vec<u8>) {
        let client = config.factory();
        let Some(client) = client.as_transcription() else {
            return;
        };

        let model = config.transcription_model.unwrap_or_else(|| {
            panic!("{} does not have transcription_model set", config.name)
        });

        let model = client.transcription_model(model);

        let resp = model
            .transcription(TranscriptionRequest {
                data,
                filename: "audio.mp3".to_string(),
                language: "en".to_string(),
                prompt: None,
                temperature: None,
                additional_params: None,
            })
            .await
            .unwrap_or_else(|err| {
                panic!(
                    "[{}]: Error occurred when sending request, {}",
                    config.name, err
                )
            });

        assert!(
            !resp.text.is_empty(),
            "[{}]: Returned transcription was empty",
            config.name
        );
    }

    #[tokio::test]
    #[ignore]
    async fn test_transcription() {
        const AUDIO_PATH: &str = "examples/audio/en-us-natural-speech.mp3";

        // Read the whole file: the request must carry the complete audio payload.
        let data = std::fs::read(AUDIO_PATH)
            .unwrap_or_else(|err| panic!("Failed to read {}: {}", AUDIO_PATH, err));

        for config in providers().into_iter().filter(ClientConfig::is_env_var_set) {
            test_transcription_client(&config, data.clone()).await;
        }
    }

    /// Arguments accepted by the `add` and `subtract` test tools.
    #[derive(Deserialize)]
    struct OperationArgs {
        x: f32,
        y: f32,
    }

    #[derive(Debug, thiserror::Error)]
    #[error("Math error")]
    struct MathError;

    /// Test tool returning `x + y`.
    #[derive(Deserialize, Serialize)]
    struct Adder;
    impl Tool for Adder {
        const NAME: &'static str = "add";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "add".to_string(),
                description: "Add x and y together".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The first number to add"
                        },
                        "y": {
                            "type": "number",
                            "description": "The second number to add"
                        }
                    }
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Adding {} and {}", args.x, args.y);
            Ok(args.x + args.y)
        }
    }

    /// Test tool returning `x - y`.
    #[derive(Deserialize, Serialize)]
    struct Subtract;
    impl Tool for Subtract {
        const NAME: &'static str = "subtract";

        type Error = MathError;
        type Args = OperationArgs;
        type Output = f32;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            serde_json::from_value(json!({
                "name": "subtract",
                "description": "Subtract y from x (i.e.: x - y)",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "x": {
                            "type": "number",
                            "description": "The number to subtract from"
                        },
                        "y": {
                            "type": "number",
                            "description": "The number to subtract"
                        }
                    }
                }
            }))
            .expect("Tool Definition")
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            println!("[tool-call] Subtracting {} from {}", args.y, args.x);
            Ok(args.x - args.y)
        }
    }
}