#![allow(clippy::doc_markdown)]
use super::Dialect;
use std::fmt::Write as _;
use std::iter::Peekable;
use std::str::Chars;
/// Result of rewriting named `:name` placeholders for a target SQL dialect.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TranslatedSql {
    /// The rewritten SQL text, with dialect-specific bind markers.
    pub sql: String,
    /// Placeholder names in the order their markers appear in `sql`. A name
    /// used more than once is listed once per occurrence.
    pub ordered_params: Vec<String>,
}
/// Rewrites named `:name` placeholders in `sql` into `dialect`'s bind syntax.
///
/// Colons inside `'…'` / `"…"` sections and inside `--` or `/* */` comments are
/// copied through untouched. `::` is treated as a cast operator, not a
/// placeholder. The returned `ordered_params` lists each placeholder's name in
/// the order the caller must bind values.
#[must_use]
pub fn translate_placeholders(sql: &str, dialect: Dialect) -> TranslatedSql {
    let mut scanner = SqlWalker::new(sql, dialect);
    scanner.run();
    scanner.into_translated()
}
/// Lexical state of the placeholder scanner.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum State {
    /// Plain SQL text, outside any literal, comment, placeholder or cast.
    Normal,
    /// Inside a quoted section; holds the quote character that closes it.
    StringLiteral(char),
    /// Inside a `--` comment, which runs up to and including the next newline.
    LineComment,
    /// Inside a `/* */` comment; holds the current nesting depth (at least 1).
    BlockComment(usize),
    /// Accumulating the name of a `:name` placeholder.
    Placeholder,
    /// Copying the type name that follows a `::` cast.
    CastTypeName,
}
/// Character-level scanner that rewrites `:name` placeholders in one SQL string.
struct SqlWalker<'src> {
    /// Target dialect; selects the marker syntax written for each placeholder.
    dialect: Dialect,
    /// Remaining input. Peekable so that two-character tokens (`--`, `/*`,
    /// `*/`, `::`, doubled quotes) can be recognised with one-char lookahead.
    chars: Peekable<Chars<'src>>,
    /// Current lexical state.
    state: State,
    /// Name of the placeholder currently being accumulated (without the `:`).
    pending_name: String,
    /// Translated SQL built so far.
    out: String,
    /// Placeholder names, one entry per emitted marker, in emission order.
    order: Vec<String>,
    /// Last Postgres positional index emitted; 0 before the first placeholder.
    pg_index: usize,
}
impl<'a> SqlWalker<'a> {
    /// Creates a walker over `sql` that emits placeholder markers for `dialect`.
    fn new(sql: &'a str, dialect: Dialect) -> Self {
        Self {
            chars: sql.chars().peekable(),
            state: State::Normal,
            // The output is the input with `:name` tokens swapped for short
            // markers, so the input length is a reasonable initial capacity.
            out: String::with_capacity(sql.len()),
            order: Vec::new(),
            pg_index: 0,
            dialect,
            pending_name: String::new(),
        }
    }

    /// Walks the entire input, dispatching each character on the current state.
    fn run(&mut self) {
        // `while let` rather than `for`: the handlers also peek at and consume
        // characters from `self.chars`, which a `for` loop borrowing the
        // iterator would not allow.
        while let Some(c) = self.chars.next() {
            match self.state {
                State::Normal => self.handle_normal(c),
                State::StringLiteral(q) => self.handle_string(c, q),
                State::LineComment => self.handle_line_comment(c),
                State::BlockComment(depth) => self.handle_block_comment(c, depth),
                State::Placeholder => self.handle_placeholder(c),
                State::CastTypeName => self.handle_cast_type_name(c),
            }
        }
        // A placeholder that runs to end of input never sees a terminating
        // character, so flush it here (e.g. `WHERE id = :id`).
        if self.state == State::Placeholder {
            self.emit_placeholder_from_pending();
        }
    }

    /// Handles a character of plain SQL.
    ///
    /// `'` and `"` open a quoted section, `--` opens a line comment, `/*` opens
    /// a block comment, and `:` is handed to `dispatch_colon`. Every other
    /// character is copied through unchanged.
    ///
    /// NOTE(review): only `'` and `"` open a quoted section; other quoting
    /// forms (backticks, Postgres dollar-quoting) are not recognised here.
    fn handle_normal(&mut self, c: char) {
        match c {
            '\'' | '"' => {
                self.out.push(c);
                self.state = State::StringLiteral(c);
            },
            '-' if self.chars.peek() == Some(&'-') => {
                self.out.push(c);
                self.out.push('-');
                self.chars.next();
                self.state = State::LineComment;
            },
            '/' if self.chars.peek() == Some(&'*') => {
                self.out.push(c);
                self.out.push('*');
                self.chars.next();
                self.state = State::BlockComment(1);
            },
            ':' => self.dispatch_colon(),
            _ => self.out.push(c),
        }
    }

    /// Decides what a just-consumed `:` introduces.
    ///
    /// - `::` is a cast: both colons are written, and the type name that
    ///   follows is copied by `handle_cast_type_name`.
    /// - `:` followed by an identifier-start character begins a placeholder.
    ///   The colon is not written; the name is collected into `pending_name`.
    /// - Anything else (`:=`, `:1`, a trailing `:`) is written as a literal colon.
    fn dispatch_colon(&mut self) {
        match self.chars.peek().copied() {
            Some(':') => {
                self.out.push(':');
                self.out.push(':');
                self.chars.next();
                self.state = State::CastTypeName;
            },
            Some(n) if is_ident_start(n) => {
                self.pending_name.clear();
                self.state = State::Placeholder;
            },
            _ => self.out.push(':'),
        }
    }

    /// Collects placeholder-name characters.
    ///
    /// The first non-identifier character ends the name and emits the marker.
    /// That character is then processed as plain SQL, so `:a::text` and
    /// `:a, :b` are handled correctly.
    fn handle_placeholder(&mut self, c: char) {
        if is_ident_continue(c) {
            self.pending_name.push(c);
        } else {
            self.emit_placeholder_from_pending();
            self.handle_normal(c);
        }
    }

    /// Copies the type name that follows a `::` cast.
    ///
    /// Identifier characters are copied verbatim. The first non-identifier
    /// character ends the type name and is re-dispatched through
    /// `handle_normal`, as in `handle_placeholder`. Without this, a chained
    /// cast (`'1'::int::text`), a comment opener right after the type
    /// (`x::int--…`, `x::int/*…`) or a quoted type name (`x::"char"`) would be
    /// copied blindly and the rest of the statement mis-lexed.
    fn handle_cast_type_name(&mut self, c: char) {
        if is_ident_continue(c) {
            self.out.push(c);
        } else {
            self.state = State::Normal;
            self.handle_normal(c);
        }
    }

    /// Writes the dialect-specific marker for the collected placeholder name,
    /// records the name in bind order, and returns to `State::Normal`.
    ///
    /// - Postgres: `$1`, `$2`, … — a fresh index per occurrence, even for a
    ///   name that was already used.
    /// - MySQL / Athena: `?`.
    /// - SQLite: the original `:name` is written back unchanged.
    fn emit_placeholder_from_pending(&mut self) {
        match self.dialect {
            Dialect::Postgres => {
                self.pg_index += 1;
                // Writing into a `String` cannot fail, so the `fmt::Result`
                // is deliberately discarded.
                let _ = write!(self.out, "${}", self.pg_index);
            },
            Dialect::MySql | Dialect::Athena => self.out.push('?'),
            Dialect::Sqlite => {
                let _ = write!(self.out, ":{}", self.pending_name);
            },
        }
        self.order.push(std::mem::take(&mut self.pending_name));
        self.state = State::Normal;
    }

    /// Copies a character inside a quoted section opened by `q`.
    ///
    /// A doubled quote (`''` or `""`) is an escaped quote and keeps the section
    /// open. A single closing quote returns to `State::Normal`. Backslash
    /// escapes are not interpreted.
    fn handle_string(&mut self, c: char, q: char) {
        self.out.push(c);
        if c == q {
            if self.chars.peek() == Some(&q) {
                self.out.push(q);
                self.chars.next();
            } else {
                self.state = State::Normal;
            }
        }
    }

    /// Copies a character of a `--` comment; the comment ends after `\n`.
    fn handle_line_comment(&mut self, c: char) {
        self.out.push(c);
        if c == '\n' {
            self.state = State::Normal;
        }
    }

    /// Copies a character of a `/* */` comment, tracking nesting depth.
    ///
    /// `/*` increments the depth and `*/` decrements it. Plain SQL resumes only
    /// when the outermost comment closes.
    ///
    /// NOTE(review): nesting is applied for every dialect. Verify this matches
    /// MySQL/SQLite, whose block comments may not nest.
    fn handle_block_comment(&mut self, c: char, depth: usize) {
        self.out.push(c);
        if c == '*' && self.chars.peek() == Some(&'/') {
            self.out.push('/');
            self.chars.next();
            self.state = if depth <= 1 {
                State::Normal
            } else {
                State::BlockComment(depth - 1)
            };
        } else if c == '/' && self.chars.peek() == Some(&'*') {
            self.out.push('*');
            self.chars.next();
            self.state = State::BlockComment(depth + 1);
        }
    }

    /// Consumes the walker, returning the rewritten SQL and the bind order.
    fn into_translated(self) -> TranslatedSql {
        TranslatedSql {
            sql: self.out,
            ordered_params: self.order,
        }
    }
}
/// Returns `true` if `c` may begin a placeholder name: an ASCII letter or `_`.
fn is_ident_start(c: char) -> bool {
    matches!(c, 'A'..='Z' | 'a'..='z' | '_')
}
/// Returns `true` if `c` may continue a placeholder or cast type name: an
/// ASCII letter, an ASCII digit, or `_`.
fn is_ident_continue(c: char) -> bool {
    matches!(c, 'A'..='Z' | 'a'..='z' | '0'..='9' | '_')
}
// Property-based tests for `translate_placeholders`, driven by `proptest`.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
// The alphabet excludes `:`, quotes, `-`, `/` and `*`, so no placeholder,
// literal, comment or cast can form; every dialect must echo the input.
#[test]
fn idempotence_no_placeholders(s in "[A-Za-z0-9 _\\.,;\\(\\)=]*") {
for d in [Dialect::Postgres, Dialect::MySql, Dialect::Athena, Dialect::Sqlite] {
let t = translate_placeholders(&s, d);
prop_assert_eq!(&t.sql, &s);
prop_assert!(t.ordered_params.is_empty());
}
}
// `:n1, :n2, ...` must report the names exactly in the order written,
// including any repeated names the generator produces.
#[test]
fn bind_order_preserved(names in proptest::collection::vec("[a-z]{1,5}", 1..=5)) {
let sql = names.iter().map(|n| format!(":{n}")).collect::<Vec<_>>().join(", ");
let t = translate_placeholders(&sql, Dialect::Postgres);
prop_assert_eq!(t.ordered_params, names);
}
// Postgres markers must be exactly `$1..=$n`. With at most 5 names, `$1`
// cannot be a prefix of a two-digit marker such as `$10`, so `contains` is
// an unambiguous check here.
#[test]
fn postgres_positional_indexing(names in proptest::collection::vec("[a-z]{1,5}", 1..=5)) {
let sql = names.iter().map(|n| format!(":{n}")).collect::<Vec<_>>().join(", ");
let t = translate_placeholders(&sql, Dialect::Postgres);
prop_assert_eq!(t.ordered_params.len(), names.len());
for i in 1..=names.len() {
let token = format!("${i}");
prop_assert!(t.sql.contains(&token));
}
let above = format!("${}", names.len() + 1);
prop_assert!(!t.sql.contains(&above));
}
// SQLite writes `:name` back verbatim, so the output must equal any input.
#[test]
fn sqlite_identity(s in any::<String>()) {
let t = translate_placeholders(&s, Dialect::Sqlite);
prop_assert_eq!(t.sql, s);
}
// Smoke test: arbitrary strings must not panic in any dialect.
#[test]
fn no_panic_on_arbitrary_input(s in any::<String>()) {
for d in [Dialect::Postgres, Dialect::MySql, Dialect::Athena, Dialect::Sqlite] {
let _ = translate_placeholders(&s, d);
}
}
}
}
// Example-based tests for `translate_placeholders`: one case per lexical rule
// (literals, comments, casts, colon edge cases, end of input).
#[cfg(test)]
mod unit_tests {
use super::*;
// Shorthand for `translate_placeholders` so each assertion stays on one line.
fn t(sql: &str, d: Dialect) -> TranslatedSql {
translate_placeholders(sql, d)
}
#[test]
fn empty_input_is_identity() {
let r = t("", Dialect::Postgres);
assert_eq!(r.sql, "");
assert!(r.ordered_params.is_empty());
}
#[test]
fn no_placeholder_is_identity_mysql() {
let r = t("SELECT 1", Dialect::MySql);
assert_eq!(r.sql, "SELECT 1");
assert!(r.ordered_params.is_empty());
}
#[test]
fn single_placeholder_postgres() {
let r = t("SELECT :id FROM t", Dialect::Postgres);
assert_eq!(r.sql, "SELECT $1 FROM t");
assert_eq!(r.ordered_params, vec!["id"]);
}
#[test]
fn single_placeholder_mysql() {
let r = t("SELECT :id FROM t", Dialect::MySql);
assert_eq!(r.sql, "SELECT ? FROM t");
assert_eq!(r.ordered_params, vec!["id"]);
}
#[test]
fn single_placeholder_athena() {
let r = t("SELECT :id FROM t", Dialect::Athena);
assert_eq!(r.sql, "SELECT ? FROM t");
assert_eq!(r.ordered_params, vec!["id"]);
}
// SQLite keeps `:name` in the SQL but still reports bind order.
#[test]
fn single_placeholder_sqlite_is_identity_with_bind_order() {
let r = t("SELECT :id FROM t", Dialect::Sqlite);
assert_eq!(r.sql, "SELECT :id FROM t");
assert_eq!(r.ordered_params, vec!["id"]);
}
// A repeated name gets a fresh `$n` per occurrence and is listed again in
// `ordered_params`; earlier indices are not reused.
#[test]
fn repeated_name_gets_fresh_index_postgres() {
let r = t("WHERE a = :a AND b = :b AND c = :a", Dialect::Postgres);
assert_eq!(r.sql, "WHERE a = $1 AND b = $2 AND c = $3");
assert_eq!(r.ordered_params, vec!["a", "b", "a"]);
}
// One query through every dialect: three placeholders, two distinct names.
#[test]
fn three_distinct_names_all_dialects_match_must_haves() {
let sql = "SELECT :id FROM t WHERE x = :x AND y = :id";
let pg = t(sql, Dialect::Postgres);
assert_eq!(pg.sql, "SELECT $1 FROM t WHERE x = $2 AND y = $3");
assert_eq!(pg.ordered_params, vec!["id", "x", "id"]);
let my = t(sql, Dialect::MySql);
assert_eq!(my.sql, "SELECT ? FROM t WHERE x = ? AND y = ?");
assert_eq!(my.ordered_params, vec!["id", "x", "id"]);
let at = t(sql, Dialect::Athena);
assert_eq!(at.sql, "SELECT ? FROM t WHERE x = ? AND y = ?");
assert_eq!(at.ordered_params, vec!["id", "x", "id"]);
let lite = t(sql, Dialect::Sqlite);
assert_eq!(lite.sql, sql);
assert_eq!(lite.ordered_params, vec!["id", "x", "id"]);
}
#[test]
fn placeholder_inside_string_literal_not_translated() {
let r = t("SELECT 'WHERE name = :foo' AS x", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 'WHERE name = :foo' AS x");
assert!(r.ordered_params.is_empty());
}
// `''` is an escaped quote, so the literal stays open across it.
#[test]
fn doubled_single_quote_escape_stays_in_literal() {
let r = t("SELECT 'it''s :foo' AS x", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 'it''s :foo' AS x");
assert!(r.ordered_params.is_empty());
}
#[test]
fn double_quoted_identifier_skips_placeholder() {
let r = t("SELECT \"col:name\" FROM t", Dialect::Postgres);
assert_eq!(r.sql, "SELECT \"col:name\" FROM t");
assert!(r.ordered_params.is_empty());
}
#[test]
fn placeholder_in_line_comment_not_translated() {
let r = t("SELECT 1 -- bind :id here", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 1 -- bind :id here");
assert!(r.ordered_params.is_empty());
}
#[test]
fn line_comment_ends_at_newline() {
let r = t("SELECT 1 -- :a\nWHERE x = :b", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 1 -- :a\nWHERE x = $1");
assert_eq!(r.ordered_params, vec!["b"]);
}
#[test]
fn placeholder_in_block_comment_not_translated() {
let r = t("SELECT /* :foo */ 1", Dialect::Postgres);
assert_eq!(r.sql, "SELECT /* :foo */ 1");
assert!(r.ordered_params.is_empty());
}
// The inner `*/` only lowers the depth; `:bar` is still inside the outer comment.
#[test]
fn nested_block_comment_tracked_via_depth() {
let r = t("SELECT /* /* :foo */ :bar */ :baz", Dialect::Postgres);
assert_eq!(r.sql, "SELECT /* /* :foo */ :bar */ $1");
assert_eq!(r.ordered_params, vec!["baz"]);
}
// The placeholder ends at the first `:`; the following `::text` is a cast.
#[test]
fn postgres_double_colon_cast_preserves_text_identifier() {
let r = t("SELECT :id::text FROM t", Dialect::Postgres);
assert_eq!(
r,
TranslatedSql {
sql: "SELECT $1::text FROM t".into(),
ordered_params: vec!["id".into()],
}
);
}
#[test]
fn postgres_double_colon_int_cast_no_placeholder() {
let r = t("SELECT 1::int", Dialect::Postgres);
assert_eq!(
r,
TranslatedSql {
sql: "SELECT 1::int".into(),
ordered_params: vec![],
}
);
}
// `:=` is not followed by an identifier-start character, so the colon is literal.
#[test]
fn mysql_session_variable_assignment_not_a_placeholder() {
let r = t("SET @x := 5", Dialect::MySql);
assert_eq!(
r,
TranslatedSql {
sql: "SET @x := 5".into(),
ordered_params: vec![],
}
);
}
// Placeholder names must start with an ASCII letter or `_`, not a digit.
#[test]
fn colon_followed_by_digit_emits_verbatim() {
let r = t("SELECT :1bad FROM t", Dialect::Postgres);
assert_eq!(
r,
TranslatedSql {
sql: "SELECT :1bad FROM t".into(),
ordered_params: vec![],
}
);
}
#[test]
fn string_literal_cast_both_colons_verbatim() {
let r = t("SELECT 'foo'::text", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 'foo'::text");
assert!(r.ordered_params.is_empty());
}
#[test]
fn placeholder_then_cast_then_placeholder() {
let r = t("SELECT :a::text, :b FROM t", Dialect::Postgres);
assert_eq!(r.sql, "SELECT $1::text, $2 FROM t");
assert_eq!(r.ordered_params, vec!["a", "b"]);
}
#[test]
fn lone_colon_at_eof_emits_verbatim() {
let r = t("SELECT 1:", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 1:");
assert!(r.ordered_params.is_empty());
}
#[test]
fn underscore_leading_placeholder_name() {
let r = t("WHERE x = :_id", Dialect::Postgres);
assert_eq!(r.sql, "WHERE x = $1");
assert_eq!(r.ordered_params, vec!["_id"]);
}
// Unterminated constructs are copied through verbatim rather than panicking.
#[test]
fn unterminated_literal_does_not_panic() {
let r = t("SELECT 'unterminated :foo", Dialect::Postgres);
assert_eq!(r.sql, "SELECT 'unterminated :foo");
assert!(r.ordered_params.is_empty());
}
#[test]
fn unterminated_block_comment_does_not_panic() {
let r = t("SELECT /* :foo", Dialect::Postgres);
assert_eq!(r.sql, "SELECT /* :foo");
assert!(r.ordered_params.is_empty());
}
// A placeholder that runs to end of input is flushed after the main loop.
#[test]
fn placeholder_at_eof_is_emitted() {
let r = t("WHERE id = :id", Dialect::Postgres);
assert_eq!(r.sql, "WHERE id = $1");
assert_eq!(r.ordered_params, vec!["id"]);
}
}