// aft/query_shape.rs

1use regex::Regex;
2use std::sync::LazyLock;
3
4static CAMEL_CASE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-z][A-Z]").unwrap());
5static SNAKE_CASE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-z]_[a-z]").unwrap());
6static PASCAL_CASE_RE: LazyLock<Regex> =
7    LazyLock::new(|| Regex::new(r"^[A-Z][a-z]+[A-Z]").unwrap());
8static ACRONYM_PASCAL_RE: LazyLock<Regex> =
9    LazyLock::new(|| Regex::new(r"\b[A-Z]{2,}[A-Z][a-z]").unwrap());
10static DOT_PATH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-zA-Z]\.[a-zA-Z]").unwrap());
11static FILE_PATH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[/\\].*\.\w{1,5}$").unwrap());
12static HEX_CODE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"0x[A-Fa-f0-9]+").unwrap());
13static ERROR_PREFIX_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\bERR_\w+").unwrap());
14static NUMERIC_ERROR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\bE\d{4,}").unwrap());
15static TYPESCRIPT_ERROR_RE: LazyLock<Regex> =
16    LazyLock::new(|| Regex::new(r"\bTS\d{4,}\b").unwrap());
17static HTTP_STATUS_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b[1-5]\d{2}\b").unwrap());
18static IDENTIFIER_TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
19    Regex::new(r"\b[A-Za-z_$][A-Za-z0-9_$]*(?:\.[A-Za-z_$][A-Za-z0-9_$]*)*\b").unwrap()
20});
21
22static WINDOWS_ABS_PATH_RE: LazyLock<Regex> =
23    LazyLock::new(|| Regex::new(r"^[A-Za-z]:[\\/][A-Za-z0-9_.\-+?\\/' ]+$").unwrap());
24static WINDOWS_REL_PATH_RE: LazyLock<Regex> =
25    LazyLock::new(|| Regex::new(r"^[A-Za-z0-9_.\-+?' ]+(\\[A-Za-z0-9_.\-+?' ]+)+$").unwrap());
26static POSIX_ABS_PATH_RE: LazyLock<Regex> =
27    LazyLock::new(|| Regex::new(r"^/[A-Za-z0-9_.\-+?/' ]+$").unwrap());
28static POSIX_REL_PATH_RE: LazyLock<Regex> =
29    LazyLock::new(|| Regex::new(r"^[A-Za-z0-9_.\-+?' ]+(/[A-Za-z0-9_.\-+?' ]+)+$").unwrap());
30static UNC_PATH_RE: LazyLock<Regex> =
31    LazyLock::new(|| Regex::new(r"^\\\\[A-Za-z0-9_.\-+?\\']+$").unwrap());
32static FILENAME_EXEMPTION_RE: LazyLock<Regex> =
33    LazyLock::new(|| Regex::new(r"^[A-Za-z_][A-Za-z0-9_.\-+'? ]*\.[A-Za-z0-9]{1,8}$").unwrap());
34static BRACE_QUANTIFIER_RE: LazyLock<Regex> =
35    LazyLock::new(|| Regex::new(r"\{\d+(?:,\d*)?\}").unwrap());
36static NAMED_CAPTURE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\(\?P<[^>]+>").unwrap());
37static CHAR_RANGE_RE: LazyLock<Regex> =
38    LazyLock::new(|| Regex::new(r"[A-Za-z0-9]-[A-Za-z0-9]").unwrap());
39
/// Lowercase words that, when they open a query, mark it as carrying natural-language intent.
const QUESTION_WORDS: &[&str] = &[
    "how",
    "what",
    "where",
    "why",
    "when",
    "which",
    "who",
    "does",
];
43
/// The shape a search query was classified as by [`classify`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QueryKind {
    /// Code-identifier shaped text, or a query of at most two words without prose signals.
    Identifier,
    /// An error code or code identifier combined with natural-language signals.
    Mixed,
    /// A hex code, `ERR_…`, `E1234`-style, `TS1234`-style or HTTP-status-like number.
    ErrorCode,
    /// A URL, file-system path or bare file name.
    Path,
    /// Text that looks like a regular expression.
    Regex,
    /// Everything else, including the empty query.
    NaturalLanguage,
}
53
/// Weights attached to a [`QueryKind`].
///
/// NOTE(review): the consumers of these values are not visible in this file; presumably they
/// blend semantic and lexical retrieval — confirm against the caller.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ShapeWeights {
    /// Weight given to the semantic side.
    pub semantic: f32,
    /// Weight given to the lexical side.
    pub lexical: f32,
    /// Flag telling the caller whether lexical handling should be used for this kind.
    pub should_use_lexical: bool,
}
60
61#[derive(Debug, Clone, Copy, PartialEq)]
62pub struct QueryShape {
63    pub kind: QueryKind,
64    pub weights: ShapeWeights,
65}
66
67pub fn classify(query: &str) -> QueryShape {
68    let trimmed = query.trim();
69    if trimmed.is_empty() {
70        return shape(QueryKind::NaturalLanguage);
71    }
72
73    if pre_tier_exempt(trimmed).is_some() {
74        return shape(QueryKind::Path);
75    }
76
77    if looks_like_regex(trimmed) {
78        return shape(QueryKind::Regex);
79    }
80
81    let words: Vec<&str> = trimmed.split_whitespace().collect();
82    let word_count = words.len();
83    let first_word_lower = words[0].to_ascii_lowercase();
84
85    if FILE_PATH_RE.is_match(trimmed) {
86        return shape(QueryKind::Path);
87    }
88
89    let has_question_word = QUESTION_WORDS.contains(&first_word_lower.as_str());
90    let is_long_phrase = word_count > 2;
91    let is_two_word_concept = is_two_word_lowercase_concept(&words);
92    let has_natural_language_signals = has_question_word || is_long_phrase || is_two_word_concept;
93    let has_error_code = contains_error_code(trimmed, word_count);
94
95    if has_error_code && has_natural_language_signals {
96        return shape(QueryKind::Mixed);
97    }
98
99    if has_error_code {
100        return shape(QueryKind::ErrorCode);
101    }
102
103    let has_code_identifier = CAMEL_CASE_RE.is_match(trimmed)
104        || SNAKE_CASE_RE.is_match(trimmed)
105        || PASCAL_CASE_RE.is_match(trimmed)
106        || ACRONYM_PASCAL_RE.is_match(trimmed)
107        || DOT_PATH_RE.is_match(trimmed);
108
109    if has_code_identifier && has_natural_language_signals {
110        return shape(QueryKind::Mixed);
111    }
112
113    if has_code_identifier || (word_count <= 2 && !has_natural_language_signals) {
114        return shape(QueryKind::Identifier);
115    }
116
117    shape(QueryKind::NaturalLanguage)
118}
119
120pub fn extract_tokens(query: &str, shape: &QueryShape) -> Vec<String> {
121    match shape.kind {
122        QueryKind::NaturalLanguage | QueryKind::Regex => Vec::new(),
123        QueryKind::Path => extract_path_tokens(query),
124        QueryKind::ErrorCode => extract_error_code_tokens(query),
125        QueryKind::Identifier => extract_identifier_tokens(query, false),
126        QueryKind::Mixed => extract_identifier_tokens(query, true),
127    }
128}
129
130pub fn pre_tier_exempt(query: &str) -> Option<&'static str> {
131    if let Some(kind) = check_url_exemption(query) {
132        return Some(kind);
133    }
134    check_path_exemption(query)
135}
136
137pub fn looks_like_regex(query: &str) -> bool {
138    crate::pattern_compile::detect_unsupported_features(query).is_some()
139        || tier_a_regex_signal(query)
140        || tier_b_character_class(query)
141        || tier_c_adjacent_meta(query)
142}
143
144fn check_url_exemption(query: &str) -> Option<&'static str> {
145    let parsed = url::Url::parse(query).ok()?;
146    if !matches!(parsed.scheme(), "http" | "https" | "file" | "ftp" | "ssh") {
147        return None;
148    }
149    if has_regex_meta_sequences(query) || has_obvious_regex_chars(query) {
150        return None;
151    }
152    Some("url")
153}
154
155fn check_path_exemption(query: &str) -> Option<&'static str> {
156    let kind = if WINDOWS_ABS_PATH_RE.is_match(query) {
157        "windows_abs"
158    } else if WINDOWS_REL_PATH_RE.is_match(query) {
159        "windows_rel"
160    } else if POSIX_ABS_PATH_RE.is_match(query) {
161        "posix_abs"
162    } else if POSIX_REL_PATH_RE.is_match(query) {
163        "posix_rel"
164    } else if UNC_PATH_RE.is_match(query) {
165        "unc"
166    } else if FILENAME_EXEMPTION_RE.is_match(query) {
167        "filename"
168    } else {
169        return None;
170    };
171    if has_path_regex_meta_sequences(query) || has_obvious_regex_chars(query) {
172        return None;
173    }
174    Some(kind)
175}
176
177fn contains_error_code(query: &str, word_count: usize) -> bool {
178    HEX_CODE_RE.is_match(query)
179        || ERROR_PREFIX_RE.is_match(query)
180        || NUMERIC_ERROR_RE.is_match(query)
181        || TYPESCRIPT_ERROR_RE.is_match(query)
182        || has_http_status(query, word_count)
183}
184
185fn has_http_status(query: &str, word_count: usize) -> bool {
186    HTTP_STATUS_RE.is_match(query)
187        && (word_count <= 3 || query.to_ascii_lowercase().contains("http"))
188}
189
190fn is_two_word_lowercase_concept(words: &[&str]) -> bool {
191    words.len() == 2
192        && words
193            .iter()
194            .all(|word| is_dictionary_style_lowercase_word(word))
195}
196
/// Reports whether `word` is at least three bytes long and made only of ASCII `a`–`z`.
fn is_dictionary_style_lowercase_word(word: &str) -> bool {
    if word.len() < 3 {
        return false;
    }
    word.bytes().all(|byte| matches!(byte, b'a'..=b'z'))
}
200
201fn has_regex_meta_sequences(query: &str) -> bool {
202    query.contains(".+")
203        || query.contains(".*")
204        || query.contains(".?")
205        || query.contains(r"\n")
206        || query.contains(r"\t")
207        || query.contains(r"\r")
208        || query.contains(r"\b")
209        || query.contains(r"\B")
210        || query.contains(r"\w")
211        || query.contains(r"\W")
212        || query.contains(r"\d")
213        || query.contains(r"\D")
214        || query.contains(r"\s")
215        || query.contains(r"\S")
216        || query.contains(r"\p{")
217        || query.contains(r"\x")
218        || query.contains(r"\u{")
219        || has_escaped_regex_metachar(query)
220}
221
222fn has_path_regex_meta_sequences(query: &str) -> bool {
223    query.contains(".+")
224        || query.contains(".*")
225        || query.contains(".?")
226        || query.contains(r"\p{")
227        || query.contains(r"\x")
228        || query.contains(r"\u{")
229        || has_path_context_regex_escape(query)
230        || has_escaped_regex_metachar(query)
231}
232
233fn has_path_context_regex_escape(query: &str) -> bool {
234    let chars = query.char_indices().collect::<Vec<_>>();
235    for index in 0..chars.len().saturating_sub(1) {
236        if chars[index].1 != '\\' {
237            continue;
238        }
239        let escaped = chars[index + 1].1;
240        if matches!(escaped, 'b' | 'B' | 'w' | 'W' | 'd' | 'D' | 's' | 'S')
241            && path_escape_looks_like_regex(&chars, index + 1)
242        {
243            return true;
244        }
245    }
246    false
247}
248
/// Decides whether the escape letter at `escaped_index` is used regex-style.
///
/// That is the case when the letter is the last character, or when the character after it is a
/// quantifier, group/class opener, alternation, anchor, backslash or slash.
fn path_escape_looks_like_regex(chars: &[(usize, char)], escaped_index: usize) -> bool {
    match chars.get(escaped_index + 1) {
        None => true,
        Some(&(_, following)) => matches!(
            following,
            '*' | '+' | '?' | '{' | '(' | '[' | '|' | '^' | '$' | '\\' | '/'
        ),
    }
}
259
260fn has_escaped_regex_metachar(query: &str) -> bool {
261    let mut escaped = false;
262    for ch in query.chars() {
263        if escaped {
264            if is_escaped_metachar(ch) {
265                return true;
266            }
267            escaped = false;
268            continue;
269        }
270        escaped = ch == '\\';
271    }
272    false
273}
274
/// Reports whether `query` contains any of `* [ ] ( ) | { }`.
fn has_obvious_regex_chars(query: &str) -> bool {
    query
        .chars()
        .any(|ch| matches!(ch, '*' | '[' | ']' | '(' | ')' | '|' | '{' | '}'))
}
285
286fn tier_a_regex_signal(query: &str) -> bool {
287    query.contains("(?:")
288        || NAMED_CAPTURE_RE.is_match(query)
289        || ["(?i)", "(?m)", "(?s)", "(?x)"]
290            .iter()
291            .any(|signal| query.contains(signal))
292        || [
293            r"\b", r"\B", r"\w", r"\W", r"\d", r"\D", r"\s", r"\S", r"\p{", r"\x", r"\u{", r"\n",
294            r"\t", r"\r",
295        ]
296        .iter()
297        .any(|signal| query.contains(signal))
298        || has_brace_quantifier(query)
299        || has_anchored_identifier(query)
300        || has_contextual_escaped_metachar(query)
301}
302
303fn has_brace_quantifier(query: &str) -> bool {
304    for matched in BRACE_QUANTIFIER_RE.find_iter(query) {
305        if matched.start() > 0
306            && query[..matched.start()]
307                .chars()
308                .last()
309                .is_some_and(|ch| !ch.is_whitespace())
310        {
311            return true;
312        }
313    }
314    false
315}
316
317fn has_anchored_identifier(query: &str) -> bool {
318    let trimmed = query.trim();
319    if let Some(rest) = trimmed.strip_prefix('^') {
320        if leading_identifier_len(rest) >= 3 {
321            return true;
322        }
323    }
324    if let Some(rest) = trimmed.strip_suffix('$') {
325        if trailing_identifier_len(rest) >= 3 {
326            return true;
327        }
328    }
329    false
330}
331
/// Counts the ASCII alphanumeric or `_` characters at the start of `text`.
fn leading_identifier_len(text: &str) -> usize {
    let is_identifier_char = |ch: char| ch.is_ascii_alphanumeric() || ch == '_';

    text.chars()
        .position(|ch| !is_identifier_char(ch))
        .unwrap_or_else(|| text.chars().count())
}
337
/// Counts the ASCII alphanumeric or `_` characters at the end of `text`.
fn trailing_identifier_len(text: &str) -> usize {
    let is_identifier_char = |ch: char| ch.is_ascii_alphanumeric() || ch == '_';

    text.chars()
        .rev()
        .position(|ch| !is_identifier_char(ch))
        .unwrap_or_else(|| text.chars().count())
}
344
345fn has_contextual_escaped_metachar(query: &str) -> bool {
346    let chars: Vec<char> = query.chars().collect();
347    let mut index = 0usize;
348    while index + 1 < chars.len() {
349        if chars[index] == '\\' && is_escaped_metachar(chars[index + 1]) {
350            let literal_after = chars[index + 2..]
351                .iter()
352                .filter(|ch| ch.is_ascii_alphanumeric() || **ch == '_')
353                .count();
354            if literal_after >= 2 {
355                return true;
356            }
357            index += 2;
358        } else {
359            index += 1;
360        }
361    }
362    false
363}
364
/// Reports whether `ch` is a regex metacharacter that is commonly written with a backslash
/// escape: one of `. * + ? ( ) [ ] { } | ^ $`.
fn is_escaped_metachar(ch: char) -> bool {
    const METACHARS: &str = ".*+?()[]{}|^$";
    METACHARS.contains(ch)
}
371
372fn tier_b_character_class(query: &str) -> bool {
373    for content in bracket_contents(query) {
374        if content.starts_with('^')
375            || CHAR_RANGE_RE.is_match(&content)
376            || [r"\w", r"\d", r"\s", r"\W", r"\D", r"\S"]
377                .iter()
378                .any(|signal| content.contains(signal))
379            || multi_char_non_identifier_class(&content)
380        {
381            return true;
382        }
383    }
384    false
385}
386
/// Collects the text between each unescaped `[` and the next unescaped `]`.
///
/// A backslash hides the character after it. While a bracket is open, further `[` characters are
/// ignored, so nested brackets end at the first `]`. An unterminated `[` yields nothing.
fn bracket_contents(query: &str) -> Vec<String> {
    let mut contents = Vec::new();
    let mut open_at: Option<usize> = None;
    let mut cursor = query.char_indices();

    while let Some((index, ch)) = cursor.next() {
        match ch {
            '\\' => {
                // Skip whatever the backslash escapes.
                cursor.next();
            }
            '[' if open_at.is_none() => open_at = Some(index + ch.len_utf8()),
            ']' => {
                if let Some(open) = open_at.take() {
                    contents.push(query[open..index].to_string());
                }
            }
            _ => {}
        }
    }
    contents
}
412
/// Reports whether `content` has at least two characters and none of them is an ASCII
/// alphanumeric, `_`, `"`, `'` or `;`.
fn multi_char_non_identifier_class(content: &str) -> bool {
    if content.chars().count() < 2 {
        return false;
    }
    content
        .chars()
        .all(|ch| !(ch.is_ascii_alphanumeric() || matches!(ch, '_' | '"' | '\'' | ';')))
}
420
421fn tier_c_adjacent_meta(query: &str) -> bool {
422    has_dot_quantifier(query)
423        || has_literal_atom_quantifier(query)
424        || has_regex_pipe(query)
425        || escaped_paren_count(query) >= 2
426}
427
/// Reports whether `query` contains `.*`, `.+` or `.?` and, once trimmed, is longer than that
/// two-character sequence alone.
fn has_dot_quantifier(query: &str) -> bool {
    let trimmed_len = query.trim().len();

    [".*", ".+", ".?"]
        .iter()
        .any(|signal| trimmed_len > signal.len() && query.contains(signal))
}
433
434fn has_literal_atom_quantifier(query: &str) -> bool {
435    let chars = query.char_indices().collect::<Vec<_>>();
436    for (index, (byte_index, ch)) in chars.iter().copied().enumerate() {
437        if !is_bare_quantifier(ch) || is_escaped_at(query, byte_index) {
438            continue;
439        }
440        if chars
441            .get(index + 1)
442            .is_some_and(|(_, next)| is_bare_quantifier(*next))
443        {
444            continue;
445        }
446        if ch == '?'
447            && (sentence_final_question_mark_in_phrase(query, byte_index)
448                || question_mark_is_code_shape(&chars, index))
449        {
450            continue;
451        }
452        if previous_is_literal_atom(&chars, index) {
453            return true;
454        }
455    }
456    false
457}
458
/// Reports whether the `?` at `byte_index` ends a phrase: only whitespace follows it and more
/// than one whitespace-separated word precedes it.
fn sentence_final_question_mark_in_phrase(query: &str, byte_index: usize) -> bool {
    let after = &query[byte_index + '?'.len_utf8()..];
    if !after.trim().is_empty() {
        return false;
    }

    let words_before = query[..byte_index].split_whitespace().count();
    words_before > 1
}
463
464fn question_mark_is_code_shape(chars: &[(usize, char)], question_index: usize) -> bool {
465    question_mark_is_optional_chain(chars, question_index)
466        || question_mark_after_empty_call(chars, question_index)
467        || question_mark_after_index_expression(chars, question_index)
468        || question_mark_is_typescript_optional(chars, question_index)
469}
470
471fn question_mark_is_optional_chain(chars: &[(usize, char)], question_index: usize) -> bool {
472    chars
473        .get(question_index + 1)
474        .is_some_and(|(_, next)| *next == '.')
475        && question_index
476            .checked_sub(1)
477            .and_then(|previous_index| chars.get(previous_index))
478            .is_some_and(|(_, previous)| is_code_expression_tail(*previous))
479}
480
481fn question_mark_after_empty_call(chars: &[(usize, char)], question_index: usize) -> bool {
482    let Some(call_open_index) = question_index.checked_sub(2) else {
483        return false;
484    };
485    chars
486        .get(question_index - 1)
487        .is_some_and(|(_, previous)| *previous == ')')
488        && chars
489            .get(call_open_index)
490            .is_some_and(|(_, open)| *open == '(')
491        && call_open_index
492            .checked_sub(1)
493            .and_then(|callee_index| chars.get(callee_index))
494            .is_some_and(|(_, callee_tail)| is_code_expression_tail(*callee_tail))
495}
496
497fn question_mark_after_index_expression(chars: &[(usize, char)], question_index: usize) -> bool {
498    if chars
499        .get(question_index.checked_sub(1).unwrap_or(usize::MAX))
500        .is_none_or(|(_, previous)| *previous != ']')
501    {
502        return false;
503    }
504
505    let mut depth = 0usize;
506    for index in (0..question_index).rev() {
507        match chars[index].1 {
508            ']' => depth += 1,
509            '[' => {
510                depth = depth.saturating_sub(1);
511                if depth == 0 {
512                    return index
513                        .checked_sub(1)
514                        .and_then(|target_index| chars.get(target_index))
515                        .is_some_and(|(_, target_tail)| is_code_expression_tail(*target_tail));
516                }
517            }
518            _ => {}
519        }
520    }
521    false
522}
523
524fn question_mark_is_typescript_optional(chars: &[(usize, char)], question_index: usize) -> bool {
525    let previous_is_identifier = question_index
526        .checked_sub(1)
527        .and_then(|previous_index| chars.get(previous_index))
528        .is_some_and(|(_, previous)| is_identifier_tail(*previous));
529    if !previous_is_identifier {
530        return false;
531    }
532    if chars
533        .get(question_index + 1)
534        .is_none_or(|(_, next)| *next != ':')
535    {
536        return false;
537    }
538
539    chars
540        .get(question_index + 2)
541        .is_none_or(|(_, after_colon)| {
542            after_colon.is_whitespace()
543                || after_colon.is_ascii_alphabetic()
544                || matches!(*after_colon, '_' | '{' | '[' | '(' | '"' | '\'')
545        })
546}
547
548fn is_code_expression_tail(ch: char) -> bool {
549    is_identifier_tail(ch) || matches!(ch, ')' | ']')
550}
551
/// Reports whether `ch` can end an identifier: an ASCII letter or digit, `_` or `$`.
fn is_identifier_tail(ch: char) -> bool {
    ch == '_' || ch == '$' || ch.is_ascii_alphanumeric()
}
555
/// Reports whether the character before `quantifier_index` is something a quantifier can apply
/// to: an ASCII letter or digit, `_`, `)` or `]`.
///
/// Returns `false` when the quantifier is the first character.
fn previous_is_literal_atom(chars: &[(usize, char)], quantifier_index: usize) -> bool {
    quantifier_index
        .checked_sub(1)
        .and_then(|previous_index| chars.get(previous_index))
        .is_some_and(|&(_, previous)| {
            previous.is_ascii_alphanumeric() || matches!(previous, '_' | ')' | ']')
        })
}
566
/// Reports whether `ch` is one of the single-character quantifiers `*`, `+` or `?`.
fn is_bare_quantifier(ch: char) -> bool {
    ch == '*' || ch == '+' || ch == '?'
}
570
/// Reports whether the character starting at `byte_index` is backslash-escaped, i.e. preceded
/// by an odd number of consecutive backslashes.
fn is_escaped_at(query: &str, byte_index: usize) -> bool {
    let before = &query[..byte_index];
    // A backslash is one byte, so the byte difference equals the number of backslashes.
    let backslash_run = before.len() - before.trim_end_matches('\\').len();
    backslash_run % 2 == 1
}
579
580fn has_regex_pipe(query: &str) -> bool {
581    for (index, ch) in query.char_indices() {
582        if ch != '|' {
583            continue;
584        }
585        let left = trailing_identifier_len(&query[..index]);
586        let right = leading_identifier_len(&query[index + ch.len_utf8()..]);
587        if left >= 3 && right >= 3 {
588            return true;
589        }
590    }
591    false
592}
593
/// Counts the backslash-escaped parentheses (`\(` and `\)`) in `query`.
///
/// Every backslash consumes the character after it, so the `(` in `\\(` is not counted.
fn escaped_paren_count(query: &str) -> usize {
    let mut count = 0usize;
    let mut remaining = query.chars();

    while let Some(ch) = remaining.next() {
        if ch != '\\' {
            continue;
        }
        if matches!(remaining.next(), Some('(' | ')')) {
            count += 1;
        }
    }
    count
}
609
610fn extract_path_tokens(query: &str) -> Vec<String> {
611    let mut tokens = Vec::new();
612    for segment in query
613        .split(['/', '\\'])
614        .filter(|segment| !segment.is_empty())
615    {
616        if segment.contains('.') {
617            if let Some(stem) = segment.rsplit_once('.').map(|(stem, _)| stem) {
618                push_unique(&mut tokens, stem);
619            }
620        }
621        push_unique(&mut tokens, segment);
622    }
623    tokens
624}
625
626fn extract_error_code_tokens(query: &str) -> Vec<String> {
627    let mut tokens = Vec::new();
628    for regex in [
629        &*HEX_CODE_RE,
630        &*ERROR_PREFIX_RE,
631        &*NUMERIC_ERROR_RE,
632        &*TYPESCRIPT_ERROR_RE,
633        &*HTTP_STATUS_RE,
634    ] {
635        for mat in regex.find_iter(query) {
636            push_unique(&mut tokens, mat.as_str());
637        }
638    }
639    if tokens.is_empty() && !query.trim().is_empty() {
640        push_unique(&mut tokens, query.trim());
641    }
642    tokens
643}
644
645fn extract_identifier_tokens(query: &str, require_code_shape: bool) -> Vec<String> {
646    let mut tokens = Vec::new();
647    for mat in IDENTIFIER_TOKEN_RE.find_iter(query) {
648        let token = mat.as_str();
649        if require_code_shape && !is_code_identifier_token(token) {
650            continue;
651        }
652        push_unique(&mut tokens, token);
653    }
654    tokens
655}
656
657fn is_code_identifier_token(token: &str) -> bool {
658    CAMEL_CASE_RE.is_match(token)
659        || SNAKE_CASE_RE.is_match(token)
660        || PASCAL_CASE_RE.is_match(token)
661        || ACRONYM_PASCAL_RE.is_match(token)
662        || DOT_PATH_RE.is_match(token)
663        || ERROR_PREFIX_RE.is_match(token)
664        || NUMERIC_ERROR_RE.is_match(token)
665        || TYPESCRIPT_ERROR_RE.is_match(token)
666}
667
/// Appends `token` to `tokens` unless it is empty or already present.
fn push_unique(tokens: &mut Vec<String>, token: &str) {
    let already_present = tokens.iter().any(|existing| existing.as_str() == token);
    if token.is_empty() || already_present {
        return;
    }
    tokens.push(token.to_owned());
}
673
674fn shape(kind: QueryKind) -> QueryShape {
675    QueryShape {
676        kind,
677        weights: weights_for(kind),
678    }
679}
680
681fn weights_for(kind: QueryKind) -> ShapeWeights {
682    match kind {
683        QueryKind::Identifier => ShapeWeights {
684            semantic: 0.2,
685            lexical: 0.8,
686            should_use_lexical: true,
687        },
688        QueryKind::Path | QueryKind::ErrorCode => ShapeWeights {
689            semantic: 0.1,
690            lexical: 0.9,
691            should_use_lexical: true,
692        },
693        QueryKind::Regex => ShapeWeights {
694            semantic: 0.0,
695            lexical: 1.0,
696            should_use_lexical: false,
697        },
698        QueryKind::NaturalLanguage => ShapeWeights {
699            semantic: 0.6,
700            lexical: 0.4,
701            should_use_lexical: false,
702        },
703        QueryKind::Mixed => ShapeWeights {
704            semantic: 0.4,
705            lexical: 0.6,
706            should_use_lexical: true,
707        },
708    }
709}
710
// Unit tests pinning how `classify` and `pre_tier_exempt` route representative queries.
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand: classify `query` and return only the detected kind.
    fn kind(query: &str) -> QueryKind {
        classify(query).kind
    }

    // Literal URLs with query strings, `+`, `@` and fragments are exempt and never regex.
    #[test]
    fn url_exemptions_allow_common_literal_url_punctuation() {
        for query in [
            "https://api.io/path",
            "https://api.io/foo?q=test",
            "https://api.io/foo+bar",
            "https://api.io/foo@bar",
            "https://api.io/foo#anchor",
        ] {
            assert_eq!(pre_tier_exempt(query), Some("url"), "{query}");
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // URLs that carry regex sequences lose the exemption and classify as regex.
    #[test]
    fn url_exemptions_reject_regex_sequences() {
        for query in [
            "https://.*",
            "https://api.io/.+",
            "file://[^ ]+",
            "file:///tmp/.+",
            r"https://api.io/users/\w+",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Literal paths and file names get the expected exemption label and classify as `Path`.
    #[test]
    fn path_and_filename_exemptions_allow_literal_punctuation() {
        for (query, expected) in [
            (r"C:\new\test", "windows_abs"),
            (r"src\bin\main.rs", "windows_rel"),
            (r"src\tab\main.ts", "windows_rel"),
            (r"packages\opencode-plugin\src", "windows_rel"),
            ("/usr/local/bin", "posix_abs"),
            ("/Users/John Doe/Documents", "posix_abs"),
            ("/home/user/.gitignore", "posix_abs"),
            ("v1/release/notes.md", "posix_rel"),
            ("/home/user/jeff's-folder", "posix_abs"),
            ("C++/parser/main.cpp", "posix_rel"),
            ("foo+bar/baz.ts", "posix_rel"),
            ("is_valid?.ts", "filename"),
            ("Cargo.lock", "filename"),
            ("tsconfig.json", "filename"),
        ] {
            assert_eq!(pre_tier_exempt(query), Some(expected), "{query}");
            assert_eq!(kind(query), QueryKind::Path, "{query}");
        }
        // No extension, so the filename exemption does not apply.
        assert_eq!(pre_tier_exempt("foo?"), None);
    }

    // Path-shaped queries that contain regex sequences classify as regex.
    #[test]
    fn path_exemptions_reject_regex_sequences() {
        for query in [
            "src/.*",
            "src/.+",
            r"C:\bin\foo*.exe",
            r"C:\Users\\w+",
            r"src\w+\main.ts",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Strong (tier A) and adjacency-based (tier C) regex signals classify as regex.
    #[test]
    fn tier_a_and_c_regex_signals_route_to_regex() {
        for query in [
            "^export",
            "foo$",
            "^main$",
            r"foo\.bar",
            r"\(method\)",
            r"\bTODO\b",
            ".*foo",
            "foo|bar",
            "(?:foo)",
            "(?P<n>foo)",
            "(?i)Todo",
            r"\p{Lu}",
            r"\xFF",
            r"\u{1F600}",
            "a{3}",
            // Bare escape sequences route to regex via Tier A. Caveat: `foo\n`
            // and similar single-backslash-escape after literal text are
            // genuinely ambiguous with Windows path segments (e.g., file `n`
            // in directory `foo`) and stay on the path/exemption path.
            r"\n",
            r"\t",
            r"\r",
            r"\tindent",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Brackets count as character classes only with class-like content; indexing,
    // attributes and array types must not classify as regex.
    #[test]
    fn character_classes_route_only_when_they_look_like_classes() {
        for query in ["[a-z]+", "[^abc]", r"[\w]+"] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
        for query in [
            "arr[0]",
            "obj[key]",
            "config[\"key\"]",
            "#[derive]",
            "Vec<[u8; 32]>",
        ] {
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Regex syntax flagged as unsupported must still classify as regex (the test name states
    // the purpose: the compile error is meant to surface instead of a silent fallback).
    #[test]
    fn unsupported_regex_syntax_still_routes_to_regex_for_compile_error() {
        for query in [
            "(?=foo)",
            "(?!foo)",
            "(?<=foo)",
            "(?<!foo)",
            "(?P=name)",
            r"\1",
            "foo*+",
            "(?>foo)",
        ] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Two plain lowercase words are treated as a concept phrase, not as identifiers.
    #[test]
    fn two_word_lowercase_concepts_route_to_natural_language() {
        for query in ["retry logic", "auth flow", "cache invalidation"] {
            assert_eq!(kind(query), QueryKind::NaturalLanguage, "{query}");
        }
    }

    // Short queries containing a code-shaped token stay `Identifier`.
    #[test]
    fn identifierish_short_queries_stay_identifier() {
        for query in ["useState hook", "parseConfig", "parse_config option"] {
            assert_eq!(kind(query), QueryKind::Identifier, "{query}");
        }
    }

    // `?` used as a try operator, optional chain or after an index is code, not a quantifier.
    #[test]
    fn question_mark_code_shapes_do_not_route_to_regex() {
        for query in ["foo()?", "optional?.length", "user?.name", "arr[0]?"] {
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // `?` directly after a literal character is still recognised as a regex quantifier.
    #[test]
    fn question_mark_regex_quantifiers_still_route_to_regex() {
        for query in ["colou?r", "https?"] {
            assert_eq!(kind(query), QueryKind::Regex, "{query}");
        }
    }

    // Lone or weakly supported metacharacters must not be enough to classify as regex.
    #[test]
    fn weak_regex_like_punctuation_does_not_route_to_regex() {
        for query in [
            "^id",
            "id$",
            "^",
            "$",
            "$HOME",
            r"\.",
            "array.length",
            "foo()",
            "map.get(key)",
            "a|b",
        ] {
            assert_ne!(kind(query), QueryKind::Regex, "{query}");
        }
    }
}