Skip to main content

use_sql_param/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7use use_sql_ident::{SqlIdentifier, SqlIdentifierError, is_valid_unquoted_ident};
8
/// Placeholder syntaxes reported by [`SqlParameter::style`].
///
/// [`SqlParameterStyle::PostgresIndexed`] is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum SqlParameterStyle {
    /// One-based indexed placeholder in PostgreSQL form, e.g. `$1`.
    #[default]
    PostgresIndexed,
    /// Positional placeholder written as a bare `?`.
    PositionalQuestionMark,
    /// Named placeholder introduced by a colon, e.g. `:user_id`.
    NamedColon,
    /// Named placeholder introduced by an at-sign, e.g. `@user_id`.
    NamedAtSign,
}
18
/// One-based position of an indexed placeholder such as `$1`.
///
/// The field is private, so code outside this module can only obtain a value
/// through [`SqlParameterIndex::new`], which rejects zero.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SqlParameterIndex(u32);
22
23impl SqlParameterIndex {
24    /// Creates a one-based parameter index.
25    ///
26    /// # Errors
27    ///
28    /// Returns [`SqlParameterError::ZeroIndex`] when `index` is zero.
29    pub const fn new(index: u32) -> Result<Self, SqlParameterError> {
30        if index == 0 {
31            Err(SqlParameterError::ZeroIndex)
32        } else {
33            Ok(Self(index))
34        }
35    }
36
37    /// Returns the one-based parameter index.
38    #[must_use]
39    pub const fn get(self) -> u32 {
40        self.0
41    }
42}
43
44impl fmt::Display for SqlParameterIndex {
45    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
46        write!(formatter, "{}", self.0)
47    }
48}
49
50/// A named SQL parameter identifier.
51#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
52pub struct SqlParameterName(SqlIdentifier);
53
54impl SqlParameterName {
55    /// Creates a named SQL parameter.
56    ///
57    /// # Errors
58    ///
59    /// Returns [`SqlParameterError`] when the parameter name is empty or not
60    /// conservatively unquoted-identifier-shaped.
61    pub fn new(input: impl AsRef<str>) -> Result<Self, SqlParameterError> {
62        let input = input.as_ref();
63        if !is_valid_unquoted_ident(input) {
64            return Err(SqlParameterError::InvalidName);
65        }
66
67        SqlIdentifier::new(input)
68            .map(Self)
69            .map_err(SqlParameterError::Identifier)
70    }
71
72    /// Returns the parameter name.
73    #[must_use]
74    pub fn as_str(&self) -> &str {
75        self.0.as_str()
76    }
77}
78
79impl AsRef<str> for SqlParameterName {
80    fn as_ref(&self) -> &str {
81        self.as_str()
82    }
83}
84
85impl fmt::Display for SqlParameterName {
86    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
87        self.0.fmt(formatter)
88    }
89}
90
91impl FromStr for SqlParameterName {
92    type Err = SqlParameterError;
93
94    fn from_str(input: &str) -> Result<Self, Self::Err> {
95        Self::new(input)
96    }
97}
98
99/// SQL parameter placeholders.
100#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
101pub enum SqlParameter {
102    PostgresIndexed(SqlParameterIndex),
103    PositionalQuestionMark,
104    NamedColon(SqlParameterName),
105    NamedAtSign(SqlParameterName),
106}
107
108impl SqlParameter {
109    /// Creates a PostgreSQL-style indexed parameter such as `$1`.
110    ///
111    /// # Errors
112    ///
113    /// Returns [`SqlParameterError::ZeroIndex`] when `index` is zero.
114    pub const fn postgres_indexed(index: u32) -> Result<Self, SqlParameterError> {
115        match SqlParameterIndex::new(index) {
116            Ok(index) => Ok(Self::PostgresIndexed(index)),
117            Err(error) => Err(error),
118        }
119    }
120
121    /// Creates a positional question-mark parameter.
122    #[must_use]
123    pub const fn positional() -> Self {
124        Self::PositionalQuestionMark
125    }
126
127    /// Creates a colon-prefixed named parameter.
128    ///
129    /// # Errors
130    ///
131    /// Returns [`SqlParameterError`] when name validation fails.
132    pub fn named_colon(name: impl AsRef<str>) -> Result<Self, SqlParameterError> {
133        SqlParameterName::new(name).map(Self::NamedColon)
134    }
135
136    /// Creates an at-sign-prefixed named parameter.
137    ///
138    /// # Errors
139    ///
140    /// Returns [`SqlParameterError`] when name validation fails.
141    pub fn named_at(name: impl AsRef<str>) -> Result<Self, SqlParameterError> {
142        SqlParameterName::new(name).map(Self::NamedAtSign)
143    }
144
145    /// Returns the placeholder style.
146    #[must_use]
147    pub const fn style(&self) -> SqlParameterStyle {
148        match self {
149            Self::PostgresIndexed(_) => SqlParameterStyle::PostgresIndexed,
150            Self::PositionalQuestionMark => SqlParameterStyle::PositionalQuestionMark,
151            Self::NamedColon(_) => SqlParameterStyle::NamedColon,
152            Self::NamedAtSign(_) => SqlParameterStyle::NamedAtSign,
153        }
154    }
155}
156
157impl fmt::Display for SqlParameter {
158    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
159        match self {
160            Self::PostgresIndexed(index) => write!(formatter, "${index}"),
161            Self::PositionalQuestionMark => formatter.write_str("?"),
162            Self::NamedColon(name) => write!(formatter, ":{name}"),
163            Self::NamedAtSign(name) => write!(formatter, "@{name}"),
164        }
165    }
166}
167
168impl FromStr for SqlParameter {
169    type Err = SqlParameterError;
170
171    fn from_str(input: &str) -> Result<Self, Self::Err> {
172        let trimmed = input.trim();
173        if trimmed.is_empty() {
174            return Err(SqlParameterError::Empty);
175        }
176        if trimmed == "?" {
177            return Ok(Self::positional());
178        }
179        if let Some(index) = trimmed.strip_prefix('$') {
180            if index.is_empty() || !index.chars().all(|character| character.is_ascii_digit()) {
181                return Err(SqlParameterError::InvalidIndexed);
182            }
183            let index = index
184                .parse::<u32>()
185                .map_err(|_| SqlParameterError::InvalidIndexed)?;
186            return Self::postgres_indexed(index);
187        }
188        if let Some(name) = trimmed.strip_prefix(':') {
189            return Self::named_colon(name);
190        }
191        if let Some(name) = trimmed.strip_prefix('@') {
192            return Self::named_at(name);
193        }
194        Err(SqlParameterError::UnknownStyle)
195    }
196}
197
198impl TryFrom<&str> for SqlParameter {
199    type Error = SqlParameterError;
200
201    fn try_from(value: &str) -> Result<Self, Self::Error> {
202        value.parse()
203    }
204}
205
206/// Error returned when SQL parameter placeholders are invalid.
207#[derive(Clone, Debug, Eq, PartialEq)]
208pub enum SqlParameterError {
209    Empty,
210    ZeroIndex,
211    InvalidIndexed,
212    InvalidName,
213    UnknownStyle,
214    Identifier(SqlIdentifierError),
215}
216
217impl fmt::Display for SqlParameterError {
218    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
219        match self {
220            Self::Empty => formatter.write_str("SQL parameter placeholder cannot be empty"),
221            Self::ZeroIndex => formatter.write_str("SQL parameter indexes are one-based"),
222            Self::InvalidIndexed => formatter.write_str("invalid indexed SQL parameter"),
223            Self::InvalidName => formatter.write_str("invalid SQL parameter name"),
224            Self::UnknownStyle => formatter.write_str("unknown SQL parameter placeholder style"),
225            Self::Identifier(error) => {
226                write!(formatter, "invalid SQL parameter identifier: {error}")
227            },
228        }
229    }
230}
231
232impl Error for SqlParameterError {}
233
#[cfg(test)]
mod tests {
    use super::{SqlParameter, SqlParameterError, SqlParameterStyle};

    #[test]
    fn parses_parameter_styles() -> Result<(), SqlParameterError> {
        assert_eq!("$1".parse::<SqlParameter>()?.to_string(), "$1");
        assert_eq!(
            "?".parse::<SqlParameter>()?.style(),
            SqlParameterStyle::PositionalQuestionMark
        );
        assert_eq!(":user_id".parse::<SqlParameter>()?.to_string(), ":user_id");
        assert_eq!("@user_id".parse::<SqlParameter>()?.to_string(), "@user_id");
        Ok(())
    }

    /// Whitespace is trimmed, leading zeros normalise away, and `TryFrom`
    /// agrees with the dedicated constructor.
    #[test]
    fn trims_and_normalises_indexed_parameters() -> Result<(), SqlParameterError> {
        assert_eq!("  $2\t".parse::<SqlParameter>()?.to_string(), "$2");
        assert_eq!("$007".parse::<SqlParameter>()?.to_string(), "$7");
        assert_eq!(SqlParameter::try_from("$3")?, SqlParameter::postgres_indexed(3)?);
        assert_eq!(SqlParameter::positional().to_string(), "?");
        Ok(())
    }

    #[test]
    fn rejects_invalid_parameters() {
        assert_eq!(
            "$0".parse::<SqlParameter>(),
            Err(SqlParameterError::ZeroIndex)
        );
        assert_eq!(
            "$abc".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidIndexed)
        );
        assert_eq!(
            ":select".parse::<SqlParameter>(),
            Err(SqlParameterError::InvalidName)
        );
    }

    /// Edge cases of the hand-written parser: empty input, malformed `$`
    /// indexes (including the `+` sign that `u32::from_str` alone would
    /// accept), overflow, and unrecognised prefixes.
    #[test]
    fn rejects_malformed_placeholders() {
        let cases = [
            ("", SqlParameterError::Empty),
            ("   ", SqlParameterError::Empty),
            ("$", SqlParameterError::InvalidIndexed),
            ("$+1", SqlParameterError::InvalidIndexed),
            ("$-1", SqlParameterError::InvalidIndexed),
            ("$1a", SqlParameterError::InvalidIndexed),
            ("$4294967296", SqlParameterError::InvalidIndexed),
            ("??", SqlParameterError::UnknownStyle),
            ("#1", SqlParameterError::UnknownStyle),
            ("user_id", SqlParameterError::UnknownStyle),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<SqlParameter>(), Err(expected), "input: {input:?}");
        }
    }
}