reductionml_core/parsers/vw_text_parser.rs

1use core::f32;
2
3use derive_more::TryInto;
4
5use smallvec::SmallVec;
6
7use crate::error::{Error, Result};
8
9use crate::object_pool::Pool;
10use crate::parsers::ParsedFeature;
11use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
12use crate::types::{Features, Label, LabelType};
13use crate::utils::AsInner;
14use crate::{CBAdfFeatures, CBLabel, FeatureMask, FeaturesType, SimpleLabel};
15
16use super::{ParsedNamespaceInfo, TextModeParser, TextModeParserFactory};
17
/// Contextual-bandit label exactly as written on one line of VW text input, before it is
/// resolved against the other lines of a multiline example.
#[derive(Clone, Copy, Debug, PartialEq)]
struct CBTextLabel {
    /// `true` when the line's label is the `shared` marker.
    shared: bool,
    /// `(action, cost, probability)` when the line carries an `action:cost:probability` label.
    acp: Option<(u32, f32, f32)>,
}
24
25#[derive(TryInto, Clone, Copy)]
26enum TextLabel {
27    Simple(f32, Option<f32>),
28    // Binary(bool),
29    CB(CBTextLabel),
30}
31
32impl AsInner<CBTextLabel> for TextLabel {
33    fn as_inner(&self) -> Option<&CBTextLabel> {
34        match self {
35            TextLabel::CB(f) => Some(f),
36            _ => None,
37        }
38    }
39    fn as_inner_mut(&mut self) -> Option<&mut CBTextLabel> {
40        match self {
41            TextLabel::CB(f) => Some(f),
42            _ => None,
43        }
44    }
45}
46
47// TODO work out where to put tag.
48// Idea - tag is not a concept here but for the cases where it was necessary (ccb) it will be folded into the feature type
49fn finalize_parsed_result_singleline<'a>(
50    parsed: TextParseResult,
51    _num_bits: u8,
52    dest: SparseFeatures,
53) -> (Features<'a>, Option<Label>) {
54    let hashed_sparse_features = Features::SparseSimple(dest);
55    match parsed.label {
56        // TODO fix
57        Some(TextLabel::Simple(x, weight)) => (
58            hashed_sparse_features,
59            Some(Label::Simple(SimpleLabel::new(x, weight.unwrap_or(1.0)))),
60        ),
61        // TODO binary
62        Some(_) => todo!(),
63        None => (hashed_sparse_features, None),
64    }
65}
66
67fn finalize_parsed_result_multiline<'a, 'b, T, U>(
68    mut feats_iter: T,
69    parsed: U,
70    expected_label: LabelType,
71    expected_features: FeaturesType,
72    _num_bits: u8,
73) -> Result<(Features<'b>, Option<Label>)>
74where
75    T: IntoIterator<Item = SparseFeatures> + Iterator<Item = SparseFeatures> + Clone,
76    U: Iterator<Item = TextParseResult<'a>>,
77{
78    match (expected_label, expected_features) {
79        (LabelType::CB, FeaturesType::SparseCBAdf) => {
80            // First thing to do is to determine if there is a shared example.
81            let mut txt_labels_iter = parsed.map(|x| x.label.unwrap()).peekable();
82            // let mut feats_iter = feats.into_iter();
83            let first_label: &CBTextLabel = txt_labels_iter
84                .peek()
85                .ok_or(Error::InvalidArgument("".to_owned()))?
86                .as_inner()
87                .expect("Label should be CB");
88            let first_is_shared = first_label.shared;
89
90            // TODO assert not more than 1 is shared.
91            let shared_ex = if first_is_shared {
92                // Consume shared token
93                txt_labels_iter.next();
94                Some(feats_iter.next().unwrap())
95            } else {
96                None
97            };
98
99            // Find the labelled action.
100            let mut label: Option<CBLabel> = None;
101            for (counter, action_label) in txt_labels_iter.enumerate() {
102                let lbl: &CBTextLabel = action_label.as_inner().expect("Label should be CB");
103                if let Some((_a, c, p)) = lbl.acp {
104                    if label.is_some() {
105                        return Err(Error::InvalidArgument(
106                            "More than one action label found".to_owned(),
107                        ));
108                    }
109                    label = Some(CBLabel {
110                        action: counter,
111                        cost: c,
112                        probability: p,
113                    });
114                }
115            }
116
117            Ok((
118                Features::SparseCBAdf(CBAdfFeatures {
119                    shared: shared_ex,
120                    actions: feats_iter.collect(),
121                }),
122                label.map(Label::CB),
123            ))
124        }
125        _ => Err(Error::InvalidArgument("".to_owned())),
126    }
127}
128
129struct TextParseResult<'a> {
130    _tag: Option<&'a str>,
131    // namespaces: Vec<ParsedNamespace<'a>>,
132    label: Option<TextLabel>,
133}
134
135fn parse_label(tokens: &[&str], label_type: LabelType) -> Result<Option<TextLabel>> {
136    match label_type {
137        LabelType::Simple => match tokens.len() {
138            0 => Ok(None),
139            1 => Ok(Some(TextLabel::Simple(
140                fast_float::parse(tokens[0]).unwrap(),
141                None,
142            ))),
143            2 => Ok(Some(TextLabel::Simple(
144                fast_float::parse(tokens[0]).unwrap(),
145                Some(fast_float::parse(tokens[1]).unwrap()),
146            ))),
147            // Initial not currently supported...
148            3 => todo!(),
149            _ => todo!(),
150        },
151        LabelType::Binary => todo!(),
152        LabelType::CB => match tokens.iter().next() {
153            None => Ok(Some(TextLabel::CB(CBTextLabel {
154                shared: false,
155                acp: None,
156            }))),
157            Some(value) if value.trim() == "shared" => Ok(Some(TextLabel::CB(CBTextLabel {
158                shared: true,
159                acp: None,
160            }))),
161            Some(value) => {
162                let mut tokens = value.split(':');
163                let action = tokens.next().unwrap().parse().unwrap();
164                let cost = fast_float::parse(tokens.next().unwrap()).unwrap();
165                let probability = fast_float::parse(tokens.next().unwrap()).unwrap();
166
167                // TODO: check that there are no more tokens
168
169                Ok(Some(TextLabel::CB(CBTextLabel {
170                    shared: false,
171                    acp: Some((action, cost, probability)),
172                })))
173            }
174        },
175    }
176}
177
178// TODO - consider conditionally allowing a feature whose name is a number ONLY to be interpreted as an anonymous features
179// This would be to mimic VW's hash "mode" of all vs txt
180fn parse_feature<'a>(feature: &'a str, offset_counter: &mut u32) -> (ParsedFeature<'a>, f32) {
181    // Check if char 0 is a :
182    let first_char_is_colon = feature.starts_with(':');
183    if first_char_is_colon {
184        // Anonymous feature
185        let value = fast_float::parse(&feature[1..]);
186        if let Ok(value) = value {
187            let offset_to_use = *offset_counter;
188            *offset_counter += 1;
189            (
190                ParsedFeature::Anonymous {
191                    offset: offset_to_use,
192                },
193                value,
194            )
195        } else {
196            return (
197                ParsedFeature::SimpleWithStringValue {
198                    name: "",
199                    value: feature[1..].trim(),
200                },
201                1.0,
202            );
203        }
204    } else {
205        // Named feature
206        let mut tokens = feature.split(':');
207        let name = tokens.next().unwrap();
208        match tokens.next() {
209            Some(value) => {
210                if let Ok(value) = fast_float::parse(value) {
211                    (ParsedFeature::Simple { name }, value)
212                } else {
213                    (
214                        ParsedFeature::SimpleWithStringValue {
215                            name,
216                            value: value.trim(),
217                        },
218                        1.0,
219                    )
220                }
221            }
222            None => (ParsedFeature::Simple { name }, 1.0),
223        }
224    }
225}
226
227fn parse_namespace_inline(
228    namespace_segment: &str,
229    dest_namespace: &mut SparseFeatures,
230    hash_seed: u32,
231    num_bits: u8,
232) -> Result<()> {
233    // Check if first char is a space or not
234    let first_char_is_space = namespace_segment.starts_with(' ');
235    let mut tokens = namespace_segment.split_ascii_whitespace();
236
237    let (namespace_name, namespace_value) = if first_char_is_space {
238        // Anonymous namespace
239        (" ", 1.0)
240    } else {
241        let namespace_info_token = tokens.next().unwrap();
242        let mut namespace_info_tokens = namespace_info_token.split(':');
243        let name = namespace_info_tokens.next().unwrap();
244        let value = match namespace_info_tokens.next() {
245            Some(value) => fast_float::parse(value).unwrap(),
246            None => 1.0,
247        };
248
249        (name, value)
250    };
251
252    let namespace_def = Namespace::from_name(namespace_name, hash_seed);
253    let namespace_hash = namespace_def.hash(hash_seed);
254
255    let dest = dest_namespace.get_or_create_namespace(namespace_def);
256    let mut offset_counter = 0;
257    for token in tokens {
258        let (parsed_feat, feat_value) = parse_feature(token, &mut offset_counter);
259        // let this_ns = dest.get_or_create_namespace_with_capacity(namespace_hash, features.len());
260        let feature_hash = parsed_feat.hash(namespace_hash);
261        let masked_hash = feature_hash.mask(FeatureMask::from_num_bits(num_bits));
262        dest.add_feature(masked_hash, feat_value * namespace_value);
263    }
264
265    Ok(())
266}
267
268fn parse_namespace_info_token(namespace_segment: &str) -> Result<(&str, f32)> {
269    let mut tokens: std::str::Split<char> = namespace_segment.split(':');
270    let name = tokens
271        .next()
272        .ok_or(Error::ParserError("Expected namespace name".to_owned()))?;
273    let value = match tokens.next() {
274        Some(value) => fast_float::parse(value).map_err(|err| {
275            Error::ParserError(format!("Failed to parse namespace value: {}", err))
276        })?,
277        None => 1.0,
278    };
279
280    Ok((name, value))
281}
282
283// "Consumes" some amount of input and returns the namespace info and the remaining input
284fn parse_namespace_info(input: &str) -> Result<(&str, (ParsedNamespaceInfo, f32))> {
285    let first_char_is_space = input.starts_with(' ');
286    // Extract up until the first space
287    if first_char_is_space {
288        Ok((&input[1..], (ParsedNamespaceInfo::Default, 1.0)))
289    } else {
290        let input_until_first_space = input.find(' ').unwrap();
291        let namespace_info_token = &input[..input_until_first_space];
292        let (ns_name, ns_value) = parse_namespace_info_token(namespace_info_token)?;
293        Ok((
294            &input[input_until_first_space..],
295            (ParsedNamespaceInfo::Named(ns_name), ns_value),
296        ))
297    }
298}
299
300fn extract_namespace_features(
301    namespace_segment: &str,
302) -> Result<(ParsedNamespaceInfo, Vec<ParsedFeature>)> {
303    let (remaining, (namespace_name, _namespace_value)) = parse_namespace_info(namespace_segment)?;
304
305    let tokens = remaining.split_ascii_whitespace();
306    let mut offset_counter = 0;
307    let extracted_featrues = tokens
308        .map(|x| {
309            let (feat, _value) = parse_feature(x, &mut offset_counter);
310            feat
311        })
312        .collect();
313    Ok((namespace_name, extracted_featrues))
314}
315
316// TODO revisit this function. Scanning to the last character is not ideal since it is linear time.
317fn parse_initial_segment(
318    text: &str,
319    label_type: LabelType,
320) -> Result<(Option<&str>, Option<TextLabel>)> {
321    // Is the last char of text a space?
322    let last_char_is_space = text.ends_with(' ');
323
324    // TODO: avoid this allocation!
325    let mut tokens: Vec<&str> = text.split_whitespace().collect();
326
327    let tag = match tokens.last() {
328        Some(&x) if (x.starts_with('\'') || !last_char_is_space) => {
329            tokens.pop();
330            if let Some(x) = x.strip_prefix('\'') {
331                Some(x)
332            } else {
333                Some(x)
334            }
335        }
336        Some(_) => None,
337        None => None,
338    };
339
340    let label = parse_label(&tokens, label_type)?;
341    Ok((tag, label))
342}
343
344fn parse_text_line_internal<'a>(
345    text: &'a str,
346    label_type: LabelType,
347    dest: &mut SparseFeatures,
348    hash_seed: u32,
349    num_bits: u8,
350) -> Result<TextParseResult<'a>> {
351    // Get string view up until first bar
352    let mut segments = text.split('|');
353    let initial_segment = segments.next().unwrap();
354    let (tag, label) = parse_initial_segment(initial_segment, label_type)?;
355
356    for segment in segments {
357        parse_namespace_inline(segment, dest, hash_seed, num_bits)?;
358    }
359    Ok(TextParseResult { _tag: tag, label })
360}
361
362#[derive(Default)]
363pub struct VwTextParserFactory;
364impl TextModeParserFactory for VwTextParserFactory {
365    type Parser = VwTextParser;
366    fn create(
367        &self,
368        features_type: FeaturesType,
369        label_type: LabelType,
370        hash_seed: u32,
371        num_bits: u8,
372        pool: std::sync::Arc<Pool<SparseFeatures>>,
373    ) -> Self::Parser {
374        VwTextParser {
375            feature_type: features_type,
376            label_type,
377            hash_seed,
378            num_bits,
379            pool,
380        }
381    }
382}
383
384pub struct VwTextParser {
385    feature_type: FeaturesType,
386    label_type: LabelType,
387    hash_seed: u32,
388    num_bits: u8,
389    pool: std::sync::Arc<Pool<SparseFeatures>>,
390}
391
392fn read_multi_lines(
393    input: &mut dyn std::io::BufRead,
394    mut output_buffer: String,
395) -> Result<Option<String>> {
396    assert!(output_buffer.is_empty());
397    loop {
398        let len_before = output_buffer.len();
399        if !output_buffer.is_empty() {
400            output_buffer.push('\n');
401        }
402        let bytes_read = input.read_line(&mut output_buffer)?;
403        if bytes_read == 0 && output_buffer.is_empty() {
404            return Ok(None);
405        }
406        output_buffer.truncate(output_buffer.trim_end().len());
407
408        // If we encounter an empty line, we are done. But if we are at
409        // the start (no data yet) we should just skip the empty line.
410        if output_buffer.is_empty() && len_before == 0 {
411            continue;
412        }
413
414        if len_before > 0 && output_buffer.len() == len_before {
415            // We read a line, but it was empty. This means we are done.
416            return Ok(Some(output_buffer));
417        }
418    }
419}
420
421fn read_single_line(
422    input: &mut dyn std::io::BufRead,
423    mut output_buffer: String,
424) -> Result<Option<String>> {
425    loop {
426        let bytes_read = input.read_line(&mut output_buffer)?;
427        if bytes_read == 0 {
428            return Ok(None);
429        }
430        output_buffer.truncate(output_buffer.trim_end().len());
431
432        // If we encounter an empty line, we are done. But if we are at
433        // the start (no data yet) we should just skip the empty line.
434        if output_buffer.is_empty() {
435            continue;
436        }
437
438        return Ok(Some(output_buffer));
439    }
440}
441
442impl TextModeParser for VwTextParser {
443    fn get_next_chunk(
444        &self,
445        input: &mut dyn std::io::BufRead,
446        mut output_buffer: String,
447    ) -> Result<Option<String>> {
448        output_buffer.clear();
449        if self.is_multiline() {
450            read_multi_lines(input, output_buffer)
451        } else {
452            read_single_line(input, output_buffer)
453        }
454    }
455
456    fn parse_chunk<'a, 'b>(&self, chunk: &'a str) -> Result<(Features<'b>, Option<Label>)> {
457        if self.is_multiline() {
458            let mut results = SmallVec::<[TextParseResult<'a>; 4]>::new();
459            let mut all_feautures = SmallVec::<[SparseFeatures; 4]>::new();
460            for line in chunk.lines() {
461                let mut dest = self.pool.get_object();
462                let result = parse_text_line_internal(
463                    line,
464                    self.label_type,
465                    &mut dest,
466                    self.hash_seed,
467                    self.num_bits,
468                )?;
469                results.push(result);
470                all_feautures.push(dest);
471            }
472            finalize_parsed_result_multiline(
473                all_feautures.into_iter(),
474                results.into_iter(),
475                self.label_type,
476                self.feature_type,
477                self.num_bits,
478            )
479        } else {
480            let mut dest = self.pool.get_object();
481            let result = parse_text_line_internal(
482                chunk,
483                self.label_type,
484                &mut dest,
485                self.hash_seed,
486                self.num_bits,
487            )?;
488            Ok(finalize_parsed_result_singleline(
489                result,
490                self.num_bits,
491                dest,
492            ))
493        }
494    }
495
496    fn extract_feature_names<'a>(
497        &self,
498        chunk: &'a str,
499    ) -> Result<std::collections::HashMap<ParsedNamespaceInfo<'a>, Vec<ParsedFeature<'a>>>> {
500        if self.is_multiline() {
501            chunk
502                .lines()
503                .flat_map(|line| {
504                    let mut segments = line.split('|');
505                    let _label_section = segments.next().unwrap();
506                    segments.map(extract_namespace_features)
507                })
508                .collect()
509        } else {
510            let mut segments = chunk.split('|');
511            let _label_section = segments.next().unwrap();
512            segments.map(extract_namespace_features).collect()
513        }
514    }
515}
516
517impl VwTextParser {
518    fn is_multiline(&self) -> bool {
519        self.feature_type == FeaturesType::SparseCBAdf && self.label_type == LabelType::CB
520    }
521}
522
#[cfg(test)]
mod tests {
    use crate::{
        error::Error,
        parsers::vw_text_parser::{read_multi_lines, read_single_line},
    };
    use std::io::{BufRead, Cursor};

    /// Signature shared by the chunk readers under test.
    type ChunkReader = fn(&mut dyn BufRead, String) -> crate::error::Result<Option<String>>;

    /// Feeds `text` to `reader` and checks that it yields exactly `expected`, in order.
    /// It then checks that end of input is reported as `None`, including on a repeated call.
    fn assert_chunks(reader: ChunkReader, text: &str, expected: &[&str]) -> Result<(), Error> {
        let mut input = Cursor::new(text);
        for want in expected {
            assert_eq!(reader(&mut input, String::new())?, Some(want.to_string()));
        }
        assert_eq!(reader(&mut input, String::new())?, None);
        assert_eq!(reader(&mut input, String::new())?, None);
        Ok(())
    }

    #[test]
    fn chunk_multiline() -> Result<(), Error> {
        assert_chunks(read_multi_lines, "line 1\nline 2", &["line 1\nline 2"])?;
        assert_chunks(
            read_multi_lines,
            "\n\n\nline 1\nline 2",
            &["line 1\nline 2"],
        )?;
        assert_chunks(
            read_multi_lines,
            "\n\n\nline 1\nline 2\n\n        ",
            &["line 1\nline 2"],
        )?;
        assert_chunks(
            read_multi_lines,
            "\n\n\nline 1\nline 2\n\n\nline 3\nline 4\n\n        ",
            &["line 1\nline 2", "line 3\nline 4"],
        )?;
        Ok(())
    }

    #[test]
    fn chunk_singleline() -> Result<(), Error> {
        assert_chunks(read_single_line, "line 1\nline 2", &["line 1", "line 2"])?;
        assert_chunks(
            read_single_line,
            "\n\n\nline 1\nline 2",
            &["line 1", "line 2"],
        )?;
        assert_chunks(
            read_single_line,
            "\n\n\nline 1\n\nline 2\n\n        ",
            &["line 1", "line 2"],
        )?;
        assert_chunks(
            read_single_line,
            "\n\n\nline 1\nline 2\n\n\nline 3\nline 4\n\n        ",
            &["line 1", "line 2", "line 3", "line 4"],
        )?;
        Ok(())
    }
}