use std::fmt::Write;
use std::path::Path;
use scythe_backend::manifest::{BackendManifest, load_manifest};
use scythe_backend::naming::{
enum_type_name, enum_variant_name, fn_name, row_struct_name, to_pascal_case, to_snake_case,
};
use scythe_backend::types::resolve_type;
use scythe_core::analyzer::{AnalyzedQuery, CompositeInfo, EnumInfo};
use scythe_core::errors::{ErrorCode, ScytheError};
use scythe_core::parser::QueryCommand;
use crate::backend_trait::{CodegenBackend, ResolvedColumn, ResolvedParam};
use crate::singularize;
/// Manifest TOML embedded into the binary at compile time (the `include_str!` path is
/// resolved relative to this source file). `load_tokio_postgres_manifest` parses it as a
/// fallback when `backends/rust-tokio-postgres/manifest.toml` does not exist on disk.
const DEFAULT_MANIFEST_TOML: &str = include_str!("../../manifests/rust-tokio-postgres.toml");
/// Code-generation backend that emits Rust code calling the `tokio_postgres` client API.
pub struct TokioPostgresBackend {
    /// Manifest supplying the naming rules (`.naming`) and type-resolution data used
    /// while generating code.
    manifest: BackendManifest,
}

impl TokioPostgresBackend {
    /// Builds the backend, loading its manifest through `load_tokio_postgres_manifest`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while loading or parsing the manifest.
    pub fn new() -> Result<Self, ScytheError> {
        load_tokio_postgres_manifest().map(|manifest| Self { manifest })
    }

    /// Borrows the manifest this backend was built with.
    pub fn manifest(&self) -> &BackendManifest {
        &self.manifest
    }
}
/// Loads the manifest for the tokio-postgres backend.
///
/// If `backends/rust-tokio-postgres/manifest.toml` exists, it is loaded from disk. A
/// relative path like this one is resolved against the process's current working
/// directory. Otherwise the compile-time embedded `DEFAULT_MANIFEST_TOML` is parsed.
///
/// # Errors
///
/// Returns an `ErrorCode::InternalError` error when the on-disk manifest fails to load
/// or the embedded manifest fails to parse.
fn load_tokio_postgres_manifest() -> Result<BackendManifest, ScytheError> {
    let override_path = Path::new("backends/rust-tokio-postgres/manifest.toml");

    // No on-disk override: fall back to the manifest baked into the binary.
    if !override_path.exists() {
        return toml::from_str(DEFAULT_MANIFEST_TOML).map_err(|err| {
            ScytheError::new(
                ErrorCode::InternalError,
                format!("failed to parse embedded manifest: {err}"),
            )
        });
    }

    load_manifest(override_path).map_err(|err| {
        ScytheError::new(
            ErrorCode::InternalError,
            format!("failed to load manifest: {err}"),
        )
    })
}
impl CodegenBackend for TokioPostgresBackend {
    /// Returns the backend identifier, `"rust-tokio-postgres"`.
    fn name(&self) -> &str {
        "rust-tokio-postgres"
    }

    /// Generates the result-row struct for `query_name`, together with its `from_row`
    /// constructor.
    ///
    /// The struct is named via the manifest's row-struct naming rules.
    fn generate_row_struct(
        &self,
        query_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let struct_name = row_struct_name(query_name, &self.manifest.naming);
        generate_struct_with_from_row(&struct_name, columns)
    }

    /// Generates the model struct for `table_name`, together with its `from_row`
    /// constructor.
    ///
    /// The struct is named after the singularized table name in PascalCase.
    fn generate_model_struct(
        &self,
        table_name: &str,
        columns: &[ResolvedColumn],
    ) -> Result<String, ScytheError> {
        let singular = singularize(table_name);
        let struct_name = to_pascal_case(&singular).into_owned();
        generate_struct_with_from_row(&struct_name, columns)
    }

    /// Generates a `pub async fn` that runs the analyzed query.
    ///
    /// The function takes `client: &tokio_postgres::Client` followed by one argument per
    /// parameter, and returns `Result<_, tokio_postgres::Error>`.
    ///
    /// Return type by command:
    /// - `One` → `struct_name`, via `client.query_one`
    /// - `Many` / `Batch` → `Vec<struct_name>`, via `client.query`
    /// - `Exec` → `()`, via `client.execute`
    /// - `ExecResult` / `ExecRows` → `u64` (the value returned by `client.execute`)
    ///
    /// Parameters whose neutral type starts with `enum::` are bound as
    /// `&name.to_string()`; all others are bound as `&name`. `_columns` is unused here.
    fn generate_query_fn(
        &self,
        analyzed: &AnalyzedQuery,
        struct_name: &str,
        _columns: &[ResolvedColumn],
        params: &[ResolvedParam],
    ) -> Result<String, ScytheError> {
        let func_name = fn_name(&analyzed.name, &self.manifest.naming);
        let mut out = String::new();
        if let Some(msg) = &analyzed.deprecated {
            // `{:?}` renders an escaped Rust string literal, so quotes or backslashes in
            // the message cannot break the generated attribute.
            let _ = writeln!(out, "#[deprecated(note = {:?})]", msg);
        }

        let mut param_parts: Vec<String> = Vec::with_capacity(params.len() + 1);
        param_parts.push("client: &tokio_postgres::Client".to_string());
        param_parts.extend(
            params
                .iter()
                .map(|param| format!("{}: {}", param.field_name, param.borrowed_type)),
        );

        let return_type = match &analyzed.command {
            QueryCommand::One => struct_name.to_string(),
            QueryCommand::Many => format!("Vec<{}>", struct_name),
            QueryCommand::Exec => "()".to_string(),
            QueryCommand::ExecResult => "u64".to_string(),
            QueryCommand::ExecRows => "u64".to_string(),
            QueryCommand::Batch => format!("Vec<{}>", struct_name),
        };
        let _ = writeln!(
            out,
            "pub async fn {}({}) -> Result<{}, tokio_postgres::Error> {{",
            func_name,
            param_parts.join(", "),
            return_type
        );

        // Embed the SQL as a raw string literal with enough `#` delimiters that the SQL
        // text can never close it early (e.g. `"col"#1` using the XOR operator).
        let cleaned_sql = super::clean_sql(&analyzed.sql);
        let sql = raw_string_literal(&cleaned_sql);

        let param_refs: String = if params.is_empty() {
            "&[]".to_string()
        } else {
            let refs: Vec<String> = params
                .iter()
                .map(|p| {
                    if p.neutral_type.starts_with("enum::") {
                        format!("&{}.to_string()", p.field_name)
                    } else {
                        format!("&{}", p.field_name)
                    }
                })
                .collect();
            format!("&[{}]", refs.join(", "))
        };

        match &analyzed.command {
            QueryCommand::One => {
                let _ = writeln!(
                    out,
                    " let row = client.query_one({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok({}::from_row(&row))", struct_name);
            }
            QueryCommand::Many | QueryCommand::Batch => {
                let _ = writeln!(
                    out,
                    " let rows = client.query({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(
                    out,
                    " Ok(rows.iter().map({}::from_row).collect())",
                    struct_name
                );
            }
            QueryCommand::Exec => {
                let _ = writeln!(
                    out,
                    " client.execute({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok(())");
            }
            QueryCommand::ExecResult | QueryCommand::ExecRows => {
                let _ = writeln!(
                    out,
                    " let rows_affected = client.execute({}, {}).await?;",
                    sql, param_refs
                );
                let _ = writeln!(out, " Ok(rows_affected)");
            }
        }
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates a Rust enum for a SQL enum type.
    ///
    /// The output is `#[derive(Debug, Clone, PartialEq, Eq)]` plus `Display` and
    /// `FromStr` impls. Each variant maps to and from its original SQL label.
    /// `FromStr` rejects unknown labels with `Err("unknown variant: …")`.
    ///
    /// Labels are emitted as escaped Rust string literals, so labels containing `"`,
    /// `\`, `{` or `}` still produce valid code.
    fn generate_enum_def(&self, enum_info: &EnumInfo) -> Result<String, ScytheError> {
        let type_name = enum_type_name(&enum_info.sql_name, &self.manifest.naming);
        let mut out = String::with_capacity(512);
        let _ = writeln!(out, "#[derive(Debug, Clone, PartialEq, Eq)]");
        let _ = writeln!(out, "pub enum {} {{", type_name);
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            let _ = writeln!(out, " {},", variant);
        }
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl std::fmt::Display for {} {{", type_name);
        let _ = writeln!(
            out,
            " fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {{"
        );
        let _ = writeln!(out, " match self {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            // The label becomes the *format string* of `write!`, so braces must be
            // doubled in addition to ordinary string escaping.
            let _ = writeln!(
                out,
                " {}::{} => write!(f, {}),",
                type_name,
                variant,
                format_string_literal(value)
            );
        }
        let _ = writeln!(out, " }}");
        let _ = writeln!(out, " }}");
        let _ = writeln!(out, "}}");
        let _ = writeln!(out);

        let _ = writeln!(out, "impl std::str::FromStr for {} {{", type_name);
        let _ = writeln!(out, " type Err = String;");
        let _ = writeln!(
            out,
            " fn from_str(s: &str) -> Result<Self, Self::Err> {{"
        );
        let _ = writeln!(out, " match s {{");
        for value in &enum_info.values {
            let variant = enum_variant_name(value, &self.manifest.naming);
            // `{:?}` yields an escaped string literal usable as a match pattern.
            let _ = writeln!(
                out,
                " {:?} => Ok({}::{}),",
                value, type_name, variant
            );
        }
        let _ = writeln!(
            out,
            " _ => Err(format!(\"unknown variant: {{}}\", s)),"
        );
        let _ = writeln!(out, " }}");
        let _ = writeln!(out, " }}");
        let _ = write!(out, "}}");
        Ok(out)
    }

    /// Generates a `#[derive(Debug, Clone)]` struct for a SQL composite type.
    ///
    /// The struct is named after the composite's SQL name in PascalCase, with one
    /// snake_case `pub` field per composite field. Field types come from `resolve_type`
    /// called with `false` as its last argument (presumably "not nullable" — confirm).
    ///
    /// # Errors
    ///
    /// Returns an `ErrorCode::InternalError` error if a field's neutral type cannot be
    /// resolved.
    fn generate_composite_def(&self, composite: &CompositeInfo) -> Result<String, ScytheError> {
        let struct_name = to_pascal_case(&composite.sql_name).into_owned();
        let mut out = String::new();
        let _ = writeln!(out, "#[derive(Debug, Clone)]");
        let _ = writeln!(out, "pub struct {} {{", struct_name);
        for field in &composite.fields {
            let rust_type = resolve_type(&field.neutral_type, &self.manifest, false)
                .map(|t| t.into_owned())
                .map_err(|e| {
                    ScytheError::new(
                        ErrorCode::InternalError,
                        format!("composite field type error: {}", e),
                    )
                })?;
            let _ = writeln!(
                out,
                " pub {}: {},",
                to_snake_case(&field.name),
                rust_type
            );
        }
        let _ = write!(out, "}}");
        Ok(out)
    }
}

/// Renders `s` as a Rust raw string literal `r#…"…"#…` that cannot be closed early.
///
/// The literal uses the fewest `#` delimiters such that `"` followed by that many `#`
/// never occurs in `s`. It uses a single `#` whenever possible.
fn raw_string_literal(s: &str) -> String {
    let mut hashes = String::from("#");
    while s.contains(format!("\"{hashes}").as_str()) {
        hashes.push('#');
    }
    format!("r{hashes}\"{s}\"{hashes}")
}

/// Renders `s` as an escaped Rust string literal for use as the format string of
/// `write!`.
///
/// `{` and `}` are doubled so they print literally. Quotes, backslashes and control
/// characters are escaped via `Debug`.
fn format_string_literal(s: &str) -> String {
    format!("{:?}", s.replace('{', "{{").replace('}', "}}"))
}
/// Generates a `#[derive(Debug, Clone)]` struct named `struct_name` with a `from_row`
/// constructor.
///
/// The struct has one `pub` field per column (`field_name: full_type`). The inherent
/// constructor is `pub fn from_row(row: &tokio_postgres::Row) -> Self`.
///
/// How each field is read:
/// - Columns whose neutral type starts with `enum::` are read as `String`
///   (`Option<String>` when nullable) and converted with `parse().unwrap()` in the
///   generated code.
/// - All other columns are read with `row.get(name)`.
///
/// Column names are emitted as escaped Rust string literals, so quoted SQL identifiers
/// containing `"` or `\` still produce valid code.
///
/// # Errors
///
/// Never fails today; the `Result` keeps the signature uniform with the other
/// generators.
fn generate_struct_with_from_row(
    struct_name: &str,
    columns: &[ResolvedColumn],
) -> Result<String, ScytheError> {
    let mut out = String::new();
    let _ = writeln!(out, "#[derive(Debug, Clone)]");
    let _ = writeln!(out, "pub struct {} {{", struct_name);
    for col in columns {
        let _ = writeln!(out, " pub {}: {},", col.field_name, col.full_type);
    }
    let _ = writeln!(out, "}}");
    let _ = writeln!(out);
    let _ = writeln!(out, "impl {} {{", struct_name);
    let _ = writeln!(
        out,
        " pub fn from_row(row: &tokio_postgres::Row) -> Self {{"
    );
    let _ = writeln!(out, " Self {{");
    for col in columns {
        // `{:?}` produces an escaped string literal, identical to `"name"` for plain
        // column names.
        let column = format!("{:?}", col.name);
        let is_enum = col.neutral_type.starts_with("enum::");
        let _ = match (is_enum, col.nullable) {
            (false, _) => writeln!(out, " {}: row.get({}),", col.field_name, column),
            (true, true) => writeln!(
                out,
                " {}: row.get::<_, Option<String>>({}).map(|s| s.parse().unwrap()),",
                col.field_name, column
            ),
            (true, false) => writeln!(
                out,
                " {}: row.get::<_, String>({}).parse().unwrap(),",
                col.field_name, column
            ),
        };
    }
    let _ = writeln!(out, " }}");
    let _ = writeln!(out, " }}");
    let _ = write!(out, "}}");
    Ok(out)
}