jpx_core/extensions/
random.rs1use std::collections::HashSet;
4
5use serde_json::Value;
6
7use crate::functions::{Function, custom_error, number_value};
8use crate::interpreter::SearchResult;
9use crate::registry::register_if_enabled;
10use crate::{Context, Runtime, defn};
11
12pub fn register_filtered(runtime: &mut Runtime, enabled: &HashSet<&str>) {
14 register_if_enabled(runtime, "random", enabled, Box::new(RandomFn::new()));
15 register_if_enabled(
16 runtime,
17 "random_choice",
18 enabled,
19 Box::new(RandomChoiceFn::new()),
20 );
21 register_if_enabled(runtime, "random_int", enabled, Box::new(RandomIntFn::new()));
22 register_if_enabled(runtime, "sample", enabled, Box::new(SampleFn::new()));
23 register_if_enabled(runtime, "shuffle", enabled, Box::new(ShuffleFn::new()));
24 register_if_enabled(runtime, "uuid", enabled, Box::new(UuidFn::new()));
25}
26
27pub struct RandomFn;
33
34impl Default for RandomFn {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl RandomFn {
41 pub fn new() -> RandomFn {
42 RandomFn
43 }
44}
45
46impl Function for RandomFn {
47 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
48 use rand::Rng;
49
50 if !args.is_empty() && args.len() != 2 {
52 return Err(custom_error(ctx, "random() takes 0 or 2 arguments"));
53 }
54
55 let mut rng = rand::thread_rng();
56
57 let value: f64 = if args.is_empty() {
58 rng.gen_range(0.0..1.0)
60 } else {
61 let min = args[0]
63 .as_f64()
64 .ok_or_else(|| custom_error(ctx, "Expected number for min"))?;
65 let max = args[1]
66 .as_f64()
67 .ok_or_else(|| custom_error(ctx, "Expected number for max"))?;
68 rng.gen_range(min..max)
69 };
70
71 Ok(number_value(value))
72 }
73}
74
75pub struct RandomChoiceFn;
80
81impl Default for RandomChoiceFn {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87impl RandomChoiceFn {
88 pub fn new() -> RandomChoiceFn {
89 RandomChoiceFn
90 }
91}
92
93impl Function for RandomChoiceFn {
94 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
95 use rand::seq::SliceRandom;
96
97 if args.len() != 1 {
98 return Err(custom_error(ctx, "random_choice() takes 1 argument"));
99 }
100
101 let arr = args[0]
102 .as_array()
103 .ok_or_else(|| custom_error(ctx, "Expected array argument"))?;
104
105 if arr.is_empty() {
106 return Ok(Value::Null);
107 }
108
109 let chosen = arr
110 .choose(&mut rand::thread_rng())
111 .cloned()
112 .unwrap_or(Value::Null);
113
114 Ok(chosen)
115 }
116}
117
118pub struct RandomIntFn;
123
124impl Default for RandomIntFn {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl RandomIntFn {
131 pub fn new() -> RandomIntFn {
132 RandomIntFn
133 }
134}
135
136impl Function for RandomIntFn {
137 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
138 use rand::Rng;
139
140 if args.len() != 2 {
141 return Err(custom_error(ctx, "random_int() takes 2 arguments"));
142 }
143
144 let min = args[0]
145 .as_f64()
146 .ok_or_else(|| custom_error(ctx, "Expected number for min"))? as i64;
147 let max = args[1]
148 .as_f64()
149 .ok_or_else(|| custom_error(ctx, "Expected number for max"))? as i64;
150
151 if min > max {
152 return Err(custom_error(ctx, "min must be less than or equal to max"));
153 }
154
155 let value = rand::thread_rng().gen_range(min..=max);
156 Ok(Value::Number(serde_json::Number::from(value)))
157 }
158}
159
160pub struct ShuffleFn;
166
167impl Default for ShuffleFn {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl ShuffleFn {
174 pub fn new() -> ShuffleFn {
175 ShuffleFn
176 }
177}
178
179impl Function for ShuffleFn {
180 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
181 if args.is_empty() || args.len() > 2 {
183 return Err(custom_error(ctx, "shuffle() takes 1 or 2 arguments"));
184 }
185
186 let arr = args[0]
187 .as_array()
188 .ok_or_else(|| custom_error(ctx, "Expected array argument"))?;
189
190 use rand::SeedableRng;
191 use rand::seq::SliceRandom;
192
193 let mut result: Vec<Value> = arr.clone();
194
195 if args.len() == 2 {
196 let seed = args[1]
198 .as_f64()
199 .ok_or_else(|| custom_error(ctx, "Expected number for seed"))?
200 as u64;
201 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
202 result.shuffle(&mut rng);
203 } else {
204 result.shuffle(&mut rand::thread_rng());
206 }
207
208 Ok(Value::Array(result))
209 }
210}
211
212pub struct SampleFn;
218
219impl Default for SampleFn {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225impl SampleFn {
226 pub fn new() -> SampleFn {
227 SampleFn
228 }
229}
230
231impl Function for SampleFn {
232 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
233 if args.len() < 2 || args.len() > 3 {
235 return Err(custom_error(ctx, "sample() takes 2 or 3 arguments"));
236 }
237
238 let arr = args[0]
239 .as_array()
240 .ok_or_else(|| custom_error(ctx, "Expected array argument"))?;
241
242 let n = args[1]
243 .as_f64()
244 .ok_or_else(|| custom_error(ctx, "Expected number argument"))? as usize;
245
246 use rand::SeedableRng;
247 use rand::seq::SliceRandom;
248
249 let sample: Vec<Value> = if args.len() == 3 {
250 let seed = args[2]
252 .as_f64()
253 .ok_or_else(|| custom_error(ctx, "Expected number for seed"))?
254 as u64;
255 let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
256 arr.choose_multiple(&mut rng, n.min(arr.len()))
257 .cloned()
258 .collect()
259 } else {
260 arr.choose_multiple(&mut rand::thread_rng(), n.min(arr.len()))
262 .cloned()
263 .collect()
264 };
265
266 Ok(Value::Array(sample))
267 }
268}
269
270defn!(UuidFn, vec![], None);
275
276impl Function for UuidFn {
277 fn evaluate(&self, args: &[Value], ctx: &mut Context<'_>) -> SearchResult {
278 self.signature.validate(args, ctx)?;
279
280 let id = uuid::Uuid::new_v4();
281 Ok(Value::String(id.to_string()))
282 }
283}
284
#[cfg(test)]
mod tests {
    use crate::Runtime;
    use serde_json::{Value, json};

    fn setup_runtime() -> Runtime {
        Runtime::builder()
            .with_standard()
            .with_all_extensions()
            .build()
    }

    /// Collects the integer elements of an array result and sorts them, so that
    /// shuffled or sampled output can be compared to its input as a multiset.
    fn sorted_ints(value: &Value) -> Vec<i64> {
        let mut ints: Vec<i64> = value
            .as_array()
            .expect("expected an array result")
            .iter()
            .map(|v| v.as_i64().expect("expected integer elements"))
            .collect();
        ints.sort_unstable();
        ints
    }

    #[test]
    fn test_random() {
        let runtime = setup_runtime();
        let expr = runtime.compile("random()").unwrap();
        let result = expr.search(&json!(null)).unwrap();
        let value = result.as_f64().unwrap();
        assert!((0.0..1.0).contains(&value));
    }

    #[test]
    fn test_random_with_bounds() {
        let runtime = setup_runtime();
        let expr = runtime.compile("random(`2`, `5`)").unwrap();
        for _ in 0..20 {
            let result = expr.search(&json!(null)).unwrap();
            let value = result.as_f64().unwrap();
            assert!((2.0..5.0).contains(&value));
        }
    }

    #[test]
    fn test_random_choice() {
        let runtime = setup_runtime();
        let data = json!(["a", "b", "c"]);
        let expr = runtime.compile("random_choice(@)").unwrap();
        let result = expr.search(&data).unwrap();
        assert!(result.is_string());
        let s = result.as_str().unwrap();
        assert!(["a", "b", "c"].contains(&s));
    }

    #[test]
    fn test_random_choice_single_element() {
        let runtime = setup_runtime();
        let data = json!([42]);
        let expr = runtime.compile("random_choice(@)").unwrap();
        let result = expr.search(&data).unwrap();
        assert_eq!(result, json!(42));
    }

    #[test]
    fn test_random_choice_empty() {
        let runtime = setup_runtime();
        let data = json!([]);
        let expr = runtime.compile("random_choice(@)").unwrap();
        let result = expr.search(&data).unwrap();
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_random_int() {
        let runtime = setup_runtime();
        let expr = runtime.compile("random_int(`1`, `10`)").unwrap();
        let result = expr.search(&json!(null)).unwrap();
        let value = result.as_i64().unwrap();
        assert!((1..=10).contains(&value));
    }

    #[test]
    fn test_random_int_min_equals_max() {
        let runtime = setup_runtime();
        let expr = runtime.compile("random_int(`5`, `5`)").unwrap();
        let result = expr.search(&json!(null)).unwrap();
        assert_eq!(result, json!(5));
    }

    #[test]
    fn test_random_int_min_greater_than_max() {
        let runtime = setup_runtime();
        let expr = runtime.compile("random_int(`10`, `1`)").unwrap();
        let result = expr.search(&json!(null));
        assert!(result.is_err());
    }

    #[test]
    fn test_shuffle() {
        let runtime = setup_runtime();
        let data = json!([1, 2, 3]);
        let expr = runtime.compile("shuffle(@)").unwrap();
        let result = expr.search(&data).unwrap();
        // A shuffle must be a permutation: same elements, same multiplicity.
        assert_eq!(sorted_ints(&result), vec![1, 2, 3]);
    }

    #[test]
    fn test_shuffle_with_seed_is_deterministic() {
        let runtime = setup_runtime();
        let data = json!([1, 2, 3, 4, 5, 6, 7, 8]);
        let expr = runtime.compile("shuffle(@, `42`)").unwrap();
        let first = expr.search(&data).unwrap();
        let second = expr.search(&data).unwrap();
        assert_eq!(first, second);
        assert_eq!(sorted_ints(&first), vec![1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn test_sample() {
        let runtime = setup_runtime();
        let data = json!([1, 2, 3, 4, 5]);
        let expr = runtime.compile("sample(@, `2`)").unwrap();
        let result = expr.search(&data).unwrap();
        let picked = sorted_ints(&result);
        assert_eq!(picked.len(), 2);
        // Elements are distinct and drawn from the input.
        assert_ne!(picked[0], picked[1]);
        assert!(picked.iter().all(|n| (1..=5).contains(n)));
    }

    #[test]
    fn test_sample_count_capped_at_length() {
        let runtime = setup_runtime();
        let data = json!([1, 2, 3]);
        let expr = runtime.compile("sample(@, `10`)").unwrap();
        let result = expr.search(&data).unwrap();
        assert_eq!(sorted_ints(&result), vec![1, 2, 3]);
    }

    #[test]
    fn test_sample_with_seed_is_deterministic() {
        let runtime = setup_runtime();
        let data = json!([1, 2, 3, 4, 5, 6, 7, 8]);
        let expr = runtime.compile("sample(@, `3`, `7`)").unwrap();
        let first = expr.search(&data).unwrap();
        let second = expr.search(&data).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_uuid() {
        let runtime = setup_runtime();
        let expr = runtime.compile("uuid()").unwrap();
        let result = expr.search(&json!(null)).unwrap();
        let uuid_str = result.as_str().unwrap();
        assert_eq!(uuid_str.len(), 36);

        // Canonical 8-4-4-4-12 hyphenated layout with the v4 version nibble.
        let bytes = uuid_str.as_bytes();
        for i in [8, 13, 18, 23] {
            assert_eq!(bytes[i], b'-');
        }
        assert_eq!(bytes[14], b'4');
        assert!(
            uuid_str
                .chars()
                .filter(|c| *c != '-')
                .all(|c| c.is_ascii_hexdigit())
        );
    }
}