tokenizers/pre_tokenizers/
punctuation.rs1use serde::{Deserialize, Serialize};
2
3use crate::tokenizer::{PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
4use crate::utils::macro_rules_attribute;
5use unicode_categories::UnicodeCategories;
6
7fn is_punc(x: char) -> bool {
8 char::is_ascii_punctuation(&x) || x.is_punctuation()
9}
10
11#[derive(Copy, Clone, Debug, PartialEq, Eq)]
12#[macro_rules_attribute(impl_serde_type!)]
13pub struct Punctuation {
14 #[serde(default = "default_split")]
15 pub behavior: SplitDelimiterBehavior,
16}
17
18fn default_split() -> SplitDelimiterBehavior {
19 SplitDelimiterBehavior::Isolated
20}
21
22impl Punctuation {
23 pub fn new(behavior: SplitDelimiterBehavior) -> Self {
24 Self { behavior }
25 }
26}
27
28impl Default for Punctuation {
29 fn default() -> Self {
30 Self::new(SplitDelimiterBehavior::Isolated)
31 }
32}
33
34impl PreTokenizer for Punctuation {
35 fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
36 pretokenized.split(|_, s| s.split(is_punc, self.behavior))
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OffsetReferential, OffsetType};

    /// With the default (`Isolated`) behavior, every punctuation character
    /// becomes its own split and the text in between is left untouched.
    #[test]
    fn punctuation_basic() {
        let pretok = Punctuation::default();
        // The run of five spaces after the first `!` is intentional: it must
        // stay attached to the following split, and it is what makes that
        // split span bytes 11..27 and the whole input 30 bytes long.
        let mut pretokenized: PreTokenizedString = "Hey friend!     How are you?!?".into();
        pretok.pre_tokenize(&mut pretokenized).unwrap();

        let splits = pretokenized
            .get_splits(OffsetReferential::Original, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s, o))
            .collect::<Vec<_>>();

        assert_eq!(
            splits,
            vec![
                ("Hey friend", (0, 10)),
                ("!", (10, 11)),
                ("     How are you", (11, 27)),
                ("?", (27, 28)),
                ("!", (28, 29)),
                ("?", (29, 30)),
            ]
        );
    }

    /// A payload without a `behavior` field deserializes to the default,
    /// i.e. the `Isolated` behavior.
    #[test]
    fn deserialization() {
        let punctuation: Punctuation = serde_json::from_str(r#"{"type": "Punctuation"}"#).unwrap();
        assert_eq!(punctuation, Punctuation::default());
        assert_eq!(
            punctuation,
            Punctuation::new(SplitDelimiterBehavior::Isolated)
        );
    }

    /// A payload tagged with a different pre-tokenizer type must be rejected;
    /// the `unwrap` on the resulting error is what triggers the expected panic.
    #[test]
    #[should_panic]
    fn deserialization_erroneous() {
        let _punctuation: Punctuation =
            serde_json::from_str(r#"{"type": "WhitespaceSplit"}"#).unwrap();
    }
}