// hat_splitter/split.rs

1use once_cell::sync::Lazy;
2use regex::Regex;
3use unicode_segmentation::UnicodeSegmentation;
4
/// A logical token produced by the lexer.
///
/// NOTE(review): `Space` holds a token that is exactly one space `" "`, while
/// `Whitespace` holds any other all-whitespace token (runs of spaces, tabs,
/// newlines, …) — see the classification in `lex`.
enum Token {
    // A run of non-whitespace, non-punctuation text.
    Word(String),
    // A single punctuation character (Unicode `\p{P}`), possibly with
    // preceding text attached by `split_punctuation`.
    Punctuation(String),
    // Whitespace other than a single space.
    Whitespace(String),
    // Exactly one space character.
    Space(String),
}
11
12impl Token {
13    fn inner(self) -> String {
14        match self {
15            Token::Word(s) | Token::Punctuation(s) | Token::Whitespace(s) | Token::Space(s) => s,
16        }
17    }
18}
19
/// Abstraction over strategies for splitting text into words.
pub trait Splitter {
    // At some point it would be great to do this without allocations...
    //fn split<'a>(&self, input: &'a str) -> Vec<&'a str>;

    /// Splits a string into words.
    fn split(&self, input: &str) -> Vec<String>;

    /// Splits a string into words and limits the size of each word to
    /// `max_bytes_per_word` bytes. As this function enforces a byte limit, it
    /// may split unicode characters. That is, this function does not
    /// guarantee that the resulting byte arrays are valid UTF-8.
    fn split_with_limit(&self, input: &str, max_bytes_per_word: usize) -> Vec<Vec<u8>>;
}
32
/// A [`Splitter`] that segments text at Unicode word boundaries, then further
/// splits on punctuation and camelCase transitions (see the `impl` below).
pub struct HATSplitter;
34
35impl Default for HATSplitter {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl HATSplitter {
42    pub fn new() -> Self {
43        Self
44    }
45
46    fn unicode_word_split(input: &str) -> Vec<&str> {
47        input.split_word_bounds().collect::<Vec<&str>>()
48    }
49
50    fn split_at_matches<'a>(s: &'a str, re: &Regex) -> Vec<&'a str> {
51        let mut result = Vec::new();
52        let mut word_start = 0;
53
54        for regex_match in re.find_iter(s) {
55            let match_start = regex_match.start();
56
57            // We can unwrap here as we assume the regex match points to a valid UTF-8 character
58            let word_end = match_start + s[match_start..].chars().next().unwrap().len_utf8();
59
60            result.push(&s[word_start..word_end]);
61            word_start = word_end;
62        }
63
64        if word_start < s.len() {
65            result.push(&s[word_start..s.len()]);
66        }
67
68        result
69    }
70
71    fn split_camel_case(s: &str) -> Vec<&str> {
72        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\p{Ll})(\p{Lu})").unwrap());
73        Self::split_at_matches(s, &RE)
74    }
75
76    fn split_punctuation(s: &str) -> Vec<&str> {
77        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\p{P}").unwrap());
78        Self::split_at_matches(s, &RE)
79    }
80
81    fn combine_spaces(strings: Vec<&str>) -> Vec<String> {
82        strings.into_iter().fold(Vec::new(), |mut acc, s| {
83            if s == " " {
84                // If we have a space and the last element is also spaces, append to it
85                if let Some(last) = acc.last_mut() {
86                    if last.chars().all(|c| c == ' ') {
87                        last.push(' ');
88                        return acc;
89                    }
90                }
91            }
92            // Otherwise add as a new element
93            acc.push(s.to_string());
94            acc
95        })
96    }
97
98    // This function does its best to avoid splitting unicode characters, but in some cases it has
99    // no choice (e.g., if max_bytes < 4 and an emoji comes in).
100    fn split_long_words(strings: Vec<String>, max_bytes: usize) -> Vec<Vec<u8>> {
101        if max_bytes == 0 {
102            panic!("max_bytes must be greater than 0");
103        }
104        strings.into_iter().fold(Vec::new(), |mut result, string| {
105            let bytes = string.as_bytes();
106            if bytes.len() <= max_bytes {
107                result.push(bytes.to_vec());
108                return result;
109            }
110
111            let mut start_byte = 0;
112            while start_byte < bytes.len() {
113                let end_byte = std::cmp::min(start_byte + max_bytes, bytes.len());
114
115                // Backtrack to find a valid UTF-8 boundary
116                let end = (start_byte + 1..=end_byte)
117                    .rev()
118                    .find(|&i| string.is_char_boundary(i))
119                    .unwrap_or(end_byte); // Fall back to end_byte if no boundary found
120
121                result.push(bytes[start_byte..end].to_vec());
122                start_byte = end;
123            }
124            result
125        })
126    }
127
128    /// The Lexer takes a string and splits it into logical tokens.
129    fn lex(s: &str) -> Vec<Token> {
130        static WHITESPACE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\s+$").unwrap());
131        static PUNCTUATION_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\p{P}$").unwrap());
132
133        let words = Self::unicode_word_split(s);
134
135        let words = words
136            .iter()
137            .flat_map(|s| Self::split_punctuation(s))
138            .flat_map(|s| Self::split_camel_case(s))
139            .collect::<Vec<&str>>();
140
141        let words = Self::combine_spaces(words);
142
143        words
144            .into_iter()
145            .map(|s| {
146                if s == " " {
147                    Token::Space(s)
148                } else if WHITESPACE_RE.is_match(s.as_str()) {
149                    Token::Whitespace(s)
150                } else if PUNCTUATION_RE.is_match(s.as_str()) {
151                    Token::Punctuation(s)
152                } else {
153                    Token::Word(s)
154                }
155            })
156            .collect()
157    }
158
159    /// The Parser takes tokens and groups them into a string split.
160    fn parse(tokens: Vec<Token>) -> Vec<String> {
161        let groups = tokens
162            .into_iter()
163            .fold(Vec::<Vec<Token>>::new(), |mut groups, token| {
164                let should_append_to_last_group = |last_group: &Vec<Token>, token: &Token| {
165                    matches!(
166                        (last_group.last(), token),
167                        (Some(Token::Space(_)), Token::Word(_))
168                            | (
169                                Some(Token::Space(_) | Token::Word(_) | Token::Punctuation(_)),
170                                Token::Punctuation(_),
171                            )
172                    )
173                };
174
175                if let Some(last_group) = groups.last_mut() {
176                    if should_append_to_last_group(last_group, &token) {
177                        last_group.push(token);
178                        return groups;
179                    }
180                }
181
182                groups.push(vec![token]);
183                groups
184            });
185
186        // Concatenate groups
187        groups
188            .into_iter()
189            .map(|group| group.into_iter().map(Token::inner).collect())
190            .collect()
191    }
192}
193
194impl Splitter for HATSplitter {
195    fn split(&self, input: &str) -> Vec<String> {
196        Self::parse(Self::lex(input))
197    }
198
199    fn split_with_limit(&self, input: &str, max_bytes: usize) -> Vec<Vec<u8>> {
200        Self::split_long_words(Self::parse(Self::lex(input)), max_bytes)
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    // A stress-test string of exotic code points: multi-byte scripts,
    // mathematical alphanumerics, emoji, and combining marks.
    static STRANGE_STUFF: &str = "𓀀✨𝒜𝓁𝑔𝑜𝓇𝒾𝓉𝒽𝓂 شْء 你好吗 こんにちは 안녕하세요 𞤢𞤭𞤤 𝔽(λx.𝑥²) 🤖🍕⟨𝛴, 𝜋⟩ 🜚 𝔽↦𝑒ⁿω₀📡;𝑧𝑎<𝔱𝓇𝑢∃>🛠️ҀЋހ±(Δ𝓧) 乁( •_• )ㄏ   ⿰木日👾";

    // Punctuation attaches to the preceding word; a space attaches to the
    // following word.
    #[test]
    fn it_works() {
        let result = HATSplitter::new().split("Hello, world!");

        assert_eq!(result, vec!["Hello,", " world!"]);
    }

    #[test]
    fn it_handles_empty_input() {
        let result = HATSplitter::new().split("");

        assert!(result.is_empty());
    }

    #[test]
    fn it_splits_camel_case() {
        let result = HATSplitter::new().split("howAreYou");

        assert_eq!(result, vec!["how", "Are", "You"]);
    }

    // Underscores are punctuation, so they stay attached to the left piece.
    #[test]
    fn it_splits_snake_case() {
        let result = HATSplitter::new().split("how_are_you");

        assert_eq!(result, vec!["how_", "are_", "you"]);
    }

    #[test]
    fn it_limits_word_size() {
        let result = HATSplitter::new().split_with_limit("verylongword", 10);

        assert_eq!(result, vec![b"verylongwo".to_vec(), b"rd".to_vec()]);
    }

    // A 4-byte emoji with a 2-byte limit must be cut mid-character.
    #[test]
    fn it_splits_large_unicode_characters() {
        let result = HATSplitter::new().split_with_limit("🌝", 2);

        assert_eq!(result.len(), 2);
    }

    #[test]
    fn it_does_not_split_unicode_where_possible() {
        // This is one word with a 2-byte 'Ăź' starting at byte offset 1. We hope that the splitter
        // preserves this character by splitting into three parts instead of two.
        let result = HATSplitter::new().split_with_limit("fĂźr", 2);

        assert_eq!(
            result,
            vec![b"f".to_vec(), "Ăź".as_bytes().to_vec(), b"r".to_vec()]
        );
    }

    // max_bytes == 0 is a programming error and must panic.
    #[test]
    #[should_panic]
    fn it_handles_zero_max_bytes() {
        HATSplitter::new().split_with_limit("abc", 0);
    }

    // Smoke test: exotic input must not panic.
    #[test]
    fn it_handles_strange_stuff() {
        HATSplitter::new().split_with_limit(STRANGE_STUFF, 100);
    }

    // Splitting any prefix of the input must yield words that are themselves
    // prefixes of the full split's words (the splitter is streaming-safe).
    #[test]
    fn it_is_causal() {
        let max_chunk_size = 1024;
        let splitter = HATSplitter::new();

        let full_split = splitter.split_with_limit(STRANGE_STUFF, max_chunk_size);

        for (i, _) in STRANGE_STUFF.char_indices() {
            let prefix = &STRANGE_STUFF[..i];
            let partial_split = splitter.split_with_limit(prefix, max_chunk_size);

            for (full_word, partial_word) in full_split.iter().zip(partial_split.iter()) {
                assert_eq!(&full_word[..partial_word.len()], partial_word);
            }
        }
    }
}