use crate::solver::{Solver, SolverResult};
use num_bigint::BigInt;
use num_rational::Rational64;
use num_traits::{ToPrimitive, Zero};
use oxiz_core::ast::{TermId, TermKind, TermManager};
14
/// Direction in which an objective term is optimized.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ObjectiveKind {
    /// Search for the smallest achievable value of the objective.
    Minimize,
    /// Search for the largest achievable value of the objective.
    Maximize,
}
23
24#[derive(Debug, Clone)]
26pub struct Objective {
27 pub term: TermId,
29 pub kind: ObjectiveKind,
31 pub priority: usize,
33}
34
35#[derive(Debug, Clone)]
37pub enum OptimizationResult {
38 Optimal {
40 value: TermId,
42 model: crate::solver::Model,
44 },
45 Unbounded,
47 Unsat,
49 Unknown,
51}
52
53#[derive(Debug)]
116pub struct Optimizer {
117 solver: Solver,
119 objectives: Vec<Objective>,
121 assertions: Vec<TermId>,
123}
124
125impl Optimizer {
126 #[must_use]
128 pub fn new() -> Self {
129 Self {
130 solver: Solver::new(),
131 objectives: Vec::new(),
132 assertions: Vec::new(),
133 }
134 }
135
136 pub fn assert(&mut self, term: TermId) {
138 self.assertions.push(term);
139 }
140
141 pub fn minimize(&mut self, term: TermId) {
143 self.objectives.push(Objective {
144 term,
145 kind: ObjectiveKind::Minimize,
146 priority: self.objectives.len(),
147 });
148 }
149
150 pub fn maximize(&mut self, term: TermId) {
152 self.objectives.push(Objective {
153 term,
154 kind: ObjectiveKind::Maximize,
155 priority: self.objectives.len(),
156 });
157 }
158
159 pub fn set_logic(&mut self, logic: &str) {
161 self.solver.set_logic(logic);
162 }
163
164 pub fn push(&mut self) {
166 self.solver.push();
167 }
168
169 pub fn pop(&mut self) {
171 self.solver.pop();
172 }
173
174 pub fn optimize(&mut self, term_manager: &mut TermManager) -> OptimizationResult {
179 for &assertion in &self.assertions.clone() {
181 self.solver.assert(assertion, term_manager);
182 }
183 self.assertions.clear();
185
186 if self.objectives.is_empty() {
187 match self.solver.check(term_manager) {
189 SolverResult::Sat => {
190 if let Some(model) = self.solver.model() {
191 let zero = term_manager.mk_int(BigInt::zero());
193 return OptimizationResult::Optimal {
194 value: zero,
195 model: model.clone(),
196 };
197 }
198 OptimizationResult::Unknown
199 }
200 SolverResult::Unsat => OptimizationResult::Unsat,
201 SolverResult::Unknown => OptimizationResult::Unknown,
202 }
203 } else {
204 let mut sorted_objectives = self.objectives.clone();
206 sorted_objectives.sort_by_key(|obj| obj.priority);
207
208 for (idx, objective) in sorted_objectives.iter().enumerate() {
210 let result = self.optimize_single(objective, term_manager);
211
212 match result {
213 OptimizationResult::Optimal { value, model } => {
214 if idx < sorted_objectives.len() - 1 {
216 self.solver.push();
218 let eq = term_manager.mk_eq(objective.term, value);
219 self.solver.assert(eq, term_manager);
220 } else {
221 return OptimizationResult::Optimal { value, model };
223 }
224 }
225 other => return other,
226 }
227 }
228
229 OptimizationResult::Unknown
230 }
231 }
232
233 fn optimize_single(
235 &mut self,
236 objective: &Objective,
237 term_manager: &mut TermManager,
238 ) -> OptimizationResult {
239 let result = self.solver.check(term_manager);
241 if result != SolverResult::Sat {
242 return match result {
243 SolverResult::Unsat => OptimizationResult::Unsat,
244 _ => OptimizationResult::Unknown,
245 };
246 }
247
248 let term_info = term_manager.get(objective.term);
250 let is_int = term_info.is_some_and(|t| t.sort == term_manager.sorts.int_sort);
251
252 if is_int {
253 self.optimize_int(objective, term_manager)
254 } else {
255 self.optimize_real(objective, term_manager)
256 }
257 }
258
259 fn optimize_int(
267 &mut self,
268 objective: &Objective,
269 term_manager: &mut TermManager,
270 ) -> OptimizationResult {
271 let result = self.solver.check(term_manager);
273 if result != SolverResult::Sat {
274 return if result == SolverResult::Unsat {
275 OptimizationResult::Unsat
276 } else {
277 OptimizationResult::Unknown
278 };
279 }
280
281 let mut best_model = match self.solver.model() {
283 Some(m) => m.clone(),
284 None => return OptimizationResult::Unknown,
285 };
286
287 let value_term = best_model.eval(objective.term, term_manager);
289
290 let mut current_value = if let Some(t) = term_manager.get(value_term) {
292 if let TermKind::IntConst(n) = &t.kind {
293 n.clone()
294 } else {
295 return OptimizationResult::Optimal {
297 value: value_term,
298 model: best_model,
299 };
300 }
301 } else {
302 return OptimizationResult::Unknown;
303 };
304
305 let mut best_value_term = value_term;
306
307 let max_iterations = 1000; for _ in 0..max_iterations {
310 self.solver.push();
312
313 let bound_term = term_manager.mk_int(current_value.clone());
315 let improvement_constraint = match objective.kind {
316 ObjectiveKind::Minimize => {
317 term_manager.mk_lt(objective.term, bound_term)
319 }
320 ObjectiveKind::Maximize => {
321 term_manager.mk_gt(objective.term, bound_term)
323 }
324 };
325 self.solver.assert(improvement_constraint, term_manager);
326
327 let result = self.solver.check(term_manager);
329 if result == SolverResult::Sat {
330 if let Some(model) = self.solver.model() {
332 let new_value_term = model.eval(objective.term, term_manager);
333
334 if let Some(t) = term_manager.get(new_value_term)
335 && let TermKind::IntConst(n) = &t.kind
336 {
337 current_value = n.clone();
338 best_value_term = new_value_term;
339 best_model = model.clone();
340 }
341 }
342 self.solver.pop();
344 } else {
345 self.solver.pop();
347 break;
348 }
349 }
350
351 OptimizationResult::Optimal {
352 value: best_value_term,
353 model: best_model,
354 }
355 }
356
357 fn optimize_real(
361 &mut self,
362 objective: &Objective,
363 term_manager: &mut TermManager,
364 ) -> OptimizationResult {
365 let result = self.solver.check(term_manager);
367 if result != SolverResult::Sat {
368 return if result == SolverResult::Unsat {
369 OptimizationResult::Unsat
370 } else {
371 OptimizationResult::Unknown
372 };
373 }
374
375 let mut best_model = match self.solver.model() {
377 Some(m) => m.clone(),
378 None => return OptimizationResult::Unknown,
379 };
380
381 let value_term = best_model.eval(objective.term, term_manager);
383
384 let mut current_value: Option<Rational64> = None;
386 if let Some(term) = term_manager.get(value_term) {
387 match &term.kind {
388 TermKind::RealConst(val) => {
389 current_value = Some(*val);
390 }
391 TermKind::IntConst(val) => {
392 let int_val = if val.sign() == num_bigint::Sign::Minus {
394 -val.to_string()
395 .trim_start_matches('-')
396 .parse::<i64>()
397 .unwrap_or(0)
398 } else {
399 val.to_string().parse::<i64>().unwrap_or(0)
400 };
401 current_value = Some(Rational64::from_integer(int_val));
402 }
403 _ => {}
404 }
405 }
406
407 let Some(mut current_val) = current_value else {
408 return OptimizationResult::Optimal {
410 value: value_term,
411 model: best_model,
412 };
413 };
414
415 let mut best_value = current_val;
416
417 let max_iterations = 1000;
419 for _ in 0..max_iterations {
420 self.solver.push();
421
422 let bound_term = term_manager.mk_real(current_val);
424 let improvement_constraint = match objective.kind {
425 ObjectiveKind::Minimize => term_manager.mk_lt(objective.term, bound_term),
426 ObjectiveKind::Maximize => term_manager.mk_gt(objective.term, bound_term),
427 };
428 self.solver.assert(improvement_constraint, term_manager);
429
430 let result = self.solver.check(term_manager);
431 if result == SolverResult::Sat {
432 if let Some(model) = self.solver.model() {
433 let new_value_term = model.eval(objective.term, term_manager);
434
435 if let Some(t) = term_manager.get(new_value_term) {
436 let new_val = match &t.kind {
437 TermKind::RealConst(v) => Some(*v),
438 TermKind::IntConst(v) => {
439 let int_val = if v.sign() == num_bigint::Sign::Minus {
440 -v.to_string()
441 .trim_start_matches('-')
442 .parse::<i64>()
443 .unwrap_or(0)
444 } else {
445 v.to_string().parse::<i64>().unwrap_or(0)
446 };
447 Some(Rational64::from_integer(int_val))
448 }
449 _ => None,
450 };
451
452 if let Some(v) = new_val {
453 current_val = v;
454 best_value = v;
455 best_model = model.clone();
456 }
457 }
458 }
459 self.solver.pop();
460 } else {
461 self.solver.pop();
462 break;
463 }
464 }
465
466 let final_value_term = term_manager.mk_real(best_value);
468 OptimizationResult::Optimal {
469 value: final_value_term,
470 model: best_model,
471 }
472 }
473}
474
475impl Default for Optimizer {
476 fn default() -> Self {
477 Self::new()
478 }
479}
480
481#[derive(Debug, Clone)]
483pub struct ParetoPoint {
484 pub values: Vec<TermId>,
486 pub model: crate::solver::Model,
488}
489
490impl Optimizer {
491 pub fn pareto_optimize(&mut self, term_manager: &mut TermManager) -> Vec<ParetoPoint> {
500 let mut pareto_front = Vec::new();
501
502 for &assertion in &self.assertions.clone() {
504 self.solver.assert(assertion, term_manager);
505 }
506 self.assertions.clear();
507
508 if self.objectives.is_empty() {
509 return pareto_front;
510 }
511
512 let max_points = 100; for _ in 0..max_points {
515 match self.solver.check(term_manager) {
517 SolverResult::Sat => {
518 if let Some(model) = self.solver.model() {
520 let mut values = Vec::new();
521 for objective in &self.objectives {
522 let value = model.eval(objective.term, term_manager);
523 values.push(value);
524 }
525
526 pareto_front.push(ParetoPoint {
528 values: values.clone(),
529 model: model.clone(),
530 });
531
532 self.solver.push();
536 let mut improvement_disjuncts = Vec::new();
537
538 for (idx, objective) in self.objectives.iter().enumerate() {
539 let current_value = values[idx];
540 let improvement = match objective.kind {
541 ObjectiveKind::Minimize => {
542 term_manager.mk_lt(objective.term, current_value)
543 }
544 ObjectiveKind::Maximize => {
545 term_manager.mk_gt(objective.term, current_value)
546 }
547 };
548 improvement_disjuncts.push(improvement);
549 }
550
551 if !improvement_disjuncts.is_empty() {
553 let constraint = term_manager.mk_or(improvement_disjuncts);
554 self.solver.assert(constraint, term_manager);
555 } else {
556 self.solver.pop();
558 break;
559 }
560 } else {
561 break;
562 }
563 }
564 SolverResult::Unsat => {
565 break;
567 }
568 SolverResult::Unknown => {
569 break;
571 }
572 }
573 }
574
575 pareto_front
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigInt;

    /// Returns the integer constant behind `value`, failing the test for anything else
    /// (including a value term the term manager does not know).
    fn int_constant(tm: &TermManager, value: TermId) -> BigInt {
        let term = tm
            .get(value)
            .expect("optimal value should be a known term");
        match &term.kind {
            TermKind::IntConst(n) => n.clone(),
            _ => panic!("Expected integer constant"),
        }
    }

    #[test]
    fn test_solver_direct() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();

        solver.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));

        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);

        solver.assert(c1, &mut tm);
        solver.assert(c2, &mut tm);

        let result = solver.check(&mut tm);
        assert_eq!(result, SolverResult::Sat, "Solver should return SAT");
    }

    #[test]
    fn test_optimizer_encoding() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));

        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);

        optimizer.assert(c1);
        optimizer.assert(c2);

        // Hand the buffered assertions to the solver the same way `optimize` does.
        for assertion in std::mem::take(&mut optimizer.assertions) {
            optimizer.solver.assert(assertion, &mut tm);
        }

        let result = optimizer.solver.check(&mut tm);
        assert_eq!(result, SolverResult::Sat, "Should be SAT after encoding");
    }

    #[test]
    fn test_optimizer_basic() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);

        let zero = tm.mk_int(BigInt::zero());
        let c1 = tm.mk_ge(x, zero);
        optimizer.assert(c1);

        let ten = tm.mk_int(BigInt::from(10));
        let c2 = tm.mk_le(x, ten);
        optimizer.assert(c2);

        optimizer.minimize(x);

        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_constant(&tm, value), BigInt::zero());
            }
            OptimizationResult::Unsat => panic!("Unexpected unsat result"),
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
            OptimizationResult::Unknown => panic!("Got unknown result"),
        }
    }

    #[test]
    fn test_optimizer_maximize() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);

        let zero = tm.mk_int(BigInt::zero());
        let c1 = tm.mk_ge(x, zero);
        optimizer.assert(c1);

        let ten = tm.mk_int(BigInt::from(10));
        let c2 = tm.mk_le(x, ten);
        optimizer.assert(c2);

        optimizer.maximize(x);

        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_constant(&tm, value), BigInt::from(10));
            }
            other => panic!("Expected optimal result, got {other:?}"),
        }
    }

    #[test]
    fn test_optimizer_without_objectives() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let c1 = tm.mk_ge(x, zero);
        optimizer.assert(c1);

        // With no objectives a satisfying model is reported with the value 0.
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_constant(&tm, value), BigInt::zero());
            }
            other => panic!("Expected optimal result, got {other:?}"),
        }
    }

    #[test]
    fn test_optimizer_lexicographic() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let y = tm.mk_var("y", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));

        for var in [x, y] {
            let lower = tm.mk_ge(var, zero);
            let upper = tm.mk_le(var, ten);
            optimizer.assert(lower);
            optimizer.assert(upper);
        }

        optimizer.minimize(x);
        optimizer.maximize(y);

        // The reported value belongs to the last (lowest-priority) objective.
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_constant(&tm, value), BigInt::from(10));
            }
            other => panic!("Expected optimal result, got {other:?}"),
        }
    }

    #[test]
    fn test_optimizer_unsat() {
        let mut optimizer = Optimizer::new();
        let mut tm = TermManager::new();

        optimizer.set_logic("QF_LIA");

        let x = tm.mk_var("x", tm.sorts.int_sort);
        let y = tm.mk_var("y", tm.sorts.int_sort);

        let eq = tm.mk_eq(x, y);
        let neq = tm.mk_not(eq);
        optimizer.assert(eq);
        optimizer.assert(neq);

        optimizer.minimize(x);

        // NOTE(review): `Optimal` would be wrong for these contradictory constraints; the
        // test tolerates it as originally written. Tighten to `Unsat` once the solver is
        // known to detect this case.
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Unsat
            | OptimizationResult::Unknown
            | OptimizationResult::Optimal { .. } => {}
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
        }
    }
}