Skip to main content

oximo_expr/
linear.rs

1use rustc_hash::{FxBuildHasher, FxHashMap};
2use smallvec::smallvec;
3
4use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
5
6/// Coefficients of a linear expression: `sum(coeff * var) + constant`.
7#[derive(Clone, Debug, Default)]
8pub struct LinearTerms {
9    pub coeffs: Vec<(VarId, f64)>,
10    pub constant: f64,
11}
12
13/// Try to interpret `id` as a linear expression. Returns `None` for any
14/// nonlinear node (Mul of two non-constants, Pow, transcendentals, ...).
15///
16/// When `resolve_params` is set, a [`ExprNode::Param`] folds to its current
17/// arena value and counts as a constant.
18fn as_linear(arena: &ExprArena, id: ExprId, resolve_params: bool) -> Option<LinearTerms> {
19    match arena.get(id) {
20        ExprNode::Const(c) => Some(LinearTerms { coeffs: Vec::new(), constant: *c }),
21        ExprNode::Param(p) if resolve_params => {
22            Some(LinearTerms { coeffs: Vec::new(), constant: arena.param_value(*p) })
23        }
24        ExprNode::Var(v) => Some(LinearTerms { coeffs: vec![(*v, 1.0)], constant: 0.0 }),
25        ExprNode::Linear { coeffs, constant } => {
26            Some(LinearTerms { coeffs: coeffs.clone(), constant: *constant })
27        }
28        ExprNode::Neg(inner) => {
29            let inner = *inner;
30            as_linear(arena, inner, resolve_params).map(|mut t| {
31                t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
32                t.constant = -t.constant;
33                t
34            })
35        }
36        ExprNode::Add(children) => {
37            let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
38            let mut acc = LinearTerms::default();
39            let mut map: FxHashMap<VarId, f64> =
40                FxHashMap::with_capacity_and_hasher(children.len() * 4, FxBuildHasher);
41            for child in children {
42                let t = as_linear(arena, child, resolve_params)?;
43                for (v, c) in t.coeffs {
44                    *map.entry(v).or_insert(0.0) += c;
45                }
46                acc.constant += t.constant;
47            }
48            acc.coeffs = map.into_iter().collect();
49            Some(acc)
50        }
51        ExprNode::Mul(children) => {
52            // Linear if and only if exactly one non-const child is linear and the rest are constants.
53            let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
54            let mut scalar = 1.0;
55            let mut linear: Option<LinearTerms> = None;
56            for child in children {
57                match arena.get(child) {
58                    ExprNode::Const(c) => scalar *= c,
59                    ExprNode::Param(p) if resolve_params => scalar *= arena.param_value(*p),
60                    _ if linear.is_none() => {
61                        linear = Some(as_linear(arena, child, resolve_params)?);
62                    }
63                    _ => return None,
64                }
65            }
66            Some(match linear {
67                None => LinearTerms { coeffs: Vec::new(), constant: scalar },
68                Some(mut t) => {
69                    t.coeffs.iter_mut().for_each(|(_, c)| *c *= scalar);
70                    t.constant *= scalar;
71                    t
72                }
73            })
74        }
75        _ => None,
76    }
77}
78
79/// Materialize a linear-terms struct into a fresh `Linear` node in the arena.
80fn push_linear(arena: &mut ExprArena, mut t: LinearTerms) -> ExprId {
81    t.coeffs.retain(|(_, c)| *c != 0.0);
82    arena.push(ExprNode::Linear { coeffs: t.coeffs, constant: t.constant })
83}
84
85/// Build `lhs + rhs`, preserving the linear fast-path when both sides are
86/// linear. Falls back to an n-ary `Add` node otherwise.
87pub(crate) fn add_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
88    if let (Some(lt), Some(rt)) = (as_linear(arena, lhs, false), as_linear(arena, rhs, false)) {
89        let mut map: FxHashMap<VarId, f64> =
90            FxHashMap::with_capacity_and_hasher(lt.coeffs.len() + rt.coeffs.len(), FxBuildHasher);
91        for (v, c) in lt.coeffs.into_iter().chain(rt.coeffs) {
92            *map.entry(v).or_insert(0.0) += c;
93        }
94        return push_linear(
95            arena,
96            LinearTerms { coeffs: map.into_iter().collect(), constant: lt.constant + rt.constant },
97        );
98    }
99    arena.push(ExprNode::Add(smallvec![lhs, rhs]))
100}
101
102/// Build `lhs - rhs`. Same linear fast-path as `add_into`.
103pub(crate) fn sub_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
104    let neg = neg_into(arena, rhs);
105    add_into(arena, lhs, neg)
106}
107
108/// Build `lhs * rhs`. If either side is constant and the other is linear, we
109/// stay on the linear fast-path. Otherwise produce a generic n-ary `Mul`.
110pub(crate) fn mul_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
111    if let ExprNode::Const(c) = *arena.get(lhs) {
112        if let Some(mut t) = as_linear(arena, rhs, false) {
113            t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
114            t.constant *= c;
115            return push_linear(arena, t);
116        }
117    }
118    if let ExprNode::Const(c) = *arena.get(rhs) {
119        if let Some(mut t) = as_linear(arena, lhs, false) {
120            t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
121            t.constant *= c;
122            return push_linear(arena, t);
123        }
124    }
125    arena.push(ExprNode::Mul(smallvec![lhs, rhs]))
126}
127
128/// Build `num / den`. If `den` is a nonzero constant `c`, fold to `num * (1/c)`
129/// so a constant-denominator division stays on the linear fast-path. Otherwise
130/// produce a `Div` node (always nonlinear, even when the numerator is linear).
131pub(crate) fn div_into(arena: &mut ExprArena, num: ExprId, den: ExprId) -> ExprId {
132    if let ExprNode::Const(c) = *arena.get(den) {
133        if c != 0.0 {
134            if let Some(mut t) = as_linear(arena, num, false) {
135                let inv = 1.0 / c;
136                t.coeffs.iter_mut().for_each(|(_, co)| *co *= inv);
137                t.constant *= inv;
138                return push_linear(arena, t);
139            }
140            let inv = arena.push(ExprNode::Const(1.0 / c));
141            return mul_into(arena, num, inv);
142        }
143    }
144    arena.push(ExprNode::Div(num, den))
145}
146
147/// Build `-rhs`, preserving linearity.
148pub(crate) fn neg_into(arena: &mut ExprArena, rhs: ExprId) -> ExprId {
149    if let Some(mut t) = as_linear(arena, rhs, false) {
150        t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
151        t.constant = -t.constant;
152        return push_linear(arena, t);
153    }
154    arena.push(ExprNode::Neg(rhs))
155}
156
157/// Snapshot the linear terms of `id`, if any. Used by solver backends to
158/// extract LP coefficients without walking the tree themselves.
159///
160/// Parameters are folded to their current arena values, so the returned
161/// coefficients reflect the latest [`ExprArena::set_param_value`] binding.
162///
163/// [`ExprArena::set_param_value`]: crate::ExprArena::set_param_value
164pub fn extract_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
165    as_linear(arena, id, true)
166}
167
168/// A nonlinear residual summand: the existing arena node `id`, taken with a
169/// leading negation when `neg` is set. Carrying the sign as a flag.
170/// Lets [`split_linear`] run without a mutable arena.
171#[derive(Copy, Clone, Debug, PartialEq, Eq)]
172pub struct SignedExpr {
173    pub id: ExprId,
174    pub neg: bool,
175}
176
177/// Split an expression into its linear part and a nonlinear residual. The
178/// returned `(LinearTerms, Vec<SignedExpr>)` satisfies
179///
180/// ```text
181/// value(id) == sum_i coef_i * var_i + constant + sum_j (-1)^neg_j value(id_j)
182/// ```
183///
184/// where the residual is empty when the whole expression is linear and
185/// otherwise lists the remaining nonlinear summands (each a pre-existing arena
186/// node, optionally negated). `LinearTerms` may have empty `coeffs` and
187/// `constant == 0.0` when the whole expression is purely nonlinear.
188pub fn split_linear(arena: &ExprArena, id: ExprId) -> (LinearTerms, Vec<SignedExpr>) {
189    if let Some(lt) = as_linear(arena, id, true) {
190        return (lt, Vec::new());
191    }
192    let mut lin = LinearTerms::default();
193    let mut residual: Vec<SignedExpr> = Vec::new();
194    let mut sign_stack: smallvec::SmallVec<[(ExprId, f64); 8]> = smallvec![(id, 1.0)];
195    while let Some((cur, sign)) = sign_stack.pop() {
196        match arena.get(cur) {
197            ExprNode::Add(children) => {
198                for c in children.iter().copied() {
199                    sign_stack.push((c, sign));
200                }
201            }
202            ExprNode::Neg(inner) => sign_stack.push((*inner, -sign)),
203            _ => {
204                if let Some(mut t) = as_linear(arena, cur, true) {
205                    if (sign - 1.0).abs() > 0.0 {
206                        t.coeffs.iter_mut().for_each(|(_, c)| *c *= sign);
207                        t.constant *= sign;
208                    }
209                    for (v, c) in t.coeffs {
210                        if let Some((_, acc)) = lin.coeffs.iter_mut().find(|(vv, _)| *vv == v) {
211                            *acc += c;
212                        } else {
213                            lin.coeffs.push((v, c));
214                        }
215                    }
216                    lin.constant += t.constant;
217                } else {
218                    residual.push(SignedExpr { id: cur, neg: sign < 0.0 });
219                }
220            }
221        }
222    }
223    lin.coeffs.retain(|(_, c)| *c != 0.0);
224    (lin, residual)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprId, ExprNode, VarId};

    /// Push the decision variable `x0` into `arena` and return its node.
    fn x0(arena: &mut ExprArena) -> ExprId {
        arena.push(ExprNode::Var(VarId(0)))
    }

    #[test]
    fn param_times_var_stays_symbolic_until_extracted() {
        // `price * x` built through the operator helper must keep the
        // parameter symbolic (a `Mul`, not a folded `Linear`) so it remains
        // re-bindable; extraction then resolves it to its current value.
        let mut arena = ExprArena::new();
        let pid = arena.new_param(3.0);
        let price = arena.param(pid);
        let x = x0(&mut arena);
        let product = mul_into(&mut arena, price, x);
        assert!(matches!(arena.get(product), ExprNode::Mul(_)));

        let LinearTerms { coeffs, constant } = extract_linear(&arena, product).expect("linear");
        assert_eq!(coeffs, vec![(VarId(0), 3.0)]);
        assert!(constant.abs() < f64::EPSILON);
    }

    #[test]
    fn rebinding_param_updates_extracted_coeff() {
        let mut arena = ExprArena::new();
        let pid = arena.new_param(3.0);
        let price = arena.param(pid);
        let x = x0(&mut arena);
        let product = mul_into(&mut arena, price, x);

        arena.set_param_value(pid, 10.0);
        let LinearTerms { coeffs, .. } = extract_linear(&arena, product).expect("linear");
        assert_eq!(coeffs, vec![(VarId(0), 10.0)]);
    }

    #[test]
    fn param_plus_var_resolves_constant() {
        let mut arena = ExprArena::new();
        let pid = arena.new_param(5.0);
        let price = arena.param(pid);
        let x = x0(&mut arena);
        let total = add_into(&mut arena, price, x);

        let LinearTerms { coeffs, constant } = extract_linear(&arena, total).expect("linear");
        assert_eq!(coeffs, vec![(VarId(0), 1.0)]);
        assert!((constant - 5.0).abs() < f64::EPSILON);
    }
}