use crate::spec::property::resolve;
use crate::spec::Property;
mod codegen;
pub(crate) mod util;
use std::collections::HashSet;
use std::path::Path;
use crate::adversarial::mutations::catalog::applied::applied_for_classes;
use crate::generate::archetypes::{Archetype, TestInput};
use crate::generate::emit::emit::{write_test, EmittedTest};
use crate::generate::emit::naming::{sanitize, test_name};
use crate::generate::emit::provenance::{header, SourceTuple};
use crate::generate::emit::seed::SeedStream;
use crate::generate::emit::{
GenError, GeneratedFile, GenerationPlan, GenerationReport, TemplateKind,
};
use crate::generate::templates::{render, TemplateContext};
use crate::spec::types::OpSpec;
use crate::spec::{OracleKind, SpecRow};
use vyre_spec::Category;
/// Emits one generated test per applicable `(op, archetype, route)` combination in `plan`.
///
/// For every op/archetype pair accepted by `applies_to`, the archetype's input is
/// materialized (pairs that yield no input are counted in `skipped_empty_inputs`),
/// routes are computed with `generate_routes_union`, routes whose template is not in
/// `plan.templates` are dropped, and each remaining route is written under `out_dir`
/// through `emit_tuple`.
///
/// # Errors
///
/// Propagates errors from route generation and emission. Returns
/// `GenError::InvalidPlan` when a route points at an input index that does not exist
/// and `GenError::NameCollision` when two emitted tests end up with the same name.
#[inline]
pub fn generate_cross_product(
    plan: &GenerationPlan,
    out_dir: &Path,
) -> Result<GenerationReport, GenError> {
    let mut report = GenerationReport::default();
    let mut seen_names = HashSet::new();
    let mut seed_stream = SeedStream::new(plan.seed);

    for op in &plan.ops {
        for archetype in &plan.archetypes {
            if !archetype.applies_to(op) {
                continue;
            }
            // Each archetype contributes at most one input; it is wrapped in a Vec
            // because the route machinery addresses inputs by index.
            let inputs = match archetype.materialize(op) {
                Some(materialized) => vec![materialized],
                None => {
                    report.skipped_empty_inputs += 1;
                    continue;
                }
            };
            let routes = generate_routes_union(op, &inputs)?;
            for route in routes
                .into_iter()
                .filter(|candidate| plan.templates.contains(&candidate.template))
            {
                let Some(input) = inputs.get(route.input_index) else {
                    return Err(GenError::InvalidPlan {
                        reason: format!(
                            "input_index {} out of range (inputs.len()={})",
                            route.input_index,
                            inputs.len()
                        ),
                    });
                };
                let file = emit_tuple(out_dir, &mut seed_stream, op, *archetype, input, &route)?;
                // NOTE(review): `emit_tuple` has already written the file by the time
                // the collision is detected here, so an earlier file with the same
                // name may have been overwritten before this error is returned.
                if !seen_names.insert(file.test_name.clone()) {
                    return Err(GenError::NameCollision(file.test_name));
                }
                report.files.push(file);
            }
        }
    }
    Ok(report)
}
/// One test route for an op: the property to check, the oracle resolved for it, and
/// the template used to render the resulting test.
#[derive(Clone)]
struct Route {
/// Identifier that, together with `oracle` and `template`, is used by
/// `push_route_with_metadata` to skip duplicate routes.
route_id: String,
/// Oracle kind resolved for `property` (see `resolved_oracle_kind`).
oracle: OracleKind,
/// String form of `oracle`, produced by `oracle_name`.
oracle_name: String,
/// The property this route checks.
property: Property,
/// Human-readable label for the property; fed into the template context.
property_label: String,
/// Template the route is rendered with; routes whose template is not in the plan are skipped.
template: TemplateKind,
/// Index into the `inputs` slice the route was generated from.
input_index: usize,
/// Validation metadata; `Some` only for routes built from `validation_cases`.
validation: Option<ValidationCase>,
/// Mutation metadata; `Some` only for routes built from `mutation_routes`.
mutation: Option<MutationRoute>,
}
/// A single validation scenario attached to a `TemplateKind::Validation` route.
#[derive(Clone, Copy)]
struct ValidationCase {
/// Rule identifier such as `"V001"`; exposed to templates as `V_NNN` and, with the
/// leading `V` trimmed, as `NNN`.
rule_id: &'static str,
/// Text used as the description of the route's `Property::PropertyCheck`.
description: &'static str,
/// Expected outcome; its label becomes part of the route id and property label.
expectation: ValidationExpectation,
}
/// Expected outcome of a validation case.
#[derive(Clone, Copy)]
enum ValidationExpectation {
/// Labelled `"must_accept"` by `validation_expectation_label`.
MustAccept,
/// Labelled `"must_reject"` by `validation_expectation_label`.
MustReject,
}
/// Mutation metadata attached to a `TemplateKind::MutationKill` route.
#[derive(Clone)]
struct MutationRoute {
/// Mutation identifier (from the mutation's `id()`); used in the route id and label.
id: String,
/// Mutation description (from `description()`); exposed as `mutation_description`.
description: String,
/// Extra text taken from the mutation's `hint()`; exposed as `mutation_details`.
details: String,
}
/// Builds the full list of candidate routes for `op` over the given `inputs`.
///
/// Routes are pushed in a fixed order: declared laws, spec-table row matches,
/// composition law (composed ops only), per-input point parity and backend
/// equivalence, validation cases, and finally mutation kills. Duplicate and
/// superseded routes are filtered by `push_route_with_metadata`.
///
/// # Errors
///
/// Propagates the arity error from `matching_spec_table_inputs`.
fn generate_routes_union(op: &OpSpec, inputs: &[TestInput]) -> Result<Vec<Route>, GenError> {
    let mut collected = Vec::new();

    // Declared algebraic laws.
    for law in &op.declared_laws {
        let label = format!("declared law `{}` holds", law.name());
        push_route(
            op,
            &mut collected,
            Property::DeclaredLawHolds { law: law.clone() },
            label,
            TemplateKind::Law,
            0,
        );
    }

    // Inputs that coincide with a spec-table row.
    let table_hits = matching_spec_table_inputs(op, inputs)?;
    for (input_index, row_index) in table_hits {
        push_route(
            op,
            &mut collected,
            Property::SpecTableRowMatches { row_index },
            format!("spec table row {row_index} matches"),
            TemplateKind::OpCorrectness,
            input_index,
        );
    }

    // Composed ops additionally check that the first declared law survives composition.
    if let (true, Some(first_law)) = (is_composed(op), op.declared_laws.first()) {
        push_route(
            op,
            &mut collected,
            Property::CompositionPreservesLaw {
                law: first_law.clone(),
            },
            "composition preserves declared law".to_string(),
            TemplateKind::Archetype,
            0,
        );
    }

    // Point parity and backend equivalence for every input.
    for (input_index, input) in inputs.iter().enumerate() {
        let parity_label = format!("point parity for `{}`", input.label);
        push_route(
            op,
            &mut collected,
            Property::PointParity { input_index },
            parity_label,
            TemplateKind::Archetype,
            input_index,
        );
        let backend_label = format!("backend equivalence for `{}`", input.label);
        push_route(
            op,
            &mut collected,
            Property::PointParity { input_index },
            backend_label,
            TemplateKind::BackendEquiv,
            input_index,
        );
    }

    // Validation cases carry their own route id and metadata.
    for case in validation_cases().iter().copied() {
        let expectation = validation_expectation_label(case.expectation);
        let label = format!("{} {}", case.rule_id, expectation);
        let route_id = format!("validation_{}_{}", sanitize(case.rule_id), expectation);
        push_route_with_metadata(
            op,
            &mut collected,
            Property::PropertyCheck {
                description: case.description.to_string(),
            },
            label,
            TemplateKind::Validation,
            0,
            route_id,
            Some(case),
            None,
        );
    }

    // One kill route per applicable mutation.
    for mutation in mutation_routes(op) {
        let label = format!("kill mutation {}", mutation.id);
        let route_id = format!("mutation_{}", sanitize(&mutation.id));
        push_route_with_metadata(
            op,
            &mut collected,
            Property::PropertyCheck {
                description: "mutation kill".to_string(),
            },
            label,
            TemplateKind::MutationKill,
            0,
            route_id,
            None,
            Some(mutation),
        );
    }

    Ok(collected)
}
/// Adds a route whose id is derived from the property itself (via `property_key`)
/// and which carries neither validation nor mutation metadata.
fn push_route(
    op: &OpSpec,
    routes: &mut Vec<Route>,
    property: Property,
    property_label: String,
    template: TemplateKind,
    input_index: usize,
) {
    // The key has to be taken before `property` is moved into the callee.
    let derived_id = property_key(&property);
    let (validation, mutation) = (None, None);
    push_route_with_metadata(
        op,
        routes,
        property,
        property_label,
        template,
        input_index,
        derived_id,
        validation,
        mutation,
    );
}
/// Appends a fully specified route to `routes` unless it is superseded or a duplicate.
///
/// A route is dropped when its resolved oracle is `OracleKind::Property`, a stronger
/// property exists for `op`, and the template is neither `Validation` nor
/// `MutationKill`. It is also dropped when a route with the same oracle, template
/// and `route_id` is already present.
fn push_route_with_metadata(
    op: &OpSpec,
    routes: &mut Vec<Route>,
    property: Property,
    property_label: String,
    template: TemplateKind,
    input_index: usize,
    route_id: String,
    validation: Option<ValidationCase>,
    mutation: Option<MutationRoute>,
) {
    let oracle = resolved_oracle_kind(op, &property);

    // Validation and mutation-kill routes are always kept, even with a weak oracle.
    let always_kept = matches!(
        template,
        TemplateKind::Validation | TemplateKind::MutationKill
    );
    let superseded =
        !always_kept && oracle == OracleKind::Property && stronger_property_exists(op);
    if superseded {
        return;
    }

    let already_present = routes.iter().any(|existing| {
        existing.oracle == oracle
            && existing.template == template
            && existing.route_id == route_id
    });
    if already_present {
        return;
    }

    let oracle_name = oracle_name(oracle).to_string();
    routes.push(Route {
        route_id,
        oracle,
        oracle_name,
        property,
        property_label,
        template,
        input_index,
        validation,
        mutation,
    });
}
/// Returns the fixed table of validation scenarios generated for every op.
///
/// Both entries share the rule id `"V001"`; the routes built from them stay distinct
/// because the expectation label is part of the route id.
fn validation_cases() -> &'static [ValidationCase] {
    static CASES: [ValidationCase; 2] = [
        ValidationCase {
            rule_id: "V001",
            description: "valid op program is accepted",
            expectation: ValidationExpectation::MustAccept,
        },
        ValidationCase {
            rule_id: "V001",
            description: "duplicate buffer names are rejected",
            expectation: ValidationExpectation::MustReject,
        },
    ];
    &CASES
}
fn validation_expectation_label(expectation: ValidationExpectation) -> &'static str {
match expectation {
ValidationExpectation::MustAccept => "must_accept",
ValidationExpectation::MustReject => "must_reject",
}
}
/// Collects route metadata for every mutation that `applied_for_classes` returns for
/// the op's `mutation_sensitivity`.
fn mutation_routes(op: &OpSpec) -> Vec<MutationRoute> {
    let mut routes = Vec::new();
    for applied in applied_for_classes(op.mutation_sensitivity).into_iter() {
        routes.push(MutationRoute {
            id: applied.id().to_string(),
            description: applied.description().to_string(),
            details: applied.hint(),
        });
    }
    routes
}
/// Resolves the oracle kind for `property` on `op`.
///
/// A `SpecTableRowMatches` property whose row index lies inside the op's spec table
/// always resolves to `OracleKind::SpecTable`; every other property uses the kind
/// reported by `resolve`.
fn resolved_oracle_kind(op: &OpSpec, property: &Property) -> OracleKind {
    let from_hierarchy = resolve(op, property).kind();
    match property {
        Property::SpecTableRowMatches { row_index } if *row_index < spec_rows(op).len() => {
            OracleKind::SpecTable
        }
        _ => from_hierarchy,
    }
}
/// Pairs every input with each spec-table row whose flattened input bytes equal the
/// input's bytes.
///
/// Returns `(input_index, row_index)` tuples ordered by input, then by row.
///
/// # Errors
///
/// Returns `GenError::ArityMismatch` for the first input whose value count differs
/// from the number of inputs in the op signature.
fn matching_spec_table_inputs(
    op: &OpSpec,
    inputs: &[TestInput],
) -> Result<Vec<(usize, usize)>, GenError> {
    let arity = op.signature.inputs.len();

    // Loop-invariant work: resolving the table (which may fall back to a registry
    // lookup inside `spec_rows`) and flattening each row do not depend on the input,
    // so they are done once instead of once per input.
    let flattened_rows = spec_rows(op)
        .iter()
        .map(|row| {
            row.inputs
                .iter()
                .flat_map(|bytes| bytes.iter().copied())
                .collect::<Vec<_>>()
        })
        .collect::<Vec<_>>();

    let mut matches = Vec::new();
    for (input_index, input) in inputs.iter().enumerate() {
        if input.values.len() != arity {
            return Err(GenError::ArityMismatch {
                expected: arity,
                actual: input.values.len(),
                op: op.id,
            });
        }
        let input_bytes = input_bytes_for_arity(input, arity);
        for (row_index, row_input) in flattened_rows.iter().enumerate() {
            if *row_input == input_bytes {
                matches.push((input_index, row_index));
            }
        }
    }
    Ok(matches)
}
/// Returns the spec table for `op`.
///
/// The op's own `spec_table` is used when it is non-empty. Otherwise the op is looked
/// up by id among the ops built from `spec_op_sources()`, and that entry's table is
/// returned (an empty slice when no entry matches).
fn spec_rows(op: &OpSpec) -> &'static [SpecRow] {
    if op.spec_table.is_empty() {
        let empty: &'static [SpecRow] = &[];
        return crate::spec::ops::all(crate::spec::spec_op_sources())
            .iter()
            .find(|registered| registered.id == op.id)
            .map_or(empty, |registered| registered.spec_table);
    }
    op.spec_table
}
/// Reports whether `op` has anything stronger than a plain property oracle to test
/// against: declared laws, spec-table rows, an oracle override, or a signature that
/// requires at least one input byte.
fn stronger_property_exists(op: &OpSpec) -> bool {
    if !op.declared_laws.is_empty() {
        return true;
    }
    if !spec_rows(op).is_empty() {
        return true;
    }
    if op.oracle_override.is_some() {
        return true;
    }
    op.signature.min_input_bytes() > 0
}
/// Returns `true` when the op is a category-A op composed of more than one part.
fn is_composed(op: &OpSpec) -> bool {
    match &op.category {
        Category::A { composition_of } => composition_of.len() > 1,
        _ => false,
    }
}
/// Renders and writes the generated test for one `(op, archetype, input, route)`
/// tuple and returns the metadata of the written file.
///
/// The seed is derived from the op id, archetype id and route key; the test name is
/// built from the same parts plus the seed. The file is written to
/// `out_dir/<sanitized op id>/<name>.rs`.
///
/// # Errors
///
/// Propagates template rendering errors and errors from `write_test`.
fn emit_tuple(
    out_dir: &Path,
    seeds: &mut SeedStream,
    op: &'static OpSpec,
    archetype: &'static dyn Archetype,
    input: &TestInput,
    route: &Route,
) -> Result<GeneratedFile, GenError> {
    // Identity of this test: key -> seed -> name.
    let key = route_key(route);
    let seed = seeds.derive(op.id, archetype.id(), &key);
    let name = test_name(op.id, archetype.id(), &key, seed);

    // Rendering happens before anything touches the filesystem.
    let template = template_src(route.template);
    let template_context = context(op, archetype, input, route);
    let rendered = render(template, &template_context)?;

    let tuple = SourceTuple {
        op: op.id,
        archetype: archetype.id(),
        oracle: &route.oracle_name,
    };
    let provenance = header(&tuple, seeds.master(), op.version);
    let rust = rust_test(&name, op, archetype, input, route, seed);

    let mut path = out_dir.join(sanitize(op.id));
    path.push(format!("{name}.rs"));

    let emitted = EmittedTest {
        test_name: &name,
        provenance: &provenance,
        rendered_template: &rendered,
        rust: &rust,
        op_under_test: op.id,
    };
    let written = write_test(&path, &emitted)?;

    Ok(GeneratedFile {
        path: written,
        test_name: name,
        op: op.id,
        archetype: archetype.id(),
        oracle: route.oracle_name.clone(),
        template: route.template,
        seed,
    })
}
/// Builds the template context for one route.
///
/// Every scalar that any template kind may reference is populated. Mutation fields
/// fall back to `"not requested"` and validation fields to `"V000"` / `"000"` when
/// the route carries no such metadata.
fn context(
    op: &OpSpec,
    archetype: &dyn Archetype,
    input: &TestInput,
    route: &Route,
) -> TemplateContext {
    let mutation_description = route.mutation.as_ref().map_or_else(
        || "not requested".to_string(),
        |mutation| mutation.description.clone(),
    );
    let mutation_details = route.mutation.as_ref().map_or_else(
        || "not requested".to_string(),
        |mutation| mutation.details.clone(),
    );
    let rule_id = route.validation.map_or("V000", |validation| validation.rule_id);
    // Same id with the leading `V` removed; the "V000" fallback becomes "000".
    let rule_number = rule_id.trim_start_matches('V');

    TemplateContext::new()
        // Op identity and provenance.
        .with_scalar("op_name", op.id)
        .with_scalar("op", op.id)
        .with_scalar("file", last_segment(op.id))
        .with_scalar("expected", expected_bytes(op, route))
        .with_scalar("inputs", format!("{:?}", input.values))
        .with_scalar("spec_path", "op_registry::all_specs")
        .with_scalar("line", provenance_line(route))
        .with_scalar("category", format!("{:?}", op.category))
        // Input and law description.
        .with_scalar("input_desc", sanitize(input.label))
        .with_scalar("law_name", law_name(&route.property))
        .with_scalar("law_description", route.property_label.clone())
        .with_scalar("input_class", input.label)
        .with_scalar("law_formula", route.property_label.clone())
        .with_scalar("specific_inputs", format!("{:?}", input.values))
        .with_scalar("law", law_name(&route.property))
        // Archetype and oracle.
        .with_scalar("archetype_id", archetype.id())
        .with_scalar("archetype_name", archetype.name())
        .with_scalar("archetype_description", archetype.description())
        .with_scalar("property", route.property_label.clone())
        .with_scalar("oracle_chosen_from_hierarchy", route.oracle_name.clone())
        // Mutation and validation metadata.
        .with_scalar("mutation_description", mutation_description)
        .with_scalar("mutation_details", mutation_details)
        .with_scalar("V_NNN", rule_id)
        .with_scalar("description", route.property_label.clone())
        .with_scalar("NNN", rule_number)
}
/// Returns the 1-based spec-table row number for the route's property, or
/// `"not-spec-table"` when the property has no spec-table row index.
fn provenance_line(route: &Route) -> String {
    match spec_row_index(&route.property) {
        Some(index) => (index + 1).to_string(),
        None => "not-spec-table".to_string(),
    }
}
use codegen::{expected_bytes, law_name, route_key, rust_test, spec_row_index};
use util::{input_bytes_for_arity, last_segment, oracle_name, property_key, template_src};