1use std::sync::LazyLock;
2
3use icu_segmenter::WordSegmenter;
4use once_cell::sync::Lazy;
5use regex::Regex;
6
/// A lexed fragment of the input, produced by `HATSplitter::lex`.
///
/// Classification happens in `lex`: a piece that is exactly `" "` becomes
/// `Space`, any other all-whitespace piece becomes `Whitespace`, a single
/// punctuation character becomes `Punctuation`, everything else is `Word`.
enum Token {
    // A word-like piece (anything not matched by the cases below).
    Word(String),
    // Exactly one Unicode punctuation character (`^\p{P}$`).
    Punctuation(String),
    // A whitespace run other than a single space (`^\s+$`).
    Whitespace(String),
    // Exactly one ASCII space character.
    Space(String),
}
13
14impl Token {
15 fn inner(self) -> String {
16 match self {
17 Token::Word(s) | Token::Punctuation(s) | Token::Whitespace(s) | Token::Space(s) => s,
18 }
19 }
20}
21
/// Splits input text into word-like chunks.
pub trait Splitter {
    /// Splits `input` into a list of string chunks.
    fn split(&self, input: &str) -> Vec<String>;

    /// Splits `input` into chunks, then re-slices each chunk into byte
    /// vectors of at most `max_bytes_per_word` bytes.
    fn split_with_limit(&self, input: &str, max_bytes_per_word: usize) -> Vec<Vec<u8>>;
}
34
35pub struct HATSplitter;
36
37impl Default for HATSplitter {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl HATSplitter {
44 pub fn new() -> Self {
45 Self
46 }
47
48 fn unicode_word_split(input: &str) -> Vec<&str> {
49 static WORD_SEGMENTER: LazyLock<WordSegmenter> =
51 LazyLock::new(WordSegmenter::new_dictionary);
52 let breakpoints: Vec<usize> = WORD_SEGMENTER.segment_str(input).collect();
53 breakpoints.windows(2).map(|w| &input[w[0]..w[1]]).collect()
54 }
55
56 fn split_at_matches<'a>(s: &'a str, re: &Regex) -> Vec<&'a str> {
57 let mut result = Vec::new();
58 let mut match_start = 0;
59
60 for regex_match in re.find_iter(s) {
61 let match_end = regex_match.start() + 1;
62 result.push(&s[match_start..match_end]);
63 match_start = match_end;
64 }
65
66 if match_start < s.len() {
67 result.push(&s[match_start..s.len()]);
68 }
69
70 result
71 }
72
73 fn split_camel_case(s: &str) -> Vec<&str> {
74 static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\p{Ll})(\p{Lu})").unwrap());
75 Self::split_at_matches(s, &RE)
76 }
77
78 fn split_punctuation(s: &str) -> Vec<&str> {
79 static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\p{P}").unwrap());
80 Self::split_at_matches(s, &RE)
81 }
82
83 fn combine_spaces(strings: Vec<&str>) -> Vec<String> {
84 strings.into_iter().fold(Vec::new(), |mut acc, s| {
85 if s == " " {
86 if let Some(last) = acc.last_mut() {
88 if last.chars().all(|c| c == ' ') {
89 last.push(' ');
90 return acc;
91 }
92 }
93 }
94 acc.push(s.to_string());
96 acc
97 })
98 }
99
100 fn split_long_words(strings: Vec<String>, max_bytes: usize) -> Vec<Vec<u8>> {
103 if max_bytes == 0 {
104 panic!("max_bytes must be greater than 0");
105 }
106 strings.into_iter().fold(Vec::new(), |mut result, string| {
107 let bytes = string.as_bytes();
108 if bytes.len() <= max_bytes {
109 result.push(bytes.to_vec());
110 return result;
111 }
112
113 let mut start_byte = 0;
114 while start_byte < bytes.len() {
115 let end_byte = std::cmp::min(start_byte + max_bytes, bytes.len());
116
117 let end = (start_byte + 1..=end_byte)
119 .rev()
120 .find(|&i| string.is_char_boundary(i))
121 .unwrap_or(end_byte); result.push(bytes[start_byte..end].to_vec());
124 start_byte = end;
125 }
126 result
127 })
128 }
129
130 fn lex(s: &str) -> Vec<Token> {
132 static WHITESPACE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\s+$").unwrap());
133 static PUNCTUATION_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\p{P}$").unwrap());
134
135 let words = Self::unicode_word_split(s);
136
137 let words = words
138 .iter()
139 .flat_map(|s| Self::split_punctuation(s))
140 .flat_map(|s| Self::split_camel_case(s))
141 .collect::<Vec<&str>>();
142
143 let words = Self::combine_spaces(words.clone());
144
145 words
146 .into_iter()
147 .map(|s| {
148 if s == " " {
149 Token::Space(s)
150 } else if WHITESPACE_RE.is_match(s.as_str()) {
151 Token::Whitespace(s)
152 } else if PUNCTUATION_RE.is_match(s.as_str()) {
153 Token::Punctuation(s)
154 } else {
155 Token::Word(s)
156 }
157 })
158 .collect()
159 }
160
161 fn parse(tokens: Vec<Token>) -> Vec<String> {
163 let groups = tokens
164 .into_iter()
165 .fold(Vec::<Vec<Token>>::new(), |mut groups, token| {
166 let should_append_to_last_group = |last_group: &Vec<Token>, token: &Token| {
167 matches!(
168 (last_group.last(), token),
169 (Some(Token::Space(_)), Token::Word(_))
170 | (
171 Some(Token::Space(_) | Token::Word(_) | Token::Punctuation(_)),
172 Token::Punctuation(_),
173 )
174 )
175 };
176
177 if let Some(last_group) = groups.last_mut() {
178 if should_append_to_last_group(last_group, &token) {
179 last_group.push(token);
180 return groups;
181 }
182 }
183
184 groups.push(vec![token]);
185 groups
186 });
187
188 groups
190 .into_iter()
191 .map(|group| group.into_iter().map(Token::inner).collect())
192 .collect()
193 }
194}
195
196impl Splitter for HATSplitter {
197 fn split(&self, input: &str) -> Vec<String> {
198 Self::parse(Self::lex(input))
199 }
200
201 fn split_with_limit(&self, input: &str, max_bytes: usize) -> Vec<Vec<u8>> {
202 Self::split_long_words(Self::parse(Self::lex(input)), max_bytes)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    // Punctuation sticks to the preceding word; the space sticks to the
    // following word.
    #[test]
    fn it_works() {
        let result = HATSplitter::new().split("Hello, world!");

        assert_eq!(result, vec!["Hello,", " world!"]);
    }

    #[test]
    fn it_handles_empty_input() {
        let result = HATSplitter::new().split("");

        assert!(result.is_empty());
    }

    // Lower→upper transitions start a new chunk.
    #[test]
    fn it_splits_camel_case() {
        let result = HATSplitter::new().split("howAreYou");

        assert_eq!(result, vec!["how", "Are", "You"]);
    }

    // '_' is punctuation, so the split happens after each underscore.
    #[test]
    fn it_splits_snake_case() {
        let result = HATSplitter::new().split("how_are_you");

        assert_eq!(result, vec!["how_", "are_", "you"]);
    }

    // Chunks longer than the byte limit are sliced.
    #[test]
    fn it_limits_word_size() {
        let result = HATSplitter::new().split_with_limit("verylongword", 10);

        assert_eq!(result, vec![b"verylongwo".to_vec(), b"rd".to_vec()]);
    }

    // A single 4-byte char with a 2-byte limit must be cut mid-character.
    #[test]
    fn it_splits_large_unicode_characters() {
        let result = HATSplitter::new().split_with_limit("🌝", 2);

        assert_eq!(result.len(), 2);
    }

    // Cuts should land on char boundaries whenever the limit allows it.
    #[test]
    fn it_does_not_split_unicode_where_possible() {
        let result = HATSplitter::new().split_with_limit("für", 2);

        assert_eq!(
            result,
            vec![b"f".to_vec(), "ü".as_bytes().to_vec(), b"r".to_vec()]
        );
    }

    #[test]
    #[should_panic]
    fn it_handles_zero_max_bytes() {
        HATSplitter::new().split_with_limit("abc", 0);
    }
}