use thiserror::Error;
/// Longest schema name accepted, in bytes.
///
/// Matches PostgreSQL's default identifier limit (`NAMEDATALEN - 1`); the
/// server truncates longer identifiers.
const MAX_SCHEMA_NAME_LENGTH: usize = 63;

/// Schema names that `validate_schema_name` refuses.
///
/// Lookups are case-insensitive, so every entry is written in lowercase.
const RESERVED_SCHEMA_NAMES: &[&str] = &[
    "public",
    "pg_catalog",
    "information_schema",
    "pg_temp",
];
#[derive(Debug, Error)]
pub enum SchemaError {
#[error("Schema name length invalid: '{name}' (must be 1-{max} characters)")]
InvalidLength { name: String, max: usize },
#[error("Schema name must start with a letter or underscore: '{0}'")]
InvalidStart(String),
#[error(
"Schema name contains invalid characters (only alphanumeric and underscore allowed): '{0}'"
)]
InvalidCharacters(String),
#[error("Schema name is reserved: '{0}'")]
ReservedName(String),
}
/// Validates that `name` is acceptable as a PostgreSQL schema name.
///
/// Rules, checked in this order:
///
/// 1. length is `1..=MAX_SCHEMA_NAME_LENGTH` bytes;
/// 2. the first character is an ASCII letter or `_`;
/// 3. every character is an ASCII letter, digit, or `_` (this whitelist is
///    what rejects quotes, comments, whitespace and other SQL punctuation);
/// 4. the name is not in `RESERVED_SCHEMA_NAMES` and does not start with
///    `pg_`. Both checks are case-insensitive.
///
/// Returns the input unchanged on success.
///
/// # Errors
///
/// * [`SchemaError::InvalidLength`] if `name` is empty or too long.
/// * [`SchemaError::InvalidStart`] if the first character is not a letter or `_`.
/// * [`SchemaError::InvalidCharacters`] if any character is outside `[A-Za-z0-9_]`.
/// * [`SchemaError::ReservedName`] for a listed reserved name or any `pg_` prefix.
pub fn validate_schema_name(name: &str) -> Result<&str, SchemaError> {
    if name.is_empty() || name.len() > MAX_SCHEMA_NAME_LENGTH {
        return Err(SchemaError::InvalidLength {
            name: name.to_string(),
            max: MAX_SCHEMA_NAME_LENGTH,
        });
    }
    // `name` is non-empty here, so this inspects a real first character.
    if !name.starts_with(|c: char| c.is_ascii_alphabetic() || c == '_') {
        return Err(SchemaError::InvalidStart(name.to_string()));
    }
    if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return Err(SchemaError::InvalidCharacters(name.to_string()));
    }
    // From here on `name` is pure ASCII, so ASCII case folding is equivalent
    // to `to_lowercase()` and needs no allocation.
    let listed = RESERVED_SCHEMA_NAMES
        .iter()
        .any(|reserved| reserved.eq_ignore_ascii_case(name));
    // PostgreSQL refuses every schema name beginning with `pg_`
    // ("unacceptable schema name"), not only the built-in ones listed above
    // (e.g. `pg_toast`, `pg_temp_3`). The check is case-insensitive because
    // unquoted identifiers are folded to lowercase by the server.
    let pg_prefixed = name.len() >= 3 && name.as_bytes()[..3].eq_ignore_ascii_case(b"pg_");
    if listed || pg_prefixed {
        return Err(SchemaError::ReservedName(name.to_string()));
    }
    Ok(name)
}
/// Role names that `validate_username` refuses.
///
/// Lookups lowercase the input, so every entry must be lowercase.
///
/// The list covers:
/// * the conventional bootstrap superuser and PostgreSQL's predefined `pg_*` roles;
/// * `public` and `none`, which PostgreSQL's grammar always rejects as role
///   names ("role name ... is reserved");
/// * `current_user`, `current_role` and `session_user`, which are SQL keywords
///   and cannot be used as unquoted role names.
///
/// Listing these makes such requests fail validation up front instead of
/// failing later at the database.
const RESERVED_USERNAMES: &[&str] = &[
    "postgres",
    "public",
    "none",
    "current_user",
    "current_role",
    "session_user",
    "pg_database_owner",
    "pg_read_all_data",
    "pg_write_all_data",
    "pg_read_all_settings",
    "pg_read_all_stats",
    "pg_stat_scan_tables",
    "pg_monitor",
    "pg_read_server_files",
    "pg_write_server_files",
    "pg_execute_server_program",
    "pg_signal_backend",
    "pg_checkpoint",
];
#[derive(Debug, Error)]
pub enum UsernameError {
#[error("Username length invalid: '{name}' (must be 1-{max} characters)")]
InvalidLength { name: String, max: usize },
#[error("Username must start with a letter or underscore: '{0}'")]
InvalidStart(String),
#[error(
"Username contains invalid characters (only alphanumeric and underscore allowed): '{0}'"
)]
InvalidCharacters(String),
#[error("Username is reserved: '{0}'")]
ReservedName(String),
}
/// Validates that `name` is acceptable as a PostgreSQL role (user) name.
///
/// Rules, checked in this order:
///
/// 1. length is `1..=MAX_SCHEMA_NAME_LENGTH` bytes;
/// 2. the first character is an ASCII letter or `_`;
/// 3. every character is an ASCII letter, digit, or `_` (this whitelist is
///    what rejects quotes, comments, whitespace and other SQL punctuation);
/// 4. the name is not in `RESERVED_USERNAMES` and does not start with `pg_`.
///    Both checks are case-insensitive.
///
/// Returns the input unchanged on success.
///
/// # Errors
///
/// * [`UsernameError::InvalidLength`] if `name` is empty or too long.
/// * [`UsernameError::InvalidStart`] if the first character is not a letter or `_`.
/// * [`UsernameError::InvalidCharacters`] if any character is outside `[A-Za-z0-9_]`.
/// * [`UsernameError::ReservedName`] for a listed reserved name or any `pg_` prefix.
pub fn validate_username(name: &str) -> Result<&str, UsernameError> {
    // Role names share PostgreSQL's 63-byte identifier limit with schema
    // names, hence the shared constant.
    if name.is_empty() || name.len() > MAX_SCHEMA_NAME_LENGTH {
        return Err(UsernameError::InvalidLength {
            name: name.to_string(),
            max: MAX_SCHEMA_NAME_LENGTH,
        });
    }
    // `name` is non-empty here, so this inspects a real first character.
    if !name.starts_with(|c: char| c.is_ascii_alphabetic() || c == '_') {
        return Err(UsernameError::InvalidStart(name.to_string()));
    }
    if !name.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return Err(UsernameError::InvalidCharacters(name.to_string()));
    }
    // From here on `name` is pure ASCII, so ASCII case folding is equivalent
    // to `to_lowercase()` and needs no allocation.
    let listed = RESERVED_USERNAMES
        .iter()
        .any(|reserved| reserved.eq_ignore_ascii_case(name));
    // PostgreSQL refuses to create any role whose name starts with `pg_`
    // ("Role names starting with \"pg_\" are reserved"). That covers
    // predefined roles added in releases newer than those listed in
    // `RESERVED_USERNAMES`. The check is case-insensitive because unquoted
    // identifiers are folded to lowercase by the server.
    let pg_prefixed = name.len() >= 3 && name.as_bytes()[..3].eq_ignore_ascii_case(b"pg_");
    if listed || pg_prefixed {
        return Err(UsernameError::ReservedName(name.to_string()));
    }
    Ok(name)
}
/// Escapes `password` for embedding inside a single-quoted SQL string literal.
///
/// Every `'` is doubled; every other character is copied through unchanged.
///
/// Backslashes are deliberately left untouched. The result is therefore only
/// safe inside a standard `'...'` literal on a server with
/// `standard_conforming_strings = on` (PostgreSQL's default since 9.1). It must
/// not be placed inside an `E'...'` escape-string literal.
pub fn escape_password(password: &str) -> String {
    let mut escaped = String::with_capacity(password.len());
    for ch in password.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_valid_schema_names() {
        let longest = "a".repeat(63);
        let accepted = [
            "my_schema",
            "tenant_123",
            "MySchema",
            "_private",
            "_123",
            "a",
            "_",
            longest.as_str(),
        ];
        for name in accepted.iter() {
            assert!(validate_schema_name(name).is_ok(), "{:?} should be accepted", name);
        }
    }

    #[test]
    fn test_sql_injection_attempts_rejected() {
        let attempts = [
            "test; DROP TABLE users; --",
            "test' OR '1'='1",
            "test/*comment*/",
            "test--comment",
            "test()",
            "test=1",
        ];
        for attempt in attempts.iter() {
            assert!(
                matches!(validate_schema_name(attempt), Err(SchemaError::InvalidCharacters(_))),
                "{:?} should be rejected as invalid characters",
                attempt
            );
        }
    }

    #[test]
    fn test_invalid_length() {
        let too_long = "a".repeat(64);
        let way_too_long = "a".repeat(1000);
        for name in ["", too_long.as_str(), way_too_long.as_str()].iter() {
            assert!(
                matches!(validate_schema_name(name), Err(SchemaError::InvalidLength { .. })),
                "{:?} should be rejected for its length",
                name
            );
        }
    }

    #[test]
    fn test_invalid_start_character() {
        for name in ["123abc", "-schema", ".schema", " schema"].iter() {
            assert!(
                matches!(validate_schema_name(name), Err(SchemaError::InvalidStart(_))),
                "{:?} should be rejected for its first character",
                name
            );
        }
    }

    #[test]
    fn test_invalid_characters() {
        let names = [
            "my-schema",
            "my.schema",
            "my schema",
            "schema@test",
            "schema#1",
            "schema$",
        ];
        for name in names.iter() {
            assert!(
                matches!(validate_schema_name(name), Err(SchemaError::InvalidCharacters(_))),
                "{:?} should be rejected as invalid characters",
                name
            );
        }
    }

    #[test]
    fn test_reserved_names() {
        let names = [
            "public",
            "PUBLIC",
            "Public",
            "pg_catalog",
            "PG_CATALOG",
            "information_schema",
            "INFORMATION_SCHEMA",
            "pg_temp",
        ];
        for name in names.iter() {
            assert!(
                matches!(validate_schema_name(name), Err(SchemaError::ReservedName(_))),
                "{:?} should be reserved",
                name
            );
        }
    }

    #[test]
    fn test_schema_error_display() {
        let cases = [
            ("", "length"),
            ("123abc", "start"),
            ("my-schema", "invalid characters"),
            ("public", "reserved"),
        ];
        for &(input, fragment) in cases.iter() {
            let message = validate_schema_name(input).unwrap_err().to_string();
            assert!(
                message.contains(fragment),
                "message {:?} should mention {:?}",
                message,
                fragment
            );
        }
    }

    #[test]
    fn test_unicode_characters_rejected() {
        let names = [
            "schema_\u{03B1}",
            "schema_\u{2603}",
            "caf\u{00E9}",
            "schema_\u{4E2D}",
        ];
        for name in names.iter() {
            assert!(
                matches!(validate_schema_name(name), Err(SchemaError::InvalidCharacters(_))),
                "{:?} should be rejected as invalid characters",
                name
            );
        }
    }

    #[test]
    fn test_valid_usernames() {
        for name in ["tenant_user", "acme_admin", "user123", "_private_user", "a"].iter() {
            assert!(validate_username(name).is_ok(), "{:?} should be accepted", name);
        }
    }

    #[test]
    fn test_username_sql_injection_rejected() {
        let attempts = [
            "admin'; DROP TABLE users; --",
            "user' OR '1'='1",
            "admin--",
            "user()",
            "admin user",
        ];
        for attempt in attempts.iter() {
            assert!(
                matches!(validate_username(attempt), Err(UsernameError::InvalidCharacters(_))),
                "{:?} should be rejected as invalid characters",
                attempt
            );
        }
    }

    #[test]
    fn test_reserved_usernames() {
        for name in ["postgres", "POSTGRES", "pg_read_all_data", "pg_monitor"].iter() {
            assert!(
                matches!(validate_username(name), Err(UsernameError::ReservedName(_))),
                "{:?} should be reserved",
                name
            );
        }
    }

    #[test]
    fn test_username_invalid_length() {
        let too_long = "a".repeat(64);
        for name in ["", too_long.as_str()].iter() {
            assert!(
                matches!(validate_username(name), Err(UsernameError::InvalidLength { .. })),
                "{:?} should be rejected for its length",
                name
            );
        }
    }

    #[test]
    fn test_username_invalid_start() {
        for name in ["123user", "-user"].iter() {
            assert!(
                matches!(validate_username(name), Err(UsernameError::InvalidStart(_))),
                "{:?} should be rejected for its first character",
                name
            );
        }
    }

    #[test]
    fn test_escape_password_no_quotes() {
        for plain in ["simple", "Password123", ""].iter() {
            assert_eq!(escape_password(plain), *plain);
        }
    }

    #[test]
    fn test_escape_password_with_quotes() {
        let cases = [
            ("pass'word", "pass''word"),
            ("it's a test", "it''s a test"),
            ("'quoted'", "''quoted''"),
            ("'''", "''''''"),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(escape_password(raw), expected);
        }
    }

    #[test]
    fn test_escape_password_sql_injection_safe() {
        let escaped = escape_password("'; DROP TABLE users; --");
        assert_eq!(escaped, "''; DROP TABLE users; --");
    }
}