hat_splitter/
split.rs

1use once_cell::sync::Lazy;
2use regex::Regex;
3use unicode_segmentation::UnicodeSegmentation;
4
/// A lexical token produced by `HATSplitter::lex`.
enum Token {
    /// A run of non-whitespace, non-punctuation text (after camel-case and
    /// punctuation splitting).
    Word(String),
    /// Exactly one Unicode punctuation character, possibly with preceding
    /// text attached by `split_at_matches`.
    Punctuation(String),
    /// Whitespace that is not a single ASCII space: tabs, newlines, and
    /// multi-space runs produced by `combine_spaces`.
    Whitespace(String),
    /// Exactly one ASCII space (`" "`); the only whitespace that `parse`
    /// allows to start a group joined by a following word.
    Space(String),
}
11
12impl Token {
13    fn inner(self) -> String {
14        match self {
15            Token::Word(s) | Token::Punctuation(s) | Token::Whitespace(s) | Token::Space(s) => s,
16        }
17    }
18}
19
/// A strategy for splitting text into word-like chunks.
pub trait Splitter {
    // At some point it would be great to do this without allocations...
    //fn split<'a>(&self, input: &'a str) -> Vec<&'a str>;

    /// Splits a string into words. Concatenating the returned strings yields
    /// the original input; no bytes are dropped.
    fn split(&self, text: &str) -> Vec<String>;

    /// Splits a string into words and limits the size of each word to `max_bytes_per_word`. As
    /// this function enforces a byte limit, it may split unicode characters. That is, this
    /// function does not guarantee that the resulting byte arrays are valid UTF-8.
    fn split_with_limit(&self, text: &str, max_bytes_per_word: usize) -> Vec<Vec<u8>>;
}
32
33pub struct HATSplitter;
34
35impl Default for HATSplitter {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
impl HATSplitter {
    /// Creates a new `HATSplitter`. The splitter holds no state, so this is
    /// free.
    pub fn new() -> Self {
        Self
    }

    /// Splits `input` into segments at Unicode word boundaries (UAX #29), as
    /// implemented by `unicode_segmentation::split_word_bounds`.
    fn unicode_word_split(input: &str) -> Vec<&str> {
        input.split_word_bounds().collect::<Vec<&str>>()
    }

    /// Splits `s` immediately *after* the first character of every match of
    /// `re`. Each piece (except possibly the last) therefore ends with the
    /// first character of a match, and concatenating the pieces reproduces
    /// `s` exactly — no bytes are dropped.
    fn split_at_matches<'a>(s: &'a str, re: &Regex) -> Vec<&'a str> {
        let mut result = Vec::new();
        let mut word_start = 0;

        for regex_match in re.find_iter(s) {
            let match_start = regex_match.start();

            // We can unwrap here as we assume the regex match points to a valid UTF-8 character
            let word_end = match_start + s[match_start..].chars().next().unwrap().len_utf8();

            result.push(&s[word_start..word_end]);
            word_start = word_end;
        }

        // Trailing text after the last match (or the whole string if there
        // was no match).
        if word_start < s.len() {
            result.push(&s[word_start..s.len()]);
        }

        result
    }

    /// Splits at lowercase→uppercase transitions, e.g. "fooBar" -> ["foo", "Bar"].
    fn split_camel_case(s: &str) -> Vec<&str> {
        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\p{Ll})(\p{Lu})").unwrap());
        Self::split_at_matches(s, &RE)
    }

    /// Splits after every Unicode punctuation character (`\p{P}`), e.g.
    /// "a.b" -> ["a.", "b"]. Punctuation stays attached to what precedes it.
    fn split_punctuation(s: &str) -> Vec<&str> {
        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\p{P}").unwrap());
        Self::split_at_matches(s, &RE)
    }

    /// Merges a single-space element into the previous element when that
    /// element consists entirely of ASCII spaces, so runs of `" "` segments
    /// collapse into one multi-space string. Multi-space strings are later
    /// classified as `Token::Whitespace`, not `Token::Space`, by `lex`.
    fn combine_spaces(strings: Vec<&str>) -> Vec<String> {
        strings.into_iter().fold(Vec::new(), |mut acc, s| {
            if s == " " {
                // If we have a space and the last element is also spaces, append to it
                if let Some(last) = acc.last_mut() {
                    if last.chars().all(|c| c == ' ') {
                        last.push(' ');
                        return acc;
                    }
                }
            }
            // Otherwise add as a new element
            acc.push(s.to_string());
            acc
        })
    }

    // This function does its best to avoid splitting unicode characters, but in some cases it has
    // no choice (e.g., if max_bytes < 4 and an emoji comes in).
    //
    // Chunks each string into pieces of at most `max_bytes` bytes,
    // backtracking to the nearest UTF-8 character boundary when possible.
    //
    // Panics if `max_bytes` is 0 (no progress would be possible).
    fn split_long_words(strings: Vec<String>, max_bytes: usize) -> Vec<Vec<u8>> {
        if max_bytes == 0 {
            panic!("max_bytes must be greater than 0");
        }
        strings.into_iter().fold(Vec::new(), |mut result, string| {
            let bytes = string.as_bytes();
            // Fast path: the whole string already fits in one chunk.
            if bytes.len() <= max_bytes {
                result.push(bytes.to_vec());
                return result;
            }

            let mut start_byte = 0;
            while start_byte < bytes.len() {
                let end_byte = std::cmp::min(start_byte + max_bytes, bytes.len());

                // Backtrack to find a valid UTF-8 boundary
                // (the range starts at start_byte + 1, so `end > start_byte`
                // always holds and the loop makes progress).
                let end = (start_byte + 1..=end_byte)
                    .rev()
                    .find(|&i| string.is_char_boundary(i))
                    .unwrap_or(end_byte); // Fall back to end_byte if no boundary found

                result.push(bytes[start_byte..end].to_vec());
                start_byte = end;
            }
            result
        })
    }

    /// The Lexer takes a string and splits it into logical tokens.
    ///
    /// Pipeline: Unicode word bounds -> punctuation split -> camel-case
    /// split -> space-run merging -> classification into `Token` variants.
    /// A segment equal to `" "` becomes `Space`; any other all-whitespace
    /// segment becomes `Whitespace`; a single punctuation character becomes
    /// `Punctuation`; everything else is a `Word`.
    fn lex(s: &str) -> Vec<Token> {
        static WHITESPACE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\s+$").unwrap());
        static PUNCTUATION_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\p{P}$").unwrap());

        let words = Self::combine_spaces(
            Self::unicode_word_split(s)
                .iter()
                .flat_map(|s| Self::split_punctuation(s))
                .flat_map(|s| Self::split_camel_case(s))
                .collect::<Vec<&str>>(),
        );

        words
            .into_iter()
            .map(|s| {
                if s == " " {
                    Token::Space(s)
                } else if WHITESPACE_RE.is_match(s.as_str()) {
                    Token::Whitespace(s)
                } else if PUNCTUATION_RE.is_match(s.as_str()) {
                    Token::Punctuation(s)
                } else {
                    Token::Word(s)
                }
            })
            .collect()
    }

    /// The Parser takes tokens and groups them into a string split.
    ///
    /// Grouping rules (encoded in `should_append_to_last_group`):
    /// - a `Word` joins a group currently ending in a `Space`
    ///   (so " world" stays together);
    /// - a `Punctuation` joins a group ending in a `Space`, `Word`, or
    ///   `Punctuation` (so "world!" stays together);
    /// - every other token starts a new group.
    fn parse(tokens: Vec<Token>) -> Vec<String> {
        let groups = tokens
            .into_iter()
            .fold(Vec::<Vec<Token>>::new(), |mut groups, token| {
                let should_append_to_last_group = |last_group: &Vec<Token>, token: &Token| {
                    matches!(
                        (last_group.last(), token),
                        (Some(Token::Space(_)), Token::Word(_))
                            | (
                                Some(Token::Space(_) | Token::Word(_) | Token::Punctuation(_)),
                                Token::Punctuation(_),
                            )
                    )
                };

                if let Some(last_group) = groups.last_mut() {
                    if should_append_to_last_group(last_group, &token) {
                        last_group.push(token);
                        return groups;
                    }
                }

                groups.push(vec![token]);
                groups
            });

        // Concatenate groups
        groups
            .into_iter()
            .map(|group| group.into_iter().map(Token::inner).collect())
            .collect()
    }
}
191
192impl Splitter for HATSplitter {
193    fn split(&self, input: &str) -> Vec<String> {
194        Self::parse(Self::lex(input))
195    }
196
197    fn split_with_limit(&self, input: &str, max_bytes: usize) -> Vec<Vec<u8>> {
198        Self::split_long_words(Self::parse(Self::lex(input)), max_bytes)
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    // A grab-bag of unusual Unicode (hieroglyphs, math alphanumerics, RTL
    // text, CJK, emoji, combining marks) used as a stress-test input.
    static STRANGE_STUFF: &str = "𓀀✨𝒜𝓁𝑔𝑜𝓇𝒾𝓉𝒽𝓂 شْء 你好吗 こんにちは 안녕하세요 𞤢𞤭𞤤 𝔽(λx.𝑥²) 🤖🍕⟨𝛴, 𝜋⟩ 🜚 𝔽↦𝑒ⁿω₀📡;𝑧𝑎<𝔱𝓇𝑢∃>🛠️ҀЋހ±(Δ𝓧) 乁( •_• )ㄏ   ⿰木日👾";

    // Punctuation sticks to the preceding word; a single space sticks to the
    // following word.
    #[test]
    fn it_works() {
        let result = HATSplitter::new().split("Hello, world!");

        assert_eq!(result, vec!["Hello,", " world!"]);
    }

    #[test]
    fn it_handles_empty_input() {
        let result = HATSplitter::new().split("");

        assert!(result.is_empty());
    }

    #[test]
    fn it_splits_camel_case() {
        let result = HATSplitter::new().split("howAreYou");

        assert_eq!(result, vec!["how", "Are", "You"]);
    }

    // Underscores are punctuation, so they end a piece rather than start one.
    #[test]
    fn it_splits_snake_case() {
        let result = HATSplitter::new().split("how_are_you");

        assert_eq!(result, vec!["how_", "are_", "you"]);
    }

    #[test]
    fn it_limits_word_size() {
        let result = HATSplitter::new().split_with_limit("verylongword", 10);

        assert_eq!(result, vec![b"verylongwo".to_vec(), b"rd".to_vec()]);
    }

    // A 4-byte emoji with a 2-byte limit must be cut mid-character.
    #[test]
    fn it_splits_large_unicode_characters() {
        let result = HATSplitter::new().split_with_limit("🌝", 2);

        assert_eq!(result.len(), 2);
    }

    #[test]
    fn it_does_not_split_unicode_where_possible() {
        // This is one word with a 2-byte 'Ăź' starting at byte offset 1. We hope that the splitter
        // preserves this character by splitting into three parts instead of two.
        let result = HATSplitter::new().split_with_limit("fĂźr", 2);

        assert_eq!(
            result,
            vec![b"f".to_vec(), "Ăź".as_bytes().to_vec(), b"r".to_vec()]
        );
    }

    // `split_long_words` panics on a zero byte budget (no progress possible).
    #[test]
    #[should_panic]
    fn it_handles_zero_max_bytes() {
        HATSplitter::new().split_with_limit("abc", 0);
    }

    // Smoke test: exotic Unicode must not panic.
    #[test]
    fn it_handles_strange_stuff() {
        HATSplitter::new().split_with_limit(STRANGE_STUFF, 100);
    }

    // Causality: splitting any prefix must produce a prefix (word-wise and
    // byte-wise) of the full split, so streaming input never changes already
    // emitted output.
    #[test]
    fn it_is_causal() {
        let max_chunk_size = 1024;
        let splitter = HATSplitter::new();

        let full_split = splitter.split_with_limit(STRANGE_STUFF, max_chunk_size);

        for (i, _) in STRANGE_STUFF.char_indices() {
            let prefix = &STRANGE_STUFF[..i];
            let partial_split = splitter.split_with_limit(prefix, max_chunk_size);

            for (full_word, partial_word) in full_split.iter().zip(partial_split.iter()) {
                assert_eq!(&full_word[..partial_word.len()], partial_word);
            }
        }
    }
}