reductionml_core/reductions/
binary.rs1use crate::error::Result;
2use crate::global_config::GlobalConfig;
3use crate::reduction::{
4 DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
5};
6use crate::reduction_factory::{
7 create_reduction, JsonReductionConfig, PascalCaseString, ReductionConfig, ReductionFactory,
8};
9use crate::utils::AsInner;
10
11use crate::reductions::CoinRegressorConfig;
12use crate::{impl_default_factory_functions, types::*, ModelIndex};
13
14use schemars::schema::RootSchema;
15use schemars::{schema_for, JsonSchema};
16use serde::{Deserialize, Serialize};
17use serde_default::DefaultFromSerde;
18use serde_json::json;
19
20#[derive(Deserialize, Serialize, JsonSchema, DefaultFromSerde)]
21#[serde(deny_unknown_fields)]
22#[serde(rename_all = "camelCase")]
23struct BinaryReductionConfig {
24 #[serde(default = "default_regressor")]
25 #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
26 regressor: JsonReductionConfig,
27}
28
29fn default_regressor() -> JsonReductionConfig {
30 JsonReductionConfig::new(
31 "Coin".try_into().unwrap(),
32 json!(CoinRegressorConfig::default()),
33 )
34}
35
36#[derive(Serialize, Deserialize)]
37struct BinaryReduction {
38 regressor: ReductionWrapper,
39}
40
/// Factory registered under the `Binary` type name; builds a binary
/// classification reduction on top of a configurable regressor.
#[derive(Default)]
pub struct BinaryReductionFactory;
43
44impl ReductionConfig for BinaryReductionConfig {
45 fn typename(&self) -> PascalCaseString {
46 "Binary".try_into().unwrap()
47 }
48
49 fn as_any(&self) -> &dyn std::any::Any {
50 self
51 }
52}
53
54impl ReductionFactory for BinaryReductionFactory {
55 impl_default_factory_functions!("Binary", BinaryReductionConfig);
56 fn create(
57 &self,
58 config: &dyn ReductionConfig,
59 global_config: &GlobalConfig,
60 num_models_above: ModelIndex,
61 ) -> Result<ReductionWrapper> {
62 let config = config
63 .as_any()
64 .downcast_ref::<BinaryReductionConfig>()
65 .unwrap();
66 let regressor_config = crate::reduction_factory::parse_config(&config.regressor)?;
67 let regressor: ReductionWrapper =
68 create_reduction(regressor_config.as_ref(), global_config, num_models_above)?;
69
70 let types = ReductionTypeDescriptionBuilder::new(
71 LabelType::Binary,
72 regressor.types().input_features_type(),
73 PredictionType::Binary,
74 )
75 .with_input_prediction_type(PredictionType::Scalar)
76 .with_output_features_type(regressor.types().input_features_type())
77 .with_output_label_type(LabelType::Simple)
78 .build();
79
80 if let Some(reason) = types.check_and_get_reason(regressor.types()) {
81 return Err(crate::error::Error::InvalidArgument(format!(
82 "Invalid reduction configuration: {}",
83 reason
84 )));
85 }
86
87 Ok(ReductionWrapper::new(
88 self.typename(),
89 Box::new(BinaryReduction { regressor }),
90 types,
91 num_models_above,
92 ))
93 }
94}
95
96impl From<BinaryLabel> for SimpleLabel {
97 fn from(label: BinaryLabel) -> Self {
98 if label.0 { 1.0 } else { -1.0 }.into()
99 }
100}
101
102#[typetag::serde]
103impl ReductionImpl for BinaryReduction {
104 fn predict(
105 &self,
106 features: &mut Features,
107 depth_info: &mut DepthInfo,
108 _model_offset: ModelIndex,
109 ) -> Prediction {
110 let pred = self.regressor.predict(features, depth_info, 0.into());
111 let scalar_pred: &ScalarPrediction = pred.as_inner().unwrap();
112
113 Prediction::Binary((scalar_pred.prediction > 0.0).into())
114 }
115
116 fn predict_then_learn(
117 &mut self,
118 features: &mut Features,
119 label: &Label,
120 depth_info: &mut DepthInfo,
121 _model_offset: ModelIndex,
122 ) -> Prediction {
123 let binary_label: &BinaryLabel = label.as_inner().unwrap();
124
125 let pred = self.regressor.predict_then_learn(
126 features,
127 &SimpleLabel::from(*binary_label).into(),
128 depth_info,
129 0.into(),
130 );
131
132 let scalar_pred: &ScalarPrediction = pred.as_inner().unwrap();
133
134 Prediction::Binary((scalar_pred.prediction > 0.0).into())
135 }
136
137 fn learn(
138 &mut self,
139 features: &mut Features,
140 label: &Label,
141 depth_info: &mut DepthInfo,
142 _model_offset: ModelIndex,
143 ) {
144 let binary_label: &BinaryLabel = label.as_inner().unwrap();
145
146 self.regressor.learn(
147 features,
148 &SimpleLabel::from(*binary_label).into(),
149 depth_info,
150 0.into(),
151 )
152 }
153
154 fn children(&self) -> Vec<&ReductionWrapper> {
155 vec![&self.regressor]
156 }
157}