1use rustc_hash::{FxBuildHasher, FxHashMap};
2
3use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
4
5#[derive(Clone, Debug, Default)]
17pub struct QuadraticTerms {
18 pub hessian: Vec<(VarId, VarId, f64)>,
23 pub linear: Vec<(VarId, f64)>,
25 pub constant: f64,
27}
28
29#[derive(Default)]
34struct Poly {
35 quad: FxHashMap<(VarId, VarId), f64>,
36 linear: FxHashMap<VarId, f64>,
37 constant: f64,
38}
39
40impl Poly {
41 fn constant(c: f64) -> Self {
42 Self { constant: c, ..Self::default() }
43 }
44
45 fn var(v: VarId) -> Self {
46 let mut linear = FxHashMap::with_capacity_and_hasher(1, FxBuildHasher);
47 linear.insert(v, 1.0);
48 Self { linear, ..Self::default() }
49 }
50
51 fn is_constant(&self) -> bool {
52 self.quad.is_empty() && self.linear.is_empty()
53 }
54
55 fn is_linear(&self) -> bool {
56 self.quad.is_empty()
57 }
58
59 fn scale(mut self, s: f64) -> Self {
60 self.constant *= s;
61 for c in self.linear.values_mut() {
62 *c *= s;
63 }
64 for c in self.quad.values_mut() {
65 *c *= s;
66 }
67 self
68 }
69
70 fn neg(self) -> Self {
71 self.scale(-1.0)
72 }
73
74 fn add_assign(&mut self, other: Poly) {
75 self.constant += other.constant;
76 for (v, c) in other.linear {
77 *self.linear.entry(v).or_insert(0.0) += c;
78 }
79 for (k, c) in other.quad {
80 *self.quad.entry(k).or_insert(0.0) += c;
81 }
82 }
83}
84
85fn pair(a: VarId, b: VarId) -> (VarId, VarId) {
87 if a.0 <= b.0 { (a, b) } else { (b, a) }
88}
89
90fn mul_linear(a: &Poly, b: &Poly) -> Poly {
93 let mut out = Poly::constant(a.constant * b.constant);
94 for (v, c) in &b.linear {
96 *out.linear.entry(*v).or_insert(0.0) += a.constant * c;
97 }
98 for (v, c) in &a.linear {
99 *out.linear.entry(*v).or_insert(0.0) += b.constant * c;
100 }
101 for (vi, ci) in &a.linear {
103 for (vj, cj) in &b.linear {
104 *out.quad.entry(pair(*vi, *vj)).or_insert(0.0) += ci * cj;
105 }
106 }
107 out
108}
109
110fn as_poly(arena: &ExprArena, id: ExprId) -> Option<Poly> {
114 match arena.get(id) {
115 ExprNode::Const(c) => Some(Poly::constant(*c)),
116 ExprNode::Var(v) => Some(Poly::var(*v)),
117 ExprNode::Linear { coeffs, constant } => {
118 let mut linear: FxHashMap<VarId, f64> =
119 FxHashMap::with_capacity_and_hasher(coeffs.len(), FxBuildHasher);
120 for (v, c) in coeffs {
121 *linear.entry(*v).or_insert(0.0) += *c;
122 }
123 Some(Poly { quad: FxHashMap::default(), linear, constant: *constant })
124 }
125 ExprNode::Neg(inner) => as_poly(arena, *inner).map(Poly::neg),
126 ExprNode::Add(children) => {
127 let mut acc = Poly::default();
128 for child in children {
129 acc.add_assign(as_poly(arena, *child)?);
130 }
131 Some(acc)
132 }
133 ExprNode::Mul(children) => {
134 let mut acc = Poly::constant(1.0);
135 for child in children {
136 let p = as_poly(arena, *child)?;
137 acc = if acc.is_constant() {
138 p.scale(acc.constant)
139 } else if p.is_constant() {
140 acc.scale(p.constant)
141 } else if acc.is_linear() && p.is_linear() {
142 mul_linear(&acc, &p)
143 } else {
144 return None;
145 };
146 }
147 Some(acc)
148 }
149 ExprNode::Pow(base, exp) => {
150 let ExprNode::Const(e) = arena.get(*exp) else { return None };
151 if (*e - e.round()).abs() >= f64::EPSILON || *e < 0.0 {
152 return None;
153 }
154 match e.round() {
155 n if n < 0.5 => Some(Poly::constant(1.0)),
156 n if n < 1.5 => as_poly(arena, *base),
157 n if n < 2.5 => {
158 let p = as_poly(arena, *base)?;
159 if !p.is_linear() {
160 return None;
161 }
162 Some(mul_linear(&p, &p))
163 }
164 _ => None,
165 }
166 }
167 ExprNode::Param(p) => Some(Poly::constant(arena.param_value(*p))),
168 ExprNode::Div(_, _)
169 | ExprNode::Sin(_)
170 | ExprNode::Cos(_)
171 | ExprNode::Exp(_)
172 | ExprNode::Log(_)
173 | ExprNode::Abs(_) => None,
174 }
175}
176
177pub fn extract_quadratic(arena: &ExprArena, id: ExprId) -> Option<QuadraticTerms> {
188 let poly = as_poly(arena, id)?;
189
190 let mut hessian: Vec<(VarId, VarId, f64)> = Vec::with_capacity(poly.quad.len());
191 for ((lo, hi), c) in poly.quad {
192 if c == 0.0 {
193 continue;
194 }
195 if lo == hi {
196 hessian.push((lo, lo, 2.0 * c));
198 } else {
199 hessian.push((hi, lo, c));
201 }
202 }
203
204 let mut linear: Vec<(VarId, f64)> =
205 poly.linear.into_iter().filter(|(_, c)| *c != 0.0).collect();
206 linear.sort_unstable_by_key(|(v, _)| v.0);
207 hessian.sort_unstable_by_key(|(r, c, _)| (c.0, r.0));
208
209 Some(QuadraticTerms { hessian, linear, constant: poly.constant })
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};
    use smallvec::smallvec;

    /// Pushes a `Var` node for `VarId(i)` and returns its id.
    fn var(arena: &mut ExprArena, i: u32) -> ExprId {
        arena.push(ExprNode::Var(VarId(i)))
    }

    /// Pushes a `Const` node holding `value` and returns its id.
    fn konst(arena: &mut ExprArena, value: f64) -> ExprId {
        arena.push(ExprNode::Const(value))
    }

    /// Pushes `base ^ exponent`, where the exponent is a fresh `Const` node.
    fn pow_lit(arena: &mut ExprArena, base: ExprId, exponent: f64) -> ExprId {
        let exp = konst(arena, exponent);
        arena.push(ExprNode::Pow(base, exp))
    }

    /// Shorthand for the expected `VarId`s in assertions.
    fn v(i: u32) -> VarId {
        VarId(i)
    }

    /// Runs the extractor, failing the test if the expression is rejected.
    fn quad_of(arena: &ExprArena, id: ExprId) -> QuadraticTerms {
        extract_quadratic(arena, id).expect("expression should be quadratic")
    }

    #[test]
    fn square_doubles_diagonal() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sq = pow_lit(&mut a, x, 2.0);
        let q = quad_of(&a, sq);
        assert_eq!(q.hessian, vec![(v(0), v(0), 2.0)]);
        assert!(q.linear.is_empty());
        assert!(q.constant.abs() < f64::EPSILON);
    }

    #[test]
    fn bilinear_off_diagonal() {
        let mut a = ExprArena::new();
        let (x, y) = (var(&mut a, 0), var(&mut a, 1));
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let q = quad_of(&a, xy);
        assert_eq!(q.hessian, vec![(v(1), v(0), 1.0)]);
        assert!(q.linear.is_empty());
    }

    #[test]
    fn cvxopt_objective_recovers_hessian() {
        // 2·x0² + x0·x1 + x1² + x0 + x1
        let mut a = ExprArena::new();
        let x0 = var(&mut a, 0);
        let x1 = var(&mut a, 1);
        let two = konst(&mut a, 2.0);
        let x0_sq = pow_lit(&mut a, x0, 2.0);
        let term0 = a.push(ExprNode::Mul(smallvec![two, x0_sq]));
        let x0x1 = a.push(ExprNode::Mul(smallvec![x0, x1]));
        let x1_sq = pow_lit(&mut a, x1, 2.0);
        let sum = a.push(ExprNode::Add(smallvec![term0, x0x1, x1_sq, x0, x1]));
        let q = quad_of(&a, sum);
        assert_eq!(q.hessian, vec![(v(0), v(0), 4.0), (v(1), v(0), 1.0), (v(1), v(1), 2.0)]);
        assert_eq!(q.linear, vec![(v(0), 1.0), (v(1), 1.0)]);
        assert!(q.constant.abs() < f64::EPSILON);
    }

    #[test]
    fn square_of_sum_cross_term() {
        // (x0 + x1)² = x0² + 2·x0·x1 + x1²
        let mut a = ExprArena::new();
        let (x0, x1) = (var(&mut a, 0), var(&mut a, 1));
        let sum = a.push(ExprNode::Add(smallvec![x0, x1]));
        let sq = pow_lit(&mut a, sum, 2.0);
        let q = quad_of(&a, sq);
        assert_eq!(q.hessian, vec![(v(0), v(0), 2.0), (v(1), v(0), 2.0), (v(1), v(1), 2.0)]);
    }

    #[test]
    fn linear_only_has_empty_hessian() {
        // 3·x + 5
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let three = konst(&mut a, 3.0);
        let scaled = a.push(ExprNode::Mul(smallvec![three, x]));
        let five = konst(&mut a, 5.0);
        let expr = a.push(ExprNode::Add(smallvec![scaled, five]));
        let q = quad_of(&a, expr);
        assert!(q.hessian.is_empty());
        assert_eq!(q.linear, vec![(v(0), 3.0)]);
        assert!((q.constant - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn constant_only() {
        let mut a = ExprArena::new();
        let seven = konst(&mut a, 7.0);
        let q = quad_of(&a, seven);
        assert!(q.hessian.is_empty());
        assert!(q.linear.is_empty());
        assert!((q.constant - 7.0).abs() < f64::EPSILON);
    }

    #[test]
    fn negation_flips_signs() {
        // -(x² + x)
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sq = pow_lit(&mut a, x, 2.0);
        let inner = a.push(ExprNode::Add(smallvec![sq, x]));
        let negated = a.push(ExprNode::Neg(inner));
        let q = quad_of(&a, negated);
        assert_eq!(q.hessian, vec![(v(0), v(0), -2.0)]);
        assert_eq!(q.linear, vec![(v(0), -1.0)]);
    }

    #[test]
    fn cubic_is_none() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let cube = pow_lit(&mut a, x, 3.0);
        assert!(extract_quadratic(&a, cube).is_none());
    }

    #[test]
    fn triple_product_is_none() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let z = var(&mut a, 2);
        let prod = a.push(ExprNode::Mul(smallvec![x, y, z]));
        assert!(extract_quadratic(&a, prod).is_none());
    }

    #[test]
    fn transcendental_is_none() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sine = a.push(ExprNode::Sin(x));
        assert!(extract_quadratic(&a, sine).is_none());
    }

    #[test]
    fn const_times_square_scales() {
        // 3·(x·y)
        let mut a = ExprArena::new();
        let (x, y) = (var(&mut a, 0), var(&mut a, 1));
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let three = konst(&mut a, 3.0);
        let scaled = a.push(ExprNode::Mul(smallvec![three, xy]));
        let q = quad_of(&a, scaled);
        assert_eq!(q.hessian, vec![(v(1), v(0), 3.0)]);
    }
}