// sqlcx-core 0.2.1
//
// SQL-first cross-language type-safe code generator — core library
// Documentation
// tokio-postgres driver generator for Rust

use std::path::Path;

use crate::error::Result;
use crate::generator::rust_lang::common::{param_type, row_field_type};
use crate::generator::{DriverGenerator, GeneratedFile};
use crate::ir::{QueryCommand, QueryDef, SqlcxIR};
use crate::utils::{pascal_case, snake_case};

/// Generator that emits Rust source code targeting the `tokio-postgres` driver.
///
/// This is a field-less unit type: everything it produces is derived from the
/// arguments passed to its methods.
pub struct TokioPostgresGenerator;

/// Renders the `<QueryName>Row` struct declaration for a query's result columns.
///
/// Every returned column becomes one `pub` field. The field is named after the
/// column's alias when one is set and after the column name otherwise; its type
/// is whatever `row_field_type` yields for that column.
///
/// Returns an empty string when the query has no result columns.
fn generate_row_struct(query: &QueryDef) -> String {
    let columns = &query.returns;
    if columns.is_empty() {
        return String::new();
    }

    // One source line per struct field, in column order.
    let mut field_lines: Vec<String> = Vec::new();
    for column in columns {
        let ident = column.alias.as_deref().unwrap_or(&column.name);
        let ty = row_field_type(column);
        field_lines.push(format!("    pub {ident}: {ty},"));
    }

    let struct_name = format!("{}Row", pascal_case(&query.name));
    let body = field_lines.join("\n");
    format!("#[derive(Debug, Clone)]\npub struct {struct_name} {{\n{body}\n}}")
}

/// Renders the `impl <QueryName>Row` block containing the `from_row` constructor.
///
/// The generated `from_row` reads each field positionally (`row.get(0)`,
/// `row.get(1)`, ...) in the order the columns appear in `query.returns`.
/// Field names follow the same alias-then-name rule as the row struct.
///
/// Returns an empty string when the query has no result columns.
fn generate_row_from_impl(query: &QueryDef) -> String {
    if query.returns.is_empty() {
        return String::new();
    }

    let struct_name = format!("{}Row", pascal_case(&query.name));
    let mut out = format!(
        "impl {struct_name} {{\n    fn from_row(row: &tokio_postgres::Row) -> Self {{\n        Self {{\n"
    );

    for (position, column) in query.returns.iter().enumerate() {
        // Newline goes between assignments, never after the last one.
        if position > 0 {
            out.push('\n');
        }
        let ident = column.alias.as_deref().unwrap_or(&column.name);
        out.push_str(&format!("            {ident}: row.get({position}),"));
    }

    // Close the struct literal, the method, and the impl block.
    out.push_str("\n        }\n    }\n}");
    out
}

/// Renders all generated items for a single query, separated by blank lines:
///
/// 1. a `pub const <FN_NAME>_SQL: &str` holding the query text,
/// 2. the row struct and its `from_row` impl (only when the query returns columns),
/// 3. a `pub async fn` that runs the query through a `&tokio_postgres::Client`.
///
/// The shape of the async function depends on `query.command`:
/// `One` uses `query_opt` and returns `Option<Row>`, `Many` uses `query` and
/// returns `Vec<Row>`, `Exec` uses `execute` and returns `()`, and `ExecResult`
/// uses `execute` and returns the `u64` it yields.
fn generate_query_function(query: &QueryDef) -> String {
    let fn_name = snake_case(&query.name);
    let sql_const = format!("{}_SQL", fn_name.to_uppercase());
    let row_type = format!("{}Row", pascal_case(&query.name));

    // `{:?}` emits the SQL as an escaped, double-quoted string literal.
    let mut sections = vec![format!("pub const {sql_const}: &str = {:?};", query.sql)];

    let row_struct = generate_row_struct(query);
    if !row_struct.is_empty() {
        sections.push(row_struct);
        sections.push(generate_row_from_impl(query));
    }

    // Function parameters: the client first, then one argument per query parameter.
    let signature = query.params.iter().fold(
        String::from("client: &tokio_postgres::Client"),
        |mut acc, param| {
            acc.push_str(&format!(", {}: {}", param.name, param_type(&param.sql_type)));
            acc
        },
    );

    // Bound-parameter slice; with no parameters this collapses to `&[]`.
    let bound: Vec<String> = query
        .params
        .iter()
        .map(|param| format!("&{}", param.name))
        .collect();
    let args = format!("&[{}]", bound.join(", "));

    let result_of = |ok: &str| format!("std::result::Result<{ok}, tokio_postgres::Error>");

    // NOTE(review): for `One`/`Many` with an empty `returns` list no row struct is
    // emitted above, yet the function below still names `<Query>Row` — confirm the
    // parser never produces that combination.
    let (return_type, body) = match query.command {
        QueryCommand::One => (
            result_of(&format!("Option<{row_type}>")),
            format!(
                "    let row = client.query_opt({sql_const}, {args}).await?;\n    Ok(row.map(|r| {row_type}::from_row(&r)))"
            ),
        ),
        QueryCommand::Many => (
            result_of(&format!("Vec<{row_type}>")),
            format!(
                "    let rows = client.query({sql_const}, {args}).await?;\n    Ok(rows.iter().map(|r| {row_type}::from_row(r)).collect())"
            ),
        ),
        QueryCommand::Exec => (
            result_of("()"),
            format!("    client.execute({sql_const}, {args}).await?;\n    Ok(())"),
        ),
        QueryCommand::ExecResult => (
            result_of("u64"),
            format!("    let count = client.execute({sql_const}, {args}).await?;\n    Ok(count)"),
        ),
    };

    sections.push(format!(
        "pub async fn {fn_name}({signature}) -> {return_type} {{\n{body}\n}}"
    ));

    sections.join("\n\n")
}

impl TokioPostgresGenerator {
    /// Returns the contents of the generated `client.rs` file.
    ///
    /// The file consists solely of comments: the "DO NOT EDIT" banner and a note
    /// that the query functions expect a `&tokio_postgres::Client`.
    pub fn generate_client(&self) -> String {
        "// Code generated by sqlcx. DO NOT EDIT.\n\n\
         // This module uses tokio-postgres for database access.\n\
         // Pass a &tokio_postgres::Client to the query functions below."
            .to_string()
    }

    /// Renders one queries file: the "DO NOT EDIT" banner followed by the
    /// generated items of every query in `queries`, separated by blank lines.
    ///
    /// With no queries the result is just the banner and a trailing newline.
    pub fn generate_query_functions(&self, queries: &[QueryDef]) -> String {
        Self::render_query_functions(queries)
    }

    /// Shared renderer behind [`Self::generate_query_functions`].
    ///
    /// Accepts any iterator of borrowed queries so callers that hold
    /// `Vec<&QueryDef>` do not have to clone each `QueryDef` into an owned slice.
    fn render_query_functions<'a>(queries: impl IntoIterator<Item = &'a QueryDef>) -> String {
        let header = "// Code generated by sqlcx. DO NOT EDIT.";
        let functions: Vec<String> = queries.into_iter().map(generate_query_function).collect();
        if functions.is_empty() {
            return format!("{header}\n");
        }
        format!("{header}\n\n{}", functions.join("\n\n"))
    }
}

impl DriverGenerator for TokioPostgresGenerator {
    /// Produces `client.rs` plus one `<stem>_queries.rs` file per distinct
    /// `source_file` found in `ir.queries`.
    ///
    /// Query files are emitted in sorted `source_file` order (the grouping map is
    /// a `BTreeMap`), and queries keep their IR order within each file.
    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
        // Group by borrowed key and borrowed queries: nothing from the IR is cloned.
        let mut grouped: std::collections::BTreeMap<&str, Vec<&QueryDef>> =
            std::collections::BTreeMap::new();
        for query in &ir.queries {
            grouped
                .entry(query.source_file.as_str())
                .or_default()
                .push(query);
        }

        let mut files = Vec::with_capacity(grouped.len() + 1);
        files.push(GeneratedFile {
            path: "client.rs".to_string(),
            content: self.generate_client(),
        });

        for (source_file, queries) in grouped {
            // NOTE(review): only the file stem is used, so two source files with the
            // same stem in different directories yield the same output path — confirm
            // whether callers guarantee unique stems.
            let basename = Path::new(source_file)
                .file_stem()
                .unwrap_or_default()
                .to_string_lossy();
            files.push(GeneratedFile {
                path: format!("{basename}_queries.rs"),
                content: Self::render_query_functions(queries),
            });
        }

        Ok(files)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Parses the shared schema and `users.sql` query fixtures with the
    /// Postgres parser and assembles them into an IR.
    fn parse_fixture_ir() -> SqlcxIR {
        let parser = PostgresParser::new();
        let (tables, enums) = parser
            .parse_schema(include_str!("../../../../../tests/fixtures/schema.sql"))
            .unwrap();
        let queries = parser
            .parse_queries(
                include_str!("../../../../../tests/fixtures/queries/users.sql"),
                &tables,
                &enums,
                "queries/users.sql",
            )
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    /// The client file mentions the driver, carries the banner, and matches its snapshot.
    #[test]
    fn generates_client_file() {
        let content = TokioPostgresGenerator.generate_client();
        for marker in ["tokio-postgres", "DO NOT EDIT"] {
            assert!(content.contains(marker));
        }
        insta::assert_snapshot!("tokio_postgres_client", content);
    }

    /// The queries file for the fixture contains the expected items and matches its snapshot.
    #[test]
    fn generates_query_functions() {
        let fixture = parse_fixture_ir();
        let content = TokioPostgresGenerator.generate_query_functions(&fixture.queries);
        for marker in [
            "pub async fn get_user",
            "pub struct GetUserRow",
            "query_opt",
            "from_row",
        ] {
            assert!(content.contains(marker));
        }
        insta::assert_snapshot!("tokio_postgres_queries", content);
    }
}