use crate::cfg::MirCfg;
use crate::dominators::DominatorTree;
use crate::loop_analysis::{self, LoopTree};
use crate::reduction::{self, ReductionKind, ReductionReport};
use crate::{BlockId, MirBody, MirExpr, MirExprKind, MirFunction, MirProgram, MirStmt};
/// A single violation found by one of the MIR legality checks.
#[derive(Debug, Clone)]
pub struct LegalityError {
/// The category of check that produced this error.
pub check: LegalityCheck,
/// Human-readable description of what was violated.
pub message: String,
/// Name of the function the violation is attributed to.
pub function: String,
}
/// Formats the error as ``[<check>] in `<function>`: <message>``, where
/// `<check>` is the `Debug` rendering of the [`LegalityCheck`] variant.
impl std::fmt::Display for LegalityError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pull the fields out by name so the argument list reads in output order.
        let Self {
            check,
            message,
            function,
        } = self;
        write!(f, "[{:?}] in `{}`: {}", check, function, message)
    }
}
/// Category tag attached to every [`LegalityError`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LegalityCheck {
/// CFG well-formedness: entry block id, block id/index agreement,
/// successor bounds, entry reachability.
CfgStructure,
/// Loop-tree invariants: header/back-edge membership, block bounds,
/// sorted body blocks, parent/child agreement, acyclic nesting.
LoopIntegrity,
/// Reduction kind vs. its reorderable/parallelizable classification.
ReductionContract,
/// NOTE(review): not emitted by any check visible in this file; presumably
/// reserved for SSA validation — confirm.
SsaIntegrity,
/// NOTE(review): not emitted by any check visible in this file; presumably
/// reserved for no-GC validation — confirm.
NoGcContract,
/// Nesting depth of structured MIR bodies exceeds the configured maximum.
StructuralBound,
/// Loop schedule annotations disagree with reductions or carry invalid
/// parameters (zero sizes, non-power-of-two vector width).
ScheduleMetadataConsistency,
/// NOTE(review): not emitted by any check visible in this file — confirm
/// intended use.
ScheduleMetadataNonSemantic,
/// Reduction flags (`reassociation_forbidden`, `strict_order_required`)
/// disagree with the reduction kind.
MetadataConsistency,
}
/// Aggregated outcome of running the legality checks over a program.
#[derive(Debug, Clone)]
pub struct LegalityReport {
/// Every violation collected, in the order the checks ran.
pub errors: Vec<LegalityError>,
/// Number of checks that produced no errors.
pub checks_passed: u32,
/// Number of checks that were run.
pub checks_total: u32,
}
impl LegalityReport {
    /// Returns `true` when no check recorded an error.
    pub fn is_ok(&self) -> bool {
        self.errors().is_empty()
    }

    /// Borrows the collected errors as a slice.
    pub fn errors(&self) -> &[LegalityError] {
        self.errors.as_slice()
    }
}
pub fn verify_mir_legality(program: &MirProgram) -> LegalityReport {
let mut errors = Vec::new();
let mut checks_passed = 0u32;
let mut checks_total = 0u32;
for func in &program.functions {
if let Some(ref cfg) = func.cfg_body {
checks_total += 1;
let cfg_errors = check_cfg_structure(cfg, &func.name);
if cfg_errors.is_empty() {
checks_passed += 1;
} else {
errors.extend(cfg_errors);
}
checks_total += 1;
let domtree = DominatorTree::compute(cfg);
let loop_tree = loop_analysis::compute_loop_tree(cfg, &domtree);
let loop_errors = check_loop_integrity(&loop_tree, cfg, &func.name);
if loop_errors.is_empty() {
checks_passed += 1;
} else {
errors.extend(loop_errors);
}
}
checks_total += 1;
let bound_errors = check_structural_bounds(func);
if bound_errors.is_empty() {
checks_passed += 1;
} else {
errors.extend(bound_errors);
}
}
checks_total += 1;
let reduction_report = reduction::detect_reductions(program, &[]);
let reduction_errors = check_reduction_contracts(&reduction_report);
if reduction_errors.is_empty() {
checks_passed += 1;
} else {
errors.extend(reduction_errors);
}
checks_total += 1;
let meta_errors = check_reduction_metadata_consistency(&reduction_report);
if meta_errors.is_empty() {
checks_passed += 1;
} else {
errors.extend(meta_errors);
}
for func in &program.functions {
if let Some(ref cfg) = func.cfg_body {
checks_total += 1;
let domtree = DominatorTree::compute(cfg);
let loop_tree = loop_analysis::compute_loop_tree(cfg, &domtree);
let sched_errors =
check_schedule_metadata(&loop_tree, &reduction_report, &func.name);
if sched_errors.is_empty() {
checks_passed += 1;
} else {
errors.extend(sched_errors);
}
}
}
LegalityReport {
errors,
checks_passed,
checks_total,
}
}
/// Validates the basic shape of a CFG.
///
/// Reports an error when:
/// - the entry block is not `BlockId(0)`;
/// - a block's id differs from its index in `basic_blocks`;
/// - a terminator names a successor outside `basic_blocks`;
/// - block 0 is not reached by a traversal starting at `cfg.entry`.
///
/// The traversal never indexes out of range, so a CFG with dangling
/// successors yields errors rather than a panic.
///
/// `fn_name` is recorded as the `function` of every error produced.
fn check_cfg_structure(cfg: &MirCfg, fn_name: &str) -> Vec<LegalityError> {
    let mut errors = Vec::new();
    let num_blocks = cfg.basic_blocks.len();
    let mut report = |message: String| {
        errors.push(LegalityError {
            check: LegalityCheck::CfgStructure,
            message,
            function: fn_name.to_string(),
        });
    };

    if cfg.entry != BlockId(0) {
        report(format!(
            "entry block is {:?}, expected BlockId(0)",
            cfg.entry
        ));
    }

    for (i, bb) in cfg.basic_blocks.iter().enumerate() {
        if bb.id.0 as usize != i {
            report(format!(
                "block at index {} has id {:?} (should be {})",
                i, bb.id, i
            ));
        }
    }

    for bb in &cfg.basic_blocks {
        for succ in bb.terminator.successors() {
            if succ.0 as usize >= num_blocks {
                report(format!(
                    "block {:?} has out-of-bounds successor {:?} (num_blocks={})",
                    bb.id, succ, num_blocks
                ));
            }
        }
    }

    // Worklist traversal from the entry block.
    let mut visited = vec![false; num_blocks];
    let mut worklist = vec![cfg.entry];
    while let Some(block) = worklist.pop() {
        let idx = block.0 as usize;
        if idx >= num_blocks || visited[idx] {
            continue;
        }
        visited[idx] = true;
        for succ in cfg.basic_blocks[idx].terminator.successors() {
            // Checked lookup: an out-of-bounds successor was already reported
            // above and must not crash the verifier here.
            if visited.get(succ.0 as usize) == Some(&false) {
                worklist.push(succ);
            }
        }
    }

    if num_blocks > 0 && !visited[0] {
        report("entry block is not reachable (impossible)".to_string());
    }

    errors
}
/// Validates the internal consistency of a loop tree against its CFG.
///
/// For every loop, reports an error when:
/// - the header is missing from `body_blocks`;
/// - a back-edge source is missing from `body_blocks`;
/// - a body block index is outside the CFG's `basic_blocks`;
/// - `body_blocks` is not sorted in non-decreasing order;
/// - the loop names a parent that does not exist, or whose `children` does
///   not list this loop.
///
/// A second pass follows each loop's parent chain and reports a nesting cycle
/// if a loop is reached twice.
///
/// All lookups are bounds-checked, so a malformed tree yields errors rather
/// than a panic. Membership tests fall back to a linear scan when
/// `body_blocks` is unsorted, because binary search gives unspecified results
/// on unsorted input and would produce misleading "not in body_blocks" errors.
///
/// `fn_name` is recorded as the `function` of every error produced.
fn check_loop_integrity(
    loop_tree: &LoopTree,
    cfg: &MirCfg,
    fn_name: &str,
) -> Vec<LegalityError> {
    // Membership test that is only allowed to binary-search sorted input.
    fn in_body<T: Ord>(body: &[T], sorted: bool, needle: &T) -> bool {
        if sorted {
            body.binary_search(needle).is_ok()
        } else {
            body.contains(needle)
        }
    }

    let mut errors = Vec::new();
    let num_blocks = cfg.basic_blocks.len();
    let num_loops = loop_tree.loops.len();
    let mut report = |message: String| {
        errors.push(LegalityError {
            check: LegalityCheck::LoopIntegrity,
            message,
            function: fn_name.to_string(),
        });
    };

    for info in &loop_tree.loops {
        // Computed up front so the membership tests below can rely on it; the
        // corresponding error is still emitted in its original position.
        let is_sorted = info.body_blocks.windows(2).all(|w| w[0] <= w[1]);

        if !in_body(&info.body_blocks[..], is_sorted, &info.header) {
            report(format!(
                "loop {:?} header {:?} not in body_blocks",
                info.id, info.header
            ));
        }

        for &src in &info.back_edge_sources {
            if !in_body(&info.body_blocks[..], is_sorted, &src) {
                report(format!(
                    "loop {:?} back-edge source {:?} not in body_blocks",
                    info.id, src
                ));
            }
        }

        for &block in &info.body_blocks {
            if block.0 as usize >= num_blocks {
                report(format!(
                    "loop {:?} body contains out-of-bounds block {:?}",
                    info.id, block
                ));
            }
        }

        if !is_sorted {
            report(format!(
                "loop {:?} body_blocks not sorted (determinism violation)",
                info.id
            ));
        }

        if let Some(parent) = info.parent {
            match loop_tree.loops.get(parent.0 as usize) {
                Some(parent_info) => {
                    if !parent_info.children.contains(&info.id) {
                        report(format!(
                            "loop {:?} claims parent {:?} but parent doesn't list it as child",
                            info.id, parent
                        ));
                    }
                }
                None => {
                    report(format!(
                        "loop {:?} claims parent {:?} which does not exist (num_loops={})",
                        info.id, parent, num_loops
                    ));
                }
            }
        }
    }

    // Walk each loop's parent chain; seeing a loop twice means a cycle.
    for info in &loop_tree.loops {
        let mut visited = vec![false; num_loops];
        let mut current = Some(info.id);
        while let Some(lid) = current {
            let idx = lid.0 as usize;
            let node = match loop_tree.loops.get(idx) {
                Some(node) => node,
                // Dangling id: the chain ends here. A dangling parent has
                // already been reported by the pass above.
                None => break,
            };
            if visited[idx] {
                report(format!(
                    "loop nesting cycle detected involving {:?}",
                    lid
                ));
                break;
            }
            visited[idx] = true;
            current = node.parent;
        }
    }

    errors
}
/// Checks that each reduction's kind agrees with its own classification.
///
/// Produces a `ReductionContract` error for a `StrictFold` reduction whose
/// kind reports itself reorderable, and for an `Unknown` reduction whose kind
/// reports itself parallelizable. Errors are attributed to the reduction's
/// `function_name`.
fn check_reduction_contracts(report: &ReductionReport) -> Vec<LegalityError> {
    report
        .reductions
        .iter()
        .flat_map(|red| {
            let strict_fold_violation = (red.kind == ReductionKind::StrictFold
                && red.kind.is_reorderable())
            .then(|| LegalityError {
                check: LegalityCheck::ReductionContract,
                message: format!(
                    "StrictFold reduction on `{}` is marked reorderable",
                    red.accumulator_var
                ),
                function: red.function_name.clone(),
            });

            let unknown_violation = (red.kind == ReductionKind::Unknown
                && red.kind.is_parallelizable())
            .then(|| LegalityError {
                check: LegalityCheck::ReductionContract,
                message: format!(
                    "Unknown reduction on `{}` is marked parallelizable",
                    red.accumulator_var
                ),
                function: red.function_name.clone(),
            });

            // StrictFold violation first, then Unknown, per reduction.
            strict_fold_violation.into_iter().chain(unknown_violation)
        })
        .collect()
}
/// Deepest nesting of structured MIR bodies accepted by the structural-bound
/// check; a body at a greater depth is reported as an error.
const MAX_NESTING_DEPTH: u32 = 256;
/// Checks the nesting depth of `func`'s structured body, starting at depth 0,
/// and returns any `StructuralBound` errors found.
fn check_structural_bounds(func: &MirFunction) -> Vec<LegalityError> {
    let mut found: Vec<LegalityError> = Vec::new();
    let root = &func.body;
    check_body_depth(root, 0, &func.name, &mut found);
    found
}
/// Recursively walks `body`, appending a `StructuralBound` error to `errors`
/// when `depth` exceeds `MAX_NESTING_DEPTH`.
///
/// Recursion descends into the bodies of `If` (both branches), `While` and
/// `NoGcBlock` statements, each one level deeper. Once the limit is exceeded
/// the error is recorded and that subtree is not explored further.
fn check_body_depth(
    body: &MirBody,
    depth: u32,
    fn_name: &str,
    errors: &mut Vec<LegalityError>,
) {
    if depth > MAX_NESTING_DEPTH {
        let message = format!(
            "nesting depth {} exceeds maximum {} — possible infinite recursion in MIR",
            depth, MAX_NESTING_DEPTH
        );
        errors.push(LegalityError {
            check: LegalityCheck::StructuralBound,
            message,
            function: fn_name.to_string(),
        });
        return;
    }

    // Every nested body sits exactly one level below the current one.
    let child_depth = depth + 1;
    for stmt in body.stmts.iter() {
        match stmt {
            MirStmt::While { body: loop_body, .. } => {
                check_body_depth(loop_body, child_depth, fn_name, errors);
            }
            MirStmt::NoGcBlock(block) => {
                check_body_depth(block, child_depth, fn_name, errors);
            }
            MirStmt::If {
                then_body,
                else_body,
                ..
            } => {
                check_body_depth(then_body, child_depth, fn_name, errors);
                if let Some(alternative) = else_body {
                    check_body_depth(alternative, child_depth, fn_name, errors);
                }
            }
            _ => {}
        }
    }
}
/// Checks each loop's schedule annotation for internal consistency.
///
/// Produces a `ScheduleMetadataConsistency` error when:
/// - a loop has at least one reduction with `strict_order_required` but its
///   schedule is not `SequentialStrict`;
/// - a `DescriptiveVectorized` width is zero or not a power of two;
/// - a `DescriptiveStaticPartition` chunk size is zero;
/// - a `DescriptiveTiled` tile size is zero.
///
/// `fn_name` is recorded as the `function` of every error produced.
fn check_schedule_metadata(
    loop_tree: &LoopTree,
    reduction_report: &ReductionReport,
    fn_name: &str,
) -> Vec<LegalityError> {
    use crate::loop_analysis::SchedulePlan;

    let mut errors = Vec::new();
    let mut flag = |message: String| {
        errors.push(LegalityError {
            check: LegalityCheck::ScheduleMetadataConsistency,
            message,
            function: fn_name.to_string(),
        });
    };

    for info in loop_tree.loops.iter() {
        // Only the number of strict reductions on this loop is needed.
        let strict_count = reduction_report
            .reductions
            .iter()
            .filter(|r| r.loop_id == Some(info.id) && r.strict_order_required)
            .count();

        if strict_count > 0 && info.schedule != SchedulePlan::SequentialStrict {
            flag(format!(
                "loop {:?} has {} strict reduction(s) but schedule is {} (must be sequential_strict)",
                info.id, strict_count, info.schedule,
            ));
        }

        // A schedule is exactly one variant, so a single match covers the
        // per-variant parameter checks.
        match info.schedule {
            SchedulePlan::DescriptiveVectorized { width } => {
                if width == 0 || (width & (width - 1)) != 0 {
                    flag(format!(
                        "loop {:?} has DescriptiveVectorized width {} which is not a power of 2",
                        info.id, width,
                    ));
                }
            }
            SchedulePlan::DescriptiveStaticPartition { chunk_size } => {
                if chunk_size == 0 {
                    flag(format!(
                        "loop {:?} has DescriptiveStaticPartition with chunk_size 0",
                        info.id,
                    ));
                }
            }
            SchedulePlan::DescriptiveTiled { tile_size } => {
                if tile_size == 0 {
                    flag(format!(
                        "loop {:?} has DescriptiveTiled with tile_size 0",
                        info.id,
                    ));
                }
            }
            _ => {}
        }
    }

    errors
}
/// Checks that each reduction's flags agree with its kind.
///
/// Produces a `MetadataConsistency` error when:
/// - a `StrictFold` reduction has `reassociation_forbidden == false`;
/// - a `StrictFold` reduction has `strict_order_required == false`;
/// - an `Unknown` reduction has `reassociation_forbidden == false`.
///
/// Other kinds are not checked. Errors are attributed to the reduction's
/// `function_name`.
fn check_reduction_metadata_consistency(report: &ReductionReport) -> Vec<LegalityError> {
    let mut errors = Vec::new();

    for red in report.reductions.iter() {
        let mut flag = |message: String| {
            errors.push(LegalityError {
                check: LegalityCheck::MetadataConsistency,
                message,
                function: red.function_name.clone(),
            });
        };

        match red.kind {
            ReductionKind::StrictFold => {
                if !red.reassociation_forbidden {
                    flag(format!(
                        "StrictFold reduction on `{}` has reassociation_forbidden=false",
                        red.accumulator_var,
                    ));
                }
                if !red.strict_order_required {
                    flag(format!(
                        "StrictFold reduction on `{}` has strict_order_required=false",
                        red.accumulator_var,
                    ));
                }
            }
            ReductionKind::Unknown if !red.reassociation_forbidden => {
                flag(format!(
                    "Unknown reduction on `{}` has reassociation_forbidden=false",
                    red.accumulator_var,
                ));
            }
            _ => {}
        }
    }

    errors
}
/// Convenience wrapper that verifies a single function.
///
/// Clones `func` into a temporary program with no struct or enum definitions
/// and `func.id` as the entry, then runs [`verify_mir_legality`] on it.
pub fn verify_function(func: &MirFunction) -> LegalityReport {
    let functions = vec![func.clone()];
    let single_fn_program = MirProgram {
        functions,
        struct_defs: Vec::new(),
        enum_defs: Vec::new(),
        entry: func.id,
    };
    verify_mir_legality(&single_fn_program)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the legality verifier: accepting paths for well-formed
    //! MIR plus rejecting paths for the checks that can be driven with
    //! hand-built inputs.

    use super::*;
    use crate::{MirBody, MirExpr, MirExprKind, MirFnId, MirStmt};
    use cjc_ast::Visibility;

    /// Integer literal expression.
    fn int_expr(v: i64) -> MirExpr {
        MirExpr {
            kind: MirExprKind::IntLit(v),
        }
    }

    /// Boolean literal expression.
    fn bool_expr(b: bool) -> MirExpr {
        MirExpr {
            kind: MirExprKind::BoolLit(b),
        }
    }

    /// Minimal private function with the given name and body and no CFG.
    fn make_fn(name: &str, body: MirBody) -> MirFunction {
        MirFunction {
            id: MirFnId(0),
            name: name.to_string(),
            type_params: vec![],
            params: vec![],
            return_type: None,
            body,
            is_nogc: false,
            cfg_body: None,
            decorators: vec![],
            vis: Visibility::Private,
            local_count: 0,
        }
    }

    /// Program whose entry is the last function (or `MirFnId(0)` when empty).
    fn make_program(functions: Vec<MirFunction>) -> MirProgram {
        let entry = functions.last().map(|f| f.id).unwrap_or(MirFnId(0));
        MirProgram {
            functions,
            struct_defs: vec![],
            enum_defs: vec![],
            entry,
        }
    }

    /// A single expression statement wrapped in `levels` nested `if` bodies,
    /// so the innermost body sits at depth `levels`.
    fn nested_ifs(levels: u32) -> MirBody {
        let mut body = MirBody {
            stmts: vec![MirStmt::Expr(int_expr(1))],
            result: None,
        };
        for _ in 0..levels {
            body = MirBody {
                stmts: vec![MirStmt::If {
                    cond: bool_expr(true),
                    then_body: body,
                    else_body: None,
                }],
                result: None,
            };
        }
        body
    }

    /// `StrictFold` reduction on `acc` in function `test` with the given flags.
    fn strict_fold_info(
        reassociation_forbidden: bool,
        strict_order_required: bool,
    ) -> reduction::ReductionInfo {
        reduction::ReductionInfo {
            id: reduction::ReductionId(0),
            accumulator_var: "acc".to_string(),
            op: reduction::ReductionOp::Add,
            kind: ReductionKind::StrictFold,
            loop_id: None,
            function_name: "test".to_string(),
            builtin_name: None,
            reassociation_forbidden,
            strict_order_required,
            accumulator_semantics: reduction::AccumulatorSemantics::Plain,
        }
    }

    #[test]
    fn test_empty_program_passes() {
        let program = make_program(vec![make_fn(
            "__main",
            MirBody {
                stmts: vec![],
                result: None,
            },
        )]);
        let report = verify_mir_legality(&program);
        assert!(report.is_ok(), "empty program should pass: {:?}", report.errors);
    }

    #[test]
    fn test_simple_while_passes() {
        let mut func = make_fn(
            "test",
            MirBody {
                stmts: vec![MirStmt::While {
                    cond: bool_expr(true),
                    body: MirBody {
                        stmts: vec![MirStmt::Expr(int_expr(1))],
                        result: None,
                    },
                }],
                result: None,
            },
        );
        func.build_cfg();
        let program = make_program(vec![func]);
        let report = verify_mir_legality(&program);
        assert!(report.is_ok(), "simple while should pass: {:?}", report.errors);
    }

    #[test]
    fn test_cfg_structure_valid() {
        let mut func = make_fn(
            "test",
            MirBody {
                stmts: vec![
                    MirStmt::Let {
                        name: "x".into(),
                        mutable: false,
                        init: int_expr(42),
                        alloc_hint: None,
                        slot: None,
                    },
                    MirStmt::If {
                        cond: bool_expr(true),
                        then_body: MirBody {
                            stmts: vec![],
                            result: Some(Box::new(int_expr(1))),
                        },
                        else_body: Some(MirBody {
                            stmts: vec![],
                            result: Some(Box::new(int_expr(2))),
                        }),
                    },
                ],
                result: None,
            },
        );
        func.build_cfg();
        let cfg = func.cfg_body.as_ref().unwrap();
        let errors = check_cfg_structure(cfg, "test");
        assert!(errors.is_empty(), "valid CFG should have no errors: {:?}", errors);
    }

    #[test]
    fn test_reasonable_nesting_passes() {
        let func = make_fn("test", nested_ifs(10));
        let errors = check_structural_bounds(&func);
        assert!(errors.is_empty(), "10-level nesting should pass");
    }

    #[test]
    fn test_nesting_at_limit_passes() {
        // Innermost body is at depth MAX_NESTING_DEPTH, which is still allowed.
        let func = make_fn("test", nested_ifs(MAX_NESTING_DEPTH));
        let errors = check_structural_bounds(&func);
        assert!(errors.is_empty(), "nesting at the limit should pass: {:?}", errors);
    }

    #[test]
    fn test_excessive_nesting_rejected() {
        // One level past the limit must be reported exactly once.
        let func = make_fn("deep", nested_ifs(MAX_NESTING_DEPTH + 1));
        let errors = check_structural_bounds(&func);
        assert_eq!(errors.len(), 1, "expected one error: {:?}", errors);
        assert_eq!(errors[0].check, LegalityCheck::StructuralBound);
        assert_eq!(errors[0].function, "deep");
    }

    #[test]
    fn test_strict_fold_is_consistent() {
        let report = ReductionReport {
            reductions: vec![strict_fold_info(true, true)],
        };
        let errors = check_reduction_contracts(&report);
        assert!(errors.is_empty(), "consistent StrictFold should pass");
        let meta_errors = check_reduction_metadata_consistency(&report);
        assert!(
            meta_errors.is_empty(),
            "consistent StrictFold metadata should pass: {:?}",
            meta_errors
        );
    }

    #[test]
    fn test_strict_fold_inconsistent_metadata_rejected() {
        let report = ReductionReport {
            reductions: vec![strict_fold_info(false, false)],
        };
        let errors = check_reduction_metadata_consistency(&report);
        // One error per violated flag.
        assert_eq!(errors.len(), 2, "expected two errors: {:?}", errors);
        assert!(errors
            .iter()
            .all(|e| e.check == LegalityCheck::MetadataConsistency));
        assert!(errors.iter().all(|e| e.function == "test"));
    }

    #[test]
    fn test_loop_integrity_nested() {
        let mut func = make_fn(
            "test",
            MirBody {
                stmts: vec![MirStmt::While {
                    cond: bool_expr(true),
                    body: MirBody {
                        stmts: vec![MirStmt::While {
                            cond: bool_expr(true),
                            body: MirBody {
                                stmts: vec![MirStmt::Expr(int_expr(1))],
                                result: None,
                            },
                        }],
                        result: None,
                    },
                }],
                result: None,
            },
        );
        func.build_cfg();
        let cfg = func.cfg_body.as_ref().unwrap();
        let domtree = DominatorTree::compute(cfg);
        let loop_tree = loop_analysis::compute_loop_tree(cfg, &domtree);
        let errors = check_loop_integrity(&loop_tree, cfg, "test");
        assert!(
            errors.is_empty(),
            "well-formed nested loops should pass: {:?}",
            errors
        );
    }

    #[test]
    fn test_verify_function_convenience() {
        let func = make_fn(
            "simple",
            MirBody {
                stmts: vec![MirStmt::Let {
                    name: "x".into(),
                    mutable: false,
                    init: int_expr(42),
                    alloc_hint: None,
                    slot: None,
                }],
                result: Some(Box::new(int_expr(42))),
            },
        );
        let report = verify_function(&func);
        assert!(report.is_ok());
        assert!(report.checks_passed > 0);
    }

    #[test]
    fn test_report_structure() {
        let program = make_program(vec![make_fn(
            "__main",
            MirBody {
                stmts: vec![],
                result: None,
            },
        )]);
        let report = verify_mir_legality(&program);
        assert!(report.is_ok());
        assert!(report.checks_total > 0);
        assert_eq!(report.checks_passed, report.checks_total);
        assert!(report.errors().is_empty());
    }

    #[test]
    fn test_report_with_errors_is_not_ok() {
        let report = LegalityReport {
            errors: vec![LegalityError {
                check: LegalityCheck::CfgStructure,
                message: "boom".to_string(),
                function: "f".to_string(),
            }],
            checks_passed: 0,
            checks_total: 1,
        };
        assert!(!report.is_ok());
        assert_eq!(report.errors().len(), 1);
    }

    #[test]
    fn test_legality_error_display() {
        let err = LegalityError {
            check: LegalityCheck::StructuralBound,
            message: "too deep".to_string(),
            function: "f".to_string(),
        };
        assert_eq!(err.to_string(), "[StructuralBound] in `f`: too deep");
    }
}