reductionml_core/
weights.rs

1use crate::{
2    hash::FNV_PRIME,
3    sparse_namespaced_features::{constant_feature_index, Namespace, SparseFeatures},
4    FeatureHash, FeatureIndex, FeatureMask, ModelIndex,
5};
6
7pub trait Weights {
8    fn weight_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> f32;
9    fn weight_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut f32;
10
11    // By convention state 0 is always the weight itself
12    fn state_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> &[f32];
13    // By convention state 0 is always the weight itself
14    fn state_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut [f32];
15}
16
/// Generates a feature-walking function named `$func_name`.
///
/// The generated function visits, in order: every first-order feature, every requested
/// quadratic interaction, every requested cubic interaction and finally (optionally) the
/// constant feature. For each one it fetches `weights.$weight_at_func(index, model_offset)`
/// and calls `func(feature_value, fetched)`.
///
/// * `$weight_type` - type of the generated `weights` parameter (`&W` or `&mut W`).
/// * `$inner_func_type` - type of the second argument handed to `func`.
/// * `$weight_at_func` - the `Weights` accessor used for every lookup.
macro_rules! generate_foreach_feature_func {
    ($func_name: ident, $weight_type: ty, $inner_func_type: ty, $weight_at_func: ident) => {
        /// Calls `func(value, weight)` for every feature of `features`, every requested
        /// quadratic and cubic interaction, and the constant feature when
        /// `constant_feature_enabled` is set.
        ///
        /// When an interaction repeats a namespace back-to-back (`(a, a)`, `(a, a, b)`,
        /// `(a, b, b)`, `(a, a, a)`), positions inside the repeated namespace are only visited
        /// in non-decreasing order, so each unordered combination (self-combinations included)
        /// is produced exactly once. Interactions that name a namespace absent from `features`
        /// contribute nothing. Interaction hashes are masked with
        /// `FeatureMask::from_num_bits(num_bits)`.
        // The eight-parameter signature is public API shared by all generated variants.
        #[allow(clippy::too_many_arguments)]
        pub fn $func_name<F, W>(
            model_offset: ModelIndex,
            features: &SparseFeatures,
            weights: $weight_type,
            quadratic_interactions: &[(Namespace, Namespace)],
            cubic_interactions: &[(Namespace, Namespace, Namespace)],
            num_bits: u8,
            constant_feature_enabled: bool,
            mut func: F,
        ) where
            F: FnMut(f32, $inner_func_type),
            W: Weights + ?Sized,
        {
            // First-order features.
            for (index, value) in features.all_features() {
                let model_weight = weights.$weight_at_func(index, model_offset);
                func(value, model_weight);
            }

            let masker = FeatureMask::from_num_bits(num_bits);

            // Quadratic interactions.
            for &(ns1_id, ns2_id) in quadratic_interactions {
                // Resolve both namespaces once per interaction rather than once per feature.
                let (ns1, ns2) = match (
                    features.get_namespace(ns1_id),
                    features.get_namespace(ns2_id),
                ) {
                    (Some(ns1), Some(ns2)) => (ns1, ns2),
                    // A missing namespace means the interaction has no features.
                    _ => continue,
                };
                let same_ns = ns1_id == ns2_id;
                for (i, feat1) in ns1.iter().enumerate() {
                    let halfhash =
                        (FNV_PRIME as u64).wrapping_mul(u32::from(feat1.0) as u64) as u32;
                    // A namespace crossed with itself only visits position pairs with j >= i.
                    let start2 = if same_ns { i } else { 0 };
                    for feat2 in ns2.iter().skip(start2) {
                        let idx = FeatureHash::from(halfhash ^ u32::from(feat2.0)).mask(masker);
                        let model_weight = weights.$weight_at_func(idx, model_offset);
                        func(feat1.1 * feat2.1, model_weight);
                    }
                }
            }

            // Cubic interactions.
            for &(ns1_id, ns2_id, ns3_id) in cubic_interactions {
                let (ns1, ns2, ns3) = match (
                    features.get_namespace(ns1_id),
                    features.get_namespace(ns2_id),
                    features.get_namespace(ns3_id),
                ) {
                    (Some(ns1), Some(ns2), Some(ns3)) => (ns1, ns2, ns3),
                    // A missing namespace means the interaction has no features.
                    _ => continue,
                };
                let same_ns12 = ns1_id == ns2_id;
                let same_ns23 = ns2_id == ns3_id;
                for (i, feat1) in ns1.iter().enumerate() {
                    let halfhash1 =
                        (FNV_PRIME as u64).wrapping_mul(u32::from(feat1.0) as u64) as u32;
                    let start2 = if same_ns12 { i } else { 0 };
                    // `enumerate` precedes `skip` so `j` is feat2's absolute position in ns2.
                    for (j, feat2) in ns2.iter().enumerate().skip(start2) {
                        let halfhash2 = (FNV_PRIME as u64)
                            .wrapping_mul((halfhash1 ^ u32::from(feat2.0)) as u64)
                            as u32;
                        // The third namespace is deduplicated against the *second* one, so the
                        // skip is relative to feat2's position `j`. Using feat1's position here
                        // would emit both (j, k) = (0, 1) and (1, 0) for `(a, b, b)`.
                        let start3 = if same_ns23 { j } else { 0 };
                        for feat3 in ns3.iter().skip(start3) {
                            let idx = FeatureHash::from(halfhash2 ^ u32::from(feat3.0))
                                .mask(masker);
                            let model_weight = weights.$weight_at_func(idx, model_offset);
                            func(feat1.1 * feat2.1 * feat3.1, model_weight);
                        }
                    }
                }
            }

            // Constant (bias) feature, always with value 1.0.
            if constant_feature_enabled {
                let constant_index = constant_feature_index(num_bits);
                let model_weight = weights.$weight_at_func(constant_index, model_offset);
                func(1.0, model_weight);
            }
        }
    };
}
93
94generate_foreach_feature_func!(foreach_feature, &W, f32, weight_at);
95generate_foreach_feature_func!(foreach_feature_with_state, &W, &[f32], state_at);
96generate_foreach_feature_func!(
97    foreach_feature_with_state_mut,
98    &mut W,
99    &mut [f32],
100    state_at_mut
101);