use std::fmt::Write;
use std::path::Path;
use scythe_backend::manifest::{BackendManifest, load_manifest};
use scythe_backend::naming::{
enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_core::analyzer::{AnalyzedColumn, AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-sqlx.toml");
/// Code generation backend that emits Rust source targeting `sqlx` on PostgreSQL
/// (generated query functions take a `sqlx::PgPool`).
pub struct SqlxBackend {
    /// Backend manifest; consulted for naming (`manifest.naming`) and type resolution.
    manifest: BackendManifest,
}

impl SqlxBackend {
    /// Builds a backend whose manifest comes from `load_sqlx_manifest`.
    ///
    /// # Errors
    ///
    /// Returns a `ScytheError` (`ErrorCode::InternalError`) when the manifest cannot be
    /// loaded from disk or the embedded fallback fails to parse.
    pub fn new() -> Result<Self, ScytheError> {
        load_sqlx_manifest().map(|manifest| Self { manifest })
    }

    /// Borrows the manifest this backend was constructed with.
    pub fn manifest(&self) -> &BackendManifest {
        &self.manifest
    }
}
/// Loads the rust-sqlx backend manifest.
///
/// If `backends/rust-sqlx/manifest.toml` exists (a relative path, so it is resolved against
/// the current working directory), it is loaded via `load_manifest`. Otherwise the copy of
/// `manifests/rust-sqlx.toml` embedded at compile time is parsed.
///
/// # Errors
///
/// Returns `ErrorCode::InternalError` if the on-disk manifest fails to load or the embedded
/// manifest fails to parse.
fn load_sqlx_manifest() -> Result<BackendManifest, ScytheError> {
    let override_path = Path::new("backends/rust-sqlx/manifest.toml");
    if !override_path.exists() {
        // No on-disk override: fall back to the manifest compiled into the binary.
        return toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|e| {
            ScytheError::new(
                ErrorCode::InternalError,
                format!("failed to parse embedded manifest: {e}"),
            )
        });
    }
    load_manifest(override_path).map_err(|e| {
        ScytheError::new(
            ErrorCode::InternalError,
            format!("failed to load manifest: {e}"),
        )
    })
}
impl CodegenBackend for SqlxBackend {
    /// Identifier of this backend.
    fn name(&self) -> &str {
        "rust-sqlx"
    }

    /// Generates the `sqlx::FromRow` struct holding one result row of `query_name`.
    ///
    /// The struct name comes from `row_struct_name` under the manifest's naming rules; each
    /// resolved column becomes a `pub` field.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        Ok(render_from_row_struct(&struct_name, columns))
    }

    /// Generates the `sqlx::FromRow` model struct for `table_name`.
    ///
    /// The struct is named after the singularized, PascalCased table name.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let struct_name = to_pascal_case(&singular);
        Ok(render_from_row_struct(&struct_name, columns))
    }

    /// Generates an `async fn` that runs `analyzed` against a `sqlx::PgPool`.
    ///
    /// The function takes `pool` followed by one argument per entry of `params`. Its return
    /// type depends on the query command:
    /// - `One`: the row struct.
    /// - `Many` / `Batch`: a `Vec` of rows.
    /// - `Exec`: `()`.
    /// - `ExecResult`: `sqlx::postgres::PgQueryResult`.
    /// - `ExecRows`: the affected-row count as `u64`.
    ///
    /// Row-returning commands with at least one output column map rows into `struct_name`
    /// through `sqlx::query_as!`; every other case uses `sqlx::query!`.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        if let Some(ref msg) = analyzed.deprecated {
            // `{:?}` renders a quoted, escaped Rust string literal. A `"` or `\` inside the
            // message therefore cannot produce a malformed attribute in the generated code.
            let _ = writeln!(out, "#[deprecated(note = {:?})]", msg);
        }

        let mut param_parts: Vec<String> = vec!["pool: &sqlx::PgPool".to_string()];
        param_parts.extend(
            params
                .iter()
                .map(|param| format!("{}: {}", param.field_name, param.borrowed_type)),
        );

        let return_type = match &analyzed.command {
            QueryCommand::One => struct_name.to_string(),
            QueryCommand::Many => format!("Vec<{}>", struct_name),
            QueryCommand::Exec => "()".to_string(),
            QueryCommand::ExecResult => "sqlx::postgres::PgQueryResult".to_string(),
            QueryCommand::ExecRows => "u64".to_string(),
            QueryCommand::Batch => format!("Vec<{}>", struct_name),
        };
        let _ = writeln!(
            out,
            "pub async fn {}({}) -> Result<{}, sqlx::Error> {{",
            func_name,
            param_parts.join(", "),
            return_type
        );

        let sql_raw = super::clean_sql(&analyzed.sql);
        // NOTE(review): `sql` is spliced verbatim between `"` delimiters below. This assumes
        // `clean_sql` leaves no unescaped `"` in the text (`rewrite_sql_for_enums` emits its
        // alias quotes pre-escaped as `\"`). Confirm against `clean_sql`.
        let sql = rewrite_sql_for_enums(&sql_raw, &analyzed.columns, &self.manifest);

        // NOTE(review): bind names are derived from `to_snake_case(&p.name)`, while the
        // signature above uses `ResolvedParam::field_name`. The two must agree (e.g. for
        // names needing keyword escaping). TODO confirm they always do.
        let bind_params: String = analyzed
            .params
            .iter()
            .map(|p| {
                let param_name = to_snake_case(&p.name);
                match p.neutral_type.strip_prefix("enum::") {
                    // Enum binds carry an explicit type cast for sqlx's compile-time checks.
                    Some(enum_name) => format!(
                        ", {} as &{}",
                        param_name,
                        enum_type_name(enum_name, &self.manifest.naming)
                    ),
                    None => format!(", {}", param_name),
                }
            })
            .collect();

        let uses_query_as = matches!(
            analyzed.command,
            QueryCommand::One | QueryCommand::Many | QueryCommand::Batch
        ) && !analyzed.columns.is_empty();
        let invocation = if uses_query_as {
            format!(
                "sqlx::query_as!({}, \"{}\"{})",
                struct_name, sql, bind_params
            )
        } else {
            format!("sqlx::query!(\"{}\"{})", sql, bind_params)
        };
        // Only `ExecRows` inspects the query result afterwards (for `rows_affected()`).
        let binding = if matches!(analyzed.command, QueryCommand::ExecRows) {
            "let result = "
        } else {
            ""
        };
        let _ = writeln!(out, " {}{}", binding, invocation);

        let fetch_method = match &analyzed.command {
            QueryCommand::One => ".fetch_one(pool)",
            QueryCommand::Many | QueryCommand::Batch => ".fetch_all(pool)",
            QueryCommand::Exec | QueryCommand::ExecResult | QueryCommand::ExecRows => {
                ".execute(pool)"
            }
        };
        let _ = writeln!(out, " {}", fetch_method);

        match &analyzed.command {
            QueryCommand::Exec => {
                let _ = writeln!(out, " .await?;");
                let _ = writeln!(out, " Ok(())");
            }
            QueryCommand::ExecRows => {
                let _ = writeln!(out, " .await?;");
                let _ = writeln!(out, " Ok(result.rows_affected())");
            }
            // The awaited `Result` already has the declared return type.
            QueryCommand::One
            | QueryCommand::Many
            | QueryCommand::Batch
            | QueryCommand::ExecResult => {
                let _ = writeln!(out, " .await");
            }
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates a `sqlx::Type` Rust enum mirroring the SQL enum `enum_info`.
    ///
    /// The type-level `rename_all = "snake_case"` only reproduces values that are already
    /// lower snake case. Any other value (e.g. `ACTIVE`, `in-progress`) gets an explicit
    /// per-variant `#[sqlx(rename = "...")]` so it encodes and decodes as the exact SQL label.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let mut out = String::with_capacity(256);
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
        let _ = writeln!(
            out,
            "#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
            enum_info.sql_name
        );
        let _ = writeln!(out, "pub enum {type_name} {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            if !snake_case_round_trips(value) {
                // `{:?}` yields an escaped, quoted literal of the exact SQL enum label.
                let _ = writeln!(out, " #[sqlx(rename = {:?})]", value);
            }
            let _ = writeln!(out, " {variant},");
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates a `sqlx::Type` struct mirroring the SQL composite type `composite`.
    ///
    /// Field types are resolved through the manifest with `resolve_type`. The `false`
    /// argument is presumably the nullability flag (TODO confirm).
    ///
    /// # Errors
    ///
    /// Returns `ErrorCode::InternalError` if any field's neutral type cannot be resolved.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        use scythe_backend::types::resolve_type;
        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
        let mut out = String::new();
        let _ = writeln!(out, "#[derive(Debug, Clone, sqlx::Type)]");
        let _ = writeln!(out, "#[sqlx(type_name = \"{}\")]", composite.sql_name);
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for field in &composite.fields {
            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(
                out,
                " pub {}: {},",
                to_snake_case(&field.name),
                rust_type
            );
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}

/// Renders a `#[derive(Debug, sqlx::FromRow)]` struct named `struct_name`, with one `pub`
/// field per resolved column. Shared by the row-struct and model-struct generators.
fn render_from_row_struct(struct_name: &str, columns: &[ResolvedColumn]) -> String {
    let mut out = String::new();
    let _ = writeln!(out, "#[derive(Debug, sqlx::FromRow)]");
    let _ = writeln!(out, "pub struct {} {{", struct_name);
    for col in columns {
        let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
    }
    let _ = write!(out, "}}");
    out
}

/// Returns `true` when sqlx's `rename_all = "snake_case"` is certain to map the generated
/// variant back to `value`, assuming `enum_variant_name` yields the PascalCase form of
/// `value`.
///
/// The check is deliberately conservative. It accepts only non-empty, `_`-separated segments
/// of two or more lowercase ASCII letters. Single-letter segments (`a_b` → `AB` → `ab`),
/// digits, uppercase and punctuation all get an explicit rename, which is always correct.
fn snake_case_round_trips(value: &str) -> bool {
    !value.is_empty()
        && value
            .split('_')
            .all(|segment| segment.len() >= 2 && segment.bytes().all(|b| b.is_ascii_lowercase()))
}
/// Adds sqlx type-override aliases for enum-typed output columns in the SELECT list of `sql`.
///
/// Each column whose neutral type is `enum::<name>` is rewritten to
/// `col AS \"col: Type\"` (or `Option<Type>` when the column is nullable), so sqlx's
/// `query!`/`query_as!` macros decode it into the generated Rust enum. The quotes are emitted
/// pre-escaped (`\"`) because the result is spliced into a Rust string literal. Only the text
/// before the first ` FROM ` (ASCII case-insensitive) is rewritten. When no ` FROM ` is found,
/// or no column is enum-typed, `sql` is returned unchanged.
fn rewrite_sql_for_enums(
    sql: &str,
    columns: &[AnalyzedColumn],
    manifest: &BackendManifest,
) -> String {
    let enum_cols: Vec<(&str, String)> = columns
        .iter()
        .filter_map(|col| {
            let enum_name = col.neutral_type.strip_prefix("enum::")?;
            let rust_type = enum_type_name(enum_name, &manifest.naming);
            let annotation = if col.nullable {
                format!("Option<{}>", rust_type)
            } else {
                rust_type
            };
            Some((col.name.as_str(), annotation))
        })
        .collect();

    let mut result = sql.to_string();
    for (col_name, annotation) in &enum_cols {
        let alias = format!("{} AS \\\"{}: {}\\\"", col_name, col_name, annotation);
        // Use `to_ascii_uppercase`, not `to_uppercase`: it preserves byte length, so
        // `from_pos` is a valid offset into `result`. Full Unicode uppercasing can change byte
        // lengths (e.g. `ı`, 2 bytes, becomes `I`, 1 byte). The offset would then point at the
        // wrong place or split a UTF-8 character and panic on slicing.
        if let Some(from_pos) = result.to_ascii_uppercase().find(" FROM ") {
            let new_select = replace_column_in_select(&result[..from_pos], col_name, &alias);
            result = format!("{}{}", new_select, &result[from_pos..]);
        }
    }
    result
}
/// Replaces one standalone occurrence of `col_name` in a SELECT list with `replacement`.
///
/// An occurrence qualifies when it is preceded by `", "` (tried first) or `" "`, and is
/// followed by end of input, whitespace, or `,`. This way a column that is merely the prefix
/// of a longer identifier (`status` inside `status_code`) is skipped. Within each pattern the
/// right-most qualifying occurrence wins. If nothing qualifies, `select` is returned unchanged.
fn replace_column_in_select(select: &str, col_name: &str, replacement: &str) -> String {
    fn ends_at_boundary(rest: &str) -> bool {
        match rest.chars().next() {
            None => true,
            Some(c) => c == ',' || c.is_whitespace(),
        }
    }

    let patterns = [format!(", {}", col_name), format!(" {}", col_name)];
    for pattern in &patterns {
        // Check every occurrence right-to-left, not just the last one. Otherwise a trailing
        // `status_code` would hide an earlier, genuine `status` column.
        let hit_end = select
            .rmatch_indices(pattern.as_str())
            .map(|(pos, _)| pos + pattern.len())
            .find(|&end| ends_at_boundary(&select[end..]));
        if let Some(end) = hit_end {
            let start = end - col_name.len();
            return format!("{}{}{}", &select[..start], replacement, &select[end..]);
        }
    }
    select.to_string()
}