rustkernel_ml/explainability.rs

//! Explainability kernels for model interpretation.
//!
//! This module provides GPU-accelerated explainability algorithms:
//! - SHAPValues - Kernel SHAP approximation for feature importance
//! - FeatureImportance - Permutation-based feature importance

use crate::types::DataMatrix;
use rand::prelude::*;
use rand::{Rng, SeedableRng, rng};
use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
use serde::{Deserialize, Serialize};

// ============================================================================
// SHAP Values Kernel
// ============================================================================

/// Configuration for SHAP computation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPConfig {
    /// Number of samples for approximation.
    pub n_samples: usize,
    /// Whether to use kernel SHAP (vs sampling SHAP).
    pub use_kernel_shap: bool,
    /// Regularization for weighted least squares.
    pub regularization: f64,
    /// Random seed for reproducibility.
    pub seed: Option<u64>,
}

impl Default for SHAPConfig {
    fn default() -> Self {
        Self {
            n_samples: 100,
            use_kernel_shap: true,
            regularization: 0.01,
            seed: None,
        }
    }
}

/// SHAP explanation for a single prediction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPExplanation {
    /// Base value (expected prediction over the background data).
    pub base_value: f64,
    /// SHAP values for each feature.
    pub shap_values: Vec<f64>,
    /// Feature names if provided.
    pub feature_names: Option<Vec<String>>,
    /// The prediction being explained.
    pub prediction: f64,
    /// Sum of SHAP values (should equal prediction - base_value).
    pub shap_sum: f64,
}

/// Batch SHAP results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SHAPBatchResult {
    /// Base value.
    pub base_value: f64,
    /// SHAP values matrix (samples x features).
    pub shap_values: Vec<Vec<f64>>,
    /// Feature names.
    pub feature_names: Option<Vec<String>>,
    /// Mean absolute SHAP values per feature.
    pub feature_importance: Vec<f64>,
}

/// SHAP Values kernel.
///
/// Computes SHAP (SHapley Additive exPlanations) values for model predictions.
/// Uses the Kernel SHAP approximation, which is model-agnostic and works with
/// any prediction function.
///
/// SHAP values satisfy:
/// - Local accuracy: f(x) = base_value + sum(shap_values)
/// - Missingness: missing features have zero contribution
/// - Consistency: if a feature's marginal contribution increases, its SHAP value does not decrease
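///
/// The exact Shapley value for feature `i` averages its marginal contribution
/// over all coalitions `S` of the other features:
/// `phi_i = sum_{S ⊆ F \ {i}} |S|! * (M - |S| - 1)! / M! * (f(S ∪ {i}) - f(S))`,
/// where `M = |F|`. Exact computation is exponential in `M`, hence the
/// sampling-based approximations below.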
#[derive(Debug, Clone)]
pub struct SHAPValues {
    metadata: KernelMetadata,
}

impl Default for SHAPValues {
    fn default() -> Self {
        Self::new()
    }
}

impl SHAPValues {
    /// Create a new SHAP Values kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/shap-values", Domain::StatisticalML)
                .with_description("Kernel SHAP for model-agnostic feature explanations")
                .with_throughput(1_000)
                .with_latency_us(500.0),
        }
    }

    /// Compute SHAP values for a single instance.
    ///
    /// # Arguments
    /// * `instance` - The instance to explain
    /// * `background` - Background dataset for baseline
    /// * `predict_fn` - Model prediction function
    /// * `config` - SHAP configuration
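    ///
    /// # Example
    ///
    /// A minimal usage sketch; the crate paths `rustkernel_ml::explainability`
    /// and `rustkernel_ml::types::DataMatrix` are assumed from this file's
    /// location in the crate.
    ///
    /// ```
    /// use rustkernel_ml::explainability::{SHAPConfig, SHAPValues};
    /// use rustkernel_ml::types::DataMatrix;
    ///
    /// // Linear model: f(x) = 3*x[0] + x[1].
    /// let predict_fn = |x: &[f64]| 3.0 * x[0] + x[1];
    /// // Background of four samples over two features (row-major).
    /// let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);
    /// let config = SHAPConfig { n_samples: 50, seed: Some(42), ..Default::default() };
    ///
    /// let explanation = SHAPValues::explain(&[1.0, 1.0], &background, predict_fn, &config);
    /// assert_eq!(explanation.shap_values.len(), 2);
    /// // Local accuracy: base_value + shap_sum approximates the prediction.
    /// ```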
    pub fn explain<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: F,
        config: &SHAPConfig,
    ) -> SHAPExplanation
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();

        if n_features == 0 || background.n_samples == 0 {
            return SHAPExplanation {
                base_value: 0.0,
                shap_values: Vec::new(),
                feature_names: None,
                prediction: 0.0,
                shap_sum: 0.0,
            };
        }

        // Compute base value as expected prediction over background
        let base_value: f64 = (0..background.n_samples)
            .map(|i| predict_fn(background.row(i)))
            .sum::<f64>()
            / background.n_samples as f64;

        let prediction = predict_fn(instance);

        // Select the configured approximation method
        let shap_values = if config.use_kernel_shap {
            Self::kernel_shap(instance, background, &predict_fn, config)
        } else {
            Self::sampling_shap(instance, background, &predict_fn, config)
        };

        let shap_sum: f64 = shap_values.iter().sum();

        SHAPExplanation {
            base_value,
            shap_values,
            feature_names: None,
            prediction,
            shap_sum,
        }
    }

    /// Kernel SHAP implementation using weighted linear regression.
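    ///
    /// Fits a weighted linear model over sampled binary coalition vectors `z`,
    /// minimizing `sum_z w(z) * (f(h(z)) - z · phi)^2` plus a ridge penalty,
    /// where `h(z)` keeps the instance's value for included features and
    /// substitutes background values for excluded ones; the fitted
    /// coefficients `phi` are the SHAP estimates.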
    fn kernel_shap<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: &F,
        config: &SHAPConfig,
    ) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();
        let n_samples = config.n_samples;

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        // Generate coalition samples
        let mut coalitions: Vec<Vec<bool>> = Vec::with_capacity(n_samples);
        let mut predictions: Vec<f64> = Vec::with_capacity(n_samples);
        let mut weights: Vec<f64> = Vec::with_capacity(n_samples);

        // Always include full and empty coalitions
        coalitions.push(vec![true; n_features]);
        coalitions.push(vec![false; n_features]);

        for coalition in &coalitions[..2] {
            let masked = Self::create_masked_instance(instance, background, coalition, &mut rng);
            predictions.push(predict_fn(&masked));
        }

        weights.push(1e6); // High weight for full coalition
        weights.push(1e6); // High weight for empty coalition

        // Sample random coalitions
        for _ in 2..n_samples {
            let coalition: Vec<bool> = (0..n_features).map(|_| rng.random_bool(0.5)).collect();

            let z: usize = coalition.iter().filter(|&&b| b).count();
            let weight = Self::kernel_shap_weight(n_features, z);

            let masked = Self::create_masked_instance(instance, background, &coalition, &mut rng);
            let pred = predict_fn(&masked);

            coalitions.push(coalition);
            predictions.push(pred);
            weights.push(weight);
        }

        // Solve weighted least squares: (X^T W X + λI)^-1 X^T W y
        Self::solve_weighted_regression(&coalitions, &predictions, &weights, config.regularization)
    }

    /// Sampling SHAP implementation (simpler, faster, less accurate).
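    ///
    /// Estimates each feature's SHAP value as the average marginal
    /// contribution `f(x_with) - f(x_without)` over random feature
    /// permutations, where `x_with`/`x_without` take the instance's values
    /// for the features up to (and including/excluding) the target feature,
    /// and background values elsewhere.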
    fn sampling_shap<F>(
        instance: &[f64],
        background: &DataMatrix,
        predict_fn: &F,
        config: &SHAPConfig,
    ) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let n_features = instance.len();
        let mut shap_values = vec![0.0; n_features];
        // At least one sample per feature, even when n_samples < n_features,
        // so the averaging below never divides by zero.
        let samples_per_feature = (config.n_samples / n_features).max(1);

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        for feature_idx in 0..n_features {
            let mut contributions = Vec::with_capacity(samples_per_feature);

            for _ in 0..samples_per_feature {
                // Random permutation
                let mut perm: Vec<usize> = (0..n_features).collect();
                perm.shuffle(&mut rng);

                let feature_pos = perm.iter().position(|&i| i == feature_idx).unwrap();

                // Features before this one in permutation
                let before: Vec<bool> = (0..n_features)
                    .map(|i| {
                        let pos = perm.iter().position(|&p| p == i).unwrap();
                        pos < feature_pos
                    })
                    .collect();

                // Include current feature
                let mut with_feature = before.clone();
                with_feature[feature_idx] = true;

                // Sample background
                let bg_idx = rng.random_range(0..background.n_samples);
                let bg = background.row(bg_idx);

                // Create masked instances
                let x_with: Vec<f64> = (0..n_features)
                    .map(|i| if with_feature[i] { instance[i] } else { bg[i] })
                    .collect();

                let x_without: Vec<f64> = (0..n_features)
                    .map(|i| if before[i] { instance[i] } else { bg[i] })
                    .collect();

                let contribution = predict_fn(&x_with) - predict_fn(&x_without);
                contributions.push(contribution);
            }

            shap_values[feature_idx] =
                contributions.iter().sum::<f64>() / contributions.len() as f64;
        }

        shap_values
    }

    /// Kernel SHAP weight function.
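    ///
    /// For example, with `M = 5` features and coalition size `z = 2`:
    /// `(5 - 1) / (C(5, 2) * 2 * (5 - 2)) = 4 / 60 ≈ 0.0667`.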
    fn kernel_shap_weight(n_features: usize, coalition_size: usize) -> f64 {
        if coalition_size == 0 || coalition_size == n_features {
            return 1e6; // Very high weight for full/empty coalitions
        }

        let m = n_features as f64;
        let z = coalition_size as f64;

        // SHAP kernel weight: (M-1) / (C(M,z) * z * (M-z))
        let binomial = Self::binomial(n_features, coalition_size);
        if binomial == 0.0 {
            return 0.0;
        }

        (m - 1.0) / (binomial * z * (m - z))
    }

    /// Binomial coefficient.
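    ///
    /// Computed in `f64` via the multiplicative form
    /// `C(n, k) = prod_{i in 0..k} (n - i) / (i + 1)`, which avoids the
    /// integer overflow of evaluating factorials directly.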
    fn binomial(n: usize, k: usize) -> f64 {
        if k > n {
            return 0.0;
        }
        let k = k.min(n - k);
        let mut result = 1.0;
        for i in 0..k {
            result *= (n - i) as f64 / (i + 1) as f64;
        }
        result
    }

    /// Create masked instance using background data.
    fn create_masked_instance(
        instance: &[f64],
        background: &DataMatrix,
        coalition: &[bool],
        rng: &mut StdRng,
    ) -> Vec<f64> {
        let bg_idx = rng.random_range(0..background.n_samples);
        let bg = background.row(bg_idx);

        coalition
            .iter()
            .enumerate()
            .map(|(i, &included)| if included { instance[i] } else { bg[i] })
            .collect()
    }

    /// Solve weighted least squares regression.
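    ///
    /// Forms the regularized normal equations
    /// `(X^T W X + lambda * I) * phi = X^T W y` explicitly and solves the
    /// resulting `n_features x n_features` system.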
    fn solve_weighted_regression(
        coalitions: &[Vec<bool>],
        predictions: &[f64],
        weights: &[f64],
        regularization: f64,
    ) -> Vec<f64> {
        if coalitions.is_empty() {
            return Vec::new();
        }

        let n_features = coalitions[0].len();
        let n_samples = coalitions.len();

        // Build design matrix X (coalitions as 0/1)
        let mut x: Vec<Vec<f64>> = Vec::with_capacity(n_samples);
        for coalition in coalitions {
            let row: Vec<f64> = coalition
                .iter()
                .map(|&b| if b { 1.0 } else { 0.0 })
                .collect();
            x.push(row);
        }

        // Compute X^T W X
        let mut xtw_x = vec![vec![0.0; n_features]; n_features];
        for i in 0..n_features {
            for j in 0..n_features {
                for k in 0..n_samples {
                    xtw_x[i][j] += x[k][i] * weights[k] * x[k][j];
                }
            }
        }

        // Add regularization
        for i in 0..n_features {
            xtw_x[i][i] += regularization;
        }

        // Compute X^T W y
        let mut xtw_y = vec![0.0; n_features];
        for i in 0..n_features {
            for k in 0..n_samples {
                xtw_y[i] += x[k][i] * weights[k] * predictions[k];
            }
        }

        // Solve via Gaussian elimination (see solve_linear_system)
        Self::solve_linear_system(&xtw_x, &xtw_y)
    }

    /// Linear system solver using Gaussian elimination with partial pivoting.
    /// Near-singular pivots are skipped and the affected unknowns set to zero.
    fn solve_linear_system(a: &[Vec<f64>], b: &[f64]) -> Vec<f64> {
        let n = b.len();
        if n == 0 {
            return Vec::new();
        }

        // Gaussian elimination with partial pivoting
        let mut aug: Vec<Vec<f64>> = a
            .iter()
            .enumerate()
            .map(|(i, row)| {
                let mut new_row = row.clone();
                new_row.push(b[i]);
                new_row
            })
            .collect();

        // Forward elimination
        for i in 0..n {
            // Find pivot
            let mut max_idx = i;
            let mut max_val = aug[i][i].abs();
            for k in (i + 1)..n {
                if aug[k][i].abs() > max_val {
                    max_val = aug[k][i].abs();
                    max_idx = k;
                }
            }

            aug.swap(i, max_idx);

            if aug[i][i].abs() < 1e-10 {
                continue;
            }

            for k in (i + 1)..n {
                let factor = aug[k][i] / aug[i][i];
                for j in i..=n {
                    aug[k][j] -= factor * aug[i][j];
                }
            }
        }

        // Back substitution
        let mut x = vec![0.0; n];
        for i in (0..n).rev() {
            if aug[i][i].abs() < 1e-10 {
                x[i] = 0.0;
                continue;
            }
            x[i] = aug[i][n];
            for j in (i + 1)..n {
                x[i] -= aug[i][j] * x[j];
            }
            x[i] /= aug[i][i];
        }

        x
    }

    /// Explain multiple instances.
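    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the unit tests below; crate paths are
    /// assumed from this file's location in the crate.
    ///
    /// ```
    /// use rustkernel_ml::explainability::{SHAPConfig, SHAPValues};
    /// use rustkernel_ml::types::DataMatrix;
    ///
    /// let predict_fn = |x: &[f64]| x[0] * 2.0;
    /// let background = DataMatrix::new(vec![0.0, 0.5, 1.0, 1.5], 4, 1);
    /// let instances = DataMatrix::new(vec![0.5, 1.0, 2.0], 3, 1);
    /// let config = SHAPConfig { n_samples: 20, seed: Some(42), ..Default::default() };
    ///
    /// let result = SHAPValues::explain_batch(&instances, &background, predict_fn, &config, None);
    /// assert_eq!(result.shap_values.len(), 3);        // one row per instance
    /// assert_eq!(result.feature_importance.len(), 1); // mean |SHAP| per feature
    /// ```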
    pub fn explain_batch<F>(
        instances: &DataMatrix,
        background: &DataMatrix,
        predict_fn: F,
        config: &SHAPConfig,
        feature_names: Option<Vec<String>>,
    ) -> SHAPBatchResult
    where
        F: Fn(&[f64]) -> f64,
    {
        if instances.n_samples == 0 {
            return SHAPBatchResult {
                base_value: 0.0,
                shap_values: Vec::new(),
                feature_names: None,
                feature_importance: Vec::new(),
            };
        }

        // Compute base value
        let base_value: f64 = (0..background.n_samples)
            .map(|i| predict_fn(background.row(i)))
            .sum::<f64>()
            / background.n_samples.max(1) as f64;

        // Compute SHAP values for each instance
        let mut shap_values: Vec<Vec<f64>> = Vec::with_capacity(instances.n_samples);

        for i in 0..instances.n_samples {
            let instance = instances.row(i);
            let explanation = Self::explain(instance, background, &predict_fn, config);
            shap_values.push(explanation.shap_values);
        }

        // Compute feature importance as mean absolute SHAP values
        let n_features = instances.n_features;
        let mut feature_importance = vec![0.0; n_features];

        for values in &shap_values {
            for (i, &v) in values.iter().enumerate() {
                feature_importance[i] += v.abs();
            }
        }

        for imp in &mut feature_importance {
            *imp /= shap_values.len() as f64;
        }

        SHAPBatchResult {
            base_value,
            shap_values,
            feature_names,
            feature_importance,
        }
    }
}

impl GpuKernel for SHAPValues {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

// ============================================================================
// Feature Importance Kernel
// ============================================================================

/// Configuration for permutation feature importance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceConfig {
    /// Number of permutations per feature.
    pub n_permutations: usize,
    /// Random seed.
    pub seed: Option<u64>,
    /// Metric to use (higher is better).
    pub metric: ImportanceMetric,
}

impl Default for FeatureImportanceConfig {
    fn default() -> Self {
        Self {
            n_permutations: 10,
            seed: None,
            metric: ImportanceMetric::Accuracy,
        }
    }
}

/// Metric for measuring importance.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ImportanceMetric {
    /// Classification accuracy.
    Accuracy,
    /// Mean squared error (for regression).
    MSE,
    /// Mean absolute error.
    MAE,
    /// R-squared score.
    R2,
}

/// Feature importance result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeatureImportanceResult {
    /// Importance scores per feature.
    pub importances: Vec<f64>,
    /// Standard deviations of importance scores.
    pub std_devs: Vec<f64>,
    /// Feature names if provided.
    pub feature_names: Option<Vec<String>>,
    /// Baseline score (without permutation).
    pub baseline_score: f64,
    /// Ranked feature indices (most important first).
    pub ranking: Vec<usize>,
}

/// Permutation Feature Importance kernel.
///
/// Computes feature importance by measuring how much model performance
/// degrades when each feature is randomly shuffled. Features that cause
/// larger degradation are more important.
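///
/// The reported importance of a feature is `baseline_score - mean(permuted
/// scores)`: large positive values mean shuffling the feature hurts the
/// model, while values near zero mean the feature is effectively unused.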
#[derive(Debug, Clone)]
pub struct FeatureImportance {
    metadata: KernelMetadata,
}

impl Default for FeatureImportance {
    fn default() -> Self {
        Self::new()
    }
}

impl FeatureImportance {
    /// Create a new Feature Importance kernel.
    #[must_use]
    pub fn new() -> Self {
        Self {
            metadata: KernelMetadata::batch("ml/feature-importance", Domain::StatisticalML)
                .with_description("Permutation-based feature importance")
                .with_throughput(5_000)
                .with_latency_us(200.0),
        }
    }

    /// Compute permutation feature importance.
    ///
    /// # Arguments
    /// * `data` - Input features
    /// * `targets` - True labels/values
    /// * `predict_fn` - Model prediction function
    /// * `config` - Configuration
    /// * `feature_names` - Optional feature names
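    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the unit tests below; crate paths are
    /// assumed from this file's location in the crate.
    ///
    /// ```
    /// use rustkernel_ml::explainability::{
    ///     FeatureImportance, FeatureImportanceConfig, ImportanceMetric,
    /// };
    /// use rustkernel_ml::types::DataMatrix;
    ///
    /// // Model that only looks at the first of three features.
    /// let predict_fn = |x: &[f64]| x[0];
    /// let data = DataMatrix::new(
    ///     vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0],
    ///     4,
    ///     3,
    /// );
    /// let targets = vec![1.0, 2.0, 3.0, 4.0];
    /// let config = FeatureImportanceConfig {
    ///     n_permutations: 5,
    ///     seed: Some(42),
    ///     metric: ImportanceMetric::MSE,
    /// };
    ///
    /// let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
    /// assert_eq!(result.ranking[0], 0); // the first feature ranks most important
    /// ```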
    pub fn compute<F>(
        data: &DataMatrix,
        targets: &[f64],
        predict_fn: F,
        config: &FeatureImportanceConfig,
        feature_names: Option<Vec<String>>,
    ) -> FeatureImportanceResult
    where
        F: Fn(&[f64]) -> f64,
    {
        if data.n_samples == 0 || data.n_features == 0 {
            return FeatureImportanceResult {
                importances: Vec::new(),
                std_devs: Vec::new(),
                feature_names: None,
                baseline_score: 0.0,
                ranking: Vec::new(),
            };
        }

        let mut rng = match config.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut rng()),
        };

        // Compute baseline score
        let predictions: Vec<f64> = (0..data.n_samples)
            .map(|i| predict_fn(data.row(i)))
            .collect();
        let baseline_score = Self::compute_score(&predictions, targets, config.metric);

        // Compute importance for each feature
        let mut importances = Vec::with_capacity(data.n_features);
        let mut std_devs = Vec::with_capacity(data.n_features);

        for feature_idx in 0..data.n_features {
            let mut scores = Vec::with_capacity(config.n_permutations);

            for _ in 0..config.n_permutations {
                // Create permuted data
                let mut perm_data = data.data.clone();
                let mut perm_indices: Vec<usize> = (0..data.n_samples).collect();
                perm_indices.shuffle(&mut rng);

                // Shuffle feature values
                for (i, &perm_idx) in perm_indices.iter().enumerate() {
                    perm_data[i * data.n_features + feature_idx] =
                        data.data[perm_idx * data.n_features + feature_idx];
                }

                let perm_matrix = DataMatrix::new(perm_data, data.n_samples, data.n_features);

                // Compute predictions with permuted feature
                let perm_predictions: Vec<f64> = (0..perm_matrix.n_samples)
                    .map(|i| predict_fn(perm_matrix.row(i)))
                    .collect();

                let score = Self::compute_score(&perm_predictions, targets, config.metric);
                scores.push(score);
            }

            // Importance = baseline - mean(permuted scores)
            let mean_score: f64 = scores.iter().sum::<f64>() / scores.len() as f64;
            let importance = baseline_score - mean_score;

            let variance: f64 =
                scores.iter().map(|s| (s - mean_score).powi(2)).sum::<f64>() / scores.len() as f64;
            let std_dev = variance.sqrt();

            importances.push(importance);
            std_devs.push(std_dev);
        }

        // Compute ranking
        let mut ranking: Vec<usize> = (0..data.n_features).collect();
        ranking.sort_by(|&a, &b| {
            importances[b]
                .partial_cmp(&importances[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        FeatureImportanceResult {
            importances,
            std_devs,
            feature_names,
            baseline_score,
            ranking,
        }
    }

    /// Compute score based on metric.
    fn compute_score(predictions: &[f64], targets: &[f64], metric: ImportanceMetric) -> f64 {
        if predictions.is_empty() || targets.is_empty() {
            return 0.0;
        }

        match metric {
            ImportanceMetric::Accuracy => {
                let correct: usize = predictions
                    .iter()
                    .zip(targets.iter())
                    .filter(|&(p, t)| (p.round() - t.round()).abs() < 0.5)
                    .count();
                correct as f64 / predictions.len() as f64
            }
            ImportanceMetric::MSE => {
                let mse: f64 = predictions
                    .iter()
                    .zip(targets.iter())
                    .map(|(p, t)| (p - t).powi(2))
                    .sum::<f64>()
                    / predictions.len() as f64;
                -mse // Negative because higher is better
            }
            ImportanceMetric::MAE => {
                let mae: f64 = predictions
                    .iter()
                    .zip(targets.iter())
                    .map(|(p, t)| (p - t).abs())
                    .sum::<f64>()
                    / predictions.len() as f64;
                -mae // Negative because higher is better
            }
            ImportanceMetric::R2 => {
                let mean_target: f64 = targets.iter().sum::<f64>() / targets.len() as f64;
                let ss_res: f64 = predictions
                    .iter()
                    .zip(targets.iter())
                    .map(|(p, t)| (t - p).powi(2))
                    .sum();
                let ss_tot: f64 = targets.iter().map(|t| (t - mean_target).powi(2)).sum();
                if ss_tot.abs() < 1e-10 {
                    0.0
                } else {
                    1.0 - ss_res / ss_tot
                }
            }
        }
    }
}

impl GpuKernel for FeatureImportance {
    fn metadata(&self) -> &KernelMetadata {
        &self.metadata
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shap_values_metadata() {
        let kernel = SHAPValues::new();
        assert_eq!(kernel.metadata().id, "ml/shap-values");
    }

    #[test]
    fn test_shap_basic() {
        // Simple linear model: f(x) = x[0] + 2*x[1]
        let predict_fn = |x: &[f64]| x[0] + 2.0 * x[1];

        let background = DataMatrix::new(vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0], 4, 2);

        let config = SHAPConfig {
            n_samples: 50,
            use_kernel_shap: true,
            regularization: 0.1,
            seed: Some(42),
        };

        let instance = vec![1.0, 1.0];
        let explanation = SHAPValues::explain(&instance, &background, predict_fn, &config);

        // For linear model, SHAP values should approximate coefficients
        assert!(explanation.shap_values.len() == 2);
        assert!(explanation.prediction > 0.0);
    }

    #[test]
    fn test_shap_batch() {
        let predict_fn = |x: &[f64]| x[0] * 2.0;

        let background = DataMatrix::new(vec![0.0, 0.5, 1.0, 1.5], 4, 1);
        let instances = DataMatrix::new(vec![0.5, 1.0, 2.0], 3, 1);

        let config = SHAPConfig {
            n_samples: 20,
            seed: Some(42),
            ..Default::default()
        };

        let result = SHAPValues::explain_batch(&instances, &background, predict_fn, &config, None);

        assert_eq!(result.shap_values.len(), 3);
        assert_eq!(result.feature_importance.len(), 1);
    }

    #[test]
    fn test_shap_empty() {
        let predict_fn = |x: &[f64]| x.iter().sum();
        let background = DataMatrix::new(vec![], 0, 0);
        let config = SHAPConfig::default();

        let explanation = SHAPValues::explain(&[], &background, predict_fn, &config);
        assert!(explanation.shap_values.is_empty());
    }

    #[test]
    fn test_kernel_shap_weight() {
        // Edge cases
        assert!(SHAPValues::kernel_shap_weight(5, 0) > 1000.0);
        assert!(SHAPValues::kernel_shap_weight(5, 5) > 1000.0);

        // Middle values should have finite weights
        let w = SHAPValues::kernel_shap_weight(5, 2);
        assert!(w > 0.0 && w < 1000.0);
    }

    #[test]
    fn test_feature_importance_metadata() {
        let kernel = FeatureImportance::new();
        assert_eq!(kernel.metadata().id, "ml/feature-importance");
    }

    #[test]
    fn test_feature_importance_basic() {
        // Model that only uses first feature
        let predict_fn = |x: &[f64]| x[0];

        let data = DataMatrix::new(
            vec![1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 0.0, 4.0, 0.0, 0.0],
            4,
            3,
        );
        let targets = vec![1.0, 2.0, 3.0, 4.0];

        let config = FeatureImportanceConfig {
            n_permutations: 5,
            seed: Some(42),
            metric: ImportanceMetric::MSE,
        };

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);

        // First feature should be most important
        assert_eq!(result.importances.len(), 3);
        assert!(result.importances[0].abs() > result.importances[1].abs());
        assert!(result.importances[0].abs() > result.importances[2].abs());
        assert_eq!(result.ranking[0], 0);
    }

    #[test]
    fn test_feature_importance_empty() {
        let predict_fn = |_: &[f64]| 0.0;
        let data = DataMatrix::new(vec![], 0, 0);
        let targets: Vec<f64> = vec![];
        let config = FeatureImportanceConfig::default();

        let result = FeatureImportance::compute(&data, &targets, predict_fn, &config, None);
        assert!(result.importances.is_empty());
    }

    #[test]
    fn test_metrics() {
        let preds = vec![1.0, 2.0, 3.0];
        let targets = vec![1.0, 2.0, 3.0];

        // Perfect predictions
        let acc = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::Accuracy);
        assert!((acc - 1.0).abs() < 0.01);

        let mse = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::MSE);
        assert!((mse - 0.0).abs() < 0.01);

        let r2 = FeatureImportance::compute_score(&preds, &targets, ImportanceMetric::R2);
        assert!((r2 - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_binomial() {
        assert!((SHAPValues::binomial(5, 2) - 10.0).abs() < 0.01);
        assert!((SHAPValues::binomial(10, 3) - 120.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 0) - 1.0).abs() < 0.01);
        assert!((SHAPValues::binomial(5, 5) - 1.0).abs() < 0.01);
    }
}