reductionml_core/parsers/
json_parser.rs1use core::{f32, panic};
2
3use crate::error::Result;
4
5use crate::object_pool::Pool;
6use crate::parsers::ParsedFeature;
7use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
8use crate::types::{Features, Label, LabelType};
9use crate::{CBAdfFeatures, CBLabel, FeatureHash, FeatureMask, FeaturesType, SimpleLabel};
10
11use super::{TextModeParser, TextModeParserFactory};
12
13use serde_json_borrow::Value;
14
15pub fn to_features(
16 val: &Value,
17 mut output: SparseFeatures,
18 hash_seed: u32,
19 num_bits: u8,
20) -> SparseFeatures {
21 match val {
22 Value::Object(obj) => {
23 for (ns_name, value) in obj {
24 let ns = output.get_or_create_namespace(Namespace::from_name(ns_name, hash_seed));
25 let ns_hash = ns.namespace().hash(hash_seed);
26 let mask = FeatureMask::from_num_bits(num_bits);
27 match value {
28 Value::Str(_) => todo!(),
29 Value::Array(ar) => match ar.first() {
30 Some(Value::Number(_)) => {
31 let it = (u32::from(ns_hash)..(u32::from(ns_hash) + ar.len() as u32))
32 .map(|x| FeatureHash::from(x).mask(mask));
33 ns.add_features_with_iter(
34 it,
35 ar.into_iter().map(|x| {
36 x.as_f64().expect("Arrays must contain the same type") as f32
37 }),
38 );
39 }
40 Some(Value::Str(_)) => {
41 ns.reserve(ar.len());
42 for string in ar {
43 let feat = ParsedFeature::Simple {
44 name: string
45 .as_str()
46 .expect("Arrays must contain the same type"),
47 };
48 ns.add_feature(feat.hash(ns_hash).mask(mask), 1.0);
49 }
50 }
51 Some(_) => panic!("Not a number or string"),
52 None => todo!(),
53 },
54
55 Value::Object(contents) => {
56 for (key, value) in contents {
57 match value {
58 Value::Number(value) => {
59 let feat: ParsedFeature<'_> =
60 ParsedFeature::Simple { name: key };
61 ns.add_feature(
62 feat.hash(ns_hash).mask(mask),
63 value.as_f64().unwrap() as f32,
64 );
65 }
66 Value::Str(value) => {
67 let feat = ParsedFeature::SimpleWithStringValue {
68 name: key,
69 value: value,
70 };
71 ns.add_feature(feat.hash(ns_hash).mask(mask), 1.0);
72 }
73 Value::Bool(value) => {
74 if *value {
75 let feat = ParsedFeature::Simple { name: key };
76 ns.add_feature(feat.hash(ns_hash).mask(mask), 1.0);
77 }
78 }
79 _ => todo!(),
80 }
81 }
82 }
83 _ => todo!(),
84 }
85 }
86 }
87 _ => panic!("Not an object"),
88 }
89 output
90}
91
92#[derive(Default)]
93pub struct JsonParserFactory;
94impl TextModeParserFactory for JsonParserFactory {
95 type Parser = JsonParser;
96
97 fn create(
98 &self,
99 features_type: FeaturesType,
100 label_type: LabelType,
101 hash_seed: u32,
102 num_bits: u8,
103 pool: std::sync::Arc<Pool<SparseFeatures>>,
104 ) -> JsonParser {
105 JsonParser {
106 _feature_type: features_type,
107 _label_type: label_type,
108 hash_seed,
109 num_bits,
110 pool,
111 }
112 }
113}
114
115pub struct JsonParser {
116 _feature_type: FeaturesType,
117 _label_type: LabelType,
118 hash_seed: u32,
119 num_bits: u8,
120 pool: std::sync::Arc<Pool<SparseFeatures>>,
121}
122
123impl TextModeParser for JsonParser {
124 fn get_next_chunk(
125 &self,
126 input: &mut dyn std::io::BufRead,
127 mut output_buffer: String,
128 ) -> Result<Option<String>> {
129 output_buffer.clear();
130 input.read_line(&mut output_buffer)?;
131 if output_buffer.is_empty() {
132 return Ok(None);
133 }
134 Ok(Some(output_buffer))
135 }
136
137 fn parse_chunk<'a, 'b>(&self, chunk: &'a str) -> Result<(Features<'b>, Option<Label>)> {
138 let json: Value = serde_json::from_str(chunk).expect("JSON was not well-formatted");
139 Ok(match (self._feature_type, self._label_type) {
140 (FeaturesType::SparseSimple, LabelType::Simple) => {
141 let label = match json.get("label") {
142 Value::Null => None,
143 Value::Number(val) => Some(SimpleLabel::from(val.as_f64().unwrap() as f32)),
144 val => {
145 let l: SimpleLabel =
146 serde_json::from_value(serde_json::Value::from(val.clone())).unwrap();
147 Some(l)
148 }
149 };
150
151 let features = match json.get("features") {
152 Value::Null => panic!("No features found"),
153 val => {
154 let feats =
155 to_features(val, self.pool.get_object(), self.hash_seed, self.num_bits);
156 feats
157 }
158 };
159
160 (Features::SparseSimple(features), label.map(|l| l.into()))
161 }
162 (FeaturesType::SparseCBAdf, LabelType::CB) => {
163 let label = match json.get("label") {
164 Value::Null => None,
165 val => {
166 let l: CBLabel =
167 serde_json::from_value(serde_json::Value::from(val.clone())).unwrap();
168 Some(l)
169 }
170 };
171
172 let shared = match json.get("shared") {
173 Value::Null => None,
174 val => {
175 let feats =
176 to_features(val, self.pool.get_object(), self.hash_seed, self.num_bits);
177 Some(feats)
178 }
179 };
180
181 let actions = match json.get("actions") {
182 Value::Null => panic!("No actions found"),
183 Value::Array(val) => val
184 .iter()
185 .map(|x| {
186 to_features(x, self.pool.get_object(), self.hash_seed, self.num_bits)
187 })
188 .collect(),
189 _ => panic!("Actions must be an array"),
190 };
191
192 (
193 Features::SparseCBAdf(CBAdfFeatures { shared, actions }),
194 label.map(|l| l.into()),
195 )
196 }
197
198 (_, _) => panic!("Feature type mismatch"),
199 })
200 }
201}
202
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use approx::assert_relative_eq;
    use serde_json::json;

    use crate::{
        object_pool::Pool,
        parsers::{JsonParserFactory, TextModeParser, TextModeParserFactory},
        sparse_namespaced_features::{Namespace, SparseFeatures},
        utils::AsInner,
        CBAdfFeatures, CBLabel, FeaturesType, LabelType, SimpleLabel,
    };

    /// Hash seed given to the parser; namespace lookups below must use the same value.
    const SEED: u32 = 0;
    /// Feature-mask bit width given to the parser.
    const BITS: u8 = 18;

    /// Builds a JSON parser for the given configuration backed by a fresh object pool.
    fn make_parser(
        features_type: FeaturesType,
        label_type: LabelType,
    ) -> <JsonParserFactory as TextModeParserFactory>::Parser {
        JsonParserFactory::default().create(
            features_type,
            label_type,
            SEED,
            BITS,
            Arc::new(Pool::new()),
        )
    }

    #[test]
    fn json_parse_cb() {
        let example = json!({
            "label": {
                "action": 3,
                "cost": 0.0,
                "probability": 0.05
            },
            "shared": {
                ":default": {
                    "bool_true": true,
                    "bool_false": false
                },
                "numbers": [4, 5.6],
                "FromUrl": {
                    "timeofday": "Afternoon",
                    "weather": "Sunny",
                    "name": "Cathy"
                }
            },
            "actions": [
                {
                    "i": { "constant": 1, "id": "Cappucino" },
                    "j": {
                        "type": "hot",
                        "origin": "kenya",
                        "organic": "yes",
                        "roast": "dark"
                    }
                }
            ]
        });

        let parser = make_parser(FeaturesType::SparseCBAdf, LabelType::CB);
        let (parsed_features, parsed_label) = parser.parse_chunk(&example.to_string()).unwrap();

        // Label fields come straight from the "label" object.
        let label: &CBLabel = parsed_label.as_ref().unwrap().as_inner().unwrap();
        assert_eq!(label.action, 3);
        assert_relative_eq!(label.cost, 0.0);
        assert_relative_eq!(label.probability, 0.05);

        let adf: &CBAdfFeatures = parsed_features.as_inner().unwrap();
        assert_eq!(adf.actions.len(), 1);
        assert!(adf.shared.is_some());

        // Shared: ":default" keeps only the `true` flag, "FromUrl" has three string
        // features, and "numbers" has two positional features summing to 9.6.
        let shared = adf.shared.as_ref().unwrap();
        assert_eq!(shared.namespaces().count(), 3);

        let default_ns = shared.get_namespace(Namespace::Default).unwrap();
        assert_eq!(default_ns.iter().count(), 1);

        let url_ns = shared
            .get_namespace(Namespace::from_name("FromUrl", SEED))
            .unwrap();
        assert_eq!(url_ns.iter().count(), 3);

        let numbers_ns = shared
            .get_namespace(Namespace::from_name("numbers", SEED))
            .unwrap();
        assert_eq!(numbers_ns.iter().count(), 2);
        let numbers_total: f32 = numbers_ns.iter().map(|(_, val)| val).sum();
        assert_relative_eq!(numbers_total, 9.6);

        // The single action has namespaces "i" and "j" only.
        let first_action = adf.actions.first().unwrap();
        assert_eq!(first_action.namespaces().count(), 2);
        assert!(first_action.get_namespace(Namespace::Default).is_none());

        let i_ns = first_action
            .get_namespace(Namespace::from_name("i", SEED))
            .unwrap();
        assert_eq!(i_ns.iter().count(), 2);

        let j_ns = first_action
            .get_namespace(Namespace::from_name("j", SEED))
            .unwrap();
        assert_eq!(j_ns.iter().count(), 4);
    }

    #[test]
    fn json_parse_simple() {
        let example = json!({
            "label": {
                "value": 0.2,
                "weight": 0.4
            },
            "features" : {
                ":default": {
                    "bool_true": true,
                    "bool_false": false
                },
                "numbers": [4, 5.6],
                "FromUrl": {
                    "timeofday": "Afternoon",
                    "weather": "Sunny",
                    "name": "Cathy"
                }
            }
        });

        let parser = make_parser(FeaturesType::SparseSimple, LabelType::Simple);
        let (parsed_features, parsed_label) = parser.parse_chunk(&example.to_string()).unwrap();

        let label: &SimpleLabel = parsed_label.as_ref().unwrap().as_inner().unwrap();
        assert_relative_eq!(label.value(), 0.2);
        assert_relative_eq!(label.weight(), 0.4);

        let sparse: &SparseFeatures = parsed_features.as_inner().unwrap();
        assert_eq!(sparse.namespaces().count(), 3);

        let default_ns = sparse.get_namespace(Namespace::Default).unwrap();
        assert_eq!(default_ns.iter().count(), 1);

        let url_ns = sparse
            .get_namespace(Namespace::from_name("FromUrl", SEED))
            .unwrap();
        assert_eq!(url_ns.iter().count(), 3);

        let numbers_ns = sparse
            .get_namespace(Namespace::from_name("numbers", SEED))
            .unwrap();
        assert_eq!(numbers_ns.iter().count(), 2);
        let numbers_total: f32 = numbers_ns.iter().map(|(_, val)| val).sum();
        assert_relative_eq!(numbers_total, 9.6);
    }
}