sqlcx-core 0.2.1

SQL-first cross-language type-safe code generator — core library
Documentation
// pgx (github.com/jackc/pgx/v5) driver generator for Go

use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::path::Path;

use crate::error::Result;
use crate::generator::go::common::{
    escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
    sql_const_name,
};
use crate::generator::{DriverGenerator, GeneratedFile};
use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
use crate::utils::pascal_case;

use super::structs::go_imports_for_columns;

/// Driver generator that emits Go code targeting the pgx v5 driver
/// (`github.com/jackc/pgx/v5`): a shared `client.go` plus `*.queries.go`
/// files named after the SQL source files the queries came from.
pub struct PgxGenerator;

/// Returns the fixed contents of the generated `client.go`.
///
/// The Go file declares the `DBTX` interface (the subset of pgx's
/// `Exec`/`Query`/`QueryRow` that generated queries call), the `Queries`
/// wrapper holding a `DBTX`, and the `New` constructor. The text has no
/// trailing newline.
fn generate_client() -> String {
    const CLIENT_GO: &str = concat!(
        "// Code generated by sqlcx. DO NOT EDIT.\n",
        "package db\n",
        "\n",
        "import (\n",
        "\t\"context\"\n",
        "\n",
        "\t\"github.com/jackc/pgx/v5\"\n",
        "\t\"github.com/jackc/pgx/v5/pgconn\"\n",
        ")\n",
        "\n",
        "type DBTX interface {\n",
        "\tExec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)\n",
        "\tQuery(ctx context.Context, sql string, args ...any) (pgx.Rows, error)\n",
        "\tQueryRow(ctx context.Context, sql string, args ...any) pgx.Row\n",
        "}\n",
        "\n",
        "type Queries struct {\n",
        "\tdb DBTX\n",
        "}\n",
        "\n",
        "func New(db DBTX) *Queries {\n",
        "\treturn &Queries{db: db}\n",
        "}",
    );
    CLIENT_GO.to_owned()
}

/// Renders every Go declaration a single query contributes to a `*.queries.go` file.
///
/// The result is a sequence of declarations separated by one blank line:
/// 1. a `const` holding the SQL text (passed through `escape_sql`);
/// 2. for `QueryCommand::One` / `QueryCommand::Many`, the `<Name>Row` struct when
///    `generate_row_struct` yields one; for `QueryCommand::ExecResult`, the
///    `<Name>Result` struct from `generate_result_struct`;
/// 3. the `(q *Queries) <Name>(...)` method that runs the SQL through `q.db`.
///
/// The method templates reference `ctx`, so the parameter list produced by
/// `func_params` is expected to declare it (NOTE(review): inferred from the
/// templates — confirm in `go::common`).
fn generate_query_function(query: &QueryDef) -> String {
    let func_name = pascal_case(&query.name);
    let const_name = sql_const_name(&query.name);
    let params = func_params(query);
    let args = query_args(query);
    let sql = escape_sql(&query.sql);

    let mut sections = vec![format!("const {const_name} = \"{sql}\"")];

    let method = match query.command {
        QueryCommand::One => {
            let row_type = format!("{func_name}Row");
            sections.extend(generate_row_struct(query));
            let scans = scan_fields(&query.returns);
            // A missing row is reported as `(nil, nil)` rather than as an error.
            format!(
                "func (q *Queries) {func_name}({params}) (*{row_type}, error) {{\n\
                 \trow := q.db.QueryRow(ctx, {const_name}{args})\n\
                 \tvar i {row_type}\n\
                 \terr := row.Scan({scans})\n\
                 \tif err != nil {{\n\
                 \t\tif errors.Is(err, pgx.ErrNoRows) {{\n\
                 \t\t\treturn nil, nil\n\
                 \t\t}}\n\
                 \t\treturn nil, err\n\
                 \t}}\n\
                 \treturn &i, nil\n\
                 }}"
            )
        }
        QueryCommand::Many => {
            let row_type = format!("{func_name}Row");
            sections.extend(generate_row_struct(query));
            let scans = scan_fields(&query.returns);
            format!(
                "func (q *Queries) {func_name}({params}) ([]{row_type}, error) {{\n\
                 \trows, err := q.db.Query(ctx, {const_name}{args})\n\
                 \tif err != nil {{\n\
                 \t\treturn nil, err\n\
                 \t}}\n\
                 \tdefer rows.Close()\n\
                 \tvar items []{row_type}\n\
                 \tfor rows.Next() {{\n\
                 \t\tvar i {row_type}\n\
                 \t\tif err := rows.Scan({scans}); err != nil {{\n\
                 \t\t\treturn nil, err\n\
                 \t\t}}\n\
                 \t\titems = append(items, i)\n\
                 \t}}\n\
                 \treturn items, rows.Err()\n\
                 }}"
            )
        }
        QueryCommand::Exec => format!(
            "func (q *Queries) {func_name}({params}) error {{\n\
             \t_, err := q.db.Exec(ctx, {const_name}{args})\n\
             \treturn err\n\
             }}"
        ),
        QueryCommand::ExecResult => {
            let result_type = format!("{func_name}Result");
            sections.push(generate_result_struct(query));
            format!(
                "func (q *Queries) {func_name}({params}) (*{result_type}, error) {{\n\
                 \ttag, err := q.db.Exec(ctx, {const_name}{args})\n\
                 \tif err != nil {{\n\
                 \t\treturn nil, err\n\
                 \t}}\n\
                 \treturn &{result_type}{{RowsAffected: tag.RowsAffected()}}, nil\n\
                 }}"
            )
        }
    };
    sections.push(method);

    sections.join("\n\n")
}

/// Computes the Go import paths a `*.queries.go` file needs for `queries`.
///
/// The set contains:
/// - `context`, because every generated method passes `ctx` to the driver;
/// - whatever `go_imports_for_columns` reports for each query's result columns
///   and for each parameter's SQL type (parameters are modelled as
///   non-nullable columns without defaults);
/// - `errors` and `github.com/jackc/pgx/v5` when any query is
///   `QueryCommand::One`, since those methods test `errors.Is(err, pgx.ErrNoRows)`.
///
/// With no queries the set is empty. Go rejects unused imports at compile time,
/// so emitting `context` for a file with no methods would produce invalid Go.
fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
    let mut imports = BTreeSet::new();
    if queries.is_empty() {
        return imports;
    }
    imports.insert("context".to_string());

    for query in queries {
        imports.extend(go_imports_for_columns(&query.returns));

        for param in &query.params {
            let col = ColumnDef {
                name: param.name.clone(),
                alias: None,
                source_table: None,
                sql_type: param.sql_type.clone(),
                nullable: false,
                has_default: false,
            };
            imports.extend(go_imports_for_columns(&[col]));
        }
    }

    // `:one` methods map pgx.ErrNoRows to (nil, nil) via errors.Is.
    if queries.iter().any(|q| q.command == QueryCommand::One) {
        imports.insert("errors".to_string());
        imports.insert("github.com/jackc/pgx/v5".to_string());
    }

    imports
}

impl PgxGenerator {
    /// Returns the contents of `client.go`: the `DBTX` interface over pgx's
    /// `Exec`/`Query`/`QueryRow`, the `Queries` wrapper and its `New` constructor.
    pub fn generate_client(&self) -> String {
        generate_client()
    }

    /// Renders one complete `*.queries.go` file for `queries`.
    ///
    /// Layout:
    /// - the "generated" header and `package db`;
    /// - an `import ( … )` block, only when `collect_query_imports` returns
    ///   something;
    /// - every query's declarations separated by blank lines, followed by a
    ///   trailing newline.
    ///
    /// An empty `queries` slice produces no function section.
    pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
        let mut out = String::from("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");

        let imports = collect_query_imports(queries);
        if !imports.is_empty() {
            let import_lines: Vec<String> = imports
                .iter()
                .map(|path| format!("\t\"{path}\""))
                .collect();
            out.push_str("\nimport (\n");
            out.push_str(&import_lines.join("\n"));
            out.push_str("\n)\n");
        }

        if !queries.is_empty() {
            let bodies: Vec<String> = queries.iter().map(generate_query_function).collect();
            out.push('\n');
            out.push_str(&bodies.join("\n\n"));
            out.push('\n');
        }

        out
    }
}

impl DriverGenerator for PgxGenerator {
    /// Generates `client.go` plus one `<stem>.queries.go` file per SQL source
    /// file stem (`queries/users.sql` -> `users.queries.go`).
    ///
    /// Files are emitted in ascending order of their first source path. Within
    /// a file, queries keep their IR order.
    ///
    /// Source files that share a stem (e.g. `a/users.sql` and `b/users.sql`)
    /// are merged into one output file. Otherwise two files would be emitted
    /// with the same path and one would overwrite the other. All generated
    /// files are `package db`, so merging loses nothing.
    ///
    /// A source path without a usable stem falls back to the stem `queries`.
    /// Emitting `.queries.go` would be silently skipped by the Go toolchain,
    /// which ignores files whose names start with `.`.
    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
        let mut files = vec![GeneratedFile {
            path: "client.go".to_string(),
            content: self.generate_client(),
        }];

        // Bucket by full source path. The BTreeMap keeps the output order
        // deterministic; each bucket preserves the IR order of its queries.
        let mut grouped: BTreeMap<&str, Vec<&QueryDef>> = BTreeMap::new();
        for query in &ir.queries {
            grouped
                .entry(query.source_file.as_str())
                .or_default()
                .push(query);
        }

        // Output path -> queries, preserving first-seen order so non-colliding
        // inputs produce exactly the same file list as before.
        let mut outputs: Vec<(String, Vec<QueryDef>)> = Vec::new();
        for (source_file, queries) in grouped {
            let stem = Path::new(source_file)
                .file_stem()
                .map(|s| s.to_string_lossy().into_owned())
                .unwrap_or_default();
            let stem = if stem.is_empty() {
                "queries".to_string()
            } else {
                stem
            };
            let path = format!("{stem}.queries.go");

            if let Some(pos) = outputs.iter().position(|(existing, _)| *existing == path) {
                outputs[pos].1.extend(queries.into_iter().cloned());
            } else {
                outputs.push((path, queries.into_iter().cloned().collect()));
            }
        }

        files.extend(outputs.into_iter().map(|(path, queries)| GeneratedFile {
            content: self.generate_query_file(&queries),
            path,
        }));

        Ok(files)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Builds an IR from the shared Postgres fixtures: `schema.sql` for tables
    /// and enums, and `queries/users.sql` for the queries.
    fn parse_fixture_ir() -> SqlcxIR {
        let parser = PostgresParser::new();

        let (tables, enums) = parser
            .parse_schema(include_str!("../../../../../tests/fixtures/schema.sql"))
            .unwrap();

        let queries = parser
            .parse_queries(
                include_str!("../../../../../tests/fixtures/queries/users.sql"),
                &tables,
                &enums,
                "queries/users.sql",
            )
            .unwrap();

        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    #[test]
    fn generates_client_file() {
        let content = PgxGenerator.generate_client();
        for expected in ["github.com/jackc/pgx/v5", "type DBTX interface"] {
            assert!(content.contains(expected), "client.go is missing {expected:?}");
        }
        insta::assert_snapshot!("go_pgx_client", content);
    }

    #[test]
    fn generates_query_file() {
        let ir = parse_fixture_ir();
        let content = PgxGenerator.generate_query_file(&ir.queries);
        for expected in ["func (q *Queries) GetUser", "func (q *Queries) ListUsers"] {
            assert!(content.contains(expected), "queries file is missing {expected:?}");
        }
        insta::assert_snapshot!("go_pgx_queries", content);
    }
}