Skip to main content

oximo_expr/
quadratic.rs

1use rustc_hash::{FxBuildHasher, FxHashMap};
2
3use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
4
5/// Quadratic decomposition of an expression: its Hessian, gradient-linear
6/// part, and constant.
7///
8/// For a degree-`<= 2` polynomial `f(x)`, this holds the exact Taylor data
9///
10/// ```text
11/// f(x) = constant + sum_i linear_i * x_i + 0.5 * x' Q x
12/// ```
13///
14/// where `Q` is the (constant) Hessian. Returned by
15/// [`extract_quadratic`].
16#[derive(Clone, Debug, Default)]
17pub struct QuadraticTerms {
18    /// Lower-triangle Hessian entries `(row, col, h)` with `row >= col`, where
19    /// `h = partial^2 f / partial x_row partial x_col`. Diagonal entries are the full second
20    /// derivative, so `a * x^2` yields `(x, x, 2a)`. This matches the
21    /// `0.5 * x' Q x` convention used by QP solvers.
22    pub hessian: Vec<(VarId, VarId, f64)>,
23    /// Linear coefficients `(var, coeff)`, the gradient of `f` at the origin.
24    pub linear: Vec<(VarId, f64)>,
25    /// The constant term `f(0)`.
26    pub constant: f64,
27}
28
29/// Internal accumulator while walking the expression. `quad` keys are ordered
30/// `(min, max)` variable pairs and hold the polynomial coefficient of
31/// `x_i * x_j` (i.e. the coefficient of `x_i^2` on the diagonal), not yet the
32/// doubled Hessian value.
33#[derive(Default)]
34struct Poly {
35    quad: FxHashMap<(VarId, VarId), f64>,
36    linear: FxHashMap<VarId, f64>,
37    constant: f64,
38}
39
40impl Poly {
41    fn constant(c: f64) -> Self {
42        Self { constant: c, ..Self::default() }
43    }
44
45    fn var(v: VarId) -> Self {
46        let mut linear = FxHashMap::with_capacity_and_hasher(1, FxBuildHasher);
47        linear.insert(v, 1.0);
48        Self { linear, ..Self::default() }
49    }
50
51    fn is_constant(&self) -> bool {
52        self.quad.is_empty() && self.linear.is_empty()
53    }
54
55    fn is_linear(&self) -> bool {
56        self.quad.is_empty()
57    }
58
59    fn scale(mut self, s: f64) -> Self {
60        self.constant *= s;
61        for c in self.linear.values_mut() {
62            *c *= s;
63        }
64        for c in self.quad.values_mut() {
65            *c *= s;
66        }
67        self
68    }
69
70    fn neg(self) -> Self {
71        self.scale(-1.0)
72    }
73
74    fn add_assign(&mut self, other: Poly) {
75        self.constant += other.constant;
76        for (v, c) in other.linear {
77            *self.linear.entry(v).or_insert(0.0) += c;
78        }
79        for (k, c) in other.quad {
80            *self.quad.entry(k).or_insert(0.0) += c;
81        }
82    }
83}
84
85/// Ordered `(min, max)` variable pair, used as the canonical quad key.
86fn pair(a: VarId, b: VarId) -> (VarId, VarId) {
87    if a.0 <= b.0 { (a, b) } else { (b, a) }
88}
89
90/// Multiply two linear polynomials, producing the degree-2 product. Both
91/// operands must be linear (`quad` empty), the caller guarantees this.
92fn mul_linear(a: &Poly, b: &Poly) -> Poly {
93    let mut out = Poly::constant(a.constant * b.constant);
94    // a.constant * b.linear + b.constant * a.linear
95    for (v, c) in &b.linear {
96        *out.linear.entry(*v).or_insert(0.0) += a.constant * c;
97    }
98    for (v, c) in &a.linear {
99        *out.linear.entry(*v).or_insert(0.0) += b.constant * c;
100    }
101    // a.linear[i] * b.linear[j] -> quad term x_i x_j
102    for (vi, ci) in &a.linear {
103        for (vj, cj) in &b.linear {
104            *out.quad.entry(pair(*vi, *vj)).or_insert(0.0) += ci * cj;
105        }
106    }
107    out
108}
109
110/// Recursively interpret `id` as a polynomial of degree `<= 2`. Returns `None`
111/// for anything of higher degree, transcendentals, or division. Parameters fold
112/// to their live arena value (a degree-0 constant).
113fn as_poly(arena: &ExprArena, id: ExprId) -> Option<Poly> {
114    match arena.get(id) {
115        ExprNode::Const(c) => Some(Poly::constant(*c)),
116        ExprNode::Var(v) => Some(Poly::var(*v)),
117        ExprNode::Linear { coeffs, constant } => {
118            let mut linear: FxHashMap<VarId, f64> =
119                FxHashMap::with_capacity_and_hasher(coeffs.len(), FxBuildHasher);
120            for (v, c) in coeffs {
121                *linear.entry(*v).or_insert(0.0) += *c;
122            }
123            Some(Poly { quad: FxHashMap::default(), linear, constant: *constant })
124        }
125        ExprNode::Neg(inner) => as_poly(arena, *inner).map(Poly::neg),
126        ExprNode::Add(children) => {
127            let mut acc = Poly::default();
128            for child in children {
129                acc.add_assign(as_poly(arena, *child)?);
130            }
131            Some(acc)
132        }
133        ExprNode::Mul(children) => {
134            let mut acc = Poly::constant(1.0);
135            for child in children {
136                let p = as_poly(arena, *child)?;
137                acc = if acc.is_constant() {
138                    p.scale(acc.constant)
139                } else if p.is_constant() {
140                    acc.scale(p.constant)
141                } else if acc.is_linear() && p.is_linear() {
142                    mul_linear(&acc, &p)
143                } else {
144                    return None;
145                };
146            }
147            Some(acc)
148        }
149        ExprNode::Pow(base, exp) => {
150            let ExprNode::Const(e) = arena.get(*exp) else { return None };
151            if (*e - e.round()).abs() >= f64::EPSILON || *e < 0.0 {
152                return None;
153            }
154            match e.round() {
155                n if n < 0.5 => Some(Poly::constant(1.0)),
156                n if n < 1.5 => as_poly(arena, *base),
157                n if n < 2.5 => {
158                    let p = as_poly(arena, *base)?;
159                    if !p.is_linear() {
160                        return None;
161                    }
162                    Some(mul_linear(&p, &p))
163                }
164                _ => None,
165            }
166        }
167        ExprNode::Param(p) => Some(Poly::constant(arena.param_value(*p))),
168        ExprNode::Div(_, _)
169        | ExprNode::Sin(_)
170        | ExprNode::Cos(_)
171        | ExprNode::Exp(_)
172        | ExprNode::Log(_)
173        | ExprNode::Abs(_) => None,
174    }
175}
176
177/// Snapshot the quadratic structure of `id`, if it is a polynomial of degree
178/// `<= 2`. Returns the Hessian (lower triangle), the linear
179/// coefficients, and the constant (see [`QuadraticTerms`]).
180///
181/// `None` is returned for any expression `classify` would call
182/// `Nonlinear` (degree `> 2`, transcendentals, non-integer/negative powers,
183/// division). Parameters are folded to their current arena values, so a
184/// polynomial whose coefficients are parameters is still extracted.
185///
186/// A purely linear (or constant) expression yields an empty `hessian`.
187pub fn extract_quadratic(arena: &ExprArena, id: ExprId) -> Option<QuadraticTerms> {
188    let poly = as_poly(arena, id)?;
189
190    let mut hessian: Vec<(VarId, VarId, f64)> = Vec::with_capacity(poly.quad.len());
191    for ((lo, hi), c) in poly.quad {
192        if c == 0.0 {
193            continue;
194        }
195        if lo == hi {
196            // Diagonal: partial^2 (c x^2)/partial x^2 = 2c.
197            hessian.push((lo, lo, 2.0 * c));
198        } else {
199            // Off-diagonal: store in the lower triangle (row = larger index).
200            hessian.push((hi, lo, c));
201        }
202    }
203
204    let mut linear: Vec<(VarId, f64)> =
205        poly.linear.into_iter().filter(|(_, c)| *c != 0.0).collect();
206    linear.sort_unstable_by_key(|(v, _)| v.0);
207    hessian.sort_unstable_by_key(|(r, c, _)| (c.0, r.0));
208
209    Some(QuadraticTerms { hessian, linear, constant: poly.constant })
210}
211
#[cfg(test)]
mod tests {
    //! Unit tests for [`extract_quadratic`]: each test builds a small
    //! expression by hand in a fresh arena and checks the extracted terms.

    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};
    use smallvec::smallvec;

    /// Push a variable node with index `i` and return its expression id.
    fn var(arena: &mut ExprArena, i: u32) -> ExprId {
        arena.push(ExprNode::Var(VarId(i)))
    }

    /// Shorthand for `VarId(i)` in expected values.
    fn v(i: u32) -> VarId {
        VarId(i)
    }

    #[test]
    fn square_doubles_diagonal() {
        // x0^2 -> Hessian (0,0,2), no linear, constant 0.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let two = a.push(ExprNode::Const(2.0));
        let sq = a.push(ExprNode::Pow(x, two));
        let q = extract_quadratic(&a, sq).unwrap();
        assert_eq!(q.hessian, vec![(v(0), v(0), 2.0)]);
        assert!(q.linear.is_empty());
        assert!(q.constant.abs() < f64::EPSILON);
    }

    #[test]
    fn bilinear_off_diagonal() {
        // x0 * x1 -> single lower-triangle entry (1,0,1).
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let q = extract_quadratic(&a, xy).unwrap();
        assert_eq!(q.hessian, vec![(v(1), v(0), 1.0)]);
        assert!(q.linear.is_empty());
    }

    #[test]
    fn cvxopt_objective_recovers_hessian() {
        // 2*x0^2 + x0*x1 + x1^2 + x0 + x1 -> Q = [[4,1],[1,2]], c = [1,1].
        let mut a = ExprArena::new();
        let x0 = var(&mut a, 0);
        let x1 = var(&mut a, 1);
        let two = a.push(ExprNode::Const(2.0));
        let x0sq = a.push(ExprNode::Pow(x0, two));
        let term0 = a.push(ExprNode::Mul(smallvec![two, x0sq]));
        let x0x1 = a.push(ExprNode::Mul(smallvec![x0, x1]));
        let two_b = a.push(ExprNode::Const(2.0));
        let x1sq = a.push(ExprNode::Pow(x1, two_b));
        let sum = a.push(ExprNode::Add(smallvec![term0, x0x1, x1sq, x0, x1]));
        let q = extract_quadratic(&a, sum).unwrap();
        assert_eq!(q.hessian, vec![(v(0), v(0), 4.0), (v(1), v(0), 1.0), (v(1), v(1), 2.0)]);
        assert_eq!(q.linear, vec![(v(0), 1.0), (v(1), 1.0)]);
        assert!(q.constant.abs() < f64::EPSILON);
    }

    #[test]
    fn square_of_sum_cross_term() {
        // (x0 + x1)^2 = x0^2 + 2 x0 x1 + x1^2 -> Q = [[2,2],[2,2]].
        let mut a = ExprArena::new();
        let x0 = var(&mut a, 0);
        let x1 = var(&mut a, 1);
        let sum = a.push(ExprNode::Add(smallvec![x0, x1]));
        let two = a.push(ExprNode::Const(2.0));
        let sq = a.push(ExprNode::Pow(sum, two));
        let q = extract_quadratic(&a, sq).unwrap();
        assert_eq!(q.hessian, vec![(v(0), v(0), 2.0), (v(1), v(0), 2.0), (v(1), v(1), 2.0)]);
    }

    #[test]
    fn linear_only_has_empty_hessian() {
        // 3*x0 + 5 -> empty hessian, linear [(0,3)], constant 5.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let three = a.push(ExprNode::Const(3.0));
        let mul = a.push(ExprNode::Mul(smallvec![three, x]));
        let five = a.push(ExprNode::Const(5.0));
        let expr = a.push(ExprNode::Add(smallvec![mul, five]));
        let q = extract_quadratic(&a, expr).unwrap();
        assert!(q.hessian.is_empty());
        assert_eq!(q.linear, vec![(v(0), 3.0)]);
        assert!((q.constant - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn constant_only() {
        // 7 -> nothing but the constant.
        let mut a = ExprArena::new();
        let c = a.push(ExprNode::Const(7.0));
        let q = extract_quadratic(&a, c).unwrap();
        assert!(q.hessian.is_empty());
        assert!(q.linear.is_empty());
        assert!((q.constant - 7.0).abs() < f64::EPSILON);
    }

    #[test]
    fn negation_flips_signs() {
        // -(x0^2 + x0) -> Hessian (0,0,-2), linear [(0,-1)].
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let two = a.push(ExprNode::Const(2.0));
        let sq = a.push(ExprNode::Pow(x, two));
        let inner = a.push(ExprNode::Add(smallvec![sq, x]));
        let neg = a.push(ExprNode::Neg(inner));
        let q = extract_quadratic(&a, neg).unwrap();
        assert_eq!(q.hessian, vec![(v(0), v(0), -2.0)]);
        assert_eq!(q.linear, vec![(v(0), -1.0)]);
    }

    #[test]
    fn cubic_is_none() {
        // x0^3 exceeds degree 2.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let three = a.push(ExprNode::Const(3.0));
        let cube = a.push(ExprNode::Pow(x, three));
        assert!(extract_quadratic(&a, cube).is_none());
    }

    #[test]
    fn triple_product_is_none() {
        // x0 * x1 * x2 exceeds degree 2.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let z = var(&mut a, 2);
        let prod = a.push(ExprNode::Mul(smallvec![x, y, z]));
        assert!(extract_quadratic(&a, prod).is_none());
    }

    #[test]
    fn transcendental_is_none() {
        // sin(x0) is not a polynomial.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Sin(x));
        assert!(extract_quadratic(&a, s).is_none());
    }

    #[test]
    fn const_times_bilinear_scales() {
        // 3 * (x0 * x1) -> off-diagonal entry scaled to 3.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let three = a.push(ExprNode::Const(3.0));
        let scaled = a.push(ExprNode::Mul(smallvec![three, xy]));
        let q = extract_quadratic(&a, scaled).unwrap();
        assert_eq!(q.hessian, vec![(v(1), v(0), 3.0)]);
    }

    #[test]
    fn const_times_square_scales() {
        // 3 * x0^2 -> the scale is applied before the diagonal is doubled,
        // giving Hessian (0,0,6) and nothing else.
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let two = a.push(ExprNode::Const(2.0));
        let sq = a.push(ExprNode::Pow(x, two));
        let three = a.push(ExprNode::Const(3.0));
        let scaled = a.push(ExprNode::Mul(smallvec![three, sq]));
        let q = extract_quadratic(&a, scaled).unwrap();
        assert_eq!(q.hessian, vec![(v(0), v(0), 6.0)]);
        assert!(q.linear.is_empty());
        assert!(q.constant.abs() < f64::EPSILON);
    }
}