impl Default for SaliencyMap {
fn default() -> Self {
Self::new()
}
}
/// Gradient-free search for counterfactual examples.
///
/// Given an input and a black-box classifier, [`CounterfactualExplainer::find`] looks for a
/// perturbed input that the classifier assigns to a requested target class. Each step
/// estimates a gradient with central finite differences of a 0/1 loss (0 when the model
/// predicts the target class, 1 otherwise). It then adds a small squared-distance penalty
/// that pulls the candidate back towards the original input.
#[derive(Debug, Clone)]
pub struct CounterfactualExplainer {
    /// Upper bound on the number of descent steps.
    max_iter: usize,
    /// Learning rate applied to the estimated gradient on each step.
    step_size: f32,
    /// Half-width of each central finite-difference probe.
    epsilon: f32,
}

impl CounterfactualExplainer {
    /// Finite-difference half-width used by [`CounterfactualExplainer::new`].
    const DEFAULT_EPSILON: f32 = 1e-4;
    /// Weight of the squared-distance penalty relative to the classification term.
    const DISTANCE_WEIGHT: f32 = 0.001;

    /// Creates an explainer that runs at most `max_iter` steps of size `step_size`.
    ///
    /// It probes with the default finite-difference half-width of `1e-4`.
    #[must_use]
    pub fn new(max_iter: usize, step_size: f32) -> Self {
        Self {
            max_iter,
            step_size,
            epsilon: Self::DEFAULT_EPSILON,
        }
    }

    /// Returns this explainer reconfigured to probe with half-width `epsilon`.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not finite and strictly positive. The finite-difference
    /// quotient divides by `2 * epsilon`.
    #[must_use]
    pub fn with_epsilon(mut self, epsilon: f32) -> Self {
        assert!(
            epsilon.is_finite() && epsilon > 0.0,
            "epsilon must be finite and positive, got {epsilon}"
        );
        self.epsilon = epsilon;
        self
    }

    /// Maximum number of descent steps.
    #[must_use]
    pub fn max_iter(&self) -> usize {
        self.max_iter
    }

    /// Learning rate applied on each descent step.
    #[must_use]
    pub fn step_size(&self) -> f32 {
        self.step_size
    }

    /// Half-width of the central finite-difference probes.
    #[must_use]
    pub fn epsilon(&self) -> f32 {
        self.epsilon
    }

    /// Searches for an input near `original` that `model_fn` classifies as `target_class`.
    ///
    /// The candidate is checked before every step and once more after the last step. So
    /// when `original` is already classified as `target_class`, it is returned with
    /// distance `0.0`, even if `max_iter` is 0.
    ///
    /// The classification loss is piecewise constant, so its finite-difference estimate is
    /// non-zero only where the probes at `±epsilon` disagree about the predicted class. At
    /// the starting point the distance penalty is also zero. The search therefore never
    /// moves unless, for some feature, `original` lies within `epsilon` of a decision
    /// boundary.
    ///
    /// Each step calls `model_fn` `2 * n + 1` times, where `n` is the number of features.
    ///
    /// Returns `None` if no candidate is classified as `target_class` within `max_iter`
    /// steps.
    #[must_use]
    pub fn find<F>(
        &self,
        original: &Vector<f32>,
        target_class: usize,
        model_fn: F,
    ) -> Option<CounterfactualResult>
    where
        F: Fn(&Vector<f32>) -> usize,
    {
        let n = original.len();
        let mut current = original.clone();
        // Scratch copy kept equal to `current` between steps. Single coordinates are
        // nudged and restored in place instead of cloning two vectors per feature.
        let mut probe = current.clone();
        let mut gradient = vec![0.0f32; n];

        for _ in 0..self.max_iter {
            if model_fn(&current) == target_class {
                return Some(Self::build_result(current, original, target_class));
            }

            for (i, slot) in gradient.iter_mut().enumerate() {
                let base = current[i];
                probe[i] = base + self.epsilon;
                let loss_plus = Self::class_loss(model_fn(&probe), target_class);
                probe[i] = base - self.epsilon;
                let loss_minus = Self::class_loss(model_fn(&probe), target_class);
                // Assign (rather than add back) so no floating-point drift accumulates.
                probe[i] = base;

                let class_grad = (loss_plus - loss_minus) / (2.0 * self.epsilon);
                let dist_grad = 2.0 * (base - original[i]);
                *slot = class_grad + Self::DISTANCE_WEIGHT * dist_grad;
            }

            // Apply the gradient only after every coordinate was probed at the same point,
            // keeping the scratch copy in sync with the new candidate.
            for (i, &g) in gradient.iter().enumerate() {
                current[i] -= self.step_size * g;
                probe[i] = current[i];
            }
        }

        (model_fn(&current) == target_class)
            .then(|| Self::build_result(current, original, target_class))
    }

    /// 0/1 classification loss: `0.0` when `predicted == target_class`, `1.0` otherwise.
    fn class_loss(predicted: usize, target_class: usize) -> f32 {
        if predicted == target_class {
            0.0
        } else {
            1.0
        }
    }

    /// Packages a successful candidate together with its distance from `original`.
    fn build_result(
        counterfactual: Vector<f32>,
        original: &Vector<f32>,
        target_class: usize,
    ) -> CounterfactualResult {
        let distance = Self::euclidean_distance(&counterfactual, original);
        CounterfactualResult {
            counterfactual,
            original: original.clone(),
            target_class,
            distance,
        }
    }

    /// Distance between two vectors, delegated to the crate's shared functional helper.
    fn euclidean_distance(a: &Vector<f32>, b: &Vector<f32>) -> f32 {
        crate::nn::functional::euclidean_distance(a.as_slice(), b.as_slice())
    }
}
/// Outcome of a successful counterfactual search, as produced by
/// `CounterfactualExplainer::find`.
#[derive(Debug, Clone)]
pub struct CounterfactualResult {
    /// Input that the model classified as `target_class`.
    pub counterfactual: Vector<f32>,
    /// Input the search started from.
    pub original: Vector<f32>,
    /// Class the counterfactual was searched for.
    pub target_class: usize,
    /// Distance between `counterfactual` and `original`.
    ///
    /// `CounterfactualExplainer::find` sets this to the Euclidean distance.
    pub distance: f32,
}

impl CounterfactualResult {
    /// Per-feature change `counterfactual[i] - original[i]`.
    ///
    /// If the two vectors differ in length, the output is truncated to the shorter one.
    #[must_use]
    pub fn feature_changes(&self) -> Vec<f32> {
        self.counterfactual
            .as_slice()
            .iter()
            .zip(self.original.as_slice())
            .map(|(&cf, &orig)| cf - orig)
            .collect()
    }

    /// Returns up to `k` `(feature_index, change)` pairs with the largest absolute change,
    /// largest first.
    ///
    /// Ties keep ascending index order, because the sort is stable. NaN changes rank below
    /// every real value.
    #[must_use]
    pub fn top_changed_features(&self, k: usize) -> Vec<(usize, f32)> {
        let mut indexed: Vec<(usize, f32)> =
            self.feature_changes().into_iter().enumerate().collect();
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN appears, and
        // `sort_by` may panic on such comparators. `total_cmp` on a NaN-safe key is total.
        indexed.sort_by(|a, b| Self::magnitude_key(b.1).total_cmp(&Self::magnitude_key(a.1)));
        indexed.truncate(k);
        indexed
    }

    /// Ranking key: `|x|`, with NaN mapped below every real magnitude.
    fn magnitude_key(x: f32) -> f32 {
        if x.is_nan() {
            -1.0
        } else {
            x.abs()
        }
    }
}
/// A local linear surrogate explanation.
///
/// The surrogate predicts `intercept + coefficients · x`; see
/// [`LIMEExplanation::local_prediction`].
#[derive(Debug, Clone)]
pub struct LIMEExplanation {
    /// Per-feature weights of the linear surrogate.
    pub coefficients: Vector<f32>,
    /// Constant term of the linear surrogate.
    pub intercept: f32,
    /// Presumably the black-box model's output for the explained instance
    /// (not used by the methods below).
    pub original_prediction: f32,
}

impl LIMEExplanation {
    /// Returns up to `k` `(feature_index, coefficient)` pairs with the largest absolute
    /// coefficient, largest first.
    ///
    /// Ties keep ascending index order, because the sort is stable. NaN coefficients rank
    /// below every real value.
    #[must_use]
    pub fn top_features(&self, k: usize) -> Vec<(usize, f32)> {
        let mut indexed: Vec<(usize, f32)> = self
            .coefficients
            .as_slice()
            .iter()
            .copied()
            .enumerate()
            .collect();
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN appears, and
        // `sort_by` may panic on such comparators. `total_cmp` on a NaN-safe key is total.
        indexed.sort_by(|a, b| Self::magnitude_key(b.1).total_cmp(&Self::magnitude_key(a.1)));
        indexed.truncate(k);
        indexed
    }

    /// Evaluates the surrogate on `sample`: `intercept + Σ coefficients[i] * sample[i]`.
    ///
    /// If `sample` and `coefficients` differ in length, the extra trailing entries of the
    /// longer one are ignored.
    #[must_use]
    pub fn local_prediction(&self, sample: &Vector<f32>) -> f32 {
        self.intercept
            + self
                .coefficients
                .as_slice()
                .iter()
                .zip(sample.as_slice())
                .map(|(&c, &x)| c * x)
                .sum::<f32>()
    }

    /// Ranking key: `|x|`, with NaN mapped below every real magnitude.
    fn magnitude_key(x: f32) -> f32 {
        if x.is_nan() {
            -1.0
        } else {
            x.abs()
        }
    }
}
// Unit tests for this module live in a separate `tests` submodule file and are compiled
// only when building tests.
#[cfg(test)]
mod tests;