Skip to main content

aft/
query_shape.rs

1use regex::Regex;
2use std::sync::LazyLock;
3
// Code-identifier shape detectors (used by `classify` and `is_code_identifier_token`).

/// A lowercase letter immediately followed by an uppercase letter, e.g. the `tU` in `getUser`.
static CAMEL_CASE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-z][A-Z]").unwrap());
/// A lowercase letter, `_`, lowercase letter, e.g. the `o_b` in `foo_bar`.
static SNAKE_CASE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-z]_[a-z]").unwrap());
/// Anchored at the start of the input: uppercase, one or more lowercase, uppercase
/// (e.g. `QueryShape`).
static PASCAL_CASE_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[A-Z][a-z]+[A-Z]").unwrap());
/// Three or more uppercase letters followed by a lowercase letter, starting at a word
/// boundary (e.g. `HTTPServer`, `IOError`).
static ACRONYM_PASCAL_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\b[A-Z]{2,}[A-Z][a-z]").unwrap());
/// A letter, `.`, letter, e.g. the `y.l` in `array.length`.
static DOT_PATH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-zA-Z]\.[a-zA-Z]").unwrap());
/// A `/` or `\`, followed anywhere later by `.` and a 1–5 word-character extension that ends
/// the input (e.g. `src/main.rs`).
static FILE_PATH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[/\\].*\.\w{1,5}$").unwrap());

// Error-code detectors (used by `classify` and `extract_error_code_tokens`).

/// `0x` followed by one or more hex digits.
static HEX_CODE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"0x[A-Fa-f0-9]+").unwrap());
/// A word beginning with `ERR_`.
static ERROR_PREFIX_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\bERR_\w+").unwrap());
/// `E` followed by four or more digits at a word start (e.g. `E0308`).
static NUMERIC_ERROR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\bE\d{4,}").unwrap());
/// A standalone three-digit number from 100 to 599.
static HTTP_STATUS_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b[1-5]\d{2}\b").unwrap());
/// Identifier-like tokens, optionally dotted (`foo`, `obj.field.sub`); used by
/// `extract_identifier_tokens`.
// NOTE(review): `$` is in the character classes, but because the pattern starts with `\b`
// a leading `$` (as in `$HOME`) is presumably never captured — confirm this is intended.
static IDENTIFIER_TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"\b[A-Za-z_$][A-Za-z0-9_$]*(?:\.[A-Za-z_$][A-Za-z0-9_$]*)*\b").unwrap()
});

// Path-shape detectors for `check_path_exemption`. Each one is anchored to the whole input.

/// Drive-letter absolute path, e.g. `C:\new\test` or `C:/x`.
static WINDOWS_ABS_PATH_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[A-Za-z]:[\\/][A-Za-z0-9_.\-+?\\/' ]+$").unwrap());
/// Two or more `\`-separated segments, e.g. `src\bin\main.rs`.
static WINDOWS_REL_PATH_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[A-Za-z0-9_.\-+?' ]+(\\[A-Za-z0-9_.\-+?' ]+)+$").unwrap());
/// `/`-rooted path, e.g. `/usr/local/bin`.
static POSIX_ABS_PATH_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^/[A-Za-z0-9_.\-+?/' ]+$").unwrap());
/// Two or more `/`-separated segments, e.g. `v1/release/notes.md`.
static POSIX_REL_PATH_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[A-Za-z0-9_.\-+?' ]+(/[A-Za-z0-9_.\-+?' ]+)+$").unwrap());
/// UNC path: starts with two backslashes (no spaces allowed).
static UNC_PATH_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^\\\\[A-Za-z0-9_.\-+?\\']+$").unwrap());
/// A bare filename: starts with a letter or `_` and ends in `.` plus a 1–8 character
/// alphanumeric extension (e.g. `Cargo.lock`).
static FILENAME_EXEMPTION_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[A-Za-z_][A-Za-z0-9_.\-+'? ]*\.[A-Za-z0-9]{1,8}$").unwrap());

// Regex-syntax detectors used by the tier checks behind `looks_like_regex`.

/// Brace quantifier: `{n}`, `{n,}` or `{n,m}`.
static BRACE_QUANTIFIER_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"\{\d+(?:,\d*)?\}").unwrap());
/// Opening of a Python-style named capture group: `(?P<name>`.
static NAMED_CAPTURE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\(\?P<[^>]+>").unwrap());
/// An alphanumeric range such as `a-z` or `0-9` (applied to character-class contents).
static CHAR_RANGE_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"[A-Za-z0-9]-[A-Za-z0-9]").unwrap());

/// Lowercased first words that mark a query as a natural-language question in `classify`.
const QUESTION_WORDS: &[&str] = &[
    "how", "what", "where", "why", "when", "which", "who", "does",
];
41
/// Coarse category that [`classify`] assigns to a search query.
///
/// Derives `Hash` (in addition to `Eq`) so kinds can be used as keys in hash maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryKind {
    /// A code identifier, or a one/two-word query with no leading question word.
    Identifier,
    /// Natural-language text that also contains code-shaped identifiers.
    Mixed,
    /// Hex codes, `ERR_*` names, `E1234`-style codes, or HTTP status numbers in short queries.
    ErrorCode,
    /// A URL, file-system path, or filename.
    Path,
    /// A regular-expression pattern.
    Regex,
    /// Free-form natural-language text.
    NaturalLanguage,
}

/// Blend of semantic and lexical retrieval chosen for a [`QueryKind`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ShapeWeights {
    /// Weight given to the semantic signal.
    pub semantic: f32,
    /// Weight given to the lexical signal.
    pub lexical: f32,
    /// Whether lexical tokens should be used for this kind of query.
    pub should_use_lexical: bool,
}

/// Result of classifying a query: its kind plus the weights associated with that kind.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QueryShape {
    /// The detected query category.
    pub kind: QueryKind,
    /// Retrieval weights for `kind`.
    pub weights: ShapeWeights,
}
64
65pub fn classify(query: &str) -> QueryShape {
66    let trimmed = query.trim();
67    if trimmed.is_empty() {
68        return shape(QueryKind::NaturalLanguage);
69    }
70
71    if pre_tier_exempt(trimmed).is_some() {
72        return shape(QueryKind::Path);
73    }
74
75    if looks_like_regex(trimmed) {
76        return shape(QueryKind::Regex);
77    }
78
79    let words: Vec<&str> = trimmed.split_whitespace().collect();
80    let word_count = words.len();
81    let first_word_lower = words[0].to_ascii_lowercase();
82
83    if FILE_PATH_RE.is_match(trimmed) {
84        return shape(QueryKind::Path);
85    }
86
87    let has_http_status = word_count <= 3 && HTTP_STATUS_RE.is_match(trimmed);
88    if HEX_CODE_RE.is_match(trimmed)
89        || ERROR_PREFIX_RE.is_match(trimmed)
90        || NUMERIC_ERROR_RE.is_match(trimmed)
91        || has_http_status
92    {
93        return shape(QueryKind::ErrorCode);
94    }
95
96    let has_code_identifier = CAMEL_CASE_RE.is_match(trimmed)
97        || SNAKE_CASE_RE.is_match(trimmed)
98        || PASCAL_CASE_RE.is_match(trimmed)
99        || ACRONYM_PASCAL_RE.is_match(trimmed)
100        || DOT_PATH_RE.is_match(trimmed);
101    let has_question_word = QUESTION_WORDS.contains(&first_word_lower.as_str());
102    let is_long_phrase = word_count > 2;
103    let has_natural_language_signals = has_question_word || is_long_phrase;
104
105    if has_code_identifier && has_natural_language_signals {
106        return shape(QueryKind::Mixed);
107    }
108
109    if has_code_identifier || (word_count <= 2 && !has_natural_language_signals) {
110        return shape(QueryKind::Identifier);
111    }
112
113    shape(QueryKind::NaturalLanguage)
114}
115
116pub fn extract_tokens(query: &str, shape: &QueryShape) -> Vec<String> {
117    match shape.kind {
118        QueryKind::NaturalLanguage | QueryKind::Regex => Vec::new(),
119        QueryKind::Path => extract_path_tokens(query),
120        QueryKind::ErrorCode => extract_error_code_tokens(query),
121        QueryKind::Identifier => extract_identifier_tokens(query, false),
122        QueryKind::Mixed => extract_identifier_tokens(query, true),
123    }
124}
125
126pub fn pre_tier_exempt(query: &str) -> Option<&'static str> {
127    if let Some(kind) = check_url_exemption(query) {
128        return Some(kind);
129    }
130    check_path_exemption(query)
131}
132
133pub fn looks_like_regex(query: &str) -> bool {
134    crate::pattern_compile::detect_unsupported_features(query).is_some()
135        || tier_a_regex_signal(query)
136        || tier_b_character_class(query)
137        || tier_c_adjacent_meta(query)
138}
139
140fn check_url_exemption(query: &str) -> Option<&'static str> {
141    let parsed = url::Url::parse(query).ok()?;
142    if !matches!(parsed.scheme(), "http" | "https" | "file" | "ftp" | "ssh") {
143        return None;
144    }
145    if has_regex_meta_sequences(query) || has_obvious_regex_chars(query) {
146        return None;
147    }
148    Some("url")
149}
150
151fn check_path_exemption(query: &str) -> Option<&'static str> {
152    let kind = if WINDOWS_ABS_PATH_RE.is_match(query) {
153        "windows_abs"
154    } else if WINDOWS_REL_PATH_RE.is_match(query) {
155        "windows_rel"
156    } else if POSIX_ABS_PATH_RE.is_match(query) {
157        "posix_abs"
158    } else if POSIX_REL_PATH_RE.is_match(query) {
159        "posix_rel"
160    } else if UNC_PATH_RE.is_match(query) {
161        "unc"
162    } else if FILENAME_EXEMPTION_RE.is_match(query) {
163        "filename"
164    } else {
165        return None;
166    };
167    if has_path_regex_meta_sequences(query) || has_obvious_regex_chars(query) {
168        return None;
169    }
170    Some(kind)
171}
172
173fn has_regex_meta_sequences(query: &str) -> bool {
174    query.contains(".+")
175        || query.contains(".*")
176        || query.contains(".?")
177        || query.contains(r"\n")
178        || query.contains(r"\t")
179        || query.contains(r"\r")
180        || query.contains(r"\b")
181        || query.contains(r"\B")
182        || query.contains(r"\w")
183        || query.contains(r"\W")
184        || query.contains(r"\d")
185        || query.contains(r"\D")
186        || query.contains(r"\s")
187        || query.contains(r"\S")
188        || query.contains(r"\p{")
189        || query.contains(r"\x")
190        || query.contains(r"\u{")
191        || has_escaped_regex_metachar(query)
192}
193
194fn has_path_regex_meta_sequences(query: &str) -> bool {
195    // Note: path-context uses doubled backslashes (Windows-style paths). Pure
196    // `\n`/`\t`/`\r` are NOT included here because they appear in Windows
197    // path segments (e.g. `C:\new\test`, `C:\temp\rs`); only the doubled
198    // forms count as deliberate escape sequences in path-shaped input.
199    query.contains(".+")
200        || query.contains(".*")
201        || query.contains(".?")
202        || query.contains(r"\\b")
203        || query.contains(r"\\B")
204        || query.contains(r"\\w")
205        || query.contains(r"\\W")
206        || query.contains(r"\\d")
207        || query.contains(r"\\D")
208        || query.contains(r"\\s")
209        || query.contains(r"\\S")
210        || query.contains(r"\p{")
211        || query.contains(r"\x")
212        || query.contains(r"\u{")
213        || has_escaped_regex_metachar(query)
214}
215
216fn has_escaped_regex_metachar(query: &str) -> bool {
217    let mut escaped = false;
218    for ch in query.chars() {
219        if escaped {
220            if is_escaped_metachar(ch) {
221                return true;
222            }
223            escaped = false;
224            continue;
225        }
226        escaped = ch == '\\';
227    }
228    false
229}
230
/// True when `query` contains any of `*`, `[`, `]`, `(`, `)`, `|`, `{`, `}`.
fn has_obvious_regex_chars(query: &str) -> bool {
    query.contains(['*', '[', ']', '(', ')', '|', '{', '}'])
}
241
242fn tier_a_regex_signal(query: &str) -> bool {
243    query.contains("(?:")
244        || NAMED_CAPTURE_RE.is_match(query)
245        || ["(?i)", "(?m)", "(?s)", "(?x)"]
246            .iter()
247            .any(|signal| query.contains(signal))
248        || [
249            r"\b", r"\B", r"\w", r"\W", r"\d", r"\D", r"\s", r"\S", r"\p{", r"\x", r"\u{", r"\n",
250            r"\t", r"\r",
251        ]
252        .iter()
253        .any(|signal| query.contains(signal))
254        || has_brace_quantifier(query)
255        || has_anchored_identifier(query)
256        || has_contextual_escaped_metachar(query)
257}
258
259fn has_brace_quantifier(query: &str) -> bool {
260    for matched in BRACE_QUANTIFIER_RE.find_iter(query) {
261        if matched.start() > 0
262            && query[..matched.start()]
263                .chars()
264                .last()
265                .is_some_and(|ch| !ch.is_whitespace())
266        {
267            return true;
268        }
269    }
270    false
271}
272
273fn has_anchored_identifier(query: &str) -> bool {
274    let trimmed = query.trim();
275    if let Some(rest) = trimmed.strip_prefix('^') {
276        if leading_identifier_len(rest) >= 3 {
277            return true;
278        }
279    }
280    if let Some(rest) = trimmed.strip_suffix('$') {
281        if trailing_identifier_len(rest) >= 3 {
282            return true;
283        }
284    }
285    false
286}
287
/// Number of leading characters of `text` that are ASCII alphanumeric or `_`.
fn leading_identifier_len(text: &str) -> usize {
    text.chars()
        .position(|ch| !(ch.is_ascii_alphanumeric() || ch == '_'))
        .unwrap_or_else(|| text.chars().count())
}

/// Number of trailing characters of `text` that are ASCII alphanumeric or `_`.
fn trailing_identifier_len(text: &str) -> usize {
    text.chars()
        .rev()
        .position(|ch| !(ch.is_ascii_alphanumeric() || ch == '_'))
        .unwrap_or_else(|| text.chars().count())
}
300
301fn has_contextual_escaped_metachar(query: &str) -> bool {
302    let chars: Vec<char> = query.chars().collect();
303    let mut index = 0usize;
304    while index + 1 < chars.len() {
305        if chars[index] == '\\' && is_escaped_metachar(chars[index + 1]) {
306            let literal_after = chars[index + 2..]
307                .iter()
308                .filter(|ch| ch.is_ascii_alphanumeric() || **ch == '_')
309                .count();
310            if literal_after >= 2 {
311                return true;
312            }
313            index += 2;
314        } else {
315            index += 1;
316        }
317    }
318    false
319}
320
/// True for the regex metacharacters `. * + ? ( ) [ ] { } | ^ $`.
fn is_escaped_metachar(ch: char) -> bool {
    ".*+?()[]{}|^$".contains(ch)
}
327
328fn tier_b_character_class(query: &str) -> bool {
329    for content in bracket_contents(query) {
330        if content.starts_with('^')
331            || CHAR_RANGE_RE.is_match(&content)
332            || [r"\w", r"\d", r"\s", r"\W", r"\D", r"\S"]
333                .iter()
334                .any(|signal| content.contains(signal))
335            || multi_char_non_identifier_class(&content)
336        {
337            return true;
338        }
339    }
340    false
341}
342
/// Collect the text between each unescaped `[` and the next unescaped `]`.
///
/// A backslash escapes the following character, so escaped brackets are ignored. A `[` seen
/// while a span is already open is kept as content, and unmatched `]` or `[` produce nothing.
fn bracket_contents(query: &str) -> Vec<String> {
    let mut found = Vec::new();
    let mut open_at: Option<usize> = None;
    let mut chars = query.char_indices();
    while let Some((pos, ch)) = chars.next() {
        match ch {
            '\\' => {
                // Skip whatever character the backslash escapes.
                chars.next();
            }
            '[' if open_at.is_none() => open_at = Some(pos + 1),
            ']' => {
                if let Some(start) = open_at.take() {
                    found.push(query[start..pos].to_owned());
                }
            }
            _ => {}
        }
    }
    found
}
368
/// True when `content` has at least two characters and none of them is ASCII alphanumeric,
/// `_`, `"`, `'` or `;`.
///
/// This catches punctuation-only classes such as `[.-]` while ignoring literal contents
/// like `["key"]` or `[u8; 32]`.
fn multi_char_non_identifier_class(content: &str) -> bool {
    let looks_literal =
        |ch: char| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '"' | '\'' | ';');
    content.chars().nth(1).is_some() && content.chars().all(|ch| !looks_literal(ch))
}
376
377fn tier_c_adjacent_meta(query: &str) -> bool {
378    has_dot_quantifier(query) || has_regex_pipe(query) || escaped_paren_count(query) >= 2
379}
380
/// True when `query` contains `.*`, `.+` or `.?` and, once trimmed, is longer (in bytes) than
/// that two-character sequence alone.
fn has_dot_quantifier(query: &str) -> bool {
    const DOT_QUANTIFIERS: [&str; 3] = [".*", ".+", ".?"];
    let trimmed_len = query.trim().len();
    DOT_QUANTIFIERS
        .iter()
        .any(|quantifier| trimmed_len > quantifier.len() && query.contains(quantifier))
}
386
387fn has_regex_pipe(query: &str) -> bool {
388    for (index, ch) in query.char_indices() {
389        if ch != '|' {
390            continue;
391        }
392        let left = trailing_identifier_len(&query[..index]);
393        let right = leading_identifier_len(&query[index + ch.len_utf8()..]);
394        if left >= 3 && right >= 3 {
395            return true;
396        }
397    }
398    false
399}
400
/// Count backslash-escaped parentheses (`\(` or `\)`).
///
/// A backslash always consumes the next character, so in `\\(` the parenthesis is not
/// escaped.
fn escaped_paren_count(query: &str) -> usize {
    let mut total = 0;
    let mut chars = query.chars();
    while let Some(ch) = chars.next() {
        if ch == '\\' && matches!(chars.next(), Some('(' | ')')) {
            total += 1;
        }
    }
    total
}
416
417fn extract_path_tokens(query: &str) -> Vec<String> {
418    let mut tokens = Vec::new();
419    for segment in query
420        .split(['/', '\\'])
421        .filter(|segment| !segment.is_empty())
422    {
423        if segment.contains('.') {
424            if let Some(stem) = segment.rsplit_once('.').map(|(stem, _)| stem) {
425                push_unique(&mut tokens, stem);
426            }
427        }
428        push_unique(&mut tokens, segment);
429    }
430    tokens
431}
432
433fn extract_error_code_tokens(query: &str) -> Vec<String> {
434    let mut tokens = Vec::new();
435    for regex in [
436        &*HEX_CODE_RE,
437        &*ERROR_PREFIX_RE,
438        &*NUMERIC_ERROR_RE,
439        &*HTTP_STATUS_RE,
440    ] {
441        for mat in regex.find_iter(query) {
442            push_unique(&mut tokens, mat.as_str());
443        }
444    }
445    if tokens.is_empty() && !query.trim().is_empty() {
446        push_unique(&mut tokens, query.trim());
447    }
448    tokens
449}
450
451fn extract_identifier_tokens(query: &str, require_code_shape: bool) -> Vec<String> {
452    let mut tokens = Vec::new();
453    for mat in IDENTIFIER_TOKEN_RE.find_iter(query) {
454        let token = mat.as_str();
455        if require_code_shape && !is_code_identifier_token(token) {
456            continue;
457        }
458        push_unique(&mut tokens, token);
459    }
460    tokens
461}
462
463fn is_code_identifier_token(token: &str) -> bool {
464    CAMEL_CASE_RE.is_match(token)
465        || SNAKE_CASE_RE.is_match(token)
466        || PASCAL_CASE_RE.is_match(token)
467        || ACRONYM_PASCAL_RE.is_match(token)
468        || DOT_PATH_RE.is_match(token)
469        || ERROR_PREFIX_RE.is_match(token)
470}
471
/// Append an owned copy of `token` unless it is empty or already present. This preserves
/// first-seen order; the duplicate check is a linear scan.
fn push_unique(tokens: &mut Vec<String>, token: &str) {
    if token.is_empty() || tokens.iter().any(|seen| seen == token) {
        return;
    }
    tokens.push(token.to_owned());
}
477
478fn shape(kind: QueryKind) -> QueryShape {
479    QueryShape {
480        kind,
481        weights: weights_for(kind),
482    }
483}
484
485fn weights_for(kind: QueryKind) -> ShapeWeights {
486    match kind {
487        QueryKind::Identifier => ShapeWeights {
488            semantic: 0.2,
489            lexical: 0.8,
490            should_use_lexical: true,
491        },
492        QueryKind::Path | QueryKind::ErrorCode => ShapeWeights {
493            semantic: 0.1,
494            lexical: 0.9,
495            should_use_lexical: true,
496        },
497        QueryKind::Regex => ShapeWeights {
498            semantic: 0.0,
499            lexical: 1.0,
500            should_use_lexical: false,
501        },
502        QueryKind::NaturalLanguage => ShapeWeights {
503            semantic: 0.6,
504            lexical: 0.4,
505            should_use_lexical: false,
506        },
507        QueryKind::Mixed => ShapeWeights {
508            semantic: 0.4,
509            lexical: 0.6,
510            should_use_lexical: true,
511        },
512    }
513}
514
// Behavioral specification for `classify` / `pre_tier_exempt`: which literal inputs stay
// exempt (URLs, paths, filenames) and which inputs must route to `QueryKind::Regex`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Classify `query` and return only its kind.
    fn kind(query: &str) -> QueryKind {
        classify(query).kind
    }

    // URLs with literal punctuation (`?`, `+`, `@`, `#`) are exempt as "url" and are never
    // classified as regex.
    #[test]
    fn url_exemptions_allow_common_literal_url_punctuation() {
        for query in [
            "https://api.io/path",
            "https://api.io/foo?q=test",
            "https://api.io/foo+bar",
            "https://api.io/foo@bar",
            "https://api.io/foo#anchor",
        ] {
            assert_eq!(pre_tier_exempt(query), Some("url"), "{query}");
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // URL-shaped input that carries regex sequences (dot quantifiers, classes, `\w`) loses
    // the exemption and routes to regex.
    #[test]
    fn url_exemptions_reject_regex_sequences() {
        for query in [
            "https://.*",
            "https://api.io/.+",
            "file://[^ ]+",
            "file:///tmp/.+",
            r"https://api.io/users/\w+",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Each path/filename shape maps to its exemption label and classifies as Path, including
    // single-backslash Windows segments such as `\new` and `\bin`.
    #[test]
    fn path_and_filename_exemptions_allow_literal_punctuation() {
        for (query, expected) in [
            (r"C:\new\test", "windows_abs"),
            (r"src\bin\main.rs", "windows_rel"),
            (r"packages\opencode-plugin\src", "windows_rel"),
            ("/usr/local/bin", "posix_abs"),
            ("/Users/John Doe/Documents", "posix_abs"),
            ("/home/user/.gitignore", "posix_abs"),
            ("v1/release/notes.md", "posix_rel"),
            ("/home/user/jeff's-folder", "posix_abs"),
            ("C++/parser/main.cpp", "posix_rel"),
            ("foo+bar/baz.ts", "posix_rel"),
            ("is_valid?.ts", "filename"),
            ("Cargo.lock", "filename"),
            ("tsconfig.json", "filename"),
        ] {
            assert_eq!(pre_tier_exempt(query), Some(expected), "{query}");
            assert_eq!(kind(query), QueryKind::Path, "{query}");
        }
        assert_eq!(pre_tier_exempt("foo?"), None);
    }

    // Path-shaped input containing dot quantifiers, `*`, or a doubled `\\w` is regex.
    #[test]
    fn path_exemptions_reject_regex_sequences() {
        for query in ["src/.*", "src/.+", r"C:\bin\foo*.exe", r"C:\Users\\w+"] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Strong (tier A) and contextual (tier C) regex signals each route to regex on their own.
    #[test]
    fn tier_a_and_c_regex_signals_route_to_regex() {
        for query in [
            "^export",
            "foo$",
            "^main$",
            r"foo\.bar",
            r"\(method\)",
            r"\bTODO\b",
            ".*foo",
            "foo|bar",
            "(?:foo)",
            "(?P<n>foo)",
            "(?i)Todo",
            r"\p{Lu}",
            r"\xFF",
            r"\u{1F600}",
            "a{3}",
            // Bare escape sequences route to regex via Tier A. Caveat: `foo\n`
            // and similar single-backslash-escape after literal text are
            // genuinely ambiguous with Windows path segments (e.g., file `n`
            // in directory `foo`) and stay on the path/exemption path.
            r"\n",
            r"\t",
            r"\r",
            r"\tindent",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Bracket spans route to regex only when they look like character classes; indexing,
    // attributes, and array types do not.
    #[test]
    fn character_classes_route_only_when_they_look_like_classes() {
        for query in ["[a-z]+", "[^abc]", r"[\w]+"] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
        for query in [
            "arr[0]",
            "obj[key]",
            "config[\"key\"]",
            "#[derive]",
            "Vec<[u8; 32]>",
        ] {
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Syntax that `pattern_compile::detect_unsupported_features` is expected to flag
    // (lookaround, backreferences, possessive/atomic constructs) still classifies as Regex,
    // per the test name so that compiling it surfaces an error.
    #[test]
    fn unsupported_regex_syntax_still_routes_to_regex_for_compile_error() {
        for query in [
            "(?=foo)",
            "(?!foo)",
            "(?<=foo)",
            "(?<!foo)",
            "(?P=name)",
            r"\1",
            "foo*+",
            "(?>foo)",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Short anchors, lone metacharacters, method calls, and short alternations are too weak
    // to count as regex.
    #[test]
    fn weak_regex_like_punctuation_does_not_route_to_regex() {
        for query in [
            "^id",
            "id$",
            "^",
            "$",
            "$HOME",
            r"\.",
            "array.length",
            "foo()",
            "map.get(key)",
            "a|b",
        ] {
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }
}