objectiveai_sdk/functions/check/
check_vector_fields.rs1use rand::Rng;
7use rand::rngs::StdRng;
8use rand::SeedableRng;
9use serde::Deserialize;
10
11use super::check_input_schema::check_input_schema;
12use super::example_inputs;
13use crate::functions::expression::{Expression, InputValue, InputSchema};
14use crate::functions::{Function, RemoteFunction};
15use schemars::JsonSchema;
16
// The vector-function fields that `check_vector_fields` validates: an input
// schema plus the three expressions that `to_function` forwards into
// `RemoteFunction::Vector`.
//
// NOTE: these are deliberately plain `//` comments. `///` docs on a type that
// derives `JsonSchema` are picked up by schemars as schema descriptions, so
// adding them here would change the generated schema.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
#[schemars(rename = "functions.check.VectorFieldsValidation")]
pub struct VectorFieldsValidation {
    // Schema of the function input; checked with `check_input_schema` and used
    // to generate the example inputs the other fields are exercised against.
    pub input_schema: InputSchema,
    // Expression compiled against an input via `compile_output_length`.
    pub output_length: Expression,
    // Expression compiled against an input via `compile_input_split`.
    pub input_split: Expression,
    // Expression compiled against an array of splits via `compile_input_merge`.
    pub input_merge: Expression,
}
26
27impl VectorFieldsValidation {
28 fn to_function(&self) -> Function {
30 Function::Remote(RemoteFunction::Vector {
31 description: String::new(),
32 input_schema: self.input_schema.clone(),
33 tasks: vec![],
34 output_length: self.output_length.clone(),
35 input_split: self.input_split.clone(),
36 input_merge: self.input_merge.clone(),
37 })
38 }
39}
40
41pub fn check_vector_fields(
46 fields: VectorFieldsValidation,
47 seed: Option<i64>,
48) -> Result<(), String> {
49 check_input_schema(&fields.input_schema)?;
51
52 let mut rng = match seed {
53 Some(s) => StdRng::seed_from_u64(s as u64),
54 None => StdRng::from_os_rng(),
55 };
56
57 let mut count = 0usize;
58 for ref input in example_inputs::generate_seeded(&fields.input_schema, StdRng::seed_from_u64(rng.random::<u64>())) {
59 count += 1;
60 let input_label = serde_json::to_string(input).unwrap_or_default();
61 check_vector_fields_for_input(&fields, &input_label, input, &mut rng)?;
62 }
63
64 if count == 0 {
65 return Err(
66 "VF22: Failed to generate any example inputs from input_schema"
67 .to_string(),
68 );
69 }
70
71 Ok(())
72}
73
74pub(crate) fn check_vector_fields_for_input(
81 fields: &VectorFieldsValidation,
82 input_label: &str,
83 input: &InputValue,
84 rng: &mut impl Rng,
85) -> Result<(), String> {
86 let output_length = fields
88 .to_function()
89 .compile_output_length(input)
90 .map_err(|e| {
91 format!("VF01: Input {}: output_length compilation failed: {}", input_label, e)
92 })?
93 .ok_or_else(|| {
94 format!(
95 "VF02: Input {}: output_length returned None (not a vector function?)",
96 input_label
97 )
98 })?;
99
100 if output_length < 2 {
101 return Err(format!(
102 "VF03: Input {}: output_length must be > 1 for vector functions, got {}. Try setting `minItems` to 2 in the `input_schema`.",
103 input_label, output_length,
104 ));
105 }
106
107 let splits = fields
109 .to_function()
110 .compile_input_split(input)
111 .map_err(|e| {
112 format!(
113 "VF04: Input {}: input_split compilation failed: {}",
114 input_label, e
115 )
116 })?
117 .ok_or_else(|| {
118 format!("VF05: Input {}: input_split returned None", input_label)
119 })?;
120
121 if splits.len() as u64 != output_length {
122 return Err(format!(
123 "VF06: Input {}: input_split produced {} elements but output_length is {}",
124 input_label,
125 splits.len(),
126 output_length,
127 ));
128 }
129
130 for (j, split) in splits.iter().enumerate() {
132 let split_len = fields
133 .to_function()
134 .compile_output_length(split)
135 .map_err(|e| {
136 format!(
137 "VF07: Input {}: output_length failed for split [{}]: {}",
138 input_label, j, e
139 )
140 })?
141 .ok_or_else(|| {
142 format!(
143 "VF08: Input {}: output_length returned None for split [{}]",
144 input_label, j
145 )
146 })?;
147
148 if split_len != 1 {
149 return Err(format!(
150 "VF09: Input {}: split [{}] output_length must be 1, got {}.\n\nSplit: {}",
151 input_label,
152 j,
153 split_len,
154 serde_json::to_string(split).unwrap_or_default()
155 ));
156 }
157 }
158
159 let merge_input = InputValue::Array(splits.clone());
161 let merged = fields
162 .to_function()
163 .compile_input_merge(&merge_input)
164 .map_err(|e| {
165 format!(
166 "VF10: Input {}: input_merge compilation failed: {}",
167 input_label, e
168 )
169 })?
170 .ok_or_else(|| {
171 format!("VF11: Input {}: input_merge returned None", input_label)
172 })?;
173
174 if !inputs_equal(input, &merged) {
175 return Err(format!(
176 "VF12: Input {}: merged input does not match original.\n\nOriginal: {}\n\nMerged: {}",
177 input_label,
178 serde_json::to_string(input).unwrap_or_default(),
179 serde_json::to_string(&merged).unwrap_or_default()
180 ));
181 }
182
183 let merged_len = fields
185 .to_function()
186 .compile_output_length(&merged)
187 .map_err(|e| {
188 format!(
189 "VF13: Input {}: output_length failed for merged input: {}",
190 input_label, e
191 )
192 })?
193 .ok_or_else(|| {
194 format!(
195 "VF14: Input {}: output_length returned None for merged input",
196 input_label
197 )
198 })?;
199
200 if merged_len != output_length {
201 return Err(format!(
202 "VF15: Input {}: merged output_length ({}) != original output_length ({})",
203 input_label, merged_len, output_length
204 ));
205 }
206
207 let mut subsets = random_subsets(splits.len(), 5, rng);
210 if splits.len() >= 3 {
213 subsets.insert(0, vec![0, 1]);
214 }
215 for subset in &subsets {
216 let sub_splits: Vec<InputValue> =
217 subset.iter().map(|&idx| splits[idx].clone()).collect();
218 let sub_merge_input = InputValue::Array(sub_splits);
219 let sub_merged = fields
220 .to_function()
221 .compile_input_merge(&sub_merge_input)
222 .map_err(|e| {
223 format!(
224 "VF16: Input {}: input_merge failed for subset {:?}: {}",
225 input_label, subset, e
226 )
227 })?
228 .ok_or_else(|| {
229 format!(
230 "VF17: Input {}: input_merge returned None for subset {:?}",
231 input_label, subset
232 )
233 })?;
234
235 let sub_merged_len = fields
236 .to_function()
237 .compile_output_length(&sub_merged)
238 .map_err(|e| {
239 format!(
240 "VF18: Input {}: output_length failed for merged subset {:?}: {}",
241 input_label, subset, e
242 )
243 })?
244 .ok_or_else(|| {
245 format!(
246 "VF19: Input {}: output_length returned None for merged subset {:?}",
247 input_label, subset
248 )
249 })?;
250
251 if sub_merged_len as usize != subset.len() {
252 return Err(format!(
253 "VF20: Input {}: merged subset {:?} output_length is {}, expected {}",
254 input_label,
255 subset,
256 sub_merged_len,
257 subset.len()
258 ));
259 }
260
261 validate_input_against_schema(
265 &sub_merged,
266 &fields.input_schema,
267 "root",
268 )
269 .map_err(|e| {
270 format!(
271 "VF21: Input {}: merged subset {:?} violates input_schema: {}",
272 input_label, subset, e
273 )
274 })?;
275 }
276
277 Ok(())
278}
279
280fn validate_input_against_schema(
283 input: &InputValue,
284 schema: &InputSchema,
285 path: &str,
286) -> Result<(), String> {
287 match (input, schema) {
288 (InputValue::Array(arr), InputSchema::Array(arr_schema)) => {
289 if let Some(min) = arr_schema.min_items {
290 if (arr.len() as u64) < min {
291 return Err(format!(
292 "VF23: {}: array has {} items but min_items is {}",
293 path,
294 arr.len(),
295 min
296 ));
297 }
298 }
299 if let Some(max) = arr_schema.max_items {
300 if (arr.len() as u64) > max {
301 return Err(format!(
302 "VF24: {}: array has {} items but max_items is {}",
303 path,
304 arr.len(),
305 max
306 ));
307 }
308 }
309 for (i, item) in arr.iter().enumerate() {
310 validate_input_against_schema(
311 item,
312 &arr_schema.items,
313 &format!("{}[{}]", path, i),
314 )?;
315 }
316 Ok(())
317 }
318 (InputValue::Object(obj), InputSchema::Object(obj_schema)) => {
319 for (key, prop_schema) in &obj_schema.properties {
320 if let Some(value) = obj.get(key) {
321 validate_input_against_schema(
322 value,
323 prop_schema,
324 &format!("{}.{}", path, key),
325 )?;
326 }
327 }
328 Ok(())
329 }
330 _ => Ok(()),
331 }
332}
333
334pub(crate) fn inputs_equal(a: &InputValue, b: &InputValue) -> bool {
336 match (a, b) {
337 (InputValue::String(a), InputValue::String(b)) => a == b,
338 (InputValue::Integer(a), InputValue::Integer(b)) => a == b,
339 (InputValue::Number(a), InputValue::Number(b)) => a == b,
340 (InputValue::Boolean(a), InputValue::Boolean(b)) => a == b,
341 (InputValue::Array(a), InputValue::Array(b)) => {
342 a.len() == b.len()
343 && a.iter().zip(b.iter()).all(|(x, y)| inputs_equal(x, y))
344 }
345 (InputValue::Object(a), InputValue::Object(b)) => {
346 a.len() == b.len()
347 && a.iter().all(|(ka, va)| {
348 b.get(ka).is_some_and(|vb| inputs_equal(va, vb))
349 })
350 }
351 (InputValue::RichContentPart(a), InputValue::RichContentPart(b)) => a == b,
352 _ => false,
353 }
354}
355
356pub(crate) fn random_subsets(length: usize, count: usize, rng: &mut impl Rng) -> Vec<Vec<usize>> {
358 if length < 2 {
359 return vec![];
360 }
361
362 let mut result = Vec::new();
363
364 for _ in 0..count {
365 let size = rng.random_range(2..=length);
366 let mut all_indices: Vec<usize> = (0..length).collect();
367
368 for i in (1..all_indices.len()).rev() {
370 let j = rng.random_range(0..=i);
371 all_indices.swap(i, j);
372 }
373
374 let mut subset: Vec<usize> =
375 all_indices.into_iter().take(size).collect();
376 subset.sort();
377 subset.dedup();
378
379 if subset.len() >= 2 {
380 result.push(subset);
381 }
382 }
383
384 result
385}