Skip to main content

otspot_core/qp/
diagnose.rs

//! Pre-solve diagnostics API for QP problems.
2
3use super::problem::QpProblem;
4
/// A diagonal entry of `Q` is reported as negative only when it is below `-DIAG_TOL`.
const DIAG_TOL: f64 = 1.0e-10;
/// Slack allowed when testing a variable's bounds for `lb > ub`.
const BOUND_TOL: f64 = 1.0e-10;
/// Max/min coefficient-magnitude ratio above which a `PoorScaling` warning is emitted.
/// Chosen as an empirical upper limit on the condition number of the IPM KKT matrix.
const SCALE_WARN_THRESHOLD: f64 = 1.0e8;
/// An empty row of `A` is an error only when its right-hand side is below `-ZERO_B_TOL`.
const ZERO_B_TOL: f64 = 1.0e-12;
10
/// How serious a diagnostic finding is.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Severity {
    /// The data is invalid as given (e.g. `Q` cannot be PSD, or bounds/constraints are
    /// provably infeasible). `diagnose` sets `DiagnosticReport::has_error` when any finding
    /// has this level.
    Error,
    /// Suspicious input that may hurt the solver but is not necessarily fatal.
    Warning,
    /// Purely informational (e.g. the problem-size summary).
    Info,
}

/// Identifies which check produced a [`DiagnosticWarning`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DiagnosticCode {
    /// A diagonal entry of `Q` is negative, so `Q` is not positive semidefinite.
    QNegativeDiagonal,
    /// `Q` stores entries below the diagonal; the input may not be upper-triangular.
    QNotSymmetric,
    /// A variable has `lb > ub`.
    VariableBoundsConflict,
    /// The max/min coefficient-magnitude ratio of `Q`, `A` or `c` is large.
    PoorScaling,
    /// A row of `A` contains no nonzero coefficients.
    ZeroRowInA,
    /// Informational summary of the problem dimensions.
    ProblemSize,
}

/// A single finding produced by `diagnose`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiagnosticWarning {
    /// Which check produced this finding.
    pub code: DiagnosticCode,
    /// How serious the finding is.
    pub severity: Severity,
    /// Human-readable description of the finding.
    pub message: String,
    /// Index of the offending variable, when the finding concerns a single variable.
    pub variable_index: Option<usize>,
    /// Index of the offending constraint (row of `A`), when the finding concerns one row.
    pub constraint_index: Option<usize>,
}

/// Size summary of a QP problem.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ProblemInfo {
    /// Number of variables.
    pub n: usize,
    /// Number of constraints (rows of `A`).
    pub m: usize,
    /// Number of stored entries in `Q`, as reported by its `nnz()`.
    pub nnz_q: usize,
    /// Number of stored entries in `A`, as reported by its `nnz()`.
    pub nnz_a: usize,
}

/// Result of `diagnose`: every finding plus a size summary.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiagnosticReport {
    /// All findings, including `Severity::Info` entries.
    pub warnings: Vec<DiagnosticWarning>,
    /// Problem dimensions.
    pub info: ProblemInfo,
    /// As produced by `diagnose`: `true` iff some entry of `warnings` has `Severity::Error`.
    pub has_error: bool,
}
53
/// Ratio of the largest to the smallest magnitude among the non-negligible entries of
/// `values`.
///
/// Entries with `|v| <= 1e-15` are treated as zero and ignored; NaN entries are ignored as
/// well. Returns `None` when no entry survives that filter, or when every surviving entry
/// is infinite.
fn coefficient_ratio(values: &[f64]) -> Option<f64> {
    let (smallest, largest) = values
        .iter()
        .map(|v| v.abs())
        .filter(|&mag| mag > 1e-15)
        .fold((f64::INFINITY, 0.0_f64), |(lo, hi), mag| (lo.min(mag), hi.max(mag)));

    // `smallest` stays at its +inf seed when nothing finite was seen.
    if smallest.is_finite() {
        Some(largest / smallest)
    } else {
        None
    }
}
74
75/// solve() 前の軽量チェック。コストは O(nnz_Q + nnz_A + n + m)。
76pub fn diagnose(problem: &QpProblem) -> DiagnosticReport {
77    let mut warnings: Vec<DiagnosticWarning> = Vec::new();
78
79    for col in 0..problem.num_vars {
80        let start = problem.q.col_ptr[col];
81        let end = problem.q.col_ptr[col + 1];
82        for k in start..end {
83            if problem.q.row_ind[k] == col && problem.q.values[k] < -DIAG_TOL {
84                warnings.push(DiagnosticWarning {
85                    code: DiagnosticCode::QNegativeDiagonal,
86                    severity: Severity::Error,
87                    message: format!(
88                        "Q[{},{}] = {:.6e} < 0: Q is not PSD",
89                        col, col, problem.q.values[k]
90                    ),
91                    variable_index: Some(col),
92                    constraint_index: None,
93                });
94            }
95        }
96    }
97
98    let mut found_lower = false;
99    'outer: for col in 0..problem.num_vars {
100        let start = problem.q.col_ptr[col];
101        let end = problem.q.col_ptr[col + 1];
102        for k in start..end {
103            if problem.q.row_ind[k] > col {
104                found_lower = true;
105                break 'outer;
106            }
107        }
108    }
109    if found_lower {
110        warnings.push(DiagnosticWarning {
111            code: DiagnosticCode::QNotSymmetric,
112            severity: Severity::Warning,
113            message: "Q has sub-diagonal entries: input may not be upper-triangular or symmetric"
114                .to_string(),
115            variable_index: None,
116            constraint_index: None,
117        });
118    }
119
120    for (j, &(lb, ub)) in problem.bounds.iter().enumerate() {
121        if lb > ub + BOUND_TOL {
122            warnings.push(DiagnosticWarning {
123                code: DiagnosticCode::VariableBoundsConflict,
124                severity: Severity::Error,
125                message: format!(
126                    "variable {}: lb ({:.6e}) > ub ({:.6e}): infeasible bounds",
127                    j, lb, ub
128                ),
129                variable_index: Some(j),
130                constraint_index: None,
131            });
132        }
133    }
134
135    if let Some(ratio) = coefficient_ratio(&problem.q.values) {
136        if ratio > SCALE_WARN_THRESHOLD {
137            warnings.push(DiagnosticWarning {
138                code: DiagnosticCode::PoorScaling,
139                severity: Severity::Warning,
140                message: format!(
141                    "Q coefficient ratio = {:.2e} > {:.2e}: poor scaling may cause numerical issues",
142                    ratio, SCALE_WARN_THRESHOLD
143                ),
144                variable_index: None,
145                constraint_index: None,
146            });
147        }
148    }
149    if let Some(ratio) = coefficient_ratio(&problem.a.values) {
150        if ratio > SCALE_WARN_THRESHOLD {
151            warnings.push(DiagnosticWarning {
152                code: DiagnosticCode::PoorScaling,
153                severity: Severity::Warning,
154                message: format!(
155                    "A coefficient ratio = {:.2e} > {:.2e}: poor scaling may cause numerical issues",
156                    ratio, SCALE_WARN_THRESHOLD
157                ),
158                variable_index: None,
159                constraint_index: None,
160            });
161        }
162    }
163    if let Some(ratio) = coefficient_ratio(&problem.c) {
164        if ratio > SCALE_WARN_THRESHOLD {
165            warnings.push(DiagnosticWarning {
166                code: DiagnosticCode::PoorScaling,
167                severity: Severity::Warning,
168                message: format!(
169                    "c coefficient ratio = {:.2e} > {:.2e}: poor scaling may cause numerical issues",
170                    ratio, SCALE_WARN_THRESHOLD
171                ),
172                variable_index: None,
173                constraint_index: None,
174            });
175        }
176    }
177
178    if problem.num_constraints > 0 {
179        let mut row_has_nonzero = vec![false; problem.num_constraints];
180        for &row in &problem.a.row_ind {
181            row_has_nonzero[row] = true;
182        }
183        for (i, &present) in row_has_nonzero.iter().enumerate() {
184            if !present {
185                let severity = if problem.b[i] < -ZERO_B_TOL {
186                    Severity::Error
187                } else {
188                    Severity::Warning
189                };
190                let msg = if severity == Severity::Error {
191                    format!(
192                        "constraint {}: zero row in A with b[{}] = {:.6e} < 0: infeasible (0 <= {})",
193                        i, i, problem.b[i], problem.b[i]
194                    )
195                } else {
196                    format!(
197                        "constraint {}: zero row in A with b[{}] = {:.6e} >= 0: redundant constraint",
198                        i, i, problem.b[i]
199                    )
200                };
201                warnings.push(DiagnosticWarning {
202                    code: DiagnosticCode::ZeroRowInA,
203                    severity,
204                    message: msg,
205                    variable_index: None,
206                    constraint_index: Some(i),
207                });
208            }
209        }
210    }
211
212    let info = ProblemInfo {
213        n: problem.num_vars,
214        m: problem.num_constraints,
215        nnz_q: problem.q.nnz(),
216        nnz_a: problem.a.nnz(),
217    };
218    warnings.push(DiagnosticWarning {
219        code: DiagnosticCode::ProblemSize,
220        severity: Severity::Info,
221        message: format!(
222            "problem size: n={}, m={}, nnz_Q={}, nnz_A={}",
223            info.n, info.m, info.nnz_q, info.nnz_a
224        ),
225        variable_index: None,
226        constraint_index: None,
227    });
228
229    let has_error = warnings.iter().any(|w| w.severity == Severity::Error);
230
231    DiagnosticReport { warnings, info, has_error }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use crate::qp::problem::QpProblem;
    use crate::sparse::CscMatrix;

    /// Both variables unbounded.
    fn free_bounds() -> Vec<(f64, f64)> {
        vec![(f64::NEG_INFINITY, f64::INFINITY); 2]
    }

    /// Q = diag(2, 2), c = 0, single constraint -x0 - x1 <= -1, free bounds: no findings
    /// besides the problem-size summary.
    fn make_simple_problem() -> QpProblem {
        let q = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, 2.0], 2, 2).unwrap();
        let a = CscMatrix::from_triplets(&[0, 0], &[0, 1], &[-1.0, -1.0], 1, 2).unwrap();
        QpProblem::new_all_le(q, vec![0.0, 0.0], a, vec![-1.0], free_bounds()).unwrap()
    }

    /// Q = diag(diag[0], diag[1]), c = 0, no constraints, given bounds.
    fn make_unconstrained(diag: [f64; 2], bounds: Vec<(f64, f64)>) -> QpProblem {
        let q = CscMatrix::from_triplets(&[0, 1], &[0, 1], &diag, 2, 2).unwrap();
        QpProblem::new_all_le(q, vec![0.0, 0.0], CscMatrix::new(0, 2), vec![], bounds).unwrap()
    }

    /// Q = diag(2, 2), c = 0, A = [[0, 0], [-1, -1]] (row 0 empty), right-hand side `b`.
    fn make_zero_row_problem(b: Vec<f64>) -> QpProblem {
        let q = CscMatrix::from_triplets(&[0, 1], &[0, 1], &[2.0, 2.0], 2, 2).unwrap();
        let a = CscMatrix::from_triplets(&[1, 1], &[0, 1], &[-1.0, -1.0], 2, 2).unwrap();
        QpProblem::new_all_le(q, vec![0.0, 0.0], a, b, free_bounds()).unwrap()
    }

    /// First finding with the given code, if any.
    fn find(report: &DiagnosticReport, code: DiagnosticCode) -> Option<&DiagnosticWarning> {
        report.warnings.iter().find(|w| w.code == code)
    }

    /// Whether any finding has the given code.
    fn has(report: &DiagnosticReport, code: DiagnosticCode) -> bool {
        find(report, code).is_some()
    }

    #[test]
    fn test_q_negative_diagonal_clean() {
        let report = diagnose(&make_simple_problem());
        assert!(!report.has_error);
        assert!(!has(&report, DiagnosticCode::QNegativeDiagonal));
    }

    #[test]
    fn test_q_negative_diagonal_detected() {
        let report = diagnose(&make_unconstrained([-2.0, 2.0], free_bounds()));
        assert!(report.has_error);
        let w = find(&report, DiagnosticCode::QNegativeDiagonal)
            .expect("negative Q[0,0] must be reported");
        assert_eq!(w.variable_index, Some(0));
    }

    #[test]
    fn test_q_symmetric_clean() {
        let report = diagnose(&make_simple_problem());
        assert!(!has(&report, DiagnosticCode::QNotSymmetric));
    }

    #[test]
    fn test_q_not_symmetric_detected() {
        let q = CscMatrix::from_triplets(
            &[0, 1, 0, 1],
            &[0, 0, 1, 1],
            &[2.0, 1.0, 1.0, 2.0],
            2, 2,
        ).unwrap();
        let prob = QpProblem::new_all_le(q, vec![0.0, 0.0], CscMatrix::new(0, 2), vec![], free_bounds())
            .unwrap();
        let report = diagnose(&prob);
        let w = find(&report, DiagnosticCode::QNotSymmetric)
            .expect("sub-diagonal entry Q[1,0] must be reported");
        assert_eq!(w.severity, Severity::Warning);
    }

    #[test]
    fn test_bounds_conflict_clean() {
        let report = diagnose(&make_simple_problem());
        assert!(!has(&report, DiagnosticCode::VariableBoundsConflict));
    }

    #[test]
    fn test_bounds_conflict_detected() {
        let mut prob = make_unconstrained([2.0, 2.0], vec![(0.0, 1.0), (0.0, 1.0)]);
        // Inject invalid bound post-construction (public field) to test diagnose detection.
        prob.bounds[1] = (2.0, 1.0);
        let report = diagnose(&prob);
        assert!(report.has_error);
        let w = find(&report, DiagnosticCode::VariableBoundsConflict)
            .expect("lb > ub on variable 1 must be reported");
        assert_eq!(w.variable_index, Some(1));
    }

    #[test]
    fn test_poor_scaling_clean() {
        let report = diagnose(&make_simple_problem());
        assert!(!has(&report, DiagnosticCode::PoorScaling));
    }

    #[test]
    fn test_poor_scaling_detected() {
        let report = diagnose(&make_unconstrained([1e10, 1.0], free_bounds()));
        let w = find(&report, DiagnosticCode::PoorScaling)
            .expect("Q ratio 1e10 must be reported");
        assert_eq!(w.severity, Severity::Warning);
    }

    #[test]
    fn test_zero_row_in_a_clean() {
        let report = diagnose(&make_simple_problem());
        assert!(!has(&report, DiagnosticCode::ZeroRowInA));
    }

    #[test]
    fn test_zero_row_in_a_warning() {
        let report = diagnose(&make_zero_row_problem(vec![0.0, -1.0]));
        let w = find(&report, DiagnosticCode::ZeroRowInA)
            .expect("empty row 0 must be reported");
        assert_eq!(w.severity, Severity::Warning);
        assert_eq!(w.constraint_index, Some(0));
    }

    #[test]
    fn test_zero_row_in_a_error() {
        let report = diagnose(&make_zero_row_problem(vec![-1.0, -1.0]));
        assert!(report.has_error);
        let w = find(&report, DiagnosticCode::ZeroRowInA)
            .expect("empty row 0 with negative b must be reported");
        assert_eq!(w.severity, Severity::Error);
    }

    #[test]
    fn test_problem_size_always_present() {
        let report = diagnose(&make_simple_problem());
        assert_eq!(report.info.n, 2);
        assert_eq!(report.info.m, 1);
        assert_eq!(report.info.nnz_q, 2);
        assert_eq!(report.info.nnz_a, 2);
        let w = find(&report, DiagnosticCode::ProblemSize)
            .expect("problem size summary must always be present");
        assert_eq!(w.severity, Severity::Info);
    }

    #[test]
    fn test_multiple_errors_combined() {
        let mut prob = make_unconstrained([-1.0, 2.0], vec![(0.0, 1.0), (0.0, 1.0)]);
        // Inject invalid bound post-construction to test combined error detection.
        prob.bounds[0] = (5.0, 1.0);
        let report = diagnose(&prob);
        assert!(report.has_error);
        let is_error = |code: DiagnosticCode| {
            report
                .warnings
                .iter()
                .any(|w| w.code == code && w.severity == Severity::Error)
        };
        assert!(is_error(DiagnosticCode::QNegativeDiagonal));
        assert!(is_error(DiagnosticCode::VariableBoundsConflict));
    }
}