Skip to main content

oximo_expr/
linear.rs

1use rustc_hash::{FxBuildHasher, FxHashMap};
2use smallvec::smallvec;
3
4use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
5
6/// Coefficients of a linear expression: `sum(coeff * var) + constant`.
7#[derive(Clone, Debug, Default)]
8pub struct LinearTerms {
9    pub coeffs: Vec<(VarId, f64)>,
10    pub constant: f64,
11}
12
13/// Try to interpret `id` as a linear expression. Returns `None` for any
14/// nonlinear node (Mul of two non-constants, Pow, transcendentals, ...).
15fn as_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
16    match arena.get(id) {
17        ExprNode::Const(c) => Some(LinearTerms { coeffs: Vec::new(), constant: *c }),
18        ExprNode::Var(v) => Some(LinearTerms { coeffs: vec![(*v, 1.0)], constant: 0.0 }),
19        ExprNode::Linear { coeffs, constant } => {
20            Some(LinearTerms { coeffs: coeffs.clone(), constant: *constant })
21        }
22        ExprNode::Neg(inner) => {
23            let inner = *inner;
24            as_linear(arena, inner).map(|mut t| {
25                t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
26                t.constant = -t.constant;
27                t
28            })
29        }
30        ExprNode::Add(children) => {
31            let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
32            let mut acc = LinearTerms::default();
33            let mut map: FxHashMap<VarId, f64> =
34                FxHashMap::with_capacity_and_hasher(children.len() * 4, FxBuildHasher);
35            for child in children {
36                let t = as_linear(arena, child)?;
37                for (v, c) in t.coeffs {
38                    *map.entry(v).or_insert(0.0) += c;
39                }
40                acc.constant += t.constant;
41            }
42            acc.coeffs = map.into_iter().collect();
43            Some(acc)
44        }
45        ExprNode::Mul(children) => {
46            // Linear if and only if exactly one non-const child is linear and the rest are constants.
47            let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
48            let mut scalar = 1.0;
49            let mut linear: Option<LinearTerms> = None;
50            for child in children {
51                if let ExprNode::Const(c) = arena.get(child) {
52                    scalar *= c;
53                } else if linear.is_none() {
54                    linear = Some(as_linear(arena, child)?);
55                } else {
56                    return None;
57                }
58            }
59            Some(match linear {
60                None => LinearTerms { coeffs: Vec::new(), constant: scalar },
61                Some(mut t) => {
62                    t.coeffs.iter_mut().for_each(|(_, c)| *c *= scalar);
63                    t.constant *= scalar;
64                    t
65                }
66            })
67        }
68        _ => None,
69    }
70}
71
72/// Materialize a linear-terms struct into a fresh `Linear` node in the arena.
73fn push_linear(arena: &mut ExprArena, mut t: LinearTerms) -> ExprId {
74    t.coeffs.retain(|(_, c)| *c != 0.0);
75    arena.push(ExprNode::Linear { coeffs: t.coeffs, constant: t.constant })
76}
77
78/// Build `lhs + rhs`, preserving the linear fast-path when both sides are
79/// linear. Falls back to an n-ary `Add` node otherwise.
80pub(crate) fn add_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
81    if let (Some(lt), Some(rt)) = (as_linear(arena, lhs), as_linear(arena, rhs)) {
82        let mut map: FxHashMap<VarId, f64> =
83            FxHashMap::with_capacity_and_hasher(lt.coeffs.len() + rt.coeffs.len(), FxBuildHasher);
84        for (v, c) in lt.coeffs.into_iter().chain(rt.coeffs) {
85            *map.entry(v).or_insert(0.0) += c;
86        }
87        return push_linear(
88            arena,
89            LinearTerms { coeffs: map.into_iter().collect(), constant: lt.constant + rt.constant },
90        );
91    }
92    arena.push(ExprNode::Add(smallvec![lhs, rhs]))
93}
94
95/// Build `lhs - rhs`. Same linear fast-path as `add_into`.
96pub(crate) fn sub_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
97    let neg = neg_into(arena, rhs);
98    add_into(arena, lhs, neg)
99}
100
101/// Build `lhs * rhs`. If either side is constant and the other is linear, we
102/// stay on the linear fast-path. Otherwise produce a generic n-ary `Mul`.
103pub(crate) fn mul_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
104    if let ExprNode::Const(c) = *arena.get(lhs) {
105        if let Some(mut t) = as_linear(arena, rhs) {
106            t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
107            t.constant *= c;
108            return push_linear(arena, t);
109        }
110    }
111    if let ExprNode::Const(c) = *arena.get(rhs) {
112        if let Some(mut t) = as_linear(arena, lhs) {
113            t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
114            t.constant *= c;
115            return push_linear(arena, t);
116        }
117    }
118    arena.push(ExprNode::Mul(smallvec![lhs, rhs]))
119}
120
121/// Build `-rhs`, preserving linearity.
122pub(crate) fn neg_into(arena: &mut ExprArena, rhs: ExprId) -> ExprId {
123    if let Some(mut t) = as_linear(arena, rhs) {
124        t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
125        t.constant = -t.constant;
126        return push_linear(arena, t);
127    }
128    arena.push(ExprNode::Neg(rhs))
129}
130
/// Snapshot the linear terms of `id`, if any. Used by solver backends to
/// extract LP coefficients without walking the tree themselves.
///
/// Returns `None` when `id` is not linear. The result is an owned copy.
/// No particular ordering of `coeffs` is guaranteed (for `Add` nodes it
/// follows hash-map iteration order). Entries with a zero coefficient are
/// not removed, so they may appear (e.g. after multiplying by zero inside a
/// `Mul` node).
pub fn extract_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
    as_linear(arena, id)
}