use tokio_postgres::GenericClient;
use crate::schema_shared::{self, SharedSchemaError};
pub use crate::schema_shared::{ColumnPair, IdKind};
/// Errors returned by the schema-installation helpers in this module.
#[derive(Debug)]
pub enum SchemaError {
    /// The Postgres driver reported an error while executing bundled or generated SQL.
    TokioPostgres(tokio_postgres::Error),
    /// A name that would be interpolated into generated SQL failed identifier validation;
    /// holds the offending string.
    InvalidIdentifier(String),
}
impl std::fmt::Display for SchemaError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SchemaError::TokioPostgres(e) => write!(f, "tokio-postgres error: {e}"),
SchemaError::InvalidIdentifier(s) => {
write!(f, "invalid Postgres identifier: {s:?}")
}
}
}
}
impl std::error::Error for SchemaError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
SchemaError::TokioPostgres(e) => Some(e),
SchemaError::InvalidIdentifier(_) => None,
}
}
}
impl From<tokio_postgres::Error> for SchemaError {
fn from(e: tokio_postgres::Error) -> Self {
SchemaError::TokioPostgres(e)
}
}
/// Maps validation errors from the shared schema helpers onto the matching
/// [`SchemaError`] variant, keeping the rejected string.
impl From<SharedSchemaError> for SchemaError {
    fn from(shared_err: SharedSchemaError) -> Self {
        match shared_err {
            SharedSchemaError::InvalidIdentifier(rejected) => Self::InvalidIdentifier(rejected),
        }
    }
}
pub const SCHEMA_SQL: &str = include_str!("../sql/schema.sql");
pub const SESSION_SQL: &str = include_str!("../sql/functions/session.sql");
pub const GENERATE_HEERID_SQL: &str = include_str!("../sql/functions/generate_heerid.sql");
pub const GENERATE_RANJID_SQL: &str = include_str!("../sql/functions/generate_ranjid.sql");
pub const INSTALL_SQL: &str = concat!(
include_str!("../sql/schema.sql"),
"\n",
include_str!("../sql/functions/session.sql"),
"\n",
include_str!("../sql/functions/generate_heerid.sql"),
"\n",
include_str!("../sql/functions/generate_ranjid.sql"),
);
pub const SEED_SQL: &str = include_str!("../sql/seed.sql");
pub const DESC_FLIP_SQL: &str = include_str!("../sql/functions/desc_flip.sql");
pub const DESC_GENERATORS_SQL: &str = include_str!("../sql/functions/desc_generators.sql");
pub const BULK_BACKFILL_SQL: &str = include_str!("../sql/functions/bulk_backfill.sql");
pub async fn install_schema<C>(client: &C) -> Result<(), tokio_postgres::Error>
where
C: GenericClient + ?Sized,
{
client.batch_execute(INSTALL_SQL).await
}
pub async fn seed_default_node<C>(client: &C) -> Result<(), tokio_postgres::Error>
where
C: GenericClient + ?Sized,
{
client.batch_execute(SEED_SQL).await
}
pub async fn install_flip_functions<C>(client: &C) -> Result<(), tokio_postgres::Error>
where
C: GenericClient + ?Sized,
{
client.batch_execute(DESC_FLIP_SQL).await
}
pub async fn install_desc_generators<C>(client: &C) -> Result<(), tokio_postgres::Error>
where
C: GenericClient + ?Sized,
{
client.batch_execute(DESC_GENERATORS_SQL).await
}
pub async fn install_migration_support<C>(client: &C) -> Result<(), tokio_postgres::Error>
where
C: GenericClient + ?Sized,
{
client.batch_execute(BULK_BACKFILL_SQL).await
}
pub async fn install_all_desc_support<C>(client: &C) -> Result<(), tokio_postgres::Error>
where
C: GenericClient + ?Sized,
{
install_flip_functions(client).await?;
install_desc_generators(client).await?;
install_migration_support(client).await?;
Ok(())
}
/// Checks that `s` is acceptable as a bare (unquoted) identifier in generated SQL.
///
/// Delegates to [`schema_shared::validate_ident`] and converts its error into
/// [`SchemaError`] via the `From` impl.
fn validate_ident(s: &str) -> Result<(), SchemaError> {
    schema_shared::validate_ident(s)?;
    Ok(())
}
/// Installs (or replaces) a `BEFORE INSERT OR UPDATE` row trigger on `table` that keeps each
/// `pair.dst` column filled with `kind.flip_fn()` applied to the matching `pair.src` column.
///
/// The generated PL/pgSQL function and the trigger are both named
/// `zzz_<table>_autofill_desc`. Behaviour of the generated trigger, per pair:
/// - on `INSERT`, `dst` is filled only when it is `NULL`;
/// - on `UPDATE`, `dst` is recomputed when `src` changed (`IS DISTINCT FROM`), or filled when
///   it is still `NULL`.
///
/// Any existing trigger of the same name is dropped first. The function definition, the
/// drop, and the `CREATE TRIGGER` are sent in one `batch_execute` call.
///
/// # Errors
///
/// - [`SchemaError::InvalidIdentifier`] if `table`, any `src`/`dst` column, or the derived
///   function/trigger name fails identifier validation.
///   - The derived name adds 18 bytes to `table`, so table names longer than 45 bytes are
///     rejected.
///   - Postgres would otherwise truncate the name to 63 bytes. Two long table names could
///     then share one function, and installing one would overwrite the other's trigger
///     function.
/// - [`SchemaError::TokioPostgres`] if the server rejects the generated SQL.
///
/// # Panics
///
/// Panics if `pairs` is empty.
pub async fn install_autofill_trigger_for_table<C>(
    client: &C,
    table: &str,
    pairs: &[ColumnPair<'_>],
    kind: IdKind,
) -> Result<(), SchemaError>
where
    C: GenericClient + ?Sized,
{
    use std::fmt::Write as _;

    assert!(!pairs.is_empty(), "at least one ColumnPair required");
    // Every name below is spliced into SQL unquoted, so each must pass validation.
    validate_ident(table)?;
    for p in pairs {
        validate_ident(p.src)?;
        validate_ident(p.dst)?;
    }
    // Shared by the function and the trigger. Validate it too, so an over-long name is
    // rejected here instead of being truncated (and possibly colliding) server-side.
    let fn_name = format!("zzz_{table}_autofill_desc");
    validate_ident(&fn_name)?;

    let flip = kind.flip_fn();
    let mut insert_body = String::new();
    let mut update_body = String::new();
    for p in pairs {
        let (src, dst) = (p.src, p.dst);
        // INSERT: only fill dst when the caller did not supply a value.
        writeln!(
            insert_body,
            "    IF NEW.{dst} IS NULL THEN NEW.{dst} := {flip}(NEW.{src}); END IF;"
        )
        .expect("write! to String cannot fail");
        // UPDATE: recompute when src changed; otherwise backfill a still-NULL dst.
        write!(
            update_body,
            "    IF NEW.{src} IS DISTINCT FROM OLD.{src} THEN\n\
             \x20     NEW.{dst} := {flip}(NEW.{src});\n\
             \x20   ELSIF NEW.{dst} IS NULL THEN\n\
             \x20     NEW.{dst} := {flip}(NEW.{src});\n\
             \x20   END IF;\n"
        )
        .expect("write! to String cannot fail");
    }

    let sql = format!(
        r#"
CREATE OR REPLACE FUNCTION {fn_name}() RETURNS trigger AS $body$
BEGIN
  IF TG_OP = 'INSERT' THEN
{insert_body}  ELSIF TG_OP = 'UPDATE' THEN
{update_body}  END IF;
  RETURN NEW;
END;
$body$ LANGUAGE plpgsql;
DROP TRIGGER IF EXISTS {fn_name} ON {table};
CREATE TRIGGER {fn_name}
  BEFORE INSERT OR UPDATE ON {table}
  FOR EACH ROW EXECUTE FUNCTION {fn_name}();
"#
    );
    client.batch_execute(&sql).await?;
    Ok(())
}
pub async fn drop_autofill_trigger_for_table<C>(client: &C, table: &str) -> Result<(), SchemaError>
where
C: GenericClient + ?Sized,
{
validate_ident(table)?;
let fn_name = format!("zzz_{}_autofill_desc", table);
let sql = format!(
"DROP TRIGGER IF EXISTS {name} ON {tbl};\n\
DROP FUNCTION IF EXISTS {name}() CASCADE;\n",
name = fn_name,
tbl = table,
);
client.batch_execute(&sql).await?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Names that must never be spliced into generated SQL.
    #[test]
    fn validate_ident_rejects_sql_injection_attempts() {
        let too_long = "x".repeat(64);
        let rejected = [
            "tbl; DROP TABLE users",
            "\"quoted\"",
            "it's",
            "tbl--",
            "two words",
            "tab\tname",
            "nl\nname",
            "",
            too_long.as_str(),
            "1tbl",
            "tbl-name",
            "tbl.name",
        ];
        for ident in rejected {
            assert!(
                matches!(validate_ident(ident), Err(SchemaError::InvalidIdentifier(_))),
                "expected {ident:?} to be rejected as InvalidIdentifier"
            );
        }
    }

    /// Plain identifiers up to the 63-byte limit are accepted.
    #[test]
    fn validate_ident_accepts_valid_identifiers() {
        let max_len = "a".repeat(63);
        let accepted = [
            "tbl",
            "_internal_thing",
            "events_v2",
            "A",
            "_",
            "id_desc",
            max_len.as_str(),
        ];
        for ident in accepted {
            assert!(validate_ident(ident).is_ok(), "expected {ident:?} to be accepted");
        }
    }

    #[test]
    fn id_kind_flip_fn_matches_sql_names() {
        assert_eq!(IdKind::Heer.flip_fn(), "heerid_to_desc");
        assert_eq!(IdKind::Ranj.flip_fn(), "ranjid_to_desc");
    }

    /// The rejected name is `Debug`-quoted in the user-facing message.
    #[test]
    fn invalid_identifier_display_quotes_the_name() {
        let err = SchemaError::InvalidIdentifier("bad name".to_owned());
        assert_eq!(err.to_string(), "invalid Postgres identifier: \"bad name\"");
    }

    /// Shared validation errors keep their payload and carry no underlying cause.
    #[test]
    fn shared_error_converts_to_invalid_identifier() {
        let err = SchemaError::from(SharedSchemaError::InvalidIdentifier("x-y".to_owned()));
        assert!(matches!(err, SchemaError::InvalidIdentifier(ref s) if s == "x-y"));
        assert!(std::error::Error::source(&err).is_none());
    }
}