qail_core/analyzer/rust_ast/
detector.rs1use std::fs;
7use std::path::Path;
8use syn::visit::Visit;
9use syn::{Expr, ExprCall, ExprMethodCall, Lit, LitStr};
10
11use crate::analyzer::{CodeReference, QueryType};
13
14use proc_macro2::Span;
16
/// One QAIL-related call site recorded while walking a Rust syntax tree.
///
/// A non-empty `table` marks a `Qail::<action>(...)` constructor call; an
/// empty `table` marks a method call whose string arguments were captured
/// and which still needs a table assigned from the surrounding context.
#[derive(Debug)]
#[derive(Clone)]
pub struct RustPattern {
    /// Table name taken from the constructor call, or empty for method calls.
    pub table: String,
    /// String literals found among the call's arguments (excluding the table).
    pub columns: Vec<String>,
    /// Source line of the call, as reported by its span.
    pub line: usize,
    /// Short rendering of the call, e.g. `Qail::get("users")` or `.columns(["id"])`.
    pub snippet: String,
}
27
28struct QailVisitor {
30 patterns: Vec<RustPattern>,
31 #[allow(dead_code)]
32 source: String,
33}
34
35impl QailVisitor {
36 fn new(source: String) -> Self {
37 Self {
38 patterns: Vec::new(),
39 source,
40 }
41 }
42
43 fn extract_string(lit: &LitStr) -> String {
45 lit.value()
46 }
47
48 fn line_from_span(&self, span: Span) -> usize {
50 span.start().line
51 }
52
53 fn extract_strings_from_expr(expr: &Expr) -> Vec<String> {
55 let mut strings = Vec::new();
56 match expr {
57 Expr::Lit(lit) => {
59 if let Lit::Str(s) = &lit.lit {
60 strings.push(Self::extract_string(s));
61 }
62 }
63 Expr::Array(arr) => {
65 for elem in &arr.elems {
66 strings.extend(Self::extract_strings_from_expr(elem));
67 }
68 }
69 Expr::Reference(r) => {
71 strings.extend(Self::extract_strings_from_expr(&r.expr));
72 }
73 _ => {}
74 }
75 strings
76 }
77
78 fn check_qailcmd_call(&mut self, path: &syn::ExprPath, args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>) {
80 let segments: Vec<_> = path.path.segments.iter().map(|s| s.ident.to_string()).collect();
81
82 if segments.len() >= 2 && segments[0] == "Qail" {
84 let action = &segments[1];
85
86 let mut columns = Vec::new();
87 let mut table = String::new();
88
89 for arg in args {
90 let extracted = Self::extract_strings_from_expr(arg);
91 if table.is_empty() && !extracted.is_empty() {
92 table = extracted[0].clone();
93 } else {
94 columns.extend(extracted);
95 }
96 }
97
98 if !table.is_empty() {
99 self.patterns.push(RustPattern {
100 table: table.clone(),
101 columns,
102 line: self.line_from_span(path.path.segments.first().map(|s| s.ident.span()).unwrap_or_else(Span::call_site)),
103 snippet: format!("Qail::{}(\"{}\")", action, table),
104 });
105 }
106 }
107 }
108
109 fn check_method_call(&mut self, method: &str, args: &syn::punctuated::Punctuated<Expr, syn::token::Comma>, span: Span) {
111 let mut all_strings = Vec::new();
112 for arg in args {
113 all_strings.extend(Self::extract_strings_from_expr(arg));
114 }
115
116 if !all_strings.is_empty() {
118 let snippet = if all_strings.len() == 1 {
119 format!(".{}(\"{}\")", method, all_strings[0])
120 } else if all_strings.len() <= 3 {
121 format!(".{}([{}])", method, all_strings.iter().map(|s| format!("\"{}\"", s)).collect::<Vec<_>>().join(", "))
122 } else {
123 format!(".{}([\"{}\" +{}])", method, all_strings[0], all_strings.len() - 1)
124 };
125
126 self.patterns.push(RustPattern {
127 table: String::new(), columns: all_strings,
129 line: self.line_from_span(span),
130 snippet,
131 });
132 }
133 }
134}
135
136impl<'ast> Visit<'ast> for QailVisitor {
137 fn visit_expr_call(&mut self, node: &'ast ExprCall) {
138 if let Expr::Path(path) = &*node.func {
139 self.check_qailcmd_call(path, &node.args);
140 }
141 syn::visit::visit_expr_call(self, node);
143 }
144
145 fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
146 let method = node.method.to_string();
147 self.check_method_call(&method, &node.args, node.method.span());
148 syn::visit::visit_expr_method_call(self, node);
150 }
151}
152
153pub struct RustAnalyzer;
155
156impl RustAnalyzer {
157 pub fn scan_file(path: &Path) -> Vec<CodeReference> {
159 let content = match fs::read_to_string(path) {
160 Ok(c) => c,
161 Err(_) => return vec![],
162 };
163
164 let syntax = match syn::parse_file(&content) {
165 Ok(s) => s,
166 Err(_) => return vec![], };
168
169 let mut visitor = QailVisitor::new(content);
170 visitor.visit_file(&syntax);
171
172 let mut merged_refs: Vec<CodeReference> = Vec::new();
174 let mut current_table = String::new();
175
176 for p in visitor.patterns {
177 if !p.table.is_empty() {
178 current_table = p.table.clone();
180 merged_refs.push(CodeReference {
181 file: path.to_path_buf(),
182 line: p.line,
183 table: p.table,
184 columns: p.columns,
185 query_type: QueryType::Qail,
186 snippet: p.snippet,
187 });
188 } else if !current_table.is_empty() {
189 merged_refs.push(CodeReference {
192 file: path.to_path_buf(),
193 line: p.line,
194 table: current_table.clone(), columns: p.columns,
196 query_type: QueryType::Qail,
197 snippet: p.snippet,
198 });
199 }
200 }
201
202 merged_refs
203 }
204
205 pub fn is_rust_project(path: &Path) -> bool {
207 let cargo_toml = if path.is_file() {
208 path.parent().map(|p| p.join("Cargo.toml"))
209 } else {
210 Some(path.join("Cargo.toml"))
211 };
212
213 cargo_toml.map(|p| p.exists()).unwrap_or(false)
214 }
215
216 pub fn scan_directory(dir: &Path) -> Vec<CodeReference> {
218 let mut refs = Vec::new();
219 Self::scan_dir_recursive(dir, &mut refs);
220 refs
221 }
222
223 fn scan_dir_recursive(dir: &Path, refs: &mut Vec<CodeReference>) {
224 let entries = match fs::read_dir(dir) {
225 Ok(e) => e,
226 Err(_) => return,
227 };
228
229 for entry in entries.flatten() {
230 let path = entry.path();
231
232 if path.is_dir() {
234 let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
235 if name == "target" || name == ".git" || name == "node_modules" {
236 continue;
237 }
238 Self::scan_dir_recursive(&path, refs);
239 } else if path.extension().map(|e| e == "rs").unwrap_or(false) {
240 refs.extend(Self::scan_file(&path));
241 }
242 }
243 }
244}
245
246#[derive(Debug, Clone, serde::Serialize)]
252pub struct RawSqlMatch {
253 pub line: usize,
255 pub column: usize,
256 pub end_line: usize,
258 pub end_column: usize,
260 pub sql_type: String,
262 pub raw_sql: String,
264 pub suggested_qail: String,
266}
267
268struct SqlDetectorVisitor {
270 matches: Vec<RawSqlMatch>,
271}
272
273impl SqlDetectorVisitor {
274 fn new() -> Self {
275 Self { matches: Vec::new() }
276 }
277
278 fn check_string_literal(&mut self, lit: &LitStr) {
280 let value = lit.value();
281 let upper = value.to_uppercase();
282
283 let sql_type = if upper.contains("SELECT") && upper.contains("FROM") {
284 "SELECT"
285 } else if upper.contains("INSERT INTO") {
286 "INSERT"
287 } else if upper.contains("UPDATE") && upper.contains("SET") {
288 "UPDATE"
289 } else if upper.contains("DELETE FROM") {
290 "DELETE"
291 } else {
292 return; };
294
295 let span = lit.span();
296 let start = span.start();
297 let end = span.end();
298
299 self.matches.push(RawSqlMatch {
302 line: start.line,
303 column: start.column, end_line: end.line,
305 end_column: end.column, sql_type: sql_type.to_string(),
307 raw_sql: value.clone(),
308 suggested_qail: super::transformer::sql_to_qail(&value).unwrap_or_else(|_| "// Could not parse SQL".to_string()),
309 });
310 }
311}
312
313impl<'ast> Visit<'ast> for SqlDetectorVisitor {
314 fn visit_lit(&mut self, lit: &'ast Lit) {
315 if let Lit::Str(lit_str) = lit {
316 self.check_string_literal(lit_str);
317 }
318 syn::visit::visit_lit(self, lit);
319 }
320}
321
322pub fn detect_raw_sql(source: &str) -> Vec<RawSqlMatch> {
324 match syn::parse_file(source) {
325 Ok(syntax) => {
326 let mut visitor = SqlDetectorVisitor::new();
327 visitor.visit_file(&syntax);
328 visitor.matches
329 }
330 Err(_) => Vec::new(),
331 }
332}
333
334pub fn detect_raw_sql_in_file(path: &Path) -> Vec<RawSqlMatch> {
336 match fs::read_to_string(path) {
337 Ok(source) => detect_raw_sql(&source),
338 Err(_) => Vec::new(),
339 }
340}
341
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_qailcmd_get() {
        let source = r#"
        fn query() {
            let cmd = Qail::get("users")
                .filter("status", Operator::Eq, "active")
                .columns(["id", "name", "email"]);
        }
        "#;

        let file = syn::parse_file(source).expect("fixture should parse as Rust");
        let mut visitor = QailVisitor::new(source.to_string());
        visitor.visit_file(&file);

        let patterns = &visitor.patterns;
        assert!(!patterns.is_empty());
        // The constructor call supplies the table name...
        assert!(patterns.iter().any(|p| p.table == "users"));
        // ...and string arguments of chained methods are captured as columns.
        assert!(patterns
            .iter()
            .any(|p| p.columns.iter().any(|c| c == "status")));
    }

    #[test]
    fn test_detect_raw_sql() {
        let source = r#"
        fn query() {
            let sql = "SELECT id, name FROM users WHERE status = 'active'";
            sqlx::query(sql);
        }
        "#;

        let found = detect_raw_sql(source);
        assert!(!found.is_empty());
        let first = &found[0];
        assert_eq!(first.sql_type, "SELECT");
        assert!(first.suggested_qail.contains("Qail::get"));
    }

    #[test]
    fn test_generate_cte_qail() {
        let source = r##"
        fn get_insights() {
            let sql = r#"
                WITH stats AS (
                    SELECT COUNT(*) FILTER (WHERE direction = 'outbound'
                        AND created_at > NOW() - INTERVAL '24 hours') AS sent
                    FROM messages
                )
                SELECT sent FROM stats
            "#;
        }
        "##;

        let found = detect_raw_sql(source);
        assert!(!found.is_empty());

        let suggestion = &found[0].suggested_qail;
        assert!(
            suggestion.contains("CTE 'stats'") || suggestion.contains("stats_cte"),
            "Should generate CTE variable: {}",
            suggestion
        );
        assert!(
            suggestion.contains("messages"),
            "Should find source table 'messages': {}",
            suggestion
        );
    }
}
408