Skip to main content

objectiveai_sdk/functions/check/example_inputs/
number.rs

1use rand::Rng;
2use rand::seq::SliceRandom;
3
4use crate::functions::expression::{InputValue, NumberInputSchema};
5
6pub fn permutations(schema: &NumberInputSchema) -> usize {
7    let min = schema.minimum;
8    let max = schema.maximum;
9    let mut count = 0;
10
11    if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m >= 0.0) {
12        count += 1;
13    }
14    if let Some(m) = min {
15        if m != 0.0 {
16            count += 1;
17        }
18    }
19    if let Some(m) = max {
20        if m != 0.0 {
21            count += 1;
22        }
23    }
24    if min.is_none() {
25        count += 1;
26    }
27    if max.is_none() {
28        count += 1;
29    }
30    if min.map_or(true, |m| m < 0.0) && max.map_or(true, |m| m >= 0.0) {
31        count += 1;
32    }
33    if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m > 0.0) {
34        count += 1;
35    }
36
37    count.max(1)
38}
39
/// One kind of example value the generator can emit.
///
/// `Min` and `Max` carry the bound they reproduce; the other kinds carry no
/// data and are resolved to a concrete number when a value is produced.
#[derive(Copy, Clone)]
enum Variant {
    /// The value `0.0`.
    Zero,
    /// The schema's minimum, emitted as-is.
    Min(f64),
    /// The schema's maximum, emitted as-is.
    Max(f64),
    /// A randomly drawn negative value.
    RandomNegative,
    /// A randomly drawn positive value.
    RandomPositive,
    /// A randomly drawn negative value no smaller than -1.
    DecimalNeg,
    /// A randomly drawn non-negative value below 1.
    DecimalPos,
}
50
51fn variants(schema: &NumberInputSchema) -> Vec<Variant> {
52    let min = schema.minimum;
53    let max = schema.maximum;
54    let mut v = Vec::with_capacity(7);
55
56    if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m >= 0.0) {
57        v.push(Variant::Zero);
58    }
59    if let Some(m) = min {
60        if m != 0.0 {
61            v.push(Variant::Min(m));
62        }
63    }
64    if let Some(m) = max {
65        if m != 0.0 {
66            v.push(Variant::Max(m));
67        }
68    }
69    if min.is_none() {
70        v.push(Variant::RandomNegative);
71    }
72    if max.is_none() {
73        v.push(Variant::RandomPositive);
74    }
75    if min.map_or(true, |m| m < 0.0) && max.map_or(true, |m| m >= 0.0) {
76        v.push(Variant::DecimalNeg);
77    }
78    if min.map_or(true, |m| m <= 0.0) && max.map_or(true, |m| m > 0.0) {
79        v.push(Variant::DecimalPos);
80    }
81    if v.is_empty() {
82        v.push(Variant::Zero);
83    }
84    v
85}
86
87pub fn generate<R: Rng>(
88    schema: &NumberInputSchema,
89    mut rng: R,
90) -> Generator<R> {
91    let vars = variants(schema);
92    let mut indices: Vec<usize> = (0..vars.len()).collect();
93    indices.shuffle(&mut rng);
94    Generator {
95        variants: vars,
96        indices,
97        pos: 0,
98        rng,
99        min: schema.minimum,
100        max: schema.maximum,
101    }
102}
103
104pub struct Generator<R: Rng> {
105    variants: Vec<Variant>,
106    indices: Vec<usize>,
107    pos: usize,
108    rng: R,
109    min: Option<f64>,
110    max: Option<f64>,
111}
112
113impl<R: Rng> Iterator for Generator<R> {
114    type Item = InputValue;
115    fn next(&mut self) -> Option<InputValue> {
116        if self.pos >= self.indices.len() {
117            self.indices.shuffle(&mut self.rng);
118            self.pos = 0;
119        }
120        let index = self.indices[self.pos];
121        self.pos += 1;
122        let value = match self.variants[index] {
123            Variant::Zero => 0.0,
124            Variant::Min(m) => m,
125            Variant::Max(m) => m,
126            Variant::RandomNegative => {
127                let upper = self.max.unwrap_or(-1.0).min(-1.0);
128                self.rng.random_range(-100.0..=upper)
129            }
130            Variant::RandomPositive => {
131                let lower = self.min.unwrap_or(1.0).max(1.0);
132                self.rng.random_range(lower..=100.0)
133            }
134            Variant::DecimalNeg => self.rng.random_range(-1.0..0.0),
135            Variant::DecimalPos => self.rng.random_range(0.0..1.0),
136        };
137        Some(InputValue::Number(value))
138    }
139}