Skip to main content

reflex/semantic/
executor.rs

1//! Command parser and query executor for semantic queries
2
3use anyhow::{Context, Result};
4use std::collections::HashSet;
5
6use crate::cache::CacheManager;
7use crate::models::{FileGroupedResult, Language, SymbolKind};
8use crate::query::{QueryEngine, QueryFilter};
9
10use super::schema::QueryCommand;
11
12/// Parse a command string into query parameters
13///
14/// The command string should be in the format:
15/// `query "pattern" [flags...]`
16///
17/// Example: `query "TODO" --symbols --lang rust`
18pub fn parse_command(command: &str) -> Result<ParsedCommand> {
19    // Parse the command using shell-words to handle quoted strings
20    let parts = shell_words::split(command).context("Failed to parse command string")?;
21
22    if parts.is_empty() {
23        anyhow::bail!("Empty command string");
24    }
25
26    // First word should be "query"
27    if parts[0] != "query" {
28        anyhow::bail!("Command must start with 'query', got '{}'", parts[0]);
29    }
30
31    if parts.len() < 2 {
32        anyhow::bail!("Missing search pattern in query command");
33    }
34
35    // Second word is the pattern
36    let pattern = parts[1].clone();
37
38    // Parse remaining arguments as flags
39    let mut parsed = ParsedCommand {
40        pattern,
41        symbols: false,
42        lang: None,
43        kind: None,
44        use_ast: false,
45        use_regex: false,
46        limit: None,
47        offset: None,
48        expand: false,
49        file: None,
50        exact: false,
51        contains: false,
52        glob: Vec::new(),
53        exclude: Vec::new(),
54        paths: false,
55        all: false,
56        force: false,
57        dependencies: false,
58        count: false,
59    };
60
61    let mut i = 2;
62    while i < parts.len() {
63        match parts[i].as_str() {
64            "--symbols" | "-s" => {
65                parsed.symbols = true;
66                i += 1;
67            }
68            "--lang" | "-l" => {
69                if i + 1 >= parts.len() {
70                    anyhow::bail!("--lang requires a value");
71                }
72                parsed.lang = Some(parts[i + 1].clone());
73                i += 2;
74            }
75            "--kind" | "-k" => {
76                if i + 1 >= parts.len() {
77                    anyhow::bail!("--kind requires a value");
78                }
79                parsed.kind = Some(parts[i + 1].clone());
80                i += 2;
81            }
82            "--ast" => {
83                parsed.use_ast = true;
84                i += 1;
85            }
86            "--regex" | "-r" => {
87                parsed.use_regex = true;
88                i += 1;
89            }
90            "--limit" | "-n" => {
91                if i + 1 >= parts.len() {
92                    anyhow::bail!("--limit requires a value");
93                }
94                let limit_val: usize = parts[i + 1].parse().context("--limit must be a number")?;
95                parsed.limit = Some(limit_val);
96                i += 2;
97            }
98            "--offset" | "-o" => {
99                if i + 1 >= parts.len() {
100                    anyhow::bail!("--offset requires a value");
101                }
102                let offset_val: usize =
103                    parts[i + 1].parse().context("--offset must be a number")?;
104                parsed.offset = Some(offset_val);
105                i += 2;
106            }
107            "--expand" => {
108                parsed.expand = true;
109                i += 1;
110            }
111            "--file" | "-f" => {
112                if i + 1 >= parts.len() {
113                    anyhow::bail!("--file requires a value");
114                }
115                parsed.file = Some(parts[i + 1].clone());
116                i += 2;
117            }
118            "--exact" => {
119                parsed.exact = true;
120                i += 1;
121            }
122            "--contains" => {
123                parsed.contains = true;
124                i += 1;
125            }
126            "--glob" | "-g" => {
127                if i + 1 >= parts.len() {
128                    anyhow::bail!("--glob requires a value");
129                }
130                parsed.glob.push(parts[i + 1].clone());
131                i += 2;
132            }
133            "--exclude" | "-x" => {
134                if i + 1 >= parts.len() {
135                    anyhow::bail!("--exclude requires a value");
136                }
137                parsed.exclude.push(parts[i + 1].clone());
138                i += 2;
139            }
140            "--paths" | "-p" => {
141                parsed.paths = true;
142                i += 1;
143            }
144            "--all" | "-a" => {
145                parsed.all = true;
146                i += 1;
147            }
148            "--force" => {
149                parsed.force = true;
150                i += 1;
151            }
152            "--dependencies" => {
153                parsed.dependencies = true;
154                i += 1;
155            }
156            "--count" | "-c" => {
157                parsed.count = true;
158                i += 1;
159            }
160            unknown => {
161                log::debug!("Ignoring unknown flag: {}", unknown);
162                i += 1;
163            }
164        }
165    }
166
167    Ok(parsed)
168}
169
/// Pattern and flags extracted from one `query "pattern" [flags...]` command string.
///
/// Values are stored as they appeared on the command line; `lang` and `kind`
/// stay as raw strings here and are only resolved to typed values in
/// `to_query_filter`.
///
/// `Default` yields an empty pattern with every flag off, which allows
/// struct-update construction (`ParsedCommand { pattern, ..Default::default() }`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParsedCommand {
    /// Search pattern: the token immediately after `query`, taken verbatim.
    pub pattern: String,
    /// Set by `--symbols` / `-s`.
    pub symbols: bool,
    /// Raw value of `--lang` / `-l` (e.g. `"rust"`, `"py"`); not validated at parse time.
    pub lang: Option<String>,
    /// Raw value of `--kind` / `-k`; a present value also turns on symbol mode
    /// when converted to a filter.
    pub kind: Option<String>,
    /// Set by `--ast`.
    pub use_ast: bool,
    /// Set by `--regex` / `-r`.
    pub use_regex: bool,
    /// Value of `--limit` / `-n`; ignored when `all` is set.
    pub limit: Option<usize>,
    /// Value of `--offset` / `-o`.
    pub offset: Option<usize>,
    /// Set by `--expand`.
    pub expand: bool,
    /// Value of `--file` / `-f`; forwarded as the filter's `file_pattern`.
    pub file: Option<String>,
    /// Set by `--exact`.
    pub exact: bool,
    /// Set by `--contains`; forwarded as the filter's `use_contains`.
    pub contains: bool,
    /// One entry per `--glob` / `-g` occurrence, in command-line order.
    pub glob: Vec<String>,
    /// One entry per `--exclude` / `-x` occurrence, in command-line order.
    pub exclude: Vec<String>,
    /// Set by `--paths` / `-p`; forwarded as the filter's `paths_only`.
    pub paths: bool,
    /// Set by `--all` / `-a`; removes any result limit.
    pub all: bool,
    /// Set by `--force`.
    pub force: bool,
    /// Set by `--dependencies`; forwarded as the filter's `include_dependencies`.
    pub dependencies: bool,
    /// Set by `--count` / `-c`. Not copied into the query filter; the executor
    /// reads it to decide whether only a count should be reported.
    pub count: bool,
}
193
194impl ParsedCommand {
195    /// Convert to QueryFilter
196    pub fn to_query_filter(&self) -> Result<QueryFilter> {
197        // Parse language
198        let language = if let Some(lang_str) = &self.lang {
199            match lang_str.to_lowercase().as_str() {
200                "rust" | "rs" => Some(Language::Rust),
201                "python" | "py" => Some(Language::Python),
202                "javascript" | "js" => Some(Language::JavaScript),
203                "typescript" | "ts" => Some(Language::TypeScript),
204                "vue" => Some(Language::Vue),
205                "svelte" => Some(Language::Svelte),
206                "go" => Some(Language::Go),
207                "java" => Some(Language::Java),
208                "php" => Some(Language::PHP),
209                "c" => Some(Language::C),
210                "cpp" | "c++" => Some(Language::Cpp),
211                "csharp" | "cs" | "c#" => Some(Language::CSharp),
212                "ruby" | "rb" => Some(Language::Ruby),
213                "kotlin" | "kt" => Some(Language::Kotlin),
214                "swift" => Some(Language::Swift),
215                "zig" => Some(Language::Zig),
216                _ => anyhow::bail!("Unknown language: {}", lang_str),
217            }
218        } else {
219            None
220        };
221
222        // Parse symbol kind
223        let kind = if let Some(kind_str) = &self.kind {
224            // Capitalize first letter for parsing
225            let capitalized = {
226                let mut chars = kind_str.chars();
227                match chars.next() {
228                    None => String::new(),
229                    Some(first) => first
230                        .to_uppercase()
231                        .chain(chars.flat_map(|c| c.to_lowercase()))
232                        .collect(),
233                }
234            };
235
236            let parsed_kind: SymbolKind = capitalized
237                .parse()
238                .ok()
239                .or_else(|| {
240                    log::debug!("Treating '{}' as unknown symbol kind", kind_str);
241                    Some(SymbolKind::Unknown(kind_str.to_string()))
242                })
243                .context("Failed to parse symbol kind")?;
244
245            Some(parsed_kind)
246        } else {
247            None
248        };
249
250        // Symbol mode is enabled if --symbols flag OR --kind is specified
251        let symbols_mode = self.symbols || self.kind.is_some();
252
253        // Handle --all flag (unlimited results)
254        let limit = if self.all { None } else { self.limit };
255
256        Ok(QueryFilter {
257            language,
258            kind,
259            use_ast: self.use_ast,
260            use_regex: self.use_regex,
261            limit,
262            symbols_mode,
263            expand: self.expand,
264            file_pattern: self.file.clone(),
265            exact: self.exact,
266            use_contains: self.contains,
267            timeout_secs: 30, // Default timeout
268            glob_patterns: self.glob.clone(),
269            exclude_patterns: self.exclude.clone(),
270            paths_only: self.paths,
271            offset: self.offset,
272            force: self.force,
273            suppress_output: true, // Suppress output for programmatic use
274            include_dependencies: self.dependencies,
275            ..Default::default()
276        })
277    }
278}
279
280/// Execute multiple queries with ordering and merging
281///
282/// Queries are executed in order based on their `order` field.
283/// Results are merged based on the `merge` flag - only queries with `merge: true`
284/// contribute to the final result set.
285///
286/// Results are deduplicated by (file_path, start_line, end_line) to avoid duplicates
287/// across multiple queries.
288///
289/// Returns a tuple of (merged results, total count across all queries, count_only mode).
290/// If count_only is true, all queries had --count flag and only the count should be displayed.
291pub async fn execute_queries(
292    queries: Vec<QueryCommand>,
293    cache: &CacheManager,
294) -> Result<(Vec<FileGroupedResult>, usize, bool)> {
295    if queries.is_empty() {
296        return Ok((Vec::new(), 0, false));
297    }
298
299    // Sort queries by order field
300    let mut sorted_queries = queries.clone();
301    sorted_queries.sort_by_key(|q| q.order);
302
303    log::info!("Executing {} queries in order", sorted_queries.len());
304
305    let mut merged_results: Vec<FileGroupedResult> = Vec::new();
306    let mut seen_matches: HashSet<(String, usize, usize)> = HashSet::new();
307    let mut total_count: usize = 0;
308    let mut all_count_only = true;
309
310    // Create a single QueryEngine and reuse it for all queries
311    // This avoids redundant cache validation and SQLite connection overhead
312    let engine = QueryEngine::new(cache.clone());
313
314    for query_cmd in sorted_queries {
315        log::debug!("Executing query {}: {}", query_cmd.order, query_cmd.command);
316
317        // Parse command
318        let parsed = parse_command(&query_cmd.command)
319            .with_context(|| format!("Failed to parse query command: {}", query_cmd.command))?;
320
321        // Track if this query has --count flag
322        if !parsed.count {
323            all_count_only = false;
324        }
325
326        // Convert to QueryFilter
327        let filter = parsed.to_query_filter()?;
328
329        // Execute query (reusing the same engine)
330        let response = engine
331            .search_with_metadata(&parsed.pattern, filter)
332            .with_context(|| format!("Failed to execute query: {}", query_cmd.command))?;
333
334        // Always accumulate total count from all queries
335        total_count += response.pagination.total;
336
337        log::debug!(
338            "Query {} returned {} file groups, {} total matches (merge={})",
339            query_cmd.order,
340            response.results.len(),
341            response.pagination.total,
342            query_cmd.merge
343        );
344
345        // If merge is true, add results to merged set (with deduplication)
346        if query_cmd.merge {
347            for file_group in response.results {
348                // Find or create file group in merged results
349                let file_path = file_group.path.clone();
350
351                let existing_group = merged_results.iter_mut().find(|g| g.path == file_path);
352
353                if let Some(group) = existing_group {
354                    // Add matches to existing group (deduplicate)
355                    for match_result in file_group.matches {
356                        let key = (
357                            file_path.clone(),
358                            match_result.span.start_line,
359                            match_result.span.end_line,
360                        );
361
362                        if !seen_matches.contains(&key) {
363                            seen_matches.insert(key);
364                            group.matches.push(match_result);
365                        }
366                    }
367                } else {
368                    // Create new group
369                    for match_result in &file_group.matches {
370                        let key = (
371                            file_path.clone(),
372                            match_result.span.start_line,
373                            match_result.span.end_line,
374                        );
375                        seen_matches.insert(key);
376                    }
377
378                    merged_results.push(file_group);
379                }
380            }
381        }
382    }
383
384    log::info!(
385        "Merged results: {} file groups, {} unique matches, {} total count (count_only={})",
386        merged_results.len(),
387        seen_matches.len(),
388        total_count,
389        all_count_only
390    );
391
392    Ok((merged_results, total_count, all_count_only))
393}
394
/// Unit tests for command parsing and filter conversion.
///
/// Covers the happy paths as well as the error paths (missing values,
/// malformed numbers, malformed commands) and the lenient handling of
/// unrecognized flags.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_query() {
        let cmd = r#"query "TODO""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert!(!parsed.symbols);
        assert!(parsed.lang.is_none());
    }

    #[test]
    fn test_parse_query_with_flags() {
        let cmd = r#"query "extract_symbols" --symbols --lang rust"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "extract_symbols");
        assert!(parsed.symbols);
        assert_eq!(parsed.lang, Some("rust".to_string()));
    }

    #[test]
    fn test_parse_query_with_short_flags() {
        let cmd = r#"query "main" -s -l py -k function -n 3 -o 2 -c"#;
        let parsed = parse_command(cmd).unwrap();

        assert!(parsed.symbols);
        assert_eq!(parsed.lang, Some("py".to_string()));
        assert_eq!(parsed.kind, Some("function".to_string()));
        assert_eq!(parsed.limit, Some(3));
        assert_eq!(parsed.offset, Some(2));
        assert!(parsed.count);
    }

    #[test]
    fn test_parse_query_with_kind() {
        let cmd = r#"query "main" --kind function --lang rust"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "main");
        assert_eq!(parsed.kind, Some("function".to_string()));
        assert_eq!(parsed.lang, Some("rust".to_string()));
    }

    #[test]
    fn test_parse_query_with_glob() {
        let cmd = r#"query "TODO" --glob "src/**/*.rs" --glob "tests/**/*.rs""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert_eq!(parsed.glob.len(), 2);
        assert_eq!(parsed.glob[0], "src/**/*.rs");
        assert_eq!(parsed.glob[1], "tests/**/*.rs");
    }

    #[test]
    fn test_parse_query_with_exclude() {
        let cmd = r#"query "config" --exclude "target/**" --exclude "*.gen.rs""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "config");
        assert_eq!(parsed.exclude.len(), 2);
    }

    #[test]
    fn test_parse_invalid_command() {
        let cmd = r#"search "pattern""#;
        let result = parse_command(cmd);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("must start with 'query'")
        );
    }

    #[test]
    fn test_parse_empty_command() {
        let cmd = "";
        let result = parse_command(cmd);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_pattern() {
        let err = parse_command("query").unwrap_err();
        assert!(err.to_string().contains("Missing search pattern"));
    }

    #[test]
    fn test_parse_unbalanced_quote() {
        let err = parse_command(r#"query "oops"#).unwrap_err();
        assert!(err.to_string().contains("Failed to parse command string"));
    }

    #[test]
    fn test_parse_flag_missing_value() {
        // A value-taking flag in final position must be rejected, not ignored.
        let err = parse_command(r#"query "TODO" --lang"#).unwrap_err();
        assert_eq!(err.to_string(), "--lang requires a value");

        let err = parse_command(r#"query "TODO" -g"#).unwrap_err();
        assert_eq!(err.to_string(), "--glob requires a value");
    }

    #[test]
    fn test_parse_non_numeric_limit() {
        let err = parse_command(r#"query "TODO" --limit many"#).unwrap_err();
        assert_eq!(err.to_string(), "--limit must be a number");
    }

    #[test]
    fn test_parse_ignores_unknown_flags() {
        // Unknown tokens are skipped one at a time; later flags still apply.
        let parsed = parse_command(r#"query "TODO" --bogus --symbols"#).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert!(parsed.symbols);
    }

    #[test]
    fn test_to_query_filter() {
        let cmd = r#"query "TODO" --symbols --lang rust --limit 10"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(filter.language, Some(Language::Rust));
        assert!(filter.symbols_mode);
        assert_eq!(filter.limit, Some(10));
    }

    #[test]
    fn test_to_query_filter_with_kind() {
        let cmd = r#"query "parse" --kind function"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert!(filter.symbols_mode); // kind implies symbols mode
        assert!(matches!(filter.kind, Some(SymbolKind::Function)));
    }

    #[test]
    fn test_to_query_filter_all_overrides_limit() {
        let parsed = parse_command(r#"query "TODO" --limit 5 --all"#).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(parsed.limit, Some(5)); // parsed value is kept...
        assert_eq!(filter.limit, None); // ...but --all removes it from the filter
    }

    #[test]
    fn test_to_query_filter_unknown_language() {
        // Language names are validated at conversion time, not at parse time.
        let parsed = parse_command(r#"query "TODO" --lang cobol"#).unwrap();
        let err = parsed.to_query_filter().unwrap_err();

        assert_eq!(err.to_string(), "Unknown language: cobol");
    }
}
489}