1use std::collections::HashMap;
2
3use approx::AbsDiffEq;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::{
7 error::{Error, Result},
8 weights::Weights,
9 FeatureIndex, ModelIndex, RawWeightsIndex, StateIndex,
10};
11
/// Returns the minimum number of bits required to represent `val`.
///
/// Zero needs no bits, so `num_bits_to_represent(0)` is `0`; any other value
/// yields the position of its highest set bit, counted from one.
fn num_bits_to_represent(val: u64) -> u64 {
    let unused_high_bits = val.leading_zeros();
    u64::from(u64::BITS - unused_high_bits)
}
15
/// Weights stored in one flat `f32` buffer, addressed by feature index and
/// model index.
///
/// The slot for a `(feature, model)` pair starts at
/// `(feature << (model_index_size_shift + feature_state_size_shift))
///  + (model << feature_state_size_shift)`
/// and spans `feature_state_size` consecutive values.
#[derive(Deserialize, Serialize, PartialEq, Debug, Clone)]
pub struct DenseWeights {
    /// Flat backing storage. It is (de)serialized through `SparseF32Vec`, so
    /// only the non-zero entries and the total length are written out.
    #[serde(
        deserialize_with = "deserialize_sparse_f32_vec",
        serialize_with = "serialize_sparse_f32_vec"
    )]
    weights: Vec<f32>,
    /// Exclusive upper bound for feature indices (checked with `debug_assert!`).
    feature_index_size: FeatureIndex,
    /// Exclusive upper bound for model indices (checked with `debug_assert!`).
    model_index_size: ModelIndex,
    /// Number of `f32` values exposed per `(feature, model)` slot.
    feature_state_size: StateIndex,
    /// Number of bits needed to represent `model_index_size - 1`.
    model_index_size_shift: u8,
    /// Number of bits needed to represent `feature_state_size - 1`.
    feature_state_size_shift: u8,
}
31
32impl AbsDiffEq for DenseWeights {
33 type Epsilon = f32;
34
35 fn default_epsilon() -> Self::Epsilon {
36 core::f32::EPSILON
37 }
38
39 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
40 if self.feature_index_size != other.feature_index_size
41 || self.model_index_size != other.model_index_size
42 || self.feature_state_size != other.feature_state_size
43 || self.model_index_size_shift != other.model_index_size_shift
44 || self.feature_state_size_shift != other.feature_state_size_shift
45 {
46 return false;
47 }
48
49 if self.weights.len() != other.weights.len() {
50 return false;
51 }
52
53 for i in 0..self.weights.len() {
54 if !self.weights[i].abs_diff_eq(&other.weights[i], epsilon) {
55 return false;
56 }
57 }
58
59 true
60 }
61}
62
/// Alternative, nested representation of `DenseWeights` for serialization.
///
/// Built by `from_dense_weights` and converted back with `to_dense_weights`.
#[derive(Deserialize, Serialize)]
pub struct DenseWeightsWithNDArray {
    /// Per-feature state, indexed as `weights[feature][model][state]`.
    /// `from_dense_weights` only inserts features that have at least one
    /// non-zero value; absent features are all zeros.
    weights: HashMap<FeatureIndex, Vec<Vec<f32>>>,
    /// Copied verbatim from the source `DenseWeights`.
    feature_index_size: FeatureIndex,
    /// Copied verbatim from the source `DenseWeights`.
    model_index_size: ModelIndex,
    /// Copied verbatim from the source `DenseWeights`.
    feature_state_size: StateIndex,
    /// Copied verbatim from the source `DenseWeights`.
    model_index_size_shift: u8,
    /// Copied verbatim from the source `DenseWeights`.
    feature_state_size_shift: u8,
}
74
75impl DenseWeightsWithNDArray {
76 pub fn from_dense_weights(weights: DenseWeights) -> Self {
77 let mut weights_map = HashMap::new();
78 let feature_index_size = weights.feature_index_size;
79 let model_index_size = weights.model_index_size;
80 let feature_state_size = weights.feature_state_size;
81 let model_index_size_shift = weights.model_index_size_shift;
82 let feature_state_size_shift = weights.feature_state_size_shift;
83
84 for i in 0..*feature_index_size {
85 let mut found_non_zero = false;
86 let mut vec = Vec::new();
87 for j in 0..*model_index_size {
88 let state = weights.state_at(i.into(), j.into());
89 if state.iter().any(|x| *x != 0.0) {
90 found_non_zero = true;
91 }
92 vec.push(state.into());
93 }
94 if found_non_zero {
95 weights_map.insert(FeatureIndex::from(i), vec);
96 }
97 }
98
99 DenseWeightsWithNDArray {
100 weights: weights_map,
101 feature_index_size,
102 model_index_size,
103 feature_state_size,
104 model_index_size_shift,
105 feature_state_size_shift,
106 }
107 }
108
109 pub fn to_dense_weights(&self) -> DenseWeights {
110 let feature_index_size_shift =
111 num_bits_to_represent(*self.feature_index_size as u64 - 1) as usize;
112
113 let weights = vec![
114 0.0;
115 (1 << feature_index_size_shift)
116 * (1 << self.model_index_size_shift)
117 * (1 << self.feature_state_size_shift)
118 ];
119
120 let mut weights = DenseWeights {
121 weights,
122 feature_index_size: self.feature_index_size,
123 model_index_size: self.model_index_size,
124 feature_state_size: self.feature_state_size,
125 model_index_size_shift: self.model_index_size_shift,
126 feature_state_size_shift: self.feature_state_size_shift,
127 };
128
129 for (feature_index, feature) in &self.weights {
130 for (model_index, state) in feature.iter().enumerate() {
131 weights
132 .state_at_mut(*feature_index, ModelIndex::from(model_index as u8))
133 .copy_from_slice(state);
134 }
135 }
136
137 weights
138 }
139}
140
/// Sparse encoding of a `Vec<f32>`: the dense length plus the non-zero
/// entries only. Used as the serialized form of `DenseWeights::weights`.
#[derive(Debug, Deserialize, Serialize)]
struct SparseF32Vec {
    /// Length of the dense vector this was built from.
    len: u64,
    /// Non-zero entries. Despite the field name, each tuple is stored as
    /// `(index, value)`.
    non_zero_value_and_index_pairs: Vec<(usize, f32)>,
}
146
147impl SparseF32Vec {
148 fn from_dense(vec: &Vec<f32>) -> SparseF32Vec {
149 let len: u64 = vec.len().try_into().unwrap();
150 let non_zero_value_and_index_pairs: Vec<(usize, f32)> = vec
151 .iter()
152 .enumerate()
153 .filter(|(_, v)| **v != 0.0)
154 .map(|(i, v)| (i, *v))
155 .collect();
156 SparseF32Vec {
157 len,
158 non_zero_value_and_index_pairs,
159 }
160 }
161
162 fn to_dense(&self) -> Vec<f32> {
163 let mut vec = vec![0.0; self.len as usize];
164 for (index, value) in self.non_zero_value_and_index_pairs.iter() {
165 vec[*index] = *value;
166 }
167 vec
168 }
169}
170
171fn serialize_sparse_f32_vec<S>(
172 vec: &Vec<f32>,
173 serializer: S,
174) -> std::result::Result<S::Ok, S::Error>
175where
176 S: Serializer,
177{
178 let sparse_vec = SparseF32Vec::from_dense(vec);
179 sparse_vec.serialize(serializer)
180}
181
182fn deserialize_sparse_f32_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<f32>, D::Error>
183where
184 D: Deserializer<'de>,
185{
186 let sparse_vec = SparseF32Vec::deserialize(deserializer)?;
187 Ok(sparse_vec.to_dense())
188}
189
190impl DenseWeights {
191 fn convert_index(
192 &self,
193 feature_index: FeatureIndex,
194 model_index: ModelIndex,
195 ) -> RawWeightsIndex {
196 debug_assert!(feature_index < self.feature_index_size);
198 debug_assert!(model_index < self.model_index_size);
199 let raw_index = ((*feature_index as usize)
200 << (self.model_index_size_shift + self.feature_state_size_shift))
201 + ((*model_index as usize) << self.feature_state_size_shift);
202
203 RawWeightsIndex::from(raw_index)
204 }
205
206 pub fn new(
207 feature_index_size: FeatureIndex,
208 model_index_size: ModelIndex,
209 feature_state_size: StateIndex,
210 ) -> Result<DenseWeights> {
211 let feature_index_size_shift =
212 num_bits_to_represent(*feature_index_size as u64 - 1) as usize;
213 let model_index_size_shift = num_bits_to_represent(*model_index_size as u64 - 1) as usize;
214 let feature_state_size_shift =
215 num_bits_to_represent(*feature_state_size as u64 - 1) as usize;
216 assert!(feature_index_size_shift + model_index_size_shift + feature_state_size_shift <= 64);
217 let weights = vec![
218 0.0;
219 (1 << feature_index_size_shift)
220 * (1 << model_index_size_shift)
221 * (1 << feature_state_size_shift)
222 ];
223 Ok(DenseWeights {
224 weights,
225 feature_index_size,
226 model_index_size,
227 feature_state_size,
228 model_index_size_shift: u8::try_from(model_index_size_shift)
230 .map_err(|e| Error::InvalidArgument(e.to_string()))?,
231 feature_state_size_shift: u8::try_from(feature_state_size_shift)
232 .map_err(|e| Error::InvalidArgument(e.to_string()))?,
233 })
234 }
235}
236
237impl Weights for DenseWeights {
238 fn weight_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> f32 {
239 let index = self.convert_index(feature_index, model_index);
240 self.weights[*index]
241 }
242
243 fn weight_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut f32 {
244 let index = self.convert_index(feature_index, model_index);
245 &mut self.weights[*index]
246 }
247
248 fn state_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> &[f32] {
249 let index = self.convert_index(feature_index, model_index);
250 &self.weights[*index..*index + *self.feature_state_size as usize]
251 }
252
253 fn state_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut [f32] {
254 let index = self.convert_index(feature_index, model_index);
255 &mut self.weights[*index..*index + *self.feature_state_size as usize]
256 }
257}
258
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_abs_diff_ne};

    use super::*;

    /// Checks the bit width for every input from 0 through 31.
    #[test]
    fn test_num_bits_to_represent() {
        // (inclusive input range, expected number of bits)
        let expectations = [
            (0..=0, 0),
            (1..=1, 1),
            (2..=3, 2),
            (4..=7, 3),
            (8..=15, 4),
            (16..=31, 5),
        ];

        for (inputs, expected_bits) in expectations {
            for input in inputs {
                assert_eq!(num_bits_to_represent(input), expected_bits);
            }
        }
    }

    /// Freshly created weights compare equal; changing one weight breaks it.
    #[test]
    fn weights_equality() {
        let make_weights = || {
            DenseWeights::new(
                FeatureIndex::from(4),
                ModelIndex::from(1),
                StateIndex::from(1),
            )
            .unwrap()
        };
        let mut w1 = make_weights();
        let w2 = make_weights();

        assert_abs_diff_eq!(w1, w2);

        *w1.weight_at_mut(FeatureIndex::from(0), ModelIndex::from(0)) = 1.0;

        assert_abs_diff_ne!(w1, w2);
    }

    /// Dense -> nested -> dense must reproduce the original weights.
    #[test]
    fn weights_roundtrip() {
        let mut w1 = DenseWeights::new(
            FeatureIndex::from(4),
            ModelIndex::from(2),
            StateIndex::from(3),
        )
        .unwrap();

        for i in 0..4 {
            let feature = FeatureIndex::from(i);
            let base = i as f32;

            *w1.weight_at_mut(feature, ModelIndex::from(0)) = base;
            *w1.weight_at_mut(feature, ModelIndex::from(1)) = base * 2_f32;
            w1.state_at_mut(feature, ModelIndex::from(0))[1] = base * 3_f32;
            w1.state_at_mut(feature, ModelIndex::from(1))[2] = base * 3_f32;
        }

        let nested = DenseWeightsWithNDArray::from_dense_weights(w1.clone());
        let w2 = nested.to_dense_weights();

        assert_abs_diff_eq!(w1, w2);
    }
}