use crate::value::{Value, Values};
/// Incremental builder for a SQL statement and its positional bind values.
///
/// SQL text is appended to an internal buffer while bound values are collected
/// in the order they are added. Placeholder text is produced by caller-supplied
/// closures, so the same writer works for numbered (`$1`) and anonymous (`?`)
/// placeholder styles.
///
/// The next placeholder index is derived from the number of collected values
/// (`values.len() + 1`) instead of being tracked in a separate counter, so it
/// cannot drift out of sync when values are added through
/// [`SqlWriter::values_mut`].
#[derive(Debug, Clone)]
pub struct SqlWriter {
    /// SQL text written so far.
    sql: String,
    /// Bound values, in the order they were added.
    values: Values,
}

impl SqlWriter {
    /// Creates an empty writer with no SQL text and no bound values.
    pub fn new() -> Self {
        Self {
            sql: String::new(),
            values: Values::default(),
        }
    }

    /// Appends `s` verbatim; no spacing, quoting, or escaping is applied.
    pub fn push(&mut self, s: &str) {
        self.sql.push_str(s);
    }

    /// Appends a single space unless the buffer is empty or already ends with `' '`.
    ///
    /// Only a literal space is checked; other trailing whitespace (such as a
    /// newline) still gets a space appended after it.
    pub fn push_space(&mut self) {
        if !self.sql.is_empty() && !self.sql.ends_with(' ') {
            self.sql.push(' ');
        }
    }

    /// Appends `ident` after passing it through `escape_fn`.
    ///
    /// Quoting and escaping are delegated entirely to `escape_fn`; its output is
    /// appended unchanged.
    pub fn push_identifier<F>(&mut self, ident: &str, escape_fn: F)
    where
        F: FnOnce(&str) -> String,
    {
        self.sql.push_str(&escape_fn(ident));
    }

    /// Appends `", "` (a comma followed by one space).
    pub fn push_comma(&mut self) {
        self.sql.push_str(", ");
    }

    /// Appends a placeholder for `value` and records the value for binding.
    ///
    /// `format_fn` receives the 1-based index the value will occupy and returns
    /// the placeholder text to write (e.g. `$1` or `?`).
    ///
    /// A null value (per `Value::is_null`) is written as the literal `NULL`
    /// instead. It is not recorded, consumes no index, and `format_fn` is not
    /// called.
    ///
    /// Returns the index used, or `None` for a null value.
    pub fn push_value<F>(&mut self, value: Value, format_fn: F) -> Option<usize>
    where
        F: FnOnce(usize) -> String,
    {
        if value.is_null() {
            self.sql.push_str("NULL");
            return None;
        }
        // Derived from `values`, so it always matches the slot this value is pushed into.
        let index = self.param_index();
        self.sql.push_str(&format_fn(index));
        self.values.push(value);
        Some(index)
    }

    /// Appends `keyword`, preceded by a separating space as in [`Self::push_space`].
    pub fn push_keyword(&mut self, keyword: &str) {
        self.push_space();
        self.sql.push_str(keyword);
    }

    /// Returns the SQL text written so far.
    pub fn sql(&self) -> &str {
        &self.sql
    }

    /// Returns the values collected so far, in the order they were added.
    pub fn values(&self) -> &Values {
        &self.values
    }

    /// Returns the index the next bound value will receive.
    ///
    /// This is one more than the number of collected values, so a fresh writer
    /// reports `1`.
    pub fn param_index(&self) -> usize {
        self.values.len() + 1
    }

    /// Consumes the writer and returns the SQL text together with the collected values.
    pub fn finish(self) -> (String, Values) {
        (self.sql, self.values)
    }

    /// Consumes the writer and returns only the SQL text, discarding the values.
    pub fn into_string(self) -> String {
        self.sql
    }

    /// Returns mutable access to the SQL buffer; text written here bypasses the helpers above.
    pub fn sql_mut(&mut self) -> &mut String {
        &mut self.sql
    }

    /// Returns mutable access to the collected values.
    ///
    /// Because [`Self::param_index`] is computed from the number of values,
    /// adding or removing values here shifts the index that later
    /// [`Self::push_value`] calls will use.
    pub fn values_mut(&mut self) -> &mut Values {
        &mut self.values
    }

    /// Returns `true` if no SQL text has been written; collected values are not considered.
    pub fn is_empty(&self) -> bool {
        self.sql.is_empty()
    }

    /// Returns the length of the SQL text in bytes.
    pub fn len(&self) -> usize {
        self.sql.len()
    }

    /// Writes each item via `f`, with `separator` between consecutive items.
    ///
    /// No separator is written before the first or after the last item, and
    /// nothing at all is written for an empty iterator.
    pub fn push_list<I, T, F>(&mut self, items: I, separator: &str, mut f: F)
    where
        I: IntoIterator<Item = T>,
        F: FnMut(&mut Self, T),
    {
        for (position, item) in items.into_iter().enumerate() {
            if position > 0 {
                self.sql.push_str(separator);
            }
            f(self, item);
        }
    }

    /// Appends clones of every value in `other`, advancing [`Self::param_index`] accordingly.
    ///
    /// No placeholder text is written. The caller is responsible for SQL that
    /// refers to these values.
    pub fn append_values(&mut self, other: &Values) {
        for value in other.iter() {
            self.values.push(value.clone());
        }
    }
}
impl Default for SqlWriter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps an identifier in double quotes, doubling any embedded double quote.
    fn quote_ident(ident: &str) -> String {
        format!("\"{}\"", ident.replace('"', "\"\""))
    }

    /// Numbered placeholder in the PostgreSQL style (`$1`, `$2`, ...).
    fn pg_placeholder(index: usize) -> String {
        format!("${}", index)
    }

    /// Anonymous placeholder in the MySQL style; the index is ignored.
    fn mysql_placeholder(_index: usize) -> String {
        "?".to_string()
    }

    #[test]
    fn test_sql_writer_basic() {
        let mut writer = SqlWriter::new();
        writer.push("SELECT");
        writer.push_space();
        writer.push("*");
        assert_eq!(writer.sql(), "SELECT *");
    }

    #[test]
    fn test_sql_writer_push_space_is_idempotent() {
        let mut writer = SqlWriter::new();
        // Nothing to separate yet, so no space is written.
        writer.push_space();
        assert!(writer.is_empty());
        writer.push("a ");
        // Already ends with a space, so nothing is added.
        writer.push_space();
        assert_eq!(writer.sql(), "a ");
        assert_eq!(writer.len(), 2);
    }

    #[test]
    fn test_sql_writer_identifier() {
        let mut writer = SqlWriter::new();
        writer.push_identifier("user", quote_ident);
        assert_eq!(writer.sql(), "\"user\"");
    }

    #[test]
    fn test_sql_writer_value_postgres() {
        let mut writer = SqlWriter::new();
        writer.push_value(Value::Int(Some(42)), pg_placeholder);
        writer.push_space();
        writer.push_value(
            Value::String(Some(Box::new("test".to_string()))),
            pg_placeholder,
        );
        assert_eq!(writer.sql(), "$1 $2");
        assert_eq!(writer.values().len(), 2);
    }

    #[test]
    fn test_sql_writer_value_mysql() {
        let mut writer = SqlWriter::new();
        writer.push_value(Value::Int(Some(42)), mysql_placeholder);
        writer.push_space();
        writer.push_value(
            Value::String(Some(Box::new("test".to_string()))),
            mysql_placeholder,
        );
        assert_eq!(writer.sql(), "? ?");
        assert_eq!(writer.values().len(), 2);
    }

    #[test]
    fn test_sql_writer_keyword() {
        let mut writer = SqlWriter::new();
        writer.push("SELECT");
        writer.push_keyword("FROM");
        writer.push_keyword("WHERE");
        assert_eq!(writer.sql(), "SELECT FROM WHERE");
    }

    #[test]
    fn test_sql_writer_list() {
        let mut writer = SqlWriter::new();
        writer.push_list(vec!["a", "b", "c"], ", ", |w, item| {
            w.push_identifier(item, quote_ident);
        });
        assert_eq!(writer.sql(), "\"a\", \"b\", \"c\"");
    }

    #[test]
    fn test_sql_writer_empty_list_writes_nothing() {
        let mut writer = SqlWriter::new();
        writer.push_list(Vec::<&str>::new(), ", ", |w, item| {
            w.push_identifier(item, quote_ident);
        });
        assert!(writer.is_empty());
    }

    #[test]
    fn test_sql_writer_comma() {
        let mut writer = SqlWriter::new();
        writer.push("a");
        writer.push_comma();
        writer.push("b");
        assert_eq!(writer.sql(), "a, b");
    }

    #[test]
    fn test_sql_writer_finish() {
        let mut writer = SqlWriter::new();
        writer.push("SELECT");
        writer.push_space();
        writer.push_value(Value::Int(Some(42)), pg_placeholder);
        let (sql, values) = writer.finish();
        assert_eq!(sql, "SELECT $1");
        assert_eq!(values.len(), 1);
    }

    #[test]
    fn test_sql_writer_append_values_advances_param_index() {
        let mut other = Values::default();
        other.push(Value::Int(Some(1)));
        other.push(Value::Int(Some(2)));
        let mut writer = SqlWriter::new();
        writer.append_values(&other);
        assert_eq!(writer.values().len(), 2);
        assert_eq!(writer.param_index(), 3);
        // The next placeholder must not collide with the two appended values.
        let index = writer.push_value(Value::Int(Some(3)), pg_placeholder);
        assert_eq!(index, Some(3));
        assert_eq!(writer.sql(), "$3");
    }

    #[test]
    fn test_sql_writer_identifier_with_embedded_double_quotes() {
        let mut writer = SqlWriter::new();
        writer.push_identifier("table\"; DROP TABLE users; --", quote_ident);
        assert_eq!(writer.sql(), "\"table\"\"; DROP TABLE users; --\"");
    }

    #[test]
    fn test_sql_writer_list_with_special_identifiers() {
        let mut writer = SqlWriter::new();
        let items = vec!["normal", "has\"quote"];
        writer.push_list(items, ", ", |w, item| {
            w.push_identifier(item, quote_ident);
        });
        assert_eq!(writer.sql(), "\"normal\", \"has\"\"quote\"");
    }
}