// oxicuda_quant/pruning/mask.rs

/// A boolean pruning mask over a flat list of weights.
///
/// Each entry is `true` when the corresponding weight is kept (active) and
/// `false` when it is pruned.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SparseMask {
    /// One flag per weight: `true` = kept, `false` = pruned.
    pub mask: Vec<bool>,
}
15
16impl SparseMask {
17 #[must_use]
19 pub fn all_active(n: usize) -> Self {
20 Self {
21 mask: vec![true; n],
22 }
23 }
24
25 #[must_use]
27 pub fn all_pruned(n: usize) -> Self {
28 Self {
29 mask: vec![false; n],
30 }
31 }
32
33 #[must_use]
35 pub fn len(&self) -> usize {
36 self.mask.len()
37 }
38
39 #[must_use]
41 pub fn is_empty(&self) -> bool {
42 self.mask.is_empty()
43 }
44
45 #[must_use]
47 pub fn count_active(&self) -> usize {
48 self.mask.iter().filter(|&&m| m).count()
49 }
50
51 #[must_use]
53 pub fn count_pruned(&self) -> usize {
54 self.mask.iter().filter(|&&m| !m).count()
55 }
56
57 #[must_use]
59 pub fn sparsity(&self) -> f32 {
60 if self.mask.is_empty() {
61 return 0.0;
62 }
63 self.count_pruned() as f32 / self.mask.len() as f32
64 }
65
66 #[must_use]
74 pub fn apply(&self, weights: &[f32]) -> Vec<f32> {
75 assert_eq!(
76 weights.len(),
77 self.mask.len(),
78 "mask/weight length mismatch"
79 );
80 weights
81 .iter()
82 .zip(self.mask.iter())
83 .map(|(&w, &m)| if m { w } else { 0.0 })
84 .collect()
85 }
86
87 pub fn apply_in_place(&self, weights: &mut [f32]) {
93 assert_eq!(
94 weights.len(),
95 self.mask.len(),
96 "mask/weight length mismatch"
97 );
98 for (w, &m) in weights.iter_mut().zip(self.mask.iter()) {
99 if !m {
100 *w = 0.0;
101 }
102 }
103 }
104
105 #[must_use]
111 pub fn and(&self, other: &Self) -> Self {
112 assert_eq!(self.mask.len(), other.mask.len(), "mask length mismatch");
113 let mask = self
114 .mask
115 .iter()
116 .zip(other.mask.iter())
117 .map(|(&a, &b)| a && b)
118 .collect();
119 Self { mask }
120 }
121
122 #[must_use]
128 pub fn or(&self, other: &Self) -> Self {
129 assert_eq!(self.mask.len(), other.mask.len(), "mask length mismatch");
130 let mask = self
131 .mask
132 .iter()
133 .zip(other.mask.iter())
134 .map(|(&a, &b)| a || b)
135 .collect();
136 Self { mask }
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn all_active_sparsity_zero() {
        let m = SparseMask::all_active(10);
        assert_abs_diff_eq!(m.sparsity(), 0.0, epsilon = 1e-6);
        assert_eq!(m.count_active(), 10);
        assert_eq!(m.count_pruned(), 0);
    }

    #[test]
    fn all_pruned_sparsity_one() {
        let m = SparseMask::all_pruned(8);
        assert_abs_diff_eq!(m.sparsity(), 1.0, epsilon = 1e-6);
        assert_eq!(m.count_active(), 0);
        assert_eq!(m.count_pruned(), 8);
    }

    #[test]
    fn empty_mask_reports_zero_sparsity() {
        // Guards the explicit empty-mask branch in `sparsity` (no 0/0 NaN).
        let m = SparseMask::all_active(0);
        assert!(m.is_empty());
        assert_eq!(m.len(), 0);
        assert_abs_diff_eq!(m.sparsity(), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn active_and_pruned_counts_partition_len() {
        let m = SparseMask {
            mask: vec![true, false, false, true, true, false, true],
        };
        assert_eq!(m.count_active() + m.count_pruned(), m.len());
        assert_eq!(m.count_active(), 4);
    }

    #[test]
    fn apply_zeroes_pruned() {
        let m = SparseMask {
            mask: vec![true, false, true, false],
        };
        let w = vec![1.0_f32, 2.0, 3.0, 4.0];
        let out = m.apply(&w);
        assert_eq!(out.len(), w.len());
        assert_abs_diff_eq!(out[0], 1.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[1], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[2], 3.0, epsilon = 1e-7);
        assert_abs_diff_eq!(out[3], 0.0, epsilon = 1e-7);
    }

    #[test]
    fn apply_in_place_modifies_weights() {
        let m = SparseMask {
            mask: vec![true, false, true],
        };
        let mut w = vec![5.0_f32, 9.0, 3.0];
        m.apply_in_place(&mut w);
        // Check every position, including the trailing active one.
        assert_abs_diff_eq!(w[0], 5.0, epsilon = 1e-7);
        assert_abs_diff_eq!(w[1], 0.0, epsilon = 1e-7);
        assert_abs_diff_eq!(w[2], 3.0, epsilon = 1e-7);
    }

    #[test]
    fn apply_matches_apply_in_place() {
        let m = SparseMask {
            mask: vec![false, true, true, false, true],
        };
        let w = vec![0.5_f32, -1.25, 2.0, 7.0, -3.5];
        let copied = m.apply(&w);
        let mut in_place = w.clone();
        m.apply_in_place(&mut in_place);
        assert_eq!(copied, in_place);
    }

    #[test]
    fn sparsity_partial() {
        let m = SparseMask {
            mask: vec![true, false, true, false, false],
        };
        assert_abs_diff_eq!(m.sparsity(), 3.0 / 5.0, epsilon = 1e-6);
    }

    #[test]
    fn and_mask() {
        let a = SparseMask {
            mask: vec![true, true, false],
        };
        let b = SparseMask {
            mask: vec![true, false, false],
        };
        let c = a.and(&b);
        assert_eq!(c.mask, vec![true, false, false]);
    }

    #[test]
    fn or_mask() {
        let a = SparseMask {
            mask: vec![true, false, false],
        };
        let b = SparseMask {
            mask: vec![false, true, false],
        };
        let c = a.or(&b);
        assert_eq!(c.mask, vec![true, true, false]);
    }

    #[test]
    #[should_panic(expected = "mask/weight length mismatch")]
    fn apply_rejects_length_mismatch() {
        let _ = SparseMask::all_active(3).apply(&[1.0_f32, 2.0]);
    }

    #[test]
    #[should_panic(expected = "mask/weight length mismatch")]
    fn apply_in_place_rejects_length_mismatch() {
        let mut w = vec![1.0_f32; 2];
        SparseMask::all_pruned(3).apply_in_place(&mut w);
    }

    #[test]
    #[should_panic(expected = "mask length mismatch")]
    fn and_rejects_length_mismatch() {
        let _ = SparseMask::all_active(2).and(&SparseMask::all_active(3));
    }

    #[test]
    #[should_panic(expected = "mask length mismatch")]
    fn or_rejects_length_mismatch() {
        let _ = SparseMask::all_pruned(4).or(&SparseMask::all_pruned(1));
    }
}