Skip to main content

objectiveai_sdk/functions/check/
check_vector_fields.rs

1//! Validation of vector function fields (output_length, input_split, input_merge).
2//!
3//! Verifies that these expressions work correctly together via round-trip testing
4//! against randomized example inputs.
5
6use rand::Rng;
7use rand::rngs::StdRng;
8use rand::SeedableRng;
9use serde::Deserialize;
10
11use super::check_input_schema::check_input_schema;
12use super::example_inputs;
13use crate::functions::expression::{Expression, InputValue, InputSchema};
14use crate::functions::{Function, RemoteFunction};
15use schemars::JsonSchema;
16
/// The 4 fields needed to validate a vector function's split/merge behavior.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[schemars(rename = "functions.check.VectorFieldsValidation")]
pub struct VectorFieldsValidation {
    // NOTE(review): the field notes below are plain `//` comments on purpose.
    // `///` comments on a `JsonSchema` type are presumably emitted as schema
    // descriptions, which would change the generated schema — confirm before
    // converting them to doc comments.

    // Schema the example inputs are generated from; merged subsets are also
    // checked against its array-size limits.
    pub input_schema: InputSchema,
    // Expression compiled against an input to get the number of outputs:
    // at least 2 for a full input, exactly 1 for each split element.
    pub output_length: Expression,
    // Expression compiled against an input to get its split elements; their
    // count must equal `output_length`.
    pub input_split: Expression,
    // Expression compiled against an array of split elements to rebuild a
    // single input from them.
    pub input_merge: Expression,
}
26
27impl VectorFieldsValidation {
28    /// Construct a minimal `Function` for compilation, cloning our expressions.
29    fn to_function(&self) -> Function {
30        Function::Remote(RemoteFunction::Vector {
31            description: String::new(),
32            input_schema: self.input_schema.clone(),
33            tasks: vec![],
34            output_length: self.output_length.clone(),
35            input_split: self.input_split.clone(),
36            input_merge: self.input_merge.clone(),
37        })
38    }
39}
40
41/// Validate that the vector fields work together correctly.
42///
43/// Generates diverse, randomized example inputs from the `input_schema`, then
44/// validates each one via [`check_vector_fields_for_input`].
45pub fn check_vector_fields(
46    fields: VectorFieldsValidation,
47    seed: Option<i64>,
48) -> Result<(), String> {
49    // Input schema permutations
50    check_input_schema(&fields.input_schema)?;
51
52    let mut rng = match seed {
53        Some(s) => StdRng::seed_from_u64(s as u64),
54        None => StdRng::from_os_rng(),
55    };
56
57    let mut count = 0usize;
58    for ref input in example_inputs::generate_seeded(&fields.input_schema, StdRng::seed_from_u64(rng.random::<u64>())) {
59        count += 1;
60        let input_label = serde_json::to_string(input).unwrap_or_default();
61        check_vector_fields_for_input(&fields, &input_label, input, &mut rng)?;
62    }
63
64    if count == 0 {
65        return Err(
66            "VF22: Failed to generate any example inputs from input_schema"
67                .to_string(),
68        );
69    }
70
71    Ok(())
72}
73
74/// Validates vector fields for a single input:
75/// 1. Compiles `output_length` — must be > 0
76/// 2. Compiles `input_split` — length must equal output_length
77/// 3. Each split element must produce output_length = 1
78/// 4. Merging all splits must reconstruct the original input
79/// 5. Merging random subsets must produce output_length = subset size
80pub(crate) fn check_vector_fields_for_input(
81    fields: &VectorFieldsValidation,
82    input_label: &str,
83    input: &InputValue,
84    rng: &mut impl Rng,
85) -> Result<(), String> {
86    // 1. Compile output_length
87    let output_length = fields
88        .to_function()
89        .compile_output_length(input)
90        .map_err(|e| {
91            format!("VF01: Input {}: output_length compilation failed: {}", input_label, e)
92        })?
93        .ok_or_else(|| {
94            format!(
95                "VF02: Input {}: output_length returned None (not a vector function?)",
96                input_label
97            )
98        })?;
99
100    if output_length < 2 {
101        return Err(format!(
102            "VF03: Input {}: output_length must be > 1 for vector functions, got {}. Try setting `minItems` to 2 in the `input_schema`.",
103            input_label, output_length,
104        ));
105    }
106
107    // 2. Compile input_split
108    let splits = fields
109        .to_function()
110        .compile_input_split(input)
111        .map_err(|e| {
112            format!(
113                "VF04: Input {}: input_split compilation failed: {}",
114                input_label, e
115            )
116        })?
117        .ok_or_else(|| {
118            format!("VF05: Input {}: input_split returned None", input_label)
119        })?;
120
121    if splits.len() as u64 != output_length {
122        return Err(format!(
123            "VF06: Input {}: input_split produced {} elements but output_length is {}",
124            input_label,
125            splits.len(),
126            output_length,
127        ));
128    }
129
130    // 3. Each split must produce output_length = 1
131    for (j, split) in splits.iter().enumerate() {
132        let split_len = fields
133            .to_function()
134            .compile_output_length(split)
135            .map_err(|e| {
136                format!(
137                    "VF07: Input {}: output_length failed for split [{}]: {}",
138                    input_label, j, e
139                )
140            })?
141            .ok_or_else(|| {
142                format!(
143                    "VF08: Input {}: output_length returned None for split [{}]",
144                    input_label, j
145                )
146            })?;
147
148        if split_len != 1 {
149            return Err(format!(
150                "VF09: Input {}: split [{}] output_length must be 1, got {}.\n\nSplit: {}",
151                input_label,
152                j,
153                split_len,
154                serde_json::to_string(split).unwrap_or_default()
155            ));
156        }
157    }
158
159    // 4. Merge all splits — must equal original input
160    let merge_input = InputValue::Array(splits.clone());
161    let merged = fields
162        .to_function()
163        .compile_input_merge(&merge_input)
164        .map_err(|e| {
165            format!(
166                "VF10: Input {}: input_merge compilation failed: {}",
167                input_label, e
168            )
169        })?
170        .ok_or_else(|| {
171            format!("VF11: Input {}: input_merge returned None", input_label)
172        })?;
173
174    if !inputs_equal(input, &merged) {
175        return Err(format!(
176            "VF12: Input {}: merged input does not match original.\n\nOriginal: {}\n\nMerged: {}",
177            input_label,
178            serde_json::to_string(input).unwrap_or_default(),
179            serde_json::to_string(&merged).unwrap_or_default()
180        ));
181    }
182
183    // 5. Merged output_length equals original output_length
184    let merged_len = fields
185        .to_function()
186        .compile_output_length(&merged)
187        .map_err(|e| {
188            format!(
189                "VF13: Input {}: output_length failed for merged input: {}",
190                input_label, e
191            )
192        })?
193        .ok_or_else(|| {
194            format!(
195                "VF14: Input {}: output_length returned None for merged input",
196                input_label
197            )
198        })?;
199
200    if merged_len != output_length {
201        return Err(format!(
202            "VF15: Input {}: merged output_length ({}) != original output_length ({})",
203            input_label, merged_len, output_length
204        ));
205    }
206
207    // 6. Random subsets — merge and verify output_length = subset size
208    //    and merged input satisfies input_schema constraints.
209    let mut subsets = random_subsets(splits.len(), 5, rng);
210    // Always test a 2-element subset deterministically so that
211    // min_items violations are caught reliably.
212    if splits.len() >= 3 {
213        subsets.insert(0, vec![0, 1]);
214    }
215    for subset in &subsets {
216        let sub_splits: Vec<InputValue> =
217            subset.iter().map(|&idx| splits[idx].clone()).collect();
218        let sub_merge_input = InputValue::Array(sub_splits);
219        let sub_merged = fields
220            .to_function()
221            .compile_input_merge(&sub_merge_input)
222            .map_err(|e| {
223                format!(
224                    "VF16: Input {}: input_merge failed for subset {:?}: {}",
225                    input_label, subset, e
226                )
227            })?
228            .ok_or_else(|| {
229                format!(
230                    "VF17: Input {}: input_merge returned None for subset {:?}",
231                    input_label, subset
232                )
233            })?;
234
235        let sub_merged_len = fields
236            .to_function()
237            .compile_output_length(&sub_merged)
238            .map_err(|e| {
239                format!(
240                    "VF18: Input {}: output_length failed for merged subset {:?}: {}",
241                    input_label, subset, e
242                )
243            })?
244            .ok_or_else(|| {
245                format!(
246                    "VF19: Input {}: output_length returned None for merged subset {:?}",
247                    input_label, subset
248                )
249            })?;
250
251        if sub_merged_len as usize != subset.len() {
252            return Err(format!(
253                "VF20: Input {}: merged subset {:?} output_length is {}, expected {}",
254                input_label,
255                subset,
256                sub_merged_len,
257                subset.len()
258            ));
259        }
260
261        // Merged subset must satisfy the input_schema constraints
262        // (e.g., min_items). This ensures the function can execute
263        // correctly with merged sub-inputs (used by swiss_system).
264        validate_input_against_schema(
265            &sub_merged,
266            &fields.input_schema,
267            "root",
268        )
269        .map_err(|e| {
270            format!(
271                "VF21: Input {}: merged subset {:?} violates input_schema: {}",
272                input_label, subset, e
273            )
274        })?;
275    }
276
277    Ok(())
278}
279
280/// Validate that an input satisfies the schema's structural constraints.
281/// Checks array min_items/max_items recursively through objects.
282fn validate_input_against_schema(
283    input: &InputValue,
284    schema: &InputSchema,
285    path: &str,
286) -> Result<(), String> {
287    match (input, schema) {
288        (InputValue::Array(arr), InputSchema::Array(arr_schema)) => {
289            if let Some(min) = arr_schema.min_items {
290                if (arr.len() as u64) < min {
291                    return Err(format!(
292                        "VF23: {}: array has {} items but min_items is {}",
293                        path,
294                        arr.len(),
295                        min
296                    ));
297                }
298            }
299            if let Some(max) = arr_schema.max_items {
300                if (arr.len() as u64) > max {
301                    return Err(format!(
302                        "VF24: {}: array has {} items but max_items is {}",
303                        path,
304                        arr.len(),
305                        max
306                    ));
307                }
308            }
309            for (i, item) in arr.iter().enumerate() {
310                validate_input_against_schema(
311                    item,
312                    &arr_schema.items,
313                    &format!("{}[{}]", path, i),
314                )?;
315            }
316            Ok(())
317        }
318        (InputValue::Object(obj), InputSchema::Object(obj_schema)) => {
319            for (key, prop_schema) in &obj_schema.properties {
320                if let Some(value) = obj.get(key) {
321                    validate_input_against_schema(
322                        value,
323                        prop_schema,
324                        &format!("{}.{}", path, key),
325                    )?;
326                }
327            }
328            Ok(())
329        }
330        _ => Ok(()),
331    }
332}
333
334/// Deep equality check for Input values.
335pub(crate) fn inputs_equal(a: &InputValue, b: &InputValue) -> bool {
336    match (a, b) {
337        (InputValue::String(a), InputValue::String(b)) => a == b,
338        (InputValue::Integer(a), InputValue::Integer(b)) => a == b,
339        (InputValue::Number(a), InputValue::Number(b)) => a == b,
340        (InputValue::Boolean(a), InputValue::Boolean(b)) => a == b,
341        (InputValue::Array(a), InputValue::Array(b)) => {
342            a.len() == b.len()
343                && a.iter().zip(b.iter()).all(|(x, y)| inputs_equal(x, y))
344        }
345        (InputValue::Object(a), InputValue::Object(b)) => {
346            a.len() == b.len()
347                && a.iter().all(|(ka, va)| {
348                    b.get(ka).is_some_and(|vb| inputs_equal(va, vb))
349                })
350        }
351        (InputValue::RichContentPart(a), InputValue::RichContentPart(b)) => a == b,
352        _ => false,
353    }
354}
355
356/// Generate random subsets of indices for subset merge testing.
357pub(crate) fn random_subsets(length: usize, count: usize, rng: &mut impl Rng) -> Vec<Vec<usize>> {
358    if length < 2 {
359        return vec![];
360    }
361
362    let mut result = Vec::new();
363
364    for _ in 0..count {
365        let size = rng.random_range(2..=length);
366        let mut all_indices: Vec<usize> = (0..length).collect();
367
368        // Fisher-Yates shuffle
369        for i in (1..all_indices.len()).rev() {
370            let j = rng.random_range(0..=i);
371            all_indices.swap(i, j);
372        }
373
374        let mut subset: Vec<usize> =
375            all_indices.into_iter().take(size).collect();
376        subset.sort();
377        subset.dedup();
378
379        if subset.len() >= 2 {
380            result.push(subset);
381        }
382    }
383
384    result
385}