1use async_trait::async_trait;
2use rand::{Rng, seq::SliceRandom};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use crate::generate_schema;
8use ai_agents_core::{Tool, ToolResult};
9
/// Stateless tool that generates random values: UUIDs, floats, integers,
/// picks from / shuffles of a list, booleans and strings.
///
/// The type carries no data, so constructing it costs nothing.
#[derive(Default)]
pub struct RandomTool;

impl RandomTool {
    /// Returns a new `RandomTool` (identical to `RandomTool::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
23
24#[derive(Debug, Deserialize, JsonSchema)]
25struct RandomInput {
26 operation: String,
28 #[serde(default)]
30 min: Option<f64>,
31 #[serde(default)]
33 max: Option<f64>,
34 #[serde(default)]
36 items: Option<Vec<Value>>,
37 #[serde(default)]
39 count: Option<usize>,
40 #[serde(default)]
42 length: Option<usize>,
43 #[serde(default)]
45 charset: Option<String>,
46}
47
// Output payloads, one per operation. The handlers serialize them with
// `serde_json::to_string`; serde emits struct fields in declaration order, so
// the field order below is the key order of the tool's JSON output.
// `Deserialize` is derived so the tests can parse that output back.

/// Result of `uuid`: a random version-4 UUID in hyphenated text form.
#[derive(Debug, Serialize, Deserialize)]
struct UuidOutput {
    uuid: String,
}

/// Result of `number`: the sampled `value` plus the `min`/`max` bounds that
/// were used (after defaults were applied).
#[derive(Debug, Serialize, Deserialize)]
struct NumberOutput {
    value: f64,
    min: f64,
    max: f64,
}

/// Result of `integer`: the sampled `value` plus the integer bounds it was
/// drawn from (both inclusive).
#[derive(Debug, Serialize, Deserialize)]
struct IntegerOutput {
    value: i64,
    min: i64,
    max: i64,
}

/// Result of `choice`: the `selected` items and how many were selected
/// (`count` equals `selected.len()`).
#[derive(Debug, Serialize, Deserialize)]
struct ChoiceOutput {
    selected: Vec<Value>,
    count: usize,
}

/// Result of `shuffle`: the input items in random order and their `count`.
#[derive(Debug, Serialize, Deserialize)]
struct ShuffleOutput {
    shuffled: Vec<Value>,
    count: usize,
}

/// Result of `bool`: a random boolean.
#[derive(Debug, Serialize, Deserialize)]
struct BoolOutput {
    value: bool,
}

/// Result of `string`: the generated `value` and the requested `length`
/// (number of characters).
#[derive(Debug, Serialize, Deserialize)]
struct StringOutput {
    value: String,
    length: usize,
}
89
90#[async_trait]
91impl Tool for RandomTool {
92 fn id(&self) -> &str {
93 "random"
94 }
95
96 fn name(&self) -> &str {
97 "Random Generator"
98 }
99
100 fn description(&self) -> &str {
101 "Generate random values. Operations: uuid (generate UUID v4), number (random float), integer (random int), choice (pick from list), shuffle (randomize list order), bool (random true/false), string (random string)."
102 }
103
104 fn input_schema(&self) -> Value {
105 generate_schema::<RandomInput>()
106 }
107
108 async fn execute(&self, args: Value) -> ToolResult {
109 let input: RandomInput = match serde_json::from_value(args) {
110 Ok(input) => input,
111 Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
112 };
113
114 match input.operation.to_lowercase().as_str() {
115 "uuid" => self.handle_uuid(),
116 "number" => self.handle_number(&input),
117 "integer" | "int" => self.handle_integer(&input),
118 "choice" | "choose" | "pick" => self.handle_choice(&input),
119 "shuffle" => self.handle_shuffle(&input),
120 "bool" | "boolean" => self.handle_bool(),
121 "string" | "str" => self.handle_string(&input),
122 _ => ToolResult::error(format!(
123 "Unknown operation: {}. Valid operations: uuid, number, integer, choice, shuffle, bool, string",
124 input.operation
125 )),
126 }
127 }
128}
129
130impl RandomTool {
131 fn handle_uuid(&self) -> ToolResult {
132 let output = UuidOutput {
133 uuid: uuid::Uuid::new_v4().to_string(),
134 };
135 self.to_result(&output)
136 }
137
138 fn handle_number(&self, input: &RandomInput) -> ToolResult {
139 let min = input.min.unwrap_or(0.0);
140 let max = input.max.unwrap_or(1.0);
141
142 if min >= max {
143 return ToolResult::error("'min' must be less than 'max'");
144 }
145
146 let mut rng = rand::thread_rng();
147 let value: f64 = rng.gen_range(min..max);
148
149 let output = NumberOutput { value, min, max };
150 self.to_result(&output)
151 }
152
153 fn handle_integer(&self, input: &RandomInput) -> ToolResult {
154 let min = input.min.unwrap_or(0.0) as i64;
155 let max = input.max.unwrap_or(100.0) as i64;
156
157 if min >= max {
158 return ToolResult::error("'min' must be less than 'max'");
159 }
160
161 let mut rng = rand::thread_rng();
162 let value: i64 = rng.gen_range(min..=max);
163
164 let output = IntegerOutput { value, min, max };
165 self.to_result(&output)
166 }
167
168 fn handle_choice(&self, input: &RandomInput) -> ToolResult {
169 let items = match &input.items {
170 Some(i) if !i.is_empty() => i,
171 Some(_) => return ToolResult::error("'items' cannot be empty"),
172 None => return ToolResult::error("'items' is required for choice operation"),
173 };
174
175 let count = input.count.unwrap_or(1).min(items.len());
176
177 let mut rng = rand::thread_rng();
178 let selected: Vec<Value> = items.choose_multiple(&mut rng, count).cloned().collect();
179
180 let output = ChoiceOutput {
181 count: selected.len(),
182 selected,
183 };
184 self.to_result(&output)
185 }
186
187 fn handle_shuffle(&self, input: &RandomInput) -> ToolResult {
188 let items = match &input.items {
189 Some(i) => i.clone(),
190 None => return ToolResult::error("'items' is required for shuffle operation"),
191 };
192
193 let mut shuffled = items;
194 let mut rng = rand::thread_rng();
195 shuffled.shuffle(&mut rng);
196
197 let output = ShuffleOutput {
198 count: shuffled.len(),
199 shuffled,
200 };
201 self.to_result(&output)
202 }
203
204 fn handle_bool(&self) -> ToolResult {
205 let mut rng = rand::thread_rng();
206 let output = BoolOutput { value: rng.r#gen() };
207 self.to_result(&output)
208 }
209
210 fn handle_string(&self, input: &RandomInput) -> ToolResult {
211 let length = input.length.unwrap_or(16);
212 let charset = input.charset.as_deref().unwrap_or("alphanumeric");
213
214 let chars: Vec<char> = match charset.to_lowercase().as_str() {
215 "alphanumeric" | "alnum" => {
216 "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
217 .chars()
218 .collect()
219 }
220 "alpha" | "letters" => "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
221 .chars()
222 .collect(),
223 "numeric" | "digits" | "numbers" => "0123456789".chars().collect(),
224 "hex" | "hexadecimal" => "0123456789abcdef".chars().collect(),
225 "lower" | "lowercase" => "abcdefghijklmnopqrstuvwxyz".chars().collect(),
226 "upper" | "uppercase" => "ABCDEFGHIJKLMNOPQRSTUVWXYZ".chars().collect(),
227 _ => {
228 return ToolResult::error(format!(
229 "Unknown charset: {}. Valid: alphanumeric, alpha, numeric, hex, lower, upper",
230 charset
231 ));
232 }
233 };
234
235 let mut rng = rand::thread_rng();
236 let value: String = (0..length)
237 .map(|_| chars[rng.gen_range(0..chars.len())])
238 .collect();
239
240 let output = StringOutput { value, length };
241 self.to_result(&output)
242 }
243
244 fn to_result<T: Serialize>(&self, output: &T) -> ToolResult {
245 match serde_json::to_string(output) {
246 Ok(json) => ToolResult::ok(json),
247 Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
248 }
249 }
250}
251
// Unit tests driving the tool through its public `execute` entry point and
// parsing the JSON output back into the output structs.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_uuid() {
        let tool = RandomTool::new();
        let result = tool.execute(serde_json::json!({"operation": "uuid"})).await;
        assert!(result.success);

        let output: UuidOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.uuid.len(), 36);
        assert!(output.uuid.contains('-'));
        // Hyphenated v4 UUIDs carry the version digit at index 14.
        assert_eq!(output.uuid.chars().nth(14), Some('4'));
    }

    #[tokio::test]
    async fn test_operation_is_case_insensitive() {
        let tool = RandomTool::new();
        let result = tool.execute(serde_json::json!({"operation": "UUID"})).await;
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_invalid_input() {
        let tool = RandomTool::new();
        // `operation` is the only required field.
        let result = tool.execute(serde_json::json!({"min": 1})).await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_number() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "number",
                "min": 10.0,
                "max": 20.0
            }))
            .await;
        assert!(result.success);

        let output: NumberOutput = serde_json::from_str(&result.output).unwrap();
        assert!(output.value >= 10.0 && output.value < 20.0);
    }

    #[tokio::test]
    async fn test_number_invalid_range() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "number",
                "min": 20.0,
                "max": 10.0
            }))
            .await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_integer() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "integer",
                "min": 1,
                "max": 10
            }))
            .await;
        assert!(result.success);

        let output: IntegerOutput = serde_json::from_str(&result.output).unwrap();
        assert!(output.value >= 1 && output.value <= 10);
    }

    #[tokio::test]
    async fn test_choice() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "choice",
                "items": ["a", "b", "c", "d"],
                "count": 2
            }))
            .await;
        assert!(result.success);

        let output: ChoiceOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.count, 2);
        assert_eq!(output.selected.len(), 2);
        for item in &output.selected {
            assert!(matches!(item.as_str(), Some("a" | "b" | "c" | "d")));
        }
        // Picks are made without replacement and the inputs are distinct.
        assert_ne!(output.selected[0], output.selected[1]);
    }

    #[tokio::test]
    async fn test_choice_single() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "choice",
                "items": [1, 2, 3]
            }))
            .await;
        assert!(result.success);

        let output: ChoiceOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.count, 1);
    }

    #[tokio::test]
    async fn test_choice_empty() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "choice",
                "items": []
            }))
            .await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_choice_missing_items() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({"operation": "choice"}))
            .await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_shuffle() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "shuffle",
                "items": [1, 2, 3, 4, 5]
            }))
            .await;
        assert!(result.success);

        let output: ShuffleOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.count, 5);
        // Shuffling must keep exactly the same elements.
        let mut values: Vec<i64> = output
            .shuffled
            .iter()
            .map(|v| v.as_i64().unwrap())
            .collect();
        values.sort_unstable();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    #[tokio::test]
    async fn test_shuffle_missing_items() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({"operation": "shuffle"}))
            .await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_bool() {
        let tool = RandomTool::new();
        let result = tool.execute(serde_json::json!({"operation": "bool"})).await;
        assert!(result.success);

        // Inspect the raw JSON: `value` must be serialized as a JSON boolean.
        let output: serde_json::Value = serde_json::from_str(&result.output).unwrap();
        assert!(output["value"].is_boolean());
    }

    #[tokio::test]
    async fn test_string_default() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({"operation": "string"}))
            .await;
        assert!(result.success);

        let output: StringOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.length, 16);
        assert_eq!(output.value.len(), 16);
        assert!(output.value.chars().all(|c| c.is_ascii_alphanumeric()));
    }

    #[tokio::test]
    async fn test_string_hex() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "string",
                "length": 8,
                "charset": "hex"
            }))
            .await;
        assert!(result.success);

        let output: StringOutput = serde_json::from_str(&result.output).unwrap();
        assert_eq!(output.length, 8);
        assert_eq!(output.value.len(), 8);
        assert!(output.value.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[tokio::test]
    async fn test_string_numeric() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "string",
                "length": 10,
                "charset": "numeric"
            }))
            .await;
        assert!(result.success);

        let output: StringOutput = serde_json::from_str(&result.output).unwrap();
        assert!(output.value.chars().all(|c| c.is_ascii_digit()));
    }

    #[tokio::test]
    async fn test_string_unknown_charset() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({
                "operation": "string",
                "charset": "emoji"
            }))
            .await;
        assert!(!result.success);
    }

    #[tokio::test]
    async fn test_invalid_operation() {
        let tool = RandomTool::new();
        let result = tool
            .execute(serde_json::json!({"operation": "invalid"}))
            .await;
        assert!(!result.success);
    }
}