/// Kind of lexical token an inserted pattern represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenType {
    /// Line-comment opener (e.g. `//`, `#`); no closing pattern.
    LineComment,
    /// Block-comment opener (e.g. `/*`); its close sequence is carried in `TokenMatch::close`.
    BlockCommentStart,
    /// Single- or double-quote string delimiter; closes with the same byte.
    StringDelimiter,
    /// Triple-quoted docstring delimiter (`"""` / `'''`); closes with the same sequence.
    DocStringDelimiter,
}
15
/// Result of a successful trie lookup: what matched and how to continue scanning.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenMatch {
    /// The kind of token that was recognized at the match position.
    pub token_type: TokenType,
    /// Closing byte sequence for paired tokens (block comments, strings,
    /// docstrings); `None` for line comments.
    pub close: Option<Vec<u8>>,
    /// Length in bytes of the matched opening pattern. Overwritten by
    /// `TokenTrie::insert` with the pattern length, so the value supplied at
    /// construction time is a placeholder.
    pub advance: usize,
}
29
/// One node of the byte trie: a 256-slot child table indexed directly by byte
/// value, plus an optional terminal match for patterns ending at this node.
struct TrieNode {
    // Always exactly 256 entries, so `children[byte as usize]` never panics.
    children: Vec<Option<Box<TrieNode>>>,
    // `Some` iff an inserted pattern ends exactly at this node.
    token_match: Option<TokenMatch>,
}
34
35impl TrieNode {
36 fn new() -> Self {
37 Self {
38 children: (0..256).map(|_| None).collect(),
39 token_match: None,
40 }
41 }
42}
43
/// Byte-level trie mapping token opening patterns (comment openers, string
/// delimiters) to their `TokenMatch` data, plus a one-byte prefilter mask.
pub struct TokenTrie {
    root: TrieNode,
    /// Bitwise OR of the first byte of every inserted pattern; consumed by
    /// `should_process` as a cheap, conservative "could this byte start a
    /// token?" filter.
    mask: u8,
}
51
52impl TokenTrie {
53 pub fn new() -> Self {
54 Self {
55 root: TrieNode::new(),
56 mask: 0,
57 }
58 }
59
60 pub fn insert(&mut self, pattern: &[u8], mut token_match: TokenMatch) {
61 if pattern.is_empty() {
62 return;
63 }
64 self.mask |= pattern[0];
65 token_match.advance = pattern.len();
66
67 let mut node = &mut self.root;
68 for &byte in pattern {
69 let idx = byte as usize;
70 if node.children[idx].is_none() {
71 node.children[idx] = Some(Box::new(TrieNode::new()));
72 }
73 node = node.children[idx].as_mut().unwrap();
74 }
75 node.token_match = Some(token_match);
76 }
77
78 pub fn match_at(&self, content: &[u8], pos: usize) -> Option<TokenMatch> {
80 let mut node = &self.root;
81 let mut last_match: Option<&TokenMatch> = None;
82
83 for &byte in &content[pos..] {
84 match &node.children[byte as usize] {
85 Some(child) => {
86 node = child;
87 if node.token_match.is_some() {
88 last_match = node.token_match.as_ref();
89 }
90 }
91 None => break,
92 }
93 }
94
95 last_match.cloned()
96 }
97
98 pub fn process_mask(&self) -> u8 {
99 self.mask
100 }
101}
102
// `Default` simply delegates to `new()`, giving an empty trie.
impl Default for TokenTrie {
    fn default() -> Self {
        Self::new()
    }
}
108
// Hand-written `Debug`: deriving it would recursively dump the 256-slot child
// tables, which is unreadable. Only the mask is shown; `finish_non_exhaustive`
// signals that fields were omitted.
impl std::fmt::Debug for TokenTrie {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenTrie")
            .field("mask", &self.mask)
            .finish_non_exhaustive()
    }
}
116
117pub fn build_from_language(lang: &crate::language::Language) -> (TokenTrie, u8) {
123 let mut trie = TokenTrie::new();
124
125 for lc in &lang.line_comments {
127 trie.insert(
128 lc.as_bytes(),
129 TokenMatch {
130 token_type: TokenType::LineComment,
131 close: None,
132 advance: 0,
133 },
134 );
135 }
136
137 for (open, close) in &lang.block_comments {
139 trie.insert(
140 open.as_bytes(),
141 TokenMatch {
142 token_type: TokenType::BlockCommentStart,
143 close: Some(close.as_bytes().to_vec()),
144 advance: 0,
145 },
146 );
147 }
148
149 trie.insert(
151 b"\"",
152 TokenMatch {
153 token_type: TokenType::StringDelimiter,
154 close: Some(b"\"".to_vec()),
155 advance: 0,
156 },
157 );
158 trie.insert(
159 b"'",
160 TokenMatch {
161 token_type: TokenType::StringDelimiter,
162 close: Some(b"'".to_vec()),
163 advance: 0,
164 },
165 );
166
167 let name = lang.name.as_str();
169 if name == "Python" || name == "Ruby" {
170 trie.insert(
171 b"\"\"\"",
172 TokenMatch {
173 token_type: TokenType::DocStringDelimiter,
174 close: Some(b"\"\"\"".to_vec()),
175 advance: 0,
176 },
177 );
178 trie.insert(
179 b"'''",
180 TokenMatch {
181 token_type: TokenType::DocStringDelimiter,
182 close: Some(b"'''".to_vec()),
183 advance: 0,
184 },
185 );
186 }
187
188 let mask = trie.process_mask();
189 (trie, mask)
190}
191
/// Conservative prefilter: returns `true` iff every bit set in `byte` is also
/// set in `mask` (equivalent to `byte & mask == byte`).
///
/// With `mask` built by `TokenTrie::insert`, a byte that starts some inserted
/// pattern always passes; bytes that cannot start any pattern usually fail,
/// though false positives are possible (the trie lookup then rejects them).
#[inline(always)]
pub fn should_process(byte: u8, mask: u8) -> bool {
    byte & !mask == 0
}
202
#[cfg(test)]
mod tests {
    use super::*;

    // An empty trie contains no patterns, so every lookup misses.
    #[test]
    fn test_empty_trie_matches_nothing() {
        let trie = TokenTrie::new();
        assert_eq!(trie.match_at(b"hello", 0), None);
    }

    // A two-byte opener matches and reports its length via `advance`;
    // line comments carry no closing pattern.
    #[test]
    fn test_single_line_comment() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"//",
            TokenMatch {
                token_type: TokenType::LineComment,
                close: None,
                advance: 0,
            },
        );
        let m = trie.match_at(b"// comment", 0).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);
        assert_eq!(m.advance, 2);
        assert!(m.close.is_none());
    }

    // Block-comment matches carry the close sequence supplied at insert time.
    #[test]
    fn test_block_comment() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"/*",
            TokenMatch {
                token_type: TokenType::BlockCommentStart,
                close: Some(b"*/".to_vec()),
                advance: 0,
            },
        );
        let m = trie.match_at(b"/* block */", 0).unwrap();
        assert_eq!(m.token_type, TokenType::BlockCommentStart);
        assert_eq!(m.advance, 2);
        assert_eq!(m.close.as_deref(), Some(b"*/".as_slice()));
    }

    // Matching is anchored at `pos`: the same content misses at 0 but hits at 2.
    #[test]
    fn test_no_match_at_wrong_position() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"//",
            TokenMatch {
                token_type: TokenType::LineComment,
                close: None,
                advance: 0,
            },
        );
        assert_eq!(trie.match_at(b"x // y", 0), None);
        let m = trie.match_at(b"x // y", 2).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);
    }

    // A single-byte quote delimiter round-trips its close sequence.
    #[test]
    fn test_string_delimiter() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"\"",
            TokenMatch {
                token_type: TokenType::StringDelimiter,
                close: Some(b"\"".to_vec()),
                advance: 0,
            },
        );
        let m = trie.match_at(b"\"hello\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::StringDelimiter);
        assert_eq!(m.close.as_deref(), Some(b"\"".as_slice()));
    }

    // The mask accepts first bytes of inserted patterns and rejects a byte
    // whose bits are not covered (the filter is conservative, not exact).
    #[test]
    fn test_process_mask_filters_correctly() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"//",
            TokenMatch {
                token_type: TokenType::LineComment,
                close: None,
                advance: 0,
            },
        );
        trie.insert(
            b"\"",
            TokenMatch {
                token_type: TokenType::StringDelimiter,
                close: Some(b"\"".to_vec()),
                advance: 0,
            },
        );
        let mask = trie.process_mask();
        assert!(should_process(b'/', mask));
        assert!(should_process(b'"', mask));
        assert!(!should_process(b'a', mask));
    }

    // When `"` and `"""` are both present, the deepest (longest) match wins.
    #[test]
    fn test_longer_match_wins() {
        let mut trie = TokenTrie::new();
        trie.insert(
            b"\"",
            TokenMatch {
                token_type: TokenType::StringDelimiter,
                close: Some(b"\"".to_vec()),
                advance: 0,
            },
        );
        trie.insert(
            b"\"\"\"",
            TokenMatch {
                token_type: TokenType::DocStringDelimiter,
                close: Some(b"\"\"\"".to_vec()),
                advance: 0,
            },
        );
        let m = trie.match_at(b"\"\"\"hello\"\"\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::DocStringDelimiter);
        assert_eq!(m.advance, 3);
    }

    // End-to-end build for a Rust-like language: line comments, block
    // comments with close, string delimiters, and a non-zero mask.
    #[test]
    fn test_build_from_rust_language() {
        use crate::language::Language;

        let lang = Language {
            name: "Rust".to_string(),
            extensions: vec![".rs".to_string()],
            line_comments: vec!["//".to_string()],
            block_comments: vec![("/*".to_string(), "*/".to_string())],
            nested_comments: true,
            ..Default::default()
        };
        let (trie, mask) = build_from_language(&lang);

        let m = trie.match_at(b"// comment", 0).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);

        let m = trie.match_at(b"/* block */", 0).unwrap();
        assert_eq!(m.token_type, TokenType::BlockCommentStart);
        assert_eq!(m.close.as_deref(), Some(b"*/".as_slice()));

        let m = trie.match_at(b"\"hello\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::StringDelimiter);

        assert_ne!(mask, 0);
    }

    // Python gets the extra triple-quote docstring delimiters in addition to
    // its `#` line comments.
    #[test]
    fn test_build_from_python_language() {
        use crate::language::Language;

        let lang = Language {
            name: "Python".to_string(),
            extensions: vec![".py".to_string()],
            line_comments: vec!["#".to_string()],
            ..Default::default()
        };
        let (trie, _mask) = build_from_language(&lang);

        let m = trie.match_at(b"# comment", 0).unwrap();
        assert_eq!(m.token_type, TokenType::LineComment);

        let m = trie.match_at(b"\"\"\"docstring\"\"\"", 0).unwrap();
        assert_eq!(m.token_type, TokenType::DocStringDelimiter);
        assert_eq!(m.close.as_deref(), Some(b"\"\"\"".as_slice()));

        let m = trie.match_at(b"'''docstring'''", 0).unwrap();
        assert_eq!(m.token_type, TokenType::DocStringDelimiter);
        assert_eq!(m.close.as_deref(), Some(b"'''".as_slice()));
    }
}