reductionml_core/
dense_weights.rs

1use std::collections::HashMap;
2
3use approx::AbsDiffEq;
4use serde::{Deserialize, Deserializer, Serialize, Serializer};
5
6use crate::{
7    error::{Error, Result},
8    weights::Weights,
9    FeatureIndex, ModelIndex, RawWeightsIndex, StateIndex,
10};
11
/// Returns the number of bits needed to represent `val`.
///
/// This is the one-based position of the highest set bit, so `0` needs zero
/// bits and `u64::MAX` needs all 64.
fn num_bits_to_represent(val: u64) -> u64 {
    u64::from(u64::BITS - val.leading_zeros())
}
15
16#[derive(Deserialize, Serialize, PartialEq, Debug, Clone)]
17pub struct DenseWeights {
18    #[serde(
19        deserialize_with = "deserialize_sparse_f32_vec",
20        serialize_with = "serialize_sparse_f32_vec"
21    )]
22    weights: Vec<f32>,
23    // Max size of index
24    feature_index_size: FeatureIndex,
25    model_index_size: ModelIndex,
26    feature_state_size: StateIndex,
27    // Number of bits required to represent index
28    model_index_size_shift: u8,
29    feature_state_size_shift: u8,
30}
31
32impl AbsDiffEq for DenseWeights {
33    type Epsilon = f32;
34
35    fn default_epsilon() -> Self::Epsilon {
36        core::f32::EPSILON
37    }
38
39    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
40        if self.feature_index_size != other.feature_index_size
41            || self.model_index_size != other.model_index_size
42            || self.feature_state_size != other.feature_state_size
43            || self.model_index_size_shift != other.model_index_size_shift
44            || self.feature_state_size_shift != other.feature_state_size_shift
45        {
46            return false;
47        }
48
49        if self.weights.len() != other.weights.len() {
50            return false;
51        }
52
53        for i in 0..self.weights.len() {
54            if !self.weights[i].abs_diff_eq(&other.weights[i], epsilon) {
55                return false;
56            }
57        }
58
59        true
60    }
61}
62
63#[derive(Deserialize, Serialize)]
64pub struct DenseWeightsWithNDArray {
65    weights: HashMap<FeatureIndex, Vec<Vec<f32>>>,
66    // Max size of index
67    feature_index_size: FeatureIndex,
68    model_index_size: ModelIndex,
69    feature_state_size: StateIndex,
70    // Number of bits required to represent index
71    model_index_size_shift: u8,
72    feature_state_size_shift: u8,
73}
74
75impl DenseWeightsWithNDArray {
76    pub fn from_dense_weights(weights: DenseWeights) -> Self {
77        let mut weights_map = HashMap::new();
78        let feature_index_size = weights.feature_index_size;
79        let model_index_size = weights.model_index_size;
80        let feature_state_size = weights.feature_state_size;
81        let model_index_size_shift = weights.model_index_size_shift;
82        let feature_state_size_shift = weights.feature_state_size_shift;
83
84        for i in 0..*feature_index_size {
85            let mut found_non_zero = false;
86            let mut vec = Vec::new();
87            for j in 0..*model_index_size {
88                let state = weights.state_at(i.into(), j.into());
89                if state.iter().any(|x| *x != 0.0) {
90                    found_non_zero = true;
91                }
92                vec.push(state.into());
93            }
94            if found_non_zero {
95                weights_map.insert(FeatureIndex::from(i), vec);
96            }
97        }
98
99        DenseWeightsWithNDArray {
100            weights: weights_map,
101            feature_index_size,
102            model_index_size,
103            feature_state_size,
104            model_index_size_shift,
105            feature_state_size_shift,
106        }
107    }
108
109    pub fn to_dense_weights(&self) -> DenseWeights {
110        let feature_index_size_shift =
111            num_bits_to_represent(*self.feature_index_size as u64 - 1) as usize;
112
113        let weights = vec![
114            0.0;
115            (1 << feature_index_size_shift)
116                * (1 << self.model_index_size_shift)
117                * (1 << self.feature_state_size_shift)
118        ];
119
120        let mut weights = DenseWeights {
121            weights,
122            feature_index_size: self.feature_index_size,
123            model_index_size: self.model_index_size,
124            feature_state_size: self.feature_state_size,
125            model_index_size_shift: self.model_index_size_shift,
126            feature_state_size_shift: self.feature_state_size_shift,
127        };
128
129        for (feature_index, feature) in &self.weights {
130            for (model_index, state) in feature.iter().enumerate() {
131                weights
132                    .state_at_mut(*feature_index, ModelIndex::from(model_index as u8))
133                    .copy_from_slice(state);
134            }
135        }
136
137        weights
138    }
139}
140
141#[derive(Debug, Deserialize, Serialize)]
142struct SparseF32Vec {
143    len: u64,
144    non_zero_value_and_index_pairs: Vec<(usize, f32)>,
145}
146
147impl SparseF32Vec {
148    fn from_dense(vec: &Vec<f32>) -> SparseF32Vec {
149        let len: u64 = vec.len().try_into().unwrap();
150        let non_zero_value_and_index_pairs: Vec<(usize, f32)> = vec
151            .iter()
152            .enumerate()
153            .filter(|(_, v)| **v != 0.0)
154            .map(|(i, v)| (i, *v))
155            .collect();
156        SparseF32Vec {
157            len,
158            non_zero_value_and_index_pairs,
159        }
160    }
161
162    fn to_dense(&self) -> Vec<f32> {
163        let mut vec = vec![0.0; self.len as usize];
164        for (index, value) in self.non_zero_value_and_index_pairs.iter() {
165            vec[*index] = *value;
166        }
167        vec
168    }
169}
170
171fn serialize_sparse_f32_vec<S>(
172    vec: &Vec<f32>,
173    serializer: S,
174) -> std::result::Result<S::Ok, S::Error>
175where
176    S: Serializer,
177{
178    let sparse_vec = SparseF32Vec::from_dense(vec);
179    sparse_vec.serialize(serializer)
180}
181
182fn deserialize_sparse_f32_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<f32>, D::Error>
183where
184    D: Deserializer<'de>,
185{
186    let sparse_vec = SparseF32Vec::deserialize(deserializer)?;
187    Ok(sparse_vec.to_dense())
188}
189
190impl DenseWeights {
191    fn convert_index(
192        &self,
193        feature_index: FeatureIndex,
194        model_index: ModelIndex,
195    ) -> RawWeightsIndex {
196        // These are very good checks but are expensive and in the hot path, so we disable them in release
197        debug_assert!(feature_index < self.feature_index_size);
198        debug_assert!(model_index < self.model_index_size);
199        let raw_index = ((*feature_index as usize)
200            << (self.model_index_size_shift + self.feature_state_size_shift))
201            + ((*model_index as usize) << self.feature_state_size_shift);
202
203        RawWeightsIndex::from(raw_index)
204    }
205
206    pub fn new(
207        feature_index_size: FeatureIndex,
208        model_index_size: ModelIndex,
209        feature_state_size: StateIndex,
210    ) -> Result<DenseWeights> {
211        let feature_index_size_shift =
212            num_bits_to_represent(*feature_index_size as u64 - 1) as usize;
213        let model_index_size_shift = num_bits_to_represent(*model_index_size as u64 - 1) as usize;
214        let feature_state_size_shift =
215            num_bits_to_represent(*feature_state_size as u64 - 1) as usize;
216        assert!(feature_index_size_shift + model_index_size_shift + feature_state_size_shift <= 64);
217        let weights = vec![
218            0.0;
219            (1 << feature_index_size_shift)
220                * (1 << model_index_size_shift)
221                * (1 << feature_state_size_shift)
222        ];
223        Ok(DenseWeights {
224            weights,
225            feature_index_size,
226            model_index_size,
227            feature_state_size,
228            // TODO better error message
229            model_index_size_shift: u8::try_from(model_index_size_shift)
230                .map_err(|e| Error::InvalidArgument(e.to_string()))?,
231            feature_state_size_shift: u8::try_from(feature_state_size_shift)
232                .map_err(|e| Error::InvalidArgument(e.to_string()))?,
233        })
234    }
235}
236
237impl Weights for DenseWeights {
238    fn weight_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> f32 {
239        let index = self.convert_index(feature_index, model_index);
240        self.weights[*index]
241    }
242
243    fn weight_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut f32 {
244        let index = self.convert_index(feature_index, model_index);
245        &mut self.weights[*index]
246    }
247
248    fn state_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> &[f32] {
249        let index = self.convert_index(feature_index, model_index);
250        &self.weights[*index..*index + *self.feature_state_size as usize]
251    }
252
253    fn state_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut [f32] {
254        let index = self.convert_index(feature_index, model_index);
255        &mut self.weights[*index..*index + *self.feature_state_size as usize]
256    }
257}
258
259// void foreach_feature(std::uint64_t model_offset, const SparseFeatures& features, const cb::DenseWeights& weights, std::invocable<float, float> auto func)
260// {
261//   for (const auto[index, value] : features.flat_values_and_indices())
262//   {
263//     const auto model_weight = weights.weight_at(index, model_offset);
264//     func(value, model_weight);
265//   }
266// }
267
#[cfg(test)]
mod tests {
    use approx::{assert_abs_diff_eq, assert_abs_diff_ne};

    use super::*;

    /// Checks every value from 0 through 31 against its expected bit count.
    #[test]
    fn test_num_bits_to_represent() {
        // (first value, last value, expected bits) for each run of inputs
        // that share the same answer.
        let cases: [(u64, u64, u64); 6] = [
            (0, 0, 0),
            (1, 1, 1),
            (2, 3, 2),
            (4, 7, 3),
            (8, 15, 4),
            (16, 31, 5),
        ];

        for &(first, last, expected_bits) in &cases {
            for val in first..=last {
                assert_eq!(num_bits_to_represent(val), expected_bits);
            }
        }
    }

    /// Two fresh tables compare equal; changing one weight breaks equality.
    #[test]
    fn weights_equality() {
        let fresh = || {
            DenseWeights::new(
                FeatureIndex::from(4),
                ModelIndex::from(1),
                StateIndex::from(1),
            )
            .unwrap()
        };
        let mut w1 = fresh();
        let w2 = fresh();

        assert_abs_diff_eq!(w1, w2);

        *w1.weight_at_mut(FeatureIndex::from(0), ModelIndex::from(0)) = 1.0;

        assert_abs_diff_ne!(w1, w2);
    }

    /// Converting to `DenseWeightsWithNDArray` and back yields the same table.
    #[test]
    fn weights_roundtrip() {
        let mut w1 = DenseWeights::new(
            FeatureIndex::from(4),
            ModelIndex::from(2),
            StateIndex::from(3),
        )
        .unwrap();

        for i in 0..4 {
            let feature = FeatureIndex::from(i);
            let base = i as f32;
            *w1.weight_at_mut(feature, ModelIndex::from(0)) = base;
            *w1.weight_at_mut(feature, ModelIndex::from(1)) = base * 2_f32;
            w1.state_at_mut(feature, ModelIndex::from(0))[1] = base * 3_f32;
            w1.state_at_mut(feature, ModelIndex::from(1))[2] = base * 3_f32;
        }

        let w2 = DenseWeightsWithNDArray::from_dense_weights(w1.clone()).to_dense_weights();

        assert_abs_diff_eq!(w1, w2);
    }
}
353}