use std::sync::Arc;
use morok_ir::UOp;
use crate::z3::convert::Z3Context;
/// Outcome of an equivalence check: `Ok(())` when Z3 proved the two expressions equal.
pub type VerificationResult = Result<(), CounterExample>;

/// Reasons an equivalence check did not succeed.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare results directly,
/// e.g. `assert_eq!(result, Err(CounterExample::Timeout))`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CounterExample {
    /// Z3 found an assignment under which the two expressions differ.
    ///
    /// `message` describes the comparison; `model` is the solver's model rendered as
    /// text (or a placeholder string when no model was available).
    Found { message: String, model: String },
    /// The solver answered `unknown`; reported to users as a timeout.
    Timeout,
    /// An expression could not be translated into Z3, or the translated expressions
    /// could not be compared. The payload is a human-readable reason.
    ConversionFailed(String),
}

impl std::fmt::Display for CounterExample {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Found { message, model } => {
                write!(f, "Counterexample found: {}\nModel: {}", message, model)
            }
            // Constant text: no formatting machinery needed.
            Self::Timeout => f.write_str("Z3 timeout or unknown result"),
            Self::ConversionFailed(reason) => write!(f, "Conversion failed: {}", reason),
        }
    }
}

impl std::error::Error for CounterExample {}
/// Checks with Z3 that `original` and `simplified` are semantically equivalent.
///
/// Both expressions are converted through one [`Z3Context`]. The solver is then asked
/// whether `original != simplified` is satisfiable. An unsatisfiable query proves
/// equivalence. Integer-sorted pairs are tried first, then boolean-sorted pairs.
///
/// # Errors
///
/// - [`CounterExample::ConversionFailed`] if either expression fails to convert, or if
///   the converted expressions are not both integers or both booleans.
/// - [`CounterExample::Found`] if Z3 finds an assignment under which they differ. The
///   message names the ops of both expressions; `model` holds the solver's model.
/// - [`CounterExample::Timeout`] if Z3 answers `unknown`.
pub fn verify_equivalence(original: &Arc<UOp>, simplified: &Arc<UOp>) -> VerificationResult {
    let mut z3ctx = Z3Context::new();

    let z3_original = z3ctx.convert_uop(original).map_err(|e| {
        CounterExample::ConversionFailed(format!("Failed to convert original: {}", e))
    })?;
    let z3_simplified = z3ctx.convert_uop(simplified).map_err(|e| {
        CounterExample::ConversionFailed(format!("Failed to convert simplified: {}", e))
    })?;

    // Build the negated equality once, whatever the sort. A single solver query and a
    // single counterexample format then serve both integer and boolean expressions.
    let (not_equal, kind) =
        if let (Some(o), Some(s)) = (z3_original.as_int(), z3_simplified.as_int()) {
            (o.eq(s).not(), "Expressions")
        } else if let (Some(o), Some(s)) = (z3_original.as_bool(), z3_simplified.as_bool()) {
            (o.eq(s).not(), "Boolean expressions")
        } else {
            return Err(CounterExample::ConversionFailed(
                "Type mismatch: cannot compare expressions".to_string(),
            ));
        };

    let solver = z3ctx.solver();
    solver.assert(not_equal);
    match solver.check() {
        // No assignment makes the expressions differ, so they are equivalent.
        z3::SatResult::Unsat => Ok(()),
        z3::SatResult::Sat => {
            let model = solver
                .get_model()
                .map(|m| m.to_string())
                .unwrap_or_else(|| "No model available".to_string());
            Err(CounterExample::Found {
                message: format!(
                    "{} not equivalent:\nOriginal: {:?}\nSimplified: {:?}",
                    kind,
                    original.op(),
                    simplified.op()
                ),
                model,
            })
        }
        z3::SatResult::Unknown => Err(CounterExample::Timeout),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use morok_dtype::DType;
    use morok_ir::types::ConstValue;

    #[test]
    fn test_verify_identity_add_zero() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let zero = UOp::const_(DType::Int32, ConstValue::Int(0));
        let x_plus_zero = x.try_add(&zero).unwrap();
        verify_equivalence(&x_plus_zero, &x).expect("x + 0 should equal x");
    }

    #[test]
    fn test_verify_commutativity() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let y = UOp::var("y", DType::Int32, 0, 100);
        let x_plus_y = x.try_add(&y).unwrap();
        let y_plus_x = y.try_add(&x).unwrap();
        verify_equivalence(&x_plus_y, &y_plus_x).expect("x + y should equal y + x");
    }

    #[test]
    fn test_verify_detect_inequality() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let one = UOp::const_(DType::Int32, ConstValue::Int(1));
        let x_plus_one = x.try_add(&one).unwrap();
        // Require a genuine counterexample. A conversion failure or solver timeout is
        // also an `Err`, but would not demonstrate that `x + 1 != x`.
        match verify_equivalence(&x_plus_one, &x) {
            Err(CounterExample::Found { message, model }) => {
                tracing::debug!(message = %message, model = %model, "z3 counterexample found");
            }
            other => panic!(
                "x + 1 should not equal x (expected a counterexample), got {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_verify_self_folding() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let x_minus_x = x.try_sub(&x).unwrap();
        let zero = UOp::const_(DType::Int32, ConstValue::Int(0));
        verify_equivalence(&x_minus_x, &zero).expect("x - x should equal 0");
    }

    #[test]
    fn test_verify_detect_wrong_constant_fold() {
        let x = UOp::var("x", DType::Int32, 0, 100);
        let x_minus_x = x.try_sub(&x).unwrap();
        let one = UOp::const_(DType::Int32, ConstValue::Int(1));
        // An unsound fold must surface as a counterexample, not as any other error.
        let result = verify_equivalence(&x_minus_x, &one);
        assert!(
            matches!(result, Err(CounterExample::Found { .. })),
            "x - x should not equal 1 (expected a counterexample), got {:?}",
            result
        );
    }
}