1use std::iter::Sum;
2use std::ops::Deref;
3
4use crate::dense_weights::DenseWeights;
5use crate::error::Result;
6use crate::global_config::GlobalConfig;
7use crate::interactions::compile_interactions;
8use crate::loss_function::{LossFunction, LossFunctionType};
9use crate::reduction::{
10 DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
11};
12use crate::reduction_factory::{PascalCaseString, ReductionConfig, ReductionFactory};
13use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
14use crate::utils::bits_to_max_feature_index;
15use crate::utils::AsInner;
16use crate::weights::{foreach_feature, foreach_feature_with_state, foreach_feature_with_state_mut};
17use crate::{impl_default_factory_functions, types::*, ModelIndex, StateIndex};
18use schemars::schema::RootSchema;
19use schemars::{schema_for, JsonSchema};
20use serde::{Deserialize, Deserializer, Serialize, Serializer};
21use serde_default::DefaultFromSerde;
22
// Hyperparameters for the coin-betting regressor (reduction type name "Coin").
//
// These are plain `//` comments on purpose: this type derives `JsonSchema`, and schemars turns
// `///` doc comments into `description` fields of the generated schema.
//
// Fields are serialized in camelCase (`alpha`, `beta`, `l1Lambda`, `l2Lambda`) and unknown
// fields are rejected when deserializing.
#[derive(Deserialize, DefaultFromSerde, Serialize, Debug, Clone, JsonSchema)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub struct CoinRegressorConfig {
    // Defaults to 4.0 (`default_alpha`). Added to the `W_WE` state slot in the betting-weight
    // formula, and used as a divisor in `sensitivity`.
    #[serde(default = "default_alpha")]
    alpha: f32,

    // Defaults to 1.0 (`default_beta`). Floor applied whenever the `W_MG` state slot grows;
    // also used in `sensitivity`.
    #[serde(default = "default_beta")]
    beta: f32,

    // Defaults to 0.0. NOTE(review): not read anywhere in this file — confirm whether an L1
    // penalty is meant to be applied.
    #[serde(default)]
    l1_lambda: f32,

    // Defaults to 0.0. Only read by `sensitivity` in this file.
    #[serde(default)]
    l2_lambda: f32,
}
39
/// Default value of `CoinRegressorConfig::alpha`.
const DEFAULT_ALPHA: f32 = 4.0;

/// Default value of `CoinRegressorConfig::beta`.
const DEFAULT_BETA: f32 = 1.0;

/// Serde default for `alpha` (referenced by name from `#[serde(default = "default_alpha")]`).
const fn default_alpha() -> f32 {
    DEFAULT_ALPHA
}

/// Serde default for `beta` (referenced by name from `#[serde(default = "default_beta")]`).
const fn default_beta() -> f32 {
    DEFAULT_BETA
}
47
48impl ReductionConfig for CoinRegressorConfig {
49 fn as_any(&self) -> &dyn std::any::Any {
50 self
51 }
52
53 fn typename(&self) -> PascalCaseString {
54 "Coin".try_into().unwrap()
55 }
56}
57
/// Running statistics used to normalize learning-time predictions.
///
/// In this file only `model_states[0]` is ever read or written (by `coin_betting_predict`).
#[derive(Clone, Serialize, Deserialize)]
struct CoinRegressorModelState {
    /// Sum over learned examples of `weight * normalized_squared_norm_x`, where each feature
    /// value is divided by the larger of its stored `W_MX` slot and its current absolute value.
    normalized_sum_norm_x: f32,
    /// Sum of the example weights passed to `coin_betting_predict`.
    total_weight: f32,
}
63
/// Wrapper that makes a boxed loss function (de)serializable.
///
/// Only the loss function's type (`get_type()`) is serialized; deserialization rebuilds the
/// function via `LossFunctionType::create`. The `Deref` impl exposes the inner
/// `dyn LossFunction` so its methods can be called directly on the holder.
struct LossFunctionHolder {
    loss_function: Box<dyn LossFunction>,
}
67
68impl Deref for LossFunctionHolder {
69 type Target = dyn LossFunction;
70
71 fn deref(&self) -> &Self::Target {
72 self.loss_function.deref()
73 }
74}
75
76impl Serialize for LossFunctionHolder {
77 fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
78 where
79 S: Serializer,
80 {
81 self.loss_function.get_type().serialize(serializer)
82 }
83}
84
85impl<'de> Deserialize<'de> for LossFunctionHolder {
86 fn deserialize<D>(deserializer: D) -> std::result::Result<LossFunctionHolder, D::Error>
87 where
88 D: Deserializer<'de>,
89 {
90 LossFunctionType::deserialize(deserializer).map(|x| LossFunctionHolder {
91 loss_function: x.create(),
92 })
93 }
94}
95
/// Online linear regressor trained with a coin-betting (COCOB-style) update.
#[derive(Serialize, Deserialize)]
struct CoinRegressor {
    /// Per-feature state; `new` allocates 6 state slots per feature, addressed by the `W_*`
    /// index constants.
    weights: DenseWeights,
    config: CoinRegressorConfig,
    /// One entry per model (`num_models_above` in `new`); only index 0 is used in this file.
    model_states: Vec<CoinRegressorModelState>,
    /// `(normalized_sum_norm_x + 1e-6) / total_weight`, recomputed on every learned example.
    average_squared_norm_x: f32,
    /// Smallest label seen by `learn` (starts at 0.0); predictions are clamped to it.
    min_label: f32,
    /// Largest label seen by `learn` (starts at 0.0); predictions are clamped to it.
    max_label: f32,
    /// Loss used to compute each update; `new` always installs squared loss.
    loss_function: LossFunctionHolder,
    /// Namespace pairs produced by `compile_interactions`.
    pairs: Vec<(Namespace, Namespace)>,
    /// Namespace triples produced by `compile_interactions`.
    triples: Vec<(Namespace, Namespace, Namespace)>,
    /// Copied from `GlobalConfig`; passed through to feature iteration.
    num_bits: u8,
    /// Copied from `GlobalConfig`; passed through to feature iteration.
    constant_feature_enabled: bool,
}
111
112impl CoinRegressor {
113 pub fn new(
114 config: CoinRegressorConfig,
115 global_config: &GlobalConfig,
116 num_models_above: ModelIndex,
117 ) -> Result<CoinRegressor> {
118 let (pairs, triples) =
119 compile_interactions(global_config.interactions(), global_config.hash_seed());
120 Ok(CoinRegressor {
121 weights: DenseWeights::new(
122 bits_to_max_feature_index(global_config.num_bits()),
123 num_models_above,
124 StateIndex::from(6),
125 )?,
126 config,
127 model_states: vec![
128 CoinRegressorModelState {
129 normalized_sum_norm_x: 0.0,
130 total_weight: 0.0
131 };
132 *num_models_above as usize
133 ],
134 average_squared_norm_x: 0.0,
135 min_label: 0.0,
136 max_label: 0.0,
137 loss_function: LossFunctionHolder {
138 loss_function: LossFunctionType::Squared.create(),
139 },
140 pairs,
141 triples,
142 num_bits: global_config.num_bits(),
143 constant_feature_enabled: global_config.constant_feature_enabled(),
144 })
145 }
146}
147
/// Factory for the "Coin" reduction; `create` builds a `CoinRegressor` from a
/// `CoinRegressorConfig`.
#[derive(Default)]
pub struct CoinRegressorFactory;
151impl ReductionFactory for CoinRegressorFactory {
152 impl_default_factory_functions!("Coin", CoinRegressorConfig);
153
154 fn create(
155 &self,
156 config: &dyn ReductionConfig,
157 global_config: &GlobalConfig,
158 num_models_above: ModelIndex,
159 ) -> Result<ReductionWrapper> {
160 let config = config
161 .as_any()
162 .downcast_ref::<CoinRegressorConfig>()
163 .unwrap();
164
165 Ok(ReductionWrapper::new(
166 self.typename(),
167 Box::new(CoinRegressor::new(
168 config.clone(),
169 global_config,
170 num_models_above,
171 )?),
172 ReductionTypeDescriptionBuilder::new(
173 LabelType::Simple,
174 FeaturesType::SparseSimple,
175 PredictionType::Scalar,
176 )
177 .build(),
178 num_models_above,
179 ))
180 }
181}
182
183#[typetag::serde]
184impl ReductionImpl for CoinRegressor {
185 fn predict(
186 &self,
187 features: &mut Features,
188 _depth_info: &mut DepthInfo,
189 _model_offset: ModelIndex,
190 ) -> Prediction {
191 let sparse_feats: &SparseFeatures = features.as_inner().unwrap();
192
193 let mut prediction = 0.0;
194 foreach_feature(
195 0.into(),
196 sparse_feats,
197 &self.weights,
198 &self.pairs,
199 &self.triples,
200 self.num_bits,
201 self.constant_feature_enabled,
202 |feat_val, weight_val| prediction += feat_val * weight_val,
203 );
204
205 if prediction.is_nan() {
206 prediction = 0.0;
207 }
208
209 let scalar_pred = ScalarPrediction {
210 prediction: prediction.clamp(self.min_label, self.max_label),
211 raw_prediction: prediction,
212 };
213 scalar_pred.into()
214 }
215
216 fn learn(
217 &mut self,
218 features: &mut Features,
219 label: &Label,
220 _depth_info: &mut DepthInfo,
221 _model_offset: ModelIndex,
222 ) {
223 let sparse_feats: &SparseFeatures = features.as_inner().unwrap();
224 let simple_label: &SimpleLabel = label.as_inner().unwrap();
225
226 self.min_label = simple_label.value().min(self.min_label);
227 self.max_label = simple_label.value().max(self.max_label);
228 let _prediction = self.coin_betting_predict(sparse_feats, simple_label.weight());
229 self.coin_betting_update_after_predict(
230 sparse_feats,
231 _prediction,
232 simple_label.value(),
233 simple_label.weight(),
234 );
235 }
236
237 fn children(&self) -> Vec<&ReductionWrapper> {
238 vec![]
239 }
240
241 fn sensitivity(
243 &self,
244 features: &Features,
245 _label: f32,
246 _prediction: f32,
247 _weight: f32,
248 _depth_info: DepthInfo,
249 ) -> f32 {
250 let mut score = 0.0;
251 let inner = |feat_value: f32, state: &[f32]| {
252 assert!(state.len() == 6);
253 let sqrtf_ng2 = state[W_G2].sqrt();
254 let uncertain =
255 (self.config.beta + sqrtf_ng2) / self.config.alpha + self.config.l2_lambda;
256 score += (1.0 / uncertain) * feat_value.signum();
257 };
258
259 let feat = features.as_inner().unwrap();
260 foreach_feature_with_state(
261 ModelIndex::from(0),
262 feat,
263 &self.weights,
264 &self.pairs,
265 &self.triples,
266 self.num_bits,
267 self.constant_feature_enabled,
268 inner,
269 );
270 score
271 }
272}
273
// Indices of the six per-feature state slots kept in `DenseWeights` for this reduction.

/// Betting weight; after each update it is stored divided by `average_squared_norm_x`.
const W_XT: usize = 0;
/// Running sum of negated gradients.
const W_ZT: usize = 1;
/// Running sum of absolute gradients (despite the name, not squared).
const W_G2: usize = 2;
/// Largest absolute feature value seen.
const W_MX: usize = 3;
/// Running sum of `-gradient * W_XT` (the accumulated betting reward).
const W_WE: usize = 4;
/// Largest absolute per-example update seen, floored at `beta` when it grows.
const W_MG: usize = 5;

/// Pair of values that can be summed element-wise.
///
/// NOTE(review): not referenced elsewhere in this file.
struct PredOutcome(f32, f32);

impl Sum for PredOutcome {
    /// Adds the first and second components independently, starting from `(0.0, 0.0)`.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(PredOutcome(0.0, 0.0), |PredOutcome(a, b), PredOutcome(x, y)| {
            PredOutcome(a + x, b + y)
        })
    }
}
295
296impl CoinRegressor {
297 fn coin_betting_predict(&mut self, features: &SparseFeatures, weight: f32) -> f32 {
298 let mut prediction = 0.0;
299 let mut normalized_squared_norm_x = 0.0;
300
301 let inner_predict = |feat_value: f32, state: &[f32]| {
302 assert!(state.len() == 6);
303
304 let w_mx = state[W_MX].max(feat_value.abs());
305
306 let w_xt = if state[W_MG] * w_mx > 0.0 {
308 ((self.config.alpha + state[W_WE])
309 / (state[W_MG] * w_mx * (state[W_MG] * w_mx + state[W_G2])))
310 * state[W_ZT]
311 } else {
312 0.0
313 };
314
315 prediction += w_xt * feat_value;
316 if w_mx > 0.0 {
317 let x_normalized = feat_value / w_mx;
318 normalized_squared_norm_x += x_normalized * x_normalized;
319 } else {
320 }
321 };
322
323 foreach_feature_with_state(
324 0.into(),
325 features,
326 &self.weights,
327 &self.pairs,
328 &self.triples,
329 self.num_bits,
330 self.constant_feature_enabled,
331 inner_predict,
332 );
333
334 self.model_states[0].normalized_sum_norm_x += normalized_squared_norm_x * weight;
336 self.model_states[0].total_weight += weight;
337 self.average_squared_norm_x =
338 (self.model_states[0].normalized_sum_norm_x + 1e-6) / self.model_states[0].total_weight;
339
340 let partial_prediction = prediction / self.average_squared_norm_x;
341
342 partial_prediction.clamp(self.min_label, self.max_label)
346 }
347
348 fn coin_betting_update_after_predict(
349 &mut self,
350 features: &SparseFeatures,
351 prediction: f32,
352 label: f32,
353 weight: f32,
354 ) {
355 let update =
356 self.loss_function
357 .first_derivative(self.min_label, self.max_label, prediction, label)
358 * weight;
359
360 let inner_update = |feat_value: f32, state: &mut [f32]| {
363 assert!(state.len() == 6);
364 let fabs_x = feat_value.abs();
367 let gradient = update * feat_value;
368 if fabs_x > state[W_MX] {
369 state[W_MX] = fabs_x;
370 }
371 let fabs_gradient = update.abs();
372 if fabs_gradient > state[W_MG] {
375 state[W_MG] = if fabs_gradient > self.config.beta {
376 fabs_gradient
377 } else {
378 self.config.beta
379 };
380 }
381 if state[W_MG] * state[W_MX] > 0.0 {
382 state[W_XT] = ((self.config.alpha + state[W_WE])
383 / (state[W_MG] * state[W_MX] * (state[W_MG] * state[W_MX] + state[W_G2])))
384 * state[W_ZT];
385 } else {
386 state[W_XT] = 0.0;
387 }
388
389 state[W_ZT] += -gradient;
390 state[W_G2] += gradient.abs();
391 state[W_WE] += -gradient * state[W_XT];
392
393 state[W_XT] /= self.average_squared_norm_x;
394
395 };
403 foreach_feature_with_state_mut(
404 ModelIndex::from(0),
405 features,
406 &mut self.weights,
407 &self.pairs,
408 &self.triples,
409 self.num_bits,
410 self.constant_feature_enabled,
411 inner_update,
412 );
413 }
414}
415
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;

    use crate::{interactions::NamespaceDef, sparse_namespaced_features::Namespace};

    use super::*;

    #[test]
    fn test_coin_betting_predict() {
        let coin_config = CoinRegressorConfig::default();
        let global_config = GlobalConfig::new(4, 0, false, &Vec::new());
        let coin = CoinRegressor::new(coin_config, &global_config, ModelIndex::from(1)).unwrap();
        let mut features = SparseFeatures::new();
        let ns = features.get_or_create_namespace(Namespace::Default);
        ns.add_feature(0.into(), 1.0);

        let mut features = Features::SparseSimple(features);

        let mut depth_info = DepthInfo::new();
        let prediction = coin.predict(&mut features, &mut depth_info, 0.into());
        assert!(matches!(prediction, Prediction::Scalar { .. }));

        // No labels have been learned yet, so the label range is still [0, 0] and the clamped
        // prediction must be exactly zero.
        let scalar: &ScalarPrediction = prediction.as_inner().unwrap();
        assert_eq!(scalar.prediction, 0.0);
    }

    #[test]
    fn test_learning() {
        let coin_config = CoinRegressorConfig::default();
        let global_config = GlobalConfig::new(2, 0, false, &Vec::new());
        let mut coin =
            CoinRegressor::new(coin_config, &global_config, ModelIndex::from(1)).unwrap();

        let mut features = SparseFeatures::new();
        {
            let ns = features.get_or_create_namespace(Namespace::Default);
            for index in 0..4 {
                ns.add_feature(index.into(), 1.0);
            }
        }

        let mut depth_info = DepthInfo::new();
        let mut features = Features::SparseSimple(features);
        let label = Label::Simple(SimpleLabel::new(0.5, 1.0));
        for _ in 0..4 {
            coin.learn(&mut features, &label, &mut depth_info, 0.into());
        }

        let pred = coin.predict(&mut features, &mut depth_info, 0.into());

        assert!(matches!(pred, Prediction::Scalar { .. }));
        let pred1: &ScalarPrediction = pred.as_inner().unwrap();
        assert_relative_eq!(pred1.prediction, 0.5);
    }

    /// Trains `regressor` on `n` examples with the single feature `x(i)` and label `yhat(x(i))`,
    /// then checks each test point is predicted within 0.1 of `yhat`.
    fn test_learning_e2e(
        x: fn(i32) -> f32,
        yhat: fn(f32) -> f32,
        n: i32,
        mut regressor: CoinRegressor,
        test_set: Vec<f32>,
    ) {
        for i in 0..n {
            let mut features = SparseFeatures::new();
            let x_value = x(i);
            {
                let ns = features.get_or_create_namespace(Namespace::Default);
                ns.add_feature(2.into(), x_value);
            }

            let mut depth_info = DepthInfo::new();
            let mut features = Features::SparseSimple(features);
            regressor.learn(
                &mut features,
                &Label::Simple(SimpleLabel::new(yhat(x_value), 1.0)),
                &mut depth_info,
                0.into(),
            );
        }

        for x in test_set {
            let mut features = SparseFeatures::new();
            {
                let ns = features.get_or_create_namespace(Namespace::Default);
                ns.add_feature(2.into(), x);
            }

            let mut depth_info = DepthInfo::new();
            let mut features = Features::SparseSimple(features);
            let pred = regressor.predict(&mut features, &mut depth_info, 0.into());
            assert!(matches!(pred, Prediction::Scalar { .. }));

            let pred_value: &ScalarPrediction = pred.as_inner().unwrap();
            assert_relative_eq!(pred_value.prediction, yhat(x), epsilon = 0.1);
        }
    }

    #[test]
    fn test_learning_const() {
        fn x(i: i32) -> f32 {
            (i % 100) as f32 / 10.0
        }
        fn yhat(_x: f32) -> f32 {
            1.0
        }

        let coin_config = CoinRegressorConfig::default();
        let global_config = GlobalConfig::new(4, 0, true, &Vec::new());
        let coin: CoinRegressor =
            CoinRegressor::new(coin_config, &global_config, ModelIndex::from(1)).unwrap();

        test_learning_e2e(x, yhat, 10000, coin, vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_learning_linear() {
        fn x(i: i32) -> f32 {
            (i % 100) as f32 / 10.0
        }
        fn yhat(x: f32) -> f32 {
            2.0 * x + 3.0
        }

        let coin_config = CoinRegressorConfig::default();
        let global_config = GlobalConfig::new(4, 0, true, &Vec::new());
        let coin: CoinRegressor =
            CoinRegressor::new(coin_config, &global_config, ModelIndex::from(1)).unwrap();

        test_learning_e2e(x, yhat, 100000, coin, vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_learning_quadratic() {
        fn x(i: i32) -> f32 {
            (i % 100) as f32 / 10.0
        }
        fn yhat(x: f32) -> f32 {
            x * x - 2.0 * x + 3.0
        }

        let coin_config = CoinRegressorConfig::default();
        let global_config = GlobalConfig::new(
            4,
            0,
            true,
            &vec![vec![NamespaceDef::Default, NamespaceDef::Default]],
        );
        let coin: CoinRegressor =
            CoinRegressor::new(coin_config, &global_config, ModelIndex::from(1)).unwrap();
        test_learning_e2e(x, yhat, 100000, coin, vec![0.0, 1.0, 2.0, 3.0]);
    }
}