use std::sync::Arc;
use rustc_hash::FxHashMap;
use crate::eq::{Equation, Term};
use crate::error::GatError;
use crate::model::{Model, ModelValue};
use crate::theory::Theory;
use crate::typecheck::infer_var_sorts;
/// A single counterexample showing that a model does not satisfy an equation.
///
/// One value of this type is produced for every variable assignment under
/// which the two sides of an equation evaluate to different model values.
#[derive(Debug, Clone)]
pub struct EquationViolation {
/// Name of the equation that failed (copied from the equation's `name`).
pub equation: Arc<str>,
/// The variable assignment under which the two sides disagree.
/// Empty when the equation mentions no variables.
pub assignment: FxHashMap<Arc<str>, ModelValue>,
/// Value the left-hand side evaluated to under `assignment`.
pub lhs_value: ModelValue,
/// Value the right-hand side evaluated to under `assignment`.
pub rhs_value: ModelValue,
}
/// Tuning knobs for model checking.
#[derive(Debug, Clone)]
pub struct CheckModelOptions {
    /// Upper bound on the number of variable assignments enumerated for a
    /// single equation. Checking an equation that would need more than this
    /// many assignments fails with an error instead of running.
    ///
    /// A value of `0` disables the limit.
    pub max_assignments: usize,
}

impl Default for CheckModelOptions {
    /// Returns options that allow up to 10 000 assignments per equation.
    fn default() -> Self {
        const DEFAULT_MAX_ASSIGNMENTS: usize = 10_000;
        CheckModelOptions {
            max_assignments: DEFAULT_MAX_ASSIGNMENTS,
        }
    }
}
/// Checks every equation of `theory` against `model` using
/// [`CheckModelOptions::default`].
///
/// Returns the list of violations found; an empty list means no equation
/// was observed to fail.
///
/// # Errors
///
/// Propagates any error returned by [`check_model_with_options`].
pub fn check_model(model: &Model, theory: &Theory) -> Result<Vec<EquationViolation>, GatError> {
    let defaults = CheckModelOptions::default();
    check_model_with_options(model, theory, &defaults)
}
/// Checks every equation of `theory` against `model`, honouring `options`.
///
/// Equations are visited in the order they appear in `theory.eqs`, and the
/// violations of each equation are appended to the result in that order.
///
/// # Errors
///
/// Stops at, and returns, the first error produced while checking an
/// equation; violations gathered up to that point are discarded.
pub fn check_model_with_options(
    model: &Model,
    theory: &Theory,
    options: &CheckModelOptions,
) -> Result<Vec<EquationViolation>, GatError> {
    theory.eqs.iter().try_fold(
        Vec::new(),
        |mut collected: Vec<EquationViolation>,
         equation|
         -> Result<Vec<EquationViolation>, GatError> {
            collected.extend(check_equation(model, equation, theory, options)?);
            Ok(collected)
        },
    )
}
/// Exhaustively checks a single equation against `model`.
///
/// Every variable of `eq` (as reported by `infer_var_sorts`) ranges over the
/// carrier set that `model` registers for the head of its sort. Both sides of
/// the equation are evaluated under each combination of carrier elements, and
/// every combination on which they differ is reported.
///
/// If any variable's carrier is empty the equation holds vacuously and no
/// evaluation takes place.
///
/// # Errors
///
/// Returns an error when
/// * `infer_var_sorts` fails,
/// * the model has no carrier set for a variable's sort,
/// * the number of assignments exceeds `options.max_assignments`
///   (a limit of `0` means "unlimited"), or
/// * evaluating either side of the equation fails.
fn check_equation(
    model: &Model,
    eq: &Equation,
    theory: &Theory,
    options: &CheckModelOptions,
) -> Result<Vec<EquationViolation>, GatError> {
    let var_sorts = infer_var_sorts(eq, theory)?;

    // Pair each variable with the slice of values it ranges over.
    let mut domains: Vec<(Arc<str>, &[ModelValue])> = Vec::new();
    for (var, sort) in var_sorts.iter() {
        let head = sort.head();
        let carrier = match model.sort_interp.get(head.as_ref()) {
            Some(values) => values,
            None => {
                return Err(GatError::ModelError(format!(
                    "no carrier set for sort '{sort}'"
                )))
            }
        };
        domains.push((Arc::clone(var), carrier.as_slice()));
    }

    // Quantifying over an empty set: nothing can go wrong.
    for (_, values) in &domains {
        if values.is_empty() {
            return Ok(Vec::new());
        }
    }

    // Size of the cartesian product, clamped to `usize::MAX` on overflow.
    // With no variables this is the empty product, 1.
    let total: usize = domains
        .iter()
        .try_fold(1usize, |product, (_, values)| {
            product.checked_mul(values.len())
        })
        .unwrap_or(usize::MAX);

    let limit = options.max_assignments;
    if limit != 0 && total > limit {
        return Err(GatError::ModelError(format!(
            "equation '{}' requires {total} assignments, exceeding limit {}",
            eq.name, options.max_assignments
        )));
    }

    // Odometer over the carriers. A variable-free equation needs no special
    // handling: the cursor is empty, the body runs once with an empty
    // assignment, and `increment_indices` immediately reports exhaustion.
    let mut cursor = vec![0usize; domains.len()];
    let mut found = Vec::new();
    loop {
        let mut assignment: FxHashMap<Arc<str>, ModelValue> = FxHashMap::default();
        for ((var, values), &position) in domains.iter().zip(cursor.iter()) {
            assignment.insert(Arc::clone(var), values[position].clone());
        }

        let lhs_value = eval_term(&eq.lhs, &assignment, model)?;
        let rhs_value = eval_term(&eq.rhs, &assignment, model)?;
        if lhs_value != rhs_value {
            found.push(EquationViolation {
                equation: Arc::clone(&eq.name),
                assignment,
                lhs_value,
                rhs_value,
            });
        }

        if !increment_indices(&mut cursor, &domains) {
            return Ok(found);
        }
    }
}
/// Evaluates `term` in `model` under the variable bindings in `assignment`.
///
/// * A variable yields a clone of its bound value.
/// * An application evaluates its arguments left to right and hands the
///   results to `model.eval`.
/// * A `let` evaluates the bound term, then evaluates the body under a copy
///   of `assignment` extended with (or overriding) that binding.
///
/// # Errors
///
/// Returns an error when a variable is missing from `assignment`, when
/// `model.eval` fails, or when the term contains a `case` expression or a
/// typed hole, neither of which can be evaluated here.
fn eval_term(
    term: &Term,
    assignment: &FxHashMap<Arc<str>, ModelValue>,
    model: &Model,
) -> Result<ModelValue, GatError> {
    match term {
        Term::Var(name) => match assignment.get(name) {
            Some(value) => Ok(value.clone()),
            None => Err(GatError::ModelError(format!(
                "variable '{name}' not in assignment"
            ))),
        },
        Term::App { op, args } => {
            // Stop at the first argument that fails to evaluate.
            let mut arg_vals: Vec<ModelValue> = Vec::with_capacity(args.len());
            for arg in args.iter() {
                arg_vals.push(eval_term(arg, assignment, model)?);
            }
            model.eval(op, &arg_vals)
        }
        Term::Let { name, bound, body } => {
            let bound_value = eval_term(bound, assignment, model)?;
            // The caller's assignment stays untouched; the body sees a copy.
            let mut scope = assignment.clone();
            scope.insert(Arc::clone(name), bound_value);
            eval_term(body, &scope, model)
        }
        Term::Case { .. } => Err(GatError::ModelError(
            "case terms are not yet supported in set-theoretic model evaluation".to_string(),
        )),
        Term::Hole { .. } => Err(GatError::ModelError(
            "typed holes cannot be evaluated in a set-theoretic model".to_string(),
        )),
    }
}
/// Advances `indices` to the next combination, odometer style.
///
/// The last position moves fastest: it is bumped first, and whenever a
/// position reaches the length of its carrier it wraps to `0` and the carry
/// moves one position to the left.
///
/// Returns `true` if `indices` now holds a fresh combination, or `false` if
/// every position wrapped (leaving `indices` all zeros), meaning the
/// enumeration is complete. An empty `indices` slice always yields `false`.
fn increment_indices(indices: &mut [usize], var_carriers: &[(Arc<str>, &[ModelValue])]) -> bool {
    let mut position = indices.len();
    while position > 0 {
        position -= 1;
        let slot = &mut indices[position];
        *slot += 1;
        if *slot < var_carriers[position].1.len() {
            return true;
        }
        // This digit overflowed: reset it and carry into the next one.
        *slot = 0;
    }
    false
}
// Unit tests for the brute-force model checker: a monoid theory is checked
// against several hand-built models, plus edge cases (empty carrier,
// variable-free equation, assignment limit, missing carrier set).
#[cfg(test)]
mod tests {
use super::*;
use crate::eq::Equation;
use crate::model::Model;
use crate::op::Operation;
use crate::sort::Sort;
use crate::theory::Theory;
// Builds the theory "Monoid": one sort `Carrier`, a binary `mul`, a nullary
// `unit`, and the equations `assoc`, `left_id` and `right_id`.
fn monoid_theory() -> Theory {
Theory::new(
"Monoid",
vec![Sort::simple("Carrier")],
vec![
Operation::new(
"mul",
vec![
("a".into(), "Carrier".into()),
("b".into(), "Carrier".into()),
],
"Carrier",
),
Operation::nullary("unit", "Carrier"),
],
vec![
// assoc: mul(a, mul(b, c)) = mul(mul(a, b), c)
Equation::new(
"assoc",
Term::app(
"mul",
vec![
Term::var("a"),
Term::app("mul", vec![Term::var("b"), Term::var("c")]),
],
),
Term::app(
"mul",
vec![
Term::app("mul", vec![Term::var("a"), Term::var("b")]),
Term::var("c"),
],
),
),
// left_id: mul(unit, a) = a
Equation::new(
"left_id",
Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
Term::var("a"),
),
// right_id: mul(a, unit) = a
Equation::new(
"right_id",
Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
Term::var("a"),
),
],
)
}
// Integers 0..5 with addition modulo 5 as `mul` and 0 as `unit`.
fn valid_z5_model() -> Model {
let mut model = Model::new("Monoid");
model.add_sort("Carrier", (0..5).map(ModelValue::Int).collect());
model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
(ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((a + b) % 5)),
_ => Err(GatError::ModelError("expected Int".into())),
});
model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
model
}
// A genuine monoid must produce no violations.
#[test]
fn valid_model_passes() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let model = valid_z5_model();
let violations = check_model(&model, &theory)?;
assert!(
violations.is_empty(),
"expected no violations, got {violations:?}"
);
Ok(())
}
// With `unit` interpreted as 1, an identity law must be reported.
#[test]
fn broken_identity_detected() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let mut model = valid_z5_model();
// Presumably replaces the `unit` registered by `valid_z5_model`; the
// assertions below rely on that.
model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(1)));
let violations = check_model(&model, &theory)?;
assert!(!violations.is_empty(), "expected violations");
let has_identity_violation = violations
.iter()
.any(|v| v.equation.as_ref() == "left_id" || v.equation.as_ref() == "right_id");
assert!(has_identity_violation);
Ok(())
}
// Subtraction clamped at zero is not associative
// (e.g. 1 - (1 - 1) = 1 but (1 - 1) - 1 clamps to 0).
#[test]
fn broken_associativity_detected() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let mut model = Model::new("Monoid");
model.add_sort(
"Carrier",
vec![ModelValue::Int(0), ModelValue::Int(1), ModelValue::Int(2)],
);
model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
(ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((*a - *b).max(0))),
_ => Err(GatError::ModelError("expected Int".into())),
});
model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
let violations = check_model(&model, &theory)?;
let has_assoc = violations.iter().any(|v| v.equation.as_ref() == "assoc");
assert!(has_assoc, "expected associativity violation");
Ok(())
}
// With an empty carrier every equation holds vacuously; `mul` would return
// an error if it were ever called, so a pass also shows it is not evaluated.
#[test]
fn empty_carrier_passes() -> Result<(), Box<dyn std::error::Error>> {
let theory = monoid_theory();
let mut model = Model::new("Monoid");
model.add_sort("Carrier", vec![]);
model.add_op("mul", |_: &[ModelValue]| {
Err(GatError::ModelError("unreachable".into()))
});
model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
let violations = check_model(&model, &theory)?;
assert!(violations.is_empty());
Ok(())
}
// An equation without variables is evaluated exactly once: it passes while
// both constants are 0 and yields a single violation once `b` becomes 1.
#[test]
fn constants_only_equation() -> Result<(), Box<dyn std::error::Error>> {
let theory = Theory::new(
"T",
vec![Sort::simple("S")],
vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
vec![Equation::new(
"a_eq_b",
Term::constant("a"),
Term::constant("b"),
)],
);
let mut model = Model::new("T");
model.add_sort("S", vec![ModelValue::Int(0)]);
model.add_op("a", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
model.add_op("b", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
let violations = check_model(&model, &theory)?;
assert!(violations.is_empty());
model.add_op("b", |_: &[ModelValue]| Ok(ModelValue::Int(1)));
let violations = check_model(&model, &theory)?;
assert_eq!(violations.len(), 1);
assert_eq!(violations[0].equation.as_ref(), "a_eq_b");
Ok(())
}
// A 100-element carrier makes the three-variable `assoc` equation need
// 100^3 assignments, far beyond the configured limit of 100.
#[test]
fn assignment_limit_exceeded() {
let theory = monoid_theory();
let mut model = Model::new("Monoid");
model.add_sort("Carrier", (0..100).map(ModelValue::Int).collect());
model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
(ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a + b)),
_ => Err(GatError::ModelError("expected Int".into())),
});
model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
let options = CheckModelOptions {
max_assignments: 100,
};
let result = check_model_with_options(&model, &theory, &options);
assert!(matches!(result, Err(GatError::ModelError(_))));
}
// A model that never registers the `Carrier` sort must be rejected with a
// model error rather than silently passing.
#[test]
fn missing_carrier_set_errors() {
let theory = monoid_theory();
let model = Model::new("Monoid");
let result = check_model(&model, &theory);
assert!(matches!(result, Err(GatError::ModelError(_))));
}
}