use std::fmt::Write;
use std::path::Path;
use scythe_backend::manifest::{BackendManifest, load_manifest};
use scythe_backend::naming::{
enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;
/// Manifest TOML embedded at compile time from `../../manifests/rust-tokio-postgres.toml`.
/// Parsed by `load_tokio_postgres_manifest` when the on-disk manifest path does not exist.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
/// Code generation backend identified as `rust-tokio-postgres`.
///
/// Constructed through `TokioPostgresBackend::new`, which only accepts
/// PostgreSQL engine names.
pub struct TokioPostgresBackend {
/// Backend manifest loaded at construction; supplies the naming rules and
/// type mappings used while generating code.
manifest: BackendManifest,
}
impl TokioPostgresBackend {
    /// Creates the backend for the given database `engine`.
    ///
    /// # Errors
    ///
    /// Returns an `InternalError` when `engine` is not one of `"postgresql"`,
    /// `"postgres"` or `"pg"`, or when the manifest cannot be loaded.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let is_postgres = matches!(engine, "postgresql" | "postgres" | "pg");
        if !is_postgres {
            let message = format!(
                "rust-tokio-postgres only supports PostgreSQL, got engine '{}'",
                engine
            );
            return Err(ScytheError::new(ErrorCode::InternalError, message));
        }

        load_tokio_postgres_manifest().map(|manifest| Self { manifest })
    }
}
/// Loads the backend manifest.
///
/// Uses `backends/rust-tokio-postgres/manifest.toml` (relative to the current
/// working directory) when that path exists; otherwise parses the manifest
/// embedded in `DEFAULT_MANIFEST_TOML`.
///
/// # Errors
///
/// Returns an `InternalError` when the chosen manifest fails to load or parse.
fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
    let internal = |message: String| ScytheError::new(ErrorCode::InternalError, message);

    let override_path = Path::new("backends/rust-tokio-postgres/manifest.toml");
    if !override_path.exists() {
        // No on-disk manifest: fall back to the copy compiled into the binary.
        return toml::from_str(DEFAULT_MANIFEST_TOML)
            .map_err(|err| internal(format!("failed to parse embedded manifest: {err}")));
    }

    load_manifest(override_path).map_err(|err| internal(format!("failed to load manifest: {err}")))
}
impl CodegenBackend for TokioPostgresBackend {
fn name(&self) -> &str {
"rust-tokio-postgres"
}
fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
&self.manifest
}
fn file_header(&self) -> String {
"// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::all)]"
.to_string()
}
fn generate_row_struct(
&self,
query_name: &str,
columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
let struct_name = row_struct_name(query_name, &self.manifest.naming);
generate_struct_with_from_row(&struct_name, columns)
}
fn generate_model_struct(
&self,
table_name: &str,
columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
let singular = singularize(table_name);
let struct_name = to_pascal_case(&singular).into_owned();
generate_struct_with_from_row(&struct_name, columns)
}
fn generate_query_fn(
&self,
analyzed: &AnalyzedQuery,
struct_name: &str,
_columns: &[ResolvedColumn],
params: &[ResolvedParam],
) -> Result<String, ScytheError> {
let func_name = fn_name(&analyzed.name, &self.manifest.naming);
let mut out = String::new();
if let Some(ref msg) = analyzed.deprecated {
let _ = writeln!(out, "#[deprecated(note = \"{}\")]", msg);
}
let mut param_parts: Vec<String> = vec!["client: &tokio_postgres::Client".to_string()];
for param in params {
param_parts.push(format!("{}: {}", param.field_name, param.borrowed_type));
}
let sql = super::clean_sql_with_optional(
&analyzed.sql,
&analyzed.optional_params,
&analyzed.params,
);
if matches!(analyzed.command, QueryCommand::Batch) {
let batch_fn_name = format!("{}_batch", func_name);
if params.len() > 1 {
let params_struct_name = format!("{}BatchParams", struct_name);
let _ = writeln!(out, "#[derive(Debug, Clone)]");
let _ = writeln!(out, "pub struct {} {{", params_struct_name);
for param in params {
let _ = writeln!(out, " pub {}: {},", param.field_name, param.full_type);
}
let _ = writeln!(out, "}}");
let _ = writeln!(out);
let _ = writeln!(
out,
"pub async fn {}(client: &tokio_postgres::Client, items: &[{}]) -> Result<(), tokio_postgres::Error> {{",
batch_fn_name, params_struct_name
);
let _ = writeln!(out, " let stmt = client.prepare(r#\"{}\"#).await?;", sql);
let _ = writeln!(out, " let tx = client.transaction().await?;");
let _ = writeln!(out, " for item in items {{");
let refs: Vec<String> = params
.iter()
.map(|p| {
if p.neutral_type.starts_with("enum::") {
format!("&item.{}.to_string()", p.field_name)
} else {
format!("&item.{}", p.field_name)
}
})
.collect();
let _ = writeln!(
out,
" tx.execute(&stmt, &[{}]).await?;",
refs.join(", ")
);
let _ = writeln!(out, " }}");
let _ = writeln!(out, " tx.commit().await?;");
let _ = writeln!(out, " Ok(())");
} else if params.len() == 1 {
let param = ¶ms[0];
let _ = writeln!(
out,
"pub async fn {}(client: &tokio_postgres::Client, items: &[{}]) -> Result<(), tokio_postgres::Error> {{",
batch_fn_name, param.full_type
);
let _ = writeln!(out, " let stmt = client.prepare(r#\"{}\"#).await?;", sql);
let _ = writeln!(out, " let tx = client.transaction().await?;");
let _ = writeln!(out, " for item in items {{");
let _ = writeln!(out, " tx.execute(&stmt, &[item]).await?;");
let _ = writeln!(out, " }}");
let _ = writeln!(out, " tx.commit().await?;");
let _ = writeln!(out, " Ok(())");
} else {
let _ = writeln!(
out,
"pub async fn {}(client: &tokio_postgres::Client, count: usize) -> Result<(), tokio_postgres::Error> {{",
batch_fn_name
);
let _ = writeln!(out, " let stmt = client.prepare(r#\"{}\"#).await?;", sql);
let _ = writeln!(out, " let tx = client.transaction().await?;");
let _ = writeln!(out, " for _ in 0..count {{");
let _ = writeln!(out, " tx.execute(&stmt, &[]).await?;");
let _ = writeln!(out, " }}");
let _ = writeln!(out, " tx.commit().await?;");
let _ = writeln!(out, " Ok(())");
}
let _ = write!(out, "}}");
return Ok(out);
}
let return_type = match &analyzed.command {
QueryCommand::One => struct_name.to_string(),
QueryCommand::Many => format!("Vec<{}>", struct_name),
QueryCommand::Exec => "()".to_string(),
QueryCommand::ExecResult => "u64".to_string(),
QueryCommand::ExecRows => "u64".to_string(),
QueryCommand::Batch => unreachable!(),
};
let _ = writeln!(
out,
"pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
func_name,
param_parts.join(", "),
return_type
);
let param_refs: String = if params.is_empty() {
"&[]".to_string()
} else {
let refs: Vec<String> = params
.iter()
.map(|p| {
if p.neutral_type.starts_with("enum::") {
format!("&{}.to_string()", p.field_name)
} else {
format!("&{}", p.field_name)
}
})
.collect();
format!("&[{}]", refs.join(", "))
};
match &analyzed.command {
QueryCommand::One => {
let _ = writeln!(
out,
" let row = client.query_one(r#\"{}\"#, {}).await?;",
sql, param_refs
);
let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
}
QueryCommand::Many => {
let _ = writeln!(
out,
" let rows = client.query(r#\"{}\"#, {}).await?;",
sql, param_refs
);
let _ = writeln!(
out,
" Ok(rows.iter().map({}::from_row).collect())",
struct_name
);
}
QueryCommand::Exec => {
let _ = writeln!(
out,
" client.execute(r#\"{}\"#, {}).await?;",
sql, param_refs
);
let _ = writeln!(out, " Ok(())");
}
QueryCommand::ExecResult | QueryCommand::ExecRows => {
let _ = writeln!(
out,
" let rows_affected = client.execute(r#\"{}\"#, {}).await?;",
sql, param_refs
);
let _ = writeln!(out, " Ok(rows_affected)");
}
QueryCommand::Batch => unreachable!(),
}
let _ = write!(out, "}}");
Ok(out)
}
fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
let mut out = String::with_capacity(512);
let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
let _ = writeln!(out, "pub enum {} {{", type_name);
for value in &enum_info.values {
let variant = enum_variant_name(value, &self.manifest.naming);
let _ = writeln!(out, " {},", variant);
}
let _ = writeln!(out, "}}");
let _ = writeln!(out);
let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
let _ = writeln!(
out,
" fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
);
let _ = writeln!(out, " match self {{");
for value in &enum_info.values {
let variant = enum_variant_name(value, &self.manifest.naming);
let _ = writeln!(
out,
" {}::{} => write!(f, \"{}\"),",
type_name, variant, value
);
}
let _ = writeln!(out, " }}");
let _ = writeln!(out, " }}");
let _ = writeln!(out, "}}");
let _ = writeln!(out);
let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
let _ = writeln!(out, " type Err = String;");
let _ = writeln!(
out,
" fn from_str(s: &str) -> Result<Self, Self::Err> {{"
);
let _ = writeln!(out, " match s {{");
for value in &enum_info.values {
let variant = enum_variant_name(value, &self.manifest.naming);
let _ = writeln!(
out,
" \"{}\" => Ok({}::{}),",
value, type_name, variant
);
}
let _ = writeln!(
out,
" _ => Err(format!(\"unknown variant: {{}}\", s)),"
);
let _ = writeln!(out, " }}");
let _ = writeln!(out, " }}");
let _ = write!(out, "}}");
Ok(out)
}
fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
let struct_name = to_pascal_case(&composite.sql_name).into_owned();
let mut out = String::new();
let _ = writeln!(out, "#[derive(Debug, Clone)]");
let _ = writeln!(out, "pub struct {} {{", struct_name);
for field in &composite.fields {
let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
.map(|t| t.into_owned())
.map_err(|e| {
ScytheError::new(
ErrorCode::InternalError,
format!("composite field type error: {}", e),
)
})?;
let _ = writeln!(
out,
" pub {}: {},",
to_snake_case(&field.name),
rust_type
);
}
let _ = write!(out, "}}");
Ok(out)
}
}
/// Generates a `#[derive(Debug, Clone)]` struct named `struct_name` with one
/// public field per column, plus an inherent `from_row` constructor that reads
/// each field from a `tokio_postgres::Row` by column name.
///
/// Columns whose neutral type starts with `enum::` are read as `String`
/// (`Option<String>` when nullable) and parsed into the field type; the
/// generated code panics if the stored value does not parse.
///
/// Column names are emitted as escaped string literals, so a name containing
/// quotes or backslashes still yields valid generated code.
fn generate_struct_with_from_row(
    struct_name: &str,
    columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
    let mut out = String::new();

    // Struct definition.
    let _ = writeln!(out, "#[derive(Debug, Clone)]");
    let _ = writeln!(out, "pub struct {} {{", struct_name);
    for col in columns {
        let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
    }
    let _ = writeln!(out, "}}");
    let _ = writeln!(out);

    // `from_row` constructor.
    let _ = writeln!(out, "impl {} {{", struct_name);
    let _ = writeln!(
        out,
        " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
    );
    let _ = writeln!(out, " Self {{");
    for col in columns {
        // `{:?}` on a string yields a quoted, escaped Rust string literal.
        let column_literal = format!("{:?}", col.name.to_string());

        if !col.neutral_type.starts_with("enum::") {
            let _ = writeln!(
                out,
                " {}: row.get({}),",
                col.field_name, column_literal
            );
            continue;
        }

        if col.nullable {
            let _ = writeln!(
                out,
                " {field}: row.get::<_, Option<String>>({col}).map(|s| s.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", {col}, s))),",
                field = col.field_name,
                col = column_literal
            );
        } else {
            let _ = writeln!(
                out,
                " {field}: {{ let val = row.get::<_, String>({col}); val.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", {col}, val)) }},",
                field = col.field_name,
                col = column_literal
            );
        }
    }
    let _ = writeln!(out, " }}");
    let _ = writeln!(out, " }}");
    let _ = write!(out, "}}");
    Ok(out)
}