use crate::proof::comparator::ComparatorKind;
use crate::spec::types::{ChainSpec, OpSpec, Strictness};
mod proof;
mod suite;
pub mod backend;
pub mod certify;
pub mod execution;
pub mod loader;
pub mod notify;
pub mod reporter;
#[allow(missing_docs)]
pub mod streaming;
pub use suite::ConformanceSuite;
use crate::spec::minimums::{MIN_BOUNDARY_VALUES, MIN_EQUIVALENCE_CLASSES};
/// Checks that `op` declares at least the minimum number of boundary values
/// and equivalence classes.
///
/// # Errors
///
/// Returns an actionable `"Fix: ..."` message when `op.boundary_values` has
/// fewer than `MIN_BOUNDARY_VALUES` entries, or when
/// `op.equivalence_classes` has fewer than `MIN_EQUIVALENCE_CLASSES` entries.
/// The boundary-value check runs first, so only the first shortfall is
/// reported.
#[inline]
pub fn validate_minimum_coverage(op: &crate::OpSpec) -> Result<(), String> {
    let boundary_count = op.boundary_values.len();
    if boundary_count < MIN_BOUNDARY_VALUES {
        return Err(format!(
            "Fix: op '{id}' has {count} boundary values, minimum is {min}.\n\
             Add boundary values covering zero, one, max, and at least one\n\
             domain-specific edge case.",
            id = op.id,
            count = boundary_count,
            min = MIN_BOUNDARY_VALUES,
        ));
    }

    let class_count = op.equivalence_classes.len();
    if class_count >= MIN_EQUIVALENCE_CLASSES {
        return Ok(());
    }
    Err(format!(
        "Fix: op '{id}' has {count} equivalence classes, minimum is {min}.\n\
         Add at least one equivalence class describing the input domain.",
        id = op.id,
        count = class_count,
        min = MIN_EQUIVALENCE_CLASSES,
    ))
}
/// Returns the sorted, deduplicated list of workgroup sizes to exercise: the
/// defaults `1` and `64`, plus `preferred` when one is given.
///
/// # Errors
///
/// Returns a `"Fix: ..."` message when `preferred` is `Some(0)` or a value
/// greater than 1024.
#[inline]
pub(crate) fn workgroup_sizes(preferred: Option<u32>) -> Result<Vec<u32>, String> {
    let extra = match preferred {
        None => None,
        Some(0) => {
            return Err(
                "Fix: workgroup_size cannot be 0. Use Some(n) with n in 1..=1024 \
                 or None to accept the default schedule."
                    .to_string(),
            );
        }
        Some(size) if size > 1024 => {
            return Err(format!(
                "Fix: workgroup_size {size} exceeds 1024. Cap to the device max."
            ));
        }
        Some(size) => Some(size),
    };

    let mut schedule = vec![1u32, 64];
    // `Option` is iterable: this appends the preferred size only if present.
    schedule.extend(extra);
    schedule.sort_unstable();
    schedule.dedup();
    Ok(schedule)
}
/// Returns the sorted, deduplicated union of the default workgroup sizes and
/// every per-spec `workgroup_size` preference declared in `chain`.
///
/// # Errors
///
/// Propagates the first error from `workgroup_sizes`, in spec order, for a
/// spec whose preferred size is 0 or greater than 1024.
fn chain_workgroup_sizes(chain: &ChainSpec) -> Result<Vec<u32>, String> {
    let mut sizes = workgroup_sizes(None)?;
    for size in chain.specs.iter().filter_map(|spec| spec.workgroup_size) {
        // `workgroup_sizes` validates `size`. Its output also repeats the
        // defaults, which the final sort + dedup collapses.
        sizes.extend(workgroup_sizes(Some(size))?);
    }
    sizes.sort_unstable();
    sizes.dedup();
    Ok(sizes)
}
/// Returns the highest `version` among the specs in `chain`.
///
/// An empty chain is treated as a caller bug: debug builds panic via
/// `debug_assert!`, while release builds fall back to `0`.
#[inline]
pub(crate) fn chain_version(chain: &ChainSpec) -> u32 {
    debug_assert!(
        !chain.specs.is_empty(),
        "chain_version called on empty chain '{}' — caller should reject \
         empty chains explicitly (P1.20-F18)",
        chain.id
    );
    let mut highest: u32 = 0;
    for spec in &chain.specs {
        highest = highest.max(spec.version);
    }
    highest
}
/// Returns the comparator of the last spec in `chain`, or
/// `ComparatorKind::ExactMatch` when the chain has no specs.
#[inline]
pub(crate) fn chain_comparator(chain: &ChainSpec) -> ComparatorKind {
    match chain.specs.last() {
        Some(final_spec) => final_spec.comparator,
        None => ComparatorKind::ExactMatch,
    }
}
/// Certificate track an op is classified into by `certificate_track_for_op`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum CertificateTrack {
    /// Strict ops whose output type is not in the float family.
    Integer,
    /// Strict ops whose output type is in the float family.
    Float,
    /// Ops declared with `Strictness::Approximate`.
    Approximate,
}
/// Classifies `op` into a certificate track.
///
/// `Strictness::Approximate` ops always go to `Approximate`. Strict ops are
/// split on whether their output type is in the float family.
#[inline]
pub(crate) fn certificate_track_for_op(op: &OpSpec) -> CertificateTrack {
    match op.strictness {
        Strictness::Approximate { .. } => CertificateTrack::Approximate,
        Strictness::Strict => {
            if op.signature.output.is_float_family() {
                CertificateTrack::Float
            } else {
                CertificateTrack::Integer
            }
        }
    }
}
/// Capability allowlist handed to naga's validator when checking op WGSL.
///
/// The set is empty, so no optional naga capability is permitted. New
/// capabilities must be opted in here deliberately.
fn vyre_minimum_capabilities() -> naga::valid::Capabilities {
    naga::valid::Capabilities::empty()
}
/// Parses and validates `op`'s WGSL with naga, after wrapping it with
/// `wrap_shader` using the default `ConformDispatchConfig`.
///
/// # Errors
///
/// Returns a `"Fix: ..."` message if naga cannot parse the wrapped shader. It
/// also returns one if the parsed module fails validation under
/// `ValidationFlags::all()` with the capability set from
/// `vyre_minimum_capabilities`.
#[inline]
pub fn validate_wgsl_syntax(op: &crate::OpSpec) -> Result<(), String> {
    let wgsl_source = (op.wgsl_fn)();
    let config = crate::pipeline::backend::ConformDispatchConfig::default();
    let wrapped = crate::pipeline::backend::wrap_shader(&wgsl_source, &config);

    // Stage 1: syntax. Validation needs a parsed module, so bail out early.
    let module = naga::front::wgsl::parse_str(&wrapped).map_err(|e| {
        format!(
            "Fix: op '{}' WGSL fails naga parsing: {e}\n\
             The shader source is syntactically invalid.",
            op.id
        )
    })?;

    // Stage 2: semantic checks plus the capability allowlist.
    naga::valid::Validator::new(
        naga::valid::ValidationFlags::all(),
        vyre_minimum_capabilities(),
    )
    .validate(&module)
    .map(|_info| ())
    .map_err(|e| {
        format!(
            "Fix: op '{}' WGSL fails naga validation: {e}\n\
             The shader parses but has semantic errors, OR uses a \
             naga capability outside vyre's minimum allowlist (see \
             vyre_minimum_capabilities). If a new capability is \
             required, opt it in there + update \
             coordination/wgsl-capability-allowlist.md.",
            op.id
        )
    })
}
#[cfg(test)]
mod tests;