oxiz_solver/model/
completion.rs1use super::builder::{Model, Value, VarId};
17#[allow(unused_imports)]
18use crate::prelude::*;
19
/// Policy that [`ModelCompleter`] follows when it must invent a value for a
/// variable the partial model leaves unassigned.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompletionStrategy {
    /// Always fall back to a default value; registered witness values are not used.
    Default,
    /// Reuse the registered witness value when there is one, else a default value.
    Witness,
    /// Reuse the registered witness value when there is one, else the minimal value.
    Minimal,
}

31#[derive(Debug, Clone)]
33pub struct CompletionConfig {
34 pub strategy: CompletionStrategy,
36 pub default_int: i64,
38 pub default_bool: bool,
40}
41
42impl Default for CompletionConfig {
43 fn default() -> Self {
44 Self {
45 strategy: CompletionStrategy::Default,
46 default_int: 0,
47 default_bool: false,
48 }
49 }
50}
51
/// Running counters describing the work performed by a [`ModelCompleter`].
#[derive(Clone, Debug, Default)]
pub struct CompletionStats {
    /// Number of variables that completion assigned a value to.
    pub vars_completed: u64,
    /// Number of those assignments that used a fallback instead of a witness.
    pub defaults_used: u64,
    /// Number of those assignments that reused a registered witness value.
    pub witnesses_used: u64,
}

63#[derive(Debug)]
65pub struct ModelCompleter {
66 config: CompletionConfig,
68 witnesses: FxHashMap<VarId, Value>,
70 stats: CompletionStats,
72}
73
74impl ModelCompleter {
75 pub fn new(config: CompletionConfig) -> Self {
77 Self {
78 config,
79 witnesses: FxHashMap::default(),
80 stats: CompletionStats::default(),
81 }
82 }
83
84 pub fn default_config() -> Self {
86 Self::new(CompletionConfig::default())
87 }
88
89 pub fn add_witness(&mut self, var: VarId, value: Value) {
91 self.witnesses.insert(var, value);
92 }
93
94 pub fn complete(&mut self, partial_model: &Model) -> Model {
98 let mut completed = partial_model.clone();
99
100 let missing_vars = self.find_missing_vars(&completed);
102
103 for var in missing_vars {
104 let value = self.complete_variable(var);
105 completed.assign_theory(var, value);
106 self.stats.vars_completed += 1;
107 }
108
109 completed
110 }
111
112 fn find_missing_vars(&self, model: &Model) -> Vec<VarId> {
114 let mut missing = Vec::new();
117
118 for &var in self.witnesses.keys() {
120 if model.get_theory(var).is_none() {
121 missing.push(var);
122 }
123 }
124
125 missing
126 }
127
128 fn complete_variable(&mut self, var: VarId) -> Value {
130 match self.config.strategy {
131 CompletionStrategy::Witness => {
132 if let Some(witness) = self.witnesses.get(&var) {
133 self.stats.witnesses_used += 1;
134 witness.clone()
135 } else {
136 self.stats.defaults_used += 1;
137 self.default_value()
138 }
139 }
140 CompletionStrategy::Default => {
141 self.stats.defaults_used += 1;
142 self.default_value()
143 }
144 CompletionStrategy::Minimal => {
145 if let Some(witness) = self.witnesses.get(&var) {
147 self.stats.witnesses_used += 1;
148 witness.clone()
149 } else {
150 self.stats.defaults_used += 1;
151 self.minimal_value()
152 }
153 }
154 }
155 }
156
157 fn default_value(&self) -> Value {
159 Value::Int(self.config.default_int)
160 }
161
162 fn minimal_value(&self) -> Value {
164 Value::Int(0)
165 }
166
167 pub fn stats(&self) -> &CompletionStats {
169 &self.stats
170 }
171
172 pub fn reset(&mut self) {
174 self.witnesses.clear();
175 self.stats = CompletionStats::default();
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_completer_creation() {
        let completer = ModelCompleter::default_config();
        let stats = completer.stats();
        assert_eq!(stats.vars_completed, 0);
        assert_eq!(stats.defaults_used, 0);
        assert_eq!(stats.witnesses_used, 0);
    }

    #[test]
    fn test_add_witness() {
        let mut completer = ModelCompleter::default_config();
        completer.add_witness(0, Value::Int(42));

        assert!(completer.witnesses.contains_key(&0));
        assert_eq!(completer.witnesses.get(&0), Some(&Value::Int(42)));
    }

    #[test]
    fn test_add_witness_replaces_previous() {
        let mut completer = ModelCompleter::default_config();
        completer.add_witness(0, Value::Int(42));
        completer.add_witness(0, Value::Int(7));

        assert_eq!(completer.witnesses.len(), 1);
        assert_eq!(completer.witnesses.get(&0), Some(&Value::Int(7)));
    }

    #[test]
    fn test_complete_with_witnesses() {
        let config = CompletionConfig {
            strategy: CompletionStrategy::Witness,
            ..Default::default()
        };
        let mut completer = ModelCompleter::new(config);
        completer.add_witness(0, Value::Int(10));
        completer.add_witness(1, Value::Bool(true));

        // Assumes a fresh `Model::new()` carries no theory assignments, so both
        // witnessed variables are missing and both are completed from witnesses.
        let partial = Model::new();
        let _completed = completer.complete(&partial);

        let stats = completer.stats();
        assert_eq!(stats.vars_completed, 2);
        assert_eq!(stats.witnesses_used, 2);
        assert_eq!(stats.defaults_used, 0);
    }

    #[test]
    fn test_default_strategy_does_not_use_witness_values() {
        let config = CompletionConfig {
            strategy: CompletionStrategy::Default,
            default_int: 7,
            ..Default::default()
        };
        let mut completer = ModelCompleter::new(config);
        completer.add_witness(0, Value::Int(10));

        let _completed = completer.complete(&Model::new());

        let stats = completer.stats();
        assert_eq!(stats.vars_completed, 1);
        assert_eq!(stats.defaults_used, 1);
        assert_eq!(stats.witnesses_used, 0);
    }

    #[test]
    fn test_reset_clears_witnesses_and_stats() {
        let mut completer = ModelCompleter::default_config();
        completer.add_witness(0, Value::Int(1));
        let _completed = completer.complete(&Model::new());
        assert_eq!(completer.stats().vars_completed, 1);

        completer.reset();

        assert!(completer.witnesses.is_empty());
        assert_eq!(completer.stats().vars_completed, 0);
        assert_eq!(completer.stats().defaults_used, 0);
        assert_eq!(completer.stats().witnesses_used, 0);
    }

    #[test]
    fn test_default_completion() {
        let config = CompletionConfig {
            strategy: CompletionStrategy::Default,
            default_int: 100,
            ..Default::default()
        };

        let completer = ModelCompleter::new(config);
        let default_val = completer.default_value();

        assert_eq!(default_val, Value::Int(100));
    }

    #[test]
    fn test_minimal_value_is_zero() {
        let completer = ModelCompleter::default_config();
        assert_eq!(completer.minimal_value(), Value::Int(0));
    }
}