Skip to main content

lean_ctx/core/
structural_tokenizer.rs

1//! Structural tokenizer treating idiomatic multi-token spans as single motifs.
2
3use std::collections::HashSet;
4use std::sync::OnceLock;
5
6use regex::Regex;
7
/// Coarse category assigned to every [`StructuralToken`].
///
/// `Hash` is derived alongside `Eq` so a kind can be used directly as a
/// `HashMap`/`HashSet` key (for example to tally tokens per kind).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TokenKind {
    /// A word contained in the keyword set of the active language.
    Keyword,
    /// Any other word made of ASCII letters, digits and underscores.
    Identifier,
    /// Punctuation: a recognised two-character operator or a single character.
    Operator,
    /// A quoted string, a Rust raw string, or a numeric literal.
    Literal,
    /// An idiomatic multi-token span (motif) such as `pub fn` or `if err != nil`.
    Pattern,
    /// Whitespace runs and comments.
    Noise,
}
17
/// One token produced by the structural tokenizer.
#[derive(Debug, Clone, PartialEq)]
pub struct StructuralToken {
    /// Category the span was classified as.
    pub kind: TokenKind,
    /// The source text covered by this token.
    pub text: String,
    /// Weight attached to the token; the tokenizer fills this from the `W_*`
    /// constant that corresponds to `kind`.
    pub weight: f64,
}
24
// Per-kind token weights, ordered from highest to lowest.
/// Weight given to `TokenKind::Pattern` tokens (idiomatic multi-token motifs).
const W_PATTERN: f64 = 3.0;
/// Weight given to `TokenKind::Keyword` tokens.
const W_KEYWORD: f64 = 2.0;
/// Weight given to `TokenKind::Literal` tokens (strings, raw strings, numbers).
const W_LITERAL: f64 = 1.5;
/// Weight given to `TokenKind::Identifier` tokens.
const W_IDENTIFIER: f64 = 1.0;
/// Weight given to `TokenKind::Operator` tokens.
const W_OPERATOR: f64 = 0.8;
/// Weight given to `TokenKind::Noise` tokens (whitespace and comments).
const W_NOISE: f64 = 0.15;
31
32fn for_in_rust_re() -> &'static Regex {
33    static CELL: OnceLock<Regex> = OnceLock::new();
34    CELL.get_or_init(|| Regex::new(r"for\s+[a-zA-Z_][a-zA-Z0-9_]*\s+in\s+").expect("for-in regex"))
35}
36
/// Return the keyword set used to classify words for `lang`.
///
/// `"rust"`/`"rs"` and `"go"` select their dedicated sets; every other value
/// falls back to a small language-agnostic set. Each set is built on first
/// use and then shared for the rest of the process.
fn keywords_for(lang: &str) -> &'static HashSet<&'static str> {
    const RUST_WORDS: &[&str] = &[
        "pub", "fn", "let", "mut", "struct", "enum", "impl", "trait", "use", "mod", "crate",
        "super", "self", "where", "type", "const", "static", "async", "await", "match", "if",
        "else", "for", "while", "loop", "break", "continue", "return", "unsafe", "move", "ref",
        "dyn", "extern", "in", "as",
    ];
    const GO_WORDS: &[&str] = &[
        "func", "package", "import", "var", "const", "type", "struct", "interface", "map",
        "chan", "defer", "go", "select", "switch", "case", "default", "if", "else", "for",
        "range", "return", "break", "continue", "fallthrough", "nil", "make", "new", "len",
        "cap",
    ];
    const GENERIC_WORDS: &[&str] = &[
        "if", "else", "for", "while", "return", "fn", "func", "let", "var", "const", "pub",
        "import", "class", "def",
    ];

    static RUST_SET: OnceLock<HashSet<&'static str>> = OnceLock::new();
    static GO_SET: OnceLock<HashSet<&'static str>> = OnceLock::new();
    static GENERIC_SET: OnceLock<HashSet<&'static str>> = OnceLock::new();

    // Pick the cache cell together with the word list that fills it.
    let (cell, words) = match lang {
        "rust" | "rs" => (&RUST_SET, RUST_WORDS),
        "go" => (&GO_SET, GO_WORDS),
        _ => (&GENERIC_SET, GENERIC_WORDS),
    };
    cell.get_or_init(|| words.iter().copied().collect())
}
92
93fn try_pattern(rest: &str, lang: &str) -> Option<(usize, String)> {
94    let ascii_patterns: &[(&str, &[&str])] = &[
95        ("if err != nil", &["go"]),
96        ("pub async fn", &["rust", "rs"]),
97        ("async fn", &["rust", "rs"]),
98        ("pub fn", &["rust", "rs"]),
99        ("fn main()", &["rust", "rs", "generic", ""]),
100        ("match ", &["rust", "rs"]),
101    ];
102
103    for (pat, langs) in ascii_patterns {
104        if !langs.iter().any(|&l| l == lang || l.is_empty()) {
105            continue;
106        }
107        if rest.starts_with(pat) {
108            return Some((pat.len(), (*pat).to_string()));
109        }
110    }
111
112    if lang == "rust" || lang == "rs" {
113        if let Some(m) = for_in_rust_re().find(rest) {
114            if m.start() == 0 {
115                return Some((m.end(), m.as_str().to_string()));
116            }
117        }
118    }
119
120    None
121}
122
/// Return the index of the first `\n` at or after `i`, or `bytes.len()` when
/// the line runs to the end of input. The newline itself is not consumed.
///
/// An `i` beyond the end of `bytes` is returned unchanged.
fn skip_line_comment(bytes: &[u8], i: usize) -> usize {
    match bytes.get(i..) {
        Some(tail) => tail
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |offset| i + offset),
        None => i,
    }
}
129
/// If a `/* … */` comment starts at `i`, return the index just past its
/// closing `*/`; an unterminated comment extends to `bytes.len()`.
///
/// Returns `None` when the bytes at `i` are not `/*`. Nested comments are not
/// tracked: the first `*/` closes the comment.
fn skip_block_comment(bytes: &[u8], i: usize) -> Option<usize> {
    let tail = bytes.get(i..)?;
    if !tail.starts_with(b"/*") {
        return None;
    }
    let body_start = i + 2;
    let end = bytes[body_start..]
        .windows(2)
        .position(|pair| pair == b"*/")
        .map_or(bytes.len(), |offset| body_start + offset + 2);
    Some(end)
}
143
/// Scan a quoted literal whose opening `quote` sits at index `i`.
///
/// Returns the index just past the matching closing quote, or `bytes.len()`
/// when the literal is never closed. A backslash skips the byte that follows
/// it, so an escaped quote does not end the literal.
fn scan_string(bytes: &[u8], quote: u8, i: usize) -> usize {
    let mut pos = i + 1;
    while let Some(&byte) = bytes.get(pos) {
        if byte == b'\\' && pos + 1 < bytes.len() {
            // Escape sequence: step over the backslash and the escaped byte.
            pos += 2;
        } else if byte == quote {
            return pos + 1;
        } else {
            pos += 1;
        }
    }
    bytes.len()
}
159
/// Scan a Rust raw string (`r"…"`, `r#"…"#`, …) whose `r` sits at index `i`.
///
/// Returns the index just past the closing delimiter (a `"` followed by as
/// many `#` as the opener used), or `bytes.len()` when the string is never
/// closed. When the bytes at `i` do not open a raw string at all — for
/// example `r#` not followed by `"` — `i` is returned unchanged, so callers
/// must treat a result equal to `i` as "no raw string here".
fn scan_raw_string(bytes: &[u8], i: usize) -> usize {
    if bytes.get(i) != Some(&b'r') || i + 1 >= bytes.len() {
        return i;
    }

    // Count the `#` run between the `r` and the opening quote.
    let after_r = i + 1;
    let hashes = bytes[after_r..]
        .iter()
        .take_while(|&&b| b == b'#')
        .count();
    let open_quote = after_r + hashes;
    if bytes.get(open_quote) != Some(&b'"') {
        return i;
    }

    // The first `"` that is followed by `hashes` `#` bytes closes the string.
    let body_start = open_quote + 1;
    (body_start..bytes.len())
        .filter(|&j| bytes[j] == b'"')
        .find(|&j| {
            bytes
                .get(j + 1..j + 1 + hashes)
                .is_some_and(|run| run.iter().all(|&b| b == b'#'))
        })
        .map_or(bytes.len(), |j| j + 1 + hashes)
}
196
/// Scan a numeric literal starting at index `i` and return the index just
/// past it. At least one byte is always consumed so callers make progress.
///
/// Recognised shapes:
/// * hexadecimal: `0x`/`0X` followed by hex digits and `_` separators;
/// * decimal: digits and `_` separators, with at most one `.`;
/// * an optional exponent `e`/`E`, optional sign, and at least one digit.
///
/// A `.` is only taken as part of the number when it is not followed by
/// another `.` (range operator, `0..10`) or by a letter or underscore
/// (method call or field access, `1.max(2)`). An `e`/`E` with no digits
/// after it is left for the next token (`2em` scans as `2`).
fn scan_number(bytes: &[u8], mut i: usize) -> usize {
    let start = i;
    let digit_at = |at: usize| bytes.get(at).is_some_and(u8::is_ascii_digit);

    if bytes.get(i) == Some(&b'0') && matches!(bytes.get(i + 1).copied(), Some(b'x' | b'X')) {
        i += 2;
        while bytes
            .get(i)
            .is_some_and(|b| b.is_ascii_hexdigit() || *b == b'_')
        {
            i += 1;
        }
        return i;
    }

    let mut seen_dot = false;
    while let Some(&byte) = bytes.get(i) {
        let belongs = match byte {
            b'0'..=b'9' | b'_' => true,
            b'.' if !seen_dot => {
                // Leave `..` and `.name` to the operator/identifier scanners.
                let next = bytes.get(i + 1).copied();
                let fractional =
                    !matches!(next, Some(b'.' | b'_' | b'a'..=b'z' | b'A'..=b'Z'));
                seen_dot = fractional;
                fractional
            }
            _ => false,
        };
        if !belongs {
            break;
        }
        i += 1;
    }

    if matches!(bytes.get(i).copied(), Some(b'e' | b'E')) {
        let mut j = i + 1;
        if matches!(bytes.get(j).copied(), Some(b'+' | b'-')) {
            j += 1;
        }
        // Commit to the exponent only when it actually has digits.
        if digit_at(j) {
            while digit_at(j) {
                j += 1;
            }
            i = j;
        }
    }

    i.max(start + 1)
}
220
/// Scan a run of ASCII letters, digits and underscores starting at index `i`
/// and return the index just past it.
///
/// At least one byte is always consumed, even when the byte at `i` is not an
/// identifier character, so callers always make progress.
fn scan_identifier(bytes: &[u8], i: usize) -> usize {
    let run = bytes.get(i..).map_or(0, |tail| {
        tail.iter()
            .take_while(|&&b| b.is_ascii_alphanumeric() || b == b'_')
            .count()
    });
    i + run.max(1)
}
228
229fn push_op(out: &mut Vec<StructuralToken>, text: &str) {
230    out.push(StructuralToken {
231        kind: TokenKind::Operator,
232        text: text.to_string(),
233        weight: W_OPERATOR,
234    });
235}
236
237/// Tokenize source into weighted structural tokens (motifs, keywords, literals, …).
238pub fn structural_tokenize(code: &str, lang: &str) -> Vec<StructuralToken> {
239    let lang_lower = lang.to_lowercase();
240    let lang_k = match lang_lower.as_str() {
241        "rust" | "rs" => "rust",
242        "go" | "golang" => "go",
243        _ => "generic",
244    };
245
246    let kw = keywords_for(lang_k);
247    let bytes = code.as_bytes();
248    let mut i = 0usize;
249    let mut out = Vec::new();
250
251    while i < bytes.len() {
252        if bytes[i].is_ascii_whitespace() {
253            let start = i;
254            while i < bytes.len() && bytes[i].is_ascii_whitespace() {
255                i += 1;
256            }
257            if start != i {
258                out.push(StructuralToken {
259                    kind: TokenKind::Noise,
260                    text: code[start..i].to_string(),
261                    weight: W_NOISE,
262                });
263            }
264            continue;
265        }
266
267        let rest = &code[i..];
268        if let Some((len, text)) = try_pattern(rest, lang_k) {
269            out.push(StructuralToken {
270                kind: TokenKind::Pattern,
271                text,
272                weight: W_PATTERN,
273            });
274            i += len;
275            continue;
276        }
277
278        if bytes[i] == b'/' && bytes.get(i + 1) == Some(&b'/') {
279            let start = i;
280            i = skip_line_comment(bytes, i);
281            out.push(StructuralToken {
282                kind: TokenKind::Noise,
283                text: code[start..i].to_string(),
284                weight: W_NOISE,
285            });
286            continue;
287        }
288
289        if let Some(next) = skip_block_comment(bytes, i) {
290            let start = i;
291            i = next;
292            out.push(StructuralToken {
293                kind: TokenKind::Noise,
294                text: code[start..i].to_string(),
295                weight: W_NOISE,
296            });
297            continue;
298        }
299
300        if lang_k == "rust"
301            && bytes[i] == b'r'
302            && (bytes.get(i + 1) == Some(&b'#') || bytes.get(i + 1) == Some(&b'"'))
303        {
304            let start = i;
305            i = scan_raw_string(bytes, i);
306            out.push(StructuralToken {
307                kind: TokenKind::Literal,
308                text: code[start..i].to_string(),
309                weight: W_LITERAL,
310            });
311            continue;
312        }
313
314        if bytes[i] == b'"' || bytes[i] == b'\'' {
315            let quote = bytes[i];
316            let start = i;
317            i = scan_string(bytes, quote, i);
318            out.push(StructuralToken {
319                kind: TokenKind::Literal,
320                text: code[start..i].to_string(),
321                weight: W_LITERAL,
322            });
323            continue;
324        }
325
326        if bytes[i].is_ascii_digit() {
327            let start = i;
328            i = scan_number(bytes, i);
329            out.push(StructuralToken {
330                kind: TokenKind::Literal,
331                text: code[start..i].to_string(),
332                weight: W_LITERAL,
333            });
334            continue;
335        }
336
337        if bytes[i].is_ascii_alphabetic() || bytes[i] == b'_' {
338            let start = i;
339            i = scan_identifier(bytes, i);
340            let word = &code[start..i];
341            let kind = if kw.contains(word) {
342                TokenKind::Keyword
343            } else {
344                TokenKind::Identifier
345            };
346            let weight = if kind == TokenKind::Keyword {
347                W_KEYWORD
348            } else {
349                W_IDENTIFIER
350            };
351            out.push(StructuralToken {
352                kind,
353                text: word.to_string(),
354                weight,
355            });
356            continue;
357        }
358
359        let two = i + 1 < bytes.len();
360        if two {
361            let pair = [bytes[i], bytes[i + 1]];
362            let s = std::str::from_utf8(&pair).unwrap_or("??");
363            match pair {
364                [b'!' | b'=' | b'<' | b'>' | b'+' | b'-', b'=']
365                | [b'-' | b'=', b'>']
366                | [b':', b':']
367                | [b'&', b'&']
368                | [b'|', b'|'] => {
369                    push_op(&mut out, s);
370                    i += 2;
371                    continue;
372                }
373                _ => {}
374            }
375        }
376
377        let ch = bytes[i] as char;
378        push_op(&mut out, &ch.to_string());
379        i += 1;
380    }
381
382    out
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// First motif token in `toks`, if any.
    fn first_pattern(toks: &[StructuralToken]) -> Option<&StructuralToken> {
        toks.iter().find(|t| t.kind == TokenKind::Pattern)
    }

    /// Whether any token of `kind` has text satisfying `pred`.
    fn has_token(
        toks: &[StructuralToken],
        kind: TokenKind,
        pred: impl Fn(&str) -> bool,
    ) -> bool {
        toks.iter().any(|t| t.kind == kind && pred(&t.text))
    }

    // `pub fn` at the start of the input is a single motif token.
    #[test]
    fn rust_pub_fn_pattern() {
        let toks = structural_tokenize("pub fn foo() {}", "rust");
        let head = &toks[0];
        assert_eq!(head.kind, TokenKind::Pattern);
        assert_eq!(head.text, "pub fn");
        assert_eq!(head.weight, W_PATTERN);
    }

    // The longer `pub async fn` motif is recognised as a whole.
    #[test]
    fn rust_async_fn_pattern() {
        let toks = structural_tokenize("pub async fn bar() {}", "rust");
        let found = has_token(&toks, TokenKind::Pattern, |text| {
            text.starts_with("pub async fn")
        });
        assert!(found, "{toks:?}");
    }

    // The `match ` motif includes its trailing space.
    #[test]
    fn rust_match_pattern_prefix() {
        let toks = structural_tokenize("match x {", "rust");
        let head = &toks[0];
        assert_eq!(head.kind, TokenKind::Pattern);
        assert_eq!(head.text, "match ");
    }

    // A `for <ident> in ` loop head is a motif.
    #[test]
    fn rust_for_in_loop_pattern() {
        let src = "for item in items.iter() {";
        let toks = structural_tokenize(src, "rust");
        assert!(has_token(&toks, TokenKind::Pattern, |text| {
            text.starts_with("for ")
        }));
    }

    // Go's error check is a motif carrying the pattern weight.
    #[test]
    fn go_err_nil_pattern() {
        let toks = structural_tokenize("if err != nil { return err }", "go");
        assert!(has_token(&toks, TokenKind::Pattern, |text| {
            text.contains("err")
        }));
        let pat = first_pattern(&toks).expect("pattern");
        assert_eq!(pat.text, "if err != nil");
        assert_eq!(pat.weight, W_PATTERN);
    }

    // Motifs outweigh both identifiers and plain keywords.
    #[test]
    fn weights_pattern_above_identifier() {
        let toks = structural_tokenize("pub fn main() {}", "rust");
        let pattern_weight = first_pattern(&toks).unwrap().weight;
        let main_weight = toks
            .iter()
            .find(|t| t.kind == TokenKind::Identifier && t.text == "main")
            .unwrap()
            .weight;
        assert!(pattern_weight > main_weight);
        assert!(pattern_weight > W_KEYWORD);
    }

    // Line comments are noise; code after them is still classified.
    #[test]
    fn comment_is_noise() {
        let toks = structural_tokenize("// hello\nlet x = 1;", "rust");
        assert!(has_token(&toks, TokenKind::Noise, |text| {
            text.starts_with("//")
        }));
        assert!(has_token(&toks, TokenKind::Keyword, |text| text == "let"));
    }

    // Quoted strings are literals carrying the literal weight.
    #[test]
    fn string_literal_kind() {
        let toks = structural_tokenize(r#"let s = "ab";"#, "rust");
        let lit = toks
            .iter()
            .find(|t| t.kind == TokenKind::Literal && t.text.starts_with('"'));
        assert!(lit.is_some());
        assert_eq!(lit.unwrap().weight, W_LITERAL);
    }
}