Skip to main content

oximo_core/
model.rs

1use std::cell::{Ref, RefCell};
2
3use oximo_expr::{Expr, ExprArena, ExprClass, ParamId, VarId, classify};
4use rustc_hash::FxHashMap;
5use smol_str::SmolStr;
6
7use crate::constraint::{Constraint, ConstraintExpr, ConstraintId};
8use crate::domain::Domain;
9use crate::error::{Error, Result};
10use crate::indexed::IndexedVar;
11use crate::objective::{Objective, ObjectiveSense};
12use crate::param::Parameter;
13use crate::set::{FromIndexKey, IndexKey, Set};
14use crate::var::{VarBuilder, Variable};
15
/// Classification of the mathematical program held by a `Model`.
///
/// The kind is never chosen by the user. [`Model::kind`] derives it from two
/// facts: whether any variable has an integer domain, and the highest
/// expression class found in the objective and the constraints.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelKind {
    /// Linear expressions only, no integer-domain variable.
    LP,
    /// Linear expressions only, at least one integer-domain variable.
    MILP,
    /// At most quadratic expressions, no integer-domain variable.
    QP,
    /// At most quadratic expressions, at least one integer-domain variable.
    MIQP,
    /// Some nonlinear expression, no integer-domain variable.
    NLP,
    /// Some nonlinear expression, at least one integer-domain variable.
    MINLP,
}
29
30/// The optimization model. Owns the expression arena, variable/parameter
31/// registries, constraints, and (optional) objective.
32///
33/// `Model` uses interior mutability so the builder API can take `&self`
34/// references.
35///
36/// Variables, constraints, and the objective are added through
37/// `RefCell`s under the hood.
38pub struct Model {
39    pub name: String,
40    pub(crate) arena: RefCell<ExprArena>,
41    pub(crate) variables: RefCell<Vec<Variable>>,
42    pub(crate) var_names: RefCell<FxHashMap<SmolStr, VarId>>,
43    pub(crate) parameters: RefCell<Vec<Parameter>>,
44    pub(crate) param_names: RefCell<FxHashMap<SmolStr, ParamId>>,
45    pub(crate) constraints: RefCell<Vec<Constraint>>,
46    pub(crate) constraint_names: RefCell<FxHashMap<SmolStr, ConstraintId>>,
47    pub(crate) objective: RefCell<Option<Objective>>,
48    cached_kind: RefCell<Option<ModelKind>>,
49}
50
51impl std::fmt::Debug for Model {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.debug_struct("Model")
54            .field("name", &self.name)
55            .field("vars", &self.variables.borrow().len())
56            .field("params", &self.parameters.borrow().len())
57            .field("constraints", &self.constraints.borrow().len())
58            .field("has_objective", &self.objective.borrow().is_some())
59            .finish()
60    }
61}
62
63impl Model {
64    pub fn new(name: impl Into<String>) -> Self {
65        Self {
66            name: name.into(),
67            arena: RefCell::new(ExprArena::new()),
68            variables: RefCell::new(Vec::new()),
69            var_names: RefCell::new(FxHashMap::default()),
70            parameters: RefCell::new(Vec::new()),
71            param_names: RefCell::new(FxHashMap::default()),
72            constraints: RefCell::new(Vec::new()),
73            constraint_names: RefCell::new(FxHashMap::default()),
74            objective: RefCell::new(None),
75            cached_kind: RefCell::new(None),
76        }
77    }
78
79    // Variables
80
81    pub fn var(&self, name: impl Into<SmolStr>) -> VarBuilder<'_> {
82        VarBuilder {
83            model: self,
84            name: name.into(),
85            lb: f64::NEG_INFINITY,
86            ub: f64::INFINITY,
87            domain: Domain::Real,
88            initial: None,
89        }
90    }
91
92    /// Called by [`VarBuilder::build`]. Pushes the var into the registry and
93    /// returns its `Expr` handle.
94    pub(crate) fn register_var<'a>(&'a self, b: VarBuilder<'a>) -> Expr<'a> {
95        let mut names = self.var_names.borrow_mut();
96        assert!(!names.contains_key(&b.name), "variable name {:?} already registered", b.name);
97        let mut vars = self.variables.borrow_mut();
98        let id = VarId(u32::try_from(vars.len()).expect("variable count overflow"));
99        vars.push(Variable {
100            id,
101            name: b.name.clone(),
102            domain: b.domain,
103            lb: b.lb,
104            ub: b.ub,
105            initial: b.initial,
106        });
107        names.insert(b.name, id);
108        drop(vars);
109        drop(names);
110        *self.cached_kind.borrow_mut() = None;
111        Expr::from_var(&self.arena, id)
112    }
113
114    pub fn indexed_var<'a>(&'a self, name: impl Into<String>, set: &Set) -> IndexedVarBuilder<'a> {
115        IndexedVarBuilder {
116            model: self,
117            base_name: name.into(),
118            keys: set.iter().collect(),
119            lb: f64::NEG_INFINITY,
120            ub: f64::INFINITY,
121            lb_by: None,
122            ub_by: None,
123            domain: Domain::Real,
124        }
125    }
126
127    pub fn variable_id(&self, name: &str) -> Option<VarId> {
128        self.var_names.borrow().get(name).copied()
129    }
130
131    pub fn variables(&self) -> Ref<'_, Vec<Variable>> {
132        self.variables.borrow()
133    }
134
135    pub fn arena(&self) -> Ref<'_, ExprArena> {
136        self.arena.borrow()
137    }
138
139    pub fn num_variables(&self) -> usize {
140        self.variables.borrow().len()
141    }
142
143    /// Fix a single-variable expression to `value`.
144    /// Convenience over [`Self::fix_var`] for handles from [`Model::var`] or
145    /// [`crate::IndexedVar`] indexing.
146    ///
147    /// # Panics
148    ///
149    /// Panics if `e` is not a bare variable handle.
150    pub fn fix(&self, e: Expr<'_>, value: f64) {
151        let id = e.var_id().expect("Model::fix expects a single-variable expression");
152        self.fix_var(id, value);
153    }
154
155    /// Fix variable `id` to `value` by setting `lb = ub = value`.
156    pub fn fix_var(&self, id: VarId, value: f64) {
157        let mut vars = self.variables.borrow_mut();
158        let v = &mut vars[id.index()];
159        v.lb = value;
160        v.ub = value;
161    }
162
163    /// Restore bounds on variable `id`. Pass `f64::NEG_INFINITY` / `f64::INFINITY`
164    /// to restore an unbounded direction.
165    pub fn unfix_var(&self, id: VarId, lb: f64, ub: f64) {
166        let mut vars = self.variables.borrow_mut();
167        let v = &mut vars[id.index()];
168        v.lb = lb;
169        v.ub = ub;
170    }
171
172    // Parameters
173
174    /// Register a named scalar parameter initialized to `value`, returning an
175    /// [`Expr`] handle that references it symbolically.
176    ///
177    /// A parameter behaves like a constant coefficient (`param * var` is linear),
178    /// but stays symbolic in the expression tree so it can be re-bound with
179    /// [`Self::set_param`] / [`Self::set_param_id`] between solves without
180    /// rebuilding the model.
181    ///
182    /// # Panics
183    ///
184    /// Panics if a parameter with the same name is already registered.
185    pub fn param<'a>(&'a self, name: impl Into<SmolStr>, value: f64) -> Expr<'a> {
186        let name = name.into();
187        assert!(
188            !self.param_names.borrow().contains_key(&name),
189            "parameter name {name:?} already registered"
190        );
191        let (id, node) = {
192            let mut a = self.arena.borrow_mut();
193            let id = a.new_param(value);
194            (id, a.param(id))
195        };
196        self.parameters.borrow_mut().push(Parameter { id, name: name.clone() });
197        self.param_names.borrow_mut().insert(name, id);
198        *self.cached_kind.borrow_mut() = None;
199        Expr::new(node, &self.arena)
200    }
201
202    /// Re-bind the parameter referenced by handle `p` to `value`.
203    ///
204    /// # Panics
205    ///
206    /// Panics if `p` is not a bare parameter handle (see [`Self::param`]).
207    pub fn set_param(&self, p: Expr<'_>, value: f64) {
208        let id = p.param_id().expect("Model::set_param expects a single-parameter expression");
209        self.set_param_id(id, value);
210    }
211
212    /// Re-bind parameter `id` to `value`. Takes effect on the next solve.
213    ///
214    /// The value is stored only in the expression arena (its single source of
215    /// truth); extraction and evaluation read it from there.
216    pub fn set_param_id(&self, id: ParamId, value: f64) {
217        self.arena.borrow_mut().set_param_value(id, value);
218        *self.cached_kind.borrow_mut() = None;
219    }
220
221    /// Current value bound to parameter `id`.
222    ///
223    /// # Panics
224    ///
225    /// Panics if `id` was not produced by [`Self::param`] on this model.
226    pub fn param_value(&self, id: ParamId) -> f64 {
227        self.arena.borrow().param_value(id)
228    }
229
230    /// Current value of the parameter referenced by handle `p`, or `None` if
231    /// `p` is not a bare parameter handle.
232    pub fn param_value_of(&self, p: Expr<'_>) -> Option<f64> {
233        p.param_id().map(|id| self.param_value(id))
234    }
235
236    pub fn parameter_id(&self, name: &str) -> Option<ParamId> {
237        self.param_names.borrow().get(name).copied()
238    }
239
240    pub fn parameters(&self) -> Ref<'_, Vec<Parameter>> {
241        self.parameters.borrow()
242    }
243
244    pub fn num_parameters(&self) -> usize {
245        self.parameters.borrow().len()
246    }
247
248    // Constraints
249
250    /// Register a new constraint.
251    ///
252    /// # Panics
253    ///
254    /// Panics if a constraint with the same name is already registered, or if
255    /// the constraint count exceeds `u32::MAX`.
256    pub fn constraint(&self, name: impl Into<SmolStr>, c: ConstraintExpr<'_>) -> ConstraintId {
257        let name = name.into();
258        let mut by_name = self.constraint_names.borrow_mut();
259        assert!(!by_name.contains_key(&name), "constraint name {name:?} already registered");
260        let mut all = self.constraints.borrow_mut();
261        let id = ConstraintId(u32::try_from(all.len()).expect("constraint count overflow"));
262        all.push(Constraint {
263            name: name.clone(),
264            lhs: c.lhs.id,
265            sense: c.sense,
266            rhs: c.rhs,
267            active: true,
268        });
269        by_name.insert(name, id);
270        *self.cached_kind.borrow_mut() = None;
271        id
272    }
273
274    /// Bulk-register constraints. Each entry is `(name, ConstraintExpr)`.
275    /// Useful with `.par_iter().map(...).collect()` style construction.
276    pub fn add_constraints<'a, I>(&'a self, items: I)
277    where
278        I: IntoIterator<Item = (SmolStr, ConstraintExpr<'a>)>,
279    {
280        for (name, c) in items {
281            self.constraint(name, c);
282        }
283    }
284
285    /// Rule-style bulk constraint registration.
286    ///
287    /// The closure receives the index as a typed value `K`. Any type
288    /// implementing [`FromIndexKey`] is accepted. Built-in impls cover `i64`,
289    /// `i32`, `usize`, `String`, raw `IndexKey`, and tuples up to arity 4.
290    /// The user states the expected shape via the closure-arg annotation.
291    ///
292    /// # Example
293    /// ```ignore
294    /// // Scalar set: closure receives a usize directly.
295    /// m.add_constraints_over("upper", &i, |i: usize| x[i].le(b[i]));
296    ///
297    /// // Tuple set: destructure inline.
298    /// m.add_constraints_over("blo", &(&tasks * &events), |(t, n): (usize, usize)| {
299    ///     (b[(t, n)] - b_min[t] * w[(t, n)]).ge(0.0)
300    /// });
301    /// ```
302    pub fn add_constraints_over<'a, K, F>(&'a self, name_prefix: &str, set: &Set, mut rule: F)
303    where
304        K: FromIndexKey,
305        F: FnMut(K) -> ConstraintExpr<'a>,
306    {
307        for key in set {
308            let typed = K::from_index_key(&key);
309            let c = rule(typed);
310            let name: SmolStr = format_index_name(name_prefix, &key).into();
311            self.constraint(name, c);
312        }
313    }
314
315    pub fn constraints(&self) -> Ref<'_, Vec<Constraint>> {
316        self.constraints.borrow()
317    }
318
319    pub fn num_constraints(&self) -> usize {
320        self.constraints.borrow().len()
321    }
322
323    pub fn constraint_id(&self, name: &str) -> Option<ConstraintId> {
324        self.constraint_names.borrow().get(name).copied()
325    }
326
327    // Objective
328
329    pub fn minimize(&self, expr: Expr<'_>) {
330        self.set_objective(expr, ObjectiveSense::Minimize);
331    }
332
333    pub fn maximize(&self, expr: Expr<'_>) {
334        self.set_objective(expr, ObjectiveSense::Maximize);
335    }
336
337    fn set_objective(&self, expr: Expr<'_>, sense: ObjectiveSense) {
338        *self.objective.borrow_mut() = Some(Objective { expr: expr.id, sense });
339        *self.cached_kind.borrow_mut() = None;
340    }
341
342    pub fn objective(&self) -> Ref<'_, Option<Objective>> {
343        self.objective.borrow()
344    }
345
346    /// Try to get a cloned copy of the objective.
347    ///
348    /// # Errors
349    ///
350    /// Returns [`Error::NoObjective`] if no objective is set on this model.
351    pub fn try_objective(&self) -> Result<Objective> {
352        self.objective.borrow().clone().ok_or(Error::NoObjective)
353    }
354
355    // Classification
356
357    /// Infer the [`ModelKind`] from current variables and expressions.
358    /// Result is cached and invalidated whenever variables, constraints, or the
359    /// objective change.
360    pub fn kind(&self) -> ModelKind {
361        if let Some(k) = *self.cached_kind.borrow() {
362            return k;
363        }
364        let arena = self.arena.borrow();
365        let has_int = self.variables.borrow().iter().any(|v| v.domain.is_integer());
366
367        // Highest expression class across the objective and every constraint
368        // determines the model class
369        let mut class = ExprClass::Linear;
370        if let Some(o) = self.objective.borrow().as_ref() {
371            class = class.max(classify(&arena, o.expr));
372        }
373        for c in self.constraints.borrow().iter() {
374            if class == ExprClass::Nonlinear {
375                break;
376            }
377            class = class.max(classify(&arena, c.lhs));
378        }
379
380        let k = match (has_int, class) {
381            (false, ExprClass::Linear) => ModelKind::LP,
382            (true, ExprClass::Linear) => ModelKind::MILP,
383            (false, ExprClass::Quadratic) => ModelKind::QP,
384            (true, ExprClass::Quadratic) => ModelKind::MIQP,
385            (false, ExprClass::Nonlinear) => ModelKind::NLP,
386            (true, ExprClass::Nonlinear) => ModelKind::MINLP,
387        };
388        *self.cached_kind.borrow_mut() = Some(k);
389        k
390    }
391}
392
393// IndexedVarBuilder
394
395/// Builder for a collection of scalar variables indexed by a [`Set`].
396///
397/// For example, `flow[i]` for `i in 0..3` registers `flow[0]`, `flow[1]`, and
398/// `flow[2]` as separate scalar variables in the model. Call `.build()` to get
399/// an [`IndexedVar`] that maps each key to its [`Expr`] handle. Bounds and
400/// domain set here apply uniformly to every scalar in the collection.
401type BoundFn<'a> = Box<dyn Fn(&IndexKey) -> f64 + 'a>;
402
403#[must_use = "IndexedVarBuilder does nothing until you call .build()"]
404pub struct IndexedVarBuilder<'a> {
405    model: &'a Model,
406    base_name: String,
407    keys: Vec<IndexKey>,
408    lb: f64,
409    ub: f64,
410    lb_by: Option<BoundFn<'a>>,
411    ub_by: Option<BoundFn<'a>>,
412    domain: Domain,
413}
414
415impl<'a> std::fmt::Debug for IndexedVarBuilder<'a> {
416    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
417        f.debug_struct("IndexedVarBuilder")
418            .field("base_name", &self.base_name)
419            .field("keys", &self.keys.len())
420            .field("lb", &self.lb)
421            .field("ub", &self.ub)
422            .field("per_key_lb", &self.lb_by.is_some())
423            .field("per_key_ub", &self.ub_by.is_some())
424            .field("domain", &self.domain)
425            .finish()
426    }
427}
428
429impl<'a> IndexedVarBuilder<'a> {
430    pub fn lb(mut self, v: f64) -> Self {
431        self.lb = v;
432        self
433    }
434    pub fn ub(mut self, v: f64) -> Self {
435        self.ub = v;
436        self
437    }
438    pub fn bounds(mut self, lb: f64, ub: f64) -> Self {
439        self.lb = lb;
440        self.ub = ub;
441        self
442    }
443    /// Per-key lower bound. Overrides [`Self::lb`] when both are set.
444    ///
445    /// The closure receives a typed index value via [`FromIndexKey`].
446    /// Annotate the argument to select the projection:
447    /// ```ignore
448    /// .lb_by(|(p, q): (String, String)| floor_for(&p, &q))
449    /// .lb_by(|i: usize| lower_bounds[i])
450    /// ```
451    pub fn lb_by<K, F>(mut self, f: F) -> Self
452    where
453        K: FromIndexKey,
454        F: Fn(K) -> f64 + 'a,
455    {
456        self.lb_by = Some(Box::new(move |k: &IndexKey| f(K::from_index_key(k))));
457        self
458    }
459    /// Per-key upper bound. Overrides [`Self::ub`] when both are set.
460    ///
461    /// The closure receives a typed index value via [`FromIndexKey`]; annotate
462    /// the argument to select the projection:
463    /// ```ignore
464    /// .ub_by(|(p, q): (String, String)| capacity_for(&p, &q))
465    /// .ub_by(|i: usize| upper_bounds[i])
466    /// ```
467    pub fn ub_by<K, F>(mut self, f: F) -> Self
468    where
469        K: FromIndexKey,
470        F: Fn(K) -> f64 + 'a,
471    {
472        self.ub_by = Some(Box::new(move |k: &IndexKey| f(K::from_index_key(k))));
473        self
474    }
475    pub fn domain(mut self, d: Domain) -> Self {
476        self.domain = d;
477        self
478    }
479    pub fn integer(mut self) -> Self {
480        self.domain = Domain::Integer;
481        self
482    }
483    pub fn binary(mut self) -> Self {
484        self.domain = Domain::Binary;
485        self.lb = 0.0;
486        self.ub = 1.0;
487        self
488    }
489
490    pub fn build(self) -> IndexedVar<'a> {
491        let mut entries = FxHashMap::default();
492        for key in self.keys {
493            let scalar_name: SmolStr = format_index_name(&self.base_name, &key).into();
494            let lb = self.lb_by.as_ref().map_or(self.lb, |f| f(&key));
495            let ub = self.ub_by.as_ref().map_or(self.ub, |f| f(&key));
496            let expr = self.model.var(scalar_name).lb(lb).ub(ub).domain(self.domain).build();
497            entries.insert(key, expr);
498        }
499        IndexedVar { entries }
500    }
501}
502
503fn format_index_name(base: &str, key: &IndexKey) -> String {
504    let mut out = String::with_capacity(base.len() + 4);
505    out.push_str(base);
506    out.push('[');
507    write_key_parts(&mut out, key);
508    out.push(']');
509    out
510}
511
512fn write_key_parts(out: &mut String, key: &IndexKey) {
513    use std::fmt::Write;
514    match key {
515        IndexKey::Int(i) => write!(out, "{i}").unwrap(),
516        IndexKey::Str(s) => out.push_str(s),
517        IndexKey::Tuple(parts) => {
518            for (i, p) in parts.iter().enumerate() {
519                if i > 0 {
520                    out.push(',');
521                }
522                write_key_parts(out, p);
523            }
524        }
525    }
526}
527
528/// Public render of an `IndexKey`'s textual form, used by helpers like
529/// [`Model::add_constraints_over`] to derive constraint names.
530pub fn display_index_key(key: &IndexKey) -> String {
531    let mut out = String::new();
532    write_key_parts(&mut out, key);
533    out
534}
535
#[cfg(test)]
mod tests {
    use oximo_expr::extract_linear;

    use super::*;
    use crate::constraint::Relate;

    /// A parameter counts as a constant coefficient, so `param * var` is linear.
    #[test]
    fn param_times_var_keeps_model_linear() {
        let model = Model::new("p");
        let scale = model.param("param", 4.0);
        let x = model.var("x").lb(0.0).build();
        model.minimize(scale * x);
        assert_eq!(model.kind(), ModelKind::LP);
    }

    /// Linear extraction sees the live parameter value, before and after a re-bind.
    #[test]
    fn param_coeff_resolves_and_rebinds() {
        let model = Model::new("p");
        let scale = model.param("param", 4.0);
        let x = model.var("x").lb(0.0).build();
        let objective = scale * x;

        let first_coeff = || {
            let arena = model.arena();
            extract_linear(&arena, objective.id).expect("linear").coeffs[0].1
        };
        assert!((first_coeff() - 4.0).abs() < f64::EPSILON);

        model.set_param(scale, 9.0);
        assert!((first_coeff() - 9.0).abs() < f64::EPSILON);
        assert_eq!(model.parameter_id("param"), Some(scale.param_id().unwrap()));
    }

    /// `param_value` / `param_value_of` read the arena value; a variable
    /// handle is not a parameter.
    #[test]
    fn param_value_reads_live_arena_value() {
        let model = Model::new("p");
        let scale = model.param("param", 4.0);
        let scale_id = scale.param_id().unwrap();
        assert!((model.param_value(scale_id) - 4.0).abs() < f64::EPSILON);
        assert!((model.param_value_of(scale).unwrap() - 4.0).abs() < f64::EPSILON);

        model.set_param(scale, 7.5);
        assert!((model.param_value(scale_id) - 7.5).abs() < f64::EPSILON);

        let x = model.var("x").build();
        assert!(model.param_value_of(x).is_none());
    }

    /// `kind` still reports LP after a parameter is re-bound.
    #[test]
    fn set_param_invalidates_kind_cache() {
        let model = Model::new("p");
        let p = model.param("p", 1.0);
        let x = model.var("x").lb(0.0).build();
        model.constraint("c", (p * x).le(10.0));
        assert_eq!(model.kind(), ModelKind::LP);
        model.set_param(p, 2.0);
        assert_eq!(model.kind(), ModelKind::LP);
    }

    /// Registering the same parameter name twice must panic.
    #[test]
    #[should_panic(expected = "parameter name \"dup\" already registered")]
    fn duplicate_param_name_panics() {
        let model = Model::new("p");
        let _first = model.param("dup", 1.0);
        let _second = model.param("dup", 2.0);
    }
}