// use-pg-extension 0.1.0
//
// PostgreSQL extension metadata primitives for RustUse.
#![forbid(unsafe_code)]
#![doc = include_str!("../README.md")]

use core::{fmt, str::FromStr};
use std::error::Error;

use use_pg_schema::PgSchemaName;

/// Extension name label `uuid-ossp`; see [`PgExtensionName::uuid_ossp`].
pub const UUID_OSSP_EXTENSION: &str = "uuid-ossp";
/// Extension name label `pgcrypto`; see [`PgExtensionName::pgcrypto`].
pub const PGCRYPTO_EXTENSION: &str = "pgcrypto";
/// Extension name label `citext`; see [`PgExtensionName::citext`].
pub const CITEXT_EXTENSION: &str = "citext";
/// Extension name label `hstore`; see [`PgExtensionName::hstore`].
pub const HSTORE_EXTENSION: &str = "hstore";
/// Extension name label `postgis`; see [`PgExtensionName::postgis`].
pub const POSTGIS_EXTENSION: &str = "postgis";
/// Extension name label `pg_trgm`; see [`PgExtensionName::pg_trgm`].
pub const PG_TRGM_EXTENSION: &str = "pg_trgm";
/// Extension name label `vector`; see [`PgExtensionName::vector`].
pub const VECTOR_EXTENSION: &str = "vector";

/// Validated PostgreSQL extension name.
///
/// Values are built through [`PgExtensionName::new`] (or `str::parse`). The stored label has
/// surrounding whitespace trimmed and contains only ASCII letters, digits, `_`, and `-`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PgExtensionName(String);

impl PgExtensionName {
    /// Builds a validated extension name from `input`.
    ///
    /// Leading and trailing whitespace is trimmed before the label is checked.
    ///
    /// # Errors
    ///
    /// Returns [`PgExtensionError`] when the label is empty or contains unsupported characters.
    pub fn new(input: impl AsRef<str>) -> Result<Self, PgExtensionError> {
        let label = validate_extension_name(input.as_ref())?;
        Ok(Self(String::from(label)))
    }

    /// Builds one of the crate's built-in names, panicking with `context` if it fails validation.
    fn built_in(label: &str, context: &str) -> Self {
        Self::new(label).expect(context)
    }

    /// Returns the built-in `uuid-ossp` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `uuid-ossp` constant is changed to an invalid extension label.
    #[must_use]
    pub fn uuid_ossp() -> Self {
        Self::built_in(UUID_OSSP_EXTENSION, "uuid-ossp is a valid extension name")
    }

    /// Returns the built-in `pgcrypto` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `pgcrypto` constant is changed to an invalid extension label.
    #[must_use]
    pub fn pgcrypto() -> Self {
        Self::built_in(PGCRYPTO_EXTENSION, "pgcrypto is a valid extension name")
    }

    /// Returns the built-in `citext` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `citext` constant is changed to an invalid extension label.
    #[must_use]
    pub fn citext() -> Self {
        Self::built_in(CITEXT_EXTENSION, "citext is a valid extension name")
    }

    /// Returns the built-in `hstore` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `hstore` constant is changed to an invalid extension label.
    #[must_use]
    pub fn hstore() -> Self {
        Self::built_in(HSTORE_EXTENSION, "hstore is a valid extension name")
    }

    /// Returns the built-in `postgis` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `postgis` constant is changed to an invalid extension label.
    #[must_use]
    pub fn postgis() -> Self {
        Self::built_in(POSTGIS_EXTENSION, "postgis is a valid extension name")
    }

    /// Returns the built-in `pg_trgm` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `pg_trgm` constant is changed to an invalid extension label.
    #[must_use]
    pub fn pg_trgm() -> Self {
        Self::built_in(PG_TRGM_EXTENSION, "pg_trgm is a valid extension name")
    }

    /// Returns the built-in `vector` extension name.
    ///
    /// # Panics
    ///
    /// Panics only if the built-in `vector` constant is changed to an invalid extension label.
    #[must_use]
    pub fn vector() -> Self {
        Self::built_in(VECTOR_EXTENSION, "vector is a valid extension name")
    }

    /// Borrows the validated extension label.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

/// Borrows the validated label so a [`PgExtensionName`] can be passed wherever
/// `impl AsRef<str>` is accepted.
impl AsRef<str> for PgExtensionName {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

/// Renders the bare extension label; width, fill, and alignment flags are not applied.
impl fmt::Display for PgExtensionName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl FromStr for PgExtensionName {
    type Err = PgExtensionError;

    fn from_str(input: &str) -> Result<Self, Self::Err> {
        Self::new(input)
    }
}

/// PostgreSQL extension version label.
///
/// The label is free-form text such as `1.6` or `3.5.0`. Surrounding whitespace is trimmed, and
/// only empty labels and labels containing control characters are rejected.
///
/// The derived ordering compares the raw label text, not semantic versions. For example,
/// `"1.10"` sorts before `"1.9"`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PgExtensionVersion(String);

impl PgExtensionVersion {
    /// Creates an extension version label.
    ///
    /// Leading and trailing whitespace is trimmed before the label is checked.
    ///
    /// # Errors
    ///
    /// Returns [`PgExtensionError`] when the label is empty or contains control characters.
    pub fn new(input: impl AsRef<str>) -> Result<Self, PgExtensionError> {
        validate_version(input.as_ref()).map(|value| Self(value.to_owned()))
    }

    /// Returns the version label.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}

/// Borrows the version label. This matches the `AsRef<str>` impl on `PgExtensionName`, so both
/// label types can be passed wherever `impl AsRef<str>` is accepted.
impl AsRef<str> for PgExtensionVersion {
    fn as_ref(&self) -> &str {
        self.as_str()
    }
}

/// Renders the version label exactly as stored; width, fill, and alignment flags are not applied.
impl fmt::Display for PgExtensionVersion {
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        out.write_str(&self.0)
    }
}

impl FromStr for PgExtensionVersion {
    type Err = PgExtensionError;

    fn from_str(input: &str) -> Result<Self, Self::Err> {
        Self::new(input)
    }
}

/// Metadata describing a PostgreSQL extension.
///
/// Only the name is required. Version, schema, and relocatability start out unset and are filled
/// in with the `with_*` builder methods; unset fields read back as `None`.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct PgExtension {
    // Field order determines the derived `PartialOrd`/`Ord`; keep it stable.
    name: PgExtensionName,
    version: Option<PgExtensionVersion>,
    schema: Option<PgSchemaName>,
    relocatable: Option<bool>,
}

impl PgExtension {
    /// Starts metadata for `name` with every optional field unset.
    #[must_use]
    pub const fn new(name: PgExtensionName) -> Self {
        Self {
            name,
            version: None,
            schema: None,
            relocatable: None,
        }
    }

    /// Sets the version label, replacing any previously set version.
    #[must_use]
    pub fn with_version(self, version: PgExtensionVersion) -> Self {
        Self {
            version: Some(version),
            ..self
        }
    }

    /// Sets the schema, replacing any previously set schema.
    #[must_use]
    pub fn with_schema(self, schema: PgSchemaName) -> Self {
        Self {
            schema: Some(schema),
            ..self
        }
    }

    /// Sets the relocatable flag, replacing any previously set value.
    #[must_use]
    pub const fn with_relocatable(mut self, relocatable: bool) -> Self {
        // Assigning in place keeps this usable in `const` contexts.
        self.relocatable = Some(relocatable);
        self
    }

    /// Borrows the extension name.
    #[must_use]
    pub const fn name(&self) -> &PgExtensionName {
        &self.name
    }

    /// Borrows the version label, if one was set.
    #[must_use]
    pub const fn version(&self) -> Option<&PgExtensionVersion> {
        self.version.as_ref()
    }

    /// Borrows the schema, if one was set.
    #[must_use]
    pub const fn schema(&self) -> Option<&PgSchemaName> {
        self.schema.as_ref()
    }

    /// Returns the relocatable flag, if one was set.
    #[must_use]
    pub const fn relocatable(&self) -> Option<bool> {
        self.relocatable
    }
}

/// Error returned when PostgreSQL extension metadata is invalid.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PgExtensionError {
    /// The extension name was empty or contained only whitespace.
    EmptyName,
    /// The extension version was empty or contained only whitespace.
    EmptyVersion,
    /// The extension name contained a character other than an ASCII letter, digit, `_`, or `-`.
    InvalidNameCharacter {
        /// Byte index of the rejected character.
        index: usize,
        /// The rejected character.
        character: char,
    },
    /// An extension name or version label contained a control character.
    ControlCharacter,
}

impl fmt::Display for PgExtensionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyName => formatter.write_str("PostgreSQL extension name cannot be empty"),
            Self::EmptyVersion => {
                formatter.write_str("PostgreSQL extension version cannot be empty")
            }
            Self::InvalidNameCharacter { index, character } => write!(
                formatter,
                "PostgreSQL extension name contains invalid character {character:?} at byte index {index}"
            ),
            Self::ControlCharacter => {
                formatter.write_str("PostgreSQL extension label cannot contain control characters")
            }
        }
    }
}

/// `PgExtensionError` has no underlying cause, so `source()` keeps the default `None`.
impl Error for PgExtensionError {}

/// Trims `input` and checks that what remains is a usable extension name.
///
/// Returns the trimmed slice on success. A control character is reported as
/// [`PgExtensionError::ControlCharacter`]. Any other character outside ASCII letters, digits,
/// `_`, and `-` is reported as [`PgExtensionError::InvalidNameCharacter`]. Its `index` is the byte
/// index in `input`, not in the trimmed slice, so `input[index..]` starts at the rejected
/// character.
fn validate_extension_name(input: &str) -> Result<&str, PgExtensionError> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err(PgExtensionError::EmptyName);
    }
    // Byte length of the leading whitespace removed by `trim`. Adding it converts positions in
    // `trimmed` into positions in the caller's `input`.
    let offset = input.len() - input.trim_start().len();
    for (index, character) in trimmed.char_indices() {
        if character.is_control() {
            return Err(PgExtensionError::ControlCharacter);
        }
        if !(character.is_ascii_alphanumeric() || matches!(character, '_' | '-')) {
            return Err(PgExtensionError::InvalidNameCharacter {
                index: offset + index,
                character,
            });
        }
    }
    Ok(trimmed)
}

/// Trims `input` and checks that what remains is a usable version label.
///
/// Returns the trimmed slice. Any non-empty label without control characters is accepted.
fn validate_version(input: &str) -> Result<&str, PgExtensionError> {
    let label = input.trim();
    if label.is_empty() {
        Err(PgExtensionError::EmptyVersion)
    } else if label.contains(char::is_control) {
        Err(PgExtensionError::ControlCharacter)
    } else {
        Ok(label)
    }
}

#[cfg(test)]
mod tests {
    use super::{
        CITEXT_EXTENSION, HSTORE_EXTENSION, PGCRYPTO_EXTENSION, PG_TRGM_EXTENSION,
        POSTGIS_EXTENSION, PgExtension, PgExtensionError, PgExtensionName, PgExtensionVersion,
        UUID_OSSP_EXTENSION, VECTOR_EXTENSION,
    };
    use use_pg_schema::PgSchemaName;

    #[test]
    fn exposes_common_extension_names() {
        assert_eq!(PgExtensionName::uuid_ossp().as_str(), UUID_OSSP_EXTENSION);
        assert_eq!(PgExtensionName::pgcrypto().as_str(), PGCRYPTO_EXTENSION);
        assert_eq!(PgExtensionName::citext().as_str(), CITEXT_EXTENSION);
        assert_eq!(PgExtensionName::hstore().as_str(), HSTORE_EXTENSION);
        assert_eq!(PgExtensionName::postgis().as_str(), POSTGIS_EXTENSION);
        assert_eq!(PgExtensionName::pg_trgm().as_str(), PG_TRGM_EXTENSION);
        assert_eq!(PgExtensionName::vector().as_str(), VECTOR_EXTENSION);
    }

    #[test]
    fn parses_and_trims_names() -> Result<(), PgExtensionError> {
        let name: PgExtensionName = "  citext \n".parse()?;
        assert_eq!(name.as_str(), "citext");
        assert_eq!(name.to_string(), "citext");
        Ok(())
    }

    #[test]
    fn rejects_invalid_names() {
        assert_eq!(PgExtensionName::new(""), Err(PgExtensionError::EmptyName));
        assert_eq!(PgExtensionName::new("   "), Err(PgExtensionError::EmptyName));
        assert_eq!(
            PgExtensionName::new("pg trgm"),
            Err(PgExtensionError::InvalidNameCharacter {
                index: 2,
                character: ' '
            })
        );
        assert_eq!(
            PgExtensionName::new("pg.trgm"),
            Err(PgExtensionError::InvalidNameCharacter {
                index: 2,
                character: '.'
            })
        );
        assert_eq!(
            PgExtensionName::new("pg\ttrgm"),
            Err(PgExtensionError::ControlCharacter)
        );
    }

    #[test]
    fn parses_and_renders_versions() -> Result<(), PgExtensionError> {
        let version: PgExtensionVersion = "1.6".parse()?;
        assert_eq!(version.as_str(), "1.6");
        assert_eq!(version.to_string(), "1.6");
        assert_eq!(PgExtensionVersion::new(" 3.5.0 ")?.as_str(), "3.5.0");
        assert_eq!(
            PgExtensionVersion::new(""),
            Err(PgExtensionError::EmptyVersion)
        );
        assert_eq!(
            PgExtensionVersion::new("  "),
            Err(PgExtensionError::EmptyVersion)
        );
        assert_eq!(
            PgExtensionVersion::new("1.\u{0}6"),
            Err(PgExtensionError::ControlCharacter)
        );
        Ok(())
    }

    #[test]
    fn new_extension_has_no_optional_metadata() {
        let extension = PgExtension::new(PgExtensionName::citext());
        assert_eq!(extension.name(), &PgExtensionName::citext());
        assert_eq!(extension.version(), None);
        assert_eq!(extension.schema(), None);
        assert_eq!(extension.relocatable(), None);
    }

    #[test]
    fn creates_extension_metadata() -> Result<(), PgExtensionError> {
        let extension = PgExtension::new(PgExtensionName::postgis())
            .with_version(PgExtensionVersion::new("3.5.0")?)
            .with_schema(PgSchemaName::public())
            .with_relocatable(false);

        assert_eq!(extension.name().as_str(), "postgis");
        assert_eq!(
            extension.version().map(PgExtensionVersion::as_str),
            Some("3.5.0")
        );
        assert_eq!(extension.schema(), Some(&PgSchemaName::public()));
        assert_eq!(extension.relocatable(), Some(false));
        Ok(())
    }

    #[test]
    fn renders_error_messages() {
        assert_eq!(
            PgExtensionError::EmptyName.to_string(),
            "PostgreSQL extension name cannot be empty"
        );
        assert_eq!(
            PgExtensionError::ControlCharacter.to_string(),
            "PostgreSQL extension label cannot contain control characters"
        );
    }
}