use std::collections::HashMap;
use std::ops::{Add, Mul, Neg, Sub};
use super::expression::Expression;
use super::variable::Variable;
use otspot_core::sparse::CscMatrix;
/// A quadratic expression: a sum of quadratic terms `c · x_i · x_j` plus a linear
/// [`Expression`] part.
#[derive(Debug, Clone, Default)]
pub struct QuadExpr {
/// Quadratic coefficients keyed by variable pair. The operators in this module build
/// keys with `canon`, and prune any entry whose coefficient becomes exactly `0.0`.
pub(crate) quad: HashMap<(Variable, Variable), f64>,
/// The linear part: per-variable coefficients plus a constant.
pub(crate) linear: Expression,
}
impl QuadExpr {
    /// Returns `true` when no quadratic terms remain.
    ///
    /// Zero coefficients are pruned as terms are combined, so an expression such as
    /// `x * y - x * y` reports `true`.
    pub fn is_linear(&self) -> bool {
        self.quad.is_empty()
    }

    /// Adds `rhs` into `self` in place (both the quadratic and the linear part).
    ///
    /// The quadratic maps are merged by walking the *smaller* map and folding it into
    /// the larger one. `a + b` therefore costs O(min(|a|, |b|)) map operations
    /// regardless of operand order, instead of always O(|b|), which turned patterns like
    /// `acc = term + acc` quadratic. IEEE-754 addition is commutative and stored entries
    /// are never zero, so the resulting coefficients are identical either way.
    fn merge_add(&mut self, mut rhs: QuadExpr) {
        if rhs.quad.len() > self.quad.len() {
            // Keep the larger map in `self` so fewer entries need re-inserting.
            std::mem::swap(&mut self.quad, &mut rhs.quad);
        }
        for (pair, c) in rhs.quad {
            insert_quad_term(&mut self.quad, pair, c);
        }
        let old = std::mem::take(&mut self.linear);
        self.linear = old + rhs.linear;
    }
}
/// Accumulates `delta` onto the coefficient stored under `key`.
///
/// A zero `delta` leaves the map untouched. When the running total lands exactly on
/// `0.0`, the entry is dropped so that cancelled terms never linger in the map.
fn insert_quad_term(
    quad: &mut HashMap<(Variable, Variable), f64>,
    key: (Variable, Variable),
    delta: f64,
) {
    use std::collections::hash_map::Entry;

    if delta == 0.0 {
        return;
    }
    match quad.entry(key) {
        Entry::Vacant(slot) => {
            slot.insert(delta);
        }
        Entry::Occupied(mut slot) => {
            let total = *slot.get() + delta;
            if total == 0.0 {
                slot.remove();
            } else {
                *slot.get_mut() = total;
            }
        }
    }
}
/// Puts a variable pair into canonical order by `(model_id, index)`.
///
/// With this ordering, `x * y` and `y * x` map to the same key in `QuadExpr::quad`.
fn canon(a: Variable, b: Variable) -> (Variable, Variable) {
    let order_key = |v: &Variable| (v.model_id, v.index);
    if order_key(&b) < order_key(&a) {
        (b, a)
    } else {
        (a, b)
    }
}
pub(crate) fn quad_to_csc(
terms: &HashMap<(Variable, Variable), f64>,
n: usize,
) -> Result<CscMatrix, String> {
if terms.is_empty() {
return Ok(CscMatrix::new(n, n));
}
let mut rows: Vec<usize> = Vec::new();
let mut cols: Vec<usize> = Vec::new();
let mut vals: Vec<f64> = Vec::new();
for (&(va, vb), &c) in terms {
let (i, j) = (va.index, vb.index);
if !c.is_finite() {
return Err(format!(
"non-finite quad coefficient at ({i}, {j}): {c}"
));
}
if i >= n || j >= n {
return Err(format!(
"quad term ({i}, {j}) out of range for {n} variables"
));
}
if i == j {
rows.push(i);
cols.push(j);
vals.push(2.0 * c);
} else {
rows.push(i);
cols.push(j);
vals.push(c);
rows.push(j);
cols.push(i);
vals.push(c);
}
}
CscMatrix::from_triplets(&rows, &cols, &vals, n, n)
.map_err(|e| e.to_string())
}
impl Variable {
pub fn pow2(self) -> QuadExpr {
self * self
}
}
impl From<Variable> for QuadExpr {
    /// Lifts a single variable into a `QuadExpr` with no quadratic terms.
    fn from(v: Variable) -> Self {
        Self::from(Expression::from(v))
    }
}
impl From<Expression> for QuadExpr {
fn from(e: Expression) -> Self {
QuadExpr { quad: HashMap::new(), linear: e }
}
}
impl From<f64> for QuadExpr {
    /// Builds a constant `QuadExpr` (no variables, no quadratic terms).
    fn from(c: f64) -> Self {
        Self::from(Expression::from(c))
    }
}
impl From<i32> for QuadExpr {
    /// Builds a constant `QuadExpr` from an integer literal.
    fn from(c: i32) -> Self {
        Self::from(Expression::from(c))
    }
}
impl Mul<Variable> for Variable {
type Output = QuadExpr;
fn mul(self, rhs: Variable) -> QuadExpr {
let mut quad = HashMap::new();
insert_quad_term(&mut quad, canon(self, rhs), 1.0);
QuadExpr { quad, linear: Expression::default() }
}
}
impl Mul<Variable> for Expression {
type Output = QuadExpr;
fn mul(self, var: Variable) -> QuadExpr {
let mut quad = HashMap::new();
let mut linear = Expression::default();
for (&v, &c) in &self.coefficients {
insert_quad_term(&mut quad, canon(v, var), c);
}
if self.constant != 0.0 {
*linear.coefficients.entry(var).or_insert(0.0) += self.constant;
}
QuadExpr { quad, linear }
}
}
impl Mul<Expression> for Variable {
    type Output = QuadExpr;
    /// `var * expr`: multiplication is commutative, so this reuses `expr * var`.
    fn mul(self, rhs: Expression) -> QuadExpr {
        <Expression as Mul<Variable>>::mul(rhs, self)
    }
}
impl Mul<f64> for QuadExpr {
    type Output = QuadExpr;
    /// Scales every quadratic coefficient and the linear part by `rhs`.
    ///
    /// Quadratic entries that become exactly `0.0` (e.g. when `rhs == 0.0`) are dropped.
    fn mul(mut self, rhs: f64) -> QuadExpr {
        self.quad.retain(|_, coef| {
            *coef *= rhs;
            *coef != 0.0
        });
        self.linear = rhs * self.linear;
        self
    }
}
impl Mul<QuadExpr> for f64 {
    type Output = QuadExpr;
    /// `scalar * quad`: delegates to `quad * scalar`.
    fn mul(self, rhs: QuadExpr) -> QuadExpr {
        <QuadExpr as Mul<f64>>::mul(rhs, self)
    }
}
impl Neg for QuadExpr {
type Output = QuadExpr;
fn neg(mut self) -> QuadExpr {
for v in self.quad.values_mut() {
*v = -*v;
}
self.linear = -self.linear;
self
}
}
impl Add for QuadExpr {
    type Output = QuadExpr;
    /// Sums two quadratic expressions; terms that cancel exactly are pruned.
    fn add(self, rhs: QuadExpr) -> QuadExpr {
        let mut sum = self;
        sum.merge_add(rhs);
        sum
    }
}
impl Sub for QuadExpr {
type Output = QuadExpr;
fn sub(self, rhs: QuadExpr) -> QuadExpr {
self + (-rhs)
}
}
impl Add<Expression> for QuadExpr {
type Output = QuadExpr;
fn add(self, rhs: Expression) -> QuadExpr {
self + QuadExpr::from(rhs)
}
}
impl Add<QuadExpr> for Expression {
    type Output = QuadExpr;
    /// `expr + quad`: delegates to `quad + expr`.
    fn add(self, rhs: QuadExpr) -> QuadExpr {
        <QuadExpr as Add<Expression>>::add(rhs, self)
    }
}
impl Sub<Expression> for QuadExpr {
type Output = QuadExpr;
fn sub(self, rhs: Expression) -> QuadExpr {
self + (-rhs)
}
}
impl Sub<QuadExpr> for Expression {
type Output = QuadExpr;
fn sub(self, rhs: QuadExpr) -> QuadExpr {
QuadExpr::from(self) + (-rhs)
}
}
impl Add<Variable> for QuadExpr {
type Output = QuadExpr;
fn add(self, rhs: Variable) -> QuadExpr {
self + Expression::from(rhs)
}
}
impl Add<QuadExpr> for Variable {
    type Output = QuadExpr;
    /// `var + quad`: delegates to `quad + var`.
    fn add(self, rhs: QuadExpr) -> QuadExpr {
        <QuadExpr as Add<Variable>>::add(rhs, self)
    }
}
impl Sub<Variable> for QuadExpr {
type Output = QuadExpr;
fn sub(self, rhs: Variable) -> QuadExpr {
self + (-Expression::from(rhs))
}
}
impl Sub<QuadExpr> for Variable {
type Output = QuadExpr;
fn sub(self, rhs: QuadExpr) -> QuadExpr {
QuadExpr::from(Expression::from(self)) + (-rhs)
}
}
impl Add<f64> for QuadExpr {
type Output = QuadExpr;
fn add(self, rhs: f64) -> QuadExpr {
self + Expression::from(rhs)
}
}
impl Add<QuadExpr> for f64 {
    type Output = QuadExpr;
    /// `scalar + quad`: delegates to `quad + scalar`.
    fn add(self, rhs: QuadExpr) -> QuadExpr {
        <QuadExpr as Add<f64>>::add(rhs, self)
    }
}
impl Sub<f64> for QuadExpr {
type Output = QuadExpr;
fn sub(self, rhs: f64) -> QuadExpr {
self + (-rhs)
}
}
impl Sub<QuadExpr> for f64 {
type Output = QuadExpr;
fn sub(self, rhs: QuadExpr) -> QuadExpr {
self + (-rhs)
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the quadratic-expression algebra and `quad_to_csc` assembly,
    // plus end-to-end QP solves through `Model`.
    use super::*;
    use crate::Model;

    const TOL: f64 = 1e-5;

    fn assert_close(a: f64, b: f64, label: &str) {
        assert!((a - b).abs() < TOL, "{label}: expected {b}, got {a}");
    }

    /// Reads `Q[row][col]` from a CSC matrix, returning 0.0 for structural zeros.
    fn q_entry(q: &CscMatrix, row: usize, col: usize) -> f64 {
        let col_start = q.col_ptr()[col];
        let col_end = q.col_ptr()[col + 1];
        for k in col_start..col_end {
            if q.row_ind()[k] == row {
                return q.values()[k];
            }
        }
        0.0
    }

    #[test]
    fn test_quad_to_csc_diagonal() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let mut terms = HashMap::new();
        terms.insert((x, x), 3.0);
        let q = quad_to_csc(&terms, 1).unwrap();
        assert_eq!(q_entry(&q, 0, 0), 6.0, "diagonal: Q[0][0] should be 2*c");
    }

    #[test]
    fn test_quad_to_csc_cross_symmetric() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let mut terms = HashMap::new();
        terms.insert(canon(x, y), 5.0);
        let q = quad_to_csc(&terms, 2).unwrap();
        assert_eq!(q_entry(&q, 0, 1), 5.0, "cross: Q[0][1] must equal c");
        assert_eq!(q_entry(&q, 1, 0), 5.0, "cross: Q[1][0] must equal c (symmetry)");
    }

    #[test]
    fn test_symmetry_sentinel_quad_to_csc_fills_both_sides() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let mut terms = HashMap::new();
        terms.insert(canon(x, y), 5.0);
        let correct = quad_to_csc(&terms, 2).unwrap();
        assert_eq!(q_entry(&correct, 0, 1), 5.0, "sentinel: Q[0][1] must be 5.0");
        assert_eq!(
            q_entry(&correct, 1, 0),
            5.0,
            "sentinel: Q[1][0] must be 5.0 — missing this entry is the classic bug"
        );
        assert_eq!(correct.nnz(), 2, "sentinel: cross term must emit exactly 2 triplets");
        let broken = CscMatrix::from_triplets(&[0], &[1], &[5.0], 2, 2).unwrap();
        assert_eq!(broken.nnz(), 1, "broken: only 1 triplet (missing lower side)");
        assert_eq!(
            q_entry(&broken, 1, 0),
            0.0,
            "broken: Q[1][0] is 0 — this is the missing-symmetry bug"
        );
        assert_ne!(
            q_entry(&broken, 0, 1),
            q_entry(&broken, 1, 0),
            "broken: Q is not symmetric (upper ≠ lower), confirming the bug exists"
        );
    }

    #[test]
    fn test_var_times_var_is_quadratic() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * x;
        assert!(!q.is_linear());
        let q2 = x * y;
        assert!(!q2.is_linear());
    }

    #[test]
    fn test_pow2_equals_var_times_var() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q1 = x * x;
        let q2 = x.pow2();
        assert_eq!(q1.quad.len(), 1);
        assert_eq!(q2.quad.len(), 1);
        // Compare the whole term maps (keys and coefficients), not just coefficient sums.
        assert_eq!(q1.quad, q2.quad, "pow2 must produce the same term as x * x");
    }

    #[test]
    fn test_scalar_mul_quad_expr() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = 3.0 * (x * x);
        assert_eq!(q.quad.len(), 1);
        let c: f64 = q.quad.values().copied().sum();
        assert!((c - 3.0).abs() < 1e-12, "scalar mul: coefficient should be 3.0, got {c}");
    }

    #[test]
    fn test_expression_times_var() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let expr = 2.0 * x;
        let q = expr * y;
        assert!(!q.is_linear());
        let c: f64 = q.quad.values().copied().sum();
        assert!((c - 2.0).abs() < 1e-12, "expr*var: coefficient should be 2.0, got {c}");
    }

    #[test]
    fn test_add_quadexprs() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * x + y * y;
        assert_eq!(q.quad.len(), 2);
    }

    #[test]
    fn test_neg_quad_expr() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = -(x * x);
        let c: f64 = q.quad.values().copied().sum();
        assert!((c + 1.0).abs() < 1e-12, "neg: coefficient should be -1.0, got {c}");
    }

    #[test]
    fn test_mixed_quad_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = 2.0 * x * x + 3.0 * x * y + y;
        assert!(!q.is_linear());
        assert_eq!(q.quad.len(), 2);
        let lin_y = q.linear.coefficient(y);
        assert!((lin_y - 1.0).abs() < 1e-12, "linear y coeff should be 1.0, got {lin_y}");
    }

    #[test]
    fn test_mixed_subtraction_with_variables() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);

        // Variable - QuadExpr
        let q = y - x * x;
        assert_eq!(q.quad.len(), 1);
        let c: f64 = q.quad.values().copied().sum();
        assert!((c + 1.0).abs() < 1e-12, "var - quad: quad coeff should be -1.0, got {c}");
        let lin_y = q.linear.coefficient(y);
        assert!((lin_y - 1.0).abs() < 1e-12, "var - quad: y coeff should be 1.0, got {lin_y}");

        // QuadExpr - Variable
        let q2 = x * y - x;
        assert_eq!(q2.quad.len(), 1);
        let lin_x = q2.linear.coefficient(x);
        assert!((lin_x + 1.0).abs() < 1e-12, "quad - var: x coeff should be -1.0, got {lin_x}");
    }

    #[test]
    fn test_minimize_x_squared_with_lb() {
        let mut model = Model::new("min_x2");
        let x = model.add_var("x", 1.0, f64::INFINITY);
        model.minimize(x * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 1.0, "min x²: x*");
        assert_close(result.objective_value, 1.0, "min x²: obj*");
    }

    #[test]
    fn test_minimize_x_squared_plus_y_squared() {
        let mut model = Model::new("min_x2_y2");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x * x + y * y);
        let result = model.solve().unwrap();
        assert_close(result[x], 1.0, "min x²+y²: x*");
        assert_close(result[y], 1.0, "min x²+y²: y*");
        assert_close(result.objective_value, 2.0, "min x²+y²: obj*");
    }

    #[test]
    fn test_minimize_pow2_api() {
        let mut model = Model::new("pow2");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x.pow2() + y.pow2());
        let result = model.solve().unwrap();
        assert_close(result.objective_value, 2.0, "pow2 API: obj*");
    }

    #[test]
    fn test_maximize_concave_qp() {
        let mut model = Model::new("max_concave");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.maximize(-(x * x) + 4.0 * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "max -x²+4x: x*");
        assert_close(result.objective_value, 4.0, "max -x²+4x: obj*");
    }

    #[test]
    fn test_minimize_cross_term_q_symmetry() {
        let mut model = Model::new("cross_sym");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x * x + x * y + y * y);
        let result = model.solve().unwrap();
        let tol = 1e-3;
        assert!(
            (result[x] - 1.0).abs() < tol,
            "cross_sym: x* ≈ 1, got {}",
            result[x]
        );
        assert!(
            (result[y] - 1.0).abs() < tol,
            "cross_sym: y* ≈ 1, got {}",
            result[y]
        );
        assert!(
            (result.objective_value - 3.0).abs() < tol,
            "cross_sym: obj* ≈ 3 (symmetric Q fill required), got {}",
            result.objective_value
        );
    }

    #[test]
    fn test_mixed_quad_linear_solve() {
        let mut model = Model::new("quad_linear");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.minimize(x * x + (-4.0) * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "quad+linear: x*");
        assert_close(result.objective_value, -4.0, "quad+linear: obj*");
    }

    #[test]
    fn test_scalar_multiple_quad_solve() {
        let mut model = Model::new("2x2_8x");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.minimize(2.0 * x * x + (-8.0) * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "2x²-8x: x*");
        assert_close(result.objective_value, -8.0, "2x²-8x: obj*");
    }

    #[test]
    fn test_dsl_qp_solves_correctly() {
        let mut m = Model::new("dsl");
        let x = m.add_var("x", 0.0, f64::INFINITY);
        let y = m.add_var("y", 0.0, f64::INFINITY);
        m.add_constraint((x + y).eq_constraint(3.0));
        m.minimize(x * x + y * y);
        let r = m.solve().unwrap();
        let tol = 1e-3;
        assert!((r[x] - 1.5).abs() < tol, "DSL x={} expected 1.5", r[x]);
        assert!((r[y] - 1.5).abs() < tol, "DSL y={} expected 1.5", r[y]);
        assert!(
            (r.objective_value - 4.5).abs() < tol,
            "DSL obj={} expected 4.5",
            r.objective_value
        );
    }

    #[test]
    fn test_linear_objective_still_works_after_quad_change() {
        let mut model = Model::new("lin");
        let x = model.add_var("x", 2.0, 10.0);
        model.minimize(x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "linear min x: x*");
    }

    #[test]
    fn test_from_expression_into_quad_expr() {
        let mut model = Model::new("lin_expr");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, 10.0);
        model.add_constraint((x + y).geq(3.0));
        model.minimize(2.0 * x + y);
        let result = model.solve().unwrap();
        assert_close(result[x], 0.0, "linear via QuadExpr: x*");
        assert_close(result[y], 3.0, "linear via QuadExpr: y*");
    }

    #[test]
    fn test_cancelled_quad_term_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * y - x * y;
        assert!(q.is_linear(), "x*y - x*y should cancel to is_linear() == true");
    }

    #[test]
    fn test_zero_scalar_mul_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = 0.0 * (x * x);
        assert!(q.is_linear(), "0.0 * x*x should prune to is_linear() == true");
    }

    #[test]
    fn test_cancelled_quad_routes_to_lp() {
        let mut model = Model::new("cancel_route");
        let x = model.add_var("x", 2.0, 2.0);
        let y = model.add_var("y", 3.0, 3.0);
        model.minimize(x * y - x * y + 1.0);
        let result = model.solve().unwrap();
        assert!(
            (result.objective_value - 1.0).abs() < TOL,
            "cancelled quad routes to LP: obj should be 1.0, got {}",
            result.objective_value
        );
    }

    #[test]
    fn test_nan_quad_coefficient_gives_error() {
        let mut model = Model::new("nan_q");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q_expr = f64::NAN * (x * x);
        model.minimize(q_expr);
        let result = model.solve();
        assert!(
            result.is_err(),
            "NaN quad coefficient should produce an error, got Ok"
        );
    }

    #[test]
    fn test_indefinite_qp_no_silent_optimal() {
        use crate::SolutionProof;
        let mut model = Model::new("indef");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).geq(1.0));
        model.minimize(x * y);
        let result = model.solve();
        // Rejecting the problem (`Err`) is acceptable; the only failure mode under test
        // is an `Ok` result that falsely claims global optimality.
        if let Ok(r) = result {
            assert_ne!(
                r.proof,
                SolutionProof::GlobalOptimal,
                "indefinite QP must not claim global optimality"
            );
        }
    }

    #[test]
    fn test_zero_coef_expr_times_var_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = (x - x) * y;
        assert!(q.is_linear(), "(x-x)*y must be is_linear(); quad.len()={}", q.quad.len());
    }

    #[test]
    fn test_multi_cancel_expr_times_var_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = (x + x + ((-2.0) * x)) * y;
        assert!(q.is_linear(), "(x+x-2x)*y must be is_linear(); quad.len()={}", q.quad.len());
    }

    #[test]
    fn test_quad_sub_self_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = x * x - x * x;
        assert!(q.is_linear(), "x*x - x*x must cancel to is_linear()");
        assert_eq!(q.quad.len(), 0, "quad map must be empty after cancellation");
    }

    #[test]
    fn test_p2d_cross_model_diagonal_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, f64::INFINITY);
        let mut m2 = Model::new("m2");
        m2.minimize(x1 * x1);
        let result = m2.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model diagonal must give InvalidInput, got {result:?}"
        );
    }

    #[test]
    fn test_p2d_cross_model_mixed_term_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, f64::INFINITY);
        let mut m2 = Model::new("m2");
        let y2 = m2.add_var("y", 0.0, f64::INFINITY);
        m1.minimize(x1 * y2);
        let result = m1.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model cross-term must give InvalidInput, got {result:?}"
        );
    }

    #[test]
    fn test_p2d_same_model_accepted() {
        let mut model = Model::new("sanity");
        let x = model.add_var("x", 1.0, f64::INFINITY);
        model.minimize(x * x);
        let result = model.solve();
        assert!(result.is_ok(), "P2-d: same-model quad must be accepted, got {result:?}");
    }

    #[test]
    fn test_p2d_cross_model_maximize_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, 5.0);
        let mut m2 = Model::new("m2");
        m2.maximize(x1 * x1);
        let result = m2.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model maximize must give InvalidInput, got {result:?}"
        );
    }
}