Skip to main content

hyperdb_api_core/protocol/
escape.rs

1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! SQL escaping utilities.
5//!
6//! This module provides zero-cost wrapper types for safe SQL escaping.
7//! Using the newtype pattern with [`std::fmt::Display`] ensures identifiers
8//! and literals are properly escaped at format-time without extra allocations.
9//!
10//! # Why Newtype + Display?
11//!
12//! The alternative -- a function like `fn escape_identifier(s: &str) -> String`
13//! -- allocates immediately even when the result is only used inside a larger
14//! `format!()` call. The newtype pattern defers escaping to `Display::fmt`,
15//! so the escaped output is written directly into the destination buffer.
16//! This is the same approach used by `std::path::Path::display()`.
17//!
18//! The convenience functions [`escape_identifier`] and [`escape_literal`] are
19//! provided for cases where a `String` is needed directly.
20
21use std::fmt;
22
23/// A wrapper that ensures a SQL identifier is properly escaped when formatted.
24///
25/// This is a zero-cost abstraction that performs escaping lazily during formatting.
26/// Identifiers are conditionally quoted:
27/// - Simple lowercase identifiers (`users`, `my_table`) are not quoted
28/// - Identifiers with uppercase letters are quoted to preserve case
29/// - Identifiers with special characters are quoted
30///
31/// # Example
32///
33/// ```
34/// use hyperdb_api_core::protocol::escape::SqlIdentifier;
35///
36/// // Simple identifiers are not quoted
37/// assert_eq!(format!("{}", SqlIdentifier("users")), "users");
38/// assert_eq!(format!("{}", SqlIdentifier("my_table")), "my_table");
39///
40/// // Uppercase letters are quoted to preserve case
41/// assert_eq!(format!("{}", SqlIdentifier("Segment")), "\"Segment\"");
42///
43/// // Special characters require quoting
44/// assert_eq!(format!("{}", SqlIdentifier("my-table")), "\"my-table\"");
45/// assert_eq!(format!("{}", SqlIdentifier("my table")), "\"my table\"");
46///
47/// // Internal quotes are escaped
48/// assert_eq!(format!("{}", SqlIdentifier("my\"table")), "\"my\"\"table\"");
49/// ```
50#[derive(Debug, Clone, Copy)]
51pub struct SqlIdentifier<'a>(pub &'a str);
52
53impl fmt::Display for SqlIdentifier<'_> {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        // Check if identifier needs quoting:
56        // 1. Not a valid unquoted identifier (has spaces, hyphens, etc.)
57        // 2. Contains uppercase letters (to preserve case - PostgreSQL case-folds unquoted identifiers)
58        let needs_quoting =
59            !is_valid_unquoted_identifier(self.0) || self.0.chars().any(char::is_uppercase);
60
61        if needs_quoting {
62            f.write_str("\"")?;
63            for c in self.0.chars() {
64                if c == '"' {
65                    f.write_str("\"\"")?;
66                } else {
67                    write!(f, "{c}")?;
68                }
69            }
70            f.write_str("\"")
71        } else {
72            f.write_str(self.0)
73        }
74    }
75}
76
77/// A wrapper that ensures a SQL string literal is properly escaped when formatted.
78///
79/// This wraps the string in single quotes and escapes any internal single quotes.
80///
81/// # Example
82///
83/// ```no_run
84/// // Marked `no_run` to dodge a Windows Defender heuristic that intermittently
85/// // refuses to launch this specific compiled doctest binary with
86/// // `ERROR_ACCESS_DENIED`. The same assertions are exercised by
87/// // `tests::test_sql_literal_display` so coverage is preserved.
88/// use hyperdb_api_core::protocol::escape::SqlLiteral;
89///
90/// assert_eq!(format!("{}", SqlLiteral("hello")), "'hello'");
91/// assert_eq!(format!("{}", SqlLiteral("it's")), "'it''s'");
92/// ```
#[derive(Debug, Clone, Copy)]
pub struct SqlLiteral<'a>(pub &'a str);

impl fmt::Display for SqlLiteral<'_> {
    /// Writes the wrapped text as a single-quoted SQL string literal.
    ///
    /// Every `'` inside the text is written as `''`; all other characters are
    /// copied unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("'")?;
        // Splitting on `'` yields the runs of text between quotes, so a doubled
        // quote is written before every segment except the first.
        for (index, segment) in self.0.split('\'').enumerate() {
            if index > 0 {
                f.write_str("''")?;
            }
            f.write_str(segment)?;
        }
        f.write_str("'")
    }
}
109
/// Reports whether `s` can be written in SQL without surrounding double quotes.
///
/// The check is purely lexical:
/// - the empty string is rejected
/// - the first character must be an ASCII letter (a-z, A-Z) or an underscore
/// - every following character must be an ASCII letter, an ASCII digit (0-9),
///   an underscore, or a dollar sign
///
/// SQL reserved words are not detected by this function, so a keyword that
/// satisfies the rules above is still reported as valid.
#[must_use]
pub fn is_valid_unquoted_identifier(s: &str) -> bool {
    // Inspecting bytes gives the same answer as inspecting chars: every
    // accepted character is ASCII, and each byte of a multi-byte UTF-8
    // sequence is >= 0x80, so it fails the ASCII checks below.
    match s.as_bytes().split_first() {
        None => false,
        Some((&lead, rest)) => {
            let lead_ok = lead == b'_' || lead.is_ascii_alphabetic();
            lead_ok
                && rest
                    .iter()
                    .all(|&byte| matches!(byte, b'_' | b'$') || byte.is_ascii_alphanumeric())
        }
    }
}
133
134/// Formats a qualified table name with proper escaping.
135///
136/// # Arguments
137///
138/// * `database` - Optional database name
139/// * `schema` - Optional schema name
140/// * `table` - Table name
141///
142/// # Example
143///
144/// ```
145/// use hyperdb_api_core::protocol::escape::format_table_name;
146///
147/// assert_eq!(format_table_name(None, None, "users"), "users");
148/// assert_eq!(format_table_name(None, Some("public"), "users"), "public.users");
149/// assert_eq!(format_table_name(Some("mydb"), Some("public"), "users"), "mydb.public.users");
150/// assert_eq!(format_table_name(None, None, "my-table"), "\"my-table\"");
151/// ```
152#[must_use]
153pub fn format_table_name(database: Option<&str>, schema: Option<&str>, table: &str) -> String {
154    match (database, schema) {
155        (Some(db), Some(s)) => format!(
156            "{}.{}.{}",
157            SqlIdentifier(db),
158            SqlIdentifier(s),
159            SqlIdentifier(table)
160        ),
161        (None, Some(s)) => format!("{}.{}", SqlIdentifier(s), SqlIdentifier(table)),
162        (Some(db), None) => format!("{}.{}", SqlIdentifier(db), SqlIdentifier(table)),
163        (None, None) => format!("{}", SqlIdentifier(table)),
164    }
165}
166
167// Backward compatibility functions - can be removed if not needed externally
168
169/// Escapes a SQL identifier (table name, column name, etc.).
170///
171/// This is a convenience function that returns the escaped identifier as a String.
172/// For more efficient formatting, use `SqlIdentifier` directly in format strings.
173///
174/// # Example
175///
176/// ```
177/// use hyperdb_api_core::protocol::escape::escape_identifier;
178///
179/// assert_eq!(escape_identifier("table"), "table");
180/// assert_eq!(escape_identifier("Segment"), "\"Segment\"");
181/// ```
182#[must_use]
183pub fn escape_identifier(identifier: &str) -> String {
184    format!("{}", SqlIdentifier(identifier))
185}
186
187/// Escapes a SQL string literal.
188///
189/// This is a convenience function that returns the escaped literal as a String.
190/// For more efficient formatting, use `SqlLiteral` directly in format strings.
191///
192/// # Example
193///
194/// ```
195/// use hyperdb_api_core::protocol::escape::escape_literal;
196///
197/// assert_eq!(escape_literal("hello"), "'hello'");
198/// assert_eq!(escape_literal("it's"), "'it''s'");
199/// ```
200#[must_use]
201pub fn escape_literal(literal: &str) -> String {
202    format!("{}", SqlLiteral(literal))
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `name` through `SqlIdentifier`'s `Display` impl.
    fn ident(name: &str) -> String {
        SqlIdentifier(name).to_string()
    }

    /// Renders `text` through `SqlLiteral`'s `Display` impl.
    fn literal(text: &str) -> String {
        SqlLiteral(text).to_string()
    }

    #[test]
    fn test_sql_identifier_display() {
        // Lowercase names that are valid without quotes come out unchanged.
        for bare in ["table", "my_table", "table1", "_private", "my$var"] {
            assert_eq!(ident(bare), bare);
        }

        // Any uppercase letter forces quoting so the case is preserved.
        for mixed_case in ["Segment", "CustomerID", "Table"] {
            assert_eq!(ident(mixed_case), format!("\"{mixed_case}\""));
        }

        // Names that are not valid bare identifiers are quoted, and embedded
        // double quotes are doubled.
        let must_quote = [
            ("my-table", "\"my-table\""),
            ("my table", "\"my table\""),
            ("1table", "\"1table\""),
            ("my\"table", "\"my\"\"table\""),
            ("", "\"\""),
        ];
        for (input, expected) in must_quote {
            assert_eq!(ident(input), expected);
        }
    }

    #[test]
    fn test_sql_literal_display() {
        let cases = [("hello", "'hello'"), ("it's", "'it''s'"), ("", "''")];
        for (input, expected) in cases {
            assert_eq!(literal(input), expected);
        }
    }

    #[test]
    fn test_is_valid_unquoted_identifier() {
        for accepted in ["table", "_private", "table1", "my$var"] {
            assert!(is_valid_unquoted_identifier(accepted));
        }

        for rejected in ["", "1table", "my-table", "my table"] {
            assert!(!is_valid_unquoted_identifier(rejected));
        }
    }

    #[test]
    fn test_format_table_name() {
        // (database, schema, table, expected); the last two rows cover names
        // that need quoting.
        let cases = [
            (None, None, "users", "users"),
            (None, Some("public"), "users", "public.users"),
            (Some("mydb"), Some("public"), "users", "mydb.public.users"),
            (None, None, "my-table", "\"my-table\""),
            (None, Some("my schema"), "users", "\"my schema\".users"),
        ];
        for (database, schema, table, expected) in cases {
            assert_eq!(format_table_name(database, schema, table), expected);
        }
    }

    #[test]
    fn test_sql_identifier_in_format() {
        // The wrappers compose directly inside a larger format string.
        let table = SqlIdentifier("users");
        let column = SqlIdentifier("Customer ID");
        let sql = format!("SELECT {column} FROM {table}");
        assert_eq!(sql, "SELECT \"Customer ID\" FROM users");
    }

    // Backward compat function tests
    #[test]
    fn test_escape_identifier() {
        let cases = [("table", "table"), ("Segment", "\"Segment\"")];
        for (input, expected) in cases {
            assert_eq!(escape_identifier(input), expected);
        }
    }

    #[test]
    fn test_escape_literal() {
        let cases = [("hello", "'hello'"), ("it's", "'it''s'")];
        for (input, expected) in cases {
            assert_eq!(escape_literal(input), expected);
        }
    }
}