Skip to main content

cortexai_tools/
macro_tools.rs

//! Tools defined using the tool macros
//!
//! This module demonstrates how to create tools with minimal boilerplate, using
//! either the `#[tool]` attribute macro on async functions or the
//! `#[derive(Tool)]` macro on structs.
5
6use cortexai_macros::Tool;
7
8// --- Function-based tools using #[tool] attribute macro ---
9
10#[cortexai_macros::tool(description = "Say hello to someone")]
11async fn say_hello(
12    #[param(description = "Name to greet", required)] name: String,
13) -> Result<serde_json::Value, String> {
14    Ok(serde_json::json!({ "greeting": format!("Hello, {}!", name) }))
15}
16
17#[cortexai_macros::tool(description = "Get weather for a city")]
18async fn get_weather(
19    #[param(description = "City name", required)] city: String,
20    #[param(description = "Unit system")] unit: Option<String>,
21) -> Result<serde_json::Value, String> {
22    let u = unit.unwrap_or_else(|| "metric".to_string());
23    Ok(serde_json::json!({ "city": city, "unit": u, "temp": 72 }))
24}
25
26#[cortexai_macros::tool(description = "Compute sum of numbers")]
27async fn compute_sum(
28    #[param(description = "First number", required)] a: i64,
29    #[param(description = "Second number", required)] b: i64,
30    #[param(description = "Include absolute value")] absolute: Option<bool>,
31) -> Result<serde_json::Value, String> {
32    let sum = a + b;
33    let result = if absolute.unwrap_or(false) { sum.abs() } else { sum };
34    Ok(serde_json::json!({ "result": result }))
35}
36
37#[cortexai_macros::tool(description = "Look up a word")]
38async fn lookup_word(
39    #[param(name = "query", description = "The word to look up", required)] word: String,
40) -> Result<serde_json::Value, String> {
41    Ok(serde_json::json!({ "word": word, "found": true }))
42}
43
44/// A simple echo tool that returns the input message
45#[derive(Tool, Default)]
46#[tool(name = "echo", description = "Echo back the input message")]
47pub struct EchoTool {
48    #[tool(param, required, description = "The message to echo back")]
49    pub message: String,
50
51    #[tool(param, description = "Number of times to repeat the message")]
52    pub repeat: Option<u32>,
53}
54
55impl EchoTool {
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// The implementation method that the macro calls
61    pub async fn run(
62        &self,
63        message: String,
64        repeat: Option<u32>,
65    ) -> Result<serde_json::Value, String> {
66        let times = repeat.unwrap_or(1);
67        let result: Vec<String> = (0..times).map(|_| message.clone()).collect();
68
69        Ok(serde_json::json!({
70            "echoed": if times == 1 { serde_json::json!(message) } else { serde_json::json!(result) },
71            "repeat_count": times
72        }))
73    }
74}
75
76/// A greeting tool that generates personalized greetings
77#[derive(Tool, Default)]
78#[tool(
79    name = "greet",
80    description = "Generate a personalized greeting message"
81)]
82pub struct GreetTool {
83    #[tool(param, required, description = "Name of the person to greet")]
84    pub name: String,
85
86    #[tool(param, description = "Language for the greeting (en, pt, es)")]
87    pub language: Option<String>,
88
89    #[tool(param, description = "Whether to use formal greeting")]
90    pub formal: Option<bool>,
91}
92
93impl GreetTool {
94    pub fn new() -> Self {
95        Self::default()
96    }
97
98    pub async fn run(
99        &self,
100        name: String,
101        language: Option<String>,
102        formal: Option<bool>,
103    ) -> Result<serde_json::Value, String> {
104        let lang = language.unwrap_or_else(|| "en".to_string());
105        let is_formal = formal.unwrap_or(false);
106
107        let greeting = match (lang.as_str(), is_formal) {
108            ("pt", true) => format!("Prezado(a) {}, é um prazer cumprimentá-lo(a).", name),
109            ("pt", false) => format!("Olá, {}! Tudo bem?", name),
110            ("es", true) => format!("Estimado(a) {}, es un placer saludarle.", name),
111            ("es", false) => format!("¡Hola, {}! ¿Qué tal?", name),
112            (_, true) => format!("Dear {}, it is a pleasure to greet you.", name),
113            (_, false) => format!("Hello, {}! How are you?", name),
114        };
115
116        Ok(serde_json::json!({
117            "greeting": greeting,
118            "language": lang,
119            "formal": is_formal
120        }))
121    }
122}
123
124/// A text transformation tool
125#[derive(Tool, Default)]
126#[tool(
127    name = "transform_text",
128    description = "Transform text with various operations"
129)]
130pub struct TransformTextTool {
131    #[tool(param, required, description = "The text to transform")]
132    pub text: String,
133
134    #[tool(
135        param,
136        required,
137        description = "Operation: uppercase, lowercase, reverse, title"
138    )]
139    pub operation: String,
140}
141
142impl TransformTextTool {
143    pub fn new() -> Self {
144        Self::default()
145    }
146
147    pub async fn run(&self, text: String, operation: String) -> Result<serde_json::Value, String> {
148        let result = match operation.to_lowercase().as_str() {
149            "uppercase" => text.to_uppercase(),
150            "lowercase" => text.to_lowercase(),
151            "reverse" => text.chars().rev().collect(),
152            "title" => text
153                .split_whitespace()
154                .map(|word| {
155                    let mut chars = word.chars();
156                    match chars.next() {
157                        None => String::new(),
158                        Some(first) => first
159                            .to_uppercase()
160                            .chain(chars.flat_map(|c| c.to_lowercase()))
161                            .collect(),
162                    }
163                })
164                .collect::<Vec<_>>()
165                .join(" "),
166            _ => return Err(format!("Unknown operation: {}", operation)),
167        };
168
169        Ok(serde_json::json!({
170            "original": text,
171            "transformed": result,
172            "operation": operation
173        }))
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use cortexai_core::tool::{ExecutionContext, Tool};
    use cortexai_core::types::AgentId;
    use std::collections::HashMap;

    /// Test tool with a `Vec<String>` parameter.
    #[derive(Tool, Default)]
    #[tool(name = "tag_tool", description = "A tool with vector parameter")]
    struct TagTool {
        #[tool(param, required, description = "List of tags")]
        tags: Vec<String>,
    }

    impl TagTool {
        async fn run(&self, tags: Vec<String>) -> Result<serde_json::Value, String> {
            Ok(serde_json::json!({ "tags": tags }))
        }
    }

    /// Test tool with an `Option<Vec<u32>>` parameter.
    #[derive(Tool, Default)]
    #[tool(name = "score_tool", description = "A tool with optional vec parameter")]
    struct ScoreTool {
        #[tool(param, required, description = "Name of scorer")]
        name: String,

        #[tool(param, description = "Optional list of scores")]
        scores: Option<Vec<u32>>,
    }

    impl ScoreTool {
        async fn run(
            &self,
            name: String,
            scores: Option<Vec<u32>>,
        ) -> Result<serde_json::Value, String> {
            Ok(serde_json::json!({ "name": name, "scores": scores }))
        }
    }

    /// Test tool with a `HashMap<String, String>` parameter.
    #[derive(Tool, Default)]
    #[tool(name = "metadata_tool", description = "A tool with hashmap parameter")]
    struct MetadataTool {
        #[tool(param, required, description = "Key-value metadata")]
        metadata: HashMap<String, String>,
    }

    impl MetadataTool {
        async fn run(&self, metadata: HashMap<String, String>) -> Result<serde_json::Value, String> {
            Ok(serde_json::json!({ "metadata": metadata }))
        }
    }

    /// Builds a fresh execution context for a dummy agent named "test".
    fn ctx() -> ExecutionContext {
        ExecutionContext::new(AgentId::new("test"))
    }

    #[test]
    fn test_echo_tool_schema() {
        let tool = EchoTool::new();
        let schema = tool.schema();

        assert_eq!(schema.name, "echo");
        assert!(schema.description.contains("Echo"));
        assert!(schema.parameters["properties"]["message"].is_object());
        assert!(schema.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("message")));
    }

    #[tokio::test]
    async fn test_echo_tool_execute() {
        let tool = EchoTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({
                    "message": "Hello, World!"
                }),
            )
            .await
            .unwrap();

        assert_eq!(result["echoed"], "Hello, World!");
        assert_eq!(result["repeat_count"], 1);
    }

    #[tokio::test]
    async fn test_echo_tool_repeat() {
        let tool = EchoTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({
                    "message": "Hi",
                    "repeat": 3
                }),
            )
            .await
            .unwrap();

        assert_eq!(result["echoed"].as_array().unwrap().len(), 3);
        assert_eq!(result["repeat_count"], 3);
    }

    #[tokio::test]
    async fn test_echo_tool_repeat_zero() {
        let tool = EchoTool::new();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "message": "Hi", "repeat": 0 }))
            .await
            .unwrap();

        // Any count other than 1 is reported as an array, even an empty one.
        assert!(result["echoed"].as_array().unwrap().is_empty());
        assert_eq!(result["repeat_count"], 0);
    }

    #[test]
    fn test_greet_tool_schema() {
        let tool = GreetTool::new();
        let schema = tool.schema();

        assert_eq!(schema.name, "greet");
        assert!(schema.parameters["properties"]["name"].is_object());
        assert!(schema.parameters["properties"]["language"].is_object());
        assert!(schema.parameters["properties"]["formal"].is_object());
    }

    #[tokio::test]
    async fn test_greet_tool_execute() {
        let tool = GreetTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({
                    "name": "Lucas",
                    "language": "pt",
                    "formal": false
                }),
            )
            .await
            .unwrap();

        assert!(result["greeting"].as_str().unwrap().contains("Olá"));
        assert!(result["greeting"].as_str().unwrap().contains("Lucas"));
    }

    #[tokio::test]
    async fn test_greet_tool_defaults_to_informal_english() {
        let tool = GreetTool::new();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "name": "Ana" }))
            .await
            .unwrap();

        assert_eq!(result["greeting"], "Hello, Ana! How are you?");
        assert_eq!(result["language"], "en");
        assert_eq!(result["formal"], false);
    }

    #[tokio::test]
    async fn test_greet_tool_formal_spanish() {
        let tool = GreetTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({ "name": "Ana", "language": "es", "formal": true }),
            )
            .await
            .unwrap();

        assert_eq!(result["greeting"], "Estimado(a) Ana, es un placer saludarle.");
        assert_eq!(result["formal"], true);
    }

    #[tokio::test]
    async fn test_transform_tool_uppercase() {
        let tool = TransformTextTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({
                    "text": "hello world",
                    "operation": "uppercase"
                }),
            )
            .await
            .unwrap();

        assert_eq!(result["transformed"], "HELLO WORLD");
    }

    #[tokio::test]
    async fn test_transform_tool_lowercase_operation_is_case_insensitive() {
        let tool = TransformTextTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({ "text": "Hello World", "operation": "LOWERCASE" }),
            )
            .await
            .unwrap();

        assert_eq!(result["transformed"], "hello world");
        assert_eq!(result["original"], "Hello World");
        // The operation is echoed back exactly as supplied.
        assert_eq!(result["operation"], "LOWERCASE");
    }

    #[tokio::test]
    async fn test_transform_tool_reverse() {
        let tool = TransformTextTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({
                    "text": "abc",
                    "operation": "reverse"
                }),
            )
            .await
            .unwrap();

        assert_eq!(result["transformed"], "cba");
    }

    #[tokio::test]
    async fn test_transform_tool_title() {
        let tool = TransformTextTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({ "text": "hello wORLD", "operation": "title" }),
            )
            .await
            .unwrap();

        assert_eq!(result["transformed"], "Hello World");
    }

    #[tokio::test]
    async fn test_transform_tool_unknown_operation() {
        let tool = TransformTextTool::new();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({ "text": "abc", "operation": "rot13" }),
            )
            .await;

        assert!(result.is_err());
    }

    // ---- Function-based #[tool] macro tests ----

    #[test]
    fn test_fn_tool_struct_name_and_schema() {
        let tool = SayHelloTool::default();
        let schema = tool.schema();

        assert_eq!(schema.name, "say_hello");
        assert_eq!(schema.description, "Say hello to someone");
        assert!(schema.parameters["properties"]["name"].is_object());
        assert_eq!(
            schema.parameters["properties"]["name"]["description"],
            "Name to greet"
        );
        assert!(schema.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("name")));
    }

    #[tokio::test]
    async fn test_fn_tool_execute_basic() {
        let tool = SayHelloTool::default();

        let result = tool
            .execute(
                &ctx(),
                serde_json::json!({ "name": "World" }),
            )
            .await
            .unwrap();

        assert_eq!(result["greeting"], "Hello, World!");
    }

    #[test]
    fn test_fn_tool_option_param_is_optional_in_schema() {
        let tool = GetWeatherTool::default();
        let schema = tool.schema();

        assert_eq!(schema.name, "get_weather");
        assert_eq!(schema.description, "Get weather for a city");

        // city is required, unit is not
        let required = schema.parameters["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("city")));
        assert!(!required.contains(&serde_json::json!("unit")));

        // Both appear in properties
        assert!(schema.parameters["properties"]["city"].is_object());
        assert!(schema.parameters["properties"]["unit"].is_object());
    }

    #[tokio::test]
    async fn test_fn_tool_option_param_execute_without_optional() {
        let tool = GetWeatherTool::default();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "city": "London" }))
            .await
            .unwrap();

        assert_eq!(result["city"], "London");
        assert_eq!(result["unit"], "metric"); // default
    }

    #[tokio::test]
    async fn test_fn_tool_option_param_execute_with_optional() {
        let tool = GetWeatherTool::default();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "city": "NYC", "unit": "imperial" }))
            .await
            .unwrap();

        assert_eq!(result["city"], "NYC");
        assert_eq!(result["unit"], "imperial");
    }

    #[test]
    fn test_fn_tool_multiple_params_schema() {
        let tool = ComputeSumTool::default();
        let schema = tool.schema();

        assert_eq!(schema.name, "compute_sum");

        let required = schema.parameters["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::json!("a")));
        assert!(required.contains(&serde_json::json!("b")));
        assert!(!required.contains(&serde_json::json!("absolute")));

        assert_eq!(schema.parameters["properties"]["a"]["type"], "integer");
        assert_eq!(schema.parameters["properties"]["b"]["type"], "integer");
        assert_eq!(schema.parameters["properties"]["absolute"]["type"], "boolean");
    }

    #[tokio::test]
    async fn test_fn_tool_multiple_params_execute() {
        let tool = ComputeSumTool::default();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "a": 3, "b": 5 }))
            .await
            .unwrap();

        assert_eq!(result["result"], 8);
    }

    #[tokio::test]
    async fn test_fn_tool_multiple_params_negative_without_absolute() {
        let tool = ComputeSumTool::default();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "a": -3, "b": 1 }))
            .await
            .unwrap();

        assert_eq!(result["result"], -2);
    }

    #[tokio::test]
    async fn test_fn_tool_multiple_params_with_optional() {
        let tool = ComputeSumTool::default();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "a": -3, "b": 1, "absolute": true }))
            .await
            .unwrap();

        assert_eq!(result["result"], 2);
    }

    #[test]
    fn test_fn_tool_param_name_override() {
        let tool = LookupWordTool::default();
        let schema = tool.schema();

        // The schema should use "query" not "word"
        assert!(schema.parameters["properties"]["query"].is_object());
        assert!(schema.parameters["properties"]["word"].is_null());
        assert!(schema.parameters["required"]
            .as_array()
            .unwrap()
            .contains(&serde_json::json!("query")));
    }

    #[tokio::test]
    async fn test_fn_tool_param_name_override_execute() {
        let tool = LookupWordTool::default();

        let result = tool
            .execute(&ctx(), serde_json::json!({ "query": "hello" }))
            .await
            .unwrap();

        assert_eq!(result["word"], "hello");
    }

    #[tokio::test]
    async fn test_fn_tool_missing_required_param() {
        let tool = GetWeatherTool::default();

        // city is required but not provided
        let result = tool.execute(&ctx(), serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_missing_required_param() {
        let tool = EchoTool::new();

        let result = tool.execute(&ctx(), serde_json::json!({})).await;

        assert!(result.is_err());
    }

    #[test]
    fn test_vec_string_param_generates_array_schema() {
        let tool = TagTool::default();
        let schema = tool.schema();
        let tags_prop = &schema.parameters["properties"]["tags"];

        assert_eq!(tags_prop["type"], "array");
        assert_eq!(tags_prop["items"]["type"], "string");
    }

    #[test]
    fn test_option_vec_u32_param_is_optional_array() {
        let tool = ScoreTool::default();
        let schema = tool.schema();
        let scores_prop = &schema.parameters["properties"]["scores"];

        // Should be array type with integer items
        assert_eq!(scores_prop["type"], "array");
        assert_eq!(scores_prop["items"]["type"], "integer");

        // Should NOT be in the required list
        let required = schema.parameters["required"].as_array().unwrap();
        assert!(
            !required.contains(&serde_json::json!("scores")),
            "Option<Vec<u32>> should not be required"
        );
    }

    #[test]
    fn test_hashmap_param_generates_object_schema() {
        let tool = MetadataTool::default();
        let schema = tool.schema();
        let meta_prop = &schema.parameters["properties"]["metadata"];

        assert_eq!(meta_prop["type"], "object");
        assert_eq!(meta_prop["additionalProperties"]["type"], "string");
    }

    #[test]
    fn test_generated_schema_is_valid_json() {
        let tool = ScoreTool::default();
        let schema = tool.schema();

        // The parameters should be a valid JSON object
        assert!(schema.parameters.is_object());
        assert!(schema.parameters["properties"].is_object());
        assert!(schema.parameters["required"].is_array());
    }
}