reductionml_core/reductions/
cb_adf.rs1use crate::error::Result;
2use crate::global_config::GlobalConfig;
3use crate::reduction::{
4 DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
5};
6use crate::reduction_factory::{
7 create_reduction, JsonReductionConfig, PascalCaseString, ReductionConfig, ReductionFactory,
8};
9use crate::utils::AsInner;
10
11use crate::reductions::CoinRegressorConfig;
12use crate::{impl_default_factory_functions, types::*, ModelIndex};
13use schemars::schema::RootSchema;
14use schemars::{schema_for, JsonSchema};
15use serde::{Deserialize, Serialize};
16use serde_default::DefaultFromSerde;
17use serde_json::json;
18
19#[derive(Serialize, Deserialize, Clone, Copy, JsonSchema, PartialEq)]
20pub enum CBType {
21 #[serde(rename = "ips")]
22 Ips,
23 #[serde(rename = "mtr")]
24 Mtr,
25}
26
27#[derive(Deserialize, Serialize, JsonSchema, DefaultFromSerde)]
28#[serde(deny_unknown_fields)]
29#[serde(rename_all = "camelCase")]
30pub struct CBAdfConfig {
31 #[serde(default = "default_cb_type")]
32 cb_type: CBType,
33 #[serde(default = "default_regressor")]
34 #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
35 regressor: JsonReductionConfig,
36}
37
38impl CBAdfConfig {
39 pub fn cb_type(&self) -> CBType {
40 self.cb_type
41 }
42}
43
44fn default_cb_type() -> CBType {
45 CBType::Mtr
46}
47
48impl ReductionConfig for CBAdfConfig {
49 fn typename(&self) -> PascalCaseString {
50 "CbAdf".try_into().unwrap()
51 }
52
53 fn as_any(&self) -> &dyn std::any::Any {
54 self
55 }
56}
57
58fn default_regressor() -> JsonReductionConfig {
59 JsonReductionConfig::new(
60 "Coin".try_into().unwrap(),
61 json!(CoinRegressorConfig::default()),
62 )
63}
64
65#[derive(Serialize, Deserialize, Default)]
66struct MtrState {
67 action_sum: usize,
68 event_sum: usize,
69}
70
71#[derive(Serialize, Deserialize)]
72struct CBAdfReduction {
73 cb_type: CBType,
74 regressor: ReductionWrapper,
75 mtr_state: MtrState,
77}
78
/// Stateless factory that builds `CbAdf` reductions from a `CBAdfConfig`.
#[derive(Debug, Default)]
pub struct CBAdfReductionFactory;
81
82impl ReductionFactory for CBAdfReductionFactory {
83 impl_default_factory_functions!("CbAdf", CBAdfConfig);
84
85 fn create(
86 &self,
87 config: &dyn ReductionConfig,
88 global_config: &GlobalConfig,
89 num_models_above: ModelIndex,
90 ) -> Result<ReductionWrapper> {
91 let config = config.as_any().downcast_ref::<CBAdfConfig>().unwrap();
92 let regressor_config = crate::reduction_factory::parse_config(&config.regressor)?;
93 let regressor: ReductionWrapper =
94 create_reduction(regressor_config.as_ref(), global_config, num_models_above)?;
95
96 let types = ReductionTypeDescriptionBuilder::new(
97 LabelType::CB,
98 FeaturesType::SparseCBAdf,
99 PredictionType::ActionScores,
100 )
101 .with_input_prediction_type(PredictionType::Scalar)
102 .with_output_features_type(FeaturesType::SparseSimple)
103 .with_output_label_type(LabelType::Simple)
104 .build();
105
106 if let Some(reason) = types.check_and_get_reason(regressor.types()) {
107 return Err(crate::error::Error::InvalidArgument(format!(
108 "Invalid reduction configuration: {}",
109 reason
110 )));
111 }
112
113 Ok(ReductionWrapper::new(
114 self.typename(),
115 Box::new(CBAdfReduction {
116 cb_type: config.cb_type,
117 regressor,
118 mtr_state: Default::default(),
119 }),
120 types,
121 num_models_above,
122 ))
123 }
124}
125
126fn generate_ips_simple_label(label: &CBLabel, current_action_index: usize) -> SimpleLabel {
128 if current_action_index == label.action {
129 debug_assert!(label.probability > 0.0);
130 (label.cost / label.probability).into()
131 } else {
132 0.0.into()
133 }
134}
135
136#[typetag::serde]
137impl ReductionImpl for CBAdfReduction {
138 fn predict(
139 &self,
140 features: &mut Features,
141 depth_info: &mut DepthInfo,
142 _model_offset: ModelIndex,
143 ) -> Prediction {
144 let cb_adf_features: &mut CBAdfFeatures = features.as_inner_mut().unwrap();
145
146 let mut action_scores = ActionScoresPrediction::default();
147 for (counter, action) in cb_adf_features.actions.iter_mut().enumerate() {
148 if let Some(shared_feats) = &cb_adf_features.shared {
149 action.append(shared_feats);
150 }
151
152 let pred = self
153 .regressor
154 .predict(&mut action.into(), depth_info, 0.into());
155 let scalar_pred: &ScalarPrediction = pred.as_inner().unwrap();
156 action_scores.0.push((counter, scalar_pred.raw_prediction));
157 if let Some(shared_feats) = &cb_adf_features.shared {
158 action.remove(shared_feats);
159 }
160 }
161
162 action_scores.into()
163 }
164
165 fn learn(
166 &mut self,
167 features: &mut Features,
168 label: &Label,
169 depth_info: &mut DepthInfo,
170 _model_offset: ModelIndex,
171 ) {
172 let cb_adf_features: &mut CBAdfFeatures = features.as_inner_mut().unwrap();
173 let cb_label: &CBLabel = label.as_inner().unwrap();
174
175 match self.cb_type {
176 CBType::Ips => {
177 for (counter, action) in cb_adf_features.actions.iter_mut().enumerate() {
178 if let Some(shared_feats) = &cb_adf_features.shared {
179 action.append(shared_feats);
180 }
181
182 self.regressor.learn(
183 &mut action.into(),
184 &(generate_ips_simple_label(cb_label, counter).into()),
185 depth_info,
186 0.into(),
187 );
188 if let Some(shared_feats) = &cb_adf_features.shared {
189 action.remove(shared_feats);
190 }
191 }
192 }
193 CBType::Mtr => {
194 self.mtr_state.action_sum += cb_adf_features.actions.len();
195 self.mtr_state.event_sum += 1;
196
197 let cost = cb_label.cost;
198 let prob = cb_label.probability;
200 let weight = 1.0 / prob
201 * (self.mtr_state.event_sum as f32 / self.mtr_state.action_sum as f32);
202
203 let simple_label = SimpleLabel::new(cost, weight);
204 match cb_adf_features.shared.as_mut() {
205 Some(shared_feats) => {
206 let action = cb_adf_features.actions.get(cb_label.action).unwrap();
207 shared_feats.append(action);
208 self.regressor.learn(
209 &mut Features::SparseSimpleRef(shared_feats),
210 &simple_label.into(),
211 depth_info,
212 0.into(),
213 );
214 shared_feats.remove(action);
215 }
216 None => {
217 todo!()
218 }
219 }
220 }
221 }
222 }
223
224 fn children(&self) -> Vec<&ReductionWrapper> {
225 vec![&self.regressor]
226 }
227}