use crate::reader::Reader;
use crate::{naming, parser::SourceTree, GgsqlError, Result};
use std::collections::HashSet;
use tree_sitter::Node;
#[derive(Debug, Clone)]
pub struct CteDefinition {
pub name: String,
pub body: String,
pub column_aliases: Vec<String>,
}
pub fn extract_ctes(source_tree: &SourceTree) -> Vec<CteDefinition> {
let root = source_tree.root();
source_tree
.find_nodes(&root, "(cte_definition) @cte")
.into_iter()
.filter_map(|node| parse_cte_definition(&node, source_tree.source))
.collect()
}
fn parse_cte_definition(node: &Node, source: &str) -> Option<CteDefinition> {
let mut name: Option<String> = None;
let mut column_aliases: Vec<String> = Vec::new();
let mut body_start: Option<usize> = None;
let mut body_end: Option<usize> = None;
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"identifier" => {
if name.is_none() {
name = Some(get_node_text(&child, source).to_string());
} else {
column_aliases.push(get_node_text(&child, source).to_string());
}
}
"select_statement" | "subquery_body" | "with_statement" => {
body_start = Some(child.start_byte());
body_end = Some(child.end_byte());
}
_ => {}
}
}
match (name, body_start, body_end) {
(Some(n), Some(start), Some(end)) => {
let body = source[start..end].to_string();
Some(CteDefinition {
name: n,
body,
column_aliases,
})
}
_ => None,
}
}
pub(crate) fn get_node_text<'a>(node: &Node, source: &'a str) -> &'a str {
&source[node.start_byte()..node.end_byte()]
}
pub fn transform_cte_references(sql: &str, cte_names: &HashSet<String>) -> String {
if cte_names.is_empty() {
return sql.to_string();
}
let mut result = sql.to_string();
for cte_name in cte_names {
let temp_table_name = naming::quote_ident(&naming::cte_table(cte_name));
let patterns = [
(
format!(r"(?i)(\bFROM\s+){}(\s|,|\)|$)", regex::escape(cte_name)),
format!("${{1}}{}${{2}}", temp_table_name),
),
(
format!(r"(?i)(\bJOIN\s+){}(\s|,|\)|$)", regex::escape(cte_name)),
format!("${{1}}{}${{2}}", temp_table_name),
),
(
format!(
r"(?i)\b{}(\.[a-zA-Z_][a-zA-Z0-9_]*)",
regex::escape(cte_name)
),
format!("{}${{1}}", temp_table_name),
),
];
for (pattern, replacement) in patterns {
if let Ok(re) = regex::Regex::new(&pattern) {
result = re.replace_all(&result, replacement.as_str()).to_string();
}
}
}
result
}
pub fn materialize_ctes(ctes: &[CteDefinition], reader: &dyn Reader) -> Result<HashSet<String>> {
let mut materialized = HashSet::new();
for cte in ctes {
let transformed_body = transform_cte_references(&cte.body, &materialized);
let temp_table_name = naming::cte_table(&cte.name);
let statements = reader.dialect().create_or_replace_temp_table_sql(
&temp_table_name,
&cte.column_aliases,
&transformed_body,
);
for stmt in &statements {
reader.execute_sql(stmt).map_err(|e| {
GgsqlError::ReaderError(format!("Failed to materialize CTE '{}': {}", cte.name, e))
})?;
}
materialized.insert(cte.name.clone());
}
Ok(materialized)
}
pub fn split_with_query(source_tree: &SourceTree) -> Option<(String, String)> {
let root = source_tree.root();
let with_node = source_tree.find_node(&root, "(with_statement) @with")?;
let mut cursor = with_node.walk();
let mut last_cte_end: Option<usize> = None;
let mut tail_node = None;
let mut seen_cte = false;
for child in with_node.children(&mut cursor) {
match child.kind() {
"cte_definition" => {
seen_cte = true;
last_cte_end = Some(child.end_byte());
}
"select_statement" if seen_cte => {
tail_node = Some((child, false));
break;
}
"from_statement" if seen_cte => {
tail_node = Some((child, true));
break;
}
_ => {}
}
}
let cte_prefix = source_tree.source[with_node.start_byte()..last_cte_end?].to_string();
let (node, is_from) = tail_node?;
let trailing = if is_from {
format!("SELECT * {}", source_tree.get_text(&node))
} else {
source_tree.get_text(&node)
};
Some((cte_prefix, trailing))
}
pub fn transform_global_sql(
source_tree: &SourceTree,
materialized_ctes: &HashSet<String>,
) -> (Vec<String>, Option<String>) {
let root = source_tree.root();
let side_effect_stmts = r#"
(sql_statement
[(create_statement)
(insert_statement)
(update_statement)
(delete_statement)] @stmt)
"#;
let side_effects: Vec<String> = source_tree
.find_texts(&root, side_effect_stmts)
.into_iter()
.map(|s| transform_cte_references(s.trim(), materialized_ctes))
.filter(|s| !s.is_empty())
.collect();
let select_sql = split_with_query(source_tree)
.map(|(_, select)| select)
.or_else(|| {
source_tree.find_text(&root, "(sql_statement (select_statement) @select)")
});
if let Some(select_sql) = select_sql {
return (
side_effects,
Some(transform_cte_references(&select_sql, materialized_ctes)),
);
}
if !has_executable_sql(source_tree) {
return (vec![], None);
}
let viz_from_query = source_tree
.find_text(
&root,
r#"(visualise_statement (from_clause (table_ref) @table))"#,
)
.map(|table| {
let q = format!("SELECT * FROM {}", table);
transform_cte_references(&q, materialized_ctes)
});
if !side_effects.is_empty() || viz_from_query.is_some() {
(side_effects, viz_from_query)
} else {
let query = source_tree
.extract_sql()
.map(|s| transform_cte_references(&s, materialized_ctes));
(vec![], query)
}
}
pub fn has_executable_sql(source_tree: &SourceTree) -> bool {
let root = source_tree.root();
let direct_statements = r#"
(sql_statement
[(select_statement)
(create_statement)
(insert_statement)
(update_statement)
(delete_statement)
(from_statement)] @stmt)
"#;
if source_tree.find_node(&root, direct_statements).is_some() {
return true;
}
if split_with_query(source_tree).is_some() {
return true;
}
let visualise_from = r#"
(visualise_statement
(from_clause) @from)
"#;
if source_tree.find_node(&root, visualise_from).is_some() {
return true;
}
false
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_ctes_single() {
let sql = "WITH sales AS (SELECT * FROM raw_sales) SELECT * FROM sales";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].name, "sales");
assert!(ctes[0].body.contains("SELECT * FROM raw_sales"));
}
#[test]
fn test_extract_ctes_multiple() {
let sql = "WITH
sales AS (SELECT * FROM raw_sales),
targets AS (SELECT * FROM goals)
SELECT * FROM sales";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);
assert_eq!(ctes.len(), 2);
assert_eq!(ctes[0].name, "sales");
assert_eq!(ctes[1].name, "targets");
}
#[test]
fn test_extract_ctes_with_column_aliases() {
let sql = "WITH t(value, label) AS (SELECT * FROM (VALUES (70, 'Target'))) SELECT * FROM t";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].name, "t");
assert_eq!(ctes[0].column_aliases, vec!["value", "label"]);
}
#[test]
fn test_extract_ctes_without_column_aliases() {
let sql = "WITH sales AS (SELECT * FROM raw_sales) SELECT * FROM sales";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);
assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].name, "sales");
assert!(ctes[0].column_aliases.is_empty());
}
#[test]
fn test_extract_ctes_none() {
let sql = "SELECT * FROM sales WHERE year = 2024";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);
assert!(ctes.is_empty());
}
#[test]
fn test_transform_cte_references() {
let test_cases: Vec<(
&str,
Vec<&str>,
Vec<&str>, // strings that should be in result
Option<&str>, // exact match (if result should equal this)
)> = vec![
(
"SELECT * FROM sales WHERE year = 2024",
vec!["sales"],
vec!["FROM \"__ggsql_cte_sales_", "__\" WHERE year = 2024"],
None,
),
(
"SELECT sales.date, targets.revenue FROM sales JOIN targets ON sales.id = targets.id",
vec!["sales", "targets"],
vec![
"FROM \"__ggsql_cte_sales_",
"JOIN \"__ggsql_cte_targets_",
"__ggsql_cte_sales_", "__ggsql_cte_targets_", ],
None,
),
(
"WHERE sales.date > '2024-01-01' AND sales.revenue > 100",
vec!["sales"],
vec!["__ggsql_cte_sales_"],
None,
),
(
"SELECT * FROM other_table",
vec!["sales"],
vec![],
Some("SELECT * FROM other_table"),
),
(
"SELECT * FROM sales",
vec![],
vec![],
Some("SELECT * FROM sales"),
),
(
"SELECT wholesale.date FROM wholesale",
vec!["sales"],
vec![],
Some("SELECT wholesale.date FROM wholesale"),
),
];
for (sql, cte_names_vec, expected_contains, exact_match) in test_cases {
let cte_names: HashSet<String> = cte_names_vec.iter().map(|s| s.to_string()).collect();
let result = transform_cte_references(sql, &cte_names);
if let Some(expected) = exact_match {
assert_eq!(result, expected, "SQL '{}' should remain unchanged", sql);
} else {
for expected in &expected_contains {
assert!(
result.contains(expected),
"Result '{}' should contain '{}' for SQL '{}'",
result,
expected,
sql
);
}
if !cte_names_vec.is_empty() {
assert!(
result.contains(naming::session_id()),
"Result should contain session UUID"
);
}
}
}
}
#[test]
fn test_split_with_query_basic() {
let sql = "WITH cte AS (SELECT * FROM x) SELECT * FROM cte";
let source_tree = SourceTree::new(sql).unwrap();
let (prefix, select) = split_with_query(&source_tree).unwrap();
assert_eq!(prefix, "WITH cte AS (SELECT * FROM x)");
assert_eq!(select, "SELECT * FROM cte");
}
#[test]
fn test_split_with_query_multiple_ctes() {
let sql = "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a JOIN b";
let source_tree = SourceTree::new(sql).unwrap();
let (prefix, select) = split_with_query(&source_tree).unwrap();
assert_eq!(prefix, "WITH a AS (SELECT 1), b AS (SELECT 2)");
assert_eq!(select, "SELECT * FROM a JOIN b");
}
#[test]
fn test_split_with_query_nested_subquery() {
let sql = "WITH cte AS (SELECT * FROM (SELECT 1)) SELECT * FROM cte";
let source_tree = SourceTree::new(sql).unwrap();
let (prefix, select) = split_with_query(&source_tree).unwrap();
assert_eq!(prefix, "WITH cte AS (SELECT * FROM (SELECT 1))");
assert_eq!(select, "SELECT * FROM cte");
}
#[test]
fn test_split_with_query_string_with_select_keyword() {
let sql = "WITH cte AS (SELECT 'SELECT' AS col) SELECT * FROM cte";
let source_tree = SourceTree::new(sql).unwrap();
let (prefix, select) = split_with_query(&source_tree).unwrap();
assert_eq!(prefix, "WITH cte AS (SELECT 'SELECT' AS col)");
assert_eq!(select, "SELECT * FROM cte");
}
#[test]
fn test_split_with_query_string_with_parens() {
let sql = "WITH cte AS (SELECT '()' AS col) SELECT * FROM cte";
let source_tree = SourceTree::new(sql).unwrap();
let (prefix, select) = split_with_query(&source_tree).unwrap();
assert_eq!(prefix, "WITH cte AS (SELECT '()' AS col)");
assert_eq!(select, "SELECT * FROM cte");
}
#[test]
fn test_split_with_query_not_a_with() {
let sql = "SELECT * FROM x";
let source_tree = SourceTree::new(sql).unwrap();
assert!(split_with_query(&source_tree).is_none());
}
#[test]
fn test_split_with_query_no_trailing_select() {
let sql = "WITH cte AS (SELECT 1) VISUALISE DRAW point";
let source_tree = SourceTree::new(sql).unwrap();
assert!(split_with_query(&source_tree).is_none());
}
#[test]
fn test_split_with_query_stat_transform_output() {
let sql = "WITH __stat_src__ AS (SELECT x FROM data), \
__binned__ AS (SELECT x, COUNT(*) AS count FROM __stat_src__ GROUP BY x) \
SELECT *, count * 1.0 / SUM(count) OVER () AS density FROM __binned__";
let source_tree = SourceTree::new(sql).unwrap();
let (prefix, select) = split_with_query(&source_tree).unwrap();
assert!(prefix.starts_with("WITH __stat_src__"));
assert!(prefix.contains("__binned__"));
assert!(prefix.ends_with(")"));
assert!(select.starts_with("SELECT *"));
assert!(select.contains("density"));
}
}