use std::path::Path;
use crate::error::Result;
use crate::generator::rust_lang::common::{param_type, row_field_type};
use crate::generator::{DriverGenerator, GeneratedFile};
use crate::ir::{QueryCommand, QueryDef, SqlcxIR};
use crate::utils::{pascal_case, snake_case};
/// Database backend targeted by the sqlx generator; it selects the pool type taken by the
/// generated functions and the placeholder syntax of the emitted SQL.
#[derive(Debug, Clone, Copy)]
pub enum SqlxBackend {
    Postgres,
    MySql,
    Sqlite,
}
impl SqlxBackend {
    /// Fully qualified sqlx pool type that generated functions take as their first argument.
    fn pool_type(self) -> &'static str {
        match self {
            SqlxBackend::Postgres => "sqlx::PgPool",
            SqlxBackend::MySql => "sqlx::MySqlPool",
            SqlxBackend::Sqlite => "sqlx::SqlitePool",
        }
    }

    /// Adapts Postgres-style `$N` placeholders in `sql` to this backend.
    ///
    /// Returns the rewritten SQL together with the `N` of every placeholder occurrence, in the
    /// order they appear. Postgres keeps `$N` untouched and returns no indices; MySQL and SQLite
    /// get one `?` per occurrence, so callers must emit one bind per returned index.
    fn rewrite_placeholders(self, sql: &str) -> (String, Vec<u32>) {
        match self {
            SqlxBackend::Postgres => (sql.to_string(), Vec::new()),
            SqlxBackend::MySql | SqlxBackend::Sqlite => Self::rewrite_to_question_marks(sql),
        }
    }

    /// Replaces each `$N` placeholder with `?`, recording `N` for every occurrence.
    ///
    /// A `$N` is left untouched when it appears:
    /// - inside a single-quoted string (`''` is treated as an escaped quote),
    /// - inside a double-quoted or backtick-quoted identifier,
    /// - inside a `--` line comment or a `/* */` block comment,
    /// - directly after an identifier character (e.g. the Postgres identifier `col$1`).
    ///
    /// Rewriting any of those would add a `?`/bind the database never sees as a parameter,
    /// so the number of binds would no longer match the number of placeholders.
    fn rewrite_to_question_marks(sql: &str) -> (String, Vec<u32>) {
        // Lexical context of the scanner.
        #[derive(Clone, Copy)]
        enum Scan {
            Code,
            Quoted(char),
            LineComment,
            BlockComment,
        }

        fn is_ident_char(c: char) -> bool {
            c.is_alphanumeric() || c == '_'
        }

        let mut out = String::with_capacity(sql.len());
        let mut indices = Vec::new();
        let mut state = Scan::Code;
        let mut prev: Option<char> = None;
        let mut chars = sql.chars().peekable();
        while let Some(c) = chars.next() {
            match state {
                Scan::Quoted(quote) => {
                    out.push(c);
                    if c == quote {
                        // A doubled quote is an escaped quote and keeps us inside the literal.
                        if let Some(escaped) = chars.next_if_eq(&quote) {
                            out.push(escaped);
                        } else {
                            state = Scan::Code;
                        }
                    }
                }
                Scan::LineComment => {
                    out.push(c);
                    if c == '\n' {
                        state = Scan::Code;
                    }
                }
                Scan::BlockComment => {
                    out.push(c);
                    if c == '*' {
                        if let Some(slash) = chars.next_if_eq(&'/') {
                            out.push(slash);
                            state = Scan::Code;
                        }
                    }
                }
                Scan::Code => match c {
                    '\'' | '"' | '`' => {
                        out.push(c);
                        state = Scan::Quoted(c);
                    }
                    '-' if chars.peek() == Some(&'-') => {
                        out.push(c);
                        state = Scan::LineComment;
                    }
                    '/' if chars.peek() == Some(&'*') => {
                        // Consume the `*` now so `/*/` is not mistaken for an open-and-close.
                        out.push(c);
                        out.push('*');
                        chars.next();
                        state = Scan::BlockComment;
                    }
                    '$' if !prev.is_some_and(is_ident_char)
                        && chars.peek().is_some_and(char::is_ascii_digit) =>
                    {
                        let mut digits = String::new();
                        while let Some(d) = chars.next_if(char::is_ascii_digit) {
                            digits.push(d);
                        }
                        // An index too large for u32 falls back to 0 (no parameter matches it).
                        indices.push(digits.parse::<u32>().unwrap_or(0));
                        out.push('?');
                    }
                    _ => out.push(c),
                },
            }
            prev = Some(c);
        }
        (out, indices)
    }
}
/// Code generator that emits sqlx-based async query functions for a single backend.
#[derive(Debug, Clone, Copy)]
pub struct SqlxGenerator {
    backend: SqlxBackend,
}
impl SqlxGenerator {
    /// Creates a generator targeting `backend` (useful when the backend is chosen at runtime).
    pub fn new(backend: SqlxBackend) -> Self {
        Self { backend }
    }
    /// Generator for PostgreSQL: functions take `&sqlx::PgPool` and `$N` placeholders are kept.
    pub fn postgres() -> Self {
        Self::new(SqlxBackend::Postgres)
    }
    /// Generator for MySQL: functions take `&sqlx::MySqlPool` and `$N` is rewritten to `?`.
    pub fn mysql() -> Self {
        Self::new(SqlxBackend::MySql)
    }
    /// Generator for SQLite: functions take `&sqlx::SqlitePool` and `$N` is rewritten to `?`.
    pub fn sqlite() -> Self {
        Self::new(SqlxBackend::Sqlite)
    }
}
/// The default generator targets PostgreSQL.
impl Default for SqlxGenerator {
    fn default() -> Self {
        Self::postgres()
    }
}
/// Renders the `{Name}Row` struct (`#[derive(sqlx::FromRow)]`) holding one result row of
/// `query`, or an empty string when the query returns no columns.
///
/// Each field is named after the column's alias, falling back to its name. A column whose
/// name is a Rust keyword (e.g. `type`) would otherwise produce code that does not compile.
/// Such fields are escaped by `keyword_safe_field_name` and tagged with
/// `#[sqlx(rename = "...")]` so `FromRow` still reads the original column.
fn generate_row_struct(query: &QueryDef) -> String {
    if query.returns.is_empty() {
        return String::new();
    }
    let type_name = format!("{}Row", pascal_case(&query.name));
    let fields: Vec<String> = query
        .returns
        .iter()
        .map(|col| {
            let column = col.alias.as_deref().unwrap_or(&col.name);
            let rust_type = row_field_type(col);
            match keyword_safe_field_name(column) {
                Some(field) => format!(
                    " #[sqlx(rename = {:?})]\n pub {}: {},",
                    column, field, rust_type
                ),
                None => format!(" pub {}: {},", column, rust_type),
            }
        })
        .collect();
    format!(
        "#[derive(Debug, Clone, sqlx::FromRow)]\npub struct {} {{\n{}\n}}",
        type_name,
        fields.join("\n")
    )
}

/// Returns an escaped field identifier when `name` is a Rust keyword, or `None` when `name`
/// can be used verbatim.
///
/// Most keywords become raw identifiers (`r#type`). `crate`, `self`, `Self` and `super`
/// cannot be raw identifiers, so they get a trailing underscore instead.
fn keyword_safe_field_name(name: &str) -> Option<String> {
    const NON_RAW: &[&str] = &["crate", "self", "Self", "super"];
    const KEYWORDS: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue", "do",
        "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen", "if", "impl",
        "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override", "priv", "pub",
        "ref", "return", "static", "struct", "trait", "true", "try", "type", "typeof",
        "unsafe", "unsized", "use", "virtual", "where", "while", "yield",
    ];
    if NON_RAW.contains(&name) {
        Some(format!("{name}_"))
    } else if KEYWORDS.contains(&name) {
        Some(format!("r#{name}"))
    } else {
        None
    }
}
/// Renders the `{Name}Result` struct that `:execresult` queries return; its only field is the
/// number of rows the statement affected.
fn generate_result_struct(query: &QueryDef) -> String {
    let base = pascal_case(&query.name);
    format!("#[derive(Debug, Clone)]\npub struct {base}Result {{\n pub rows_affected: u64,\n}}")
}
/// Renders all generated code for one query.
///
/// The output contains, separated by blank lines:
/// - the SQL constant, with placeholders adapted to `backend`;
/// - the optional `{Name}Row` / `{Name}Result` structs;
/// - the `pub async fn` that executes the query against the backend's pool.
fn generate_query_function(backend: SqlxBackend, query: &QueryDef) -> String {
    let fn_name = snake_case(&query.name);
    let sql_const_name = format!("{}_SQL", fn_name.to_uppercase());
    let mut parts: Vec<String> = Vec::new();
    let (rewritten_sql, occurrence_indices) = backend.rewrite_placeholders(&query.sql);
    parts.push(format!(
        "pub const {}: &str = {:?};",
        sql_const_name, rewritten_sql
    ));
    let row_struct = generate_row_struct(query);
    if !row_struct.is_empty() {
        parts.push(row_struct);
    }
    if query.command == QueryCommand::ExecResult {
        parts.push(generate_result_struct(query));
    }
    let mut params_sig = format!("pool: &{}", backend.pool_type());
    for p in &query.params {
        let ptype = param_type(&p.sql_type);
        params_sig.push_str(&format!(", {}: {}", p.name, ptype));
    }
    let binds = render_binds(query, &occurrence_indices);
    let (return_type, body) = match query.command {
        QueryCommand::One => {
            let type_name = format!("{}Row", pascal_case(&query.name));
            (
                format!("Result<Option<{}>, sqlx::Error>", type_name),
                format!(
                    " sqlx::query_as::<_, {}>({}){}
.fetch_optional(pool)
.await",
                    type_name, sql_const_name, binds
                ),
            )
        }
        QueryCommand::Many => {
            let type_name = format!("{}Row", pascal_case(&query.name));
            (
                format!("Result<Vec<{}>, sqlx::Error>", type_name),
                format!(
                    " sqlx::query_as::<_, {}>({}){}
.fetch_all(pool)
.await",
                    type_name, sql_const_name, binds
                ),
            )
        }
        QueryCommand::Exec => (
            "Result<(), sqlx::Error>".to_string(),
            format!(
                " sqlx::query({}){}
.execute(pool)
.await
.map(|_| ())",
                sql_const_name, binds
            ),
        ),
        QueryCommand::ExecResult => {
            let result_type = format!("{}Result", pascal_case(&query.name));
            (
                format!("Result<{}, sqlx::Error>", result_type),
                format!(
                    " let result = sqlx::query({}){}
.execute(pool)
.await?;
Ok({} {{ rows_affected: result.rows_affected() }})",
                    sql_const_name, binds, result_type
                ),
            )
        }
    };
    parts.push(format!(
        "pub async fn {}({}) -> {} {{\n{}\n}}",
        fn_name, params_sig, return_type, body
    ));
    parts.join("\n\n")
}

/// Renders the chain of `.bind(...)` calls for `query`.
///
/// `occurrence_indices` is the per-occurrence `$N` list returned by `rewrite_placeholders`.
/// - When it is empty (Postgres keeps `$N`), every parameter is bound exactly once. Binds are
///   sorted by the parameter's `$N` index because sqlx binds positionally, so declaration
///   order alone could mismatch `$1, $2, ...`.
/// - Otherwise one bind is emitted per occurrence. A parameter that occurs more than once is
///   bound with `.clone()` at every use, so the generated code never moves a value twice.
fn render_binds(query: &QueryDef, occurrence_indices: &[u32]) -> String {
    if occurrence_indices.is_empty() {
        let mut ordered: Vec<_> = query.params.iter().collect();
        ordered.sort_by_key(|p| p.index);
        return ordered
            .into_iter()
            .map(|p| format!("\n .bind({})", p.name))
            .collect();
    }
    let mut seen = std::collections::HashSet::new();
    let reused: std::collections::HashSet<u32> = occurrence_indices
        .iter()
        .copied()
        .filter(|idx| !seen.insert(*idx))
        .collect();
    occurrence_indices
        .iter()
        .map(|idx| {
            // NOTE(review): a placeholder with no matching parameter renders as `unknown`,
            // which makes the generated code fail to compile instead of binding a wrong value.
            let param_name = query
                .params
                .iter()
                .find(|p| p.index == *idx)
                .map(|p| p.name.as_str())
                .unwrap_or("unknown");
            let expr = if reused.contains(idx) {
                format!("{param_name}.clone()")
            } else {
                param_name.to_string()
            };
            format!("\n .bind({expr})")
        })
        .collect()
}
impl SqlxGenerator {
    /// Returns the contents of the generated `client.rs`.
    ///
    /// The file is only a header explaining that the query functions take a sqlx pool
    /// reference directly; there is no client wrapper type.
    pub fn generate_client(&self) -> String {
        [
            "// Code generated by sqlcx. DO NOT EDIT.",
            "",
            "// This module uses sqlx for database access.",
            "// Pass a &sqlx::PgPool, &sqlx::MySqlPool, or &sqlx::SqlitePool",
            "// to the query functions below.",
        ]
        .join("\n")
    }

    /// Renders one queries module.
    ///
    /// The module is the generated-code header followed by each query's code, separated by
    /// blank lines. With no queries it is the header plus a trailing newline.
    pub fn generate_query_functions(&self, queries: &[QueryDef]) -> String {
        let mut module = String::from("// Code generated by sqlcx. DO NOT EDIT.\n\nuse sqlx;");
        if queries.is_empty() {
            module.push('\n');
            return module;
        }
        for query in queries {
            module.push_str("\n\n");
            module.push_str(&generate_query_function(self.backend, query));
        }
        module
    }
}
impl DriverGenerator for SqlxGenerator {
    /// Produces `client.rs` followed by one `<stem>_queries.rs` per distinct `source_file`.
    ///
    /// Queries are grouped in a `BTreeMap`, so the query files come out sorted by source path.
    fn generate(&self, ir: &SqlcxIR) -> Result<Vec<GeneratedFile>> {
        let mut by_source: std::collections::BTreeMap<&str, Vec<QueryDef>> =
            std::collections::BTreeMap::new();
        for query in &ir.queries {
            by_source
                .entry(query.source_file.as_str())
                .or_default()
                .push(query.clone());
        }
        let client = GeneratedFile {
            path: "client.rs".to_string(),
            content: self.generate_client(),
        };
        let query_files = by_source.iter().map(|(source_file, queries)| {
            let stem = Path::new(*source_file)
                .file_stem()
                .unwrap_or_default()
                .to_string_lossy();
            GeneratedFile {
                path: format!("{}_queries.rs", stem),
                content: self.generate_query_functions(queries),
            }
        });
        Ok(std::iter::once(client).chain(query_files).collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::*;
    use crate::parser::DatabaseParser;
    use crate::parser::postgres::PostgresParser;

    /// Parses the shared Postgres fixtures (schema plus `queries/users.sql`) into an IR.
    fn parse_fixture_ir() -> SqlcxIR {
        let parser = PostgresParser::new();
        let (tables, enums) = parser
            .parse_schema(include_str!("../../../../../tests/fixtures/schema.sql"))
            .unwrap();
        let queries = parser
            .parse_queries(
                include_str!("../../../../../tests/fixtures/queries/users.sql"),
                &tables,
                &enums,
                "queries/users.sql",
            )
            .unwrap();
        SqlcxIR {
            tables,
            queries,
            enums,
        }
    }

    /// Renders every fixture query with `generator` into a single module string.
    fn render_fixture(generator: SqlxGenerator) -> String {
        let ir = parse_fixture_ir();
        generator.generate_query_functions(&ir.queries)
    }

    /// Builds a plain scalar `SqlType` (no array element, enum, or JSON metadata).
    fn scalar(name: &str, category: SqlTypeCategory) -> SqlType {
        SqlType {
            raw: name.to_string(),
            normalized: name.to_string(),
            category,
            element_type: None,
            enum_name: None,
            enum_values: None,
            json_shape: None,
        }
    }

    /// Builds a query with a single `$1` parameter, no result columns, from `q.sql`.
    fn one_param_query(
        name: &str,
        command: QueryCommand,
        sql: &str,
        param: &str,
        sql_type: SqlType,
    ) -> QueryDef {
        QueryDef {
            name: name.to_string(),
            command,
            sql: sql.to_string(),
            params: vec![ParamDef {
                index: 1,
                name: param.to_string(),
                sql_type,
            }],
            returns: vec![],
            source_file: "q.sql".to_string(),
        }
    }

    #[test]
    fn generates_client_file() {
        let content = SqlxGenerator::postgres().generate_client();
        assert!(content.contains("sqlx"));
        assert!(content.contains("DO NOT EDIT"));
        insta::assert_snapshot!("sqlx_client", content);
    }

    #[test]
    fn generates_query_functions() {
        let content = render_fixture(SqlxGenerator::postgres());
        for needle in [
            "pub async fn get_user",
            "pub struct GetUserRow",
            "GET_USER_SQL",
            "pub async fn list_users",
            "pub async fn create_user",
            "pub async fn delete_user",
            "DeleteUserResult",
            "pool: &sqlx::PgPool",
        ] {
            assert!(content.contains(needle), "missing {needle:?}");
        }
        insta::assert_snapshot!("sqlx_queries", content);
    }

    #[test]
    fn mysql_backend_rewrites_placeholders_and_uses_mysqlpool() {
        let content = render_fixture(SqlxGenerator::mysql());
        assert!(content.contains("pool: &sqlx::MySqlPool"));
        assert!(!content.contains("pool: &sqlx::PgPool"));
        assert!(content.contains("WHERE id = ?"));
        assert!(!content.contains("WHERE id = $1"));
        insta::assert_snapshot!("sqlx_mysql_queries", content);
    }

    #[test]
    fn sqlite_backend_rewrites_placeholders_and_uses_sqlitepool() {
        let content = render_fixture(SqlxGenerator::sqlite());
        assert!(content.contains("pool: &sqlx::SqlitePool"));
        assert!(!content.contains("pool: &sqlx::PgPool"));
        assert!(content.contains("WHERE id = ?"));
        insta::assert_snapshot!("sqlx_sqlite_queries", content);
    }

    #[test]
    fn placeholder_rewrite_preserves_nonparam_dollars() {
        let input = "SELECT '$foo' FROM x WHERE a = $1";

        let (sql, idx) = SqlxBackend::MySql.rewrite_placeholders(input);
        assert_eq!(sql, "SELECT '$foo' FROM x WHERE a = ?");
        assert_eq!(idx, vec![1]);

        let (sql, idx) = SqlxBackend::Postgres.rewrite_placeholders(input);
        assert_eq!(sql, "SELECT '$foo' FROM x WHERE a = $1");
        assert!(idx.is_empty());
    }

    #[test]
    fn rewrite_tracks_occurrence_indices_for_reused_params() {
        let (sql, idx) = SqlxBackend::MySql.rewrite_placeholders("WHERE x = $1 OR y = $1");
        assert_eq!(sql, "WHERE x = ? OR y = ?");
        assert_eq!(idx, vec![1, 1]);
    }

    #[test]
    fn rewrite_tracks_occurrence_indices_for_out_of_order_params() {
        let (sql, idx) = SqlxBackend::Sqlite.rewrite_placeholders("WHERE b = $2 AND a = $1");
        assert_eq!(sql, "WHERE b = ? AND a = ?");
        assert_eq!(idx, vec![2, 1]);
    }

    #[test]
    fn reused_param_in_mysql_body_clones_every_use() {
        let query = one_param_query(
            "SearchUsers",
            QueryCommand::Many,
            "SELECT * FROM users WHERE name ILIKE $1 OR email ILIKE $1",
            "q",
            scalar("text", SqlTypeCategory::String),
        );

        let out = generate_query_function(SqlxBackend::MySql, &query);
        let clone_count = out.matches(".bind(q.clone())").count();
        assert_eq!(clone_count, 2, "expected two .bind(q.clone()), got: {out}");

        let out = generate_query_function(SqlxBackend::Postgres, &query);
        assert!(out.contains(".bind(q)"));
        assert!(!out.contains(".bind(q.clone())"));
    }

    #[test]
    fn single_use_param_does_not_clone() {
        let query = one_param_query(
            "GetOne",
            QueryCommand::One,
            "SELECT * FROM t WHERE id = $1",
            "id",
            scalar("integer", SqlTypeCategory::Number),
        );
        let out = generate_query_function(SqlxBackend::MySql, &query);
        assert!(out.contains(".bind(id)"));
        assert!(!out.contains(".bind(id.clone())"));
    }

    #[test]
    fn rewrite_preserves_dollar_n_inside_string_literals() {
        let (sql, idx) =
            SqlxBackend::MySql.rewrite_placeholders("SELECT '$1' FROM users WHERE id = $1");
        assert_eq!(sql, "SELECT '$1' FROM users WHERE id = ?");
        assert_eq!(idx, vec![1]);

        let (sql, idx) =
            SqlxBackend::MySql.rewrite_placeholders("SELECT 'O''Brien $1' FROM x WHERE id = $1");
        assert_eq!(sql, "SELECT 'O''Brien $1' FROM x WHERE id = ?");
        assert_eq!(idx, vec![1]);
    }

    #[test]
    fn snake_case_conversion() {
        for (input, expected) in [
            ("GetUser", "get_user"),
            ("get_user", "get_user"),
            ("ListUsers", "list_users"),
            ("CreateUser", "create_user"),
        ] {
            assert_eq!(snake_case(input), expected);
        }
    }

    #[test]
    fn param_type_uses_references_for_strings() {
        assert_eq!(param_type(&scalar("text", SqlTypeCategory::String)), "&str");
    }

    #[test]
    fn param_type_keeps_primitives_by_value() {
        assert_eq!(param_type(&scalar("integer", SqlTypeCategory::Number)), "i32");
    }
}