hat_splitter/
split.rs

1use std::sync::LazyLock;
2
3use icu_segmenter::WordSegmenter;
4use once_cell::sync::Lazy;
5use regex::Regex;
6
/// One lexed piece of the input, tagged with how the lexer classified it.
#[derive(Debug)]
enum Token {
    /// Any piece that is not whitespace and not a single punctuation character.
    Word(String),
    /// Exactly one punctuation character (`\p{P}`).
    Punctuation(String),
    /// An all-whitespace piece other than a single space (tabs, newlines, multi-space runs).
    Whitespace(String),
    /// Exactly one space character (`" "`).
    Space(String),
}
13
14impl Token {
15    fn inner(self) -> String {
16        match self {
17            Token::Word(s) | Token::Punctuation(s) | Token::Whitespace(s) | Token::Space(s) => s,
18        }
19    }
20}
21
/// Splits strings into word-like pieces.
pub trait Splitter {
    // TODO: at some point it would be great to do this without allocations, e.g.:
    //fn split<'a>(&self, input: &'a str) -> Vec<&'a str>;

    /// Splits a string into words.
    fn split(&self, input: &str) -> Vec<String>;

    /// Splits a string into words and limits the size of each word to `max_bytes_per_word`
    /// bytes. As this function enforces a byte limit, it may split unicode characters. That is,
    /// this function does not guarantee that the resulting byte arrays are valid UTF-8.
    fn split_with_limit(&self, input: &str, max_bytes_per_word: usize) -> Vec<Vec<u8>>;
}
34
/// A stateless [`Splitter`] that segments text at Unicode word boundaries (via ICU) and then
/// further splits camelCase and snake_case identifiers.
pub struct HATSplitter;
36
37impl Default for HATSplitter {
38    fn default() -> Self {
39        Self::new()
40    }
41}
42
43impl HATSplitter {
44    pub fn new() -> Self {
45        Self
46    }
47
48    fn unicode_word_split(input: &str) -> Vec<&str> {
49        // Note: we could also try `new_auto` which uses a LSTM (we should figure out which is better)
50        static WORD_SEGMENTER: LazyLock<WordSegmenter> =
51            LazyLock::new(WordSegmenter::new_dictionary);
52        let breakpoints: Vec<usize> = WORD_SEGMENTER.segment_str(input).collect();
53        breakpoints.windows(2).map(|w| &input[w[0]..w[1]]).collect()
54    }
55
56    fn split_at_matches<'a>(s: &'a str, re: &Regex) -> Vec<&'a str> {
57        let mut result = Vec::new();
58        let mut match_start = 0;
59
60        for regex_match in re.find_iter(s) {
61            let match_end = regex_match.start() + 1;
62            result.push(&s[match_start..match_end]);
63            match_start = match_end;
64        }
65
66        if match_start < s.len() {
67            result.push(&s[match_start..s.len()]);
68        }
69
70        result
71    }
72
73    fn split_camel_case(s: &str) -> Vec<&str> {
74        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\p{Ll})(\p{Lu})").unwrap());
75        Self::split_at_matches(s, &RE)
76    }
77
78    fn split_snake_case(s: &str) -> Vec<&str> {
79        static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"_").unwrap());
80        Self::split_at_matches(s, &RE)
81    }
82
83    fn combine_spaces(strings: Vec<&str>) -> Vec<String> {
84        strings.into_iter().fold(Vec::new(), |mut acc, s| {
85            if s == " " {
86                // If we have a space and the last element is also spaces, append to it
87                if let Some(last) = acc.last_mut() {
88                    if last.chars().all(|c| c == ' ') {
89                        last.push(' ');
90                        return acc;
91                    }
92                }
93            }
94            // Otherwise add as a new element
95            acc.push(s.to_string());
96            acc
97        })
98    }
99
100    // This function does its best to avoid splitting unicode characters, but in some cases it has
101    // no choice (e.g., if max_bytes < 4 and an emoji comes in).
102    fn split_long_words(strings: Vec<String>, max_bytes: usize) -> Vec<Vec<u8>> {
103        if max_bytes == 0 {
104            panic!("max_bytes must be greater than 0");
105        }
106        strings.into_iter().fold(Vec::new(), |mut result, string| {
107            let bytes = string.as_bytes();
108            if bytes.len() <= max_bytes {
109                result.push(bytes.to_vec());
110                return result;
111            }
112
113            let mut start_byte = 0;
114            while start_byte < bytes.len() {
115                let end_byte = std::cmp::min(start_byte + max_bytes, bytes.len());
116
117                // Backtrack to find a valid UTF-8 boundary
118                let end = (start_byte + 1..=end_byte)
119                    .rev()
120                    .find(|&i| string.is_char_boundary(i))
121                    .unwrap_or(end_byte); // Fall back to end_byte if no boundary found
122
123                result.push(bytes[start_byte..end].to_vec());
124                start_byte = end;
125            }
126            result
127        })
128    }
129
130    /// The Lexer takes a string and splits it into logical tokens.
131    fn lex(s: &str) -> Vec<Token> {
132        static WHITESPACE_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\s+$").unwrap());
133        static PUNCTUATION_RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"^\p{P}$").unwrap());
134
135        let words = Self::unicode_word_split(s);
136
137        let words = words
138            .iter()
139            .flat_map(|s| Self::split_camel_case(s))
140            .flat_map(|s| Self::split_snake_case(s))
141            .collect::<Vec<&str>>();
142
143        let words = Self::combine_spaces(words.clone());
144
145        words
146            .into_iter()
147            .map(|s| {
148                if s == " " {
149                    Token::Space(s)
150                } else if WHITESPACE_RE.is_match(s.as_str()) {
151                    Token::Whitespace(s)
152                } else if PUNCTUATION_RE.is_match(s.as_str()) {
153                    Token::Punctuation(s)
154                } else {
155                    Token::Word(s)
156                }
157            })
158            .collect()
159    }
160
161    /// The Parser takes tokens and groups them into a string split.
162    fn parse(tokens: Vec<Token>) -> Vec<String> {
163        let groups = tokens
164            .into_iter()
165            .fold(Vec::<Vec<Token>>::new(), |mut groups, token| {
166                let should_append_to_last_group = |last_group: &Vec<Token>, token: &Token| {
167                    matches!(
168                        (last_group.last(), token),
169                        (Some(Token::Space(_)), Token::Word(_))
170                            | (
171                                Some(Token::Space(_) | Token::Word(_) | Token::Punctuation(_)),
172                                Token::Punctuation(_),
173                            )
174                    )
175                };
176
177                if let Some(last_group) = groups.last_mut() {
178                    if should_append_to_last_group(last_group, &token) {
179                        last_group.push(token);
180                        return groups;
181                    }
182                }
183
184                groups.push(vec![token]);
185                groups
186            });
187
188        // Concatenate groups
189        groups
190            .into_iter()
191            .map(|group| group.into_iter().map(Token::inner).collect())
192            .collect()
193    }
194}
195
196impl Splitter for HATSplitter {
197    fn split(&self, input: &str) -> Vec<String> {
198        Self::parse(Self::lex(input))
199    }
200
201    fn split_with_limit(&self, input: &str, max_bytes: usize) -> Vec<Vec<u8>> {
202        Self::split_long_words(Self::parse(Self::lex(input)), max_bytes)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let parts = HATSplitter::new().split("Hello, world!");
        assert_eq!(parts, vec!["Hello,", " world!"]);
    }

    #[test]
    fn it_handles_empty_input() {
        assert!(HATSplitter::new().split("").is_empty());
    }

    #[test]
    fn it_splits_camel_case() {
        assert_eq!(
            HATSplitter::new().split("howAreYou"),
            vec!["how", "Are", "You"]
        );
    }

    #[test]
    fn it_splits_snake_case() {
        assert_eq!(
            HATSplitter::new().split("how_are_you"),
            vec!["how_", "are_", "you"]
        );
    }

    #[test]
    fn it_limits_word_size() {
        let parts = HATSplitter::new().split_with_limit("verylongword", 10);
        assert_eq!(parts, vec![b"verylongwo".to_vec(), b"rd".to_vec()]);
    }

    #[test]
    fn it_splits_large_unicode_characters() {
        // A 4-byte emoji forced into 2-byte chunks has to be cut mid-character.
        let parts = HATSplitter::new().split_with_limit("🌝", 2);
        assert_eq!(parts.len(), 2);
    }

    #[test]
    fn it_does_not_split_unicode_where_possible() {
        // This is one word with a 2-byte 'ü' starting at byte offset 1. We hope that the
        // splitter preserves this character by splitting into three parts instead of two.
        let parts = HATSplitter::new().split_with_limit("für", 2);
        assert_eq!(
            parts,
            vec![b"f".to_vec(), "ü".as_bytes().to_vec(), b"r".to_vec()]
        );
    }

    #[test]
    #[should_panic]
    fn it_handles_zero_max_bytes() {
        HATSplitter::new().split_with_limit("abc", 0);
    }
}