use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::path::Path;
use crate::error::Result;
use crate::generator::go::common::{
escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
sql_const_name,
};
use crate::generator::{DriverGenerator, GeneratedFile};
use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
use crate::utils::pascal_case;
use super::structs::go_imports_for_columns;
/// Driver generator that emits Go sources for the `github.com/jackc/pgx/v5`
/// driver: a shared `client.go` plus per-source-file query files.
pub struct PgxGenerator;
/// Returns the contents of the generated `client.go`.
///
/// The file declares the `DBTX` interface (the `Exec`, `Query` and `QueryRow`
/// methods the generated query functions call), the `Queries` wrapper struct
/// holding a `DBTX`, and its `New` constructor.
fn generate_client() -> String {
    const CLIENT_GO: &str = r#"// Code generated by sqlcx. DO NOT EDIT.
package db
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type DBTX interface {
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
}
type Queries struct {
db DBTX
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}"#;
    String::from(CLIENT_GO)
}
/// Renders the Go source for one query.
///
/// The output contains:
/// - a `const` holding the escaped SQL text;
/// - any supporting struct the command needs;
/// - the `(*Queries)` method that runs the query.
///
/// Sections are joined with a blank line.
///
/// The method shape is chosen by `query.command`:
/// - `One`: scans a single row into `<Name>Row` and returns `*<Name>Row`.
///   A `pgx.ErrNoRows` error is reported as `(nil, nil)`, not as an error.
/// - `Many`: collects every row into a `[]<Name>Row`, then returns `rows.Err()`.
/// - `Exec`: runs the statement and returns only the error.
/// - `ExecResult`: returns a `*<Name>Result` filled from the command tag's
///   `RowsAffected()`.
fn generate_query_function(query: &QueryDef) -> String {
    let func_name = pascal_case(&query.name);
    let const_name = sql_const_name(&query.name);
    let params = func_params(query);
    let args = query_args(query);

    let mut sections = vec![format!(
        "const {const_name} = \"{sql}\"",
        sql = escape_sql(&query.sql)
    )];

    match query.command {
        QueryCommand::One => {
            let row_type = format!("{func_name}Row");
            // The row struct is optional; `Option` iterates over zero or one item.
            sections.extend(generate_row_struct(query));
            let scans = scan_fields(&query.returns);
            sections.push(format!(
                "func (q *Queries) {func_name}({params}) (*{row_type}, error) {{\n\
                 \trow := q.db.QueryRow(ctx, {const_name}{args})\n\
                 \tvar i {row_type}\n\
                 \terr := row.Scan({scans})\n\
                 \tif err != nil {{\n\
                 \t\tif errors.Is(err, pgx.ErrNoRows) {{\n\
                 \t\t\treturn nil, nil\n\
                 \t\t}}\n\
                 \t\treturn nil, err\n\
                 \t}}\n\
                 \treturn &i, nil\n}}"
            ));
        }
        QueryCommand::Many => {
            let row_type = format!("{func_name}Row");
            sections.extend(generate_row_struct(query));
            let scans = scan_fields(&query.returns);
            sections.push(format!(
                "func (q *Queries) {func_name}({params}) ([]{row_type}, error) {{\n\
                 \trows, err := q.db.Query(ctx, {const_name}{args})\n\
                 \tif err != nil {{\n\
                 \t\treturn nil, err\n\
                 \t}}\n\
                 \tdefer rows.Close()\n\
                 \tvar items []{row_type}\n\
                 \tfor rows.Next() {{\n\
                 \t\tvar i {row_type}\n\
                 \t\tif err := rows.Scan({scans}); err != nil {{\n\
                 \t\t\treturn nil, err\n\
                 \t\t}}\n\
                 \t\titems = append(items, i)\n\
                 \t}}\n\
                 \treturn items, rows.Err()\n}}"
            ));
        }
        QueryCommand::Exec => {
            sections.push(format!(
                "func (q *Queries) {func_name}({params}) error {{\n\
                 \t_, err := q.db.Exec(ctx, {const_name}{args})\n\
                 \treturn err\n}}"
            ));
        }
        QueryCommand::ExecResult => {
            let result_type = format!("{func_name}Result");
            sections.push(generate_result_struct(query));
            sections.push(format!(
                "func (q *Queries) {func_name}({params}) (*{result_type}, error) {{\n\
                 \ttag, err := q.db.Exec(ctx, {const_name}{args})\n\
                 \tif err != nil {{\n\
                 \t\treturn nil, err\n\
                 \t}}\n\
                 \treturn &{result_type}{{RowsAffected: tag.RowsAffected()}}, nil\n}}"
            ));
        }
    }

    sections.join("\n\n")
}
/// Collects the Go import paths a `*.queries.go` file for `queries` needs.
///
/// - An empty `queries` slice produces no query methods, so no imports are
///   returned. Go rejects unused imports, so a lone `"context"` would make the
///   generated file fail to compile.
/// - Otherwise `context` is always included, because every generated method
///   passes `ctx` to the driver.
/// - Imports needed by return columns and parameter types are taken from
///   `go_imports_for_columns`.
/// - `errors` and `github.com/jackc/pgx/v5` are added when any `:one` query
///   is present, since its method calls `errors.Is(err, pgx.ErrNoRows)`.
fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
    let mut imports = BTreeSet::new();
    if queries.is_empty() {
        return imports;
    }
    imports.insert("context".to_string());

    for query in queries {
        imports.extend(go_imports_for_columns(&query.returns));
        for param in &query.params {
            // Describe the parameter as a non-nullable column so the same
            // type -> import mapping used for return columns applies.
            // NOTE(review): presumably mirrors how `func_params` renders
            // parameter types (non-nullable) — confirm.
            let col = ColumnDef {
                name: param.name.clone(),
                alias: None,
                source_table: None,
                sql_type: param.sql_type.clone(),
                nullable: false,
                has_default: false,
            };
            imports.extend(go_imports_for_columns(&[col]));
        }
    }

    if queries.iter().any(|q| q.command == QueryCommand::One) {
        imports.insert("errors".to_string());
        imports.insert("github.com/jackc/pgx/v5".to_string());
    }
    imports
}
impl PgxGenerator {
    /// Returns the source of the shared `client.go` file.
    pub fn generate_client(&self) -> String {
        generate_client()
    }

    /// Renders a complete Go query file for `queries`.
    ///
    /// The file is assembled in this order:
    /// 1. the generated-code header and `package db` clause;
    /// 2. an `import ( ... )` block, only when `collect_query_imports`
    ///    reports at least one path;
    /// 3. each query's rendered code, separated by blank lines.
    ///
    /// The query section is preceded by a blank line and followed by a
    /// trailing newline, and is present only when `queries` is non-empty.
    pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
        let mut out = String::from("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");

        let imports = collect_query_imports(queries);
        if !imports.is_empty() {
            let import_lines = imports
                .iter()
                .map(|path| format!("\t\"{path}\""))
                .collect::<Vec<_>>()
                .join("\n");
            out.push_str("\nimport (\n");
            out.push_str(&import_lines);
            out.push_str("\n)\n");
        }

        if !queries.is_empty() {
            let rendered = queries
                .iter()
                .map(generate_query_function)
                .collect::<Vec<_>>();
            out.push('\n');
            out.push_str(&rendered.join("\n\n"));
            out.push('\n');
        }

        out
    }
}
impl DriverGenerator for PgxGenerator {
fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
let mut files = Vec::new();
files.push(GeneratedFile {
path: "client.go".to_string(),
content: self.generate_client(),
});
let mut grouped: BTreeMap<String, Vec<&QueryDef>> = BTreeMap::new();
for query in &ir.queries {
grouped
.entry(query.source_file.clone())
.or_default()
.push(query);
}
for (source_file, queries) in &grouped {
let basename = Path::new(source_file)
.file_stem()
.unwrap_or_default()
.to_string_lossy();
let owned: Vec<QueryDef> = queries.iter().map(|q| (*q).clone()).collect();
files.push(GeneratedFile {
path: format!("{}.queries.go", basename),
content: self.generate_query_file(&owned),
});
}
Ok(files)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::postgres::PostgresParser;
    use crate::parser::DatabaseParser;

    /// Builds an IR from the shared Postgres schema fixture and the
    /// `queries/users.sql` query fixture.
    fn parse_fixture_ir() -> SqlcxIR {
        let parser = PostgresParser::new();
        let (tables, enums) = parser
            .parse_schema(include_str!("../../../../../tests/fixtures/schema.sql"))
            .unwrap();
        let queries = parser
            .parse_queries(
                include_str!("../../../../../tests/fixtures/queries/users.sql"),
                &tables,
                &enums,
                "queries/users.sql",
            )
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    #[test]
    fn generates_client_file() {
        let client = PgxGenerator.generate_client();
        for needle in ["github.com/jackc/pgx/v5", "type DBTX interface"] {
            assert!(client.contains(needle), "client.go is missing {:?}", needle);
        }
        insta::assert_snapshot!("go_pgx_client", client);
    }

    #[test]
    fn generates_query_file() {
        let ir = parse_fixture_ir();
        let rendered = PgxGenerator.generate_query_file(&ir.queries);
        for needle in ["func (q *Queries) GetUser", "func (q *Queries) ListUsers"] {
            assert!(rendered.contains(needle), "query file is missing {:?}", needle);
        }
        insta::assert_snapshot!("go_pgx_queries", rendered);
    }
}