1use super::traits::{CallbackAction, CallbackContext, TrainerCallback};
4
/// Attribution technique an [`ExplainabilityCallback`] is configured with.
///
/// The chosen variant is stamped onto every recorded [`FeatureImportanceResult`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExplainMethod {
    /// Permutation importance (see `ExplainabilityCallback::compute_permutation_importance`).
    PermutationImportance,
    /// Integrated Gradients (see `ExplainabilityCallback::compute_integrated_gradients`).
    IntegratedGradients,
    /// Saliency maps (see `ExplainabilityCallback::compute_saliency`).
    Saliency,
}
15
16#[derive(Debug, Clone)]
18pub struct FeatureImportanceResult {
19 pub epoch: usize,
21 pub importances: Vec<(usize, f32)>,
23 pub method: ExplainMethod,
25}
26
27#[derive(Debug)]
42pub struct ExplainabilityCallback {
43 method: ExplainMethod,
44 top_k: usize,
45 eval_samples: usize,
46 results: Vec<FeatureImportanceResult>,
47 feature_names: Option<Vec<String>>,
48}
49
50impl ExplainabilityCallback {
51 pub fn new(method: ExplainMethod) -> Self {
57 Self { method, top_k: 10, eval_samples: 50, results: Vec::new(), feature_names: None }
58 }
59
60 pub fn with_top_k(mut self, k: usize) -> Self {
62 self.top_k = k;
63 self
64 }
65
66 pub fn with_eval_samples(mut self, n: usize) -> Self {
68 self.eval_samples = n;
69 self
70 }
71
72 pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
74 self.feature_names = Some(names);
75 self
76 }
77
78 pub fn method(&self) -> ExplainMethod {
80 self.method
81 }
82
83 pub fn top_k(&self) -> usize {
85 self.top_k
86 }
87
88 pub fn eval_samples(&self) -> usize {
90 self.eval_samples
91 }
92
93 pub fn results(&self) -> &[FeatureImportanceResult] {
95 &self.results
96 }
97
98 pub fn feature_names(&self) -> Option<&[String]> {
100 self.feature_names.as_deref()
101 }
102
103 pub fn record_importances(&mut self, epoch: usize, importances: Vec<(usize, f32)>) {
107 let mut sorted = importances;
108 sorted
109 .sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(std::cmp::Ordering::Equal));
110 sorted.truncate(self.top_k);
111
112 self.results.push(FeatureImportanceResult {
113 epoch,
114 importances: sorted,
115 method: self.method,
116 });
117 }
118
119 pub fn compute_permutation_importance<P>(
127 &self,
128 predict_fn: P,
129 x: &[aprender::primitives::Vector<f32>],
130 y: &[f32],
131 ) -> Vec<(usize, f32)>
132 where
133 P: Fn(&aprender::primitives::Vector<f32>) -> f32,
134 {
135 let importance = aprender::interpret::PermutationImportance::compute(
136 predict_fn,
137 x,
138 y,
139 |pred, true_val| (pred - true_val).powi(2), );
141
142 importance.scores().as_slice().iter().enumerate().map(|(i, &v)| (i, v)).collect()
143 }
144
145 pub fn compute_integrated_gradients<F>(
153 &self,
154 model_fn: F,
155 sample: &aprender::primitives::Vector<f32>,
156 baseline: &aprender::primitives::Vector<f32>,
157 ) -> Vec<(usize, f32)>
158 where
159 F: Fn(&aprender::primitives::Vector<f32>) -> f32,
160 {
161 let ig = aprender::interpret::IntegratedGradients::default();
162 let attributions = ig.attribute(model_fn, sample, baseline);
163
164 attributions.as_slice().iter().enumerate().map(|(i, &v)| (i, v)).collect()
165 }
166
167 pub fn compute_saliency<F>(
174 &self,
175 model_fn: F,
176 sample: &aprender::primitives::Vector<f32>,
177 ) -> Vec<(usize, f32)>
178 where
179 F: Fn(&aprender::primitives::Vector<f32>) -> f32,
180 {
181 let sm = aprender::interpret::SaliencyMap::default();
182 let saliency = sm.compute(model_fn, sample);
183
184 saliency.as_slice().iter().enumerate().map(|(i, &v)| (i, v)).collect()
185 }
186
187 pub fn consistent_top_features(&self) -> Vec<(usize, f32)> {
189 if self.results.is_empty() {
190 return Vec::new();
191 }
192
193 let mut freq: std::collections::HashMap<usize, (usize, f32)> =
195 std::collections::HashMap::new();
196
197 for result in &self.results {
198 for (idx, score) in &result.importances {
199 let entry = freq.entry(*idx).or_insert((0, 0.0));
200 entry.0 += 1;
201 entry.1 += score.abs();
202 }
203 }
204
205 let mut features: Vec<_> = freq
207 .into_iter()
208 .map(|(idx, (count, total))| (idx, total / count as f32, count))
209 .collect();
210
211 features.sort_by(|a, b| {
212 b.2.cmp(&a.2).then_with(|| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal))
213 });
214
215 features.into_iter().take(self.top_k).map(|(idx, avg_score, _)| (idx, avg_score)).collect()
216 }
217}
218
219impl TrainerCallback for ExplainabilityCallback {
220 fn on_epoch_end(&mut self, ctx: &CallbackContext) -> CallbackAction {
221 let _ = ctx; CallbackAction::Continue
226 }
227
228 fn name(&self) -> &'static str {
229 "ExplainabilityCallback"
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_explainability_callback_creation() {
        let cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);
        assert_eq!(cb.method(), ExplainMethod::PermutationImportance);
        assert_eq!(cb.top_k(), 10);
        assert_eq!(cb.eval_samples(), 50);
        assert!(cb.results().is_empty());
    }

    #[test]
    fn test_explainability_callback_builder() {
        let cb = ExplainabilityCallback::new(ExplainMethod::IntegratedGradients)
            .with_top_k(5)
            .with_eval_samples(100)
            .with_feature_names(vec!["f1".to_string(), "f2".to_string()]);

        assert_eq!(cb.method(), ExplainMethod::IntegratedGradients);
        assert_eq!(cb.top_k(), 5);
        assert_eq!(cb.eval_samples(), 100);
        assert_eq!(cb.feature_names(), Some(&["f1".to_string(), "f2".to_string()][..]));
    }

    #[test]
    fn test_explainability_callback_record_importances() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(3);

        let importances = vec![(0, 0.5), (1, 0.3), (2, 0.8), (3, 0.1), (4, 0.6)];
        cb.record_importances(0, importances);

        assert_eq!(cb.results().len(), 1);
        let result = &cb.results()[0];
        assert_eq!(result.epoch, 0);
        assert_eq!(result.method, ExplainMethod::Saliency);
        assert_eq!(result.importances.len(), 3);
        assert_eq!(result.importances[0].0, 2);
        assert_eq!(result.importances[1].0, 4);
        assert_eq!(result.importances[2].0, 0);
    }

    #[test]
    fn test_explainability_callback_consistent_features() {
        let mut cb =
            ExplainabilityCallback::new(ExplainMethod::PermutationImportance).with_top_k(2);

        cb.record_importances(0, vec![(0, 0.8), (1, 0.6), (2, 0.1)]);
        cb.record_importances(1, vec![(0, 0.7), (2, 0.5), (1, 0.2)]);
        cb.record_importances(2, vec![(0, 0.9), (1, 0.4), (2, 0.3)]);

        // Each epoch keeps only its top-2 features, so feature 0 survives three
        // times, feature 1 twice (epochs 0 and 2) and feature 2 once (epoch 1).
        let consistent = cb.consistent_top_features();
        assert_eq!(consistent.len(), 2);
        assert_eq!(consistent[0].0, 0);
        assert!((consistent[0].1 - 0.8).abs() < 1e-5);
        assert_eq!(consistent[1].0, 1);
        assert!((consistent[1].1 - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_explainability_callback_trainer_callback_impl() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);
        let ctx = CallbackContext::default();

        assert_eq!(cb.on_epoch_end(&ctx), CallbackAction::Continue);
        assert_eq!(cb.name(), "ExplainabilityCallback");
    }

    #[test]
    fn test_explain_method_enum() {
        assert_ne!(ExplainMethod::PermutationImportance, ExplainMethod::IntegratedGradients);
        assert_ne!(ExplainMethod::IntegratedGradients, ExplainMethod::Saliency);
        assert_ne!(ExplainMethod::Saliency, ExplainMethod::PermutationImportance);

        let method = ExplainMethod::Saliency;
        let cloned = method;
        assert_eq!(method, cloned);
    }

    #[test]
    fn test_feature_importance_result_fields() {
        let result = FeatureImportanceResult {
            epoch: 5,
            importances: vec![(0, 0.9), (1, 0.7)],
            method: ExplainMethod::IntegratedGradients,
        };

        assert_eq!(result.epoch, 5);
        assert_eq!(result.importances.len(), 2);
        assert_eq!(result.method, ExplainMethod::IntegratedGradients);
    }

    #[test]
    fn test_explainability_empty_results() {
        let cb = ExplainabilityCallback::new(ExplainMethod::Saliency);
        assert!(cb.consistent_top_features().is_empty());
    }

    #[test]
    fn test_explainability_feature_names_none() {
        let cb = ExplainabilityCallback::new(ExplainMethod::Saliency);
        assert!(cb.feature_names().is_none());
    }

    #[test]
    fn test_explainability_record_importances_negative() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(2);
        let importances = vec![(0, -0.9), (1, 0.5), (2, -0.3)];
        cb.record_importances(0, importances);
        let result = &cb.results()[0];
        assert_eq!(result.importances[0].0, 0);
        assert_eq!(result.importances[1].0, 1);
    }

    #[test]
    fn test_explainability_record_importances_ties_keep_input_order() {
        // Equal magnitudes (0.5 and -0.5) must keep their input order.
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(2);
        cb.record_importances(0, vec![(3, 0.5), (1, -0.5), (2, 0.1)]);
        let kept: Vec<usize> = cb.results()[0].importances.iter().map(|&(i, _)| i).collect();
        assert_eq!(kept, vec![3, 1]);
    }

    #[test]
    fn test_explainability_top_k_zero_keeps_nothing() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::Saliency).with_top_k(0);
        cb.record_importances(0, vec![(0, 1.0)]);
        assert_eq!(cb.results().len(), 1);
        assert!(cb.results()[0].importances.is_empty());
        assert!(cb.consistent_top_features().is_empty());
    }

    #[test]
    fn test_explainability_callback_basic() {
        let mut cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);
        assert_eq!(cb.name(), "ExplainabilityCallback");

        let mut ctx = CallbackContext::default();
        ctx.step = 5;
        ctx.loss = 0.5;

        cb.on_step_end(&ctx);
    }

    #[test]
    fn test_explainability_compute_permutation_importance() {
        let cb = ExplainabilityCallback::new(ExplainMethod::PermutationImportance);

        let x = vec![
            aprender::primitives::Vector::from_slice(&[1.0, 2.0, 3.0]),
            aprender::primitives::Vector::from_slice(&[4.0, 5.0, 6.0]),
            aprender::primitives::Vector::from_slice(&[7.0, 8.0, 9.0]),
        ];
        let y = vec![1.0, 2.0, 3.0];

        let predict_fn = |v: &aprender::primitives::Vector<f32>| -> f32 {
            v.as_slice()[0] * 0.1 + v.as_slice()[1] * 0.2
        };

        let importance = cb.compute_permutation_importance(predict_fn, &x, &y);
        assert_eq!(importance.len(), 3);
    }

    #[test]
    fn test_explainability_compute_integrated_gradients() {
        let cb = ExplainabilityCallback::new(ExplainMethod::IntegratedGradients);

        let sample = aprender::primitives::Vector::from_slice(&[1.0, 2.0, 3.0]);
        let baseline = aprender::primitives::Vector::from_slice(&[0.0, 0.0, 0.0]);

        let model_fn =
            |v: &aprender::primitives::Vector<f32>| -> f32 { v.as_slice().iter().sum::<f32>() };

        let attributions = cb.compute_integrated_gradients(model_fn, &sample, &baseline);
        assert_eq!(attributions.len(), 3);
    }

    #[test]
    fn test_explainability_compute_saliency() {
        let cb = ExplainabilityCallback::new(ExplainMethod::Saliency);

        let sample = aprender::primitives::Vector::from_slice(&[1.0, 2.0, 3.0]);

        let model_fn =
            |v: &aprender::primitives::Vector<f32>| -> f32 { v.as_slice().iter().sum::<f32>() };

        let saliency = cb.compute_saliency(model_fn, &sample);
        assert_eq!(saliency.len(), 3);
    }
}