reductionml_core/
weights.rs1use crate::{
2 hash::FNV_PRIME,
3 sparse_namespaced_features::{constant_feature_index, Namespace, SparseFeatures},
4 FeatureHash, FeatureIndex, FeatureMask, ModelIndex,
5};
6
7pub trait Weights {
8 fn weight_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> f32;
9 fn weight_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut f32;
10
11 fn state_at(&self, feature_index: FeatureIndex, model_index: ModelIndex) -> &[f32];
13 fn state_at_mut(&mut self, feature_index: FeatureIndex, model_index: ModelIndex) -> &mut [f32];
15}
16
/// Generates a function that visits every feature of an example and hands each
/// feature value, together with the matching entry looked up in a [`Weights`] store,
/// to a callback.
///
/// Macro parameters:
/// - `$func_name`: name of the generated public function.
/// - `$weight_type`: how the weights are borrowed (e.g. `&W` or `&mut W`).
/// - `$inner_func_type`: type of the second callback argument (what the accessor returns).
/// - `$weight_at_func`: the [`Weights`] accessor used to look each feature up.
macro_rules! generate_foreach_feature_func {
    ($func_name: ident, $weight_type: ty, $inner_func_type: ty, $weight_at_func: ident) => {
        /// Calls `func(value, entry)` for every feature derived from `features`, where
        /// `entry` is read from `weights` at `(feature index, model_offset)`.
        ///
        /// Visit order:
        /// 1. every `(index, value)` yielded by `features.all_features()`, using the
        ///    index as-is (no masking is applied here);
        /// 2. each quadratic interaction `(ns1, ns2)`: the index is
        ///    `(FNV_PRIME * f1) ^ f2`, masked to `num_bits` bits, and the value is
        ///    `v1 * v2`;
        /// 3. each cubic interaction `(ns1, ns2, ns3)`: the index is
        ///    `(FNV_PRIME * ((FNV_PRIME * f1) ^ f2)) ^ f3`, masked to `num_bits` bits,
        ///    and the value is `v1 * v2 * v3`;
        /// 4. when `constant_feature_enabled`, the constant feature at
        ///    `constant_feature_index(num_bits)` with value `1.0`.
        ///
        /// Rules for interactions:
        /// - When two adjacent namespaces of an interaction are identical, the inner
        ///   loop starts at the outer loop's position. A self-interaction therefore
        ///   yields each combination (with repetition) exactly once, in
        ///   non-decreasing index order.
        /// - An interaction that names a namespace absent from `features`
        ///   contributes nothing.
        pub fn $func_name<F, W>(
            model_offset: ModelIndex,
            features: &SparseFeatures,
            weights: $weight_type,
            quadratic_interactions: &[(Namespace, Namespace)],
            cubic_interactions: &[(Namespace, Namespace, Namespace)],
            num_bits: u8,
            constant_feature_enabled: bool,
            mut func: F,
        ) where
            F: FnMut(f32, $inner_func_type),
            W: Weights,
        {
            for (index, value) in features.all_features() {
                let model_weight = weights.$weight_at_func(index, model_offset);
                func(value, model_weight);
            }

            let masker = FeatureMask::from_num_bits(num_bits);

            for (ns1, ns2) in quadratic_interactions {
                let same_ns = ns1 == ns2;
                // Both lookups are loop-invariant, so resolve them once; a missing
                // namespace means the interaction produces no features.
                let (first, second) =
                    match (features.get_namespace(*ns1), features.get_namespace(*ns2)) {
                        (Some(first), Some(second)) => (first, second),
                        _ => continue,
                    };
                for (i, feat1) in first.iter().enumerate() {
                    let multiplied =
                        (FNV_PRIME as u64).wrapping_mul(u32::from(feat1.0) as u64) as u32;
                    // Self-interaction: only visit feat2 at positions >= i.
                    for feat2 in second.iter().skip(if same_ns { i } else { 0 }) {
                        let idx =
                            FeatureHash::from(multiplied ^ u32::from(feat2.0)).mask(masker);
                        let model_weight = weights.$weight_at_func(idx, model_offset);
                        func(feat1.1 * feat2.1, model_weight);
                    }
                }
            }

            for (ns1, ns2, ns3) in cubic_interactions {
                let same_ns12 = ns1 == ns2;
                let same_ns23 = ns2 == ns3;
                let (first, second, third) = match (
                    features.get_namespace(*ns1),
                    features.get_namespace(*ns2),
                    features.get_namespace(*ns3),
                ) {
                    (Some(first), Some(second), Some(third)) => (first, second, third),
                    _ => continue,
                };
                for (i, feat1) in first.iter().enumerate() {
                    let halfhash1 =
                        (FNV_PRIME as u64).wrapping_mul(u32::from(feat1.0) as u64) as u32;
                    // `enumerate` before `skip` so `j` is feat2's absolute position in ns2.
                    for (j, feat2) in second
                        .iter()
                        .enumerate()
                        .skip(if same_ns12 { i } else { 0 })
                    {
                        let halfhash2 = (FNV_PRIME as u64)
                            .wrapping_mul((halfhash1 ^ u32::from(feat2.0)) as u64)
                            as u32;
                        // When ns2 == ns3 the innermost loop must start at feat2's
                        // position `j` (not feat1's `i`). Otherwise it would emit
                        // duplicate permutations, or use an index from an unrelated
                        // namespace when ns1 != ns2.
                        for feat3 in third.iter().skip(if same_ns23 { j } else { 0 }) {
                            let idx = FeatureHash::from(halfhash2 ^ u32::from(feat3.0))
                                .mask(masker);
                            let model_weight = weights.$weight_at_func(idx, model_offset);
                            func(feat1.1 * feat2.1 * feat3.1, model_weight);
                        }
                    }
                }
            }

            if constant_feature_enabled {
                let constant_feature_index = constant_feature_index(num_bits);
                let model_weight = weights.$weight_at_func(constant_feature_index, model_offset);
                func(1.0, model_weight);
            }
        }
    };
}
93
94generate_foreach_feature_func!(foreach_feature, &W, f32, weight_at);
95generate_foreach_feature_func!(foreach_feature_with_state, &W, &[f32], state_at);
96generate_foreach_feature_func!(
97 foreach_feature_with_state_mut,
98 &mut W,
99 &mut [f32],
100 state_at_mut
101);