reductionml_core/parsers/
json_parser.rs

1use core::{f32, panic};
2
3use crate::error::Result;
4
5use crate::object_pool::Pool;
6use crate::parsers::ParsedFeature;
7use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
8use crate::types::{Features, Label, LabelType};
9use crate::{CBAdfFeatures, CBLabel, FeatureHash, FeatureMask, FeaturesType, SimpleLabel};
10
11use super::{TextModeParser, TextModeParserFactory};
12
13use serde_json_borrow::Value;
14
15pub fn to_features(
16    val: &Value,
17    mut output: SparseFeatures,
18    hash_seed: u32,
19    num_bits: u8,
20) -> SparseFeatures {
21    match val {
22        Value::Object(obj) => {
23            for (ns_name, value) in obj {
24                let ns = output.get_or_create_namespace(Namespace::from_name(ns_name, hash_seed));
25                let ns_hash = ns.namespace().hash(hash_seed);
26                let mask = FeatureMask::from_num_bits(num_bits);
27                match value {
28                    Value::Str(_) => todo!(),
29                    Value::Array(ar) => match ar.first() {
30                        Some(Value::Number(_)) => {
31                            let it = (u32::from(ns_hash)..(u32::from(ns_hash) + ar.len() as u32))
32                                .map(|x| FeatureHash::from(x).mask(mask));
33                            ns.add_features_with_iter(
34                                it,
35                                ar.into_iter().map(|x| {
36                                    x.as_f64().expect("Arrays must contain the same type") as f32
37                                }),
38                            );
39                        }
40                        Some(Value::Str(_)) => {
41                            ns.reserve(ar.len());
42                            for string in ar {
43                                let feat = ParsedFeature::Simple {
44                                    name: string
45                                        .as_str()
46                                        .expect("Arrays must contain the same type"),
47                                };
48                                ns.add_feature(feat.hash(ns_hash).mask(mask), 1.0);
49                            }
50                        }
51                        Some(_) => panic!("Not a number or string"),
52                        None => todo!(),
53                    },
54
55                    Value::Object(contents) => {
56                        for (key, value) in contents {
57                            match value {
58                                Value::Number(value) => {
59                                    let feat: ParsedFeature<'_> =
60                                        ParsedFeature::Simple { name: key };
61                                    ns.add_feature(
62                                        feat.hash(ns_hash).mask(mask),
63                                        value.as_f64().unwrap() as f32,
64                                    );
65                                }
66                                Value::Str(value) => {
67                                    let feat = ParsedFeature::SimpleWithStringValue {
68                                        name: key,
69                                        value: value,
70                                    };
71                                    ns.add_feature(feat.hash(ns_hash).mask(mask), 1.0);
72                                }
73                                Value::Bool(value) => {
74                                    if *value {
75                                        let feat = ParsedFeature::Simple { name: key };
76                                        ns.add_feature(feat.hash(ns_hash).mask(mask), 1.0);
77                                    }
78                                }
79                                _ => todo!(),
80                            }
81                        }
82                    }
83                    _ => todo!(),
84                }
85            }
86        }
87        _ => panic!("Not an object"),
88    }
89    output
90}
91
92#[derive(Default)]
93pub struct JsonParserFactory;
94impl TextModeParserFactory for JsonParserFactory {
95    type Parser = JsonParser;
96
97    fn create(
98        &self,
99        features_type: FeaturesType,
100        label_type: LabelType,
101        hash_seed: u32,
102        num_bits: u8,
103        pool: std::sync::Arc<Pool<SparseFeatures>>,
104    ) -> JsonParser {
105        JsonParser {
106            _feature_type: features_type,
107            _label_type: label_type,
108            hash_seed,
109            num_bits,
110            pool,
111        }
112    }
113}
114
115pub struct JsonParser {
116    _feature_type: FeaturesType,
117    _label_type: LabelType,
118    hash_seed: u32,
119    num_bits: u8,
120    pool: std::sync::Arc<Pool<SparseFeatures>>,
121}
122
123impl TextModeParser for JsonParser {
124    fn get_next_chunk(
125        &self,
126        input: &mut dyn std::io::BufRead,
127        mut output_buffer: String,
128    ) -> Result<Option<String>> {
129        output_buffer.clear();
130        input.read_line(&mut output_buffer)?;
131        if output_buffer.is_empty() {
132            return Ok(None);
133        }
134        Ok(Some(output_buffer))
135    }
136
137    fn parse_chunk<'a, 'b>(&self, chunk: &'a str) -> Result<(Features<'b>, Option<Label>)> {
138        let json: Value = serde_json::from_str(chunk).expect("JSON was not well-formatted");
139        Ok(match (self._feature_type, self._label_type) {
140            (FeaturesType::SparseSimple, LabelType::Simple) => {
141                let label = match json.get("label") {
142                    Value::Null => None,
143                    Value::Number(val) => Some(SimpleLabel::from(val.as_f64().unwrap() as f32)),
144                    val => {
145                        let l: SimpleLabel =
146                            serde_json::from_value(serde_json::Value::from(val.clone())).unwrap();
147                        Some(l)
148                    }
149                };
150
151                let features = match json.get("features") {
152                    Value::Null => panic!("No features found"),
153                    val => {
154                        let feats =
155                            to_features(val, self.pool.get_object(), self.hash_seed, self.num_bits);
156                        feats
157                    }
158                };
159
160                (Features::SparseSimple(features), label.map(|l| l.into()))
161            }
162            (FeaturesType::SparseCBAdf, LabelType::CB) => {
163                let label = match json.get("label") {
164                    Value::Null => None,
165                    val => {
166                        let l: CBLabel =
167                            serde_json::from_value(serde_json::Value::from(val.clone())).unwrap();
168                        Some(l)
169                    }
170                };
171
172                let shared = match json.get("shared") {
173                    Value::Null => None,
174                    val => {
175                        let feats =
176                            to_features(val, self.pool.get_object(), self.hash_seed, self.num_bits);
177                        Some(feats)
178                    }
179                };
180
181                let actions = match json.get("actions") {
182                    Value::Null => panic!("No actions found"),
183                    Value::Array(val) => val
184                        .iter()
185                        .map(|x| {
186                            to_features(x, self.pool.get_object(), self.hash_seed, self.num_bits)
187                        })
188                        .collect(),
189                    _ => panic!("Actions must be an array"),
190                };
191
192                (
193                    Features::SparseCBAdf(CBAdfFeatures { shared, actions }),
194                    label.map(|l| l.into()),
195                )
196            }
197
198            (_, _) => panic!("Feature type mismatch"),
199        })
200    }
201}
202
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use approx::assert_relative_eq;
    use serde_json::json;

    use crate::{
        object_pool::Pool,
        parsers::{JsonParserFactory, TextModeParser, TextModeParserFactory},
        sparse_namespaced_features::{Namespace, SparseFeatures},
        utils::AsInner,
        CBAdfFeatures, CBLabel, FeaturesType, LabelType, SimpleLabel,
    };

    /// Namespace payload used by both tests: a default namespace with two
    /// booleans, a numeric array and a string-valued namespace.
    fn example_namespaces() -> serde_json::Value {
        json!({
            ":default": {
                "bool_true": true,
                "bool_false": false
            },
            "numbers": [4, 5.6],
            "FromUrl": {
                "timeofday": "Afternoon",
                "weather": "Sunny",
                "name": "Cathy"
            }
        })
    }

    /// Builds a parser with hash seed 0, 18 hash bits and a fresh pool.
    fn build_parser(features_type: FeaturesType, label_type: LabelType) -> impl TextModeParser {
        JsonParserFactory::default().create(
            features_type,
            label_type,
            0,
            18,
            Arc::new(Pool::new()),
        )
    }

    /// Checks the features parsed from `example_namespaces()`.
    fn assert_example_namespaces(feats: &SparseFeatures) {
        assert_eq!(feats.namespaces().count(), 3);

        // Only `bool_true` yields a feature.
        let default_ns = feats.get_namespace(Namespace::Default).unwrap();
        assert_eq!(default_ns.iter().count(), 1);

        let from_url_ns = feats
            .get_namespace(Namespace::from_name("FromUrl", 0))
            .unwrap();
        assert_eq!(from_url_ns.iter().count(), 3);

        let numbers_ns = feats
            .get_namespace(Namespace::from_name("numbers", 0))
            .unwrap();
        assert_eq!(numbers_ns.iter().count(), 2);
        assert_relative_eq!(numbers_ns.iter().map(|(_, val)| val).sum::<f32>(), 9.6);
    }

    #[test]
    fn json_parse_cb() {
        let json_obj = json!({
            "label": {
                "action": 3,
                "cost": 0.0,
                "probability": 0.05
            },
            "shared": example_namespaces(),
            "actions": [
                {
                    "i": { "constant": 1, "id": "Cappucino" },
                    "j": {
                        "type": "hot",
                        "origin": "kenya",
                        "organic": "yes",
                        "roast": "dark"
                    }
                }
            ]
        });

        let parser = build_parser(FeaturesType::SparseCBAdf, LabelType::CB);
        let (features, label) = parser.parse_chunk(&json_obj.to_string()).unwrap();

        let cb_label: &CBLabel = label.as_ref().unwrap().as_inner().unwrap();
        assert_eq!(cb_label.action, 3);
        assert_relative_eq!(cb_label.cost, 0.0);
        assert_relative_eq!(cb_label.probability, 0.05);

        let cb_feats: &CBAdfFeatures = features.as_inner().unwrap();
        assert_eq!(cb_feats.actions.len(), 1);
        assert!(cb_feats.shared.is_some());
        assert_example_namespaces(cb_feats.shared.as_ref().unwrap());

        let action = cb_feats.actions.get(0).unwrap();
        assert_eq!(action.namespaces().count(), 2);
        assert!(action.get_namespace(Namespace::Default).is_none());
        let action_i_ns = action.get_namespace(Namespace::from_name("i", 0)).unwrap();
        assert_eq!(action_i_ns.iter().count(), 2);
        let action_j_ns = action.get_namespace(Namespace::from_name("j", 0)).unwrap();
        assert_eq!(action_j_ns.iter().count(), 4);
    }

    #[test]
    fn json_parse_simple() {
        let json_obj = json!({
            "label": {
                "value": 0.2,
                "weight": 0.4
            },
            "features": example_namespaces()
        });

        let parser = build_parser(FeaturesType::SparseSimple, LabelType::Simple);
        let (features, label) = parser.parse_chunk(&json_obj.to_string()).unwrap();

        let lbl: &SimpleLabel = label.as_ref().unwrap().as_inner().unwrap();
        assert_relative_eq!(lbl.value(), 0.2);
        assert_relative_eq!(lbl.weight(), 0.4);

        let features: &SparseFeatures = features.as_inner().unwrap();
        assert_example_namespaces(features);
    }
}