use crate::op::Op;
use crate::program::Function;
/// A stack-depth inconsistency found while verifying a function's bytecode.
///
/// Reported when one instruction is reached with two different operand-stack
/// depths.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StackError {
    /// Name of the function in which the mismatch was found.
    pub fn_name: String,
    /// Index of the instruction at which the two depths disagree.
    pub pc: usize,
    /// One of the two conflicting depths (reported as "path A").
    pub depth_a: i32,
    /// The other conflicting depth (reported as "path B").
    pub depth_b: i32,
}

impl std::fmt::Display for StackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "stack depth mismatch in `{}` at pc {}: path A leaves depth {}, path B leaves depth {}",
            self.fn_name, self.pc, self.depth_a, self.depth_b
        )
    }
}

// Lets callers box the error as `dyn Error` or propagate it with `?`.
impl std::error::Error for StackError {}
/// Runs [`verify_function`] over every entry of `functions` and returns all
/// stack-depth mismatches that were reported, in the order the functions
/// were visited.
pub fn verify_program(functions: &[Function]) -> Vec<StackError> {
    let mut found = Vec::new();
    functions
        .iter()
        .for_each(|func| verify_function(func, &mut found));
    found
}
/// Checks that every reachable instruction of `func` is always entered with
/// the same operand-stack depth, appending a [`StackError`] to `errors` for
/// each disagreement.
///
/// The walk starts at instruction 0 with depth 0 and follows fall-through and
/// jump edges. `Return`, `TailCall` and `Panic` end a path. Edges that leave
/// the code (at or past its end, or to an unrepresentable address) are
/// ignored rather than reported.
pub fn verify_function(func: &Function, errors: &mut Vec<StackError>) {
    let n = func.code.len();
    if n == 0 {
        return;
    }
    // Depth recorded the first time each instruction is reached.
    let mut depths: Vec<Option<i32>> = vec![None; n];
    let mut worklist: Vec<(usize, i32)> = vec![(0, 0)];
    while let Some((pc, depth)) = worklist.pop() {
        if pc >= n {
            continue;
        }
        if let Some(prev) = depths[pc] {
            // Already visited: only check consistency, never re-expand, so
            // loops in the control flow terminate.
            if prev != depth {
                errors.push(StackError {
                    fn_name: func.name.clone(),
                    pc,
                    depth_a: prev,
                    depth_b: depth,
                });
            }
            continue;
        }
        depths[pc] = Some(depth);
        let op = &func.code[pc];
        let next_depth = depth + stack_delta(op);
        match op {
            Op::Jump(off) => {
                if let Some(target) = jump_target(pc, *off) {
                    worklist.push((target, next_depth));
                }
            }
            Op::JumpIf(off) | Op::JumpIfNot(off) => {
                // Fall-through is pushed first so the branch target is
                // explored first (the worklist is LIFO).
                worklist.push((pc + 1, next_depth));
                if let Some(target) = jump_target(pc, *off) {
                    worklist.push((target, next_depth));
                }
            }
            Op::Return | Op::TailCall { .. } | Op::Panic(_) => {}
            _ => {
                worklist.push((pc + 1, next_depth));
            }
        }
    }
}

/// Resolves a relative jump: the target is `pc + 1 + off`.
///
/// Returns `None` when the target would be negative or would not fit in
/// `usize`. Checked arithmetic is used so malformed offsets cannot overflow
/// or wrap around to a bogus address.
fn jump_target(pc: usize, off: i32) -> Option<usize> {
    let base = pc.checked_add(1)?;
    if off >= 0 {
        // Non-negative `i32` always fits in `usize` on supported targets.
        base.checked_add(off as usize)
    } else {
        base.checked_sub(off.unsigned_abs() as usize)
    }
}
fn stack_delta(op: &Op) -> i32 {
match op {
Op::PushConst(_) => 1,
Op::Pop => -1,
Op::Dup => 1,
Op::LoadLocal(_) => 1,
Op::StoreLocal(_) => -1,
Op::MakeRecord { field_name_indices } => -(field_name_indices.len() as i32) + 1,
Op::MakeTuple(n) => -(*n as i32) + 1,
Op::MakeList(n) => -(*n as i32) + 1,
Op::MakeVariant { arity, .. } => -(*arity as i32) + 1,
Op::GetField(_) => 0,
Op::GetElem(_) => 0,
Op::GetListElem(_) => 0,
Op::GetListLen => 0,
Op::TestVariant(_) => 0,
Op::GetVariant(_) => 0,
Op::GetVariantArg(_)=> 0,
Op::ListAppend => -1,
Op::GetListElemDyn => -1,
Op::Jump(_) | Op::JumpIf(_) | Op::JumpIfNot(_) => {
match op {
Op::JumpIf(_) | Op::JumpIfNot(_) => -1,
_ => 0,
}
}
Op::Call { arity, .. } => -(*arity as i32) + 1,
Op::TailCall { arity, .. } => -(*arity as i32) + 1,
Op::CallClosure { arity, .. }=> -(*arity as i32 + 1) + 1, Op::EffectCall { arity, .. } => -(*arity as i32) + 1,
Op::MakeClosure { capture_count, .. } => -(*capture_count as i32) + 1,
Op::SortByKey { .. } => -1,
Op::ParallelMap { .. }=> -1,
Op::Return => -1, Op::Panic(_)=> 0,
Op::IntAdd | Op::IntSub | Op::IntMul | Op::IntDiv | Op::IntMod => -1,
Op::IntEq | Op::IntLt | Op::IntLe => -1,
Op::IntNeg => 0,
Op::FloatAdd | Op::FloatSub | Op::FloatMul | Op::FloatDiv => -1,
Op::FloatEq | Op::FloatLt | Op::FloatLe => -1,
Op::FloatNeg => 0,
Op::NumAdd | Op::NumSub | Op::NumMul | Op::NumDiv | Op::NumMod => -1,
Op::NumEq | Op::NumLt | Op::NumLe => -1,
Op::NumNeg => 0,
Op::BoolAnd | Op::BoolOr => -1,
Op::BoolNot => 0,
Op::StrConcat => -1,
Op::StrLen => 0,
Op::StrEq => -1,
Op::BytesLen => 0,
Op::BytesEq => -1,
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op::Op;
    use crate::program::Function;

    /// Wraps `code` in a zero-arity `Function` called `name`, with four
    /// locals and every other field left at an empty/zero value.
    fn make_fn(name: &str, code: Vec<Op>) -> Function {
        Function {
            name: String::from(name),
            code,
            arity: 0,
            locals_count: 4,
            effects: vec![],
            refinements: vec![],
            body_hash: crate::program::ZERO_BODY_HASH,
        }
    }

    /// Both arms of the branch reach the final `Return` with depth 1.
    #[test]
    fn clean_match_no_errors() {
        let f = make_fn(
            "clean",
            vec![
                Op::LoadLocal(0),  // pc 0: depth 0 -> 1
                Op::Dup,           // pc 1: depth 1 -> 2
                Op::TestVariant(0), // pc 2: depth 2 -> 2
                Op::JumpIfNot(3),  // pc 3: depth 2 -> 1, branches to pc 7
                Op::Pop,           // pc 4: depth 1 -> 0
                Op::PushConst(0),  // pc 5: depth 0 -> 1
                Op::Jump(2),       // pc 6: jumps to pc 9 with depth 1
                Op::Pop,           // pc 7: depth 1 -> 0
                Op::PushConst(1),  // pc 8: depth 0 -> 1
                Op::Return,        // pc 9: entered with depth 1 on both paths
            ],
        );
        let mut errs = Vec::new();
        verify_function(&f, &mut errs);
        assert!(errs.is_empty(), "expected no errors, got: {errs:?}");
    }

    /// One arm pushes a single value and the other pushes two, so the paths
    /// disagree at the shared `Return`.
    #[test]
    fn leaked_scrutinee_detected() {
        let unbalanced = make_fn(
            "mismatch",
            vec![
                Op::PushConst(0), // pc 0: depth 0 -> 1
                Op::JumpIfNot(2), // pc 1: depth 1 -> 0, branches to pc 4
                Op::PushConst(0), // pc 2: depth 0 -> 1
                Op::Jump(2),      // pc 3: jumps to pc 6 with depth 1
                Op::PushConst(0), // pc 4: depth 0 -> 1
                Op::PushConst(0), // pc 5: depth 1 -> 2
                Op::Return,       // pc 6: entered with depth 1 and depth 2
            ],
        );
        let mut found = Vec::new();
        verify_function(&unbalanced, &mut found);
        assert!(!found.is_empty(), "expected stack mismatch error");
        assert_eq!(found[0].fn_name, "mismatch");
    }
}