// reductionml_core/reductions/binary.rs

1use crate::error::Result;
2use crate::global_config::GlobalConfig;
3use crate::reduction::{
4    DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
5};
6use crate::reduction_factory::{
7    create_reduction, JsonReductionConfig, PascalCaseString, ReductionConfig, ReductionFactory,
8};
9use crate::utils::AsInner;
10
11use crate::reductions::CoinRegressorConfig;
12use crate::{impl_default_factory_functions, types::*, ModelIndex};
13
14use schemars::schema::RootSchema;
15use schemars::{schema_for, JsonSchema};
16use serde::{Deserialize, Serialize};
17use serde_default::DefaultFromSerde;
18use serde_json::json;
19
20#[derive(Deserialize, Serialize, JsonSchema, DefaultFromSerde)]
21#[serde(deny_unknown_fields)]
22#[serde(rename_all = "camelCase")]
23struct BinaryReductionConfig {
24    #[serde(default = "default_regressor")]
25    #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
26    regressor: JsonReductionConfig,
27}
28
29fn default_regressor() -> JsonReductionConfig {
30    JsonReductionConfig::new(
31        "Coin".try_into().unwrap(),
32        json!(CoinRegressorConfig::default()),
33    )
34}
35
36#[derive(Serialize, Deserialize)]
37struct BinaryReduction {
38    regressor: ReductionWrapper,
39}
40
/// Factory that builds `Binary` reductions from a `BinaryReductionConfig`.
///
/// The factory is stateless, so it is cheap to copy and trivially debuggable.
#[derive(Default, Debug, Clone, Copy)]
pub struct BinaryReductionFactory;
43
44impl ReductionConfig for BinaryReductionConfig {
45    fn typename(&self) -> PascalCaseString {
46        "Binary".try_into().unwrap()
47    }
48
49    fn as_any(&self) -> &dyn std::any::Any {
50        self
51    }
52}
53
54impl ReductionFactory for BinaryReductionFactory {
55    impl_default_factory_functions!("Binary", BinaryReductionConfig);
56    fn create(
57        &self,
58        config: &dyn ReductionConfig,
59        global_config: &GlobalConfig,
60        num_models_above: ModelIndex,
61    ) -> Result<ReductionWrapper> {
62        let config = config
63            .as_any()
64            .downcast_ref::<BinaryReductionConfig>()
65            .unwrap();
66        let regressor_config = crate::reduction_factory::parse_config(&config.regressor)?;
67        let regressor: ReductionWrapper =
68            create_reduction(regressor_config.as_ref(), global_config, num_models_above)?;
69
70        let types = ReductionTypeDescriptionBuilder::new(
71            LabelType::Binary,
72            regressor.types().input_features_type(),
73            PredictionType::Binary,
74        )
75        .with_input_prediction_type(PredictionType::Scalar)
76        .with_output_features_type(regressor.types().input_features_type())
77        .with_output_label_type(LabelType::Simple)
78        .build();
79
80        if let Some(reason) = types.check_and_get_reason(regressor.types()) {
81            return Err(crate::error::Error::InvalidArgument(format!(
82                "Invalid reduction configuration: {}",
83                reason
84            )));
85        }
86
87        Ok(ReductionWrapper::new(
88            self.typename(),
89            Box::new(BinaryReduction { regressor }),
90            types,
91            num_models_above,
92        ))
93    }
94}
95
96impl From<BinaryLabel> for SimpleLabel {
97    fn from(label: BinaryLabel) -> Self {
98        if label.0 { 1.0 } else { -1.0 }.into()
99    }
100}
101
102#[typetag::serde]
103impl ReductionImpl for BinaryReduction {
104    fn predict(
105        &self,
106        features: &mut Features,
107        depth_info: &mut DepthInfo,
108        _model_offset: ModelIndex,
109    ) -> Prediction {
110        let pred = self.regressor.predict(features, depth_info, 0.into());
111        let scalar_pred: &ScalarPrediction = pred.as_inner().unwrap();
112
113        Prediction::Binary((scalar_pred.prediction > 0.0).into())
114    }
115
116    fn predict_then_learn(
117        &mut self,
118        features: &mut Features,
119        label: &Label,
120        depth_info: &mut DepthInfo,
121        _model_offset: ModelIndex,
122    ) -> Prediction {
123        let binary_label: &BinaryLabel = label.as_inner().unwrap();
124
125        let pred = self.regressor.predict_then_learn(
126            features,
127            &SimpleLabel::from(*binary_label).into(),
128            depth_info,
129            0.into(),
130        );
131
132        let scalar_pred: &ScalarPrediction = pred.as_inner().unwrap();
133
134        Prediction::Binary((scalar_pred.prediction > 0.0).into())
135    }
136
137    fn learn(
138        &mut self,
139        features: &mut Features,
140        label: &Label,
141        depth_info: &mut DepthInfo,
142        _model_offset: ModelIndex,
143    ) {
144        let binary_label: &BinaryLabel = label.as_inner().unwrap();
145
146        self.regressor.learn(
147            features,
148            &SimpleLabel::from(*binary_label).into(),
149            depth_info,
150            0.into(),
151        )
152    }
153
154    fn children(&self) -> Vec<&ReductionWrapper> {
155        vec![&self.regressor]
156    }
157}