Skip to main content

use_pg_extension/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7use use_pg_schema::PgSchemaName;
8
/// Extension name `uuid-ossp`; see [`PgExtensionName::uuid_ossp`].
pub const UUID_OSSP_EXTENSION: &str = "uuid-ossp";
/// Extension name `pgcrypto`; see [`PgExtensionName::pgcrypto`].
pub const PGCRYPTO_EXTENSION: &str = "pgcrypto";
/// Extension name `citext`; see [`PgExtensionName::citext`].
pub const CITEXT_EXTENSION: &str = "citext";
/// Extension name `hstore`; see [`PgExtensionName::hstore`].
pub const HSTORE_EXTENSION: &str = "hstore";
/// Extension name `postgis`; see [`PgExtensionName::postgis`].
pub const POSTGIS_EXTENSION: &str = "postgis";
/// Extension name `pg_trgm`; see [`PgExtensionName::pg_trgm`].
pub const PG_TRGM_EXTENSION: &str = "pg_trgm";
/// Extension name `vector`; see [`PgExtensionName::vector`].
pub const VECTOR_EXTENSION: &str = "vector";
23
/// Validated PostgreSQL extension name.
///
/// Build one with [`PgExtensionName::new`] (which trims surrounding whitespace before
/// validating) or with one of the well-known constructors. Equality, ordering and hashing all
/// follow the wrapped label text.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PgExtensionName(String);
27
28impl PgExtensionName {
29    /// Creates an extension name label.
30    ///
31    /// # Errors
32    ///
33    /// Returns [`PgExtensionError`] when the label is empty or contains unsupported characters.
34    pub fn new(input: impl AsRef<str>) -> Result<Self, PgExtensionError> {
35        validate_extension_name(input.as_ref()).map(|value| Self(value.to_owned()))
36    }
37
38    /// Returns `uuid-ossp` as an extension name.
39    ///
40    /// # Panics
41    ///
42    /// Panics only if the built-in `uuid-ossp` constant is changed to an invalid extension label.
43    #[must_use]
44    pub fn uuid_ossp() -> Self {
45        Self::new(UUID_OSSP_EXTENSION).expect("uuid-ossp is a valid extension name")
46    }
47
48    /// Returns `pgcrypto` as an extension name.
49    ///
50    /// # Panics
51    ///
52    /// Panics only if the built-in `pgcrypto` constant is changed to an invalid extension label.
53    #[must_use]
54    pub fn pgcrypto() -> Self {
55        Self::new(PGCRYPTO_EXTENSION).expect("pgcrypto is a valid extension name")
56    }
57
58    /// Returns `citext` as an extension name.
59    ///
60    /// # Panics
61    ///
62    /// Panics only if the built-in `citext` constant is changed to an invalid extension label.
63    #[must_use]
64    pub fn citext() -> Self {
65        Self::new(CITEXT_EXTENSION).expect("citext is a valid extension name")
66    }
67
68    /// Returns `hstore` as an extension name.
69    ///
70    /// # Panics
71    ///
72    /// Panics only if the built-in `hstore` constant is changed to an invalid extension label.
73    #[must_use]
74    pub fn hstore() -> Self {
75        Self::new(HSTORE_EXTENSION).expect("hstore is a valid extension name")
76    }
77
78    /// Returns `postgis` as an extension name.
79    ///
80    /// # Panics
81    ///
82    /// Panics only if the built-in `postgis` constant is changed to an invalid extension label.
83    #[must_use]
84    pub fn postgis() -> Self {
85        Self::new(POSTGIS_EXTENSION).expect("postgis is a valid extension name")
86    }
87
88    /// Returns `pg_trgm` as an extension name.
89    ///
90    /// # Panics
91    ///
92    /// Panics only if the built-in `pg_trgm` constant is changed to an invalid extension label.
93    #[must_use]
94    pub fn pg_trgm() -> Self {
95        Self::new(PG_TRGM_EXTENSION).expect("pg_trgm is a valid extension name")
96    }
97
98    /// Returns `vector` as an extension name.
99    ///
100    /// # Panics
101    ///
102    /// Panics only if the built-in `vector` constant is changed to an invalid extension label.
103    #[must_use]
104    pub fn vector() -> Self {
105        Self::new(VECTOR_EXTENSION).expect("vector is a valid extension name")
106    }
107
108    /// Returns the extension name label.
109    #[must_use]
110    pub fn as_str(&self) -> &str {
111        &self.0
112    }
113}
114
115impl AsRef<str> for PgExtensionName {
116    fn as_ref(&self) -> &str {
117        self.as_str()
118    }
119}
120
121impl fmt::Display for PgExtensionName {
122    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
123        formatter.write_str(self.as_str())
124    }
125}
126
127impl FromStr for PgExtensionName {
128    type Err = PgExtensionError;
129
130    fn from_str(input: &str) -> Result<Self, Self::Err> {
131        Self::new(input)
132    }
133}
134
/// Validated PostgreSQL extension version label, stored with surrounding whitespace trimmed.
///
/// The derived ordering compares the label text lexicographically. It is *not* a semantic
/// version comparison: for example `"10.0"` sorts before `"9.0"`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PgExtensionVersion(String);
138
139impl PgExtensionVersion {
140    /// Creates an extension version label.
141    ///
142    /// # Errors
143    ///
144    /// Returns [`PgExtensionError`] when the label is empty or contains control characters.
145    pub fn new(input: impl AsRef<str>) -> Result<Self, PgExtensionError> {
146        validate_version(input.as_ref()).map(|value| Self(value.to_owned()))
147    }
148
149    /// Returns the version label.
150    #[must_use]
151    pub fn as_str(&self) -> &str {
152        &self.0
153    }
154}
155
156impl fmt::Display for PgExtensionVersion {
157    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
158        formatter.write_str(self.as_str())
159    }
160}
161
162impl FromStr for PgExtensionVersion {
163    type Err = PgExtensionError;
164
165    fn from_str(input: &str) -> Result<Self, Self::Err> {
166        Self::new(input)
167    }
168}
169
170/// PostgreSQL extension metadata.
171#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
172pub struct PgExtension {
173    name: PgExtensionName,
174    version: Option<PgExtensionVersion>,
175    schema: Option<PgSchemaName>,
176    relocatable: Option<bool>,
177}
178
179impl PgExtension {
180    /// Creates extension metadata from a name.
181    #[must_use]
182    pub const fn new(name: PgExtensionName) -> Self {
183        Self {
184            name,
185            version: None,
186            schema: None,
187            relocatable: None,
188        }
189    }
190
191    /// Adds an extension version label.
192    #[must_use]
193    pub fn with_version(mut self, version: PgExtensionVersion) -> Self {
194        self.version = Some(version);
195        self
196    }
197
198    /// Adds schema metadata.
199    #[must_use]
200    pub fn with_schema(mut self, schema: PgSchemaName) -> Self {
201        self.schema = Some(schema);
202        self
203    }
204
205    /// Adds relocatable metadata.
206    #[must_use]
207    pub const fn with_relocatable(mut self, relocatable: bool) -> Self {
208        self.relocatable = Some(relocatable);
209        self
210    }
211
212    /// Returns the extension name.
213    #[must_use]
214    pub const fn name(&self) -> &PgExtensionName {
215        &self.name
216    }
217
218    /// Returns the optional version label.
219    #[must_use]
220    pub const fn version(&self) -> Option<&PgExtensionVersion> {
221        self.version.as_ref()
222    }
223
224    /// Returns the optional schema metadata.
225    #[must_use]
226    pub const fn schema(&self) -> Option<&PgSchemaName> {
227        self.schema.as_ref()
228    }
229
230    /// Returns the optional relocatable metadata.
231    #[must_use]
232    pub const fn relocatable(&self) -> Option<bool> {
233        self.relocatable
234    }
235}
236
/// Error returned when PostgreSQL extension metadata is invalid.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PgExtensionError {
    /// The extension name was empty after trimming surrounding whitespace.
    EmptyName,
    /// The extension version was empty after trimming surrounding whitespace.
    EmptyVersion,
    /// The extension name contains a character outside `[A-Za-z0-9_-]`, or a `-` in a position
    /// PostgreSQL rejects: leading, trailing, or doubled (`--`).
    InvalidNameCharacter {
        /// Byte offset of `character` within the *trimmed* name.
        index: usize,
        /// The offending character.
        character: char,
    },
    /// A name or version label contains a control character (`char::is_control`).
    ControlCharacter,
}

impl fmt::Display for PgExtensionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::EmptyName => formatter.write_str("PostgreSQL extension name cannot be empty"),
            Self::EmptyVersion => {
                formatter.write_str("PostgreSQL extension version cannot be empty")
            }
            Self::InvalidNameCharacter { index, character } => write!(
                formatter,
                "PostgreSQL extension name contains invalid character {character:?} at byte index {index}"
            ),
            Self::ControlCharacter => {
                formatter.write_str("PostgreSQL extension label cannot contain control characters")
            }
        }
    }
}

impl Error for PgExtensionError {}

/// Validates and trims a PostgreSQL extension name, returning the trimmed slice of `input`.
///
/// After trimming surrounding whitespace, the label must be non-empty and consist only of ASCII
/// letters, digits, `_` and `-`. It must also not start or end with `-` and not contain `--`.
/// PostgreSQL's `CREATE EXTENSION` applies the same dash rules, because extension script files
/// are named `name--version.sql` and a misplaced dash would make those names ambiguous.
fn validate_extension_name(input: &str) -> Result<&str, PgExtensionError> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err(PgExtensionError::EmptyName);
    }
    let mut previous = None;
    for (index, character) in trimmed.char_indices() {
        if character.is_control() {
            return Err(PgExtensionError::ControlCharacter);
        }
        if !(character.is_ascii_alphanumeric() || matches!(character, '_' | '-')) {
            return Err(PgExtensionError::InvalidNameCharacter { index, character });
        }
        // A leading dash, or a dash directly after another dash, is rejected by PostgreSQL.
        if character == '-' && (index == 0 || previous == Some('-')) {
            return Err(PgExtensionError::InvalidNameCharacter { index, character });
        }
        previous = Some(character);
    }
    if trimmed.ends_with('-') {
        // `-` is one byte, so a trailing dash starts at `len - 1`.
        return Err(PgExtensionError::InvalidNameCharacter {
            index: trimmed.len() - 1,
            character: '-',
        });
    }
    Ok(trimmed)
}

/// Validates and trims a PostgreSQL extension version label, returning the trimmed slice.
///
/// After trimming surrounding whitespace, the label must be non-empty and free of control
/// characters. Every other character, including interior spaces and punctuation, is accepted.
///
/// NOTE(review): PostgreSQL's own version-name check also rejects `--`, a leading or trailing
/// `-`, and directory separators. This function does not, since `PgExtensionError` has no
/// variant to report it; consider tightening in a future breaking release.
fn validate_version(input: &str) -> Result<&str, PgExtensionError> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err(PgExtensionError::EmptyVersion);
    }
    if trimmed.chars().any(char::is_control) {
        return Err(PgExtensionError::ControlCharacter);
    }
    Ok(trimmed)
}
292
#[cfg(test)]
mod tests {
    use super::{
        CITEXT_EXTENSION, HSTORE_EXTENSION, PG_TRGM_EXTENSION, PGCRYPTO_EXTENSION,
        POSTGIS_EXTENSION, PgExtension, PgExtensionError, PgExtensionName, PgExtensionVersion,
        UUID_OSSP_EXTENSION, VECTOR_EXTENSION,
    };
    use use_pg_schema::PgSchemaName;

    #[test]
    fn exposes_common_extension_names() {
        // Every built-in constructor must agree with its public constant.
        let cases = [
            (PgExtensionName::uuid_ossp(), UUID_OSSP_EXTENSION),
            (PgExtensionName::pgcrypto(), PGCRYPTO_EXTENSION),
            (PgExtensionName::citext(), CITEXT_EXTENSION),
            (PgExtensionName::hstore(), HSTORE_EXTENSION),
            (PgExtensionName::postgis(), POSTGIS_EXTENSION),
            (PgExtensionName::pg_trgm(), PG_TRGM_EXTENSION),
            (PgExtensionName::vector(), VECTOR_EXTENSION),
        ];
        for (name, expected) in cases {
            assert_eq!(name.as_str(), expected);
        }
    }

    #[test]
    fn parses_and_renders_names() -> Result<(), PgExtensionError> {
        let name: PgExtensionName = "  pg_trgm ".parse()?;
        assert_eq!(name.as_str(), "pg_trgm");
        assert_eq!(name.to_string(), "pg_trgm");
        let borrowed: &str = name.as_ref();
        assert_eq!(borrowed, "pg_trgm");
        Ok(())
    }

    #[test]
    fn rejects_invalid_names() {
        assert_eq!(PgExtensionName::new("   "), Err(PgExtensionError::EmptyName));
        assert_eq!(
            PgExtensionName::new("pg trgm"),
            Err(PgExtensionError::InvalidNameCharacter {
                index: 2,
                character: ' ',
            })
        );
        assert_eq!(
            PgExtensionName::new("pg\u{0}trgm"),
            Err(PgExtensionError::ControlCharacter)
        );
    }

    #[test]
    fn parses_and_renders_versions() -> Result<(), PgExtensionError> {
        let version: PgExtensionVersion = "1.6".parse()?;
        assert_eq!(version.as_str(), "1.6");
        assert_eq!(version.to_string(), "1.6");
        assert_eq!(PgExtensionVersion::new(" 1.6 ")?.as_str(), "1.6");
        assert_eq!(
            PgExtensionVersion::new(""),
            Err(PgExtensionError::EmptyVersion)
        );
        assert_eq!(
            PgExtensionVersion::new("1.\n6"),
            Err(PgExtensionError::ControlCharacter)
        );
        Ok(())
    }

    #[test]
    fn new_extension_has_no_optional_metadata() {
        let extension = PgExtension::new(PgExtensionName::citext());
        assert_eq!(extension.name(), &PgExtensionName::citext());
        assert_eq!(extension.version(), None);
        assert_eq!(extension.schema(), None);
        assert_eq!(extension.relocatable(), None);
    }

    #[test]
    fn creates_extension_metadata() -> Result<(), PgExtensionError> {
        let extension = PgExtension::new(PgExtensionName::postgis())
            .with_version(PgExtensionVersion::new("3.5.0")?)
            .with_schema(PgSchemaName::public())
            .with_relocatable(false);

        assert_eq!(extension.name().as_str(), "postgis");
        assert_eq!(
            extension.version().map(PgExtensionVersion::as_str),
            Some("3.5.0")
        );
        assert_eq!(extension.schema(), Some(&PgSchemaName::public()));
        assert_eq!(extension.relocatable(), Some(false));
        Ok(())
    }
}