Skip to main content

oximo_expr/
linear.rs

1use rustc_hash::{FxBuildHasher, FxHashMap};
2use smallvec::smallvec;
3
4use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
5
6/// Coefficients of a linear expression: `sum(coeff * var) + constant`.
7#[derive(Clone, Debug, Default)]
8pub struct LinearTerms {
9    pub coeffs: Vec<(VarId, f64)>,
10    pub constant: f64,
11}
12
13/// Try to interpret `id` as a linear expression. Returns `None` for any
14/// nonlinear node (Mul of two non-constants, Pow, transcendentals, ...).
15///
16/// When `resolve_params` is set, a [`ExprNode::Param`] folds to its current
17/// arena value and counts as a constant.
18fn as_linear(arena: &ExprArena, id: ExprId, resolve_params: bool) -> Option<LinearTerms> {
19    match arena.get(id) {
20        ExprNode::Const(c) => Some(LinearTerms { coeffs: Vec::new(), constant: *c }),
21        ExprNode::Param(p) if resolve_params => {
22            Some(LinearTerms { coeffs: Vec::new(), constant: arena.param_value(*p) })
23        }
24        ExprNode::Var(v) => Some(LinearTerms { coeffs: vec![(*v, 1.0)], constant: 0.0 }),
25        ExprNode::Linear { coeffs, constant } => {
26            Some(LinearTerms { coeffs: coeffs.clone(), constant: *constant })
27        }
28        ExprNode::Neg(inner) => {
29            let inner = *inner;
30            as_linear(arena, inner, resolve_params).map(|mut t| {
31                t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
32                t.constant = -t.constant;
33                t
34            })
35        }
36        ExprNode::Add(children) => {
37            let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
38            let mut acc = LinearTerms::default();
39            let mut map: FxHashMap<VarId, f64> =
40                FxHashMap::with_capacity_and_hasher(children.len() * 4, FxBuildHasher);
41            for child in children {
42                let t = as_linear(arena, child, resolve_params)?;
43                for (v, c) in t.coeffs {
44                    *map.entry(v).or_insert(0.0) += c;
45                }
46                acc.constant += t.constant;
47            }
48            acc.coeffs = map.into_iter().collect();
49            Some(acc)
50        }
51        ExprNode::Mul(children) => {
52            // Linear if and only if exactly one non-const child is linear and the rest are constants.
53            let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
54            let mut scalar = 1.0;
55            let mut linear: Option<LinearTerms> = None;
56            for child in children {
57                match arena.get(child) {
58                    ExprNode::Const(c) => scalar *= c,
59                    ExprNode::Param(p) if resolve_params => scalar *= arena.param_value(*p),
60                    _ if linear.is_none() => {
61                        linear = Some(as_linear(arena, child, resolve_params)?);
62                    }
63                    _ => return None,
64                }
65            }
66            Some(match linear {
67                None => LinearTerms { coeffs: Vec::new(), constant: scalar },
68                Some(mut t) => {
69                    t.coeffs.iter_mut().for_each(|(_, c)| *c *= scalar);
70                    t.constant *= scalar;
71                    t
72                }
73            })
74        }
75        _ => None,
76    }
77}
78
79/// Materialize a linear-terms struct into a fresh `Linear` node in the arena.
80fn push_linear(arena: &mut ExprArena, mut t: LinearTerms) -> ExprId {
81    t.coeffs.retain(|(_, c)| *c != 0.0);
82    arena.push(ExprNode::Linear { coeffs: t.coeffs, constant: t.constant })
83}
84
85/// Build `lhs + rhs`, preserving the linear fast-path when both sides are
86/// linear. Falls back to an n-ary `Add` node otherwise.
87pub(crate) fn add_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
88    if let (Some(lt), Some(rt)) = (as_linear(arena, lhs, false), as_linear(arena, rhs, false)) {
89        let mut map: FxHashMap<VarId, f64> =
90            FxHashMap::with_capacity_and_hasher(lt.coeffs.len() + rt.coeffs.len(), FxBuildHasher);
91        for (v, c) in lt.coeffs.into_iter().chain(rt.coeffs) {
92            *map.entry(v).or_insert(0.0) += c;
93        }
94        return push_linear(
95            arena,
96            LinearTerms { coeffs: map.into_iter().collect(), constant: lt.constant + rt.constant },
97        );
98    }
99    arena.push(ExprNode::Add(smallvec![lhs, rhs]))
100}
101
102/// Build a flat n-ary sum of `ids` as a single `Add` node.
103/// `as_linear`/`split_linear` collapse the resulting `Add`
104/// in one pass at extraction, so the linear fast-path is preserved.
105///
106/// # Panics
107/// Panics if `ids` is empty (callers supply at least one term).
108pub(crate) fn add_n(arena: &mut ExprArena, ids: &[ExprId]) -> ExprId {
109    match ids {
110        [] => panic!("add_n on an empty term list"),
111        [one] => *one,
112        _ => arena.push(ExprNode::Add(ids.iter().copied().collect())),
113    }
114}
115
116/// Build `lhs - rhs`. Same linear fast-path as `add_into`.
117pub(crate) fn sub_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
118    let neg = neg_into(arena, rhs);
119    add_into(arena, lhs, neg)
120}
121
122/// Build `lhs * rhs`. If either side is constant and the other is linear, we
123/// stay on the linear fast-path. Otherwise produce a generic n-ary `Mul`.
124pub(crate) fn mul_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
125    if let ExprNode::Const(c) = *arena.get(lhs) {
126        if let Some(mut t) = as_linear(arena, rhs, false) {
127            t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
128            t.constant *= c;
129            return push_linear(arena, t);
130        }
131    }
132    if let ExprNode::Const(c) = *arena.get(rhs) {
133        if let Some(mut t) = as_linear(arena, lhs, false) {
134            t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
135            t.constant *= c;
136            return push_linear(arena, t);
137        }
138    }
139    arena.push(ExprNode::Mul(smallvec![lhs, rhs]))
140}
141
142/// Build `num / den`. If `den` is a nonzero constant `c`, fold to `num * (1/c)`
143/// so a constant-denominator division stays on the linear fast-path. Otherwise
144/// produce a `Div` node (always nonlinear, even when the numerator is linear).
145pub(crate) fn div_into(arena: &mut ExprArena, num: ExprId, den: ExprId) -> ExprId {
146    if let ExprNode::Const(c) = *arena.get(den) {
147        if c != 0.0 {
148            if let Some(mut t) = as_linear(arena, num, false) {
149                let inv = 1.0 / c;
150                t.coeffs.iter_mut().for_each(|(_, co)| *co *= inv);
151                t.constant *= inv;
152                return push_linear(arena, t);
153            }
154            let inv = arena.push(ExprNode::Const(1.0 / c));
155            return mul_into(arena, num, inv);
156        }
157    }
158    arena.push(ExprNode::Div(num, den))
159}
160
161/// Build `-rhs`, preserving linearity.
162pub(crate) fn neg_into(arena: &mut ExprArena, rhs: ExprId) -> ExprId {
163    if let Some(mut t) = as_linear(arena, rhs, false) {
164        t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
165        t.constant = -t.constant;
166        return push_linear(arena, t);
167    }
168    arena.push(ExprNode::Neg(rhs))
169}
170
171/// Snapshot the linear terms of `id`, if any. Used by solver backends to
172/// extract LP coefficients without walking the tree themselves.
173///
174/// Parameters are folded to their current arena values, so the returned
175/// coefficients reflect the latest [`ExprArena::set_param_value`] binding.
176///
177/// [`ExprArena::set_param_value`]: crate::ExprArena::set_param_value
178pub fn extract_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
179    as_linear(arena, id, true)
180}
181
182/// A nonlinear residual summand: the existing arena node `id`, taken with a
183/// leading negation when `neg` is set. Carrying the sign as a flag.
184/// Lets [`split_linear`] run without a mutable arena.
185#[derive(Copy, Clone, Debug, PartialEq, Eq)]
186pub struct SignedExpr {
187    pub id: ExprId,
188    pub neg: bool,
189}
190
191/// Split an expression into its linear part and a nonlinear residual. The
192/// returned `(LinearTerms, Vec<SignedExpr>)` satisfies
193///
194/// ```text
195/// value(id) == sum_i coef_i * var_i + constant + sum_j (-1)^neg_j value(id_j)
196/// ```
197///
198/// where the residual is empty when the whole expression is linear and
199/// otherwise lists the remaining nonlinear summands (each a pre-existing arena
200/// node, optionally negated). `LinearTerms` may have empty `coeffs` and
201/// `constant == 0.0` when the whole expression is purely nonlinear.
202pub fn split_linear(arena: &ExprArena, id: ExprId) -> (LinearTerms, Vec<SignedExpr>) {
203    if let Some(lt) = as_linear(arena, id, true) {
204        return (lt, Vec::new());
205    }
206    let mut lin = LinearTerms::default();
207    let mut residual: Vec<SignedExpr> = Vec::new();
208    let mut sign_stack: smallvec::SmallVec<[(ExprId, f64); 8]> = smallvec![(id, 1.0)];
209    while let Some((cur, sign)) = sign_stack.pop() {
210        match arena.get(cur) {
211            ExprNode::Add(children) => {
212                for c in children.iter().copied() {
213                    sign_stack.push((c, sign));
214                }
215            }
216            ExprNode::Neg(inner) => sign_stack.push((*inner, -sign)),
217            _ => {
218                if let Some(mut t) = as_linear(arena, cur, true) {
219                    if (sign - 1.0).abs() > 0.0 {
220                        t.coeffs.iter_mut().for_each(|(_, c)| *c *= sign);
221                        t.constant *= sign;
222                    }
223                    for (v, c) in t.coeffs {
224                        if let Some((_, acc)) = lin.coeffs.iter_mut().find(|(vv, _)| *vv == v) {
225                            *acc += c;
226                        } else {
227                            lin.coeffs.push((v, c));
228                        }
229                    }
230                    lin.constant += t.constant;
231                } else {
232                    residual.push(SignedExpr { id: cur, neg: sign < 0.0 });
233                }
234            }
235        }
236    }
237    lin.coeffs.retain(|(_, c)| *c != 0.0);
238    (lin, residual)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};

    #[test]
    fn param_times_var_stays_symbolic_until_extracted() {
        // Build `price * x` through the operator helper. The parameter must NOT
        // be folded into a Linear node at build time, so it stays re-bindable.
        let mut arena = ExprArena::new();
        let pid = arena.new_param(3.0);
        let price = arena.param(pid);
        let xnode = arena.push(ExprNode::Var(VarId(0)));
        let prod = mul_into(&mut arena, price, xnode);
        assert!(matches!(arena.get(prod), ExprNode::Mul(_)));

        let terms = extract_linear(&arena, prod).expect("linear");
        assert_eq!(terms.coeffs, vec![(VarId(0), 3.0)]);
        assert!(terms.constant.abs() < f64::EPSILON);
    }

    #[test]
    fn rebinding_param_updates_extracted_coeff() {
        let mut arena = ExprArena::new();
        let pid = arena.new_param(3.0);
        let price = arena.param(pid);
        let xnode = arena.push(ExprNode::Var(VarId(0)));
        let prod = mul_into(&mut arena, price, xnode);

        arena.set_param_value(pid, 10.0);
        let terms = extract_linear(&arena, prod).expect("linear");
        assert_eq!(terms.coeffs, vec![(VarId(0), 10.0)]);
    }

    #[test]
    fn param_plus_var_resolves_constant() {
        let mut arena = ExprArena::new();
        let pid = arena.new_param(5.0);
        let price = arena.param(pid);
        let xnode = arena.push(ExprNode::Var(VarId(0)));
        let sum = add_into(&mut arena, price, xnode);
        let terms = extract_linear(&arena, sum).expect("linear");
        assert_eq!(terms.coeffs, vec![(VarId(0), 1.0)]);
        assert!((terms.constant - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn division_by_constant_folds_to_linear() {
        let mut arena = ExprArena::new();
        let x = arena.push(ExprNode::Var(VarId(0)));
        let four = arena.push(ExprNode::Const(4.0));
        let quotient = div_into(&mut arena, x, four);
        assert!(matches!(arena.get(quotient), ExprNode::Linear { .. }));

        let terms = extract_linear(&arena, quotient).expect("linear");
        assert_eq!(terms.coeffs, vec![(VarId(0), 0.25)]);
        assert!(terms.constant.abs() < f64::EPSILON);
    }

    #[test]
    fn split_linear_separates_signed_nonlinear_summands() {
        // (x + 3) - y*z  ==>  linear part x + 3, residual -(y*z).
        let mut arena = ExprArena::new();
        let x = arena.push(ExprNode::Var(VarId(0)));
        let y = arena.push(ExprNode::Var(VarId(1)));
        let z = arena.push(ExprNode::Var(VarId(2)));
        let yz = mul_into(&mut arena, y, z);
        let three = arena.push(ExprNode::Const(3.0));
        let x_plus_3 = add_into(&mut arena, x, three);
        let expr = sub_into(&mut arena, x_plus_3, yz);

        let (lin, residual) = split_linear(&arena, expr);
        assert_eq!(lin.coeffs, vec![(VarId(0), 1.0)]);
        assert!((lin.constant - 3.0).abs() < f64::EPSILON);
        assert_eq!(residual, vec![SignedExpr { id: yz, neg: true }]);
    }

    #[test]
    fn add_n_returns_single_term_unchanged() {
        let mut arena = ExprArena::new();
        let x = arena.push(ExprNode::Var(VarId(0)));
        assert_eq!(add_n(&mut arena, &[x]), x);
    }

    #[test]
    #[should_panic(expected = "add_n on an empty term list")]
    fn add_n_rejects_empty_input() {
        let mut arena = ExprArena::new();
        let _ = add_n(&mut arena, &[]);
    }
}