use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::path::Path;
use crate::error::Result;
use crate::generator::go::common::{
escape_sql, func_params, generate_result_struct, generate_row_struct, query_args, scan_fields,
sql_const_name,
};
use crate::generator::{DriverGenerator, GeneratedFile};
use crate::ir::{ColumnDef, QueryCommand, QueryDef, SqlcxIR};
use crate::utils::pascal_case;
use super::structs::go_imports_for_columns;
/// SQL dialect targeted by the generated `database/sql` code.
///
/// The backend only affects how bind placeholders are spelled in the emitted
/// SQL constants: PostgreSQL keeps the numbered `$N` form, while MySQL and
/// SQLite get positional `?` markers.
#[derive(Debug, Clone, Copy)]
pub enum DatabaseSqlBackend {
    /// `$N` placeholders are emitted unchanged.
    Postgres,
    /// `$N` placeholders are rewritten to `?`.
    MySql,
    /// `$N` placeholders are rewritten to `?`.
    Sqlite,
}
impl DatabaseSqlBackend {
    /// Rewrites numbered `$N` placeholders into the backend's native form.
    ///
    /// Returns the rewritten SQL together with the parameter number of every
    /// rewritten placeholder, in the order it appears in the text (repeats
    /// included). For `Postgres` the SQL is returned unchanged with an empty
    /// index list.
    ///
    /// For the `?` backends, text the database never treats as a placeholder is
    /// copied verbatim:
    /// - single-quoted string literals (a doubled `''` is an escaped quote),
    /// - double-quoted spans (a doubled `""` is an escaped quote),
    /// - `-- …` line comments (up to and including the newline),
    /// - `/* … */` block comments (not nested).
    ///
    /// Rewriting a `$N` inside one of those would produce a `?` that the
    /// database does not count as a bind parameter. Because an argument is
    /// still emitted for it, the number of arguments would no longer match.
    ///
    /// A digit run that does not fit in `u32` is recorded as index `0`.
    fn rewrite_placeholders(self, sql: &str) -> (String, Vec<u32>) {
        match self {
            DatabaseSqlBackend::Postgres => (sql.to_string(), Vec::new()),
            DatabaseSqlBackend::MySql | DatabaseSqlBackend::Sqlite => {
                let mut out = String::with_capacity(sql.len());
                let mut indices = Vec::new();
                let mut chars = sql.chars().peekable();
                while let Some(c) = chars.next() {
                    match c {
                        // Quoted span: copy through the matching close quote. A
                        // doubled quote character is an escape and keeps the span open.
                        '\'' | '"' => {
                            out.push(c);
                            while let Some(ch) = chars.next() {
                                out.push(ch);
                                if ch == c {
                                    match chars.next_if_eq(&c) {
                                        Some(escaped) => out.push(escaped),
                                        None => break,
                                    }
                                }
                            }
                        }
                        // Line comment: copy through the end of the line.
                        '-' if chars.peek() == Some(&'-') => {
                            out.push(c);
                            for ch in chars.by_ref() {
                                out.push(ch);
                                if ch == '\n' {
                                    break;
                                }
                            }
                        }
                        // Block comment: copy through the first `*/`. The opening
                        // `*` is consumed first so that `/*/` does not count as closed.
                        '/' if chars.peek() == Some(&'*') => {
                            out.push(c);
                            out.extend(chars.next());
                            let mut prev = '\0';
                            for ch in chars.by_ref() {
                                out.push(ch);
                                if prev == '*' && ch == '/' {
                                    break;
                                }
                                prev = ch;
                            }
                        }
                        '$' if chars.peek().is_some_and(char::is_ascii_digit) => {
                            let mut digits = String::new();
                            while let Some(d) = chars.next_if(char::is_ascii_digit) {
                                digits.push(d);
                            }
                            indices.push(digits.parse::<u32>().unwrap_or(0));
                            out.push('?');
                        }
                        _ => out.push(c),
                    }
                }
                (out, indices)
            }
        }
    }
}
/// Go code generator targeting the standard `database/sql` package.
///
/// Its only setting is the [`DatabaseSqlBackend`], which decides how bind
/// placeholders are written into the emitted SQL constants.
pub struct DatabaseSqlGenerator {
    backend: DatabaseSqlBackend,
}
impl DatabaseSqlGenerator {
    /// Shared constructor behind the public per-backend constructors.
    fn for_backend(backend: DatabaseSqlBackend) -> Self {
        DatabaseSqlGenerator { backend }
    }

    /// Generator whose SQL keeps PostgreSQL-style `$N` placeholders.
    pub fn postgres() -> Self {
        Self::for_backend(DatabaseSqlBackend::Postgres)
    }

    /// Generator whose SQL uses `?` placeholders for MySQL.
    pub fn mysql() -> Self {
        Self::for_backend(DatabaseSqlBackend::MySql)
    }

    /// Generator whose SQL uses `?` placeholders for SQLite.
    pub fn sqlite() -> Self {
        Self::for_backend(DatabaseSqlBackend::Sqlite)
    }
}
/// The default generator targets PostgreSQL.
impl Default for DatabaseSqlGenerator {
    fn default() -> Self {
        DatabaseSqlGenerator::postgres()
    }
}
/// Returns the fixed contents of `client.go`.
///
/// The file declares the `DBTX` interface (the `ExecContext` / `QueryContext` /
/// `QueryRowContext` subset of `database/sql` that generated queries call),
/// the `Queries` wrapper holding a `DBTX`, and its `New` constructor.
fn generate_client() -> String {
    const CLIENT_GO: &str = r#"// Code generated by sqlcx. DO NOT EDIT.
package db
import (
"context"
"database/sql"
)
type DBTX interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type Queries struct {
db DBTX
}
func New(db DBTX) *Queries {
return &Queries{db: db}
}"#;
    String::from(CLIENT_GO)
}
/// Builds the trailing Go argument list (`, a, b, a`) for `?`-style backends.
///
/// `?` markers are purely positional, so arguments follow the order in which
/// `$N` placeholders appeared in the SQL, and a reused placeholder repeats its
/// argument. An empty `indices` slice yields an empty string.
///
/// NOTE(review): a placeholder with no matching `ParamDef` is bound as Go
/// `nil`; presumably the parser always supplies one param per `$N` — confirm.
/// NOTE(review): `ParamDef::name` is used verbatim; confirm this matches the
/// Go parameter names produced by `func_params`.
fn positional_args(query: &QueryDef, indices: &[u32]) -> String {
    indices
        .iter()
        .map(|idx| {
            query
                .params
                .iter()
                .find(|p| p.index == *idx)
                .map_or("nil", |p| p.name.as_str())
        })
        .fold(String::new(), |mut acc, name| {
            acc.push_str(", ");
            acc.push_str(name);
            acc
        })
}
/// Renders the Go code for one query: a `const` holding the SQL text, any
/// row/result struct the query needs, and the `(*Queries)` method that runs it.
/// The pieces are joined with blank lines.
///
/// The method shape follows `query.command`:
/// - `One`: `QueryRowContext` + `Scan`; `sql.ErrNoRows` becomes `(nil, nil)`.
/// - `Many`: `QueryContext`, scanning every row into a slice, then `rows.Err()`.
/// - `Exec`: `ExecContext`, returning only the error.
/// - `ExecResult`: `ExecContext`, returning `RowsAffected` in a `<Name>Result`.
///
/// If the backend rewrote any placeholders, arguments are emitted positionally
/// via `positional_args`. Otherwise the argument list comes from `query_args`.
fn generate_query_function(backend: DatabaseSqlBackend, query: &QueryDef) -> String {
    let const_name = sql_const_name(&query.name);
    let func_name = pascal_case(&query.name);
    let params = func_params(query);
    let (rewritten_sql, occurrence_indices) = backend.rewrite_placeholders(&query.sql);
    let mut parts: Vec<String> = vec![format!(
        "const {} = \"{}\"",
        const_name,
        escape_sql(&rewritten_sql),
    )];
    // An empty index list means nothing was rewritten (Postgres, or no `$N` in
    // the SQL), so the generic argument list applies.
    let args = if occurrence_indices.is_empty() {
        query_args(query)
    } else {
        positional_args(query, &occurrence_indices)
    };
    match query.command {
        QueryCommand::One => {
            let row_type = format!("{}Row", func_name);
            if let Some(row_struct) = generate_row_struct(query) {
                parts.push(row_struct);
            }
            let scans = scan_fields(&query.returns);
            parts.push(format!(
                "func (q *Queries) {}({}) (*{}, error) {{\n\
\trow := q.db.QueryRowContext(ctx, {}{})
\tvar i {}
\terr := row.Scan({})
\tif err == sql.ErrNoRows {{
\t\treturn nil, nil
\t}}
\tif err != nil {{
\t\treturn nil, err
\t}}
\treturn &i, nil\n}}",
                func_name, params, row_type, const_name, args, row_type, scans,
            ));
        }
        QueryCommand::Many => {
            let row_type = format!("{}Row", func_name);
            if let Some(row_struct) = generate_row_struct(query) {
                parts.push(row_struct);
            }
            let scans = scan_fields(&query.returns);
            parts.push(format!(
                "func (q *Queries) {}({}) ([]{}, error) {{\n\
\trows, err := q.db.QueryContext(ctx, {}{})
\tif err != nil {{
\t\treturn nil, err
\t}}
\tdefer rows.Close()
\tvar items []{}
\tfor rows.Next() {{
\t\tvar i {}
\t\tif err := rows.Scan({}); err != nil {{
\t\t\treturn nil, err
\t\t}}
\t\titems = append(items, i)
\t}}
\treturn items, rows.Err()\n}}",
                func_name, params, row_type, const_name, args, row_type, row_type, scans,
            ));
        }
        QueryCommand::Exec => {
            parts.push(format!(
                "func (q *Queries) {}({}) error {{\n\
\t_, err := q.db.ExecContext(ctx, {}{})\n\treturn err\n}}",
                func_name, params, const_name, args,
            ));
        }
        QueryCommand::ExecResult => {
            let result_type = format!("{}Result", func_name);
            parts.push(generate_result_struct(query));
            parts.push(format!(
                "func (q *Queries) {}({}) (*{}, error) {{\n\
\tresult, err := q.db.ExecContext(ctx, {}{})
\tif err != nil {{
\t\treturn nil, err
\t}}
\taffected, err := result.RowsAffected()
\tif err != nil {{
\t\treturn nil, err
\t}}
\treturn &{}{{RowsAffected: affected}}, nil\n}}",
                func_name, params, result_type, const_name, args, result_type,
            ));
        }
    }
    parts.join("\n\n")
}
/// Collects the Go import paths required by a generated `*.queries.go` file.
///
/// - `context`: added whenever at least one query method is emitted, since
///   every method forwards its `ctx` argument to the `DBTX` call.
/// - `database/sql`: added only for `One` queries, whose body compares against
///   `sql.ErrNoRows`.
///   - `ExecResult` bodies never name the `sql` package; the `sql.Result` is
///     held in a type-inferred local.
///   - Go rejects unused imports, so adding it for them would break any file
///     that has no `One` query.
///   - NOTE(review): this assumes the struct from `generate_result_struct` does
///     not reference `sql` either — confirm.
/// - Column types: whatever `go_imports_for_columns` reports for the result
///   columns and for each parameter (parameters are modelled as non-nullable
///   columns).
///
/// An empty `queries` slice yields an empty set: no code would use any import,
/// and an unused import is a Go compile error.
fn collect_query_imports(queries: &[QueryDef]) -> BTreeSet<String> {
    let mut imports = BTreeSet::new();
    if queries.is_empty() {
        return imports;
    }
    imports.insert("context".to_string());
    for query in queries {
        imports.extend(go_imports_for_columns(&query.returns));
        if matches!(query.command, QueryCommand::One) {
            imports.insert("database/sql".to_string());
        }
        for param in &query.params {
            // Parameters carry no column metadata of their own; wrap each in a
            // non-nullable column so the shared type-to-import mapping applies.
            let col = ColumnDef {
                name: param.name.clone(),
                alias: None,
                source_table: None,
                sql_type: param.sql_type.clone(),
                nullable: false,
                has_default: false,
            };
            imports.extend(go_imports_for_columns(&[col]));
        }
    }
    imports
}
impl DatabaseSqlGenerator {
    /// Renders `client.go` by delegating to the module-level `generate_client`.
    pub fn generate_client(&self) -> String {
        generate_client()
    }

    /// Renders one Go file containing a method for every query in `queries`.
    ///
    /// Layout, in order:
    /// - the generated-code header and `package db`;
    /// - an `import ( … )` block with one tab-indented quoted path per line,
    ///   omitted when no import is needed;
    /// - the query functions separated by blank lines and followed by a
    ///   trailing newline, omitted when `queries` is empty.
    pub fn generate_query_file(&self, queries: &[QueryDef]) -> String {
        let imports = collect_query_imports(queries);
        let mut content = String::from("// Code generated by sqlcx. DO NOT EDIT.\npackage db\n");

        if !imports.is_empty() {
            let import_lines = imports
                .iter()
                .map(|path| format!("\t\"{}\"", path))
                .collect::<Vec<_>>()
                .join("\n");
            content.push_str("\nimport (\n");
            content.push_str(&import_lines);
            content.push_str("\n)\n");
        }

        if !queries.is_empty() {
            let body = queries
                .iter()
                .map(|query| generate_query_function(self.backend, query))
                .collect::<Vec<_>>()
                .join("\n\n");
            content.push('\n');
            content.push_str(&body);
            content.push('\n');
        }

        content
    }
}
impl DriverGenerator for DatabaseSqlGenerator {
    /// Produces `client.go` followed by one `<stem>.queries.go` per distinct
    /// `source_file`, in sorted `source_file` order. Within each file, queries
    /// keep their order from `ir.queries`.
    ///
    /// NOTE(review): the output name uses only the file stem, so `a/users.sql`
    /// and `b/users.sql` would both map to `users.queries.go` — confirm callers
    /// guarantee unique stems.
    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
        let client = GeneratedFile {
            path: "client.go".to_string(),
            content: self.generate_client(),
        };

        let mut by_source: BTreeMap<&str, Vec<QueryDef>> = BTreeMap::new();
        for query in &ir.queries {
            by_source
                .entry(query.source_file.as_str())
                .or_default()
                .push(query.clone());
        }

        let query_files = by_source.iter().map(|(source_file, queries)| {
            let stem = Path::new(source_file)
                .file_stem()
                .unwrap_or_default()
                .to_string_lossy();
            GeneratedFile {
                path: format!("{}.queries.go", stem),
                content: self.generate_query_file(queries),
            }
        });

        Ok(std::iter::once(client).chain(query_files).collect())
    }
}
#[cfg(test)]
mod tests {
    // Tests for the Go `database/sql` generator: the client scaffolding, whole
    // query files per backend (pinned with `insta` snapshots), and the
    // `$N` -> `?` placeholder rewrite used by the MySQL/SQLite backends.
    use super::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Parses the shared fixture schema and `queries/users.sql` with the
    /// Postgres parser and bundles tables, queries and enums into an IR.
    /// The `include_str!` paths are relative to this source file.
    fn parse_fixture_ir() -> SqlcxIR {
        let schema_sql = include_str!("../../../../../tests/fixtures/schema.sql");
        let queries_sql = include_str!("../../../../../tests/fixtures/queries/users.sql");
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(schema_sql).unwrap();
        let queries = parser
            .parse_queries(queries_sql, &tables, &enums, "queries/users.sql")
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    // The client file declares the DBTX interface, the Queries wrapper and New.
    #[test]
    fn generates_client_file() {
        let gen_ = DatabaseSqlGenerator::postgres();
        let content = gen_.generate_client();
        assert!(content.contains("type DBTX interface"));
        assert!(content.contains("type Queries struct"));
        assert!(content.contains("func New(db DBTX) *Queries"));
        insta::assert_snapshot!("go_database_sql_client", content);
    }

    // Postgres output contains a method per fixture query and keeps `$1`.
    #[test]
    fn generates_query_file() {
        let ir = parse_fixture_ir();
        let gen_ = DatabaseSqlGenerator::postgres();
        let content = gen_.generate_query_file(&ir.queries);
        assert!(content.contains("func (q *Queries) GetUser"));
        assert!(content.contains("func (q *Queries) ListUsers"));
        assert!(content.contains("func (q *Queries) CreateUser"));
        assert!(content.contains("func (q *Queries) DeleteUser"));
        assert!(content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_queries", content);
    }

    // MySQL output replaces `$1` with `?`.
    #[test]
    fn mysql_backend_rewrites_placeholders_to_question_marks() {
        let ir = parse_fixture_ir();
        let gen_ = DatabaseSqlGenerator::mysql();
        let content = gen_.generate_query_file(&ir.queries);
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_mysql_queries", content);
    }

    // SQLite output replaces `$1` with `?`.
    #[test]
    fn sqlite_backend_rewrites_placeholders_to_question_marks() {
        let ir = parse_fixture_ir();
        let gen_ = DatabaseSqlGenerator::sqlite();
        let content = gen_.generate_query_file(&ir.queries);
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("go_database_sql_sqlite_queries", content);
    }

    // A `$1` inside a single-quoted literal is left alone and not recorded.
    #[test]
    fn placeholder_rewrite_preserves_dollar_in_string_literals() {
        let (rewritten, idx) =
            DatabaseSqlBackend::MySql.rewrite_placeholders("SELECT '$1' FROM x WHERE a = $1");
        assert_eq!(rewritten, "SELECT '$1' FROM x WHERE a = ?");
        assert_eq!(idx, vec![1]);
    }

    // A parameter used twice becomes two `?` markers and is passed twice.
    #[test]
    fn reused_param_emits_repeated_args_in_mysql() {
        use crate::ir::{ParamDef, SqlType, SqlTypeCategory};
        let query = QueryDef {
            name: "Search".to_string(),
            command: QueryCommand::Many,
            sql: "SELECT id FROM users WHERE name = $1 OR email = $1".to_string(),
            params: vec![ParamDef {
                index: 1,
                name: "q".to_string(),
                sql_type: SqlType {
                    raw: "text".to_string(),
                    normalized: "text".to_string(),
                    category: SqlTypeCategory::String,
                    element_type: None,
                    enum_name: None,
                    enum_values: None,
                    json_shape: None,
                },
            }],
            returns: vec![],
            source_file: "q.sql".to_string(),
        };
        let out = generate_query_function(DatabaseSqlBackend::MySql, &query);
        // Counts `?` over the whole generated output, not only the SQL constant.
        assert_eq!(out.matches('?').count(), 2);
        assert!(out.contains(", q, q)"), "expected `, q, q)` in: {out}");
    }

    // Arguments follow placeholder order in the SQL text (`$2` before `$1`),
    // not parameter-index order.
    #[test]
    fn out_of_order_params_bind_in_document_order_in_sqlite() {
        use crate::ir::{ParamDef, SqlType, SqlTypeCategory};
        let int_type = SqlType {
            raw: "integer".to_string(),
            normalized: "integer".to_string(),
            category: SqlTypeCategory::Number,
            element_type: None,
            enum_name: None,
            enum_values: None,
            json_shape: None,
        };
        let query = QueryDef {
            name: "Range".to_string(),
            command: QueryCommand::Many,
            sql: "SELECT id FROM t WHERE b = $2 AND a = $1".to_string(),
            params: vec![
                ParamDef {
                    index: 1,
                    name: "a".to_string(),
                    sql_type: int_type.clone(),
                },
                ParamDef {
                    index: 2,
                    name: "b".to_string(),
                    sql_type: int_type,
                },
            ],
            returns: vec![],
            source_file: "q.sql".to_string(),
        };
        let out = generate_query_function(DatabaseSqlBackend::Sqlite, &query);
        assert!(out.contains(", b, a)"), "expected `, b, a)` in: {out}");
        assert!(!out.contains(", a, b)"));
    }
}