use crate::SqlError;
use sqlparser::dialect::Dialect;
/// Quotes `value` as a single SQL identifier using the quote style reported by `dialect`.
///
/// The opening quote comes from [`Dialect::identifier_quote_style`] (which is
/// given `value`); when the dialect reports no style, `"` is used. Every
/// occurrence of the closing quote inside `value` is doubled so the whole
/// string stays one identifier (e.g. `test"db` → `"test""db"`).
///
/// A `[` opening quote is closed with `]`, and embedded `]` characters are
/// doubled; for every other style the same character opens and closes.
///
/// `dialect` may be unsized, so `&dyn Dialect` (e.g. a dereferenced
/// `Box<dyn Dialect>`) is accepted as well as concrete dialect types.
#[must_use]
pub fn quote_ident(value: &str, dialect: &(impl Dialect + ?Sized)) -> String {
    let open = dialect.identifier_quote_style(value).unwrap_or('"');
    // Bracket quoting is asymmetric: `[ ... ]`, and only `]` needs escaping inside it.
    let close = if open == '[' { ']' } else { open };
    let mut out = String::with_capacity(value.len() + 2);
    out.push(open);
    for ch in value.chars() {
        if ch == close {
            // Escape an embedded closing quote by repeating it.
            out.push(close);
        }
        out.push(ch);
    }
    out.push(close);
    out
}
/// Wraps `value` in single quotes so it can be embedded as a SQL string literal.
///
/// Both `'` and `\` are escaped by repeating them (`'` → `''`, `\` → `\\`);
/// every other character, including non-ASCII text, is copied unchanged.
///
/// NOTE(review): doubling backslashes matches engines that treat `\` as an
/// escape inside `'...'` (e.g. MySQL's default mode). Where `\` is an ordinary
/// character (e.g. PostgreSQL with `standard_conforming_strings = on`, SQLite),
/// the stored value will contain two backslashes. Confirm this is intended for
/// every target dialect.
#[must_use]
pub fn quote_literal(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    quoted.extend(value.chars().flat_map(|c| {
        // Escape-worthy characters are emitted twice, everything else once.
        let copies = if matches!(c, '\\' | '\'') { 2 } else { 1 };
        std::iter::repeat(c).take(copies)
    }));
    quoted.push('\'');
    quoted
}
/// Checks that `name` is acceptable as an identifier.
///
/// The check is deliberately permissive: any character other than a control
/// character is allowed (spaces, punctuation, digits, non-ASCII text), as long
/// as the name contains at least one non-whitespace character.
///
/// # Errors
///
/// Returns [`SqlError::InvalidIdentifier`] holding a copy of `name` when it is
/// empty, consists only of whitespace, or contains any control character.
pub fn validate_ident(name: &str) -> Result<(), SqlError> {
    let has_visible_content = !name.trim().is_empty();
    if has_visible_content && !name.chars().any(char::is_control) {
        Ok(())
    } else {
        Err(SqlError::InvalidIdentifier(name.to_string()))
    }
}
#[cfg(test)]
mod tests {
// Unit tests pinning the accept/reject rules of `validate_ident` and the exact
// output of `quote_ident` / `quote_literal` for the dialects imported below.
use sqlparser::dialect::{MySqlDialect, PostgreSqlDialect, SQLiteDialect};
use super::*;
// Ordinary ASCII names (letters, digits, underscores) are accepted.
#[test]
fn accepts_standard_names() {
assert!(validate_ident("users").is_ok());
assert!(validate_ident("my_table").is_ok());
assert!(validate_ident("DB_123").is_ok());
}
// Hyphens are allowed; validate_ident only rejects blank names and control chars.
#[test]
fn accepts_hyphenated_names() {
assert!(validate_ident("eu-docker").is_ok());
assert!(validate_ident("access-logs").is_ok());
}
// Dots, leading digits, non-ASCII letters and inner spaces all pass validation.
// NOTE(review): presumably callers rely on quote_ident for safe embedding — confirm.
#[test]
fn accepts_special_chars() {
assert!(validate_ident("my.db").is_ok());
assert!(validate_ident("123db").is_ok());
assert!(validate_ident("café").is_ok());
assert!(validate_ident("a b").is_ok());
}
#[test]
fn rejects_empty() {
assert!(validate_ident("").is_err());
}
// Names that are only whitespace (after trim) are rejected.
#[test]
fn rejects_whitespace_only() {
assert!(validate_ident(" ").is_err());
assert!(validate_ident("\t").is_err());
}
// Any control character anywhere in the name is rejected (NUL, newline, 0x1F).
#[test]
fn rejects_control_chars() {
assert!(validate_ident("test\x00db").is_err());
assert!(validate_ident("test\ndb").is_err());
assert!(validate_ident("test\x1Fdb").is_err());
}
// PostgreSQL quotes with `"` and doubles embedded `"`.
#[test]
fn quote_with_postgres_dialect() {
let d = PostgreSqlDialect {};
assert_eq!(quote_ident("users", &d), "\"users\"");
assert_eq!(quote_ident("eu-docker", &d), "\"eu-docker\"");
assert_eq!(quote_ident("test\"db", &d), "\"test\"\"db\"");
}
// MySQL quotes with backticks and doubles embedded backticks.
#[test]
fn quote_with_mysql_dialect() {
let d = MySqlDialect {};
assert_eq!(quote_ident("users", &d), "`users`");
assert_eq!(quote_ident("test`db", &d), "`test``db`");
}
// As pinned here, SQLiteDialect also reports backtick quoting.
#[test]
fn quote_with_sqlite_dialect() {
let d = SQLiteDialect {};
assert_eq!(quote_ident("users", &d), "`users`");
assert_eq!(quote_ident("test`db", &d), "`test``db`");
}
// Single quotes inside a literal are escaped by doubling them.
#[test]
fn quote_literal_escapes_single_quotes() {
assert_eq!(quote_literal("my_db"), "'my_db'");
assert_eq!(quote_literal(""), "''");
assert_eq!(quote_literal("it's"), "'it''s'");
assert_eq!(quote_literal("a'b'c"), "'a''b''c'");
}
// validate_ident enforces no length limit.
#[test]
fn accepts_long_identifier() {
let long_name: String = "a".repeat(10_000);
assert!(validate_ident(&long_name).is_ok());
}
// Control chars are caught at the end, at the start, and DEL (0x7F) counts too.
#[test]
fn rejects_mixed_valid_and_control() {
assert!(validate_ident("valid\x00").is_err());
assert!(validate_ident("\x01start").is_err());
assert!(validate_ident("mid\x7Fdle").is_err());
}
// SQL metacharacters are not rejected by validation; quoting must neutralize them.
#[test]
fn accepts_sql_injection_payload_in_ident() {
assert!(validate_ident("Robert'; DROP TABLE students;--").is_ok());
}
#[test]
fn accepts_emoji() {
assert!(validate_ident("🎉").is_ok());
assert!(validate_ident("table_🔥").is_ok());
}
#[test]
fn accepts_cjk() {
assert!(validate_ident("数据库").is_ok());
assert!(validate_ident("テーブル").is_ok());
}
// Each embedded quote is doubled, then the whole value is wrapped: 2 -> 4 + 2.
#[test]
fn quote_ident_only_backticks_mysql() {
let d = MySqlDialect {};
assert_eq!(quote_ident("``", &d), "``````");
}
#[test]
fn quote_ident_only_double_quotes_postgres() {
let d = PostgreSqlDialect {};
assert_eq!(quote_ident("\"\"", &d), "\"\"\"\"\"\"");
}
// Quotes adjacent to the wrapping quotes are still doubled.
#[test]
fn quote_ident_quote_at_start_and_end() {
let mysql = MySqlDialect {};
assert_eq!(quote_ident("`x`", &mysql), "```x```");
let pg = PostgreSqlDialect {};
assert_eq!(quote_ident("\"x\"", &pg), "\"\"\"x\"\"\"");
}
// Only the dialect's own quote character is escaped; other quote chars are copied as-is.
#[test]
fn quote_ident_cross_dialect_foreign_quote_passes_through() {
let mysql = MySqlDialect {};
assert_eq!(quote_ident("test\"db", &mysql), "`test\"db`");
let pg = PostgreSqlDialect {};
assert_eq!(quote_ident("test`db", &pg), "\"test`db\"");
}
// quote_ident does not validate: an empty name yields an empty quoted identifier.
#[test]
fn quote_ident_empty_string() {
let mysql = MySqlDialect {};
assert_eq!(quote_ident("", &mysql), "``");
let pg = PostgreSqlDialect {};
assert_eq!(quote_ident("", &pg), "\"\"");
}
// Output length is input length plus the two wrapping quotes when nothing needs escaping.
#[test]
fn quote_ident_long_string_completes() {
let long_name: String = "a".repeat(10_000);
let pg = PostgreSqlDialect {};
let quoted = quote_ident(&long_name, &pg);
assert_eq!(quoted.len(), 10_002);
}
// Backslashes in literals are doubled (see the quote_literal implementation).
#[test]
fn quote_literal_trailing_backslash() {
assert_eq!(quote_literal("test\\"), "'test\\\\'");
}
#[test]
fn quote_literal_single_backslash() {
assert_eq!(quote_literal("\\"), "'\\\\'");
}
// Backslash and quote are escaped independently: `\'` -> `\\''`.
#[test]
fn quote_literal_backslash_then_quote() {
assert_eq!(quote_literal("\\'"), "'\\\\'''");
}
#[test]
fn quote_literal_only_backslashes() {
assert_eq!(quote_literal("\\\\\\"), "'\\\\\\\\\\\\'");
}
// The embedded quote is doubled, so the payload stays inside one literal.
#[test]
fn quote_literal_sql_injection_payload() {
assert_eq!(
quote_literal("Robert'; DROP TABLE students;--"),
"'Robert''; DROP TABLE students;--'"
);
}
// 1000 quotes -> 2000 escaped quotes + 2 wrapping quotes.
#[test]
fn quote_literal_many_quotes_completes() {
let input: String = "'".repeat(1_000);
let result = quote_literal(&input);
assert_eq!(result.len(), 2_002);
}
#[test]
fn quote_literal_backslash_and_quotes_mixed() {
assert_eq!(quote_literal("it\\'s"), "'it\\\\''s'");
}
#[test]
fn quote_literal_no_special_chars() {
assert_eq!(quote_literal("plain"), "'plain'");
}
// Non-ASCII characters are copied through unchanged.
#[test]
fn quote_literal_unicode_untouched() {
assert_eq!(quote_literal("café"), "'café'");
assert_eq!(quote_literal("数据"), "'数据'");
}
}