reductionml_core/parsers/
dsjson_parser.rs1use core::f32;
2
3use serde_json_borrow::Value;
4
5use crate::error::Result;
6
7use crate::object_pool::Pool;
8use crate::parsers::ParsedFeature;
9use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
10use crate::types::{Features, Label, LabelType};
11use crate::{CBAdfFeatures, CBLabel, FeatureMask, FeaturesType};
12
13use super::{ParsedNamespaceInfo, TextModeParser, TextModeParserFactory};
14
15#[derive(Default)]
16pub struct DsJsonParserFactory;
17impl TextModeParserFactory for DsJsonParserFactory {
18 type Parser = DsJsonParser;
19
20 fn create(
21 &self,
22 features_type: FeaturesType,
23 label_type: LabelType,
24 hash_seed: u32,
25 num_bits: u8,
26 pool: std::sync::Arc<Pool<SparseFeatures>>,
27 ) -> DsJsonParser {
28 if features_type != FeaturesType::SparseCBAdf {
30 panic!("DsJsonParser only supports SparseCBAdf")
31 }
32
33 if label_type != LabelType::CB {
34 panic!("DsJsonParser only supports CB labels")
35 }
36
37 DsJsonParser {
38 _feature_type: features_type,
39 _label_type: label_type,
40 hash_seed,
41 num_bits,
42 pool,
43 }
44 }
45}
46
47pub struct DsJsonParser {
48 _feature_type: FeaturesType,
49 _label_type: LabelType,
50 hash_seed: u32,
51 num_bits: u8,
52 pool: std::sync::Arc<Pool<SparseFeatures>>,
53}
54
55impl DsJsonParser {
56 pub fn handle_features(
57 &self,
58 features: &mut SparseFeatures,
59 object_key: &str,
60 json_value: &Value,
61 namespace_stack: &mut Vec<Namespace>,
62 ) {
63 if object_key.starts_with('_') {
65 return;
66 }
67
68 match json_value {
70 Value::Null => panic!("Null is not supported"),
71 Value::Bool(true) => {
72 let current_ns = *namespace_stack
73 .last()
74 .expect("namespace stack should not be empty here");
75 let current_ns_hash = current_ns.hash(self.hash_seed);
76 let current_feats = features.get_or_create_namespace(current_ns);
77 current_feats.add_feature(
78 ParsedFeature::Simple { name: object_key }
79 .hash(current_ns_hash)
80 .mask(FeatureMask::from_num_bits(self.num_bits)),
81 1.0,
82 );
83 }
84 Value::Bool(false) => (),
85 Value::Number(value) => {
86 let current_ns = *namespace_stack
87 .last()
88 .expect("namespace stack should not be empty here");
89 let current_ns_hash = current_ns.hash(self.hash_seed);
90 let current_feats = features.get_or_create_namespace(current_ns);
91 current_feats.add_feature(
92 ParsedFeature::Simple { name: object_key }
93 .hash(current_ns_hash)
94 .mask(FeatureMask::from_num_bits(self.num_bits)),
95 value.as_f64().unwrap() as f32,
96 );
97 }
98 Value::Str(value) => {
99 let current_ns = namespace_stack
100 .last()
101 .expect("namespace stack should not be empty here");
102 let current_ns_hash = current_ns.hash(self.hash_seed);
103 let current_feats = features.get_or_create_namespace(*current_ns);
104 current_feats.add_feature(
105 ParsedFeature::SimpleWithStringValue {
106 name: object_key,
107 value,
108 }
109 .hash(current_ns_hash)
110 .mask(FeatureMask::from_num_bits(self.num_bits)),
111 1.0,
112 );
113 }
114 Value::Array(value) => {
115 namespace_stack.push(Namespace::from_name(object_key, self.hash_seed));
116 let current_ns = *namespace_stack
117 .last()
118 .expect("namespace stack should not be empty here");
119 let current_ns_hash = current_ns.hash(self.hash_seed);
120 for (anon_idx, v) in value.iter().enumerate() {
121 match v {
122 Value::Number(value) => {
123 let current_feats = features.get_or_create_namespace(current_ns);
126 current_feats.add_feature(
127 ParsedFeature::Anonymous {
128 offset: anon_idx as u32,
129 }
130 .hash(current_ns_hash)
131 .mask(FeatureMask::from_num_bits(self.num_bits)),
132 value.as_f64().unwrap() as f32,
133 );
134 }
135 Value::Object(_) => {
136 self.handle_features(features, object_key, v, namespace_stack);
137 }
138 Value::Null => (),
140 _ => panic!(
141 "Array of non-number or object is not supported key:{} value:{:?}",
142 object_key, v
143 ),
144 }
145 }
146 namespace_stack.pop().unwrap();
147 }
148 Value::Object(value) => {
149 namespace_stack.push(Namespace::from_name(object_key, self.hash_seed));
150 for (key, v) in value {
151 self.handle_features(features, key, v, namespace_stack);
152 }
153 namespace_stack.pop().unwrap();
154 }
155 }
156 }
157}
158
159impl TextModeParser for DsJsonParser {
160 fn get_next_chunk(
161 &self,
162 input: &mut dyn std::io::BufRead,
163 mut output_buffer: String,
164 ) -> Result<Option<String>> {
165 output_buffer.clear();
166 input.read_line(&mut output_buffer)?;
167 if output_buffer.is_empty() {
168 return Ok(None);
169 }
170 Ok(Some(output_buffer))
171 }
172
173 fn parse_chunk<'a, 'b>(&self, chunk: &'a str) -> Result<(Features<'b>, Option<Label>)> {
174 let json: Value = serde_json::from_str(chunk).expect("JSON was not well-formatted");
175
176 let mut namespace_stack = Vec::new();
177
178 let mut shared_ex = self.pool.get_object();
179 self.handle_features(&mut shared_ex, " ", json.get("c"), &mut namespace_stack);
180 assert!(namespace_stack.is_empty());
181
182 let mut actions = Vec::new();
183 for item in json.get("c").get("_multi").iter_array().unwrap() {
184 let mut action = self.pool.get_object();
185 self.handle_features(&mut action, " ", item, &mut namespace_stack);
186 actions.push(action);
187 assert!(namespace_stack.is_empty());
188 }
189
190 let label = match (
191 json.get("_label_cost"),
192 json.get("_label_probability"),
193 json.get("_labelIndex"),
194 ) {
195 (Value::Number(cost), Value::Number(prob), Value::Number(action)) => Some(CBLabel {
196 action: action.as_u64().unwrap() as usize,
197 cost: cost.as_f64().unwrap() as f32,
198 probability: prob.as_f64().unwrap() as f32,
199 }),
200 (Value::Null, Value::Null, Value::Null) => None,
201 _ => panic!("Invalid label, all 3 or none must be present"),
202 };
203
204 Ok((
205 Features::SparseCBAdf(CBAdfFeatures {
206 shared: Some(shared_ex),
207 actions,
208 }),
209 label.map(Label::CB),
210 ))
211 }
212
213 fn extract_feature_names<'a>(
214 &self,
215 _chunk: &'a str,
216 ) -> Result<std::collections::HashMap<ParsedNamespaceInfo<'a>, Vec<ParsedFeature<'a>>>> {
217 todo!()
218 }
219}
220
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use approx::assert_relative_eq;
    use serde_json::json;

    use crate::{
        object_pool::Pool,
        parsers::{DsJsonParserFactory, TextModeParser, TextModeParserFactory},
        sparse_namespaced_features::Namespace,
        utils::AsInner,
        CBAdfFeatures, CBLabel, FeaturesType, LabelType,
    };

    /// Hash seed shared by the parser and the namespace lookups in the test.
    const SEED: u32 = 0;

    /// Parses a representative DSJSON record, then checks the CB label and how
    /// shared and action features are grouped into namespaces.
    #[test]
    fn extract_dsjson_test_chain_hash() {
        let record = json!({
            "_label_cost": -0.0,
            "_label_probability": 0.05000000074505806,
            "_label_Action": 4,
            "_labelIndex": 3,
            "o": [
                {
                    "v": 0.0,
                    "EventId": "13118d9b4c114f8485d9dec417e3aefe",
                    "ActionTaken": false
                }
            ],
            "Timestamp": "2021-02-04T16:31:29.2460000Z",
            "Version": "1",
            "EventId": "13118d9b4c114f8485d9dec417e3aefe",
            "a": [4, 2, 1, 3],
            "c": {
                "bool_true": true,
                "bool_false": false,
                "numbers": [4, 5.6],
                "FromUrl": [
                    { "timeofday": "Afternoon", "weather": "Sunny", "name": "Cathy" }
                ],
                "_multi": [
                    {
                        "_tag": "Cappucino",
                        "i": { "constant": 1, "id": "Cappucino" },
                        "j": [
                            {
                                "type": "hot",
                                "origin": "kenya",
                                "organic": "yes",
                                "roast": "dark"
                            }
                        ]
                    }
                ]
            },
            "p": [0.05, 0.05, 0.05, 0.85],
            "VWState": {
                "m": "ff0744c1aa494e1ab39ba0c78d048146/550c12cbd3aa47f09fbed3387fb9c6ec"
            },
            "_original_label_cost": -0.0
        })
        .to_string();

        let parser = DsJsonParserFactory::default().create(
            FeaturesType::SparseCBAdf,
            LabelType::CB,
            SEED,
            18,
            Arc::new(Pool::new()),
        );
        let (features, label) = parser.parse_chunk(&record).unwrap();
        let named = |name: &str| Namespace::from_name(name, SEED);

        // Label fields: _labelIndex, _label_cost, _label_probability.
        let cb_label: &CBLabel = label.as_ref().unwrap().as_inner().unwrap();
        assert_eq!(cb_label.action, 3);
        assert_relative_eq!(cb_label.cost, 0.0);
        assert_relative_eq!(cb_label.probability, 0.05);

        let cb_feats: &CBAdfFeatures = features.as_inner().unwrap();
        assert_eq!(cb_feats.actions.len(), 1);

        // Shared example: default (bool_true only), FromUrl and numbers.
        let shared = cb_feats
            .shared
            .as_ref()
            .expect("shared features should be present");
        assert_eq!(shared.namespaces().count(), 3);

        let default_ns = shared.get_namespace(Namespace::Default).unwrap();
        assert_eq!(default_ns.iter().count(), 1);

        let from_url_ns = shared.get_namespace(named("FromUrl")).unwrap();
        assert_eq!(from_url_ns.iter().count(), 3);

        let numbers_ns = shared.get_namespace(named("numbers")).unwrap();
        assert_eq!(numbers_ns.iter().count(), 2);
        assert_relative_eq!(numbers_ns.iter().map(|(_, val)| val).sum::<f32>(), 9.6);

        // Single action: `_tag` is skipped, "i" and "j" become namespaces.
        let action = cb_feats.actions.first().unwrap();
        assert_eq!(action.namespaces().count(), 2);
        assert!(action.get_namespace(Namespace::Default).is_none());

        let i_ns = action.get_namespace(named("i")).unwrap();
        assert_eq!(i_ns.iter().count(), 2);

        let j_ns = action.get_namespace(named("j")).unwrap();
        assert_eq!(j_ns.iter().count(), 4);
    }
}