tokenizers/pre_tokenizers/
digits.rs1use serde::{Deserialize, Serialize};
2
3use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
4use crate::utils::macro_rules_attribute;
5
/// Pre-tokenizer that splits numeric characters out of the surrounding text.
///
/// Characters are classified with [`char::is_numeric`], so any Unicode numeric
/// character counts as a digit here, not only ASCII `0`-`9`.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
// NOTE(review): serialization support presumably comes from the crate's
// `impl_serde_type!` macro (defined outside this file) — verify there.
#[macro_rules_attribute(impl_serde_type!)]
pub struct Digits {
    /// When `true`, every numeric character becomes its own split; when `false`,
    /// consecutive numeric characters are grouped into a single split.
    pub individual_digits: bool,
}

15impl Digits {
16 pub fn new(individual_digits: bool) -> Self {
17 Self { individual_digits }
18 }
19}
20
21impl Default for Digits {
22 fn default() -> Self {
23 Self::new(false)
24 }
25}
26
27impl PreTokenizer for Digits {
28 fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
29 if self.individual_digits {
30 pretokenized.split(|_, normalized| {
31 normalized.split(char::is_numeric, SplitDelimiterBehavior::Isolated)
32 })
33 } else {
34 pretokenized.split(|_, normalized| {
35 normalized.split(char::is_numeric, SplitDelimiterBehavior::Contiguous)
36 })
37 }
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OffsetReferential, OffsetType};

    /// Collects the `(text, byte offsets)` pair of every split, with offsets
    /// expressed in the requested referential.
    fn splits(
        pretokenized: &PreTokenizedString,
        referential: OffsetReferential,
    ) -> Vec<(&str, (usize, usize))> {
        pretokenized
            .get_splits(referential, OffsetType::Byte)
            .into_iter()
            .map(|(text, offsets, _)| (text, offsets))
            .collect()
    }

    #[test]
    fn numbers() {
        let mut pretokenized = PreTokenizedString::from("Hey 123 friend!");
        let pre_tokenizer = Digits::new(false);
        pre_tokenizer.pre_tokenize(&mut pretokenized).unwrap();

        let expected: &[(&str, (usize, usize))] =
            &[("Hey ", (0, 4)), ("123", (4, 7)), (" friend!", (7, 15))];
        // The pre-tokenizer only splits and never rewrites text, so both
        // referentials report the same offsets.
        assert_eq!(splits(&pretokenized, OffsetReferential::Normalized), expected);
        assert_eq!(splits(&pretokenized, OffsetReferential::Original), expected);
    }

    #[test]
    fn individual_digits() {
        let mut pretokenized = PreTokenizedString::from("Hey 123 friend!");
        let pre_tokenizer = Digits::new(true);
        pre_tokenizer.pre_tokenize(&mut pretokenized).unwrap();

        let expected: &[(&str, (usize, usize))] = &[
            ("Hey ", (0, 4)),
            ("1", (4, 5)),
            ("2", (5, 6)),
            ("3", (6, 7)),
            (" friend!", (7, 15)),
        ];
        assert_eq!(splits(&pretokenized, OffsetReferential::Normalized), expected);
        assert_eq!(splits(&pretokenized, OffsetReferential::Original), expected);
    }
}