reductionml_core/reductions/
coin.rs

1use std::iter::Sum;
2use std::ops::Deref;
3
4use crate::dense_weights::DenseWeights;
5use crate::error::Result;
6use crate::global_config::GlobalConfig;
7use crate::interactions::compile_interactions;
8use crate::loss_function::{LossFunction, LossFunctionType};
9use crate::reduction::{
10    DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
11};
12use crate::reduction_factory::{PascalCaseString, ReductionConfig, ReductionFactory};
13use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
14use crate::utils::bits_to_max_feature_index;
15use crate::utils::AsInner;
16use crate::weights::{foreach_feature, foreach_feature_with_state, foreach_feature_with_state_mut};
17use crate::{impl_default_factory_functions, types::*, ModelIndex, StateIndex};
18use schemars::schema::RootSchema;
19use schemars::{schema_for, JsonSchema};
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21use serde_default::DefaultFromSerde;
22
23#[derive(Deserialize, DefaultFromSerde, Serialize, Debug, Clone, JsonSchema)]
24#[serde(deny_unknown_fields)]
25#[serde(rename_all = "camelCase")]
26pub struct CoinRegressorConfig {
27    #[serde(default = "default_alpha")]
28    alpha: f32,
29
30    #[serde(default = "default_beta")]
31    beta: f32,
32
33    #[serde(default)]
34    l1_lambda: f32,
35
36    #[serde(default)]
37    l2_lambda: f32,
38}
39
/// Serde default for the `alpha` field of `CoinRegressorConfig`.
const fn default_alpha() -> f32 {
    4.0_f32
}

/// Serde default for the `beta` field of `CoinRegressorConfig`.
const fn default_beta() -> f32 {
    1.0_f32
}
47
48impl ReductionConfig for CoinRegressorConfig {
49    fn as_any(&self) -> &dyn std::any::Any {
50        self
51    }
52
53    fn typename(&self) -> PascalCaseString {
54        "Coin".try_into().unwrap()
55    }
56}
57
58#[derive(Clone, Serialize, Deserialize)]
59struct CoinRegressorModelState {
60    normalized_sum_norm_x: f32,
61    total_weight: f32,
62}
63
64struct LossFunctionHolder {
65    loss_function: Box<dyn LossFunction>,
66}
67
68impl Deref for LossFunctionHolder {
69    type Target = dyn LossFunction;
70
71    fn deref(&self) -> &Self::Target {
72        self.loss_function.deref()
73    }
74}
75
76impl Serialize for LossFunctionHolder {
77    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
78    where
79        S: Serializer,
80    {
81        self.loss_function.get_type().serialize(serializer)
82    }
83}
84
85impl<'de> Deserialize<'de> for LossFunctionHolder {
86    fn deserialize<D>(deserializer: D) -> std::result::Result<LossFunctionHolder, D::Error>
87    where
88        D: Deserializer<'de>,
89    {
90        LossFunctionType::deserialize(deserializer).map(|x| LossFunctionHolder {
91            loss_function: x.create(),
92        })
93    }
94}
95
96#[derive(Serialize, Deserialize)]
97struct CoinRegressor {
98    weights: DenseWeights,
99    config: CoinRegressorConfig,
100    model_states: Vec<CoinRegressorModelState>,
101    average_squared_norm_x: f32,
102    min_label: f32,
103    max_label: f32,
104    // TODO allow this to be chosen
105    loss_function: LossFunctionHolder,
106    pairs: Vec<(Namespace, Namespace)>,
107    triples: Vec<(Namespace, Namespace, Namespace)>,
108    num_bits: u8,
109    constant_feature_enabled: bool,
110}
111
112impl CoinRegressor {
113    pub fn new(
114        config: CoinRegressorConfig,
115        global_config: &GlobalConfig,
116        num_models_above: ModelIndex,
117    ) -> Result<CoinRegressor> {
118        let (pairs, triples) =
119            compile_interactions(global_config.interactions(), global_config.hash_seed());
120        Ok(CoinRegressor {
121            weights: DenseWeights::new(
122                bits_to_max_feature_index(global_config.num_bits()),
123                num_models_above,
124                StateIndex::from(6),
125            )?,
126            config,
127            model_states: vec![
128                CoinRegressorModelState {
129                    normalized_sum_norm_x: 0.0,
130                    total_weight: 0.0
131                };
132                *num_models_above as usize
133            ],
134            average_squared_norm_x: 0.0,
135            min_label: 0.0,
136            max_label: 0.0,
137            loss_function: LossFunctionHolder {
138                loss_function: LossFunctionType::Squared.create(),
139            },
140            pairs,
141            triples,
142            num_bits: global_config.num_bits(),
143            constant_feature_enabled: global_config.constant_feature_enabled(),
144        })
145    }
146}
147
/// Factory that builds coin-betting regressor reductions from a `CoinRegressorConfig`.
#[derive(Default)]
pub struct CoinRegressorFactory;
150
151impl ReductionFactory for CoinRegressorFactory {
152    impl_default_factory_functions!("Coin", CoinRegressorConfig);
153
154    fn create(
155        &self,
156        config: &dyn ReductionConfig,
157        global_config: &GlobalConfig,
158        num_models_above: ModelIndex,
159    ) -> Result<ReductionWrapper> {
160        let config = config
161            .as_any()
162            .downcast_ref::<CoinRegressorConfig>()
163            .unwrap();
164
165        Ok(ReductionWrapper::new(
166            self.typename(),
167            Box::new(CoinRegressor::new(
168                config.clone(),
169                global_config,
170                num_models_above,
171            )?),
172            ReductionTypeDescriptionBuilder::new(
173                LabelType::Simple,
174                FeaturesType::SparseSimple,
175                PredictionType::Scalar,
176            )
177            .build(),
178            num_models_above,
179        ))
180    }
181}
182
183#[typetag::serde]
184impl ReductionImpl for CoinRegressor {
185    fn predict(
186        &self,
187        features: &mut Features,
188        _depth_info: &mut DepthInfo,
189        _model_offset: ModelIndex,
190    ) -> Prediction {
191        let sparse_feats: &SparseFeatures = features.as_inner().unwrap();
192
193        let mut prediction = 0.0;
194        foreach_feature(
195            0.into(),
196            sparse_feats,
197            &self.weights,
198            &self.pairs,
199            &self.triples,
200            self.num_bits,
201            self.constant_feature_enabled,
202            |feat_val, weight_val| prediction += feat_val * weight_val,
203        );
204
205        if prediction.is_nan() {
206            prediction = 0.0;
207        }
208
209        let scalar_pred = ScalarPrediction {
210            prediction: prediction.clamp(self.min_label, self.max_label),
211            raw_prediction: prediction,
212        };
213        scalar_pred.into()
214    }
215
216    fn learn(
217        &mut self,
218        features: &mut Features,
219        label: &Label,
220        _depth_info: &mut DepthInfo,
221        _model_offset: ModelIndex,
222    ) {
223        let sparse_feats: &SparseFeatures = features.as_inner().unwrap();
224        let simple_label: &SimpleLabel = label.as_inner().unwrap();
225
226        self.min_label = simple_label.value().min(self.min_label);
227        self.max_label = simple_label.value().max(self.max_label);
228        let _prediction = self.coin_betting_predict(sparse_feats, simple_label.weight());
229        self.coin_betting_update_after_predict(
230            sparse_feats,
231            _prediction,
232            simple_label.value(),
233            simple_label.weight(),
234        );
235    }
236
237    fn children(&self) -> Vec<&ReductionWrapper> {
238        vec![]
239    }
240
241    // TODO fix model index
242    fn sensitivity(
243        &self,
244        features: &Features,
245        _label: f32,
246        _prediction: f32,
247        _weight: f32,
248        _depth_info: DepthInfo,
249    ) -> f32 {
250        let mut score = 0.0;
251        let inner = |feat_value: f32, state: &[f32]| {
252            assert!(state.len() == 6);
253            let sqrtf_ng2 = state[W_G2].sqrt();
254            let uncertain =
255                (self.config.beta + sqrtf_ng2) / self.config.alpha + self.config.l2_lambda;
256            score += (1.0 / uncertain) * feat_value.signum();
257        };
258
259        let feat = features.as_inner().unwrap();
260        foreach_feature_with_state(
261            ModelIndex::from(0),
262            feat,
263            &self.weights,
264            &self.pairs,
265            &self.triples,
266            self.num_bits,
267            self.constant_feature_enabled,
268            inner,
269        );
270        score
271    }
272}
273
/// State slot: current parameter (bet). After an update it is stored divided
/// by the average squared feature norm.
const W_XT: usize = 0;
/// State slot: running sum of negative gradients.
const W_ZT: usize = 1;
/// State slot: running sum of absolute gradients.
const W_G2: usize = 2;
/// State slot: largest absolute feature value seen.
const W_MX: usize = 3;
/// State slot: accumulated wealth.
const W_WE: usize = 4;
/// State slot: largest weighted loss-derivative magnitude seen, floored at
/// `beta` (maximum Lipschitz-constant estimate).
const W_MG: usize = 5;
280
281// TODO constant
282
/// Pair of `f32` totals that can be summed component-wise.
/// NOTE(review): no users are visible in this file.
struct PredOutcome(f32, f32);

impl Sum for PredOutcome {
    /// Adds items component-wise, left to right, starting from `(0.0, 0.0)`.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(
            PredOutcome(0.0, 0.0),
            |PredOutcome(acc_a, acc_b), PredOutcome(a, b)| PredOutcome(acc_a + a, acc_b + b),
        )
    }
}
295
296impl CoinRegressor {
297    fn coin_betting_predict(&mut self, features: &SparseFeatures, weight: f32) -> f32 {
298        let mut prediction = 0.0;
299        let mut normalized_squared_norm_x = 0.0;
300
301        let inner_predict = |feat_value: f32, state: &[f32]| {
302            assert!(state.len() == 6);
303
304            let w_mx = state[W_MX].max(feat_value.abs());
305
306            // COCOB update without sigmoid
307            let w_xt = if state[W_MG] * w_mx > 0.0 {
308                ((self.config.alpha + state[W_WE])
309                    / (state[W_MG] * w_mx * (state[W_MG] * w_mx + state[W_G2])))
310                    * state[W_ZT]
311            } else {
312                0.0
313            };
314
315            prediction += w_xt * feat_value;
316            if w_mx > 0.0 {
317                let x_normalized = feat_value / w_mx;
318                normalized_squared_norm_x += x_normalized * x_normalized;
319            } else {
320            }
321        };
322
323        foreach_feature_with_state(
324            0.into(),
325            features,
326            &self.weights,
327            &self.pairs,
328            &self.triples,
329            self.num_bits,
330            self.constant_feature_enabled,
331            inner_predict,
332        );
333
334        // todo select correct one
335        self.model_states[0].normalized_sum_norm_x += normalized_squared_norm_x * weight;
336        self.model_states[0].total_weight += weight;
337        self.average_squared_norm_x =
338            (self.model_states[0].normalized_sum_norm_x + 1e-6) / self.model_states[0].total_weight;
339
340        let partial_prediction = prediction / self.average_squared_norm_x;
341
342        // dbg!(partial_prediction);
343
344        // todo check nan
345        partial_prediction.clamp(self.min_label, self.max_label)
346    }
347
348    fn coin_betting_update_after_predict(
349        &mut self,
350        features: &SparseFeatures,
351        prediction: f32,
352        label: f32,
353        weight: f32,
354    ) {
355        let update =
356            self.loss_function
357                .first_derivative(self.min_label, self.max_label, prediction, label)
358                * weight;
359
360        // dbg!(update);
361
362        let inner_update = |feat_value: f32, state: &mut [f32]| {
363            assert!(state.len() == 6);
364            // dbg!(feat_value);
365            //   float fabs_x = std::fabs(x);
366            let fabs_x = feat_value.abs();
367            let gradient = update * feat_value;
368            if fabs_x > state[W_MX] {
369                state[W_MX] = fabs_x;
370            }
371            let fabs_gradient = update.abs();
372            // if (fabs_gradient > w[W_MG]) { w[W_MG] = fabs_gradient > d.ftrl_beta ? fabs_gradient : d.ftrl_beta; }
373
374            if fabs_gradient > state[W_MG] {
375                state[W_MG] = if fabs_gradient > self.config.beta {
376                    fabs_gradient
377                } else {
378                    self.config.beta
379                };
380            }
381            if state[W_MG] * state[W_MX] > 0.0 {
382                state[W_XT] = ((self.config.alpha + state[W_WE])
383                    / (state[W_MG] * state[W_MX] * (state[W_MG] * state[W_MX] + state[W_G2])))
384                    * state[W_ZT];
385            } else {
386                state[W_XT] = 0.0;
387            }
388
389            state[W_ZT] += -gradient;
390            state[W_G2] += gradient.abs();
391            state[W_WE] += -gradient * state[W_XT];
392
393            state[W_XT] /= self.average_squared_norm_x;
394
395            // dbg!(state[W_XT]);
396            // dbg!(state[W_ZT]);
397            // dbg!(state[W_G2]);
398            // dbg!(state[W_MX]);
399            // dbg!(state[W_WE]);
400            // dbg!(state[W_MG]);
401            // dbg!("---");
402        };
403        foreach_feature_with_state_mut(
404            ModelIndex::from(0),
405            features,
406            &mut self.weights,
407            &self.pairs,
408            &self.triples,
409            self.num_bits,
410            self.constant_feature_enabled,
411            inner_update,
412        );
413    }
414}
415
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{interactions::NamespaceDef, sparse_namespaced_features::Namespace};

    use super::*;

    /// Builds a single-model regressor with the default coin config.
    fn default_regressor(global_config: &GlobalConfig) -> CoinRegressor {
        CoinRegressor::new(
            CoinRegressorConfig::default(),
            global_config,
            ModelIndex::from(1),
        )
        .unwrap()
    }

    /// Input schedule shared by the end-to-end tests: `(i % 100) / 10`.
    fn cyclic_x(i: i32) -> f32 {
        (i % 100) as f32 / 10.0
    }

    /// One default-namespace feature at index 2 holding `x_value`.
    fn sparse_with_x(x_value: f32) -> SparseFeatures {
        let mut sparse = SparseFeatures::new();
        // TODO: 0 index is breaking quadratic test since 0^0 = 0
        sparse
            .get_or_create_namespace(Namespace::Default)
            .add_feature(2.into(), x_value);
        sparse
    }

    #[test]
    fn test_coin_betting_predict() {
        let global_config = GlobalConfig::new(4, 0, false, &Vec::new());
        let coin = default_regressor(&global_config);

        let mut sparse = SparseFeatures::new();
        sparse
            .get_or_create_namespace(Namespace::Default)
            .add_feature(0.into(), 1.0);
        let mut features = Features::SparseSimple(sparse);

        let mut depth_info = DepthInfo::new();
        let prediction = coin.predict(&mut features, &mut depth_info, 0.into());
        // The prediction must be the Scalar variant.
        assert!(matches!(prediction, Prediction::Scalar { .. }));
    }

    #[test]
    fn test_learning() {
        let global_config = GlobalConfig::new(2, 0, false, &Vec::new());
        let mut coin = default_regressor(&global_config);

        let mut sparse = SparseFeatures::new();
        {
            let ns = sparse.get_or_create_namespace(Namespace::Default);
            ns.add_feature(0.into(), 1.0);
            ns.add_feature(1.into(), 1.0);
            ns.add_feature(2.into(), 1.0);
            ns.add_feature(3.into(), 1.0);
        }
        let mut features = Features::SparseSimple(sparse);
        let mut depth_info = DepthInfo::new();

        // Four identical passes over the same example.
        let label = Label::Simple(SimpleLabel::new(0.5, 1.0));
        for _ in 0..4 {
            coin.learn(&mut features, &label, &mut depth_info, 0.into());
        }

        let pred = coin.predict(&mut features, &mut depth_info, 0.into());
        assert!(matches!(pred, Prediction::Scalar { .. }));
        let scalar: &ScalarPrediction = pred.as_inner().unwrap();
        assert_relative_eq!(scalar.prediction, 0.5);
    }

    /// Trains on `n` examples `(x(i), yhat(x(i)))`, then checks that each value
    /// in `test_set` is predicted within 0.1 of `yhat`.
    fn test_learning_e2e(
        x: fn(i32) -> f32,
        yhat: fn(f32) -> f32,
        n: i32,
        mut regressor: CoinRegressor,
        test_set: Vec<f32>,
    ) {
        for i in 0..n {
            let x_value = x(i);
            let mut features = Features::SparseSimple(sparse_with_x(x_value));
            let mut depth_info = DepthInfo::new();
            regressor.learn(
                &mut features,
                &Label::Simple(SimpleLabel::new(yhat(x_value), 1.0)),
                &mut depth_info,
                0.into(),
            );
        }

        for x_value in test_set {
            let mut features = Features::SparseSimple(sparse_with_x(x_value));
            let mut depth_info = DepthInfo::new();
            let pred = regressor.predict(&mut features, &mut depth_info, 0.into());
            assert!(matches!(pred, Prediction::Scalar { .. }));

            let pred_value: &ScalarPrediction = pred.as_inner().unwrap();
            assert_relative_eq!(pred_value.prediction, yhat(x_value), epsilon = 0.1);
        }
    }

    #[test]
    fn test_learning_const() {
        let global_config = GlobalConfig::new(4, 0, true, &Vec::new());
        test_learning_e2e(
            cyclic_x,
            |_x: f32| 1.0,
            10000,
            default_regressor(&global_config),
            vec![0.0, 1.0, 2.0, 3.0],
        );
    }

    #[test]
    fn test_learning_linear() {
        let global_config = GlobalConfig::new(4, 0, true, &Vec::new());
        test_learning_e2e(
            cyclic_x,
            |x: f32| 2.0 * x + 3.0,
            100000,
            default_regressor(&global_config),
            vec![0.0, 1.0, 2.0, 3.0],
        );
    }

    #[test]
    fn test_learning_quadratic() {
        let global_config = GlobalConfig::new(
            4,
            0,
            true,
            &vec![vec![NamespaceDef::Default, NamespaceDef::Default]],
        );
        test_learning_e2e(
            cyclic_x,
            |x: f32| x * x - 2.0 * x + 3.0,
            100000,
            default_regressor(&global_config),
            vec![0.0, 1.0, 2.0, 3.0],
        );
    }
}