1use std::cell::{Ref, RefCell};
2
3use oximo_expr::{Expr, ExprArena, VarId};
4use rustc_hash::FxHashMap;
5use smol_str::SmolStr;
6
7use crate::constraint::{Constraint, ConstraintExpr, ConstraintId};
8use crate::domain::Domain;
9use crate::error::{Error, Result};
10use crate::indexed::IndexedVar;
11use crate::objective::{Objective, ObjectiveSense};
12use crate::set::{FromIndexKey, IndexKey, Set};
13use crate::var::{VarBuilder, Variable};
14
/// Broad problem class of a [`Model`], as reported by [`Model::kind`].
///
/// `Model::kind` only distinguishes "has integer-domain variables" from
/// "has nonlinear terms", so it returns one of `LP`, `MILP`, `NLP` or `MINLP`.
/// The quadratic variants `QP` / `MIQP` are never produced by that method as
/// currently written; quadratic terms are reported as `NLP` / `MINLP`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelKind {
    /// Linear program: no integer-domain variables and no nonlinear terms.
    LP,
    /// Mixed-integer linear program: integer-domain variables, no nonlinear terms.
    MILP,
    /// Quadratic program (not produced by `Model::kind`).
    QP,
    /// Mixed-integer quadratic program (not produced by `Model::kind`).
    MIQP,
    /// Nonlinear program: nonlinear terms, no integer-domain variables.
    NLP,
    /// Mixed-integer nonlinear program: both integer-domain variables and nonlinear terms.
    MINLP,
}
28
29pub struct Model {
38 pub name: String,
39 pub(crate) arena: RefCell<ExprArena>,
40 pub(crate) variables: RefCell<Vec<Variable>>,
41 pub(crate) var_names: RefCell<FxHashMap<SmolStr, VarId>>,
42 pub(crate) constraints: RefCell<Vec<Constraint>>,
43 pub(crate) constraint_names: RefCell<FxHashMap<SmolStr, ConstraintId>>,
44 pub(crate) objective: RefCell<Option<Objective>>,
45 cached_kind: RefCell<Option<ModelKind>>,
46}
47
48impl std::fmt::Debug for Model {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("Model")
51 .field("name", &self.name)
52 .field("vars", &self.variables.borrow().len())
53 .field("constraints", &self.constraints.borrow().len())
54 .field("has_objective", &self.objective.borrow().is_some())
55 .finish()
56 }
57}
58
59impl Model {
60 pub fn new(name: impl Into<String>) -> Self {
61 Self {
62 name: name.into(),
63 arena: RefCell::new(ExprArena::new()),
64 variables: RefCell::new(Vec::new()),
65 var_names: RefCell::new(FxHashMap::default()),
66 constraints: RefCell::new(Vec::new()),
67 constraint_names: RefCell::new(FxHashMap::default()),
68 objective: RefCell::new(None),
69 cached_kind: RefCell::new(None),
70 }
71 }
72
73 pub fn var(&self, name: impl Into<SmolStr>) -> VarBuilder<'_> {
76 VarBuilder {
77 model: self,
78 name: name.into(),
79 lb: f64::NEG_INFINITY,
80 ub: f64::INFINITY,
81 domain: Domain::Real,
82 initial: None,
83 }
84 }
85
86 pub(crate) fn register_var<'a>(&'a self, b: VarBuilder<'a>) -> Expr<'a> {
89 let mut names = self.var_names.borrow_mut();
90 assert!(!names.contains_key(&b.name), "variable name {:?} already registered", b.name);
91 let mut vars = self.variables.borrow_mut();
92 let id = VarId(u32::try_from(vars.len()).expect("variable count overflow"));
93 vars.push(Variable {
94 id,
95 name: b.name.clone(),
96 domain: b.domain,
97 lb: b.lb,
98 ub: b.ub,
99 initial: b.initial,
100 });
101 names.insert(b.name, id);
102 drop(vars);
103 drop(names);
104 *self.cached_kind.borrow_mut() = None;
105 Expr::from_var(&self.arena, id)
106 }
107
108 pub fn indexed_var<'a>(&'a self, name: impl Into<String>, set: &Set) -> IndexedVarBuilder<'a> {
109 IndexedVarBuilder {
110 model: self,
111 base_name: name.into(),
112 keys: set.iter().collect(),
113 lb: f64::NEG_INFINITY,
114 ub: f64::INFINITY,
115 lb_by: None,
116 ub_by: None,
117 domain: Domain::Real,
118 }
119 }
120
121 pub fn variable_id(&self, name: &str) -> Option<VarId> {
122 self.var_names.borrow().get(name).copied()
123 }
124
125 pub fn variables(&self) -> Ref<'_, Vec<Variable>> {
126 self.variables.borrow()
127 }
128
129 pub fn arena(&self) -> Ref<'_, ExprArena> {
130 self.arena.borrow()
131 }
132
133 pub fn num_variables(&self) -> usize {
134 self.variables.borrow().len()
135 }
136
137 pub fn fix_var(&self, id: VarId, value: f64) {
139 let mut vars = self.variables.borrow_mut();
140 let v = &mut vars[id.index()];
141 v.lb = value;
142 v.ub = value;
143 }
144
145 pub fn unfix_var(&self, id: VarId, lb: f64, ub: f64) {
148 let mut vars = self.variables.borrow_mut();
149 let v = &mut vars[id.index()];
150 v.lb = lb;
151 v.ub = ub;
152 }
153
154 pub fn constraint(&self, name: impl Into<SmolStr>, c: ConstraintExpr<'_>) -> ConstraintId {
163 let name = name.into();
164 let mut by_name = self.constraint_names.borrow_mut();
165 assert!(!by_name.contains_key(&name), "constraint name {name:?} already registered");
166 let mut all = self.constraints.borrow_mut();
167 let id = ConstraintId(u32::try_from(all.len()).expect("constraint count overflow"));
168 all.push(Constraint {
169 name: name.clone(),
170 lhs: c.lhs.id,
171 sense: c.sense,
172 rhs: c.rhs,
173 active: true,
174 });
175 by_name.insert(name, id);
176 *self.cached_kind.borrow_mut() = None;
177 id
178 }
179
180 pub fn add_constraints<'a, I>(&'a self, items: I)
183 where
184 I: IntoIterator<Item = (SmolStr, ConstraintExpr<'a>)>,
185 {
186 for (name, c) in items {
187 self.constraint(name, c);
188 }
189 }
190
191 pub fn add_constraints_over<'a, K, F>(&'a self, name_prefix: &str, set: &Set, mut rule: F)
209 where
210 K: FromIndexKey,
211 F: FnMut(K) -> ConstraintExpr<'a>,
212 {
213 for key in set {
214 let typed = K::from_index_key(&key);
215 let c = rule(typed);
216 let name: SmolStr = format_index_name(name_prefix, &key).into();
217 self.constraint(name, c);
218 }
219 }
220
221 pub fn constraints(&self) -> Ref<'_, Vec<Constraint>> {
222 self.constraints.borrow()
223 }
224
225 pub fn num_constraints(&self) -> usize {
226 self.constraints.borrow().len()
227 }
228
229 pub fn constraint_id(&self, name: &str) -> Option<ConstraintId> {
230 self.constraint_names.borrow().get(name).copied()
231 }
232
233 pub fn minimize(&self, expr: Expr<'_>) {
236 self.set_objective(expr, ObjectiveSense::Minimize);
237 }
238
239 pub fn maximize(&self, expr: Expr<'_>) {
240 self.set_objective(expr, ObjectiveSense::Maximize);
241 }
242
243 fn set_objective(&self, expr: Expr<'_>, sense: ObjectiveSense) {
244 *self.objective.borrow_mut() = Some(Objective { expr: expr.id, sense });
245 *self.cached_kind.borrow_mut() = None;
246 }
247
248 pub fn objective(&self) -> Ref<'_, Option<Objective>> {
249 self.objective.borrow()
250 }
251
252 pub fn try_objective(&self) -> Result<Objective> {
258 self.objective.borrow().clone().ok_or(Error::NoObjective)
259 }
260
261 pub fn kind(&self) -> ModelKind {
267 if let Some(k) = *self.cached_kind.borrow() {
268 return k;
269 }
270 let arena = self.arena.borrow();
271 let has_int = self.variables.borrow().iter().any(|v| v.domain.is_integer());
272 let nonlinear = self.constraints.borrow().iter().any(|c| has_nonlinear(&arena, c.lhs))
273 || self.objective.borrow().as_ref().is_some_and(|o| has_nonlinear(&arena, o.expr));
274 let k = match (has_int, nonlinear) {
275 (false, false) => ModelKind::LP,
276 (true, false) => ModelKind::MILP,
277 (false, true) => ModelKind::NLP,
278 (true, true) => ModelKind::MINLP,
279 };
280 *self.cached_kind.borrow_mut() = Some(k);
281 k
282 }
283}
284
285fn has_nonlinear(arena: &ExprArena, id: oximo_expr::ExprId) -> bool {
286 use oximo_expr::ExprNode as N;
287 match arena.get(id) {
288 N::Const(_) | N::Var(_) | N::Param(_) | N::Linear { .. } => false,
289 N::Neg(inner) => has_nonlinear(arena, *inner),
290 N::Add(children) => children.iter().any(|c| has_nonlinear(arena, *c)),
291 N::Mul(children) => {
292 let mut nonconst = 0;
293 for c in children {
294 if !matches!(arena.get(*c), N::Const(_)) {
295 nonconst += 1;
296 }
297 if has_nonlinear(arena, *c) {
298 return true;
299 }
300 }
301 nonconst >= 2
302 }
303 N::Pow(_, _) | N::Sin(_) | N::Cos(_) | N::Exp(_) | N::Log(_) => true,
304 }
305}
306
307type BoundFn<'a> = Box<dyn Fn(&IndexKey) -> f64 + 'a>;
316
317#[must_use = "IndexedVarBuilder does nothing until you call .build()"]
318pub struct IndexedVarBuilder<'a> {
319 model: &'a Model,
320 base_name: String,
321 keys: Vec<IndexKey>,
322 lb: f64,
323 ub: f64,
324 lb_by: Option<BoundFn<'a>>,
325 ub_by: Option<BoundFn<'a>>,
326 domain: Domain,
327}
328
329impl<'a> std::fmt::Debug for IndexedVarBuilder<'a> {
330 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
331 f.debug_struct("IndexedVarBuilder")
332 .field("base_name", &self.base_name)
333 .field("keys", &self.keys.len())
334 .field("lb", &self.lb)
335 .field("ub", &self.ub)
336 .field("per_key_lb", &self.lb_by.is_some())
337 .field("per_key_ub", &self.ub_by.is_some())
338 .field("domain", &self.domain)
339 .finish()
340 }
341}
342
343impl<'a> IndexedVarBuilder<'a> {
344 pub fn lb(mut self, v: f64) -> Self {
345 self.lb = v;
346 self
347 }
348 pub fn ub(mut self, v: f64) -> Self {
349 self.ub = v;
350 self
351 }
352 pub fn bounds(mut self, lb: f64, ub: f64) -> Self {
353 self.lb = lb;
354 self.ub = ub;
355 self
356 }
357 pub fn lb_by<K, F>(mut self, f: F) -> Self
366 where
367 K: FromIndexKey,
368 F: Fn(K) -> f64 + 'a,
369 {
370 self.lb_by = Some(Box::new(move |k: &IndexKey| f(K::from_index_key(k))));
371 self
372 }
373 pub fn ub_by<K, F>(mut self, f: F) -> Self
382 where
383 K: FromIndexKey,
384 F: Fn(K) -> f64 + 'a,
385 {
386 self.ub_by = Some(Box::new(move |k: &IndexKey| f(K::from_index_key(k))));
387 self
388 }
389 pub fn domain(mut self, d: Domain) -> Self {
390 self.domain = d;
391 self
392 }
393 pub fn integer(mut self) -> Self {
394 self.domain = Domain::Integer;
395 self
396 }
397 pub fn binary(mut self) -> Self {
398 self.domain = Domain::Binary;
399 self.lb = 0.0;
400 self.ub = 1.0;
401 self
402 }
403
404 pub fn build(self) -> IndexedVar<'a> {
405 let mut entries = FxHashMap::default();
406 for key in self.keys {
407 let scalar_name: SmolStr = format_index_name(&self.base_name, &key).into();
408 let lb = self.lb_by.as_ref().map_or(self.lb, |f| f(&key));
409 let ub = self.ub_by.as_ref().map_or(self.ub, |f| f(&key));
410 let expr = self.model.var(scalar_name).lb(lb).ub(ub).domain(self.domain).build();
411 entries.insert(key, expr);
412 }
413 IndexedVar { entries }
414 }
415}
416
417fn format_index_name(base: &str, key: &IndexKey) -> String {
418 let mut out = String::with_capacity(base.len() + 4);
419 out.push_str(base);
420 out.push('[');
421 write_key_parts(&mut out, key);
422 out.push(']');
423 out
424}
425
426fn write_key_parts(out: &mut String, key: &IndexKey) {
427 use std::fmt::Write;
428 match key {
429 IndexKey::Int(i) => write!(out, "{i}").unwrap(),
430 IndexKey::Str(s) => out.push_str(s),
431 IndexKey::Tuple(parts) => {
432 for (i, p) in parts.iter().enumerate() {
433 if i > 0 {
434 out.push(',');
435 }
436 write_key_parts(out, p);
437 }
438 }
439 }
440}
441
442pub fn display_index_key(key: &IndexKey) -> String {
445 let mut out = String::new();
446 write_key_parts(&mut out, key);
447 out
448}