use std::fmt;
/// A SQL identifier (table, schema, column, ...) that quotes itself when displayed.
///
/// Formatting emits the wrapped string unchanged when it passes
/// [`is_valid_unquoted_identifier`] and contains no uppercase characters.
/// Otherwise the string is wrapped in double quotes, with every embedded `"`
/// doubled (`my"table` becomes `"my""table"`). An embedded quote therefore
/// cannot terminate the quoted identifier early.
///
/// Uppercase characters force quoting. This is presumably because the target
/// database folds unquoted identifiers to lower case (as PostgreSQL does).
/// Confirm this against the dialect in use.
///
/// NOTE(review): reserved words are not detected, so e.g. `table` or `select`
/// are emitted bare. If such words are used as identifiers in a dialect that
/// reserves them, the generated SQL will not parse.
///
/// Formatter options such as width and fill are ignored.
#[derive(Debug, Clone, Copy)]
pub struct SqlIdentifier<'a>(pub &'a str);

impl fmt::Display for SqlIdentifier<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ident = self.0;
        // `is_valid_unquoted_identifier` only admits ASCII, so the Unicode-aware
        // `is_uppercase` check effectively only matters for ASCII letters.
        let needs_quoting =
            !is_valid_unquoted_identifier(ident) || ident.chars().any(char::is_uppercase);
        if !needs_quoting {
            return f.write_str(ident);
        }
        f.write_str("\"")?;
        // Write the text between embedded quotes as whole slices instead of
        // formatting one `char` at a time; each `"` becomes `""`.
        for (idx, segment) in ident.split('"').enumerate() {
            if idx > 0 {
                f.write_str("\"\"")?;
            }
            f.write_str(segment)?;
        }
        f.write_str("\"")
    }
}
/// A SQL string literal that renders as a single-quoted, escaped token.
///
/// The text is wrapped in `'...'` and every embedded `'` is doubled
/// (`it's` becomes `'it''s'`). No other character is altered.
///
/// NOTE(review): backslashes are passed through unchanged. That is only safe
/// where backslash is not an escape character inside `'...'` literals, e.g.
/// PostgreSQL with `standard_conforming_strings = on`. MySQL's default SQL mode
/// does treat backslash as an escape character. Confirm the target dialect, and
/// prefer bind parameters for untrusted values.
///
/// Formatter options such as width and fill are ignored.
#[derive(Debug, Clone, Copy)]
pub struct SqlLiteral<'a>(pub &'a str);

impl fmt::Display for SqlLiteral<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("'")?;
        // Emit the runs between embedded quotes as whole slices rather than
        // formatting one `char` at a time; each `'` becomes `''`.
        for (idx, segment) in self.0.split('\'').enumerate() {
            if idx > 0 {
                f.write_str("''")?;
            }
            f.write_str(segment)?;
        }
        f.write_str("'")
    }
}
/// Reports whether `s` has the shape of an identifier that needs no quoting.
///
/// The accepted shape is one ASCII letter or `_`, followed by any number of
/// ASCII letters, ASCII digits, `_`, or `$`. The empty string and any string
/// containing a non-ASCII character are rejected.
///
/// Only the character shape is checked here. Letter case and reserved words are
/// not considered.
#[must_use]
pub fn is_valid_unquoted_identifier(s: &str) -> bool {
    // Every accepted character is ASCII, so checking raw bytes is equivalent to
    // checking chars: each byte of a multi-byte UTF-8 sequence is >= 0x80 and
    // fails both predicates below.
    match s.as_bytes().split_first() {
        None => false,
        Some((&lead, rest)) => {
            (lead == b'_' || lead.is_ascii_alphabetic())
                && rest
                    .iter()
                    .all(|&b| b == b'_' || b == b'$' || b.is_ascii_alphanumeric())
        }
    }
}
#[must_use]
pub fn format_table_name(database: Option<&str>, schema: Option<&str>, table: &str) -> String {
match (database, schema) {
(Some(db), Some(s)) => format!(
"{}.{}.{}",
SqlIdentifier(db),
SqlIdentifier(s),
SqlIdentifier(table)
),
(None, Some(s)) => format!("{}.{}", SqlIdentifier(s), SqlIdentifier(table)),
(Some(db), None) => format!("{}.{}", SqlIdentifier(db), SqlIdentifier(table)),
(None, None) => format!("{}", SqlIdentifier(table)),
}
}
/// Returns `identifier` as an owned, possibly quoted SQL identifier.
///
/// This is the `String` form of [`SqlIdentifier`]. The text is returned
/// unchanged when it passes [`is_valid_unquoted_identifier`] and has no
/// uppercase characters. Otherwise it is wrapped in double quotes, with each
/// embedded `"` doubled.
#[must_use]
pub fn escape_identifier(identifier: &str) -> String {
    SqlIdentifier(identifier).to_string()
}
/// Returns `literal` as an owned SQL string literal.
///
/// This is the `String` form of [`SqlLiteral`]. The text is wrapped in single
/// quotes and each embedded `'` is doubled; nothing else is escaped.
#[must_use]
pub fn escape_literal(literal: &str) -> String {
    SqlLiteral(literal).to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_identifier_display() {
        assert_eq!(format!("{}", SqlIdentifier("table")), "table");
        assert_eq!(format!("{}", SqlIdentifier("my_table")), "my_table");
        assert_eq!(format!("{}", SqlIdentifier("table1")), "table1");
        assert_eq!(format!("{}", SqlIdentifier("_private")), "_private");
        assert_eq!(format!("{}", SqlIdentifier("my$var")), "my$var");
        assert_eq!(format!("{}", SqlIdentifier("Segment")), "\"Segment\"");
        assert_eq!(format!("{}", SqlIdentifier("CustomerID")), "\"CustomerID\"");
        assert_eq!(format!("{}", SqlIdentifier("Table")), "\"Table\"");
        assert_eq!(format!("{}", SqlIdentifier("my-table")), "\"my-table\"");
        assert_eq!(format!("{}", SqlIdentifier("my table")), "\"my table\"");
        assert_eq!(format!("{}", SqlIdentifier("1table")), "\"1table\"");
        assert_eq!(format!("{}", SqlIdentifier("my\"table")), "\"my\"\"table\"");
        assert_eq!(format!("{}", SqlIdentifier("")), "\"\"");
    }

    /// Embedded quotes at the start, at the end, and back to back must each be
    /// doubled exactly once. Non-ASCII text, a leading `$`, and mixed case all
    /// force quoting.
    #[test]
    fn test_sql_identifier_quote_escaping_edge_cases() {
        assert_eq!(format!("{}", SqlIdentifier("\"")), "\"\"\"\"");
        assert_eq!(format!("{}", SqlIdentifier("a\"\"b")), "\"a\"\"\"\"b\"");
        assert_eq!(format!("{}", SqlIdentifier("\"x\"")), "\"\"\"x\"\"\"");
        assert_eq!(format!("{}", SqlIdentifier("café")), "\"café\"");
        assert_eq!(format!("{}", SqlIdentifier("$var")), "\"$var\"");
        assert_eq!(format!("{}", SqlIdentifier("my_Table")), "\"my_Table\"");
    }

    #[test]
    fn test_sql_literal_display() {
        assert_eq!(format!("{}", SqlLiteral("hello")), "'hello'");
        assert_eq!(format!("{}", SqlLiteral("it's")), "'it''s'");
        assert_eq!(format!("{}", SqlLiteral("")), "''");
    }

    /// Single quotes in any position are doubled; double quotes are left alone.
    #[test]
    fn test_sql_literal_quote_escaping_edge_cases() {
        assert_eq!(format!("{}", SqlLiteral("'")), "''''");
        assert_eq!(format!("{}", SqlLiteral("''")), "''''''");
        assert_eq!(format!("{}", SqlLiteral("'x'")), "'''x'''");
        assert_eq!(format!("{}", SqlLiteral("say \"hi\"")), "'say \"hi\"'");
    }

    #[test]
    fn test_is_valid_unquoted_identifier() {
        assert!(is_valid_unquoted_identifier("table"));
        assert!(is_valid_unquoted_identifier("_private"));
        assert!(is_valid_unquoted_identifier("table1"));
        assert!(is_valid_unquoted_identifier("my$var"));
        assert!(!is_valid_unquoted_identifier(""));
        assert!(!is_valid_unquoted_identifier("1table"));
        assert!(!is_valid_unquoted_identifier("my-table"));
        assert!(!is_valid_unquoted_identifier("my table"));
    }

    #[test]
    fn test_format_table_name() {
        assert_eq!(format_table_name(None, None, "users"), "users");
        assert_eq!(
            format_table_name(None, Some("public"), "users"),
            "public.users"
        );
        assert_eq!(
            format_table_name(Some("mydb"), Some("public"), "users"),
            "mydb.public.users"
        );
        assert_eq!(format_table_name(None, None, "my-table"), "\"my-table\"");
        assert_eq!(
            format_table_name(None, Some("my schema"), "users"),
            "\"my schema\".users"
        );
    }

    /// Covers the one `format_table_name` branch the test above misses.
    /// It documents the current behaviour: a database with no schema becomes a
    /// two-part prefix. It also checks that every component is escaped on its own.
    #[test]
    fn test_format_table_name_remaining_branches() {
        assert_eq!(format_table_name(Some("mydb"), None, "users"), "mydb.users");
        assert_eq!(
            format_table_name(Some("My DB"), Some("Sales"), "Order Items"),
            "\"My DB\".\"Sales\".\"Order Items\""
        );
    }

    #[test]
    fn test_sql_identifier_in_format() {
        let table = "users";
        let column = "Customer ID";
        let sql = format!(
            "SELECT {} FROM {}",
            SqlIdentifier(column),
            SqlIdentifier(table)
        );
        assert_eq!(sql, "SELECT \"Customer ID\" FROM users");
    }

    #[test]
    fn test_escape_identifier() {
        assert_eq!(escape_identifier("table"), "table");
        assert_eq!(escape_identifier("Segment"), "\"Segment\"");
    }

    #[test]
    fn test_escape_literal() {
        assert_eq!(escape_literal("hello"), "'hello'");
        assert_eq!(escape_literal("it's"), "'it''s'");
    }
}