1use core::f32;
2
3use derive_more::TryInto;
4
5use smallvec::SmallVec;
6
7use crate::error::{Error, Result};
8
9use crate::object_pool::Pool;
10use crate::parsers::ParsedFeature;
11use crate::sparse_namespaced_features::{Namespace, SparseFeatures};
12use crate::types::{Features, Label, LabelType};
13use crate::utils::AsInner;
14use crate::{CBAdfFeatures, CBLabel, FeatureMask, FeaturesType, SimpleLabel};
15
16use super::{ParsedNamespaceInfo, TextModeParser, TextModeParserFactory};
17
/// Contextual-bandit label information parsed from the label section of one line.
#[derive(Clone, Copy)]
struct CBTextLabel {
    /// `true` when the label section is the literal `shared`.
    shared: bool,
    /// `(action, cost, probability)` parsed from an `action:cost:probability` label;
    /// `None` for unlabelled and shared lines.
    acp: Option<(u32, f32, f32)>,
}
24
25#[derive(TryInto, Clone, Copy)]
26enum TextLabel {
27 Simple(f32, Option<f32>),
28 CB(CBTextLabel),
30}
31
32impl AsInner<CBTextLabel> for TextLabel {
33 fn as_inner(&self) -> Option<&CBTextLabel> {
34 match self {
35 TextLabel::CB(f) => Some(f),
36 _ => None,
37 }
38 }
39 fn as_inner_mut(&mut self) -> Option<&mut CBTextLabel> {
40 match self {
41 TextLabel::CB(f) => Some(f),
42 _ => None,
43 }
44 }
45}
46
47fn finalize_parsed_result_singleline<'a>(
50 parsed: TextParseResult,
51 _num_bits: u8,
52 dest: SparseFeatures,
53) -> (Features<'a>, Option<Label>) {
54 let hashed_sparse_features = Features::SparseSimple(dest);
55 match parsed.label {
56 Some(TextLabel::Simple(x, weight)) => (
58 hashed_sparse_features,
59 Some(Label::Simple(SimpleLabel::new(x, weight.unwrap_or(1.0)))),
60 ),
61 Some(_) => todo!(),
63 None => (hashed_sparse_features, None),
64 }
65}
66
67fn finalize_parsed_result_multiline<'a, 'b, T, U>(
68 mut feats_iter: T,
69 parsed: U,
70 expected_label: LabelType,
71 expected_features: FeaturesType,
72 _num_bits: u8,
73) -> Result<(Features<'b>, Option<Label>)>
74where
75 T: IntoIterator<Item = SparseFeatures> + Iterator<Item = SparseFeatures> + Clone,
76 U: Iterator<Item = TextParseResult<'a>>,
77{
78 match (expected_label, expected_features) {
79 (LabelType::CB, FeaturesType::SparseCBAdf) => {
80 let mut txt_labels_iter = parsed.map(|x| x.label.unwrap()).peekable();
82 let first_label: &CBTextLabel = txt_labels_iter
84 .peek()
85 .ok_or(Error::InvalidArgument("".to_owned()))?
86 .as_inner()
87 .expect("Label should be CB");
88 let first_is_shared = first_label.shared;
89
90 let shared_ex = if first_is_shared {
92 txt_labels_iter.next();
94 Some(feats_iter.next().unwrap())
95 } else {
96 None
97 };
98
99 let mut label: Option<CBLabel> = None;
101 for (counter, action_label) in txt_labels_iter.enumerate() {
102 let lbl: &CBTextLabel = action_label.as_inner().expect("Label should be CB");
103 if let Some((_a, c, p)) = lbl.acp {
104 if label.is_some() {
105 return Err(Error::InvalidArgument(
106 "More than one action label found".to_owned(),
107 ));
108 }
109 label = Some(CBLabel {
110 action: counter,
111 cost: c,
112 probability: p,
113 });
114 }
115 }
116
117 Ok((
118 Features::SparseCBAdf(CBAdfFeatures {
119 shared: shared_ex,
120 actions: feats_iter.collect(),
121 }),
122 label.map(Label::CB),
123 ))
124 }
125 _ => Err(Error::InvalidArgument("".to_owned())),
126 }
127}
128
129struct TextParseResult<'a> {
130 _tag: Option<&'a str>,
131 label: Option<TextLabel>,
133}
134
135fn parse_label(tokens: &[&str], label_type: LabelType) -> Result<Option<TextLabel>> {
136 match label_type {
137 LabelType::Simple => match tokens.len() {
138 0 => Ok(None),
139 1 => Ok(Some(TextLabel::Simple(
140 fast_float::parse(tokens[0]).unwrap(),
141 None,
142 ))),
143 2 => Ok(Some(TextLabel::Simple(
144 fast_float::parse(tokens[0]).unwrap(),
145 Some(fast_float::parse(tokens[1]).unwrap()),
146 ))),
147 3 => todo!(),
149 _ => todo!(),
150 },
151 LabelType::Binary => todo!(),
152 LabelType::CB => match tokens.iter().next() {
153 None => Ok(Some(TextLabel::CB(CBTextLabel {
154 shared: false,
155 acp: None,
156 }))),
157 Some(value) if value.trim() == "shared" => Ok(Some(TextLabel::CB(CBTextLabel {
158 shared: true,
159 acp: None,
160 }))),
161 Some(value) => {
162 let mut tokens = value.split(':');
163 let action = tokens.next().unwrap().parse().unwrap();
164 let cost = fast_float::parse(tokens.next().unwrap()).unwrap();
165 let probability = fast_float::parse(tokens.next().unwrap()).unwrap();
166
167 Ok(Some(TextLabel::CB(CBTextLabel {
170 shared: false,
171 acp: Some((action, cost, probability)),
172 })))
173 }
174 },
175 }
176}
177
178fn parse_feature<'a>(feature: &'a str, offset_counter: &mut u32) -> (ParsedFeature<'a>, f32) {
181 let first_char_is_colon = feature.starts_with(':');
183 if first_char_is_colon {
184 let value = fast_float::parse(&feature[1..]);
186 if let Ok(value) = value {
187 let offset_to_use = *offset_counter;
188 *offset_counter += 1;
189 (
190 ParsedFeature::Anonymous {
191 offset: offset_to_use,
192 },
193 value,
194 )
195 } else {
196 return (
197 ParsedFeature::SimpleWithStringValue {
198 name: "",
199 value: feature[1..].trim(),
200 },
201 1.0,
202 );
203 }
204 } else {
205 let mut tokens = feature.split(':');
207 let name = tokens.next().unwrap();
208 match tokens.next() {
209 Some(value) => {
210 if let Ok(value) = fast_float::parse(value) {
211 (ParsedFeature::Simple { name }, value)
212 } else {
213 (
214 ParsedFeature::SimpleWithStringValue {
215 name,
216 value: value.trim(),
217 },
218 1.0,
219 )
220 }
221 }
222 None => (ParsedFeature::Simple { name }, 1.0),
223 }
224 }
225}
226
227fn parse_namespace_inline(
228 namespace_segment: &str,
229 dest_namespace: &mut SparseFeatures,
230 hash_seed: u32,
231 num_bits: u8,
232) -> Result<()> {
233 let first_char_is_space = namespace_segment.starts_with(' ');
235 let mut tokens = namespace_segment.split_ascii_whitespace();
236
237 let (namespace_name, namespace_value) = if first_char_is_space {
238 (" ", 1.0)
240 } else {
241 let namespace_info_token = tokens.next().unwrap();
242 let mut namespace_info_tokens = namespace_info_token.split(':');
243 let name = namespace_info_tokens.next().unwrap();
244 let value = match namespace_info_tokens.next() {
245 Some(value) => fast_float::parse(value).unwrap(),
246 None => 1.0,
247 };
248
249 (name, value)
250 };
251
252 let namespace_def = Namespace::from_name(namespace_name, hash_seed);
253 let namespace_hash = namespace_def.hash(hash_seed);
254
255 let dest = dest_namespace.get_or_create_namespace(namespace_def);
256 let mut offset_counter = 0;
257 for token in tokens {
258 let (parsed_feat, feat_value) = parse_feature(token, &mut offset_counter);
259 let feature_hash = parsed_feat.hash(namespace_hash);
261 let masked_hash = feature_hash.mask(FeatureMask::from_num_bits(num_bits));
262 dest.add_feature(masked_hash, feat_value * namespace_value);
263 }
264
265 Ok(())
266}
267
268fn parse_namespace_info_token(namespace_segment: &str) -> Result<(&str, f32)> {
269 let mut tokens: std::str::Split<char> = namespace_segment.split(':');
270 let name = tokens
271 .next()
272 .ok_or(Error::ParserError("Expected namespace name".to_owned()))?;
273 let value = match tokens.next() {
274 Some(value) => fast_float::parse(value).map_err(|err| {
275 Error::ParserError(format!("Failed to parse namespace value: {}", err))
276 })?,
277 None => 1.0,
278 };
279
280 Ok((name, value))
281}
282
283fn parse_namespace_info(input: &str) -> Result<(&str, (ParsedNamespaceInfo, f32))> {
285 let first_char_is_space = input.starts_with(' ');
286 if first_char_is_space {
288 Ok((&input[1..], (ParsedNamespaceInfo::Default, 1.0)))
289 } else {
290 let input_until_first_space = input.find(' ').unwrap();
291 let namespace_info_token = &input[..input_until_first_space];
292 let (ns_name, ns_value) = parse_namespace_info_token(namespace_info_token)?;
293 Ok((
294 &input[input_until_first_space..],
295 (ParsedNamespaceInfo::Named(ns_name), ns_value),
296 ))
297 }
298}
299
300fn extract_namespace_features(
301 namespace_segment: &str,
302) -> Result<(ParsedNamespaceInfo, Vec<ParsedFeature>)> {
303 let (remaining, (namespace_name, _namespace_value)) = parse_namespace_info(namespace_segment)?;
304
305 let tokens = remaining.split_ascii_whitespace();
306 let mut offset_counter = 0;
307 let extracted_featrues = tokens
308 .map(|x| {
309 let (feat, _value) = parse_feature(x, &mut offset_counter);
310 feat
311 })
312 .collect();
313 Ok((namespace_name, extracted_featrues))
314}
315
316fn parse_initial_segment(
318 text: &str,
319 label_type: LabelType,
320) -> Result<(Option<&str>, Option<TextLabel>)> {
321 let last_char_is_space = text.ends_with(' ');
323
324 let mut tokens: Vec<&str> = text.split_whitespace().collect();
326
327 let tag = match tokens.last() {
328 Some(&x) if (x.starts_with('\'') || !last_char_is_space) => {
329 tokens.pop();
330 if let Some(x) = x.strip_prefix('\'') {
331 Some(x)
332 } else {
333 Some(x)
334 }
335 }
336 Some(_) => None,
337 None => None,
338 };
339
340 let label = parse_label(&tokens, label_type)?;
341 Ok((tag, label))
342}
343
344fn parse_text_line_internal<'a>(
345 text: &'a str,
346 label_type: LabelType,
347 dest: &mut SparseFeatures,
348 hash_seed: u32,
349 num_bits: u8,
350) -> Result<TextParseResult<'a>> {
351 let mut segments = text.split('|');
353 let initial_segment = segments.next().unwrap();
354 let (tag, label) = parse_initial_segment(initial_segment, label_type)?;
355
356 for segment in segments {
357 parse_namespace_inline(segment, dest, hash_seed, num_bits)?;
358 }
359 Ok(TextParseResult { _tag: tag, label })
360}
361
362#[derive(Default)]
363pub struct VwTextParserFactory;
364impl TextModeParserFactory for VwTextParserFactory {
365 type Parser = VwTextParser;
366 fn create(
367 &self,
368 features_type: FeaturesType,
369 label_type: LabelType,
370 hash_seed: u32,
371 num_bits: u8,
372 pool: std::sync::Arc<Pool<SparseFeatures>>,
373 ) -> Self::Parser {
374 VwTextParser {
375 feature_type: features_type,
376 label_type,
377 hash_seed,
378 num_bits,
379 pool,
380 }
381 }
382}
383
384pub struct VwTextParser {
385 feature_type: FeaturesType,
386 label_type: LabelType,
387 hash_seed: u32,
388 num_bits: u8,
389 pool: std::sync::Arc<Pool<SparseFeatures>>,
390}
391
/// Reads the next multiline chunk (consecutive non-blank lines) from `input`.
///
/// How the chunk is assembled:
/// - leading blank or whitespace-only lines are skipped;
/// - lines of the chunk are joined with `'\n'`;
/// - trailing whitespace, including the line terminator, is trimmed after every read;
/// - the chunk ends at the first blank or whitespace-only line, or at end of input.
///
/// `output_buffer` becomes the storage of the returned chunk.
///
/// Returns `Ok(None)` when end of input is reached before any non-blank line.
///
/// # Errors
/// Propagates errors returned by `read_line`.
///
/// # Panics
/// Panics if `output_buffer` is not empty.
fn read_multi_lines(
    input: &mut dyn std::io::BufRead,
    mut output_buffer: String,
) -> Result<Option<String>> {
    assert!(output_buffer.is_empty());
    loop {
        let len_before = output_buffer.len();
        // Separate the upcoming line from the previous one.
        if !output_buffer.is_empty() {
            output_buffer.push('\n');
        }
        let bytes_read = input.read_line(&mut output_buffer)?;
        // End of input with nothing buffered: no more chunks.
        if bytes_read == 0 && output_buffer.is_empty() {
            return Ok(None);
        }
        // Drop the line terminator and trailing whitespace. If the new line was blank
        // (or nothing was read), this also removes the '\n' pushed above, restoring
        // the buffer to `len_before`.
        output_buffer.truncate(output_buffer.trim_end().len());

        // Still before the first non-blank line: skip this blank line.
        if output_buffer.is_empty() && len_before == 0 {
            continue;
        }

        // Nothing was added after existing content: blank line or end of input, so the
        // chunk is complete.
        if len_before > 0 && output_buffer.len() == len_before {
            return Ok(Some(output_buffer));
        }
    }
}
420
421fn read_single_line(
422 input: &mut dyn std::io::BufRead,
423 mut output_buffer: String,
424) -> Result<Option<String>> {
425 loop {
426 let bytes_read = input.read_line(&mut output_buffer)?;
427 if bytes_read == 0 {
428 return Ok(None);
429 }
430 output_buffer.truncate(output_buffer.trim_end().len());
431
432 if output_buffer.is_empty() {
435 continue;
436 }
437
438 return Ok(Some(output_buffer));
439 }
440}
441
442impl TextModeParser for VwTextParser {
443 fn get_next_chunk(
444 &self,
445 input: &mut dyn std::io::BufRead,
446 mut output_buffer: String,
447 ) -> Result<Option<String>> {
448 output_buffer.clear();
449 if self.is_multiline() {
450 read_multi_lines(input, output_buffer)
451 } else {
452 read_single_line(input, output_buffer)
453 }
454 }
455
456 fn parse_chunk<'a, 'b>(&self, chunk: &'a str) -> Result<(Features<'b>, Option<Label>)> {
457 if self.is_multiline() {
458 let mut results = SmallVec::<[TextParseResult<'a>; 4]>::new();
459 let mut all_feautures = SmallVec::<[SparseFeatures; 4]>::new();
460 for line in chunk.lines() {
461 let mut dest = self.pool.get_object();
462 let result = parse_text_line_internal(
463 line,
464 self.label_type,
465 &mut dest,
466 self.hash_seed,
467 self.num_bits,
468 )?;
469 results.push(result);
470 all_feautures.push(dest);
471 }
472 finalize_parsed_result_multiline(
473 all_feautures.into_iter(),
474 results.into_iter(),
475 self.label_type,
476 self.feature_type,
477 self.num_bits,
478 )
479 } else {
480 let mut dest = self.pool.get_object();
481 let result = parse_text_line_internal(
482 chunk,
483 self.label_type,
484 &mut dest,
485 self.hash_seed,
486 self.num_bits,
487 )?;
488 Ok(finalize_parsed_result_singleline(
489 result,
490 self.num_bits,
491 dest,
492 ))
493 }
494 }
495
496 fn extract_feature_names<'a>(
497 &self,
498 chunk: &'a str,
499 ) -> Result<std::collections::HashMap<ParsedNamespaceInfo<'a>, Vec<ParsedFeature<'a>>>> {
500 if self.is_multiline() {
501 chunk
502 .lines()
503 .flat_map(|line| {
504 let mut segments = line.split('|');
505 let _label_section = segments.next().unwrap();
506 segments.map(extract_namespace_features)
507 })
508 .collect()
509 } else {
510 let mut segments = chunk.split('|');
511 let _label_section = segments.next().unwrap();
512 segments.map(extract_namespace_features).collect()
513 }
514 }
515}
516
517impl VwTextParser {
518 fn is_multiline(&self) -> bool {
519 self.feature_type == FeaturesType::SparseCBAdf && self.label_type == LabelType::CB
520 }
521}
522
#[cfg(test)]
mod tests {
    use crate::{
        error::Error,
        parsers::vw_text_parser::{read_multi_lines, read_single_line},
    };
    use std::io::Cursor;

    /// Signature shared by the chunk readers under test.
    type ChunkReader =
        fn(&mut dyn std::io::BufRead, String) -> crate::error::Result<Option<String>>;

    /// Reads every chunk from `text` with `read_chunk`.
    ///
    /// After the reader first reports end of input, this also checks that it keeps
    /// reporting end of input.
    fn drain(text: &str, read_chunk: ChunkReader) -> Result<Vec<String>, Error> {
        let mut input = Cursor::new(text);
        let mut chunks = Vec::new();
        while let Some(chunk) = read_chunk(&mut input, String::new())? {
            chunks.push(chunk);
        }
        assert_eq!(read_chunk(&mut input, String::new())?, None);
        Ok(chunks)
    }

    #[test]
    fn chunk_multiline() -> Result<(), Error> {
        let cases: [(&str, &[&str]); 4] = [
            ("line 1\nline 2", &["line 1\nline 2"]),
            ("\n\n\nline 1\nline 2", &["line 1\nline 2"]),
            ("\n\n\nline 1\nline 2\n\n    ", &["line 1\nline 2"]),
            (
                "\n\n\nline 1\nline 2\n\n\nline 3\nline 4\n\n    ",
                &["line 1\nline 2", "line 3\nline 4"],
            ),
        ];
        for (text, expected) in cases.iter() {
            assert_eq!(drain(text, read_multi_lines)?, expected.to_vec());
        }
        Ok(())
    }

    #[test]
    fn chunk_singleline() -> Result<(), Error> {
        let cases: [(&str, &[&str]); 4] = [
            ("line 1\nline 2", &["line 1", "line 2"]),
            ("\n\n\nline 1\nline 2", &["line 1", "line 2"]),
            ("\n\n\nline 1\n\nline 2\n\n    ", &["line 1", "line 2"]),
            (
                "\n\n\nline 1\nline 2\n\n\nline 3\nline 4\n\n    ",
                &["line 1", "line 2", "line 3", "line 4"],
            ),
        ];
        for (text, expected) in cases.iter() {
            assert_eq!(drain(text, read_single_line)?, expected.to_vec());
        }
        Ok(())
    }
}