/// Boolean keep/prune mask over a flat buffer of `f32` weights.
///
/// `mask[i] == true` means element `i` is active (kept); `false` means it is
/// pruned and gets zeroed by [`SparseMask::apply`] / [`SparseMask::apply_in_place`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparseMask {
    /// Per-element flags: `true` = active (kept), `false` = pruned.
    pub mask: Vec<bool>,
}

impl SparseMask {
    /// Creates a mask of length `n` in which every element is active.
    #[must_use]
    pub fn all_active(n: usize) -> Self {
        Self {
            mask: vec![true; n],
        }
    }

    /// Creates a mask of length `n` in which every element is pruned.
    #[must_use]
    pub fn all_pruned(n: usize) -> Self {
        Self {
            mask: vec![false; n],
        }
    }

    /// Number of elements covered by the mask.
    #[must_use]
    pub fn len(&self) -> usize {
        self.mask.len()
    }

    /// Returns `true` if the mask covers no elements.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.mask.is_empty()
    }

    /// Number of active (kept) elements.
    #[must_use]
    pub fn count_active(&self) -> usize {
        self.mask.iter().filter(|&&m| m).count()
    }

    /// Number of pruned elements.
    #[must_use]
    pub fn count_pruned(&self) -> usize {
        self.mask.iter().filter(|&&m| !m).count()
    }

    /// Fraction of elements that are pruned, in `[0.0, 1.0]`.
    ///
    /// An empty mask reports a sparsity of `0.0`.
    #[must_use]
    pub fn sparsity(&self) -> f32 {
        if self.mask.is_empty() {
            return 0.0;
        }
        // Divide in f64: `usize as f32` is only exact up to 2^24 elements, a
        // size real weight tensors routinely exceed, whereas f64 represents
        // counts exactly up to 2^53. Narrow to f32 only once, at the end.
        (self.count_pruned() as f64 / self.mask.len() as f64) as f32
    }

    /// Returns a copy of `weights` with every pruned position set to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if `weights.len() != self.len()`.
    #[must_use]
    pub fn apply(&self, weights: &[f32]) -> Vec<f32> {
        assert_eq!(
            weights.len(),
            self.mask.len(),
            "mask/weight length mismatch"
        );
        weights
            .iter()
            .zip(&self.mask)
            .map(|(&w, &m)| if m { w } else { 0.0 })
            .collect()
    }

    /// Sets every pruned position of `weights` to `0.0` in place; active
    /// positions are left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `weights.len() != self.len()`.
    pub fn apply_in_place(&self, weights: &mut [f32]) {
        assert_eq!(
            weights.len(),
            self.mask.len(),
            "mask/weight length mismatch"
        );
        for (w, &m) in weights.iter_mut().zip(&self.mask) {
            if !m {
                *w = 0.0;
            }
        }
    }

    /// Element-wise logical AND: an element is active only if it is active
    /// in both masks.
    ///
    /// # Panics
    ///
    /// Panics if the two masks differ in length.
    #[must_use]
    pub fn and(&self, other: &Self) -> Self {
        self.zip_with(other, |a, b| a && b)
    }

    /// Element-wise logical OR: an element is active if it is active in
    /// either mask.
    ///
    /// # Panics
    ///
    /// Panics if the two masks differ in length.
    #[must_use]
    pub fn or(&self, other: &Self) -> Self {
        self.zip_with(other, |a, b| a || b)
    }

    /// Combines two equal-length masks element-wise with `op`; shared body of
    /// [`SparseMask::and`] and [`SparseMask::or`].
    #[track_caller]
    fn zip_with(&self, other: &Self, op: impl Fn(bool, bool) -> bool) -> Self {
        assert_eq!(self.mask.len(), other.mask.len(), "mask length mismatch");
        let mask = self
            .mask
            .iter()
            .zip(&other.mask)
            .map(|(&a, &b)| op(a, b))
            .collect();
        Self { mask }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SparseMask`: counting, sparsity, masking and the
    //! element-wise combinators, including empty and length-mismatch cases.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn all_active_sparsity_zero() {
        let m = SparseMask::all_active(10);
        assert_abs_diff_eq!(m.sparsity(), 0.0, epsilon = 1e-6);
        assert_eq!(m.count_active(), 10);
        assert_eq!(m.count_pruned(), 0);
    }

    #[test]
    fn all_pruned_sparsity_one() {
        let m = SparseMask::all_pruned(8);
        assert_abs_diff_eq!(m.sparsity(), 1.0, epsilon = 1e-6);
        assert_eq!(m.count_active(), 0);
        assert_eq!(m.count_pruned(), 8);
    }

    #[test]
    fn empty_mask() {
        let m = SparseMask::all_active(0);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
        // Empty masks report zero sparsity rather than dividing by zero.
        assert_abs_diff_eq!(m.sparsity(), 0.0, epsilon = 1e-6);
        assert!(m.apply(&[]).is_empty());
    }

    #[test]
    fn apply_zeroes_pruned() {
        let m = SparseMask {
            mask: vec![true, false, true, false],
        };
        let w = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = m.apply(&w);
        assert_abs_diff_eq!(out[0], 1.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[1], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[2], 3.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[3], 0.0, epsilon = 1e-7);
    }

    #[test]
    fn apply_in_place_modifies_weights() {
        let m = SparseMask {
            mask: vec![true, false, true],
        };
        let mut w = vec![5.0_f32, 9.0, 3.0];
        m.apply_in_place(&mut w);
        assert_abs_diff_eq!(w[0], 5.0, epsilon = 1e-7);
        assert_abs_diff_eq!(w[1], 0.0, epsilon = 1e-7);
        // Active positions must be left untouched, not just the first one.
        assert_abs_diff_eq!(w[2], 3.0, epsilon = 1e-7);
    }

    #[test]
    #[should_panic(expected = "mask/weight length mismatch")]
    fn apply_length_mismatch_panics() {
        let m = SparseMask::all_active(3);
        let _ = m.apply(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "mask/weight length mismatch")]
    fn apply_in_place_length_mismatch_panics() {
        let m = SparseMask::all_active(3);
        let mut w = [1.0_f32, 2.0];
        m.apply_in_place(&mut w);
    }

    #[test]
    fn sparsity_partial() {
        let m = SparseMask {
            mask: vec![true, false, true, false, false],
        };
        assert_abs_diff_eq!(m.sparsity(), 3.0 / 5.0, epsilon = 1e-6);
    }

    #[test]
    fn and_mask() {
        let a = SparseMask {
            mask: vec![true, true, false],
        };
        let b = SparseMask {
            mask: vec![true, false, false],
        };
        let c = a.and(&b);
        assert_eq!(c.mask, vec![true, false, false]);
    }

    #[test]
    fn or_mask() {
        let a = SparseMask {
            mask: vec![true, false, false],
        };
        let b = SparseMask {
            mask: vec![false, true, false],
        };
        let c = a.or(&b);
        assert_eq!(c.mask, vec![true, true, false]);
    }

    #[test]
    #[should_panic(expected = "mask length mismatch")]
    fn and_length_mismatch_panics() {
        let _ = SparseMask::all_active(2).and(&SparseMask::all_active(3));
    }

    #[test]
    #[should_panic(expected = "mask length mismatch")]
    fn or_length_mismatch_panics() {
        let _ = SparseMask::all_pruned(4).or(&SparseMask::all_pruned(1));
    }
}