reductionml_core/parsers/
dsjson_parser.rs

1use core::f32;
2
3use serde_json_borrow::Value;
4
5use crate::error::Result;
6
7use crate::object_pool::Pool;
8use crate::parsers::ParsedFeature;
9use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
10use crate::types::{Features, Label, LabelType};
11use crate::{CBAdfFeatures, CBLabel, FeatureMask, FeaturesType};
12
13use super::{ParsedNamespaceInfo, TextModeParser, TextModeParserFactory};
14
15#[derive(Default)]
16pub struct DsJsonParserFactory;
17impl TextModeParserFactory for DsJsonParserFactory {
18    type Parser = DsJsonParser;
19
20    fn create(
21        &self,
22        features_type: FeaturesType,
23        label_type: LabelType,
24        hash_seed: u32,
25        num_bits: u8,
26        pool: std::sync::Arc<Pool<SparseFeatures>>,
27    ) -> DsJsonParser {
28        // Only supports CB
29        if features_type != FeaturesType::SparseCBAdf {
30            panic!("DsJsonParser only supports SparseCBAdf")
31        }
32
33        if label_type != LabelType::CB {
34            panic!("DsJsonParser only supports CB labels")
35        }
36
37        DsJsonParser {
38            _feature_type: features_type,
39            _label_type: label_type,
40            hash_seed,
41            num_bits,
42            pool,
43        }
44    }
45}
46
/// Text-mode parser for DSJson lines, producing sparse CB ADF features and an
/// optional CB label.
pub struct DsJsonParser {
    // Feature type the parser was created for. Stored at construction; the code
    // in this file never reads it back (presumably why it is underscore-prefixed).
    _feature_type: FeaturesType,
    // Label type the parser was created for. Stored at construction; the code in
    // this file never reads it back.
    _label_type: LabelType,
    // Seed used when building namespaces from JSON keys and when hashing them.
    hash_seed: u32,
    // Number of bits used to build the `FeatureMask` applied to every feature hash.
    num_bits: u8,
    // Pool that supplies the `SparseFeatures` objects filled in by `parse_chunk`.
    pool: std::sync::Arc<Pool<SparseFeatures>>,
}
54
55impl DsJsonParser {
56    pub fn handle_features(
57        &self,
58        features: &mut SparseFeatures,
59        object_key: &str,
60        json_value: &Value,
61        namespace_stack: &mut Vec<Namespace>,
62    ) {
63        // All underscore prefixed keys are ignored.
64        if object_key.starts_with('_') {
65            return;
66        }
67
68        // skip everything with _
69        match json_value {
70            Value::Null => panic!("Null is not supported"),
71            Value::Bool(true) => {
72                let current_ns = *namespace_stack
73                    .last()
74                    .expect("namespace stack should not be empty here");
75                let current_ns_hash = current_ns.hash(self.hash_seed);
76                let current_feats = features.get_or_create_namespace(current_ns);
77                current_feats.add_feature(
78                    ParsedFeature::Simple { name: object_key }
79                        .hash(current_ns_hash)
80                        .mask(FeatureMask::from_num_bits(self.num_bits)),
81                    1.0,
82                );
83            }
84            Value::Bool(false) => (),
85            Value::Number(value) => {
86                let current_ns = *namespace_stack
87                    .last()
88                    .expect("namespace stack should not be empty here");
89                let current_ns_hash = current_ns.hash(self.hash_seed);
90                let current_feats = features.get_or_create_namespace(current_ns);
91                current_feats.add_feature(
92                    ParsedFeature::Simple { name: object_key }
93                        .hash(current_ns_hash)
94                        .mask(FeatureMask::from_num_bits(self.num_bits)),
95                    value.as_f64().unwrap() as f32,
96                );
97            }
98            Value::Str(value) => {
99                let current_ns = namespace_stack
100                    .last()
101                    .expect("namespace stack should not be empty here");
102                let current_ns_hash = current_ns.hash(self.hash_seed);
103                let current_feats = features.get_or_create_namespace(*current_ns);
104                current_feats.add_feature(
105                    ParsedFeature::SimpleWithStringValue {
106                        name: object_key,
107                        value,
108                    }
109                    .hash(current_ns_hash)
110                    .mask(FeatureMask::from_num_bits(self.num_bits)),
111                    1.0,
112                );
113            }
114            Value::Array(value) => {
115                namespace_stack.push(Namespace::from_name(object_key, self.hash_seed));
116                let current_ns = *namespace_stack
117                    .last()
118                    .expect("namespace stack should not be empty here");
119                let current_ns_hash = current_ns.hash(self.hash_seed);
120                for (anon_idx, v) in value.iter().enumerate() {
121                    match v {
122                        Value::Number(value) => {
123                            // Not super efficient but it works
124                            // Doing this in the outside doesn't work as it would mean two mutable borrows.
125                            let current_feats = features.get_or_create_namespace(current_ns);
126                            current_feats.add_feature(
127                                ParsedFeature::Anonymous {
128                                    offset: anon_idx as u32,
129                                }
130                                .hash(current_ns_hash)
131                                .mask(FeatureMask::from_num_bits(self.num_bits)),
132                                value.as_f64().unwrap() as f32,
133                            );
134                        }
135                        Value::Object(_) => {
136                            self.handle_features(features, object_key, v, namespace_stack);
137                        }
138                        // Just ignore null and do nothing
139                        Value::Null => (),
140                        _ => panic!(
141                            "Array of non-number or object is not supported key:{} value:{:?}",
142                            object_key, v
143                        ),
144                    }
145                }
146                namespace_stack.pop().unwrap();
147            }
148            Value::Object(value) => {
149                namespace_stack.push(Namespace::from_name(object_key, self.hash_seed));
150                for (key, v) in value {
151                    self.handle_features(features, key, v, namespace_stack);
152                }
153                namespace_stack.pop().unwrap();
154            }
155        }
156    }
157}
158
159impl TextModeParser for DsJsonParser {
160    fn get_next_chunk(
161        &self,
162        input: &mut dyn std::io::BufRead,
163        mut output_buffer: String,
164    ) -> Result<Option<String>> {
165        output_buffer.clear();
166        input.read_line(&mut output_buffer)?;
167        if output_buffer.is_empty() {
168            return Ok(None);
169        }
170        Ok(Some(output_buffer))
171    }
172
173    fn parse_chunk<'a, 'b>(&self, chunk: &'a str) -> Result<(Features<'b>, Option<Label>)> {
174        let json: Value = serde_json::from_str(chunk).expect("JSON was not well-formatted");
175
176        let mut namespace_stack = Vec::new();
177
178        let mut shared_ex = self.pool.get_object();
179        self.handle_features(&mut shared_ex, " ", json.get("c"), &mut namespace_stack);
180        assert!(namespace_stack.is_empty());
181
182        let mut actions = Vec::new();
183        for item in json.get("c").get("_multi").iter_array().unwrap() {
184            let mut action = self.pool.get_object();
185            self.handle_features(&mut action, " ", item, &mut namespace_stack);
186            actions.push(action);
187            assert!(namespace_stack.is_empty());
188        }
189
190        let label = match (
191            json.get("_label_cost"),
192            json.get("_label_probability"),
193            json.get("_labelIndex"),
194        ) {
195            (Value::Number(cost), Value::Number(prob), Value::Number(action)) => Some(CBLabel {
196                action: action.as_u64().unwrap() as usize,
197                cost: cost.as_f64().unwrap() as f32,
198                probability: prob.as_f64().unwrap() as f32,
199            }),
200            (Value::Null, Value::Null, Value::Null) => None,
201            _ => panic!("Invalid label, all 3 or none must be present"),
202        };
203
204        Ok((
205            Features::SparseCBAdf(CBAdfFeatures {
206                shared: Some(shared_ex),
207                actions,
208            }),
209            label.map(Label::CB),
210        ))
211    }
212
213    fn extract_feature_names<'a>(
214        &self,
215        _chunk: &'a str,
216    ) -> Result<std::collections::HashMap<ParsedNamespaceInfo<'a>, Vec<ParsedFeature<'a>>>> {
217        todo!()
218    }
219}
220
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use approx::assert_relative_eq;
    use serde_json::json;

    use crate::{
        object_pool::Pool,
        parsers::{DsJsonParserFactory, TextModeParser, TextModeParserFactory},
        sparse_namespaced_features::Namespace,
        utils::AsInner,
        CBAdfFeatures, CBLabel, FeaturesType, LabelType,
    };

    /// Parses a full DSJson event and checks the label plus the namespaces and
    /// feature counts of the shared example and of the single action.
    #[test]
    fn extract_dsjson_test_chain_hash() {
        let event = json!({
          "_label_cost": -0.0,
          "_label_probability": 0.05000000074505806,
          "_label_Action": 4,
          "_labelIndex": 3,
          "o": [
            {
              "v": 0.0,
              "EventId": "13118d9b4c114f8485d9dec417e3aefe",
              "ActionTaken": false
            }
          ],
          "Timestamp": "2021-02-04T16:31:29.2460000Z",
          "Version": "1",
          "EventId": "13118d9b4c114f8485d9dec417e3aefe",
          "a": [4, 2, 1, 3],
          "c": {
            "bool_true": true,
            "bool_false": false,
            "numbers": [4, 5.6],
            "FromUrl": [
              { "timeofday": "Afternoon", "weather": "Sunny", "name": "Cathy" }
            ],
            "_multi": [
              {
                "_tag": "Cappucino",
                "i": { "constant": 1, "id": "Cappucino" },
                "j": [
                  {
                    "type": "hot",
                    "origin": "kenya",
                    "organic": "yes",
                    "roast": "dark"
                  }
                ]
              }
            ]
          },
          "p": [0.05, 0.05, 0.05, 0.85],
          "VWState": {
            "m": "ff0744c1aa494e1ab39ba0c78d048146/550c12cbd3aa47f09fbed3387fb9c6ec"
          },
          "_original_label_cost": -0.0
        });

        // Seed 0 and 18 bits; the namespace lookups below use the same seed.
        let factory = DsJsonParserFactory::default();
        let parser = factory.create(
            FeaturesType::SparseCBAdf,
            LabelType::CB,
            0,
            18,
            Arc::new(Pool::new()),
        );

        let line = event.to_string();
        let (features, label) = parser.parse_chunk(&line).unwrap();

        // Label: index, cost and probability come from the three `_label*` fields.
        let parsed_label: &CBLabel = label.as_ref().unwrap().as_inner().unwrap();
        assert_eq!(parsed_label.action, 3);
        assert_relative_eq!(parsed_label.cost, 0.0);
        assert_relative_eq!(parsed_label.probability, 0.05);

        let parsed_feats: &CBAdfFeatures = features.as_inner().unwrap();
        assert_eq!(parsed_feats.actions.len(), 1);

        // Shared example: default namespace (bool_true), "FromUrl" and "numbers".
        assert!(parsed_feats.shared.is_some());
        let shared_ex = parsed_feats.shared.as_ref().unwrap();
        assert_eq!(shared_ex.namespaces().count(), 3);

        let default_ns = shared_ex.get_namespace(Namespace::Default).unwrap();
        assert_eq!(default_ns.iter().count(), 1);

        let from_url_ns = shared_ex
            .get_namespace(Namespace::from_name("FromUrl", 0))
            .unwrap();
        assert_eq!(from_url_ns.iter().count(), 3);

        let numbers_ns = shared_ex
            .get_namespace(Namespace::from_name("numbers", 0))
            .unwrap();
        assert_eq!(numbers_ns.iter().count(), 2);
        let numbers_total = numbers_ns.iter().map(|(_, val)| val).sum::<f32>();
        assert_relative_eq!(numbers_total, 9.6);

        // Single action: namespaces "i" and "j", nothing in the default namespace.
        let first_action = parsed_feats.actions.get(0).unwrap();
        assert_eq!(first_action.namespaces().count(), 2);
        assert!(first_action.get_namespace(Namespace::Default).is_none());

        let i_ns = first_action
            .get_namespace(Namespace::from_name("i", 0))
            .unwrap();
        assert_eq!(i_ns.iter().count(), 2);

        let j_ns = first_action
            .get_namespace(Namespace::from_name("j", 0))
            .unwrap();
        assert_eq!(j_ns.iter().count(), 4);
    }
}