Skip to main content

objectiveai_sdk/functions/check/
check_vector_fields.rs

//! Validation of a vector function's `output_length`, `input_split`, and
//! `input_merge` expressions.
//!
//! Verifies, via round-trip testing against randomized example inputs, that
//! these expressions work correctly together.
5
6use rand::Rng;
7use rand::SeedableRng;
8use rand::rngs::StdRng;
9use serde::Deserialize;
10
11use super::check_input_schema::check_input_schema;
12use super::example_inputs;
13use crate::functions::expression::{Expression, InputSchema, InputValue};
14use crate::functions::{Function, RemoteFunction};
15use schemars::JsonSchema;
16
/// The 4 fields needed to validate a vector function's split/merge behavior.
// NOTE: schemars turns `///` doc comments on this type and its fields into
// JSON-schema descriptions, so the field notes below deliberately use `//`
// to leave the generated schema unchanged.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[schemars(rename = "functions.check.VectorFieldsValidation")]
pub struct VectorFieldsValidation {
    // Schema of the function input. Example inputs are generated from it, and
    // merged subset inputs are checked against its array length bounds.
    pub input_schema: InputSchema,
    // Expression compiled against an input to obtain its output length.
    pub output_length: Expression,
    // Expression compiled against an input to split it into per-item inputs.
    pub input_split: Expression,
    // Expression compiled against an array of split inputs to merge them back.
    pub input_merge: Expression,
}
26
27impl VectorFieldsValidation {
28    /// Construct a minimal `Function` for compilation, cloning our expressions.
29    fn to_function(&self) -> Function {
30        Function::Remote(RemoteFunction::Vector {
31            description: String::new(),
32            input_schema: self.input_schema.clone(),
33            tasks: vec![],
34            output_length: self.output_length.clone(),
35            input_split: self.input_split.clone(),
36            input_merge: self.input_merge.clone(),
37        })
38    }
39}
40
41/// Validate that the vector fields work together correctly.
42///
43/// Generates diverse, randomized example inputs from the `input_schema`, then
44/// validates each one via [`check_vector_fields_for_input`].
45pub fn check_vector_fields(
46    fields: VectorFieldsValidation,
47    seed: Option<i64>,
48) -> Result<(), String> {
49    // Input schema permutations
50    check_input_schema(&fields.input_schema)?;
51
52    let mut rng = match seed {
53        Some(s) => StdRng::seed_from_u64(s as u64),
54        None => StdRng::from_os_rng(),
55    };
56
57    let mut count = 0usize;
58    for ref input in example_inputs::generate_seeded(
59        &fields.input_schema,
60        StdRng::seed_from_u64(rng.random::<u64>()),
61    ) {
62        count += 1;
63        let input_label = serde_json::to_string(input).unwrap_or_default();
64        check_vector_fields_for_input(&fields, &input_label, input, &mut rng)?;
65    }
66
67    if count == 0 {
68        return Err(
69            "VF22: Failed to generate any example inputs from input_schema"
70                .to_string(),
71        );
72    }
73
74    Ok(())
75}
76
77/// Validates vector fields for a single input:
78/// 1. Compiles `output_length` — must be > 0
79/// 2. Compiles `input_split` — length must equal output_length
80/// 3. Each split element must produce output_length = 1
81/// 4. Merging all splits must reconstruct the original input
82/// 5. Merging random subsets must produce output_length = subset size
83pub(crate) fn check_vector_fields_for_input(
84    fields: &VectorFieldsValidation,
85    input_label: &str,
86    input: &InputValue,
87    rng: &mut impl Rng,
88) -> Result<(), String> {
89    // 1. Compile output_length
90    let output_length = fields
91        .to_function()
92        .compile_output_length(input)
93        .map_err(|e| {
94            format!("VF01: Input {}: output_length compilation failed: {}", input_label, e)
95        })?
96        .ok_or_else(|| {
97            format!(
98                "VF02: Input {}: output_length returned None (not a vector function?)",
99                input_label
100            )
101        })?;
102
103    if output_length < 2 {
104        return Err(format!(
105            "VF03: Input {}: output_length must be > 1 for vector functions, got {}. Try setting `minItems` to 2 in the `input_schema`.",
106            input_label, output_length,
107        ));
108    }
109
110    // 2. Compile input_split
111    let splits = fields
112        .to_function()
113        .compile_input_split(input)
114        .map_err(|e| {
115            format!(
116                "VF04: Input {}: input_split compilation failed: {}",
117                input_label, e
118            )
119        })?
120        .ok_or_else(|| {
121            format!("VF05: Input {}: input_split returned None", input_label)
122        })?;
123
124    if splits.len() as u64 != output_length {
125        return Err(format!(
126            "VF06: Input {}: input_split produced {} elements but output_length is {}",
127            input_label,
128            splits.len(),
129            output_length,
130        ));
131    }
132
133    // 3. Each split must produce output_length = 1
134    for (j, split) in splits.iter().enumerate() {
135        let split_len = fields
136            .to_function()
137            .compile_output_length(split)
138            .map_err(|e| {
139                format!(
140                    "VF07: Input {}: output_length failed for split [{}]: {}",
141                    input_label, j, e
142                )
143            })?
144            .ok_or_else(|| {
145                format!(
146                    "VF08: Input {}: output_length returned None for split [{}]",
147                    input_label, j
148                )
149            })?;
150
151        if split_len != 1 {
152            return Err(format!(
153                "VF09: Input {}: split [{}] output_length must be 1, got {}.\n\nSplit: {}",
154                input_label,
155                j,
156                split_len,
157                serde_json::to_string(split).unwrap_or_default()
158            ));
159        }
160    }
161
162    // 4. Merge all splits — must equal original input
163    let merge_input = InputValue::Array(splits.clone());
164    let merged = fields
165        .to_function()
166        .compile_input_merge(&merge_input)
167        .map_err(|e| {
168            format!(
169                "VF10: Input {}: input_merge compilation failed: {}",
170                input_label, e
171            )
172        })?
173        .ok_or_else(|| {
174            format!("VF11: Input {}: input_merge returned None", input_label)
175        })?;
176
177    if !inputs_equal(input, &merged) {
178        return Err(format!(
179            "VF12: Input {}: merged input does not match original.\n\nOriginal: {}\n\nMerged: {}",
180            input_label,
181            serde_json::to_string(input).unwrap_or_default(),
182            serde_json::to_string(&merged).unwrap_or_default()
183        ));
184    }
185
186    // 5. Merged output_length equals original output_length
187    let merged_len = fields
188        .to_function()
189        .compile_output_length(&merged)
190        .map_err(|e| {
191            format!(
192                "VF13: Input {}: output_length failed for merged input: {}",
193                input_label, e
194            )
195        })?
196        .ok_or_else(|| {
197            format!(
198                "VF14: Input {}: output_length returned None for merged input",
199                input_label
200            )
201        })?;
202
203    if merged_len != output_length {
204        return Err(format!(
205            "VF15: Input {}: merged output_length ({}) != original output_length ({})",
206            input_label, merged_len, output_length
207        ));
208    }
209
210    // 6. Random subsets — merge and verify output_length = subset size
211    //    and merged input satisfies input_schema constraints.
212    let mut subsets = random_subsets(splits.len(), 5, rng);
213    // Always test a 2-element subset deterministically so that
214    // min_items violations are caught reliably.
215    if splits.len() >= 3 {
216        subsets.insert(0, vec![0, 1]);
217    }
218    for subset in &subsets {
219        let sub_splits: Vec<InputValue> =
220            subset.iter().map(|&idx| splits[idx].clone()).collect();
221        let sub_merge_input = InputValue::Array(sub_splits);
222        let sub_merged = fields
223            .to_function()
224            .compile_input_merge(&sub_merge_input)
225            .map_err(|e| {
226                format!(
227                    "VF16: Input {}: input_merge failed for subset {:?}: {}",
228                    input_label, subset, e
229                )
230            })?
231            .ok_or_else(|| {
232                format!(
233                    "VF17: Input {}: input_merge returned None for subset {:?}",
234                    input_label, subset
235                )
236            })?;
237
238        let sub_merged_len = fields
239            .to_function()
240            .compile_output_length(&sub_merged)
241            .map_err(|e| {
242                format!(
243                    "VF18: Input {}: output_length failed for merged subset {:?}: {}",
244                    input_label, subset, e
245                )
246            })?
247            .ok_or_else(|| {
248                format!(
249                    "VF19: Input {}: output_length returned None for merged subset {:?}",
250                    input_label, subset
251                )
252            })?;
253
254        if sub_merged_len as usize != subset.len() {
255            return Err(format!(
256                "VF20: Input {}: merged subset {:?} output_length is {}, expected {}",
257                input_label,
258                subset,
259                sub_merged_len,
260                subset.len()
261            ));
262        }
263
264        // Merged subset must satisfy the input_schema constraints
265        // (e.g., min_items). This ensures the function can execute
266        // correctly with merged sub-inputs (used by swiss_system).
267        validate_input_against_schema(
268            &sub_merged,
269            &fields.input_schema,
270            "root",
271        )
272        .map_err(|e| {
273            format!(
274                "VF21: Input {}: merged subset {:?} violates input_schema: {}",
275                input_label, subset, e
276            )
277        })?;
278    }
279
280    Ok(())
281}
282
283/// Validate that an input satisfies the schema's structural constraints.
284/// Checks array min_items/max_items recursively through objects.
285fn validate_input_against_schema(
286    input: &InputValue,
287    schema: &InputSchema,
288    path: &str,
289) -> Result<(), String> {
290    match (input, schema) {
291        (InputValue::Array(arr), InputSchema::Array(arr_schema)) => {
292            if let Some(min) = arr_schema.min_items {
293                if (arr.len() as u64) < min {
294                    return Err(format!(
295                        "VF23: {}: array has {} items but min_items is {}",
296                        path,
297                        arr.len(),
298                        min
299                    ));
300                }
301            }
302            if let Some(max) = arr_schema.max_items {
303                if (arr.len() as u64) > max {
304                    return Err(format!(
305                        "VF24: {}: array has {} items but max_items is {}",
306                        path,
307                        arr.len(),
308                        max
309                    ));
310                }
311            }
312            for (i, item) in arr.iter().enumerate() {
313                validate_input_against_schema(
314                    item,
315                    &arr_schema.items,
316                    &format!("{}[{}]", path, i),
317                )?;
318            }
319            Ok(())
320        }
321        (InputValue::Object(obj), InputSchema::Object(obj_schema)) => {
322            for (key, prop_schema) in &obj_schema.properties {
323                if let Some(value) = obj.get(key) {
324                    validate_input_against_schema(
325                        value,
326                        prop_schema,
327                        &format!("{}.{}", path, key),
328                    )?;
329                }
330            }
331            Ok(())
332        }
333        _ => Ok(()),
334    }
335}
336
337/// Deep equality check for Input values.
338pub(crate) fn inputs_equal(a: &InputValue, b: &InputValue) -> bool {
339    match (a, b) {
340        (InputValue::String(a), InputValue::String(b)) => a == b,
341        (InputValue::Integer(a), InputValue::Integer(b)) => a == b,
342        (InputValue::Number(a), InputValue::Number(b)) => a == b,
343        (InputValue::Boolean(a), InputValue::Boolean(b)) => a == b,
344        (InputValue::Array(a), InputValue::Array(b)) => {
345            a.len() == b.len()
346                && a.iter().zip(b.iter()).all(|(x, y)| inputs_equal(x, y))
347        }
348        (InputValue::Object(a), InputValue::Object(b)) => {
349            a.len() == b.len()
350                && a.iter().all(|(ka, va)| {
351                    b.get(ka).is_some_and(|vb| inputs_equal(va, vb))
352                })
353        }
354        (InputValue::RichContentPart(a), InputValue::RichContentPart(b)) => {
355            a == b
356        }
357        _ => false,
358    }
359}
360
361/// Generate random subsets of indices for subset merge testing.
362pub(crate) fn random_subsets(
363    length: usize,
364    count: usize,
365    rng: &mut impl Rng,
366) -> Vec<Vec<usize>> {
367    if length < 2 {
368        return vec![];
369    }
370
371    let mut result = Vec::new();
372
373    for _ in 0..count {
374        let size = rng.random_range(2..=length);
375        let mut all_indices: Vec<usize> = (0..length).collect();
376
377        // Fisher-Yates shuffle
378        for i in (1..all_indices.len()).rev() {
379            let j = rng.random_range(0..=i);
380            all_indices.swap(i, j);
381        }
382
383        let mut subset: Vec<usize> =
384            all_indices.into_iter().take(size).collect();
385        subset.sort();
386        subset.dedup();
387
388        if subset.len() >= 2 {
389            result.push(subset);
390        }
391    }
392
393    result
394}