Skip to main content

panproto_gat/
check_model.rs

1//! Model equation satisfaction checking.
2//!
3//! Verifies that a [`Model`] satisfies all equations of its [`Theory`]
4//! by enumerating variable assignments from carrier sets and evaluating
5//! both sides.
6
7use std::sync::Arc;
8
9use rustc_hash::FxHashMap;
10
11use crate::eq::{Equation, Term};
12use crate::error::GatError;
13use crate::model::{Model, ModelValue};
14use crate::theory::Theory;
15use crate::typecheck::infer_var_sorts;
16
17/// A single violation of an equation in a model.
18#[derive(Debug, Clone)]
19pub struct EquationViolation {
20    /// The name of the violated equation.
21    pub equation: Arc<str>,
22    /// The variable assignment that produced the violation.
23    pub assignment: FxHashMap<Arc<str>, ModelValue>,
24    /// The value the LHS evaluated to.
25    pub lhs_value: ModelValue,
26    /// The value the RHS evaluated to.
27    pub rhs_value: ModelValue,
28}
29
/// Tuning knobs for [`check_model_with_options`].
#[derive(Clone, Debug)]
pub struct CheckModelOptions {
    /// Upper bound on how many variable assignments a single equation may
    /// require before checking gives up with an error. A value of `0`
    /// disables the bound entirely. Default: 10,000.
    pub max_assignments: usize,
}

impl Default for CheckModelOptions {
    fn default() -> Self {
        // Enough to exhaustively check small models without letting a large
        // carrier set blow up combinatorially.
        const DEFAULT_MAX_ASSIGNMENTS: usize = 10_000;
        Self {
            max_assignments: DEFAULT_MAX_ASSIGNMENTS,
        }
    }
}
45
46/// Check whether a model satisfies all equations of its theory.
47///
48/// Returns a list of violations (empty means the model is valid).
49///
50/// # Errors
51///
52/// Returns [`GatError`] if variable sorts cannot be inferred or a carrier
53/// set is missing from the model.
54pub fn check_model(model: &Model, theory: &Theory) -> Result<Vec<EquationViolation>, GatError> {
55    check_model_with_options(model, theory, &CheckModelOptions::default())
56}
57
58/// Check with configurable options.
59///
60/// # Errors
61///
62/// Returns [`GatError::ModelError`] if the assignment count exceeds
63/// `options.max_assignments`, or other errors from type inference.
64pub fn check_model_with_options(
65    model: &Model,
66    theory: &Theory,
67    options: &CheckModelOptions,
68) -> Result<Vec<EquationViolation>, GatError> {
69    let mut violations = Vec::new();
70
71    for eq in &theory.eqs {
72        let eq_violations = check_equation(model, eq, theory, options)?;
73        violations.extend(eq_violations);
74    }
75
76    Ok(violations)
77}
78
79/// Check a single equation against all valid variable assignments.
80fn check_equation(
81    model: &Model,
82    eq: &Equation,
83    theory: &Theory,
84    options: &CheckModelOptions,
85) -> Result<Vec<EquationViolation>, GatError> {
86    let var_sorts = infer_var_sorts(eq, theory)?;
87
88    // Build ordered list of (var_name, carrier_set) pairs.
89    let var_carriers: Vec<(Arc<str>, &[ModelValue])> = var_sorts
90        .iter()
91        .map(|(var, sort)| {
92            let head = sort.head();
93            let carrier = model
94                .sort_interp
95                .get(head.as_ref())
96                .ok_or_else(|| GatError::ModelError(format!("no carrier set for sort '{sort}'")))?;
97            Ok((Arc::clone(var), carrier.as_slice()))
98        })
99        .collect::<Result<Vec<_>, GatError>>()?;
100
101    // If any carrier is empty, there are zero valid assignments.
102    if var_carriers.iter().any(|(_, carrier)| carrier.is_empty()) {
103        return Ok(vec![]);
104    }
105
106    // Handle the zero-variable case: one assignment (the empty one).
107    if var_carriers.is_empty() {
108        let assignment = FxHashMap::default();
109        let lhs_val = eval_term(&eq.lhs, &assignment, model)?;
110        let rhs_val = eval_term(&eq.rhs, &assignment, model)?;
111        if lhs_val != rhs_val {
112            return Ok(vec![EquationViolation {
113                equation: Arc::clone(&eq.name),
114                assignment,
115                lhs_value: lhs_val,
116                rhs_value: rhs_val,
117            }]);
118        }
119        return Ok(vec![]);
120    }
121
122    // Compute total assignment count for limit check.
123    let total: usize = var_carriers
124        .iter()
125        .map(|(_, carrier)| carrier.len())
126        .try_fold(1usize, usize::checked_mul)
127        .unwrap_or(usize::MAX);
128
129    if options.max_assignments > 0 && total > options.max_assignments {
130        return Err(GatError::ModelError(format!(
131            "equation '{}' requires {total} assignments, exceeding limit {}",
132            eq.name, options.max_assignments
133        )));
134    }
135
136    let mut violations = Vec::new();
137    let mut indices = vec![0usize; var_carriers.len()];
138
139    loop {
140        // Build current assignment.
141        let assignment: FxHashMap<Arc<str>, ModelValue> = var_carriers
142            .iter()
143            .zip(indices.iter())
144            .map(|((var, carrier), &idx)| (Arc::clone(var), carrier[idx].clone()))
145            .collect();
146
147        // Evaluate both sides.
148        let lhs_val = eval_term(&eq.lhs, &assignment, model)?;
149        let rhs_val = eval_term(&eq.rhs, &assignment, model)?;
150
151        if lhs_val != rhs_val {
152            violations.push(EquationViolation {
153                equation: Arc::clone(&eq.name),
154                assignment,
155                lhs_value: lhs_val,
156                rhs_value: rhs_val,
157            });
158        }
159
160        // Increment indices (odometer-style).
161        if !increment_indices(&mut indices, &var_carriers) {
162            break;
163        }
164    }
165
166    Ok(violations)
167}
168
169/// Evaluate a term under a variable-to-ModelValue assignment.
170fn eval_term(
171    term: &Term,
172    assignment: &FxHashMap<Arc<str>, ModelValue>,
173    model: &Model,
174) -> Result<ModelValue, GatError> {
175    match term {
176        Term::Var(name) => assignment
177            .get(name)
178            .cloned()
179            .ok_or_else(|| GatError::ModelError(format!("variable '{name}' not in assignment"))),
180
181        Term::App { op, args } => {
182            let arg_vals: Vec<ModelValue> = args
183                .iter()
184                .map(|a| eval_term(a, assignment, model))
185                .collect::<Result<Vec<_>, _>>()?;
186            model.eval(op, &arg_vals)
187        }
188
189        Term::Case {
190            scrutinee,
191            branches,
192        } => {
193            // Model evaluation of a case term: evaluate the scrutinee
194            // and match against branches by constructor-tagged
195            // model values. Set-theoretic models return a
196            // ModelValue::Constructor variant when appropriate. Since
197            // the current Model runtime does not carry constructor
198            // tags, we surface this as an unsupported-in-model error;
199            // the typechecker still verifies well-formedness
200            // independently.
201            let _ = (scrutinee, branches);
202            Err(GatError::ModelError(
203                "case terms are not yet supported in set-theoretic model evaluation".to_string(),
204            ))
205        }
206
207        Term::Hole { .. } => Err(GatError::ModelError(
208            "typed holes cannot be evaluated in a set-theoretic model".to_string(),
209        )),
210        Term::Let { name, bound, body } => {
211            let v = eval_term(bound, assignment, model)?;
212            let mut extended = assignment.clone();
213            extended.insert(Arc::clone(name), v);
214            eval_term(body, &extended, model)
215        }
216    }
217}
218
/// Advance `indices` to the next combination, odometer style.
///
/// `indices[i]` selects an element of the carrier slice `var_carriers[i].1`,
/// and the last position varies fastest. A position that rolls past the end
/// of its carrier wraps to zero and carries into the position on its left.
///
/// Returns `false` once every combination has been visited; at that point
/// `indices` has wrapped back to all zeros. With no positions at all it
/// returns `false` immediately.
///
/// The element type is irrelevant here; only carrier lengths are read.
fn increment_indices<T>(indices: &mut [usize], var_carriers: &[(Arc<str>, &[T])]) -> bool {
    debug_assert_eq!(
        indices.len(),
        var_carriers.len(),
        "one index per variable carrier"
    );
    for (idx, (_, carrier)) in indices.iter_mut().zip(var_carriers).rev() {
        *idx += 1;
        if *idx < carrier.len() {
            return true;
        }
        *idx = 0;
    }
    false
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eq::Equation;
    use crate::model::Model;
    use crate::op::Operation;
    use crate::sort::Sort;
    use crate::theory::Theory;

    /// The theory of monoids: one sort, a binary `mul`, a constant `unit`,
    /// and the associativity and left/right identity laws.
    fn monoid_theory() -> Theory {
        Theory::new(
            "Monoid",
            vec![Sort::simple("Carrier")],
            vec![
                Operation::new(
                    "mul",
                    vec![
                        ("a".into(), "Carrier".into()),
                        ("b".into(), "Carrier".into()),
                    ],
                    "Carrier",
                ),
                Operation::nullary("unit", "Carrier"),
            ],
            vec![
                Equation::new(
                    "assoc",
                    Term::app(
                        "mul",
                        vec![
                            Term::var("a"),
                            Term::app("mul", vec![Term::var("b"), Term::var("c")]),
                        ],
                    ),
                    Term::app(
                        "mul",
                        vec![
                            Term::app("mul", vec![Term::var("a"), Term::var("b")]),
                            Term::var("c"),
                        ],
                    ),
                ),
                Equation::new(
                    "left_id",
                    Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
                    Term::var("a"),
                ),
                Equation::new(
                    "right_id",
                    Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
                    Term::var("a"),
                ),
            ],
        )
    }

    /// Integers mod 5 under addition, with unit 0: a valid monoid model.
    fn valid_z5_model() -> Model {
        let mut model = Model::new("Monoid");
        model.add_sort("Carrier", (0..5).map(ModelValue::Int).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((a + b) % 5)),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
        model
    }

    #[test]
    fn valid_model_passes() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let model = valid_z5_model();
        let violations = check_model(&model, &theory)?;
        assert!(
            violations.is_empty(),
            "expected no violations, got {violations:?}"
        );
        Ok(())
    }

    #[test]
    fn broken_identity_detected() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = valid_z5_model();
        // Break both identity laws: unit() returns 1 instead of 0, so
        // mul(unit, a) = mul(a, unit) = a + 1 (mod 5), which never equals a.
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        let violations = check_model(&model, &theory)?;
        assert!(!violations.is_empty(), "expected violations");

        let has_identity_violation = violations
            .iter()
            .any(|v| v.equation.as_ref() == "left_id" || v.equation.as_ref() == "right_id");
        assert!(has_identity_violation);
        Ok(())
    }

    #[test]
    fn broken_associativity_detected() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        model.add_sort(
            "Carrier",
            vec![ModelValue::Int(0), ModelValue::Int(1), ModelValue::Int(2)],
        );
        // Non-associative: saturating subtraction (a - b, clamped to 0).
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((*a - *b).max(0))),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        let violations = check_model(&model, &theory)?;
        let has_assoc = violations.iter().any(|v| v.equation.as_ref() == "assoc");
        assert!(has_assoc, "expected associativity violation");
        Ok(())
    }

    #[test]
    fn empty_carrier_passes() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        model.add_sort("Carrier", vec![]);
        model.add_op("mul", |_: &[ModelValue]| {
            Err(GatError::ModelError("unreachable".into()))
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        // Every monoid equation has at least one variable ranging over the
        // empty carrier, so there are no assignments to check and `mul` is
        // never called. A variable-free equation would still be checked once.
        let violations = check_model(&model, &theory)?;
        assert!(violations.is_empty());
        Ok(())
    }

    #[test]
    fn constants_only_equation() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "T",
            vec![Sort::simple("S")],
            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
            vec![Equation::new(
                "a_eq_b",
                Term::constant("a"),
                Term::constant("b"),
            )],
        );

        // Model where a() = b() = 0: passes.
        let mut model = Model::new("T");
        model.add_sort("S", vec![ModelValue::Int(0)]);
        model.add_op("a", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
        model.add_op("b", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
        let violations = check_model(&model, &theory)?;
        assert!(violations.is_empty());

        // Model where a() = 0, b() = 1: fails.
        model.add_op("b", |_: &[ModelValue]| Ok(ModelValue::Int(1)));
        let violations = check_model(&model, &theory)?;
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].equation.as_ref(), "a_eq_b");
        Ok(())
    }

    #[test]
    fn assignment_limit_exceeded() {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        // Large carrier set: 100 elements, assoc has 3 variables -> 1M assignments.
        model.add_sort("Carrier", (0..100).map(ModelValue::Int).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a + b)),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        let options = CheckModelOptions {
            max_assignments: 100,
        };
        let result = check_model_with_options(&model, &theory, &options);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }

    #[test]
    fn zero_limit_means_unlimited() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        // 22 elements: assoc needs 22^3 = 10,648 assignments, just above the
        // default limit of 10,000.
        model.add_sort("Carrier", (0..22).map(ModelValue::Int).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((a + b) % 22)),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        // The default options refuse this model outright...
        assert!(matches!(
            check_model(&model, &theory),
            Err(GatError::ModelError(_))
        ));

        // ...but a limit of 0 disables the bound and the valid model passes.
        let options = CheckModelOptions { max_assignments: 0 };
        let violations = check_model_with_options(&model, &theory, &options)?;
        assert!(
            violations.is_empty(),
            "expected no violations, got {violations:?}"
        );
        Ok(())
    }

    #[test]
    fn violation_records_assignment_and_values() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = valid_z5_model();
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        let violations = check_model(&model, &theory)?;
        let left: Vec<&EquationViolation> = violations
            .iter()
            .filter(|v| v.equation.as_ref() == "left_id")
            .collect();
        // (1 + a) mod 5 never equals a, so all 5 assignments violate left_id.
        assert_eq!(left.len(), 5);

        let at_zero = left
            .iter()
            .find(|v| v.assignment.get("a") == Some(&ModelValue::Int(0)))
            .ok_or("no left_id violation recorded for a = 0")?;
        assert_eq!(at_zero.lhs_value, ModelValue::Int(1));
        assert_eq!(at_zero.rhs_value, ModelValue::Int(0));
        Ok(())
    }

    #[test]
    fn missing_carrier_set_errors() {
        let theory = monoid_theory();
        let model = Model::new("Monoid");
        // No carrier set added; should error.
        let result = check_model(&model, &theory);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }
}