Skip to main content

ai_agents_tools/builtin/
random.rs

1use async_trait::async_trait;
2use rand::{Rng, seq::SliceRandom};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::generate_schema;
8use ai_agents_core::{Tool, ToolResult};
9
/// Built-in tool that generates random values: UUIDs, floats, integers,
/// picks from a list, shuffles, booleans and strings.
///
/// The tool holds no state (it is a unit struct), so one instance can be reused
/// for any number of calls.
#[derive(Default)]
pub struct RandomTool;

impl RandomTool {
    /// Creates the tool; identical to `RandomTool::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
23
24#[derive(Debug, Deserialize, JsonSchema)]
25struct RandomInput {
26    /// Operation to perform: uuid, number, integer, choice, shuffle, bool, string
27    operation: String,
28    /// Minimum value (for number/integer operations)
29    #[serde(default)]
30    min: Option<f64>,
31    /// Maximum value (for number/integer operations)
32    #[serde(default)]
33    max: Option<f64>,
34    /// Items to choose from or shuffle (for choice/shuffle operations)
35    #[serde(default)]
36    items: Option<Vec<Value>>,
37    /// Number of items to choose (for choice operation, default: 1)
38    #[serde(default)]
39    count: Option<usize>,
40    /// Length of random string (for string operation, default: 16)
41    #[serde(default)]
42    length: Option<usize>,
43    /// Character set for string: alphanumeric, alpha, numeric, hex (default: alphanumeric)
44    #[serde(default)]
45    charset: Option<String>,
46}
47
48#[derive(Debug, Serialize, Deserialize)]
49struct UuidOutput {
50    uuid: String,
51}
52
53#[derive(Debug, Serialize, Deserialize)]
54struct NumberOutput {
55    value: f64,
56    min: f64,
57    max: f64,
58}
59
60#[derive(Debug, Serialize, Deserialize)]
61struct IntegerOutput {
62    value: i64,
63    min: i64,
64    max: i64,
65}
66
67#[derive(Debug, Serialize, Deserialize)]
68struct ChoiceOutput {
69    selected: Vec<Value>,
70    count: usize,
71}
72
73#[derive(Debug, Serialize, Deserialize)]
74struct ShuffleOutput {
75    shuffled: Vec<Value>,
76    count: usize,
77}
78
79#[derive(Debug, Serialize, Deserialize)]
80struct BoolOutput {
81    value: bool,
82}
83
84#[derive(Debug, Serialize, Deserialize)]
85struct StringOutput {
86    value: String,
87    length: usize,
88}
89
90#[async_trait]
91impl Tool for RandomTool {
92    fn id(&self) -> &str {
93        "random"
94    }
95
96    fn name(&self) -> &str {
97        "Random Generator"
98    }
99
100    fn description(&self) -> &str {
101        "Generate random values. Operations: uuid (generate UUID v4), number (random float), integer (random int), choice (pick from list), shuffle (randomize list order), bool (random true/false), string (random string)."
102    }
103
104    fn input_schema(&self) -> Value {
105        generate_schema::<RandomInput>()
106    }
107
108    async fn execute(&self, args: Value) -> ToolResult {
109        let input: RandomInput = match serde_json::from_value(args) {
110            Ok(input) => input,
111            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
112        };
113
114        match input.operation.to_lowercase().as_str() {
115            "uuid" => self.handle_uuid(),
116            "number" => self.handle_number(&input),
117            "integer" | "int" => self.handle_integer(&input),
118            "choice" | "choose" | "pick" => self.handle_choice(&input),
119            "shuffle" => self.handle_shuffle(&input),
120            "bool" | "boolean" => self.handle_bool(),
121            "string" | "str" => self.handle_string(&input),
122            _ => ToolResult::error(format!(
123                "Unknown operation: {}. Valid operations: uuid, number, integer, choice, shuffle, bool, string",
124                input.operation
125            )),
126        }
127    }
128}
129
130impl RandomTool {
131    fn handle_uuid(&self) -> ToolResult {
132        let output = UuidOutput {
133            uuid: uuid::Uuid::new_v4().to_string(),
134        };
135        self.to_result(&output)
136    }
137
138    fn handle_number(&self, input: &RandomInput) -> ToolResult {
139        let min = input.min.unwrap_or(0.0);
140        let max = input.max.unwrap_or(1.0);
141
142        if min >= max {
143            return ToolResult::error("'min' must be less than 'max'");
144        }
145
146        let mut rng = rand::thread_rng();
147        let value: f64 = rng.gen_range(min..max);
148
149        let output = NumberOutput { value, min, max };
150        self.to_result(&output)
151    }
152
153    fn handle_integer(&self, input: &RandomInput) -> ToolResult {
154        let min = input.min.unwrap_or(0.0) as i64;
155        let max = input.max.unwrap_or(100.0) as i64;
156
157        if min >= max {
158            return ToolResult::error("'min' must be less than 'max'");
159        }
160
161        let mut rng = rand::thread_rng();
162        let value: i64 = rng.gen_range(min..=max);
163
164        let output = IntegerOutput { value, min, max };
165        self.to_result(&output)
166    }
167
168    fn handle_choice(&self, input: &RandomInput) -> ToolResult {
169        let items = match &input.items {
170            Some(i) if !i.is_empty() => i,
171            Some(_) => return ToolResult::error("'items' cannot be empty"),
172            None => return ToolResult::error("'items' is required for choice operation"),
173        };
174
175        let count = input.count.unwrap_or(1).min(items.len());
176
177        let mut rng = rand::thread_rng();
178        let selected: Vec<Value> = items.choose_multiple(&mut rng, count).cloned().collect();
179
180        let output = ChoiceOutput {
181            count: selected.len(),
182            selected,
183        };
184        self.to_result(&output)
185    }
186
187    fn handle_shuffle(&self, input: &RandomInput) -> ToolResult {
188        let items = match &input.items {
189            Some(i) => i.clone(),
190            None => return ToolResult::error("'items' is required for shuffle operation"),
191        };
192
193        let mut shuffled = items;
194        let mut rng = rand::thread_rng();
195        shuffled.shuffle(&mut rng);
196
197        let output = ShuffleOutput {
198            count: shuffled.len(),
199            shuffled,
200        };
201        self.to_result(&output)
202    }
203
204    fn handle_bool(&self) -> ToolResult {
205        let mut rng = rand::thread_rng();
206        let output = BoolOutput { value: rng.r#gen() };
207        self.to_result(&output)
208    }
209
210    fn handle_string(&self, input: &RandomInput) -> ToolResult {
211        let length = input.length.unwrap_or(16);
212        let charset = input.charset.as_deref().unwrap_or("alphanumeric");
213
214        let chars: Vec<char> = match charset.to_lowercase().as_str() {
215            "alphanumeric" | "alnum" => {
216                "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
217                    .chars()
218                    .collect()
219            }
220            "alpha" | "letters" => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
221                .chars()
222                .collect(),
223            "numeric" | "digits" | "numbers" => "0123456789".chars().collect(),
224            "hex" | "hexadecimal" => "0123456789abcdef".chars().collect(),
225            "lower" | "lowercase" => "abcdefghijklmnopqrstuvwxyz".chars().collect(),
226            "upper" | "uppercase" => "ABCDEFGHIJKLMNOPQRSTUVWXYZ".chars().collect(),
227            _ => {
228                return ToolResult::error(format!(
229                    "Unknown charset: {}. Valid: alphanumeric, alpha, numeric, hex, lower, upper",
230                    charset
231                ));
232            }
233        };
234
235        let mut rng = rand::thread_rng();
236        let value: String = (0..length)
237            .map(|_| chars[rng.gen_range(0..chars.len())])
238            .collect();
239
240        let output = StringOutput { value, length };
241        self.to_result(&output)
242    }
243
244    fn to_result<T: Serialize>(&self, output: &T) -> ToolResult {
245        match serde_json::to_string(output) {
246            Ok(json) => ToolResult::ok(json),
247            Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
248        }
249    }
250}
251
252#[cfg(test)]
253mod tests {
254    use super::*;
255
256    #[tokio::test]
257    async fn test_uuid() {
258        let tool = RandomTool::new();
259        let result = tool.execute(serde_json::json!({"operation": "uuid"})).await;
260        assert!(result.success);
261
262        let output: UuidOutput = serde_json::from_str(&result.output).unwrap();
263        assert_eq!(output.uuid.len(), 36);
264        assert!(output.uuid.contains('-'));
265    }
266
267    #[tokio::test]
268    async fn test_number() {
269        let tool = RandomTool::new();
270        let result = tool
271            .execute(serde_json::json!({
272                "operation": "number",
273                "min": 10.0,
274                "max": 20.0
275            }))
276            .await;
277        assert!(result.success);
278
279        let output: NumberOutput = serde_json::from_str(&result.output).unwrap();
280        assert!(output.value >= 10.0 && output.value < 20.0);
281    }
282
283    #[tokio::test]
284    async fn test_number_invalid_range() {
285        let tool = RandomTool::new();
286        let result = tool
287            .execute(serde_json::json!({
288                "operation": "number",
289                "min": 20.0,
290                "max": 10.0
291            }))
292            .await;
293        assert!(!result.success);
294    }
295
296    #[tokio::test]
297    async fn test_integer() {
298        let tool = RandomTool::new();
299        let result = tool
300            .execute(serde_json::json!({
301                "operation": "integer",
302                "min": 1,
303                "max": 10
304            }))
305            .await;
306        assert!(result.success);
307
308        let output: IntegerOutput = serde_json::from_str(&result.output).unwrap();
309        assert!(output.value >= 1 && output.value <= 10);
310    }
311
312    #[tokio::test]
313    async fn test_choice() {
314        let tool = RandomTool::new();
315        let result = tool
316            .execute(serde_json::json!({
317                "operation": "choice",
318                "items": ["a", "b", "c", "d"],
319                "count": 2
320            }))
321            .await;
322        assert!(result.success);
323
324        let output: ChoiceOutput = serde_json::from_str(&result.output).unwrap();
325        assert_eq!(output.count, 2);
326        assert_eq!(output.selected.len(), 2);
327    }
328
329    #[tokio::test]
330    async fn test_choice_single() {
331        let tool = RandomTool::new();
332        let result = tool
333            .execute(serde_json::json!({
334                "operation": "choice",
335                "items": [1, 2, 3]
336            }))
337            .await;
338        assert!(result.success);
339
340        let output: ChoiceOutput = serde_json::from_str(&result.output).unwrap();
341        assert_eq!(output.count, 1);
342    }
343
344    #[tokio::test]
345    async fn test_choice_empty() {
346        let tool = RandomTool::new();
347        let result = tool
348            .execute(serde_json::json!({
349                "operation": "choice",
350                "items": []
351            }))
352            .await;
353        assert!(!result.success);
354    }
355
356    #[tokio::test]
357    async fn test_shuffle() {
358        let tool = RandomTool::new();
359        let result = tool
360            .execute(serde_json::json!({
361                "operation": "shuffle",
362                "items": [1, 2, 3, 4, 5]
363            }))
364            .await;
365        assert!(result.success);
366
367        let output: ShuffleOutput = serde_json::from_str(&result.output).unwrap();
368        assert_eq!(output.count, 5);
369    }
370
371    #[tokio::test]
372    async fn test_bool() {
373        let tool = RandomTool::new();
374        let result = tool.execute(serde_json::json!({"operation": "bool"})).await;
375        assert!(result.success);
376
377        let output: BoolOutput = serde_json::from_str(&result.output).unwrap();
378        assert!(output.value == true || output.value == false);
379    }
380
381    #[tokio::test]
382    async fn test_string_default() {
383        let tool = RandomTool::new();
384        let result = tool
385            .execute(serde_json::json!({"operation": "string"}))
386            .await;
387        assert!(result.success);
388
389        let output: StringOutput = serde_json::from_str(&result.output).unwrap();
390        assert_eq!(output.length, 16);
391        assert_eq!(output.value.len(), 16);
392    }
393
394    #[tokio::test]
395    async fn test_string_hex() {
396        let tool = RandomTool::new();
397        let result = tool
398            .execute(serde_json::json!({
399                "operation": "string",
400                "length": 8,
401                "charset": "hex"
402            }))
403            .await;
404        assert!(result.success);
405
406        let output: StringOutput = serde_json::from_str(&result.output).unwrap();
407        assert_eq!(output.length, 8);
408        assert!(output.value.chars().all(|c| c.is_ascii_hexdigit()));
409    }
410
411    #[tokio::test]
412    async fn test_string_numeric() {
413        let tool = RandomTool::new();
414        let result = tool
415            .execute(serde_json::json!({
416                "operation": "string",
417                "length": 10,
418                "charset": "numeric"
419            }))
420            .await;
421        assert!(result.success);
422
423        let output: StringOutput = serde_json::from_str(&result.output).unwrap();
424        assert!(output.value.chars().all(|c| c.is_ascii_digit()));
425    }
426
427    #[tokio::test]
428    async fn test_invalid_operation() {
429        let tool = RandomTool::new();
430        let result = tool
431            .execute(serde_json::json!({"operation": "invalid"}))
432            .await;
433        assert!(!result.success);
434    }
435}