Skip to main content

otspot_model/
quad_expr.rs

//! Quadratic expression type for QP objectives.
//!
//! `QuadExpr` extends [`Expression`] with quadratic terms, enabling ergonomic
//! construction of QP objectives via operator overloading:
//!
//! ```rust,no_run
//! use otspot_model::Model;
//!
//! let mut model = Model::new("qp");
//! let x = model.add_var("x", 1.0, f64::INFINITY);
//! let y = model.add_var("y", 0.0, f64::INFINITY);
//! model.minimize(x * x + 2.0 * x * y);  // min x² + 2xy
//! ```

15use std::collections::HashMap;
16use std::ops::{Add, Mul, Neg, Sub};
17
18use super::expression::Expression;
19use super::variable::Variable;
20use otspot_core::sparse::CscMatrix;
21
/// A quadratic (or linear) expression for use as a QP objective.
///
/// Stores quadratic terms as `(va, vb) → coefficient` where the pair is in
/// canonical order: `(va.model_id, va.index) ≤ (vb.model_id, vb.index)`.
///
/// The operator impls in this module route quadratic writes through a
/// zero-pruning helper (or prune after scaling), so exact-zero coefficients
/// are not kept and [`QuadExpr::is_linear`] can be a plain emptiness check.
///
/// # Q-matrix convention
/// When converted to `CscMatrix` by the crate-internal `quad_to_csc`, the
/// "1/2 xᵀQx" convention is used:
/// - Diagonal `c · xi²` → `Q[i][i] = 2c`  (so `1/2 · Q[i][i] · xi² = c · xi²`)
/// - Cross `c · xi · xj` (i≠j) → `Q[i][j] = Q[j][i] = c`  (symmetric fill, both sides)
#[derive(Debug, Clone, Default)]
pub struct QuadExpr {
    /// Quadratic terms in canonical-pair key order.
    pub(crate) quad: HashMap<(Variable, Variable), f64>,
    /// Linear and constant parts.
    pub(crate) linear: Expression,
}

39impl QuadExpr {
40    /// Returns `true` if this expression contains no quadratic terms.
41    pub fn is_linear(&self) -> bool {
42        self.quad.is_empty()
43    }
44
45    fn merge_add(&mut self, rhs: QuadExpr) {
46        for (pair, c) in rhs.quad {
47            insert_quad_term(&mut self.quad, pair, c);
48        }
49        let old = std::mem::take(&mut self.linear);
50        self.linear = old + rhs.linear;
51    }
52}
53
54/// Accumulate a quadratic term into the map, skipping zero deltas and removing
55/// entries that cancel to exactly zero.  This is the **single insertion
56/// chokepoint** for all quad-term construction — routing every write through
57/// here prevents zero-coefficient entries from leaking into `QuadExpr::quad`
58/// and causing spurious `is_linear() == false`.
59fn insert_quad_term(
60    quad: &mut HashMap<(Variable, Variable), f64>,
61    key: (Variable, Variable),
62    delta: f64,
63) {
64    if delta == 0.0 {
65        return; // zero contribution — skip to avoid polluting the map
66    }
67    let entry = quad.entry(key).or_insert(0.0);
68    *entry += delta;
69    if *entry == 0.0 {
70        quad.remove(&key);
71    }
72}
73
74/// Returns the canonical pair `(a, b)` with `(a.model_id, a.index) ≤ (b.model_id, b.index)`.
75fn canon(a: Variable, b: Variable) -> (Variable, Variable) {
76    if (a.model_id, a.index) <= (b.model_id, b.index) {
77        (a, b)
78    } else {
79        (b, a)
80    }
81}
82
83/// Convert quadratic-term map to a symmetric `CscMatrix` using the 1/2 xᵀQx convention.
84///
85/// - Diagonal entry `(i, i) → c`: emits `Q[i][i] = 2c`.
86/// - Off-diagonal `(i, j) → c` (i ≠ j): emits both `Q[i][j] = c` and `Q[j][i] = c`.
87///
88/// Returns an error if any variable index is out of range for a matrix of size `n×n`.
89pub(crate) fn quad_to_csc(
90    terms: &HashMap<(Variable, Variable), f64>,
91    n: usize,
92) -> Result<CscMatrix, String> {
93    if terms.is_empty() {
94        return Ok(CscMatrix::new(n, n));
95    }
96
97    let mut rows: Vec<usize> = Vec::new();
98    let mut cols: Vec<usize> = Vec::new();
99    let mut vals: Vec<f64> = Vec::new();
100
101    for (&(va, vb), &c) in terms {
102        let (i, j) = (va.index, vb.index);
103        if !c.is_finite() {
104            return Err(format!(
105                "non-finite quad coefficient at ({i}, {j}): {c}"
106            ));
107        }
108        if i >= n || j >= n {
109            return Err(format!(
110                "quad term ({i}, {j}) out of range for {n} variables"
111            ));
112        }
113        if i == j {
114            // Diagonal: 1/2 · Q[i][i] · xi² = c · xi²  ⟹  Q[i][i] = 2c
115            rows.push(i);
116            cols.push(j);
117            vals.push(2.0 * c);
118        } else {
119            // Off-diagonal symmetric: 1/2 · (Q[i][j] + Q[j][i]) · xi·xj = c · xi·xj  ⟹  Q[i][j] = Q[j][i] = c
120            rows.push(i);
121            cols.push(j);
122            vals.push(c);
123            rows.push(j);
124            cols.push(i);
125            vals.push(c);
126        }
127    }
128
129    CscMatrix::from_triplets(&rows, &cols, &vals, n, n)
130        .map_err(|e| e.to_string())
131}
132
133// ---------------------------------------------------------------------------
134// Variable extension: pow2
135// ---------------------------------------------------------------------------
136
137impl Variable {
138    /// Returns `x²` as a [`QuadExpr`].
139    pub fn pow2(self) -> QuadExpr {
140        self * self
141    }
142}
143
144// ---------------------------------------------------------------------------
145// From conversions
146// ---------------------------------------------------------------------------
147
148impl From<Variable> for QuadExpr {
149    fn from(v: Variable) -> Self {
150        QuadExpr { quad: HashMap::new(), linear: Expression::from(v) }
151    }
152}
153
154impl From<Expression> for QuadExpr {
155    fn from(e: Expression) -> Self {
156        QuadExpr { quad: HashMap::new(), linear: e }
157    }
158}
159
160impl From<f64> for QuadExpr {
161    fn from(c: f64) -> Self {
162        QuadExpr { quad: HashMap::new(), linear: Expression::from(c) }
163    }
164}
165
166impl From<i32> for QuadExpr {
167    fn from(c: i32) -> Self {
168        QuadExpr { quad: HashMap::new(), linear: Expression::from(c) }
169    }
170}
171
172// ---------------------------------------------------------------------------
173// Variable * Variable → QuadExpr
174// ---------------------------------------------------------------------------
175
176impl Mul<Variable> for Variable {
177    type Output = QuadExpr;
178    fn mul(self, rhs: Variable) -> QuadExpr {
179        let mut quad = HashMap::new();
180        insert_quad_term(&mut quad, canon(self, rhs), 1.0);
181        QuadExpr { quad, linear: Expression::default() }
182    }
183}
184
185// ---------------------------------------------------------------------------
186// Expression * Variable  /  Variable * Expression → QuadExpr
187// ---------------------------------------------------------------------------
188
189impl Mul<Variable> for Expression {
190    type Output = QuadExpr;
191    fn mul(self, var: Variable) -> QuadExpr {
192        let mut quad = HashMap::new();
193        let mut linear = Expression::default();
194        for (&v, &c) in &self.coefficients {
195            // Route through the single chokepoint so zero coefficients (e.g.
196            // from `(x - x) * y`) never enter the map.
197            insert_quad_term(&mut quad, canon(v, var), c);
198        }
199        if self.constant != 0.0 {
200            *linear.coefficients.entry(var).or_insert(0.0) += self.constant;
201        }
202        QuadExpr { quad, linear }
203    }
204}
205
206impl Mul<Expression> for Variable {
207    type Output = QuadExpr;
208    fn mul(self, rhs: Expression) -> QuadExpr {
209        rhs * self
210    }
211}
212
213// ---------------------------------------------------------------------------
214// f64 * QuadExpr  /  QuadExpr * f64
215// ---------------------------------------------------------------------------
216
217impl Mul<f64> for QuadExpr {
218    type Output = QuadExpr;
219    fn mul(mut self, rhs: f64) -> QuadExpr {
220        for v in self.quad.values_mut() {
221            *v *= rhs;
222        }
223        // Prune entries zeroed by multiplication (e.g. 0.0 * x*x → is_linear = true).
224        self.quad.retain(|_, c| *c != 0.0);
225        self.linear = rhs * self.linear;
226        self
227    }
228}
229
230impl Mul<QuadExpr> for f64 {
231    type Output = QuadExpr;
232    fn mul(self, rhs: QuadExpr) -> QuadExpr {
233        rhs * self
234    }
235}
236
237// ---------------------------------------------------------------------------
238// Neg
239// ---------------------------------------------------------------------------
240
241impl Neg for QuadExpr {
242    type Output = QuadExpr;
243    fn neg(mut self) -> QuadExpr {
244        for v in self.quad.values_mut() {
245            *v = -*v;
246        }
247        self.linear = -self.linear;
248        self
249    }
250}
251
252// ---------------------------------------------------------------------------
253// QuadExpr ± QuadExpr
254// ---------------------------------------------------------------------------
255
256impl Add for QuadExpr {
257    type Output = QuadExpr;
258    fn add(mut self, rhs: QuadExpr) -> QuadExpr {
259        self.merge_add(rhs);
260        self
261    }
262}
263
264impl Sub for QuadExpr {
265    type Output = QuadExpr;
266    fn sub(self, rhs: QuadExpr) -> QuadExpr {
267        self + (-rhs)
268    }
269}
270
271// ---------------------------------------------------------------------------
272// QuadExpr ± Expression
273// ---------------------------------------------------------------------------
274
275impl Add<Expression> for QuadExpr {
276    type Output = QuadExpr;
277    fn add(self, rhs: Expression) -> QuadExpr {
278        self + QuadExpr::from(rhs)
279    }
280}
281
282impl Add<QuadExpr> for Expression {
283    type Output = QuadExpr;
284    fn add(self, rhs: QuadExpr) -> QuadExpr {
285        rhs + self
286    }
287}
288
289impl Sub<Expression> for QuadExpr {
290    type Output = QuadExpr;
291    fn sub(self, rhs: Expression) -> QuadExpr {
292        self + (-rhs)
293    }
294}
295
296impl Sub<QuadExpr> for Expression {
297    type Output = QuadExpr;
298    fn sub(self, rhs: QuadExpr) -> QuadExpr {
299        QuadExpr::from(self) + (-rhs)
300    }
301}
302
303// ---------------------------------------------------------------------------
304// QuadExpr ± Variable
305// ---------------------------------------------------------------------------
306
307impl Add<Variable> for QuadExpr {
308    type Output = QuadExpr;
309    fn add(self, rhs: Variable) -> QuadExpr {
310        self + Expression::from(rhs)
311    }
312}
313
314impl Add<QuadExpr> for Variable {
315    type Output = QuadExpr;
316    fn add(self, rhs: QuadExpr) -> QuadExpr {
317        rhs + self
318    }
319}
320
321impl Sub<Variable> for QuadExpr {
322    type Output = QuadExpr;
323    fn sub(self, rhs: Variable) -> QuadExpr {
324        self + (-Expression::from(rhs))
325    }
326}
327
328impl Sub<QuadExpr> for Variable {
329    type Output = QuadExpr;
330    fn sub(self, rhs: QuadExpr) -> QuadExpr {
331        QuadExpr::from(Expression::from(self)) + (-rhs)
332    }
333}
334
335// ---------------------------------------------------------------------------
336// QuadExpr ± f64
337// ---------------------------------------------------------------------------
338
339impl Add<f64> for QuadExpr {
340    type Output = QuadExpr;
341    fn add(self, rhs: f64) -> QuadExpr {
342        self + Expression::from(rhs)
343    }
344}
345
346impl Add<QuadExpr> for f64 {
347    type Output = QuadExpr;
348    fn add(self, rhs: QuadExpr) -> QuadExpr {
349        rhs + self
350    }
351}
352
353impl Sub<f64> for QuadExpr {
354    type Output = QuadExpr;
355    fn sub(self, rhs: f64) -> QuadExpr {
356        self + (-rhs)
357    }
358}
359
360impl Sub<QuadExpr> for f64 {
361    type Output = QuadExpr;
362    fn sub(self, rhs: QuadExpr) -> QuadExpr {
363        self + (-rhs)
364    }
365}
366
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Model;

    /// Absolute tolerance used by `assert_close` and a few direct checks.
    const TOL: f64 = 1e-5;

    /// Asserts `|a - b| < TOL`, labelling any failure with `label`.
    fn assert_close(a: f64, b: f64, label: &str) {
        assert!((a - b).abs() < TOL, "{label}: expected {b}, got {a}");
    }

    /// Extract Q[row][col] from a CscMatrix (returns 0.0 if absent).
    ///
    /// Scans column `col` linearly, so it does not rely on row indices within
    /// a column being sorted.
    fn q_entry(q: &CscMatrix, row: usize, col: usize) -> f64 {
        let col_start = q.col_ptr()[col];
        let col_end = q.col_ptr()[col + 1];
        for k in col_start..col_end {
            if q.row_ind()[k] == row {
                return q.values()[k];
            }
        }
        0.0
    }

    // --- quad_to_csc unit tests ---

    #[test]
    fn test_quad_to_csc_diagonal() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        // c · x² with c = 3.0  →  Q[0][0] = 6.0
        let mut terms = HashMap::new();
        terms.insert((x, x), 3.0);
        let q = quad_to_csc(&terms, 1).unwrap();
        assert_eq!(q_entry(&q, 0, 0), 6.0, "diagonal: Q[0][0] should be 2*c");
    }

    #[test]
    fn test_quad_to_csc_cross_symmetric() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // c · x·y with c = 5.0  →  Q[0][1] = Q[1][0] = 5.0
        let mut terms = HashMap::new();
        terms.insert(canon(x, y), 5.0);
        let q = quad_to_csc(&terms, 2).unwrap();
        assert_eq!(q_entry(&q, 0, 1), 5.0, "cross: Q[0][1] must equal c");
        assert_eq!(q_entry(&q, 1, 0), 5.0, "cross: Q[1][0] must equal c (symmetry)");
    }

    /// Sentinel: verify that `quad_to_csc` fills both Q[i][j] and Q[j][i].
    ///
    /// A broken upper-triangle-only implementation would produce nnz=1 for a
    /// cross term (only Q[0][1], missing Q[1][0]).  The correct implementation
    /// produces nnz=2.  We also confirm Q[1][0] == 0 in the broken Q, proving
    /// the missing entry would cause a wrong (asymmetric) matrix.
    #[test]
    fn test_symmetry_sentinel_quad_to_csc_fills_both_sides() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);

        // Cross term with c = 5.0 (as built by e.g. `5.0 * (x * y)`):
        // quad[(x,y)] = 5.0  →  must emit Q[0][1] = Q[1][0] = 5.0
        let mut terms = HashMap::new();
        terms.insert(canon(x, y), 5.0);
        let correct = quad_to_csc(&terms, 2).unwrap();

        // Both sides must be present and equal (symmetric fill):
        assert_eq!(q_entry(&correct, 0, 1), 5.0, "sentinel: Q[0][1] must be 5.0");
        assert_eq!(
            q_entry(&correct, 1, 0),
            5.0,
            "sentinel: Q[1][0] must be 5.0 — missing this entry is the classic bug"
        );
        // nnz = 2: one entry per triangle
        assert_eq!(correct.nnz(), 2, "sentinel: cross term must emit exactly 2 triplets");

        // No-op proof: a broken upper-triangle-only matrix has Q[1][0] == 0.
        let broken = CscMatrix::from_triplets(&[0], &[1], &[5.0], 2, 2).unwrap();
        assert_eq!(broken.nnz(), 1, "broken: only 1 triplet (missing lower side)");
        assert_eq!(
            q_entry(&broken, 1, 0),
            0.0,
            "broken: Q[1][0] is 0 — this is the missing-symmetry bug"
        );
        assert_ne!(
            q_entry(&broken, 0, 1),
            q_entry(&broken, 1, 0),
            "broken: Q is not symmetric (upper ≠ lower), confirming the bug exists"
        );
    }

    // --- Operator DSL tests ---

    #[test]
    fn test_var_times_var_is_quadratic() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * x;
        assert!(!q.is_linear());
        let q2 = x * y;
        assert!(!q2.is_linear());
    }

    #[test]
    fn test_pow2_equals_var_times_var() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q1 = x * x;
        let q2 = x.pow2();
        // Both should produce the same quad entry
        assert_eq!(q1.quad.len(), 1);
        assert_eq!(q2.quad.len(), 1);
        let c1: f64 = q1.quad.values().copied().sum();
        let c2: f64 = q2.quad.values().copied().sum();
        assert!((c1 - c2).abs() < 1e-12);
    }

    #[test]
    fn test_scalar_mul_quad_expr() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        // 3.0 * x * x: coefficient should be 3.0 in quad map
        let q = 3.0 * (x * x);
        assert_eq!(q.quad.len(), 1);
        let c: f64 = q.quad.values().copied().sum();
        assert!((c - 3.0).abs() < 1e-12, "scalar mul: coefficient should be 3.0, got {c}");
    }

    #[test]
    fn test_expression_times_var() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // (2.0 * x) * y  →  QuadExpr with quad[(x,y)] = 2.0
        let expr = 2.0 * x;
        let q = expr * y;
        assert!(!q.is_linear());
        // Extract coefficient for the x-y pair
        let c: f64 = q.quad.values().copied().sum();
        assert!((c - 2.0).abs() < 1e-12, "expr*var: coefficient should be 2.0, got {c}");
    }

    #[test]
    fn test_add_quadexprs() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // x*x + y*y should have two diagonal entries
        let q = x * x + y * y;
        assert_eq!(q.quad.len(), 2);
    }

    #[test]
    fn test_neg_quad_expr() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = -(x * x);
        let c: f64 = q.quad.values().copied().sum();
        assert!((c + 1.0).abs() < 1e-12, "neg: coefficient should be -1.0, got {c}");
    }

    #[test]
    fn test_mixed_quad_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        // 2*x*x + 3*x*y + y   (quadratic + linear mixed)
        let q = 2.0 * x * x + 3.0 * x * y + y;
        assert!(!q.is_linear());
        // Should have two quad entries: (x,x) and (x,y)
        assert_eq!(q.quad.len(), 2);
        // Linear part should have y with coefficient 1.0
        let lin_y = q.linear.coefficient(y);
        assert!((lin_y - 1.0).abs() < 1e-12, "linear y coeff should be 1.0, got {lin_y}");
    }

    // --- Model solve tests (through Model API) ---

    #[test]
    fn test_minimize_x_squared_with_lb() {
        // min x²  s.t. x ≥ 1  →  x* = 1, obj* = 1
        let mut model = Model::new("min_x2");
        let x = model.add_var("x", 1.0, f64::INFINITY);
        model.minimize(x * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 1.0, "min x²: x*");
        assert_close(result.objective_value, 1.0, "min x²: obj*");
    }

    #[test]
    fn test_minimize_x_squared_plus_y_squared() {
        // min x² + y²  s.t. x + y = 2, x,y ≥ 0  →  x* = y* = 1, obj* = 2
        let mut model = Model::new("min_x2_y2");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x * x + y * y);
        let result = model.solve().unwrap();
        assert_close(result[x], 1.0, "min x²+y²: x*");
        assert_close(result[y], 1.0, "min x²+y²: y*");
        assert_close(result.objective_value, 2.0, "min x²+y²: obj*");
    }

    #[test]
    fn test_minimize_pow2_api() {
        // Same as above via x.pow2() + y.pow2()
        let mut model = Model::new("pow2");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x.pow2() + y.pow2());
        let result = model.solve().unwrap();
        assert_close(result.objective_value, 2.0, "pow2 API: obj*");
    }

    #[test]
    fn test_maximize_concave_qp() {
        // max -x² + 4x  s.t. x ≥ 0  (concave → unique interior max at x=2, obj=4)
        // Sign-flip check: Q must be negated for maximize.
        let mut model = Model::new("max_concave");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.maximize(-(x * x) + 4.0 * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "max -x²+4x: x*");
        assert_close(result.objective_value, 4.0, "max -x²+4x: obj*");
    }

    #[test]
    fn test_minimize_cross_term_q_symmetry() {
        // min x² + x·y + y²  s.t. x + y = 2, x,y ≥ 0
        // → x* = y* = 1, obj* = 1 + 1 + 1 = 3
        //
        // Symmetry proof: if Q[0][1] were set but Q[1][0] omitted (upper-triangle only),
        // the effective objective would be x² + y² + ½·x·y, giving obj* = 2.5 ≠ 3.
        // This test therefore FAILS under a broken (upper-triangle-only) implementation.
        let mut model = Model::new("cross_sym");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).eq_constraint(2.0));
        model.minimize(x * x + x * y + y * y);
        let result = model.solve().unwrap();
        let tol = 1e-3;
        assert!(
            (result[x] - 1.0).abs() < tol,
            "cross_sym: x* ≈ 1, got {}",
            result[x]
        );
        assert!(
            (result[y] - 1.0).abs() < tol,
            "cross_sym: y* ≈ 1, got {}",
            result[y]
        );
        assert!(
            (result.objective_value - 3.0).abs() < tol,
            "cross_sym: obj* ≈ 3 (symmetric Q fill required), got {}",
            result.objective_value
        );
    }

    #[test]
    fn test_mixed_quad_linear_solve() {
        // min x² - 4x  s.t. x ≥ 0  →  x* = 2, obj* = 4 - 8 = -4
        // Written as: minimize(x*x + (-4.0) * x)
        let mut model = Model::new("quad_linear");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.minimize(x * x + (-4.0) * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "quad+linear: x*");
        assert_close(result.objective_value, -4.0, "quad+linear: obj*");
    }

    #[test]
    fn test_scalar_multiple_quad_solve() {
        // min 2·x² - 8·x  s.t. x ≥ 0  →  x* = 2, obj* = 8 - 16 = -8
        let mut model = Model::new("2x2_8x");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        model.minimize(2.0 * x * x + (-8.0) * x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "2x²-8x: x*");
        assert_close(result.objective_value, -8.0, "2x²-8x: obj*");
    }

    #[test]
    fn test_dsl_qp_solves_correctly() {
        // Verify DSL minimize(x*x + y*y) gives correct answer
        // min x² + y²  s.t. x+y=3, x,y≥0  →  x=y=1.5, obj=4.5
        let mut m = Model::new("dsl");
        let x = m.add_var("x", 0.0, f64::INFINITY);
        let y = m.add_var("y", 0.0, f64::INFINITY);
        m.add_constraint((x + y).eq_constraint(3.0));
        m.minimize(x * x + y * y);
        let r = m.solve().unwrap();

        let tol = 1e-3;
        assert!((r[x] - 1.5).abs() < tol, "DSL x={} expected 1.5", r[x]);
        assert!((r[y] - 1.5).abs() < tol, "DSL y={} expected 1.5", r[y]);
        assert!((r.objective_value - 4.5).abs() < tol, "DSL obj={} expected 4.5", r.objective_value);
    }

    #[test]
    fn test_linear_objective_still_works_after_quad_change() {
        // minimize(x) should work normally via Into<QuadExpr> (pure-linear path)
        let mut model = Model::new("lin");
        let x = model.add_var("x", 2.0, 10.0);
        model.minimize(x);
        let result = model.solve().unwrap();
        assert_close(result[x], 2.0, "linear min x: x*");
    }

    #[test]
    fn test_from_expression_into_quad_expr() {
        // model.minimize(2.0 * x + y) — Expression → QuadExpr via From
        let mut model = Model::new("lin_expr");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, 10.0);
        model.add_constraint((x + y).geq(3.0));
        model.minimize(2.0 * x + y);  // Expression into QuadExpr (no quad terms)
        let result = model.solve().unwrap();
        assert_close(result[x], 0.0, "linear via QuadExpr: x*");
        assert_close(result[y], 3.0, "linear via QuadExpr: y*");
    }

    // --- P3.1: pruning of zero-coefficient quad entries ---

    #[test]
    fn test_cancelled_quad_term_is_linear() {
        // x*y - x*y should cancel to zero quad terms → is_linear() == true
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = x * y - x * y;
        assert!(q.is_linear(), "x*y - x*y should cancel to is_linear() == true");
    }

    #[test]
    fn test_zero_scalar_mul_is_linear() {
        // 0.0 * (x * x) should prune the quad entry → is_linear() == true
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = 0.0 * (x * x);
        assert!(q.is_linear(), "0.0 * x*x should prune to is_linear() == true");
    }

    #[test]
    fn test_cancelled_quad_routes_to_lp() {
        // x*y - x*y + 1.0 cancels to the constant 1.0, so the objective has no
        // quad terms and is expected to take the LP path rather than the QP path.
        // x and y are fixed (lb == ub), and the objective value is the constant 1.0.
        let mut model = Model::new("cancel_route");
        let x = model.add_var("x", 2.0, 2.0);
        let y = model.add_var("y", 3.0, 3.0);
        model.minimize(x * y - x * y + 1.0);  // = 1.0 (constant only, LP path)
        let result = model.solve().unwrap();
        assert!((result.objective_value - 1.0).abs() < TOL,
            "cancelled quad routes to LP: obj should be 1.0, got {}", result.objective_value);
    }

    // --- P3.3: coverage gap — NaN coefficients / indefinite QP ---

    #[test]
    fn test_nan_quad_coefficient_gives_error() {
        // NaN coefficient in quadratic term should produce an error at solve time.
        let mut model = Model::new("nan_q");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q_expr = f64::NAN * (x * x);
        model.minimize(q_expr);
        let result = model.solve();
        assert!(
            result.is_err(),
            "NaN quad coefficient should produce an error, got Ok"
        );
    }

    #[test]
    fn test_indefinite_qp_no_silent_optimal() {
        use crate::SolutionProof;
        // min x·y  s.t. x+y≥1, x,y≥0 — indefinite (non-convex) QP.
        // Must NOT silently claim GlobalOptimal.
        let mut model = Model::new("indef");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        model.add_constraint((x + y).geq(1.0));
        model.minimize(x * y);
        let result = model.solve();
        match result {
            Ok(r) => {
                assert_ne!(
                    r.proof,
                    SolutionProof::GlobalOptimal,
                    "indefinite QP must not claim global optimality"
                );
            }
            Err(_) => {
                // Error (e.g. NonConvex) is also acceptable for indefinite QP
            }
        }
    }

    // ---------------------------------------------------------------------------
    // P2-c: zero-coefficient prune via single chokepoint (insert_quad_term)
    // ---------------------------------------------------------------------------

    // `(x - x) * y` — Expression has coef[x]=0, must not enter quad map.
    #[test]
    fn test_zero_coef_expr_times_var_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = (x - x) * y;
        assert!(q.is_linear(), "(x-x)*y must be is_linear(); quad.len()={}", q.quad.len());
    }

    // `(x + x - 2*x) * y` — three-way cancellation.
    #[test]
    fn test_multi_cancel_expr_times_var_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let y = model.add_var("y", 0.0, f64::INFINITY);
        let q = (x + x + ((-2.0) * x)) * y;
        assert!(q.is_linear(), "(x+x-2x)*y must be is_linear(); quad.len()={}", q.quad.len());
    }

    // x*x - x*x: merge cancellation still works (existing test extended).
    #[test]
    fn test_quad_sub_self_is_linear() {
        let mut model = Model::new("m");
        let x = model.add_var("x", 0.0, f64::INFINITY);
        let q = x * x - x * x;
        assert!(q.is_linear(), "x*x - x*x must cancel to is_linear()");
        assert_eq!(q.quad.len(), 0, "quad map must be empty after cancellation");
    }

    // ---------------------------------------------------------------------------
    // P2-d: cross-model variable validation in apply_objective
    // ---------------------------------------------------------------------------

    // Diagonal term from another model must be rejected.
    #[test]
    fn test_p2d_cross_model_diagonal_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, f64::INFINITY);

        let mut m2 = Model::new("m2");
        // x1 has m1's model_id; minimizing it in m2 must error.
        m2.minimize(x1 * x1);
        let result = m2.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model diagonal must give InvalidInput, got {result:?}"
        );
    }

    // Cross term mixing variables from two models must be rejected.
    #[test]
    fn test_p2d_cross_model_mixed_term_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, f64::INFINITY);

        let mut m2 = Model::new("m2");
        let y2 = m2.add_var("y", 0.0, f64::INFINITY);

        // x1 belongs to m1, y2 belongs to m2; the cross term is invalid for both.
        m1.minimize(x1 * y2);
        let result = m1.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model cross-term must give InvalidInput, got {result:?}"
        );
    }

    // Sanity: same-model variable works correctly (no false positive).
    #[test]
    fn test_p2d_same_model_accepted() {
        let mut model = Model::new("sanity");
        let x = model.add_var("x", 1.0, f64::INFINITY);
        model.minimize(x * x);
        let result = model.solve();
        assert!(result.is_ok(), "P2-d: same-model quad must be accepted, got {result:?}");
    }

    // maximize path: cross-model rejection via maximize (not just minimize).
    #[test]
    fn test_p2d_cross_model_maximize_rejected() {
        use crate::ModelError;
        let mut m1 = Model::new("m1");
        let x1 = m1.add_var("x", 0.0, 5.0);

        let mut m2 = Model::new("m2");
        m2.maximize(x1 * x1);
        let result = m2.solve();
        assert!(
            matches!(result, Err(ModelError::InvalidInput(_))),
            "P2-d: cross-model maximize must give InvalidInput, got {result:?}"
        );
    }
}