reductionml_core/reductions/
debug.rs

1use crate::error::Result;
2use crate::global_config::GlobalConfig;
3
4use crate::reduction::{
5    DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
6};
7use crate::reduction_factory::{
8    create_reduction, JsonReductionConfig, PascalCaseString, ReductionConfig, ReductionFactory,
9};
10
11use crate::{impl_default_factory_functions, types::*, ModelIndex};
12use schemars::schema::RootSchema;
13use schemars::{schema_for, JsonSchema};
14use serde::{Deserialize, Serialize};
15use serde_default::DefaultFromSerde;
16use serde_json::json;
17
18#[derive(Deserialize, Serialize, JsonSchema, DefaultFromSerde)]
19#[serde(deny_unknown_fields)]
20#[serde(rename_all = "camelCase")]
21pub struct DebugConfig {
22    #[serde(default = "default_cb_type")]
23    id: String,
24    #[serde(default = "default_false")]
25    prediction: bool,
26    #[serde(default = "default_false")]
27    label: bool,
28    #[serde(default = "default_false")]
29    features: bool,
30    #[serde(default = "default_indent")]
31    indent: usize,
32    #[serde(default = "default_next")]
33    #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
34    next: JsonReductionConfig,
35}
36
/// Serde default for the `id` field of [`DebugConfig`]: the empty string.
fn default_cb_type() -> String {
    String::new()
}
40
/// Serde default for the boolean logging switches: everything is off.
fn default_false() -> bool {
    bool::default()
}
44
/// Serde default for the `indent` field: zero.
fn default_indent() -> usize {
    0_usize
}
48
49fn default_next() -> JsonReductionConfig {
50    JsonReductionConfig::new("Unknown".try_into().unwrap(), json!({}))
51}
52
53impl ReductionConfig for DebugConfig {
54    fn typename(&self) -> PascalCaseString {
55        "Debug".try_into().unwrap()
56    }
57
58    fn as_any(&self) -> &dyn std::any::Any {
59        self
60    }
61}
62
63#[derive(Serialize, Deserialize)]
64struct DebugReduction {
65    id: String,
66    indent: usize,
67    prediction: bool,
68    label: bool,
69    features: bool,
70    next: ReductionWrapper,
71}
72
/// Stateless factory that builds `Debug` reductions from a [`DebugConfig`].
#[derive(Default)]
pub struct DebugReductionFactory;
75
76impl ReductionFactory for DebugReductionFactory {
77    impl_default_factory_functions!("Debug", DebugConfig);
78
79    fn create(
80        &self,
81        config: &dyn ReductionConfig,
82        global_config: &GlobalConfig,
83        num_models_above: ModelIndex,
84    ) -> Result<ReductionWrapper> {
85        let config = config.as_any().downcast_ref::<DebugConfig>().unwrap();
86        let next_config = crate::reduction_factory::parse_config(&config.next)?;
87        let next: ReductionWrapper =
88            create_reduction(next_config.as_ref(), global_config, num_models_above)?;
89
90        let types: crate::reduction::ReductionTypeDescription =
91            ReductionTypeDescriptionBuilder::new(
92                next.types().input_label_type(),
93                next.types().input_features_type(),
94                next.types().output_prediction_type(),
95            )
96            .with_output_features_type(next.types().input_features_type())
97            .with_input_prediction_type(next.types().output_prediction_type())
98            .with_output_label_type(next.types().input_label_type())
99            .build();
100
101        if let Some(reason) = types.check_and_get_reason(next.types()) {
102            return Err(crate::error::Error::InvalidArgument(format!(
103                "Invalid reduction configuration: {}",
104                reason
105            )));
106        }
107
108        Ok(ReductionWrapper::new(
109            self.typename(),
110            Box::new(DebugReduction {
111                id: config.id.clone(),
112                indent: config.indent,
113                prediction: config.prediction,
114                features: config.features,
115                label: config.label,
116                next,
117            }),
118            types,
119            num_models_above,
120        ))
121    }
122}
123
124impl DebugReduction {
125    fn print_debug<S: AsRef<str>>(
126        &self,
127        func: S,
128        offset: ModelIndex,
129        depth_info: &DepthInfo,
130        msg: S,
131    ) {
132        let space = " ";
133        let indent = self.indent;
134        let id = &self.id;
135        let func = func.as_ref();
136        let msg = msg.as_ref();
137        let off = u8::from(offset);
138        let abs_off = u8::from(depth_info.absolute_offset());
139        eprintln!("{space:indent$}[{id}({func}), off: {off}, abs_off: {abs_off}] {msg}");
140    }
141}
142
143#[typetag::serde]
144impl ReductionImpl for DebugReduction {
145    fn predict(
146        &self,
147        features: &mut Features,
148        depth_info: &mut DepthInfo,
149        model_offset: ModelIndex,
150    ) -> Prediction {
151        if self.features {
152            self.print_debug(
153                "predict",
154                model_offset,
155                depth_info,
156                &format!("features: {:?}", features),
157            );
158        }
159        let prediction = self.next.predict(features, depth_info, 0.into());
160
161        if self.prediction {
162            self.print_debug(
163                "predict",
164                model_offset,
165                depth_info,
166                &format!("prediction: {:?}", prediction),
167            );
168        }
169        prediction
170    }
171
172    fn predict_then_learn(
173        &mut self,
174        features: &mut Features,
175        label: &Label,
176        depth_info: &mut DepthInfo,
177        model_offset: ModelIndex,
178    ) -> Prediction {
179        if self.features {
180            self.print_debug(
181                "predict_then_learn",
182                model_offset,
183                depth_info,
184                &format!("features: {:?}", features),
185            );
186        }
187
188        if self.label {
189            self.print_debug(
190                "predict_then_learn",
191                model_offset,
192                depth_info,
193                &format!("label: {:?}", label),
194            );
195        }
196
197        let prediction = self
198            .next
199            .predict_then_learn(features, label, depth_info, model_offset);
200        if self.prediction {
201            self.print_debug(
202                "predict_then_learn",
203                model_offset,
204                depth_info,
205                &format!("prediction: {:?}", prediction),
206            );
207        }
208        prediction
209    }
210
211    fn learn(
212        &mut self,
213        features: &mut Features,
214        label: &Label,
215        depth_info: &mut DepthInfo,
216        model_offset: ModelIndex,
217    ) {
218        if self.features {
219            self.print_debug(
220                "learn",
221                model_offset,
222                depth_info,
223                &format!("features: {:?}", features),
224            );
225        }
226
227        if self.label {
228            self.print_debug(
229                "learn",
230                model_offset,
231                depth_info,
232                &format!("label: {:?}", label),
233            );
234        }
235        self.next.learn(features, label, depth_info, 0.into());
236    }
237
238    fn children(&self) -> Vec<&ReductionWrapper> {
239        vec![&self.next]
240    }
241}