// fdars_core/explain/mod.rs

1//! Explainability toolkit for FPC-based scalar-on-function models.
2//!
3//! - [`functional_pdp`] / [`functional_pdp_logistic`] — PDP/ICE
4//! - [`beta_decomposition`] / [`beta_decomposition_logistic`] — per-FPC β(t) decomposition
5//! - [`significant_regions`] / [`significant_regions_from_se`] — CI-based significant intervals
6//! - [`fpc_permutation_importance`] / [`fpc_permutation_importance_logistic`] — permutation importance
7//! - [`influence_diagnostics`] — Cook's distance and leverage
8//! - [`friedman_h_statistic`] / [`friedman_h_statistic_logistic`] — FPC interaction detection
9//! - [`pointwise_importance`] / [`pointwise_importance_logistic`] — pointwise variable importance
10//! - [`fpc_vif`] / [`fpc_vif_logistic`] — variance inflation factors
11//! - [`fpc_shap_values`] / [`fpc_shap_values_logistic`] — SHAP values
12//! - [`dfbetas_dffits`] — DFBETAS and DFFITS influence diagnostics
13//! - [`prediction_intervals`] — prediction intervals for new observations
14//! - [`fpc_ale`] / [`fpc_ale_logistic`] — accumulated local effects
15//! - [`loo_cv_press`] — LOO-CV / PRESS diagnostics
16//! - [`sobol_indices`] / [`sobol_indices_logistic`] — Sobol sensitivity indices
17//! - [`calibration_diagnostics`] — calibration diagnostics (logistic)
18//! - [`functional_saliency`] / [`functional_saliency_logistic`] — functional saliency maps
19//! - [`domain_selection`] / [`domain_selection_logistic`] — domain/interval importance
20//! - [`conditional_permutation_importance`] / [`conditional_permutation_importance_logistic`]
21//! - [`counterfactual_regression`] / [`counterfactual_logistic`] — counterfactual explanations
22//! - [`prototype_criticism`] — MMD-based prototype/criticism selection
23//! - [`lime_explanation`] / [`lime_explanation_logistic`] — LIME local surrogates
24//! - [`expected_calibration_error`] — ECE, MCE, ACE calibration metrics
25//! - [`conformal_prediction_residuals`] — split-conformal prediction intervals
26//! - [`regression_depth`] / [`regression_depth_logistic`] — depth-based regression diagnostics
27//! - [`explanation_stability`] / [`explanation_stability_logistic`] — bootstrap stability analysis
28//! - [`anchor_explanation`] / [`anchor_explanation_logistic`] — beam-search anchor rules
29
30// Submodules
31pub(crate) mod helpers;
32
33mod advanced;
34mod ale_lime;
35mod counterfactual;
36mod diagnostics;
37mod importance;
38mod pdp;
39mod sensitivity;
40mod shap;
41
42// ===========================================================================
43// Public re-exports (backward compatible)
44// ===========================================================================
45
46// --- pdp.rs ---
47pub use pdp::{
48    beta_decomposition, beta_decomposition_logistic, functional_pdp, functional_pdp_logistic,
49    significant_regions, significant_regions_from_se, BetaDecomposition, FunctionalPdpResult,
50    SignificanceDirection, SignificantRegion,
51};
52
53// --- importance.rs ---
54pub use importance::{
55    conditional_permutation_importance, conditional_permutation_importance_logistic,
56    fpc_permutation_importance, fpc_permutation_importance_logistic, pointwise_importance,
57    pointwise_importance_logistic, ConditionalPermutationImportanceResult,
58    FpcPermutationImportance, PointwiseImportanceResult,
59};
60
61// --- diagnostics.rs ---
62pub use diagnostics::{
63    dfbetas_dffits, fpc_vif, fpc_vif_logistic, influence_diagnostics, loo_cv_press,
64    prediction_intervals, DfbetasDffitsResult, InfluenceDiagnostics, LooCvResult,
65    PredictionIntervalResult, VifResult,
66};
67
68// --- shap.rs ---
69pub use shap::{
70    fpc_shap_values, fpc_shap_values_logistic, friedman_h_statistic, friedman_h_statistic_logistic,
71    FpcShapValues, FriedmanHResult,
72};
73
74// --- ale_lime.rs ---
75pub use ale_lime::{
76    fpc_ale, fpc_ale_logistic, lime_explanation, lime_explanation_logistic, AleResult, LimeResult,
77};
78
79// --- sensitivity.rs ---
80pub use sensitivity::{
81    domain_selection, domain_selection_logistic, functional_saliency, functional_saliency_logistic,
82    sobol_indices, sobol_indices_logistic, DomainSelectionResult, FunctionalSaliencyResult,
83    ImportantInterval, SobolIndicesResult,
84};
85
86// --- counterfactual.rs ---
87pub use counterfactual::{
88    counterfactual_logistic, counterfactual_regression, prototype_criticism, CounterfactualResult,
89    PrototypeCriticismResult,
90};
91
92// --- advanced.rs ---
93pub use advanced::{
94    anchor_explanation, anchor_explanation_logistic, calibration_diagnostics,
95    conformal_prediction_residuals, expected_calibration_error, explanation_stability,
96    explanation_stability_logistic, regression_depth, regression_depth_logistic, AnchorCondition,
97    AnchorResult, AnchorRule, CalibrationDiagnosticsResult, ConformalPredictionResult, DepthType,
98    EceResult, RegressionDepthResult, StabilityAnalysisResult,
99};
100
101// ===========================================================================
102// pub(crate) re-exports from helpers (for explain_generic.rs, conformal.rs)
103// ===========================================================================
104
105pub(crate) use helpers::{
106    accumulate_kernel_shap_sample, anchor_beam_search, build_coalition_scores,
107    build_stability_result, clone_scores_matrix, compute_ale, compute_column_means,
108    compute_conditioning_bins, compute_domain_selection, compute_h_squared, compute_kernel_mean,
109    compute_lime, compute_mean_scalar, compute_saliency_map, compute_sobol_component,
110    compute_witness, gaussian_kernel_matrix, generate_sobol_matrices, get_obs_scalar,
111    greedy_prototype_selection, ice_to_pdp, make_grid, mean_absolute_column, median_bandwidth,
112    permute_component, project_scores, reconstruct_delta_function, sample_random_coalition,
113    shapley_kernel_weight, solve_kernel_shap_obs, subsample_rows,
114};
115
116// pub(crate) re-export from diagnostics (used by explain_generic.rs)
117pub(crate) use diagnostics::compute_vif_from_scores;
118
119// ===========================================================================
120// Tests
121// ===========================================================================
122
123#[cfg(test)]
124mod tests {
125    use super::*;
126    use crate::scalar_on_function::{fregre_lm, functional_logistic};
127    use std::f64::consts::PI;
128
129    fn generate_test_data(n: usize, m: usize, seed: u64) -> (crate::matrix::FdMatrix, Vec<f64>) {
130        use crate::matrix::FdMatrix;
131        let t: Vec<f64> = (0..m).map(|j| j as f64 / (m - 1) as f64).collect();
132        let mut data = FdMatrix::zeros(n, m);
133        let mut y = vec![0.0; n];
134        for i in 0..n {
135            let phase =
136                (seed.wrapping_mul(17).wrapping_add(i as u64 * 31) % 1000) as f64 / 1000.0 * PI;
137            let amplitude =
138                ((seed.wrapping_mul(13).wrapping_add(i as u64 * 7) % 100) as f64 / 100.0) - 0.5;
139            for j in 0..m {
140                data[(i, j)] =
141                    (2.0 * PI * t[j] + phase).sin() + amplitude * (4.0 * PI * t[j]).cos();
142            }
143            y[i] = 2.0 * phase + 3.0 * amplitude;
144        }
145        (data, y)
146    }
147
148    #[test]
149    fn test_functional_pdp_shape() {
150        let (data, y) = generate_test_data(30, 50, 42);
151        let fit = fregre_lm(&data, &y, None, 3).unwrap();
152        let pdp = functional_pdp(&fit, &data, None, 0, 20).unwrap();
153        assert_eq!(pdp.grid_values.len(), 20);
154        assert_eq!(pdp.pdp_curve.len(), 20);
155        assert_eq!(pdp.ice_curves.shape(), (30, 20));
156        assert_eq!(pdp.component, 0);
157    }
158
159    #[test]
160    fn test_functional_pdp_linear_ice_parallel() {
161        let (data, y) = generate_test_data(30, 50, 42);
162        let fit = fregre_lm(&data, &y, None, 3).unwrap();
163        let pdp = functional_pdp(&fit, &data, None, 1, 10).unwrap();
164
165        // For linear model, all ICE curves should have the same slope
166        let grid_range = pdp.grid_values[9] - pdp.grid_values[0];
167        let slope_0 = (pdp.ice_curves[(0, 9)] - pdp.ice_curves[(0, 0)]) / grid_range;
168        for i in 1..30 {
169            let slope_i = (pdp.ice_curves[(i, 9)] - pdp.ice_curves[(i, 0)]) / grid_range;
170            assert!(
171                (slope_i - slope_0).abs() < 1e-10,
172                "ICE curves should be parallel for linear model: slope_0={}, slope_{}={}",
173                slope_0,
174                i,
175                slope_i
176            );
177        }
178    }
179
180    #[test]
181    fn test_functional_pdp_logistic_probabilities() {
182        let (data, y_cont) = generate_test_data(30, 50, 42);
183        let y_median = {
184            let mut sorted = y_cont.clone();
185            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
186            sorted[sorted.len() / 2]
187        };
188        let y_bin: Vec<f64> = y_cont
189            .iter()
190            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
191            .collect();
192
193        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
194        let pdp = functional_pdp_logistic(&fit, &data, None, 0, 15).unwrap();
195
196        assert_eq!(pdp.grid_values.len(), 15);
197        assert_eq!(pdp.pdp_curve.len(), 15);
198        assert_eq!(pdp.ice_curves.shape(), (30, 15));
199
200        // All ICE values and PDP values should be valid probabilities in [0, 1]
201        for g in 0..15 {
202            assert!(
203                pdp.pdp_curve[g] >= 0.0 && pdp.pdp_curve[g] <= 1.0,
204                "PDP should be in [0,1], got {}",
205                pdp.pdp_curve[g]
206            );
207            for i in 0..30 {
208                assert!(
209                    pdp.ice_curves[(i, g)] >= 0.0 && pdp.ice_curves[(i, g)] <= 1.0,
210                    "ICE should be in [0,1], got {}",
211                    pdp.ice_curves[(i, g)]
212                );
213            }
214        }
215    }
216
217    #[test]
218    fn test_functional_pdp_invalid_component() {
219        let (data, y) = generate_test_data(30, 50, 42);
220        let fit = fregre_lm(&data, &y, None, 3).unwrap();
221        assert!(functional_pdp(&fit, &data, None, 3, 10).is_none());
222        assert!(functional_pdp(&fit, &data, None, 0, 1).is_none());
223    }
224
225    #[test]
226    fn test_functional_pdp_column_mismatch() {
227        use crate::matrix::FdMatrix;
228        let (data, y) = generate_test_data(30, 50, 42);
229        let fit = fregre_lm(&data, &y, None, 3).unwrap();
230        let wrong_data = FdMatrix::zeros(30, 40);
231        assert!(functional_pdp(&fit, &wrong_data, None, 0, 10).is_none());
232    }
233
234    #[test]
235    fn test_functional_pdp_zero_rows() {
236        use crate::matrix::FdMatrix;
237        let (data, y) = generate_test_data(30, 50, 42);
238        let fit = fregre_lm(&data, &y, None, 3).unwrap();
239        let empty_data = FdMatrix::zeros(0, 50);
240        assert!(functional_pdp(&fit, &empty_data, None, 0, 10).is_none());
241    }
242
243    #[test]
244    fn test_functional_pdp_logistic_column_mismatch() {
245        use crate::matrix::FdMatrix;
246        let (data, y_cont) = generate_test_data(30, 50, 42);
247        let y_median = {
248            let mut sorted = y_cont.clone();
249            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
250            sorted[sorted.len() / 2]
251        };
252        let y_bin: Vec<f64> = y_cont
253            .iter()
254            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
255            .collect();
256        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
257        let wrong_data = FdMatrix::zeros(30, 40);
258        assert!(functional_pdp_logistic(&fit, &wrong_data, None, 0, 10).is_none());
259    }
260
261    #[test]
262    fn test_functional_pdp_logistic_zero_rows() {
263        use crate::matrix::FdMatrix;
264        let (data, y_cont) = generate_test_data(30, 50, 42);
265        let y_median = {
266            let mut sorted = y_cont.clone();
267            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
268            sorted[sorted.len() / 2]
269        };
270        let y_bin: Vec<f64> = y_cont
271            .iter()
272            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
273            .collect();
274        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
275        let empty_data = FdMatrix::zeros(0, 50);
276        assert!(functional_pdp_logistic(&fit, &empty_data, None, 0, 10).is_none());
277    }
278
279    // Beta decomposition tests
280
281    #[test]
282    fn test_beta_decomposition_sums_to_beta_t() {
283        let (data, y) = generate_test_data(30, 50, 42);
284        let fit = fregre_lm(&data, &y, None, 3).unwrap();
285        let dec = beta_decomposition(&fit).unwrap();
286        for j in 0..50 {
287            let sum: f64 = dec.components.iter().map(|c| c[j]).sum();
288            assert!(
289                (sum - fit.beta_t[j]).abs() < 1e-10,
290                "Decomposition should sum to beta_t at j={}: {} vs {}",
291                j,
292                sum,
293                fit.beta_t[j]
294            );
295        }
296    }
297
298    #[test]
299    fn test_beta_decomposition_proportions_sum_to_one() {
300        let (data, y) = generate_test_data(30, 50, 42);
301        let fit = fregre_lm(&data, &y, None, 3).unwrap();
302        let dec = beta_decomposition(&fit).unwrap();
303        let total: f64 = dec.variance_proportion.iter().sum();
304        assert!(
305            (total - 1.0).abs() < 1e-10,
306            "Proportions should sum to 1: {}",
307            total
308        );
309    }
310
311    #[test]
312    fn test_beta_decomposition_coefficients_match() {
313        let (data, y) = generate_test_data(30, 50, 42);
314        let fit = fregre_lm(&data, &y, None, 3).unwrap();
315        let dec = beta_decomposition(&fit).unwrap();
316        for k in 0..3 {
317            assert!(
318                (dec.coefficients[k] - fit.coefficients[1 + k]).abs() < 1e-12,
319                "Coefficient mismatch at k={}",
320                k
321            );
322        }
323    }
324
325    #[test]
326    fn test_beta_decomposition_logistic_sums() {
327        let (data, y_cont) = generate_test_data(30, 50, 42);
328        let y_median = {
329            let mut sorted = y_cont.clone();
330            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
331            sorted[sorted.len() / 2]
332        };
333        let y_bin: Vec<f64> = y_cont
334            .iter()
335            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
336            .collect();
337        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
338        let dec = beta_decomposition_logistic(&fit).unwrap();
339        for j in 0..50 {
340            let sum: f64 = dec.components.iter().map(|c| c[j]).sum();
341            assert!(
342                (sum - fit.beta_t[j]).abs() < 1e-10,
343                "Logistic decomposition should sum to beta_t at j={}",
344                j
345            );
346        }
347    }
348
349    // Significant regions tests
350
351    #[test]
352    fn test_significant_regions_all_positive() {
353        let lower = vec![0.1, 0.2, 0.3, 0.4, 0.5];
354        let upper = vec![1.0, 1.0, 1.0, 1.0, 1.0];
355        let regions = significant_regions(&lower, &upper).unwrap();
356        assert_eq!(regions.len(), 1);
357        assert_eq!(regions[0].start_idx, 0);
358        assert_eq!(regions[0].end_idx, 4);
359        assert_eq!(regions[0].direction, SignificanceDirection::Positive);
360    }
361
362    #[test]
363    fn test_significant_regions_none() {
364        let lower = vec![-0.5, -0.3, -0.1, -0.5];
365        let upper = vec![0.5, 0.3, 0.1, 0.5];
366        let regions = significant_regions(&lower, &upper).unwrap();
367        assert!(regions.is_empty());
368    }
369
370    #[test]
371    fn test_significant_regions_mixed() {
372        let lower = vec![0.1, 0.2, -0.5, -1.0, -0.8];
373        let upper = vec![0.9, 0.8, 0.5, -0.1, -0.2];
374        let regions = significant_regions(&lower, &upper).unwrap();
375        assert_eq!(regions.len(), 2);
376        assert_eq!(regions[0].direction, SignificanceDirection::Positive);
377        assert_eq!(regions[0].start_idx, 0);
378        assert_eq!(regions[0].end_idx, 1);
379        assert_eq!(regions[1].direction, SignificanceDirection::Negative);
380        assert_eq!(regions[1].start_idx, 3);
381        assert_eq!(regions[1].end_idx, 4);
382    }
383
384    #[test]
385    fn test_significant_regions_from_se() {
386        let beta_t = vec![2.0, 2.0, 0.0, -2.0, -2.0];
387        let beta_se = vec![0.5, 0.5, 0.5, 0.5, 0.5];
388        let z = 1.96;
389        let regions = significant_regions_from_se(&beta_t, &beta_se, z).unwrap();
390        assert_eq!(regions.len(), 2);
391        assert_eq!(regions[0].direction, SignificanceDirection::Positive);
392        assert_eq!(regions[1].direction, SignificanceDirection::Negative);
393    }
394
395    #[test]
396    fn test_significant_regions_single_point() {
397        let lower = vec![-1.0, 0.5, -1.0];
398        let upper = vec![1.0, 1.0, 1.0];
399        let regions = significant_regions(&lower, &upper).unwrap();
400        assert_eq!(regions.len(), 1);
401        assert_eq!(regions[0].start_idx, 1);
402        assert_eq!(regions[0].end_idx, 1);
403    }
404
405    // FPC permutation importance tests
406
407    #[test]
408    fn test_fpc_importance_shape() {
409        let (data, y) = generate_test_data(30, 50, 42);
410        let fit = fregre_lm(&data, &y, None, 3).unwrap();
411        let imp = fpc_permutation_importance(&fit, &data, &y, 10, 42).unwrap();
412        assert_eq!(imp.importance.len(), 3);
413        assert_eq!(imp.permuted_metric.len(), 3);
414    }
415
416    #[test]
417    fn test_fpc_importance_nonnegative() {
418        let (data, y) = generate_test_data(40, 50, 42);
419        let fit = fregre_lm(&data, &y, None, 3).unwrap();
420        let imp = fpc_permutation_importance(&fit, &data, &y, 50, 42).unwrap();
421        for k in 0..3 {
422            assert!(
423                imp.importance[k] >= -0.05,
424                "Importance should be approximately nonneg: k={}, val={}",
425                k,
426                imp.importance[k]
427            );
428        }
429    }
430
431    #[test]
432    fn test_fpc_importance_dominant_largest() {
433        let (data, y) = generate_test_data(50, 50, 42);
434        let fit = fregre_lm(&data, &y, None, 3).unwrap();
435        let imp = fpc_permutation_importance(&fit, &data, &y, 100, 42).unwrap();
436        let max_imp = imp
437            .importance
438            .iter()
439            .cloned()
440            .fold(f64::NEG_INFINITY, f64::max);
441        assert!(
442            max_imp > 0.0,
443            "At least one component should be important: {:?}",
444            imp.importance
445        );
446    }
447
448    #[test]
449    fn test_fpc_importance_reproducible() {
450        let (data, y) = generate_test_data(30, 50, 42);
451        let fit = fregre_lm(&data, &y, None, 3).unwrap();
452        let imp1 = fpc_permutation_importance(&fit, &data, &y, 20, 999).unwrap();
453        let imp2 = fpc_permutation_importance(&fit, &data, &y, 20, 999).unwrap();
454        for k in 0..3 {
455            assert!(
456                (imp1.importance[k] - imp2.importance[k]).abs() < 1e-12,
457                "Same seed should produce same result at k={}",
458                k
459            );
460        }
461    }
462
463    #[test]
464    fn test_fpc_importance_logistic_shape() {
465        let (data, y_cont) = generate_test_data(30, 50, 42);
466        let y_median = {
467            let mut sorted = y_cont.clone();
468            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
469            sorted[sorted.len() / 2]
470        };
471        let y_bin: Vec<f64> = y_cont
472            .iter()
473            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
474            .collect();
475        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
476        let imp = fpc_permutation_importance_logistic(&fit, &data, &y_bin, 10, 42).unwrap();
477        assert_eq!(imp.importance.len(), 3);
478        assert!(imp.baseline_metric >= 0.0 && imp.baseline_metric <= 1.0);
479    }
480
481    // Influence diagnostics tests
482
483    #[test]
484    fn test_influence_leverage_sum_equals_p() {
485        let (data, y) = generate_test_data(30, 50, 42);
486        let fit = fregre_lm(&data, &y, None, 3).unwrap();
487        let diag = influence_diagnostics(&fit, &data, None).unwrap();
488        let h_sum: f64 = diag.leverage.iter().sum();
489        assert!(
490            (h_sum - diag.p as f64).abs() < 1e-6,
491            "Leverage sum should equal p={}: got {}",
492            diag.p,
493            h_sum
494        );
495    }
496
497    #[test]
498    fn test_influence_leverage_range() {
499        let (data, y) = generate_test_data(30, 50, 42);
500        let fit = fregre_lm(&data, &y, None, 3).unwrap();
501        let diag = influence_diagnostics(&fit, &data, None).unwrap();
502        for (i, &h) in diag.leverage.iter().enumerate() {
503            assert!(
504                (-1e-10..=1.0 + 1e-10).contains(&h),
505                "Leverage out of range at i={}: {}",
506                i,
507                h
508            );
509        }
510    }
511
512    #[test]
513    fn test_influence_cooks_nonnegative() {
514        let (data, y) = generate_test_data(30, 50, 42);
515        let fit = fregre_lm(&data, &y, None, 3).unwrap();
516        let diag = influence_diagnostics(&fit, &data, None).unwrap();
517        for (i, &d) in diag.cooks_distance.iter().enumerate() {
518            assert!(d >= 0.0, "Cook's D should be nonneg at i={}: {}", i, d);
519        }
520    }
521
522    #[test]
523    fn test_influence_high_leverage_outlier() {
524        let (mut data, mut y) = generate_test_data(30, 50, 42);
525        for j in 0..50 {
526            data[(0, j)] *= 10.0;
527        }
528        y[0] = 100.0;
529        let fit = fregre_lm(&data, &y, None, 3).unwrap();
530        let diag = influence_diagnostics(&fit, &data, None).unwrap();
531        let max_cd = diag
532            .cooks_distance
533            .iter()
534            .cloned()
535            .fold(f64::NEG_INFINITY, f64::max);
536        assert!(
537            (diag.cooks_distance[0] - max_cd).abs() < 1e-10,
538            "Outlier should have max Cook's D"
539        );
540    }
541
542    #[test]
543    fn test_influence_shape() {
544        let (data, y) = generate_test_data(30, 50, 42);
545        let fit = fregre_lm(&data, &y, None, 3).unwrap();
546        let diag = influence_diagnostics(&fit, &data, None).unwrap();
547        assert_eq!(diag.leverage.len(), 30);
548        assert_eq!(diag.cooks_distance.len(), 30);
549        assert_eq!(diag.p, 4);
550    }
551
552    #[test]
553    fn test_influence_column_mismatch_returns_none() {
554        use crate::matrix::FdMatrix;
555        let (data, y) = generate_test_data(30, 50, 42);
556        let fit = fregre_lm(&data, &y, None, 3).unwrap();
557        let wrong_data = FdMatrix::zeros(30, 40);
558        assert!(influence_diagnostics(&fit, &wrong_data, None).is_none());
559    }
560
561    // Friedman H-statistic tests
562
563    #[test]
564    fn test_h_statistic_linear_zero() {
565        let (data, y) = generate_test_data(30, 50, 42);
566        let fit = fregre_lm(&data, &y, None, 3).unwrap();
567        let h = friedman_h_statistic(&fit, &data, 0, 1, 10).unwrap();
568        assert!(
569            h.h_squared.abs() < 1e-6,
570            "H^2 should be ~0 for linear model: {}",
571            h.h_squared
572        );
573    }
574
575    #[test]
576    fn test_h_statistic_logistic_positive() {
577        let (data, y_cont) = generate_test_data(40, 50, 42);
578        let y_median = {
579            let mut sorted = y_cont.clone();
580            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
581            sorted[sorted.len() / 2]
582        };
583        let y_bin: Vec<f64> = y_cont
584            .iter()
585            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
586            .collect();
587        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
588        let h = friedman_h_statistic_logistic(&fit, &data, None, 0, 1, 10).unwrap();
589        assert!(
590            h.h_squared >= 0.0,
591            "H^2 should be nonneg for logistic: {}",
592            h.h_squared
593        );
594    }
595
596    #[test]
597    fn test_h_statistic_symmetry() {
598        let (data, y) = generate_test_data(30, 50, 42);
599        let fit = fregre_lm(&data, &y, None, 3).unwrap();
600        let h01 = friedman_h_statistic(&fit, &data, 0, 1, 10).unwrap();
601        let h10 = friedman_h_statistic(&fit, &data, 1, 0, 10).unwrap();
602        assert!(
603            (h01.h_squared - h10.h_squared).abs() < 1e-10,
604            "H(0,1) should equal H(1,0): {} vs {}",
605            h01.h_squared,
606            h10.h_squared
607        );
608    }
609
610    #[test]
611    fn test_h_statistic_grid_shape() {
612        let (data, y) = generate_test_data(30, 50, 42);
613        let fit = fregre_lm(&data, &y, None, 3).unwrap();
614        let h = friedman_h_statistic(&fit, &data, 0, 2, 8).unwrap();
615        assert_eq!(h.grid_j.len(), 8);
616        assert_eq!(h.grid_k.len(), 8);
617        assert_eq!(h.pdp_2d.shape(), (8, 8));
618    }
619
620    #[test]
621    fn test_h_statistic_bounded() {
622        let (data, y_cont) = generate_test_data(40, 50, 42);
623        let y_median = {
624            let mut sorted = y_cont.clone();
625            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
626            sorted[sorted.len() / 2]
627        };
628        let y_bin: Vec<f64> = y_cont
629            .iter()
630            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
631            .collect();
632        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
633        let h = friedman_h_statistic_logistic(&fit, &data, None, 0, 1, 10).unwrap();
634        assert!(
635            h.h_squared >= 0.0 && h.h_squared <= 1.0 + 1e-6,
636            "H^2 should be in [0,1]: {}",
637            h.h_squared
638        );
639    }
640
641    #[test]
642    fn test_h_statistic_same_component_none() {
643        let (data, y) = generate_test_data(30, 50, 42);
644        let fit = fregre_lm(&data, &y, None, 3).unwrap();
645        assert!(friedman_h_statistic(&fit, &data, 1, 1, 10).is_none());
646    }
647
648    // Pointwise importance tests
649
650    #[test]
651    fn test_pointwise_importance_shape() {
652        let (data, y) = generate_test_data(30, 50, 42);
653        let fit = fregre_lm(&data, &y, None, 3).unwrap();
654        let pi = pointwise_importance(&fit).unwrap();
655        assert_eq!(pi.importance.len(), 50);
656        assert_eq!(pi.importance_normalized.len(), 50);
657        assert_eq!(pi.component_importance.shape(), (3, 50));
658        assert_eq!(pi.score_variance.len(), 3);
659    }
660
661    #[test]
662    fn test_pointwise_importance_normalized_sums_to_one() {
663        let (data, y) = generate_test_data(30, 50, 42);
664        let fit = fregre_lm(&data, &y, None, 3).unwrap();
665        let pi = pointwise_importance(&fit).unwrap();
666        let total: f64 = pi.importance_normalized.iter().sum();
667        assert!(
668            (total - 1.0).abs() < 1e-10,
669            "Normalized importance should sum to 1: {}",
670            total
671        );
672    }
673
674    #[test]
675    fn test_pointwise_importance_all_nonneg() {
676        let (data, y) = generate_test_data(30, 50, 42);
677        let fit = fregre_lm(&data, &y, None, 3).unwrap();
678        let pi = pointwise_importance(&fit).unwrap();
679        for (j, &v) in pi.importance.iter().enumerate() {
680            assert!(v >= -1e-15, "Importance should be nonneg at j={}: {}", j, v);
681        }
682    }
683
684    #[test]
685    fn test_pointwise_importance_component_sum_equals_total() {
686        let (data, y) = generate_test_data(30, 50, 42);
687        let fit = fregre_lm(&data, &y, None, 3).unwrap();
688        let pi = pointwise_importance(&fit).unwrap();
689        for j in 0..50 {
690            let sum: f64 = (0..3).map(|k| pi.component_importance[(k, j)]).sum();
691            assert!(
692                (sum - pi.importance[j]).abs() < 1e-10,
693                "Component sum should equal total at j={}: {} vs {}",
694                j,
695                sum,
696                pi.importance[j]
697            );
698        }
699    }
700
701    #[test]
702    fn test_pointwise_importance_logistic_shape() {
703        let (data, y_cont) = generate_test_data(30, 50, 42);
704        let y_median = {
705            let mut sorted = y_cont.clone();
706            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
707            sorted[sorted.len() / 2]
708        };
709        let y_bin: Vec<f64> = y_cont
710            .iter()
711            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
712            .collect();
713        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
714        let pi = pointwise_importance_logistic(&fit).unwrap();
715        assert_eq!(pi.importance.len(), 50);
716        assert_eq!(pi.score_variance.len(), 3);
717    }
718
719    // VIF tests
720
721    #[test]
722    fn test_vif_orthogonal_fpcs_near_one() {
723        let (data, y) = generate_test_data(50, 50, 42);
724        let fit = fregre_lm(&data, &y, None, 3).unwrap();
725        let vif = fpc_vif(&fit, &data, None).unwrap();
726        for (k, &v) in vif.vif.iter().enumerate() {
727            assert!(
728                (v - 1.0).abs() < 0.5,
729                "Orthogonal FPC VIF should be ~1 at k={}: {}",
730                k,
731                v
732            );
733        }
734    }
735
736    #[test]
737    fn test_vif_all_positive() {
738        let (data, y) = generate_test_data(50, 50, 42);
739        let fit = fregre_lm(&data, &y, None, 3).unwrap();
740        let vif = fpc_vif(&fit, &data, None).unwrap();
741        for (k, &v) in vif.vif.iter().enumerate() {
742            assert!(v >= 1.0 - 1e-6, "VIF should be >= 1 at k={}: {}", k, v);
743        }
744    }
745
746    #[test]
747    fn test_vif_shape() {
748        let (data, y) = generate_test_data(50, 50, 42);
749        let fit = fregre_lm(&data, &y, None, 3).unwrap();
750        let vif = fpc_vif(&fit, &data, None).unwrap();
751        assert_eq!(vif.vif.len(), 3);
752        assert_eq!(vif.labels.len(), 3);
753    }
754
755    #[test]
756    fn test_vif_labels_correct() {
757        let (data, y) = generate_test_data(50, 50, 42);
758        let fit = fregre_lm(&data, &y, None, 3).unwrap();
759        let vif = fpc_vif(&fit, &data, None).unwrap();
760        assert_eq!(vif.labels[0], "FPC_0");
761        assert_eq!(vif.labels[1], "FPC_1");
762        assert_eq!(vif.labels[2], "FPC_2");
763    }
764
765    #[test]
766    fn test_vif_logistic_agrees_with_linear() {
767        let (data, y_cont) = generate_test_data(50, 50, 42);
768        let y_median = {
769            let mut sorted = y_cont.clone();
770            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
771            sorted[sorted.len() / 2]
772        };
773        let y_bin: Vec<f64> = y_cont
774            .iter()
775            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
776            .collect();
777        let fit_lm = fregre_lm(&data, &y_cont, None, 3).unwrap();
778        let fit_log = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
779        let vif_lm = fpc_vif(&fit_lm, &data, None).unwrap();
780        let vif_log = fpc_vif_logistic(&fit_log, &data, None).unwrap();
781        for k in 0..3 {
782            assert!(
783                (vif_lm.vif[k] - vif_log.vif[k]).abs() < 1e-6,
784                "VIF should agree: lm={}, log={}",
785                vif_lm.vif[k],
786                vif_log.vif[k]
787            );
788        }
789    }
790
791    #[test]
792    fn test_vif_single_predictor() {
793        let (data, y) = generate_test_data(50, 50, 42);
794        let fit = fregre_lm(&data, &y, None, 1).unwrap();
795        let vif = fpc_vif(&fit, &data, None).unwrap();
796        assert_eq!(vif.vif.len(), 1);
797        assert!(
798            (vif.vif[0] - 1.0).abs() < 1e-6,
799            "Single predictor VIF should be 1: {}",
800            vif.vif[0]
801        );
802    }
803
804    // SHAP tests
805
806    #[test]
807    fn test_shap_linear_sum_to_fitted() {
808        let (data, y) = generate_test_data(30, 50, 42);
809        let fit = fregre_lm(&data, &y, None, 3).unwrap();
810        let shap = fpc_shap_values(&fit, &data, None).unwrap();
811        for i in 0..30 {
812            let sum: f64 = (0..3).map(|k| shap.values[(i, k)]).sum::<f64>() + shap.base_value;
813            assert!(
814                (sum - fit.fitted_values[i]).abs() < 1e-8,
815                "SHAP sum should equal fitted at i={}: {} vs {}",
816                i,
817                sum,
818                fit.fitted_values[i]
819            );
820        }
821    }
822
823    #[test]
824    fn test_shap_linear_shape() {
825        let (data, y) = generate_test_data(30, 50, 42);
826        let fit = fregre_lm(&data, &y, None, 3).unwrap();
827        let shap = fpc_shap_values(&fit, &data, None).unwrap();
828        assert_eq!(shap.values.shape(), (30, 3));
829        assert_eq!(shap.mean_scores.len(), 3);
830    }
831
832    #[test]
833    fn test_shap_linear_sign_matches_coefficient() {
834        let (data, y) = generate_test_data(50, 50, 42);
835        let fit = fregre_lm(&data, &y, None, 3).unwrap();
836        let shap = fpc_shap_values(&fit, &data, None).unwrap();
837        for k in 0..3 {
838            let coef_k = fit.coefficients[1 + k];
839            if coef_k.abs() < 1e-10 {
840                continue;
841            }
842            for i in 0..50 {
843                let score_centered = fit.fpca.scores[(i, k)] - shap.mean_scores[k];
844                let expected_sign = (coef_k * score_centered).signum();
845                if shap.values[(i, k)].abs() > 1e-10 {
846                    assert_eq!(
847                        shap.values[(i, k)].signum(),
848                        expected_sign,
849                        "SHAP sign mismatch at i={}, k={}",
850                        i,
851                        k
852                    );
853                }
854            }
855        }
856    }
857
858    #[test]
859    fn test_shap_logistic_sum_approximates_prediction() {
860        let (data, y_cont) = generate_test_data(30, 50, 42);
861        let y_median = {
862            let mut sorted = y_cont.clone();
863            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
864            sorted[sorted.len() / 2]
865        };
866        let y_bin: Vec<f64> = y_cont
867            .iter()
868            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
869            .collect();
870        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
871        let shap = fpc_shap_values_logistic(&fit, &data, None, 500, 42).unwrap();
872        let mut shap_sums = Vec::new();
873        for i in 0..30 {
874            let sum: f64 = (0..3).map(|k| shap.values[(i, k)]).sum::<f64>() + shap.base_value;
875            shap_sums.push(sum);
876        }
877        let mean_shap: f64 = shap_sums.iter().sum::<f64>() / 30.0;
878        let mean_prob: f64 = fit.probabilities.iter().sum::<f64>() / 30.0;
879        let mut cov = 0.0;
880        let mut var_s = 0.0;
881        let mut var_p = 0.0;
882        for i in 0..30 {
883            let ds = shap_sums[i] - mean_shap;
884            let dp = fit.probabilities[i] - mean_prob;
885            cov += ds * dp;
886            var_s += ds * ds;
887            var_p += dp * dp;
888        }
889        let corr = if var_s > 0.0 && var_p > 0.0 {
890            cov / (var_s.sqrt() * var_p.sqrt())
891        } else {
892            0.0
893        };
894        assert!(
895            corr > 0.5,
896            "Logistic SHAP sums should correlate with probabilities: r={}",
897            corr
898        );
899    }
900
901    #[test]
902    fn test_shap_logistic_reproducible() {
903        let (data, y_cont) = generate_test_data(30, 50, 42);
904        let y_median = {
905            let mut sorted = y_cont.clone();
906            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
907            sorted[sorted.len() / 2]
908        };
909        let y_bin: Vec<f64> = y_cont
910            .iter()
911            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
912            .collect();
913        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
914        let s1 = fpc_shap_values_logistic(&fit, &data, None, 100, 999).unwrap();
915        let s2 = fpc_shap_values_logistic(&fit, &data, None, 100, 999).unwrap();
916        for i in 0..30 {
917            for k in 0..3 {
918                assert!(
919                    (s1.values[(i, k)] - s2.values[(i, k)]).abs() < 1e-12,
920                    "Same seed should give same SHAP at i={}, k={}",
921                    i,
922                    k
923                );
924            }
925        }
926    }
927
928    #[test]
929    fn test_shap_invalid_returns_none() {
930        use crate::matrix::FdMatrix;
931        let (data, y) = generate_test_data(30, 50, 42);
932        let fit = fregre_lm(&data, &y, None, 3).unwrap();
933        let empty = FdMatrix::zeros(0, 50);
934        assert!(fpc_shap_values(&fit, &empty, None).is_none());
935    }
936
937    // DFBETAS / DFFITS tests
938
939    #[test]
940    fn test_dfbetas_shape() {
941        let (data, y) = generate_test_data(30, 50, 42);
942        let fit = fregre_lm(&data, &y, None, 3).unwrap();
943        let db = dfbetas_dffits(&fit, &data, None).unwrap();
944        assert_eq!(db.dfbetas.shape(), (30, 4));
945        assert_eq!(db.dffits.len(), 30);
946        assert_eq!(db.studentized_residuals.len(), 30);
947        assert_eq!(db.p, 4);
948    }
949
950    #[test]
951    fn test_dffits_sign_matches_residual() {
952        let (data, y) = generate_test_data(30, 50, 42);
953        let fit = fregre_lm(&data, &y, None, 3).unwrap();
954        let db = dfbetas_dffits(&fit, &data, None).unwrap();
955        for i in 0..30 {
956            if fit.residuals[i].abs() > 1e-10 && db.dffits[i].abs() > 1e-10 {
957                assert_eq!(
958                    db.dffits[i].signum(),
959                    fit.residuals[i].signum(),
960                    "DFFITS sign should match residual at i={}",
961                    i
962                );
963            }
964        }
965    }
966
967    #[test]
968    fn test_dfbetas_outlier_flagged() {
969        let (mut data, mut y) = generate_test_data(30, 50, 42);
970        for j in 0..50 {
971            data[(0, j)] *= 10.0;
972        }
973        y[0] = 100.0;
974        let fit = fregre_lm(&data, &y, None, 3).unwrap();
975        let db = dfbetas_dffits(&fit, &data, None).unwrap();
976        let max_dffits = db
977            .dffits
978            .iter()
979            .map(|v| v.abs())
980            .fold(f64::NEG_INFINITY, f64::max);
981        assert!(
982            db.dffits[0].abs() >= max_dffits - 1e-10,
983            "Outlier should have max |DFFITS|"
984        );
985    }
986
987    #[test]
988    fn test_dfbetas_cutoff_value() {
989        let (data, y) = generate_test_data(30, 50, 42);
990        let fit = fregre_lm(&data, &y, None, 3).unwrap();
991        let db = dfbetas_dffits(&fit, &data, None).unwrap();
992        assert!(
993            (db.dfbetas_cutoff - 2.0 / (30.0_f64).sqrt()).abs() < 1e-10,
994            "DFBETAS cutoff should be 2/sqrt(n)"
995        );
996        assert!(
997            (db.dffits_cutoff - 2.0 * (4.0 / 30.0_f64).sqrt()).abs() < 1e-10,
998            "DFFITS cutoff should be 2*sqrt(p/n)"
999        );
1000    }
1001
1002    #[test]
1003    fn test_dfbetas_underdetermined_returns_none() {
1004        let (data, y) = generate_test_data(3, 50, 42);
1005        let fit = fregre_lm(&data, &y, None, 2).unwrap();
1006        assert!(dfbetas_dffits(&fit, &data, None).is_none());
1007    }
1008
1009    #[test]
1010    fn test_dffits_consistency_with_cooks() {
1011        let (data, y) = generate_test_data(40, 50, 42);
1012        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1013        let db = dfbetas_dffits(&fit, &data, None).unwrap();
1014        let infl = influence_diagnostics(&fit, &data, None).unwrap();
1015        let mut dffits_order: Vec<usize> = (0..40).collect();
1016        dffits_order.sort_by(|&a, &b| db.dffits[b].abs().partial_cmp(&db.dffits[a].abs()).unwrap());
1017        let mut cooks_order: Vec<usize> = (0..40).collect();
1018        cooks_order.sort_by(|&a, &b| {
1019            infl.cooks_distance[b]
1020                .partial_cmp(&infl.cooks_distance[a])
1021                .unwrap()
1022        });
1023        assert_eq!(
1024            dffits_order[0], cooks_order[0],
1025            "Top influential obs should agree between DFFITS and Cook's D"
1026        );
1027    }
1028
1029    // Prediction interval tests
1030
1031    #[test]
1032    fn test_prediction_interval_training_data_matches_fitted() {
1033        let (data, y) = generate_test_data(30, 50, 42);
1034        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1035        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
1036        for i in 0..30 {
1037            assert!(
1038                (pi.predictions[i] - fit.fitted_values[i]).abs() < 1e-6,
1039                "Prediction should match fitted at i={}: {} vs {}",
1040                i,
1041                pi.predictions[i],
1042                fit.fitted_values[i]
1043            );
1044        }
1045    }
1046
1047    #[test]
1048    fn test_prediction_interval_covers_training_y() {
1049        let (data, y) = generate_test_data(30, 50, 42);
1050        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1051        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
1052        let mut covered = 0;
1053        for i in 0..30 {
1054            if y[i] >= pi.lower[i] && y[i] <= pi.upper[i] {
1055                covered += 1;
1056            }
1057        }
1058        assert!(
1059            covered >= 20,
1060            "At least ~67% of training y should be covered: {}/30",
1061            covered
1062        );
1063    }
1064
1065    #[test]
1066    fn test_prediction_interval_symmetry() {
1067        let (data, y) = generate_test_data(30, 50, 42);
1068        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1069        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
1070        for i in 0..30 {
1071            let above = pi.upper[i] - pi.predictions[i];
1072            let below = pi.predictions[i] - pi.lower[i];
1073            assert!(
1074                (above - below).abs() < 1e-10,
1075                "Interval should be symmetric at i={}: above={}, below={}",
1076                i,
1077                above,
1078                below
1079            );
1080        }
1081    }
1082
1083    #[test]
1084    fn test_prediction_interval_wider_at_99_than_95() {
1085        let (data, y) = generate_test_data(30, 50, 42);
1086        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1087        let pi95 = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
1088        let pi99 = prediction_intervals(&fit, &data, None, &data, None, 0.99).unwrap();
1089        for i in 0..30 {
1090            let width95 = pi95.upper[i] - pi95.lower[i];
1091            let width99 = pi99.upper[i] - pi99.lower[i];
1092            assert!(
1093                width99 >= width95 - 1e-10,
1094                "99% interval should be wider at i={}: {} vs {}",
1095                i,
1096                width99,
1097                width95
1098            );
1099        }
1100    }
1101
1102    #[test]
1103    fn test_prediction_interval_shape() {
1104        let (data, y) = generate_test_data(30, 50, 42);
1105        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1106        let pi = prediction_intervals(&fit, &data, None, &data, None, 0.95).unwrap();
1107        assert_eq!(pi.predictions.len(), 30);
1108        assert_eq!(pi.lower.len(), 30);
1109        assert_eq!(pi.upper.len(), 30);
1110        assert_eq!(pi.prediction_se.len(), 30);
1111        assert!((pi.confidence_level - 0.95).abs() < 1e-15);
1112    }
1113
1114    #[test]
1115    fn test_prediction_interval_invalid_confidence_returns_none() {
1116        let (data, y) = generate_test_data(30, 50, 42);
1117        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1118        assert!(prediction_intervals(&fit, &data, None, &data, None, 0.0).is_none());
1119        assert!(prediction_intervals(&fit, &data, None, &data, None, 1.0).is_none());
1120        assert!(prediction_intervals(&fit, &data, None, &data, None, -0.5).is_none());
1121    }
1122
1123    // ALE tests
1124
1125    #[test]
1126    fn test_ale_linear_is_linear() {
1127        let (data, y) = generate_test_data(50, 50, 42);
1128        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1129        let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();
1130        if ale.bin_midpoints.len() >= 3 {
1131            let slopes: Vec<f64> = ale
1132                .ale_values
1133                .windows(2)
1134                .zip(ale.bin_midpoints.windows(2))
1135                .map(|(v, m)| {
1136                    let dx = m[1] - m[0];
1137                    if dx.abs() > 1e-15 {
1138                        (v[1] - v[0]) / dx
1139                    } else {
1140                        0.0
1141                    }
1142                })
1143                .collect();
1144            let mean_slope = slopes.iter().sum::<f64>() / slopes.len() as f64;
1145            for (b, &s) in slopes.iter().enumerate() {
1146                assert!(
1147                    (s - mean_slope).abs() < mean_slope.abs() * 0.5 + 0.5,
1148                    "ALE slope should be constant for linear model at bin {}: {} vs mean {}",
1149                    b,
1150                    s,
1151                    mean_slope
1152                );
1153            }
1154        }
1155    }
1156
1157    #[test]
1158    fn test_ale_centered_mean_zero() {
1159        let (data, y) = generate_test_data(50, 50, 42);
1160        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1161        let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();
1162        let total_n: usize = ale.bin_counts.iter().sum();
1163        let weighted_mean: f64 = ale
1164            .ale_values
1165            .iter()
1166            .zip(&ale.bin_counts)
1167            .map(|(&a, &c)| a * c as f64)
1168            .sum::<f64>()
1169            / total_n as f64;
1170        assert!(
1171            weighted_mean.abs() < 1e-10,
1172            "ALE should be centered at zero: {}",
1173            weighted_mean
1174        );
1175    }
1176
1177    #[test]
1178    fn test_ale_bin_counts_sum_to_n() {
1179        let (data, y) = generate_test_data(50, 50, 42);
1180        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1181        let ale = fpc_ale(&fit, &data, None, 0, 10).unwrap();
1182        let total: usize = ale.bin_counts.iter().sum();
1183        assert_eq!(total, 50, "Bin counts should sum to n");
1184    }
1185
1186    #[test]
1187    fn test_ale_shape() {
1188        let (data, y) = generate_test_data(50, 50, 42);
1189        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1190        let ale = fpc_ale(&fit, &data, None, 0, 8).unwrap();
1191        let nb = ale.ale_values.len();
1192        assert_eq!(ale.bin_midpoints.len(), nb);
1193        assert_eq!(ale.bin_edges.len(), nb + 1);
1194        assert_eq!(ale.bin_counts.len(), nb);
1195        assert_eq!(ale.component, 0);
1196    }
1197
1198    #[test]
1199    fn test_ale_logistic_bounded() {
1200        let (data, y_cont) = generate_test_data(50, 50, 42);
1201        let y_median = {
1202            let mut sorted = y_cont.clone();
1203            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1204            sorted[sorted.len() / 2]
1205        };
1206        let y_bin: Vec<f64> = y_cont
1207            .iter()
1208            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
1209            .collect();
1210        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1211        let ale = fpc_ale_logistic(&fit, &data, None, 0, 10).unwrap();
1212        for &v in &ale.ale_values {
1213            assert!(v.abs() < 2.0, "Logistic ALE should be bounded: {}", v);
1214        }
1215    }
1216
1217    #[test]
1218    fn test_ale_invalid_returns_none() {
1219        let (data, y) = generate_test_data(30, 50, 42);
1220        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1221        assert!(fpc_ale(&fit, &data, None, 5, 10).is_none());
1222        assert!(fpc_ale(&fit, &data, None, 0, 0).is_none());
1223    }
1224
1225    // LOO-CV / PRESS tests
1226
1227    #[test]
1228    fn test_loo_cv_shape() {
1229        let (data, y) = generate_test_data(30, 50, 42);
1230        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1231        let loo = loo_cv_press(&fit, &data, &y, None).unwrap();
1232        assert_eq!(loo.loo_residuals.len(), 30);
1233        assert_eq!(loo.leverage.len(), 30);
1234    }
1235
1236    #[test]
1237    fn test_loo_r_squared_leq_r_squared() {
1238        let (data, y) = generate_test_data(30, 50, 42);
1239        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1240        let loo = loo_cv_press(&fit, &data, &y, None).unwrap();
1241        assert!(
1242            loo.loo_r_squared <= fit.r_squared + 1e-10,
1243            "LOO R^2 ({}) should be <= training R^2 ({})",
1244            loo.loo_r_squared,
1245            fit.r_squared
1246        );
1247    }
1248
1249    #[test]
1250    fn test_loo_press_equals_sum_squares() {
1251        let (data, y) = generate_test_data(30, 50, 42);
1252        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1253        let loo = loo_cv_press(&fit, &data, &y, None).unwrap();
1254        let manual_press: f64 = loo.loo_residuals.iter().map(|r| r * r).sum();
1255        assert!(
1256            (loo.press - manual_press).abs() < 1e-10,
1257            "PRESS mismatch: {} vs {}",
1258            loo.press,
1259            manual_press
1260        );
1261    }
1262
1263    // Sobol tests
1264
1265    #[test]
1266    fn test_sobol_linear_nonnegative() {
1267        let (data, y) = generate_test_data(30, 50, 42);
1268        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1269        let sobol = sobol_indices(&fit, &data, &y, None).unwrap();
1270        for (k, &s) in sobol.first_order.iter().enumerate() {
1271            assert!(s >= -1e-10, "S_{} should be >= 0: {}", k, s);
1272        }
1273    }
1274
1275    #[test]
1276    fn test_sobol_linear_sum_approx_r2() {
1277        let (data, y) = generate_test_data(30, 50, 42);
1278        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1279        let sobol = sobol_indices(&fit, &data, &y, None).unwrap();
1280        let sum_s: f64 = sobol.first_order.iter().sum();
1281        assert!(
1282            (sum_s - fit.r_squared).abs() < 0.2,
1283            "Sum S_k ({}) should be close to R^2 ({})",
1284            sum_s,
1285            fit.r_squared
1286        );
1287    }
1288
1289    #[test]
1290    fn test_sobol_logistic_bounded() {
1291        let (data, y_cont) = generate_test_data(30, 50, 42);
1292        let y_bin = {
1293            let mut s = y_cont.clone();
1294            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1295            let med = s[s.len() / 2];
1296            y_cont
1297                .iter()
1298                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1299                .collect::<Vec<_>>()
1300        };
1301        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1302        let sobol = sobol_indices_logistic(&fit, &data, None, 500, 42).unwrap();
1303        for &s in &sobol.first_order {
1304            assert!(s > -0.5 && s < 1.5, "Logistic S_k should be bounded: {}", s);
1305        }
1306    }
1307
1308    // Calibration tests
1309
1310    #[test]
1311    fn test_calibration_brier_range() {
1312        let (data, y_cont) = generate_test_data(30, 50, 42);
1313        let y_bin = {
1314            let mut s = y_cont.clone();
1315            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1316            let med = s[s.len() / 2];
1317            y_cont
1318                .iter()
1319                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1320                .collect::<Vec<_>>()
1321        };
1322        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1323        let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();
1324        assert!(
1325            cal.brier_score >= 0.0 && cal.brier_score <= 1.0,
1326            "Brier score should be in [0,1]: {}",
1327            cal.brier_score
1328        );
1329    }
1330
1331    #[test]
1332    fn test_calibration_bin_counts_sum_to_n() {
1333        let (data, y_cont) = generate_test_data(30, 50, 42);
1334        let y_bin = {
1335            let mut s = y_cont.clone();
1336            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1337            let med = s[s.len() / 2];
1338            y_cont
1339                .iter()
1340                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1341                .collect::<Vec<_>>()
1342        };
1343        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1344        let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();
1345        let total: usize = cal.bin_counts.iter().sum();
1346        assert_eq!(total, 30, "Bin counts should sum to n");
1347    }
1348
1349    #[test]
1350    fn test_calibration_n_groups_match() {
1351        let (data, y_cont) = generate_test_data(30, 50, 42);
1352        let y_bin = {
1353            let mut s = y_cont.clone();
1354            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1355            let med = s[s.len() / 2];
1356            y_cont
1357                .iter()
1358                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1359                .collect::<Vec<_>>()
1360        };
1361        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1362        let cal = calibration_diagnostics(&fit, &y_bin, 5).unwrap();
1363        assert_eq!(cal.n_groups, cal.reliability_bins.len());
1364        assert_eq!(cal.n_groups, cal.bin_counts.len());
1365    }
1366
1367    // Saliency tests
1368
1369    #[test]
1370    fn test_saliency_linear_shape() {
1371        let (data, y) = generate_test_data(30, 50, 42);
1372        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1373        let sal = functional_saliency(&fit, &data, None).unwrap();
1374        assert_eq!(sal.saliency_map.shape(), (30, 50));
1375        assert_eq!(sal.mean_absolute_saliency.len(), 50);
1376    }
1377
1378    #[test]
1379    fn test_saliency_logistic_bounded_by_quarter_beta() {
1380        let (data, y_cont) = generate_test_data(30, 50, 42);
1381        let y_bin = {
1382            let mut s = y_cont.clone();
1383            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1384            let med = s[s.len() / 2];
1385            y_cont
1386                .iter()
1387                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1388                .collect::<Vec<_>>()
1389        };
1390        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1391        let sal = functional_saliency_logistic(&fit).unwrap();
1392        for i in 0..30 {
1393            for j in 0..50 {
1394                assert!(
1395                    sal.saliency_map[(i, j)].abs() <= 0.25 * fit.beta_t[j].abs() + 1e-10,
1396                    "|s| should be <= 0.25 * |beta(t)| at ({},{})",
1397                    i,
1398                    j
1399                );
1400            }
1401        }
1402    }
1403
1404    #[test]
1405    fn test_saliency_mean_abs_nonneg() {
1406        let (data, y) = generate_test_data(30, 50, 42);
1407        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1408        let sal = functional_saliency(&fit, &data, None).unwrap();
1409        for &v in &sal.mean_absolute_saliency {
1410            assert!(v >= 0.0, "Mean absolute saliency should be >= 0: {}", v);
1411        }
1412    }
1413
1414    // Domain selection tests
1415
1416    #[test]
1417    fn test_domain_selection_valid_indices() {
1418        let (data, y) = generate_test_data(30, 50, 42);
1419        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1420        let ds = domain_selection(&fit, 5, 0.01).unwrap();
1421        for iv in &ds.intervals {
1422            assert!(iv.start_idx <= iv.end_idx);
1423            assert!(iv.end_idx < 50);
1424        }
1425    }
1426
1427    #[test]
1428    fn test_domain_selection_full_window_one_interval() {
1429        let (data, y) = generate_test_data(30, 50, 42);
1430        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1431        let ds = domain_selection(&fit, 50, 0.01).unwrap();
1432        assert!(
1433            ds.intervals.len() <= 1,
1434            "Full window should give <= 1 interval: {}",
1435            ds.intervals.len()
1436        );
1437    }
1438
1439    #[test]
1440    fn test_domain_selection_high_threshold_fewer() {
1441        let (data, y) = generate_test_data(30, 50, 42);
1442        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1443        let ds_low = domain_selection(&fit, 5, 0.01).unwrap();
1444        let ds_high = domain_selection(&fit, 5, 0.5).unwrap();
1445        assert!(
1446            ds_high.intervals.len() <= ds_low.intervals.len(),
1447            "Higher threshold should give <= intervals: {} vs {}",
1448            ds_high.intervals.len(),
1449            ds_low.intervals.len()
1450        );
1451    }
1452
1453    // Conditional permutation importance tests
1454
1455    #[test]
1456    fn test_cond_perm_shape() {
1457        let (data, y) = generate_test_data(30, 50, 42);
1458        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1459        let cp = conditional_permutation_importance(&fit, &data, &y, None, 3, 5, 42).unwrap();
1460        assert_eq!(cp.importance.len(), 3);
1461        assert_eq!(cp.permuted_metric.len(), 3);
1462        assert_eq!(cp.unconditional_importance.len(), 3);
1463    }
1464
1465    #[test]
1466    fn test_cond_perm_vs_unconditional_close() {
1467        let (data, y) = generate_test_data(40, 50, 42);
1468        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1469        let cp = conditional_permutation_importance(&fit, &data, &y, None, 3, 20, 42).unwrap();
1470        for k in 0..3 {
1471            let diff = (cp.importance[k] - cp.unconditional_importance[k]).abs();
1472            assert!(
1473                diff < 0.5,
1474                "Conditional vs unconditional should be similar for FPC {}: {} vs {}",
1475                k,
1476                cp.importance[k],
1477                cp.unconditional_importance[k]
1478            );
1479        }
1480    }
1481
1482    #[test]
1483    fn test_cond_perm_importance_nonneg() {
1484        let (data, y) = generate_test_data(40, 50, 42);
1485        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1486        let cp = conditional_permutation_importance(&fit, &data, &y, None, 3, 20, 42).unwrap();
1487        for k in 0..3 {
1488            assert!(
1489                cp.importance[k] >= -0.15,
1490                "Importance should be approx >= 0 for FPC {}: {}",
1491                k,
1492                cp.importance[k]
1493            );
1494        }
1495    }
1496
1497    // Counterfactual tests
1498
1499    #[test]
1500    fn test_counterfactual_regression_exact() {
1501        let (data, y) = generate_test_data(30, 50, 42);
1502        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1503        let target = fit.fitted_values[0] + 1.0;
1504        let cf = counterfactual_regression(&fit, &data, None, 0, target).unwrap();
1505        assert!(cf.found);
1506        assert!(
1507            (cf.counterfactual_prediction - target).abs() < 1e-10,
1508            "Counterfactual prediction should match target: {} vs {}",
1509            cf.counterfactual_prediction,
1510            target
1511        );
1512    }
1513
1514    #[test]
1515    fn test_counterfactual_regression_minimal() {
1516        let (data, y) = generate_test_data(30, 50, 42);
1517        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1518        let gap = 1.0;
1519        let target = fit.fitted_values[0] + gap;
1520        let cf = counterfactual_regression(&fit, &data, None, 0, target).unwrap();
1521        let gamma: Vec<f64> = (0..3).map(|k| fit.coefficients[1 + k]).collect();
1522        let gamma_norm: f64 = gamma.iter().map(|g| g * g).sum::<f64>().sqrt();
1523        let expected_dist = gap.abs() / gamma_norm;
1524        assert!(
1525            (cf.distance - expected_dist).abs() < 1e-6,
1526            "Distance should be |gap|/||gamma||: {} vs {}",
1527            cf.distance,
1528            expected_dist
1529        );
1530    }
1531
1532    #[test]
1533    fn test_counterfactual_logistic_flips_class() {
1534        let (data, y_cont) = generate_test_data(30, 50, 42);
1535        let y_bin = {
1536            let mut s = y_cont.clone();
1537            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1538            let med = s[s.len() / 2];
1539            y_cont
1540                .iter()
1541                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1542                .collect::<Vec<_>>()
1543        };
1544        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1545        let cf = counterfactual_logistic(&fit, &data, None, 0, 1000, 0.5).unwrap();
1546        if cf.found {
1547            let orig_class = if cf.original_prediction >= 0.5 { 1 } else { 0 };
1548            let new_class = if cf.counterfactual_prediction >= 0.5 {
1549                1
1550            } else {
1551                0
1552            };
1553            assert_ne!(
1554                orig_class, new_class,
1555                "Class should flip: orig={}, new={}",
1556                orig_class, new_class
1557            );
1558        }
1559    }
1560
1561    #[test]
1562    fn test_counterfactual_invalid_obs_none() {
1563        let (data, y) = generate_test_data(30, 50, 42);
1564        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1565        assert!(counterfactual_regression(&fit, &data, None, 100, 0.0).is_none());
1566    }
1567
1568    // Prototype/criticism tests
1569
1570    #[test]
1571    fn test_prototype_criticism_shape() {
1572        let (data, y) = generate_test_data(30, 50, 42);
1573        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1574        let pc = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();
1575        assert_eq!(pc.prototype_indices.len(), 5);
1576        assert_eq!(pc.prototype_witness.len(), 5);
1577        assert_eq!(pc.criticism_indices.len(), 3);
1578        assert_eq!(pc.criticism_witness.len(), 3);
1579    }
1580
1581    #[test]
1582    fn test_prototype_criticism_no_overlap() {
1583        let (data, y) = generate_test_data(30, 50, 42);
1584        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1585        let pc = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();
1586        for &ci in &pc.criticism_indices {
1587            assert!(
1588                !pc.prototype_indices.contains(&ci),
1589                "Criticism {} should not be a prototype",
1590                ci
1591            );
1592        }
1593    }
1594
1595    #[test]
1596    fn test_prototype_criticism_bandwidth_positive() {
1597        let (data, y) = generate_test_data(30, 50, 42);
1598        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1599        let pc = prototype_criticism(&fit.fpca, 3, 5, 3).unwrap();
1600        assert!(
1601            pc.bandwidth > 0.0,
1602            "Bandwidth should be > 0: {}",
1603            pc.bandwidth
1604        );
1605    }
1606
1607    // LIME tests
1608
1609    #[test]
1610    fn test_lime_linear_matches_global() {
1611        let (data, y) = generate_test_data(40, 50, 42);
1612        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1613        let lime = lime_explanation(&fit, &data, None, 0, 5000, 1.0, 42).unwrap();
1614        for k in 0..3 {
1615            let global = fit.coefficients[1 + k];
1616            let local = lime.attributions[k];
1617            let rel_err = if global.abs() > 1e-6 {
1618                (local - global).abs() / global.abs()
1619            } else {
1620                local.abs()
1621            };
1622            assert!(
1623                rel_err < 0.5,
1624                "LIME should approximate global coef for FPC {}: local={}, global={}",
1625                k,
1626                local,
1627                global
1628            );
1629        }
1630    }
1631
1632    #[test]
1633    fn test_lime_logistic_shape() {
1634        let (data, y_cont) = generate_test_data(30, 50, 42);
1635        let y_bin = {
1636            let mut s = y_cont.clone();
1637            s.sort_by(|a, b| a.partial_cmp(b).unwrap());
1638            let med = s[s.len() / 2];
1639            y_cont
1640                .iter()
1641                .map(|&v| if v >= med { 1.0 } else { 0.0 })
1642                .collect::<Vec<_>>()
1643        };
1644        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1645        let lime = lime_explanation_logistic(&fit, &data, None, 0, 500, 1.0, 42).unwrap();
1646        assert_eq!(lime.attributions.len(), 3);
1647        assert!(
1648            lime.local_r_squared >= 0.0 && lime.local_r_squared <= 1.0,
1649            "R^2 should be in [0,1]: {}",
1650            lime.local_r_squared
1651        );
1652    }
1653
1654    #[test]
1655    fn test_lime_invalid_none() {
1656        let (data, y) = generate_test_data(30, 50, 42);
1657        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1658        assert!(lime_explanation(&fit, &data, None, 100, 100, 1.0, 42).is_none());
1659        assert!(lime_explanation(&fit, &data, None, 0, 0, 1.0, 42).is_none());
1660        assert!(lime_explanation(&fit, &data, None, 0, 100, 0.0, 42).is_none());
1661    }
1662
1663    // ECE tests
1664
1665    fn make_logistic_fit() -> (
1666        crate::matrix::FdMatrix,
1667        Vec<f64>,
1668        crate::scalar_on_function::FunctionalLogisticResult,
1669    ) {
1670        let (data, y_cont) = generate_test_data(40, 50, 42);
1671        let y_median = {
1672            let mut sorted = y_cont.clone();
1673            sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
1674            sorted[sorted.len() / 2]
1675        };
1676        let y_bin: Vec<f64> = y_cont
1677            .iter()
1678            .map(|&v| if v >= y_median { 1.0 } else { 0.0 })
1679            .collect();
1680        let fit = functional_logistic(&data, &y_bin, None, 3, 25, 1e-6).unwrap();
1681        (data, y_bin, fit)
1682    }
1683
1684    #[test]
1685    fn test_ece_range() {
1686        let (_data, y_bin, fit) = make_logistic_fit();
1687        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
1688        assert!(
1689            ece.ece >= 0.0 && ece.ece <= 1.0,
1690            "ECE out of range: {}",
1691            ece.ece
1692        );
1693        assert!(
1694            ece.mce >= 0.0 && ece.mce <= 1.0,
1695            "MCE out of range: {}",
1696            ece.mce
1697        );
1698    }
1699
1700    #[test]
1701    fn test_ece_leq_mce() {
1702        let (_data, y_bin, fit) = make_logistic_fit();
1703        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
1704        assert!(
1705            ece.ece <= ece.mce + 1e-10,
1706            "ECE should <= MCE: {} vs {}",
1707            ece.ece,
1708            ece.mce
1709        );
1710    }
1711
1712    #[test]
1713    fn test_ece_bin_contributions_sum() {
1714        let (_data, y_bin, fit) = make_logistic_fit();
1715        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
1716        let sum: f64 = ece.bin_ece_contributions.iter().sum();
1717        assert!(
1718            (sum - ece.ece).abs() < 1e-10,
1719            "Contributions should sum to ECE: {} vs {}",
1720            sum,
1721            ece.ece
1722        );
1723    }
1724
1725    #[test]
1726    fn test_ece_n_bins_match() {
1727        let (_data, y_bin, fit) = make_logistic_fit();
1728        let ece = expected_calibration_error(&fit, &y_bin, 10).unwrap();
1729        assert_eq!(ece.n_bins, 10);
1730        assert_eq!(ece.bin_ece_contributions.len(), 10);
1731    }
1732
1733    // Conformal prediction tests
1734
1735    #[test]
1736    fn test_conformal_coverage_near_target() {
1737        let (data, y) = generate_test_data(60, 50, 42);
1738        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1739        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
1740            .unwrap();
1741        assert!(
1742            cp.coverage >= 0.8,
1743            "Coverage {} should be >= 0.8 for alpha=0.1",
1744            cp.coverage
1745        );
1746    }
1747
1748    #[test]
1749    fn test_conformal_interval_width_positive() {
1750        let (data, y) = generate_test_data(60, 50, 42);
1751        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1752        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
1753            .unwrap();
1754        for i in 0..cp.predictions.len() {
1755            assert!(
1756                cp.upper[i] > cp.lower[i],
1757                "Upper should > lower at {}: {} vs {}",
1758                i,
1759                cp.upper[i],
1760                cp.lower[i]
1761            );
1762        }
1763    }
1764
1765    #[test]
1766    fn test_conformal_quantile_positive() {
1767        let (data, y) = generate_test_data(60, 50, 42);
1768        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1769        let cp = conformal_prediction_residuals(&fit, &data, &y, &data, None, None, 0.3, 0.1, 42)
1770            .unwrap();
1771        assert!(
1772            cp.residual_quantile >= 0.0,
1773            "Quantile should be >= 0: {}",
1774            cp.residual_quantile
1775        );
1776    }
1777
1778    #[test]
1779    fn test_conformal_lengths_match() {
1780        use crate::matrix::FdMatrix;
1781        let (data, y) = generate_test_data(60, 50, 42);
1782        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1783        let test_data = FdMatrix::zeros(10, 50);
1784        let cp =
1785            conformal_prediction_residuals(&fit, &data, &y, &test_data, None, None, 0.3, 0.1, 42)
1786                .unwrap();
1787        assert_eq!(cp.predictions.len(), 10);
1788        assert_eq!(cp.lower.len(), 10);
1789        assert_eq!(cp.upper.len(), 10);
1790    }
1791
1792    // Regression depth tests
1793
1794    #[test]
1795    fn test_regression_depth_range() {
1796        let (data, y) = generate_test_data(30, 50, 42);
1797        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1798        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::FraimanMuniz, 42).unwrap();
1799        for (i, &d) in rd.score_depths.iter().enumerate() {
1800            assert!(
1801                (-1e-10..=1.0 + 1e-10).contains(&d),
1802                "Depth out of range at {}: {}",
1803                i,
1804                d
1805            );
1806        }
1807    }
1808
1809    #[test]
1810    fn test_regression_depth_beta_nonneg() {
1811        let (data, y) = generate_test_data(30, 50, 42);
1812        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1813        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::FraimanMuniz, 42).unwrap();
1814        assert!(
1815            rd.beta_depth >= -1e-10,
1816            "Beta depth should be >= 0: {}",
1817            rd.beta_depth
1818        );
1819    }
1820
1821    #[test]
1822    fn test_regression_depth_score_lengths() {
1823        let (data, y) = generate_test_data(30, 50, 42);
1824        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1825        let rd = regression_depth(&fit, &data, &y, None, 20, DepthType::ModifiedBand, 42).unwrap();
1826        assert_eq!(rd.score_depths.len(), 30);
1827    }
1828
1829    #[test]
1830    fn test_regression_depth_types_all_work() {
1831        let (data, y) = generate_test_data(30, 50, 42);
1832        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1833        for dt in [
1834            DepthType::FraimanMuniz,
1835            DepthType::ModifiedBand,
1836            DepthType::FunctionalSpatial,
1837        ] {
1838            let rd = regression_depth(&fit, &data, &y, None, 10, dt, 42);
1839            assert!(rd.is_some(), "Depth type {:?} should work", dt);
1840        }
1841    }
1842
1843    // Stability tests
1844
1845    #[test]
1846    fn test_stability_beta_std_nonneg() {
1847        let (data, y) = generate_test_data(30, 50, 42);
1848        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
1849        for (j, &s) in sa.beta_t_std.iter().enumerate() {
1850            assert!(s >= 0.0, "Std should be >= 0 at {}: {}", j, s);
1851        }
1852    }
1853
1854    #[test]
1855    fn test_stability_coefficient_std_length() {
1856        let (data, y) = generate_test_data(30, 50, 42);
1857        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
1858        assert_eq!(sa.coefficient_std.len(), 3);
1859    }
1860
1861    #[test]
1862    fn test_stability_importance_bounded() {
1863        let (data, y) = generate_test_data(30, 50, 42);
1864        let sa = explanation_stability(&data, &y, None, 3, 20, 42).unwrap();
1865        assert!(
1866            sa.importance_stability >= -1.0 - 1e-10 && sa.importance_stability <= 1.0 + 1e-10,
1867            "Importance stability out of range: {}",
1868            sa.importance_stability
1869        );
1870    }
1871
1872    #[test]
1873    fn test_stability_more_boots_more_stable() {
1874        let (data, y) = generate_test_data(40, 50, 42);
1875        let sa1 = explanation_stability(&data, &y, None, 3, 5, 42).unwrap();
1876        let sa2 = explanation_stability(&data, &y, None, 3, 50, 42).unwrap();
1877        assert!(sa2.n_boot_success >= sa1.n_boot_success);
1878    }
1879
1880    // Anchor tests
1881
1882    #[test]
1883    fn test_anchor_precision_meets_threshold() {
1884        let (data, y) = generate_test_data(40, 50, 42);
1885        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1886        let ar = anchor_explanation(&fit, &data, None, 0, 0.8, 5).unwrap();
1887        assert!(
1888            ar.rule.precision >= 0.8 - 1e-10,
1889            "Precision {} should meet 0.8",
1890            ar.rule.precision
1891        );
1892    }
1893
1894    #[test]
1895    fn test_anchor_coverage_range() {
1896        let (data, y) = generate_test_data(40, 50, 42);
1897        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1898        let ar = anchor_explanation(&fit, &data, None, 0, 0.8, 5).unwrap();
1899        assert!(
1900            ar.rule.coverage > 0.0 && ar.rule.coverage <= 1.0,
1901            "Coverage out of range: {}",
1902            ar.rule.coverage
1903        );
1904    }
1905
1906    #[test]
1907    fn test_anchor_observation_matches() {
1908        let (data, y) = generate_test_data(40, 50, 42);
1909        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1910        let ar = anchor_explanation(&fit, &data, None, 5, 0.8, 5).unwrap();
1911        assert_eq!(ar.observation, 5);
1912    }
1913
1914    #[test]
1915    fn test_anchor_invalid_obs_none() {
1916        let (data, y) = generate_test_data(40, 50, 42);
1917        let fit = fregre_lm(&data, &y, None, 3).unwrap();
1918        assert!(anchor_explanation(&fit, &data, None, 100, 0.8, 5).is_none());
1919    }
1920}