use super::value::{SqlValue, ToSqlValue};
/// Builds an unqualified [`ColumnRef`] for a column whose name is chosen at runtime.
///
/// The name is copied as given; this function does not quote or validate it.
#[must_use]
pub fn dyn_col(name: &str) -> ColumnRef {
    let name = name.to_owned();
    ColumnRef { table: None, name }
}
/// A reference to a column, optionally qualified by a table name.
///
/// Rendered to SQL text by `to_sql` and turned into predicates by the
/// comparison methods implemented on this type.
#[derive(Debug, Clone)]
pub struct ColumnRef {
/// Table qualifier; when `Some`, `to_sql` renders `table.name`.
pub table: Option<String>,
/// Column name, emitted verbatim by `to_sql`.
pub name: String,
}
/// Constructors, SQL rendering and predicate builders for [`ColumnRef`].
///
/// Every predicate method consumes the column and returns an [`ExprBuilder`].
/// Compared values are converted with [`ToSqlValue`] and handed to the builder
/// as [`SqlValue`] parameters rather than being spliced into the SQL text.
impl ColumnRef {
    /// Creates a column reference qualified by `table`.
    #[must_use]
    pub fn qualified(table: &str, name: &str) -> Self {
        let table = Some(table.to_owned());
        let name = name.to_owned();
        Self { table, name }
    }

    /// Renders the reference as `table.name`, or just `name` when unqualified.
    ///
    /// Both parts are emitted verbatim; no identifier quoting is applied.
    #[must_use]
    pub fn to_sql(&self) -> String {
        self.table.as_deref().map_or_else(
            || self.name.clone(),
            |table| format!("{table}.{}", self.name),
        )
    }

    /// Shared body of the binary predicates: `<column> <op> <value>` with `value` bound.
    fn compare_with<T: ToSqlValue>(self, op: &str, value: T) -> ExprBuilder {
        let column = ExprBuilder::from(self);
        let bound = ExprBuilder::from(value.to_sql_value());
        ExprBuilder::binary(column, op, bound)
    }

    /// Builds `<column> = <value>`.
    #[must_use]
    pub fn eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
        self.compare_with("=", value)
    }

    /// Builds `<column> != <value>`.
    #[must_use]
    pub fn not_eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
        self.compare_with("!=", value)
    }

    /// Builds `<column> < <value>`.
    #[must_use]
    pub fn lt<T: ToSqlValue>(self, value: T) -> ExprBuilder {
        self.compare_with("<", value)
    }

    /// Builds `<column> <= <value>`.
    #[must_use]
    pub fn lt_eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
        self.compare_with("<=", value)
    }

    /// Builds `<column> > <value>`.
    #[must_use]
    pub fn gt<T: ToSqlValue>(self, value: T) -> ExprBuilder {
        self.compare_with(">", value)
    }

    /// Builds `<column> >= <value>`.
    #[must_use]
    pub fn gt_eq<T: ToSqlValue>(self, value: T) -> ExprBuilder {
        self.compare_with(">=", value)
    }

    /// Builds `<column> IS NULL`; binds no parameters.
    #[must_use]
    pub fn is_null(self) -> ExprBuilder {
        let column = ExprBuilder::from(self);
        ExprBuilder::postfix(column, "IS NULL")
    }

    /// Builds `<column> IS NOT NULL`; binds no parameters.
    #[must_use]
    pub fn is_not_null(self) -> ExprBuilder {
        let column = ExprBuilder::from(self);
        ExprBuilder::postfix(column, "IS NOT NULL")
    }

    /// Builds `<column> LIKE <pattern>` with the pattern bound as a parameter.
    #[must_use]
    pub fn like<T: ToSqlValue>(self, pattern: T) -> ExprBuilder {
        self.compare_with("LIKE", pattern)
    }

    /// Builds `<column> NOT LIKE <pattern>` with the pattern bound as a parameter.
    #[must_use]
    pub fn not_like<T: ToSqlValue>(self, pattern: T) -> ExprBuilder {
        self.compare_with("NOT LIKE", pattern)
    }

    /// Builds `<column> BETWEEN <low> AND <high>` with both bounds as parameters.
    #[must_use]
    pub fn between<T: ToSqlValue, U: ToSqlValue>(self, low: T, high: U) -> ExprBuilder {
        let column = ExprBuilder::from(self);
        let lower = low.to_sql_value();
        let upper = high.to_sql_value();
        ExprBuilder::between(column, lower, upper, false)
    }

    /// Builds `<column> NOT BETWEEN <low> AND <high>` with both bounds as parameters.
    #[must_use]
    pub fn not_between<T: ToSqlValue, U: ToSqlValue>(self, low: T, high: U) -> ExprBuilder {
        let column = ExprBuilder::from(self);
        let lower = low.to_sql_value();
        let upper = high.to_sql_value();
        ExprBuilder::between(column, lower, upper, true)
    }

    /// Builds `<column> IN (...)` with one bound parameter per element of `values`.
    #[must_use]
    pub fn in_list<T: ToSqlValue>(self, values: Vec<T>) -> ExprBuilder {
        let bound = values
            .into_iter()
            .map(|item| item.to_sql_value())
            .collect::<Vec<SqlValue>>();
        ExprBuilder::in_list_impl(ExprBuilder::from(self), bound, false)
    }

    /// Builds `<column> NOT IN (...)` with one bound parameter per element of `values`.
    #[must_use]
    pub fn not_in_list<T: ToSqlValue>(self, values: Vec<T>) -> ExprBuilder {
        let bound = values
            .into_iter()
            .map(|item| item.to_sql_value())
            .collect::<Vec<SqlValue>>();
        ExprBuilder::in_list_impl(ExprBuilder::from(self), bound, true)
    }
}
/// A SQL expression fragment together with the values bound to its `?` placeholders.
#[derive(Debug, Clone)]
pub struct ExprBuilder {
/// SQL text of the expression; values added by the builder methods appear as `?`.
sql: String,
/// Bound values, kept in the order the builder methods appended their placeholders.
params: Vec<SqlValue>,
}
/// Constructors, combinators and accessors for [`ExprBuilder`].
///
/// Combinators consume their operands and concatenate the parameter lists left
/// to right, so the placeholders they emit stay aligned with the bound values.
impl ExprBuilder {
    /// Wraps a raw SQL fragment that carries no bound parameters.
    ///
    /// The text is emitted verbatim, so it should never be built from untrusted input.
    #[must_use]
    pub fn raw(sql: impl Into<String>) -> Self {
        Self {
            sql: sql.into(),
            params: Vec::new(),
        }
    }

    /// Creates an expression consisting of a bare column name, emitted verbatim.
    #[must_use]
    pub fn column(name: &str) -> Self {
        Self {
            sql: name.to_owned(),
            params: Vec::new(),
        }
    }

    /// Creates a single `?` placeholder bound to `value`.
    #[must_use]
    pub fn value<T: ToSqlValue>(value: T) -> Self {
        Self {
            sql: String::from("?"),
            params: vec![value.to_sql_value()],
        }
    }

    /// Joins two expressions as `<left> <op> <right>`, left parameters first.
    fn binary(left: Self, op: &str, right: Self) -> Self {
        let sql = format!("{} {op} {}", left.sql, right.sql);
        let mut params = left.params;
        params.extend(right.params);
        Self { sql, params }
    }

    /// Appends a postfix operator: `<operand> <op>`.
    fn postfix(operand: Self, op: &str) -> Self {
        Self {
            sql: format!("{} {op}", operand.sql),
            params: operand.params,
        }
    }

    /// Builds `<expr> [NOT] BETWEEN ? AND ?`, binding `low` then `high`.
    fn between(expr: Self, low: SqlValue, high: SqlValue, negated: bool) -> Self {
        let keyword = if negated { "NOT BETWEEN" } else { "BETWEEN" };
        let sql = format!("{} {keyword} ? AND ?", expr.sql);
        let mut params = expr.params;
        params.push(low);
        params.push(high);
        Self { sql, params }
    }

    /// Builds `<expr> [NOT] IN (?, ?, ...)` with one placeholder per value.
    ///
    /// An empty `values` list would render `IN ()`, which several SQL dialects
    /// reject as a syntax error. In that case the result is the constant
    /// predicate `1 = 0` (nothing is in an empty list) or, when `negated`,
    /// `1 = 1`. The constant no longer mentions `expr`, so `expr`'s parameters
    /// are dropped to keep placeholders and parameters in step.
    fn in_list_impl(expr: Self, values: Vec<SqlValue>, negated: bool) -> Self {
        if values.is_empty() {
            let constant = if negated { "1 = 1" } else { "1 = 0" };
            return Self {
                sql: String::from(constant),
                params: Vec::new(),
            };
        }
        let keyword = if negated { "NOT IN" } else { "IN" };
        let placeholders = vec!["?"; values.len()].join(", ");
        let sql = format!("{} {keyword} ({placeholders})", expr.sql);
        let mut params = expr.params;
        params.extend(values);
        Self { sql, params }
    }

    /// Shared body of the value comparisons: `<self> <op> ?` with `value` bound.
    fn compare_with<T: ToSqlValue>(self, op: &str, value: T) -> Self {
        let bound = Self {
            sql: String::from("?"),
            params: vec![value.to_sql_value()],
        };
        Self::binary(self, op, bound)
    }

    /// Combines two expressions as `<self> AND <other>`.
    ///
    /// No parentheses are added; wrap an `OR` operand with [`ExprBuilder::paren`]
    /// first, otherwise SQL precedence will regroup it.
    #[must_use]
    pub fn and(self, other: Self) -> Self {
        Self::binary(self, "AND", other)
    }

    /// Combines two expressions as `<self> OR <other>`.
    ///
    /// No parentheses are added; use [`ExprBuilder::paren`] to control grouping.
    #[must_use]
    pub fn or(self, other: Self) -> Self {
        Self::binary(self, "OR", other)
    }

    /// Wraps the expression in parentheses.
    #[must_use]
    pub fn paren(self) -> Self {
        Self {
            sql: format!("({})", self.sql),
            params: self.params,
        }
    }

    /// Negates the whole expression as `NOT (<expr>)`.
    ///
    /// The operand is parenthesised because `NOT` binds tighter than `AND` and
    /// `OR`: without the parentheses, negating `a OR b` would render
    /// `NOT a OR b`, which negates only `a`.
    #[must_use]
    // Named after the SQL keyword; this builds text and is not `std::ops::Not`.
    #[allow(clippy::should_implement_trait)]
    pub fn not(self) -> Self {
        Self {
            sql: format!("NOT ({})", self.sql),
            params: self.params,
        }
    }

    /// Builds `<self> = ?` with `value` bound.
    #[must_use]
    pub fn eq<T: ToSqlValue>(self, value: T) -> Self {
        self.compare_with("=", value)
    }

    /// Builds `<self> != ?` with `value` bound.
    #[must_use]
    pub fn not_eq<T: ToSqlValue>(self, value: T) -> Self {
        self.compare_with("!=", value)
    }

    /// Builds `<self> < ?` with `value` bound.
    #[must_use]
    pub fn lt<T: ToSqlValue>(self, value: T) -> Self {
        self.compare_with("<", value)
    }

    /// Builds `<self> <= ?` with `value` bound.
    #[must_use]
    pub fn lt_eq<T: ToSqlValue>(self, value: T) -> Self {
        self.compare_with("<=", value)
    }

    /// Builds `<self> > ?` with `value` bound.
    #[must_use]
    pub fn gt<T: ToSqlValue>(self, value: T) -> Self {
        self.compare_with(">", value)
    }

    /// Builds `<self> >= ?` with `value` bound.
    #[must_use]
    pub fn gt_eq<T: ToSqlValue>(self, value: T) -> Self {
        self.compare_with(">=", value)
    }

    /// Builds `<self> IS NULL`.
    #[must_use]
    pub fn is_null(self) -> Self {
        Self::postfix(self, "IS NULL")
    }

    /// Builds `<self> IS NOT NULL`.
    #[must_use]
    pub fn is_not_null(self) -> Self {
        Self::postfix(self, "IS NOT NULL")
    }

    /// Builds `<self> LIKE ?` with `pattern` bound.
    #[must_use]
    pub fn like<T: ToSqlValue>(self, pattern: T) -> Self {
        self.compare_with("LIKE", pattern)
    }

    /// Builds `<self> NOT LIKE ?` with `pattern` bound.
    #[must_use]
    pub fn not_like<T: ToSqlValue>(self, pattern: T) -> Self {
        self.compare_with("NOT LIKE", pattern)
    }

    /// Builds `<self> IN (...)` with one bound parameter per element.
    ///
    /// An empty `values` list yields the always-false predicate `1 = 0`.
    #[must_use]
    pub fn in_list<T: ToSqlValue>(self, values: Vec<T>) -> Self {
        let bound: Vec<SqlValue> = values.into_iter().map(|item| item.to_sql_value()).collect();
        Self::in_list_impl(self, bound, false)
    }

    /// Builds `<self> NOT IN (...)` with one bound parameter per element.
    ///
    /// An empty `values` list yields the always-true predicate `1 = 1`.
    #[must_use]
    pub fn not_in_list<T: ToSqlValue>(self, values: Vec<T>) -> Self {
        let bound: Vec<SqlValue> = values.into_iter().map(|item| item.to_sql_value()).collect();
        Self::in_list_impl(self, bound, true)
    }

    /// Returns the SQL text built so far.
    #[must_use]
    pub fn sql(&self) -> &str {
        &self.sql
    }

    /// Returns the bound parameters in placeholder order.
    #[must_use]
    pub fn params(&self) -> &[SqlValue] {
        &self.params
    }

    /// Consumes the builder and returns `(sql, params)`.
    #[must_use]
    pub fn build(self) -> (String, Vec<SqlValue>) {
        (self.sql, self.params)
    }
}
/// Turns a column reference into an expression with no bound parameters.
impl From<ColumnRef> for ExprBuilder {
    /// Uses the column's `to_sql` rendering as the expression text.
    fn from(col: ColumnRef) -> Self {
        let sql = col.to_sql();
        Self {
            sql,
            params: Vec::new(),
        }
    }
}
/// Turns a value into a single `?` placeholder bound to that value.
impl From<SqlValue> for ExprBuilder {
    /// The value never enters the SQL text; it becomes the only parameter.
    fn from(value: SqlValue) -> Self {
        let params = vec![value];
        Self {
            sql: "?".to_owned(),
            params,
        }
    }
}
/// Unit tests pinning the SQL text and parameter counts produced by the builders.
#[cfg(test)]
mod tests {
    use super::*;

    /// `eq` renders one placeholder and binds one value.
    #[test]
    fn test_column_eq() {
        let predicate = dyn_col("name").eq("Alice");
        let bound = predicate.params();
        assert_eq!(predicate.sql(), "name = ?");
        assert_eq!(bound.len(), 1);
    }

    /// Ordering comparisons render their operator between column and placeholder.
    #[test]
    fn test_column_comparison() {
        let cases = [
            (dyn_col("age").gt(18), "age > ?"),
            (dyn_col("age").lt_eq(65), "age <= ?"),
        ];
        for (predicate, expected) in cases {
            assert_eq!(predicate.sql(), expected);
        }
    }

    /// `IS NULL` binds nothing.
    #[test]
    fn test_is_null() {
        let predicate = dyn_col("deleted_at").is_null();
        assert_eq!(predicate.sql(), "deleted_at IS NULL");
        assert!(predicate.params().is_empty());
    }

    /// The `LIKE` pattern is bound, not inlined.
    #[test]
    fn test_like() {
        let predicate = dyn_col("email").like("%@example.com");
        assert_eq!(predicate.sql(), "email LIKE ?");
    }

    /// `BETWEEN` binds both bounds.
    #[test]
    fn test_between() {
        let predicate = dyn_col("price").between(10, 100);
        let bound = predicate.params();
        assert_eq!(predicate.sql(), "price BETWEEN ? AND ?");
        assert_eq!(bound.len(), 2);
    }

    /// `IN` renders one placeholder per list element.
    #[test]
    fn test_in_list() {
        let statuses = vec!["active", "pending"];
        let predicate = dyn_col("status").in_list(statuses);
        assert_eq!(predicate.sql(), "status IN (?, ?)");
        assert_eq!(predicate.params().len(), 2);
    }

    /// `and`/`or` concatenate operands; grouping comes from an explicit `paren`.
    #[test]
    fn test_and_or() {
        let adult = dyn_col("age").gt(18);
        let verified = dyn_col("verified").eq(true);
        let adult_or_verified = adult.or(verified).paren();
        let combined = dyn_col("active").eq(true).and(adult_or_verified);
        assert_eq!(combined.sql(), "active = ? AND (age > ? OR verified = ?)");
        assert_eq!(combined.params().len(), 3);
    }

    /// Qualified columns render as `table.name`.
    #[test]
    fn test_qualified_column() {
        let column = ColumnRef::qualified("users", "name");
        let predicate = column.eq("Bob");
        assert_eq!(predicate.sql(), "users.name = ?");
    }

    /// Hostile text stays in the parameter list and never reaches the SQL string.
    #[test]
    fn test_sql_injection_prevention() {
        let malicious = "'; DROP TABLE users; --";
        let predicate = dyn_col("name").eq(malicious);
        assert_eq!(predicate.sql(), "name = ?");
        let first = &predicate.params()[0];
        assert!(matches!(first, SqlValue::Text(text) if text == malicious));
    }
}