Skip to main content

objectiveai_sdk/functions/check/example_inputs/
optional.rs

1use rand::Rng;
2use rand::rngs::StdRng;
3use rand::seq::SliceRandom;
4use rand::SeedableRng;
5
6use crate::functions::expression::{InputValue, InputSchema};
7
8pub fn permutations(schema: &InputSchema) -> usize {
9    inner_permutations(schema) * 2
10}
11
12pub fn inner_permutations(schema: &InputSchema) -> usize {
13    match schema {
14        InputSchema::Boolean(s) => super::boolean::permutations(s),
15        InputSchema::String(s) => super::string::permutations(s),
16        InputSchema::Integer(s) => super::integer::permutations(s),
17        InputSchema::Number(s) => super::number::permutations(s),
18        InputSchema::Image(s) => super::image::permutations(s),
19        InputSchema::Audio(s) => super::audio::permutations(s),
20        InputSchema::Video(s) => super::video::permutations(s),
21        InputSchema::File(s) => super::file::permutations(s),
22        InputSchema::Object(s) => super::object::permutations(s),
23        InputSchema::Array(s) => super::array::permutations(s),
24        InputSchema::AnyOf(s) => super::any_of::permutations(s),
25    }
26}
27
28pub fn generate(schema: &InputSchema, mut rng: StdRng) -> Generator {
29    let inner_count = inner_permutations(schema);
30    // 0..inner_count = present, inner_count..inner_count*2 = absent
31    let total = inner_count * 2;
32    let mut indices: Vec<usize> = (0..total).collect();
33    indices.shuffle(&mut rng);
34
35    let inner = super::multi::generate(schema, StdRng::seed_from_u64(rng.random::<u64>()));
36
37    Generator {
38        inner,
39        inner_count,
40        indices,
41        pos: 0,
42        rng,
43    }
44}
45
46pub struct Generator {
47    inner: super::multi::Generator,
48    inner_count: usize,
49    indices: Vec<usize>,
50    pos: usize,
51    rng: StdRng,
52}
53
54impl Iterator for Generator {
55    type Item = Option<InputValue>;
56    fn next(&mut self) -> Option<Option<InputValue>> {
57        if self.indices.is_empty() {
58            return Some(None);
59        }
60        if self.pos >= self.indices.len() {
61            self.indices.shuffle(&mut self.rng);
62            self.pos = 0;
63        }
64        let index = self.indices[self.pos];
65        self.pos += 1;
66        if index < self.inner_count {
67            Some(self.inner.next())
68        } else {
69            Some(None)
70        }
71    }
72}