reductionml_core/reductions/
debug.rs1use crate::error::Result;
2use crate::global_config::GlobalConfig;
3
4use crate::reduction::{
5 DepthInfo, ReductionImpl, ReductionTypeDescriptionBuilder, ReductionWrapper,
6};
7use crate::reduction_factory::{
8 create_reduction, JsonReductionConfig, PascalCaseString, ReductionConfig, ReductionFactory,
9};
10
11use crate::{impl_default_factory_functions, types::*, ModelIndex};
12use schemars::schema::RootSchema;
13use schemars::{schema_for, JsonSchema};
14use serde::{Deserialize, Serialize};
15use serde_default::DefaultFromSerde;
16use serde_json::json;
17
// Configuration for the `Debug` reduction: a wrapper that prints selected
// data to stderr before/after delegating to the `next` reduction.
//
// Keys are deserialized in camelCase and unknown keys are rejected. Every
// field is optional; the `default_*` functions named in the attributes supply
// the fallback values.
//
// NOTE: plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated
// schema as descriptions, which would change the emitted schema.
#[derive(Deserialize, Serialize, JsonSchema, DefaultFromSerde)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub struct DebugConfig {
    // Label included in every printed line; defaults to the empty string.
    #[serde(default = "default_cb_type")]
    id: String,
    // Print the prediction returned by `next`; off by default.
    #[serde(default = "default_false")]
    prediction: bool,
    // Print the label handed to the learning calls; off by default.
    #[serde(default = "default_false")]
    label: bool,
    // Print the incoming features; off by default.
    #[serde(default = "default_false")]
    features: bool,
    // Width of the whitespace padding placed in front of each printed line;
    // defaults to 0.
    #[serde(default = "default_indent")]
    indent: usize,
    // Configuration of the wrapped reduction. The default is a placeholder
    // named "Unknown" with an empty JSON object as its configuration.
    #[serde(default = "default_next")]
    #[schemars(schema_with = "crate::config_schema::gen_json_reduction_config_schema")]
    next: JsonReductionConfig,
}
36
/// Serde default for `DebugConfig::id`: an empty label.
fn default_cb_type() -> String {
    String::new()
}
40
/// Serde default for the boolean switches of `DebugConfig`: each kind of
/// output stays disabled unless the configuration turns it on.
fn default_false() -> bool {
    // `bool::default()` is `false`.
    bool::default()
}
44
/// Serde default for `DebugConfig::indent`: no additional padding width.
fn default_indent() -> usize {
    0_usize
}
48
49fn default_next() -> JsonReductionConfig {
50 JsonReductionConfig::new("Unknown".try_into().unwrap(), json!({}))
51}
52
53impl ReductionConfig for DebugConfig {
54 fn typename(&self) -> PascalCaseString {
55 "Debug".try_into().unwrap()
56 }
57
58 fn as_any(&self) -> &dyn std::any::Any {
59 self
60 }
61}
62
/// Runtime state of the `Debug` reduction: the print settings copied from
/// `DebugConfig` plus the wrapped reduction that does the actual work.
#[derive(Serialize, Deserialize)]
struct DebugReduction {
    /// Label included in every printed line.
    id: String,
    /// Width of the whitespace padding placed in front of each printed line.
    indent: usize,
    /// Whether predictions returned by `next` are printed.
    prediction: bool,
    /// Whether labels passed to the learning calls are printed.
    label: bool,
    /// Whether incoming features are printed.
    features: bool,
    /// The wrapped reduction every call is delegated to.
    next: ReductionWrapper,
}
72
/// Factory that builds `Debug` reductions from a `DebugConfig`.
#[derive(Default)]
pub struct DebugReductionFactory;
75
76impl ReductionFactory for DebugReductionFactory {
77 impl_default_factory_functions!("Debug", DebugConfig);
78
79 fn create(
80 &self,
81 config: &dyn ReductionConfig,
82 global_config: &GlobalConfig,
83 num_models_above: ModelIndex,
84 ) -> Result<ReductionWrapper> {
85 let config = config.as_any().downcast_ref::<DebugConfig>().unwrap();
86 let next_config = crate::reduction_factory::parse_config(&config.next)?;
87 let next: ReductionWrapper =
88 create_reduction(next_config.as_ref(), global_config, num_models_above)?;
89
90 let types: crate::reduction::ReductionTypeDescription =
91 ReductionTypeDescriptionBuilder::new(
92 next.types().input_label_type(),
93 next.types().input_features_type(),
94 next.types().output_prediction_type(),
95 )
96 .with_output_features_type(next.types().input_features_type())
97 .with_input_prediction_type(next.types().output_prediction_type())
98 .with_output_label_type(next.types().input_label_type())
99 .build();
100
101 if let Some(reason) = types.check_and_get_reason(next.types()) {
102 return Err(crate::error::Error::InvalidArgument(format!(
103 "Invalid reduction configuration: {}",
104 reason
105 )));
106 }
107
108 Ok(ReductionWrapper::new(
109 self.typename(),
110 Box::new(DebugReduction {
111 id: config.id.clone(),
112 indent: config.indent,
113 prediction: config.prediction,
114 features: config.features,
115 label: config.label,
116 next,
117 }),
118 types,
119 num_models_above,
120 ))
121 }
122}
123
124impl DebugReduction {
125 fn print_debug<S: AsRef<str>>(
126 &self,
127 func: S,
128 offset: ModelIndex,
129 depth_info: &DepthInfo,
130 msg: S,
131 ) {
132 let space = " ";
133 let indent = self.indent;
134 let id = &self.id;
135 let func = func.as_ref();
136 let msg = msg.as_ref();
137 let off = u8::from(offset);
138 let abs_off = u8::from(depth_info.absolute_offset());
139 eprintln!("{space:indent$}[{id}({func}), off: {off}, abs_off: {abs_off}] {msg}");
140 }
141}
142
143#[typetag::serde]
144impl ReductionImpl for DebugReduction {
145 fn predict(
146 &self,
147 features: &mut Features,
148 depth_info: &mut DepthInfo,
149 model_offset: ModelIndex,
150 ) -> Prediction {
151 if self.features {
152 self.print_debug(
153 "predict",
154 model_offset,
155 depth_info,
156 &format!("features: {:?}", features),
157 );
158 }
159 let prediction = self.next.predict(features, depth_info, 0.into());
160
161 if self.prediction {
162 self.print_debug(
163 "predict",
164 model_offset,
165 depth_info,
166 &format!("prediction: {:?}", prediction),
167 );
168 }
169 prediction
170 }
171
172 fn predict_then_learn(
173 &mut self,
174 features: &mut Features,
175 label: &Label,
176 depth_info: &mut DepthInfo,
177 model_offset: ModelIndex,
178 ) -> Prediction {
179 if self.features {
180 self.print_debug(
181 "predict_then_learn",
182 model_offset,
183 depth_info,
184 &format!("features: {:?}", features),
185 );
186 }
187
188 if self.label {
189 self.print_debug(
190 "predict_then_learn",
191 model_offset,
192 depth_info,
193 &format!("label: {:?}", label),
194 );
195 }
196
197 let prediction = self
198 .next
199 .predict_then_learn(features, label, depth_info, model_offset);
200 if self.prediction {
201 self.print_debug(
202 "predict_then_learn",
203 model_offset,
204 depth_info,
205 &format!("prediction: {:?}", prediction),
206 );
207 }
208 prediction
209 }
210
211 fn learn(
212 &mut self,
213 features: &mut Features,
214 label: &Label,
215 depth_info: &mut DepthInfo,
216 model_offset: ModelIndex,
217 ) {
218 if self.features {
219 self.print_debug(
220 "learn",
221 model_offset,
222 depth_info,
223 &format!("features: {:?}", features),
224 );
225 }
226
227 if self.label {
228 self.print_debug(
229 "learn",
230 model_offset,
231 depth_info,
232 &format!("label: {:?}", label),
233 );
234 }
235 self.next.learn(features, label, depth_info, 0.into());
236 }
237
238 fn children(&self) -> Vec<&ReductionWrapper> {
239 vec![&self.next]
240 }
241}