objectiveai_sdk/functions/check/
check_vector_fields.rs1use rand::Rng;
7use rand::SeedableRng;
8use rand::rngs::StdRng;
9use serde::Deserialize;
10
11use super::check_input_schema::check_input_schema;
12use super::example_inputs;
13use crate::functions::expression::{Expression, InputSchema, InputValue};
14use crate::functions::{Function, RemoteFunction};
15use schemars::JsonSchema;
16
17#[derive(Debug, Clone, Deserialize, JsonSchema)]
19#[schemars(rename = "functions.check.VectorFieldsValidation")]
20pub struct VectorFieldsValidation {
21 pub input_schema: InputSchema,
22 pub output_length: Expression,
23 pub input_split: Expression,
24 pub input_merge: Expression,
25}
26
27impl VectorFieldsValidation {
28 fn to_function(&self) -> Function {
30 Function::Remote(RemoteFunction::Vector {
31 description: String::new(),
32 input_schema: self.input_schema.clone(),
33 tasks: vec![],
34 output_length: self.output_length.clone(),
35 input_split: self.input_split.clone(),
36 input_merge: self.input_merge.clone(),
37 })
38 }
39}
40
41pub fn check_vector_fields(
46 fields: VectorFieldsValidation,
47 seed: Option<i64>,
48) -> Result<(), String> {
49 check_input_schema(&fields.input_schema)?;
51
52 let mut rng = match seed {
53 Some(s) => StdRng::seed_from_u64(s as u64),
54 None => StdRng::from_os_rng(),
55 };
56
57 let mut count = 0usize;
58 for ref input in example_inputs::generate_seeded(
59 &fields.input_schema,
60 StdRng::seed_from_u64(rng.random::<u64>()),
61 ) {
62 count += 1;
63 let input_label = serde_json::to_string(input).unwrap_or_default();
64 check_vector_fields_for_input(&fields, &input_label, input, &mut rng)?;
65 }
66
67 if count == 0 {
68 return Err(
69 "VF22: Failed to generate any example inputs from input_schema"
70 .to_string(),
71 );
72 }
73
74 Ok(())
75}
76
77pub(crate) fn check_vector_fields_for_input(
84 fields: &VectorFieldsValidation,
85 input_label: &str,
86 input: &InputValue,
87 rng: &mut impl Rng,
88) -> Result<(), String> {
89 let output_length = fields
91 .to_function()
92 .compile_output_length(input)
93 .map_err(|e| {
94 format!("VF01: Input {}: output_length compilation failed: {}", input_label, e)
95 })?
96 .ok_or_else(|| {
97 format!(
98 "VF02: Input {}: output_length returned None (not a vector function?)",
99 input_label
100 )
101 })?;
102
103 if output_length < 2 {
104 return Err(format!(
105 "VF03: Input {}: output_length must be > 1 for vector functions, got {}. Try setting `minItems` to 2 in the `input_schema`.",
106 input_label, output_length,
107 ));
108 }
109
110 let splits = fields
112 .to_function()
113 .compile_input_split(input)
114 .map_err(|e| {
115 format!(
116 "VF04: Input {}: input_split compilation failed: {}",
117 input_label, e
118 )
119 })?
120 .ok_or_else(|| {
121 format!("VF05: Input {}: input_split returned None", input_label)
122 })?;
123
124 if splits.len() as u64 != output_length {
125 return Err(format!(
126 "VF06: Input {}: input_split produced {} elements but output_length is {}",
127 input_label,
128 splits.len(),
129 output_length,
130 ));
131 }
132
133 for (j, split) in splits.iter().enumerate() {
135 let split_len = fields
136 .to_function()
137 .compile_output_length(split)
138 .map_err(|e| {
139 format!(
140 "VF07: Input {}: output_length failed for split [{}]: {}",
141 input_label, j, e
142 )
143 })?
144 .ok_or_else(|| {
145 format!(
146 "VF08: Input {}: output_length returned None for split [{}]",
147 input_label, j
148 )
149 })?;
150
151 if split_len != 1 {
152 return Err(format!(
153 "VF09: Input {}: split [{}] output_length must be 1, got {}.\n\nSplit: {}",
154 input_label,
155 j,
156 split_len,
157 serde_json::to_string(split).unwrap_or_default()
158 ));
159 }
160 }
161
162 let merge_input = InputValue::Array(splits.clone());
164 let merged = fields
165 .to_function()
166 .compile_input_merge(&merge_input)
167 .map_err(|e| {
168 format!(
169 "VF10: Input {}: input_merge compilation failed: {}",
170 input_label, e
171 )
172 })?
173 .ok_or_else(|| {
174 format!("VF11: Input {}: input_merge returned None", input_label)
175 })?;
176
177 if !inputs_equal(input, &merged) {
178 return Err(format!(
179 "VF12: Input {}: merged input does not match original.\n\nOriginal: {}\n\nMerged: {}",
180 input_label,
181 serde_json::to_string(input).unwrap_or_default(),
182 serde_json::to_string(&merged).unwrap_or_default()
183 ));
184 }
185
186 let merged_len = fields
188 .to_function()
189 .compile_output_length(&merged)
190 .map_err(|e| {
191 format!(
192 "VF13: Input {}: output_length failed for merged input: {}",
193 input_label, e
194 )
195 })?
196 .ok_or_else(|| {
197 format!(
198 "VF14: Input {}: output_length returned None for merged input",
199 input_label
200 )
201 })?;
202
203 if merged_len != output_length {
204 return Err(format!(
205 "VF15: Input {}: merged output_length ({}) != original output_length ({})",
206 input_label, merged_len, output_length
207 ));
208 }
209
210 let mut subsets = random_subsets(splits.len(), 5, rng);
213 if splits.len() >= 3 {
216 subsets.insert(0, vec![0, 1]);
217 }
218 for subset in &subsets {
219 let sub_splits: Vec<InputValue> =
220 subset.iter().map(|&idx| splits[idx].clone()).collect();
221 let sub_merge_input = InputValue::Array(sub_splits);
222 let sub_merged = fields
223 .to_function()
224 .compile_input_merge(&sub_merge_input)
225 .map_err(|e| {
226 format!(
227 "VF16: Input {}: input_merge failed for subset {:?}: {}",
228 input_label, subset, e
229 )
230 })?
231 .ok_or_else(|| {
232 format!(
233 "VF17: Input {}: input_merge returned None for subset {:?}",
234 input_label, subset
235 )
236 })?;
237
238 let sub_merged_len = fields
239 .to_function()
240 .compile_output_length(&sub_merged)
241 .map_err(|e| {
242 format!(
243 "VF18: Input {}: output_length failed for merged subset {:?}: {}",
244 input_label, subset, e
245 )
246 })?
247 .ok_or_else(|| {
248 format!(
249 "VF19: Input {}: output_length returned None for merged subset {:?}",
250 input_label, subset
251 )
252 })?;
253
254 if sub_merged_len as usize != subset.len() {
255 return Err(format!(
256 "VF20: Input {}: merged subset {:?} output_length is {}, expected {}",
257 input_label,
258 subset,
259 sub_merged_len,
260 subset.len()
261 ));
262 }
263
264 validate_input_against_schema(
268 &sub_merged,
269 &fields.input_schema,
270 "root",
271 )
272 .map_err(|e| {
273 format!(
274 "VF21: Input {}: merged subset {:?} violates input_schema: {}",
275 input_label, subset, e
276 )
277 })?;
278 }
279
280 Ok(())
281}
282
283fn validate_input_against_schema(
286 input: &InputValue,
287 schema: &InputSchema,
288 path: &str,
289) -> Result<(), String> {
290 match (input, schema) {
291 (InputValue::Array(arr), InputSchema::Array(arr_schema)) => {
292 if let Some(min) = arr_schema.min_items {
293 if (arr.len() as u64) < min {
294 return Err(format!(
295 "VF23: {}: array has {} items but min_items is {}",
296 path,
297 arr.len(),
298 min
299 ));
300 }
301 }
302 if let Some(max) = arr_schema.max_items {
303 if (arr.len() as u64) > max {
304 return Err(format!(
305 "VF24: {}: array has {} items but max_items is {}",
306 path,
307 arr.len(),
308 max
309 ));
310 }
311 }
312 for (i, item) in arr.iter().enumerate() {
313 validate_input_against_schema(
314 item,
315 &arr_schema.items,
316 &format!("{}[{}]", path, i),
317 )?;
318 }
319 Ok(())
320 }
321 (InputValue::Object(obj), InputSchema::Object(obj_schema)) => {
322 for (key, prop_schema) in &obj_schema.properties {
323 if let Some(value) = obj.get(key) {
324 validate_input_against_schema(
325 value,
326 prop_schema,
327 &format!("{}.{}", path, key),
328 )?;
329 }
330 }
331 Ok(())
332 }
333 _ => Ok(()),
334 }
335}
336
337pub(crate) fn inputs_equal(a: &InputValue, b: &InputValue) -> bool {
339 match (a, b) {
340 (InputValue::String(a), InputValue::String(b)) => a == b,
341 (InputValue::Integer(a), InputValue::Integer(b)) => a == b,
342 (InputValue::Number(a), InputValue::Number(b)) => a == b,
343 (InputValue::Boolean(a), InputValue::Boolean(b)) => a == b,
344 (InputValue::Array(a), InputValue::Array(b)) => {
345 a.len() == b.len()
346 && a.iter().zip(b.iter()).all(|(x, y)| inputs_equal(x, y))
347 }
348 (InputValue::Object(a), InputValue::Object(b)) => {
349 a.len() == b.len()
350 && a.iter().all(|(ka, va)| {
351 b.get(ka).is_some_and(|vb| inputs_equal(va, vb))
352 })
353 }
354 (InputValue::RichContentPart(a), InputValue::RichContentPart(b)) => {
355 a == b
356 }
357 _ => false,
358 }
359}
360
361pub(crate) fn random_subsets(
363 length: usize,
364 count: usize,
365 rng: &mut impl Rng,
366) -> Vec<Vec<usize>> {
367 if length < 2 {
368 return vec![];
369 }
370
371 let mut result = Vec::new();
372
373 for _ in 0..count {
374 let size = rng.random_range(2..=length);
375 let mut all_indices: Vec<usize> = (0..length).collect();
376
377 for i in (1..all_indices.len()).rev() {
379 let j = rng.random_range(0..=i);
380 all_indices.swap(i, j);
381 }
382
383 let mut subset: Vec<usize> =
384 all_indices.into_iter().take(size).collect();
385 subset.sort();
386 subset.dedup();
387
388 if subset.len() >= 2 {
389 result.push(subset);
390 }
391 }
392
393 result
394}