reflex/semantic/
executor.rs

1//! Command parser and query executor for semantic queries
2
3use anyhow::{Context, Result};
4use std::collections::HashSet;
5
6use crate::cache::CacheManager;
7use crate::models::{FileGroupedResult, Language, SymbolKind};
8use crate::query::{QueryEngine, QueryFilter};
9
10use super::schema::QueryCommand;
11
12/// Parse a command string into query parameters
13///
14/// The command string should be in the format:
15/// `query "pattern" [flags...]`
16///
17/// Example: `query "TODO" --symbols --lang rust`
18pub fn parse_command(command: &str) -> Result<ParsedCommand> {
19    // Parse the command using shell-words to handle quoted strings
20    let parts = shell_words::split(command)
21        .context("Failed to parse command string")?;
22
23    if parts.is_empty() {
24        anyhow::bail!("Empty command string");
25    }
26
27    // First word should be "query"
28    if parts[0] != "query" {
29        anyhow::bail!("Command must start with 'query', got '{}'", parts[0]);
30    }
31
32    if parts.len() < 2 {
33        anyhow::bail!("Missing search pattern in query command");
34    }
35
36    // Second word is the pattern
37    let pattern = parts[1].clone();
38
39    // Parse remaining arguments as flags
40    let mut parsed = ParsedCommand {
41        pattern,
42        symbols: false,
43        lang: None,
44        kind: None,
45        use_ast: false,
46        use_regex: false,
47        limit: None,
48        offset: None,
49        expand: false,
50        file: None,
51        exact: false,
52        contains: false,
53        glob: Vec::new(),
54        exclude: Vec::new(),
55        paths: false,
56        all: false,
57        force: false,
58        dependencies: false,
59        count: false,
60    };
61
62    let mut i = 2;
63    while i < parts.len() {
64        match parts[i].as_str() {
65            "--symbols" | "-s" => {
66                parsed.symbols = true;
67                i += 1;
68            }
69            "--lang" | "-l" => {
70                if i + 1 >= parts.len() {
71                    anyhow::bail!("--lang requires a value");
72                }
73                parsed.lang = Some(parts[i + 1].clone());
74                i += 2;
75            }
76            "--kind" | "-k" => {
77                if i + 1 >= parts.len() {
78                    anyhow::bail!("--kind requires a value");
79                }
80                parsed.kind = Some(parts[i + 1].clone());
81                i += 2;
82            }
83            "--ast" => {
84                parsed.use_ast = true;
85                i += 1;
86            }
87            "--regex" | "-r" => {
88                parsed.use_regex = true;
89                i += 1;
90            }
91            "--limit" | "-n" => {
92                if i + 1 >= parts.len() {
93                    anyhow::bail!("--limit requires a value");
94                }
95                let limit_val: usize = parts[i + 1].parse()
96                    .context("--limit must be a number")?;
97                parsed.limit = Some(limit_val);
98                i += 2;
99            }
100            "--offset" | "-o" => {
101                if i + 1 >= parts.len() {
102                    anyhow::bail!("--offset requires a value");
103                }
104                let offset_val: usize = parts[i + 1].parse()
105                    .context("--offset must be a number")?;
106                parsed.offset = Some(offset_val);
107                i += 2;
108            }
109            "--expand" => {
110                parsed.expand = true;
111                i += 1;
112            }
113            "--file" | "-f" => {
114                if i + 1 >= parts.len() {
115                    anyhow::bail!("--file requires a value");
116                }
117                parsed.file = Some(parts[i + 1].clone());
118                i += 2;
119            }
120            "--exact" => {
121                parsed.exact = true;
122                i += 1;
123            }
124            "--contains" => {
125                parsed.contains = true;
126                i += 1;
127            }
128            "--glob" | "-g" => {
129                if i + 1 >= parts.len() {
130                    anyhow::bail!("--glob requires a value");
131                }
132                parsed.glob.push(parts[i + 1].clone());
133                i += 2;
134            }
135            "--exclude" | "-x" => {
136                if i + 1 >= parts.len() {
137                    anyhow::bail!("--exclude requires a value");
138                }
139                parsed.exclude.push(parts[i + 1].clone());
140                i += 2;
141            }
142            "--paths" | "-p" => {
143                parsed.paths = true;
144                i += 1;
145            }
146            "--all" | "-a" => {
147                parsed.all = true;
148                i += 1;
149            }
150            "--force" => {
151                parsed.force = true;
152                i += 1;
153            }
154            "--dependencies" => {
155                parsed.dependencies = true;
156                i += 1;
157            }
158            "--count" | "-c" => {
159                parsed.count = true;
160                i += 1;
161            }
162            unknown => {
163                log::debug!("Ignoring unknown flag: {}", unknown);
164                i += 1;
165            }
166        }
167    }
168
169    Ok(parsed)
170}
171
/// Pattern and flags extracted from a `query "pattern" [flags...]` command.
///
/// Produced by `parse_command`; every field mirrors one command-line flag.
/// `Default` yields an empty pattern with every flag unset, and
/// `PartialEq`/`Eq` allow parsed commands to be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ParsedCommand {
    /// The search pattern: the token that follows `query`.
    pub pattern: String,
    /// Set by `--symbols` / `-s`.
    pub symbols: bool,
    /// Raw value of `--lang` / `-l`; validated in `to_query_filter`.
    pub lang: Option<String>,
    /// Raw value of `--kind` / `-k`; when present, symbol mode is implied.
    pub kind: Option<String>,
    /// Set by `--ast`.
    pub use_ast: bool,
    /// Set by `--regex` / `-r`.
    pub use_regex: bool,
    /// Value of `--limit` / `-n`; dropped when `all` is set.
    pub limit: Option<usize>,
    /// Value of `--offset` / `-o`.
    pub offset: Option<usize>,
    /// Set by `--expand`.
    pub expand: bool,
    /// Value of `--file` / `-f`; forwarded as the filter's `file_pattern`.
    pub file: Option<String>,
    /// Set by `--exact`.
    pub exact: bool,
    /// Set by `--contains`; forwarded as the filter's `use_contains`.
    pub contains: bool,
    /// Every `--glob` / `-g` value, in the order given.
    pub glob: Vec<String>,
    /// Every `--exclude` / `-x` value, in the order given.
    pub exclude: Vec<String>,
    /// Set by `--paths` / `-p`; forwarded as the filter's `paths_only`.
    pub paths: bool,
    /// Set by `--all` / `-a`; removes any result limit.
    pub all: bool,
    /// Set by `--force`.
    pub force: bool,
    /// Set by `--dependencies`; forwarded as `include_dependencies`.
    pub dependencies: bool,
    /// Set by `--count` / `-c`; read by `execute_queries` to detect
    /// count-only mode.
    pub count: bool,
}
195
196impl ParsedCommand {
197    /// Convert to QueryFilter
198    pub fn to_query_filter(&self) -> Result<QueryFilter> {
199        // Parse language
200        let language = if let Some(lang_str) = &self.lang {
201            match lang_str.to_lowercase().as_str() {
202                "rust" | "rs" => Some(Language::Rust),
203                "python" | "py" => Some(Language::Python),
204                "javascript" | "js" => Some(Language::JavaScript),
205                "typescript" | "ts" => Some(Language::TypeScript),
206                "vue" => Some(Language::Vue),
207                "svelte" => Some(Language::Svelte),
208                "go" => Some(Language::Go),
209                "java" => Some(Language::Java),
210                "php" => Some(Language::PHP),
211                "c" => Some(Language::C),
212                "cpp" | "c++" => Some(Language::Cpp),
213                "csharp" | "cs" | "c#" => Some(Language::CSharp),
214                "ruby" | "rb" => Some(Language::Ruby),
215                "kotlin" | "kt" => Some(Language::Kotlin),
216                "swift" => Some(Language::Swift),
217                "zig" => Some(Language::Zig),
218                _ => anyhow::bail!("Unknown language: {}", lang_str),
219            }
220        } else {
221            None
222        };
223
224        // Parse symbol kind
225        let kind = if let Some(kind_str) = &self.kind {
226            // Capitalize first letter for parsing
227            let capitalized = {
228                let mut chars = kind_str.chars();
229                match chars.next() {
230                    None => String::new(),
231                    Some(first) => first.to_uppercase()
232                        .chain(chars.flat_map(|c| c.to_lowercase()))
233                        .collect()
234                }
235            };
236
237            let parsed_kind: SymbolKind = capitalized.parse()
238                .ok()
239                .or_else(|| {
240                    log::debug!("Treating '{}' as unknown symbol kind", kind_str);
241                    Some(SymbolKind::Unknown(kind_str.to_string()))
242                })
243                .context("Failed to parse symbol kind")?;
244
245            Some(parsed_kind)
246        } else {
247            None
248        };
249
250        // Symbol mode is enabled if --symbols flag OR --kind is specified
251        let symbols_mode = self.symbols || self.kind.is_some();
252
253        // Handle --all flag (unlimited results)
254        let limit = if self.all {
255            None
256        } else {
257            self.limit
258        };
259
260        Ok(QueryFilter {
261            language,
262            kind,
263            use_ast: self.use_ast,
264            use_regex: self.use_regex,
265            limit,
266            symbols_mode,
267            expand: self.expand,
268            file_pattern: self.file.clone(),
269            exact: self.exact,
270            use_contains: self.contains,
271            timeout_secs: 30, // Default timeout
272            glob_patterns: self.glob.clone(),
273            exclude_patterns: self.exclude.clone(),
274            paths_only: self.paths,
275            offset: self.offset,
276            force: self.force,
277            suppress_output: true, // Suppress output for programmatic use
278            include_dependencies: self.dependencies,
279            ..Default::default()
280        })
281    }
282}
283
284/// Execute multiple queries with ordering and merging
285///
286/// Queries are executed in order based on their `order` field.
287/// Results are merged based on the `merge` flag - only queries with `merge: true`
288/// contribute to the final result set.
289///
290/// Results are deduplicated by (file_path, start_line, end_line) to avoid duplicates
291/// across multiple queries.
292///
293/// Returns a tuple of (merged results, total count across all queries, count_only mode).
294/// If count_only is true, all queries had --count flag and only the count should be displayed.
295pub async fn execute_queries(
296    queries: Vec<QueryCommand>,
297    cache: &CacheManager,
298) -> Result<(Vec<FileGroupedResult>, usize, bool)> {
299    if queries.is_empty() {
300        return Ok((Vec::new(), 0, false));
301    }
302
303    // Sort queries by order field
304    let mut sorted_queries = queries.clone();
305    sorted_queries.sort_by_key(|q| q.order);
306
307    log::info!("Executing {} queries in order", sorted_queries.len());
308
309    let mut merged_results: Vec<FileGroupedResult> = Vec::new();
310    let mut seen_matches: HashSet<(String, usize, usize)> = HashSet::new();
311    let mut total_count: usize = 0;
312    let mut all_count_only = true;
313
314    // Create a single QueryEngine and reuse it for all queries
315    // This avoids redundant cache validation and SQLite connection overhead
316    let engine = QueryEngine::new(cache.clone());
317
318    for query_cmd in sorted_queries {
319        log::debug!("Executing query {}: {}", query_cmd.order, query_cmd.command);
320
321        // Parse command
322        let parsed = parse_command(&query_cmd.command)
323            .with_context(|| format!("Failed to parse query command: {}", query_cmd.command))?;
324
325        // Track if this query has --count flag
326        if !parsed.count {
327            all_count_only = false;
328        }
329
330        // Convert to QueryFilter
331        let filter = parsed.to_query_filter()?;
332
333        // Execute query (reusing the same engine)
334        let response = engine.search_with_metadata(&parsed.pattern, filter)
335            .with_context(|| format!("Failed to execute query: {}", query_cmd.command))?;
336
337        // Always accumulate total count from all queries
338        total_count += response.pagination.total;
339
340        log::debug!(
341            "Query {} returned {} file groups, {} total matches (merge={})",
342            query_cmd.order,
343            response.results.len(),
344            response.pagination.total,
345            query_cmd.merge
346        );
347
348        // If merge is true, add results to merged set (with deduplication)
349        if query_cmd.merge {
350            for file_group in response.results {
351                // Find or create file group in merged results
352                let file_path = file_group.path.clone();
353
354                let existing_group = merged_results.iter_mut()
355                    .find(|g| g.path == file_path);
356
357                if let Some(group) = existing_group {
358                    // Add matches to existing group (deduplicate)
359                    for match_result in file_group.matches {
360                        let key = (
361                            file_path.clone(),
362                            match_result.span.start_line,
363                            match_result.span.end_line,
364                        );
365
366                        if !seen_matches.contains(&key) {
367                            seen_matches.insert(key);
368                            group.matches.push(match_result);
369                        }
370                    }
371                } else {
372                    // Create new group
373                    for match_result in &file_group.matches {
374                        let key = (
375                            file_path.clone(),
376                            match_result.span.start_line,
377                            match_result.span.end_line,
378                        );
379                        seen_matches.insert(key);
380                    }
381
382                    merged_results.push(file_group);
383                }
384            }
385        }
386    }
387
388    log::info!(
389        "Merged results: {} file groups, {} unique matches, {} total count (count_only={})",
390        merged_results.len(),
391        seen_matches.len(),
392        total_count,
393        all_count_only
394    );
395
396    Ok((merged_results, total_count, all_count_only))
397}
398
// Unit tests for command parsing and for the conversion to `QueryFilter`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_simple_query() {
        let cmd = r#"query "TODO""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert!(!parsed.symbols);
        assert!(parsed.lang.is_none());
    }

    #[test]
    fn test_parse_query_with_flags() {
        let cmd = r#"query "extract_symbols" --symbols --lang rust"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "extract_symbols");
        assert!(parsed.symbols);
        assert_eq!(parsed.lang, Some("rust".to_string()));
    }

    #[test]
    fn test_parse_query_with_short_flags() {
        let cmd = r#"query "TODO" -s -l py -n 3 -o 6 -c"#;
        let parsed = parse_command(cmd).unwrap();

        assert!(parsed.symbols);
        assert_eq!(parsed.lang, Some("py".to_string()));
        assert_eq!(parsed.limit, Some(3));
        assert_eq!(parsed.offset, Some(6));
        assert!(parsed.count);
    }

    #[test]
    fn test_parse_query_with_kind() {
        let cmd = r#"query "main" --kind function --lang rust"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "main");
        assert_eq!(parsed.kind, Some("function".to_string()));
        assert_eq!(parsed.lang, Some("rust".to_string()));
    }

    #[test]
    fn test_parse_query_with_glob() {
        let cmd = r#"query "TODO" --glob "src/**/*.rs" --glob "tests/**/*.rs""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert_eq!(parsed.glob.len(), 2);
        assert_eq!(parsed.glob[0], "src/**/*.rs");
        assert_eq!(parsed.glob[1], "tests/**/*.rs");
    }

    #[test]
    fn test_parse_query_with_exclude() {
        let cmd = r#"query "config" --exclude "target/**" --exclude "*.gen.rs""#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "config");
        assert_eq!(parsed.exclude.len(), 2);
        // Values must be kept verbatim and in the order given.
        assert_eq!(parsed.exclude[0], "target/**");
        assert_eq!(parsed.exclude[1], "*.gen.rs");
    }

    #[test]
    fn test_parse_ignores_unknown_flags() {
        let cmd = r#"query "TODO" --bogus --symbols"#;
        let parsed = parse_command(cmd).unwrap();

        assert_eq!(parsed.pattern, "TODO");
        assert!(parsed.symbols);
    }

    #[test]
    fn test_parse_invalid_command() {
        let cmd = r#"search "pattern""#;
        let result = parse_command(cmd);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("must start with 'query'"));
    }

    #[test]
    fn test_parse_empty_command() {
        let cmd = "";
        let result = parse_command(cmd);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_missing_pattern() {
        let result = parse_command("query");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing search pattern"));
    }

    #[test]
    fn test_parse_flag_missing_value() {
        let result = parse_command(r#"query "TODO" --lang"#);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("--lang requires a value"));
    }

    #[test]
    fn test_parse_limit_not_a_number() {
        let result = parse_command(r#"query "TODO" --limit ten"#);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("--limit must be a number"));
    }

    #[test]
    fn test_to_query_filter() {
        let cmd = r#"query "TODO" --symbols --lang rust --limit 10"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(filter.language, Some(Language::Rust));
        assert!(filter.symbols_mode);
        assert_eq!(filter.limit, Some(10));
    }

    #[test]
    fn test_to_query_filter_all_overrides_limit() {
        let cmd = r#"query "TODO" --limit 10 --all"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert_eq!(filter.limit, None);
    }

    #[test]
    fn test_to_query_filter_unknown_language() {
        let cmd = r#"query "TODO" --lang klingon"#;
        let parsed = parse_command(cmd).unwrap();

        // `.err()` avoids requiring `Debug` on the success type.
        let err = parsed
            .to_query_filter()
            .err()
            .expect("unknown language must be rejected");
        assert!(err.to_string().contains("Unknown language: klingon"));
    }

    #[test]
    fn test_to_query_filter_with_kind() {
        let cmd = r#"query "parse" --kind function"#;
        let parsed = parse_command(cmd).unwrap();
        let filter = parsed.to_query_filter().unwrap();

        assert!(filter.symbols_mode); // kind implies symbols mode
        assert!(matches!(filter.kind, Some(SymbolKind::Function)));
    }
}