hyperdb_api_core/protocol/escape.rs
1// Copyright (c) 2026, Salesforce, Inc. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! SQL escaping utilities.
5//!
6//! This module provides zero-cost wrapper types for safe SQL escaping.
7//! Using the newtype pattern with [`std::fmt::Display`] ensures identifiers
8//! and literals are properly escaped at format-time without extra allocations.
9//!
10//! # Why Newtype + Display?
11//!
12//! The alternative -- a function like `fn escape_identifier(s: &str) -> String`
13//! -- allocates immediately even when the result is only used inside a larger
14//! `format!()` call. The newtype pattern defers escaping to `Display::fmt`,
15//! so the escaped output is written directly into the destination buffer.
16//! This is the same approach used by `std::path::Path::display()`.
17//!
18//! The convenience functions [`escape_identifier`] and [`escape_literal`] are
19//! provided for cases where a `String` is needed directly.
20
21use std::fmt;
22
23/// A wrapper that ensures a SQL identifier is properly escaped when formatted.
24///
25/// This is a zero-cost abstraction that performs escaping lazily during formatting.
26/// Identifiers are conditionally quoted:
27/// - Simple lowercase identifiers (`users`, `my_table`) are not quoted
28/// - Identifiers with uppercase letters are quoted to preserve case
29/// - Identifiers with special characters are quoted
30///
31/// # Example
32///
33/// ```
34/// use hyperdb_api_core::protocol::escape::SqlIdentifier;
35///
36/// // Simple identifiers are not quoted
37/// assert_eq!(format!("{}", SqlIdentifier("users")), "users");
38/// assert_eq!(format!("{}", SqlIdentifier("my_table")), "my_table");
39///
40/// // Uppercase letters are quoted to preserve case
41/// assert_eq!(format!("{}", SqlIdentifier("Segment")), "\"Segment\"");
42///
43/// // Special characters require quoting
44/// assert_eq!(format!("{}", SqlIdentifier("my-table")), "\"my-table\"");
45/// assert_eq!(format!("{}", SqlIdentifier("my table")), "\"my table\"");
46///
47/// // Internal quotes are escaped
48/// assert_eq!(format!("{}", SqlIdentifier("my\"table")), "\"my\"\"table\"");
49/// ```
50#[derive(Debug, Clone, Copy)]
51pub struct SqlIdentifier<'a>(pub &'a str);
52
53impl fmt::Display for SqlIdentifier<'_> {
54 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55 // Check if identifier needs quoting:
56 // 1. Not a valid unquoted identifier (has spaces, hyphens, etc.)
57 // 2. Contains uppercase letters (to preserve case - PostgreSQL case-folds unquoted identifiers)
58 let needs_quoting =
59 !is_valid_unquoted_identifier(self.0) || self.0.chars().any(char::is_uppercase);
60
61 if needs_quoting {
62 f.write_str("\"")?;
63 for c in self.0.chars() {
64 if c == '"' {
65 f.write_str("\"\"")?;
66 } else {
67 write!(f, "{c}")?;
68 }
69 }
70 f.write_str("\"")
71 } else {
72 f.write_str(self.0)
73 }
74 }
75}
76
/// A wrapper that ensures a SQL string literal is properly escaped when formatted.
///
/// The string is wrapped in single quotes and every embedded single quote is
/// doubled (`'` becomes `''`). No other character is altered.
///
/// NOTE(review): backslashes are emitted verbatim. That is only correct when the
/// server treats `'...'` literals as standard-conforming (no backslash escape
/// sequences) — confirm this holds for every connection that uses this type.
///
/// # Example
///
/// ```no_run
/// // Marked `no_run` to dodge a Windows Defender heuristic that intermittently
/// // refuses to launch this specific compiled doctest binary with
/// // `ERROR_ACCESS_DENIED`. The same assertions are exercised by
/// // `tests::test_sql_literal_display` so coverage is preserved.
/// use hyperdb_api_core::protocol::escape::SqlLiteral;
///
/// assert_eq!(format!("{}", SqlLiteral("hello")), "'hello'");
/// assert_eq!(format!("{}", SqlLiteral("it's")), "'it''s'");
/// ```
#[derive(Debug, Clone, Copy)]
pub struct SqlLiteral<'a>(pub &'a str);

impl fmt::Display for SqlLiteral<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("'")?;
        // Emit each run between embedded single quotes as one slice rather than
        // formatting char-by-char; the quote that separated the runs is doubled.
        for (i, run) in self.0.split('\'').enumerate() {
            if i > 0 {
                f.write_str("''")?;
            }
            f.write_str(run)?;
        }
        f.write_str("'")
    }
}
109
/// Reports whether `s` can be written as a SQL identifier without quotes.
///
/// The accepted form is ASCII-only:
/// - the first character is an ASCII letter (`a-z`, `A-Z`) or an underscore;
/// - every following character is an ASCII letter, an ASCII digit (`0-9`),
///   an underscore, or a dollar sign.
///
/// The empty string is rejected, as is any string containing a non-ASCII character.
///
/// This is a purely lexical check: uppercase letters are accepted, and SQL
/// reserved words are **not** rejected.
#[must_use]
pub fn is_valid_unquoted_identifier(s: &str) -> bool {
    // Working on bytes is equivalent to working on chars here: every byte of a
    // non-ASCII character is >= 0x80 and therefore fails all the ASCII tests.
    match s.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            let valid_start = head == b'_' || head.is_ascii_alphabetic();
            valid_start
                && tail
                    .iter()
                    .all(|&b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'$'))
        }
    }
}
133
134/// Formats a qualified table name with proper escaping.
135///
136/// # Arguments
137///
138/// * `database` - Optional database name
139/// * `schema` - Optional schema name
140/// * `table` - Table name
141///
142/// # Example
143///
144/// ```
145/// use hyperdb_api_core::protocol::escape::format_table_name;
146///
147/// assert_eq!(format_table_name(None, None, "users"), "users");
148/// assert_eq!(format_table_name(None, Some("public"), "users"), "public.users");
149/// assert_eq!(format_table_name(Some("mydb"), Some("public"), "users"), "mydb.public.users");
150/// assert_eq!(format_table_name(None, None, "my-table"), "\"my-table\"");
151/// ```
152#[must_use]
153pub fn format_table_name(database: Option<&str>, schema: Option<&str>, table: &str) -> String {
154 match (database, schema) {
155 (Some(db), Some(s)) => format!(
156 "{}.{}.{}",
157 SqlIdentifier(db),
158 SqlIdentifier(s),
159 SqlIdentifier(table)
160 ),
161 (None, Some(s)) => format!("{}.{}", SqlIdentifier(s), SqlIdentifier(table)),
162 (Some(db), None) => format!("{}.{}", SqlIdentifier(db), SqlIdentifier(table)),
163 (None, None) => format!("{}", SqlIdentifier(table)),
164 }
165}
166
// Convenience functions that return an owned `String`. They are public and kept for
// backward compatibility, so check for external callers before removing them.
168
169/// Escapes a SQL identifier (table name, column name, etc.).
170///
171/// This is a convenience function that returns the escaped identifier as a String.
172/// For more efficient formatting, use `SqlIdentifier` directly in format strings.
173///
174/// # Example
175///
176/// ```
177/// use hyperdb_api_core::protocol::escape::escape_identifier;
178///
179/// assert_eq!(escape_identifier("table"), "table");
180/// assert_eq!(escape_identifier("Segment"), "\"Segment\"");
181/// ```
182#[must_use]
183pub fn escape_identifier(identifier: &str) -> String {
184 format!("{}", SqlIdentifier(identifier))
185}
186
187/// Escapes a SQL string literal.
188///
189/// This is a convenience function that returns the escaped literal as a String.
190/// For more efficient formatting, use `SqlLiteral` directly in format strings.
191///
192/// # Example
193///
194/// ```
195/// use hyperdb_api_core::protocol::escape::escape_literal;
196///
197/// assert_eq!(escape_literal("hello"), "'hello'");
198/// assert_eq!(escape_literal("it's"), "'it''s'");
199/// ```
200#[must_use]
201pub fn escape_literal(literal: &str) -> String {
202 format!("{}", SqlLiteral(literal))
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sql_identifier_display() {
        // Valid unquoted identifiers with only lowercase should not be quoted
        assert_eq!(format!("{}", SqlIdentifier("table")), "table");
        assert_eq!(format!("{}", SqlIdentifier("my_table")), "my_table");
        assert_eq!(format!("{}", SqlIdentifier("table1")), "table1");
        assert_eq!(format!("{}", SqlIdentifier("_private")), "_private");
        assert_eq!(format!("{}", SqlIdentifier("my$var")), "my$var");

        // Identifiers with uppercase letters should be quoted to preserve case
        assert_eq!(format!("{}", SqlIdentifier("Segment")), "\"Segment\"");
        assert_eq!(format!("{}", SqlIdentifier("CustomerID")), "\"CustomerID\"");
        assert_eq!(format!("{}", SqlIdentifier("Table")), "\"Table\"");

        // Invalid unquoted identifiers should be quoted
        assert_eq!(format!("{}", SqlIdentifier("my-table")), "\"my-table\"");
        assert_eq!(format!("{}", SqlIdentifier("my table")), "\"my table\"");
        assert_eq!(format!("{}", SqlIdentifier("1table")), "\"1table\"");
        assert_eq!(format!("{}", SqlIdentifier("my\"table")), "\"my\"\"table\"");
        assert_eq!(format!("{}", SqlIdentifier("")), "\"\"");
    }

    #[test]
    fn test_sql_identifier_quote_edge_cases() {
        // Consecutive quotes and quotes at either edge are each doubled.
        assert_eq!(format!("{}", SqlIdentifier("a\"\"b")), "\"a\"\"\"\"b\"");
        assert_eq!(format!("{}", SqlIdentifier("\"")), "\"\"\"\"");
        assert_eq!(format!("{}", SqlIdentifier("\"x\"")), "\"\"\"x\"\"\"");
        // Non-ASCII characters are never accepted unquoted.
        assert_eq!(format!("{}", SqlIdentifier("café")), "\"café\"");
    }

    #[test]
    fn test_sql_literal_display() {
        assert_eq!(format!("{}", SqlLiteral("hello")), "'hello'");
        assert_eq!(format!("{}", SqlLiteral("it's")), "'it''s'");
        assert_eq!(format!("{}", SqlLiteral("")), "''");
    }

    #[test]
    fn test_sql_literal_edge_cases() {
        // Consecutive quotes and quotes at either edge are each doubled.
        assert_eq!(format!("{}", SqlLiteral("''")), "''''''");
        assert_eq!(format!("{}", SqlLiteral("'a'")), "'''a'''");
        // Only single quotes are escaped; everything else passes through verbatim.
        assert_eq!(format!("{}", SqlLiteral("a\\b\"c")), "'a\\b\"c'");
        assert_eq!(format!("{}", SqlLiteral("héllo")), "'héllo'");
    }

    #[test]
    fn test_is_valid_unquoted_identifier() {
        assert!(is_valid_unquoted_identifier("table"));
        assert!(is_valid_unquoted_identifier("_private"));
        assert!(is_valid_unquoted_identifier("table1"));
        assert!(is_valid_unquoted_identifier("my$var"));
        assert!(is_valid_unquoted_identifier("Table"));

        assert!(!is_valid_unquoted_identifier(""));
        assert!(!is_valid_unquoted_identifier("1table"));
        assert!(!is_valid_unquoted_identifier("$var"));
        assert!(!is_valid_unquoted_identifier("my-table"));
        assert!(!is_valid_unquoted_identifier("my table"));
        assert!(!is_valid_unquoted_identifier("café"));
    }

    #[test]
    fn test_format_table_name() {
        assert_eq!(format_table_name(None, None, "users"), "users");
        assert_eq!(
            format_table_name(None, Some("public"), "users"),
            "public.users"
        );
        assert_eq!(
            format_table_name(Some("mydb"), Some("public"), "users"),
            "mydb.public.users"
        );
        // Test with names that need quoting
        assert_eq!(format_table_name(None, None, "my-table"), "\"my-table\"");
        assert_eq!(
            format_table_name(None, Some("my schema"), "users"),
            "\"my schema\".users"
        );
        assert_eq!(
            format_table_name(Some("MyDb"), Some("Sales"), "Order Items"),
            "\"MyDb\".\"Sales\".\"Order Items\""
        );
    }

    #[test]
    fn test_sql_identifier_in_format() {
        // Demonstrate zero-allocation composability
        let table = "users";
        let column = "Customer ID";
        let sql = format!(
            "SELECT {} FROM {}",
            SqlIdentifier(column),
            SqlIdentifier(table)
        );
        assert_eq!(sql, "SELECT \"Customer ID\" FROM users");
    }

    // Backward compat function tests
    #[test]
    fn test_escape_identifier() {
        assert_eq!(escape_identifier("table"), "table");
        assert_eq!(escape_identifier("Segment"), "\"Segment\"");
    }

    #[test]
    fn test_escape_literal() {
        assert_eq!(escape_literal("hello"), "'hello'");
        assert_eq!(escape_literal("it's"), "'it''s'");
    }
}