reductionml_core/reductions/
cb_adf.rs

1use crate::error::Result;
2use crate::global_config::GlobalConfig;
3use crate::reduction::{
4    DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
5};
6use crate::reduction_factory::{
7    create_reduction, JsonReductionConfig, PascalCaseString, ReductionConfig, ReductionFactory,
8};
9use crate::utils::AsInner;
10
11use crate::reductions::CoinRegressorConfig;
12use crate::{impl_default_factory_functions, types::*, ModelIndex};
13use schemars::schema::RootSchema;
14use schemars::{schema_for, JsonSchema};
15use serde::{Deserialize, Serialize};
16use serde_default::DefaultFromSerde;
17use serde_json::json;
18
19#[derive(Serialize, Deserialize, Clone, Copy, JsonSchema, PartialEq)]
20pub enum CBType {
21    #[serde(rename = "ips")]
22    Ips,
23    #[serde(rename = "mtr")]
24    Mtr,
25}
26
// Configuration for the `CbAdf` reduction. Keys are camelCase when (de)serialized
// and unknown keys are rejected.
//
// NOTE: plain `//` comments are used on purpose. This type derives `JsonSchema`,
// and schemars copies `///` doc comments into the generated schema as descriptions.
#[derive(Deserialize, Serialize, JsonSchema, DefaultFromSerde)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub struct CBAdfConfig {
    // How CB labels are converted into regression examples; falls back to
    // `default_cb_type()` when the key is omitted.
    #[serde(default = "default_cb_type")]
    cb_type: CBType,
    // Configuration of the wrapped regressor; falls back to `default_regressor()`
    // when the key is omitted.
    #[serde(default = "default_regressor")]
    #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
    regressor: JsonReductionConfig,
}
37
38impl CBAdfConfig {
39    pub fn cb_type(&self) -> CBType {
40        self.cb_type
41    }
42}
43
/// Serde default for the `cb_type` field of `CBAdfConfig`: `CBType::Mtr`.
fn default_cb_type() -> CBType {
    CBType::Mtr
}
47
48impl ReductionConfig for CBAdfConfig {
49    fn typename(&self) -> PascalCaseString {
50        "CbAdf".try_into().unwrap()
51    }
52
53    fn as_any(&self) -> &dyn std::any::Any {
54        self
55    }
56}
57
58fn default_regressor() -> JsonReductionConfig {
59    JsonReductionConfig::new(
60        "Coin".try_into().unwrap(),
61        json!(CoinRegressorConfig::default()),
62    )
63}
64
/// Running totals maintained by the MTR learning path; their ratio scales the
/// importance weight of each MTR example.
#[derive(Serialize, Deserialize, Default)]
struct MtrState {
    /// Total number of actions seen across all MTR `learn` calls.
    action_sum: usize,
    /// Number of MTR `learn` calls (events) seen so far.
    event_sum: usize,
}
70
/// Contextual-bandit reduction with action-dependent features: scores each action
/// with a wrapped scalar regressor and trains that regressor from CB labels.
#[derive(Serialize, Deserialize)]
struct CBAdfReduction {
    /// Strategy used in `learn` to build regression examples from a CB label.
    cb_type: CBType,
    /// Wrapped regressor that is called once per action in `predict`.
    regressor: ReductionWrapper,
    // TODO: have MTR state per interleaved model.
    /// Counters used to compute the MTR importance weight.
    mtr_state: MtrState,
}
78
/// Factory that builds the `CbAdf` reduction from a `CBAdfConfig`.
#[derive(Default)]
pub struct CBAdfReductionFactory;
81
82impl ReductionFactory for CBAdfReductionFactory {
83    impl_default_factory_functions!("CbAdf", CBAdfConfig);
84
85    fn create(
86        &self,
87        config: &dyn ReductionConfig,
88        global_config: &GlobalConfig,
89        num_models_above: ModelIndex,
90    ) -> Result<ReductionWrapper> {
91        let config = config.as_any().downcast_ref::<CBAdfConfig>().unwrap();
92        let regressor_config = crate::reduction_factory::parse_config(&config.regressor)?;
93        let regressor: ReductionWrapper =
94            create_reduction(regressor_config.as_ref(), global_config, num_models_above)?;
95
96        let types = ReductionTypeDescriptionBuilder::new(
97            LabelType::CB,
98            FeaturesType::SparseCBAdf,
99            PredictionType::ActionScores,
100        )
101        .with_input_prediction_type(PredictionType::Scalar)
102        .with_output_features_type(FeaturesType::SparseSimple)
103        .with_output_label_type(LabelType::Simple)
104        .build();
105
106        if let Some(reason) = types.check_and_get_reason(regressor.types()) {
107            return Err(crate::error::Error::InvalidArgument(format!(
108                "Invalid reduction configuration: {}",
109                reason
110            )));
111        }
112
113        Ok(ReductionWrapper::new(
114            self.typename(),
115            Box::new(CBAdfReduction {
116                cb_type: config.cb_type,
117                regressor,
118                mtr_state: Default::default(),
119            }),
120            types,
121            num_models_above,
122        ))
123    }
124}
125
126// TODO: clip_p
127fn generate_ips_simple_label(label: &CBLabel, current_action_index: usize) -> SimpleLabel {
128    if current_action_index == label.action {
129        debug_assert!(label.probability > 0.0);
130        (label.cost / label.probability).into()
131    } else {
132        0.0.into()
133    }
134}
135
136#[typetag::serde]
137impl ReductionImpl for CBAdfReduction {
138    fn predict(
139        &self,
140        features: &mut Features,
141        depth_info: &mut DepthInfo,
142        _model_offset: ModelIndex,
143    ) -> Prediction {
144        let cb_adf_features: &mut CBAdfFeatures = features.as_inner_mut().unwrap();
145
146        let mut action_scores = ActionScoresPrediction::default();
147        for (counter, action) in cb_adf_features.actions.iter_mut().enumerate() {
148            if let Some(shared_feats) = &cb_adf_features.shared {
149                action.append(shared_feats);
150            }
151
152            let pred = self
153                .regressor
154                .predict(&mut action.into(), depth_info, 0.into());
155            let scalar_pred: &ScalarPrediction = pred.as_inner().unwrap();
156            action_scores.0.push((counter, scalar_pred.raw_prediction));
157            if let Some(shared_feats) = &cb_adf_features.shared {
158                action.remove(shared_feats);
159            }
160        }
161
162        action_scores.into()
163    }
164
165    fn learn(
166        &mut self,
167        features: &mut Features,
168        label: &Label,
169        depth_info: &mut DepthInfo,
170        _model_offset: ModelIndex,
171    ) {
172        let cb_adf_features: &mut CBAdfFeatures = features.as_inner_mut().unwrap();
173        let cb_label: &CBLabel = label.as_inner().unwrap();
174
175        match self.cb_type {
176            CBType::Ips => {
177                for (counter, action) in cb_adf_features.actions.iter_mut().enumerate() {
178                    if let Some(shared_feats) = &cb_adf_features.shared {
179                        action.append(shared_feats);
180                    }
181
182                    self.regressor.learn(
183                        &mut action.into(),
184                        &(generate_ips_simple_label(cb_label, counter).into()),
185                        depth_info,
186                        0.into(),
187                    );
188                    if let Some(shared_feats) = &cb_adf_features.shared {
189                        action.remove(shared_feats);
190                    }
191                }
192            }
193            CBType::Mtr => {
194                self.mtr_state.action_sum += cb_adf_features.actions.len();
195                self.mtr_state.event_sum += 1;
196
197                let cost = cb_label.cost;
198                // TODO clip_p
199                let prob = cb_label.probability;
200                let weight = 1.0 / prob
201                    * (self.mtr_state.event_sum as f32 / self.mtr_state.action_sum as f32);
202
203                let simple_label = SimpleLabel::new(cost, weight);
204                match cb_adf_features.shared.as_mut() {
205                    Some(shared_feats) => {
206                        let action = cb_adf_features.actions.get(cb_label.action).unwrap();
207                        shared_feats.append(action);
208                        self.regressor.learn(
209                            &mut Features::SparseSimpleRef(shared_feats),
210                            &simple_label.into(),
211                            depth_info,
212                            0.into(),
213                        );
214                        shared_feats.remove(action);
215                    }
216                    None => {
217                        todo!()
218                    }
219                }
220            }
221        }
222    }
223
224    fn children(&self) -> Vec<&ReductionWrapper> {
225        vec![&self.regressor]
226    }
227}