// entrenar/train/callback/explainability.rs

//! Explainability callback for computing feature attributions during training.

use std::collections::HashMap;

use super::traits::{CallbackAction, CallbackContext, TrainerCallback};

/// Attribution technique used to score how much each input feature
/// contributes to a model's output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ExplainMethod {
    /// Permutation importance (fast and model-agnostic).
    PermutationImportance,
    /// Integrated gradients (intended for differentiable models).
    IntegratedGradients,
    /// Saliency map (gradient-based attribution).
    Saliency,
}

16/// Feature importance result for a single epoch
17#[derive(Debug, Clone)]
18pub struct FeatureImportanceResult {
19    /// Epoch when computed
20    pub epoch: usize,
21    /// Feature index to importance score
22    pub importances: Vec<(usize, f32)>,
23    /// Method used
24    pub method: ExplainMethod,
25}
26
27/// Callback for computing feature attributions during training
28///
29/// Integrates with aprender's interpret module to provide explainability
30/// insights during model evaluation.
31///
32/// # Example
33///
34/// ```ignore
35/// use entrenar::train::{ExplainabilityCallback, ExplainMethod};
36///
37/// let callback = ExplainabilityCallback::new(ExplainMethod::PermutationImportance)
38///     .with_top_k(5)
39///     .with_eval_samples(100);
40/// ```
41#[derive(Debug)]
42pub struct ExplainabilityCallback {
43    method: ExplainMethod,
44    top_k: usize,
45    eval_samples: usize,
46    results: Vec<FeatureImportanceResult>,
47    feature_names: Option<Vec<String>>,
48}
49
50impl ExplainabilityCallback {
51    /// Create new explainability callback
52    ///
53    /// # Arguments
54    ///
55    /// * `method` - Attribution method to use
56    pub fn new(method: ExplainMethod) -> Self {
57        Self { method, top_k: 10, eval_samples: 50, results: Vec::new(), feature_names: None }
58    }
59
60    /// Set number of top features to track
61    pub fn with_top_k(mut self, k: usize) -> Self {
62        self.top_k = k;
63        self
64    }
65
66    /// Set number of samples to use for evaluation
67    pub fn with_eval_samples(mut self, n: usize) -> Self {
68        self.eval_samples = n;
69        self
70    }
71
72    /// Set feature names for interpretability
73    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
74        self.feature_names = Some(names);
75        self
76    }
77
78    /// Get attribution method
79    pub fn method(&self) -> ExplainMethod {
80        self.method
81    }
82
83    /// Get top-k setting
84    pub fn top_k(&self) -> usize {
85        self.top_k
86    }
87
88    /// Get eval samples setting
89    pub fn eval_samples(&self) -> usize {
90        self.eval_samples
91    }
92
93    /// Get all computed results
94    pub fn results(&self) -> &[FeatureImportanceResult] {
95        &self.results
96    }
97
98    /// Get feature names if set
99    pub fn feature_names(&self) -> Option<&[String]> {
100        self.feature_names.as_deref()
101    }
102
103    /// Record feature importances for an epoch
104    ///
105    /// Call this during on_epoch_end with computed importances
106    pub fn record_importances(&mut self, epoch: usize, importances: Vec<(usize, f32)>) {
107        let mut sorted = importances;
108        sorted
109            .sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(std::cmp::Ordering::Equal));
110        sorted.truncate(self.top_k);
111
112        self.results.push(FeatureImportanceResult {
113            epoch,
114            importances: sorted,
115            method: self.method,
116        });
117    }
118
119    /// Compute permutation importance using aprender
120    ///
121    /// # Arguments
122    ///
123    /// * `predict_fn` - Model prediction function
124    /// * `x` - Feature vectors
125    /// * `y` - Target values
126    pub fn compute_permutation_importance<P>(
127        &self,
128        predict_fn: P,
129        x: &[aprender::primitives::Vector<f32>],
130        y: &[f32],
131    ) -> Vec<(usize, f32)>
132    where
133        P: Fn(&aprender::primitives::Vector<f32>) -> f32,
134    {
135        let importance = aprender::interpret::PermutationImportance::compute(
136            predict_fn,
137            x,
138            y,
139            |pred, true_val| (pred - true_val).powi(2), // MSE
140        );
141
142        importance.scores().as_slice().iter().enumerate().map(|(i, &v)| (i, v)).collect()
143    }
144
145    /// Compute integrated gradients using aprender
146    ///
147    /// # Arguments
148    ///
149    /// * `model_fn` - Model prediction function
150    /// * `sample` - Input sample to explain
151    /// * `baseline` - Baseline input (typically zeros)
152    pub fn compute_integrated_gradients<F>(
153        &self,
154        model_fn: F,
155        sample: &aprender::primitives::Vector<f32>,
156        baseline: &aprender::primitives::Vector<f32>,
157    ) -> Vec<(usize, f32)>
158    where
159        F: Fn(&aprender::primitives::Vector<f32>) -> f32,
160    {
161        let ig = aprender::interpret::IntegratedGradients::default();
162        let attributions = ig.attribute(model_fn, sample, baseline);
163
164        attributions.as_slice().iter().enumerate().map(|(i, &v)| (i, v)).collect()
165    }
166
167    /// Compute saliency map using aprender
168    ///
169    /// # Arguments
170    ///
171    /// * `model_fn` - Model prediction function
172    /// * `sample` - Input sample to explain
173    pub fn compute_saliency<F>(
174        &self,
175        model_fn: F,
176        sample: &aprender::primitives::Vector<f32>,
177    ) -> Vec<(usize, f32)>
178    where
179        F: Fn(&aprender::primitives::Vector<f32>) -> f32,
180    {
181        let sm = aprender::interpret::SaliencyMap::default();
182        let saliency = sm.compute(model_fn, sample);
183
184        saliency.as_slice().iter().enumerate().map(|(i, &v)| (i, v)).collect()
185    }
186
187    /// Get top features that have been consistently important across epochs
188    pub fn consistent_top_features(&self) -> Vec<(usize, f32)> {
189        if self.results.is_empty() {
190            return Vec::new();
191        }
192
193        // Count frequency of each feature in top-k across epochs
194        let mut freq: std::collections::HashMap<usize, (usize, f32)> =
195            std::collections::HashMap::new();
196
197        for result in &self.results {
198            for (idx, score) in &result.importances {
199                let entry = freq.entry(*idx).or_insert((0, 0.0));
200                entry.0 += 1;
201                entry.1 += score.abs();
202            }
203        }
204
205        // Average score and sort by frequency then score
206        let mut features: Vec<_> = freq
207            .into_iter()
208            .map(|(idx, (count, total))| (idx, total / count as f32, count))
209            .collect();
210
211        features.sort_by(|a, b| {
212            b.2.cmp(&a.2).then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
213        });
214
215        features.into_iter().take(self.top_k).map(|(idx, avg_score, _)| (idx, avg_score)).collect()
216    }
217}
218
219impl TrainerCallback for ExplainabilityCallback {
220    fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
221        // Note: Actual computation requires model and data access
222        // This callback stores configuration and results
223        // Users should call compute_* methods and record_importances externally
224        let _ = ctx; // Acknowledge context
225        CallbackAction::Continue
226    }
227
228    fn name(&self) -> &'static str {
229        "ExplainabilityCallback"
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_explainability_callback_creation() {
        let cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);
        assert_eq!(cb.method(), ExplainMethod::PermutationImportance);
        assert_eq!(cb.top_k(), 10); // Default
        assert_eq!(cb.eval_samples(), 50); // Default
        assert!(cb.results().is_empty());
    }

    #[test]
    fn test_explainability_callback_builder() {
        let cb = ExplainabilityCallback::new(ExplainMethod::IntegratedGradients)
            .with_top_k(5)
            .with_eval_samples(100)
            .with_feature_names(vec!["f1".to_string(), "f2".to_string()]);

        assert_eq!(cb.method(), ExplainMethod::IntegratedGradients);
        assert_eq!(cb.top_k(), 5);
        assert_eq!(cb.eval_samples(), 100);
        assert_eq!(cb.feature_names(), Some(&["f1".to_string(), "f2".to_string()][..]));
    }

    #[test]
    fn test_explainability_callback_record_importances() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(3);

        // Record importances for epoch 0
        let importances = vec![(0, 0.5), (1, 0.3), (2, 0.8), (3, 0.1), (4, 0.6)];
        cb.record_importances(0, importances);

        assert_eq!(cb.results().len(), 1);
        let result = &cb.results()[0];
        assert_eq!(result.epoch, 0);
        assert_eq!(result.method, ExplainMethod::Saliency);
        assert_eq!(result.importances.len(), 3); // Top 3

        // Sorted by absolute value descending; scores travel with their indices.
        assert_eq!(result.importances[0], (2, 0.8));
        assert_eq!(result.importances[1], (4, 0.6));
        assert_eq!(result.importances[2], (0, 0.5));
    }

    #[test]
    fn test_explainability_record_importances_top_k_zero() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(0);
        cb.record_importances(3, vec![(0, 0.5), (1, 0.9)]);

        // The epoch is still recorded, just with no retained features.
        assert_eq!(cb.results().len(), 1);
        assert_eq!(cb.results()[0].epoch, 3);
        assert!(cb.results()[0].importances.is_empty());
        assert!(cb.consistent_top_features().is_empty());
    }

    #[test]
    fn test_explainability_record_importances_ties_keep_input_order() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(3);
        // Same magnitude regardless of sign: the stable sort keeps input order.
        cb.record_importances(0, vec![(7, 0.5), (3, -0.5), (5, 0.5)]);

        let order: Vec<usize> = cb.results()[0].importances.iter().map(|&(i, _)| i).collect();
        assert_eq!(order, vec![7, 3, 5]);
    }

    #[test]
    fn test_explainability_callback_consistent_features() {
        let mut cb =
            ExplainabilityCallback::new(ExplainMethod::PermutationImportance).with_top_k(2);

        // Epoch 0: features 0 and 1 are important
        cb.record_importances(0, vec![(0, 0.8), (1, 0.6), (2, 0.1)]);
        // Epoch 1: features 0 and 2 are important
        cb.record_importances(1, vec![(0, 0.7), (2, 0.5), (1, 0.2)]);
        // Epoch 2: features 0 and 1 are important again
        cb.record_importances(2, vec![(0, 0.9), (1, 0.4), (2, 0.3)]);

        let consistent = cb.consistent_top_features();
        // Appearances in the per-epoch top 2: feature 0 -> 3, feature 1 -> 2, feature 2 -> 1.
        assert_eq!(consistent.len(), 2);
        assert_eq!(consistent[0].0, 0);
        assert_eq!(consistent[1].0, 1);
        // Scores are mean |importance| over the epochs each feature appeared in.
        assert!((consistent[0].1 - 0.8).abs() < 1e-6);
        assert!((consistent[1].1 - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_explainability_callback_trainer_callback_impl() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);
        let ctx = CallbackContext::default();

        // Should always continue (doesn't auto-compute)
        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.name(), "ExplainabilityCallback");
        assert!(cb.results().is_empty());
    }

    #[test]
    fn test_explain_method_enum() {
        // All variants are distinct
        assert_ne!(ExplainMethod::PermutationImportance, ExplainMethod::IntegratedGradients);
        assert_ne!(ExplainMethod::IntegratedGradients, ExplainMethod::Saliency);
        assert_ne!(ExplainMethod::Saliency, ExplainMethod::PermutationImportance);

        // `method` is still usable after the assignment only because the type is Copy.
        let method = ExplainMethod::Saliency;
        let copied = method;
        assert_eq!(method, copied);
    }

    #[test]
    fn test_feature_importance_result_fields() {
        let result = FeatureImportanceResult {
            epoch: 5,
            importances: vec![(0, 0.9), (1, 0.7)],
            method: ExplainMethod::IntegratedGradients,
        };

        assert_eq!(result.epoch, 5);
        assert_eq!(result.importances.len(), 2);
        assert_eq!(result.method, ExplainMethod::IntegratedGradients);
    }

    #[test]
    fn test_explainability_empty_results() {
        let cb = ExplainabilityCallback::new(ExplainMethod::Saliency);
        assert!(cb.consistent_top_features().is_empty());
    }

    #[test]
    fn test_explainability_feature_names_none() {
        let cb = ExplainabilityCallback::new(ExplainMethod::Saliency);
        assert!(cb.feature_names().is_none());
    }

    #[test]
    fn test_explainability_record_importances_negative() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(2);
        let importances = vec![(0, -0.9), (1, 0.5), (2, -0.3)];
        cb.record_importances(0, importances);
        let result = &cb.results()[0];
        // Ranking uses |score|, and the stored score keeps its sign.
        assert_eq!(result.importances[0], (0, -0.9));
        assert_eq!(result.importances[1], (1, 0.5));
    }

    #[test]
    fn test_explainability_callback_basic() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);
        assert_eq!(cb.name(), "ExplainabilityCallback");

        let mut ctx = CallbackContext::default();
        ctx.step = 5;
        ctx.loss = 0.5;

        cb.on_step_end(&ctx);
        // The callback never computes attributions on its own, so nothing is recorded.
        assert!(cb.results().is_empty());
    }

    #[test]
    fn test_explainability_compute_permutation_importance() {
        let cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);

        // Create sample data using aprender's Vector type
        let x = vec![
            aprender::primitives::Vector::from_slice(&[1.0, 2.0, 3.0]),
            aprender::primitives::Vector::from_slice(&[4.0, 5.0, 6.0]),
            aprender::primitives::Vector::from_slice(&[7.0, 8.0, 9.0]),
        ];
        let y = vec![1.0, 2.0, 3.0];

        // Simple linear prediction function
        let predict_fn = |v: &aprender::primitives::Vector<f32>| -> f32 {
            v.as_slice()[0] * 0.1 + v.as_slice()[1] * 0.2
        };

        let importance = cb.compute_permutation_importance(predict_fn, &x, &y);
        assert_eq!(importance.len(), 3);
        // Pairs come back in feature-index order.
        let indices: Vec<usize> = importance.iter().map(|&(i, _)| i).collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    fn test_explainability_compute_integrated_gradients() {
        let cb = ExplainabilityCallback::new(ExplainMethod::IntegratedGradients);

        let sample = aprender::primitives::Vector::from_slice(&[1.0, 2.0, 3.0]);
        let baseline = aprender::primitives::Vector::from_slice(&[0.0, 0.0, 0.0]);

        let model_fn =
            |v: &aprender::primitives::Vector<f32>| -> f32 { v.as_slice().iter().sum::<f32>() };

        let attributions = cb.compute_integrated_gradients(model_fn, &sample, &baseline);
        assert_eq!(attributions.len(), 3);
    }

    #[test]
    fn test_explainability_compute_saliency() {
        let cb = ExplainabilityCallback::new(ExplainMethod::Saliency);

        let sample = aprender::primitives::Vector::from_slice(&[1.0, 2.0, 3.0]);

        let model_fn =
            |v: &aprender::primitives::Vector<f32>| -> f32 { v.as_slice().iter().sum::<f32>() };

        let saliency = cb.compute_saliency(model_fn, &sample);
        assert_eq!(saliency.len(), 3);
    }
}