1#[allow(unused_imports)]
10use crate::prelude::*;
11use crate::solver::{Solver, SolverResult};
12use num_bigint::BigInt;
13use num_rational::Rational64;
14use num_traits::Zero;
15use oxiz_core::ast::{TermId, TermKind, TermManager};
16
/// Direction in which an objective term is optimized.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ObjectiveKind {
    /// Seek the smallest achievable value of the objective term.
    Minimize,
    /// Seek the largest achievable value of the objective term.
    Maximize,
}
25
26#[derive(Debug, Clone)]
28pub struct Objective {
29 pub term: TermId,
31 pub kind: ObjectiveKind,
33 pub priority: usize,
35}
36
37#[derive(Debug, Clone)]
39pub enum OptimizationResult {
40 Optimal {
42 value: TermId,
44 model: crate::solver::Model,
46 },
47 Unbounded,
49 Unsat,
51 Unknown,
53}
54
55#[derive(Debug)]
118pub struct Optimizer {
119 solver: Solver,
121 objectives: Vec<Objective>,
123 assertions: Vec<TermId>,
125}
126
127impl Optimizer {
128 #[must_use]
130 pub fn new() -> Self {
131 Self {
132 solver: Solver::new(),
133 objectives: Vec::new(),
134 assertions: Vec::new(),
135 }
136 }
137
138 pub fn assert(&mut self, term: TermId) {
140 self.assertions.push(term);
141 }
142
143 pub fn minimize(&mut self, term: TermId) {
145 self.objectives.push(Objective {
146 term,
147 kind: ObjectiveKind::Minimize,
148 priority: self.objectives.len(),
149 });
150 }
151
152 pub fn maximize(&mut self, term: TermId) {
154 self.objectives.push(Objective {
155 term,
156 kind: ObjectiveKind::Maximize,
157 priority: self.objectives.len(),
158 });
159 }
160
161 pub fn set_logic(&mut self, logic: &str) {
163 self.solver.set_logic(logic);
164 }
165
166 pub fn push(&mut self) {
168 self.solver.push();
169 }
170
171 pub fn pop(&mut self) {
173 self.solver.pop();
174 }
175
176 pub fn optimize(&mut self, term_manager: &mut TermManager) -> OptimizationResult {
181 for &assertion in &self.assertions.clone() {
183 self.solver.assert(assertion, term_manager);
184 }
185 self.assertions.clear();
187
188 if self.objectives.is_empty() {
189 match self.solver.check(term_manager) {
191 SolverResult::Sat => {
192 if let Some(model) = self.solver.model() {
193 let zero = term_manager.mk_int(BigInt::zero());
195 return OptimizationResult::Optimal {
196 value: zero,
197 model: model.clone(),
198 };
199 }
200 OptimizationResult::Unknown
201 }
202 SolverResult::Unsat => OptimizationResult::Unsat,
203 SolverResult::Unknown => OptimizationResult::Unknown,
204 }
205 } else {
206 let mut sorted_objectives = self.objectives.clone();
208 sorted_objectives.sort_by_key(|obj| obj.priority);
209
210 for (idx, objective) in sorted_objectives.iter().enumerate() {
212 let result = self.optimize_single(objective, term_manager);
213
214 match result {
215 OptimizationResult::Optimal { value, model } => {
216 if idx < sorted_objectives.len() - 1 {
218 self.solver.push();
220 let eq = term_manager.mk_eq(objective.term, value);
221 self.solver.assert(eq, term_manager);
222 } else {
223 return OptimizationResult::Optimal { value, model };
225 }
226 }
227 other => return other,
228 }
229 }
230
231 OptimizationResult::Unknown
232 }
233 }
234
235 fn optimize_single(
237 &mut self,
238 objective: &Objective,
239 term_manager: &mut TermManager,
240 ) -> OptimizationResult {
241 let result = self.solver.check(term_manager);
243 if result != SolverResult::Sat {
244 return match result {
245 SolverResult::Unsat => OptimizationResult::Unsat,
246 _ => OptimizationResult::Unknown,
247 };
248 }
249
250 let term_info = term_manager.get(objective.term);
252 let is_int = term_info.is_some_and(|t| t.sort == term_manager.sorts.int_sort);
253
254 if is_int {
255 self.optimize_int(objective, term_manager)
256 } else {
257 self.optimize_real(objective, term_manager)
258 }
259 }
260
261 fn optimize_int(
269 &mut self,
270 objective: &Objective,
271 term_manager: &mut TermManager,
272 ) -> OptimizationResult {
273 let result = self.solver.check(term_manager);
275 if result != SolverResult::Sat {
276 return if result == SolverResult::Unsat {
277 OptimizationResult::Unsat
278 } else {
279 OptimizationResult::Unknown
280 };
281 }
282
283 let mut best_model = match self.solver.model() {
285 Some(m) => m.clone(),
286 None => return OptimizationResult::Unknown,
287 };
288
289 let value_term = best_model.eval(objective.term, term_manager);
291
292 let mut current_value = if let Some(t) = term_manager.get(value_term) {
294 if let TermKind::IntConst(n) = &t.kind {
295 n.clone()
296 } else {
297 return OptimizationResult::Optimal {
299 value: value_term,
300 model: best_model,
301 };
302 }
303 } else {
304 return OptimizationResult::Unknown;
305 };
306
307 let mut best_value_term = value_term;
308
309 let max_iterations = 1000; for _ in 0..max_iterations {
312 self.solver.push();
314
315 let bound_term = term_manager.mk_int(current_value.clone());
317 let improvement_constraint = match objective.kind {
318 ObjectiveKind::Minimize => {
319 term_manager.mk_lt(objective.term, bound_term)
321 }
322 ObjectiveKind::Maximize => {
323 term_manager.mk_gt(objective.term, bound_term)
325 }
326 };
327 self.solver.assert(improvement_constraint, term_manager);
328
329 let result = self.solver.check(term_manager);
331 if result == SolverResult::Sat {
332 if let Some(model) = self.solver.model() {
334 let new_value_term = model.eval(objective.term, term_manager);
335
336 if let Some(t) = term_manager.get(new_value_term)
337 && let TermKind::IntConst(n) = &t.kind
338 {
339 current_value = n.clone();
340 best_value_term = new_value_term;
341 best_model = model.clone();
342 }
343 }
344 self.solver.pop();
346 } else {
347 self.solver.pop();
349 break;
350 }
351 }
352
353 OptimizationResult::Optimal {
354 value: best_value_term,
355 model: best_model,
356 }
357 }
358
359 fn optimize_real(
363 &mut self,
364 objective: &Objective,
365 term_manager: &mut TermManager,
366 ) -> OptimizationResult {
367 let result = self.solver.check(term_manager);
369 if result != SolverResult::Sat {
370 return if result == SolverResult::Unsat {
371 OptimizationResult::Unsat
372 } else {
373 OptimizationResult::Unknown
374 };
375 }
376
377 let mut best_model = match self.solver.model() {
379 Some(m) => m.clone(),
380 None => return OptimizationResult::Unknown,
381 };
382
383 let value_term = best_model.eval(objective.term, term_manager);
385
386 let mut current_value: Option<Rational64> = None;
388 if let Some(term) = term_manager.get(value_term) {
389 match &term.kind {
390 TermKind::RealConst(val) => {
391 current_value = Some(*val);
392 }
393 TermKind::IntConst(val) => {
394 let int_val = if val.sign() == num_bigint::Sign::Minus {
396 -val.to_string()
397 .trim_start_matches('-')
398 .parse::<i64>()
399 .unwrap_or(0)
400 } else {
401 val.to_string().parse::<i64>().unwrap_or(0)
402 };
403 current_value = Some(Rational64::from_integer(int_val));
404 }
405 _ => {}
406 }
407 }
408
409 let Some(mut current_val) = current_value else {
410 return OptimizationResult::Optimal {
412 value: value_term,
413 model: best_model,
414 };
415 };
416
417 let mut best_value = current_val;
418
419 let max_iterations = 1000;
421 for _ in 0..max_iterations {
422 self.solver.push();
423
424 let bound_term = term_manager.mk_real(current_val);
426 let improvement_constraint = match objective.kind {
427 ObjectiveKind::Minimize => term_manager.mk_lt(objective.term, bound_term),
428 ObjectiveKind::Maximize => term_manager.mk_gt(objective.term, bound_term),
429 };
430 self.solver.assert(improvement_constraint, term_manager);
431
432 let result = self.solver.check(term_manager);
433 if result == SolverResult::Sat {
434 if let Some(model) = self.solver.model() {
435 let new_value_term = model.eval(objective.term, term_manager);
436
437 if let Some(t) = term_manager.get(new_value_term) {
438 let new_val = match &t.kind {
439 TermKind::RealConst(v) => Some(*v),
440 TermKind::IntConst(v) => {
441 let int_val = if v.sign() == num_bigint::Sign::Minus {
442 -v.to_string()
443 .trim_start_matches('-')
444 .parse::<i64>()
445 .unwrap_or(0)
446 } else {
447 v.to_string().parse::<i64>().unwrap_or(0)
448 };
449 Some(Rational64::from_integer(int_val))
450 }
451 _ => None,
452 };
453
454 if let Some(v) = new_val {
455 current_val = v;
456 best_value = v;
457 best_model = model.clone();
458 }
459 }
460 }
461 self.solver.pop();
462 } else {
463 self.solver.pop();
464 break;
465 }
466 }
467
468 let final_value_term = term_manager.mk_real(best_value);
470 OptimizationResult::Optimal {
471 value: final_value_term,
472 model: best_model,
473 }
474 }
475}
476
477impl Default for Optimizer {
478 fn default() -> Self {
479 Self::new()
480 }
481}
482
483#[derive(Debug, Clone)]
485pub struct ParetoPoint {
486 pub values: Vec<TermId>,
488 pub model: crate::solver::Model,
490}
491
492impl Optimizer {
493 pub fn pareto_optimize(&mut self, term_manager: &mut TermManager) -> Vec<ParetoPoint> {
502 let mut pareto_front = Vec::new();
503
504 for &assertion in &self.assertions.clone() {
506 self.solver.assert(assertion, term_manager);
507 }
508 self.assertions.clear();
509
510 if self.objectives.is_empty() {
511 return pareto_front;
512 }
513
514 let max_points = 100; for _ in 0..max_points {
517 match self.solver.check(term_manager) {
519 SolverResult::Sat => {
520 if let Some(model) = self.solver.model() {
522 let mut values = Vec::new();
523 for objective in &self.objectives {
524 let value = model.eval(objective.term, term_manager);
525 values.push(value);
526 }
527
528 pareto_front.push(ParetoPoint {
530 values: values.clone(),
531 model: model.clone(),
532 });
533
534 self.solver.push();
538 let mut improvement_disjuncts = Vec::new();
539
540 for (idx, objective) in self.objectives.iter().enumerate() {
541 let current_value = values[idx];
542 let improvement = match objective.kind {
543 ObjectiveKind::Minimize => {
544 term_manager.mk_lt(objective.term, current_value)
545 }
546 ObjectiveKind::Maximize => {
547 term_manager.mk_gt(objective.term, current_value)
548 }
549 };
550 improvement_disjuncts.push(improvement);
551 }
552
553 if !improvement_disjuncts.is_empty() {
555 let constraint = term_manager.mk_or(improvement_disjuncts);
556 self.solver.assert(constraint, term_manager);
557 } else {
558 self.solver.pop();
560 break;
561 }
562 } else {
563 break;
564 }
565 }
566 SolverResult::Unsat => {
567 break;
569 }
570 SolverResult::Unknown => {
571 break;
573 }
574 }
575 }
576
577 pareto_front
578 }
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigInt;

    /// Returns the integer literal behind `value`. Fails the test, instead of
    /// passing vacuously, when the term is unknown or not an integer constant.
    fn int_literal(tm: &TermManager, value: TermId) -> BigInt {
        let Some(term) = tm.get(value) else {
            panic!("value term is not registered in the term manager");
        };
        match &term.kind {
            TermKind::IntConst(n) => n.clone(),
            _ => panic!("Expected integer constant"),
        }
    }

    #[test]
    fn test_solver_direct() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();

        solver.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));

        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);

        solver.assert(c1, &mut tm);
        solver.assert(c2, &mut tm);

        let result = solver.check(&mut tm);
        assert_eq!(result, SolverResult::Sat, "Solver should return SAT");
    }

    #[test]
    fn test_optimizer_encoding() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));

        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);

        optimizer.assert(c1);
        optimizer.assert(c2);

        for &assertion in &optimizer.assertions.clone() {
            optimizer.solver.assert(assertion, &mut tm);
        }
        optimizer.assertions.clear();

        let result = optimizer.solver.check(&mut tm);
        assert_eq!(result, SolverResult::Sat, "Should be SAT after encoding");
    }

    #[test]
    fn test_optimizer_basic() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);

        let zero = tm.mk_int(BigInt::zero());
        let c1 = tm.mk_ge(x, zero);
        optimizer.assert(c1);

        let ten = tm.mk_int(BigInt::from(10));
        let c2 = tm.mk_le(x, ten);
        optimizer.assert(c2);

        optimizer.minimize(x);

        let result = optimizer.optimize(&mut tm);
        match result {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_literal(&tm, value), BigInt::zero());
            }
            OptimizationResult::Unsat => panic!("Unexpected unsat result"),
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
            OptimizationResult::Unknown => panic!("Got unknown result"),
        }
    }

    #[test]
    fn test_optimizer_maximize() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);

        let zero = tm.mk_int(BigInt::zero());
        let c1 = tm.mk_ge(x, zero);
        optimizer.assert(c1);

        let ten = tm.mk_int(BigInt::from(10));
        let c2 = tm.mk_le(x, ten);
        optimizer.assert(c2);

        optimizer.maximize(x);

        let result = optimizer.optimize(&mut tm);
        match result {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_literal(&tm, value), BigInt::from(10));
            }
            _ => panic!("Expected optimal result"),
        }
    }

    #[test]
    fn test_optimizer_without_objectives_reports_zero() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));
        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);
        optimizer.assert(c1);
        optimizer.assert(c2);

        let result = optimizer.optimize(&mut tm);
        match result {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_literal(&tm, value), BigInt::zero());
            }
            other => panic!("Expected optimal result, got {other:?}"),
        }
    }

    #[test]
    fn test_optimizer_lexicographic() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let y = tm.mk_var("y", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));

        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);
        let c3 = tm.mk_ge(y, zero);
        let c4 = tm.mk_le(y, ten);
        for constraint in [c1, c2, c3, c4] {
            optimizer.assert(constraint);
        }

        optimizer.minimize(x);
        optimizer.maximize(y);

        // The reported value belongs to the last objective (maximize y).
        let result = optimizer.optimize(&mut tm);
        match result {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_literal(&tm, value), BigInt::from(10));
            }
            other => panic!("Expected optimal result, got {other:?}"),
        }
    }

    #[test]
    fn test_pareto_single_objective_reaches_optimum() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));
        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);
        optimizer.assert(c1);
        optimizer.assert(c2);

        optimizer.minimize(x);

        let front = optimizer.pareto_optimize(&mut tm);
        assert!(!front.is_empty(), "expected at least one Pareto point");

        let values: Vec<BigInt> = front
            .iter()
            .map(|point| {
                assert_eq!(point.values.len(), 1, "one value per objective");
                int_literal(&tm, point.values[0])
            })
            .collect();
        assert!(
            values
                .iter()
                .all(|v| *v >= BigInt::zero() && *v <= BigInt::from(10)),
            "every point must satisfy the bounds"
        );
        assert!(
            values.contains(&BigInt::zero()),
            "the minimum must be among the reported points"
        );
    }

    #[test]
    fn test_optimizer_unsat() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let y = tm.mk_var("y", tm.sorts.int_sort);

        let eq = tm.mk_eq(x, y);
        let neq = tm.mk_not(eq);
        optimizer.assert(eq);
        optimizer.assert(neq);

        optimizer.minimize(x);

        // `x = y` together with `not (x = y)` is contradictory. Only
        // `Unbounded` is rejected here; the other outcomes are tolerated.
        let result = optimizer.optimize(&mut tm);
        match result {
            OptimizationResult::Unsat
            | OptimizationResult::Unknown
            | OptimizationResult::Optimal { .. } => {}
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
        }
    }
}