use std::marker::PhantomData;
use crate::condition::{Condition, JoinOp, OrderDir, SqlValue};
/// SQL flavour used when rendering bound-parameter placeholders.
///
/// `Postgres` (the default) numbers placeholders `$1, $2, …`; `Sqlite` uses
/// positional `?` markers.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Default)]
pub enum Dialect {
    /// PostgreSQL; selected by `Dialect::default()`.
    #[default]
    Postgres,
    /// SQLite.
    Sqlite,
}
/// One JOIN clause attached to a query builder.
///
/// The structured variants hold `(table, on_condition)` and render as
/// `<KIND> JOIN table ON on_condition`; `Raw` is emitted verbatim.
#[derive(Clone, Debug)]
pub enum Join {
    /// `INNER JOIN table ON condition`.
    Inner(String, String),
    /// `LEFT JOIN table ON condition`.
    Left(String, String),
    /// `RIGHT JOIN table ON condition`.
    Right(String, String),
    /// Arbitrary join SQL, inserted as-is.
    Raw(String),
}
/// Fluent, by-value builder for parameterised SQL against a single table.
///
/// Every builder method consumes `self` and returns the updated builder, so a
/// discarded return value silently loses the change (hence `#[must_use]`).
/// `T` is only a type-level tag stored as `PhantomData<T>`; presumably the row
/// type the query decodes into — confirm with callers.
#[must_use = "QueryBuilder methods return the updated builder; dropping it discards the change"]
#[derive(Debug, Clone)]
pub struct QueryBuilder<T> {
    /// Table named in `FROM` / `UPDATE` / `DELETE FROM`.
    table: String,
    /// Explicit column list; `None` renders `*`. Ignored when `select_raw` is set.
    select_cols: Option<Vec<String>>,
    /// Raw select-list expression; takes precedence over `select_cols`.
    select_raw: Option<String>,
    /// Emit `SELECT DISTINCT` (superseded by a non-empty `distinct_on_cols`).
    distinct: bool,
    /// Columns for `DISTINCT ON (...)`; empty means the clause is not emitted.
    distinct_on_cols: Vec<String>,
    /// JOIN clauses, rendered in insertion order.
    joins: Vec<Join>,
    /// WHERE conditions, each paired with the connective emitted before it
    /// (the first condition's connective is not emitted).
    conditions: Vec<(JoinOp, Condition)>,
    /// `GROUP BY` columns.
    group_by: Vec<String>,
    /// Raw `HAVING` expression (no bound parameters).
    having: Option<String>,
    /// Structured `ORDER BY` entries, rendered before `order_raw`.
    order: Vec<(String, OrderDir)>,
    /// Raw `ORDER BY` fragment, appended after `order`.
    order_raw: Option<String>,
    /// `LIMIT` value.
    limit_val: Option<usize>,
    /// `OFFSET` value.
    offset_val: Option<usize>,
    /// Whether the query may be served from a read replica; `true` from `new`.
    /// NOTE(review): read outside this file — confirm the routing semantics.
    pub use_replica: bool,
    /// Extra select expressions appended after the base select list.
    extra_select_exprs: Vec<String>,
    _marker: PhantomData<T>,
}
impl<T> QueryBuilder<T> {
    /// Creates a builder equivalent to `SELECT * FROM table` with no clauses.
    ///
    /// `use_replica` starts out `true`; see [`QueryBuilder::on_write_db`].
    pub fn new(table: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            select_cols: None,
            select_raw: None,
            distinct: false,
            distinct_on_cols: Vec::new(),
            joins: Vec::new(),
            conditions: Vec::new(),
            group_by: Vec::new(),
            having: None,
            order: Vec::new(),
            order_raw: None,
            limit_val: None,
            offset_val: None,
            use_replica: true,
            extra_select_exprs: Vec::new(),
            _marker: PhantomData,
        }
    }

    // ----- SELECT list -------------------------------------------------------

    /// Replaces the selected column list (default `*`).
    ///
    /// Ignored when [`QueryBuilder::select_raw`] has also been set.
    pub fn select(mut self, cols: &[&str]) -> Self {
        self.select_cols = Some(cols.iter().map(|s| s.to_string()).collect());
        self
    }

    /// Sets a raw select-list expression; takes precedence over `select`.
    pub fn select_raw(mut self, expr: &str) -> Self {
        self.select_raw = Some(expr.to_string());
        self
    }

    /// Appends an extra select expression after the base columns (or `*`).
    pub fn add_select_expr(mut self, expr: impl Into<String>) -> Self {
        self.extra_select_exprs.push(expr.into());
        self
    }

    /// Emits `SELECT DISTINCT`; superseded by [`QueryBuilder::distinct_on`].
    pub fn distinct(mut self) -> Self {
        self.distinct = true;
        self
    }

    /// Emits `SELECT DISTINCT ON (cols)`, replacing any previous column list.
    ///
    /// `DISTINCT ON` is PostgreSQL syntax; it is emitted for every dialect.
    pub fn distinct_on(mut self, cols: &[&str]) -> Self {
        self.distinct_on_cols = cols.iter().map(|s| s.to_string()).collect();
        self
    }

    // ----- JOIN / GROUP BY ----------------------------------------------------

    /// Appends `INNER JOIN table ON on`. Both strings are inserted verbatim.
    pub fn inner_join(mut self, table: &str, on: &str) -> Self {
        self.joins
            .push(Join::Inner(table.to_string(), on.to_string()));
        self
    }

    /// Appends `LEFT JOIN table ON on`. Both strings are inserted verbatim.
    pub fn left_join(mut self, table: &str, on: &str) -> Self {
        self.joins
            .push(Join::Left(table.to_string(), on.to_string()));
        self
    }

    /// Appends `RIGHT JOIN table ON on`. Both strings are inserted verbatim.
    pub fn right_join(mut self, table: &str, on: &str) -> Self {
        self.joins
            .push(Join::Right(table.to_string(), on.to_string()));
        self
    }

    /// Appends a raw join fragment (preceded by a single space).
    pub fn join_raw(mut self, raw: &str) -> Self {
        self.joins.push(Join::Raw(raw.to_string()));
        self
    }

    /// Replaces the `GROUP BY` column list.
    pub fn group_by(mut self, cols: &[&str]) -> Self {
        self.group_by = cols.iter().map(|s| s.to_string()).collect();
        self
    }

    /// Sets the raw `HAVING` expression (replacing any previous one).
    pub fn having(mut self, expr: &str) -> Self {
        self.having = Some(expr.to_string());
        self
    }

    // ----- WHERE (AND) ----------------------------------------------------------
    //
    // `where_*` methods join with AND, `or_where_*` with OR. The connective of
    // the first condition is never emitted. Conditions are rendered flat, so
    // SQL precedence applies (AND binds tighter than OR); use `where_group` /
    // `or_where_group` for explicit parentheses.

    /// Adds `col = val` (bound parameter).
    pub fn where_eq(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::And, Condition::Eq(col.into(), val.into()))
    }

    /// Adds `col != val` (bound parameter).
    pub fn where_ne(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::And, Condition::Ne(col.into(), val.into()))
    }

    /// Adds `col > val` (bound parameter).
    pub fn where_gt(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::And, Condition::Gt(col.into(), val.into()))
    }

    /// Adds `col >= val` (bound parameter).
    pub fn where_gte(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::And, Condition::Gte(col.into(), val.into()))
    }

    /// Adds `col < val` (bound parameter).
    pub fn where_lt(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::And, Condition::Lt(col.into(), val.into()))
    }

    /// Adds `col <= val` (bound parameter).
    pub fn where_lte(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::And, Condition::Lte(col.into(), val.into()))
    }

    /// Adds `col LIKE pattern` (pattern bound as a parameter).
    pub fn where_like(self, col: &str, pattern: &str) -> Self {
        self.push(JoinOp::And, Condition::Like(col.into(), pattern.into()))
    }

    /// Adds `col NOT LIKE pattern` (pattern bound as a parameter).
    pub fn where_not_like(self, col: &str, pattern: &str) -> Self {
        self.push(JoinOp::And, Condition::NotLike(col.into(), pattern.into()))
    }

    /// Adds a null check on `col` (`Condition::IsNull`).
    pub fn where_null(self, col: &str) -> Self {
        self.push(JoinOp::And, Condition::IsNull(col.into()))
    }

    /// Adds a not-null check on `col` (`Condition::IsNotNull`).
    pub fn where_not_null(self, col: &str) -> Self {
        self.push(JoinOp::And, Condition::IsNotNull(col.into()))
    }

    /// Adds `col IN (...)`, one bound parameter per value.
    pub fn where_in(self, col: &str, vals: Vec<impl Into<SqlValue>>) -> Self {
        self.push(
            JoinOp::And,
            Condition::In(col.into(), vals.into_iter().map(Into::into).collect()),
        )
    }

    /// Adds `col NOT IN (...)`, one bound parameter per value.
    pub fn where_not_in(self, col: &str, vals: Vec<impl Into<SqlValue>>) -> Self {
        self.push(
            JoinOp::And,
            Condition::NotIn(col.into(), vals.into_iter().map(Into::into).collect()),
        )
    }

    /// Adds `col BETWEEN lo AND hi` (both bounds bound as parameters).
    pub fn where_between(
        self,
        col: &str,
        lo: impl Into<SqlValue>,
        hi: impl Into<SqlValue>,
    ) -> Self {
        self.push(
            JoinOp::And,
            Condition::Between(col.into(), lo.into(), hi.into()),
        )
    }

    /// Adds `col NOT BETWEEN lo AND hi` (both bounds bound as parameters).
    pub fn where_not_between(
        self,
        col: &str,
        lo: impl Into<SqlValue>,
        hi: impl Into<SqlValue>,
    ) -> Self {
        self.push(
            JoinOp::And,
            Condition::NotBetween(col.into(), lo.into(), hi.into()),
        )
    }

    /// Adds a raw SQL condition, inserted verbatim with no bound parameters.
    pub fn where_raw(self, sql: &str) -> Self {
        self.push(JoinOp::And, Condition::Raw(sql.into()))
    }

    /// Adds `col ILIKE pattern` (pattern bound as a parameter).
    pub fn where_ilike(self, col: &str, pattern: &str) -> Self {
        self.push(JoinOp::And, Condition::ILike(col.into(), pattern.into()))
    }

    /// OR-joined variant of [`QueryBuilder::where_ilike`].
    pub fn or_where_ilike(self, col: &str, pattern: &str) -> Self {
        self.push(JoinOp::Or, Condition::ILike(col.into(), pattern.into()))
    }

    /// Adds a JSON key comparison, e.g. `col->>'key' = $n` on Postgres.
    pub fn where_json(self, col: &str, key: &str, val: impl Into<SqlValue>) -> Self {
        self.push(
            JoinOp::And,
            Condition::JsonGet(col.into(), key.into(), val.into()),
        )
    }

    /// OR-joined variant of [`QueryBuilder::where_json`].
    pub fn or_where_json(self, col: &str, key: &str, val: impl Into<SqlValue>) -> Self {
        self.push(
            JoinOp::Or,
            Condition::JsonGet(col.into(), key.into(), val.into()),
        )
    }

    /// Adds `col @> '<json_val>'::jsonb` (PostgreSQL JSONB containment).
    ///
    /// `json_val` is embedded as a SQL string literal rather than bound as a
    /// parameter, so single quotes are doubled to keep the literal well-formed
    /// (without this, a `'` in the value breaks the statement or injects SQL).
    /// `col` is inserted verbatim and must be trusted.
    pub fn where_json_contains(self, col: &str, json_val: &str) -> Self {
        let escaped = json_val.replace('\'', "''");
        self.push(
            JoinOp::And,
            Condition::Raw(format!("{col} @> '{escaped}'::jsonb")),
        )
    }

    /// Alias for [`QueryBuilder::where_op`].
    pub fn where_column(self, col: &str, op: &str, val: impl Into<SqlValue>) -> Self {
        self.where_op(col, op, val)
    }

    /// Adds `col <op> val`, AND-joined.
    ///
    /// `=`/`==`, `!=`/`<>`, `>`, `>=`, `<` and `<=` map to parameterised
    /// conditions. Any other operator falls back to a raw fragment in which
    /// `val` is formatted with `SqlValue`'s `Display` impl and is **not** bound
    /// as a parameter — do not route untrusted values through that path.
    pub fn where_op(self, col: &str, op: &str, val: impl Into<SqlValue>) -> Self {
        let cond = match op {
            "=" | "==" => Condition::Eq(col.into(), val.into()),
            "!=" | "<>" => Condition::Ne(col.into(), val.into()),
            ">" => Condition::Gt(col.into(), val.into()),
            ">=" => Condition::Gte(col.into(), val.into()),
            "<" => Condition::Lt(col.into(), val.into()),
            "<=" => Condition::Lte(col.into(), val.into()),
            other => Condition::Raw(format!("{col} {other} {}", val.into())),
        };
        self.push(JoinOp::And, cond)
    }

    /// AND-joins a parenthesised group built by `f` on a fresh builder.
    ///
    /// Only the inner builder's conditions are used; an empty group is a no-op.
    pub fn where_group<F>(self, f: F) -> Self
    where
        F: FnOnce(QueryBuilder<T>) -> QueryBuilder<T>,
    {
        let inner = f(QueryBuilder::new(""));
        if inner.conditions.is_empty() {
            return self;
        }
        self.push(JoinOp::And, Condition::Group(inner.conditions))
    }

    /// OR-joined variant of [`QueryBuilder::where_group`].
    pub fn or_where_group<F>(self, f: F) -> Self
    where
        F: FnOnce(QueryBuilder<T>) -> QueryBuilder<T>,
    {
        let inner = f(QueryBuilder::new(""));
        if inner.conditions.is_empty() {
            return self;
        }
        self.push(JoinOp::Or, Condition::Group(inner.conditions))
    }

    // ----- WHERE (OR) -----------------------------------------------------------

    /// Alias for [`QueryBuilder::or_where_eq`].
    pub fn or_where(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.or_where_eq(col, val)
    }

    /// OR-joined `col = val`.
    pub fn or_where_eq(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::Or, Condition::Eq(col.into(), val.into()))
    }

    /// OR-joined `col != val`.
    pub fn or_where_ne(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::Or, Condition::Ne(col.into(), val.into()))
    }

    /// OR-joined `col > val`.
    pub fn or_where_gt(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::Or, Condition::Gt(col.into(), val.into()))
    }

    /// OR-joined `col >= val`.
    pub fn or_where_gte(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::Or, Condition::Gte(col.into(), val.into()))
    }

    /// OR-joined `col < val`.
    pub fn or_where_lt(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::Or, Condition::Lt(col.into(), val.into()))
    }

    /// OR-joined `col <= val`.
    pub fn or_where_lte(self, col: &str, val: impl Into<SqlValue>) -> Self {
        self.push(JoinOp::Or, Condition::Lte(col.into(), val.into()))
    }

    /// OR-joined `col LIKE pattern`.
    pub fn or_where_like(self, col: &str, pattern: &str) -> Self {
        self.push(JoinOp::Or, Condition::Like(col.into(), pattern.into()))
    }

    /// OR-joined null check on `col`.
    pub fn or_where_null(self, col: &str) -> Self {
        self.push(JoinOp::Or, Condition::IsNull(col.into()))
    }

    /// OR-joined not-null check on `col`.
    pub fn or_where_not_null(self, col: &str) -> Self {
        self.push(JoinOp::Or, Condition::IsNotNull(col.into()))
    }

    /// OR-joined `col IN (...)`.
    pub fn or_where_in(self, col: &str, vals: Vec<impl Into<SqlValue>>) -> Self {
        self.push(
            JoinOp::Or,
            Condition::In(col.into(), vals.into_iter().map(Into::into).collect()),
        )
    }

    /// OR-joined `col BETWEEN lo AND hi`.
    pub fn or_where_between(
        self,
        col: &str,
        lo: impl Into<SqlValue>,
        hi: impl Into<SqlValue>,
    ) -> Self {
        self.push(
            JoinOp::Or,
            Condition::Between(col.into(), lo.into(), hi.into()),
        )
    }

    /// OR-joined raw SQL condition (inserted verbatim).
    pub fn or_where_raw(self, sql: &str) -> Self {
        self.push(JoinOp::Or, Condition::Raw(sql.into()))
    }

    // ----- ORDER BY / LIMIT / OFFSET ----------------------------------------------

    /// Appends `col ASC` to the ORDER BY list.
    pub fn order_by(mut self, col: &str) -> Self {
        self.order.push((col.into(), OrderDir::Asc));
        self
    }

    /// Appends `col DESC` to the ORDER BY list.
    pub fn order_by_desc(mut self, col: &str) -> Self {
        self.order.push((col.into(), OrderDir::Desc));
        self
    }

    /// Sets the raw ORDER BY fragment, **replacing** any previous raw fragment.
    ///
    /// The raw fragment is rendered after all structured `order_by*` entries.
    pub fn order_by_raw(mut self, expr: &str) -> Self {
        self.order_raw = Some(expr.to_string());
        self
    }

    /// Appends several `(column, direction)` pairs to the ORDER BY list.
    pub fn order_by_many(mut self, cols: &[(&str, OrderDir)]) -> Self {
        self.order
            .extend(cols.iter().map(|(col, dir)| (col.to_string(), *dir)));
        self
    }

    /// Clears all ordering (structured and raw) and orders by `col ASC`.
    pub fn reorder(mut self, col: &str) -> Self {
        self.order.clear();
        self.order_raw = None;
        self.order.push((col.into(), OrderDir::Asc));
        self
    }

    /// Clears all ordering (structured and raw) and orders by `col DESC`.
    pub fn reorder_desc(mut self, col: &str) -> Self {
        self.order.clear();
        self.order_raw = None;
        self.order.push((col.into(), OrderDir::Desc));
        self
    }

    /// Sets `LIMIT n`.
    pub fn limit(mut self, n: usize) -> Self {
        self.limit_val = Some(n);
        self
    }

    /// Sets `OFFSET n`.
    pub fn offset(mut self, n: usize) -> Self {
        self.offset_val = Some(n);
        self
    }

    // ----- Rendering -----------------------------------------------------------------

    /// [`QueryBuilder::to_sql_with_dialect`] for PostgreSQL.
    pub fn to_sql(&self) -> (String, Vec<SqlValue>) {
        self.to_sql_with_dialect(Dialect::Postgres)
    }

    /// Renders the full `SELECT` statement and its bound parameters (in
    /// placeholder order) for `dialect`.
    ///
    /// Clause order: `SELECT [DISTINCT [ON (...)]] cols FROM table`, joins,
    /// `WHERE`, `GROUP BY`/`HAVING`, `ORDER BY`, `LIMIT`, `OFFSET`.
    pub fn to_sql_with_dialect(&self, dialect: Dialect) -> (String, Vec<SqlValue>) {
        let mut params: Vec<SqlValue> = Vec::new();
        let mut sql = format!(
            "SELECT {}{} FROM {}",
            self.render_distinct_prefix(),
            self.render_select_list(),
            self.table
        );
        sql.push_str(&self.build_joins());
        sql.push_str(&self.build_where_dialect(dialect, &mut params));
        sql.push_str(&self.build_group_by());
        sql.push_str(&self.build_order());
        sql.push_str(&self.render_limit_offset(dialect));
        (sql, params)
    }

    /// [`QueryBuilder::to_count_sql_with_dialect`] for PostgreSQL.
    pub fn to_count_sql(&self) -> (String, Vec<SqlValue>) {
        self.to_count_sql_with_dialect(Dialect::Postgres)
    }

    /// Renders `SELECT COUNT(*) FROM table` plus joins and WHERE.
    ///
    /// The select list, DISTINCT, GROUP BY/HAVING, ORDER BY, LIMIT and OFFSET
    /// are ignored, so the count is of all matching joined rows.
    pub fn to_count_sql_with_dialect(&self, dialect: Dialect) -> (String, Vec<SqlValue>) {
        let mut params: Vec<SqlValue> = Vec::new();
        let joins = self.build_joins();
        let where_clause = self.build_where_dialect(dialect, &mut params);
        (
            format!(
                "SELECT COUNT(*) FROM {}{}{}",
                self.table, joins, where_clause
            ),
            params,
        )
    }

    /// [`QueryBuilder::to_delete_sql_with_dialect`] for PostgreSQL.
    pub fn to_delete_sql(&self) -> (String, Vec<SqlValue>) {
        self.to_delete_sql_with_dialect(Dialect::Postgres)
    }

    /// Renders `DELETE FROM table [WHERE ...]`. Joins are not included.
    ///
    /// With no conditions this deletes every row of the table.
    pub fn to_delete_sql_with_dialect(&self, dialect: Dialect) -> (String, Vec<SqlValue>) {
        let mut params: Vec<SqlValue> = Vec::new();
        let where_clause = self.build_where_dialect(dialect, &mut params);
        (
            format!("DELETE FROM {}{}", self.table, where_clause),
            params,
        )
    }

    /// [`QueryBuilder::to_update_sql_with_dialect`] for PostgreSQL.
    pub fn to_update_sql(&self, data: &[(&str, SqlValue)]) -> (String, Vec<SqlValue>) {
        self.to_update_sql_with_dialect(Dialect::Postgres, data)
    }

    /// Renders `UPDATE table SET col = ?, ... [WHERE ...]`.
    ///
    /// SET values are bound first, so WHERE placeholders continue numbering
    /// after them (e.g. `SET a = $1, b = $2 WHERE id = $3`).
    pub fn to_update_sql_with_dialect(
        &self,
        dialect: Dialect,
        data: &[(&str, SqlValue)],
    ) -> (String, Vec<SqlValue>) {
        let set_clauses: Vec<String> = data
            .iter()
            .enumerate()
            .map(|(i, (col, _))| match dialect {
                Dialect::Postgres => format!("{col} = ${}", i + 1),
                Dialect::Sqlite => format!("{col} = ?"),
            })
            .collect();
        let mut params: Vec<SqlValue> = data.iter().map(|(_, v)| v.clone()).collect();
        let mut sql = format!("UPDATE {} SET {}", self.table, set_clauses.join(", "));
        sql.push_str(&self.build_where_dialect(dialect, &mut params));
        (sql, params)
    }

    /// [`QueryBuilder::insert_sql_with_dialect`] for PostgreSQL.
    pub fn insert_sql(table: &str, data: &[(&str, SqlValue)]) -> (String, Vec<SqlValue>) {
        Self::insert_sql_with_dialect(Dialect::Postgres, table, data)
    }

    /// Renders a single-row `INSERT INTO table (cols) VALUES (...)`.
    pub fn insert_sql_with_dialect(
        dialect: Dialect,
        table: &str,
        data: &[(&str, SqlValue)],
    ) -> (String, Vec<SqlValue>) {
        let cols: Vec<&str> = data.iter().map(|(c, _)| *c).collect();
        let placeholders: Vec<String> = match dialect {
            Dialect::Postgres => (1..=data.len()).map(|i| format!("${i}")).collect(),
            Dialect::Sqlite => (0..data.len()).map(|_| "?".to_string()).collect(),
        };
        let params: Vec<SqlValue> = data.iter().map(|(_, v)| v.clone()).collect();
        (
            format!(
                "INSERT INTO {table} ({}) VALUES ({})",
                cols.join(", "),
                placeholders.join(", ")
            ),
            params,
        )
    }

    /// Renders a multi-row PostgreSQL `INSERT ... VALUES (...), (...)`.
    ///
    /// The column list is taken from the first row; placeholders are numbered
    /// continuously across rows.
    ///
    /// # Panics
    ///
    /// Panics if `rows` is empty, or if any row's column names (in order) differ
    /// from the first row's — otherwise values would be silently misaligned.
    pub fn bulk_insert_sql(table: &str, rows: &[Vec<(&str, SqlValue)>]) -> (String, Vec<SqlValue>) {
        assert!(
            !rows.is_empty(),
            "bulk_insert_sql requires at least one row"
        );
        let cols: Vec<&str> = rows[0].iter().map(|(c, _)| *c).collect();
        let mut params: Vec<SqlValue> = Vec::with_capacity(rows.len() * cols.len());
        let mut value_groups: Vec<String> = Vec::with_capacity(rows.len());
        for (row_idx, row) in rows.iter().enumerate() {
            assert!(
                row.iter().map(|(c, _)| *c).eq(cols.iter().copied()),
                "bulk_insert_sql: columns of row {row_idx} do not match the first row"
            );
            // Placeholders continue from the parameters already bound.
            let start = params.len() + 1;
            let placeholders: Vec<String> = (start..start + row.len())
                .map(|i| format!("${i}"))
                .collect();
            value_groups.push(format!("({})", placeholders.join(", ")));
            params.extend(row.iter().map(|(_, v)| v.clone()));
        }
        (
            format!(
                "INSERT INTO {table} ({}) VALUES {}",
                cols.join(", "),
                value_groups.join(", ")
            ),
            params,
        )
    }

    /// Renders a PostgreSQL `UPDATE table SET ... [WHERE ...]` from explicit
    /// conditions, without needing a builder instance.
    ///
    /// SET values occupy `$1..=$n`; WHERE placeholders continue after them.
    pub fn update_sql(
        table: &str,
        data: &[(&str, SqlValue)],
        conditions: &[(JoinOp, Condition)],
    ) -> (String, Vec<SqlValue>) {
        let set_clauses: Vec<String> = data
            .iter()
            .enumerate()
            .map(|(i, (col, _))| format!("{col} = ${}", i + 1))
            .collect();
        let mut params: Vec<SqlValue> = data.iter().map(|(_, v)| v.clone()).collect();
        let mut sql = format!("UPDATE {table} SET {}", set_clauses.join(", "));
        // Renders nothing when `conditions` is empty.
        sql.push_str(&build_where_from(conditions, &mut params));
        (sql, params)
    }

    /// Applies `f` to the builder only when `condition` is true.
    pub fn when<F>(self, condition: bool, f: F) -> Self
    where
        F: FnOnce(Self) -> Self,
    {
        if condition {
            f(self)
        } else {
            self
        }
    }

    /// Applies `f` with the contained value only when `opt` is `Some`.
    pub fn when_some<V, F>(self, opt: Option<V>, f: F) -> Self
    where
        F: FnOnce(Self, V) -> Self,
    {
        match opt {
            Some(v) => f(self, v),
            None => self,
        }
    }

    /// Appends an arbitrary condition with the given connective.
    pub fn push_condition(self, op: JoinOp, cond: Condition) -> Self {
        self.push(op, cond)
    }

    /// Sets `use_replica` to `false`.
    pub fn on_write_db(mut self) -> Self {
        self.use_replica = false;
        self
    }

    // ----- pgvector helpers ------------------------------------------------------------
    //
    // The embedding is rendered as a `'[..]'::vector` literal; `col`, `op` and
    // `threshold` are interpolated verbatim (not bound as parameters).

    /// Orders by `col <-> embedding` and sets `LIMIT k`.
    ///
    /// Uses [`QueryBuilder::order_by_raw`], so it replaces any previous raw
    /// ordering, and structured `order_by*` entries still sort first.
    pub fn nearest_to(self, col: &str, embedding: &[f32], k: usize) -> Self {
        let vec_lit = format_vector(embedding);
        self.order_by_raw(&format!("{col} <-> '{vec_lit}'::vector"))
            .limit(k)
    }

    /// Adds `col <=> embedding <op> threshold` (cosine-distance filter).
    pub fn where_cosine_distance(
        self,
        col: &str,
        embedding: &[f32],
        op: &str,
        threshold: f64,
    ) -> Self {
        let vec_lit = format_vector(embedding);
        self.where_raw(&format!("{col} <=> '{vec_lit}'::vector {op} {threshold}"))
    }

    /// Adds `col <-> embedding <op> threshold` (L2-distance filter).
    pub fn where_vector_distance(
        self,
        col: &str,
        embedding: &[f32],
        op: &str,
        threshold: f64,
    ) -> Self {
        let vec_lit = format_vector(embedding);
        self.where_raw(&format!("{col} <-> '{vec_lit}'::vector {op} {threshold}"))
    }

    /// Adds `col <#> embedding <op> threshold` (inner-product operator filter).
    pub fn where_inner_product(
        self,
        col: &str,
        embedding: &[f32],
        op: &str,
        threshold: f64,
    ) -> Self {
        let vec_lit = format_vector(embedding);
        self.where_raw(&format!("{col} <#> '{vec_lit}'::vector {op} {threshold}"))
    }

    // ----- Internals -----------------------------------------------------------------

    /// Appends `(op, cond)` to the condition list.
    fn push(mut self, op: JoinOp, cond: Condition) -> Self {
        self.conditions.push((op, cond));
        self
    }

    /// `DISTINCT ON (...) ` / `DISTINCT ` prefix (with trailing space), or `""`.
    fn render_distinct_prefix(&self) -> String {
        if !self.distinct_on_cols.is_empty() {
            format!("DISTINCT ON ({}) ", self.distinct_on_cols.join(", "))
        } else if self.distinct {
            "DISTINCT ".to_string()
        } else {
            String::new()
        }
    }

    /// Select list: raw expression, explicit columns, or `*`, plus any extras.
    fn render_select_list(&self) -> String {
        let base = match (&self.select_raw, &self.select_cols) {
            (Some(raw), _) => raw.clone(),
            (None, Some(cols)) => cols.join(", "),
            (None, None) => "*".to_string(),
        };
        if self.extra_select_exprs.is_empty() {
            base
        } else {
            format!("{base}, {}", self.extra_select_exprs.join(", "))
        }
    }

    /// ` LIMIT n` / ` OFFSET n` suffix.
    ///
    /// SQLite only accepts OFFSET as part of a LIMIT clause, so an offset with
    /// no limit is rendered as `LIMIT -1 OFFSET n` there (a negative limit
    /// means "no upper bound" in SQLite).
    fn render_limit_offset(&self, dialect: Dialect) -> String {
        let mut out = String::new();
        match (self.limit_val, dialect) {
            (Some(n), _) => out.push_str(&format!(" LIMIT {n}")),
            (None, Dialect::Sqlite) if self.offset_val.is_some() => out.push_str(" LIMIT -1"),
            _ => {}
        }
        if let Some(n) = self.offset_val {
            out.push_str(&format!(" OFFSET {n}"));
        }
        out
    }

    /// All JOIN clauses, each preceded by a space.
    fn build_joins(&self) -> String {
        self.joins
            .iter()
            .map(|join| match join {
                Join::Inner(t, on) => format!(" INNER JOIN {t} ON {on}"),
                Join::Left(t, on) => format!(" LEFT JOIN {t} ON {on}"),
                Join::Right(t, on) => format!(" RIGHT JOIN {t} ON {on}"),
                Join::Raw(raw) => format!(" {raw}"),
            })
            .collect()
    }

    /// ` WHERE ...` for this builder's conditions (empty when there are none),
    /// appending bound values to `params`.
    fn build_where_dialect(&self, dialect: Dialect, params: &mut Vec<SqlValue>) -> String {
        build_where_from_dialect(dialect, &self.conditions, params)
    }

    /// ` GROUP BY ...` and/or ` HAVING ...`, each only when set.
    fn build_group_by(&self) -> String {
        let mut out = String::new();
        if !self.group_by.is_empty() {
            out.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
        }
        if let Some(h) = &self.having {
            out.push_str(&format!(" HAVING {h}"));
        }
        out
    }

    /// ` ORDER BY ...`: structured entries first, then the raw fragment.
    fn build_order(&self) -> String {
        let parts: Vec<String> = self
            .order
            .iter()
            .map(|(col, dir)| format!("{col} {dir}"))
            .chain(self.order_raw.iter().cloned())
            .collect();
        if parts.is_empty() {
            return String::new();
        }
        format!(" ORDER BY {}", parts.join(", "))
    }

    /// The accumulated `(connective, condition)` list.
    pub fn conditions(&self) -> &[(JoinOp, Condition)] {
        &self.conditions
    }

    /// Renders only the PostgreSQL ` WHERE ...` fragment (empty if there are
    /// no conditions) and its parameters.
    pub fn to_where_clause(&self) -> (String, Vec<SqlValue>) {
        let mut params = Vec::new();
        let clause = self.build_where_dialect(Dialect::Postgres, &mut params);
        (clause, params)
    }

    /// Renders PostgreSQL `SELECT <agg_expr> FROM table` plus joins and WHERE.
    ///
    /// `agg_expr` is inserted verbatim; GROUP BY, ORDER BY, LIMIT and OFFSET are
    /// ignored.
    pub fn to_aggregate_sql(&self, agg_expr: &str) -> (String, Vec<SqlValue>) {
        let mut params: Vec<SqlValue> = Vec::new();
        let joins = self.build_joins();
        let where_clause = self.build_where_dialect(Dialect::Postgres, &mut params);
        (
            format!(
                "SELECT {agg_expr} FROM {}{}{}",
                self.table, joins, where_clause
            ),
            params,
        )
    }

    /// Renders a PostgreSQL `INSERT ... ON CONFLICT (conflict_cols) ...` upsert.
    ///
    /// Non-conflict columns are updated from `EXCLUDED`. When every column is a
    /// conflict column, `DO NOTHING` is emitted instead. `conflict_cols` must be
    /// non-empty: an empty slice renders `ON CONFLICT ()`, which is invalid.
    pub fn upsert_sql(
        table: &str,
        data: &[(&str, SqlValue)],
        conflict_cols: &[&str],
    ) -> (String, Vec<SqlValue>) {
        let cols: Vec<&str> = data.iter().map(|(c, _)| *c).collect();
        let placeholders: Vec<String> = (1..=data.len()).map(|i| format!("${i}")).collect();
        let params: Vec<SqlValue> = data.iter().map(|(_, v)| v.clone()).collect();
        let conflict_target = conflict_cols.join(", ");
        let update_set: Vec<String> = cols
            .iter()
            .filter(|c| !conflict_cols.contains(c))
            .map(|c| format!("{c} = EXCLUDED.{c}"))
            .collect();
        let action = if update_set.is_empty() {
            "DO NOTHING".to_string()
        } else {
            format!("DO UPDATE SET {}", update_set.join(", "))
        };
        let sql = format!(
            "INSERT INTO {table} ({}) VALUES ({}) ON CONFLICT ({conflict_target}) {action}",
            cols.join(", "),
            placeholders.join(", "),
        );
        (sql, params)
    }
}
/// Formats an embedding as a pgvector text literal body, e.g. `[1,2.5,3]`.
///
/// Each element uses `f32`'s `Display` form, separated by commas with no spaces.
fn format_vector(v: &[f32]) -> String {
    let mut literal = String::from("[");
    for (position, component) in v.iter().enumerate() {
        if position > 0 {
            literal.push(',');
        }
        literal.push_str(&component.to_string());
    }
    literal.push(']');
    literal
}
/// Renders a ` WHERE ...` clause for `conditions` using PostgreSQL `$n`
/// placeholders, appending bound values to `params`.
fn build_where_from(conditions: &[(JoinOp, Condition)], params: &mut Vec<SqlValue>) -> String {
    let postgres = Dialect::Postgres;
    build_where_from_dialect(postgres, conditions, params)
}
/// Renders ` WHERE c1 OP c2 OP c3 ...` for `conditions` in `dialect`.
///
/// Returns an empty string when `conditions` is empty. Each condition's bound
/// values are appended to `params`; for PostgreSQL, numbering starts at
/// `params.len() + 1`, so values already in `params` (e.g. UPDATE SET values)
/// keep the lower placeholder numbers. The connective paired with the first
/// condition is not emitted.
fn build_where_from_dialect(
    dialect: Dialect,
    conditions: &[(JoinOp, Condition)],
    params: &mut Vec<SqlValue>,
) -> String {
    let mut clause = String::new();
    for (position, (join_op, condition)) in conditions.iter().enumerate() {
        let next_index = params.len() + 1;
        let (fragment, bound) = match dialect {
            Dialect::Sqlite => condition.to_param_sql_sqlite(),
            Dialect::Postgres => condition.to_param_sql(next_index),
        };
        params.extend(bound);
        if position == 0 {
            clause.push_str(" WHERE ");
        } else {
            clause.push(' ');
            clause.push_str(&join_op.to_string());
            clause.push(' ');
        }
        clause.push_str(&fragment);
    }
    clause
}
// Unit tests pinning the exact SQL text and parameter order produced by
// `QueryBuilder` for both dialects.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_select() {
        let (sql, params) = QueryBuilder::<()>::new("users").to_sql();
        assert_eq!(sql, "SELECT * FROM users");
        assert!(params.is_empty());
    }

    #[test]
    fn distinct_select() {
        let (sql, _) = QueryBuilder::<()>::new("users").distinct().to_sql();
        assert!(sql.starts_with("SELECT DISTINCT * FROM users"));
    }

    #[test]
    fn where_eq_generates_param() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_eq("id", 42i64)
            .to_sql();
        assert!(sql.contains("WHERE id = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], SqlValue::Integer(42));
    }

    #[test]
    fn multiple_conditions() {
        let (sql, params) = QueryBuilder::<()>::new("posts")
            .where_eq("active", true)
            .where_like("title", "%rust%")
            .to_sql();
        assert!(sql.contains("WHERE active = $1 AND title LIKE $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn or_conditions() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_eq("role", "admin")
            .or_where_eq("role", "moderator")
            .to_sql();
        assert!(sql.contains("WHERE role = $1 OR role = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn where_between() {
        let (sql, params) = QueryBuilder::<()>::new("orders")
            .where_between("amount", 10i64, 100i64)
            .to_sql();
        assert!(sql.contains("amount BETWEEN $1 AND $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn where_not_in() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_not_in("status", vec!["banned", "deleted"])
            .to_sql();
        assert!(sql.contains("status NOT IN ($1, $2)"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn where_not_like() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .where_not_like("email", "%@spam.com")
            .to_sql();
        assert!(sql.contains("email NOT LIKE $1"));
    }

    #[test]
    fn to_update_sql() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_eq("id", 1i64)
            .to_update_sql(&[("name", "Bob".into()), ("active", true.into())]);
        assert!(sql.starts_with("UPDATE users SET name = $1, active = $2"));
        assert!(sql.contains("WHERE id = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn order_limit_offset() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .order_by_desc("created_at")
            .order_by("name")
            .limit(10)
            .offset(20)
            .to_sql();
        assert!(sql.contains("ORDER BY created_at DESC, name ASC"));
        assert!(sql.contains("LIMIT 10"));
        assert!(sql.contains("OFFSET 20"));
    }

    #[test]
    fn count_sql() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .where_eq("active", true)
            .to_count_sql();
        assert!(sql.starts_with("SELECT COUNT(*) FROM users"));
    }

    #[test]
    fn delete_sql() {
        let (sql, params) = QueryBuilder::<()>::new("sessions")
            .where_eq("user_id", 5i64)
            .to_delete_sql();
        assert!(sql.contains("DELETE FROM sessions WHERE user_id = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn insert_sql() {
        let (sql, params) = QueryBuilder::<()>::insert_sql(
            "users",
            &[("name", "Alice".into()), ("email", "a@a.com".into())],
        );
        assert!(sql.contains("INSERT INTO users (name, email) VALUES ($1, $2)"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn where_in() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_in("id", vec![1i64, 2, 3])
            .to_sql();
        assert!(sql.contains("id IN ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn select_specific_columns() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .select(&["id", "email"])
            .to_sql();
        assert!(sql.starts_with("SELECT id, email FROM users"));
    }

    #[test]
    fn option_value_null() {
        let val: SqlValue = Option::<i64>::None.into();
        assert_eq!(val, SqlValue::Null);
    }

    #[test]
    fn option_value_some() {
        let val: SqlValue = Some(42i64).into();
        assert_eq!(val, SqlValue::Integer(42));
    }

    #[test]
    fn inner_join() {
        let (sql, _) = QueryBuilder::<()>::new("orders")
            .inner_join("users", "users.id = orders.user_id")
            .to_sql();
        assert!(sql.contains("INNER JOIN users ON users.id = orders.user_id"));
    }

    #[test]
    fn left_join_with_where() {
        let (sql, params) = QueryBuilder::<()>::new("orders")
            .left_join("users", "users.id = orders.user_id")
            .where_eq("orders.status", "paid")
            .to_sql();
        assert!(sql.contains("LEFT JOIN users ON users.id = orders.user_id"));
        assert!(sql.contains("WHERE orders.status = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn right_join() {
        let (sql, _) = QueryBuilder::<()>::new("orders")
            .right_join("products", "products.id = orders.product_id")
            .to_sql();
        assert!(sql.contains("RIGHT JOIN products ON products.id = orders.product_id"));
    }

    #[test]
    fn group_by_and_having() {
        let (sql, _) = QueryBuilder::<()>::new("orders")
            .select(&["user_id", "COUNT(*) as total"])
            .group_by(&["user_id"])
            .having("COUNT(*) > 5")
            .to_sql();
        assert!(sql.contains("GROUP BY user_id"));
        assert!(sql.contains("HAVING COUNT(*) > 5"));
        let gpos = sql.find("GROUP BY").unwrap();
        let hpos = sql.find("HAVING").unwrap();
        assert!(gpos < hpos);
    }

    #[test]
    fn count_sql_with_join() {
        let (sql, _) = QueryBuilder::<()>::new("orders")
            .inner_join("users", "users.id = orders.user_id")
            .where_eq("users.active", true)
            .to_count_sql();
        assert!(sql.contains("INNER JOIN users ON users.id = orders.user_id"));
        assert!(sql.contains("SELECT COUNT(*) FROM orders"));
    }

    #[test]
    fn bulk_insert_sql_two_rows() {
        let rows: Vec<Vec<(&str, SqlValue)>> = vec![
            vec![("name", "Alice".into()), ("email", "a@a.com".into())],
            vec![("name", "Bob".into()), ("email", "b@b.com".into())],
        ];
        let (sql, params) = QueryBuilder::<()>::bulk_insert_sql("users", &rows);
        assert!(sql.starts_with("INSERT INTO users (name, email) VALUES"));
        assert!(sql.contains("($1, $2), ($3, $4)"));
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn bulk_insert_sql_single_row() {
        let rows = vec![vec![("x", SqlValue::Integer(1))]];
        let (sql, params) = QueryBuilder::<()>::bulk_insert_sql("t", &rows);
        assert!(sql.contains("($1)"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn where_ilike() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_ilike("name", "alice%")
            .to_sql();
        assert!(sql.contains("name ILIKE $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn where_op_gt() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_op("age", ">", 18i64)
            .to_sql();
        assert!(sql.contains("age > $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn where_group_subquery() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_eq("active", true)
            .where_group(|q| q.where_eq("role", "admin").or_where_eq("role", "mod"))
            .to_sql();
        assert!(sql.contains("active = $1"));
        assert!(sql.contains("AND (role = $2 OR role = $3)"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn select_raw() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .select_raw("id, LOWER(email) as email_lower")
            .to_sql();
        assert!(sql.starts_with("SELECT id, LOWER(email) as email_lower FROM users"));
    }

    #[test]
    fn distinct_on() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .distinct_on(&["email"])
            .to_sql();
        assert!(sql.starts_with("SELECT DISTINCT ON (email) * FROM users"));
    }

    #[test]
    fn join_raw() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .join_raw("INNER JOIN subscriptions s ON s.user_id = users.id AND s.active = true")
            .to_sql();
        assert!(sql.contains("INNER JOIN subscriptions s ON s.user_id = users.id"));
    }

    #[test]
    fn order_by_raw() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .order_by_raw("NULLS LAST, score DESC")
            .to_sql();
        assert!(sql.contains("ORDER BY NULLS LAST, score DESC"));
    }

    #[test]
    fn order_by_many() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .order_by_many(&[("role", OrderDir::Asc), ("created_at", OrderDir::Desc)])
            .to_sql();
        assert!(sql.contains("ORDER BY role ASC, created_at DESC"));
    }

    #[test]
    fn reorder_clears_previous() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .order_by("name")
            .reorder_desc("created_at")
            .to_sql();
        assert!(sql.contains("ORDER BY created_at DESC"));
        assert!(!sql.contains("name"));
    }

    #[test]
    fn upsert_sql_basic() {
        let (sql, params) = QueryBuilder::<()>::upsert_sql(
            "users",
            &[("email", "x@y.com".into()), ("name", "Bob".into())],
            &["email"],
        );
        assert!(sql.contains("ON CONFLICT (email) DO UPDATE SET name = EXCLUDED.name"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn where_json_extraction() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_json("settings", "theme", "dark")
            .to_sql();
        assert!(sql.contains("settings->>'theme' = $1"), "sql={sql}");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], SqlValue::Text("dark".into()));
    }

    #[test]
    fn where_json_contains_raw() {
        let (sql, _) = QueryBuilder::<()>::new("posts")
            .where_json_contains("tags", r#"["rust"]"#)
            .to_sql();
        assert!(sql.contains(r#"tags @> '["rust"]'::jsonb"#), "sql={sql}");
    }

    #[test]
    fn to_aggregate_sql() {
        let (sql, params) = QueryBuilder::<()>::new("orders")
            .where_eq("user_id", 1i64)
            .to_aggregate_sql("MAX(total)");
        assert!(sql.starts_with("SELECT MAX(total) FROM orders"));
        assert!(sql.contains("WHERE user_id = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn when_applies_closure_when_true() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .when(true, |q| q.where_eq("active", true))
            .to_sql();
        assert!(sql.contains("WHERE active = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn when_noop_when_false() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .when(false, |q| q.where_eq("active", true))
            .to_sql();
        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn when_some_applies_with_value() {
        let role: Option<&str> = Some("admin");
        let (sql, params) = QueryBuilder::<()>::new("users")
            .when_some(role, |q, r| q.where_eq("role", r))
            .to_sql();
        assert!(sql.contains("role = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn when_some_noop_when_none() {
        let role: Option<&str> = None;
        let (sql, params) = QueryBuilder::<()>::new("users")
            .when_some(role, |q, r| q.where_eq("role", r))
            .to_sql();
        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn chained_when_calls() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .when(true, |q| q.where_eq("active", true))
            .when(true, |q| q.where_eq("role", "admin"))
            .when(false, |q| q.where_eq("deleted", true))
            .to_sql();
        assert!(sql.contains("active = $1"));
        assert!(sql.contains("role = $2"));
        assert!(!sql.contains("deleted"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn add_select_expr_appends_to_star() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .add_select_expr(
                "(SELECT COUNT(*) FROM posts WHERE posts.user_id = users.id) AS posts_count",
            )
            .to_sql();
        assert!(sql.starts_with("SELECT *, (SELECT COUNT(*)"));
    }

    #[test]
    fn add_select_expr_appends_to_cols() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .select(&["id", "email"])
            .add_select_expr("42 AS answer")
            .to_sql();
        assert!(sql.starts_with("SELECT id, email, 42 AS answer FROM users"));
    }

    #[test]
    fn multiple_add_select_exprs() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .add_select_expr(
                "(SELECT COUNT(*) FROM posts WHERE posts.user_id = users.id) AS posts_count",
            )
            .add_select_expr(
                "(SELECT COUNT(*) FROM comments WHERE comments.user_id = users.id) AS comments_count",
            )
            .to_sql();
        assert!(sql.contains("posts_count"));
        assert!(sql.contains("comments_count"));
    }

    #[test]
    fn where_column_alias() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_column("age", ">", 18i64)
            .to_sql();
        assert!(sql.contains("age > $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn or_where_alias() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_eq("role", "admin")
            .or_where("role", "moderator")
            .to_sql();
        assert!(sql.contains("role = $1 OR role = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn subquery_exists_no_inner() {
        // `Condition` and `JoinOp` are already in scope via `use super::*`.
        let (sql, params) = QueryBuilder::<()>::new("users")
            .push_condition(
                JoinOp::And,
                Condition::Subquery {
                    exists: true,
                    table: "posts".to_string(),
                    fk_expr: "posts.user_id = users.id".to_string(),
                    inner: vec![],
                },
            )
            .to_sql();
        assert!(sql.contains("EXISTS (SELECT 1 FROM posts WHERE posts.user_id = users.id)"));
        assert!(params.is_empty());
    }

    #[test]
    fn subquery_not_exists_with_inner() {
        let inner = vec![(
            JoinOp::And,
            Condition::Eq("published".to_string(), SqlValue::Bool(true)),
        )];
        let (sql, params) = QueryBuilder::<()>::new("users")
            .push_condition(
                JoinOp::And,
                Condition::Subquery {
                    exists: false,
                    table: "posts".to_string(),
                    fk_expr: "posts.user_id = users.id".to_string(),
                    inner,
                },
            )
            .to_sql();
        assert!(sql.contains(
            "NOT EXISTS (SELECT 1 FROM posts WHERE posts.user_id = users.id AND published = $1)"
        ));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], SqlValue::Bool(true));
    }

    #[test]
    fn subquery_outer_params_plus_inner_params() {
        let inner = vec![(
            JoinOp::And,
            Condition::Eq("published".to_string(), SqlValue::Bool(true)),
        )];
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_eq("active", true)
            .push_condition(
                JoinOp::And,
                Condition::Subquery {
                    exists: true,
                    table: "posts".to_string(),
                    fk_expr: "posts.user_id = users.id".to_string(),
                    inner,
                },
            )
            .to_sql();
        assert!(sql.contains("active = $1"));
        assert!(sql.contains("published = $2"), "sql={sql}");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn use_replica_default_true() {
        let q = QueryBuilder::<()>::new("users");
        assert!(q.use_replica);
    }

    #[test]
    fn on_write_db_clears_replica() {
        let q = QueryBuilder::<()>::new("users").on_write_db();
        assert!(!q.use_replica);
    }

    #[test]
    fn nearest_to_generates_order_and_limit() {
        let embedding = vec![1.0f32, 2.0, 3.0];
        let (sql, params) = QueryBuilder::<()>::new("documents")
            .nearest_to("embedding", &embedding, 10)
            .to_sql();
        assert!(sql.contains("embedding <-> '[1,2,3]'::vector"), "sql={sql}");
        assert!(sql.contains("LIMIT 10"), "sql={sql}");
        assert!(params.is_empty());
    }

    #[test]
    fn where_cosine_distance_generates_filter() {
        let embedding = vec![0.5f32, 0.5];
        let (sql, params) = QueryBuilder::<()>::new("docs")
            .where_cosine_distance("embedding", &embedding, "<", 0.3)
            .to_sql();
        assert!(
            sql.contains("embedding <=> '[0.5,0.5]'::vector < 0.3"),
            "sql={sql}"
        );
        assert!(params.is_empty());
    }

    #[test]
    fn where_vector_distance_generates_filter() {
        let embedding = vec![1.0f32];
        let (sql, _) = QueryBuilder::<()>::new("docs")
            .where_vector_distance("vec", &embedding, "<", 1.5)
            .to_sql();
        assert!(sql.contains("vec <-> '[1]'::vector < 1.5"), "sql={sql}");
    }

    #[test]
    fn nearest_to_with_additional_filter() {
        let embedding = vec![1.0f32, 0.0];
        let (sql, params) = QueryBuilder::<()>::new("docs")
            .where_eq("active", true)
            .nearest_to("emb", &embedding, 5)
            .to_sql();
        assert!(sql.contains("WHERE active = $1"), "sql={sql}");
        assert!(sql.contains("emb <-> '[1,0]'::vector"), "sql={sql}");
        assert!(sql.contains("LIMIT 5"), "sql={sql}");
        assert_eq!(params.len(), 1);
    }

    // ----- Additional coverage for previously untested paths -----

    #[test]
    fn sqlite_insert_uses_question_marks() {
        let (sql, params) = QueryBuilder::<()>::insert_sql_with_dialect(
            Dialect::Sqlite,
            "users",
            &[("name", "Alice".into()), ("email", "a@a.com".into())],
        );
        assert_eq!(sql, "INSERT INTO users (name, email) VALUES (?, ?)");
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn sqlite_update_set_uses_question_marks() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .to_update_sql_with_dialect(Dialect::Sqlite, &[("name", "Bob".into())]);
        assert_eq!(sql, "UPDATE users SET name = ?");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn upsert_all_conflict_cols_does_nothing() {
        let (sql, params) =
            QueryBuilder::<()>::upsert_sql("tags", &[("name", "rust".into())], &["name"]);
        assert_eq!(
            sql,
            "INSERT INTO tags (name) VALUES ($1) ON CONFLICT (name) DO NOTHING"
        );
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn static_update_sql_without_conditions() {
        let (sql, params) =
            QueryBuilder::<()>::update_sql("users", &[("name", "Bob".into())], &[]);
        assert_eq!(sql, "UPDATE users SET name = $1");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn static_update_sql_numbers_where_after_set() {
        let conds = vec![(
            JoinOp::And,
            Condition::Eq("id".to_string(), SqlValue::Integer(7)),
        )];
        let (sql, params) =
            QueryBuilder::<()>::update_sql("users", &[("name", "Bob".into())], &conds);
        assert!(sql.starts_with("UPDATE users SET name = $1"), "sql={sql}");
        assert!(sql.contains("WHERE id = $2"), "sql={sql}");
        assert_eq!(params.len(), 2);
        assert_eq!(params[1], SqlValue::Integer(7));
    }

    #[test]
    fn to_where_clause_has_leading_where() {
        let (clause, params) = QueryBuilder::<()>::new("users")
            .where_eq("id", 1i64)
            .to_where_clause();
        assert!(clause.starts_with(" WHERE "), "clause={clause}");
        assert!(clause.contains("id = $1"), "clause={clause}");
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn offset_without_limit_postgres() {
        let (sql, _) = QueryBuilder::<()>::new("users").offset(5).to_sql();
        assert_eq!(sql, "SELECT * FROM users OFFSET 5");
    }

    #[test]
    fn where_group_empty_is_noop() {
        let (sql, params) = QueryBuilder::<()>::new("users")
            .where_group(|q| q)
            .to_sql();
        assert_eq!(sql, "SELECT * FROM users");
        assert!(params.is_empty());
    }

    #[test]
    fn reorder_clears_raw_order() {
        let (sql, _) = QueryBuilder::<()>::new("users")
            .order_by_raw("score DESC")
            .reorder("name")
            .to_sql();
        assert!(sql.contains("ORDER BY name ASC"), "sql={sql}");
        assert!(!sql.contains("score"), "sql={sql}");
    }

    #[test]
    fn conditions_accessor_counts_pushes() {
        let q = QueryBuilder::<()>::new("users")
            .where_eq("a", 1i64)
            .or_where_null("b");
        assert_eq!(q.conditions().len(), 2);
    }
}