//! `aft/query_shape.rs` — query-shape classification for search queries.

1use regex::Regex;
2use std::sync::LazyLock;
3
4static CAMEL_CASE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-z][A-Z]").unwrap());
5static SNAKE_CASE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-z]_[a-z]").unwrap());
6static PASCAL_CASE_RE: LazyLock<Regex> =
7    LazyLock::new(|| Regex::new(r"^[A-Z][a-z]+[A-Z]").unwrap());
8static ACRONYM_PASCAL_RE: LazyLock<Regex> =
9    LazyLock::new(|| Regex::new(r"\b[A-Z]{2,}[A-Z][a-z]").unwrap());
10static DOT_PATH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[a-zA-Z]\.[a-zA-Z]").unwrap());
11static FILE_PATH_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"[/\\].*\.\w{1,5}$").unwrap());
12static HEX_CODE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"0x[A-Fa-f0-9]+").unwrap());
13static ERROR_PREFIX_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\bERR_\w+").unwrap());
14static NUMERIC_ERROR_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\bE\d{4,}").unwrap());
15static HTTP_STATUS_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b[1-5]\d{2}\b").unwrap());
16static IDENTIFIER_TOKEN_RE: LazyLock<Regex> = LazyLock::new(|| {
17    Regex::new(r"\b[A-Za-z_$][A-Za-z0-9_$]*(?:\.[A-Za-z_$][A-Za-z0-9_$]*)*\b").unwrap()
18});
19
20const QUESTION_WORDS: &[&str] = &[
21    "how", "what", "where", "why", "when", "which", "who", "does",
22];
23
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub enum QueryKind {
26    Identifier,
27    Mixed,
28    ErrorCode,
29    Path,
30    NaturalLanguage,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq)]
34pub struct ShapeWeights {
35    pub semantic: f32,
36    pub lexical: f32,
37    pub should_use_lexical: bool,
38}
39
40#[derive(Debug, Clone, Copy, PartialEq)]
41pub struct QueryShape {
42    pub kind: QueryKind,
43    pub weights: ShapeWeights,
44}
45
46pub fn classify(query: &str) -> QueryShape {
47    let trimmed = query.trim();
48    if trimmed.is_empty() {
49        return shape(QueryKind::NaturalLanguage);
50    }
51
52    let words: Vec<&str> = trimmed.split_whitespace().collect();
53    let word_count = words.len();
54    let first_word_lower = words[0].to_ascii_lowercase();
55
56    if FILE_PATH_RE.is_match(trimmed) {
57        return shape(QueryKind::Path);
58    }
59
60    let has_http_status = word_count <= 3 && HTTP_STATUS_RE.is_match(trimmed);
61    if HEX_CODE_RE.is_match(trimmed)
62        || ERROR_PREFIX_RE.is_match(trimmed)
63        || NUMERIC_ERROR_RE.is_match(trimmed)
64        || has_http_status
65    {
66        return shape(QueryKind::ErrorCode);
67    }
68
69    let has_code_identifier = CAMEL_CASE_RE.is_match(trimmed)
70        || SNAKE_CASE_RE.is_match(trimmed)
71        || PASCAL_CASE_RE.is_match(trimmed)
72        || ACRONYM_PASCAL_RE.is_match(trimmed)
73        || DOT_PATH_RE.is_match(trimmed);
74    let has_question_word = QUESTION_WORDS.contains(&first_word_lower.as_str());
75    let is_long_phrase = word_count > 2;
76    let has_natural_language_signals = has_question_word || is_long_phrase;
77
78    if has_code_identifier && has_natural_language_signals {
79        return shape(QueryKind::Mixed);
80    }
81
82    if has_code_identifier || (word_count <= 2 && !has_natural_language_signals) {
83        return shape(QueryKind::Identifier);
84    }
85
86    shape(QueryKind::NaturalLanguage)
87}
88
89pub fn extract_tokens(query: &str, shape: &QueryShape) -> Vec<String> {
90    match shape.kind {
91        QueryKind::NaturalLanguage => Vec::new(),
92        QueryKind::Path => extract_path_tokens(query),
93        QueryKind::ErrorCode => extract_error_code_tokens(query),
94        QueryKind::Identifier => extract_identifier_tokens(query, false),
95        QueryKind::Mixed => extract_identifier_tokens(query, true),
96    }
97}
98
99fn extract_path_tokens(query: &str) -> Vec<String> {
100    let mut tokens = Vec::new();
101    for segment in query
102        .split(['/', '\\'])
103        .filter(|segment| !segment.is_empty())
104    {
105        if segment.contains('.') {
106            if let Some(stem) = segment.rsplit_once('.').map(|(stem, _)| stem) {
107                push_unique(&mut tokens, stem);
108            }
109        }
110        push_unique(&mut tokens, segment);
111    }
112    tokens
113}
114
115fn extract_error_code_tokens(query: &str) -> Vec<String> {
116    let mut tokens = Vec::new();
117    for regex in [
118        &*HEX_CODE_RE,
119        &*ERROR_PREFIX_RE,
120        &*NUMERIC_ERROR_RE,
121        &*HTTP_STATUS_RE,
122    ] {
123        for mat in regex.find_iter(query) {
124            push_unique(&mut tokens, mat.as_str());
125        }
126    }
127    if tokens.is_empty() && !query.trim().is_empty() {
128        push_unique(&mut tokens, query.trim());
129    }
130    tokens
131}
132
133fn extract_identifier_tokens(query: &str, require_code_shape: bool) -> Vec<String> {
134    let mut tokens = Vec::new();
135    for mat in IDENTIFIER_TOKEN_RE.find_iter(query) {
136        let token = mat.as_str();
137        if require_code_shape && !is_code_identifier_token(token) {
138            continue;
139        }
140        push_unique(&mut tokens, token);
141    }
142    tokens
143}
144
145fn is_code_identifier_token(token: &str) -> bool {
146    CAMEL_CASE_RE.is_match(token)
147        || SNAKE_CASE_RE.is_match(token)
148        || PASCAL_CASE_RE.is_match(token)
149        || ACRONYM_PASCAL_RE.is_match(token)
150        || DOT_PATH_RE.is_match(token)
151        || ERROR_PREFIX_RE.is_match(token)
152}
153
154fn push_unique(tokens: &mut Vec<String>, token: &str) {
155    if !token.is_empty() && !tokens.iter().any(|existing| existing == token) {
156        tokens.push(token.to_string());
157    }
158}
159
160fn shape(kind: QueryKind) -> QueryShape {
161    QueryShape {
162        kind,
163        weights: weights_for(kind),
164    }
165}
166
167fn weights_for(kind: QueryKind) -> ShapeWeights {
168    match kind {
169        QueryKind::Identifier => ShapeWeights {
170            semantic: 0.2,
171            lexical: 0.8,
172            should_use_lexical: true,
173        },
174        QueryKind::Path | QueryKind::ErrorCode => ShapeWeights {
175            semantic: 0.1,
176            lexical: 0.9,
177            should_use_lexical: true,
178        },
179        QueryKind::NaturalLanguage => ShapeWeights {
180            semantic: 0.6,
181            lexical: 0.4,
182            should_use_lexical: false,
183        },
184        QueryKind::Mixed => ShapeWeights {
185            semantic: 0.4,
186            lexical: 0.6,
187            should_use_lexical: true,
188        },
189    }
190}