Skip to main content

oxiz_solver/mbqi/
model_completion.rs

1//! Model Completion Algorithms
2//!
3//! This module implements model completion for MBQI. Model completion is the process
4//! of taking a partial model (which may only define values for some terms) and
5//! completing it to a total model that assigns values to all terms.
6//!
7//! The key challenge is handling function symbols and uninterpreted sorts, which may
8//! have infinitely many possible interpretations. We use several strategies:
9//!
10//! 1. **Macro Solving**: Identify quantifiers that can be solved as macros
11//! 2. **Projection Functions**: Map infinite domains to finite representatives
12//! 3. **Default Values**: Assign sensible defaults for undefined terms
13//! 4. **Finite Universes**: Restrict uninterpreted sorts to finite sets
14//!
15//! # References
16//!
17//! - Z3's model_fixer.cpp and q_model_fixer.cpp
//! - "Complete Instantiation for Quantified Formulas in Satisfiability Modulo Theories" (Ge & de Moura, CAV 2009)
19
20#![allow(missing_docs)]
21#![allow(dead_code)]
22
23use lasso::Spur;
24use num_bigint::BigInt;
25use num_rational::Rational64;
26use oxiz_core::ast::{TermId, TermKind, TermManager};
27use oxiz_core::sort::SortId;
28use rustc_hash::{FxHashMap, FxHashSet};
29use smallvec::SmallVec;
30use std::cmp::Ordering;
31use std::fmt;
32
33use super::QuantifiedFormula;
34
35/// A completed model that assigns values to all relevant terms
36#[derive(Debug, Clone)]
37pub struct CompletedModel {
38    /// Term assignments (term -> value)
39    pub assignments: FxHashMap<TermId, TermId>,
40    /// Function interpretations
41    pub function_interps: FxHashMap<Spur, FunctionInterpretation>,
42    /// Universes for uninterpreted sorts (sort -> finite set of values)
43    pub universes: FxHashMap<SortId, Vec<TermId>>,
44    /// Default values for each sort
45    pub defaults: FxHashMap<SortId, TermId>,
46    /// Generation number
47    pub generation: u32,
48}
49
50impl CompletedModel {
51    /// Create a new empty completed model
52    pub fn new() -> Self {
53        Self {
54            assignments: FxHashMap::default(),
55            function_interps: FxHashMap::default(),
56            universes: FxHashMap::default(),
57            defaults: FxHashMap::default(),
58            generation: 0,
59        }
60    }
61
62    /// Get the value of a term in this model
63    pub fn eval(&self, term: TermId) -> Option<TermId> {
64        self.assignments.get(&term).copied()
65    }
66
67    /// Set the value of a term
68    pub fn set(&mut self, term: TermId, value: TermId) {
69        self.assignments.insert(term, value);
70    }
71
72    /// Get the universe for a sort
73    pub fn universe(&self, sort: SortId) -> Option<&[TermId]> {
74        self.universes.get(&sort).map(|v| v.as_slice())
75    }
76
77    /// Add a value to a sort's universe
78    pub fn add_to_universe(&mut self, sort: SortId, value: TermId) {
79        self.universes.entry(sort).or_default().push(value);
80    }
81
82    /// Get the default value for a sort
83    pub fn default_value(&self, sort: SortId) -> Option<TermId> {
84        self.defaults.get(&sort).copied()
85    }
86
87    /// Set the default value for a sort
88    pub fn set_default(&mut self, sort: SortId, value: TermId) {
89        self.defaults.insert(sort, value);
90    }
91
92    /// Check if a sort has an uninterpreted universe
93    pub fn has_uninterpreted_sort(&self, sort: SortId) -> bool {
94        self.universes.contains_key(&sort)
95    }
96}
97
98impl Default for CompletedModel {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104/// A function interpretation (finite representation of function mapping)
105#[derive(Debug, Clone)]
106pub struct FunctionInterpretation {
107    /// Function name
108    pub name: Spur,
109    /// Arity
110    pub arity: usize,
111    /// Domain sorts
112    pub domain: SmallVec<[SortId; 4]>,
113    /// Range sort
114    pub range: SortId,
115    /// Explicit entries (args -> result)
116    pub entries: Vec<FunctionEntry>,
117    /// Default/else value (for arguments not in entries)
118    pub else_value: Option<TermId>,
119    /// Projection functions for arguments (if any)
120    pub projections: Vec<Option<ProjectionFunctionDef>>,
121}
122
123impl FunctionInterpretation {
124    /// Create a new function interpretation
125    pub fn new(name: Spur, domain: SmallVec<[SortId; 4]>, range: SortId) -> Self {
126        let arity = domain.len();
127        Self {
128            name,
129            arity,
130            domain,
131            range,
132            entries: Vec::new(),
133            else_value: None,
134            projections: vec![None; arity],
135        }
136    }
137
138    /// Add an entry to the function table
139    pub fn add_entry(&mut self, args: Vec<TermId>, result: TermId) {
140        if args.len() == self.arity {
141            self.entries.push(FunctionEntry { args, result });
142        }
143    }
144
145    /// Lookup a value in the function table
146    pub fn lookup(&self, args: &[TermId]) -> Option<TermId> {
147        for entry in &self.entries {
148            if entry.args == args {
149                return Some(entry.result);
150            }
151        }
152        self.else_value
153    }
154
155    /// Check if this is a constant function
156    pub fn is_constant(&self) -> bool {
157        self.arity == 0
158    }
159
160    /// Check if the interpretation is partial (missing else value or entries)
161    pub fn is_partial(&self) -> bool {
162        self.else_value.is_none() && !self.entries.is_empty()
163    }
164
165    /// Get the most common result value
166    pub fn max_occurrence_result(&self) -> Option<TermId> {
167        if self.entries.is_empty() {
168            return None;
169        }
170
171        let mut counts: FxHashMap<TermId, usize> = FxHashMap::default();
172        for entry in &self.entries {
173            *counts.entry(entry.result).or_insert(0) += 1;
174        }
175
176        counts
177            .into_iter()
178            .max_by_key(|(_, count)| *count)
179            .map(|(term, _)| term)
180    }
181}
182
183/// A single entry in a function interpretation
184#[derive(Debug, Clone)]
185pub struct FunctionEntry {
186    /// Arguments
187    pub args: Vec<TermId>,
188    /// Result value
189    pub result: TermId,
190}
191
192/// Definition of a projection function for argument position
193#[derive(Debug, Clone)]
194pub struct ProjectionFunctionDef {
195    /// Argument index this projection is for
196    pub arg_index: usize,
197    /// Sort being projected
198    pub sort: SortId,
199    /// Sorted values that appear in function applications
200    pub values: Vec<TermId>,
201    /// Mapping from value to representative term
202    pub value_to_term: FxHashMap<TermId, TermId>,
203    /// Mapping from term to value
204    pub term_to_value: FxHashMap<TermId, TermId>,
205}
206
207impl ProjectionFunctionDef {
208    /// Create a new projection function definition
209    pub fn new(arg_index: usize, sort: SortId) -> Self {
210        Self {
211            arg_index,
212            sort,
213            values: Vec::new(),
214            value_to_term: FxHashMap::default(),
215            term_to_value: FxHashMap::default(),
216        }
217    }
218
219    /// Add a value to the projection
220    pub fn add_value(&mut self, value: TermId, term: TermId) {
221        if !self.values.contains(&value) {
222            self.values.push(value);
223        }
224        self.value_to_term.insert(value, term);
225        self.term_to_value.insert(term, value);
226    }
227
228    /// Project a value to its representative
229    pub fn project(&self, value: TermId) -> Option<TermId> {
230        self.value_to_term.get(&value).copied()
231    }
232}
233
234/// Model completer that takes partial models and makes them complete
235#[derive(Debug)]
236pub struct ModelCompleter {
237    /// Macro solver
238    macro_solver: MacroSolver,
239    /// Model fixer for function interpretations
240    model_fixer: ModelFixer,
241    /// Handler for uninterpreted sorts
242    uninterp_handler: UninterpretedSortHandler,
243    /// Cache of completed models
244    cache: FxHashMap<u64, CompletedModel>,
245    /// Statistics
246    stats: CompletionStats,
247}
248
249impl ModelCompleter {
250    /// Create a new model completer
251    pub fn new() -> Self {
252        Self {
253            macro_solver: MacroSolver::new(),
254            model_fixer: ModelFixer::new(),
255            uninterp_handler: UninterpretedSortHandler::new(),
256            cache: FxHashMap::default(),
257            stats: CompletionStats::default(),
258        }
259    }
260
261    /// Complete a partial model
262    pub fn complete(
263        &mut self,
264        partial_model: &FxHashMap<TermId, TermId>,
265        quantifiers: &[QuantifiedFormula],
266        manager: &mut TermManager,
267    ) -> Result<CompletedModel, CompletionError> {
268        self.stats.num_completions += 1;
269
270        // Start with the partial model
271        let mut completed = CompletedModel::new();
272        completed.assignments = partial_model.clone();
273
274        // Try to solve some quantifiers as macros
275        let macro_results = self.macro_solver.solve_macros(quantifiers, manager)?;
276        for (func_name, interp) in macro_results {
277            completed.function_interps.insert(func_name, interp);
278        }
279
280        // Complete function interpretations
281        self.model_fixer
282            .fix_model(&mut completed, quantifiers, manager)?;
283
284        // Handle uninterpreted sorts
285        self.uninterp_handler
286            .complete_universes(&mut completed, manager)?;
287
288        // Set default values for all sorts
289        self.set_default_values(&mut completed, manager)?;
290
291        Ok(completed)
292    }
293
294    /// Set default values for all sorts in the model
295    fn set_default_values(
296        &mut self,
297        model: &mut CompletedModel,
298        manager: &mut TermManager,
299    ) -> Result<(), CompletionError> {
300        // Boolean
301        if !model.defaults.contains_key(&manager.sorts.bool_sort) {
302            model.set_default(manager.sorts.bool_sort, manager.mk_false());
303        }
304
305        // Integer
306        if !model.defaults.contains_key(&manager.sorts.int_sort) {
307            model.set_default(manager.sorts.int_sort, manager.mk_int(BigInt::from(0)));
308        }
309
310        // Real
311        if !model.defaults.contains_key(&manager.sorts.real_sort) {
312            model.set_default(
313                manager.sorts.real_sort,
314                manager.mk_real(Rational64::from_integer(0)),
315            );
316        }
317
318        // Uninterpreted sorts - use first element from universe
319        // Collect defaults first to avoid borrow conflict
320        let defaults_to_set: Vec<(SortId, TermId)> = model
321            .universes
322            .iter()
323            .filter_map(|(sort, universe)| {
324                if !model.defaults.contains_key(sort) {
325                    universe.first().map(|&first| (*sort, first))
326                } else {
327                    None
328                }
329            })
330            .collect();
331
332        for (sort, value) in defaults_to_set {
333            model.set_default(sort, value);
334        }
335
336        Ok(())
337    }
338
339    /// Get completion statistics
340    pub fn stats(&self) -> &CompletionStats {
341        &self.stats
342    }
343}
344
345impl Default for ModelCompleter {
346    fn default() -> Self {
347        Self::new()
348    }
349}
350
351/// Macro solver that identifies quantifiers that can be solved as macros
352///
353/// A quantifier can be solved as a macro if it has the form:
354/// ∀x. f(x) = body(x)
355/// where f is an uninterpreted function and body doesn't contain f
356#[derive(Debug)]
357pub struct MacroSolver {
358    /// Detected macros
359    macros: FxHashMap<Spur, MacroDefinition>,
360    /// Statistics
361    stats: MacroStats,
362}
363
364impl MacroSolver {
365    /// Create a new macro solver
366    pub fn new() -> Self {
367        Self {
368            macros: FxHashMap::default(),
369            stats: MacroStats::default(),
370        }
371    }
372
373    /// Try to solve quantifiers as macros
374    pub fn solve_macros(
375        &mut self,
376        quantifiers: &[QuantifiedFormula],
377        manager: &mut TermManager,
378    ) -> Result<FxHashMap<Spur, FunctionInterpretation>, CompletionError> {
379        let mut results = FxHashMap::default();
380
381        for quant in quantifiers {
382            if let Some(macro_def) = self.try_extract_macro(quant, manager)? {
383                self.stats.num_macros_found += 1;
384                let interp = self.macro_to_interpretation(&macro_def, manager)?;
385                results.insert(macro_def.func_name, interp);
386                self.macros.insert(macro_def.func_name, macro_def);
387            }
388        }
389
390        Ok(results)
391    }
392
393    /// Try to extract a macro from a quantified formula
394    fn try_extract_macro(
395        &self,
396        quant: &QuantifiedFormula,
397        manager: &TermManager,
398    ) -> Result<Option<MacroDefinition>, CompletionError> {
399        // Look for pattern: ∀x. f(x) = body(x)
400        let Some(body_term) = manager.get(quant.body) else {
401            return Ok(None);
402        };
403
404        // Check if body is an equality
405        if let TermKind::Eq(lhs, rhs) = &body_term.kind {
406            // Try both directions
407            if let Some(macro_def) = self.try_extract_macro_from_eq(*lhs, *rhs, quant, manager)? {
408                return Ok(Some(macro_def));
409            }
410            if let Some(macro_def) = self.try_extract_macro_from_eq(*rhs, *lhs, quant, manager)? {
411                return Ok(Some(macro_def));
412            }
413        }
414
415        Ok(None)
416    }
417
418    /// Try to extract macro from equality lhs = rhs
419    fn try_extract_macro_from_eq(
420        &self,
421        lhs: TermId,
422        rhs: TermId,
423        quant: &QuantifiedFormula,
424        manager: &TermManager,
425    ) -> Result<Option<MacroDefinition>, CompletionError> {
426        let Some(lhs_term) = manager.get(lhs) else {
427            return Ok(None);
428        };
429
430        // Check if lhs is f(x1, ..., xn) where f is uninterpreted
431        if let TermKind::Apply { func, args } = &lhs_term.kind {
432            // Check if all args are bound variables
433            let mut is_macro = true;
434            for &arg in args.iter() {
435                if let Some(arg_term) = manager.get(arg)
436                    && !matches!(arg_term.kind, TermKind::Var(_))
437                {
438                    is_macro = false;
439                    break;
440                }
441            }
442
443            if is_macro {
444                // Check if rhs doesn't contain f
445                if !self.contains_function(rhs, *func, manager) {
446                    return Ok(Some(MacroDefinition {
447                        quantifier: quant.term,
448                        func_name: *func,
449                        bound_vars: quant.bound_vars.clone(),
450                        body: rhs,
451                    }));
452                }
453            }
454        }
455
456        Ok(None)
457    }
458
459    /// Check if term contains a function application
460    fn contains_function(&self, term: TermId, func: Spur, manager: &TermManager) -> bool {
461        let mut visited = FxHashSet::default();
462        self.contains_function_rec(term, func, manager, &mut visited)
463    }
464
465    fn contains_function_rec(
466        &self,
467        term: TermId,
468        func: Spur,
469        manager: &TermManager,
470        visited: &mut FxHashSet<TermId>,
471    ) -> bool {
472        if visited.contains(&term) {
473            return false;
474        }
475        visited.insert(term);
476
477        let Some(t) = manager.get(term) else {
478            return false;
479        };
480
481        match &t.kind {
482            TermKind::Apply { func: f, args } => {
483                if *f == func {
484                    return true;
485                }
486                for &arg in args.iter() {
487                    if self.contains_function_rec(arg, func, manager, visited) {
488                        return true;
489                    }
490                }
491                false
492            }
493            _ => {
494                // Recursively check children
495                let children = self.get_children(term, manager);
496                for child in children {
497                    if self.contains_function_rec(child, func, manager, visited) {
498                        return true;
499                    }
500                }
501                false
502            }
503        }
504    }
505
506    /// Get children of a term
507    fn get_children(&self, term: TermId, manager: &TermManager) -> Vec<TermId> {
508        let Some(t) = manager.get(term) else {
509            return vec![];
510        };
511
512        match &t.kind {
513            TermKind::Not(arg) | TermKind::Neg(arg) => vec![*arg],
514            TermKind::And(args)
515            | TermKind::Or(args)
516            | TermKind::Add(args)
517            | TermKind::Mul(args) => args.to_vec(),
518            TermKind::Sub(lhs, rhs)
519            | TermKind::Div(lhs, rhs)
520            | TermKind::Mod(lhs, rhs)
521            | TermKind::Eq(lhs, rhs)
522            | TermKind::Lt(lhs, rhs)
523            | TermKind::Le(lhs, rhs)
524            | TermKind::Gt(lhs, rhs)
525            | TermKind::Ge(lhs, rhs)
526            | TermKind::Implies(lhs, rhs) => vec![*lhs, *rhs],
527            TermKind::Ite(cond, then_br, else_br) => vec![*cond, *then_br, *else_br],
528            TermKind::Apply { args, .. } => args.to_vec(),
529            _ => vec![],
530        }
531    }
532
533    /// Convert a macro definition to a function interpretation
534    fn macro_to_interpretation(
535        &self,
536        macro_def: &MacroDefinition,
537        manager: &mut TermManager,
538    ) -> Result<FunctionInterpretation, CompletionError> {
539        // For now, create an empty interpretation
540        // In a full implementation, we would evaluate the body for various inputs
541        let func_name = macro_def.func_name;
542
543        // Get function signature (this is simplified - real implementation would look it up)
544        let domain = SmallVec::new();
545        let range = manager.sorts.bool_sort; // Placeholder
546
547        let interp = FunctionInterpretation::new(func_name, domain, range);
548        Ok(interp)
549    }
550
551    /// Get statistics
552    pub fn stats(&self) -> &MacroStats {
553        &self.stats
554    }
555}
556
557impl Default for MacroSolver {
558    fn default() -> Self {
559        Self::new()
560    }
561}
562
563/// A macro definition extracted from a quantifier
564#[derive(Debug, Clone)]
565pub struct MacroDefinition {
566    /// Original quantifier
567    pub quantifier: TermId,
568    /// Function being defined
569    pub func_name: Spur,
570    /// Bound variables
571    pub bound_vars: SmallVec<[(Spur, SortId); 4]>,
572    /// Definition body
573    pub body: TermId,
574}
575
576/// Model fixer that completes function interpretations
577#[derive(Debug)]
578pub struct ModelFixer {
579    /// Projection functions by sort
580    projections: FxHashMap<SortId, Box<dyn ProjectionFunction>>,
581    /// Statistics
582    stats: FixerStats,
583}
584
585impl ModelFixer {
586    /// Create a new model fixer
587    pub fn new() -> Self {
588        Self {
589            projections: FxHashMap::default(),
590            stats: FixerStats::default(),
591        }
592    }
593
594    /// Fix a model by completing function interpretations
595    pub fn fix_model(
596        &mut self,
597        model: &mut CompletedModel,
598        quantifiers: &[QuantifiedFormula],
599        manager: &mut TermManager,
600    ) -> Result<(), CompletionError> {
601        self.stats.num_fixes += 1;
602
603        // Collect all partial functions from quantifiers
604        let partial_functions = self.collect_partial_functions(quantifiers, manager);
605
606        // For each partial function, add projection functions
607        // Process one at a time to avoid borrow conflicts
608        for func_name in partial_functions.iter() {
609            // Check if function exists first (immutable borrow)
610            let has_interp = model.function_interps.contains_key(func_name);
611            if has_interp {
612                // Get mutable reference in separate scope
613                if let Some(interp) = model.function_interps.get_mut(func_name) {
614                    // Create a minimal projection without full model access
615                    // This is a simplified version - full implementation would cache model data
616                    for arg_idx in 0..interp.arity {
617                        let sort = interp.domain[arg_idx];
618                        if self.needs_projection(sort, manager) {
619                            // Placeholder: would need model data extracted first
620                            interp.projections[arg_idx] = None;
621                        }
622                    }
623                }
624            }
625        }
626
627        // Complete partial interpretations
628        for interp in model.function_interps.values_mut() {
629            if interp.is_partial() {
630                // Use most common value as default
631                if let Some(default) = interp.max_occurrence_result() {
632                    interp.else_value = Some(default);
633                }
634            }
635        }
636
637        Ok(())
638    }
639
640    /// Collect partial function symbols from quantifiers
641    fn collect_partial_functions(
642        &self,
643        quantifiers: &[QuantifiedFormula],
644        manager: &TermManager,
645    ) -> FxHashSet<Spur> {
646        let mut functions = FxHashSet::default();
647
648        for quant in quantifiers {
649            self.collect_partial_functions_rec(quant.body, &mut functions, manager);
650        }
651
652        functions
653    }
654
655    fn collect_partial_functions_rec(
656        &self,
657        term: TermId,
658        functions: &mut FxHashSet<Spur>,
659        manager: &TermManager,
660    ) {
661        let Some(t) = manager.get(term) else {
662            return;
663        };
664
665        if let TermKind::Apply { func, args } = &t.kind {
666            // Check if any arg contains variables (not ground)
667            let has_vars = args.iter().any(|&arg| {
668                if let Some(arg_t) = manager.get(arg) {
669                    matches!(arg_t.kind, TermKind::Var(_))
670                } else {
671                    false
672                }
673            });
674
675            if has_vars {
676                functions.insert(*func);
677            }
678
679            // Recurse into args
680            for &arg in args.iter() {
681                self.collect_partial_functions_rec(arg, functions, manager);
682            }
683        }
684
685        // Recurse into other children
686        match &t.kind {
687            TermKind::Not(arg) | TermKind::Neg(arg) => {
688                self.collect_partial_functions_rec(*arg, functions, manager);
689            }
690            TermKind::And(args) | TermKind::Or(args) => {
691                for &arg in args.iter() {
692                    self.collect_partial_functions_rec(arg, functions, manager);
693                }
694            }
695            TermKind::Eq(lhs, rhs) | TermKind::Lt(lhs, rhs) | TermKind::Le(lhs, rhs) => {
696                self.collect_partial_functions_rec(*lhs, functions, manager);
697                self.collect_partial_functions_rec(*rhs, functions, manager);
698            }
699            _ => {}
700        }
701    }
702
703    /// Add projection functions for a function interpretation
704    fn add_projection_functions(
705        &mut self,
706        interp: &mut FunctionInterpretation,
707        model: &CompletedModel,
708        manager: &mut TermManager,
709    ) -> Result<(), CompletionError> {
710        // For each argument position, create a projection if needed
711        for arg_idx in 0..interp.arity {
712            let sort = interp.domain[arg_idx];
713
714            // Check if we need a projection for this sort
715            if self.needs_projection(sort, manager) {
716                let proj_def = self.create_projection(interp, arg_idx, model, manager)?;
717                interp.projections[arg_idx] = Some(proj_def);
718            }
719        }
720
721        Ok(())
722    }
723
724    /// Check if a sort needs projection
725    fn needs_projection(&self, sort: SortId, manager: &TermManager) -> bool {
726        // Arithmetic sorts benefit from projection
727        sort == manager.sorts.int_sort || sort == manager.sorts.real_sort
728    }
729
730    /// Create a projection function for an argument position
731    fn create_projection(
732        &mut self,
733        interp: &FunctionInterpretation,
734        arg_idx: usize,
735        model: &CompletedModel,
736        manager: &mut TermManager,
737    ) -> Result<ProjectionFunctionDef, CompletionError> {
738        let sort = interp.domain[arg_idx];
739        let mut proj_def = ProjectionFunctionDef::new(arg_idx, sort);
740
741        // Collect all values that appear at this argument position
742        for entry in &interp.entries {
743            if let Some(&arg_term) = entry.args.get(arg_idx) {
744                // Evaluate the argument in the model
745                let value = model.eval(arg_term).unwrap_or(arg_term);
746                proj_def.add_value(value, arg_term);
747            }
748        }
749
750        // Sort the values
751        proj_def
752            .values
753            .sort_by(|a, b| self.compare_values(*a, *b, sort, manager));
754
755        Ok(proj_def)
756    }
757
758    /// Compare two values for a given sort
759    fn compare_values(
760        &self,
761        a: TermId,
762        b: TermId,
763        _sort: SortId,
764        manager: &TermManager,
765    ) -> Ordering {
766        let a_term = manager.get(a);
767        let b_term = manager.get(b);
768
769        if let (Some(at), Some(bt)) = (a_term, b_term) {
770            // Integer comparison
771            if let (TermKind::IntConst(av), TermKind::IntConst(bv)) = (&at.kind, &bt.kind) {
772                return av.cmp(bv);
773            }
774
775            // Real comparison
776            if let (TermKind::RealConst(av), TermKind::RealConst(bv)) = (&at.kind, &bt.kind) {
777                return av.cmp(bv);
778            }
779
780            // Boolean comparison (false < true)
781            match (&at.kind, &bt.kind) {
782                (TermKind::False, TermKind::True) => return Ordering::Less,
783                (TermKind::True, TermKind::False) => return Ordering::Greater,
784                (TermKind::False, TermKind::False) | (TermKind::True, TermKind::True) => {
785                    return Ordering::Equal;
786                }
787                _ => {}
788            }
789        }
790
791        // Fall back to ID comparison
792        a.0.cmp(&b.0)
793    }
794
795    /// Get statistics
796    pub fn stats(&self) -> &FixerStats {
797        &self.stats
798    }
799}
800
801impl Default for ModelFixer {
802    fn default() -> Self {
803        Self::new()
804    }
805}
806
807/// Trait for projection functions (maps infinite domain to finite representatives)
808pub trait ProjectionFunction: fmt::Debug + Send + Sync {
809    /// Compare two values (for sorting)
810    fn compare(&self, a: TermId, b: TermId, manager: &TermManager) -> bool;
811
812    /// Create a less-than term
813    fn mk_lt(&self, x: TermId, y: TermId, manager: &mut TermManager) -> TermId;
814}
815
816/// Arithmetic projection function
817#[derive(Debug)]
818pub struct ArithmeticProjection {
819    /// Whether this is for integers (vs reals)
820    is_int: bool,
821}
822
823impl ArithmeticProjection {
824    pub fn new(is_int: bool) -> Self {
825        Self { is_int }
826    }
827}
828
829impl ProjectionFunction for ArithmeticProjection {
830    fn compare(&self, a: TermId, b: TermId, manager: &TermManager) -> bool {
831        let a_term = manager.get(a);
832        let b_term = manager.get(b);
833
834        if let (Some(at), Some(bt)) = (a_term, b_term) {
835            if let (TermKind::IntConst(av), TermKind::IntConst(bv)) = (&at.kind, &bt.kind) {
836                return av < bv;
837            }
838            if let (TermKind::RealConst(av), TermKind::RealConst(bv)) = (&at.kind, &bt.kind) {
839                return av < bv;
840            }
841        }
842
843        a.0 < b.0
844    }
845
846    fn mk_lt(&self, x: TermId, y: TermId, manager: &mut TermManager) -> TermId {
847        manager.mk_lt(x, y)
848    }
849}
850
851/// Handler for uninterpreted sorts
852#[derive(Debug)]
853pub struct UninterpretedSortHandler {
854    /// Maximum universe size for each sort
855    max_universe_size: usize,
856    /// Statistics
857    stats: UninterpStats,
858}
859
860impl UninterpretedSortHandler {
861    /// Create a new handler
862    pub fn new() -> Self {
863        Self {
864            max_universe_size: 8,
865            stats: UninterpStats::default(),
866        }
867    }
868
869    /// Create with custom universe size limit
870    pub fn with_max_size(max_size: usize) -> Self {
871        let mut handler = Self::new();
872        handler.max_universe_size = max_size;
873        handler
874    }
875
876    /// Complete universes for uninterpreted sorts
877    pub fn complete_universes(
878        &mut self,
879        model: &mut CompletedModel,
880        manager: &mut TermManager,
881    ) -> Result<(), CompletionError> {
882        // Identify uninterpreted sorts
883        let uninterp_sorts = self.identify_uninterpreted_sorts(model, manager);
884
885        for sort in uninterp_sorts {
886            if let std::collections::hash_map::Entry::Vacant(e) = model.universes.entry(sort) {
887                // Create a finite universe for this sort
888                let universe = self.create_finite_universe(sort, manager)?;
889                e.insert(universe);
890                self.stats.num_universes_created += 1;
891            }
892        }
893
894        Ok(())
895    }
896
897    /// Identify uninterpreted sorts in the model
898    fn identify_uninterpreted_sorts(
899        &self,
900        model: &CompletedModel,
901        manager: &TermManager,
902    ) -> Vec<SortId> {
903        let mut sorts = Vec::new();
904
905        // Collect sorts from function interpretations
906        for interp in model.function_interps.values() {
907            for &sort in &interp.domain {
908                if self.is_uninterpreted(sort, manager) && !sorts.contains(&sort) {
909                    sorts.push(sort);
910                }
911            }
912            if self.is_uninterpreted(interp.range, manager) && !sorts.contains(&interp.range) {
913                sorts.push(interp.range);
914            }
915        }
916
917        sorts
918    }
919
920    /// Check if a sort is uninterpreted
921    fn is_uninterpreted(&self, sort: SortId, manager: &TermManager) -> bool {
922        // A sort is uninterpreted if it's not a built-in sort
923        sort != manager.sorts.bool_sort
924            && sort != manager.sorts.int_sort
925            && sort != manager.sorts.real_sort
926    }
927
928    /// Create a finite universe for a sort
929    fn create_finite_universe(
930        &self,
931        sort: SortId,
932        manager: &mut TermManager,
933    ) -> Result<Vec<TermId>, CompletionError> {
934        let mut universe = Vec::new();
935
936        // Create fresh constants for the universe
937        for i in 0..self.max_universe_size {
938            let name = format!("u!{}", i);
939            let const_id = manager.mk_var(&name, sort);
940            universe.push(const_id);
941        }
942
943        Ok(universe)
944    }
945
946    /// Get statistics
947    pub fn stats(&self) -> &UninterpStats {
948        &self.stats
949    }
950}
951
952impl Default for UninterpretedSortHandler {
953    fn default() -> Self {
954        Self::new()
955    }
956}
957
/// Error produced while completing a model.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare completion results
/// directly (e.g. `assert_eq!(res, Err(CompletionError::ResourceLimit))`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompletionError {
    /// Could not complete the model; the payload explains why.
    CompletionFailed(String),
    /// A resource limit was exceeded during completion.
    ResourceLimit,
    /// The model is invalid; the payload describes the problem.
    InvalidModel(String),
}

impl fmt::Display for CompletionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::CompletionFailed(msg) => write!(f, "Model completion failed: {}", msg),
            Self::ResourceLimit => write!(f, "Resource limit exceeded during completion"),
            Self::InvalidModel(msg) => write!(f, "Invalid model: {}", msg),
        }
    }
}

impl std::error::Error for CompletionError {}
980
/// Counters kept by the model completer.
#[derive(Clone, Debug, Default)]
pub struct CompletionStats {
    /// Number of completions counted so far.
    pub num_completions: usize,
    /// Number of failed completions counted so far.
    pub num_failures: usize,
}
987
/// Counters kept by the macro solver.
#[derive(Clone, Debug, Default)]
pub struct MacroStats {
    /// Number of macros identified so far.
    pub num_macros_found: usize,
    /// Number of macros applied so far.
    pub num_macros_applied: usize,
}
994
/// Counters kept by the model fixer.
#[derive(Clone, Debug, Default)]
pub struct FixerStats {
    /// Number of fixes counted so far.
    pub num_fixes: usize,
    /// Number of projection functions created so far.
    pub num_projections_created: usize,
}
1001
/// Counters kept by the uninterpreted-sort handler.
#[derive(Clone, Debug, Default)]
pub struct UninterpStats {
    /// Number of finite universes created for uninterpreted sorts.
    pub num_universes_created: usize,
    /// Total number of elements across the created universes.
    pub total_universe_size: usize,
}
1008
#[cfg(test)]
mod tests {
    use super::*;
    use lasso::Key;

    /// Interned symbol used as the function name in the interpretation tests.
    fn symbol() -> Spur {
        Spur::try_from_usize(1).expect("valid spur")
    }

    #[test]
    fn test_completed_model_creation() {
        let fresh = CompletedModel::new();
        assert!(fresh.assignments.is_empty());
        assert!(fresh.function_interps.is_empty());
    }

    #[test]
    fn test_completed_model_eval() {
        let mut model = CompletedModel::new();
        let (key, val) = (TermId::new(1), TermId::new(2));

        model.set(key, val);

        assert_eq!(model.eval(key), Some(val));
        assert_eq!(model.eval(TermId::new(99)), None);
    }

    #[test]
    fn test_function_interpretation_lookup() {
        // Binary function: both argument sorts are sort 1.
        let mut interp = FunctionInterpretation::new(
            symbol(),
            SmallVec::from_elem(SortId::new(1), 2),
            SortId::new(1),
        );

        let args = vec![TermId::new(1), TermId::new(2)];
        let expected = TermId::new(10);
        interp.add_entry(args.clone(), expected);

        assert_eq!(interp.lookup(&args), Some(expected));
        assert_eq!(interp.lookup(&[TermId::new(99)]), None);
    }

    #[test]
    fn test_function_interpretation_else_value() {
        // Nullary function whose only behaviour comes from its else-value.
        let mut interp = FunctionInterpretation::new(symbol(), SmallVec::new(), SortId::new(1));

        let fallback = TermId::new(42);
        interp.else_value = Some(fallback);

        assert_eq!(interp.lookup(&[TermId::new(99)]), Some(fallback));
    }

    #[test]
    fn test_function_interpretation_max_occurrence() {
        // Unary function: the single argument sort is sort 1.
        let mut interp = FunctionInterpretation::new(
            symbol(),
            SmallVec::from_elem(SortId::new(1), 1),
            SortId::new(1),
        );

        let (common, rare) = (TermId::new(10), TermId::new(20));
        for (arg, result) in [(1, common), (2, common), (3, rare)] {
            interp.add_entry(vec![TermId::new(arg)], result);
        }

        assert_eq!(interp.max_occurrence_result(), Some(common));
    }

    #[test]
    fn test_projection_function_def() {
        let mut proj = ProjectionFunctionDef::new(0, SortId::new(1));

        let (value, term) = (TermId::new(1), TermId::new(10));
        proj.add_value(value, term);

        assert_eq!(proj.project(value), Some(term));
        assert_eq!(proj.values.len(), 1);
    }

    #[test]
    fn test_model_completer_creation() {
        assert_eq!(ModelCompleter::new().stats.num_completions, 0);
    }

    #[test]
    fn test_macro_solver_creation() {
        assert_eq!(MacroSolver::new().stats.num_macros_found, 0);
    }

    #[test]
    fn test_model_fixer_creation() {
        assert_eq!(ModelFixer::new().stats.num_fixes, 0);
    }

    #[test]
    fn test_uninterpreted_sort_handler_creation() {
        assert_eq!(UninterpretedSortHandler::new().max_universe_size, 8);
    }

    #[test]
    fn test_uninterpreted_sort_handler_custom_size() {
        assert_eq!(UninterpretedSortHandler::with_max_size(16).max_universe_size, 16);
    }

    #[test]
    fn test_arithmetic_projection() {
        assert!(ArithmeticProjection::new(true).is_int);
    }

    #[test]
    fn test_completion_error_display() {
        let rendered = CompletionError::CompletionFailed("test".to_string()).to_string();
        assert!(rendered.contains("test"));
    }
}