sql_middleware/
translation.rs

1use std::borrow::Cow;
2
/// Placeholder syntax that placeholder translation should emit.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PlaceholderStyle {
    /// Numbered `$N` parameters as used by PostgreSQL (e.g. `$1`).
    Postgres,
    /// Numbered `?N` parameters as used by SQLite and LibSQL/Turso (e.g. `?1`).
    Sqlite,
}
11
/// Per-call override of the pool's default placeholder-translation setting.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TranslationMode {
    /// Defer to whatever the pool is configured to do.
    PoolDefault,
    /// Always translate, even when the pool default is off.
    ForceOn,
    /// Never translate, even when the pool default is on.
    ForceOff,
}

impl TranslationMode {
    /// Decide whether translation is enabled for a call.
    ///
    /// `pool_default` is the pool-level setting. It is returned as-is for
    /// [`TranslationMode::PoolDefault`] and ignored by the two forcing variants.
    #[must_use]
    pub fn resolve(self, pool_default: bool) -> bool {
        // `None` means "no override": fall back to the pool's own setting.
        let forced = match self {
            Self::PoolDefault => None,
            Self::ForceOn => Some(true),
            Self::ForceOff => Some(false),
        };
        forced.unwrap_or(pool_default)
    }
}
33
34/// Per-call options for query/execute paths.
35#[derive(Debug, Clone, Copy, PartialEq, Eq)]
36pub struct QueryOptions {
37    pub translation: TranslationMode,
38}
39
40impl Default for QueryOptions {
41    fn default() -> Self {
42        Self {
43            translation: TranslationMode::PoolDefault,
44        }
45    }
46}
47
48impl QueryOptions {
49    #[must_use]
50    pub fn with_translation(mut self, translation: TranslationMode) -> Self {
51        self.translation = translation;
52        self
53    }
54}
55
56/// Translate placeholders between Postgres-style `$N` and SQLite-style `?N`.
57///
58/// Warning: translation skips quoted strings, comments, and dollar-quoted blocks via a lightweight
59/// state machine; it may still miss edge cases in complex SQL. For dialect-specific SQL (e.g.,
60/// PL/pgSQL bodies), prefer backend-specific SQL instead of relying on translation:
61/// ```rust
62/// # use sql_middleware::prelude::*;
63/// # async fn demo(conn: &mut MiddlewarePoolConnection) -> Result<(), SqlMiddlewareDbError> {
64/// let query = match conn {
65///     MiddlewarePoolConnection::Postgres { .. } => r#"$function$
66/// BEGIN
67///     RETURN ($1 ~ $q$[\t\r\n\v\\]$q$);
68/// END;
69/// $function$"#,
70///     MiddlewarePoolConnection::Sqlite { .. } | MiddlewarePoolConnection::Turso { .. } => {
71///         include_str!("../sql/functions/sqlite/03_sp_get_scores.sql")
72///     }
73/// };
74/// # let _ = query;
75/// # Ok(())
76/// # }
77/// ```
78/// Returns a borrowed `Cow` when no changes are needed.
79#[must_use]
80pub fn translate_placeholders(
81    sql: &str,
82    target: PlaceholderStyle,
83    enabled: bool,
84) -> Cow<'_, str> {
85    if !enabled {
86        return Cow::Borrowed(sql);
87    }
88
89    let mut out: Option<String> = None;
90    let mut state = State::Normal;
91    let mut idx = 0;
92    let bytes = sql.as_bytes();
93
94    while idx < bytes.len() {
95        let b = bytes[idx];
96        let mut replaced = false;
97        match state {
98            State::Normal => match b {
99                b'\'' => state = State::SingleQuoted,
100                b'"' => state = State::DoubleQuoted,
101                b'-' if bytes.get(idx + 1) == Some(&b'-') => {
102                    state = State::LineComment;
103                }
104                b'/' if bytes.get(idx + 1) == Some(&b'*') => {
105                    state = State::BlockComment(1);
106                }
107                b'$' => {
108                    if let Some((tag, advance)) = try_start_dollar_quote(bytes, idx) {
109                        state = State::DollarQuoted(tag);
110                        idx = advance;
111                    } else if matches!(target, PlaceholderStyle::Sqlite)
112                        && let Some((digits_end, digits)) = scan_digits(bytes, idx + 1)
113                    {
114                        let buf = out.get_or_insert_with(|| sql[..idx].to_string());
115                        buf.push('?');
116                        buf.push_str(digits);
117                        idx = digits_end - 1;
118                        replaced = true;
119                    }
120                }
121                b'?' if matches!(target, PlaceholderStyle::Postgres) => {
122                    if let Some((digits_end, digits)) = scan_digits(bytes, idx + 1) {
123                        let buf = out.get_or_insert_with(|| sql[..idx].to_string());
124                        buf.push('$');
125                        buf.push_str(digits);
126                        idx = digits_end - 1;
127                        replaced = true;
128                    }
129                }
130                _ => {}
131            },
132            State::SingleQuoted => {
133                if b == b'\'' {
134                    if bytes.get(idx + 1) == Some(&b'\'') {
135                        idx += 1; // skip escaped quote
136                    } else {
137                        state = State::Normal;
138                    }
139                }
140            }
141            State::DoubleQuoted => {
142                if b == b'"' {
143                    if bytes.get(idx + 1) == Some(&b'"') {
144                        idx += 1; // skip escaped quote
145                    } else {
146                        state = State::Normal;
147                    }
148                }
149            }
150            State::LineComment => {
151                if b == b'\n' {
152                    state = State::Normal;
153                }
154            }
155            State::BlockComment(depth) => {
156                if b == b'/' && bytes.get(idx + 1) == Some(&b'*') {
157                    state = State::BlockComment(depth + 1);
158                } else if b == b'*' && bytes.get(idx + 1) == Some(&b'/') {
159                    if depth == 1 {
160                        state = State::Normal;
161                    } else {
162                        state = State::BlockComment(depth - 1);
163                    }
164                }
165            }
166            State::DollarQuoted(ref tag) => {
167                if b == b'$' && matches_tag(bytes, idx, tag) {
168                    let tag_len = tag.len();
169                    state = State::Normal;
170                    idx += tag_len;
171                }
172            }
173        }
174
175        if let Some(ref mut buf) = out && !replaced {
176            buf.push(b as char);
177        }
178
179        idx += 1;
180    }
181
182    match out {
183        Some(buf) => Cow::Owned(buf),
184        None => Cow::Borrowed(sql),
185    }
186}
187
188#[derive(Clone)]
189enum State {
190    Normal,
191    SingleQuoted,
192    DoubleQuoted,
193    LineComment,
194    BlockComment(u32),
195    DollarQuoted(String),
196}
197
198fn scan_digits(bytes: &[u8], start: usize) -> Option<(usize, &str)> {
199    let mut idx = start;
200    while idx < bytes.len() && bytes[idx].is_ascii_digit() {
201        idx += 1;
202    }
203    if idx == start {
204        None
205    } else {
206        std::str::from_utf8(&bytes[start..idx])
207            .ok()
208            .map(|digits| (idx, digits))
209    }
210}
211
212fn try_start_dollar_quote(bytes: &[u8], start: usize) -> Option<(String, usize)> {
213    let mut idx = start + 1;
214    while idx < bytes.len() && bytes[idx] != b'$' {
215        let b = bytes[idx];
216        if !(b.is_ascii_alphanumeric() || b == b'_') {
217            return None;
218        }
219        idx += 1;
220    }
221
222    if idx < bytes.len() && bytes[idx] == b'$' {
223        let tag = String::from_utf8(bytes[start + 1..idx].to_vec()).ok()?;
224        Some((tag, idx))
225    } else {
226        None
227    }
228}
229
230fn matches_tag(bytes: &[u8], idx: usize, tag: &str) -> bool {
231    let end = idx + 1 + tag.len();
232    end < bytes.len()
233        && bytes[idx + 1..=end].starts_with(tag.as_bytes())
234        && bytes.get(end) == Some(&b'$')
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Run an enabled translation toward `target`.
    fn translate(sql: &str, target: PlaceholderStyle) -> Cow<'_, str> {
        translate_placeholders(sql, target, true)
    }

    #[test]
    fn translates_sqlite_to_postgres() {
        assert_eq!(
            translate("select * from t where a = ?1 and b = ?2", PlaceholderStyle::Postgres),
            "select * from t where a = $1 and b = $2"
        );
    }

    #[test]
    fn translates_postgres_to_sqlite() {
        assert_eq!(
            translate("insert into t values($1, $2)", PlaceholderStyle::Sqlite),
            "insert into t values(?1, ?2)"
        );
    }

    #[test]
    fn skips_inside_literals_and_comments() {
        let input = "select '?1', $1 -- $2\n/* ?3 */ from t where a = $1";
        let expected = "select '?1', ?1 -- $2\n/* ?3 */ from t where a = ?1";
        assert_eq!(translate(input, PlaceholderStyle::Sqlite), expected);
    }

    #[test]
    fn skips_dollar_quoted_blocks() {
        let input = "$foo$ select $1 from t $foo$ where a = $1";
        let expected = "$foo$ select $1 from t $foo$ where a = ?1";
        assert_eq!(translate(input, PlaceholderStyle::Sqlite), expected);
    }

    #[test]
    fn respects_disabled_flag() {
        let input = "select * from t where a = ?1";
        let translated = translate_placeholders(input, PlaceholderStyle::Postgres, false);
        assert!(matches!(translated, Cow::Borrowed(_)));
        assert_eq!(translated, input);
    }

    #[test]
    fn translation_mode_resolution() {
        // (mode, pool default, expected resolution)
        let cases = [
            (TranslationMode::ForceOn, false, true),
            (TranslationMode::ForceOff, true, false),
            (TranslationMode::PoolDefault, true, true),
            (TranslationMode::PoolDefault, false, false),
        ];
        for (mode, pool_default, expected) in cases {
            assert_eq!(
                mode.resolve(pool_default),
                expected,
                "{mode:?} with pool default {pool_default}"
            );
        }
    }
}