#[allow(unused_imports)]
use crate::prelude::*;
use crate::solver::{Solver, SolverResult};
use num_bigint::BigInt;
use num_rational::Rational64;
use num_traits::Zero;
use oxiz_core::ast::{TermId, TermKind, TermManager};
/// Direction in which an [`Objective`] is optimized.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ObjectiveKind {
    /// Search for the smallest achievable value.
    Minimize,
    /// Search for the largest achievable value.
    Maximize,
}
/// A single optimization goal registered with an [`Optimizer`].
#[derive(Clone, Debug)]
pub struct Objective {
    /// Term whose value is optimized.
    pub term: TermId,
    /// Whether `term` is minimized or maximized.
    pub kind: ObjectiveKind,
    /// Lexicographic rank: objectives are sorted ascending by this value before optimizing.
    pub priority: usize,
}
/// Outcome of a lexicographic optimization run.
#[derive(Clone, Debug)]
pub enum OptimizationResult {
    /// The best objective value found, together with a model that attains it.
    Optimal {
        /// Term holding the objective's value in `model`.
        value: TermId,
        /// Model of the assertions in which the objective takes `value`.
        model: crate::solver::Model,
    },
    /// The objective can be improved without limit.
    ///
    /// NOTE(review): no code path in this module's `Optimizer` constructs this variant.
    Unbounded,
    /// The assertions are unsatisfiable.
    Unsat,
    /// The solver could not decide, or no model was available.
    Unknown,
}
/// Optimization front end over [`Solver`] supporting lexicographic and Pareto objectives.
#[derive(Debug)]
pub struct Optimizer {
    /// Underlying solver that receives assertions and answers satisfiability checks.
    solver: Solver,
    /// Registered objectives, in registration order.
    objectives: Vec<Objective>,
    /// Hard constraints queued by `assert` and not yet handed to `solver`.
    assertions: Vec<TermId>,
}
impl Optimizer {
#[must_use]
pub fn new() -> Self {
Self {
solver: Solver::new(),
objectives: Vec::new(),
assertions: Vec::new(),
}
}
pub fn assert(&mut self, term: TermId) {
self.assertions.push(term);
}
pub fn minimize(&mut self, term: TermId) {
self.objectives.push(Objective {
term,
kind: ObjectiveKind::Minimize,
priority: self.objectives.len(),
});
}
pub fn maximize(&mut self, term: TermId) {
self.objectives.push(Objective {
term,
kind: ObjectiveKind::Maximize,
priority: self.objectives.len(),
});
}
pub fn set_logic(&mut self, logic: &str) {
self.solver.set_logic(logic);
}
pub fn push(&mut self) {
self.solver.push();
}
pub fn pop(&mut self) {
self.solver.pop();
}
pub fn optimize(&mut self, term_manager: &mut TermManager) -> OptimizationResult {
for &assertion in &self.assertions.clone() {
self.solver.assert(assertion, term_manager);
}
self.assertions.clear();
if self.objectives.is_empty() {
match self.solver.check(term_manager) {
SolverResult::Sat => {
if let Some(model) = self.solver.model() {
let zero = term_manager.mk_int(BigInt::zero());
return OptimizationResult::Optimal {
value: zero,
model: model.clone(),
};
}
OptimizationResult::Unknown
}
SolverResult::Unsat => OptimizationResult::Unsat,
SolverResult::Unknown => OptimizationResult::Unknown,
}
} else {
let mut sorted_objectives = self.objectives.clone();
sorted_objectives.sort_by_key(|obj| obj.priority);
for (idx, objective) in sorted_objectives.iter().enumerate() {
let result = self.optimize_single(objective, term_manager);
match result {
OptimizationResult::Optimal { value, model } => {
if idx < sorted_objectives.len() - 1 {
self.solver.push();
let eq = term_manager.mk_eq(objective.term, value);
self.solver.assert(eq, term_manager);
} else {
return OptimizationResult::Optimal { value, model };
}
}
other => return other,
}
}
OptimizationResult::Unknown
}
}
fn optimize_single(
&mut self,
objective: &Objective,
term_manager: &mut TermManager,
) -> OptimizationResult {
let result = self.solver.check(term_manager);
if result != SolverResult::Sat {
return match result {
SolverResult::Unsat => OptimizationResult::Unsat,
_ => OptimizationResult::Unknown,
};
}
let term_info = term_manager.get(objective.term);
let is_int = term_info.is_some_and(|t| t.sort == term_manager.sorts.int_sort);
if is_int {
self.optimize_int(objective, term_manager)
} else {
self.optimize_real(objective, term_manager)
}
}
fn optimize_int(
&mut self,
objective: &Objective,
term_manager: &mut TermManager,
) -> OptimizationResult {
let result = self.solver.check(term_manager);
if result != SolverResult::Sat {
return if result == SolverResult::Unsat {
OptimizationResult::Unsat
} else {
OptimizationResult::Unknown
};
}
let mut best_model = match self.solver.model() {
Some(m) => m.clone(),
None => return OptimizationResult::Unknown,
};
let value_term = best_model.eval(objective.term, term_manager);
let mut current_value = if let Some(t) = term_manager.get(value_term) {
if let TermKind::IntConst(n) = &t.kind {
n.clone()
} else {
return OptimizationResult::Optimal {
value: value_term,
model: best_model,
};
}
} else {
return OptimizationResult::Unknown;
};
let mut best_value_term = value_term;
let max_iterations = 1000; for _ in 0..max_iterations {
self.solver.push();
let bound_term = term_manager.mk_int(current_value.clone());
let improvement_constraint = match objective.kind {
ObjectiveKind::Minimize => {
term_manager.mk_lt(objective.term, bound_term)
}
ObjectiveKind::Maximize => {
term_manager.mk_gt(objective.term, bound_term)
}
};
self.solver.assert(improvement_constraint, term_manager);
let result = self.solver.check(term_manager);
if result == SolverResult::Sat {
if let Some(model) = self.solver.model() {
let new_value_term = model.eval(objective.term, term_manager);
if let Some(t) = term_manager.get(new_value_term)
&& let TermKind::IntConst(n) = &t.kind
{
current_value = n.clone();
best_value_term = new_value_term;
best_model = model.clone();
}
}
self.solver.pop();
} else {
self.solver.pop();
break;
}
}
OptimizationResult::Optimal {
value: best_value_term,
model: best_model,
}
}
fn optimize_real(
&mut self,
objective: &Objective,
term_manager: &mut TermManager,
) -> OptimizationResult {
let result = self.solver.check(term_manager);
if result != SolverResult::Sat {
return if result == SolverResult::Unsat {
OptimizationResult::Unsat
} else {
OptimizationResult::Unknown
};
}
let mut best_model = match self.solver.model() {
Some(m) => m.clone(),
None => return OptimizationResult::Unknown,
};
let value_term = best_model.eval(objective.term, term_manager);
let mut current_value: Option<Rational64> = None;
if let Some(term) = term_manager.get(value_term) {
match &term.kind {
TermKind::RealConst(val) => {
current_value = Some(*val);
}
TermKind::IntConst(val) => {
let int_val = if val.sign() == num_bigint::Sign::Minus {
-val.to_string()
.trim_start_matches('-')
.parse::<i64>()
.unwrap_or(0)
} else {
val.to_string().parse::<i64>().unwrap_or(0)
};
current_value = Some(Rational64::from_integer(int_val));
}
_ => {}
}
}
let Some(mut current_val) = current_value else {
return OptimizationResult::Optimal {
value: value_term,
model: best_model,
};
};
let mut best_value = current_val;
let max_iterations = 1000;
for _ in 0..max_iterations {
self.solver.push();
let bound_term = term_manager.mk_real(current_val);
let improvement_constraint = match objective.kind {
ObjectiveKind::Minimize => term_manager.mk_lt(objective.term, bound_term),
ObjectiveKind::Maximize => term_manager.mk_gt(objective.term, bound_term),
};
self.solver.assert(improvement_constraint, term_manager);
let result = self.solver.check(term_manager);
if result == SolverResult::Sat {
if let Some(model) = self.solver.model() {
let new_value_term = model.eval(objective.term, term_manager);
if let Some(t) = term_manager.get(new_value_term) {
let new_val = match &t.kind {
TermKind::RealConst(v) => Some(*v),
TermKind::IntConst(v) => {
let int_val = if v.sign() == num_bigint::Sign::Minus {
-v.to_string()
.trim_start_matches('-')
.parse::<i64>()
.unwrap_or(0)
} else {
v.to_string().parse::<i64>().unwrap_or(0)
};
Some(Rational64::from_integer(int_val))
}
_ => None,
};
if let Some(v) = new_val {
current_val = v;
best_value = v;
best_model = model.clone();
}
}
}
self.solver.pop();
} else {
self.solver.pop();
break;
}
}
let final_value_term = term_manager.mk_real(best_value);
OptimizationResult::Optimal {
value: final_value_term,
model: best_model,
}
}
}
impl Default for Optimizer {
fn default() -> Self {
Self::new()
}
}
/// One point of a multi-objective trade-off reported by the optimizer.
#[derive(Clone, Debug)]
pub struct ParetoPoint {
    /// Value of each objective in `model`, in the order the objectives were registered.
    pub values: Vec<TermId>,
    /// Model in which the objectives take `values`.
    pub model: crate::solver::Model,
}
impl Optimizer {
pub fn pareto_optimize(&mut self, term_manager: &mut TermManager) -> Vec<ParetoPoint> {
let mut pareto_front = Vec::new();
for &assertion in &self.assertions.clone() {
self.solver.assert(assertion, term_manager);
}
self.assertions.clear();
if self.objectives.is_empty() {
return pareto_front;
}
let max_points = 100; for _ in 0..max_points {
match self.solver.check(term_manager) {
SolverResult::Sat => {
if let Some(model) = self.solver.model() {
let mut values = Vec::new();
for objective in &self.objectives {
let value = model.eval(objective.term, term_manager);
values.push(value);
}
pareto_front.push(ParetoPoint {
values: values.clone(),
model: model.clone(),
});
self.solver.push();
let mut improvement_disjuncts = Vec::new();
for (idx, objective) in self.objectives.iter().enumerate() {
let current_value = values[idx];
let improvement = match objective.kind {
ObjectiveKind::Minimize => {
term_manager.mk_lt(objective.term, current_value)
}
ObjectiveKind::Maximize => {
term_manager.mk_gt(objective.term, current_value)
}
};
improvement_disjuncts.push(improvement);
}
if !improvement_disjuncts.is_empty() {
let constraint = term_manager.mk_or(improvement_disjuncts);
self.solver.assert(constraint, term_manager);
} else {
self.solver.pop();
break;
}
} else {
break;
}
}
SolverResult::Unsat => {
break;
}
SolverResult::Unknown => {
break;
}
}
}
pareto_front
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_bigint::BigInt;

    /// Returns a `QF_LIA` optimizer paired with a fresh term manager.
    fn lia_optimizer() -> (Optimizer, TermManager) {
        let mut optimizer = Optimizer::new();
        optimizer.set_logic("QF_LIA");
        (optimizer, TermManager::new())
    }

    /// Declares integer variable `name` and queues `name >= lo` and `name <= hi` on `optimizer`.
    fn bounded_int_var(
        optimizer: &mut Optimizer,
        tm: &mut TermManager,
        name: &str,
        lo: i64,
        hi: i64,
    ) -> TermId {
        let int_sort = tm.sorts.int_sort;
        let var = tm.mk_var(name, int_sort);
        let lower = tm.mk_int(BigInt::from(lo));
        let upper = tm.mk_int(BigInt::from(hi));
        let at_least = tm.mk_ge(var, lower);
        let at_most = tm.mk_le(var, upper);
        optimizer.assert(at_least);
        optimizer.assert(at_most);
        var
    }

    /// Reads the integer held by `value`, failing the test if it is not an integer constant.
    ///
    /// Unlike an `if let Some(..)` check, an unknown term fails instead of passing vacuously.
    fn int_value(tm: &TermManager, value: TermId) -> BigInt {
        let term = tm.get(value).expect("optimizer returned an unknown term");
        match &term.kind {
            TermKind::IntConst(n) => n.clone(),
            _ => panic!("Expected integer constant"),
        }
    }

    #[test]
    fn test_solver_direct() {
        let mut solver = Solver::new();
        let mut tm = TermManager::new();
        solver.set_logic("QF_LIA");
        let x = tm.mk_var("x", tm.sorts.int_sort);
        let zero = tm.mk_int(BigInt::zero());
        let ten = tm.mk_int(BigInt::from(10));
        let c1 = tm.mk_ge(x, zero);
        let c2 = tm.mk_le(x, ten);
        solver.assert(c1, &mut tm);
        solver.assert(c2, &mut tm);
        let result = solver.check(&mut tm);
        assert_eq!(result, SolverResult::Sat, "Solver should return SAT");
    }

    #[test]
    fn test_optimizer_encoding() {
        let (mut optimizer, mut tm) = lia_optimizer();
        let _x = bounded_int_var(&mut optimizer, &mut tm, "x", 0, 10);
        for assertion in std::mem::take(&mut optimizer.assertions) {
            optimizer.solver.assert(assertion, &mut tm);
        }
        let result = optimizer.solver.check(&mut tm);
        assert_eq!(result, SolverResult::Sat, "Should be SAT after encoding");
    }

    #[test]
    fn test_optimizer_basic() {
        let (mut optimizer, mut tm) = lia_optimizer();
        let x = bounded_int_var(&mut optimizer, &mut tm, "x", 0, 10);
        optimizer.minimize(x);
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_value(&tm, value), BigInt::zero());
            }
            OptimizationResult::Unsat => panic!("Unexpected unsat result"),
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
            OptimizationResult::Unknown => panic!("Got unknown result"),
        }
    }

    #[test]
    fn test_optimizer_maximize() {
        let (mut optimizer, mut tm) = lia_optimizer();
        let x = bounded_int_var(&mut optimizer, &mut tm, "x", 0, 10);
        optimizer.maximize(x);
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_value(&tm, value), BigInt::from(10));
            }
            _ => panic!("Expected optimal result"),
        }
    }

    #[test]
    fn test_optimizer_lexicographic() {
        let (mut optimizer, mut tm) = lia_optimizer();
        let x = bounded_int_var(&mut optimizer, &mut tm, "x", 0, 10);
        let y = bounded_int_var(&mut optimizer, &mut tm, "y", 0, 10);
        optimizer.minimize(x);
        optimizer.maximize(y);
        // The reported value belongs to the last (lowest-priority) objective.
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Optimal { value, .. } => {
                assert_eq!(int_value(&tm, value), BigInt::from(10));
            }
            _ => panic!("Expected optimal result"),
        }
    }

    #[test]
    fn test_pareto_single_objective_reaches_minimum() {
        let (mut optimizer, mut tm) = lia_optimizer();
        let x = bounded_int_var(&mut optimizer, &mut tm, "x", 0, 10);
        optimizer.minimize(x);
        let front = optimizer.pareto_optimize(&mut tm);
        let last = front
            .last()
            .expect("a satisfiable problem has at least one Pareto point");
        assert_eq!(int_value(&tm, last.values[0]), BigInt::zero());
    }

    #[test]
    fn test_optimizer_unsat() {
        let (mut optimizer, mut tm) = lia_optimizer();
        let x = tm.mk_var("x", tm.sorts.int_sort);
        let y = tm.mk_var("y", tm.sorts.int_sort);
        let eq = tm.mk_eq(x, y);
        let neq = tm.mk_not(eq);
        optimizer.assert(eq);
        optimizer.assert(neq);
        optimizer.minimize(x);
        // NOTE(review): the assertions are contradictory, so `Optimal` should arguably be
        // rejected here; it is still tolerated pending confirmation against the solver.
        match optimizer.optimize(&mut tm) {
            OptimizationResult::Unsat
            | OptimizationResult::Unknown
            | OptimizationResult::Optimal { .. } => {}
            OptimizationResult::Unbounded => panic!("Unexpected unbounded result"),
        }
    }
}