pub mod backend_trait;
pub mod backends;
pub mod resolve;
pub mod validation;
pub use backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
pub use backends::get_backend;
use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{row_struct_name, to_pascal_case};
use scythe_core::analyzer::{AnalyzedQuery, EnumInfo};
use scythe_core::catalog::Catalog;
use scythe_core::errors::ScytheError;
use scythe_core::parser::QueryCommand;
#[derive(Debug, Default)]
pub struct GeneratedCode {
pub query_fn: Option<String>,
pub row_struct: Option<String>,
pub model_struct: Option<String>,
pub enum_def: Option<String>,
}
pub(crate) fn singularize(name: &str) -> String {
if let Some(stem) = name.strip_suffix("ies") {
format!("{stem}y")
} else if name.ends_with("sses")
|| name.ends_with("shes")
|| name.ends_with("ches")
|| name.ends_with("xes")
|| name.ends_with("zes")
|| name.ends_with("ses")
{
name[..name.len() - 2].to_string()
} else if name.ends_with('s') && !name.ends_with("ss") {
name[..name.len() - 1].to_string()
} else {
name.to_string()
}
}
pub fn get_manifest_for_backend(backend_name: &str) -> Result<BackendManifest, ScytheError> {
let backend = get_backend(backend_name, "postgresql")?;
Ok(backend.manifest().clone())
}
fn determine_struct_name(analyzed: &AnalyzedQuery, manifest: &BackendManifest) -> String {
if let Some(ref table_name) = analyzed.source_table {
let singular = singularize(table_name);
to_pascal_case(&singular).into_owned()
} else {
row_struct_name(&analyzed.name, &manifest.naming)
}
}
pub fn generate_with_backend(
analyzed: &AnalyzedQuery,
backend: &dyn CodegenBackend,
) -> Result<GeneratedCode, ScytheError> {
let manifest = backend.manifest();
let columns = resolve::resolve_columns(&analyzed.columns, manifest)?;
let params = resolve::resolve_params(&analyzed.params, manifest)?;
let mut result = GeneratedCode::default();
let enum_def = generate_enum_defs_via_backend(analyzed, backend)?;
if !enum_def.is_empty() {
result.enum_def = Some(enum_def);
}
let needs_row_struct = matches!(analyzed.command, QueryCommand::One | QueryCommand::Many);
if needs_row_struct && !analyzed.columns.is_empty() {
if let Some(ref table_name) = analyzed.source_table {
result.model_struct = Some(backend.generate_model_struct(table_name, &columns)?);
} else {
result.row_struct = Some(backend.generate_row_struct(&analyzed.name, &columns)?);
}
}
if !analyzed.composites.is_empty() {
let mut comp_defs = String::new();
for (i, comp) in analyzed.composites.iter().enumerate() {
if i > 0 {
comp_defs.push_str("\n\n");
}
comp_defs.push_str(&backend.generate_composite_def(comp)?);
}
if !comp_defs.is_empty() {
if let Some(ref mut existing) = result.model_struct {
existing.push_str("\n\n");
existing.push_str(&comp_defs);
} else {
result.model_struct = Some(comp_defs);
}
}
}
let struct_name = determine_struct_name(analyzed, manifest);
result.query_fn = Some(backend.generate_query_fn(analyzed, &struct_name, &columns, ¶ms)?);
Ok(result)
}
fn generate_enum_defs_via_backend(
analyzed: &AnalyzedQuery,
backend: &dyn CodegenBackend,
) -> Result<String, ScytheError> {
use ahash::AHashSet;
use std::fmt::Write;
let mut out = String::new();
let mut seen_enums: AHashSet<String> = AHashSet::new();
let enum_sources: Vec<&str> = analyzed
.columns
.iter()
.filter_map(|col| col.neutral_type.strip_prefix("enum::"))
.chain(
analyzed
.params
.iter()
.filter_map(|p| p.neutral_type.strip_prefix("enum::")),
)
.collect();
for sql_name in enum_sources {
if !seen_enums.insert(sql_name.to_string()) {
continue;
}
if !out.is_empty() {
let _ = writeln!(out);
}
if let Some(enum_info) = analyzed.enums.iter().find(|e| e.sql_name == sql_name) {
out.push_str(&backend.generate_enum_def(enum_info)?);
} else {
let stub_info = EnumInfo {
sql_name: sql_name.to_string(),
values: vec![],
};
out.push_str(&backend.generate_enum_def(&stub_info)?);
}
}
Ok(out)
}
pub fn generate(analyzed: &AnalyzedQuery) -> Result<GeneratedCode, ScytheError> {
let backend = get_backend("rust-sqlx", "postgresql")?;
generate_with_backend(analyzed, &*backend)
}
pub fn generate_from_catalog(_catalog: &Catalog) -> Result<GeneratedCode, ScytheError> {
Ok(GeneratedCode::default())
}
pub fn generate_single_enum_def_with_backend(
enum_info: &EnumInfo,
backend: &dyn CodegenBackend,
) -> Result<String, ScytheError> {
backend.generate_enum_def(enum_info)
}
pub fn generate_single_enum_def(enum_info: &EnumInfo, manifest: &BackendManifest) -> String {
use scythe_backend::naming::{enum_type_name, enum_variant_name};
use std::fmt::Write;
let mut out = String::with_capacity(256);
let type_name = enum_type_name(&enum_info.sql_name, &manifest.naming);
let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq, sqlx::Type)]");
let _ = writeln!(
out,
"#[sqlx(type_name = \"{}\", rename_all = \"snake_case\")]",
enum_info.sql_name
);
let _ = writeln!(out, "pub enum {type_name} {{");
for value in &enum_info.values {
let variant = enum_variant_name(value, &manifest.naming);
let _ = writeln!(out, " {variant},");
}
let _ = write!(out, "}}");
out
}
pub fn load_or_default_manifest() -> Result<BackendManifest, ScytheError> {
let b = backends::sqlx::SqlxBackend::new("postgresql")?;
Ok(b.manifest().clone())
}
#[cfg(test)]
mod tests {
use super::*;
use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
use scythe_core::parser::QueryCommand;
fn make_query(
name: &str,
command: QueryCommand,
sql: &str,
columns: Vec<AnalyzedColumn>,
params: Vec<AnalyzedParam>,
) -> AnalyzedQuery {
AnalyzedQuery {
name: name.to_string(),
command,
sql: sql.to_string(),
columns,
params,
deprecated: None,
source_table: None,
composites: Vec::new(),
enums: Vec::new(),
}
}
#[test]
fn test_generate_select_many() {
let query = make_query(
"ListUsers",
QueryCommand::Many,
"SELECT id, name, email FROM users",
vec![
AnalyzedColumn {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
},
AnalyzedColumn {
name: "name".to_string(),
neutral_type: "string".to_string(),
nullable: false,
},
AnalyzedColumn {
name: "email".to_string(),
neutral_type: "string".to_string(),
nullable: true,
},
],
vec![],
);
let result = generate(&query).unwrap();
let row_struct = result.row_struct.unwrap();
assert!(row_struct.contains("pub struct ListUsersRow"));
assert!(row_struct.contains("pub id: i32"));
assert!(row_struct.contains("pub name: String"));
assert!(row_struct.contains("pub email: Option<String>"));
let query_fn = result.query_fn.unwrap();
assert!(query_fn.contains("pub async fn list_users("));
assert!(query_fn.contains("-> Result<Vec<ListUsersRow>, sqlx::Error>"));
assert!(query_fn.contains(".fetch_all(pool)"));
}
#[test]
fn test_generate_select_one_with_param() {
let query = make_query(
"GetUser",
QueryCommand::One,
"SELECT id, name FROM users WHERE id = $1",
vec![
AnalyzedColumn {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
},
AnalyzedColumn {
name: "name".to_string(),
neutral_type: "string".to_string(),
nullable: false,
},
],
vec![AnalyzedParam {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
position: 1,
}],
);
let result = generate(&query).unwrap();
let query_fn = result.query_fn.unwrap();
assert!(query_fn.contains("pub async fn get_user("));
assert!(query_fn.contains("id: i32"));
assert!(query_fn.contains("-> Result<GetUserRow, sqlx::Error>"));
assert!(query_fn.contains(".fetch_one(pool)"));
}
#[test]
fn test_generate_exec() {
let query = make_query(
"DeleteUser",
QueryCommand::Exec,
"DELETE FROM users WHERE id = $1",
vec![],
vec![AnalyzedParam {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
position: 1,
}],
);
let result = generate(&query).unwrap();
assert!(result.row_struct.is_none());
let query_fn = result.query_fn.unwrap();
assert!(query_fn.contains("pub async fn delete_user("));
assert!(query_fn.contains("-> Result<(), sqlx::Error>"));
assert!(query_fn.contains(".execute(pool)"));
}
#[test]
fn test_generate_with_enum_column() {
let query = make_query(
"GetUserStatus",
QueryCommand::One,
"SELECT id, status FROM users WHERE id = $1",
vec![
AnalyzedColumn {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
},
AnalyzedColumn {
name: "status".to_string(),
neutral_type: "enum::user_status".to_string(),
nullable: false,
},
],
vec![AnalyzedParam {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
position: 1,
}],
);
let result = generate(&query).unwrap();
assert!(result.enum_def.is_some());
let enum_def = result.enum_def.unwrap();
assert!(enum_def.contains("pub enum UserStatus"));
assert!(enum_def.contains("type_name = \"user_status\""));
let row_struct = result.row_struct.unwrap();
assert!(row_struct.contains("pub status: UserStatus"));
}
#[test]
fn test_generate_from_catalog_returns_default() {
let catalog = Catalog::from_ddl(&["CREATE TABLE t (id INTEGER);"]).unwrap();
let result = generate_from_catalog(&catalog).unwrap();
assert!(result.query_fn.is_none());
assert!(result.row_struct.is_none());
}
#[test]
fn test_singularize_basic() {
assert_eq!(singularize("users"), "user");
assert_eq!(singularize("orders"), "order");
assert_eq!(singularize("posts"), "post");
}
#[test]
fn test_singularize_ies() {
assert_eq!(singularize("categories"), "category");
assert_eq!(singularize("entries"), "entry");
}
#[test]
fn test_singularize_sses() {
assert_eq!(singularize("addresses"), "address");
assert_eq!(singularize("classes"), "class");
}
#[test]
fn test_singularize_no_change() {
assert_eq!(singularize("status"), "statu");
assert_eq!(singularize("boss"), "boss");
assert_eq!(singularize("address"), "address");
}
#[test]
fn test_singularize_shes_ches_xes() {
assert_eq!(singularize("batches"), "batch");
assert_eq!(singularize("boxes"), "box");
assert_eq!(singularize("wishes"), "wish");
}
#[test]
fn test_tokio_postgres_backend_basic() {
let backend = get_backend("tokio-postgres", "postgresql").unwrap();
let query = make_query(
"ListUsers",
QueryCommand::Many,
"SELECT id, name FROM users",
vec![
AnalyzedColumn {
name: "id".to_string(),
neutral_type: "int32".to_string(),
nullable: false,
},
AnalyzedColumn {
name: "name".to_string(),
neutral_type: "string".to_string(),
nullable: false,
},
],
vec![],
);
let result = generate_with_backend(&query, &*backend).unwrap();
let row_struct = result.row_struct.unwrap();
assert!(row_struct.contains("pub struct ListUsersRow"));
assert!(row_struct.contains("pub id: i32"));
assert!(row_struct.contains("pub name: String"));
assert!(row_struct.contains("from_row"));
assert!(row_struct.contains("tokio_postgres::Row"));
assert!(!row_struct.contains("sqlx"));
let query_fn = result.query_fn.unwrap();
assert!(query_fn.contains("pub async fn list_users("));
assert!(query_fn.contains("tokio_postgres::Client"));
assert!(query_fn.contains("tokio_postgres::Error"));
assert!(!query_fn.contains("sqlx"));
}
#[test]
fn test_tokio_postgres_enum() {
let backend = get_backend("tokio-postgres", "postgresql").unwrap();
let enum_info = scythe_core::analyzer::EnumInfo {
sql_name: "user_status".to_string(),
values: vec!["active".to_string(), "inactive".to_string()],
};
let def = backend.generate_enum_def(&enum_info).unwrap();
assert!(def.contains("pub enum UserStatus"));
assert!(def.contains("Active"));
assert!(def.contains("Inactive"));
assert!(def.contains("impl std::fmt::Display"));
assert!(def.contains("impl std::str::FromStr"));
assert!(!def.contains("sqlx"));
}
}