// sql_cli/analysis/mod.rs

//! Analysis module - provides structured query analysis for IDE/plugin integration.
//! This lets tools understand SQL structure without manually parsing query text.
3
4pub mod statement_dependencies;
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, CTE};
10
/// Comprehensive query analysis result.
///
/// Serializable summary of a parsed query: star usage, referenced tables and
/// columns, CTEs, and FROM/WHERE details. Built by `analyze_query`.
#[derive(Serialize, Deserialize, Debug)]
pub struct QueryAnalysis {
    /// Whether the query is syntactically valid.
    /// `analyze_query` always reports `true`; it is handed an already-parsed AST.
    pub valid: bool,
    /// Type of query: "SELECT", "CTE", etc.
    /// (`analyze_query` currently always reports "SELECT").
    pub query_type: String,
    /// Whether the main query's select list contains a `*` item
    pub has_star: bool,
    /// Locations of SELECT * in query (one entry per `*` item in the main select list)
    pub star_locations: Vec<StarLocation>,
    /// Tables referenced in query (currently only the main query's FROM table, if any)
    pub tables: Vec<String>,
    /// Columns explicitly named in the main select list, de-duplicated in first-seen order
    pub columns: Vec<String>,
    /// CTEs in query, in the order they appear in the AST
    pub ctes: Vec<CteAnalysis>,
    /// FROM clause information (`None` when there is neither a FROM table nor a FROM subquery)
    pub from_clause: Option<FromClauseInfo>,
    /// WHERE clause information (`None` when the query has no WHERE clause)
    pub where_clause: Option<WhereClauseInfo>,
    /// Parse/validation errors (`analyze_query` leaves this empty)
    pub errors: Vec<String>,
}
35
/// Location of SELECT * in query
///
/// NOTE: `analyze_query` does not yet have real source positions; it fills in
/// the placeholder position line 1, column 8 for every star it reports.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StarLocation {
    /// Line number (1-indexed)
    pub line: usize,
    /// Column number (1-indexed)
    pub column: usize,
    /// Context: "main_query", "cte:name", "subquery"
    /// (`analyze_query` currently only emits "main_query")
    pub context: String,
}
46
/// CTE analysis information
///
/// One entry per CTE in the query, produced by `analyze_cte`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CteAnalysis {
    /// CTE name
    pub name: String,
    /// CTE type: "Standard", "WEB", or "FILE" (the values `analyze_cte` emits)
    pub cte_type: String,
    /// Start line (will be populated when parser tracks positions; currently always 1)
    pub start_line: usize,
    /// End line (will be populated when parser tracks positions; currently always 1)
    pub end_line: usize,
    /// Start byte offset in query (currently always 0)
    pub start_offset: usize,
    /// End byte offset in query (currently always 0)
    pub end_offset: usize,
    /// Whether this CTE contains SELECT * (only checked for "Standard" CTEs)
    pub has_star: bool,
    /// Columns produced by this CTE (if known; currently always empty)
    pub columns: Vec<String>,
    /// WEB CTE configuration (`Some` only for "WEB" CTEs)
    pub web_config: Option<WebCteConfig>,
}
69
/// WEB CTE configuration details
///
/// Populated by `analyze_cte` from the parsed WEB CTE specification.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WebCteConfig {
    /// URL endpoint
    pub url: String,
    /// HTTP method: the `Debug` rendering of the parsed method, or "GET" when
    /// the CTE does not specify one
    pub method: String,
    /// Headers (if any) as (name, value) pairs
    pub headers: Vec<(String, String)>,
    /// Format (CSV, JSON, etc.): the `Debug` rendering of the parsed format,
    /// `None` when the CTE does not specify one
    pub format: Option<String>,
}
82
/// FROM clause analysis
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FromClauseInfo {
    /// Type: "table", "subquery", "function", "cte"
    /// (`analyze_query` currently only emits "table" and "subquery")
    pub source_type: String,
    /// Table/CTE name (if applicable; `None` for a subquery source)
    pub name: Option<String>,
}
91
/// WHERE clause analysis
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WhereClauseInfo {
    /// Whether WHERE clause is present
    /// (`analyze_query` only builds this struct when one exists, so it sets `true`)
    pub present: bool,
    /// Columns referenced in WHERE, de-duplicated in first-seen order.
    /// NOTE: `analyze_query` records at most one column per condition.
    pub columns_referenced: Vec<String>,
}
100
/// Column expansion result
///
/// Pairs a query containing `SELECT *` with its expanded form. Nothing in this
/// module constructs it; it is a serializable result type for callers.
#[derive(Serialize, Deserialize, Debug)]
pub struct ColumnExpansion {
    /// Original query with SELECT *
    pub original_query: String,
    /// Expanded query with actual column names
    pub expanded_query: String,
    /// Column information
    pub columns: Vec<ColumnInfo>,
    /// Number of columns expanded
    pub expansion_count: usize,
    /// Columns from CTEs (cte_name -> columns)
    pub cte_columns: HashMap<String, Vec<String>>,
}
115
/// Column information
///
/// Name/type pair describing a single column; used as the element type of
/// `ColumnExpansion::columns`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ColumnInfo {
    /// Column name
    pub name: String,
    /// Data type, carried as free-form text
    pub data_type: String,
}
124
/// Query context at a specific position
///
/// Returned by `find_query_context` to describe which part of the query a
/// cursor position falls in.
#[derive(Serialize, Deserialize, Debug)]
pub struct QueryContext {
    /// Type: "main_query", "CTE", "subquery"
    /// (`find_query_context` currently only emits "main_query" and "CTE")
    pub context_type: String,
    /// CTE name (if in CTE)
    pub cte_name: Option<String>,
    /// CTE index (0-based)
    pub cte_index: Option<usize>,
    /// Query bounds
    pub query_bounds: QueryBounds,
    /// Parent query bounds (if in CTE/subquery)
    pub parent_query_bounds: Option<QueryBounds>,
    /// Whether this can be executed independently
    /// (`find_query_context` reports `true` for the main query and for
    /// Standard CTEs, `false` for other CTE kinds)
    pub can_execute_independently: bool,
}
141
/// Query boundary information
///
/// NOTE: `find_query_context` currently fills these with estimated line
/// numbers and zero byte offsets, because it has no real position data yet.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct QueryBounds {
    /// Start line (1-indexed)
    pub start_line: usize,
    /// End line (1-indexed)
    pub end_line: usize,
    /// Start byte offset
    pub start_offset: usize,
    /// End byte offset
    pub end_offset: usize,
}
154
155/// Analyze a SQL query and return structured information
156pub fn analyze_query(ast: &SelectStatement, _sql: &str) -> QueryAnalysis {
157    let mut analysis = QueryAnalysis {
158        valid: true,
159        query_type: "SELECT".to_string(),
160        has_star: false,
161        star_locations: vec![],
162        tables: vec![],
163        columns: vec![],
164        ctes: vec![],
165        from_clause: None,
166        where_clause: None,
167        errors: vec![],
168    };
169
170    // Analyze CTEs
171    for cte in &ast.ctes {
172        analysis.ctes.push(analyze_cte(cte));
173    }
174
175    // Check for SELECT * in main query
176    for item in &ast.select_items {
177        if matches!(item, SelectItem::Star { .. }) {
178            analysis.has_star = true;
179            analysis.star_locations.push(StarLocation {
180                line: 1, // TODO: Track actual line when parser supports it
181                column: 8,
182                context: "main_query".to_string(),
183            });
184        }
185    }
186
187    // Extract table references
188    if let Some(ref table) = ast.from_table {
189        let table_name: String = table.clone();
190        analysis.tables.push(table_name.clone());
191        analysis.from_clause = Some(FromClauseInfo {
192            source_type: "table".to_string(),
193            name: Some(table_name),
194        });
195    } else if ast.from_subquery.is_some() {
196        analysis.from_clause = Some(FromClauseInfo {
197            source_type: "subquery".to_string(),
198            name: None,
199        });
200    }
201
202    // Analyze WHERE clause
203    if let Some(ref where_clause) = ast.where_clause {
204        let mut columns = vec![];
205        // TODO: Extract column references from WHERE conditions
206        for condition in &where_clause.conditions {
207            // Extract column names from condition.expr
208            // This is simplified - full implementation would walk the expression tree
209            if let Some(col) = extract_column_from_expr(&condition.expr) {
210                if !columns.contains(&col) {
211                    columns.push(col);
212                }
213            }
214        }
215
216        analysis.where_clause = Some(WhereClauseInfo {
217            present: true,
218            columns_referenced: columns,
219        });
220    }
221
222    // Extract explicitly named columns from SELECT
223    for item in &ast.select_items {
224        if let SelectItem::Column {
225            column: col_ref, ..
226        } = item
227        {
228            if !analysis.columns.contains(&col_ref.name) {
229                analysis.columns.push(col_ref.name.clone());
230            }
231        }
232    }
233
234    analysis
235}
236
237fn analyze_cte(cte: &CTE) -> CteAnalysis {
238    let cte_type_str = match &cte.cte_type {
239        CTEType::Standard(_) => "Standard",
240        CTEType::Web(_) => "WEB",
241        CTEType::File(_) => "FILE",
242    };
243
244    let mut has_star = false;
245    let mut web_config = None;
246
247    match &cte.cte_type {
248        CTEType::Standard(stmt) => {
249            // Check if CTE query has SELECT *
250            for item in &stmt.select_items {
251                if matches!(item, SelectItem::Star { .. }) {
252                    has_star = true;
253                    break;
254                }
255            }
256        }
257        CTEType::Web(web_spec) => {
258            let method_str = match &web_spec.method {
259                Some(m) => format!("{:?}", m),
260                None => "GET".to_string(),
261            };
262            web_config = Some(WebCteConfig {
263                url: web_spec.url.clone(),
264                method: method_str,
265                headers: web_spec.headers.clone(),
266                format: web_spec.format.as_ref().map(|f| format!("{:?}", f)),
267            });
268        }
269        CTEType::File(_) => {
270            // FILE CTE: metadata listing. No web config, no star check needed.
271        }
272    }
273
274    CteAnalysis {
275        name: cte.name.clone(),
276        cte_type: cte_type_str.to_string(),
277        start_line: 1, // TODO: Track when parser supports it
278        end_line: 1,   // TODO: Track when parser supports it
279        start_offset: 0,
280        end_offset: 0,
281        has_star,
282        columns: vec![], // TODO: Extract column names
283        web_config,
284    }
285}
286
287fn extract_column_from_expr(expr: &crate::sql::parser::ast::SqlExpression) -> Option<String> {
288    use crate::sql::parser::ast::SqlExpression;
289
290    match expr {
291        SqlExpression::Column(col_ref) => Some(col_ref.name.clone()),
292        SqlExpression::BinaryOp { left, right, .. } => {
293            // Try left first, then right
294            extract_column_from_expr(left).or_else(|| extract_column_from_expr(right))
295        }
296        SqlExpression::FunctionCall { args, .. } => {
297            // Extract from first argument
298            args.first().and_then(|arg| extract_column_from_expr(arg))
299        }
300        _ => None,
301    }
302}
303
304/// Extract a specific CTE as a testable query
305/// Returns ALL CTEs up to and including the target, then SELECT * FROM target
306/// This ensures the query is executable since CTEs depend on previous CTEs
307pub fn extract_cte(ast: &SelectStatement, cte_name: &str) -> Option<String> {
308    // Find the target CTE index
309    let mut target_index = None;
310    for (idx, cte) in ast.ctes.iter().enumerate() {
311        if cte.name == cte_name {
312            target_index = Some(idx);
313            break;
314        }
315    }
316
317    let target_index = target_index?;
318
319    // Build query with all CTEs up to and including target
320    let mut parts = vec![];
321
322    // Add WITH clause with all CTEs up to target
323    parts.push("WITH".to_string());
324
325    for (idx, cte) in ast.ctes.iter().enumerate() {
326        if idx > target_index {
327            break; // Stop after target CTE
328        }
329
330        // Add comma separator for CTEs after the first
331        let prefix = if idx == 0 { "" } else { "," };
332
333        match &cte.cte_type {
334            CTEType::Standard(stmt) => {
335                parts.push(format!("{} {} AS (", prefix, cte.name));
336                parts.push(indent_query(&format_select_statement(stmt), 2));
337                parts.push(")".to_string());
338            }
339            CTEType::Web(web_spec) => {
340                parts.push(format!("{} WEB {} AS (", prefix, cte.name));
341                parts.push(format!("  URL '{}'", web_spec.url));
342
343                if let Some(ref m) = web_spec.method {
344                    parts.push(format!("  METHOD {:?}", m));
345                }
346
347                if let Some(ref f) = web_spec.format {
348                    parts.push(format!("  FORMAT {:?}", f));
349                }
350
351                if let Some(cache) = web_spec.cache_seconds {
352                    parts.push(format!("  CACHE {}", cache));
353                }
354
355                if !web_spec.headers.is_empty() {
356                    parts.push("  HEADERS (".to_string());
357                    for (i, (k, v)) in web_spec.headers.iter().enumerate() {
358                        let comma = if i < web_spec.headers.len() - 1 {
359                            ","
360                        } else {
361                            ""
362                        };
363                        parts.push(format!("    '{}': '{}'{}", k, v, comma));
364                    }
365                    parts.push("  )".to_string());
366                }
367
368                // Add FORM_FILE entries
369                for (field_name, file_path) in &web_spec.form_files {
370                    parts.push(format!("  FORM_FILE '{}' '{}'", field_name, file_path));
371                }
372
373                // Add FORM_FIELD entries (handle JSON formatting)
374                for (field_name, value) in &web_spec.form_fields {
375                    let trimmed_value = value.trim();
376                    // Check if value looks like JSON
377                    if (trimmed_value.starts_with('{') && trimmed_value.ends_with('}'))
378                        || (trimmed_value.starts_with('[') && trimmed_value.ends_with(']'))
379                    {
380                        // Use $JSON$ delimiters for JSON values
381                        parts.push(format!(
382                            "  FORM_FIELD '{}' $JSON${}$JSON$",
383                            field_name, trimmed_value
384                        ));
385                    } else {
386                        // Regular value with single quotes
387                        parts.push(format!("  FORM_FIELD '{}' '{}'", field_name, value));
388                    }
389                }
390
391                if let Some(ref b) = web_spec.body {
392                    // Check if body is JSON
393                    let trimmed_body = b.trim();
394                    if (trimmed_body.starts_with('{') && trimmed_body.ends_with('}'))
395                        || (trimmed_body.starts_with('[') && trimmed_body.ends_with(']'))
396                    {
397                        parts.push(format!("  BODY $JSON${}$JSON$", trimmed_body));
398                    } else {
399                        parts.push(format!("  BODY '{}'", b));
400                    }
401                }
402
403                if let Some(ref jp) = web_spec.json_path {
404                    parts.push(format!("  JSON_PATH '{}'", jp));
405                }
406
407                parts.push(")".to_string());
408            }
409            CTEType::File(file_spec) => {
410                parts.push(format!("{} {} AS (", prefix, cte.name));
411                parts.push(format!("  FILE PATH '{}'", file_spec.path));
412                if file_spec.recursive {
413                    parts.push("  RECURSIVE".to_string());
414                }
415                if let Some(ref g) = file_spec.glob {
416                    parts.push(format!("  GLOB '{}'", g));
417                }
418                if let Some(d) = file_spec.max_depth {
419                    parts.push(format!("  MAX_DEPTH {}", d));
420                }
421                if let Some(m) = file_spec.max_files {
422                    parts.push(format!("  MAX_FILES {}", m));
423                }
424                if file_spec.follow_links {
425                    parts.push("  FOLLOW_LINKS".to_string());
426                }
427                if file_spec.include_hidden {
428                    parts.push("  INCLUDE_HIDDEN".to_string());
429                }
430                parts.push(")".to_string());
431            }
432        }
433    }
434
435    // Add SELECT * FROM target
436    parts.push(format!("SELECT * FROM {}", cte_name));
437
438    Some(parts.join("\n"))
439}
440
/// Prefix every line of `query` with `spaces` space characters.
///
/// Lines are split with `str::lines`, so a trailing newline is dropped and the
/// result never ends with one. Blank lines still receive the indent, and an
/// empty input yields an empty string.
fn indent_query(query: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    let mut indented = String::with_capacity(query.len() + spaces);
    for (n, line) in query.lines().enumerate() {
        // Newlines go between lines only, never after the last one.
        if n > 0 {
            indented.push('\n');
        }
        indented.push_str(&pad);
        indented.push_str(line);
    }
    indented
}
449
450fn format_cte_as_query(cte: &CTE) -> String {
451    match &cte.cte_type {
452        CTEType::Standard(stmt) => {
453            // Format the SELECT statement
454            // This is simplified - could use the AST formatter
455            format_select_statement(stmt)
456        }
457        CTEType::Web(web_spec) => {
458            // Can't execute WEB CTE independently (needs WITH WEB syntax)
459            let mut parts = vec![
460                format!("WITH WEB {} AS (", cte.name),
461                format!("  URL '{}'", web_spec.url),
462            ];
463
464            if let Some(ref m) = web_spec.method {
465                parts.push(format!("  METHOD {:?}", m));
466            }
467
468            if !web_spec.headers.is_empty() {
469                parts.push("  HEADERS (".to_string());
470                for (k, v) in &web_spec.headers {
471                    parts.push(format!("    '{}' = '{}'", k, v));
472                }
473                parts.push("  )".to_string());
474            }
475
476            if let Some(ref b) = web_spec.body {
477                parts.push(format!("  BODY '{}'", b));
478            }
479
480            if let Some(ref f) = web_spec.format {
481                parts.push(format!("  FORMAT {:?}", f));
482            }
483
484            parts.push(")".to_string());
485            parts.push(format!("SELECT * FROM {}", cte.name));
486
487            parts.join("\n")
488        }
489        CTEType::File(file_spec) => {
490            let mut parts = vec![
491                format!("WITH {} AS (", cte.name),
492                format!("  FILE PATH '{}'", file_spec.path),
493            ];
494            if file_spec.recursive {
495                parts.push("  RECURSIVE".to_string());
496            }
497            if let Some(ref g) = file_spec.glob {
498                parts.push(format!("  GLOB '{}'", g));
499            }
500            if let Some(d) = file_spec.max_depth {
501                parts.push(format!("  MAX_DEPTH {}", d));
502            }
503            if let Some(m) = file_spec.max_files {
504                parts.push(format!("  MAX_FILES {}", m));
505            }
506            parts.push(")".to_string());
507            parts.push(format!("SELECT * FROM {}", cte.name));
508            parts.join("\n")
509        }
510    }
511}
512
513fn format_select_statement(stmt: &SelectStatement) -> String {
514    let mut parts = vec!["SELECT".to_string()];
515
516    // SELECT items
517    if stmt.select_items.is_empty() {
518        parts.push("  *".to_string());
519    } else {
520        for (i, item) in stmt.select_items.iter().enumerate() {
521            let prefix = if i == 0 { "    " } else { "  , " };
522            match item {
523                SelectItem::Star { .. } => parts.push(format!("{}*", prefix)),
524                SelectItem::StarExclude {
525                    excluded_columns, ..
526                } => {
527                    parts.push(format!(
528                        "{}* EXCLUDE ({})",
529                        prefix,
530                        excluded_columns.join(", ")
531                    ));
532                }
533                SelectItem::Column { column: col, .. } => {
534                    parts.push(format!("{}{}", prefix, col.name));
535                }
536                SelectItem::Expression { expr, alias, .. } => {
537                    let expr_str = format_expr(expr);
538                    parts.push(format!("{}{} AS {}", prefix, expr_str, alias));
539                }
540            }
541        }
542    }
543
544    // FROM
545    if let Some(ref table) = stmt.from_table {
546        parts.push(format!("FROM {}", table));
547    }
548
549    // WHERE
550    if let Some(ref where_clause) = stmt.where_clause {
551        parts.push("WHERE".to_string());
552        for (i, condition) in where_clause.conditions.iter().enumerate() {
553            let connector = if i > 0 {
554                condition
555                    .connector
556                    .as_ref()
557                    .map(|op| match op {
558                        crate::sql::parser::ast::LogicalOp::And => "AND",
559                        crate::sql::parser::ast::LogicalOp::Or => "OR",
560                    })
561                    .unwrap_or("AND")
562            } else {
563                ""
564            };
565            let expr_str = format_expr(&condition.expr);
566            if i == 0 {
567                parts.push(format!("  {}", expr_str));
568            } else {
569                parts.push(format!("  {} {}", connector, expr_str));
570            }
571        }
572    }
573
574    // LIMIT
575    if let Some(limit) = stmt.limit {
576        parts.push(format!("LIMIT {}", limit));
577    }
578
579    parts.join("\n")
580}
581
/// Format an expression using the centralized AST formatter
/// This ensures consistency with query reformatting
///
/// Thin wrapper: the work is delegated entirely to
/// `ast_formatter::format_expression`, so this module does not carry its own
/// expression-printing logic.
fn format_expr(expr: &crate::sql::parser::ast::SqlExpression) -> String {
    crate::sql::parser::ast_formatter::format_expression(expr)
}
587
588/// Find query context at a specific line:column position
589pub fn find_query_context(ast: &SelectStatement, line: usize, _column: usize) -> QueryContext {
590    // Check if position is within a CTE
591    for (idx, cte) in ast.ctes.iter().enumerate() {
592        // TODO: Use actual line numbers when parser tracks them
593        // For now, assume each CTE is ~5 lines
594        let cte_start = 1 + (idx * 5);
595        let cte_end = cte_start + 4;
596
597        if line >= cte_start && line <= cte_end {
598            return QueryContext {
599                context_type: "CTE".to_string(),
600                cte_name: Some(cte.name.clone()),
601                cte_index: Some(idx),
602                query_bounds: QueryBounds {
603                    start_line: cte_start,
604                    end_line: cte_end,
605                    start_offset: 0,
606                    end_offset: 0,
607                },
608                parent_query_bounds: Some(QueryBounds {
609                    start_line: 1,
610                    end_line: 100, // TODO: Track actual end
611                    start_offset: 0,
612                    end_offset: 0,
613                }),
614                can_execute_independently: matches!(cte.cte_type, CTEType::Standard(_)),
615            };
616        }
617    }
618
619    // Otherwise, in main query
620    QueryContext {
621        context_type: "main_query".to_string(),
622        cte_name: None,
623        cte_index: None,
624        query_bounds: QueryBounds {
625            start_line: 1,
626            end_line: 100, // TODO: Track actual end
627            start_offset: 0,
628            end_offset: 0,
629        },
630        parent_query_bounds: None,
631        can_execute_independently: true,
632    }
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::recursive_parser::Parser;

    /// Parse `sql` and run `analyze_query` on the resulting AST.
    /// Panics if the query does not parse.
    fn analyze(sql: &str) -> QueryAnalysis {
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        analyze_query(&ast, sql)
    }

    /// A plain `SELECT *` reports the star and the FROM table.
    #[test]
    fn test_analyze_simple_query() {
        let report = analyze("SELECT * FROM trades WHERE price > 100");

        assert!(report.valid);
        assert_eq!(report.query_type, "SELECT");
        assert!(report.has_star);
        assert_eq!(report.star_locations.len(), 1);
        assert_eq!(report.tables, vec!["trades"]);
    }

    /// A Standard CTE is listed with its name, kind and star flag.
    #[test]
    fn test_analyze_cte_query() {
        let report =
            analyze("WITH trades AS (SELECT * FROM raw_trades) SELECT symbol FROM trades");

        assert!(report.valid);
        assert_eq!(report.ctes.len(), 1);

        let first_cte = &report.ctes[0];
        assert_eq!(first_cte.name, "trades");
        assert_eq!(first_cte.cte_type, "Standard");
        assert!(first_cte.has_star);
    }

    /// Extracting a CTE keeps its SELECT, source table and WHERE condition.
    #[test]
    fn test_extract_cte() {
        let sql =
            "WITH trades AS (SELECT * FROM raw_trades WHERE price > 100) SELECT * FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();

        let extracted = extract_cte(&ast, "trades").unwrap();

        for expected in ["SELECT", "raw_trades", "price > 100"] {
            assert!(extracted.contains(expected));
        }
    }
}
683}