use std::borrow::Cow;
use core::fmt::{Display, Formatter};
use core::str::FromStr;
/// Longest identifier accepted, measured in bytes (the UTF-8 length returned by
/// `str::len`), not in characters.
///
/// NOTE(review): 63 presumably mirrors PostgreSQL's default `NAMEDATALEN - 1`;
/// confirm before changing it.
pub const MAX_LENGTH: usize = 63;

/// Checks `input` against the identifier rules and reports the first violation.
///
/// The rules are tested in a fixed order: empty first, then longer than
/// [`MAX_LENGTH`] bytes, then containing a NUL byte. An over-long string that
/// also contains NUL is therefore reported as [`ParseError::TooLong`]. No other
/// restriction applies: spaces, punctuation, leading digits and non-ASCII text
/// are all accepted.
///
/// This is a `const fn` so `Identifier::from_static_or_panic` can run it during
/// constant evaluation. That is why the byte scan peels the slice with a
/// pattern rather than using iterator adaptors, which are unavailable in
/// `const` code.
///
/// Returns `None` when `input` is a valid identifier.
const fn validate(input: &str) -> Option<ParseError> {
    let mut remaining = input.as_bytes();
    if remaining.is_empty() {
        return Some(ParseError::Empty);
    }
    if remaining.len() > MAX_LENGTH {
        return Some(ParseError::TooLong);
    }
    // Take one byte off the front per iteration until the slice is exhausted.
    while let [byte, rest @ ..] = remaining {
        if *byte == 0 {
            return Some(ParseError::ContainsNul);
        }
        remaining = rest;
    }
    None
}
/// Validated identifier text shared by every typed name wrapper in this module.
///
/// The text is either borrowed for `'static` (built by
/// `Identifier::from_static_or_panic`) or owned (built by parsing). Equality and
/// ordering compare the text itself. Deserialization goes through
/// `TryFrom<String>`, so serialized input is validated the same way as parsed
/// input.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    serde::Serialize,
    serde::Deserialize,
)]
#[serde(try_from = "String")]
struct Identifier(Cow<'static, str>);
impl TryFrom<String> for Identifier {
type Error = ParseError;
fn try_from(value: String) -> Result<Self, Self::Error> {
value.parse()
}
}
impl Identifier {
#[must_use]
const fn from_static_or_panic(input: &'static str) -> Self {
match validate(input) {
Some(error) => panic!("{}", error.message()),
None => Self(Cow::Borrowed(input)),
}
}
#[must_use]
fn as_str(&self) -> &str {
&self.0
}
}
impl Display for Identifier {
    /// Writes the identifier text exactly as stored, with no quoting.
    ///
    /// Uses `Formatter::pad` so the caller's width, fill, alignment and
    /// precision flags take effect (`{:>12}` right-aligns, `{:.3}` truncates).
    /// A nested `write!(formatter, "{}", ..)` starts a fresh format spec and
    /// silently drops those flags.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
        formatter.pad(&self.0)
    }
}
impl AsRef<str> for Identifier {
fn as_ref(&self) -> &str {
&self.0
}
}
impl FromStr for Identifier {
type Err = ParseError;
fn from_str(input: &str) -> Result<Self, Self::Err> {
match validate(input) {
Some(error) => Err(error),
None => Ok(Self(Cow::Owned(input.to_owned()))),
}
}
}
/// Reason an identifier was rejected by validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ParseError {
    /// The input was the empty string.
    Empty,
    /// The input was longer than [`MAX_LENGTH`] bytes.
    TooLong,
    /// The input contained a NUL (`\0`) byte.
    ContainsNul,
}

impl ParseError {
    /// Returns a fixed, human-readable description of the error.
    ///
    /// This is a `const fn` so it can be called from `const` code; the
    /// identifier constructors use it as their panic message.
    #[must_use]
    pub const fn message(&self) -> &'static str {
        match *self {
            Self::Empty => "identifier cannot be empty",
            Self::TooLong => "identifier exceeds maximum length",
            Self::ContainsNul => "identifier cannot contain NUL bytes",
        }
    }
}

impl Display for ParseError {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
        formatter.write_str(self.message())
    }
}

/// `Display` provides the message; there is no underlying source error.
impl std::error::Error for ParseError {}
/// Declares a public, validated newtype over `Identifier` for one kind of
/// database object, plus a unit-test module named `$test_mod`.
///
/// Each generated type gets `FromStr`, `TryFrom<String>` (also used by serde
/// deserialization), `Display`, `AsRef<str>`, `as_str` and a `const`
/// constructor `from_static_or_panic`. Attributes written before the type
/// name, including `///` doc comments, are forwarded to the struct.
macro_rules! define_identifier_type {
    ($(#[$meta:meta])* $name:ident, $test_mod:ident) => {
        $(#[$meta])*
        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
        #[serde(try_from = "String")]
        pub struct $name(Identifier);

        impl TryFrom<String> for $name {
            type Error = ParseError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                // Hand the owned string straight to `Identifier`'s owned
                // conversion instead of re-parsing it as a borrowed `&str`.
                Identifier::try_from(value).map(Self)
            }
        }

        impl $name {
            /// Builds a value from a string literal, validating it like `parse`.
            ///
            /// # Panics
            ///
            /// Panics with the [`ParseError::message`] text when `input` is
            /// invalid. Inside a `const` initializer, that panic is reported as
            /// a compile error.
            #[must_use]
            pub const fn from_static_or_panic(input: &'static str) -> Self {
                Self(Identifier::from_static_or_panic(input))
            }

            /// Returns the name as a borrowed string slice.
            #[must_use]
            pub fn as_str(&self) -> &str {
                self.0.as_str()
            }
        }

        impl Display for $name {
            fn fmt(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
                // Pass the caller's formatter through, not a fresh `write!`,
                // so width/fill/precision flags reach the inner identifier.
                Display::fmt(&self.0, formatter)
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.0.as_ref()
            }
        }

        impl FromStr for $name {
            type Err = ParseError;

            fn from_str(input: &str) -> Result<Self, Self::Err> {
                Identifier::from_str(input).map(Self)
            }
        }

        #[cfg(test)]
        mod $test_mod {
            use super::*;

            #[test]
            fn parse_valid() {
                let value: $name = "test".parse().unwrap();
                assert_eq!(value.to_string(), "test");
            }

            #[test]
            fn parse_valid_with_space() {
                let value: $name = "test value".parse().unwrap();
                assert_eq!(value.to_string(), "test value");
            }

            #[test]
            fn parse_empty_fails() {
                let result: Result<$name, _> = "".parse();
                assert!(matches!(result, Err(ParseError::Empty)));
            }

            #[test]
            fn parse_contains_nul_fails() {
                let result: Result<$name, _> = "test\0value".parse();
                assert!(matches!(result, Err(ParseError::ContainsNul)));
            }

            #[test]
            fn parse_too_long_fails() {
                let input = "a".repeat(MAX_LENGTH + 1);
                let result: Result<$name, _> = input.parse();
                assert!(matches!(result, Err(ParseError::TooLong)));
            }

            #[test]
            fn try_from_string_matches_parse() {
                let value = $name::try_from(String::from("test value")).unwrap();
                assert_eq!(value.as_str(), "test value");
                assert!(matches!($name::try_from(String::new()), Err(ParseError::Empty)));
            }

            #[test]
            fn from_static_or_panic_valid() {
                const VALUE: $name = $name::from_static_or_panic("test");
                assert_eq!(VALUE.as_str(), "test");
            }
        }
    };
}
define_identifier_type!(
    /// Name of a table.
    Table,
    table
);

define_identifier_type!(
    /// Name of a schema.
    Schema,
    schema
);

impl Schema {
    /// The schema named `public`.
    pub const PUBLIC: Self = Self::from_static_or_panic("public");
}
/// A table name paired with the schema that contains it.
///
/// `Display` renders `schema.table` with both parts written verbatim. Nothing
/// is quoted or escaped, and identifiers may themselves contain `.` or spaces,
/// so the rendered text is not guaranteed to be unambiguous.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)]
pub struct QualifiedTable {
    /// Schema containing the table.
    pub schema: Schema,
    /// Unqualified table name.
    pub table: Table,
}

impl Display for QualifiedTable {
    fn fmt(&self, formatter: &mut Formatter<'_>) -> core::fmt::Result {
        let Self { schema, table } = self;
        write!(formatter, "{}.{}", schema, table)
    }
}
define_identifier_type!(
    /// Name of a column.
    Column,
    column
);

define_identifier_type!(
    /// Name of an index.
    Index,
    index
);

define_identifier_type!(
    /// Name of a constraint.
    Constraint,
    constraint
);

define_identifier_type!(
    /// Name of an extension.
    Extension,
    extension
);

define_identifier_type!(
    /// Name of a sequence.
    Sequence,
    sequence
);

define_identifier_type!(
    /// Name of a function.
    Function,
    function
);

define_identifier_type!(
    /// Name of a trigger.
    Trigger,
    trigger
);

define_identifier_type!(
    /// Name of a domain.
    Domain,
    domain
);

define_identifier_type!(
    /// Name of a type.
    Type,
    r#type
);

define_identifier_type!(
    /// Name of a view.
    View,
    view
);

define_identifier_type!(
    /// Name of a relation. `Table`, `View` and `MaterializedView` names
    /// convert into it with `From`.
    Relation,
    relation
);
impl From<Table> for Relation {
fn from(table: Table) -> Self {
Self(table.0)
}
}
impl From<View> for Relation {
fn from(view: View) -> Self {
Self(view.0)
}
}
define_identifier_type!(
MaterializedView,
materialized_view
);
impl From<MaterializedView> for Relation {
fn from(materialized_view: MaterializedView) -> Self {
Self(materialized_view.0)
}
}
define_identifier_type!(
    /// Name of an operator.
    Operator,
    operator
);

define_identifier_type!(
    /// Name of an aggregate.
    Aggregate,
    aggregate
);

define_identifier_type!(
    /// Name of a collation.
    Collation,
    collation
);

define_identifier_type!(
    /// Name of a tablespace.
    Tablespace,
    tablespace
);

define_identifier_type!(
    /// Name of a policy.
    Policy,
    policy
);

define_identifier_type!(
    /// Name of a rule.
    Rule,
    rule
);

define_identifier_type!(
    /// Name of a publication.
    Publication,
    publication
);

define_identifier_type!(
    /// Name of a subscription.
    Subscription,
    subscription
);

define_identifier_type!(
    /// Name of a foreign server.
    ForeignServer,
    foreign_server
);

define_identifier_type!(
    /// Name of a foreign-data wrapper.
    ForeignDataWrapper,
    foreign_data_wrapper
);

define_identifier_type!(
    /// Name of a foreign table.
    ForeignTable,
    foreign_table
);

define_identifier_type!(
    /// Name of an event trigger.
    EventTrigger,
    event_trigger
);

define_identifier_type!(
    /// Name of a procedural language.
    Language,
    language
);

define_identifier_type!(
    /// Name of a text search configuration.
    TextSearchConfiguration,
    text_search_configuration
);

define_identifier_type!(
    /// Name of a text search dictionary.
    TextSearchDictionary,
    text_search_dictionary
);

define_identifier_type!(
    /// Name of a conversion.
    Conversion,
    conversion
);

define_identifier_type!(
    /// Name of an operator class.
    OperatorClass,
    operator_class
);

define_identifier_type!(
    /// Name of an operator family.
    OperatorFamily,
    operator_family
);

define_identifier_type!(
    /// Name of an access method.
    AccessMethod,
    access_method
);

define_identifier_type!(
    /// Name of a statistics object.
    StatisticsObject,
    statistics_object
);

define_identifier_type!(
    /// Name of a database.
    Database,
    database
);

impl Database {
    /// The database named `postgres`.
    pub const POSTGRES: Self = Self::from_static_or_panic("postgres");
}
define_identifier_type!(
    /// Name of a role.
    Role,
    role
);

impl Role {
    /// The role named `postgres`.
    pub const POSTGRES: Self = Self::from_static_or_panic("postgres");
}

/// Alias for [`Role`]: users are represented with the same identifier type.
pub type User = Role;
#[cfg(test)]
mod tests {
    use super::*;

    mod identifier {
        use super::*;

        #[test]
        fn parse_valid_simple() {
            let identifier: Identifier = "users".parse().unwrap();
            assert_eq!(identifier.to_string(), "users");
        }

        #[test]
        fn parse_valid_with_space() {
            let identifier: Identifier = "my table".parse().unwrap();
            assert_eq!(identifier.to_string(), "my table");
        }

        #[test]
        fn parse_valid_with_special_chars() {
            let identifier: Identifier = "my-table.name".parse().unwrap();
            assert_eq!(identifier.to_string(), "my-table.name");
        }

        #[test]
        fn parse_valid_starting_with_digit() {
            let identifier: Identifier = "1table".parse().unwrap();
            assert_eq!(identifier.to_string(), "1table");
        }

        #[test]
        fn parse_valid_max_length() {
            let input = "a".repeat(MAX_LENGTH);
            let identifier: Identifier = input.parse().unwrap();
            assert_eq!(identifier.to_string(), input);
        }

        #[test]
        fn parse_empty_fails() {
            let result: Result<Identifier, _> = "".parse();
            assert_eq!(result, Err(ParseError::Empty));
        }

        #[test]
        fn parse_too_long_fails() {
            let input = "a".repeat(MAX_LENGTH + 1);
            let result: Result<Identifier, _> = input.parse();
            assert_eq!(result, Err(ParseError::TooLong));
        }

        #[test]
        fn parse_contains_nul_fails() {
            let result: Result<Identifier, _> = "my\0table".parse();
            assert_eq!(result, Err(ParseError::ContainsNul));
        }

        #[test]
        fn length_limit_counts_bytes_not_characters() {
            // "é" is two bytes in UTF-8: this string has no more than
            // MAX_LENGTH characters but more than MAX_LENGTH bytes.
            let input = "é".repeat(MAX_LENGTH / 2 + 1);
            assert!(input.chars().count() <= MAX_LENGTH);
            assert!(input.len() > MAX_LENGTH);
            let result: Result<Identifier, _> = input.parse();
            assert_eq!(result, Err(ParseError::TooLong));
        }

        #[test]
        fn length_error_takes_precedence_over_nul() {
            let input = format!("\0{}", "a".repeat(MAX_LENGTH));
            let result: Result<Identifier, _> = input.parse();
            assert_eq!(result, Err(ParseError::TooLong));
        }

        #[test]
        fn try_from_string_valid() {
            let identifier = Identifier::try_from(String::from("users")).unwrap();
            assert_eq!(identifier.as_str(), "users");
        }

        #[test]
        fn try_from_string_invalid() {
            assert_eq!(Identifier::try_from(String::new()), Err(ParseError::Empty));
            assert_eq!(
                Identifier::try_from(String::from("my\0table")),
                Err(ParseError::ContainsNul)
            );
        }

        #[test]
        fn from_static_or_panic_valid() {
            const USERS: Identifier = Identifier::from_static_or_panic("users");
            assert_eq!(USERS.as_str(), "users");
        }

        #[test]
        #[should_panic(expected = "identifier cannot be empty")]
        fn from_static_or_panic_empty_panics() {
            let _ = Identifier::from_static_or_panic("");
        }
    }

    mod well_known {
        use super::*;

        #[test]
        fn constants_have_expected_text() {
            assert_eq!(Schema::PUBLIC.as_str(), "public");
            assert_eq!(Database::POSTGRES.as_str(), "postgres");
            assert_eq!(Role::POSTGRES.as_str(), "postgres");
            let user: User = Role::POSTGRES;
            assert_eq!(user, Role::POSTGRES);
        }

        #[test]
        fn qualified_table_joins_with_dot() {
            let qualified = QualifiedTable {
                schema: Schema::PUBLIC,
                table: "users".parse().unwrap(),
            };
            assert_eq!(qualified.to_string(), "public.users");
        }

        #[test]
        fn relation_conversions_keep_name() {
            let table: Table = "users".parse().unwrap();
            let view: View = "active_users".parse().unwrap();
            let materialized: MaterializedView = "user_stats".parse().unwrap();
            assert_eq!(Relation::from(table).as_str(), "users");
            assert_eq!(Relation::from(view).as_str(), "active_users");
            assert_eq!(Relation::from(materialized).as_str(), "user_stats");
        }
    }
}