use std::collections::{HashMap, HashSet};
use rand::Rng;
use rand::rngs::StdRng;
use rand::SeedableRng;
use crate::functions::alpha_vector::RemoteFunction;
use crate::functions::expression::InputValue;
use crate::functions::{CompiledTask, Function, Task};
use crate::functions::check::check_description;
use crate::functions::check::check_input_schema;
use crate::functions::check::{
ModalityFlags, check_modality_coverage, collect_schema_modalities,
collect_task_modalities,
};
use crate::functions::check::{
VectorOutputShape, check_vector_distribution,
};
use crate::functions::check::{
VectorFieldsValidation, check_vector_fields_for_input, random_subsets,
};
use crate::functions::check::compile_and_validate_one_input;
use crate::functions::check::example_inputs;
pub fn check_alpha_leaf_vector_function(
function: &RemoteFunction,
seed: Option<i64>,
) -> Result<(), String> {
let (description, input_schema, _tasks) = match function {
RemoteFunction::Leaf {
description,
input_schema,
tasks,
} => (description, input_schema, tasks),
RemoteFunction::Branch { .. } => {
return Err(
"AV01: Expected alpha.vector.leaf.function, got alpha.vector.branch.function"
.to_string(),
);
}
};
check_description(description)?;
let transpiled_input_schema = input_schema.clone().transpile();
check_input_schema(&transpiled_input_schema)?;
if _tasks.is_empty() {
return Err(
"AV03: Functions must have at least one task".to_string(),
);
}
let transpiled = function.clone().transpile();
let (
transpiled_input_schema_ref,
transpiled_output_length,
transpiled_input_split,
transpiled_input_merge,
) = match &transpiled {
crate::functions::RemoteFunction::Vector {
input_schema,
output_length,
input_split,
input_merge,
..
} => (input_schema, output_length, input_split, input_merge),
_ => unreachable!(),
};
let vector_fields = VectorFieldsValidation {
input_schema: transpiled_input_schema_ref.clone(),
output_length: transpiled_output_length.clone(),
input_split: transpiled_input_split.clone(),
input_merge: transpiled_input_merge.clone(),
};
let func_template = Function::Remote(transpiled.clone());
let task_count = _tasks.len();
let mut per_task_indexed: Vec<HashMap<usize, (usize, HashSet<String>)>> =
vec![HashMap::new(); task_count];
let mut per_task_has_varying = vec![false; task_count];
let mut per_task_skipped = vec![false; task_count];
let mut seen_dist_tasks: HashSet<(usize, usize)> = HashSet::new();
let mut count = 0usize;
let mut schema_modalities: ModalityFlags = [false; 4];
collect_schema_modalities(transpiled_input_schema_ref, &mut schema_modalities);
let mut task_modalities: ModalityFlags = [false; 4];
let mut rng = match seed {
Some(s) => StdRng::seed_from_u64(s as u64),
None => StdRng::from_os_rng(),
};
for ref input in example_inputs::generate_seeded(transpiled_input_schema_ref, StdRng::seed_from_u64(rng.random::<u64>())) {
count += 1;
let input_label = serde_json::to_string(input).unwrap_or_default();
check_vector_fields_for_input(&vector_fields, &input_label, input, &mut rng)?;
let compiled_tasks = compile_and_validate_one_input(
&input_label,
&transpiled,
input,
None,
)?;
for (j, compiled_task) in compiled_tasks.iter().enumerate() {
if let Some(CompiledTask::One(Task::VectorCompletion(vc))) =
compiled_task
{
let key = (j, vc.responses.len());
if seen_dist_tasks.insert(key) {
let ol = func_template
.clone()
.compile_output_length(input)
.ok()
.flatten()
.unwrap_or(0) as usize;
check_vector_distribution(
j,
input,
&Task::VectorCompletion(vc.clone()),
&VectorOutputShape::VectorCompletion(
vc.responses.len(),
),
ol,
)?;
}
}
}
for (j, compiled_task) in compiled_tasks.iter().enumerate() {
let Some(compiled_task) = compiled_task else {
per_task_skipped[j] = true;
continue;
};
if let CompiledTask::One(Task::VectorCompletion(vc)) =
compiled_task
{
collect_task_modalities(vc, &mut task_modalities);
for (ri, response) in vc.responses.iter().enumerate() {
let key =
serde_json::to_string(response).unwrap_or_default();
let entry = per_task_indexed[j]
.entry(ri)
.or_insert_with(|| (0, HashSet::new()));
entry.0 += 1;
entry.1.insert(key);
}
if !per_task_has_varying[j] && vc.responses.len() >= 2 {
let first = serde_json::to_string(&vc.responses[0])
.unwrap_or_default();
let has_different = vc.responses[1..].iter().any(|r| {
serde_json::to_string(r).unwrap_or_default() != first
});
if has_different {
per_task_has_varying[j] = true;
}
}
}
}
let splits = func_template
.clone()
.compile_input_split(input)
.map_err(|e| {
format!(
"AV09: Merged input validation, input {}: input_split failed: {}",
input_label, e
)
})?
.ok_or_else(|| {
format!(
"AV10: Merged input validation, input {}: input_split returned None",
input_label
)
})?;
if splits.len() >= 2 {
let subsets = random_subsets(splits.len(), 3, &mut rng);
for subset in &subsets {
let sub_splits: Vec<InputValue> =
subset.iter().map(|&idx| splits[idx].clone()).collect();
let merge_input = InputValue::Array(sub_splits);
let merged = func_template
.clone()
.compile_input_merge(&merge_input)
.map_err(|e| {
format!(
"AV11: Merged input validation, input {}, subset {:?}: \
input_merge failed: {}",
input_label, subset, e
)
})?
.ok_or_else(|| {
format!(
"AV12: Merged input validation, input {}, subset {:?}: \
input_merge returned None",
input_label, subset
)
})?;
let merged_label =
serde_json::to_string(&merged).unwrap_or_default();
compile_and_validate_one_input(
&merged_label,
&transpiled,
&merged,
None,
)?;
}
}
}
if count == 0 {
return Err(
"AV15: Failed to generate any example inputs from input_schema"
.to_string(),
);
}
if count >= 2 {
for (j, indexed) in per_task_indexed.iter().enumerate() {
for (&ri, (occurrences, unique_values)) in indexed {
let total = *occurrences
+ if per_task_skipped[j] { 1 } else { 0 };
if total <= 1 {
continue;
}
let effective = unique_values.len()
+ if per_task_skipped[j] { 1 } else { 0 };
if effective < 2 {
return Err(format!(
"AV16: Task [{}]: response at index {} is a fixed value — \
responses must be derived from an array in the input",
j, ri,
));
}
}
}
for (j, has_varying) in per_task_has_varying.iter().enumerate() {
if !has_varying && !per_task_skipped[j] {
return Err(format!(
"AV17: Task [{}]: all responses are equal to each other for every \
example input — rankings are useless if every item is the same",
j,
));
}
}
}
check_modality_coverage(&schema_modalities, &task_modalities, "AV18")?;
Ok(())
}