1use approx::assert_abs_diff_eq;
2use serde::{Deserialize, Serialize};
3
4use crate::{reduction_factory::PascalCaseString, types::*, ModelIndex};
5
6#[derive(Clone, Copy, PartialEq, Eq)]
7pub struct DepthInfo {
8 offset: ModelIndex,
9}
10
11impl DepthInfo {
12 pub fn new() -> DepthInfo {
13 DepthInfo { offset: 0.into() }
14 }
15 pub(crate) fn increment(&mut self, num_models_below: ModelIndex, i: ModelIndex) {
16 self.offset = (*self.offset + (*num_models_below * *i)).into();
17 }
18 pub(crate) fn decrement(&mut self, num_models_below: ModelIndex, i: ModelIndex) {
19 self.offset = (*self.offset - (*num_models_below * *i)).into();
20 }
21
22 pub(crate) fn absolute_offset(&self) -> ModelIndex {
23 self.offset
24 }
25}
26
27impl ReductionWrapper {
28 pub fn predict(
29 &self,
30 features: &mut Features,
31 depth_info: &mut DepthInfo,
32 model_offset: ModelIndex,
33 ) -> Prediction {
34 if cfg!(debug_assertions) {
36 let mut features_copy = features.clone();
37 depth_info.increment(self.num_models_below, model_offset);
38 let res = self.reduction.predict(features, depth_info, model_offset);
39 depth_info.decrement(self.num_models_below, model_offset);
40
41 assert_abs_diff_eq!(features, &mut features_copy);
44 res
45 } else {
46 depth_info.increment(self.num_models_below, model_offset);
47 let res = self.reduction.predict(features, depth_info, model_offset);
48 depth_info.decrement(self.num_models_below, model_offset);
49 res
50 }
51 }
52 pub fn predict_then_learn(
53 &mut self,
54 features: &mut Features,
55 label: &Label,
56 depth_info: &mut DepthInfo,
57 model_offset: ModelIndex,
58 ) -> Prediction {
59 if cfg!(debug_assertions) {
60 let mut features_copy = features.clone();
61
62 depth_info.increment(self.num_models_below, model_offset);
63 let res = self
64 .reduction
65 .predict_then_learn(features, label, depth_info, model_offset);
66 depth_info.decrement(self.num_models_below, model_offset);
67 assert_abs_diff_eq!(features, &mut features_copy);
70 res
71 } else {
72 depth_info.increment(self.num_models_below, model_offset);
73 let res = self
74 .reduction
75 .predict_then_learn(features, label, depth_info, model_offset);
76 depth_info.decrement(self.num_models_below, model_offset);
77 res
78 }
79 }
80 pub fn learn(
81 &mut self,
82 features: &mut Features,
83 label: &Label,
84 depth_info: &mut DepthInfo,
85 model_offset: ModelIndex,
86 ) {
87 if cfg!(debug_assertions) {
90 let mut features_copy = features.clone();
91
92 depth_info.increment(self.num_models_below, model_offset);
93 self.reduction
94 .learn(features, label, depth_info, model_offset);
95 depth_info.decrement(self.num_models_below, model_offset);
96 assert_abs_diff_eq!(features, &mut features_copy);
99 } else {
100 depth_info.increment(self.num_models_below, model_offset);
101 self.reduction
102 .learn(features, label, depth_info, model_offset);
103 depth_info.decrement(self.num_models_below, model_offset);
104 }
105 }
106
107 pub fn children(&self) -> Vec<&ReductionWrapper> {
108 self.reduction.children()
109 }
110
111 pub fn sensitivity(
113 &self,
114 features: &Features,
115 label: f32,
116 prediction: f32,
117 weight: f32,
118 depth_info: DepthInfo,
119 ) -> f32 {
120 self.reduction
121 .sensitivity(features, label, prediction, weight, depth_info)
122 }
123}
124
125#[derive(Serialize, Deserialize)]
126pub struct ReductionTypeDescription {
127 input_label_type: LabelType,
128 output_label_type: Option<LabelType>,
129 input_features_type: FeaturesType,
130 output_features_type: Option<FeaturesType>,
131 input_prediction_type: Option<PredictionType>,
132 output_prediction_type: PredictionType,
133}
134
135impl ReductionTypeDescription {
136 pub fn input_label_type(&self) -> LabelType {
137 self.input_label_type
138 }
139 pub fn output_label_type(&self) -> Option<LabelType> {
140 self.output_label_type
141 }
142 pub fn input_features_type(&self) -> FeaturesType {
143 self.input_features_type
144 }
145 pub fn output_features_type(&self) -> Option<FeaturesType> {
146 self.output_features_type
147 }
148 pub fn input_prediction_type(&self) -> Option<PredictionType> {
149 self.input_prediction_type
150 }
151 pub fn output_prediction_type(&self) -> PredictionType {
152 self.output_prediction_type
153 }
154}
155pub struct ReductionTypeDescriptionBuilder {
156 types: ReductionTypeDescription,
157}
158
159impl ReductionTypeDescriptionBuilder {
160 pub fn new(
161 input_label_type: LabelType,
162 input_features_type: FeaturesType,
163 output_prediction_type: PredictionType,
164 ) -> ReductionTypeDescriptionBuilder {
165 ReductionTypeDescriptionBuilder {
166 types: ReductionTypeDescription::new(
167 input_label_type,
168 None,
169 input_features_type,
170 None,
171 None,
172 output_prediction_type,
173 ),
174 }
175 }
176
177 pub fn with_output_label_type(
178 mut self,
179 output_label_type: LabelType,
180 ) -> ReductionTypeDescriptionBuilder {
181 self.types.output_label_type = Some(output_label_type);
182 self
183 }
184
185 pub fn with_output_features_type(
186 mut self,
187 output_features_type: FeaturesType,
188 ) -> ReductionTypeDescriptionBuilder {
189 self.types.output_features_type = Some(output_features_type);
190 self
191 }
192
193 pub fn with_input_prediction_type(
194 mut self,
195 input_prediction_type: PredictionType,
196 ) -> ReductionTypeDescriptionBuilder {
197 self.types.input_prediction_type = Some(input_prediction_type);
198 self
199 }
200
201 pub fn build(self) -> ReductionTypeDescription {
202 self.types
203 }
204}
205
206impl ReductionTypeDescription {
207 fn new(
208 input_label_type: LabelType,
209 output_label_type: Option<LabelType>,
210 input_features_type: FeaturesType,
211 output_features_type: Option<FeaturesType>,
212 input_prediction_type: Option<PredictionType>,
213 output_prediction_type: PredictionType,
214 ) -> ReductionTypeDescription {
215 ReductionTypeDescription {
216 input_label_type,
217 output_label_type,
218 input_features_type,
219 output_features_type,
220 input_prediction_type,
221 output_prediction_type,
222 }
223 }
224
225 pub fn check_and_get_reason(&self, base: &ReductionTypeDescription) -> Option<String> {
226 let mut res = None;
227 if self.output_label_type != Some(base.input_label_type) {
228 res = Some(format!(
229 "input_label_type: {:?} != {:?}",
230 self.input_label_type, base.input_label_type
231 ));
232 }
233
234 if self.output_features_type != Some(base.input_features_type) {
235 res = Some(format!(
236 "output_features_type: {:?} != {:?}",
237 self.output_features_type, base.output_features_type
238 ));
239 }
240
241 if self.input_prediction_type != Some(base.output_prediction_type) {
242 res = Some(format!(
243 "output_prediction_type: {:?} != {:?}",
244 self.output_prediction_type, base.output_prediction_type
245 ));
246 }
247 res
248 }
249}
250
251#[derive(Serialize, Deserialize)]
252pub struct ReductionWrapper {
253 typename: PascalCaseString,
254 reduction: Box<dyn ReductionImpl>,
255 type_description: ReductionTypeDescription,
256 num_models_below: ModelIndex,
257}
258
259impl ReductionWrapper {
260 pub fn new(
261 typename: PascalCaseString,
262 reduction: Box<dyn ReductionImpl>,
263 type_description: ReductionTypeDescription,
264 num_models_below: ModelIndex,
265 ) -> ReductionWrapper {
266 ReductionWrapper {
267 typename,
268 reduction,
269 type_description,
270 num_models_below,
271 }
272 }
273
274 pub fn types(&self) -> &ReductionTypeDescription {
275 &self.type_description
276 }
277
278 pub fn typename(&self) -> &str {
279 self.typename.as_ref()
280 }
281}
282
283#[typetag::serde(tag = "type")]
284pub trait ReductionImpl: Send {
285 fn predict(
286 &self,
287 features: &mut Features,
288 depth_info: &mut DepthInfo,
289 model_offset: ModelIndex,
290 ) -> Prediction;
291 fn predict_then_learn(
292 &mut self,
293 features: &mut Features,
294 label: &Label,
295 depth_info: &mut DepthInfo,
296 model_offset: ModelIndex,
297 ) -> Prediction {
298 let depth_info_copy: DepthInfo = *depth_info;
299 let prediction = self.predict(features, depth_info, model_offset);
300 let depth_info_copy2: DepthInfo = depth_info_copy;
301 self.learn(features, label, depth_info, model_offset);
302 assert!(depth_info == &depth_info_copy2);
303 assert!(depth_info == &depth_info_copy);
304 prediction
305 }
306
307 fn learn(
308 &mut self,
309 features: &mut Features,
310 label: &Label,
311 depth_info: &mut DepthInfo,
312 model_offset: ModelIndex,
313 );
314 fn sensitivity(
315 &self,
316 features: &Features,
317 label: f32,
318 prediction: f32,
319 weight: f32,
320 depth_info: DepthInfo,
321 ) -> f32 {
322 self.children()
323 .first()
324 .unwrap()
325 .sensitivity(features, label, prediction, weight, depth_info)
326 }
327 fn children(&self) -> Vec<&ReductionWrapper>;
328}