use scythe_backend::manifest::BackendManifest;
use scythe_backend::naming::{
enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use std::fmt::Write;
use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;
/// Built-in PostgreSQL manifest, embedded at compile time. `TokioPostgresBackend::new`
/// passes it as the fallback to `super::load_or_default_manifest` for the
/// `postgresql` / `postgres` / `pg` engines.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
/// Built-in Redshift manifest, embedded at compile time. Used as the fallback
/// manifest when `TokioPostgresBackend::new` is called with `"redshift"`.
const DEFAULT_MANIFEST_REDSHIFT: &str =
    include_str!("../../manifests/rust-tokio-postgres.redshift.toml");
/// Code generation backend that emits async Rust query functions targeting the
/// `tokio_postgres` crate (PostgreSQL and Redshift engines).
pub struct TokioPostgresBackend {
    /// Manifest returned by `super::load_or_default_manifest` in `new`.
    manifest: BackendManifest,
    /// When true, generated structs and enums also derive `serde::Serialize`
    /// and `serde::Deserialize`. Set through the `serde` option.
    serde: bool,
    /// Extra derive names appended to generated `#[derive(...)]` lines. Set
    /// through the comma-separated `derive` option.
    extra_derives: Vec<String>,
}
impl TokioPostgresBackend {
    /// Creates a backend for the given database `engine`.
    ///
    /// `"postgresql"`, `"postgres"` and `"pg"` select the embedded PostgreSQL
    /// manifest, and `"redshift"` selects the embedded Redshift manifest. The
    /// selected manifest is passed as the fallback to
    /// `super::load_or_default_manifest` along with the path
    /// `backends/rust-tokio-postgres/manifest.toml`.
    ///
    /// Serde derives start disabled and no extra derives are configured.
    ///
    /// # Errors
    ///
    /// Returns an `InternalError` for any other engine name. Any error from
    /// loading the manifest is propagated.
    pub fn new(engine: &str) -> Result<Self, ScytheError> {
        let default_toml = match engine {
            "postgresql" | "postgres" | "pg" => DEFAULT_MANIFEST_TOML,
            "redshift" => DEFAULT_MANIFEST_REDSHIFT,
            _ => {
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    format!(
                        "rust-tokio-postgres only supports PostgreSQL/Redshift, got engine '{}'",
                        engine
                    ),
                ));
            }
        };
        let manifest = super::load_or_default_manifest(
            "backends/rust-tokio-postgres/manifest.toml",
            default_toml,
        )?;
        Ok(Self {
            manifest,
            serde: false,
            extra_derives: Vec::new(),
        })
    }

    /// `#[derive(...)]` line for generated structs: `Debug, Clone`, then the
    /// optional serde derives, then the user's extra derives.
    fn struct_derives(&self) -> String {
        self.derive_attr(&["Debug", "Clone"])
    }

    /// `#[derive(...)]` line for generated enums: `Debug, Clone, PartialEq, Eq`,
    /// then the optional serde derives, then the user's extra derives.
    fn enum_derives(&self) -> String {
        self.derive_attr(&["Debug", "Clone", "PartialEq", "Eq"])
    }

    /// Builds a `#[derive(...)]` attribute from `base`, the serde derives (when
    /// enabled) and `extra_derives`, in that order.
    ///
    /// Empty entries are skipped. Repeated names are kept only at their first
    /// position. Deriving the same trait twice (for example, a user passing
    /// `derive = "Debug"`) would make the generated code fail to compile with
    /// conflicting implementations.
    fn derive_attr(&self, base: &[&str]) -> String {
        let mut candidates: Vec<&str> = base.to_vec();
        if self.serde {
            candidates.push("serde::Serialize");
            candidates.push("serde::Deserialize");
        }
        candidates.extend(self.extra_derives.iter().map(|d| d.trim()));

        // Only a handful of entries, so a linear `contains` check is fine.
        let mut derives: Vec<&str> = Vec::with_capacity(candidates.len());
        for d in candidates {
            if !d.is_empty() && !derives.contains(&d) {
                derives.push(d);
            }
        }
        format!("#[derive({})]", derives.join(", "))
    }
}
/// Leading parameter of every generated query function. It accepts any client
/// implementing `tokio_postgres::GenericClient + Sync`.
const CLIENT_PARAM: &str = "client: &(impl tokio_postgres::GenericClient + Sync)";
/// Error type in the `Result` returned by every generated query function.
const ERROR_TYPE: &str = "tokio_postgres::Error";
/// Emits async Rust code targeting `tokio_postgres` for analyzed queries, SQL
/// enums and composite types.
impl CodegenBackend for TokioPostgresBackend {
    fn name(&self) -> &str {
        "rust-tokio-postgres"
    }

    fn manifest(&self) -> &scythe_backend::manifest::BackendManifest {
        &self.manifest
    }

    /// Engines reported to callers. `TokioPostgresBackend::new` also accepts the
    /// aliases `"postgres"` and `"pg"`.
    fn supported_engines(&self) -> &[&str] {
        &["postgresql", "redshift"]
    }

    /// Header written at the top of each generated file. It contains an
    /// auto-generated marker plus an inner `#![allow(...)]` that silences lints
    /// the generated code commonly triggers.
    fn file_header(&self) -> String {
        "// Auto-generated by scythe. Do not edit.\n#![allow(dead_code, unused_imports, clippy::needless_question_mark, clippy::redundant_closure)]"
            .to_string()
    }

    /// Applies backend options:
    /// - `serde`: `"true"` enables serde derives; any other value disables them.
    /// - `derive`: comma-separated list of extra derives (entries are trimmed).
    fn apply_options(
        &mut self,
        options: &std::collections::HashMap<String, String>,
    ) -> Result<(), ScytheError> {
        if let Some(val) = options.get("serde") {
            self.serde = val == "true";
        }
        if let Some(val) = options.get("derive") {
            // Blank entries (from "Hash,,Eq" or a trailing comma) are dropped.
            // Otherwise they would render as `#[derive(Hash, , Eq)]`, which is
            // not valid Rust.
            self.extra_derives = val
                .split(',')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .map(String::from)
                .collect();
        }
        Ok(())
    }

    /// Generates the row struct (plus its `from_row` constructor) for a query's
    /// result columns.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        generate_struct_with_from_row(&struct_name, columns, &self.struct_derives())
    }

    /// Generates the model struct for a table, named after the singularized,
    /// PascalCased table name.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let struct_name = to_pascal_case(&singular).into_owned();
        generate_struct_with_from_row(&struct_name, columns, &self.struct_derives())
    }

    /// Generates the async function for one analyzed query.
    ///
    /// - `Batch` queries become `<name>_batch`. It prepares the statement once
    ///   and executes it once per item: per `items` element (a
    ///   `<Struct>BatchParams` struct when there are several parameters), or
    ///   `count` times when there are none.
    /// - Every other command becomes a single `query_one` / `query_opt` /
    ///   `query` / `execute` call.
    ///
    /// Enum parameters are always bound through `to_string()`.
    ///
    /// # Errors
    ///
    /// Returns an `InternalError` for `Grouped` queries, which must be rewritten
    /// before codegen.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();

        // `{:?}` renders the note as an escaped string literal, so quotes or
        // backslashes in it cannot break the attribute. The attribute is
        // pushed immediately before each `pub async fn` so it never attaches to
        // the batch params struct instead.
        let deprecated_attr = analyzed
            .deprecated
            .as_ref()
            .map(|msg| format!("#[deprecated(note = {:?})]\n", msg.to_string()))
            .unwrap_or_default();

        let sql = raw_sql_literal(&super::clean_sql_with_optional(
            &analyzed.sql,
            &analyzed.optional_params,
            &analyzed.params,
        ));

        if matches!(analyzed.command, QueryCommand::Batch) {
            let batch_fn_name = format!("{}_batch", func_name);
            match params {
                [] => {
                    out.push_str(&deprecated_attr);
                    let _ = writeln!(
                        out,
                        "pub async fn {}({}, count: usize) -> Result<(), {}> {{",
                        batch_fn_name, CLIENT_PARAM, ERROR_TYPE
                    );
                    let _ = writeln!(out, " let stmt = client.prepare({}).await?;", sql);
                    let _ = writeln!(out, " for _ in 0..count {{");
                    let _ = writeln!(out, " client.execute(&stmt, &[]).await?;");
                }
                [param] => {
                    out.push_str(&deprecated_attr);
                    let _ = writeln!(
                        out,
                        "pub async fn {}({}, items: &[{}]) -> Result<(), {}> {{",
                        batch_fn_name, CLIENT_PARAM, param.full_type, ERROR_TYPE
                    );
                    let _ = writeln!(out, " let stmt = client.prepare({}).await?;", sql);
                    let _ = writeln!(out, " for item in items {{");
                    // Generated enums implement Display but not ToSql, so bind
                    // them as text, matching every other call path.
                    let arg = if param.neutral_type.starts_with("enum::") {
                        "&item.to_string()"
                    } else {
                        "item"
                    };
                    let _ = writeln!(out, " client.execute(&stmt, &[{}]).await?;", arg);
                }
                _ => {
                    let params_struct_name = format!("{}BatchParams", struct_name);
                    let _ = writeln!(out, "{}", self.struct_derives());
                    let _ = writeln!(out, "pub struct {} {{", params_struct_name);
                    for param in params {
                        let _ = writeln!(out, " pub {}: {},", param.field_name, param.full_type);
                    }
                    let _ = writeln!(out, "}}");
                    let _ = writeln!(out);
                    out.push_str(&deprecated_attr);
                    let _ = writeln!(
                        out,
                        "pub async fn {}({}, items: &[{}]) -> Result<(), {}> {{",
                        batch_fn_name, CLIENT_PARAM, params_struct_name, ERROR_TYPE
                    );
                    let _ = writeln!(out, " let stmt = client.prepare({}).await?;", sql);
                    let _ = writeln!(out, " for item in items {{");
                    let _ = writeln!(
                        out,
                        " client.execute(&stmt, &[{}]).await?;",
                        bind_args(params, "item.")
                    );
                }
            }
            // Shared tail: close the loop, return, close the fn.
            let _ = writeln!(out, " }}");
            let _ = writeln!(out, " Ok(())");
            let _ = write!(out, "}}");
            return Ok(out);
        }

        let return_type = match &analyzed.command {
            QueryCommand::One => struct_name.to_string(),
            QueryCommand::Opt => format!("Option<{}>", struct_name),
            QueryCommand::Many => format!("Vec<{}>", struct_name),
            QueryCommand::Exec => "()".to_string(),
            QueryCommand::ExecResult => "u64".to_string(),
            QueryCommand::ExecRows => "u64".to_string(),
            QueryCommand::Batch => unreachable!(),
            QueryCommand::Grouped => {
                return Err(ScytheError::new(
                    ErrorCode::InternalError,
                    "Grouped queries should be rewritten before codegen".to_string(),
                ));
            }
        };

        let mut signature: Vec<String> = Vec::with_capacity(params.len() + 1);
        signature.push(CLIENT_PARAM.to_string());
        signature.extend(
            params
                .iter()
                .map(|p| format!("{}: {}", p.field_name, p.borrowed_type)),
        );
        out.push_str(&deprecated_attr);
        let _ = writeln!(
            out,
            "pub async fn {}({}) -> Result<{}, {}> {{",
            func_name,
            signature.join(", "),
            return_type,
            ERROR_TYPE
        );

        // With no params this renders `&[]`.
        let param_refs = format!("&[{}]", bind_args(params, ""));
        match &analyzed.command {
            QueryCommand::One => {
                let _ = writeln!(
                    out,
                    " let row = client.query_one({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
            }
            QueryCommand::Opt => {
                let _ = writeln!(
                    out,
                    " let row = client.query_opt({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok(row.as_ref().map({}::from_row))", struct_name);
            }
            QueryCommand::Many => {
                let _ = writeln!(
                    out,
                    " let rows = client.query({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(
                    out,
                    " Ok(rows.iter().map({}::from_row).collect())",
                    struct_name
                );
            }
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    " client.execute({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok(())");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    " let rows_affected = client.execute({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok(rows_affected)");
            }
            // Batch returned early and Grouped returned an error above.
            QueryCommand::Batch | QueryCommand::Grouped => {
                unreachable!("handled before the signature was written")
            }
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates a Rust enum for a SQL enum type, plus two impls:
    /// - `Display` writes the SQL label for each variant.
    /// - `FromStr` parses a SQL label and returns `Err("unknown variant: ...")`
    ///   for anything else.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        // (label literal, variant name) pairs. Labels are rendered with `{:?}`,
        // so quotes, backslashes or braces in a SQL enum label still produce a
        // valid string literal and are never read as format placeholders.
        let variants: Vec<(String, String)> = enum_info
            .values
            .iter()
            .map(|value| {
                (
                    format!("{:?}", value),
                    enum_variant_name(value, &self.manifest.naming).to_string(),
                )
            })
            .collect();

        let mut out = String::with_capacity(512);
        let _ = writeln!(out, "{}", self.enum_derives());
        let _ = writeln!(out, "pub enum {} {{", type_name);
        for (_, variant) in &variants {
            let _ = writeln!(out, " {},", variant);
        }
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
        let _ = writeln!(
            out,
            " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
        );
        let _ = writeln!(out, " match self {{");
        for (label, variant) in &variants {
            let _ = writeln!(out, " {}::{} => f.write_str({}),", type_name, variant, label);
        }
        let _ = writeln!(out, " }}");
        let _ = writeln!(out, " }}");
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
        let _ = writeln!(out, " type Err = String;");
        let _ = writeln!(
            out,
            " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
        );
        let _ = writeln!(out, " match s {{");
        for (label, variant) in &variants {
            let _ = writeln!(out, " {} => Ok({}::{}),", label, type_name, variant);
        }
        let _ = writeln!(
            out,
            " _ => Err(format!(\"unknown variant: {{}}\", s)),"
        );
        let _ = writeln!(out, " }}");
        let _ = writeln!(out, " }}");
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates a plain struct for a SQL composite type. Each field type is
    /// resolved through the manifest as non-nullable.
    ///
    /// # Errors
    ///
    /// Returns an `InternalError` if a field's neutral type cannot be resolved.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
        let mut out = String::new();
        let _ = writeln!(out, "{}", self.struct_derives());
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for field in &composite.fields {
            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(
                out,
                " pub {}: {},",
                to_snake_case(&field.name),
                rust_type
            );
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}

/// Renders the comma-separated bind-argument list for `params`, each prefixed
/// with `prefix` (for example `"item."` inside batch loops).
///
/// Enum parameters are bound via `to_string()`. Generated enums implement
/// `Display`, not `ToSql`.
fn bind_args(params: &[ResolvedParam], prefix: &str) -> String {
    params
        .iter()
        .map(|p| {
            if p.neutral_type.starts_with("enum::") {
                format!("&{}{}.to_string()", prefix, p.field_name)
            } else {
                format!("&{}{}", prefix, p.field_name)
            }
        })
        .collect::<Vec<_>>()
        .join(", ")
}

/// Renders `sql` as a Rust raw string literal.
///
/// The delimiter uses one more `#` than the longest run of `#` that directly
/// follows a `"` inside `sql`, with a minimum of one. SQL containing `"#`
/// therefore cannot end the literal early, and ordinary SQL still renders as
/// `r#"..."#`.
fn raw_sql_literal(sql: &str) -> String {
    let mut longest = 0usize;
    let mut run: Option<usize> = None;
    for c in sql.chars() {
        run = match (c, run) {
            ('"', _) => Some(0),
            ('#', Some(n)) => Some(n + 1),
            _ => None,
        };
        if let Some(n) = run {
            longest = longest.max(n);
        }
    }
    let hashes = "#".repeat(longest + 1);
    format!("r{h}\"{sql}\"{h}", h = hashes, sql = sql)
}
/// Renders `struct_name` as a `pub` struct with one `pub` field per column.
/// `derive_line` is placed above it, and an inherent
/// `from_row(&tokio_postgres::Row) -> Self` constructor is generated alongside.
///
/// Plain columns are read with `row.get(<column name>)`. Enum columns
/// (`neutral_type` starting with `enum::`) are read as `String`, or as
/// `Option<String>` when nullable, and parsed with `FromStr`. The generated
/// code panics with `unexpected enum value for column '<name>': <value>` if
/// parsing fails.
///
/// Column names are emitted with `{:?}`. A name containing quotes, backslashes
/// or braces therefore still yields a valid string literal and is never spliced
/// into the `panic!` format string.
fn generate_struct_with_from_row(
    struct_name: &str,
    columns: &[ResolvedColumn],
    derive_line: &str,
) -> Result<String, ScytheError> {
    let mut out = String::new();
    let _ = writeln!(out, "{}", derive_line);
    let _ = writeln!(out, "pub struct {} {{", struct_name);
    for col in columns {
        let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
    }
    let _ = writeln!(out, "}}");
    let _ = writeln!(out);
    let _ = writeln!(out, "impl {} {{", struct_name);
    let _ = writeln!(
        out,
        " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
    );
    let _ = writeln!(out, " Self {{");
    for col in columns {
        let field = &col.field_name;
        // Escaped string literal holding the column name, e.g. `"status"`.
        let name = format!("{:?}", col.name);
        if !col.neutral_type.starts_with("enum::") {
            let _ = writeln!(out, " {}: row.get({}),", field, name);
        } else if col.nullable {
            let _ = writeln!(
                out,
                " {field}: row.get::<_, Option<String>>({name}).map(|s| s.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", {name}, s))),",
                field = field,
                name = name
            );
        } else {
            let _ = writeln!(
                out,
                " {field}: {{ let val: String = row.get({name}); val.parse().unwrap_or_else(|_| panic!(\"unexpected enum value for column '{{}}': {{}}\", {name}, val)) }},",
                field = field,
                name = name
            );
        }
    }
    let _ = writeln!(out, " }}");
    let _ = writeln!(out, " }}");
    let _ = write!(out, "}}");
    Ok(out)
}