Skip to main content

fdars_core/explain/
importance.rs

1//! Permutation importance, pointwise importance, and conditional permutation importance.
2
3use super::helpers::{
4    clone_scores_matrix, compute_conditioning_bins, compute_score_variance,
5    logistic_accuracy_from_scores, permute_component, project_scores, shuffle_global,
6};
7use crate::error::FdarError;
8use crate::matrix::FdMatrix;
9use crate::scalar_on_function::{sigmoid, FregreLmResult, FunctionalLogisticResult};
10use rand::prelude::*;
11
12// ===========================================================================
13// FPC Permutation Importance
14// ===========================================================================
15
/// Outcome of an FPC permutation-importance run.
///
/// Produced by both the linear variant (metric = R^2) and the logistic
/// variant (metric = accuracy). Every vector holds one entry per FPC
/// component (length ncomp).
#[derive(Debug, Clone, PartialEq)]
pub struct FpcPermutationImportance {
    /// Drop in the metric per component: `baseline_metric - permuted_metric[k]`.
    /// Positive values mean the metric got worse when component `k` was shuffled.
    pub importance: Vec<f64>,
    /// Metric of the unpermuted model (R^2 or accuracy).
    pub baseline_metric: f64,
    /// Metric averaged over all permutations of component `k`.
    pub permuted_metric: Vec<f64>,
}
26
27/// Permutation importance for a linear functional regression (metric = R^2).
28///
29/// # Errors
30///
31/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
32/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
33/// count.
34/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
35/// Returns [`FdarError::ComputationFailed`] if the total sum of squares is zero.
36///
37/// # Examples
38///
39/// ```
40/// use fdars_core::matrix::FdMatrix;
41/// use fdars_core::scalar_on_function::fregre_lm;
42/// use fdars_core::explain::fpc_permutation_importance;
43///
44/// let (n, m) = (20, 30);
45/// let data = FdMatrix::from_column_major(
46///     (0..n * m).map(|k| {
47///         let i = (k % n) as f64;
48///         let j = (k / n) as f64;
49///         ((i + 1.0) * j * 0.2).sin()
50///     }).collect(),
51///     n, m,
52/// ).unwrap();
53/// let y: Vec<f64> = (0..n).map(|i| (i as f64 * 0.5).sin()).collect();
54/// let fit = fregre_lm(&data, &y, None, 3).unwrap();
55/// let imp = fpc_permutation_importance(&fit, &data, &y, 10, 42).unwrap();
56/// assert_eq!(imp.importance.len(), 3);
57/// ```
58#[must_use = "expensive computation whose result should not be discarded"]
59pub fn fpc_permutation_importance(
60    fit: &FregreLmResult,
61    data: &FdMatrix,
62    y: &[f64],
63    n_perm: usize,
64    seed: u64,
65) -> Result<FpcPermutationImportance, FdarError> {
66    let (n, m) = data.shape();
67    if n == 0 {
68        return Err(FdarError::InvalidDimension {
69            parameter: "data",
70            expected: ">0 rows".into(),
71            actual: "0".into(),
72        });
73    }
74    if n != y.len() {
75        return Err(FdarError::InvalidDimension {
76            parameter: "y",
77            expected: format!("{n} (matching data rows)"),
78            actual: format!("{}", y.len()),
79        });
80    }
81    if m != fit.fpca.mean.len() {
82        return Err(FdarError::InvalidDimension {
83            parameter: "data",
84            expected: format!("{} columns", fit.fpca.mean.len()),
85            actual: format!("{m}"),
86        });
87    }
88    if n_perm == 0 {
89        return Err(FdarError::InvalidParameter {
90            parameter: "n_perm",
91            message: "must be > 0".into(),
92        });
93    }
94    let ncomp = fit.ncomp;
95    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
96
97    // Baseline R^2 -- compute from same FPC-only prediction used in permuted path
98    // to ensure consistent comparison (gamma terms are constant across permutations)
99    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
100    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
101    if ss_tot == 0.0 {
102        return Err(FdarError::ComputationFailed {
103            operation: "fpc_permutation_importance",
104            detail: "total sum of squares is zero; all response values may be identical — check your data".into(),
105        });
106    }
107    let identity_idx: Vec<usize> = (0..n).collect();
108    let ss_res_base = permuted_ss_res_linear(
109        &scores,
110        &fit.coefficients,
111        y,
112        n,
113        ncomp,
114        ncomp,
115        &identity_idx,
116    );
117    let baseline = 1.0 - ss_res_base / ss_tot;
118
119    let mut rng = StdRng::seed_from_u64(seed);
120    let mut importance = vec![0.0; ncomp];
121    let mut permuted_metric = vec![0.0; ncomp];
122
123    for k in 0..ncomp {
124        let mut sum_r2 = 0.0;
125        for _ in 0..n_perm {
126            let mut idx: Vec<usize> = (0..n).collect();
127            idx.shuffle(&mut rng);
128            let ss_res_perm =
129                permuted_ss_res_linear(&scores, &fit.coefficients, y, n, ncomp, k, &idx);
130            sum_r2 += 1.0 - ss_res_perm / ss_tot;
131        }
132        let mean_perm = sum_r2 / n_perm as f64;
133        permuted_metric[k] = mean_perm;
134        importance[k] = baseline - mean_perm;
135    }
136
137    Ok(FpcPermutationImportance {
138        importance,
139        baseline_metric: baseline,
140        permuted_metric,
141    })
142}
143
144/// Permutation importance for functional logistic regression (metric = accuracy).
145///
146/// # Errors
147///
148/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
149/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
150/// count.
151/// Returns [`FdarError::InvalidParameter`] if `n_perm` is zero.
152#[must_use = "expensive computation whose result should not be discarded"]
153pub fn fpc_permutation_importance_logistic(
154    fit: &FunctionalLogisticResult,
155    data: &FdMatrix,
156    y: &[f64],
157    n_perm: usize,
158    seed: u64,
159) -> Result<FpcPermutationImportance, FdarError> {
160    let (n, m) = data.shape();
161    if n == 0 {
162        return Err(FdarError::InvalidDimension {
163            parameter: "data",
164            expected: ">0 rows".into(),
165            actual: "0".into(),
166        });
167    }
168    if n != y.len() {
169        return Err(FdarError::InvalidDimension {
170            parameter: "y",
171            expected: format!("{n} (matching data rows)"),
172            actual: format!("{}", y.len()),
173        });
174    }
175    if m != fit.fpca.mean.len() {
176        return Err(FdarError::InvalidDimension {
177            parameter: "data",
178            expected: format!("{} columns", fit.fpca.mean.len()),
179            actual: format!("{m}"),
180        });
181    }
182    if n_perm == 0 {
183        return Err(FdarError::InvalidParameter {
184            parameter: "n_perm",
185            message: "must be > 0".into(),
186        });
187    }
188    let ncomp = fit.ncomp;
189    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
190
191    let baseline: f64 = (0..n)
192        .filter(|&i| {
193            let pred = if fit.probabilities[i] >= 0.5 {
194                1.0
195            } else {
196                0.0
197            };
198            (pred - y[i]).abs() < 1e-10
199        })
200        .count() as f64
201        / n as f64;
202
203    let mut rng = StdRng::seed_from_u64(seed);
204    let mut importance = vec![0.0; ncomp];
205    let mut permuted_metric = vec![0.0; ncomp];
206
207    for k in 0..ncomp {
208        let mut sum_acc = 0.0;
209        for _ in 0..n_perm {
210            let mut perm_scores = clone_scores_matrix(&scores, n, ncomp);
211            shuffle_global(&mut perm_scores, &scores, k, n, &mut rng);
212            sum_acc += logistic_accuracy_from_scores(
213                &perm_scores,
214                fit.intercept,
215                &fit.coefficients,
216                y,
217                n,
218                ncomp,
219            );
220        }
221        let mean_acc = sum_acc / n_perm as f64;
222        permuted_metric[k] = mean_acc;
223        importance[k] = baseline - mean_acc;
224    }
225
226    Ok(FpcPermutationImportance {
227        importance,
228        baseline_metric: baseline,
229        permuted_metric,
230    })
231}
232
233/// Compute SS_res with component k shuffled by given index permutation.
234fn permuted_ss_res_linear(
235    scores: &FdMatrix,
236    coefficients: &[f64],
237    y: &[f64],
238    n: usize,
239    ncomp: usize,
240    k: usize,
241    perm_idx: &[usize],
242) -> f64 {
243    (0..n)
244        .map(|i| {
245            let mut yhat = coefficients[0];
246            for c in 0..ncomp {
247                let s = if c == k {
248                    scores[(perm_idx[i], c)]
249                } else {
250                    scores[(i, c)]
251                };
252                yhat += coefficients[1 + c] * s;
253            }
254            (y[i] - yhat).powi(2)
255        })
256        .sum()
257}
258
259// ===========================================================================
260// Pointwise Variable Importance
261// ===========================================================================
262
263/// Result of pointwise variable importance analysis.
264#[derive(Debug, Clone, PartialEq)]
265#[non_exhaustive]
266pub struct PointwiseImportanceResult {
267    /// Importance at each grid point (length m).
268    pub importance: Vec<f64>,
269    /// Normalized importance summing to 1 (length m).
270    pub importance_normalized: Vec<f64>,
271    /// Per-component importance (ncomp x m).
272    pub component_importance: FdMatrix,
273    /// Variance of each FPC score (length ncomp).
274    pub score_variance: Vec<f64>,
275}
276
277/// Pointwise variable importance for a linear functional regression model.
278///
279/// Measures how much X(t_j) contributes to prediction variance via the FPC decomposition.
280///
281/// # Errors
282///
283/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
284/// Returns [`FdarError::InvalidDimension`] if the rotation matrix has zero rows
285/// or the scores matrix has fewer than 2 rows.
286#[must_use = "expensive computation whose result should not be discarded"]
287pub fn pointwise_importance(fit: &FregreLmResult) -> Result<PointwiseImportanceResult, FdarError> {
288    let ncomp = fit.ncomp;
289    let m = fit.fpca.rotation.nrows();
290    let n = fit.fpca.scores.nrows();
291    if ncomp == 0 {
292        return Err(FdarError::InvalidParameter {
293            parameter: "ncomp",
294            message: "must be > 0".into(),
295        });
296    }
297    if m == 0 {
298        return Err(FdarError::InvalidDimension {
299            parameter: "rotation",
300            expected: ">0 rows".into(),
301            actual: "0".into(),
302        });
303    }
304    if n < 2 {
305        return Err(FdarError::InvalidDimension {
306            parameter: "scores",
307            expected: ">=2 rows".into(),
308            actual: format!("{n}"),
309        });
310    }
311
312    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
313    let (component_importance, importance, importance_normalized) =
314        compute_pointwise_importance_core(
315            &fit.coefficients,
316            &fit.fpca.rotation,
317            &score_variance,
318            ncomp,
319            m,
320        );
321
322    Ok(PointwiseImportanceResult {
323        importance,
324        importance_normalized,
325        component_importance,
326        score_variance,
327    })
328}
329
330/// Pointwise variable importance for a functional logistic regression model.
331///
332/// # Errors
333///
334/// Returns [`FdarError::InvalidParameter`] if `fit.ncomp` is zero.
335/// Returns [`FdarError::InvalidDimension`] if the rotation matrix has zero rows
336/// or the scores matrix has fewer than 2 rows.
337#[must_use = "expensive computation whose result should not be discarded"]
338pub fn pointwise_importance_logistic(
339    fit: &FunctionalLogisticResult,
340) -> Result<PointwiseImportanceResult, FdarError> {
341    let ncomp = fit.ncomp;
342    let m = fit.fpca.rotation.nrows();
343    let n = fit.fpca.scores.nrows();
344    if ncomp == 0 {
345        return Err(FdarError::InvalidParameter {
346            parameter: "ncomp",
347            message: "must be > 0".into(),
348        });
349    }
350    if m == 0 {
351        return Err(FdarError::InvalidDimension {
352            parameter: "rotation",
353            expected: ">0 rows".into(),
354            actual: "0".into(),
355        });
356    }
357    if n < 2 {
358        return Err(FdarError::InvalidDimension {
359            parameter: "scores",
360            expected: ">=2 rows".into(),
361            actual: format!("{n}"),
362        });
363    }
364
365    let score_variance = compute_score_variance(&fit.fpca.scores, n, ncomp);
366    let (component_importance, importance, importance_normalized) =
367        compute_pointwise_importance_core(
368            &fit.coefficients,
369            &fit.fpca.rotation,
370            &score_variance,
371            ncomp,
372            m,
373        );
374
375    Ok(PointwiseImportanceResult {
376        importance,
377        importance_normalized,
378        component_importance,
379        score_variance,
380    })
381}
382
383/// Compute component importance matrix and aggregated importance.
384fn compute_pointwise_importance_core(
385    coefficients: &[f64],
386    rotation: &FdMatrix,
387    score_variance: &[f64],
388    ncomp: usize,
389    m: usize,
390) -> (FdMatrix, Vec<f64>, Vec<f64>) {
391    let mut component_importance = FdMatrix::zeros(ncomp, m);
392    for k in 0..ncomp {
393        let ck = coefficients[1 + k];
394        for j in 0..m {
395            component_importance[(k, j)] = (ck * rotation[(j, k)]).powi(2) * score_variance[k];
396        }
397    }
398
399    let mut importance = vec![0.0; m];
400    for j in 0..m {
401        for k in 0..ncomp {
402            importance[j] += component_importance[(k, j)];
403        }
404    }
405
406    let total: f64 = importance.iter().sum();
407    let importance_normalized = if total > 0.0 {
408        importance.iter().map(|&v| v / total).collect()
409    } else {
410        vec![0.0; m]
411    };
412
413    (component_importance, importance, importance_normalized)
414}
415
416// ===========================================================================
417// Conditional Permutation Importance
418// ===========================================================================
419
/// Outcome of a conditional permutation-importance run.
///
/// Every vector is indexed by FPC component (length ncomp).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ConditionalPermutationImportanceResult {
    /// Importance under conditional permutation:
    /// `baseline_metric - permuted_metric[k]`.
    pub importance: Vec<f64>,
    /// Metric of the unpermuted model (R^2 or accuracy).
    pub baseline_metric: f64,
    /// Metric averaged over the conditional permutations of component `k`.
    pub permuted_metric: Vec<f64>,
    /// Importance under ordinary (unconditional) permutation, reported
    /// alongside for comparison.
    pub unconditional_importance: Vec<f64>,
}
433
434/// Conditional permutation importance for a linear functional regression model.
435///
436/// # Errors
437///
438/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
439/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
440/// count.
441/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
442/// Returns [`FdarError::ComputationFailed`] if the total sum of squares is zero.
443#[must_use = "expensive computation whose result should not be discarded"]
444pub fn conditional_permutation_importance(
445    fit: &FregreLmResult,
446    data: &FdMatrix,
447    y: &[f64],
448    scalar_covariates: Option<&FdMatrix>,
449    n_bins: usize,
450    n_perm: usize,
451    seed: u64,
452) -> Result<ConditionalPermutationImportanceResult, FdarError> {
453    let (n, m) = data.shape();
454    if n == 0 {
455        return Err(FdarError::InvalidDimension {
456            parameter: "data",
457            expected: ">0 rows".into(),
458            actual: "0".into(),
459        });
460    }
461    if n != y.len() {
462        return Err(FdarError::InvalidDimension {
463            parameter: "y",
464            expected: format!("{n} (matching data rows)"),
465            actual: format!("{}", y.len()),
466        });
467    }
468    if m != fit.fpca.mean.len() {
469        return Err(FdarError::InvalidDimension {
470            parameter: "data",
471            expected: format!("{} columns", fit.fpca.mean.len()),
472            actual: format!("{m}"),
473        });
474    }
475    if n_perm == 0 {
476        return Err(FdarError::InvalidParameter {
477            parameter: "n_perm",
478            message: "must be > 0".into(),
479        });
480    }
481    if n_bins == 0 {
482        return Err(FdarError::InvalidParameter {
483            parameter: "n_bins",
484            message: "must be > 0".into(),
485        });
486    }
487    let _ = scalar_covariates;
488    let ncomp = fit.ncomp;
489    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
490
491    let y_mean: f64 = y.iter().sum::<f64>() / n as f64;
492    let ss_tot: f64 = y.iter().map(|&yi| (yi - y_mean).powi(2)).sum();
493    if ss_tot == 0.0 {
494        return Err(FdarError::ComputationFailed {
495            operation: "conditional_permutation_importance",
496            detail: "total sum of squares is zero; all response values may be identical — check your data".into(),
497        });
498    }
499    let ss_res_base: f64 = fit.residuals.iter().map(|r| r * r).sum();
500    let baseline = 1.0 - ss_res_base / ss_tot;
501
502    let predict_r2 = |score_mat: &FdMatrix| -> f64 {
503        let ss_res: f64 = (0..n)
504            .map(|i| {
505                let mut yhat = fit.coefficients[0];
506                for c in 0..ncomp {
507                    yhat += fit.coefficients[1 + c] * score_mat[(i, c)];
508                }
509                (y[i] - yhat).powi(2)
510            })
511            .sum();
512        1.0 - ss_res / ss_tot
513    };
514
515    let mut rng = StdRng::seed_from_u64(seed);
516    let mut importance = vec![0.0; ncomp];
517    let mut permuted_metric = vec![0.0; ncomp];
518    let mut unconditional_importance = vec![0.0; ncomp];
519
520    for k in 0..ncomp {
521        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
522        let (mean_cond, mean_uncond) =
523            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_r2);
524        permuted_metric[k] = mean_cond;
525        importance[k] = baseline - mean_cond;
526        unconditional_importance[k] = baseline - mean_uncond;
527    }
528
529    Ok(ConditionalPermutationImportanceResult {
530        importance,
531        baseline_metric: baseline,
532        permuted_metric,
533        unconditional_importance,
534    })
535}
536
537/// Conditional permutation importance for a functional logistic regression model.
538///
539/// # Errors
540///
541/// Returns [`FdarError::InvalidDimension`] if `data` has zero rows, its column
542/// count does not match `fit.fpca.mean`, or `y.len()` does not match the row
543/// count.
544/// Returns [`FdarError::InvalidParameter`] if `n_perm` or `n_bins` is zero.
545#[must_use = "expensive computation whose result should not be discarded"]
546pub fn conditional_permutation_importance_logistic(
547    fit: &FunctionalLogisticResult,
548    data: &FdMatrix,
549    y: &[f64],
550    scalar_covariates: Option<&FdMatrix>,
551    n_bins: usize,
552    n_perm: usize,
553    seed: u64,
554) -> Result<ConditionalPermutationImportanceResult, FdarError> {
555    let (n, m) = data.shape();
556    if n == 0 {
557        return Err(FdarError::InvalidDimension {
558            parameter: "data",
559            expected: ">0 rows".into(),
560            actual: "0".into(),
561        });
562    }
563    if n != y.len() {
564        return Err(FdarError::InvalidDimension {
565            parameter: "y",
566            expected: format!("{n} (matching data rows)"),
567            actual: format!("{}", y.len()),
568        });
569    }
570    if m != fit.fpca.mean.len() {
571        return Err(FdarError::InvalidDimension {
572            parameter: "data",
573            expected: format!("{} columns", fit.fpca.mean.len()),
574            actual: format!("{m}"),
575        });
576    }
577    if n_perm == 0 {
578        return Err(FdarError::InvalidParameter {
579            parameter: "n_perm",
580            message: "must be > 0".into(),
581        });
582    }
583    if n_bins == 0 {
584        return Err(FdarError::InvalidParameter {
585            parameter: "n_bins",
586            message: "must be > 0".into(),
587        });
588    }
589    let _ = scalar_covariates;
590    let ncomp = fit.ncomp;
591    let scores = project_scores(data, &fit.fpca.mean, &fit.fpca.rotation, ncomp);
592
593    let baseline: f64 = (0..n)
594        .filter(|&i| {
595            let pred = if fit.probabilities[i] >= 0.5 {
596                1.0
597            } else {
598                0.0
599            };
600            (pred - y[i]).abs() < 1e-10
601        })
602        .count() as f64
603        / n as f64;
604
605    let predict_acc = |score_mat: &FdMatrix| -> f64 {
606        let correct: usize = (0..n)
607            .filter(|&i| {
608                let mut eta = fit.intercept;
609                for c in 0..ncomp {
610                    eta += fit.coefficients[1 + c] * score_mat[(i, c)];
611                }
612                let pred = if sigmoid(eta) >= 0.5 { 1.0 } else { 0.0 };
613                (pred - y[i]).abs() < 1e-10
614            })
615            .count();
616        correct as f64 / n as f64
617    };
618
619    let mut rng = StdRng::seed_from_u64(seed);
620    let mut importance = vec![0.0; ncomp];
621    let mut permuted_metric = vec![0.0; ncomp];
622    let mut unconditional_importance = vec![0.0; ncomp];
623
624    for k in 0..ncomp {
625        let bins = compute_conditioning_bins(&scores, ncomp, k, n, n_bins);
626        let (mean_cond, mean_uncond) =
627            permute_component(&scores, &bins, k, n, ncomp, n_perm, &mut rng, &predict_acc);
628        permuted_metric[k] = mean_cond;
629        importance[k] = baseline - mean_cond;
630        unconditional_importance[k] = baseline - mean_uncond;
631    }
632
633    Ok(ConditionalPermutationImportanceResult {
634        importance,
635        baseline_metric: baseline,
636        permuted_metric,
637        unconditional_importance,
638    })
639}