Skip to main content

oxicuda_quant/pruning/
mask.rs

1//! # Sparse Mask
2//!
3//! A `SparseMask` is a boolean weight mask where `false` = pruned, `true` = active.
4//! It is produced by a pruner and applied to weight tensors to zero out pruned entries.
5
/// Boolean pruning mask over a weight tensor.
///
/// `true`  → weight is active (kept).
/// `false` → weight is pruned (zeroed).
///
/// Besides `Debug` and `Clone`, the mask derives:
/// - `PartialEq`/`Eq`, so two masks can be compared directly.
/// - `Default`, which yields an empty mask (the same as
///   [`SparseMask::all_active`]`(0)`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SparseMask {
    /// Element-wise mask: entry `i` decides whether element `i` of the weight
    /// tensor is kept, so its length must equal that tensor's length.
    pub mask: Vec<bool>,
}
15
16impl SparseMask {
17    /// Create a mask with all weights active (unpruned).
18    #[must_use]
19    pub fn all_active(n: usize) -> Self {
20        Self {
21            mask: vec![true; n],
22        }
23    }
24
25    /// Create a mask with all weights pruned.
26    #[must_use]
27    pub fn all_pruned(n: usize) -> Self {
28        Self {
29            mask: vec![false; n],
30        }
31    }
32
33    /// Number of elements in the mask.
34    #[must_use]
35    pub fn len(&self) -> usize {
36        self.mask.len()
37    }
38
39    /// Returns `true` if the mask covers an empty weight tensor.
40    #[must_use]
41    pub fn is_empty(&self) -> bool {
42        self.mask.is_empty()
43    }
44
45    /// Number of active (non-pruned) weights.
46    #[must_use]
47    pub fn count_active(&self) -> usize {
48        self.mask.iter().filter(|&&m| m).count()
49    }
50
51    /// Number of pruned weights.
52    #[must_use]
53    pub fn count_pruned(&self) -> usize {
54        self.mask.iter().filter(|&&m| !m).count()
55    }
56
57    /// Fraction of weights that are pruned ∈ [0, 1].
58    #[must_use]
59    pub fn sparsity(&self) -> f32 {
60        if self.mask.is_empty() {
61            return 0.0;
62        }
63        self.count_pruned() as f32 / self.mask.len() as f32
64    }
65
66    /// Apply the mask to a weight slice: pruned weights become 0.
67    ///
68    /// Returns a new `Vec<f32>` of the same length.
69    ///
70    /// # Panics
71    ///
72    /// Panics if `weights.len() != self.len()`.
73    #[must_use]
74    pub fn apply(&self, weights: &[f32]) -> Vec<f32> {
75        assert_eq!(
76            weights.len(),
77            self.mask.len(),
78            "mask/weight length mismatch"
79        );
80        weights
81            .iter()
82            .zip(self.mask.iter())
83            .map(|(&w, &m)| if m { w } else { 0.0 })
84            .collect()
85    }
86
87    /// Apply the mask in-place: pruned weights become 0.
88    ///
89    /// # Panics
90    ///
91    /// Panics if `weights.len() != self.len()`.
92    pub fn apply_in_place(&self, weights: &mut [f32]) {
93        assert_eq!(
94            weights.len(),
95            self.mask.len(),
96            "mask/weight length mismatch"
97        );
98        for (w, &m) in weights.iter_mut().zip(self.mask.iter()) {
99            if !m {
100                *w = 0.0;
101            }
102        }
103    }
104
105    /// Combine two masks with logical AND (both must be active to stay active).
106    ///
107    /// # Panics
108    ///
109    /// Panics if the masks have different lengths.
110    #[must_use]
111    pub fn and(&self, other: &Self) -> Self {
112        assert_eq!(self.mask.len(), other.mask.len(), "mask length mismatch");
113        let mask = self
114            .mask
115            .iter()
116            .zip(other.mask.iter())
117            .map(|(&a, &b)| a && b)
118            .collect();
119        Self { mask }
120    }
121
122    /// Combine two masks with logical OR (at least one active → active).
123    ///
124    /// # Panics
125    ///
126    /// Panics if the masks have different lengths.
127    #[must_use]
128    pub fn or(&self, other: &Self) -> Self {
129        assert_eq!(self.mask.len(), other.mask.len(), "mask length mismatch");
130        let mask = self
131            .mask
132            .iter()
133            .zip(other.mask.iter())
134            .map(|(&a, &b)| a || b)
135            .collect();
136        Self { mask }
137    }
138}
139
140// ─── Tests ───────────────────────────────────────────────────────────────────
141
#[cfg(test)]
mod tests {
    // Unit tests for `SparseMask`: construction, counting, sparsity,
    // application to weights, mask combination, and the documented panics.
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn all_active_sparsity_zero() {
        let m = SparseMask::all_active(10);
        assert_abs_diff_eq!(m.sparsity(), 0.0, epsilon = 1e-6);
        assert_eq!(m.count_active(), 10);
        assert_eq!(m.count_pruned(), 0);
    }

    #[test]
    fn all_pruned_sparsity_one() {
        let m = SparseMask::all_pruned(8);
        assert_abs_diff_eq!(m.sparsity(), 1.0, epsilon = 1e-6);
        assert_eq!(m.count_active(), 0);
        assert_eq!(m.count_pruned(), 8);
    }

    #[test]
    fn len_and_is_empty() {
        let m = SparseMask::all_pruned(4);
        assert_eq!(m.len(), 4);
        assert!(!m.is_empty());
        assert!(SparseMask::all_active(0).is_empty());
    }

    #[test]
    fn empty_mask_has_zero_sparsity() {
        // An empty mask must not divide by zero; it reports 0.0.
        let m = SparseMask::all_active(0);
        assert_eq!(m.len(), 0);
        assert_abs_diff_eq!(m.sparsity(), 0.0, epsilon = 1e-6);
        assert!(m.apply(&[]).is_empty());
    }

    #[test]
    fn active_plus_pruned_equals_len() {
        let m = SparseMask {
            mask: vec![true, false, false, true, false, true, true],
        };
        assert_eq!(m.count_active() + m.count_pruned(), m.len());
    }

    #[test]
    fn apply_zeroes_pruned() {
        let m = SparseMask {
            mask: vec![true, false, true, false],
        };
        let w = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = m.apply(&w);
        assert_eq!(out.len(), w.len());
        assert_abs_diff_eq!(out[0], 1.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[1], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[2], 3.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[3], 0.0, epsilon = 1e-7);
    }

    #[test]
    fn apply_in_place_modifies_weights() {
        let m = SparseMask {
            mask: vec![true, false, true],
        };
        let mut w = vec![5.0_f32, 9.0, 3.0];
        m.apply_in_place(&mut w);
        assert_abs_diff_eq!(w[0], 5.0, epsilon = 1e-7);
        assert_abs_diff_eq!(w[1], 0.0, epsilon = 1e-7);
        // Kept trailing element must be left untouched.
        assert_abs_diff_eq!(w[2], 3.0, epsilon = 1e-7);
    }

    #[test]
    #[should_panic(expected = "mask/weight length mismatch")]
    fn apply_panics_on_length_mismatch() {
        let m = SparseMask::all_active(3);
        let _ = m.apply(&[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "mask/weight length mismatch")]
    fn apply_in_place_panics_on_length_mismatch() {
        let m = SparseMask::all_active(2);
        let mut w = [1.0_f32; 3];
        m.apply_in_place(&mut w);
    }

    #[test]
    fn sparsity_partial() {
        let m = SparseMask {
            mask: vec![true, false, true, false, false],
        };
        assert_abs_diff_eq!(m.sparsity(), 3.0 / 5.0, epsilon = 1e-6);
    }

    #[test]
    fn and_mask() {
        let a = SparseMask {
            mask: vec![true, true, false],
        };
        let b = SparseMask {
            mask: vec![true, false, false],
        };
        let c = a.and(&b);
        assert_eq!(c.mask, vec![true, false, false]);
    }

    #[test]
    fn or_mask() {
        let a = SparseMask {
            mask: vec![true, false, false],
        };
        let b = SparseMask {
            mask: vec![false, true, false],
        };
        let c = a.or(&b);
        assert_eq!(c.mask, vec![true, true, false]);
    }

    #[test]
    #[should_panic(expected = "mask length mismatch")]
    fn and_panics_on_length_mismatch() {
        let _ = SparseMask::all_active(2).and(&SparseMask::all_active(3));
    }

    #[test]
    #[should_panic(expected = "mask length mismatch")]
    fn or_panics_on_length_mismatch() {
        let _ = SparseMask::all_active(2).or(&SparseMask::all_active(3));
    }
}