use std::collections::{HashMap, HashSet};
use syn::spanned::Spanned;
use syn::visit::{self, Visit};
use syn::{Expr, ExprCall, ExprMethodCall, ExprPath, Local, Pat, PatIdent};
/// Identifiers that are always treated as resolving into the op under test,
/// whichever op is being checked.
const FORBIDDEN_CALL_SEGMENTS: &[&str] = &["cpu_fn", "wgsl_fn"];

/// Call and method names exempt from the forbidden-segment check, even when
/// the op name itself ends in one of them.
const ALLOWED_CALL_SEGMENTS: &[&str] = &["run", "resolve"];

/// Macros whose invocation is accepted without expansion; any macro not
/// listed here is rejected as unverifiable.
const ALLOWED_MACRO_SEGMENTS: &[&str] = &[
    "assert", "assert_eq", "assert_ne",
    "debug_assert", "debug_assert_eq", "debug_assert_ne",
    "format", "panic", "vec",
];
/// Result of checking one generated test source against the op it exercises.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IndependenceCertificate {
    /// The `op_under_test` string the check was run for.
    pub op_name: String,
    /// Whether the test's oracle was judged independent of that op.
    pub verdict: CertificateVerdict,
}

/// Verdict carried by an [`IndependenceCertificate`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CertificateVerdict {
    /// Nothing tying the oracle to the op was found (or the check was
    /// bypassed via `#[allow(tautological_test)]`).
    Independent,
    /// The first construct found that appears to route the oracle through
    /// the op under test.
    SuspectedTautology {
        /// Source line of the offending expression or macro.
        line: usize,
        /// Rendered path of the offending call or macro.
        call_path: String,
        /// Explanation of the finding together with a suggested fix.
        hint: String,
    },
}
/// Check whether the test file `rust_source` derives its expected values
/// independently of `op_under_test`.
///
/// Steps, in order:
/// 1. Parse `rust_source` as a Rust file.
/// 2. If any top-level `fn` carries `#[allow(tautological_test)]`, log a
///    warning and certify the file as `Independent` without scanning.
/// 3. Otherwise walk the whole file with an `IndependenceVisitor`. Top-level
///    `fn` items are indexed by name so calls to them can be followed.
///
/// # Errors
/// Returns [`IndependenceError::InvalidRust`] when `rust_source` does not
/// parse as a Rust file.
///
/// # Panics
/// When no bypass applies, panics if `op_under_test` is empty or ends with
/// `.` or `::` (asserted by `build_forbidden_set`).
#[inline]
pub fn verify_test_independence(
    rust_source: &str,
    op_under_test: &str,
) -> Result<IndependenceCertificate, IndependenceError> {
    let file = match syn::parse_file(rust_source) {
        Ok(parsed) => parsed,
        Err(err) => {
            return Err(IndependenceError::InvalidRust {
                reason: err.to_string(),
            })
        }
    };

    let top_level_fns: Vec<&syn::ItemFn> = file
        .items
        .iter()
        .filter_map(|item| match item {
            syn::Item::Fn(f) => Some(f),
            _ => None,
        })
        .collect();

    // NOTE(review): a single opted-out top-level fn certifies the WHOLE file,
    // including any other test functions in it — confirm this is intended.
    if let Some(opted_out) = top_level_fns
        .iter()
        .find(|f| has_allow_tautological_test(&f.attrs))
    {
        tracing::warn!(
            op = op_under_test,
            function = %opted_out.sig.ident,
            "independence check bypassed via #[allow(tautological_test)]"
        );
        return Ok(IndependenceCertificate {
            op_name: op_under_test.to_string(),
            verdict: CertificateVerdict::Independent,
        });
    }

    let forbidden_segments = build_forbidden_set(op_under_test);
    let local_fns: HashMap<String, &syn::ItemFn> = top_level_fns
        .into_iter()
        .map(|f| (f.sig.ident.to_string(), f))
        .collect();

    let mut visitor = IndependenceVisitor {
        op_under_test,
        forbidden: &forbidden_segments,
        verdict: CertificateVerdict::Independent,
        local_fns,
        local_closures: HashMap::new(),
        scanning_stack: HashSet::new(),
    };
    visitor.visit_file(&file);

    let verdict = visitor.verdict;
    Ok(IndependenceCertificate {
        op_name: op_under_test.to_string(),
        verdict,
    })
}
/// Failure modes of `verify_test_independence`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum IndependenceError {
    /// The input could not be parsed as a Rust file; `reason` holds the
    /// parser's error message.
    InvalidRust { reason: String },
}

impl std::fmt::Display for IndependenceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Single-variant enum, so this destructuring is irrefutable.
        let Self::InvalidRust { reason } = self;
        write!(f, "independence checker received invalid Rust. Fix: {reason}")
    }
}

impl std::error::Error for IndependenceError {}
/// Return `true` when any attribute is an `#[allow(...)]` list naming the
/// `tautological_test` lint.
///
/// The lint name must appear as a whole identifier. A raw substring search
/// would also accept e.g. `#[allow(tautological_tests)]` or
/// `#[allow(not_tautological_test)]`, silently disabling the whole gate. The
/// tokenization (split on anything that is not ASCII alphanumeric or `_`)
/// mirrors `token_contains_segment`.
fn has_allow_tautological_test(attrs: &[syn::Attribute]) -> bool {
    attrs.iter().any(|attr| {
        let syn::Meta::List(meta_list) = &attr.meta else {
            return false;
        };
        meta_list.path.is_ident("allow")
            && meta_list
                .tokens
                .to_string()
                .split(|ch: char| !(ch.is_ascii_alphanumeric() || ch == '_'))
                .any(|word| word == "tautological_test")
    })
}
/// Build the set of identifiers that mark a call (or a mention inside a
/// macro) as resolving into the op under test.
///
/// The set always contains `FORBIDDEN_CALL_SEGMENTS`. It also contains the
/// text after the last `.` and the text after the last `::` of
/// `op_under_test`. When the name lacks one of those delimiters, the
/// corresponding entry is the whole name.
///
/// # Panics
/// Panics when `op_under_test` is empty or ends with `.` or `::`.
fn build_forbidden_set(op_under_test: &str) -> HashSet<String> {
    assert!(!op_under_test.is_empty(), "op_under_test must be non-empty");
    let ends_in_delimiter = op_under_test.ends_with('.') || op_under_test.ends_with("::");
    assert!(
        !ends_in_delimiter,
        "op_under_test `{op_under_test}` must not end with a path delimiter. \
         Fix: strip trailing `.` or `::`.",
    );

    let tail_after_dot = op_under_test.rsplit('.').next();
    let tail_after_colons = op_under_test.rsplit("::").next();

    let mut forbidden: HashSet<String> = HashSet::new();
    forbidden.extend(FORBIDDEN_CALL_SEGMENTS.iter().map(|segment| segment.to_string()));
    forbidden.extend(
        [tail_after_dot, tail_after_colons]
            .iter()
            .flatten()
            .filter(|tail| !tail.is_empty())
            .map(|tail| tail.to_string()),
    );
    forbidden
}
/// AST walker that records the first construct routing a test oracle through
/// the op under test.
struct IndependenceVisitor<'a> {
    /// Current result; new checks are skipped once it leaves `Independent`.
    verdict: CertificateVerdict,
    /// Op name, used only to word diagnostics.
    op_under_test: &'a str,
    /// Identifiers that mark a call or macro as tautological.
    forbidden: &'a HashSet<String>,
    /// Top-level `fn` items by name, followed when called.
    local_fns: HashMap<String, &'a syn::ItemFn>,
    /// `let name = |..| ..` closures seen so far, followed when called.
    local_closures: HashMap<String, &'a syn::ExprClosure>,
    /// Names currently being followed; guards against unbounded recursion.
    scanning_stack: HashSet<String>,
}
impl<'ast> Visit<'ast> for IndependenceVisitor<'ast> {
    /// Remember `let name = |..| ..;` closures so later calls to `name(..)`
    /// can be followed into the closure body, then continue the walk.
    fn visit_local(&mut self, local: &'ast Local) {
        if let Pat::Ident(PatIdent { ident, .. }) = &local.pat {
            if let Some(init) = &local.init {
                if let Expr::Closure(closure) = &*init.expr {
                    self.local_closures.insert(ident.to_string(), closure);
                }
            }
        }
        visit::visit_local(self, local);
    }

    /// While no tautology has been recorded, scan calls and method calls,
    /// then recurse into sub-expressions.
    ///
    /// Macro expressions are not scanned here. The default `visit::visit_expr`
    /// reaches them through `visit_expr_macro`, which calls `visit_macro`
    /// below. Scanning them here too ran the macro scanner (and its token
    /// stringification) twice for every expression macro.
    fn visit_expr(&mut self, i: &'ast Expr) {
        if matches!(self.verdict, CertificateVerdict::Independent) {
            match i {
                Expr::Call(call) => self.scan_call(call, expr_line(i)),
                Expr::MethodCall(method) => self.scan_method_call(method, expr_line(i)),
                _ => {}
            }
        }
        visit::visit_expr(self, i);
    }

    /// Scan every macro invocation (expression, statement, or item position)
    /// while no tautology has been recorded, then recurse.
    fn visit_macro(&mut self, i: &'ast syn::Macro) {
        if matches!(self.verdict, CertificateVerdict::Independent) {
            self.scan_macro_direct(i, span_line(i.span()));
        }
        visit::visit_macro(self, i);
    }
}
impl IndependenceVisitor<'_> {
fn scan_call(&mut self, call: &ExprCall, line: usize) {
let Expr::Path(ExprPath { path, .. }) = call.func.as_ref() else {
return;
};
let last = match path.segments.last() {
Some(segment) => segment.ident.to_string(),
None => return,
};
if ALLOWED_CALL_SEGMENTS.contains(&last.as_str()) {
return;
}
if self.forbidden.contains(&last) {
self.verdict = CertificateVerdict::SuspectedTautology {
line,
call_path: path_to_string(path),
hint: format!(
"this test derives its oracle from a call to `{last}`, which \
resolves into the op under test ({op}). Fix: derive expected \
from (a) a hand-written literal in the spec table, (b) \
vyre_reference::run on an independent IR program, or (c) \
a law formula evaluated on the output.",
op = self.op_under_test,
),
};
return;
}
if !self.scanning_stack.contains(&last) {
if let Some(f) = self.local_fns.get(&last).copied() {
self.scanning_stack.insert(last.clone());
visit::visit_item_fn(self, f);
self.scanning_stack.remove(&last);
} else if let Some(closure) = self.local_closures.get(&last).copied() {
self.scanning_stack.insert(last.clone());
visit::visit_expr(self, &closure.body);
self.scanning_stack.remove(&last);
}
}
}
fn scan_method_call(&mut self, method: &ExprMethodCall, line: usize) {
let ident = method.method.to_string();
if ALLOWED_CALL_SEGMENTS.contains(&ident.as_str()) {
return;
}
if self.forbidden.contains(&ident) {
self.verdict = CertificateVerdict::SuspectedTautology {
line,
call_path: format!(".{ident}()"),
hint: format!(
"this test derives its oracle from `.{ident}()`, which on \
a spec object resolves into the op under test ({op}). Derive \
expected from a literal, the reference interpreter, or a law \
formula — never from a call on the spec being tested.",
op = self.op_under_test,
),
};
}
}
fn scan_macro_direct(&mut self, mac: &syn::Macro, line: usize) {
let Some(last) = mac.path.segments.last() else {
self.reject_unverified_macro("<empty>", line);
return;
};
let macro_name = last.ident.to_string();
if self.forbidden.contains(¯o_name) {
self.verdict = CertificateVerdict::SuspectedTautology {
line,
call_path: format!("{}!", path_to_string(&mac.path)),
hint: format!(
"this test invokes macro `{macro_name}!` which resolves into \
the op under test ({op}). Fix: derive expected from a literal \
spec row, vyre_reference::run, or a law formula.",
op = self.op_under_test,
),
};
return;
}
let token_text = mac.tokens.to_string();
if let Some(segment) = self
.forbidden
.iter()
.find(|segment| token_contains_segment(&token_text, segment))
{
self.verdict = CertificateVerdict::SuspectedTautology {
line,
call_path: format!("{}!(... {segment} ...)", path_to_string(&mac.path)),
hint: format!(
"this macro invocation contains `{segment}` while computing \
an oracle for {op}. Macros are opaque before expansion, so \
the independence gate rejects the value. Fix: move the oracle \
to a literal spec row, vyre_reference::run, or an explicit law formula.",
op = self.op_under_test,
),
};
return;
}
if !ALLOWED_MACRO_SEGMENTS.contains(¯o_name.as_str()) {
self.reject_unverified_macro(¯o_name, line);
}
}
fn reject_unverified_macro(&mut self, macro_name: &str, line: usize) {
self.verdict = CertificateVerdict::SuspectedTautology {
line,
call_path: format!("{macro_name}!"),
hint: format!(
"macro `{macro_name}!` is not in the independence gate's safe \
macro list, so its expansion cannot be verified before emit. \
Fix: replace it with explicit Rust expressions or add a narrow \
audited scanner rule before using it in generated tests for {op}.",
op = self.op_under_test,
),
};
}
}
/// Line on which `expr` starts, taken from its span.
///
/// Uses the file-level `Spanned` import; the previous function-local
/// re-import was redundant.
fn expr_line(expr: &syn::Expr) -> usize {
    span_line(expr.span())
}

/// Line on which `span` starts (proc-macro2 reports 1-indexed lines).
fn span_line(span: proc_macro2::Span) -> usize {
    span.start().line
}
/// Render `path` as its segment identifiers joined by `::`, for use as a
/// diagnostic `call_path`.
///
/// A leading `::` and any generic arguments are not reproduced.
fn path_to_string(path: &syn::Path) -> String {
    let mut rendered = String::new();
    for (position, segment) in path.segments.iter().enumerate() {
        if position != 0 {
            rendered.push_str("::");
        }
        rendered.push_str(&segment.ident.to_string());
    }
    rendered
}
/// Return `true` when `segment` equals one of the words of `tokens`.
///
/// `tokens` is split on every character that is not ASCII alphanumeric or
/// `_`. `cpu_fn` therefore matches in `spec . cpu_fn (x)` but not inside
/// `cpu_fn_v2`. Adjacent separators produce empty words.
fn token_contains_segment(tokens: &str, segment: &str) -> bool {
    let is_separator = |ch: char| !ch.is_ascii_alphanumeric() && ch != '_';
    tokens.split(is_separator).any(|word| word == segment)
}
#[cfg(test)]
mod tests;