// otspot_model/quad_expr.rs
//! Quadratic expression type for QP objectives.
//!
//! `QuadExpr` extends [`Expression`] with quadratic terms, enabling ergonomic
//! construction of QP objectives via operator overloading:
//!
//! ```rust,no_run
//! use otspot_model::Model;
//!
//! let mut model = Model::new("qp");
//! let x = model.add_var("x", 1.0, f64::INFINITY);
//! let y = model.add_var("y", 0.0, f64::INFINITY);
//! model.minimize(x * x + 2.0 * x * y);  // min x² + 2xy
//! ```

15use std::collections::HashMap;
16use std::ops::{Add, Mul, Neg, Sub};
17
18use super::expression::Expression;
19use super::variable::Variable;
20use otspot_core::sparse::CscMatrix;
21
22/// A quadratic (or linear) expression for use as a QP objective.
23///
24/// Stores quadratic terms as `(va, vb) → coefficient` where the pair is in
25/// canonical order: `(va.model_id, va.index) ≤ (vb.model_id, vb.index)`.
26///
27/// # Q-matrix convention
28/// When converted to `CscMatrix` via `quad_to_csc`, the "1/2 xᵀQx" convention is used:
29/// - Diagonal `c · xi²` → `Q[i][i] = 2c`  (so `1/2 · Q[i][i] · xi² = c · xi²`)
30/// - Cross `c · xi · xj` (i≠j) → `Q[i][j] = Q[j][i] = c`  (symmetric fill, both sides)
31#[derive(Debug, Clone, Default)]
32pub struct QuadExpr {
33    /// Quadratic terms in canonical-pair key order.
34    pub(crate) quad: HashMap<(Variable, Variable), f64>,
35    /// Linear and constant parts.
36    pub(crate) linear: Expression,
37}
38
39impl QuadExpr {
40    /// Returns `true` if this expression contains no quadratic terms.
41    pub fn is_linear(&self) -> bool {
42        self.quad.is_empty()
43    }
44
45    fn merge_add(&mut self, rhs: QuadExpr) {
46        for (pair, c) in rhs.quad {
47            insert_quad_term(&mut self.quad, pair, c);
48        }
49        let old = std::mem::take(&mut self.linear);
50        self.linear = old + rhs.linear;
51    }
52}
53
54/// Accumulate a quadratic term into the map, skipping zero deltas and removing
55/// entries that cancel to exactly zero.  This is the **single insertion
56/// chokepoint** for all quad-term construction — routing every write through
57/// here prevents zero-coefficient entries from leaking into `QuadExpr::quad`
58/// and causing spurious `is_linear() == false`.
59fn insert_quad_term(
60    quad: &mut HashMap<(Variable, Variable), f64>,
61    key: (Variable, Variable),
62    delta: f64,
63) {
64    if delta == 0.0 {
65        return; // zero contribution — skip to avoid polluting the map
66    }
67    let entry = quad.entry(key).or_insert(0.0);
68    *entry += delta;
69    if *entry == 0.0 {
70        quad.remove(&key);
71    }
72}
73
74/// Returns the canonical pair `(a, b)` with `(a.model_id, a.index) ≤ (b.model_id, b.index)`.
75fn canon(a: Variable, b: Variable) -> (Variable, Variable) {
76    if (a.model_id, a.index) <= (b.model_id, b.index) {
77        (a, b)
78    } else {
79        (b, a)
80    }
81}
82
83/// Convert quadratic-term map to a symmetric `CscMatrix` using the 1/2 xᵀQx convention.
84///
85/// - Diagonal entry `(i, i) → c`: emits `Q[i][i] = 2c`.
86/// - Off-diagonal `(i, j) → c` (i ≠ j): emits both `Q[i][j] = c` and `Q[j][i] = c`.
87///
88/// Returns an error if any variable index is out of range for a matrix of size `n×n`.
89pub(crate) fn quad_to_csc(
90    terms: &HashMap<(Variable, Variable), f64>,
91    n: usize,
92) -> Result<CscMatrix, String> {
93    if terms.is_empty() {
94        return Ok(CscMatrix::new(n, n));
95    }
96
97    let mut rows: Vec<usize> = Vec::new();
98    let mut cols: Vec<usize> = Vec::new();
99    let mut vals: Vec<f64> = Vec::new();
100
101    for (&(va, vb), &c) in terms {
102        let (i, j) = (va.index, vb.index);
103        if !c.is_finite() {
104            return Err(format!("non-finite quad coefficient at ({i}, {j}): {c}"));
105        }
106        if i >= n || j >= n {
107            return Err(format!(
108                "quad term ({i}, {j}) out of range for {n} variables"
109            ));
110        }
111        if i == j {
112            // Diagonal: 1/2 · Q[i][i] · xi² = c · xi²  ⟹  Q[i][i] = 2c
113            rows.push(i);
114            cols.push(j);
115            vals.push(2.0 * c);
116        } else {
117            // Off-diagonal symmetric: 1/2 · (Q[i][j] + Q[j][i]) · xi·xj = c · xi·xj  ⟹  Q[i][j] = Q[j][i] = c
118            rows.push(i);
119            cols.push(j);
120            vals.push(c);
121            rows.push(j);
122            cols.push(i);
123            vals.push(c);
124        }
125    }
126
127    CscMatrix::from_triplets(&rows, &cols, &vals, n, n).map_err(|e| e.to_string())
128}
129
130// ---------------------------------------------------------------------------
131// Variable extension: pow2
132// ---------------------------------------------------------------------------
133
134impl Variable {
135    /// Returns `x²` as a [`QuadExpr`].
136    pub fn pow2(self) -> QuadExpr {
137        self * self
138    }
139}
140
141// ---------------------------------------------------------------------------
142// From conversions
143// ---------------------------------------------------------------------------
144
145impl From<Variable> for QuadExpr {
146    fn from(v: Variable) -> Self {
147        QuadExpr {
148            quad: HashMap::new(),
149            linear: Expression::from(v),
150        }
151    }
152}
153
154impl From<Expression> for QuadExpr {
155    fn from(e: Expression) -> Self {
156        QuadExpr {
157            quad: HashMap::new(),
158            linear: e,
159        }
160    }
161}
162
163impl From<f64> for QuadExpr {
164    fn from(c: f64) -> Self {
165        QuadExpr {
166            quad: HashMap::new(),
167            linear: Expression::from(c),
168        }
169    }
170}
171
172impl From<i32> for QuadExpr {
173    fn from(c: i32) -> Self {
174        QuadExpr {
175            quad: HashMap::new(),
176            linear: Expression::from(c),
177        }
178    }
179}
180
181// ---------------------------------------------------------------------------
182// Variable * Variable → QuadExpr
183// ---------------------------------------------------------------------------
184
185impl Mul<Variable> for Variable {
186    type Output = QuadExpr;
187    fn mul(self, rhs: Variable) -> QuadExpr {
188        let mut quad = HashMap::new();
189        insert_quad_term(&mut quad, canon(self, rhs), 1.0);
190        QuadExpr {
191            quad,
192            linear: Expression::default(),
193        }
194    }
195}
196
197// ---------------------------------------------------------------------------
198// Expression * Variable  /  Variable * Expression → QuadExpr
199// ---------------------------------------------------------------------------
200
201impl Mul<Variable> for Expression {
202    type Output = QuadExpr;
203    fn mul(self, var: Variable) -> QuadExpr {
204        let mut quad = HashMap::new();
205        let mut linear = Expression::default();
206        for (&v, &c) in &self.coefficients {
207            // Route through the single chokepoint so zero coefficients (e.g.
208            // from `(x - x) * y`) never enter the map.
209            insert_quad_term(&mut quad, canon(v, var), c);
210        }
211        if self.constant != 0.0 {
212            *linear.coefficients.entry(var).or_insert(0.0) += self.constant;
213        }
214        QuadExpr { quad, linear }
215    }
216}
217
218impl Mul<Expression> for Variable {
219    type Output = QuadExpr;
220    fn mul(self, rhs: Expression) -> QuadExpr {
221        rhs * self
222    }
223}
224
225// ---------------------------------------------------------------------------
226// f64 * QuadExpr  /  QuadExpr * f64
227// ---------------------------------------------------------------------------
228
229impl Mul<f64> for QuadExpr {
230    type Output = QuadExpr;
231    fn mul(mut self, rhs: f64) -> QuadExpr {
232        for v in self.quad.values_mut() {
233            *v *= rhs;
234        }
235        // Prune entries zeroed by multiplication (e.g. 0.0 * x*x → is_linear = true).
236        self.quad.retain(|_, c| *c != 0.0);
237        self.linear = rhs * self.linear;
238        self
239    }
240}
241
242impl Mul<QuadExpr> for f64 {
243    type Output = QuadExpr;
244    fn mul(self, rhs: QuadExpr) -> QuadExpr {
245        rhs * self
246    }
247}
248
249// ---------------------------------------------------------------------------
250// Neg
251// ---------------------------------------------------------------------------
252
253impl Neg for QuadExpr {
254    type Output = QuadExpr;
255    fn neg(mut self) -> QuadExpr {
256        for v in self.quad.values_mut() {
257            *v = -*v;
258        }
259        self.linear = -self.linear;
260        self
261    }
262}
263
264// ---------------------------------------------------------------------------
265// QuadExpr ± QuadExpr
266// ---------------------------------------------------------------------------
267
268impl Add for QuadExpr {
269    type Output = QuadExpr;
270    fn add(mut self, rhs: QuadExpr) -> QuadExpr {
271        self.merge_add(rhs);
272        self
273    }
274}
275
276impl Sub for QuadExpr {
277    type Output = QuadExpr;
278    fn sub(self, rhs: QuadExpr) -> QuadExpr {
279        self + (-rhs)
280    }
281}
282
283// ---------------------------------------------------------------------------
284// QuadExpr ± Expression
285// ---------------------------------------------------------------------------
286
287impl Add<Expression> for QuadExpr {
288    type Output = QuadExpr;
289    fn add(self, rhs: Expression) -> QuadExpr {
290        self + QuadExpr::from(rhs)
291    }
292}
293
294impl Add<QuadExpr> for Expression {
295    type Output = QuadExpr;
296    fn add(self, rhs: QuadExpr) -> QuadExpr {
297        rhs + self
298    }
299}
300
301impl Sub<Expression> for QuadExpr {
302    type Output = QuadExpr;
303    fn sub(self, rhs: Expression) -> QuadExpr {
304        self + (-rhs)
305    }
306}
307
308impl Sub<QuadExpr> for Expression {
309    type Output = QuadExpr;
310    fn sub(self, rhs: QuadExpr) -> QuadExpr {
311        QuadExpr::from(self) + (-rhs)
312    }
313}
314
315// ---------------------------------------------------------------------------
316// QuadExpr ± Variable
317// ---------------------------------------------------------------------------
318
319impl Add<Variable> for QuadExpr {
320    type Output = QuadExpr;
321    fn add(self, rhs: Variable) -> QuadExpr {
322        self + Expression::from(rhs)
323    }
324}
325
326impl Add<QuadExpr> for Variable {
327    type Output = QuadExpr;
328    fn add(self, rhs: QuadExpr) -> QuadExpr {
329        rhs + self
330    }
331}
332
333impl Sub<Variable> for QuadExpr {
334    type Output = QuadExpr;
335    fn sub(self, rhs: Variable) -> QuadExpr {
336        self + (-Expression::from(rhs))
337    }
338}
339
340impl Sub<QuadExpr> for Variable {
341    type Output = QuadExpr;
342    fn sub(self, rhs: QuadExpr) -> QuadExpr {
343        QuadExpr::from(Expression::from(self)) + (-rhs)
344    }
345}
346
347// ---------------------------------------------------------------------------
348// QuadExpr ± f64
349// ---------------------------------------------------------------------------
350
351impl Add<f64> for QuadExpr {
352    type Output = QuadExpr;
353    fn add(self, rhs: f64) -> QuadExpr {
354        self + Expression::from(rhs)
355    }
356}
357
358impl Add<QuadExpr> for f64 {
359    type Output = QuadExpr;
360    fn add(self, rhs: QuadExpr) -> QuadExpr {
361        rhs + self
362    }
363}
364
365impl Sub<f64> for QuadExpr {
366    type Output = QuadExpr;
367    fn sub(self, rhs: f64) -> QuadExpr {
368        self + (-rhs)
369    }
370}
371
372impl Sub<QuadExpr> for f64 {
373    type Output = QuadExpr;
374    fn sub(self, rhs: QuadExpr) -> QuadExpr {
375        self + (-rhs)
376    }
377}
378
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Model;

    const TOL: f64 = 1e-5;

    fn assert_close(a: f64, b: f64, label: &str) {
        assert!((a - b).abs() < TOL, "{label}: expected {b}, got {a}");
    }

    /// Extract `Q[row][col]` by scanning the stored entries of column `col`.
    /// Returns 0.0 when the entry is not stored.
    fn q_entry(q: &CscMatrix, row: usize, col: usize) -> f64 {
        let (start, end) = (q.col_ptr()[col], q.col_ptr()[col + 1]);
        (start..end)
            .find(|&k| q.row_ind()[k] == row)
            .map_or(0.0, |k| q.values()[k])
    }

    // --- quad_to_csc unit tests ---

    #[test]
    fn test_quad_to_csc_diagonal() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        // c · x² with c = 3.0  →  Q[0][0] = 6.0
        let mut terms = HashMap::new();
        terms.insert((x, x), 3.0);
        let q = quad_to_csc(&terms, 1).unwrap();
        assert_eq!(q_entry(&q, 0, 0), 6.0, "diagonal: Q[0][0] should be 2*c");
    }

    #[test]
    fn test_quad_to_csc_cross_symmetric() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // c · x·y with c = 5.0  →  Q[0][1] = Q[1][0] = 5.0
        let mut terms = HashMap::new();
        terms.insert(canon(x, y), 5.0);
        let q = quad_to_csc(&terms, 2).unwrap();
        assert_eq!(q_entry(&q, 0, 1), 5.0, "cross: Q[0][1] must equal c");
        assert_eq!(
            q_entry(&q, 1, 0),
            5.0,
            "cross: Q[1][0] must equal c (symmetry)"
        );
    }

    /// Sentinel: verify that `quad_to_csc` fills both Q[i][j] and Q[j][i].
    ///
    /// A broken upper-triangle-only implementation would produce nnz=1 for a
    /// cross term (only Q[0][1], missing Q[1][0]). The correct implementation
    /// produces nnz=2. We also confirm Q[1][0] == 0 in a hand-built broken Q,
    /// showing that the missing entry yields a wrong (asymmetric) matrix.
    #[test]
    fn test_symmetry_sentinel_quad_to_csc_fills_both_sides() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);

        // Coefficient 5.0 on the canonical (x, y) pair → must emit Q[0][1] = Q[1][0] = 5.0
        let mut terms = HashMap::new();
        terms.insert(canon(x, y), 5.0);
        let correct = quad_to_csc(&terms, 2).unwrap();

        // Both sides must be present and equal (symmetric fill):
        assert_eq!(
            q_entry(&correct, 0, 1),
            5.0,
            "sentinel: Q[0][1] must be 5.0"
        );
        assert_eq!(
            q_entry(&correct, 1, 0),
            5.0,
            "sentinel: Q[1][0] must be 5.0 — missing this entry is the classic bug"
        );
        // nnz = 2: one entry per triangle
        assert_eq!(
            correct.nnz(),
            2,
            "sentinel: cross term must emit exactly 2 triplets"
        );

        // Counter-example: a broken upper-triangle-only matrix has Q[1][0] == 0.
        let broken = CscMatrix::from_triplets(&[0], &[1], &[5.0], 2, 2).unwrap();
        assert_eq!(
            broken.nnz(),
            1,
            "broken: only 1 triplet (missing lower side)"
        );
        assert_eq!(
            q_entry(&broken, 1, 0),
            0.0,
            "broken: Q[1][0] is 0 — this is the missing-symmetry bug"
        );
        assert_ne!(
            q_entry(&broken, 0, 1),
            q_entry(&broken, 1, 0),
            "broken: Q is not symmetric (upper ≠ lower), confirming the bug exists"
        );
    }

    // --- Operator DSL tests ---

    #[test]
    fn test_var_times_var_is_quadratic() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * x;
        assert!(!q.is_linear());
        let q2 = x * y;
        assert!(!q2.is_linear());
    }

    #[test]
    fn test_pow2_equals_var_times_var() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q1 = x * x;
        let q2 = x.pow2();
        // Both should produce the same quad entry
        assert_eq!(q1.quad.len(), 1);
        assert_eq!(q2.quad.len(), 1);
        let c1: f64 = q1.quad.values().copied().sum();
        let c2: f64 = q2.quad.values().copied().sum();
        assert!((c1 - c2).abs() < 1e-12);
    }

    #[test]
    fn test_scalar_mul_quad_expr() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        // 3.0 * x * x: coefficient should be 3.0 in quad map
        let q = 3.0 * (x * x);
        assert_eq!(q.quad.len(), 1);
        let c: f64 = q.quad.values().copied().sum();
        assert!(
            (c - 3.0).abs() < 1e-12,
            "scalar mul: coefficient should be 3.0, got {c}"
        );
    }

    #[test]
    fn test_expression_times_var() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // (2.0 * x) * y  →  QuadExpr with quad[(x,y)] = 2.0
        let expr = 2.0 * x;
        let q = expr * y;
        assert!(!q.is_linear());
        // Extract coefficient for the x-y pair
        let c: f64 = q.quad.values().copied().sum();
        assert!(
            (c - 2.0).abs() < 1e-12,
            "expr*var: coefficient should be 2.0, got {c}"
        );
    }

    #[test]
    fn test_add_quadexprs() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // x*x + y*y should have two diagonal entries
        let q = x * x + y * y;
        assert_eq!(q.quad.len(), 2);
    }

    #[test]
    fn test_neg_quad_expr() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = -(x * x);
        let c: f64 = q.quad.values().copied().sum();
        assert!(
            (c + 1.0).abs() < 1e-12,
            "neg: coefficient should be -1.0, got {c}"
        );
    }

    #[test]
    fn test_mixed_quad_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // 2*x*x + 3*x*y + y   (quadratic + linear mixed)
        let q = 2.0 * x * x + 3.0 * x * y + y;
        assert!(!q.is_linear());
        // Should have two quad entries: (x,x) and (x,y)
        assert_eq!(q.quad.len(), 2);
        // Linear part should have y with coefficient 1.0
        let lin_y = q.linear.coefficient(y);
        assert!(
            (lin_y - 1.0).abs() < 1e-12,
            "linear y coeff should be 1.0, got {lin_y}"
        );
    }

    // --- Model solve tests (through Model API) ---

    #[test]
    fn test_minimize_x_squared_with_lb() {
        // min x²  s.t. x ≥ 1  →  x* = 1, obj* = 1
        let mut model = Model::new("min_x2");
        let x = model.add_var("x", 1.0, f64::INFINITY);
        model.minimize(x * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 1.0, "min x²: x*");
        assert_close(result.objective_value, 1.0, "min x²: obj*");
    }

    #[test]
    fn test_minimize_x_squared_plus_y_squared() {
        // min x² + y²  s.t. x + y = 2, x,y ≥ 0  →  x* = y* = 1, obj* = 2
        let mut model = Model::new("min_x2_y2");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x * x + y * y);
        let result = model.solve().unwrap();
        assert_close(result[x], 1.0, "min x²+y²: x*");
        assert_close(result[y], 1.0, "min x²+y²: y*");
        assert_close(result.objective_value, 2.0, "min x²+y²: obj*");
    }

    #[test]
    fn test_minimize_pow2_api() {
        // Same as above via x.pow2() + y.pow2()
        let mut model = Model::new("pow2");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x.pow2() + y.pow2());
        let result = model.solve().unwrap();
        assert_close(result.objective_value, 2.0, "pow2 API: obj*");
    }

    #[test]
    fn test_maximize_concave_qp() {
        // max -x² + 4x  s.t. x ≥ 0  (concave → unique interior max at x=2, obj=4)
        // Sign-flip check: Q must be negated for maximize.
        let mut model = Model::new("max_concave");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.maximize(-(x * x) + 4.0 * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "max -x²+4x: x*");
        assert_close(result.objective_value, 4.0, "max -x²+4x: obj*");
    }

    #[test]
    fn test_minimize_cross_term_q_symmetry() {
        // min x² + x·y + y²  s.t. x + y = 2, x,y ≥ 0
        // → x* = y* = 1, obj* = 1 + 1 + 1 = 3
        //
        // Symmetry proof: if Q[0][1] were set but Q[1][0] omitted (upper-triangle only),
        // the effective objective would be x² + y² + ½·x·y, giving obj* = 2.5 ≠ 3.
        // This test therefore FAILS under a broken (upper-triangle-only) implementation.
        let mut model = Model::new("cross_sym");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x * x + x * y + y * y);
        let result = model.solve().unwrap();
        let tol = 1e-3;
        assert!(
            (result[x] - 1.0).abs() < tol,
            "cross_sym: x* ≈ 1, got {}",
            result[x]
        );
        assert!(
            (result[y] - 1.0).abs() < tol,
            "cross_sym: y* ≈ 1, got {}",
            result[y]
        );
        assert!(
            (result.objective_value - 3.0).abs() < tol,
            "cross_sym: obj* ≈ 3 (symmetric Q fill required), got {}",
            result.objective_value
        );
    }

    #[test]
    fn test_mixed_quad_linear_solve() {
        // min x² - 4x  s.t. x ≥ 0  →  x* = 2, obj* = 4 - 8 = -4
        // Written as: minimize(x*x + (-4.0) * x)
        let mut model = Model::new("quad_linear");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.minimize(x * x + (-4.0) * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "quad+linear: x*");
        assert_close(result.objective_value, -4.0, "quad+linear: obj*");
    }

    #[test]
    fn test_scalar_multiple_quad_solve() {
        // min 2·x² - 8·x  s.t. x ≥ 0  →  x* = 2, obj* = 8 - 16 = -8
        let mut model = Model::new("2x2_8x");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.minimize(2.0 * x * x + (-8.0) * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "2x²-8x: x*");
        assert_close(result.objective_value, -8.0, "2x²-8x: obj*");
    }

    #[test]
    fn test_dsl_qp_solves_correctly() {
        // Verify DSL minimize(x*x + y*y) gives correct answer
        // min x² + y²  s.t. x+y=3, x,y≥0  →  x=y=1.5, obj=4.5
        let mut m = Model::new("dsl");
        let x = m.add_var("x", 0.0, f64::INFINITY);
        let y = m.add_var("y", 0.0, f64::INFINITY);
        m.add_constraint((x + y).eq_constraint(3.0));
        m.minimize(x * x + y * y);
        let r = m.solve().unwrap();

        let tol = 1e-3;
        assert!((r[x] - 1.5).abs() < tol, "DSL x={} expected 1.5", r[x]);
        assert!((r[y] - 1.5).abs() < tol, "DSL y={} expected 1.5", r[y]);
        assert!(
            (r.objective_value - 4.5).abs() < tol,
            "DSL obj={} expected 4.5",
            r.objective_value
        );
    }

    #[test]
    fn test_linear_objective_still_works_after_quad_change() {
        // minimize(x) should work normally via Into<QuadExpr> (pure-linear path)
        let mut model = Model::new("lin");
        let x = model.add_var("x", 2.0, 10.0);
        model.minimize(x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "linear min x: x*");
    }

    #[test]
    fn test_from_expression_into_quad_expr() {
        // model.minimize(2.0 * x + y) — Expression → QuadExpr via From
        let mut model = Model::new("lin_expr");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, 10.0);
        model.add_constraint((x + y).geq(3.0));
        model.minimize(2.0 * x + y); // Expression into QuadExpr (no quad terms)
        let result = model.solve().unwrap();
        assert_close(result[x], 0.0, "linear via QuadExpr: x*");
        assert_close(result[y], 3.0, "linear via QuadExpr: y*");
    }

    // --- P3.1: pruning of zero quad entries ---

    #[test]
    fn test_cancelled_quad_term_is_linear() {
        // x*y - x*y should cancel to zero quad terms → is_linear() == true
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * y - x * y;
        assert!(
            q.is_linear(),
            "x*y - x*y should cancel to is_linear() == true"
        );
    }

    #[test]
    fn test_zero_scalar_mul_is_linear() {
        // 0.0 * (x * x) should prune the quad entry → is_linear() == true
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = 0.0 * (x * x);
        assert!(
            q.is_linear(),
            "0.0 * x*x should prune to is_linear() == true"
        );
    }

    #[test]
    fn test_cancelled_quad_routes_to_lp() {
        // x*y - x*y + 1.0 cancels to the constant 1.0 (is_linear() == true), so it
        // is expected to take the LP path; this test only checks the objective.
        // x and y are fixed (x = 2, y = 3), and the objective is 1.0 at that point.
        let mut model = Model::new("cancel_route");
        let x = model.add_var("x", 2.0, 2.0);
        let y = model.add_var("y", 3.0, 3.0);
        model.minimize(x * y - x * y + 1.0); // = 1.0 (constant only, LP path)
        let result = model.solve().unwrap();
        assert!(
            (result.objective_value - 1.0).abs() < TOL,
            "cancelled quad routes to LP: obj should be 1.0, got {}",
            result.objective_value
        );
    }

    // --- P3.3: coverage gap — NaN coefficients / indefinite QP ---

    #[test]
    fn test_nan_quad_coefficient_gives_error() {
        // NaN coefficient in quadratic term should produce an error at solve time.
        let mut model = Model::new("nan_q");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q_expr = f64::NAN * (x * x);
        model.minimize(q_expr);
        let result = model.solve();
        assert!(
            result.is_err(),
            "NaN quad coefficient should produce an error, got Ok"
        );
    }

    #[test]
    fn test_indefinite_qp_no_silent_optimal() {
        use crate::SolutionProof;
        // min x·y  s.t. x+y≥1, x,y≥0 — indefinite (non-convex) QP.
        // Must NOT silently claim GlobalOptimal.
        let mut model = Model::new("indef");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).geq(1.0));
        model.minimize(x * y);
        let result = model.solve();
        match result {
            Ok(r) => {
                assert_ne!(
                    r.proof,
                    SolutionProof::GlobalOptimal,
                    "indefinite QP must not claim global optimality"
                );
            }
            Err(_) => {
                // Error (e.g. NonConvex) is also acceptable for indefinite QP
            }
        }
    }

    // ---------------------------------------------------------------------------
    // P2-c: zero-coefficient prune via single chokepoint (insert_quad_term)
    // ---------------------------------------------------------------------------

    // `(x - x) * y` — Expression has coef[x]=0, must not enter quad map.
    #[test]
    fn test_zero_coef_expr_times_var_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = (x - x) * y;
        assert!(
            q.is_linear(),
            "(x-x)*y must be is_linear(); quad.len()={}",
            q.quad.len()
        );
    }

    // `(x + x - 2*x) * y` — three-way cancellation.
    #[test]
    fn test_multi_cancel_expr_times_var_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = (x + x + ((-2.0) * x)) * y;
        assert!(
            q.is_linear(),
            "(x+x-2x)*y must be is_linear(); quad.len()={}",
            q.quad.len()
        );
    }

    // x*x - x*x: merge cancellation still works (existing test extended).
    #[test]
    fn test_quad_sub_self_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = x * x - x * x;
        assert!(q.is_linear(), "x*x - x*x must cancel to is_linear()");
        assert_eq!(q.quad.len(), 0, "quad map must be empty after cancellation");
    }

    // ---------------------------------------------------------------------------
    // P2-d: cross-model variable validation in apply_objective
    // ---------------------------------------------------------------------------

    // Diagonal term from another model must be rejected.
    #[test]
    fn test_p2d_cross_model_diagonal_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, f64::INFINITY);

        let mut m2 = Model::new("m2");
        // x1 has m1's model_id; minimizing it in m2 must error.
        m2.minimize(x1 * x1);
        let result = m2.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model diagonal must give InvalidInput, got {result:?}"
        );
    }

    // Cross term mixing variables from two models must be rejected.
    #[test]
    fn test_p2d_cross_model_mixed_term_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, f64::INFINITY);

        let mut m2 = Model::new("m2");
        let y2 = m2.add_var("y", 0.0, f64::INFINITY);

        // x1 belongs to m1, y2 belongs to m2; the cross term is invalid for both.
        m1.minimize(x1 * y2);
        let result = m1.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model cross-term must give InvalidInput, got {result:?}"
        );
    }

    // Sanity: same-model variable works correctly (no false positive).
    #[test]
    fn test_p2d_same_model_accepted() {
        let mut model = Model::new("sanity");
        let x = model.add_var("x", 1.0, f64::INFINITY);
        model.minimize(x * x);
        let result = model.solve();
        assert!(
            result.is_ok(),
            "P2-d: same-model quad must be accepted, got {result:?}"
        );
    }

    // maximize path: cross-model rejection via maximize (not just minimize).
    #[test]
    fn test_p2d_cross_model_maximize_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, 5.0);

        let mut m2 = Model::new("m2");
        m2.maximize(x1 * x1);
        let result = m2.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model maximize must give InvalidInput, got {result:?}"
        );
    }
}