unobtanium_segmenter/augmentation/
classify.rs1use unicode_properties::GeneralCategoryGroup;
2use unicode_properties::UnicodeGeneralCategory;
3
4use crate::augmentation::Augmenter;
5use crate::SegmentedToken;
6use crate::SegmentedTokenKind;
7
/// Augmenter that assigns a [`SegmentedTokenKind`] to every token it sees.
///
/// The kind is derived from the Unicode general categories of the token's
/// characters. Any token containing a letter or a number becomes
/// `AlphaNumeric`. Other tokens are classified as `Symbol` or `Separator`
/// (or left without a kind) depending on which categories occur in them.
#[derive(Debug, Clone, Default)]
pub struct AugmentationClassify {}

impl AugmentationClassify {
    /// Creates a new classifier; identical to `AugmentationClassify::default()`.
    pub fn new() -> Self {
        Default::default()
    }
}
32
33impl Augmenter for AugmentationClassify {
34 fn augment<'a>(&self, mut token: SegmentedToken<'a>) -> SegmentedToken<'a> {
35 let mut has_seperators = false;
36 let mut has_symbols = false;
37 for c in token.get_text_prefer_normalized().chars() {
38 match c.general_category_group() {
39 GeneralCategoryGroup::Letter | GeneralCategoryGroup::Number => {
40 token.kind = Some(SegmentedTokenKind::AlphaNumeric);
41 return token;
42 }
43 GeneralCategoryGroup::Punctuation | GeneralCategoryGroup::Separator => {
44 has_seperators = true
45 }
46 GeneralCategoryGroup::Symbol | GeneralCategoryGroup::Other => match c {
47 '\n' | '\0' => has_seperators = true,
48 _ => has_symbols = true,
49 },
50 GeneralCategoryGroup::Mark => { }
51 }
52 }
53 if has_symbols {
54 token.kind = Some(SegmentedTokenKind::Symbol);
55 return token;
56 }
57 if has_seperators {
58 token.kind = Some(SegmentedTokenKind::Separator);
59 return token;
60 }
61 token.kind = None;
62 return token;
63 }
64}
65
#[cfg(test)]
mod test {

    use super::*;

    use crate::chain::ChainAugmenter;
    use crate::chain::ChainSegmenter;
    use crate::chain::StartSegmentationChain;
    use crate::segmentation::UnicodeWordSplitter;

    /// Shorthand for the expected `AlphaNumeric` kind.
    fn a() -> Option<SegmentedTokenKind> {
        Some(SegmentedTokenKind::AlphaNumeric)
    }

    /// Shorthand for the expected `Separator` kind.
    fn s() -> Option<SegmentedTokenKind> {
        Some(SegmentedTokenKind::Separator)
    }

    /// Shorthand for the expected `Symbol` kind.
    fn y() -> Option<SegmentedTokenKind> {
        Some(SegmentedTokenKind::Symbol)
    }

    #[test]
    fn test_unicode_word_split() {
        let test_text = "The quick (\"brown\") fox🦊 can't jump 32.3 feet, right?\nThe quick (\"brown\") fox. The value of π in german is '3,141592…'.";

        let word_splitter = UnicodeWordSplitter::new();
        let classifier = AugmentationClassify::new();

        let result: Vec<(&str, Option<SegmentedTokenKind>)> = test_text
            .start_segmentation_chain()
            .chain_segmenter(&word_splitter)
            .chain_augmenter(&classifier)
            .map(|t| (t.text, t.kind))
            .collect();

        let expected_tokens = vec![
            ("The", a()),
            (" ", s()),
            ("quick", a()),
            (" ", s()),
            ("(", s()),
            ("\"", s()),
            ("brown", a()),
            ("\"", s()),
            (")", s()),
            (" ", s()),
            ("fox", a()),
            ("🦊", y()),
            (" ", s()),
            ("can't", a()),
            (" ", s()),
            ("jump", a()),
            (" ", s()),
            ("32.3", a()),
            (" ", s()),
            ("feet", a()),
            (",", s()),
            (" ", s()),
            ("right", a()),
            ("?", s()),
            ("\n", s()),
            ("The", a()),
            (" ", s()),
            ("quick", a()),
            (" ", s()),
            ("(", s()),
            ("\"", s()),
            ("brown", a()),
            ("\"", s()),
            (")", s()),
            (" ", s()),
            ("fox", a()),
            (".", s()),
            (" ", s()),
            ("The", a()),
            (" ", s()),
            ("value", a()),
            (" ", s()),
            ("of", a()),
            (" ", s()),
            // Must match the `π` in `test_text` exactly (was mojibake "Ï€").
            ("π", a()),
            (" ", s()),
            ("in", a()),
            (" ", s()),
            ("german", a()),
            (" ", s()),
            ("is", a()),
            (" ", s()),
            ("'", s()),
            ("3,141592", a()),
            ("…", s()),
            ("'", s()),
            (".", s()),
        ];

        assert_eq!(result, expected_tokens);
    }
}