reductionml_core/
reduction.rs

1use approx::assert_abs_diff_eq;
2use serde::{Deserialize, Serialize};
3
4use crate::{reduction_factory::PascalCaseString, types::*, ModelIndex};
5
6#[derive(Clone, Copy, PartialEq, Eq)]
7pub struct DepthInfo {
8    offset: ModelIndex,
9}
10
11impl DepthInfo {
12    pub fn new() -> DepthInfo {
13        DepthInfo { offset: 0.into() }
14    }
15    pub(crate) fn increment(&mut self, num_models_below: ModelIndex, i: ModelIndex) {
16        self.offset = (*self.offset + (*num_models_below * *i)).into();
17    }
18    pub(crate) fn decrement(&mut self, num_models_below: ModelIndex, i: ModelIndex) {
19        self.offset = (*self.offset - (*num_models_below * *i)).into();
20    }
21
22    pub(crate) fn absolute_offset(&self) -> ModelIndex {
23        self.offset
24    }
25}
26
27impl ReductionWrapper {
28    pub fn predict(
29        &self,
30        features: &mut Features,
31        depth_info: &mut DepthInfo,
32        model_offset: ModelIndex,
33    ) -> Prediction {
34        // TODO assert prediction matches expected type.
35        if cfg!(debug_assertions) {
36            let mut features_copy = features.clone();
37            depth_info.increment(self.num_models_below, model_offset);
38            let res = self.reduction.predict(features, depth_info, model_offset);
39            depth_info.decrement(self.num_models_below, model_offset);
40
41            // This is an important check to ensure that a reduction put the features back in the
42            // same state as it found them.
43            assert_abs_diff_eq!(features, &mut features_copy);
44            res
45        } else {
46            depth_info.increment(self.num_models_below, model_offset);
47            let res = self.reduction.predict(features, depth_info, model_offset);
48            depth_info.decrement(self.num_models_below, model_offset);
49            res
50        }
51    }
52    pub fn predict_then_learn(
53        &mut self,
54        features: &mut Features,
55        label: &Label,
56        depth_info: &mut DepthInfo,
57        model_offset: ModelIndex,
58    ) -> Prediction {
59        if cfg!(debug_assertions) {
60            let mut features_copy = features.clone();
61
62            depth_info.increment(self.num_models_below, model_offset);
63            let res = self
64                .reduction
65                .predict_then_learn(features, label, depth_info, model_offset);
66            depth_info.decrement(self.num_models_below, model_offset);
67            // This is an important check to ensure that a reduction put the features back in the
68            // same state as it found them.
69            assert_abs_diff_eq!(features, &mut features_copy);
70            res
71        } else {
72            depth_info.increment(self.num_models_below, model_offset);
73            let res = self
74                .reduction
75                .predict_then_learn(features, label, depth_info, model_offset);
76            depth_info.decrement(self.num_models_below, model_offset);
77            res
78        }
79    }
80    pub fn learn(
81        &mut self,
82        features: &mut Features,
83        label: &Label,
84        depth_info: &mut DepthInfo,
85        model_offset: ModelIndex,
86    ) {
87        // TODO assert label matches expected type.
88
89        if cfg!(debug_assertions) {
90            let mut features_copy = features.clone();
91
92            depth_info.increment(self.num_models_below, model_offset);
93            self.reduction
94                .learn(features, label, depth_info, model_offset);
95            depth_info.decrement(self.num_models_below, model_offset);
96            // This is an important check to ensure that a reduction put the features back in the
97            // same state as it found them.
98            assert_abs_diff_eq!(features, &mut features_copy);
99        } else {
100            depth_info.increment(self.num_models_below, model_offset);
101            self.reduction
102                .learn(features, label, depth_info, model_offset);
103            depth_info.decrement(self.num_models_below, model_offset);
104        }
105    }
106
107    pub fn children(&self) -> Vec<&ReductionWrapper> {
108        self.reduction.children()
109    }
110
111    // TODO work out how to handle model offset for sensitivity...
112    pub fn sensitivity(
113        &self,
114        features: &Features,
115        label: f32,
116        prediction: f32,
117        weight: f32,
118        depth_info: DepthInfo,
119    ) -> f32 {
120        self.reduction
121            .sensitivity(features, label, prediction, weight, depth_info)
122    }
123}
124
/// Declares the label, features and prediction types a reduction works with.
///
/// `ReductionTypeDescription::check_and_get_reason` compares the `output_*` fields and
/// `input_prediction_type` against the corresponding fields of the reduction below (`base`).
/// Field order is left as-is because it affects the serde-derived representation in
/// sequence-based formats.
#[derive(Serialize, Deserialize)]
pub struct ReductionTypeDescription {
    /// Label type this reduction takes as input.
    input_label_type: LabelType,
    /// Label type handed to the base; compared against the base's `input_label_type`.
    output_label_type: Option<LabelType>,
    /// Features type this reduction takes as input.
    input_features_type: FeaturesType,
    /// Features type handed to the base; compared against the base's `input_features_type`.
    output_features_type: Option<FeaturesType>,
    /// Prediction type expected from the base; compared against the base's
    /// `output_prediction_type`.
    input_prediction_type: Option<PredictionType>,
    /// Prediction type this reduction produces.
    output_prediction_type: PredictionType,
}
134
135impl ReductionTypeDescription {
136    pub fn input_label_type(&self) -> LabelType {
137        self.input_label_type
138    }
139    pub fn output_label_type(&self) -> Option<LabelType> {
140        self.output_label_type
141    }
142    pub fn input_features_type(&self) -> FeaturesType {
143        self.input_features_type
144    }
145    pub fn output_features_type(&self) -> Option<FeaturesType> {
146        self.output_features_type
147    }
148    pub fn input_prediction_type(&self) -> Option<PredictionType> {
149        self.input_prediction_type
150    }
151    pub fn output_prediction_type(&self) -> PredictionType {
152        self.output_prediction_type
153    }
154}
155pub struct ReductionTypeDescriptionBuilder {
156    types: ReductionTypeDescription,
157}
158
159impl ReductionTypeDescriptionBuilder {
160    pub fn new(
161        input_label_type: LabelType,
162        input_features_type: FeaturesType,
163        output_prediction_type: PredictionType,
164    ) -> ReductionTypeDescriptionBuilder {
165        ReductionTypeDescriptionBuilder {
166            types: ReductionTypeDescription::new(
167                input_label_type,
168                None,
169                input_features_type,
170                None,
171                None,
172                output_prediction_type,
173            ),
174        }
175    }
176
177    pub fn with_output_label_type(
178        mut self,
179        output_label_type: LabelType,
180    ) -> ReductionTypeDescriptionBuilder {
181        self.types.output_label_type = Some(output_label_type);
182        self
183    }
184
185    pub fn with_output_features_type(
186        mut self,
187        output_features_type: FeaturesType,
188    ) -> ReductionTypeDescriptionBuilder {
189        self.types.output_features_type = Some(output_features_type);
190        self
191    }
192
193    pub fn with_input_prediction_type(
194        mut self,
195        input_prediction_type: PredictionType,
196    ) -> ReductionTypeDescriptionBuilder {
197        self.types.input_prediction_type = Some(input_prediction_type);
198        self
199    }
200
201    pub fn build(self) -> ReductionTypeDescription {
202        self.types
203    }
204}
205
206impl ReductionTypeDescription {
207    fn new(
208        input_label_type: LabelType,
209        output_label_type: Option<LabelType>,
210        input_features_type: FeaturesType,
211        output_features_type: Option<FeaturesType>,
212        input_prediction_type: Option<PredictionType>,
213        output_prediction_type: PredictionType,
214    ) -> ReductionTypeDescription {
215        ReductionTypeDescription {
216            input_label_type,
217            output_label_type,
218            input_features_type,
219            output_features_type,
220            input_prediction_type,
221            output_prediction_type,
222        }
223    }
224
225    pub fn check_and_get_reason(&self, base: &ReductionTypeDescription) -> Option<String> {
226        let mut res = None;
227        if self.output_label_type != Some(base.input_label_type) {
228            res = Some(format!(
229                "input_label_type: {:?} != {:?}",
230                self.input_label_type, base.input_label_type
231            ));
232        }
233
234        if self.output_features_type != Some(base.input_features_type) {
235            res = Some(format!(
236                "output_features_type: {:?} != {:?}",
237                self.output_features_type, base.output_features_type
238            ));
239        }
240
241        if self.input_prediction_type != Some(base.output_prediction_type) {
242            res = Some(format!(
243                "output_prediction_type: {:?} != {:?}",
244                self.output_prediction_type, base.output_prediction_type
245            ));
246        }
247        res
248    }
249}
250
/// A boxed reduction together with the metadata used when calling it as part of a
/// reduction stack.
#[derive(Serialize, Deserialize)]
pub struct ReductionWrapper {
    /// Name of the wrapped reduction's type.
    typename: PascalCaseString,
    /// The wrapped reduction; the wrapper's predict/learn methods forward to it.
    reduction: Box<dyn ReductionImpl>,
    /// Label/features/prediction types declared for the wrapped reduction.
    type_description: ReductionTypeDescription,
    /// Multiplied by `model_offset` to advance `DepthInfo` around each forwarded call.
    /// Presumably the number of models in the sub-stack below this reduction — TODO confirm.
    num_models_below: ModelIndex,
}
258
259impl ReductionWrapper {
260    pub fn new(
261        typename: PascalCaseString,
262        reduction: Box<dyn ReductionImpl>,
263        type_description: ReductionTypeDescription,
264        num_models_below: ModelIndex,
265    ) -> ReductionWrapper {
266        ReductionWrapper {
267            typename,
268            reduction,
269            type_description,
270            num_models_below,
271        }
272    }
273
274    pub fn types(&self) -> &ReductionTypeDescription {
275        &self.type_description
276    }
277
278    pub fn typename(&self) -> &str {
279        self.typename.as_ref()
280    }
281}
282
283#[typetag::serde(tag = "type")]
284pub trait ReductionImpl: Send {
285    fn predict(
286        &self,
287        features: &mut Features,
288        depth_info: &mut DepthInfo,
289        model_offset: ModelIndex,
290    ) -> Prediction;
291    fn predict_then_learn(
292        &mut self,
293        features: &mut Features,
294        label: &Label,
295        depth_info: &mut DepthInfo,
296        model_offset: ModelIndex,
297    ) -> Prediction {
298        let depth_info_copy: DepthInfo = *depth_info;
299        let prediction = self.predict(features, depth_info, model_offset);
300        let depth_info_copy2: DepthInfo = depth_info_copy;
301        self.learn(features, label, depth_info, model_offset);
302        assert!(depth_info == &depth_info_copy2);
303        assert!(depth_info == &depth_info_copy);
304        prediction
305    }
306
307    fn learn(
308        &mut self,
309        features: &mut Features,
310        label: &Label,
311        depth_info: &mut DepthInfo,
312        model_offset: ModelIndex,
313    );
314    fn sensitivity(
315        &self,
316        features: &Features,
317        label: f32,
318        prediction: f32,
319        weight: f32,
320        depth_info: DepthInfo,
321    ) -> f32 {
322        self.children()
323            .first()
324            .unwrap()
325            .sensitivity(features, label, prediction, weight, depth_info)
326    }
327    fn children(&self) -> Vec<&ReductionWrapper>;
328}