use crate::storage::find_args::Paged;
/// Database backend targeted by the SQL-fragment helpers in this module.
///
/// The variant selects identifier quoting, placeholder syntax, the current-time
/// expression and the upsert form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Dialect {
    /// SQLite: backtick-quoted identifiers, `?` placeholders.
    Sqlite,
    /// MySQL: backtick-quoted identifiers, `?` placeholders.
    Mysql,
    /// PostgreSQL: double-quoted identifiers, numbered `$N` placeholders.
    Postgres,
}

impl Dialect {
    /// SQL expression for the current date/time on this backend.
    ///
    /// SQLite uses `datetime('now')`; MySQL and PostgreSQL use `NOW()`.
    pub fn now_expr(self) -> &'static str {
        match self {
            Dialect::Sqlite => "datetime('now')",
            Dialect::Mysql | Dialect::Postgres => "NOW()",
        }
    }

    /// Quotes an identifier (a column or table name) for this dialect.
    ///
    /// MySQL and SQLite wrap the name in backticks, PostgreSQL in double quotes.
    /// A quote character occurring inside `col` is doubled — the escape all three
    /// backends recognise inside a quoted identifier — so an unusual or hostile name
    /// cannot close the quoting early and splice extra SQL into the statement.
    pub fn quote_column(self, col: &str) -> String {
        let (quote, escaped) = match self {
            Dialect::Mysql | Dialect::Sqlite => ('`', "``"),
            Dialect::Postgres => ('"', "\"\""),
        };
        format!("{quote}{}{quote}", col.replace(quote, escaped))
    }
}
/// Returns the bind-parameter placeholder text for parameter number `index`.
///
/// PostgreSQL embeds the number (`$1`, `$2`, ...); SQLite and MySQL use a bare
/// positional `?` and ignore `index`. `WhereBuilder` numbers its parameters from 1.
pub fn placeholder(dialect: Dialect, index: usize) -> String {
    match dialect {
        Dialect::Postgres => format!("${}", index),
        Dialect::Sqlite | Dialect::Mysql => String::from("?"),
    }
}
/// Builds the dialect-specific upsert tail to append to an `INSERT ... VALUES (...)`.
///
/// * SQLite / PostgreSQL: ` ON CONFLICT (<conflict_cols>) DO UPDATE SET c = excluded.c, ...`
/// * MySQL: ` ON DUPLICATE KEY UPDATE c = VALUES(c), ...` — `conflict_cols` is ignored
///   here because MySQL's `ON DUPLICATE KEY` form has no conflict-target list.
///
/// Column names are emitted exactly as given (unquoted).
pub fn upsert_suffix(dialect: Dialect, conflict_cols: &[&str], update_cols: &[&str]) -> String {
    // NOTE(review): PostgreSQL case-folds unquoted identifiers, so a mixed-case name
    // such as `userId` resolves to `userid` here even though `WhereBuilder` quotes it.
    // Confirm callers pass lowercase or pre-quoted names. An empty `update_cols` also
    // produces invalid SQL (nothing after `SET` / `UPDATE`).
    match dialect {
        Dialect::Mysql => {
            let assignments = update_cols
                .iter()
                .map(|c| format!("{c} = VALUES({c})"))
                .collect::<Vec<_>>()
                .join(", ");
            format!(" ON DUPLICATE KEY UPDATE {}", assignments)
        }
        Dialect::Sqlite | Dialect::Postgres => {
            let target = conflict_cols.join(", ");
            let assignments = update_cols
                .iter()
                .map(|c| format!("{c} = excluded.{c}"))
                .collect::<Vec<_>>()
                .join(", ");
            format!(" ON CONFLICT ({}) DO UPDATE SET {}", target, assignments)
        }
    }
}
/// Describes how to obtain the id generated by the most recent `INSERT`.
///
/// Only the MySQL result is an executable query. For SQLite and PostgreSQL the
/// returned text is an SQL comment naming the mechanism to use instead; it is a
/// hint for the caller, not a statement to run.
pub fn last_insert_id_query(dialect: Dialect) -> &'static str {
    match dialect {
        Dialect::Mysql => "SELECT LAST_INSERT_ID()",
        Dialect::Postgres => "-- use RETURNING clause",
        Dialect::Sqlite => "-- use result.last_insert_rowid()",
    }
}
/// Incrementally builds a ` WHERE ...` clause whose conditions are joined by `AND`.
///
/// The builder only takes column names; values are never written into the SQL.
/// Each `add_*` call reserves one or more bind parameters, numbered consecutively,
/// so callers must bind values in the same order the conditions were added.
#[derive(Debug, Clone)]
pub struct WhereBuilder {
    /// Rendered conditions, in the order they were added.
    clauses: Vec<String>,
    /// Backend that determines identifier quoting and placeholder syntax.
    dialect: Dialect,
    /// Number of placeholders emitted so far; the next one gets `param_index + 1`.
    param_index: usize,
}
impl WhereBuilder {
    /// Creates an empty builder for `dialect`: no conditions, no parameters.
    pub fn new(dialect: Dialect) -> Self {
        Self {
            clauses: Vec::new(),
            dialect,
            param_index: 0,
        }
    }

    /// Shorthand for `WhereBuilder::new(Dialect::Sqlite)`.
    pub fn new_sqlite() -> Self {
        Self::new(Dialect::Sqlite)
    }

    /// Reserves the next parameter slot and returns its placeholder text.
    fn next_placeholder(&mut self) -> String {
        self.param_index += 1;
        placeholder(self.dialect, self.param_index)
    }

    /// Reserves `count` consecutive parameter slots and returns their placeholders
    /// joined with `", "`.
    fn placeholder_list(&mut self, count: usize) -> String {
        (0..count)
            .map(|_| self.next_placeholder())
            .collect::<Vec<_>>()
            .join(", ")
    }

    /// Appends `<quoted column> <op> <placeholder>`, binding one parameter.
    fn push_comparison(&mut self, column: &str, op: &str) {
        let ph = self.next_placeholder();
        let qc = self.dialect.quote_column(column);
        self.clauses.push(format!("{qc} {op} {ph}"));
    }

    /// Adds `column = ?`, binding one parameter.
    pub fn add_eq(&mut self, column: &str) {
        self.push_comparison(column, "=");
    }

    /// Adds `column >= ?`, binding one parameter.
    pub fn add_gte(&mut self, column: &str) {
        self.push_comparison(column, ">=");
    }

    /// Adds `column <= ?`, binding one parameter.
    // The lint would otherwise fire because nothing in the crate calls this yet.
    #[allow(dead_code)]
    pub fn add_lte(&mut self, column: &str) {
        self.push_comparison(column, "<=");
    }

    /// Adds `column LIKE ?`, binding one parameter.
    ///
    /// The bound value is the whole pattern (including any `%`/`_` wildcards);
    /// no `ESCAPE` clause is emitted.
    // The lint would otherwise fire because nothing in the crate calls this yet.
    #[allow(dead_code)]
    pub fn add_like(&mut self, column: &str) {
        self.push_comparison(column, "LIKE");
    }

    /// Adds `column IN (?, ?, ...)` with `count` placeholders.
    ///
    /// When `count` is 0, no condition is added and no parameter is reserved.
    // NOTE(review): an empty IN list would logically match no rows, but skipping the
    // clause removes the filter entirely. Confirm callers treat "empty" as "no filter".
    // The lint would otherwise fire because nothing in the crate calls this yet.
    #[allow(dead_code)]
    pub fn add_in(&mut self, column: &str, count: usize) {
        if count == 0 {
            return;
        }
        let qc = self.dialect.quote_column(column);
        let list = self.placeholder_list(count);
        self.clauses.push(format!("{qc} IN ({list})"));
    }

    /// Adds a correlated-subquery condition:
    /// `(SELECT sub_col FROM table WHERE table.join_col = outer_table.outer_col) IN (?, ...)`.
    ///
    /// Every identifier is quoted for the dialect. Reserves `count` parameters; when
    /// `count` is 0, nothing is added.
    // NOTE(review): as a scalar subquery it should yield at most one row per outer row;
    // presumably `join_col` is unique in `table` — verify.
    // The lint would otherwise fire because nothing in the crate calls this yet.
    #[allow(dead_code)]
    pub fn add_subquery_in(
        &mut self,
        table: &str,
        sub_col: &str,
        join_col: &str,
        outer_table: &str,
        outer_col: &str,
        count: usize,
    ) {
        if count == 0 {
            return;
        }
        let qt = self.dialect.quote_column(table);
        let qsc = self.dialect.quote_column(sub_col);
        let qjc = self.dialect.quote_column(join_col);
        let qot = self.dialect.quote_column(outer_table);
        let qoc = self.dialect.quote_column(outer_col);
        let list = self.placeholder_list(count);
        self.clauses.push(format!(
            "(SELECT {qsc} FROM {qt} WHERE {qt}.{qjc} = {qot}.{qoc}) IN ({list})"
        ));
    }

    /// Number of parameters the conditions added so far expect to be bound.
    pub fn param_count(&self) -> usize {
        self.param_index
    }

    /// Renders ` WHERE a AND b ...` (with a leading space), or an empty string when
    /// no conditions were added.
    pub fn build_where(&self) -> String {
        if self.clauses.is_empty() {
            return String::new();
        }
        format!(" WHERE {}", self.clauses.join(" AND "))
    }

    /// Renders ` ORDER BY <column> DESC` or ` ORDER BY <column> ASC`.
    ///
    /// `column` is inserted verbatim, without quoting (this associated function has
    /// no dialect to quote with), so it must be a trusted identifier or expression.
    pub fn build_order_by(column: &str, desc: bool) -> String {
        let direction = if desc { "DESC" } else { "ASC" };
        format!(" ORDER BY {column} {direction}")
    }

    /// Renders ` LIMIT <limit> OFFSET <offset>` from `paged`.
    ///
    /// The values are formatted into the SQL text rather than bound, so they do not
    /// count toward [`WhereBuilder::param_count`].
    pub fn build_limit_offset(paged: &Paged) -> String {
        format!(" LIMIT {} OFFSET {}", paged.limit, paged.offset)
    }
}
// Each test pins the exact SQL text produced, so any change in quoting, spacing or
// placeholder numbering is caught.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn sqlite_placeholders() {
        let mut wb = WhereBuilder::new(Dialect::Sqlite);
        wb.add_eq("userId");
        wb.add_gte("created_at");
        assert_eq!(
            wb.build_where(),
            " WHERE `userId` = ? AND `created_at` >= ?"
        );
        assert_eq!(wb.param_count(), 2);
    }
    #[test]
    fn mysql_placeholders() {
        let mut wb = WhereBuilder::new(Dialect::Mysql);
        wb.add_eq("userId");
        wb.add_eq("status");
        assert_eq!(wb.build_where(), " WHERE `userId` = ? AND `status` = ?");
    }
    #[test]
    fn postgres_placeholders() {
        let mut wb = WhereBuilder::new(Dialect::Postgres);
        wb.add_eq("userId");
        wb.add_gte("created_at");
        wb.add_eq("status");
        assert_eq!(
            wb.build_where(),
            " WHERE \"userId\" = $1 AND \"created_at\" >= $2 AND \"status\" = $3"
        );
        assert_eq!(wb.param_count(), 3);
    }
    #[test]
    fn postgres_in_clause() {
        let mut wb = WhereBuilder::new(Dialect::Postgres);
        wb.add_in("id", 3);
        assert_eq!(wb.build_where(), " WHERE \"id\" IN ($1, $2, $3)");
    }
    #[test]
    fn upsert_sqlite() {
        let result = upsert_suffix(Dialect::Sqlite, &["userId"], &["name", "updated_at"]);
        assert_eq!(
            result,
            " ON CONFLICT (userId) DO UPDATE SET name = excluded.name, updated_at = excluded.updated_at"
        );
    }
    #[test]
    fn upsert_mysql() {
        let result = upsert_suffix(Dialect::Mysql, &["userId"], &["name", "updated_at"]);
        assert_eq!(
            result,
            " ON DUPLICATE KEY UPDATE name = VALUES(name), updated_at = VALUES(updated_at)"
        );
    }
    #[test]
    fn sqlite_subquery_in() {
        let mut wb = WhereBuilder::new(Dialect::Sqlite);
        wb.add_eq("userId");
        wb.add_subquery_in(
            "transactions",
            "status",
            "transactionId",
            "outputs",
            "transactionId",
            2,
        );
        assert_eq!(
            wb.build_where(),
            " WHERE `userId` = ? AND (SELECT `status` FROM `transactions` WHERE `transactions`.`transactionId` = `outputs`.`transactionId`) IN (?, ?)"
        );
        assert_eq!(wb.param_count(), 3);
    }
    #[test]
    fn postgres_subquery_in() {
        let mut wb = WhereBuilder::new(Dialect::Postgres);
        wb.add_subquery_in(
            "transactions",
            "status",
            "transactionId",
            "outputs",
            "transactionId",
            2,
        );
        assert_eq!(
            wb.build_where(),
            " WHERE (SELECT \"status\" FROM \"transactions\" WHERE \"transactions\".\"transactionId\" = \"outputs\".\"transactionId\") IN ($1, $2)"
        );
        assert_eq!(wb.param_count(), 2);
    }
    #[test]
    fn empty_builder_renders_nothing() {
        for dialect in [Dialect::Sqlite, Dialect::Mysql, Dialect::Postgres] {
            let wb = WhereBuilder::new(dialect);
            assert_eq!(wb.build_where(), "");
            assert_eq!(wb.param_count(), 0);
        }
    }
    #[test]
    fn new_sqlite_matches_explicit_dialect() {
        let mut shorthand = WhereBuilder::new_sqlite();
        let mut explicit = WhereBuilder::new(Dialect::Sqlite);
        shorthand.add_eq("userId");
        explicit.add_eq("userId");
        assert_eq!(shorthand.build_where(), explicit.build_where());
        assert_eq!(shorthand.build_where(), " WHERE `userId` = ?");
    }
    #[test]
    fn postgres_numbering_spans_all_condition_kinds() {
        let mut wb = WhereBuilder::new(Dialect::Postgres);
        wb.add_lte("amount");
        wb.add_in("id", 2);
        wb.add_like("name");
        assert_eq!(
            wb.build_where(),
            " WHERE \"amount\" <= $1 AND \"id\" IN ($2, $3) AND \"name\" LIKE $4"
        );
        assert_eq!(wb.param_count(), 4);
    }
    #[test]
    fn order_by_direction() {
        assert_eq!(
            WhereBuilder::build_order_by("created_at", true),
            " ORDER BY created_at DESC"
        );
        assert_eq!(
            WhereBuilder::build_order_by("created_at", false),
            " ORDER BY created_at ASC"
        );
    }
    #[test]
    fn dialect_specific_snippets() {
        assert_eq!(placeholder(Dialect::Postgres, 7), "$7");
        assert_eq!(placeholder(Dialect::Mysql, 7), "?");
        assert_eq!(placeholder(Dialect::Sqlite, 7), "?");
        assert_eq!(Dialect::Sqlite.now_expr(), "datetime('now')");
        assert_eq!(Dialect::Mysql.now_expr(), "NOW()");
        assert_eq!(Dialect::Postgres.now_expr(), "NOW()");
        assert_eq!(last_insert_id_query(Dialect::Mysql), "SELECT LAST_INSERT_ID()");
    }
}