use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// A single proposal to move an expression out of a query and into a CTE.
///
/// Values of this type are built by [`ExtractionAnalyzer::analyze`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractionSuggestion {
/// The expression text that was found in the analyzed SQL.
pub expression: String,
/// Which detector flagged the expression.
pub reason: ExtractionReason,
/// Proposed name for the new CTE.
pub suggested_cte_name: String,
/// SQL text proposed as the body of the new CTE.
pub cte_query: String,
/// Column alias meant to stand in for `expression`; `analyze` sets it to the same alias that
/// `cte_query` assigns to the expression.
pub replacement: String,
/// Ranking weight; `analyze` returns suggestions sorted by this value, highest first.
pub complexity_score: u32,
}
/// The category of finding behind an [`ExtractionSuggestion`].
///
/// Only `ComplexCalculation`, `CaseStatement`, `StringManipulation` and `WindowFunction` are
/// produced by the analyzer code in this file; the other variants are declared but not emitted
/// here.
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub enum ExtractionReason {
/// An arithmetic expression (reported by `analyze` with score 10).
ComplexCalculation,
/// Presumably an expression that occurs more than once — not produced in this file.
RepeatedExpression,
/// A window function call with its `OVER (...)` clause (reported with score 20).
WindowFunction,
/// Presumably a nested subquery — not produced in this file.
Subquery,
/// A string-function call (reported with score 12).
StringManipulation,
/// A `CASE ... END` expression (reported with score 15).
CaseStatement,
/// Presumably an aggregate used inside `WHERE` — not produced in this file.
AggregateInWhere,
}
/// A list of CTE definitions together with the query that follows them.
///
/// This is the input of [`CTEOptimizer::optimize_chain`].
#[derive(Debug, Serialize, Deserialize)]
pub struct CTEChain {
/// The CTE definitions. Their order matters to the optimizer: the linear-chain check compares
/// each entry with the one directly before it.
pub ctes: Vec<CTEDefinition>,
/// Text of the final query; the optimizer searches it for CTE names.
pub main_query: String,
}
/// One named CTE inside a [`CTEChain`].
#[derive(Debug, Serialize, Deserialize)]
pub struct CTEDefinition {
/// Name of the CTE; the optimizer looks for this text in the main query and in other CTEs'
/// `dependencies`.
pub name: String,
/// SQL text of the CTE (not inspected by the code in this file).
pub query: String,
/// Names this CTE depends on; compared against other CTEs' `name` values by the optimizer.
pub dependencies: Vec<String>,
/// Column names of the CTE (not inspected by the code in this file).
pub columns: Vec<String>,
}
/// Stateless namespace for the heuristics that look for CTE extraction candidates in SQL text.
///
/// The type has no fields; everything is exposed as associated functions.
pub struct ExtractionAnalyzer;
impl ExtractionAnalyzer {
/// Scans `sql` for expressions that could be moved into a CTE.
///
/// Four detectors run in a fixed order, each behind a cheap substring pre-check on the query:
/// arithmetic (` * ` or ` / `), `CASE WHEN` (checked on an upper-cased copy, so any letter case
/// triggers it), string functions (`SUBSTRING` / `CONTAINS`, case-sensitive) and window
/// functions (`OVER (`, case-sensitive). Each detector contributes at most one suggestion.
///
/// Returns the suggestions ordered by descending `complexity_score`; the vector is empty when
/// nothing was detected.
pub fn analyze(sql: &str) -> Vec<ExtractionSuggestion> {
    let mut found: Vec<ExtractionSuggestion> = Vec::new();

    // Arithmetic expressions.
    let calculation = if sql.contains(" * ") || sql.contains(" / ") {
        Self::find_complex_calculation(sql)
    } else {
        None
    };
    found.extend(calculation.map(|expression| ExtractionSuggestion {
        // The CTE text only borrows the expression, so it is built before the move below.
        cte_query: Self::generate_cte_for_calculation(&expression),
        expression,
        reason: ExtractionReason::ComplexCalculation,
        suggested_cte_name: "calculated".to_string(),
        replacement: "calculated_value".to_string(),
        complexity_score: 10,
    }));

    // CASE expressions.
    let case_expression = if sql.to_uppercase().contains("CASE WHEN") {
        Self::find_case_statement(sql)
    } else {
        None
    };
    found.extend(case_expression.map(|expression| ExtractionSuggestion {
        cte_query: Self::generate_cte_for_case(&expression),
        expression,
        reason: ExtractionReason::CaseStatement,
        suggested_cte_name: "categorized".to_string(),
        replacement: "category".to_string(),
        complexity_score: 15,
    }));

    // String function calls.
    let string_call = if sql.contains("SUBSTRING") || sql.contains("CONTAINS") {
        Self::find_string_manipulation(sql)
    } else {
        None
    };
    found.extend(string_call.map(|expression| ExtractionSuggestion {
        cte_query: Self::generate_cte_for_string(&expression),
        expression,
        reason: ExtractionReason::StringManipulation,
        suggested_cte_name: "parsed".to_string(),
        replacement: "parsed_value".to_string(),
        complexity_score: 12,
    }));

    // Window functions.
    let window_call = if sql.contains("OVER (") {
        Self::find_window_function(sql)
    } else {
        None
    };
    found.extend(window_call.map(|expression| ExtractionSuggestion {
        cte_query: Self::generate_cte_for_window(&expression),
        expression,
        reason: ExtractionReason::WindowFunction,
        suggested_cte_name: "windowed".to_string(),
        replacement: "window_result".to_string(),
        complexity_score: 20,
    }));

    // Stable sort, highest score first; equal scores keep detection order.
    found.sort_by(|left, right| right.complexity_score.cmp(&left.complexity_score));
    found
}
/// Looks for one of the known arithmetic expressions in `sql`.
///
/// Only the exact texts `price * quantity` and `amount * rate` are recognized, and they are
/// tried in that order. Returns the first one contained in `sql`, or `None` when neither occurs.
fn find_complex_calculation(sql: &str) -> Option<String> {
    // Listed by priority: the first entry present in the query wins.
    const KNOWN_PRODUCTS: [&str; 2] = ["price * quantity", "amount * rate"];

    KNOWN_PRODUCTS
        .iter()
        .copied()
        .find(|candidate| sql.contains(*candidate))
        .map(String::from)
}
/// Extracts the first complete `CASE ... END` expression from `sql`.
///
/// The query is scanned word by word, so `CASE` and `END` are only recognized as whole,
/// case-insensitive keywords (not inside identifiers such as `lowercase` or `vendors`).
/// Text inside single-quoted string literals is skipped, and nested `CASE` expressions are
/// tracked so the result runs up to the `END` that closes the outermost `CASE`.
///
/// All offsets are computed on `sql` itself, so non-ASCII text cannot shift or invalidate
/// the returned slice.
///
/// Returns the matched text exactly as written in `sql`, or `None` when there is no `CASE`
/// keyword or it is never closed.
fn find_case_statement(sql: &str) -> Option<String> {
    /// Bytes that may be part of a word; non-ASCII bytes are treated as word bytes so that
    /// multi-byte characters are never split and never form a keyword boundary.
    fn is_word_byte(byte: u8) -> bool {
        byte.is_ascii_alphanumeric() || byte == b'_' || !byte.is_ascii()
    }

    let bytes = sql.as_bytes();
    let mut outer_start = 0;
    let mut depth = 0usize;
    let mut pos = 0;

    while pos < bytes.len() {
        let byte = bytes[pos];
        if byte == b'\'' {
            // Jump past the string literal. A doubled quote (`''`) simply reads as two
            // adjacent literals, which skips the same text.
            pos = match bytes[pos + 1..].iter().position(|&b| b == b'\'') {
                Some(offset) => pos + offset + 2,
                None => bytes.len(),
            };
        } else if is_word_byte(byte) {
            let word_start = pos;
            while pos < bytes.len() && is_word_byte(bytes[pos]) {
                pos += 1;
            }
            // Both ends sit next to an ASCII byte or a string edge, so this is a valid slice.
            let word = &sql[word_start..pos];
            if word.eq_ignore_ascii_case("CASE") {
                if depth == 0 {
                    outer_start = word_start;
                }
                depth += 1;
            } else if depth > 0 && word.eq_ignore_ascii_case("END") {
                depth -= 1;
                if depth == 0 {
                    return Some(sql[outer_start..pos].to_string());
                }
            }
        } else {
            pos += 1;
        }
    }
    None
}
/// Extracts the first `SUBSTRING_AFTER(...)` call from `sql`.
///
/// The search for the function name is case-sensitive. The result runs from the function name
/// through the closing parenthesis reported by `find_matching_paren`. Returns `None` when the
/// name is absent or no matching parenthesis is found.
fn find_string_manipulation(sql: &str) -> Option<String> {
    let call_start = sql.find("SUBSTRING_AFTER")?;
    let from_call = &sql[call_start..];
    // Index of the closing parenthesis, relative to the start of the call.
    let close = Self::find_matching_paren(from_call)?;
    Some(from_call[..=close].to_string())
}
/// Extracts the first `ROW_NUMBER() ... OVER (...)` expression from `sql`.
///
/// Only the literal text `ROW_NUMBER()` is searched for (case-sensitive). The result runs from
/// there through the parenthesis that closes the first group found after the following `OVER`.
/// Returns `None` when any of these pieces is missing.
fn find_window_function(sql: &str) -> Option<String> {
    let call_start = sql.find("ROW_NUMBER()")?;
    let from_call = &sql[call_start..];
    let over_offset = from_call.find("OVER")?;

    // Parenthesis matching starts right after the four bytes of "OVER".
    let after_over = &from_call[over_offset + 4..];
    let close = Self::find_matching_paren(after_over)?;

    // `close` is relative to `after_over`; the extra 1 makes the end bound include the `)`.
    Some(from_call[..over_offset + 5 + close].to_string())
}
/// Finds the `)` that closes the first `(` in `s`.
///
/// Nested parentheses are balanced. A `)` that appears before any `(` has been opened is
/// ignored instead of unbalancing the count, and parentheses inside single-quoted string
/// literals are not treated as structure.
///
/// Returns the byte index of the closing parenthesis within `s`, or `None` when `s` contains
/// no `(` or the first group is never closed.
fn find_matching_paren(s: &str) -> Option<usize> {
    let mut depth = 0usize;
    let mut in_literal = false;

    for (idx, ch) in s.char_indices() {
        match ch {
            // A doubled quote (`''`) toggles twice and therefore stays inside the literal.
            '\'' => in_literal = !in_literal,
            _ if in_literal => {}
            '(' => depth += 1,
            // Stray closers (depth == 0) fall through to the catch-all arm.
            ')' if depth > 0 => {
                depth -= 1;
                if depth == 0 {
                    return Some(idx);
                }
            }
            _ => {}
        }
    }
    None
}
/// Builds the CTE body that exposes `expr` under the alias `calculated_value`.
///
/// The query selects every column plus the aliased expression from the placeholder table
/// name `source_table`.
fn generate_cte_for_calculation(expr: &str) -> String {
    ["SELECT *, ", expr, " as calculated_value FROM source_table"].concat()
}
/// Builds the CTE body that exposes `expr` under the alias `category`.
///
/// The query selects every column plus the aliased expression from the placeholder table
/// name `source_table`.
fn generate_cte_for_case(expr: &str) -> String {
    ["SELECT *, ", expr, " as category FROM source_table"].concat()
}
/// Builds the CTE body that exposes `expr` under the alias `parsed_value`.
///
/// The query selects every column plus the aliased expression from the placeholder table
/// name `source_table`.
fn generate_cte_for_string(expr: &str) -> String {
    ["SELECT *, ", expr, " as parsed_value FROM source_table"].concat()
}
/// Builds the CTE body that exposes `expr` under the alias `window_result`.
///
/// The query selects every column plus the aliased expression from the placeholder table
/// name `source_table`.
fn generate_cte_for_window(expr: &str) -> String {
    ["SELECT *, ", expr, " as window_result FROM source_table"].concat()
}
}
/// Stateless namespace for the checks that review an existing [`CTEChain`].
///
/// The type has no fields; everything is exposed as associated functions.
pub struct CTEOptimizer;
impl CTEOptimizer {
/// Reviews `chain` and returns human-readable improvement hints.
///
/// Three checks run in this order, and their messages are appended in the same order:
/// 1. every pair of CTEs (earlier one first) that `can_combine` accepts,
/// 2. every CTE that `find_used_ctes` does not report as used,
/// 3. a single note when `is_linear_chain` holds for the whole list.
///
/// Returns an empty vector when no check fires.
pub fn optimize_chain(chain: &CTEChain) -> Vec<String> {
    let ctes = &chain.ctes;
    let mut hints: Vec<String> = Vec::new();

    // Pair each CTE with every later one so each unordered pair is examined once.
    for (position, earlier) in ctes.iter().enumerate() {
        for later in &ctes[position + 1..] {
            if Self::can_combine(earlier, later) {
                hints.push(format!(
                    "CTEs '{}' and '{}' could be combined to reduce complexity",
                    earlier.name, later.name
                ));
            }
        }
    }

    let referenced = Self::find_used_ctes(&chain.main_query, ctes);
    hints.extend(
        ctes.iter()
            .filter(|cte| !referenced.contains(&cte.name))
            .map(|cte| format!("CTE '{}' appears to be unused", cte.name)),
    );

    if Self::is_linear_chain(ctes) {
        hints.push("This linear CTE chain could potentially be flattened".to_string());
    }

    hints
}
/// Reports whether two CTEs are candidates for merging.
///
/// Returns `true` when either CTE lists the other's name among its `dependencies`.
fn can_combine(cte1: &CTEDefinition, cte2: &CTEDefinition) -> bool {
    let depends_on = |consumer: &CTEDefinition, producer: &CTEDefinition| {
        consumer
            .dependencies
            .iter()
            .any(|dependency| dependency == &producer.name)
    };

    depends_on(cte1, cte2) || depends_on(cte2, cte1)
}
fn find_used_ctes(query: &str, ctes: &[CTEDefinition]) -> HashSet<String> {
let mut used = HashSet::new();
for cte in ctes {
if query.contains(&cte.name) {
used.insert(cte.name.clone());
}
}
used
}
/// Reports whether `ctes` forms a linear chain.
///
/// A chain is linear when every CTE after the first has exactly one dependency and that
/// dependency is the CTE directly before it in the slice. The first CTE's own dependencies
/// are not examined.
///
/// Fewer than two CTEs never form a chain, so an empty or single-element slice yields `false`
/// rather than triggering a flattening hint for nothing.
fn is_linear_chain(ctes: &[CTEDefinition]) -> bool {
    if ctes.len() < 2 {
        return false;
    }

    ctes.windows(2).all(|pair| {
        let (previous, current) = (&pair[0], &pair[1]);
        current.dependencies.len() == 1 && current.dependencies[0] == previous.name
    })
}
}
/// Runs [`ExtractionAnalyzer::analyze`] on `sql` and wraps the outcome in a JSON report.
///
/// The report is an object with three keys: `original` (the input text), `suggestions` (the
/// serialized suggestions) and `recommendation` (a one-line summary that names the number of
/// suggestions, or states that none were found).
///
/// The body shown here always returns `Ok`.
pub fn suggest_extraction(sql: &str) -> Result<serde_json::Value> {
    let suggestions = ExtractionAnalyzer::analyze(sql);

    let recommendation = match suggestions.len() {
        0 => "No extraction opportunities found".to_string(),
        count => format!("Consider extracting {} expressions to CTEs", count),
    };

    let report = serde_json::json!({
        "original": sql,
        "suggestions": suggestions,
        "recommendation": recommendation
    });
    Ok(report)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `price * quantity` product must be reported, and as the top-ranked suggestion.
    #[test]
    fn test_extraction_detection() {
        let found =
            ExtractionAnalyzer::analyze("SELECT * FROM orders WHERE price * quantity > 1000");

        assert!(!found.is_empty());
        assert!(matches!(
            found[0].reason,
            ExtractionReason::ComplexCalculation
        ));
    }

    /// A `CASE WHEN ... END` expression must yield a `CaseStatement` suggestion.
    #[test]
    fn test_case_extraction() {
        let found = ExtractionAnalyzer::analyze(
            "SELECT CASE WHEN age <= 20 THEN 'young' ELSE 'old' END FROM users",
        );

        let has_case = found
            .iter()
            .any(|suggestion| matches!(suggestion.reason, ExtractionReason::CaseStatement));
        assert!(has_case);
    }
}