1use std::collections::HashSet;
4use std::sync::OnceLock;
5
6use regex::Regex;
7
/// Coarse structural category assigned to each token produced by
/// `structural_tokenize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TokenKind {
    /// A word that appears in the keyword set of the active language.
    Keyword,
    /// A word (ASCII letters, digits, `_`) that is not in the keyword set.
    Identifier,
    /// A recognised two-character operator, or any other single symbol that
    /// no other rule claimed.
    Operator,
    /// A quoted string (single or double quotes), a Rust raw string, or a number.
    Literal,
    /// A multi-token idiom matched as a unit, e.g. `pub fn` or `if err != nil`.
    Pattern,
    /// Whitespace runs and `//` or `/* ... */` comments.
    Noise,
}
17
/// One token emitted by `structural_tokenize`, together with its weight.
#[derive(Debug, Clone, PartialEq)]
pub struct StructuralToken {
    /// Structural category of the token.
    pub kind: TokenKind,
    /// Source text covered by the token (for patterns, the matched idiom).
    pub text: String,
    /// Weight attached to the token; the tokenizer sets it from the `W_*`
    /// constant that corresponds to `kind`.
    pub weight: f64,
}
24
// Per-kind token weights, from most to least significant.

/// Weight of a `TokenKind::Pattern` token (multi-token idiom).
const W_PATTERN: f64 = 3.0;
/// Weight of a `TokenKind::Keyword` token.
const W_KEYWORD: f64 = 2.0;
/// Weight of a `TokenKind::Literal` token (strings and numbers).
const W_LITERAL: f64 = 1.5;
/// Weight of a `TokenKind::Identifier` token.
const W_IDENTIFIER: f64 = 1.0;
/// Weight of a `TokenKind::Operator` token.
const W_OPERATOR: f64 = 0.8;
/// Weight of a `TokenKind::Noise` token (whitespace and comments).
const W_NOISE: f64 = 0.15;
31
32fn for_in_rust_re() -> &'static Regex {
33 static CELL: OnceLock<Regex> = OnceLock::new();
34 CELL.get_or_init(|| Regex::new(r"for\s+[a-zA-Z_][a-zA-Z0-9_]*\s+in\s+").expect("for-in regex"))
35}
36
/// Returns the keyword set used to classify words for `lang`.
///
/// `"rust"` and `"rs"` select the Rust set, `"go"` selects the Go set, and any
/// other value selects a small generic set. Each set is built once on first
/// request and shared for the rest of the program.
fn keywords_for(lang: &str) -> &'static HashSet<&'static str> {
    const RUST_WORDS: &[&str] = &[
        "pub", "fn", "let", "mut", "struct", "enum", "impl", "trait", "use", "mod", "crate",
        "super", "self", "where", "type", "const", "static", "async", "await", "match", "if",
        "else", "for", "while", "loop", "break", "continue", "return", "unsafe", "move", "ref",
        "dyn", "extern", "in", "as",
    ];
    const GO_WORDS: &[&str] = &[
        "func", "package", "import", "var", "const", "type", "struct", "interface", "map",
        "chan", "defer", "go", "select", "switch", "case", "default", "if", "else", "for",
        "range", "return", "break", "continue", "fallthrough", "nil", "make", "new", "len",
        "cap",
    ];
    const GENERIC_WORDS: &[&str] = &[
        "if", "else", "for", "while", "return", "fn", "func", "let", "var", "const", "pub",
        "import", "class", "def",
    ];

    static RUST_SET: OnceLock<HashSet<&str>> = OnceLock::new();
    static GO_SET: OnceLock<HashSet<&str>> = OnceLock::new();
    static GENERIC_SET: OnceLock<HashSet<&str>> = OnceLock::new();

    // Pick the cache cell together with the word list that fills it.
    let (cell, words) = match lang {
        "rust" | "rs" => (&RUST_SET, RUST_WORDS),
        "go" => (&GO_SET, GO_WORDS),
        _ => (&GENERIC_SET, GENERIC_WORDS),
    };

    cell.get_or_init(|| words.iter().copied().collect())
}
92
93fn try_pattern(rest: &str, lang: &str) -> Option<(usize, String)> {
94 let ascii_patterns: &[(&str, &[&str])] = &[
95 ("if err != nil", &["go"]),
96 ("pub async fn", &["rust", "rs"]),
97 ("async fn", &["rust", "rs"]),
98 ("pub fn", &["rust", "rs"]),
99 ("fn main()", &["rust", "rs", "generic", ""]),
100 ("match ", &["rust", "rs"]),
101 ];
102
103 for (pat, langs) in ascii_patterns {
104 if !langs.iter().any(|&l| l == lang || l.is_empty()) {
105 continue;
106 }
107 if rest.starts_with(pat) {
108 return Some((pat.len(), (*pat).to_string()));
109 }
110 }
111
112 if lang == "rust" || lang == "rs" {
113 if let Some(m) = for_in_rust_re().find(rest) {
114 if m.start() == 0 {
115 return Some((m.end(), m.as_str().to_string()));
116 }
117 }
118 }
119
120 None
121}
122
/// Returns the index of the first `\n` at or after `i`, or `bytes.len()` when
/// the rest of the input contains no newline.
///
/// The newline itself is not consumed. An `i` past the end is returned as is.
fn skip_line_comment(bytes: &[u8], i: usize) -> usize {
    match bytes.get(i..) {
        Some(tail) => tail
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |offset| i + offset),
        None => i,
    }
}
129
/// Skips a `/* ... */` comment that starts exactly at `i`.
///
/// Returns `None` when `bytes[i..]` does not begin with `/*`. Otherwise
/// returns the index just past the first `*/` found after the opener, or
/// `bytes.len()` when the comment is never closed. Nested comments are not
/// tracked.
fn skip_block_comment(bytes: &[u8], i: usize) -> Option<usize> {
    if !matches!(bytes.get(i..i + 2), Some([b'/', b'*'])) {
        return None;
    }

    // Search begins after the opener so that `/*/` is not treated as closed.
    let body_start = i + 2;
    let end = bytes[body_start..]
        .windows(2)
        .position(|pair| matches!(pair, [b'*', b'/']))
        .map_or(bytes.len(), |offset| body_start + offset + 2);
    Some(end)
}
143
/// Scans a quoted literal whose opening quote is at index `i`.
///
/// Returns the index just past the closing `quote`, or `bytes.len()` when the
/// literal is unterminated. A backslash skips the byte that follows it, so an
/// escaped quote does not end the literal.
fn scan_string(bytes: &[u8], quote: u8, i: usize) -> usize {
    let mut pos = i + 1; // step over the opening quote
    while let Some(&byte) = bytes.get(pos) {
        match byte {
            // Escape sequence: skip the backslash and the escaped byte.
            b'\\' if pos + 1 < bytes.len() => pos += 2,
            b if b == quote => return pos + 1,
            _ => pos += 1,
        }
    }
    bytes.len()
}
159
/// Scans a Rust raw string literal (`r"..."`, `r#"..."#`, ...) starting at `i`.
///
/// Returns the index just past the closing quote and its hashes, or
/// `bytes.len()` when the literal is never closed. Returns `i` unchanged when
/// the input at `i` is not the start of a raw string (for example `r#type`).
fn scan_raw_string(bytes: &[u8], i: usize) -> usize {
    if bytes.get(i) != Some(&b'r') || i + 1 >= bytes.len() {
        return i;
    }

    // Count the `#`s between `r` and the opening quote.
    let hashes = bytes[i + 1..].iter().take_while(|&&b| b == b'#').count();
    let open_quote = i + 1 + hashes;
    if bytes.get(open_quote) != Some(&b'"') {
        return i;
    }

    // A quote closes the literal only when followed by the same number of `#`s.
    let closes_at = |quote: usize| {
        bytes
            .get(quote + 1..quote + 1 + hashes)
            .is_some_and(|tail| tail.iter().all(|&b| b == b'#'))
    };

    (open_quote + 1..bytes.len())
        .find(|&j| bytes[j] == b'"' && closes_at(j))
        .map_or(bytes.len(), |j| j + 1 + hashes)
}
196
/// Scans a numeric literal starting at `i` and returns the index just past it.
///
/// Recognised forms:
/// - hexadecimal: `0x` / `0X` followed by hex digits and `_` separators;
/// - decimal: digits and `_` separators, optionally followed by one fraction
///   (`.` plus digits) and one exponent (`e`/`E`, optional sign, digits).
///
/// A `.` is consumed only when a digit follows it, so range expressions
/// (`0..10`) and method calls (`1.max(2)`) keep their dots as separate tokens.
/// An `e`/`E` is consumed only when it introduces at least one exponent digit.
/// At least one byte is always consumed.
fn scan_number(bytes: &[u8], mut i: usize) -> usize {
    let start = i;
    let is_digit_or_sep = |b: &u8| b.is_ascii_digit() || *b == b'_';

    if bytes.get(i) == Some(&b'0') && matches!(bytes.get(i + 1), Some(b'x' | b'X')) {
        i += 2;
        while bytes
            .get(i)
            .is_some_and(|b| b.is_ascii_hexdigit() || *b == b'_')
        {
            i += 1;
        }
        return i;
    }

    // Integer part.
    while bytes.get(i).is_some_and(is_digit_or_sep) {
        i += 1;
    }

    // Fraction: a single `.` that is directly followed by a digit.
    if bytes.get(i) == Some(&b'.') && bytes.get(i + 1).is_some_and(|b| b.is_ascii_digit()) {
        i += 1;
        while bytes.get(i).is_some_and(is_digit_or_sep) {
            i += 1;
        }
    }

    // Exponent: commit only if at least one digit follows the marker/sign.
    if matches!(bytes.get(i), Some(b'e' | b'E')) {
        let mut j = i + 1;
        if matches!(bytes.get(j), Some(b'+' | b'-')) {
            j += 1;
        }
        if bytes.get(j).is_some_and(|b| b.is_ascii_digit()) {
            while bytes.get(j).is_some_and(|b| b.is_ascii_digit()) {
                j += 1;
            }
            i = j;
        }
    }

    i.max(start + 1)
}
220
/// Scans a run of ASCII letters, digits and `_` starting at `i` and returns
/// the index just past it.
///
/// At least one byte is always consumed, so the result is never less than
/// `i + 1`, even when the byte at `i` is not a word character.
fn scan_identifier(bytes: &[u8], i: usize) -> usize {
    let run = bytes.get(i..).map_or(0, |tail| {
        tail.iter()
            .take_while(|&&b| b.is_ascii_alphanumeric() || b == b'_')
            .count()
    });
    i + run.max(1)
}
228
229fn push_op(out: &mut Vec<StructuralToken>, text: &str) {
230 out.push(StructuralToken {
231 kind: TokenKind::Operator,
232 text: text.to_string(),
233 weight: W_OPERATOR,
234 });
235}
236
237pub fn structural_tokenize(code: &str, lang: &str) -> Vec<StructuralToken> {
239 let lang_lower = lang.to_lowercase();
240 let lang_k = match lang_lower.as_str() {
241 "rust" | "rs" => "rust",
242 "go" | "golang" => "go",
243 _ => "generic",
244 };
245
246 let kw = keywords_for(lang_k);
247 let bytes = code.as_bytes();
248 let mut i = 0usize;
249 let mut out = Vec::new();
250
251 while i < bytes.len() {
252 if bytes[i].is_ascii_whitespace() {
253 let start = i;
254 while i < bytes.len() && bytes[i].is_ascii_whitespace() {
255 i += 1;
256 }
257 if start != i {
258 out.push(StructuralToken {
259 kind: TokenKind::Noise,
260 text: code[start..i].to_string(),
261 weight: W_NOISE,
262 });
263 }
264 continue;
265 }
266
267 let rest = &code[i..];
268 if let Some((len, text)) = try_pattern(rest, lang_k) {
269 out.push(StructuralToken {
270 kind: TokenKind::Pattern,
271 text,
272 weight: W_PATTERN,
273 });
274 i += len;
275 continue;
276 }
277
278 if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'/') {
279 let start = i;
280 i = skip_line_comment(bytes, i);
281 out.push(StructuralToken {
282 kind: TokenKind::Noise,
283 text: code[start..i].to_string(),
284 weight: W_NOISE,
285 });
286 continue;
287 }
288
289 if let Some(next) = skip_block_comment(bytes, i) {
290 let start = i;
291 i = next;
292 out.push(StructuralToken {
293 kind: TokenKind::Noise,
294 text: code[start..i].to_string(),
295 weight: W_NOISE,
296 });
297 continue;
298 }
299
300 if lang_k == "rust"
301 && bytes[i] == b'r'
302 && (bytes.get(i + 1) == Some(&b'#') || bytes.get(i + 1) == Some(&b'"'))
303 {
304 let start = i;
305 i = scan_raw_string(bytes, i);
306 out.push(StructuralToken {
307 kind: TokenKind::Literal,
308 text: code[start..i].to_string(),
309 weight: W_LITERAL,
310 });
311 continue;
312 }
313
314 if bytes[i] == b'"' || bytes[i] == b'\'' {
315 let quote = bytes[i];
316 let start = i;
317 i = scan_string(bytes, quote, i);
318 out.push(StructuralToken {
319 kind: TokenKind::Literal,
320 text: code[start..i].to_string(),
321 weight: W_LITERAL,
322 });
323 continue;
324 }
325
326 if bytes[i].is_ascii_digit() {
327 let start = i;
328 i = scan_number(bytes, i);
329 out.push(StructuralToken {
330 kind: TokenKind::Literal,
331 text: code[start..i].to_string(),
332 weight: W_LITERAL,
333 });
334 continue;
335 }
336
337 if bytes[i].is_ascii_alphabetic() || bytes[i] == b'_' {
338 let start = i;
339 i = scan_identifier(bytes, i);
340 let word = &code[start..i];
341 let kind = if kw.contains(word) {
342 TokenKind::Keyword
343 } else {
344 TokenKind::Identifier
345 };
346 let weight = if kind == TokenKind::Keyword {
347 W_KEYWORD
348 } else {
349 W_IDENTIFIER
350 };
351 out.push(StructuralToken {
352 kind,
353 text: word.to_string(),
354 weight,
355 });
356 continue;
357 }
358
359 let two = i + 1 < bytes.len();
360 if two {
361 let pair = [bytes[i], bytes[i + 1]];
362 let s = std::str::from_utf8(&pair).unwrap_or("??");
363 match pair {
364 [b'!' | b'=' | b'<' | b'>' | b'+' | b'-', b'=']
365 | [b'-' | b'=', b'>']
366 | [b':', b':']
367 | [b'&', b'&']
368 | [b'|', b'|'] => {
369 push_op(&mut out, s);
370 i += 2;
371 continue;
372 }
373 _ => {}
374 }
375 }
376
377 let ch = bytes[i] as char;
378 push_op(&mut out, &ch.to_string());
379 i += 1;
380 }
381
382 out
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// True when some token has the given kind and a text accepted by `text_ok`.
    fn has_token(
        toks: &[StructuralToken],
        kind: TokenKind,
        text_ok: impl Fn(&str) -> bool,
    ) -> bool {
        toks.iter().any(|t| t.kind == kind && text_ok(&t.text))
    }

    /// First token of the given kind, if any.
    fn first_of(toks: &[StructuralToken], kind: TokenKind) -> Option<&StructuralToken> {
        toks.iter().find(|t| t.kind == kind)
    }

    // `pub fn` at the start of Rust input is a single pattern token.
    #[test]
    fn rust_pub_fn_pattern() {
        let toks = structural_tokenize("pub fn foo() {}", "rust");
        let head = &toks[0];
        assert_eq!(head.kind, TokenKind::Pattern);
        assert_eq!(head.text, "pub fn");
        assert_eq!(head.weight, W_PATTERN);
    }

    // The longer `pub async fn` idiom is recognised as a pattern.
    #[test]
    fn rust_async_fn_pattern() {
        let toks = structural_tokenize("pub async fn bar() {}", "rust");
        let found = has_token(&toks, TokenKind::Pattern, |s| s.starts_with("pub async fn"));
        assert!(found, "{toks:?}");
    }

    // `match ` (with its trailing space) is a pattern in Rust.
    #[test]
    fn rust_match_pattern_prefix() {
        let toks = structural_tokenize("match x {", "rust");
        let head = &toks[0];
        assert_eq!(head.kind, TokenKind::Pattern);
        assert_eq!(head.text, "match ");
    }

    // The head of a `for ... in` loop is a pattern in Rust.
    #[test]
    fn rust_for_in_loop_pattern() {
        let src = "for item in items.iter() {";
        let toks = structural_tokenize(src, "rust");
        assert!(has_token(&toks, TokenKind::Pattern, |s| s.starts_with("for ")));
    }

    // Go's error check is a pattern carrying the pattern weight.
    #[test]
    fn go_err_nil_pattern() {
        let toks = structural_tokenize("if err != nil { return err }", "go");
        assert!(has_token(&toks, TokenKind::Pattern, |s| s.contains("err")));

        let pat = first_of(&toks, TokenKind::Pattern).expect("pattern");
        assert_eq!(pat.text, "if err != nil");
        assert_eq!(pat.weight, W_PATTERN);
    }

    // Patterns outweigh both identifiers and keywords.
    #[test]
    fn weights_pattern_above_identifier() {
        let toks = structural_tokenize("pub fn main() {}", "rust");
        let p = first_of(&toks, TokenKind::Pattern).unwrap();
        let id = toks
            .iter()
            .find(|t| t.kind == TokenKind::Identifier && t.text == "main")
            .unwrap();
        assert!(p.weight > id.weight);
        assert!(p.weight > W_KEYWORD);
    }

    // A line comment is noise and does not hide the code on the next line.
    #[test]
    fn comment_is_noise() {
        let toks = structural_tokenize("// hello\nlet x = 1;", "rust");
        assert!(has_token(&toks, TokenKind::Noise, |s| s.starts_with("//")));
        assert!(has_token(&toks, TokenKind::Keyword, |s| s == "let"));
    }

    // A double-quoted string is a literal with the literal weight.
    #[test]
    fn string_literal_kind() {
        let toks = structural_tokenize(r#"let s = "ab";"#, "rust");
        let lit = toks
            .iter()
            .find(|t| t.kind == TokenKind::Literal && t.text.starts_with('"'));
        assert!(lit.is_some());
        assert_eq!(lit.unwrap().weight, W_LITERAL);
    }
}