objectiveai_sdk/functions/check/example_inputs/
number.rs1use rand::Rng;
2use rand::seq::SliceRandom;
3
4use crate::functions::expression::{InputValue, NumberInputSchema};
5
6pub fn permutations(schema: &NumberInputSchema) -> usize {
7 let min = schema.minimum;
8 let max = schema.maximum;
9 let mut count = 0;
10
11 if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m >= 0.0) {
12 count += 1;
13 }
14 if let Some(m) = min {
15 if m != 0.0 {
16 count += 1;
17 }
18 }
19 if let Some(m) = max {
20 if m != 0.0 {
21 count += 1;
22 }
23 }
24 if min.is_none() {
25 count += 1;
26 }
27 if max.is_none() {
28 count += 1;
29 }
30 if min.map_or(true, |m| m < 0.0) && max.map_or(true, |m| m >= 0.0) {
31 count += 1;
32 }
33 if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m > 0.0) {
34 count += 1;
35 }
36
37 count.max(1)
38}
39
/// One kind of example number the generator can emit for a schema.
///
/// `Min` and `Max` carry the schema bound they reproduce verbatim; the
/// remaining variants are turned into concrete values when the generator
/// yields them.
#[derive(Copy, Clone)]
enum Variant {
    /// Exactly `0.0`.
    Zero,
    /// The schema's `minimum`, emitted unchanged.
    Min(f64),
    /// The schema's `maximum`, emitted unchanged.
    Max(f64),
    /// A random value at most `-1.0`; only selected when the schema has no minimum.
    RandomNegative,
    /// A random value at least `1.0`; only selected when the schema has no maximum.
    RandomPositive,
    /// A random value in `[-1.0, 0.0)`.
    DecimalNeg,
    /// A random value in `[0.0, 1.0)`.
    DecimalPos,
}
50
51fn variants(schema: &NumberInputSchema) -> Vec<Variant> {
52 let min = schema.minimum;
53 let max = schema.maximum;
54 let mut v = Vec::with_capacity(7);
55
56 if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m >= 0.0) {
57 v.push(Variant::Zero);
58 }
59 if let Some(m) = min {
60 if m != 0.0 {
61 v.push(Variant::Min(m));
62 }
63 }
64 if let Some(m) = max {
65 if m != 0.0 {
66 v.push(Variant::Max(m));
67 }
68 }
69 if min.is_none() {
70 v.push(Variant::RandomNegative);
71 }
72 if max.is_none() {
73 v.push(Variant::RandomPositive);
74 }
75 if min.map_or(true, |m| m < 0.0) && max.map_or(true, |m| m >= 0.0) {
76 v.push(Variant::DecimalNeg);
77 }
78 if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m > 0.0) {
79 v.push(Variant::DecimalPos);
80 }
81 if v.is_empty() {
82 v.push(Variant::Zero);
83 }
84 v
85}
86
87pub fn generate<R: Rng>(
88 schema: &NumberInputSchema,
89 mut rng: R,
90) -> Generator<R> {
91 let vars = variants(schema);
92 let mut indices: Vec<usize> = (0..vars.len()).collect();
93 indices.shuffle(&mut rng);
94 Generator {
95 variants: vars,
96 indices,
97 pos: 0,
98 rng,
99 min: schema.minimum,
100 max: schema.maximum,
101 }
102}
103
104pub struct Generator<R: Rng> {
105 variants: Vec<Variant>,
106 indices: Vec<usize>,
107 pos: usize,
108 rng: R,
109 min: Option<f64>,
110 max: Option<f64>,
111}
112
113impl<R: Rng> Iterator for Generator<R> {
114 type Item = InputValue;
115 fn next(&mut self) -> Option<InputValue> {
116 if self.pos >= self.indices.len() {
117 self.indices.shuffle(&mut self.rng);
118 self.pos = 0;
119 }
120 let index = self.indices[self.pos];
121 self.pos += 1;
122 let value = match self.variants[index] {
123 Variant::Zero => 0.0,
124 Variant::Min(m) => m,
125 Variant::Max(m) => m,
126 Variant::RandomNegative => {
127 let upper = self.max.unwrap_or(-1.0).min(-1.0);
128 self.rng.random_range(-100.0..=upper)
129 }
130 Variant::RandomPositive => {
131 let lower = self.min.unwrap_or(1.0).max(1.0);
132 self.rng.random_range(lower..=100.0)
133 }
134 Variant::DecimalNeg => self.rng.random_range(-1.0..0.0),
135 Variant::DecimalPos => self.rng.random_range(0.0..1.0),
136 };
137 Some(InputValue::Number(value))
138 }
139}