1use std::sync::Arc;
8
9use rustc_hash::FxHashMap;
10
11use crate::eq::{Equation, Term};
12use crate::error::GatError;
13use crate::model::{Model, ModelValue};
14use crate::theory::Theory;
15use crate::typecheck::infer_var_sorts;
16
/// A counterexample to one equation of a theory: an assignment of model
/// values to the equation's free variables under which the left- and
/// right-hand sides evaluate to different values.
#[derive(Debug, Clone)]
pub struct EquationViolation {
    /// Name of the violated equation (taken from the equation's `name`).
    pub equation: Arc<str>,
    /// Value chosen for each free variable of the equation. Empty when the
    /// equation has no variables.
    pub assignment: FxHashMap<Arc<str>, ModelValue>,
    /// Value of the equation's left-hand side under `assignment`.
    pub lhs_value: ModelValue,
    /// Value of the equation's right-hand side under `assignment`.
    pub rhs_value: ModelValue,
}
29
/// Tuning knobs for [`check_model_with_options`].
#[derive(Debug, Clone)]
pub struct CheckModelOptions {
    /// Upper bound on how many variable assignments may be enumerated for a
    /// single equation. An equation that would need more is rejected with a
    /// `GatError::ModelError` instead of being checked. A value of `0`
    /// disables the bound entirely.
    pub max_assignments: usize,
}

impl Default for CheckModelOptions {
    /// Limits enumeration to ten thousand assignments per equation.
    fn default() -> Self {
        const DEFAULT_MAX_ASSIGNMENTS: usize = 10_000;
        CheckModelOptions {
            max_assignments: DEFAULT_MAX_ASSIGNMENTS,
        }
    }
}
45
46pub fn check_model(model: &Model, theory: &Theory) -> Result<Vec<EquationViolation>, GatError> {
55 check_model_with_options(model, theory, &CheckModelOptions::default())
56}
57
58pub fn check_model_with_options(
65 model: &Model,
66 theory: &Theory,
67 options: &CheckModelOptions,
68) -> Result<Vec<EquationViolation>, GatError> {
69 let mut violations = Vec::new();
70
71 for eq in &theory.eqs {
72 let eq_violations = check_equation(model, eq, theory, options)?;
73 violations.extend(eq_violations);
74 }
75
76 Ok(violations)
77}
78
79fn check_equation(
81 model: &Model,
82 eq: &Equation,
83 theory: &Theory,
84 options: &CheckModelOptions,
85) -> Result<Vec<EquationViolation>, GatError> {
86 let var_sorts = infer_var_sorts(eq, theory)?;
87
88 let var_carriers: Vec<(Arc<str>, &[ModelValue])> = var_sorts
90 .iter()
91 .map(|(var, sort)| {
92 let head = sort.head();
93 let carrier = model
94 .sort_interp
95 .get(head.as_ref())
96 .ok_or_else(|| GatError::ModelError(format!("no carrier set for sort '{sort}'")))?;
97 Ok((Arc::clone(var), carrier.as_slice()))
98 })
99 .collect::<Result<Vec<_>, GatError>>()?;
100
101 if var_carriers.iter().any(|(_, carrier)| carrier.is_empty()) {
103 return Ok(vec![]);
104 }
105
106 if var_carriers.is_empty() {
108 let assignment = FxHashMap::default();
109 let lhs_val = eval_term(&eq.lhs, &assignment, model)?;
110 let rhs_val = eval_term(&eq.rhs, &assignment, model)?;
111 if lhs_val != rhs_val {
112 return Ok(vec![EquationViolation {
113 equation: Arc::clone(&eq.name),
114 assignment,
115 lhs_value: lhs_val,
116 rhs_value: rhs_val,
117 }]);
118 }
119 return Ok(vec![]);
120 }
121
122 let total: usize = var_carriers
124 .iter()
125 .map(|(_, carrier)| carrier.len())
126 .try_fold(1usize, usize::checked_mul)
127 .unwrap_or(usize::MAX);
128
129 if options.max_assignments > 0 && total > options.max_assignments {
130 return Err(GatError::ModelError(format!(
131 "equation '{}' requires {total} assignments, exceeding limit {}",
132 eq.name, options.max_assignments
133 )));
134 }
135
136 let mut violations = Vec::new();
137 let mut indices = vec![0usize; var_carriers.len()];
138
139 loop {
140 let assignment: FxHashMap<Arc<str>, ModelValue> = var_carriers
142 .iter()
143 .zip(indices.iter())
144 .map(|((var, carrier), &idx)| (Arc::clone(var), carrier[idx].clone()))
145 .collect();
146
147 let lhs_val = eval_term(&eq.lhs, &assignment, model)?;
149 let rhs_val = eval_term(&eq.rhs, &assignment, model)?;
150
151 if lhs_val != rhs_val {
152 violations.push(EquationViolation {
153 equation: Arc::clone(&eq.name),
154 assignment,
155 lhs_value: lhs_val,
156 rhs_value: rhs_val,
157 });
158 }
159
160 if !increment_indices(&mut indices, &var_carriers) {
162 break;
163 }
164 }
165
166 Ok(violations)
167}
168
169fn eval_term(
171 term: &Term,
172 assignment: &FxHashMap<Arc<str>, ModelValue>,
173 model: &Model,
174) -> Result<ModelValue, GatError> {
175 match term {
176 Term::Var(name) => assignment
177 .get(name)
178 .cloned()
179 .ok_or_else(|| GatError::ModelError(format!("variable '{name}' not in assignment"))),
180
181 Term::App { op, args } => {
182 let arg_vals: Vec<ModelValue> = args
183 .iter()
184 .map(|a| eval_term(a, assignment, model))
185 .collect::<Result<Vec<_>, _>>()?;
186 model.eval(op, &arg_vals)
187 }
188
189 Term::Case {
190 scrutinee,
191 branches,
192 } => {
193 let _ = (scrutinee, branches);
202 Err(GatError::ModelError(
203 "case terms are not yet supported in set-theoretic model evaluation".to_string(),
204 ))
205 }
206
207 Term::Hole { .. } => Err(GatError::ModelError(
208 "typed holes cannot be evaluated in a set-theoretic model".to_string(),
209 )),
210 Term::Let { name, bound, body } => {
211 let v = eval_term(bound, assignment, model)?;
212 let mut extended = assignment.clone();
213 extended.insert(Arc::clone(name), v);
214 eval_term(body, &extended, model)
215 }
216 }
217}
218
219fn increment_indices(indices: &mut [usize], var_carriers: &[(Arc<str>, &[ModelValue])]) -> bool {
221 for i in (0..indices.len()).rev() {
222 indices[i] += 1;
223 if indices[i] < var_carriers[i].1.len() {
224 return true;
225 }
226 indices[i] = 0;
227 }
228 false
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::eq::Equation;
    use crate::model::Model;
    use crate::op::Operation;
    use crate::sort::Sort;
    use crate::theory::Theory;

    /// Monoid theory with associativity and left/right identity laws.
    fn monoid_theory() -> Theory {
        Theory::new(
            "Monoid",
            vec![Sort::simple("Carrier")],
            vec![
                Operation::new(
                    "mul",
                    vec![
                        ("a".into(), "Carrier".into()),
                        ("b".into(), "Carrier".into()),
                    ],
                    "Carrier",
                ),
                Operation::nullary("unit", "Carrier"),
            ],
            vec![
                Equation::new(
                    "assoc",
                    Term::app(
                        "mul",
                        vec![
                            Term::var("a"),
                            Term::app("mul", vec![Term::var("b"), Term::var("c")]),
                        ],
                    ),
                    Term::app(
                        "mul",
                        vec![
                            Term::app("mul", vec![Term::var("a"), Term::var("b")]),
                            Term::var("c"),
                        ],
                    ),
                ),
                Equation::new(
                    "left_id",
                    Term::app("mul", vec![Term::constant("unit"), Term::var("a")]),
                    Term::var("a"),
                ),
                Equation::new(
                    "right_id",
                    Term::app("mul", vec![Term::var("a"), Term::constant("unit")]),
                    Term::var("a"),
                ),
            ],
        )
    }

    /// (Z/5, +, 0): a valid monoid with a 5-element carrier.
    fn valid_z5_model() -> Model {
        let mut model = Model::new("Monoid");
        model.add_sort("Carrier", (0..5).map(ModelValue::Int).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((a + b) % 5)),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
        model
    }

    #[test]
    fn valid_model_passes() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let model = valid_z5_model();
        let violations = check_model(&model, &theory)?;
        assert!(
            violations.is_empty(),
            "expected no violations, got {violations:?}"
        );
        Ok(())
    }

    #[test]
    fn broken_identity_detected() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = valid_z5_model();
        // Replace the unit with a non-identity element.
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        let violations = check_model(&model, &theory)?;
        assert!(!violations.is_empty(), "expected violations");

        let has_identity_violation = violations
            .iter()
            .any(|v| v.equation.as_ref() == "left_id" || v.equation.as_ref() == "right_id");
        assert!(has_identity_violation);
        Ok(())
    }

    #[test]
    fn violation_records_assignment_and_values() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = valid_z5_model();
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(1)));

        let violations = check_model(&model, &theory)?;
        // mul is still addition mod 5, so associativity holds, but with
        // unit = 1 both identity laws fail for each of the 5 values of `a`.
        assert_eq!(violations.len(), 10, "got {violations:?}");
        for v in &violations {
            assert_ne!(v.equation.as_ref(), "assoc");
            let a = v.assignment.get("a").ok_or("missing assignment for 'a'")?;
            // Both identity laws have the bare variable `a` on the right.
            assert_eq!(v.rhs_value, *a);
            assert_ne!(v.lhs_value, v.rhs_value);
        }
        Ok(())
    }

    #[test]
    fn broken_associativity_detected() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        model.add_sort(
            "Carrier",
            vec![ModelValue::Int(0), ModelValue::Int(1), ModelValue::Int(2)],
        );
        // Truncated subtraction is not associative.
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int((*a - *b).max(0))),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        let violations = check_model(&model, &theory)?;
        let has_assoc = violations.iter().any(|v| v.equation.as_ref() == "assoc");
        assert!(has_assoc, "expected associativity violation");
        Ok(())
    }

    #[test]
    fn empty_carrier_passes() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        model.add_sort("Carrier", vec![]);
        model.add_op("mul", |_: &[ModelValue]| {
            Err(GatError::ModelError("unreachable".into()))
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        // Every equation mentions a variable of the empty sort, so all hold
        // vacuously and `mul` is never called.
        let violations = check_model(&model, &theory)?;
        assert!(violations.is_empty());
        Ok(())
    }

    #[test]
    fn constants_only_equation() -> Result<(), Box<dyn std::error::Error>> {
        let theory = Theory::new(
            "T",
            vec![Sort::simple("S")],
            vec![Operation::nullary("a", "S"), Operation::nullary("b", "S")],
            vec![Equation::new(
                "a_eq_b",
                Term::constant("a"),
                Term::constant("b"),
            )],
        );

        let mut model = Model::new("T");
        model.add_sort("S", vec![ModelValue::Int(0)]);
        model.add_op("a", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
        model.add_op("b", |_: &[ModelValue]| Ok(ModelValue::Int(0)));
        let violations = check_model(&model, &theory)?;
        assert!(violations.is_empty());

        model.add_op("b", |_: &[ModelValue]| Ok(ModelValue::Int(1)));
        let violations = check_model(&model, &theory)?;
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].equation.as_ref(), "a_eq_b");
        assert!(violations[0].assignment.is_empty());
        assert_eq!(violations[0].lhs_value, ModelValue::Int(0));
        assert_eq!(violations[0].rhs_value, ModelValue::Int(1));
        Ok(())
    }

    #[test]
    fn assignment_limit_exceeded() {
        let theory = monoid_theory();
        let mut model = Model::new("Monoid");
        model.add_sort("Carrier", (0..100).map(ModelValue::Int).collect());
        model.add_op("mul", |args: &[ModelValue]| match (&args[0], &args[1]) {
            (ModelValue::Int(a), ModelValue::Int(b)) => Ok(ModelValue::Int(a + b)),
            _ => Err(GatError::ModelError("expected Int".into())),
        });
        model.add_op("unit", |_: &[ModelValue]| Ok(ModelValue::Int(0)));

        let options = CheckModelOptions {
            max_assignments: 100,
        };
        let result = check_model_with_options(&model, &theory, &options);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }

    #[test]
    fn assignment_limit_is_inclusive() -> Result<(), Box<dyn std::error::Error>> {
        // `assoc` has three variables over a 5-element carrier: 125 assignments.
        let theory = monoid_theory();
        let model = valid_z5_model();

        let at_limit = CheckModelOptions {
            max_assignments: 125,
        };
        assert!(check_model_with_options(&model, &theory, &at_limit)?.is_empty());

        let below_limit = CheckModelOptions {
            max_assignments: 124,
        };
        let result = check_model_with_options(&model, &theory, &below_limit);
        assert!(matches!(result, Err(GatError::ModelError(_))));
        Ok(())
    }

    #[test]
    fn zero_limit_means_unlimited() -> Result<(), Box<dyn std::error::Error>> {
        let theory = monoid_theory();
        let model = valid_z5_model();
        let options = CheckModelOptions { max_assignments: 0 };
        let violations = check_model_with_options(&model, &theory, &options)?;
        assert!(violations.is_empty());
        Ok(())
    }

    #[test]
    fn missing_carrier_set_errors() {
        let theory = monoid_theory();
        let model = Model::new("Monoid");
        let result = check_model(&model, &theory);
        assert!(matches!(result, Err(GatError::ModelError(_))));
    }
}