1use regex::Regex;
2use std::collections::HashSet;
3use std::sync::LazyLock;
4
/// Broad topic a search query appears to be about.
///
/// Intents are detected by regex matching against the lowercased query, and a single query
/// may carry several of them at once (they are collected into a `HashSet`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum QueryIntent {
    /// Looking for a function, method, class, or how something is implemented.
    FunctionSearch,
    /// Errors, exceptions, and try/catch handling.
    ErrorHandling,
    /// Databases, SQL, queries, models, tables, or connections.
    Database,
    /// HTTP APIs: endpoints, routes, requests, and responses.
    Api,
    /// Logins, tokens, passwords, sessions, and permissions.
    Authentication,
    /// Tests, mocks, assertions, and fixtures.
    Testing,
}
15
16#[derive(Debug, Clone)]
18pub struct AnalyzedQuery {
19 pub original: String,
20 pub search_query: String,
22 pub tokens: Vec<String>,
23 pub normalized_tokens: Vec<String>,
24 pub intents: HashSet<QueryIntent>,
25 pub is_entity_query: bool,
26 pub has_class_keyword: bool,
27}
28
29static CAMEL_CASE_RE: LazyLock<Regex> =
30 LazyLock::new(|| Regex::new(r"([a-z])([A-Z])").unwrap());
31static WORD_RE: LazyLock<Regex> =
32 LazyLock::new(|| Regex::new(r"\w+").unwrap());
33static CAMEL_CASE_PATTERN: LazyLock<Regex> =
34 LazyLock::new(|| Regex::new(r"[A-Z][a-z]+[A-Z]").unwrap());
35
/// Words that signal a descriptive search ("find all tests") rather than a bare identifier.
///
/// If any token of a query appears here, `is_entity_query` reports `false`.
const ACTION_WORDS: &[&str] = &[
    // Search-style verbs.
    "find", "search", "get", "show", "list", "display",
    // Question words.
    "how", "what", "where", "which",
    // Quantifiers, articles, and auxiliaries.
    "all", "every", "the", "a", "an", "is", "are", "was", "were", "do", "does",
];
40
/// Filler words that `remove_stop_words` strips from a query before synonym expansion.
///
/// Each query word is lowercased before lookup, so entries here must be lowercase.
const STOP_WORDS: &[&str] = &[
    "how", "does", "the", "where", "is", "what", "which", "do",
    "a", "an", "are", "was", "were", "this", "that", "it",
    "in", "on", "at", "to", "for", "of", "with", "by",
    "from", "can", "could", "would", "should", "will", "be", "been",
    "being", "have", "has", "had", "did", "done", "about", "into",
    "through", "its", "my", "your", "there", "here", "when", "why",
    "all", "each", "every", "both", "few", "more", "most", "other",
    "some", "such", "than", "too", "very", "just", "also",
];
50
51struct IntentPattern {
53 intent: QueryIntent,
54 patterns: Vec<&'static str>,
55}
56
57static INTENT_PATTERNS: LazyLock<Vec<IntentPattern>> = LazyLock::new(|| {
58 vec![
59 IntentPattern {
60 intent: QueryIntent::FunctionSearch,
61 patterns: vec![
62 r"\bfunction\b", r"\bdef\b", r"\bmethod\b", r"\bclass\b",
63 r"how.*work", r"implement", r"algorithm",
64 ],
65 },
66 IntentPattern {
67 intent: QueryIntent::ErrorHandling,
68 patterns: vec![
69 r"\berror\b", r"\bexception\b", r"\btry\b", r"\bcatch\b",
70 r"handle.*error", r"exception.*handling",
71 ],
72 },
73 IntentPattern {
74 intent: QueryIntent::Database,
75 patterns: vec![
76 r"\bdatabase\b", r"\bdb\b", r"\bquery\b", r"\bsql\b",
77 r"\bmodel\b", r"\btable\b", r"connection",
78 ],
79 },
80 IntentPattern {
81 intent: QueryIntent::Api,
82 patterns: vec![
83 r"\bapi\b", r"\bendpoint\b", r"\broute\b", r"\brequest\b",
84 r"\bresponse\b", r"\bhttp\b", r"rest.*api",
85 ],
86 },
87 IntentPattern {
88 intent: QueryIntent::Authentication,
89 patterns: vec![
90 r"\bauth\b", r"\blogin\b", r"\btoken\b", r"\bpassword\b",
91 r"\bsession\b", r"authenticate", r"permission",
92 ],
93 },
94 IntentPattern {
95 intent: QueryIntent::Testing,
96 patterns: vec![
97 r"\btest\b", r"\bmock\b", r"\bassert\b", r"\bfixture\b",
98 r"unit.*test", r"integration.*test",
99 ],
100 },
101 ]
102});
103
/// Abbreviation <-> full-form table used by `expand_query`.
///
/// Each entry is `(abbreviations, full_forms)`, and lookup runs in both directions: a token
/// found among the abbreviations pulls in the full forms, and a token found among the full
/// forms pulls in the abbreviations. The `middleware` and `handler` rows list the word
/// itself as a full form. That adds nothing for that word, because the token being expanded
/// is never appended again.
const SYNONYMS: &[(&[&str], &[&str])] = &[
    // Security and configuration.
    (&["auth"], &["authentication", "authorize", "authorization"]),
    (&["db"], &["database"]),
    (&["config"], &["configuration", "settings"]),
    (&["init"], &["initialize", "initialization"]),
    // Common short names in code.
    (&["err"], &["error"]),
    (&["msg"], &["message"]),
    (&["req"], &["request"]),
    (&["res", "resp"], &["response"]),
    (&["middleware"], &["middleware"]),
    (&["handler"], &["handler", "handle"]),
    (&["util"], &["utility", "utils", "helpers"]),
    (&["param"], &["parameter"]),
    (&["ctx"], &["context"]),
    (&["conn"], &["connection"]),
    // Concurrency.
    (&["async"], &["asynchronous"]),
    (&["sync"], &["synchronous"]),
];
123
124pub fn analyze_query(query: &str) -> AnalyzedQuery {
126 let tokens = tokenize(query);
127 let normalized = normalize_tokens(&tokens);
128 let intents = detect_intents(query);
129 let is_entity = is_entity_query(&tokens, query);
130 let has_class = query.to_lowercase().contains("class");
131
132 let filtered = remove_stop_words(query);
134 let search_query = expand_query(&filtered, &normalize_tokens(&tokenize(&filtered)));
135
136 AnalyzedQuery {
137 original: query.to_string(),
138 search_query,
139 tokens,
140 normalized_tokens: normalized,
141 intents,
142 is_entity_query: is_entity,
143 has_class_keyword: has_class,
144 }
145}
146
147fn remove_stop_words(query: &str) -> String {
149 let words: Vec<&str> = query.split_whitespace().collect();
150
151 let filtered: Vec<&str> = words
152 .iter()
153 .filter(|w| !STOP_WORDS.contains(&w.to_lowercase().as_str()))
154 .copied()
155 .collect();
156
157 if filtered.is_empty() {
159 query.to_string()
160 } else {
161 filtered.join(" ")
162 }
163}
164
165fn expand_query(original: &str, tokens: &[String]) -> String {
167 let mut expansions: Vec<String> = Vec::new();
168
169 for token in tokens {
170 for (abbreviations, full_forms) in SYNONYMS {
171 if abbreviations.contains(&token.as_str()) {
172 for form in *full_forms {
173 if *form != token.as_str() {
174 expansions.push((*form).to_string());
175 }
176 }
177 }
178 if full_forms.contains(&token.as_str()) {
180 for abbr in *abbreviations {
181 if *abbr != token.as_str() {
182 expansions.push((*abbr).to_string());
183 }
184 }
185 }
186 }
187 }
188
189 if expansions.is_empty() {
190 original.to_string()
191 } else {
192 format!("{} {}", original, expansions.join(" "))
194 }
195}
196
197pub fn tokenize(text: &str) -> Vec<String> {
199 let expanded = CAMEL_CASE_RE.replace_all(text, "$1 $2");
201 let expanded = expanded.replace('_', " ").replace('-', " ");
203 WORD_RE
205 .find_iter(&expanded)
206 .map(|m| m.as_str().to_lowercase())
207 .collect()
208}
209
/// Lowercases `tokens`, drops one-character tokens, and removes duplicates.
///
/// The first occurrence of each lowercased token is kept, so input order is preserved.
/// Length is measured in characters, not bytes, so a one-character non-ASCII token such as
/// `"é"` is dropped just like `"a"`. Previously `str::len` was used, and `"é"` (2 bytes)
/// slipped through.
fn normalize_tokens(tokens: &[String]) -> Vec<String> {
    let mut seen = HashSet::new();
    tokens
        .iter()
        .filter(|t| t.chars().count() > 1)
        .map(|t| t.to_lowercase())
        // `insert` returns false for a token we have already emitted.
        .filter(|lower| seen.insert(lower.clone()))
        .collect()
}
226
227fn detect_intents(query: &str) -> HashSet<QueryIntent> {
229 let lower = query.to_lowercase();
230 let mut intents = HashSet::new();
231
232 for ip in INTENT_PATTERNS.iter() {
233 for pattern in &ip.patterns {
234 if let Ok(re) = Regex::new(pattern) {
235 if re.is_match(&lower) {
236 intents.insert(ip.intent.clone());
237 break;
238 }
239 }
240 }
241 }
242
243 intents
244}
245
246fn is_entity_query(tokens: &[String], original: &str) -> bool {
248 if tokens.len() > 3 {
250 return false;
251 }
252
253 if tokens.iter().any(|t| ACTION_WORDS.contains(&t.as_str())) {
255 return false;
256 }
257
258 if CAMEL_CASE_PATTERN.is_match(original) {
260 return true;
261 }
262
263 tokens.len() <= 2
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize() {
        let cases: [(&str, &[&str]); 3] = [
            ("getUserById", &["get", "user", "by", "id"]),
            ("get_user_by_id", &["get", "user", "by", "id"]),
            // Uppercase runs are not split at camelCase boundaries.
            ("HTTPClient", &["httpclient"]),
        ];
        for (input, expected) in cases {
            assert_eq!(tokenize(input), expected, "tokenize({input:?})");
        }
    }

    #[test]
    fn test_analyze_query() {
        let analyzed = analyze_query("find the login function");
        for intent in [QueryIntent::Authentication, QueryIntent::FunctionSearch] {
            assert!(analyzed.intents.contains(&intent), "missing intent {intent:?}");
        }
        assert!(!analyzed.is_entity_query);
    }

    #[test]
    fn test_entity_query_detection() {
        assert!(analyze_query("UserService").is_entity_query);
        assert!(!analyze_query("find all tests").is_entity_query);
    }

    #[test]
    fn test_query_expansion() {
        let expanded = analyze_query("auth middleware").search_query;
        assert!(expanded.contains("authentication"));
        assert!(expanded.contains("auth middleware"));

        assert!(analyze_query("db connection").search_query.contains("database"));

        // Nothing to expand: the filtered query comes back untouched.
        assert_eq!(analyze_query("foobar baz").search_query, "foobar baz");
    }

    #[test]
    fn test_stop_word_removal() {
        let cases = [
            ("how does routing work", "routing work"),
            ("where is the main app defined", "main app defined"),
            ("how to add a new endpoint", "add new endpoint"),
            // Every word is a stop word: the query is returned unchanged.
            ("the is a", "the is a"),
            ("middleware auth", "middleware auth"),
        ];
        for (input, expected) in cases {
            assert_eq!(remove_stop_words(input), expected, "remove_stop_words({input:?})");
        }
    }

    #[test]
    fn test_intent_detection() {
        let cases = [
            ("database connection", QueryIntent::Database),
            ("error handling middleware", QueryIntent::ErrorHandling),
        ];
        for (query, intent) in cases {
            assert!(
                analyze_query(query).intents.contains(&intent),
                "{query:?} should have intent {intent:?}"
            );
        }
    }
}