use rayon::prelude::*;
/// A model that maps a slice of `f64` feature values to a single `f64` output.
///
/// The `Sync` supertrait means a shared reference to the model may be used
/// from several threads at once; `masked_prediction` in this file calls
/// `predict` from inside a rayon parallel iterator.
pub trait MaskedModel: Sync {
/// Returns the model's output for the feature values in `input`.
fn predict(&self, input: &[f64]) -> f64;
}
/// Blanket implementation: any `Sync` callable of shape `Fn(&[f64]) -> f64`
/// (closure or plain function) can be used wherever a [`MaskedModel`] is
/// expected.
impl<F: Fn(&[f64]) -> f64 + Sync> MaskedModel for F {
    /// Forwards `input` to the underlying callable and returns its result.
    fn predict(&self, input: &[f64]) -> f64 {
        self(input)
    }
}
/// Averages the model's predictions over hybrid inputs built from `x` and
/// each row of `background`.
///
/// For every background row a fresh input vector is assembled: position `i`
/// takes `x[i]` when `mask[i] == 1`, and the row's value at `i` for any other
/// mask value. The model is evaluated on each hybrid input in parallel (via
/// rayon) and the arithmetic mean of the results is returned.
///
/// If `background` is empty the result is `NaN` (zero divided by zero).
///
/// # Panics
///
/// - if `x.len() != mask.len()`;
/// - if any row of `background` has a length different from `x.len()`.
pub fn masked_prediction(
    model: &dyn MaskedModel,
    x: &[f64],
    background: &[Vec<f64>],
    mask: &[u8],
) -> f64 {
    assert_eq!(x.len(), mask.len());
    // Validate every row before any model call so that a malformed background
    // fails deterministically with a clear message, instead of being silently
    // truncated (row too long) or panicking on an index inside a worker
    // thread (row too short).
    for (row_idx, bg_row) in background.iter().enumerate() {
        assert_eq!(
            x.len(),
            bg_row.len(),
            "background row {row_idx} length does not match x"
        );
    }
    let num_bg = background.len() as f64;
    let sum: f64 = background
        .par_iter()
        .map(|bg_row| {
            let input: Vec<f64> = x
                .iter()
                .zip(bg_row)
                .zip(mask)
                .map(|((&x_val, &bg_val), &flag)| if flag == 1 { x_val } else { bg_val })
                .collect();
            model.predict(&input)
        })
        .sum();
    sum / num_bg
}
/// Builds a hybrid feature vector from `x` and `background_row`.
///
/// Position `i` of the result is `x[i]` when `mask[i] == 1`; for any other
/// mask value it is `background_row[i]`.
///
/// # Panics
///
/// Panics if `mask` or `background_row` does not have the same length as `x`.
pub fn apply_mask(
    x: &[f64],
    background_row: &[f64],
    mask: &[u8],
) -> Vec<f64> {
    assert_eq!(x.len(), mask.len());
    assert_eq!(x.len(), background_row.len());
    x.iter()
        .zip(background_row)
        .zip(mask)
        .map(|((&kept, &fallback), &flag)| if flag == 1 { kept } else { fallback })
        .collect()
}