// hat_splitter/split.rs

1use std::sync::LazyLock;
2
3use icu_segmenter::WordSegmenter;
4use once_cell::sync::Lazy;
5use regex::Regex;
6
/// A lexed piece of the input string, classified by what it contains.
///
/// Every variant carries the exact substring it was lexed from, so
/// concatenating the inner strings of a token sequence reproduces the input.
#[derive(Debug)]
enum Token {
    /// A word (anything that is not whitespace or a single punctuation mark).
    Word(String),
    /// A single Unicode punctuation character.
    Punctuation(String),
    /// A whitespace run that is not exactly one space (e.g. tabs, newlines, "  ").
    Whitespace(String),
    /// Exactly one space, or a run of spaces merged by `combine_spaces`.
    Space(String),
}
13
14impl Token {
15    fn inner(self) -> String {
16        match self {
17            Token::Word(s) | Token::Punctuation(s) | Token::Whitespace(s) | Token::Space(s) => s,
18        }
19    }
20}
21
/// A strategy for splitting text into word-like pieces.
pub trait Splitter {
    // At some point it would be great to do this without allocations...
    //fn split<'a>(&self, input: &'a str) -> Vec<&'a str>;

    /// Splits a string into words.
    fn split(&self, input: &str) -> Vec<String>;

    /// Splits a string into words and limits the size of each word to `max_bytes`. As this
    /// function enforces a byte limit, it may split unicode characters. That is, this function
    /// does not guarantee that the resulting byte arrays are valid UTF-8.
    fn split_with_limit(&self, input: &str, max_bytes_per_word: usize) -> Vec<Vec<u8>>;
}
34
/// Stateless [`Splitter`] implementation based on ICU word segmentation plus
/// punctuation and camelCase refinement.
pub struct HATSplitter;
36
37impl Default for HATSplitter {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl HATSplitter {
44    pub fn new() -> Self {
45        Self
46    }
47
48    fn unicode_word_split(input: &str) -> Vec<&str> {
49        // Note: we could also try `new_auto` which uses a LSTM (we should figure out which is better)
50        static WORD_SEGMENTER: LazyLock<WordSegmenter> =
51            LazyLock::new(WordSegmenter::new_dictionary);
52        let breakpoints: Vec<usize> = WORD_SEGMENTER.segment_str(input).collect();
53        breakpoints.windows(2).map(|w| &input[w[0]..w[1]]).collect()
54    }
55
56    fn split_at_matches<'a>(s: &'a str, re: &Regex) -> Vec<&'a str> {
57        let mut result = Vec::new();
58        let mut word_start = 0;
59
60        for regex_match in re.find_iter(s) {
61            let match_start = regex_match.start();
62
63            // We can unwrap here as we assume the regex match points to a valid UTF-8 character
64            let word_end = match_start + s[match_start..].chars().next().unwrap().len_utf8();
65
66            result.push(&s[word_start..word_end]);
67            word_start = word_end;
68        }
69
70        if word_start < s.len() {
71            result.push(&s[word_start..s.len()]);
72        }
73
74        result
75    }
76
77    fn split_camel_case(s: &str) -> Vec<&str> {
78        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\p{Ll})(\p{Lu})").unwrap());
79        Self::split_at_matches(s, &RE)
80    }
81
82    fn split_punctuation(s: &str) -> Vec<&str> {
83        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"\p{P}").unwrap());
84        Self::split_at_matches(s, &RE)
85    }
86
87    fn combine_spaces(strings: Vec<&str>) -> Vec<String> {
88        strings.into_iter().fold(Vec::new(), |mut acc, s| {
89            if s == " " {
90                // If we have a space and the last element is also spaces, append to it
91                if let Some(last) = acc.last_mut() {
92                    if last.chars().all(|c| c == ' ') {
93                        last.push(' ');
94                        return acc;
95                    }
96                }
97            }
98            // Otherwise add as a new element
99            acc.push(s.to_string());
100            acc
101        })
102    }
103
104    // This function does its best to avoid splitting unicode characters, but in some cases it has
105    // no choice (e.g., if max_bytes < 4 and an emoji comes in).
106    fn split_long_words(strings: Vec<String>, max_bytes: usize) -> Vec<Vec<u8>> {
107        if max_bytes == 0 {
108            panic!("max_bytes must be greater than 0");
109        }
110        strings.into_iter().fold(Vec::new(), |mut result, string| {
111            let bytes = string.as_bytes();
112            if bytes.len() <= max_bytes {
113                result.push(bytes.to_vec());
114                return result;
115            }
116
117            let mut start_byte = 0;
118            while start_byte < bytes.len() {
119                let end_byte = std::cmp::min(start_byte + max_bytes, bytes.len());
120
121                // Backtrack to find a valid UTF-8 boundary
122                let end = (start_byte + 1..=end_byte)
123                    .rev()
124                    .find(|&i| string.is_char_boundary(i))
125                    .unwrap_or(end_byte); // Fall back to end_byte if no boundary found
126
127                result.push(bytes[start_byte..end].to_vec());
128                start_byte = end;
129            }
130            result
131        })
132    }
133
134    /// The Lexer takes a string and splits it into logical tokens.
135    fn lex(s: &str) -> Vec<Token> {
136        static WHITESPACE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\s+$").unwrap());
137        static PUNCTUATION_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\p{P}$").unwrap());
138
139        let words = Self::unicode_word_split(s);
140
141        let words = words
142            .iter()
143            .flat_map(|s| Self::split_punctuation(s))
144            .flat_map(|s| Self::split_camel_case(s))
145            .collect::<Vec<&str>>();
146
147        let words = Self::combine_spaces(words);
148
149        words
150            .into_iter()
151            .map(|s| {
152                if s == " " {
153                    Token::Space(s)
154                } else if WHITESPACE_RE.is_match(s.as_str()) {
155                    Token::Whitespace(s)
156                } else if PUNCTUATION_RE.is_match(s.as_str()) {
157                    Token::Punctuation(s)
158                } else {
159                    Token::Word(s)
160                }
161            })
162            .collect()
163    }
164
165    /// The Parser takes tokens and groups them into a string split.
166    fn parse(tokens: Vec<Token>) -> Vec<String> {
167        let groups = tokens
168            .into_iter()
169            .fold(Vec::<Vec<Token>>::new(), |mut groups, token| {
170                let should_append_to_last_group = |last_group: &Vec<Token>, token: &Token| {
171                    matches!(
172                        (last_group.last(), token),
173                        (Some(Token::Space(_)), Token::Word(_))
174                            | (
175                                Some(Token::Space(_) | Token::Word(_) | Token::Punctuation(_)),
176                                Token::Punctuation(_),
177                            )
178                    )
179                };
180
181                if let Some(last_group) = groups.last_mut() {
182                    if should_append_to_last_group(last_group, &token) {
183                        last_group.push(token);
184                        return groups;
185                    }
186                }
187
188                groups.push(vec![token]);
189                groups
190            });
191
192        // Concatenate groups
193        groups
194            .into_iter()
195            .map(|group| group.into_iter().map(Token::inner).collect())
196            .collect()
197    }
198}
199
200impl Splitter for HATSplitter {
201    fn split(&self, input: &str) -> Vec<String> {
202        Self::parse(Self::lex(input))
203    }
204
205    fn split_with_limit(&self, input: &str, max_bytes: usize) -> Vec<Vec<u8>> {
206        Self::split_long_words(Self::parse(Self::lex(input)), max_bytes)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let splitter = HATSplitter::new();
        assert_eq!(splitter.split("Hello, world!"), vec!["Hello,", " world!"]);
    }

    #[test]
    fn it_handles_empty_input() {
        let parts = HATSplitter::new().split("");
        assert!(parts.is_empty());
    }

    #[test]
    fn it_splits_camel_case() {
        let parts = HATSplitter::new().split("howAreYou");
        assert_eq!(parts, vec!["how", "Are", "You"]);
    }

    #[test]
    fn it_splits_snake_case() {
        let parts = HATSplitter::new().split("how_are_you");
        assert_eq!(parts, vec!["how_", "are_", "you"]);
    }

    #[test]
    fn it_limits_word_size() {
        let chunks = HATSplitter::new().split_with_limit("verylongword", 10);
        assert_eq!(chunks, vec![b"verylongwo".to_vec(), b"rd".to_vec()]);
    }

    #[test]
    fn it_splits_large_unicode_characters() {
        // A 4-byte emoji with a 2-byte limit must be split mid-character.
        let chunks = HATSplitter::new().split_with_limit("🌝", 2);
        assert_eq!(chunks.len(), 2);
    }

    #[test]
    fn it_does_not_split_unicode_where_possible() {
        // This is one word with a 2-byte 'Ăź' starting at byte offset 1. We hope that the splitter
        // preserves this character by splitting into three parts instead of two.
        let chunks = HATSplitter::new().split_with_limit("fĂźr", 2);
        let expected = vec![b"f".to_vec(), "Ăź".as_bytes().to_vec(), b"r".to_vec()];
        assert_eq!(chunks, expected);
    }

    #[test]
    #[should_panic]
    fn it_handles_zero_max_bytes() {
        HATSplitter::new().split_with_limit("abc", 0);
    }

    #[test]
    fn it_handles_strange_stuff() {
        // Smoke test: exotic scripts, emoji, and combining marks must not panic.
        let text = "𓀀✨𝒜𝓁𝑔𝑜𝓇𝒾𝓉𝒽𝓂 شْء 你好吗 こんにちは 안녕하세요 𞤢𞤭𞤤 𝔽(λx.𝑥²) 🤖🍕⟨𝛴, 𝜋⟩ 🜚 𝔽↦𝑒ⁿω₀📡;𝑧𝑎<𝔱𝓇𝑢∃>🛠️ҀЋހ±(Δ𝓧) 乁( •_• )ㄏ   ⿰木日👾";
        HATSplitter::new().split_with_limit(text, 100);
    }
}