use rustc_hash::{FxBuildHasher, FxHashMap};
use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
/// Sparse coefficients of a quadratic expression, as assembled by
/// `extract_quadratic`.
#[derive(Clone, Debug, Default)]
pub struct QuadraticTerms {
/// Second-derivative entries as `(row, col, value)` triples.
/// `extract_quadratic` emits only entries with `row.0 >= col.0`, doubles the
/// coefficient of each squared variable, drops exact zeros and orders the
/// triples by `(col, row)`.
pub hessian: Vec<(VarId, VarId, f64)>,
/// First-order `(variable, coefficient)` pairs. `extract_quadratic` drops
/// exact zeros and orders the pairs by variable id.
pub linear: Vec<(VarId, f64)>,
/// Constant offset of the expression.
pub constant: f64,
}
/// Working representation of a polynomial of degree at most two, keyed by
/// variable so that repeated terms accumulate into a single coefficient.
#[derive(Default)]
struct Poly {
/// Coefficients of second-order monomials, keyed by a variable pair.
/// `mul_linear` normalizes each key through `pair`, so `(a, b)` and `(b, a)`
/// share one entry.
quad: FxHashMap<(VarId, VarId), f64>,
/// Coefficients of first-order monomials.
linear: FxHashMap<VarId, f64>,
/// Constant term.
constant: f64,
}
impl Poly {
fn constant(c: f64) -> Self {
Self { constant: c, ..Self::default() }
}
fn var(v: VarId) -> Self {
let mut linear = FxHashMap::with_capacity_and_hasher(1, FxBuildHasher);
linear.insert(v, 1.0);
Self { linear, ..Self::default() }
}
fn is_constant(&self) -> bool {
self.quad.is_empty() && self.linear.is_empty()
}
fn is_linear(&self) -> bool {
self.quad.is_empty()
}
fn scale(mut self, s: f64) -> Self {
self.constant *= s;
for c in self.linear.values_mut() {
*c *= s;
}
for c in self.quad.values_mut() {
*c *= s;
}
self
}
fn neg(self) -> Self {
self.scale(-1.0)
}
fn add_assign(&mut self, other: Poly) {
self.constant += other.constant;
for (v, c) in other.linear {
*self.linear.entry(v).or_insert(0.0) += c;
}
for (k, c) in other.quad {
*self.quad.entry(k).or_insert(0.0) += c;
}
}
}
/// Orders two variables so that the one with the smaller id comes first.
///
/// Used as the canonical key for a product of two variables, so that
/// `a * b` and `b * a` land on the same map entry.
fn pair(a: VarId, b: VarId) -> (VarId, VarId) {
    // Swap only when strictly out of order; equal ids are returned as given.
    if b.0 < a.0 {
        (b, a)
    } else {
        (a, b)
    }
}
/// Expands the product of two polynomials using only their constant and
/// linear parts.
///
/// Any quadratic entries of `a` or `b` are ignored, so callers are expected
/// to pass polynomials for which `is_linear()` holds.
fn mul_linear(a: &Poly, b: &Poly) -> Poly {
    let mut product = Poly::constant(a.constant * b.constant);

    // constant(a) * linear(b), followed by constant(b) * linear(a).
    let first_order = b
        .linear
        .iter()
        .map(|(var, coef)| (*var, a.constant * *coef))
        .chain(a.linear.iter().map(|(var, coef)| (*var, b.constant * *coef)));
    for (var, term) in first_order {
        *product.linear.entry(var).or_default() += term;
    }

    // linear(a) * linear(b): every pairing contributes to one canonical key,
    // so the two orderings of a mixed product are summed together.
    for (lhs_var, lhs_coef) in &a.linear {
        for (rhs_var, rhs_coef) in &b.linear {
            let key = pair(*lhs_var, *rhs_var);
            *product.quad.entry(key).or_default() += *lhs_coef * *rhs_coef;
        }
    }

    product
}
/// Recursively lowers the expression rooted at `id` into a [`Poly`].
///
/// Returns `None` as soon as a sub-expression cannot be represented: a
/// product whose factors are neither constant nor both linear, a power whose
/// exponent is not a literal constant within `f64::EPSILON` of `0`, `1` or
/// `2`, a square of a non-linear base, or any division, trigonometric,
/// exponential, logarithm or absolute-value node. Parameters are replaced by
/// `arena.param_value`.
fn as_poly(arena: &ExprArena, id: ExprId) -> Option<Poly> {
    match arena.get(id) {
        ExprNode::Const(value) => Some(Poly::constant(*value)),
        ExprNode::Param(param) => Some(Poly::constant(arena.param_value(*param))),
        ExprNode::Var(var) => Some(Poly::var(*var)),
        ExprNode::Linear { coeffs, constant } => {
            let mut poly = Poly {
                quad: FxHashMap::default(),
                linear: FxHashMap::with_capacity_and_hasher(coeffs.len(), FxBuildHasher),
                constant: *constant,
            };
            // Accumulate rather than overwrite, in case a variable is listed
            // more than once.
            for (var, coef) in coeffs {
                *poly.linear.entry(*var).or_default() += *coef;
            }
            Some(poly)
        }
        ExprNode::Neg(inner) => Some(as_poly(arena, *inner)?.neg()),
        ExprNode::Add(children) => children.iter().try_fold(Poly::default(), |mut sum, child| {
            sum.add_assign(as_poly(arena, *child)?);
            Some(sum)
        }),
        ExprNode::Mul(children) => {
            // Running product, starting from the multiplicative identity.
            let mut product = Poly::constant(1.0);
            for child in children {
                let factor = as_poly(arena, *child)?;
                product = match (product.is_constant(), factor.is_constant()) {
                    (true, _) => factor.scale(product.constant),
                    (false, true) => product.scale(factor.constant),
                    (false, false) if product.is_linear() && factor.is_linear() => {
                        mul_linear(&product, &factor)
                    }
                    // Anything else is treated as exceeding degree two.
                    (false, false) => return None,
                };
            }
            Some(product)
        }
        ExprNode::Pow(base, exp) => {
            // Only a literal constant exponent is recognized.
            let ExprNode::Const(exponent) = arena.get(*exp) else { return None };
            let exponent = *exponent;
            let nearest = exponent.round();
            let fractional = (exponent - nearest).abs() >= f64::EPSILON;
            if fractional || exponent < 0.0 {
                return None;
            }
            if nearest < 0.5 {
                // Exponent 0: the base is not inspected at all.
                Some(Poly::constant(1.0))
            } else if nearest < 1.5 {
                as_poly(arena, *base)
            } else if nearest < 2.5 {
                let operand = as_poly(arena, *base)?;
                operand.is_linear().then(|| mul_linear(&operand, &operand))
            } else {
                None
            }
        }
        ExprNode::Div(_, _)
        | ExprNode::Sin(_)
        | ExprNode::Cos(_)
        | ExprNode::Exp(_)
        | ExprNode::Log(_)
        | ExprNode::Abs(_) => None,
    }
}
pub fn extract_quadratic(arena: &ExprArena, id: ExprId) -> Option<QuadraticTerms> {
let poly = as_poly(arena, id)?;
let mut hessian: Vec<(VarId, VarId, f64)> = Vec::with_capacity(poly.quad.len());
for ((lo, hi), c) in poly.quad {
if c == 0.0 {
continue;
}
if lo == hi {
hessian.push((lo, lo, 2.0 * c));
} else {
hessian.push((hi, lo, c));
}
}
let mut linear: Vec<(VarId, f64)> =
poly.linear.into_iter().filter(|(_, c)| *c != 0.0).collect();
linear.sort_unstable_by_key(|(v, _)| v.0);
hessian.sort_unstable_by_key(|(r, c, _)| (c.0, r.0));
Some(QuadraticTerms { hessian, linear, constant: poly.constant })
}
// Unit tests for `extract_quadratic`. Each test builds a small expression in
// a fresh arena and checks the extracted Hessian, linear and constant parts.
#[cfg(test)]
mod tests {
use super::*;
use crate::arena::{ExprArena, ExprNode, VarId};
use smallvec::smallvec;
// Pushes a variable node with index `i` and returns its expression id.
fn var(arena: &mut ExprArena, i: u32) -> ExprId {
arena.push(ExprNode::Var(VarId(i)))
}
// Shorthand for building the `VarId` expected in assertions.
fn v(i: u32) -> VarId {
VarId(i)
}
// x^2: the single diagonal entry is twice the coefficient of x^2.
#[test]
fn square_doubles_diagonal() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let two = a.push(ExprNode::Const(2.0));
let sq = a.push(ExprNode::Pow(x, two));
let q = extract_quadratic(&a, sq).unwrap();
assert_eq!(q.hessian, vec![(v(0), v(0), 2.0)]);
assert!(q.linear.is_empty());
assert!(q.constant.abs() < f64::EPSILON);
}
// x*y: one off-diagonal entry, reported with the larger id as the row.
#[test]
fn bilinear_off_diagonal() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let y = var(&mut a, 1);
let xy = a.push(ExprNode::Mul(smallvec![x, y]));
let q = extract_quadratic(&a, xy).unwrap();
assert_eq!(q.hessian, vec![(v(1), v(0), 1.0)]);
assert!(q.linear.is_empty());
}
// 2*x0^2 + x0*x1 + x1^2 + x0 + x1. The `two` node is shared between the
// exponent of x0^2 and the factor multiplying it.
#[test]
fn cvxopt_objective_recovers_hessian() {
let mut a = ExprArena::new();
let x0 = var(&mut a, 0);
let x1 = var(&mut a, 1);
let two = a.push(ExprNode::Const(2.0));
let x0sq = a.push(ExprNode::Pow(x0, two));
let term0 = a.push(ExprNode::Mul(smallvec![two, x0sq]));
let x0x1 = a.push(ExprNode::Mul(smallvec![x0, x1]));
let two_b = a.push(ExprNode::Const(2.0));
let x1sq = a.push(ExprNode::Pow(x1, two_b));
let sum = a.push(ExprNode::Add(smallvec![term0, x0x1, x1sq, x0, x1]));
let q = extract_quadratic(&a, sum).unwrap();
assert_eq!(q.hessian, vec![(v(0), v(0), 4.0), (v(1), v(0), 1.0), (v(1), v(1), 2.0)]);
assert_eq!(q.linear, vec![(v(0), 1.0), (v(1), 1.0)]);
assert!(q.constant.abs() < f64::EPSILON);
}
// (x0 + x1)^2 = x0^2 + 2*x0*x1 + x1^2: the cross term sums both orderings.
#[test]
fn square_of_sum_cross_term() {
let mut a = ExprArena::new();
let x0 = var(&mut a, 0);
let x1 = var(&mut a, 1);
let sum = a.push(ExprNode::Add(smallvec![x0, x1]));
let two = a.push(ExprNode::Const(2.0));
let sq = a.push(ExprNode::Pow(sum, two));
let q = extract_quadratic(&a, sq).unwrap();
assert_eq!(q.hessian, vec![(v(0), v(0), 2.0), (v(1), v(0), 2.0), (v(1), v(1), 2.0)]);
}
// 3*x + 5: no second-order terms, so the Hessian is empty.
#[test]
fn linear_only_has_empty_hessian() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let three = a.push(ExprNode::Const(3.0));
let mul = a.push(ExprNode::Mul(smallvec![three, x]));
let five = a.push(ExprNode::Const(5.0));
let expr = a.push(ExprNode::Add(smallvec![mul, five]));
let q = extract_quadratic(&a, expr).unwrap();
assert!(q.hessian.is_empty());
assert_eq!(q.linear, vec![(v(0), 3.0)]);
assert!((q.constant - 5.0).abs() < f64::EPSILON);
}
// A bare constant yields only the constant term.
#[test]
fn constant_only() {
let mut a = ExprArena::new();
let c = a.push(ExprNode::Const(7.0));
let q = extract_quadratic(&a, c).unwrap();
assert!(q.hessian.is_empty());
assert!(q.linear.is_empty());
assert!((q.constant - 7.0).abs() < f64::EPSILON);
}
// -(x^2 + x): negation flips both the Hessian and the linear coefficient.
#[test]
fn negation_flips_signs() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let two = a.push(ExprNode::Const(2.0));
let sq = a.push(ExprNode::Pow(x, two));
let inner = a.push(ExprNode::Add(smallvec![sq, x]));
let neg = a.push(ExprNode::Neg(inner));
let q = extract_quadratic(&a, neg).unwrap();
assert_eq!(q.hessian, vec![(v(0), v(0), -2.0)]);
assert_eq!(q.linear, vec![(v(0), -1.0)]);
}
// x^3: exponents above two are rejected.
#[test]
fn cubic_is_none() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let three = a.push(ExprNode::Const(3.0));
let cube = a.push(ExprNode::Pow(x, three));
assert!(extract_quadratic(&a, cube).is_none());
}
// x*y*z: a product of three variables is rejected.
#[test]
fn triple_product_is_none() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let y = var(&mut a, 1);
let z = var(&mut a, 2);
let prod = a.push(ExprNode::Mul(smallvec![x, y, z]));
assert!(extract_quadratic(&a, prod).is_none());
}
// sin(x): transcendental nodes are rejected.
#[test]
fn transcendental_is_none() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let s = a.push(ExprNode::Sin(x));
assert!(extract_quadratic(&a, s).is_none());
}
// 3*(x*y): a constant factor scales an existing quadratic term. Despite the
// test name, the scaled factor is the bilinear product x*y, not a square.
#[test]
fn const_times_square_scales() {
let mut a = ExprArena::new();
let x = var(&mut a, 0);
let y = var(&mut a, 1);
let xy = a.push(ExprNode::Mul(smallvec![x, y]));
let three = a.push(ExprNode::Const(3.0));
let scaled = a.push(ExprNode::Mul(smallvec![three, xy]));
let q = extract_quadratic(&a, scaled).unwrap();
assert_eq!(q.hessian, vec![(v(1), v(0), 3.0)]);
}
}