use std::collections::BTreeMap;
use std::path::Path;
use crate::error::Result;
use crate::generator::GeneratedFile;
use crate::generator::python::common::{
PyTypeMap, escape_sql, generate_params_class, generate_row_class,
};
use crate::ir::{QueryCommand, QueryDef, SqlcxIR};
use crate::utils::{pascal_case, snake_case};
pub struct PyBodyCtx<'a> {
pub sql_const: &'a str,
pub row_type: &'a str,
pub params_arg: &'a str,
pub command: QueryCommand,
}
/// Driver-specific hooks for rendering Python query modules.
///
/// The shared generators in this module call these hooks for every
/// driver-dependent piece: file header import, connection annotation,
/// sync/async keyword, SQL text, parameter argument and function body. The
/// [`PyTypeMap`] supertrait is also required (presumably the SQL-to-Python
/// type mapping used by the row/params class generators — confirm there).
pub trait PyDriverShape: PyTypeMap {
    // ---- module header -------------------------------------------------

    /// Driver import text, appended verbatim to the end of the generated
    /// file header (after the fixed `dataclasses`/`typing`/`datetime` imports).
    fn driver_import(&self) -> &'static str;

    // ---- function signature --------------------------------------------

    /// Python type annotation used for each generated function's `conn`
    /// parameter.
    fn connection_type(&self) -> &'static str;

    /// `true` to emit generated functions as `async def`, `false` for `def`.
    fn is_async(&self) -> bool;

    // ---- per-query rendering -------------------------------------------

    /// SQL text to embed for `query`. The result is escaped and stored in
    /// the query's `<NAME>_SQL` constant. Presumably this rewrites
    /// placeholders into the driver's parameter style — confirm per driver.
    fn rewrite_sql(&self, query: &QueryDef) -> String;

    /// Text passed to [`Self::render_body`] as [`PyBodyCtx::params_arg`].
    fn build_params_arg(&self, query: &QueryDef) -> String;

    /// Renders a query function, returning `(return_type, body)`.
    ///
    /// `return_type` is placed after `->` in the signature. `body` is placed
    /// on the line after the signature's `:` with no indentation added, so it
    /// must carry its own Python indentation.
    fn render_body(&self, ctx: &PyBodyCtx<'_>) -> (String, String);
}
/// Renders the complete Python source for a single query.
///
/// Output sections, separated by one blank line:
/// 1. the row class from `generate_row_class` (omitted when empty),
/// 2. the params class from `generate_params_class` (omitted when empty),
/// 3. a `<SNAKE_NAME>_SQL = "..."` constant holding the driver-rewritten,
///    escaped SQL,
/// 4. the function itself, `def`/`async def` per [`PyDriverShape::is_async`],
///    whose return type and body come from [`PyDriverShape::render_body`].
///
/// A `params: <PascalName>Params` argument is added only when the query
/// declares parameters.
pub fn generate_query_function<D: PyDriverShape + ?Sized>(driver: &D, query: &QueryDef) -> String {
    let fn_name = snake_case(&query.name);
    let type_prefix = pascal_case(&query.name);

    let row_class = generate_row_class(driver, query);
    let params_class = generate_params_class(driver, query);

    let rewritten_sql = driver.rewrite_sql(query);
    let sql_const_name = format!("{}_SQL", fn_name.to_uppercase());
    let sql_const = format!("{sql_const_name} = \"{}\"", escape_sql(&rewritten_sql));

    // Only parameterized queries take a typed `params` argument.
    let params_sig = if query.params.is_empty() {
        String::new()
    } else {
        format!(", params: {type_prefix}Params")
    };

    let params_arg = driver.build_params_arg(query);
    let row_type_name = format!("{type_prefix}Row");
    let ctx = PyBodyCtx {
        sql_const: &sql_const_name,
        row_type: &row_type_name,
        params_arg: &params_arg,
        command: query.command,
    };
    let (return_type, body) = driver.render_body(&ctx);

    let def_kw = if driver.is_async() { "async def" } else { "def" };
    // `body` is appended right after the newline, so the driver's output
    // must already be indented.
    let signature = format!(
        "{def_kw} {fn_name}(conn: {}{params_sig}) -> {return_type}:\n{body}",
        driver.connection_type()
    );

    let mut parts: Vec<String> = Vec::with_capacity(4);
    // The class generators return an empty string when they emit nothing;
    // skip those so no empty sections appear in the output.
    if !row_class.is_empty() {
        parts.push(row_class);
    }
    if !params_class.is_empty() {
        parts.push(params_class);
    }
    parts.push(sql_const);
    parts.push(signature);
    parts.join("\n\n")
}
/// Renders a whole Python module containing every query in `queries`.
///
/// The module starts with a fixed "generated code" header and standard
/// imports, followed by the driver's import text. Each rendered query is then
/// appended after two blank lines. A module with no queries ends with a
/// single extra newline after the header.
pub fn generate_queries_file<D: PyDriverShape + ?Sized>(
    driver: &D,
    queries: &[QueryDef],
) -> String {
    let mut module = String::from(
        "# Code generated by sqlcx. DO NOT EDIT.\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Any\nfrom datetime import datetime\n",
    );
    module.push_str(driver.driver_import());

    if queries.is_empty() {
        module.push('\n');
        return module;
    }

    // Each query is preceded by the "\n\n\n" separator, which is equivalent
    // to a separator after the header followed by a "\n\n\n"-join.
    for query in queries {
        module.push_str("\n\n\n");
        module.push_str(&generate_query_function(driver, query));
    }
    module
}
/// Generates one Python module per SQL source file referenced by `ir`.
///
/// Queries are grouped by their `source_file`. Each group is rendered with
/// [`generate_queries_file`] into `<stem>_queries.py`, where `<stem>` is the
/// source file name without directory or extension. Files are returned in
/// ascending order of source path, because the groups live in a `BTreeMap`.
///
/// NOTE(review): only the file stem reaches the output path. Two source files
/// with the same stem in different directories (e.g. `a/users.sql` and
/// `b/users.sql`) therefore produce two entries with the same `path`. A
/// source path with no stem yields `_queries.py`.
///
/// # Errors
///
/// This function currently never returns `Err`; the `Result` return type is
/// part of the public interface.
pub fn generate_driver_files<D: PyDriverShape + ?Sized>(
    driver: &D,
    ir: &SqlcxIR,
) -> Result<Vec<GeneratedFile>> {
    // Key by a borrowed path (no per-query `String` clone). Each query is
    // cloned exactly once into its group, because `generate_queries_file`
    // needs a contiguous `&[QueryDef]`.
    let mut grouped: BTreeMap<&str, Vec<QueryDef>> = BTreeMap::new();
    for query in &ir.queries {
        grouped
            .entry(query.source_file.as_str())
            .or_default()
            .push(query.clone());
    }

    let files: Vec<GeneratedFile> = grouped
        .into_iter()
        .map(|(source_file, queries)| {
            let basename = Path::new(source_file)
                .file_stem()
                .unwrap_or_default()
                .to_string_lossy();
            GeneratedFile {
                path: format!("{basename}_queries.py"),
                content: generate_queries_file(driver, &queries),
            }
        })
        .collect();
    Ok(files)
}