Skip to main content

objectiveai_sdk/functions/check/example_inputs/
any_of.rs

1use rand::Rng;
2use rand::SeedableRng;
3use rand::rngs::StdRng;
4use rand::seq::SliceRandom;
5
6use crate::functions::expression::{AnyOfInputSchema, InputValue};
7
8fn max_inner_permutations(schema: &AnyOfInputSchema) -> usize {
9    schema
10        .any_of
11        .iter()
12        .map(|v| super::optional::inner_permutations(v))
13        .max()
14        .unwrap_or(0)
15}
16
17pub fn permutations(schema: &AnyOfInputSchema) -> usize {
18    schema
19        .any_of
20        .len()
21        .saturating_mul(max_inner_permutations(schema))
22}
23
24pub fn generate<R: Rng>(schema: &AnyOfInputSchema, mut rng: R) -> Generator<R> {
25    let variant_count = schema.any_of.len();
26    let max_inner = max_inner_permutations(schema);
27
28    let generators: Vec<super::multi::Generator> = schema
29        .any_of
30        .iter()
31        .map(|v| {
32            super::multi::generate(
33                v,
34                StdRng::seed_from_u64(rng.random::<u64>()),
35            )
36        })
37        .collect();
38
39    // Each variant appears max_inner times per cycle
40    let mut order: Vec<usize> = (0..variant_count)
41        .flat_map(|i| std::iter::repeat(i).take(max_inner))
42        .collect();
43    order.shuffle(&mut rng);
44
45    Generator {
46        generators,
47        order,
48        pos: 0,
49        rng,
50    }
51}
52
53pub struct Generator<R: Rng> {
54    generators: Vec<super::multi::Generator>,
55    order: Vec<usize>,
56    pos: usize,
57    rng: R,
58}
59
60impl<R: Rng> Iterator for Generator<R> {
61    type Item = InputValue;
62    fn next(&mut self) -> Option<InputValue> {
63        if self.order.is_empty() {
64            return None;
65        }
66        if self.pos >= self.order.len() {
67            self.order.shuffle(&mut self.rng);
68            self.pos = 0;
69        }
70        let variant_idx = self.order[self.pos];
71        self.pos += 1;
72        self.generators[variant_idx].next()
73    }
74}