Skip to main content

use_sql_ident/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
/// One SQL identifier segment (for example `users`) that has passed validation.
///
/// Values are built through [`SqlIdentifier::new`] or the `FromStr`/`TryFrom`
/// impls that delegate to it, so the stored text is trimmed, non-empty, and
/// free of dots and control characters.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SqlIdentifier(String);
10
11impl SqlIdentifier {
12    /// Creates an identifier segment from conservative SQL identifier text.
13    ///
14    /// # Errors
15    ///
16    /// Returns [`SqlIdentifierError`] when the value is empty, contains a dot,
17    /// or contains control characters.
18    pub fn new(input: impl AsRef<str>) -> Result<Self, SqlIdentifierError> {
19        validate_identifier_text(input.as_ref()).map(|value| Self(value.to_owned()))
20    }
21
22    /// Returns the stored identifier text.
23    #[must_use]
24    pub fn as_str(&self) -> &str {
25        &self.0
26    }
27
28    /// Consumes the identifier and returns the stored text.
29    #[must_use]
30    pub fn into_string(self) -> String {
31        self.0
32    }
33}
34
35impl AsRef<str> for SqlIdentifier {
36    fn as_ref(&self) -> &str {
37        self.as_str()
38    }
39}
40
41impl fmt::Display for SqlIdentifier {
42    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
43        formatter.write_str(self.as_str())
44    }
45}
46
47impl FromStr for SqlIdentifier {
48    type Err = SqlIdentifierError;
49
50    fn from_str(input: &str) -> Result<Self, Self::Err> {
51        Self::new(input)
52    }
53}
54
55impl TryFrom<&str> for SqlIdentifier {
56    type Error = SqlIdentifierError;
57
58    fn try_from(value: &str) -> Result<Self, Self::Error> {
59        Self::new(value)
60    }
61}
62
63/// A dot-qualified SQL name.
64#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
65pub struct SqlQualifiedName {
66    parts: Vec<SqlIdentifier>,
67}
68
69impl SqlQualifiedName {
70    /// Creates a qualified name from one or more identifier parts.
71    ///
72    /// # Errors
73    ///
74    /// Returns [`SqlIdentifierError::EmptyQualifiedName`] when `parts` is empty.
75    pub fn new(parts: Vec<SqlIdentifier>) -> Result<Self, SqlIdentifierError> {
76        if parts.is_empty() {
77            return Err(SqlIdentifierError::EmptyQualifiedName);
78        }
79
80        Ok(Self { parts })
81    }
82
83    /// Parses a dot-qualified name using conservative dot splitting.
84    ///
85    /// # Errors
86    ///
87    /// Returns [`SqlIdentifierError`] when any segment is invalid.
88    pub fn parse(input: &str) -> Result<Self, SqlIdentifierError> {
89        let trimmed = input.trim();
90        if trimmed.is_empty() {
91            return Err(SqlIdentifierError::EmptyQualifiedName);
92        }
93
94        let parts = trimmed
95            .split('.')
96            .map(SqlIdentifier::new)
97            .collect::<Result<Vec<_>, _>>()?;
98        Self::new(parts)
99    }
100
101    /// Returns the identifier parts.
102    #[must_use]
103    pub fn parts(&self) -> &[SqlIdentifier] {
104        &self.parts
105    }
106}
107
108impl fmt::Display for SqlQualifiedName {
109    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
110        let mut parts = self.parts.iter();
111        if let Some(first) = parts.next() {
112            write!(formatter, "{first}")?;
113        }
114        for part in parts {
115            write!(formatter, ".{part}")?;
116        }
117        Ok(())
118    }
119}
120
121impl FromStr for SqlQualifiedName {
122    type Err = SqlIdentifierError;
123
124    fn from_str(input: &str) -> Result<Self, Self::Err> {
125        Self::parse(input)
126    }
127}
128
129impl TryFrom<&str> for SqlQualifiedName {
130    type Error = SqlIdentifierError;
131
132    fn try_from(value: &str) -> Result<Self, Self::Error> {
133        Self::parse(value)
134    }
135}
136
137/// A SQL alias identifier.
138#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
139pub struct SqlAlias(SqlIdentifier);
140
141impl SqlAlias {
142    /// Creates an alias from identifier text.
143    ///
144    /// # Errors
145    ///
146    /// Returns [`SqlIdentifierError`] when validation fails.
147    pub fn new(input: impl AsRef<str>) -> Result<Self, SqlIdentifierError> {
148        SqlIdentifier::new(input).map(Self)
149    }
150
151    /// Returns the alias as an identifier.
152    #[must_use]
153    pub const fn identifier(&self) -> &SqlIdentifier {
154        &self.0
155    }
156
157    /// Returns the stored alias text.
158    #[must_use]
159    pub fn as_str(&self) -> &str {
160        self.0.as_str()
161    }
162}
163
164impl AsRef<str> for SqlAlias {
165    fn as_ref(&self) -> &str {
166        self.as_str()
167    }
168}
169
170impl fmt::Display for SqlAlias {
171    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
172        self.0.fmt(formatter)
173    }
174}
175
176impl FromStr for SqlAlias {
177    type Err = SqlIdentifierError;
178
179    fn from_str(input: &str) -> Result<Self, Self::Err> {
180        Self::new(input)
181    }
182}
183
184impl TryFrom<&str> for SqlAlias {
185    type Error = SqlIdentifierError;
186
187    fn try_from(value: &str) -> Result<Self, Self::Error> {
188        Self::new(value)
189    }
190}
191
/// Reason a piece of SQL identifier text was rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlIdentifierError {
    /// The supplied value was empty after trimming.
    Empty,
    /// An identifier segment contained `.`, which is reserved as the
    /// qualified-name separator.
    ContainsDot,
    /// A qualified name was requested with no segments at all.
    EmptyQualifiedName,
    /// A control character was found in the identifier text.
    ControlCharacter {
        /// Byte offset of the character within the trimmed segment that was
        /// validated (not within the caller's untrimmed input).
        index: usize,
        /// The offending character.
        character: char,
    },
}
209
210impl fmt::Display for SqlIdentifierError {
211    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
212        match self {
213            Self::Empty => formatter.write_str("SQL identifier cannot be empty"),
214            Self::ContainsDot => formatter.write_str("SQL identifier segment cannot contain a dot"),
215            Self::EmptyQualifiedName => formatter.write_str("SQL qualified name cannot be empty"),
216            Self::ControlCharacter { index, character } => write!(
217                formatter,
218                "SQL identifier contains control character {character:?} at byte index {index}"
219            ),
220        }
221    }
222}
223
224impl Error for SqlIdentifierError {}
225
226/// Returns `true` when `input` is conservatively valid as an unquoted SQL identifier.
227#[must_use]
228pub fn is_valid_unquoted_ident(input: &str) -> bool {
229    validate_unquoted_ident(input).is_ok()
230}
231
232/// Returns `true` when an identifier should be quoted for conservative SQL rendering.
233#[must_use]
234pub fn needs_quoting(input: &str) -> bool {
235    !is_valid_unquoted_ident(input)
236}
237
/// Wraps `input` in SQL double quotes, escaping each embedded `"` as `""`.
///
/// Leading and trailing whitespace is trimmed first. No other validation is
/// performed, so every other character (including control characters) is
/// copied through unchanged.
#[must_use]
pub fn quote_ident(input: &str) -> String {
    let escaped = input.trim().replace('"', "\"\"");
    format!("\"{escaped}\"")
}
253
254/// Normalizes an identifier for simple display-oriented comparisons.
255#[must_use]
256pub fn normalize_ident(input: &str) -> String {
257    let trimmed = input.trim();
258    if is_valid_unquoted_ident(trimmed) {
259        trimmed.to_ascii_lowercase()
260    } else {
261        quote_ident(trimmed)
262    }
263}
264
265fn validate_identifier_text(input: &str) -> Result<&str, SqlIdentifierError> {
266    let trimmed = input.trim();
267    if trimmed.is_empty() {
268        return Err(SqlIdentifierError::Empty);
269    }
270    if trimmed.contains('.') {
271        return Err(SqlIdentifierError::ContainsDot);
272    }
273    if let Some((index, character)) = trimmed
274        .char_indices()
275        .find(|(_, character)| character.is_control())
276    {
277        return Err(SqlIdentifierError::ControlCharacter { index, character });
278    }
279    Ok(trimmed)
280}
281
282fn validate_unquoted_ident(input: &str) -> Result<(), SqlIdentifierError> {
283    let trimmed = validate_identifier_text(input)?;
284    let mut characters = trimmed.chars();
285    let Some(first) = characters.next() else {
286        return Err(SqlIdentifierError::Empty);
287    };
288    if !(first == '_' || first.is_ascii_alphabetic()) {
289        return Err(SqlIdentifierError::Empty);
290    }
291    if !characters.all(|character| character == '_' || character.is_ascii_alphanumeric()) {
292        return Err(SqlIdentifierError::Empty);
293    }
294    if is_reserved_like(trimmed) {
295        return Err(SqlIdentifierError::Empty);
296    }
297    Ok(())
298}
299
/// Returns `true` when `input` is a keyword that should not be emitted as a
/// bare identifier.
///
/// The input is trimmed and compared ASCII case-insensitively.
///
/// The list is deliberately conservative. It covers statement keywords plus
/// words that are reserved, or commonly problematic, in widely used SQL
/// dialects. A false positive only costs an unnecessary pair of quotes. A false
/// negative can render invalid SQL, such as an unquoted `user` table or an `as`
/// column.
fn is_reserved_like(input: &str) -> bool {
    // Kept in alphabetical order so additions are easy to review.
    const RESERVED_LIKE: &[&str] = &[
        "ALL", "ALTER", "AND", "AS", "ASC", "BETWEEN", "BY", "CASE", "CHECK", "COLUMN",
        "CONSTRAINT", "CREATE", "CROSS", "DEFAULT", "DELETE", "DESC", "DISTINCT", "DROP",
        "ELSE", "END", "EXCEPT", "EXISTS", "FALSE", "FOR", "FOREIGN", "FROM", "FULL", "GRANT",
        "GROUP", "HAVING", "IN", "INDEX", "INNER", "INSERT", "INTERSECT", "INTO", "IS", "JOIN",
        "KEY", "LEFT", "LIKE", "LIMIT", "NOT", "NULL", "OFFSET", "ON", "OR", "ORDER", "OUTER",
        "PRIMARY", "REFERENCES", "RETURNING", "RIGHT", "SELECT", "SET", "TABLE", "THEN", "TO",
        "TRUE", "UNION", "UNIQUE", "UPDATE", "USER", "USING", "VALUES", "VIEW", "WHEN",
        "WHERE", "WITH",
    ];

    let candidate = input.trim();
    // `eq_ignore_ascii_case` avoids allocating an uppercased copy per call.
    RESERVED_LIKE
        .iter()
        .any(|keyword| keyword.eq_ignore_ascii_case(candidate))
}
331
#[cfg(test)]
mod tests {
    use super::{
        SqlAlias, SqlIdentifier, SqlIdentifierError, SqlQualifiedName, is_valid_unquoted_ident,
        needs_quoting, normalize_ident, quote_ident,
    };

    #[test]
    fn validates_identifier_text() -> Result<(), SqlIdentifierError> {
        let identifier = SqlIdentifier::new(" users ")?;
        assert_eq!(identifier.as_str(), "users");
        assert_eq!(SqlIdentifier::new(""), Err(SqlIdentifierError::Empty));
        assert_eq!(SqlIdentifier::new("   "), Err(SqlIdentifierError::Empty));
        assert_eq!(
            SqlIdentifier::new("public.users"),
            Err(SqlIdentifierError::ContainsDot)
        );
        Ok(())
    }

    #[test]
    fn reports_control_characters_relative_to_trimmed_text() {
        // The leading space is trimmed before scanning, so BEL sits at byte 1.
        let rejected = SqlIdentifier::new(" a\u{7}b");
        assert_eq!(
            rejected,
            Err(SqlIdentifierError::ControlCharacter {
                index: 1,
                character: '\u{7}',
            })
        );
        let message = SqlIdentifierError::ControlCharacter {
            index: 1,
            character: '\u{7}',
        }
        .to_string();
        assert!(message.contains("byte index 1"));
    }

    #[test]
    fn parsing_traits_match_constructor() -> Result<(), SqlIdentifierError> {
        let parsed: SqlIdentifier = "users".parse()?;
        assert_eq!(parsed, SqlIdentifier::try_from("users")?);
        assert_eq!(parsed.to_string(), "users");
        assert_eq!(parsed.into_string(), "users");
        Ok(())
    }

    #[test]
    fn builds_aliases() -> Result<(), SqlIdentifierError> {
        let alias = SqlAlias::new(" total ")?;
        assert_eq!(alias.as_str(), "total");
        assert_eq!(alias.identifier().as_str(), "total");
        assert_eq!(alias.to_string(), "total");
        assert_eq!(SqlAlias::new("a.b"), Err(SqlIdentifierError::ContainsDot));
        Ok(())
    }

    #[test]
    fn checks_unquoted_identifiers() {
        assert!(is_valid_unquoted_ident("users_1"));
        assert!(is_valid_unquoted_ident("_private"));
        assert!(!is_valid_unquoted_ident("1users"));
        assert!(!is_valid_unquoted_ident("select"));
        assert!(!is_valid_unquoted_ident("SELECT"));
        assert!(!is_valid_unquoted_ident("caf\u{e9}"));
        assert!(needs_quoting("order items"));
        assert!(!needs_quoting("orders"));
    }

    #[test]
    fn quotes_and_normalizes_identifiers() {
        assert_eq!(quote_ident("user\"name"), "\"user\"\"name\"");
        assert_eq!(quote_ident("  spaced name "), "\"spaced name\"");
        assert_eq!(normalize_ident("Users"), "users");
        assert_eq!(normalize_ident("select"), "\"select\"");
        assert_eq!(normalize_ident(" Order Items "), "\"Order Items\"");
    }

    #[test]
    fn parses_qualified_names() -> Result<(), SqlIdentifierError> {
        let qualified = SqlQualifiedName::parse("public.users")?;
        assert_eq!(qualified.parts().len(), 2);
        assert_eq!(qualified.to_string(), "public.users");
        assert_eq!(
            SqlQualifiedName::parse(" public . users ")?.to_string(),
            "public.users"
        );
        assert!(SqlQualifiedName::parse("public.").is_err());
        assert_eq!(
            SqlQualifiedName::parse("   "),
            Err(SqlIdentifierError::EmptyQualifiedName)
        );
        assert_eq!(
            SqlQualifiedName::new(Vec::new()),
            Err(SqlIdentifierError::EmptyQualifiedName)
        );
        Ok(())
    }
}