use std::collections::{HashMap, HashSet};

use anyhow::{Context, Result};

use crate::cache::CacheManager;
use crate::models::{FileGroupedResult, Language, SymbolKind};
use crate::query::{QueryEngine, QueryFilter};

use super::schema::QueryCommand;
11
12pub fn parse_command(command: &str) -> Result<ParsedCommand> {
19 let parts = shell_words::split(command)
21 .context("Failed to parse command string")?;
22
23 if parts.is_empty() {
24 anyhow::bail!("Empty command string");
25 }
26
27 if parts[0] != "query" {
29 anyhow::bail!("Command must start with 'query', got '{}'", parts[0]);
30 }
31
32 if parts.len() < 2 {
33 anyhow::bail!("Missing search pattern in query command");
34 }
35
36 let pattern = parts[1].clone();
38
39 let mut parsed = ParsedCommand {
41 pattern,
42 symbols: false,
43 lang: None,
44 kind: None,
45 use_ast: false,
46 use_regex: false,
47 limit: None,
48 offset: None,
49 expand: false,
50 file: None,
51 exact: false,
52 contains: false,
53 glob: Vec::new(),
54 exclude: Vec::new(),
55 paths: false,
56 all: false,
57 force: false,
58 dependencies: false,
59 count: false,
60 };
61
62 let mut i = 2;
63 while i < parts.len() {
64 match parts[i].as_str() {
65 "--symbols" | "-s" => {
66 parsed.symbols = true;
67 i += 1;
68 }
69 "--lang" | "-l" => {
70 if i + 1 >= parts.len() {
71 anyhow::bail!("--lang requires a value");
72 }
73 parsed.lang = Some(parts[i + 1].clone());
74 i += 2;
75 }
76 "--kind" | "-k" => {
77 if i + 1 >= parts.len() {
78 anyhow::bail!("--kind requires a value");
79 }
80 parsed.kind = Some(parts[i + 1].clone());
81 i += 2;
82 }
83 "--ast" => {
84 parsed.use_ast = true;
85 i += 1;
86 }
87 "--regex" | "-r" => {
88 parsed.use_regex = true;
89 i += 1;
90 }
91 "--limit" | "-n" => {
92 if i + 1 >= parts.len() {
93 anyhow::bail!("--limit requires a value");
94 }
95 let limit_val: usize = parts[i + 1].parse()
96 .context("--limit must be a number")?;
97 parsed.limit = Some(limit_val);
98 i += 2;
99 }
100 "--offset" | "-o" => {
101 if i + 1 >= parts.len() {
102 anyhow::bail!("--offset requires a value");
103 }
104 let offset_val: usize = parts[i + 1].parse()
105 .context("--offset must be a number")?;
106 parsed.offset = Some(offset_val);
107 i += 2;
108 }
109 "--expand" => {
110 parsed.expand = true;
111 i += 1;
112 }
113 "--file" | "-f" => {
114 if i + 1 >= parts.len() {
115 anyhow::bail!("--file requires a value");
116 }
117 parsed.file = Some(parts[i + 1].clone());
118 i += 2;
119 }
120 "--exact" => {
121 parsed.exact = true;
122 i += 1;
123 }
124 "--contains" => {
125 parsed.contains = true;
126 i += 1;
127 }
128 "--glob" | "-g" => {
129 if i + 1 >= parts.len() {
130 anyhow::bail!("--glob requires a value");
131 }
132 parsed.glob.push(parts[i + 1].clone());
133 i += 2;
134 }
135 "--exclude" | "-x" => {
136 if i + 1 >= parts.len() {
137 anyhow::bail!("--exclude requires a value");
138 }
139 parsed.exclude.push(parts[i + 1].clone());
140 i += 2;
141 }
142 "--paths" | "-p" => {
143 parsed.paths = true;
144 i += 1;
145 }
146 "--all" | "-a" => {
147 parsed.all = true;
148 i += 1;
149 }
150 "--force" => {
151 parsed.force = true;
152 i += 1;
153 }
154 "--dependencies" => {
155 parsed.dependencies = true;
156 i += 1;
157 }
158 "--count" | "-c" => {
159 parsed.count = true;
160 i += 1;
161 }
162 unknown => {
163 log::debug!("Ignoring unknown flag: {}", unknown);
164 i += 1;
165 }
166 }
167 }
168
169 Ok(parsed)
170}
171
/// The structured form of a `query ...` command string, as produced by
/// `parse_command`.
///
/// Boolean fields are `false` and optional fields are `None` unless the
/// corresponding flag was present. `Default` yields exactly that "no flags"
/// state with an empty pattern.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParsedCommand {
    /// The search pattern (the token immediately after `query`).
    pub pattern: String,
    /// `--symbols` / `-s`: search symbol definitions.
    pub symbols: bool,
    /// `--lang` / `-l`: raw language name, resolved in `to_query_filter`.
    pub lang: Option<String>,
    /// `--kind` / `-k`: raw symbol-kind name, resolved in `to_query_filter`.
    pub kind: Option<String>,
    /// `--ast`: request AST-based matching.
    pub use_ast: bool,
    /// `--regex` / `-r`: treat the pattern as a regular expression.
    pub use_regex: bool,
    /// `--limit` / `-n`: maximum number of results (dropped when `all` is set).
    pub limit: Option<usize>,
    /// `--offset` / `-o`: number of results to skip.
    pub offset: Option<usize>,
    /// `--expand`: passed through to the filter's `expand` option.
    pub expand: bool,
    /// `--file` / `-f`: file pattern passed through to the filter.
    pub file: Option<String>,
    /// `--exact`: passed through to the filter's `exact` option.
    pub exact: bool,
    /// `--contains`: passed through to the filter's `use_contains` option.
    pub contains: bool,
    /// `--glob` / `-g` (repeatable): include glob patterns, in order given.
    pub glob: Vec<String>,
    /// `--exclude` / `-x` (repeatable): exclude glob patterns, in order given.
    pub exclude: Vec<String>,
    /// `--paths` / `-p`: return only paths.
    pub paths: bool,
    /// `--all` / `-a`: ignore any `limit`.
    pub all: bool,
    /// `--force`: passed through to the filter's `force` option.
    pub force: bool,
    /// `--dependencies`: passed through as `include_dependencies`.
    pub dependencies: bool,
    /// `--count` / `-c`: the caller only needs the match count.
    pub count: bool,
}
195
196impl ParsedCommand {
197 pub fn to_query_filter(&self) -> Result<QueryFilter> {
199 let language = if let Some(lang_str) = &self.lang {
201 match lang_str.to_lowercase().as_str() {
202 "rust" | "rs" => Some(Language::Rust),
203 "python" | "py" => Some(Language::Python),
204 "javascript" | "js" => Some(Language::JavaScript),
205 "typescript" | "ts" => Some(Language::TypeScript),
206 "vue" => Some(Language::Vue),
207 "svelte" => Some(Language::Svelte),
208 "go" => Some(Language::Go),
209 "java" => Some(Language::Java),
210 "php" => Some(Language::PHP),
211 "c" => Some(Language::C),
212 "cpp" | "c++" => Some(Language::Cpp),
213 "csharp" | "cs" | "c#" => Some(Language::CSharp),
214 "ruby" | "rb" => Some(Language::Ruby),
215 "kotlin" | "kt" => Some(Language::Kotlin),
216 "swift" => Some(Language::Swift),
217 "zig" => Some(Language::Zig),
218 _ => anyhow::bail!("Unknown language: {}", lang_str),
219 }
220 } else {
221 None
222 };
223
224 let kind = if let Some(kind_str) = &self.kind {
226 let capitalized = {
228 let mut chars = kind_str.chars();
229 match chars.next() {
230 None => String::new(),
231 Some(first) => first.to_uppercase()
232 .chain(chars.flat_map(|c| c.to_lowercase()))
233 .collect()
234 }
235 };
236
237 let parsed_kind: SymbolKind = capitalized.parse()
238 .ok()
239 .or_else(|| {
240 log::debug!("Treating '{}' as unknown symbol kind", kind_str);
241 Some(SymbolKind::Unknown(kind_str.to_string()))
242 })
243 .context("Failed to parse symbol kind")?;
244
245 Some(parsed_kind)
246 } else {
247 None
248 };
249
250 let symbols_mode = self.symbols || self.kind.is_some();
252
253 let limit = if self.all {
255 None
256 } else {
257 self.limit
258 };
259
260 Ok(QueryFilter {
261 language,
262 kind,
263 use_ast: self.use_ast,
264 use_regex: self.use_regex,
265 limit,
266 symbols_mode,
267 expand: self.expand,
268 file_pattern: self.file.clone(),
269 exact: self.exact,
270 use_contains: self.contains,
271 timeout_secs: 30, glob_patterns: self.glob.clone(),
273 exclude_patterns: self.exclude.clone(),
274 paths_only: self.paths,
275 offset: self.offset,
276 force: self.force,
277 suppress_output: true, include_dependencies: self.dependencies,
279 ..Default::default()
280 })
281 }
282}
283
284pub async fn execute_queries(
296 queries: Vec<QueryCommand>,
297 cache: &CacheManager,
298) -> Result<(Vec<FileGroupedResult>, usize, bool)> {
299 if queries.is_empty() {
300 return Ok((Vec::new(), 0, false));
301 }
302
303 let mut sorted_queries = queries.clone();
305 sorted_queries.sort_by_key(|q| q.order);
306
307 log::info!("Executing {} queries in order", sorted_queries.len());
308
309 let mut merged_results: Vec<FileGroupedResult> = Vec::new();
310 let mut seen_matches: HashSet<(String, usize, usize)> = HashSet::new();
311 let mut total_count: usize = 0;
312 let mut all_count_only = true;
313
314 let engine = QueryEngine::new(cache.clone());
317
318 for query_cmd in sorted_queries {
319 log::debug!("Executing query {}: {}", query_cmd.order, query_cmd.command);
320
321 let parsed = parse_command(&query_cmd.command)
323 .with_context(|| format!("Failed to parse query command: {}", query_cmd.command))?;
324
325 if !parsed.count {
327 all_count_only = false;
328 }
329
330 let filter = parsed.to_query_filter()?;
332
333 let response = engine.search_with_metadata(&parsed.pattern, filter)
335 .with_context(|| format!("Failed to execute query: {}", query_cmd.command))?;
336
337 total_count += response.pagination.total;
339
340 log::debug!(
341 "Query {} returned {} file groups, {} total matches (merge={})",
342 query_cmd.order,
343 response.results.len(),
344 response.pagination.total,
345 query_cmd.merge
346 );
347
348 if query_cmd.merge {
350 for file_group in response.results {
351 let file_path = file_group.path.clone();
353
354 let existing_group = merged_results.iter_mut()
355 .find(|g| g.path == file_path);
356
357 if let Some(group) = existing_group {
358 for match_result in file_group.matches {
360 let key = (
361 file_path.clone(),
362 match_result.span.start_line,
363 match_result.span.end_line,
364 );
365
366 if !seen_matches.contains(&key) {
367 seen_matches.insert(key);
368 group.matches.push(match_result);
369 }
370 }
371 } else {
372 for match_result in &file_group.matches {
374 let key = (
375 file_path.clone(),
376 match_result.span.start_line,
377 match_result.span.end_line,
378 );
379 seen_matches.insert(key);
380 }
381
382 merged_results.push(file_group);
383 }
384 }
385 }
386 }
387
388 log::info!(
389 "Merged results: {} file groups, {} unique matches, {} total count (count_only={})",
390 merged_results.len(),
391 seen_matches.len(),
392 total_count,
393 all_count_only
394 );
395
396 Ok((merged_results, total_count, all_count_only))
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_query() {
        let cmd = r#"query "TODO""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert!(!parsed.symbols);
        assert!(parsed.lang.is_none());
    }

    #[test]
    fn test_parse_query_with_flags() {
        let cmd = r#"query "extract_symbols" --symbols --lang rust"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "extract_symbols");
        assert!(parsed.symbols);
        assert_eq!(parsed.lang, Some("rust".to_string()));
    }

    #[test]
    fn test_parse_query_with_kind() {
        let cmd = r#"query "main" --kind function --lang rust"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "main");
        assert_eq!(parsed.kind, Some("function".to_string()));
        assert_eq!(parsed.lang, Some("rust".to_string()));
    }

    #[test]
    fn test_parse_query_with_glob() {
        let cmd = r#"query "TODO" --glob "src/**/*.rs" --glob "tests/**/*.rs""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert_eq!(parsed.glob.len(), 2);
        assert_eq!(parsed.glob[0], "src/**/*.rs");
        assert_eq!(parsed.glob[1], "tests/**/*.rs");
    }

    #[test]
    fn test_parse_query_with_exclude() {
        let cmd = r#"query "config" --exclude "target/**" --exclude "*.gen.rs""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "config");
        assert_eq!(parsed.exclude.len(), 2);
    }

    #[test]
    fn test_parse_invalid_command() {
        let cmd = r#"search "pattern""#;
        let result = parse_command(cmd);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must start with 'query'"));
    }

    #[test]
    fn test_parse_empty_command() {
        let cmd = "";
        let result = parse_command(cmd);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_pattern() {
        let err = parse_command("query").unwrap_err();
        assert!(err.to_string().contains("Missing search pattern"));
    }

    #[test]
    fn test_parse_flag_missing_value() {
        let err = parse_command(r#"query "x" --lang"#).unwrap_err();
        assert!(err.to_string().contains("--lang requires a value"));
    }

    #[test]
    fn test_parse_non_numeric_limit() {
        let err = parse_command(r#"query "x" --limit ten"#).unwrap_err();
        assert!(err.to_string().contains("--limit must be a number"));
    }

    #[test]
    fn test_parse_short_flags_and_unknown_flag_ignored() {
        let parsed = parse_command(r#"query "x" -n 5 -o 10 -c --bogus"#).unwrap();

        assert_eq!(parsed.limit, Some(5));
        assert_eq!(parsed.offset, Some(10));
        assert!(parsed.count);
    }

    #[test]
    fn test_to_query_filter() {
        let cmd = r#"query "TODO" --symbols --lang rust --limit 10"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(filter.language, Some(Language::Rust));
        assert!(filter.symbols_mode);
        assert_eq!(filter.limit, Some(10));
    }

    #[test]
    fn test_to_query_filter_with_kind() {
        let cmd = r#"query "parse" --kind function"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert!(filter.symbols_mode);
        assert!(matches!(filter.kind, Some(SymbolKind::Function)));
    }

    #[test]
    fn test_kind_is_case_insensitive() {
        let parsed = parse_command(r#"query "parse" --kind FUNCTION"#).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert!(matches!(filter.kind, Some(SymbolKind::Function)));
    }

    #[test]
    fn test_lang_alias_is_case_insensitive() {
        let parsed = parse_command(r#"query "x" --lang RS"#).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(filter.language, Some(Language::Rust));
    }

    #[test]
    fn test_all_overrides_limit() {
        let parsed = parse_command(r#"query "x" --limit 5 --all"#).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(filter.limit, None);
    }

    #[test]
    fn test_unknown_language_is_error() {
        let parsed = parse_command(r#"query "x" --lang cobol"#).unwrap();
        let err = parsed.to_query_filter().unwrap_err();

        assert!(err.to_string().contains("Unknown language: cobol"));
    }
}