use std::borrow::Cow;
/// A validated PostgreSQL parameter name.
///
/// The inner string is either borrowed from a `&'static str` (see
/// `Name::from_static_or_panic`) or owned (see the `FromStr` and
/// `TryFrom<String>` impls); every constructor in this file runs
/// `validate_name` first.
///
/// With serde the type is represented as a plain string: serialization goes
/// through `String::from(Name)` and deserialization through
/// `Name::try_from(String)`, so deserialized names are validated as well.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize, serde::Serialize)]
#[serde(try_from = "String", into = "String")]
pub struct Name(Cow<'static, str>);
impl From<Name> for String {
fn from(name: Name) -> Self {
name.0.into()
}
}
impl Name {
    /// Returns the parameter name as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_ref()
    }

    /// Builds a `Name` from a string with `'static` lifetime without allocating.
    ///
    /// Being a `const fn`, this can initialise constants and statics; when it
    /// is evaluated in a const context an invalid name becomes a compile-time
    /// error instead of a runtime panic.
    ///
    /// # Panics
    ///
    /// Panics if `name` is rejected by `validate_name`: it is empty, does not
    /// start with an ASCII letter or underscore, or contains a byte other than
    /// an ASCII letter, ASCII digit, `_` or `.`.
    #[must_use]
    pub const fn from_static_or_panic(name: &'static str) -> Self {
        if let Err(reason) = validate_name(name) {
            // Literal messages only: formatting an error value is not
            // available inside a `const fn`.
            match reason {
                NameError::Empty => {
                    panic!("PostgreSQL parameter name cannot be empty")
                }
                NameError::InvalidStartCharacter => {
                    panic!("PostgreSQL parameter name must start with a letter or underscore")
                }
                NameError::InvalidCharacter => {
                    panic!("PostgreSQL parameter name contains an invalid character")
                }
            }
        }
        Self(Cow::Borrowed(name))
    }
}
/// Lets a [`Name`] be passed wherever an `impl AsRef<str>` is accepted.
impl AsRef<str> for Name {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}
/// Displays the name exactly as stored.
///
/// The text is written with `Formatter::write_str`, so width, fill and
/// precision flags given in the format string are not applied.
impl std::fmt::Display for Name {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Reasons a string is rejected as a PostgreSQL parameter name.
///
/// Returned by `validate_name` and, through it, by the `FromStr` and
/// `TryFrom<String>` impls for `Name`.
#[derive(Debug, thiserror::Error)]
pub enum NameError {
/// The name was the empty string.
#[error("PostgreSQL parameter name cannot be empty")]
Empty,
/// The first byte was neither an ASCII letter nor `_`.
#[error("PostgreSQL parameter name must start with a letter or underscore")]
InvalidStartCharacter,
/// A byte after the first was not an ASCII letter, ASCII digit, `_` or `.`.
#[error("PostgreSQL parameter name contains an invalid character")]
InvalidCharacter,
}
/// Parses a borrowed string into an owned [`Name`], e.g. via `"x".parse::<Name>()`.
///
/// # Errors
///
/// Returns the [`NameError`] reported by `validate_name` when the input is
/// not a valid name. The input is only copied after validation succeeds.
impl std::str::FromStr for Name {
    type Err = NameError;

    fn from_str(name: &str) -> Result<Self, Self::Err> {
        validate_name(name)?;
        Ok(Self(Cow::Owned(name.to_owned())))
    }
}
/// Validates an owned `String` and wraps it as a [`Name`] without copying.
///
/// This is also the conversion serde uses when deserializing a `Name`
/// (`#[serde(try_from = "String")]`).
///
/// # Errors
///
/// Returns the [`NameError`] reported by `validate_name` when the input is
/// not a valid name.
impl TryFrom<String> for Name {
    type Error = NameError;

    fn try_from(name: String) -> Result<Self, Self::Error> {
        validate_name(name.as_str())?;
        Ok(Self(Cow::Owned(name)))
    }
}
impl From<&'static str> for Name {
fn from(name: &'static str) -> Self {
Self::from_static_or_panic(name)
}
}
/// Checks that `name` is acceptable as a parameter name.
///
/// Rules enforced here, byte by byte:
/// - the name must not be empty;
/// - the first byte must be an ASCII letter or `_`;
/// - every later byte must be an ASCII letter, ASCII digit, `_` or `.`.
///
/// Any non-ASCII byte therefore fails one of the last two rules. Dots are
/// accepted anywhere after the first byte; consecutive or trailing dots are
/// not rejected.
///
/// # Errors
///
/// Returns the [`NameError`] variant matching the first rule that fails.
const fn validate_name(name: &str) -> Result<(), NameError> {
    let bytes = name.as_bytes();
    let length = bytes.len();
    if length == 0 {
        return Err(NameError::Empty);
    }
    if !matches!(bytes[0], b'A'..=b'Z' | b'a'..=b'z' | b'_') {
        return Err(NameError::InvalidStartCharacter);
    }
    // Iterators are not usable in a `const fn`, hence the manual index loop.
    let mut position = 1;
    while position < length {
        match bytes[position] {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'_' | b'.' => position += 1,
            _ => return Err(NameError::InvalidCharacter),
        }
    }
    Ok(())
}
/// Maximum length, in bytes, that `validate_value` accepts for a parameter value.
///
/// NOTE: the panic message in `Value::from_static_or_panic` spells this number
/// out literally; keep the two in sync if this constant changes.
pub const VALUE_MAX_LEN: usize = 4096;
/// A validated PostgreSQL parameter value.
///
/// The inner string is either borrowed from a `&'static str` (see
/// `Value::from_static_or_panic`) or owned (see the `FromStr` and
/// `TryFrom<String>` impls); every constructor in this file runs
/// `validate_value` first, which rejects NUL bytes and strings longer than
/// `VALUE_MAX_LEN` bytes. The empty string is accepted.
///
/// With serde the type is represented as a plain string: serialization goes
/// through `String::from(Value)` and deserialization through
/// `Value::try_from(String)`, so deserialized values are validated as well.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, serde::Deserialize, serde::Serialize)]
#[serde(try_from = "String", into = "String")]
pub struct Value(Cow<'static, str>);
impl From<Value> for String {
fn from(value: Value) -> Self {
value.0.into()
}
}
impl Value {
    /// Returns the parameter value as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_ref()
    }

    /// Builds a `Value` from a string with `'static` lifetime without allocating.
    ///
    /// Being a `const fn`, this can initialise constants and statics; when it
    /// is evaluated in a const context an invalid value becomes a compile-time
    /// error instead of a runtime panic.
    ///
    /// # Panics
    ///
    /// Panics if `value` is rejected by `validate_value`: it contains a NUL
    /// byte or is longer than `VALUE_MAX_LEN` bytes.
    #[must_use]
    pub const fn from_static_or_panic(value: &'static str) -> Self {
        if let Err(reason) = validate_value(value) {
            // Literal messages only: formatting the error's fields is not
            // available inside a `const fn`. The 4096 below mirrors
            // `VALUE_MAX_LEN`.
            match reason {
                ValueError::ContainsNul { .. } => {
                    panic!("PostgreSQL parameter value cannot contain NUL byte")
                }
                ValueError::TooLong { .. } => {
                    panic!("PostgreSQL parameter value exceeds maximum of 4096 bytes")
                }
            }
        }
        Self(Cow::Borrowed(value))
    }
}
/// Lets a [`Value`] be passed wherever an `impl AsRef<str>` is accepted.
impl AsRef<str> for Value {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}
/// Displays the value exactly as stored.
///
/// The text is written with `Formatter::write_str`, so width, fill and
/// precision flags given in the format string are not applied.
impl std::fmt::Display for Value {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}
/// Reasons a string is rejected as a PostgreSQL parameter value.
///
/// Returned by `validate_value` and, through it, by the `FromStr` and
/// `TryFrom<String>` impls for `Value`.
#[derive(Debug, thiserror::Error)]
pub enum ValueError {
/// The value contains a NUL byte; `index` is the byte offset of the first one.
#[error("PostgreSQL parameter value contains NUL byte at index {index}")]
ContainsNul { index: usize },
/// The value is `length` bytes long, which exceeds the limit `max`.
#[error("PostgreSQL parameter value length {length} exceeds maximum {max}")]
TooLong { length: usize, max: usize },
}
/// Parses a borrowed string into an owned [`Value`], e.g. via `"x".parse::<Value>()`.
///
/// # Errors
///
/// Returns the [`ValueError`] reported by `validate_value` when the input is
/// not a valid value. The input is only copied after validation succeeds.
impl std::str::FromStr for Value {
    type Err = ValueError;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        validate_value(value)?;
        Ok(Self(Cow::Owned(value.to_owned())))
    }
}
/// Validates an owned `String` and wraps it as a [`Value`] without copying.
///
/// This is also the conversion serde uses when deserializing a `Value`
/// (`#[serde(try_from = "String")]`).
///
/// # Errors
///
/// Returns the [`ValueError`] reported by `validate_value` when the input is
/// not a valid value.
impl TryFrom<String> for Value {
    type Error = ValueError;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        validate_value(value.as_str())?;
        Ok(Self(Cow::Owned(value)))
    }
}
impl From<&'static str> for Value {
fn from(value: &'static str) -> Self {
Self::from_static_or_panic(value)
}
}
/// Checks that `value` is acceptable as a parameter value.
///
/// The length limit is checked first, then the bytes are scanned for NUL.
/// An empty string passes both checks.
///
/// # Errors
///
/// - [`ValueError::TooLong`] when the value is longer than `VALUE_MAX_LEN` bytes.
/// - [`ValueError::ContainsNul`] with the offset of the first NUL byte.
const fn validate_value(value: &str) -> Result<(), ValueError> {
    let bytes = value.as_bytes();
    let length = bytes.len();
    if length > VALUE_MAX_LEN {
        return Err(ValueError::TooLong {
            length,
            max: VALUE_MAX_LEN,
        });
    }
    // Iterators are not usable in a `const fn`, hence the manual index loop.
    let mut index = 0;
    while index < length {
        match bytes[index] {
            b'\0' => return Err(ValueError::ContainsNul { index }),
            _ => index += 1,
        }
    }
    Ok(())
}
/// Collection of parameters keyed by [`Name`], kept in the key order defined by
/// `Name`'s derived `Ord`.
pub type Map = std::collections::BTreeMap<Name, Value>;