pub mod binding;
pub mod depth;
pub mod err;
pub mod options;
pub mod limits;
pub mod validate;
pub mod report;
pub mod validation_error;
mod atomic_rules;
mod barrier;
mod bytes_rejection;
mod cast;
mod expr_rules;
mod fusion_safety;
pub mod linear_type;
mod nodes;
mod self_composition;
mod shadowing;
pub mod shape_predicate;
mod typecheck;
mod uniformity;
pub(crate) use binding::Binding;
pub use depth::{
LimitState, DEFAULT_MAX_CALL_DEPTH, DEFAULT_MAX_EXPR_DEPTH, DEFAULT_MAX_NESTING_DEPTH,
DEFAULT_MAX_NODE_COUNT,
};
pub(crate) use err::err;
pub use options::{BackendCapabilities, BackendValidationCapabilities, ValidationOptions};
pub use report::{ValidationReport, ValidationWarning};
pub use validate::validate;
pub use validate::validate_with_options;
pub use validation_error::ValidationError;
#[cfg(test)]
mod tests {
use super::validate;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};
/// An output buffer named `out` declared with `DataType::Array { element_size: 4 }`
/// must make `validate` report the unsupported-element-type error for `array<4B>`.
#[test]
fn array_output_buffer_rejected() {
    let buffers = vec![BufferDecl::output(
        "out",
        0,
        DataType::Array { element_size: 4 },
    )];
    let program = Program::wrapped(buffers, [1, 1, 1], Vec::new());
    let errors = validate(&program);
    // Print the collected errors on failure, as the sibling tests in this module do,
    // so a regression shows what the validator actually reported.
    assert!(
        errors.iter().any(|error| {
            error
                .message
                .contains("output buffer `out` uses unsupported element type `array<4B>`")
        }),
        "array output buffer was not rejected: got {errors:?}"
    );
}
/// An output buffer named `out` declared with `DataType::Tensor` must make
/// `validate` report the unsupported-element-type error for `tensor`.
#[test]
fn tensor_output_buffer_rejected() {
    let buffers = vec![BufferDecl::output("out", 0, DataType::Tensor)];
    let program = Program::wrapped(buffers, [1, 1, 1], Vec::new());
    let errors = validate(&program);
    // Print the collected errors on failure, as the sibling tests in this module do,
    // so a regression shows what the validator actually reported.
    assert!(
        errors.iter().any(|error| {
            error
                .message
                .contains("output buffer `out` uses unsupported element type `tensor`")
        }),
        "tensor output buffer was not rejected: got {errors:?}"
    );
}
/// `Program::wrapped` given entry nodes that are not regions (a store followed by a
/// return) must expose an entry consisting of exactly one `Node::Region` whose
/// generator equals `Program::ROOT_REGION_GENERATOR`.
#[test]
fn wrapped_constructor_inserts_root_region_for_raw_entry() {
    let buffers = vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)];
    let raw_entry = vec![Node::store("out", Expr::u32(0), Expr::u32(9)), Node::Return];
    let program = Program::wrapped(buffers, [1, 1, 1], raw_entry);
    // Show the actual entry on failure; a bare `matches!` assertion would only say
    // that the pattern did not match.
    assert!(
        matches!(
            program.entry(),
            [Node::Region { generator, .. }] if generator.as_ref() == Program::ROOT_REGION_GENERATOR
        ),
        "raw entry was not wrapped in a single root region: got {:?}",
        program.entry()
    );
}
/// When the entry passed to `Program::wrapped` is already a single `Node::Region`,
/// the resulting program's entry must compare equal to that same region.
#[test]
fn wrapped_constructor_preserves_existing_top_level_regions() {
    let existing = Node::Region {
        generator: "already.region".into(),
        source_region: None,
        body: std::sync::Arc::new(vec![Node::Return]),
    };
    // Hand a clone to the constructor so the original stays available for comparison.
    let entry = vec![existing.clone()];
    let program = Program::wrapped(Vec::new(), [1, 1, 1], entry);
    assert_eq!(program.entry(), &[existing]);
}
/// A program built with `Program::new` from bare top-level statements (a store and a
/// return, with no enclosing `Node::Region`) must be rejected by `validate` with an
/// error that mentions either "top-level Region" or "Node::Region".
#[test]
// NOTE(review): the deprecated item silenced here is presumably `Program::new`; it is
// used on purpose because the test needs an entry that has not been wrapped in a
// region — confirm against the `Program` API.
#[allow(deprecated)]
fn raw_top_level_statement_is_rejected() {
    let program = Program::new(
        vec![BufferDecl::output("out", 0, DataType::U32).with_count(1)],
        [1, 1, 1],
        vec![Node::store("out", Expr::u32(0), Expr::u32(7)), Node::Return],
    );
    let errors = validate(&program);
    // Print the collected errors on failure, as the sibling tests in this module do,
    // so a regression shows what the validator actually reported.
    assert!(
        errors.iter().any(|error| {
            error.message.contains("top-level Region") || error.message.contains("Node::Region")
        }),
        "raw top-level statement was not rejected: got {errors:?}"
    );
}
/// A buffer `ghost` declared with `LinearType::Linear`, in a program whose body is
/// only `Node::Return`, must produce a `validate` error that names the buffer and its
/// declared linear type.
#[test]
fn linear_type_violation_surfaces_through_validate() {
    use crate::ir::LinearType;
    let ghost = BufferDecl::output("ghost", 0, DataType::U32)
        .with_count(1)
        .with_linear_type(LinearType::Linear);
    let program = Program::wrapped(vec![ghost], [1, 1, 1], vec![Node::Return]);
    let errors = validate(&program);
    let flagged = errors
        .iter()
        .any(|e| e.message.contains("`ghost` declared `LinearType::Linear`"));
    assert!(
        flagged,
        "linear-type checker not wired into validate(): got {errors:?}"
    );
}
/// A buffer `ok` declared without any `with_linear_type` call, and written once by the
/// program body, must not yield any `validate` error containing `LinearType::`.
#[test]
fn unrestricted_buffer_is_not_flagged_by_linear_type_checker() {
    let ok = BufferDecl::output("ok", 0, DataType::U32).with_count(1);
    let body = vec![Node::store("ok", Expr::u32(0), Expr::u32(0))];
    let program = Program::wrapped(vec![ok], [1, 1, 1], body);
    let errors = validate(&program);
    // Every reported error (if any) must be unrelated to linear types.
    let none_linear = errors
        .iter()
        .all(|e| !e.message.contains("LinearType::"));
    assert!(none_linear, "unrestricted buffer flagged: {errors:?}");
}
/// A buffer `misaligned` with count 3 and `ShapePredicate::MultipleOf(64)` (3 is not a
/// multiple of 64) must produce a `validate` error that mentions both the buffer name
/// and the `count % 64 == 0` condition.
#[test]
fn shape_predicate_violation_surfaces_through_validate() {
    use crate::ir::ShapePredicate;
    let misaligned = BufferDecl::output("misaligned", 0, DataType::U32)
        .with_count(3)
        .with_shape_predicate(ShapePredicate::MultipleOf(64));
    let body = vec![Node::store("misaligned", Expr::u32(0), Expr::u32(0))];
    let program = Program::wrapped(vec![misaligned], [1, 1, 1], body);
    let errors = validate(&program);
    // A single error must carry both the buffer name and the violated condition.
    let reported = errors.iter().any(|e| {
        let message = &e.message;
        message.contains("`misaligned`") && message.contains("count % 64 == 0")
    });
    assert!(
        reported,
        "shape-predicate checker not wired into validate(): got {errors:?}"
    );
}
/// A buffer `aligned` with count 128 and `ShapePredicate::MultipleOf(64)` (128 is a
/// multiple of 64) must not yield any `validate` error containing `count % `.
#[test]
fn satisfied_shape_predicate_is_not_flagged() {
    use crate::ir::ShapePredicate;
    let aligned = BufferDecl::output("aligned", 0, DataType::U32)
        .with_count(128)
        .with_shape_predicate(ShapePredicate::MultipleOf(64));
    let body = vec![Node::store("aligned", Expr::u32(0), Expr::u32(0))];
    let program = Program::wrapped(vec![aligned], [1, 1, 1], body);
    let errors = validate(&program);
    // Every reported error (if any) must be unrelated to the count predicate.
    let none_about_count = errors.iter().all(|e| !e.message.contains("count % "));
    assert!(none_about_count, "satisfied predicate flagged: {errors:?}");
}
}