1#[allow(unused_imports)]
10use crate::prelude::*;
11use crate::solver::{Solver, SolverResult};
12use num_bigint::BigInt;
13use num_rational::Rational64;
14use num_traits::Zero;
15use oxiz_core::ast::{TermId, TermKind, TermManager};
16
/// Direction in which an objective term is optimized.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObjectiveKind {
    /// Search for smaller values (improvement steps assert `term < best`).
    Minimize,
    /// Search for larger values (improvement steps assert `term > best`).
    Maximize,
}
25
/// A single optimization target registered with an [`Optimizer`].
#[derive(Debug, Clone)]
pub struct Objective {
    /// Term whose value is optimized.
    pub term: TermId,
    /// Whether `term` is minimized or maximized.
    pub kind: ObjectiveKind,
    /// Position in the lexicographic order used by `Optimizer::optimize`
    /// (ascending, so lower values are optimized first). `Optimizer::minimize`
    /// and `Optimizer::maximize` set it to the registration index.
    pub priority: usize,
}
36
/// Outcome of [`Optimizer::optimize`].
#[derive(Debug, Clone)]
pub enum OptimizationResult {
    /// The optimizer reports `value` as the objective's optimum.
    Optimal {
        /// Term for the objective's value in `model` (normally a constant;
        /// the integer `0` when no objectives are registered).
        value: TermId,
        /// Model in which the objective takes `value`.
        model: crate::solver::Model,
    },
    /// The objective is unbounded in its optimization direction.
    /// NOTE(review): no code in this file currently produces this variant.
    Unbounded,
    /// The asserted constraints are unsatisfiable.
    Unsat,
    /// The solver could not decide, or no result could be established.
    Unknown,
}
54
/// Optimization front-end over a [`Solver`]: collects hard constraints and
/// objectives, then searches for optimal (or Pareto) models.
#[derive(Debug)]
pub struct Optimizer {
    /// Underlying SMT solver that receives the flushed assertions.
    solver: Solver,
    /// Registered objectives, in registration order.
    objectives: Vec<Objective>,
    /// Constraints buffered by `assert` until the next `optimize` /
    /// `pareto_optimize` call flushes them into `solver`.
    assertions: Vec<TermId>,
}
126
127impl Optimizer {
128 #[must_use]
130 pub fn new() -> Self {
131 Self {
132 solver: Solver::new(),
133 objectives: Vec::new(),
134 assertions: Vec::new(),
135 }
136 }
137
138 pub fn assert(&mut self, term: TermId) {
140 self.assertions.push(term);
141 }
142
143 pub fn minimize(&mut self, term: TermId) {
145 self.objectives.push(Objective {
146 term,
147 kind: ObjectiveKind::Minimize,
148 priority: self.objectives.len(),
149 });
150 }
151
152 pub fn maximize(&mut self, term: TermId) {
154 self.objectives.push(Objective {
155 term,
156 kind: ObjectiveKind::Maximize,
157 priority: self.objectives.len(),
158 });
159 }
160
161 pub fn set_logic(&mut self, logic: &str) {
163 self.solver.set_logic(logic);
164 }
165
166 pub fn push(&mut self) {
168 self.solver.push();
169 }
170
171 pub fn pop(&mut self) {
173 self.solver.pop();
174 }
175
176 pub fn optimize(&mut self, term_manager: &mut TermManager) -> OptimizationResult {
181 for &assertion in &self.assertions.clone() {
183 self.solver.assert(assertion, term_manager);
184 }
185 self.assertions.clear();
187
188 if self.objectives.is_empty() {
189 match self.solver.check(term_manager) {
191 SolverResult::Sat => {
192 if let Some(model) = self.solver.model() {
193 let zero = term_manager.mk_int(BigInt::zero());
195 return OptimizationResult::Optimal {
196 value: zero,
197 model: model.clone(),
198 };
199 }
200 OptimizationResult::Unknown
201 }
202 SolverResult::Unsat => OptimizationResult::Unsat,
203 SolverResult::Unknown => OptimizationResult::Unknown,
204 }
205 } else {
206 let mut sorted_objectives = self.objectives.clone();
208 sorted_objectives.sort_by_key(|obj| obj.priority);
209
210 for (idx, objective) in sorted_objectives.iter().enumerate() {
212 let result = self.optimize_single(objective, term_manager);
213
214 match result {
215 OptimizationResult::Optimal { value, model } => {
216 if idx < sorted_objectives.len() - 1 {
218 self.solver.push();
220 let eq = term_manager.mk_eq(objective.term, value);
221 self.solver.assert(eq, term_manager);
222 } else {
223 return OptimizationResult::Optimal { value, model };
225 }
226 }
227 other => return other,
228 }
229 }
230
231 OptimizationResult::Unknown
232 }
233 }
234
235 fn optimize_single(
237 &mut self,
238 objective: &Objective,
239 term_manager: &mut TermManager,
240 ) -> OptimizationResult {
241 let result = self.solver.check(term_manager);
243 if result != SolverResult::Sat {
244 return match result {
245 SolverResult::Unsat => OptimizationResult::Unsat,
246 _ => OptimizationResult::Unknown,
247 };
248 }
249
250 let term_info = term_manager.get(objective.term);
252 let is_int = term_info.is_some_and(|t| t.sort == term_manager.sorts.int_sort);
253
254 if is_int {
255 self.optimize_int(objective, term_manager)
256 } else {
257 self.optimize_real(objective, term_manager)
258 }
259 }
260
261 fn optimize_int(
269 &mut self,
270 objective: &Objective,
271 term_manager: &mut TermManager,
272 ) -> OptimizationResult {
273 let result = self.solver.check(term_manager);
275 if result != SolverResult::Sat {
276 return if result == SolverResult::Unsat {
277 OptimizationResult::Unsat
278 } else {
279 OptimizationResult::Unknown
280 };
281 }
282
283 let mut best_model = match self.solver.model() {
285 Some(m) => m.clone(),
286 None => return OptimizationResult::Unknown,
287 };
288
289 let value_term = best_model.eval(objective.term, term_manager);
291
292 let mut current_value = if let Some(t) = term_manager.get(value_term) {
294 if let TermKind::IntConst(n) = &t.kind {
295 n.clone()
296 } else {
297 return OptimizationResult::Optimal {
299 value: value_term,
300 model: best_model,
301 };
302 }
303 } else {
304 return OptimizationResult::Unknown;
305 };
306
307 let mut best_value_term = value_term;
308
309 let max_iterations = 1000; for _ in 0..max_iterations {
312 self.solver.push();
314
315 let bound_term = term_manager.mk_int(current_value.clone());
317 let improvement_constraint = match objective.kind {
318 ObjectiveKind::Minimize => {
319 term_manager.mk_lt(objective.term, bound_term)
321 }
322 ObjectiveKind::Maximize => {
323 term_manager.mk_gt(objective.term, bound_term)
325 }
326 };
327 self.solver.assert(improvement_constraint, term_manager);
328
329 let result = self.solver.check(term_manager);
331 if result == SolverResult::Sat {
332 let mut value_updated = false;
334 if let Some(model) = self.solver.model() {
335 let new_value_term = model.eval(objective.term, term_manager);
336
337 if let Some(t) = term_manager.get(new_value_term)
338 && let TermKind::IntConst(n) = &t.kind
339 {
340 current_value = n.clone();
341 best_value_term = new_value_term;
342 best_model = model.clone();
343 value_updated = true;
344 }
345 }
346 self.solver.pop();
348 if !value_updated {
349 break;
352 }
353 } else {
354 self.solver.pop();
356 break;
357 }
358 }
359
360 OptimizationResult::Optimal {
361 value: best_value_term,
362 model: best_model,
363 }
364 }
365
366 fn optimize_real(
370 &mut self,
371 objective: &Objective,
372 term_manager: &mut TermManager,
373 ) -> OptimizationResult {
374 let result = self.solver.check(term_manager);
376 if result != SolverResult::Sat {
377 return if result == SolverResult::Unsat {
378 OptimizationResult::Unsat
379 } else {
380 OptimizationResult::Unknown
381 };
382 }
383
384 let mut best_model = match self.solver.model() {
386 Some(m) => m.clone(),
387 None => return OptimizationResult::Unknown,
388 };
389
390 let value_term = best_model.eval(objective.term, term_manager);
392
393 let mut current_value: Option<Rational64> = None;
395 if let Some(term) = term_manager.get(value_term) {
396 match &term.kind {
397 TermKind::RealConst(val) => {
398 current_value = Some(*val);
399 }
400 TermKind::IntConst(val) => {
401 let int_val = if val.sign() == num_bigint::Sign::Minus {
403 -val.to_string()
404 .trim_start_matches('-')
405 .parse::<i64>()
406 .unwrap_or(0)
407 } else {
408 val.to_string().parse::<i64>().unwrap_or(0)
409 };
410 current_value = Some(Rational64::from_integer(int_val));
411 }
412 _ => {}
413 }
414 }
415
416 let Some(mut current_val) = current_value else {
417 return OptimizationResult::Optimal {
419 value: value_term,
420 model: best_model,
421 };
422 };
423
424 let mut best_value = current_val;
425
426 let max_iterations = 1000;
428 for _ in 0..max_iterations {
429 self.solver.push();
430
431 let bound_term = term_manager.mk_real(current_val);
433 let improvement_constraint = match objective.kind {
434 ObjectiveKind::Minimize => term_manager.mk_lt(objective.term, bound_term),
435 ObjectiveKind::Maximize => term_manager.mk_gt(objective.term, bound_term),
436 };
437 self.solver.assert(improvement_constraint, term_manager);
438
439 let result = self.solver.check(term_manager);
440 if result == SolverResult::Sat {
441 let mut value_updated = false;
442 if let Some(model) = self.solver.model() {
443 let new_value_term = model.eval(objective.term, term_manager);
444
445 if let Some(t) = term_manager.get(new_value_term) {
446 let new_val = match &t.kind {
447 TermKind::RealConst(v) => Some(*v),
448 TermKind::IntConst(v) => {
449 let int_val = if v.sign() == num_bigint::Sign::Minus {
450 -v.to_string()
451 .trim_start_matches('-')
452 .parse::<i64>()
453 .unwrap_or(0)
454 } else {
455 v.to_string().parse::<i64>().unwrap_or(0)
456 };
457 Some(Rational64::from_integer(int_val))
458 }
459 _ => None,
460 };
461
462 if let Some(v) = new_val {
463 current_val = v;
464 best_value = v;
465 best_model = model.clone();
466 value_updated = true;
467 }
468 }
469 }
470 self.solver.pop();
471 if !value_updated {
472 break;
475 }
476 } else {
477 self.solver.pop();
478 break;
479 }
480 }
481
482 let final_value_term = term_manager.mk_real(best_value);
484 OptimizationResult::Optimal {
485 value: final_value_term,
486 model: best_model,
487 }
488 }
489}
490
491impl Default for Optimizer {
492 fn default() -> Self {
493 Self::new()
494 }
495}
496
/// One point reported by [`Optimizer::pareto_optimize`].
#[derive(Debug, Clone)]
pub struct ParetoPoint {
    /// Value of each objective evaluated in `model`, in objective
    /// registration order.
    pub values: Vec<TermId>,
    /// Model from which `values` were evaluated.
    pub model: crate::solver::Model,
}
505
506impl Optimizer {
507 pub fn pareto_optimize(&mut self, term_manager: &mut TermManager) -> Vec<ParetoPoint> {
516 let mut pareto_front = Vec::new();
517
518 for &assertion in &self.assertions.clone() {
520 self.solver.assert(assertion, term_manager);
521 }
522 self.assertions.clear();
523
524 if self.objectives.is_empty() {
525 return pareto_front;
526 }
527
528 let max_points = 100; for _ in 0..max_points {
531 match self.solver.check(term_manager) {
533 SolverResult::Sat => {
534 if let Some(model) = self.solver.model() {
536 let mut values = Vec::new();
537 for objective in &self.objectives {
538 let value = model.eval(objective.term, term_manager);
539 values.push(value);
540 }
541
542 pareto_front.push(ParetoPoint {
544 values: values.clone(),
545 model: model.clone(),
546 });
547
548 self.solver.push();
552 let mut improvement_disjuncts = Vec::new();
553
554 for (idx, objective) in self.objectives.iter().enumerate() {
555 let current_value = values[idx];
556 let improvement = match objective.kind {
557 ObjectiveKind::Minimize => {
558 term_manager.mk_lt(objective.term, current_value)
559 }
560 ObjectiveKind::Maximize => {
561 term_manager.mk_gt(objective.term, current_value)
562 }
563 };
564 improvement_disjuncts.push(improvement);
565 }
566
567 if !improvement_disjuncts.is_empty() {
569 let constraint = term_manager.mk_or(improvement_disjuncts);
570 self.solver.assert(constraint, term_manager);
571 } else {
572 self.solver.pop();
574 break;
575 }
576 } else {
577 break;
578 }
579 }
580 SolverResult::Unsat => {
581 break;
583 }
584 SolverResult::Unknown => {
585 break;
587 }
588 }
589 }
590
591 pareto_front
592 }
593}
594
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigInt;

    /// Declares an `Int` variable `x` and builds `x >= lower` and `x <= upper`.
    /// Returns `(x, lower_constraint, upper_constraint)`.
    fn bounded_x(tm: &mut TermManager, lower: BigInt, upper: BigInt) -> (TermId, TermId, TermId) {
        let x = tm.mk_var("x", tm.sorts.int_sort);
        let lo = tm.mk_int(lower);
        let hi = tm.mk_int(upper);
        let at_least = tm.mk_ge(x, lo);
        let at_most = tm.mk_le(x, hi);
        (x, at_least, at_most)
    }

    /// Panics unless `value`, when it resolves to a term, is the integer
    /// constant `expected`.
    fn expect_int_value(tm: &TermManager, value: TermId, expected: &BigInt) {
        let Some(term) = tm.get(value) else {
            return;
        };
        let TermKind::IntConst(n) = &term.kind else {
            panic!("Expected integer constant");
        };
        assert_eq!(n, expected);
    }

    #[test]
    fn test_solver_direct() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();
        solver.set_logic("QF_LIA");

        let (_, at_least, at_most) = bounded_x(&mut tm, BigInt::zero(), BigInt::from(10));
        for constraint in [at_least, at_most] {
            solver.assert(constraint, &mut tm);
        }

        assert_eq!(
            solver.check(&mut tm),
            SolverResult::Sat,
            "Solver should return SAT"
        );
    }

    #[test]
    fn test_optimizer_encoding() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();
        optimizer.set_logic("QF_LIA");

        let (_, at_least, at_most) = bounded_x(&mut tm, BigInt::zero(), BigInt::from(10));
        optimizer.assert(at_least);
        optimizer.assert(at_most);

        // Flush the buffered assertions by hand, as `optimize` does.
        for assertion in core::mem::take(&mut optimizer.assertions) {
            optimizer.solver.assert(assertion, &mut tm);
        }

        assert_eq!(
            optimizer.solver.check(&mut tm),
            SolverResult::Sat,
            "Should be SAT after encoding"
        );
    }

    #[test]
    fn test_optimizer_basic() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();
        optimizer.set_logic("QF_LIA");

        let (x, at_least, at_most) = bounded_x(&mut tm, BigInt::zero(), BigInt::from(10));
        optimizer.assert(at_least);
        optimizer.assert(at_most);
        optimizer.minimize(x);

        let value = match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => value,
            OptimizationResult::Unsat => panic!("Unexpected unsat result"),
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
            OptimizationResult::Unknown => panic!("Got unknown result"),
        };
        expect_int_value(&tm, value, &BigInt::zero());
    }

    #[test]
    fn test_optimizer_maximize() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();
        optimizer.set_logic("QF_LIA");

        let (x, at_least, at_most) = bounded_x(&mut tm, BigInt::zero(), BigInt::from(10));
        optimizer.assert(at_least);
        optimizer.assert(at_most);
        optimizer.maximize(x);

        let OptimizationResult::Optimal { value, .. } = optimizer.optimize(&mut tm) else {
            panic!("Expected optimal result");
        };
        expect_int_value(&tm, value, &BigInt::from(10));
    }

    #[test]
    fn test_optimizer_unsat() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();
        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let y = tm.mk_var("y", tm.sorts.int_sort);
        let same = tm.mk_eq(x, y);
        let different = tm.mk_not(same);
        optimizer.assert(same);
        optimizer.assert(different);
        optimizer.minimize(x);

        // This test accepts Unsat, Unknown or Optimal; only Unbounded fails it.
        let result = optimizer.optimize(&mut tm);
        assert!(
            !matches!(result, OptimizationResult::Unbounded),
            "Unexpected unbounded result"
        );
    }
}