use std::fmt;
use std::hash::{Hash, Hasher};
use std::str::FromStr;
use smallvec::SmallVec;
use crate::error::{Error, Result};
/// Maximum length of a PostgreSQL identifier, measured in **bytes**.
///
/// PostgreSQL keeps at most `NAMEDATALEN - 1` (63 by default) bytes of an identifier and
/// truncates anything longer.
pub(crate) const PG_IDENTIFIER_LIMIT: usize = 63;

/// Quotes `name` as a SQL delimited identifier (`"..."`), doubling every embedded `"`.
///
/// The empty string is accepted here and yields `""`. Callers that need a non-empty
/// identifier (such as `Name::try_new`) must check that themselves.
///
/// # Errors
///
/// Returns `Error::InvalidName` when `name` is longer than `PG_IDENTIFIER_LIMIT` bytes.
/// The limit is checked on the UTF-8 byte length rather than the character count, because
/// that is what the server limits. A name of 63 multi-byte characters would otherwise pass
/// this check and then be truncated server-side, possibly colliding with another name.
pub fn escape_name(name: &str) -> Result<String> {
    let len = name.len();
    if len > PG_IDENTIFIER_LIMIT {
        return Err(Error::InvalidName(format!(
            "Name exceeds PostgreSQL identifier limit ({len} bytes > {PG_IDENTIFIER_LIMIT} bytes)"
        )));
    }
    let escaped_inner = name.replace('"', "\"\"");
    Ok(format!("\"{escaped_inner}\""))
}
/// Wraps `path` in double quotes for embedding in SQL text (for example a database
/// file path such as `/tmp/data.hyper`), doubling every `"` that appears inside it.
///
/// Unlike `escape_name`, no length limit is enforced and the empty string is accepted.
#[must_use]
pub fn escape_sql_path(path: &str) -> String {
    let mut quoted = String::with_capacity(path.len() + 2);
    quoted.push('"');
    for ch in path.chars() {
        if ch == '"' {
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
/// Renders `value` as a SQL string literal: surrounds it with single quotes and
/// doubles every embedded `'`.
///
/// Backslashes are copied through unchanged. NOTE(review): this is only correct when the
/// server treats `\` literally in ordinary string literals (`standard_conforming_strings`
/// on) — confirm for the target server.
#[must_use]
pub fn escape_string_literal(value: &str) -> String {
    let doubled = value.split('\'').collect::<Vec<&str>>().join("''");
    format!("'{doubled}'")
}
/// A single validated SQL identifier, kept both as supplied and in quoted form.
///
/// Construction rejects empty names and any name `escape_name` rejects (e.g. too long).
#[derive(Clone, Debug)]
#[must_use = "Name represents a validated SQL identifier that should not be discarded. Use it in your SQL queries or table definitions"]
pub struct Name {
    /// Quoted form produced by `escape_name`, ready to splice into SQL text.
    escaped: String,
    /// The identifier exactly as the caller supplied it.
    unescaped: String,
}

impl Name {
    /// Validates `name` and builds a [`Name`] from it.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidName` if `name` is empty, and otherwise propagates any error
    /// from `escape_name` (such as exceeding the identifier length limit).
    pub fn try_new(name: impl Into<String>) -> Result<Self> {
        let raw: String = name.into();
        if raw.is_empty() {
            return Err(Error::InvalidName("Name must not be empty".into()));
        }
        escape_name(&raw).map(|escaped| Name {
            escaped,
            unescaped: raw,
        })
    }

    /// Returns the quoted identifier, suitable for direct use in SQL text.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.escaped.as_str()
    }

    /// Returns the identifier as originally supplied, without quoting.
    #[must_use]
    pub fn unescaped(&self) -> &str {
        self.unescaped.as_str()
    }
}
/// Splits a possibly-qualified SQL identifier such as `db.schema."my table"` into its
/// unescaped components.
///
/// Parsing rules:
/// * Components are separated by `.` outside double quotes.
/// * Inside double quotes, `.` is ordinary text and `""` stands for one literal `"`.
/// * Unquoted text is kept verbatim: there is no case folding and no whitespace trimming.
///
/// Malformed input yields an **empty** result. Malformed means any of:
/// * an empty string;
/// * an empty component (`a..b`, `.a`, `a.`, or a zero-length quoted `""`);
/// * an unterminated quoted section (`"abc`).
///
/// Callers in this module match the result against fixed component counts, so an empty
/// result is reported as an invalid identifier. The input is never silently collapsed
/// into a different, shorter qualified name.
fn parse_qualified_identifier(s: &str) -> SmallVec<[String; 3]> {
    let mut parts = SmallVec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    let mut chars = s.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '"' => {
                // `&&` short-circuits, so the following quote is only consumed when we are
                // inside a quoted section and it forms the escape sequence `""`.
                if in_quotes && chars.next_if_eq(&'"').is_some() {
                    current.push('"');
                } else {
                    in_quotes = !in_quotes;
                }
            }
            '.' if !in_quotes => {
                if current.is_empty() {
                    // Leading dot, doubled dot, or a zero-length quoted component.
                    return SmallVec::new();
                }
                parts.push(std::mem::take(&mut current));
            }
            _ => current.push(c),
        }
    }
    if in_quotes || current.is_empty() {
        // Unterminated quote, trailing dot, or empty input.
        return SmallVec::new();
    }
    parts.push(current);
    parts
}
/// Formats the name in its quoted SQL form (e.g. `"users"`), so it can be interpolated
/// directly into SQL text. Width/fill flags on the formatter are not applied.
impl fmt::Display for Name {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.escaped)
    }
}
/// Two names are equal when their unescaped text is equal. The quoted form is computed from
/// that text in `Name::try_new`, so comparing it as well would be redundant.
impl PartialEq for Name {
    fn eq(&self, other: &Self) -> bool {
        self.unescaped() == other.unescaped()
    }
}

// String equality is a total equivalence relation, so `Eq` holds.
impl Eq for Name {}
/// Delegates to [`Ord`] so that `partial_cmp` and `cmp` can never disagree.
impl PartialOrd for Name {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Orders names lexicographically by the bytes of their unescaped text.
impl Ord for Name {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.unescaped().cmp(other.unescaped())
    }
}
impl Hash for Name {
fn hash<H: Hasher>(&self, state: &mut H) {
self.unescaped.hash(state);
}
}
impl TryFrom<&str> for Name {
type Error = Error;
fn try_from(s: &str) -> Result<Self> {
Self::try_new(s)
}
}
impl TryFrom<&String> for Name {
type Error = Error;
fn try_from(s: &String) -> Result<Self> {
Self::try_new(s.as_str())
}
}
impl TryFrom<String> for Name {
type Error = Error;
fn try_from(s: String) -> Result<Self> {
Self::try_new(s)
}
}
impl FromStr for Name {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::try_new(s)
}
}
/// A validated database identifier: a thin wrapper around a single [`Name`].
///
/// Equality, ordering and hashing are derived, i.e. they follow the wrapped [`Name`].
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[must_use = "DatabaseName represents a validated database identifier that should not be discarded. Use it in your connection or table definitions"]
pub struct DatabaseName {
    name: Name,
}

impl DatabaseName {
    /// Validates `name` and wraps it as a database name.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Name::try_new`] for an invalid (empty or too long) name.
    pub fn try_new(name: impl Into<String>) -> Result<Self> {
        Name::try_new(name).map(|name| DatabaseName { name })
    }

    /// Returns the underlying validated [`Name`].
    pub fn name(&self) -> &Name {
        &self.name
    }

    /// Returns the database name as supplied, without quoting.
    #[must_use]
    pub fn unescaped(&self) -> &str {
        Name::unescaped(&self.name)
    }
}
/// Formats as the quoted identifier, exactly like the wrapped [`Name`].
impl fmt::Display for DatabaseName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.name, f)
    }
}
impl TryFrom<&str> for DatabaseName {
type Error = Error;
fn try_from(s: &str) -> Result<Self> {
Self::try_new(s)
}
}
impl TryFrom<&String> for DatabaseName {
type Error = Error;
fn try_from(s: &String) -> Result<Self> {
Self::try_new(s.as_str())
}
}
impl TryFrom<String> for DatabaseName {
type Error = Error;
fn try_from(s: String) -> Result<Self> {
Self::try_new(s)
}
}
impl From<Name> for DatabaseName {
fn from(name: Name) -> Self {
DatabaseName { name }
}
}
impl FromStr for DatabaseName {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::try_new(s)
}
}
/// A validated schema identifier, optionally qualified by a database name.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[must_use = "SchemaName represents a validated schema identifier that should not be discarded. Use it in your table definitions or queries"]
pub struct SchemaName {
    database: Option<DatabaseName>,
    schema: Name,
}

impl SchemaName {
    /// Builds an unqualified schema name.
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`Name::try_new`].
    pub fn try_new(schema: impl Into<String>) -> Result<Self> {
        Name::try_new(schema).map(|schema| SchemaName {
            database: None,
            schema,
        })
    }

    /// Qualifies this schema with `database`, replacing any database set earlier.
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`DatabaseName::try_new`]; `self` is consumed
    /// either way.
    pub fn with_database(self, database: impl Into<String>) -> Result<Self> {
        let database = DatabaseName::try_new(database)?;
        Ok(SchemaName {
            database: Some(database),
            ..self
        })
    }

    /// Returns the database qualifier, if one was set.
    #[must_use]
    pub fn database(&self) -> Option<&DatabaseName> {
        self.database.as_ref()
    }

    /// Returns the schema component.
    pub fn schema(&self) -> &Name {
        &self.schema
    }

    /// Returns the schema name as supplied, without quoting or database qualifier.
    #[must_use]
    pub fn unescaped(&self) -> &str {
        self.schema().unescaped()
    }
}
/// Formats as `"db"."schema"` when a database is set, otherwise as `"schema"`.
impl fmt::Display for SchemaName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.database {
            Some(db) => write!(f, "{db}.{}", self.schema),
            None => write!(f, "{}", self.schema),
        }
    }
}
/// Parses `schema` or `db.schema`; see the `FromStr` implementation.
impl TryFrom<&str> for SchemaName {
    type Error = Error;

    fn try_from(value: &str) -> Result<Self> {
        value.parse::<Self>()
    }
}

/// Parses a borrowed `String`; delegates to the `&str` conversion.
impl TryFrom<&String> for SchemaName {
    type Error = Error;

    fn try_from(value: &String) -> Result<Self> {
        Self::try_from(value.as_str())
    }
}

/// Parses an owned `String`; delegates to the `&str` conversion.
impl TryFrom<String> for SchemaName {
    type Error = Error;

    fn try_from(value: String) -> Result<Self> {
        Self::try_from(value.as_str())
    }
}

/// Wraps an already-validated [`Name`] as an unqualified schema name.
impl From<Name> for SchemaName {
    fn from(schema: Name) -> Self {
        Self {
            database: None,
            schema,
        }
    }
}
/// Parses `schema` or `db.schema`. Each part may be double-quoted, with `""` standing for a
/// literal quote inside a quoted part.
///
/// # Errors
///
/// Returns `Error::InvalidName` when the input does not split into one or two components.
/// Otherwise propagates validation errors for the individual parts.
impl FromStr for SchemaName {
    type Err = Error;

    fn from_str(s: &str) -> Result<Self> {
        match parse_qualified_identifier(s).as_slice() {
            [schema] => SchemaName::try_new(schema),
            [database, schema] => SchemaName::try_new(schema)?.with_database(database),
            _ => Err(Error::InvalidName(format!("Invalid SQL identifier: {s}"))),
        }
    }
}
/// A validated table identifier with optional schema and database qualifiers.
///
/// Qualifiers are attached with [`TableName::with_schema`] and [`TableName::with_database`]
/// in any order. Each call replaces a previously set value.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[must_use = "TableName represents a validated table identifier that should not be discarded. Use it in your queries or table operations"]
pub struct TableName {
    database: Option<DatabaseName>,
    schema: Option<Name>,
    table: Name,
}

impl TableName {
    /// Builds an unqualified table name.
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`Name::try_new`].
    pub fn try_new(table: impl Into<String>) -> Result<Self> {
        let table = Name::try_new(table)?;
        Ok(TableName {
            database: None,
            schema: None,
            table,
        })
    }

    /// Sets (or replaces) the schema qualifier.
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`Name::try_new`]; `self` is consumed either way.
    pub fn with_schema(self, schema: impl Into<String>) -> Result<Self> {
        let schema = Name::try_new(schema)?;
        Ok(TableName {
            schema: Some(schema),
            ..self
        })
    }

    /// Sets (or replaces) the database qualifier.
    ///
    /// NOTE(review): with a database but no schema, `Display` renders `"db"."table"`, which
    /// SQL reads as a schema-qualified name. Presumably callers always also set a schema;
    /// confirm.
    ///
    /// # Errors
    ///
    /// Propagates validation errors from [`DatabaseName::try_new`]; `self` is consumed
    /// either way.
    pub fn with_database(self, database: impl Into<String>) -> Result<Self> {
        let database = DatabaseName::try_new(database)?;
        Ok(TableName {
            database: Some(database),
            ..self
        })
    }

    /// Returns the database qualifier, if one was set.
    #[must_use]
    pub fn database(&self) -> Option<&DatabaseName> {
        self.database.as_ref()
    }

    /// Returns the schema qualifier, if one was set.
    #[must_use]
    pub fn schema(&self) -> Option<&Name> {
        self.schema.as_ref()
    }

    /// Returns the table component.
    pub fn table(&self) -> &Name {
        &self.table
    }

    /// Returns the bare table name as supplied, without quoting or qualifiers.
    #[must_use]
    pub fn unescaped(&self) -> &str {
        self.table().unescaped()
    }
}
/// Formats as the dot-joined quoted parts: `"db"."schema"."table"`. Qualifiers that are not
/// set are omitted.
impl fmt::Display for TableName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let qualifiers: [Option<&dyn fmt::Display>; 2] = [
            self.database.as_ref().map(|db| db as &dyn fmt::Display),
            self.schema.as_ref().map(|schema| schema as &dyn fmt::Display),
        ];
        for qualifier in qualifiers.into_iter().flatten() {
            write!(f, "{qualifier}.")?;
        }
        write!(f, "{}", self.table)
    }
}
/// Parses `table`, `schema.table` or `db.schema.table`; see the `FromStr` implementation.
impl TryFrom<&str> for TableName {
    type Error = Error;

    fn try_from(value: &str) -> Result<Self> {
        value.parse::<Self>()
    }
}

/// Parses a borrowed `String`; delegates to the `&str` conversion.
impl TryFrom<&String> for TableName {
    type Error = Error;

    fn try_from(value: &String) -> Result<Self> {
        Self::try_from(value.as_str())
    }
}

/// Parses an owned `String`; delegates to the `&str` conversion.
impl TryFrom<String> for TableName {
    type Error = Error;

    fn try_from(value: String) -> Result<Self> {
        Self::try_from(value.as_str())
    }
}

/// Wraps an already-validated [`Name`] as an unqualified table name.
impl From<Name> for TableName {
    fn from(table: Name) -> Self {
        Self {
            database: None,
            schema: None,
            table,
        }
    }
}
impl FromStr for TableName {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
let mut parts = Vec::new();
let mut current = String::new();
let mut in_quotes = false;
let mut chars = s.chars().peekable();
while let Some(c) = chars.next() {
match c {
'"' => {
if in_quotes && chars.peek() == Some(&'"') {
current.push('"');
chars.next(); } else {
in_quotes = !in_quotes;
}
}
'.' if !in_quotes => {
if !current.is_empty() {
parts.push(current.split_off(0));
}
}
_ => current.push(c),
}
}
if !current.is_empty() {
parts.push(current);
}
match parts.as_slice() {
[t] => TableName::try_new(t),
[s, t] => TableName::try_new(t)?.with_schema(s),
[d, s, t] => TableName::try_new(t)?.with_schema(s)?.with_database(d),
_ => Err(Error::InvalidName(format!("Invalid SQL identifier: {s}"))),
}
}
}
/// Builds a `TableName` from one, two or three parts, table last:
/// `table_name!(table)`, `table_name!(schema, table)`, `table_name!(db, schema, table)`.
///
/// Expands to an expression of type `Result<TableName, Error>`. Every validation error, for
/// any part, is returned in that `Result`. No `?` is hidden inside the expansion, so the
/// macro works in any function, including ones whose error type does not convert from this
/// crate's `Error`. Append `?` at the call site to propagate.
///
/// Parts are validated table first, then schema, then database. A later argument expression
/// is not evaluated once an earlier part has failed.
#[macro_export]
macro_rules! table_name {
    ($db:expr, $schema:expr, $table:expr) => {
        match $crate::TableName::try_new($table) {
            ::core::result::Result::Ok(table) => match table.with_schema($schema) {
                ::core::result::Result::Ok(table) => table.with_database($db),
                ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
            },
            ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
        }
    };
    ($schema:expr, $table:expr) => {
        match $crate::TableName::try_new($table) {
            ::core::result::Result::Ok(table) => table.with_schema($schema),
            ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
        }
    };
    ($table:expr) => {
        $crate::TableName::try_new($table)
    };
}
/// Builds a `SchemaName` from one or two parts, schema last:
/// `schema_name!(schema)` or `schema_name!(db, schema)`.
///
/// Expands to an expression of type `Result<SchemaName, Error>`. Every validation error is
/// returned in that `Result`. No `?` is hidden inside the expansion, so the macro works in
/// any function. Append `?` at the call site to propagate.
///
/// The database argument is not evaluated if the schema part fails validation.
#[macro_export]
macro_rules! schema_name {
    ($db:expr, $schema:expr) => {
        match $crate::SchemaName::try_new($schema) {
            ::core::result::Result::Ok(schema) => schema.with_database($db),
            ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
        }
    };
    ($schema:expr) => {
        $crate::SchemaName::try_new($schema)
    };
}
// Unit tests for the escaping helpers and for Name / DatabaseName / SchemaName /
// TableName, including their string parsers, TryFrom conversions and macros.
#[cfg(test)]
mod tests {
use super::*;
// escape_name quotes and doubles embedded `"`; the empty string is accepted at this
// level (it is Name::try_new that rejects empty names).
#[test]
fn test_escape_name() {
assert_eq!(escape_name("table").unwrap(), "\"table\"");
assert_eq!(escape_name("my_table").unwrap(), "\"my_table\"");
assert_eq!(escape_name("table\"quote").unwrap(), "\"table\"\"quote\"");
assert_eq!(escape_name("").unwrap(), "\"\"");
}
// A name of exactly PG_IDENTIFIER_LIMIT is accepted and one longer is rejected.
// The inputs are ASCII, so byte and character counts coincide here.
#[test]
fn test_escape_name_too_long() {
let max_name = "a".repeat(PG_IDENTIFIER_LIMIT);
assert!(escape_name(&max_name).is_ok());
let too_long = "a".repeat(PG_IDENTIFIER_LIMIT + 1);
let err = escape_name(&too_long).unwrap_err();
assert!(err.to_string().contains("identifier limit"));
}
// escape_sql_path applies the same quote-doubling but no length limit.
#[test]
fn test_escape_sql_path() {
assert_eq!(escape_sql_path("/tmp/data.hyper"), "\"/tmp/data.hyper\"");
assert_eq!(
escape_sql_path("/tmp/my \"db\".hyper"),
"\"/tmp/my \"\"db\"\".hyper\""
);
assert_eq!(escape_sql_path(""), "\"\"");
let long_path = format!("/very/long/path/{}.hyper", "a".repeat(100));
let escaped = escape_sql_path(&long_path);
assert!(escaped.starts_with('"'));
assert!(escaped.ends_with('"'));
}
// String literals use single quotes with `'` doubled.
#[test]
fn test_escape_string_literal() {
assert_eq!(escape_string_literal("hello"), "'hello'");
assert_eq!(escape_string_literal("it's"), "'it''s'");
assert_eq!(escape_string_literal(""), "''");
}
// Display yields the quoted form; unescaped() yields the raw input.
#[test]
fn test_name() {
let name = Name::try_new("users").unwrap();
assert_eq!(name.to_string(), "\"users\"");
assert_eq!(name.unescaped(), "users");
assert!(!name.unescaped().is_empty());
}
#[test]
fn test_name_with_quotes() {
let name = Name::try_new("table\"name").unwrap();
assert_eq!(name.to_string(), "\"table\"\"name\"");
assert_eq!(name.unescaped(), "table\"name");
}
#[test]
fn test_database_name() {
let db = DatabaseName::try_new("mydb").unwrap();
assert_eq!(db.to_string(), "\"mydb\"");
assert_eq!(db.unescaped(), "mydb");
}
// Builder-style qualification: with_database prefixes the quoted database name.
#[test]
fn test_schema_name() {
let schema = SchemaName::try_new("public").unwrap();
assert_eq!(schema.to_string(), "\"public\"");
let qualified = SchemaName::try_new("public")
.unwrap()
.with_database("mydb")
.unwrap();
assert_eq!(qualified.to_string(), "\"mydb\".\"public\"");
}
#[test]
fn test_table_name() {
let simple = TableName::try_new("users").unwrap();
assert_eq!(simple.to_string(), "\"users\"");
let with_schema = TableName::try_new("users")
.unwrap()
.with_schema("public")
.unwrap();
assert_eq!(with_schema.to_string(), "\"public\".\"users\"");
let full = TableName::try_new("users")
.unwrap()
.with_schema("public")
.unwrap()
.with_database("mydb")
.unwrap();
assert_eq!(full.to_string(), "\"mydb\".\"public\".\"users\"");
}
#[test]
fn test_name_equality() {
let name1 = Name::try_new("test").unwrap();
let name2 = Name::try_new("test").unwrap();
let name3 = Name::try_new("other").unwrap();
assert_eq!(name1, name2);
assert_ne!(name1, name3);
}
// SchemaName parsing: one or two dot-separated parts, quoted parts with `""` escapes;
// three parts and the empty string are rejected.
#[test]
fn test_schema_name_from_str() {
let schema: SchemaName = "public".parse().unwrap();
assert_eq!(schema.to_string(), "\"public\"");
assert_eq!(schema.unescaped(), "public");
let qualified: SchemaName = "mydb.public".parse().unwrap();
assert_eq!(qualified.to_string(), "\"mydb\".\"public\"");
assert_eq!(qualified.unescaped(), "public");
assert_eq!(qualified.database().unwrap().unescaped(), "mydb");
let quoted: SchemaName = "\"my db\".\"my schema\"".parse().unwrap();
assert_eq!(quoted.to_string(), "\"my db\".\"my schema\"");
assert_eq!(quoted.unescaped(), "my schema");
let escaped: SchemaName = "\"schema\"\"name\"".parse().unwrap();
assert_eq!(escaped.to_string(), "\"schema\"\"name\"");
assert_eq!(escaped.unescaped(), "schema\"name");
assert!("db.schema.table".parse::<SchemaName>().is_err());
assert!("".parse::<SchemaName>().is_err());
}
// TableName parsing: one to three parts; dots inside quotes are part of the name;
// four parts and the empty string are rejected.
#[test]
fn test_table_name_from_str() {
let table: TableName = "users".parse().unwrap();
assert_eq!(table.to_string(), "\"users\"");
assert_eq!(table.unescaped(), "users");
let with_schema: TableName = "public.users".parse().unwrap();
assert_eq!(with_schema.to_string(), "\"public\".\"users\"");
assert_eq!(with_schema.unescaped(), "users");
assert_eq!(with_schema.schema().unwrap().unescaped(), "public");
let full: TableName = "mydb.public.users".parse().unwrap();
assert_eq!(full.to_string(), "\"mydb\".\"public\".\"users\"");
assert_eq!(full.unescaped(), "users");
assert_eq!(full.schema().unwrap().unescaped(), "public");
assert_eq!(full.database().unwrap().unescaped(), "mydb");
let quoted: TableName = "\"my db\".\"my schema\".\"my table\"".parse().unwrap();
assert_eq!(quoted.to_string(), "\"my db\".\"my schema\".\"my table\"");
assert_eq!(quoted.unescaped(), "my table");
let escaped: TableName = "\"table\"\"name\"".parse().unwrap();
assert_eq!(escaped.to_string(), "\"table\"\"name\"");
assert_eq!(escaped.unescaped(), "table\"name");
let with_dots: TableName = "\"schema.name\".\"table.name\"".parse().unwrap();
assert_eq!(with_dots.to_string(), "\"schema.name\".\"table.name\"");
assert_eq!(with_dots.schema().unwrap().unescaped(), "schema.name");
assert_eq!(with_dots.unescaped(), "table.name");
assert!("db.schema.table.extra".parse::<TableName>().is_err());
assert!("".parse::<TableName>().is_err());
}
// The macros are used with a trailing `?`, so these tests return Result.
#[test]
fn test_schema_name_macro() -> Result<()> {
let schema = schema_name!("public")?;
assert_eq!(schema.to_string(), "\"public\"");
assert_eq!(schema.unescaped(), "public");
let qualified = schema_name!("mydb", "public")?;
assert_eq!(qualified.to_string(), "\"mydb\".\"public\"");
assert_eq!(qualified.unescaped(), "public");
assert_eq!(qualified.database().unwrap().unescaped(), "mydb");
Ok(())
}
#[test]
fn test_table_name_macro() -> Result<()> {
let table = table_name!("users")?;
assert_eq!(table.to_string(), "\"users\"");
assert_eq!(table.unescaped(), "users");
let with_schema = table_name!("public", "users")?;
assert_eq!(with_schema.to_string(), "\"public\".\"users\"");
assert_eq!(with_schema.unescaped(), "users");
assert_eq!(with_schema.schema().unwrap().unescaped(), "public");
let full = table_name!("mydb", "public", "users")?;
assert_eq!(full.to_string(), "\"mydb\".\"public\".\"users\"");
assert_eq!(full.unescaped(), "users");
assert_eq!(full.schema().unwrap().unescaped(), "public");
assert_eq!(full.database().unwrap().unescaped(), "mydb");
Ok(())
}
// TryFrom<&str> / TryFrom<String> go through the same parser as FromStr.
#[test]
fn test_schema_name_try_from() {
let schema: SchemaName = "public".try_into().unwrap();
assert_eq!(schema.to_string(), "\"public\"");
assert_eq!(schema.unescaped(), "public");
let qualified: SchemaName = "mydb.public".try_into().unwrap();
assert_eq!(qualified.to_string(), "\"mydb\".\"public\"");
assert_eq!(qualified.unescaped(), "public");
assert_eq!(qualified.database().unwrap().unescaped(), "mydb");
let schema_string: SchemaName = String::from("public").try_into().unwrap();
assert_eq!(schema_string.to_string(), "\"public\"");
}
#[test]
fn test_table_name_try_from() {
let table: TableName = "users".try_into().unwrap();
assert_eq!(table.to_string(), "\"users\"");
assert_eq!(table.unescaped(), "users");
let with_schema: TableName = "public.users".try_into().unwrap();
assert_eq!(with_schema.to_string(), "\"public\".\"users\"");
assert_eq!(with_schema.unescaped(), "users");
assert_eq!(with_schema.schema().unwrap().unescaped(), "public");
let full: TableName = "mydb.public.users".try_into().unwrap();
assert_eq!(full.to_string(), "\"mydb\".\"public\".\"users\"");
assert_eq!(full.unescaped(), "users");
assert_eq!(full.schema().unwrap().unescaped(), "public");
assert_eq!(full.database().unwrap().unescaped(), "mydb");
let table_string: TableName = String::from("users").try_into().unwrap();
assert_eq!(table_string.to_string(), "\"users\"");
let invalid: std::result::Result<TableName, _> = "db.schema.table.extra".try_into();
assert!(invalid.is_err());
}
}