1use once_cell::sync::Lazy;
2use regex::Regex;
3use unicode_segmentation::UnicodeSegmentation;
4
/// Lexer token categories produced by `HATSplitter::lex`.
///
/// Each variant owns the exact text of the fragment it was lexed from, so the
/// original input can be reconstructed losslessly by concatenating tokens.
enum Token {
    /// Fallback category: any fragment that is not a single space, not
    /// all-whitespace, and not a single punctuation character.
    Word(String),
    /// A fragment matching `^\p{P}$` — exactly one punctuation character.
    Punctuation(String),
    /// An all-whitespace fragment (`^\s+$`) other than a single space,
    /// e.g. tabs, newlines, or a run of two or more spaces.
    Whitespace(String),
    /// Exactly one ASCII space (`" "`); kept distinct from `Whitespace`
    /// because `parse` glues a space onto the word that follows it.
    Space(String),
}
11
12impl Token {
13 fn inner(self) -> String {
14 match self {
15 Token::Word(s) | Token::Punctuation(s) | Token::Whitespace(s) | Token::Space(s) => s,
16 }
17 }
18}
19
pub trait Splitter {
    /// Splits `text` into chunks. Concatenating the chunks reproduces `text`
    /// exactly — no characters are added or dropped.
    fn split(&self, text: &str) -> Vec<String>;

    /// Like [`Splitter::split`], but additionally re-splits any chunk longer
    /// than `max_bytes_per_word` bytes and returns raw byte chunks, each at
    /// most `max_bytes_per_word` bytes long.
    fn split_with_limit(&self, text: &str, max_bytes_per_word: usize) -> Vec<Vec<u8>>;
}
32
33pub struct HATSplitter;
34
35impl Default for HATSplitter {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl HATSplitter {
42 pub fn new() -> Self {
43 Self
44 }
45
46 fn unicode_word_split(input: &str) -> Vec<&str> {
47 input.split_word_bounds().collect::<Vec<&str>>()
48 }
49
50 fn split_at_matches<'a>(s: &'a str, re: &Regex) -> Vec<&'a str> {
51 let mut result = Vec::new();
52 let mut word_start = 0;
53
54 for regex_match in re.find_iter(s) {
55 let match_start = regex_match.start();
56
57 let word_end = match_start + s[match_start..].chars().next().unwrap().len_utf8();
59
60 result.push(&s[word_start..word_end]);
61 word_start = word_end;
62 }
63
64 if word_start < s.len() {
65 result.push(&s[word_start..s.len()]);
66 }
67
68 result
69 }
70
71 fn split_camel_case(s: &str) -> Vec<&str> {
72 static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\p{Ll})(\p{Lu})").unwrap());
73 Self::split_at_matches(s, &RE)
74 }
75
76 fn split_punctuation(s: &str) -> Vec<&str> {
77 static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\p{P}").unwrap());
78 Self::split_at_matches(s, &RE)
79 }
80
81 fn combine_spaces(strings: Vec<&str>) -> Vec<String> {
82 strings.into_iter().fold(Vec::new(), |mut acc, s| {
83 if s == " " {
84 if let Some(last) = acc.last_mut() {
86 if last.chars().all(|c| c == ' ') {
87 last.push(' ');
88 return acc;
89 }
90 }
91 }
92 acc.push(s.to_string());
94 acc
95 })
96 }
97
98 fn split_long_words(strings: Vec<String>, max_bytes: usize) -> Vec<Vec<u8>> {
101 if max_bytes == 0 {
102 panic!("max_bytes must be greater than 0");
103 }
104 strings.into_iter().fold(Vec::new(), |mut result, string| {
105 let bytes = string.as_bytes();
106 if bytes.len() <= max_bytes {
107 result.push(bytes.to_vec());
108 return result;
109 }
110
111 let mut start_byte = 0;
112 while start_byte < bytes.len() {
113 let end_byte = std::cmp::min(start_byte + max_bytes, bytes.len());
114
115 let end = (start_byte + 1..=end_byte)
117 .rev()
118 .find(|&i| string.is_char_boundary(i))
119 .unwrap_or(end_byte); result.push(bytes[start_byte..end].to_vec());
122 start_byte = end;
123 }
124 result
125 })
126 }
127
128 fn lex(s: &str) -> Vec<Token> {
130 static WHITESPACE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\s+$").unwrap());
131 static PUNCTUATION_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\p{P}$").unwrap());
132
133 let words = Self::combine_spaces(
134 Self::unicode_word_split(s)
135 .iter()
136 .flat_map(|s| Self::split_punctuation(s))
137 .flat_map(|s| Self::split_camel_case(s))
138 .collect::<Vec<&str>>(),
139 );
140
141 words
142 .into_iter()
143 .map(|s| {
144 if s == " " {
145 Token::Space(s)
146 } else if WHITESPACE_RE.is_match(s.as_str()) {
147 Token::Whitespace(s)
148 } else if PUNCTUATION_RE.is_match(s.as_str()) {
149 Token::Punctuation(s)
150 } else {
151 Token::Word(s)
152 }
153 })
154 .collect()
155 }
156
157 fn parse(tokens: Vec<Token>) -> Vec<String> {
159 let groups = tokens
160 .into_iter()
161 .fold(Vec::<Vec<Token>>::new(), |mut groups, token| {
162 let should_append_to_last_group = |last_group: &Vec<Token>, token: &Token| {
163 matches!(
164 (last_group.last(), token),
165 (Some(Token::Space(_)), Token::Word(_))
166 | (
167 Some(Token::Space(_) | Token::Word(_) | Token::Punctuation(_)),
168 Token::Punctuation(_),
169 )
170 )
171 };
172
173 if let Some(last_group) = groups.last_mut() {
174 if should_append_to_last_group(last_group, &token) {
175 last_group.push(token);
176 return groups;
177 }
178 }
179
180 groups.push(vec![token]);
181 groups
182 });
183
184 groups
186 .into_iter()
187 .map(|group| group.into_iter().map(Token::inner).collect())
188 .collect()
189 }
190}
191
192impl Splitter for HATSplitter {
193 fn split(&self, input: &str) -> Vec<String> {
194 Self::parse(Self::lex(input))
195 }
196
197 fn split_with_limit(&self, input: &str, max_bytes: usize) -> Vec<Vec<u8>> {
198 Self::split_long_words(Self::parse(Self::lex(input)), max_bytes)
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    // Unicode stress input: mixed scripts, emoji, and symbols.
    // NOTE(review): the bytes below look mis-encoded (mojibake), but they are
    // preserved verbatim — the tests only require "strange" Unicode input,
    // not readable text. Do not re-encode or normalize this literal.
    static STRANGE_STUFF: &str = "đâ¨đđđđđđžđđ˝đ Ř´ŮŘĄ ä˝ ĺĽ˝ĺ ăăăŤăĄăŻ ěë
íě¸ě đ¤˘đ¤đ¤¤ đ˝(Îťx.đĽÂ˛) đ¤đâ¨đ´, đ⊠đ đ˝âŚđâżĎâđĄ;đ§đ<đąđđ˘â>đ ď¸ŇĐŢÂą(Îđ§) äš( â˘_⢠)ă âż°ć¨ćĽđž";

    // Punctuation sticks to the preceding word; a space is grouped with the
    // word that follows it. Concatenating the chunks reproduces the input.
    #[test]
    fn it_works() {
        let result = HATSplitter::new().split("Hello, world!");

        assert_eq!(result, vec!["Hello,", " world!"]);
    }

    // Empty input yields no chunks at all (not a single empty chunk).
    #[test]
    fn it_handles_empty_input() {
        let result = HATSplitter::new().split("");

        assert!(result.is_empty());
    }

    // A lowercase→uppercase transition starts a new chunk.
    #[test]
    fn it_splits_camel_case() {
        let result = HATSplitter::new().split("howAreYou");

        assert_eq!(result, vec!["how", "Are", "You"]);
    }

    // Underscores are punctuation: each ends a chunk and stays attached to
    // the text on its left.
    #[test]
    fn it_splits_snake_case() {
        let result = HATSplitter::new().split("how_are_you");

        assert_eq!(result, vec!["how_", "are_", "you"]);
    }

    // Chunks longer than the byte limit are re-split into byte chunks.
    #[test]
    fn it_limits_word_size() {
        let result = HATSplitter::new().split_with_limit("verylongword", 10);

        assert_eq!(result, vec![b"verylongwo".to_vec(), b"rd".to_vec()]);
    }

    // A single 4-byte scalar with a 2-byte limit must be cut mid-character
    // into two chunks rather than returned oversized.
    #[test]
    fn it_splits_large_unicode_characters() {
        let result = HATSplitter::new().split_with_limit("đ", 2);

        assert_eq!(result.len(), 2);
    }

    // A multi-byte char is kept intact when a char boundary exists inside
    // the byte window ("Ăź" is 2 bytes and fits the limit exactly).
    #[test]
    fn it_does_not_split_unicode_where_possible() {
        let result = HATSplitter::new().split_with_limit("fĂźr", 2);

        assert_eq!(
            result,
            vec![b"f".to_vec(), "Ăź".as_bytes().to_vec(), b"r".to_vec()]
        );
    }

    // A zero byte limit is a programming error: split_long_words panics.
    #[test]
    #[should_panic]
    fn it_handles_zero_max_bytes() {
        HATSplitter::new().split_with_limit("abc", 0);
    }

    // Smoke test: arbitrary mixed-script input must not panic.
    #[test]
    fn it_handles_strange_stuff() {
        HATSplitter::new().split_with_limit(STRANGE_STUFF, 100);
    }

    // Causality: splitting any prefix of the input must produce chunks that
    // are byte-prefixes of the corresponding full-input chunks, which makes
    // the splitter usable on streamed/partial input.
    #[test]
    fn it_is_causal() {
        let max_chunk_size = 1024;
        let splitter = HATSplitter::new();

        let full_split = splitter.split_with_limit(STRANGE_STUFF, max_chunk_size);

        for (i, _) in STRANGE_STUFF.char_indices() {
            let prefix = &STRANGE_STUFF[..i];
            let partial_split = splitter.split_with_limit(prefix, max_chunk_size);

            for (full_word, partial_word) in full_split.iter().zip(partial_split.iter()) {
                assert_eq!(&full_word[..partial_word.len()], partial_word);
            }
        }
    }
}
289}