pub mod statement_dependencies;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, CTE};
/// Summary of a parsed `SELECT` statement, as produced by [`analyze_query`].
#[derive(Serialize, Deserialize, Debug)]
pub struct QueryAnalysis {
/// Whether the query is considered valid; `analyze_query` always sets `true`.
pub valid: bool,
/// Statement kind; `analyze_query` always reports `"SELECT"`.
pub query_type: String,
/// `true` when the top-level select list contains a plain `*` item.
pub has_star: bool,
/// One entry per plain `*` item in the top-level select list.
pub star_locations: Vec<StarLocation>,
/// Table named in the `FROM` clause; empty when the source is a subquery or absent.
pub tables: Vec<String>,
/// Distinct plain column names from the top-level select list, in first-seen order.
pub columns: Vec<String>,
/// Per-CTE analysis, in declaration order.
pub ctes: Vec<CteAnalysis>,
/// Description of the `FROM` source, if any.
pub from_clause: Option<FromClauseInfo>,
/// Summary of the `WHERE` clause, if present.
pub where_clause: Option<WhereClauseInfo>,
/// Analysis errors; `analyze_query` in this file never adds entries.
pub errors: Vec<String>,
}
/// Position of a `*` select item.
///
/// NOTE(review): `analyze_query` currently fills `line`/`column` with the fixed
/// placeholder values 1 and 8 rather than real source positions.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StarLocation {
/// Line of the `*` (currently a placeholder value).
pub line: usize,
/// Column of the `*` (currently a placeholder value).
pub column: usize,
/// Where the star was found; `analyze_query` uses `"main_query"`.
pub context: String,
}
/// Summary of a single CTE, as produced by `analyze_cte`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CteAnalysis {
/// CTE name as declared in the `WITH` clause.
pub name: String,
/// CTE kind label: `"Standard"`, `"WEB"` or `"FILE"`.
pub cte_type: String,
/// Start line; `analyze_cte` currently always sets the placeholder `1`.
pub start_line: usize,
/// End line; `analyze_cte` currently always sets the placeholder `1`.
pub end_line: usize,
/// Start offset; `analyze_cte` currently always sets `0`.
pub start_offset: usize,
/// End offset; `analyze_cte` currently always sets `0`.
pub end_offset: usize,
/// `true` when a standard CTE's select list contains a plain `*` item
/// (always `false` for WEB and FILE CTEs).
pub has_star: bool,
/// Column names; `analyze_cte` currently leaves this empty.
pub columns: Vec<String>,
/// Request settings, present only for WEB CTEs.
pub web_config: Option<WebCteConfig>,
}
/// Request settings of a `WEB` CTE, as summarized by `analyze_cte`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WebCteConfig {
/// Target URL copied from the CTE.
pub url: String,
/// HTTP method rendered with `Debug`, or `"GET"` when the CTE specifies none.
pub method: String,
/// Header key/value pairs copied from the CTE.
pub headers: Vec<(String, String)>,
/// Response format rendered with `Debug`, if the CTE specifies one.
pub format: Option<String>,
}
/// Description of the `FROM` source of the main query.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FromClauseInfo {
/// `"table"` for a named table or `"subquery"` for a derived table.
pub source_type: String,
/// Table name when `source_type` is `"table"`; `None` for subqueries.
pub name: Option<String>,
}
/// Summary of the main query's `WHERE` clause.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WhereClauseInfo {
/// Always `true` when built by `analyze_query` (the struct is only created
/// when a `WHERE` clause exists).
pub present: bool,
/// Distinct column names found in the conditions, in first-seen order.
/// At most one name per condition: the first column `extract_column_from_expr`
/// encounters.
pub columns_referenced: Vec<String>,
}
/// NOTE(review): not constructed in this file. The field names suggest it reports
/// the result of expanding `*` into explicit columns — confirm against its producer.
#[derive(Serialize, Deserialize, Debug)]
pub struct ColumnExpansion {
/// Query text presumably before expansion.
pub original_query: String,
/// Query text presumably after expansion.
pub expanded_query: String,
/// Columns presumably substituted for `*`.
pub columns: Vec<ColumnInfo>,
/// Presumably the number of `*` items that were expanded — TODO confirm.
pub expansion_count: usize,
/// Presumably CTE name -> that CTE's column names — verify against the producer.
pub cte_columns: HashMap<String, Vec<String>>,
}
/// A column name paired with a type label (used by `ColumnExpansion::columns`).
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ColumnInfo {
/// Column name.
pub name: String,
/// Type label as a string; its vocabulary is not defined in this file.
pub data_type: String,
}
/// Query context for a cursor position, as returned by [`find_query_context`].
#[derive(Serialize, Deserialize, Debug)]
pub struct QueryContext {
/// `"CTE"` when the position falls inside a CTE, otherwise `"main_query"`.
pub context_type: String,
/// Name of the enclosing CTE; `None` for the main query.
pub cte_name: Option<String>,
/// Declaration index of the enclosing CTE; `None` for the main query.
pub cte_index: Option<usize>,
/// Estimated bounds of the query containing the position.
pub query_bounds: QueryBounds,
/// Estimated bounds of the whole statement for CTE contexts; `None` for the main query.
pub parent_query_bounds: Option<QueryBounds>,
/// `true` for the main query and for standard CTEs; `false` for WEB/FILE CTEs.
pub can_execute_independently: bool,
}
/// Line and offset span of a query fragment.
///
/// NOTE(review): `find_query_context` fills these with heuristic line numbers
/// and always sets both offsets to 0.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct QueryBounds {
/// First line of the span.
pub start_line: usize,
/// Last line of the span (inclusive, as used by `find_query_context`).
pub end_line: usize,
/// Start offset of the span.
pub start_offset: usize,
/// End offset of the span.
pub end_offset: usize,
}
pub fn analyze_query(ast: &SelectStatement, _sql: &str) -> QueryAnalysis {
let mut analysis = QueryAnalysis {
valid: true,
query_type: "SELECT".to_string(),
has_star: false,
star_locations: vec![],
tables: vec![],
columns: vec![],
ctes: vec![],
from_clause: None,
where_clause: None,
errors: vec![],
};
for cte in &ast.ctes {
analysis.ctes.push(analyze_cte(cte));
}
for item in &ast.select_items {
if matches!(item, SelectItem::Star { .. }) {
analysis.has_star = true;
analysis.star_locations.push(StarLocation {
line: 1, column: 8,
context: "main_query".to_string(),
});
}
}
if let Some(ref table) = ast.from_table {
let table_name: String = table.clone();
analysis.tables.push(table_name.clone());
analysis.from_clause = Some(FromClauseInfo {
source_type: "table".to_string(),
name: Some(table_name),
});
} else if ast.from_subquery.is_some() {
analysis.from_clause = Some(FromClauseInfo {
source_type: "subquery".to_string(),
name: None,
});
}
if let Some(ref where_clause) = ast.where_clause {
let mut columns = vec![];
for condition in &where_clause.conditions {
if let Some(col) = extract_column_from_expr(&condition.expr) {
if !columns.contains(&col) {
columns.push(col);
}
}
}
analysis.where_clause = Some(WhereClauseInfo {
present: true,
columns_referenced: columns,
});
}
for item in &ast.select_items {
if let SelectItem::Column {
column: col_ref, ..
} = item
{
if !analysis.columns.contains(&col_ref.name) {
analysis.columns.push(col_ref.name.clone());
}
}
}
analysis
}
/// Summarizes a single CTE into a [`CteAnalysis`].
///
/// Standard CTEs report whether their select list has a plain `*`; WEB CTEs
/// carry their request settings. Line/offset fields are placeholders
/// (1, 1, 0, 0) and `columns` is left empty. The type labels keep their mixed
/// casing (`"Standard"` vs `"WEB"`/`"FILE"`) because callers compare them.
fn analyze_cte(cte: &CTE) -> CteAnalysis {
    let (kind, has_star, web_config) = match &cte.cte_type {
        CTEType::Standard(stmt) => {
            let star = stmt
                .select_items
                .iter()
                .any(|item| matches!(item, SelectItem::Star { .. }));
            ("Standard", star, None)
        }
        CTEType::Web(web_spec) => {
            // A WEB CTE without an explicit method is reported as GET.
            let method = web_spec
                .method
                .as_ref()
                .map_or_else(|| "GET".to_string(), |m| format!("{:?}", m));
            let config = WebCteConfig {
                url: web_spec.url.clone(),
                method,
                headers: web_spec.headers.clone(),
                format: web_spec.format.as_ref().map(|f| format!("{:?}", f)),
            };
            ("WEB", false, Some(config))
        }
        CTEType::File(_) => ("FILE", false, None),
    };

    CteAnalysis {
        name: cte.name.clone(),
        cte_type: kind.to_string(),
        start_line: 1,
        end_line: 1,
        start_offset: 0,
        end_offset: 0,
        has_star,
        columns: Vec::new(),
        web_config,
    }
}
/// Returns the name of the first column reference found in `expr`, if any.
///
/// Searches a binary operation's left operand before its right one, and only
/// the first argument of a function call; other expression kinds yield `None`.
fn extract_column_from_expr(expr: &crate::sql::parser::ast::SqlExpression) -> Option<String> {
    use crate::sql::parser::ast::SqlExpression as Expr;

    match expr {
        Expr::Column(column) => Some(column.name.clone()),
        Expr::BinaryOp { left, right, .. } => match extract_column_from_expr(left) {
            found @ Some(_) => found,
            // Only look at the right operand when the left one had no column.
            None => extract_column_from_expr(right),
        },
        Expr::FunctionCall { args, .. } => match args.first() {
            Some(first_arg) => extract_column_from_expr(first_arg),
            None => None,
        },
        _ => None,
    }
}
pub fn extract_cte(ast: &SelectStatement, cte_name: &str) -> Option<String> {
let mut target_index = None;
for (idx, cte) in ast.ctes.iter().enumerate() {
if cte.name == cte_name {
target_index = Some(idx);
break;
}
}
let target_index = target_index?;
let mut parts = vec![];
parts.push("WITH".to_string());
for (idx, cte) in ast.ctes.iter().enumerate() {
if idx > target_index {
break; }
let prefix = if idx == 0 { "" } else { "," };
match &cte.cte_type {
CTEType::Standard(stmt) => {
parts.push(format!("{} {} AS (", prefix, cte.name));
parts.push(indent_query(&format_select_statement(stmt), 2));
parts.push(")".to_string());
}
CTEType::Web(web_spec) => {
parts.push(format!("{} WEB {} AS (", prefix, cte.name));
parts.push(format!(" URL '{}'", web_spec.url));
if let Some(ref m) = web_spec.method {
parts.push(format!(" METHOD {:?}", m));
}
if let Some(ref f) = web_spec.format {
parts.push(format!(" FORMAT {:?}", f));
}
if let Some(cache) = web_spec.cache_seconds {
parts.push(format!(" CACHE {}", cache));
}
if !web_spec.headers.is_empty() {
parts.push(" HEADERS (".to_string());
for (i, (k, v)) in web_spec.headers.iter().enumerate() {
let comma = if i < web_spec.headers.len() - 1 {
","
} else {
""
};
parts.push(format!(" '{}': '{}'{}", k, v, comma));
}
parts.push(" )".to_string());
}
for (field_name, file_path) in &web_spec.form_files {
parts.push(format!(" FORM_FILE '{}' '{}'", field_name, file_path));
}
for (field_name, value) in &web_spec.form_fields {
let trimmed_value = value.trim();
if (trimmed_value.starts_with('{') && trimmed_value.ends_with('}'))
|| (trimmed_value.starts_with('[') && trimmed_value.ends_with(']'))
{
parts.push(format!(
" FORM_FIELD '{}' $JSON${}$JSON$",
field_name, trimmed_value
));
} else {
parts.push(format!(" FORM_FIELD '{}' '{}'", field_name, value));
}
}
if let Some(ref b) = web_spec.body {
let trimmed_body = b.trim();
if (trimmed_body.starts_with('{') && trimmed_body.ends_with('}'))
|| (trimmed_body.starts_with('[') && trimmed_body.ends_with(']'))
{
parts.push(format!(" BODY $JSON${}$JSON$", trimmed_body));
} else {
parts.push(format!(" BODY '{}'", b));
}
}
if let Some(ref jp) = web_spec.json_path {
parts.push(format!(" JSON_PATH '{}'", jp));
}
parts.push(")".to_string());
}
CTEType::File(file_spec) => {
parts.push(format!("{} {} AS (", prefix, cte.name));
parts.push(format!(" FILE PATH '{}'", file_spec.path));
if file_spec.recursive {
parts.push(" RECURSIVE".to_string());
}
if let Some(ref g) = file_spec.glob {
parts.push(format!(" GLOB '{}'", g));
}
if let Some(d) = file_spec.max_depth {
parts.push(format!(" MAX_DEPTH {}", d));
}
if let Some(m) = file_spec.max_files {
parts.push(format!(" MAX_FILES {}", m));
}
if file_spec.follow_links {
parts.push(" FOLLOW_LINKS".to_string());
}
if file_spec.include_hidden {
parts.push(" INCLUDE_HIDDEN".to_string());
}
parts.push(")".to_string());
}
}
}
parts.push(format!("SELECT * FROM {}", cte_name));
Some(parts.join("\n"))
}
/// Prefixes every line of `query` with `spaces` space characters.
///
/// Lines are split with `str::lines` (so `\r\n` is normalized and a trailing
/// newline is dropped) and re-joined with `\n`.
fn indent_query(query: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    let mut indented = String::with_capacity(query.len());
    for (n, line) in query.lines().enumerate() {
        if n > 0 {
            indented.push('\n');
        }
        indented.push_str(&pad);
        indented.push_str(line);
    }
    indented
}
fn format_cte_as_query(cte: &CTE) -> String {
match &cte.cte_type {
CTEType::Standard(stmt) => {
format_select_statement(stmt)
}
CTEType::Web(web_spec) => {
let mut parts = vec![
format!("WITH WEB {} AS (", cte.name),
format!(" URL '{}'", web_spec.url),
];
if let Some(ref m) = web_spec.method {
parts.push(format!(" METHOD {:?}", m));
}
if !web_spec.headers.is_empty() {
parts.push(" HEADERS (".to_string());
for (k, v) in &web_spec.headers {
parts.push(format!(" '{}' = '{}'", k, v));
}
parts.push(" )".to_string());
}
if let Some(ref b) = web_spec.body {
parts.push(format!(" BODY '{}'", b));
}
if let Some(ref f) = web_spec.format {
parts.push(format!(" FORMAT {:?}", f));
}
parts.push(")".to_string());
parts.push(format!("SELECT * FROM {}", cte.name));
parts.join("\n")
}
CTEType::File(file_spec) => {
let mut parts = vec![
format!("WITH {} AS (", cte.name),
format!(" FILE PATH '{}'", file_spec.path),
];
if file_spec.recursive {
parts.push(" RECURSIVE".to_string());
}
if let Some(ref g) = file_spec.glob {
parts.push(format!(" GLOB '{}'", g));
}
if let Some(d) = file_spec.max_depth {
parts.push(format!(" MAX_DEPTH {}", d));
}
if let Some(m) = file_spec.max_files {
parts.push(format!(" MAX_FILES {}", m));
}
parts.push(")".to_string());
parts.push(format!("SELECT * FROM {}", cte.name));
parts.join("\n")
}
}
}
/// Renders a [`SelectStatement`] as multi-line SQL text.
///
/// The output contains, joined with newlines:
/// - the select list, or ` *` when it is empty;
/// - `FROM <table>`;
/// - `WHERE` followed by one condition per line;
/// - `LIMIT`.
///
/// Only these parts of the statement are rendered.
///
/// NOTE(review): the AND/OR keyword placed before condition `i` is read from
/// `conditions[i].connector`, defaulting to AND. If the parser instead stores on
/// each condition the operator linking it to the *next* condition, OR clauses
/// would be rendered as AND. Confirm against the parser.
fn format_select_statement(stmt: &SelectStatement) -> String {
    use crate::sql::parser::ast::LogicalOp;

    let mut lines: Vec<String> = vec!["SELECT".to_string()];

    if stmt.select_items.is_empty() {
        lines.push(" *".to_string());
    }
    for (i, item) in stmt.select_items.iter().enumerate() {
        let lead = if i > 0 { " , " } else { " " };
        let text = match item {
            SelectItem::Star { .. } => "*".to_string(),
            SelectItem::StarExclude {
                excluded_columns, ..
            } => format!("* EXCLUDE ({})", excluded_columns.join(", ")),
            SelectItem::Column { column: col, .. } => col.name.to_string(),
            SelectItem::Expression { expr, alias, .. } => {
                format!("{} AS {}", format_expr(expr), alias)
            }
        };
        lines.push(format!("{}{}", lead, text));
    }

    if let Some(table) = &stmt.from_table {
        lines.push(format!("FROM {}", table));
    }

    if let Some(clause) = &stmt.where_clause {
        lines.push("WHERE".to_string());
        for (i, condition) in clause.conditions.iter().enumerate() {
            let expr_text = format_expr(&condition.expr);
            if i == 0 {
                lines.push(format!(" {}", expr_text));
                continue;
            }
            let joiner = match &condition.connector {
                Some(LogicalOp::Or) => "OR",
                Some(LogicalOp::And) | None => "AND",
            };
            lines.push(format!(" {} {}", joiner, expr_text));
        }
    }

    if let Some(limit) = stmt.limit {
        lines.push(format!("LIMIT {}", limit));
    }

    lines.join("\n")
}
/// Renders an expression back to SQL text via the shared AST formatter.
fn format_expr(expr: &crate::sql::parser::ast::SqlExpression) -> String {
    use crate::sql::parser::ast_formatter::format_expression as render;
    render(expr)
}
/// Locates the query context (a CTE or the main query) for a cursor position.
///
/// NOTE(review): positions are estimated, not measured:
/// - CTE `idx` is assumed to occupy lines `1 + 5*idx ..= 5 + 5*idx`;
/// - the whole statement is assumed to span lines 1..=100;
/// - offsets are always 0 and `_column` is ignored.
///
/// Returns the first CTE whose estimated range contains `line`, otherwise the
/// main-query context.
pub fn find_query_context(ast: &SelectStatement, line: usize, _column: usize) -> QueryContext {
    const LINES_PER_CTE: usize = 5;
    const ASSUMED_LAST_LINE: usize = 100;

    let statement_bounds = QueryBounds {
        start_line: 1,
        end_line: ASSUMED_LAST_LINE,
        start_offset: 0,
        end_offset: 0,
    };

    let cte_context = ast.ctes.iter().enumerate().find_map(|(idx, cte)| {
        let first = 1 + idx * LINES_PER_CTE;
        let last = first + (LINES_PER_CTE - 1);
        if !(first..=last).contains(&line) {
            return None;
        }
        Some(QueryContext {
            context_type: "CTE".to_string(),
            cte_name: Some(cte.name.clone()),
            cte_index: Some(idx),
            query_bounds: QueryBounds {
                start_line: first,
                end_line: last,
                start_offset: 0,
                end_offset: 0,
            },
            parent_query_bounds: Some(statement_bounds.clone()),
            // Only plain SELECT CTEs can run on their own.
            can_execute_independently: matches!(cte.cte_type, CTEType::Standard(_)),
        })
    });

    cte_context.unwrap_or_else(|| QueryContext {
        context_type: "main_query".to_string(),
        cte_name: None,
        cte_index: None,
        query_bounds: statement_bounds,
        parent_query_bounds: None,
        can_execute_independently: true,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::recursive_parser::Parser;

    #[test]
    fn test_analyze_simple_query() {
        let sql = "SELECT * FROM trades WHERE price > 100";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        let analysis = analyze_query(&ast, sql);
        assert!(analysis.valid);
        assert_eq!(analysis.query_type, "SELECT");
        assert!(analysis.has_star);
        assert_eq!(analysis.star_locations.len(), 1);
        assert_eq!(analysis.tables, vec!["trades"]);
    }

    /// The FROM table and the presence of a WHERE clause are both reported.
    #[test]
    fn test_analyze_reports_from_and_where() {
        let sql = "SELECT * FROM trades WHERE price > 100";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        let analysis = analyze_query(&ast, sql);
        let from = analysis
            .from_clause
            .expect("FROM clause should be reported");
        assert_eq!(from.source_type, "table");
        assert_eq!(from.name.as_deref(), Some("trades"));
        let where_info = analysis
            .where_clause
            .expect("WHERE clause should be reported");
        assert!(where_info.present);
    }

    /// A select list without `*` reports no stars.
    #[test]
    fn test_analyze_without_star() {
        let sql = "SELECT symbol FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        let analysis = analyze_query(&ast, sql);
        assert!(!analysis.has_star);
        assert!(analysis.star_locations.is_empty());
    }

    #[test]
    fn test_analyze_cte_query() {
        let sql = "WITH trades AS (SELECT * FROM raw_trades) SELECT symbol FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        let analysis = analyze_query(&ast, sql);
        assert!(analysis.valid);
        assert_eq!(analysis.ctes.len(), 1);
        assert_eq!(analysis.ctes[0].name, "trades");
        assert_eq!(analysis.ctes[0].cte_type, "Standard");
        assert!(analysis.ctes[0].has_star);
    }

    #[test]
    fn test_extract_cte() {
        let sql =
            "WITH trades AS (SELECT * FROM raw_trades WHERE price > 100) SELECT * FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        let extracted = extract_cte(&ast, "trades").unwrap();
        assert!(extracted.contains("SELECT"));
        assert!(extracted.contains("raw_trades"));
        assert!(extracted.contains("price > 100"));
    }

    /// Asking for a CTE that does not exist yields `None` rather than a query.
    #[test]
    fn test_extract_cte_unknown_name_returns_none() {
        let sql = "WITH trades AS (SELECT * FROM raw_trades) SELECT * FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        assert!(extract_cte(&ast, "missing").is_none());
    }

    /// CTEs declared after the target are not included in the extracted query.
    #[test]
    fn test_extract_cte_drops_later_ctes() {
        let sql = "WITH a AS (SELECT * FROM t1), b AS (SELECT * FROM a) SELECT * FROM b";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();
        let extracted = extract_cte(&ast, "a").unwrap();
        assert!(extracted.ends_with("SELECT * FROM a"));
        assert!(!extracted.contains(" b AS ("));
    }

    /// Lines inside the first CTE's estimated range map to that CTE; others to the main query.
    #[test]
    fn test_find_query_context_cte_and_main() {
        let sql = "WITH trades AS (SELECT * FROM raw_trades) SELECT symbol FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();

        let in_cte = find_query_context(&ast, 1, 0);
        assert_eq!(in_cte.context_type, "CTE");
        assert_eq!(in_cte.cte_name.as_deref(), Some("trades"));
        assert_eq!(in_cte.cte_index, Some(0));
        assert!(in_cte.can_execute_independently);

        let in_main = find_query_context(&ast, 50, 0);
        assert_eq!(in_main.context_type, "main_query");
        assert!(in_main.cte_index.is_none());
        assert!(in_main.parent_query_bounds.is_none());
    }
}