sqlcx-core 0.2.1

SQL-first cross-language type-safe code generator — core library
Documentation
// Shared Python query-function generator. Per-driver divergence flows
// through `PyDriverShape`: imports, connection type, async-ness, placeholder
// rewrite, params-arg formatting, and the body of each command. No client.py,
// no wrappers — queries.py imports directly from the driver package.

use std::collections::BTreeMap;
use std::path::Path;

use crate::error::Result;
use crate::generator::GeneratedFile;
use crate::generator::python::common::{
    PyTypeMap, escape_sql, generate_params_class, generate_row_class,
};
use crate::ir::{QueryCommand, QueryDef, SqlcxIR};
use crate::utils::{pascal_case, snake_case};

/// Everything a driver needs to render the body of one generated Python function.
///
/// Built by `generate_query_function` and handed to `PyDriverShape::render_body`.
/// The string fields borrow from locals of that call, so the context does not
/// outlive it.
pub struct PyBodyCtx<'a> {
    /// Kind of command the query was declared as; the driver picks the body
    /// shape from this.
    pub command: QueryCommand,
    /// Name of the module-level constant holding the (rewritten) SQL text,
    /// e.g. `GET_USER_SQL`.
    pub sql_const: &'a str,
    /// Driver-formatted parameters argument, as produced by
    /// `PyDriverShape::build_params_arg`.
    pub params_arg: &'a str,
    /// Name of the row dataclass for this query (`<PascalName>Row`).
    /// `generate_query_function` fills this in unconditionally, even when no
    /// row class was emitted.
    pub row_type: &'a str,
}

/// Per-driver hooks consumed by the shared Python query generator.
///
/// Everything that differs between Python database drivers goes through
/// these methods. The free functions in this module own the overall
/// module and function layout. `PyTypeMap` is a supertrait, so a driver can
/// also be passed to the shared row/params dataclass helpers.
pub trait PyDriverShape: PyTypeMap {
    // ---- module level ----------------------------------------------------

    /// Driver-specific import line appended to the common module header.
    ///
    /// It is interpolated directly after the shared
    /// `from datetime import datetime` line, and no newline is added after it.
    fn driver_import(&self) -> &'static str;

    // ---- function signature ----------------------------------------------

    /// Whether generated functions are emitted as `async def` (`true`) or
    /// plain `def` (`false`).
    fn is_async(&self) -> bool;

    /// Python type annotation used for the leading `conn` parameter of every
    /// generated function.
    fn connection_type(&self) -> &'static str;

    // ---- function contents -----------------------------------------------

    /// Driver-specific form of the query's SQL (e.g. placeholder rewriting).
    ///
    /// The result is escaped and stored in the module-level `<NAME>_SQL`
    /// constant.
    fn rewrite_sql(&self, query: &QueryDef) -> String;

    /// Text for the query's parameters argument.
    ///
    /// The generator never inspects it; it only forwards it to
    /// [`render_body`](Self::render_body) via `PyBodyCtx::params_arg`.
    /// NOTE(review): presumably a Python expression; confirm per driver.
    fn build_params_arg(&self, query: &QueryDef) -> String;

    /// Render the generated function's `(return_type, body)`.
    ///
    /// `return_type` is placed after `->` in the signature. `body` is
    /// appended verbatim after the signature's `:\n`, so it must already
    /// carry its own Python indentation.
    fn render_body(&self, ctx: &PyBodyCtx<'_>) -> (String, String);
}

/// Render everything one query contributes to a generated Python module.
///
/// The result contains up to four pieces, each separated by one blank line:
/// the row dataclass (if any), the params dataclass (if any), the `<NAME>_SQL`
/// constant, and the function definition. Driver-specific pieces come from
/// `driver`: the SQL rewrite, the params argument, the connection type,
/// sync vs. async, and the body.
pub fn generate_query_function<D: PyDriverShape + ?Sized>(driver: &D, query: &QueryDef) -> String {
    let func_name = snake_case(&query.name);
    let type_stem = pascal_case(&query.name);

    // Optional dataclasses; an empty string means "nothing to emit".
    let row_class = generate_row_class(driver, query);
    let params_class = generate_params_class(driver, query);

    // Module-level constant holding the driver-rewritten SQL, e.g. `GET_USER_SQL = "..."`.
    let const_name = format!("{}_SQL", func_name.to_uppercase());
    let const_line = format!(
        "{const_name} = \"{}\"",
        escape_sql(&driver.rewrite_sql(query))
    );

    let params_arg = driver.build_params_arg(query);
    let row_type = format!("{type_stem}Row");
    let (return_type, body) = driver.render_body(&PyBodyCtx {
        sql_const: &const_name,
        row_type: &row_type,
        params_arg: &params_arg,
        command: query.command,
    });

    let keyword = if driver.is_async() { "async def" } else { "def" };
    // Only queries that declare parameters take a typed `params` argument.
    let extra_param = if query.params.is_empty() {
        String::new()
    } else {
        format!(", params: {type_stem}Params")
    };
    let definition = format!(
        "{keyword} {func_name}(conn: {conn}{extra_param}) -> {return_type}:\n{body}",
        conn = driver.connection_type(),
    );

    // The constant and the definition are never empty, so this only drops
    // dataclasses that were not generated.
    let mut pieces = vec![row_class, params_class, const_line, definition];
    pieces.retain(|piece| !piece.is_empty());
    pieces.join("\n\n")
}

/// Assemble a complete generated Python module for `queries`.
///
/// The module begins with the "do not edit" banner, the shared imports
/// (`annotations`, `dataclass`, `Any`, `datetime`) and the driver's own import
/// line. Each query's rendered output follows, preceded by two blank lines.
/// With no queries, the module is just the header plus a trailing newline.
pub fn generate_queries_file<D: PyDriverShape + ?Sized>(
    driver: &D,
    queries: &[QueryDef],
) -> String {
    let mut module = format!(
        "# Code generated by sqlcx. DO NOT EDIT.\nfrom __future__ import annotations\n\nfrom dataclasses import dataclass\nfrom typing import Any\nfrom datetime import datetime\n{}",
        driver.driver_import()
    );
    if queries.is_empty() {
        module.push('\n');
        return module;
    }
    for query in queries {
        // Two blank lines between top-level definitions, PEP 8 style.
        module.push_str("\n\n\n");
        module.push_str(&generate_query_function(driver, query));
    }
    module
}

/// Generate one Python module per distinct `source_file` among `ir.queries`.
///
/// Each output file is named `<stem>_queries.py`, where `<stem>` is the file
/// stem of the source path (directory and extension dropped). Its content is
/// produced by `generate_queries_file` for that file's queries, in the order
/// they appear in `ir.queries`. Files are returned sorted by source path.
///
/// NOTE(review): two sources in different directories that share a stem
/// (e.g. `a/users.sql` and `b/users.sql`) produce the same output `path`.
/// A source path with no file stem yields `_queries.py`. Confirm whether
/// either case should be rejected.
///
/// # Errors
///
/// Currently never returns `Err`; the `Result` is part of the public signature.
pub fn generate_driver_files<D: PyDriverShape + ?Sized>(
    driver: &D,
    ir: &SqlcxIR,
) -> Result<Vec<GeneratedFile>> {
    // Key by a borrowed `&str`. Cloning `source_file` for every `entry` call
    // allocated a new `String` per query, even when the group already existed.
    // `BTreeMap` keeps the output order deterministic (sorted by source path).
    let mut grouped: BTreeMap<&str, Vec<QueryDef>> = BTreeMap::new();
    for query in &ir.queries {
        // `generate_queries_file` needs a contiguous `&[QueryDef]`, so each
        // query is cloned once, straight into its group. This avoids
        // collecting references first and cloning them in a second pass.
        grouped
            .entry(query.source_file.as_str())
            .or_default()
            .push(query.clone());
    }

    let files: Vec<GeneratedFile> = grouped
        .into_iter()
        .map(|(source_file, queries)| {
            let basename = Path::new(source_file)
                .file_stem()
                .unwrap_or_default()
                .to_string_lossy();
            GeneratedFile {
                path: format!("{basename}_queries.py"),
                content: generate_queries_file(driver, &queries),
            }
        })
        .collect();
    Ok(files)
}