// qail_core/analyzer/rust_ast/detector.rs

//! Rust AST analyzer using `syn`.
//!
//! Detects QAIL patterns in Rust source code by parsing it into a syntax
//! tree instead of matching raw text with regular expressions.
5
6use std::fs;
7use std::path::Path;
8use syn::visit::Visit;
9use syn::{Expr, ExprCall, ExprMethodCall, Lit, LitStr};
10
11// Use the parent analyzer module's types
12use crate::analyzer::{CodeReference, QueryType};
13
// proc_macro2 span type, used to extract line numbers
15use proc_macro2::Span;
16
/// A QAIL-related match found while walking a Rust syntax tree.
///
/// Table-level matches come from `Qail::<action>(...)` calls and carry a
/// non-empty `table`. Method-call matches leave `table` empty and list the
/// call's string-literal arguments in `columns`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RustPattern {
    /// Table name, or empty for a method-call match.
    pub table: String,
    /// String-literal arguments associated with the match (column names).
    pub columns: Vec<String>,
    /// 1-based source line of the identifier that anchored the match
    /// (the leading `Qail` path segment, or the method name).
    pub line: usize,
    /// Reconstructed summary of the matched call; not a verbatim source excerpt.
    pub snippet: String,
}
27
28/// Visitor that walks Rust AST to find QAIL patterns
29struct QailVisitor {
30    patterns: Vec<RustPattern>,
31    #[allow(dead_code)]
32    source: String,
33}
34
35impl QailVisitor {
36    fn new(source: String) -> Self {
37        Self {
38            patterns: Vec::new(),
39            source,
40        }
41    }
42
43    /// Extract string value from a string literal
44    fn extract_string(lit: &LitStr) -> String {
45        lit.value()
46    }
47
48    /// Approximate line number from span
49    fn line_from_span(&self, span: Span) -> usize {
50        span.start().line
51    }
52
53    /// Extract all string literals from any expression (generic approach)
54    fn extract_strings_from_expr(expr: &Expr) -> Vec<String> {
55        let mut strings = Vec::new();
56        match expr {
57            // Direct string literal
58            Expr::Lit(lit) => {
59                if let Lit::Str(s) = &lit.lit {
60                    strings.push(Self::extract_string(s));
61                }
62            }
63            // Array of strings [\"a\", \"b\", \"c\"]
64            Expr::Array(arr) => {
65                for elem in &arr.elems {
66                    strings.extend(Self::extract_strings_from_expr(elem));
67                }
68            }
69            // Reference &\"string\"
70            Expr::Reference(r) => {
71                strings.extend(Self::extract_strings_from_expr(&r.expr));
72            }
73            _ => {}
74        }
75        strings
76    }
77
78    /// Check if this is a Qail constructor call (generic - detects ALL constructors)
79    fn check_qailcmd_call(&mut self, path: &syn::ExprPath, args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>) {
80        let segments: Vec<_> = path.path.segments.iter().map(|s| s.ident.to_string()).collect();
81        
82        // Match Qail::* where * is any method
83        if segments.len() >= 2 && segments[0] == "Qail" {
84            let action = &segments[1];
85            
86            let mut columns = Vec::new();
87            let mut table = String::new();
88            
89            for arg in args {
90                let extracted = Self::extract_strings_from_expr(arg);
91                if table.is_empty() && !extracted.is_empty() {
92                    table = extracted[0].clone();
93                } else {
94                    columns.extend(extracted);
95                }
96            }
97            
98            if !table.is_empty() {
99                self.patterns.push(RustPattern {
100                    table: table.clone(),
101                    columns,
102                    line: self.line_from_span(path.path.segments.first().map(|s| s.ident.span()).unwrap_or_else(Span::call_site)),
103                    snippet: format!("Qail::{}(\"{}\")", action, table),
104                });
105            }
106        }
107    }
108
109    /// Check method calls for column/table references (generic - captures ALL string arguments)
110    fn check_method_call(&mut self, method: &str, args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>, span: Span) {
111        let mut all_strings = Vec::new();
112        for arg in args {
113            all_strings.extend(Self::extract_strings_from_expr(arg));
114        }
115        
116        // If we found any strings, record this method call
117        if !all_strings.is_empty() {
118            let snippet = if all_strings.len() == 1 {
119                format!(".{}(\"{}\")", method, all_strings[0])
120            } else if all_strings.len() <= 3 {
121                format!(".{}([{}])", method, all_strings.iter().map(|s| format!("\"{}\"", s)).collect::<Vec<_>>().join(", "))
122            } else {
123                format!(".{}([\"{}\" +{}])", method, all_strings[0], all_strings.len() - 1)
124            };
125            
126            self.patterns.push(RustPattern {
127                table: String::new(), // Will be merged with parent
128                columns: all_strings,
129                line: self.line_from_span(span),
130                snippet,
131            });
132        }
133    }
134}
135
136impl<'ast> Visit<'ast> for QailVisitor {
137    fn visit_expr_call(&mut self, node: &'ast ExprCall) {
138        if let Expr::Path(path) = &*node.func {
139            self.check_qailcmd_call(path, &node.args);
140        }
141        // Continue visiting children
142        syn::visit::visit_expr_call(self, node);
143    }
144
145    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
146        let method = node.method.to_string();
147        self.check_method_call(&method, &node.args, node.method.span());
148        // Continue visiting children
149        syn::visit::visit_expr_method_call(self, node);
150    }
151}
152
/// Scanner for QAIL usage in Rust source files and directories.
///
/// It has no fields; its scanning functions are associated functions (no `self`).
#[derive(Debug, Clone, Copy, Default)]
pub struct RustAnalyzer;
155
156impl RustAnalyzer {
157    /// Scan a Rust file for QAIL patterns using AST parsing
158    pub fn scan_file(path: &Path) -> Vec<CodeReference> {
159        let content = match fs::read_to_string(path) {
160            Ok(c) => c,
161            Err(_) => return vec![],
162        };
163
164        let syntax = match syn::parse_file(&content) {
165            Ok(s) => s,
166            Err(_) => return vec![], // Fall back to regex if parse fails
167        };
168
169        let mut visitor = QailVisitor::new(content);
170        visitor.visit_file(&syntax);
171
172        // Post-process: merge column patterns with their preceding table pattern
173        let mut merged_refs: Vec<CodeReference> = Vec::new();
174        let mut current_table = String::new();
175
176        for p in visitor.patterns {
177            if !p.table.is_empty() {
178                // This is a table reference (Qail::get("table"))
179                current_table = p.table.clone();
180                merged_refs.push(CodeReference {
181                    file: path.to_path_buf(),
182                    line: p.line,
183                    table: p.table,
184                    columns: p.columns,
185                    query_type: QueryType::Qail,
186                    snippet: p.snippet,
187                });
188            } else if !current_table.is_empty() {
189                // This is a column reference (.filter("col"), .columns([...]))
190                // Associate it with the current table
191                merged_refs.push(CodeReference {
192                    file: path.to_path_buf(),
193                    line: p.line,
194                    table: current_table.clone(), // <-- Associate with parent table!
195                    columns: p.columns,
196                    query_type: QueryType::Qail,
197                    snippet: p.snippet,
198                });
199            }
200        }
201
202        merged_refs
203    }
204
205    /// Check if this is a Rust project (has Cargo.toml)
206    pub fn is_rust_project(path: &Path) -> bool {
207        let cargo_toml = if path.is_file() {
208            path.parent().map(|p| p.join("Cargo.toml"))
209        } else {
210            Some(path.join("Cargo.toml"))
211        };
212        
213        cargo_toml.map(|p| p.exists()).unwrap_or(false)
214    }
215
216    /// Scan a directory for Rust files
217    pub fn scan_directory(dir: &Path) -> Vec<CodeReference> {
218        let mut refs = Vec::new();
219        Self::scan_dir_recursive(dir, &mut refs);
220        refs
221    }
222
223    fn scan_dir_recursive(dir: &Path, refs: &mut Vec<CodeReference>) {
224        let entries = match fs::read_dir(dir) {
225            Ok(e) => e,
226            Err(_) => return,
227        };
228
229        for entry in entries.flatten() {
230            let path = entry.path();
231
232            // Skip common non-source directories
233            if path.is_dir() {
234                let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
235                if name == "target" || name == ".git" || name == "node_modules" {
236                    continue;
237                }
238                Self::scan_dir_recursive(&path, refs);
239            } else if path.extension().map(|e| e == "rs").unwrap_or(false) {
240                refs.extend(Self::scan_file(&path));
241            }
242        }
243    }
244}
245
246// =============================================================================
247// Raw SQL Detection (for VS Code extension)
248// =============================================================================
249
250/// A raw SQL statement detected in Rust source code
251#[derive(Debug, Clone, serde::Serialize)]
252pub struct RawSqlMatch {
253    /// Line number (1-indexed)
254    pub line: usize,
255    pub column: usize,
256    /// End line number (1-indexed)
257    pub end_line: usize,
258    /// End column number (1-indexed)
259    pub end_column: usize,
260    /// Type of SQL statement
261    pub sql_type: String,
262    /// The raw SQL content
263    pub raw_sql: String,
264    /// Suggested QAIL equivalent
265    pub suggested_qail: String,
266}
267
268/// Visitor that finds raw SQL strings in Rust code
269struct SqlDetectorVisitor {
270    matches: Vec<RawSqlMatch>,
271}
272
273impl SqlDetectorVisitor {
274    fn new() -> Self {
275        Self { matches: Vec::new() }
276    }
277
278    /// Check if a string literal contains SQL
279    fn check_string_literal(&mut self, lit: &LitStr) {
280        let value = lit.value();
281        let upper = value.to_uppercase();
282        
283        let sql_type = if upper.contains("SELECT") && upper.contains("FROM") {
284            "SELECT"
285        } else if upper.contains("INSERT INTO") {
286            "INSERT"
287        } else if upper.contains("UPDATE") && upper.contains("SET") {
288            "UPDATE"
289        } else if upper.contains("DELETE FROM") {
290            "DELETE"
291        } else {
292            return; // Not SQL
293        };
294
295        let span = lit.span();
296        let start = span.start();
297        let end = span.end();
298
299        // The span includes the quotes, so we use the exact positions
300        // But we need to ensure we capture the entire literal including quotes
301        self.matches.push(RawSqlMatch {
302            line: start.line,
303            column: start.column, // 0-indexed, includes opening quote
304            end_line: end.line,
305            end_column: end.column, // 0-indexed, should be after closing quote
306            sql_type: sql_type.to_string(),
307            raw_sql: value.clone(),
308            suggested_qail: super::transformer::sql_to_qail(&value).unwrap_or_else(|_| "// Could not parse SQL".to_string()),
309        });
310    }
311}
312 
313impl<'ast> Visit<'ast> for SqlDetectorVisitor {
314    fn visit_lit(&mut self, lit: &'ast Lit) {
315        if let Lit::Str(lit_str) = lit {
316            self.check_string_literal(lit_str);
317        }
318        syn::visit::visit_lit(self, lit);
319    }
320}
321
322/// Detect raw SQL strings in a Rust source file
323pub fn detect_raw_sql(source: &str) -> Vec<RawSqlMatch> {
324    match syn::parse_file(source) {
325        Ok(syntax) => {
326            let mut visitor = SqlDetectorVisitor::new();
327            visitor.visit_file(&syntax);
328            visitor.matches
329        }
330        Err(_) => Vec::new(),
331    }
332}
333
334/// Detect raw SQL strings in a file by path
335pub fn detect_raw_sql_in_file(path: &Path) -> Vec<RawSqlMatch> {
336    match fs::read_to_string(path) {
337        Ok(source) => detect_raw_sql(&source),
338        Err(_) => Vec::new(),
339    }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_qailcmd_get() {
        let src = r#"
            fn query() {
                let cmd = Qail::get("users")
                    .filter("status", Operator::Eq, "active")
                    .columns(["id", "name", "email"]);
            }
        "#;

        let file = syn::parse_file(src).unwrap();
        let mut visitor = QailVisitor::new(src.to_string());
        visitor.visit_file(&file);
        let found = &visitor.patterns;

        assert!(!found.is_empty());
        // The constructor call yields a table-level pattern for "users".
        assert!(found.iter().any(|p| p.table == "users"));
        // `.filter("status", ...)` yields a pattern listing "status".
        assert!(found.iter().any(|p| p.columns.iter().any(|c| c == "status")));
    }

    #[test]
    fn test_detect_raw_sql() {
        let src = r#"
            fn query() {
                let sql = "SELECT id, name FROM users WHERE status = 'active'";
                sqlx::query(sql);
            }
        "#;

        let found = detect_raw_sql(src);
        assert!(!found.is_empty());

        let first = &found[0];
        assert_eq!(first.sql_type, "SELECT");
        assert!(first.suggested_qail.contains("Qail::get"));
    }

    #[test]
    fn test_generate_cte_qail() {
        let src = r##"
            fn get_insights() {
                let sql = r#"
                    WITH stats AS (
                        SELECT COUNT(*) FILTER (WHERE direction = 'outbound' 
                        AND created_at > NOW() - INTERVAL '24 hours') AS sent
                        FROM messages
                    )
                    SELECT sent FROM stats
                "#;
            }
        "##;

        let found = detect_raw_sql(src);
        assert!(!found.is_empty());

        let qail = &found[0].suggested_qail;
        // A CTE should be emitted as its own variable.
        assert!(
            qail.contains("CTE 'stats'") || qail.contains("stats_cte"),
            "Should generate CTE variable: {}",
            qail
        );
        // The table read inside the CTE must appear in the suggestion.
        assert!(qail.contains("messages"), "Should find source table 'messages': {}", qail);
    }
}
408