use std::cell::{Ref, RefCell};
use oximo_expr::{Expr, ExprArena, ExprClass, ParamId, VarId, classify};
use rustc_hash::FxHashMap;
use smol_str::SmolStr;
use crate::constraint::{Constraint, ConstraintExpr, ConstraintId};
use crate::domain::Domain;
use crate::error::{Error, Result};
use crate::indexed::IndexedVar;
use crate::objective::{Objective, ObjectiveSense};
use crate::param::Parameter;
use crate::set::{FromIndexKey, IndexKey, Set};
use crate::var::{VarBuilder, Variable};
/// Structural class of a [`Model`], as computed by [`Model::kind`].
///
/// Two facts decide the class: whether any variable has an integer domain,
/// and the most general expression class (linear, quadratic or nonlinear)
/// found in the objective and the constraint left-hand sides.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ModelKind {
    /// Linear expressions only, continuous variables only.
    LP,
    /// Linear expressions, at least one integer variable.
    MILP,
    /// Quadratic objective and/or constraints, continuous variables only.
    QP,
    /// Quadratic objective and/or constraints, at least one integer variable.
    MIQP,
    /// Some nonlinear expression, continuous variables only.
    NLP,
    /// Some nonlinear expression, at least one integer variable.
    MINLP,
}
/// An optimisation model: variables, parameters, constraints and an optional
/// objective, all sharing one expression arena.
///
/// Every collection sits behind a `RefCell`, so the building methods take
/// `&self`. Calling a mutating method while a `Ref` returned by one of the
/// accessors (e.g. `Model::variables`) is still alive panics with a borrow
/// error.
pub struct Model {
    /// Human-readable model name.
    pub name: String,
    /// Arena owning the expression nodes built against this model; parameter
    /// values are stored here as well (`new_param` / `set_param_value`).
    pub(crate) arena: RefCell<ExprArena>,
    /// Variables in registration order; a variable's `VarId` is its index here.
    pub(crate) variables: RefCell<Vec<Variable>>,
    /// Name -> id lookup for `variables`; used to reject duplicate names.
    pub(crate) var_names: RefCell<FxHashMap<SmolStr, VarId>>,
    /// Parameters in registration order (name and id only).
    pub(crate) parameters: RefCell<Vec<Parameter>>,
    /// Name -> id lookup for `parameters`; used to reject duplicate names.
    pub(crate) param_names: RefCell<FxHashMap<SmolStr, ParamId>>,
    /// Constraints in registration order; a `ConstraintId` is created from
    /// the position the constraint is pushed at.
    pub(crate) constraints: RefCell<Vec<Constraint>>,
    /// Name -> id lookup for `constraints`; used to reject duplicate names.
    pub(crate) constraint_names: RefCell<FxHashMap<SmolStr, ConstraintId>>,
    /// Current objective; replaced wholesale by `minimize` / `maximize`.
    pub(crate) objective: RefCell<Option<Objective>>,
    /// Memoised result of `Model::kind`; reset to `None` by the mutating
    /// methods that can change the classification.
    cached_kind: RefCell<Option<ModelKind>>,
}
impl std::fmt::Debug for Model {
    /// Prints the model name plus element counts rather than the full
    /// contents, which keeps debug output short for large models.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("Model");
        summary.field("name", &self.name);
        summary.field("vars", &self.variables.borrow().len());
        summary.field("params", &self.parameters.borrow().len());
        summary.field("constraints", &self.constraints.borrow().len());
        summary.field("has_objective", &self.objective.borrow().is_some());
        summary.finish()
    }
}
impl Model {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
arena: RefCell::new(ExprArena::new()),
variables: RefCell::new(Vec::new()),
var_names: RefCell::new(FxHashMap::default()),
parameters: RefCell::new(Vec::new()),
param_names: RefCell::new(FxHashMap::default()),
constraints: RefCell::new(Vec::new()),
constraint_names: RefCell::new(FxHashMap::default()),
objective: RefCell::new(None),
cached_kind: RefCell::new(None),
}
}
pub fn var(&self, name: impl Into<SmolStr>) -> VarBuilder<'_> {
VarBuilder {
model: self,
name: name.into(),
lb: f64::NEG_INFINITY,
ub: f64::INFINITY,
domain: Domain::Real,
initial: None,
}
}
pub(crate) fn register_var<'a>(&'a self, b: VarBuilder<'a>) -> Expr<'a> {
let mut names = self.var_names.borrow_mut();
assert!(!names.contains_key(&b.name), "variable name {:?} already registered", b.name);
let mut vars = self.variables.borrow_mut();
let id = VarId(u32::try_from(vars.len()).expect("variable count overflow"));
vars.push(Variable {
id,
name: b.name.clone(),
domain: b.domain,
lb: b.lb,
ub: b.ub,
initial: b.initial,
});
names.insert(b.name, id);
drop(vars);
drop(names);
*self.cached_kind.borrow_mut() = None;
Expr::from_var(&self.arena, id)
}
pub fn indexed_var<'a>(&'a self, name: impl Into<String>, set: &Set) -> IndexedVarBuilder<'a> {
IndexedVarBuilder {
model: self,
base_name: name.into(),
keys: set.iter().collect(),
lb: f64::NEG_INFINITY,
ub: f64::INFINITY,
lb_by: None,
ub_by: None,
domain: Domain::Real,
}
}
pub fn variable_id(&self, name: &str) -> Option<VarId> {
self.var_names.borrow().get(name).copied()
}
pub fn variables(&self) -> Ref<'_, Vec<Variable>> {
self.variables.borrow()
}
pub fn arena(&self) -> Ref<'_, ExprArena> {
self.arena.borrow()
}
pub fn num_variables(&self) -> usize {
self.variables.borrow().len()
}
pub fn fix(&self, e: Expr<'_>, value: f64) {
let id = e.var_id().expect("Model::fix expects a single-variable expression");
self.fix_var(id, value);
}
pub fn fix_var(&self, id: VarId, value: f64) {
let mut vars = self.variables.borrow_mut();
let v = &mut vars[id.index()];
v.lb = value;
v.ub = value;
}
pub fn unfix_var(&self, id: VarId, lb: f64, ub: f64) {
let mut vars = self.variables.borrow_mut();
let v = &mut vars[id.index()];
v.lb = lb;
v.ub = ub;
}
pub fn param<'a>(&'a self, name: impl Into<SmolStr>, value: f64) -> Expr<'a> {
let name = name.into();
assert!(
!self.param_names.borrow().contains_key(&name),
"parameter name {name:?} already registered"
);
let (id, node) = {
let mut a = self.arena.borrow_mut();
let id = a.new_param(value);
(id, a.param(id))
};
self.parameters.borrow_mut().push(Parameter { id, name: name.clone() });
self.param_names.borrow_mut().insert(name, id);
*self.cached_kind.borrow_mut() = None;
Expr::new(node, &self.arena)
}
pub fn set_param(&self, p: Expr<'_>, value: f64) {
let id = p.param_id().expect("Model::set_param expects a single-parameter expression");
self.set_param_id(id, value);
}
pub fn set_param_id(&self, id: ParamId, value: f64) {
self.arena.borrow_mut().set_param_value(id, value);
*self.cached_kind.borrow_mut() = None;
}
pub fn param_value(&self, id: ParamId) -> f64 {
self.arena.borrow().param_value(id)
}
pub fn param_value_of(&self, p: Expr<'_>) -> Option<f64> {
p.param_id().map(|id| self.param_value(id))
}
pub fn parameter_id(&self, name: &str) -> Option<ParamId> {
self.param_names.borrow().get(name).copied()
}
pub fn parameters(&self) -> Ref<'_, Vec<Parameter>> {
self.parameters.borrow()
}
pub fn num_parameters(&self) -> usize {
self.parameters.borrow().len()
}
pub fn constraint(&self, name: impl Into<SmolStr>, c: ConstraintExpr<'_>) -> ConstraintId {
let name = name.into();
let mut by_name = self.constraint_names.borrow_mut();
assert!(!by_name.contains_key(&name), "constraint name {name:?} already registered");
let mut all = self.constraints.borrow_mut();
let id = ConstraintId(u32::try_from(all.len()).expect("constraint count overflow"));
all.push(Constraint {
name: name.clone(),
lhs: c.lhs.id,
sense: c.sense,
rhs: c.rhs,
active: true,
});
by_name.insert(name, id);
*self.cached_kind.borrow_mut() = None;
id
}
pub fn add_constraints<'a, I>(&'a self, items: I)
where
I: IntoIterator<Item = (SmolStr, ConstraintExpr<'a>)>,
{
for (name, c) in items {
self.constraint(name, c);
}
}
pub fn add_constraints_over<'a, K, F>(&'a self, name_prefix: &str, set: &Set, mut rule: F)
where
K: FromIndexKey,
F: FnMut(K) -> ConstraintExpr<'a>,
{
for key in set {
let typed = K::from_index_key(&key);
let c = rule(typed);
let name: SmolStr = format_index_name(name_prefix, &key).into();
self.constraint(name, c);
}
}
pub fn constraints(&self) -> Ref<'_, Vec<Constraint>> {
self.constraints.borrow()
}
pub fn num_constraints(&self) -> usize {
self.constraints.borrow().len()
}
pub fn constraint_id(&self, name: &str) -> Option<ConstraintId> {
self.constraint_names.borrow().get(name).copied()
}
pub fn minimize(&self, expr: Expr<'_>) {
self.set_objective(expr, ObjectiveSense::Minimize);
}
pub fn maximize(&self, expr: Expr<'_>) {
self.set_objective(expr, ObjectiveSense::Maximize);
}
fn set_objective(&self, expr: Expr<'_>, sense: ObjectiveSense) {
*self.objective.borrow_mut() = Some(Objective { expr: expr.id, sense });
*self.cached_kind.borrow_mut() = None;
}
pub fn objective(&self) -> Ref<'_, Option<Objective>> {
self.objective.borrow()
}
pub fn try_objective(&self) -> Result<Objective> {
self.objective.borrow().clone().ok_or(Error::NoObjective)
}
pub fn kind(&self) -> ModelKind {
if let Some(k) = *self.cached_kind.borrow() {
return k;
}
let arena = self.arena.borrow();
let has_int = self.variables.borrow().iter().any(|v| v.domain.is_integer());
let mut class = ExprClass::Linear;
if let Some(o) = self.objective.borrow().as_ref() {
class = class.max(classify(&arena, o.expr));
}
for c in self.constraints.borrow().iter() {
if class == ExprClass::Nonlinear {
break;
}
class = class.max(classify(&arena, c.lhs));
}
let k = match (has_int, class) {
(false, ExprClass::Linear) => ModelKind::LP,
(true, ExprClass::Linear) => ModelKind::MILP,
(false, ExprClass::Quadratic) => ModelKind::QP,
(true, ExprClass::Quadratic) => ModelKind::MIQP,
(false, ExprClass::Nonlinear) => ModelKind::NLP,
(true, ExprClass::Nonlinear) => ModelKind::MINLP,
};
*self.cached_kind.borrow_mut() = Some(k);
k
}
}
/// Boxed per-key bound function: maps an index key to a bound value.
type BoundFn<'a> = Box<dyn Fn(&IndexKey) -> f64 + 'a>;
/// Builder for a family of scalar variables, one per key of a `Set`, created
/// by `Model::indexed_var`.
///
/// No variables are created until `build` is called.
#[must_use = "IndexedVarBuilder does nothing until you call .build()"]
pub struct IndexedVarBuilder<'a> {
    // Model the generated variables are added to.
    model: &'a Model,
    // Name prefix; each generated variable is named `base_name[key]`.
    base_name: String,
    // Keys copied out of the set when the builder was created.
    keys: Vec<IndexKey>,
    // Shared lower bound, used for keys without a per-key override.
    lb: f64,
    // Shared upper bound, used for keys without a per-key override.
    ub: f64,
    // Optional per-key lower bound; when set it wins over `lb` for every key.
    lb_by: Option<BoundFn<'a>>,
    // Optional per-key upper bound; when set it wins over `ub` for every key.
    ub_by: Option<BoundFn<'a>>,
    // Domain applied to every generated variable.
    domain: Domain,
}
impl std::fmt::Debug for IndexedVarBuilder<'_> {
    /// Reports the builder configuration.
    ///
    /// The key list is shown only as a count. The boxed bound closures are
    /// shown only as "set / not set" flags, since they have no `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("IndexedVarBuilder");
        out.field("base_name", &self.base_name);
        out.field("keys", &self.keys.len());
        out.field("lb", &self.lb);
        out.field("ub", &self.ub);
        out.field("per_key_lb", &self.lb_by.is_some());
        out.field("per_key_ub", &self.ub_by.is_some());
        out.field("domain", &self.domain);
        out.finish()
    }
}
impl<'a> IndexedVarBuilder<'a> {
    /// Sets the lower bound shared by every key.
    ///
    /// A per-key lower bound from [`lb_by`](Self::lb_by) takes precedence
    /// over this value for every key, regardless of which was called last.
    pub fn lb(mut self, v: f64) -> Self {
        self.lb = v;
        self
    }

    /// Sets the upper bound shared by every key.
    ///
    /// A per-key upper bound from [`ub_by`](Self::ub_by) takes precedence
    /// over this value for every key, regardless of which was called last.
    pub fn ub(mut self, v: f64) -> Self {
        self.ub = v;
        self
    }

    /// Sets both shared bounds at once; equivalent to `.lb(lb).ub(ub)`.
    pub fn bounds(mut self, lb: f64, ub: f64) -> Self {
        self.lb = lb;
        self.ub = ub;
        self
    }

    /// Computes each key's lower bound with `f`, after converting the key to
    /// `K` via [`FromIndexKey`].
    ///
    /// When set, this overrides the shared lower bound from `lb`, `bounds` or
    /// `binary` for every key.
    pub fn lb_by<K, F>(mut self, f: F) -> Self
    where
        K: FromIndexKey,
        F: Fn(K) -> f64 + 'a,
    {
        self.lb_by = Some(Box::new(move |k: &IndexKey| f(K::from_index_key(k))));
        self
    }

    /// Computes each key's upper bound with `f`, after converting the key to
    /// `K` via [`FromIndexKey`].
    ///
    /// When set, this overrides the shared upper bound from `ub`, `bounds` or
    /// `binary` for every key.
    pub fn ub_by<K, F>(mut self, f: F) -> Self
    where
        K: FromIndexKey,
        F: Fn(K) -> f64 + 'a,
    {
        self.ub_by = Some(Box::new(move |k: &IndexKey| f(K::from_index_key(k))));
        self
    }

    /// Sets the domain of every generated variable; bounds are left unchanged.
    pub fn domain(mut self, d: Domain) -> Self {
        self.domain = d;
        self
    }

    /// Shorthand for `.domain(Domain::Integer)`; bounds are left unchanged.
    pub fn integer(mut self) -> Self {
        self.domain = Domain::Integer;
        self
    }

    /// Makes every variable binary: sets the domain to [`Domain::Binary`] and
    /// the shared bounds to `[0, 1]`.
    ///
    /// Per-key closures from `lb_by` / `ub_by`, if set, still override these
    /// shared bounds.
    pub fn binary(mut self) -> Self {
        self.domain = Domain::Binary;
        self.lb = 0.0;
        self.ub = 1.0;
        self
    }

    /// Creates one scalar variable per key and returns them keyed by index.
    ///
    /// Each variable is named `base_name[key]`. Its bounds come from the
    /// per-key closure when one was supplied, otherwise from the shared bound.
    /// Keys are processed in the order they were copied from the set.
    ///
    /// # Panics
    /// NOTE(review): variables are added via `Model::var(..).build()`. If that
    /// routes through `Model::register_var`, a generated name that collides
    /// with an existing variable panics there; confirm in `VarBuilder::build`.
    pub fn build(self) -> IndexedVar<'a> {
        let Self { model, base_name, keys, lb, ub, lb_by, ub_by, domain } = self;
        // Exactly one entry per key: reserve once instead of rehashing as the
        // map grows.
        let mut entries = FxHashMap::with_capacity_and_hasher(keys.len(), Default::default());
        for key in keys {
            let scalar_name: SmolStr = format_index_name(&base_name, &key).into();
            let key_lb = lb_by.as_ref().map_or(lb, |f| f(&key));
            let key_ub = ub_by.as_ref().map_or(ub, |f| f(&key));
            let expr = model.var(scalar_name).lb(key_lb).ub(key_ub).domain(domain).build();
            entries.insert(key, expr);
        }
        IndexedVar { entries }
    }
}
/// Builds the scalar name for `key` under `base`, e.g. `x[3]` or `flow[a,b]`.
///
/// The text between the brackets is exactly what [`write_key_parts`] produces.
fn format_index_name(base: &str, key: &IndexKey) -> String {
    let mut name = String::with_capacity(base.len() + 4);
    name += base;
    name.push('[');
    write_key_parts(&mut name, key);
    name += "]";
    name
}
/// Appends the textual form of `key` to `out`.
///
/// - `Int`: formatted with `Display`.
/// - `Str`: copied verbatim.
/// - `Tuple`: its parts, rendered recursively and joined with `,`. Nested
///   tuples get no extra brackets, so they flatten into the same list.
fn write_key_parts(out: &mut String, key: &IndexKey) {
    use std::fmt::Write;
    match key {
        IndexKey::Int(i) => write!(out, "{i}").expect("writing to a String cannot fail"),
        IndexKey::Str(s) => out.push_str(s),
        IndexKey::Tuple(parts) => {
            // Emit the separator before every part except the first.
            let mut separator = "";
            for part in parts.iter() {
                out.push_str(separator);
                write_key_parts(out, part);
                separator = ",";
            }
        }
    }
}
/// Renders `key` as a standalone string, without the surrounding brackets
/// that variable and constraint names use.
pub fn display_index_key(key: &IndexKey) -> String {
    let mut rendered = String::new();
    write_key_parts(&mut rendered, key);
    rendered
}
#[cfg(test)]
mod tests {
    use oximo_expr::extract_linear;
    use super::*;
    use crate::constraint::Relate;

    #[test]
    fn param_times_var_keeps_model_linear() {
        let m = Model::new("p");
        let param = m.param("param", 4.0);
        let x = m.var("x").lb(0.0).build();
        m.minimize(param * x);
        assert_eq!(m.kind(), ModelKind::LP);
    }

    #[test]
    fn param_coeff_resolves_and_rebinds() {
        let m = Model::new("p");
        let param = m.param("param", 4.0);
        let x = m.var("x").lb(0.0).build();
        let obj = param * x;
        let coeff = |m: &Model| {
            let arena = m.arena();
            extract_linear(&arena, obj.id).expect("linear").coeffs[0].1
        };
        assert!((coeff(&m) - 4.0).abs() < f64::EPSILON);
        m.set_param(param, 9.0);
        assert!((coeff(&m) - 9.0).abs() < f64::EPSILON);
        assert_eq!(m.parameter_id("param"), Some(param.param_id().unwrap()));
    }

    #[test]
    fn param_value_reads_live_arena_value() {
        let m = Model::new("p");
        let param = m.param("param", 4.0);
        let id = param.param_id().unwrap();
        assert!((m.param_value(id) - 4.0).abs() < f64::EPSILON);
        assert!((m.param_value_of(param).unwrap() - 4.0).abs() < f64::EPSILON);
        m.set_param(param, 7.5);
        assert!((m.param_value(id) - 7.5).abs() < f64::EPSILON);
        let x = m.var("x").build();
        assert!(m.param_value_of(x).is_none());
    }

    #[test]
    fn set_param_invalidates_kind_cache() {
        let m = Model::new("p");
        let p = m.param("p", 1.0);
        let x = m.var("x").lb(0.0).build();
        m.constraint("c", (p * x).le(10.0));
        assert_eq!(m.kind(), ModelKind::LP);
        // The kind is LP both before and after the update, so only the cache
        // itself shows whether invalidation happened.
        assert_eq!(*m.cached_kind.borrow(), Some(ModelKind::LP));
        m.set_param(p, 2.0);
        assert_eq!(*m.cached_kind.borrow(), None);
        assert_eq!(m.kind(), ModelKind::LP);
    }

    #[test]
    fn registering_integer_var_invalidates_kind_cache() {
        let m = Model::new("k");
        let x = m.var("x").lb(0.0).build();
        m.minimize(x);
        assert_eq!(m.kind(), ModelKind::LP);
        let _y = m.var("y").domain(Domain::Integer).build();
        assert_eq!(m.kind(), ModelKind::MILP);
    }

    #[test]
    fn squared_objective_is_quadratic() {
        let m = Model::new("q");
        let x = m.var("x").build();
        m.minimize(x * x);
        assert_eq!(m.kind(), ModelKind::QP);
    }

    #[test]
    fn fix_and_unfix_rewrite_bounds() {
        let m = Model::new("b");
        let x = m.var("x").lb(0.0).ub(10.0).build();
        let id = x.var_id().unwrap();
        m.fix(x, 3.0);
        {
            let vars = m.variables();
            let v = &vars[id.index()];
            assert!((v.lb - 3.0).abs() < f64::EPSILON);
            assert!((v.ub - 3.0).abs() < f64::EPSILON);
        }
        m.unfix_var(id, 0.0, 10.0);
        let vars = m.variables();
        let v = &vars[id.index()];
        assert!(v.lb.abs() < f64::EPSILON);
        assert!((v.ub - 10.0).abs() < f64::EPSILON);
    }

    #[test]
    fn try_objective_reports_missing_objective() {
        let m = Model::new("o");
        assert!(matches!(m.try_objective(), Err(Error::NoObjective)));
        let x = m.var("x").build();
        m.maximize(x);
        assert!(m.try_objective().is_ok());
    }

    #[test]
    #[should_panic(expected = "parameter name \"dup\" already registered")]
    fn duplicate_param_name_panics() {
        let m = Model::new("p");
        let _a = m.param("dup", 1.0);
        let _b = m.param("dup", 2.0);
    }

    #[test]
    #[should_panic(expected = "variable name \"dup\" already registered")]
    fn duplicate_var_name_panics() {
        let m = Model::new("v");
        let _a = m.var("dup").build();
        let _b = m.var("dup").build();
    }

    #[test]
    #[should_panic(expected = "constraint name \"dup\" already registered")]
    fn duplicate_constraint_name_panics() {
        let m = Model::new("c");
        let x = m.var("x").build();
        m.constraint("dup", x.le(1.0));
        m.constraint("dup", x.le(2.0));
    }
}