use crate::Ident;
use crate::condition::Condition;
use crate::error::{OrmError, OrmResult};
use crate::ident::IntoIdent;
use crate::sql::Sql;
use std::sync::Arc;
use tokio_postgres::types::ToSql;
/// Boolean expression tree that renders into the body of a SQL `WHERE` clause.
#[derive(Debug, Clone)]
pub enum WhereExpr {
/// A single [`Condition`] leaf.
Atom(Condition),
/// Conjunction of the child expressions; an empty list renders as `TRUE`.
And(Vec<WhereExpr>),
/// Disjunction of the child expressions; an empty list renders as `FALSE`.
Or(Vec<WhereExpr>),
/// Logical negation of the inner expression.
Not(Box<WhereExpr>),
/// Caller-supplied SQL text that is pushed into the statement as text,
/// not as a bind parameter.
Raw(String),
}
impl WhereExpr {
    /// Wraps a single [`Condition`] as a leaf expression.
    pub fn atom(condition: Condition) -> Self {
        WhereExpr::Atom(condition)
    }

    /// Builds a conjunction of `exprs`. An empty list renders as `TRUE`.
    pub fn and(exprs: Vec<WhereExpr>) -> Self {
        WhereExpr::And(exprs)
    }

    /// Builds a disjunction of `exprs`. An empty list renders as `FALSE`.
    pub fn or(exprs: Vec<WhereExpr>) -> Self {
        WhereExpr::Or(exprs)
    }

    /// Negates `expr`.
    // This is an associated constructor, not an operator overload, so the
    // `std::ops::Not` suggestion from clippy does not apply.
    #[allow(clippy::should_implement_trait)]
    pub fn not(expr: WhereExpr) -> Self {
        WhereExpr::Not(Box::new(expr))
    }

    /// Wraps caller-supplied SQL text. The text is emitted as SQL, not bound
    /// as a parameter, so it must never contain untrusted input.
    pub fn raw(sql: impl Into<String>) -> Self {
        WhereExpr::Raw(sql.into())
    }

    /// Combines `self` and `other` with `AND`.
    ///
    /// If `self` is already an `And`, `other` is appended to it instead of
    /// creating another nesting level.
    pub fn and_with(self, other: WhereExpr) -> WhereExpr {
        match self {
            WhereExpr::And(mut exprs) => {
                exprs.push(other);
                WhereExpr::And(exprs)
            }
            _ => WhereExpr::And(vec![self, other]),
        }
    }

    /// Combines `self` and `other` with `OR`.
    ///
    /// If `self` is already an `Or`, `other` is appended to it instead of
    /// creating another nesting level.
    pub fn or_with(self, other: WhereExpr) -> WhereExpr {
        match self {
            WhereExpr::Or(mut exprs) => {
                exprs.push(other);
                WhereExpr::Or(exprs)
            }
            _ => WhereExpr::Or(vec![self, other]),
        }
    }

    /// Returns `true` only for an empty `And`, which renders as `TRUE`.
    pub fn is_trivially_true(&self) -> bool {
        matches!(self, WhereExpr::And(exprs) if exprs.is_empty())
    }

    /// Returns `true` only for an empty `Or`, which renders as `FALSE`.
    pub fn is_trivially_false(&self) -> bool {
        matches!(self, WhereExpr::Or(exprs) if exprs.is_empty())
    }

    /// Appends this expression to `sql`.
    ///
    /// Groups with two or more children are wrapped in parentheses, and a
    /// single-child group renders as that child alone. A [`WhereExpr::Raw`]
    /// fragment is emitted verbatim at the top level, but is parenthesized
    /// when it is an operand of `AND` / `OR` / `NOT`, so operators inside the
    /// fragment cannot re-associate with the surrounding expression.
    pub fn append_to_sql(&self, sql: &mut Sql) {
        self.append_inner(sql, false);
    }

    /// Renders `self`; `nested` is `true` when the output becomes an operand
    /// of an enclosing `AND` / `OR` / `NOT`.
    fn append_inner(&self, sql: &mut Sql, nested: bool) {
        match self {
            WhereExpr::Atom(cond) => {
                cond.append_to_sql(sql);
            }
            WhereExpr::And(exprs) => Self::append_group(exprs, " AND ", "TRUE", sql, nested),
            WhereExpr::Or(exprs) => Self::append_group(exprs, " OR ", "FALSE", sql, nested),
            WhereExpr::Not(expr) => {
                sql.push("(NOT ");
                expr.append_inner(sql, true);
                sql.push(")");
            }
            WhereExpr::Raw(s) => {
                if nested {
                    // Without the parentheses `NOT a OR b` would parse as
                    // `(NOT a) OR b`, and `a OR b AND c` as `a OR (b AND c)`.
                    sql.push("(");
                    sql.push(s);
                    sql.push(")");
                } else {
                    sql.push(s);
                }
            }
        }
    }

    /// Renders an `AND` / `OR` group: `empty` for no children, the child
    /// itself for one child, otherwise `(c1 <separator> c2 ...)`.
    fn append_group(
        exprs: &[WhereExpr],
        separator: &'static str,
        empty: &'static str,
        sql: &mut Sql,
        nested: bool,
    ) {
        match exprs {
            [] => {
                sql.push(empty);
            }
            // A lone child takes the group's place, so it inherits the
            // group's nesting context.
            [only] => only.append_inner(sql, nested),
            many => {
                sql.push("(");
                for (i, expr) in many.iter().enumerate() {
                    if i > 0 {
                        sql.push(separator);
                    }
                    expr.append_inner(sql, true);
                }
                sql.push(")");
            }
        }
    }
}
impl From<Condition> for WhereExpr {
fn from(cond: Condition) -> Self {
WhereExpr::Atom(cond)
}
}
/// Direction of a sort key. Ascending is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SortDir {
    /// Renders as `ASC`.
    #[default]
    Asc,
    /// Renders as `DESC`.
    Desc,
}

impl SortDir {
    /// SQL keyword for this direction.
    fn to_sql(self) -> &'static str {
        match self {
            Self::Desc => "DESC",
            Self::Asc => "ASC",
        }
    }
}
/// Explicit placement of NULL values within a sort key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NullsOrder {
    /// Renders as `NULLS FIRST`.
    First,
    /// Renders as `NULLS LAST`.
    Last,
}

impl NullsOrder {
    /// SQL clause for this placement.
    fn to_sql(self) -> &'static str {
        match self {
            Self::Last => "NULLS LAST",
            Self::First => "NULLS FIRST",
        }
    }
}
/// One entry of an `ORDER BY` list.
#[derive(Debug, Clone)]
pub enum OrderItem {
/// Sort on an identifier with a direction and an optional NULLS placement.
Column {
/// Identifier to sort on.
column: Ident,
/// Sort direction, always rendered (`ASC` or `DESC`).
dir: SortDir,
/// Explicit NULLS placement; `None` emits no `NULLS ...` clause.
nulls: Option<NullsOrder>,
},
/// Caller-supplied SQL text pushed into the `ORDER BY` list as-is.
Raw(String),
}
impl OrderItem {
    /// Creates a column sort key with no explicit NULLS placement.
    pub fn new(column: Ident, dir: SortDir) -> Self {
        Self::Column {
            column,
            dir,
            nulls: None,
        }
    }

    /// Creates an entry from caller-supplied SQL text.
    pub fn raw(sql: impl Into<String>) -> Self {
        Self::Raw(sql.into())
    }

    /// Sets the NULLS placement of a column entry. Raw entries are returned
    /// unchanged.
    pub fn nulls(self, order: NullsOrder) -> Self {
        match self {
            Self::Column { column, dir, .. } => Self::Column {
                column,
                dir,
                nulls: Some(order),
            },
            untouched @ Self::Raw(_) => untouched,
        }
    }

    /// Appends this entry (without any separator) to `sql`.
    fn append_to_sql(&self, sql: &mut Sql) {
        match self {
            Self::Raw(fragment) => {
                sql.push(fragment);
            }
            Self::Column { column, dir, nulls } => {
                sql.push_ident_ref(column);
                sql.push(" ");
                sql.push(dir.to_sql());
                if let Some(placement) = nulls {
                    sql.push(" ");
                    sql.push(placement.to_sql());
                }
            }
        }
    }
}
/// Builder for an `ORDER BY` clause made of [`OrderItem`]s.
#[derive(Debug, Clone, Default)]
pub struct OrderBy {
// Entries in insertion order; rendered comma-separated.
items: Vec<OrderItem>,
}
impl OrderBy {
    /// Creates an empty `ORDER BY` builder.
    pub fn new() -> Self {
        Self { items: Vec::new() }
    }

    /// Adds an ascending sort key.
    ///
    /// # Errors
    /// Propagates the error from `column.into_ident()`.
    pub fn asc(self, column: impl IntoIdent) -> OrmResult<Self> {
        let ident = column.into_ident()?;
        Ok(self.add(OrderItem::new(ident, SortDir::Asc)))
    }

    /// Adds a descending sort key.
    ///
    /// # Errors
    /// Propagates the error from `column.into_ident()`.
    pub fn desc(self, column: impl IntoIdent) -> OrmResult<Self> {
        let ident = column.into_ident()?;
        Ok(self.add(OrderItem::new(ident, SortDir::Desc)))
    }

    /// Adds a sort key with an explicit direction and NULLS placement.
    ///
    /// # Errors
    /// Propagates the error from `column.into_ident()`.
    pub fn with_nulls(
        self,
        column: impl IntoIdent,
        dir: SortDir,
        nulls: NullsOrder,
    ) -> OrmResult<Self> {
        let ident = column.into_ident()?;
        Ok(self.add(OrderItem::new(ident, dir).nulls(nulls)))
    }

    /// Adds a pre-built entry.
    // Builder-style `add`, not arithmetic; `std::ops::Add` would be misleading.
    #[allow(clippy::should_implement_trait)]
    pub fn add(mut self, item: OrderItem) -> Self {
        self.items.push(item);
        self
    }

    /// Returns `true` when no sort key has been added.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Appends ` ORDER BY ...` (with a leading space) to `sql`; does nothing
    /// when there are no entries.
    pub fn append_to_sql(&self, sql: &mut Sql) {
        if self.is_empty() {
            return;
        }
        sql.push(" ORDER BY ");
        self.append_items(sql);
    }

    /// Renders the clause as a standalone string starting with `ORDER BY`
    /// (no leading space), or an empty string when there are no entries.
    pub fn to_sql(&self) -> String {
        if self.is_empty() {
            return String::new();
        }
        let mut rendered = Sql::empty();
        rendered.push("ORDER BY ");
        self.append_items(&mut rendered);
        rendered.to_sql()
    }

    /// Appends the comma-separated entries, without the `ORDER BY` keyword.
    fn append_items(&self, sql: &mut Sql) {
        let mut entries = self.items.iter();
        if let Some(first) = entries.next() {
            first.append_to_sql(sql);
        }
        for entry in entries {
            sql.push(", ");
            entry.append_to_sql(sql);
        }
    }
}
/// `LIMIT` / `OFFSET` settings; each clause is emitted only when its field is `Some`.
#[derive(Debug, Clone, Default)]
pub struct Pagination {
/// Value bound for the `LIMIT` clause, if any.
pub limit: Option<i64>,
/// Value bound for the `OFFSET` clause, if any.
pub offset: Option<i64>,
}
impl Pagination {
    /// Creates a pagination with neither `LIMIT` nor `OFFSET`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds the pagination for a 1-based `page` of `per_page` rows:
    /// `LIMIT per_page OFFSET (page - 1) * per_page`.
    ///
    /// # Errors
    /// Returns a validation error when `page < 1`, when `per_page` is
    /// negative (it would produce a negative `LIMIT` / `OFFSET`), or when the
    /// offset does not fit in an `i64`.
    pub fn page(page: i64, per_page: i64) -> OrmResult<Self> {
        if page < 1 {
            return Err(OrmError::validation(format!(
                "page must be >= 1, got {page}"
            )));
        }
        if per_page < 0 {
            return Err(OrmError::validation(format!(
                "per_page must be >= 0, got {per_page}"
            )));
        }
        // `page - 1` cannot underflow because `page >= 1`. The product can
        // overflow, which would panic in debug builds and wrap in release.
        let offset = (page - 1).checked_mul(per_page).ok_or_else(|| {
            OrmError::validation(format!(
                "page offset overflows i64 (page {page}, per_page {per_page})"
            ))
        })?;
        Ok(Self {
            limit: Some(per_page),
            offset: Some(offset),
        })
    }

    /// Sets the `LIMIT` value.
    pub fn limit(mut self, n: i64) -> Self {
        self.limit = Some(n);
        self
    }

    /// Sets the `OFFSET` value.
    pub fn offset(mut self, n: i64) -> Self {
        self.offset = Some(n);
        self
    }

    /// Returns `true` when neither `LIMIT` nor `OFFSET` is set.
    pub fn is_empty(&self) -> bool {
        self.limit.is_none() && self.offset.is_none()
    }

    /// Appends ` LIMIT <bind>` and/or ` OFFSET <bind>` to `sql`, each only
    /// when set. The values are passed as bind parameters.
    pub fn append_to_sql(&self, sql: &mut Sql) {
        if let Some(limit) = self.limit {
            sql.push(" LIMIT ");
            sql.push_bind(limit);
        }
        if let Some(offset) = self.offset {
            sql.push(" OFFSET ");
            sql.push_bind(offset);
        }
    }
}
/// Shared, type-erased bind value stored inside keyset cursors.
type DynValue = Arc<dyn ToSql + Send + Sync>;
/// Limit assigned by the `asc` / `desc` constructors of `Keyset1` and `Keyset2`.
const DEFAULT_KEYSET_LIMIT: i64 = 50;
/// Keyset pagination position, carrying the boundary value(s) to seek from.
#[derive(Debug, Clone)]
pub enum Cursor<T> {
/// Seek to rows after the value in sort order (`>` ascending, `<` descending).
After(T),
/// Seek to rows before the value in sort order (`<` ascending, `>` descending).
Before(T),
}
/// Picks the comparison operator for a keyset seek predicate.
///
/// Seeking after the cursor on an ascending sort, or before it on a
/// descending sort, selects larger values (`>`). The other two combinations
/// select smaller values (`<`).
fn seek_cmp_op<T>(dir: SortDir, cursor: &Cursor<T>) -> &'static str {
    let ascending = dir == SortDir::Asc;
    let seeking_after = matches!(cursor, Cursor::After(_));
    if ascending == seeking_after {
        ">"
    } else {
        "<"
    }
}
/// Keyset (seek) pagination over a single sort column.
#[derive(Debug, Clone)]
pub struct Keyset1 {
// Column used both for ordering and for the seek comparison.
column: Ident,
// Sort direction; together with the cursor kind it selects `<` or `>`.
dir: SortDir,
// Boundary value to seek from; `None` means no seek predicate is produced.
cursor: Option<Cursor<DynValue>>,
// Row limit; checked to be >= 1 when the ORDER BY / limit part is rendered.
limit: i64,
}
impl Keyset1 {
pub fn asc(column: impl IntoIdent) -> OrmResult<Self> {
Ok(Self {
column: column.into_ident()?,
dir: SortDir::Asc,
cursor: None,
limit: DEFAULT_KEYSET_LIMIT,
})
}
pub fn desc(column: impl IntoIdent) -> OrmResult<Self> {
Ok(Self {
column: column.into_ident()?,
dir: SortDir::Desc,
cursor: None,
limit: DEFAULT_KEYSET_LIMIT,
})
}
pub fn after<T>(mut self, v: T) -> Self
where
T: ToSql + Send + Sync + 'static,
{
self.cursor = Some(Cursor::After(Arc::new(v)));
self
}
pub fn before<T>(mut self, v: T) -> Self
where
T: ToSql + Send + Sync + 'static,
{
self.cursor = Some(Cursor::Before(Arc::new(v)));
self
}
pub fn limit(mut self, n: i64) -> Self {
self.limit = n;
self
}
pub fn order_by(&self) -> OrderBy {
OrderBy::new().add(OrderItem::new(self.column.clone(), self.dir))
}
pub fn into_where_expr(&self) -> OrmResult<WhereExpr> {
let Some(cursor) = &self.cursor else {
return Ok(WhereExpr::and(Vec::new()));
};
let op = seek_cmp_op(self.dir, cursor);
let v = match cursor {
Cursor::After(v) | Cursor::Before(v) => v.clone(),
};
Ok(WhereExpr::atom(Condition::cmp_dyn(
self.column.clone(),
op,
v,
)))
}
pub fn append_order_by_limit_to_sql(&self, sql: &mut Sql) -> OrmResult<()> {
if self.limit < 1 {
return Err(OrmError::validation(format!(
"keyset limit must be >= 1, got {}",
self.limit
)));
}
self.order_by().append_to_sql(sql);
sql.limit(self.limit);
Ok(())
}
pub fn append_to_sql(&self, sql: &mut Sql) -> OrmResult<()> {
let seek = self.into_where_expr()?;
if !seek.is_trivially_true() {
sql.push(" WHERE ");
seek.append_to_sql(sql);
}
self.append_order_by_limit_to_sql(sql)
}
}
/// Keyset (seek) pagination over two sort columns compared as a tuple.
#[derive(Debug, Clone)]
pub struct Keyset2 {
// First (most significant) sort column.
a: Ident,
// Second sort column.
b: Ident,
// Sort direction applied to both columns.
dir: SortDir,
// Boundary values for (a, b); `None` means no seek predicate is produced.
cursor: Option<Cursor<(DynValue, DynValue)>>,
// Row limit; checked to be >= 1 when the ORDER BY / limit part is rendered.
limit: i64,
}
impl Keyset2 {
pub fn asc(a: impl IntoIdent, b: impl IntoIdent) -> OrmResult<Self> {
Ok(Self {
a: a.into_ident()?,
b: b.into_ident()?,
dir: SortDir::Asc,
cursor: None,
limit: DEFAULT_KEYSET_LIMIT,
})
}
pub fn desc(a: impl IntoIdent, b: impl IntoIdent) -> OrmResult<Self> {
Ok(Self {
a: a.into_ident()?,
b: b.into_ident()?,
dir: SortDir::Desc,
cursor: None,
limit: DEFAULT_KEYSET_LIMIT,
})
}
pub fn after<A, B>(mut self, a: A, b: B) -> Self
where
A: ToSql + Send + Sync + 'static,
B: ToSql + Send + Sync + 'static,
{
self.cursor = Some(Cursor::After((Arc::new(a), Arc::new(b))));
self
}
pub fn before<A, B>(mut self, a: A, b: B) -> Self
where
A: ToSql + Send + Sync + 'static,
B: ToSql + Send + Sync + 'static,
{
self.cursor = Some(Cursor::Before((Arc::new(a), Arc::new(b))));
self
}
pub fn limit(mut self, n: i64) -> Self {
self.limit = n;
self
}
pub fn order_by(&self) -> OrderBy {
OrderBy::new()
.add(OrderItem::new(self.a.clone(), self.dir))
.add(OrderItem::new(self.b.clone(), self.dir))
}
pub fn into_where_expr(&self) -> OrmResult<WhereExpr> {
let Some(cursor) = &self.cursor else {
return Ok(WhereExpr::and(Vec::new()));
};
let op = seek_cmp_op(self.dir, cursor);
let (va, vb) = match cursor {
Cursor::After((va, vb)) | Cursor::Before((va, vb)) => (va.clone(), vb.clone()),
};
Ok(WhereExpr::atom(Condition::tuple2_cmp_dyn(
self.a.clone(),
self.b.clone(),
op,
va,
vb,
)))
}
pub fn append_order_by_limit_to_sql(&self, sql: &mut Sql) -> OrmResult<()> {
if self.limit < 1 {
return Err(OrmError::validation(format!(
"keyset limit must be >= 1, got {}",
self.limit
)));
}
self.order_by().append_to_sql(sql);
sql.limit(self.limit);
Ok(())
}
pub fn append_to_sql(&self, sql: &mut Sql) -> OrmResult<()> {
let seek = self.into_where_expr()?;
if !seek.is_trivially_true() {
sql.push(" WHERE ");
seek.append_to_sql(sql);
}
self.append_order_by_limit_to_sql(sql)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `expr` into an empty statement and returns the SQL text.
    fn where_sql(expr: &WhereExpr) -> String {
        let mut sql = Sql::empty();
        expr.append_to_sql(&mut sql);
        sql.to_sql()
    }

    /// Base statement shared by the tests that append clauses.
    fn users_query() -> Sql {
        Sql::new("SELECT * FROM users")
    }

    /// Appends `pag` to the base statement and returns the SQL text.
    fn paginated_users(pag: &Pagination) -> String {
        let mut sql = users_query();
        pag.append_to_sql(&mut sql);
        sql.to_sql()
    }

    // ---- WhereExpr ----

    #[test]
    fn where_atom() {
        let expr = WhereExpr::atom(Condition::eq("status", "active").unwrap());
        assert_eq!(where_sql(&expr), "status = $1");
    }

    #[test]
    fn where_and() {
        let expr = WhereExpr::And(vec![
            WhereExpr::Atom(Condition::eq("a", 1_i32).unwrap()),
            WhereExpr::Atom(Condition::eq("b", 2_i32).unwrap()),
        ]);
        assert_eq!(where_sql(&expr), "(a = $1 AND b = $2)");
    }

    #[test]
    fn where_or() {
        let expr = WhereExpr::Or(vec![
            WhereExpr::Atom(Condition::eq("role", "admin").unwrap()),
            WhereExpr::Atom(Condition::eq("role", "owner").unwrap()),
        ]);
        assert_eq!(where_sql(&expr), "(role = $1 OR role = $2)");
    }

    #[test]
    fn where_not() {
        let inner = WhereExpr::Atom(Condition::eq("deleted", true).unwrap());
        let expr = WhereExpr::Not(Box::new(inner));
        assert_eq!(where_sql(&expr), "(NOT deleted = $1)");
    }

    #[test]
    fn where_nested() {
        let roles = WhereExpr::Or(vec![
            WhereExpr::Atom(Condition::eq("role", "admin").unwrap()),
            WhereExpr::Atom(Condition::eq("role", "owner").unwrap()),
        ]);
        let expr = WhereExpr::And(vec![
            WhereExpr::Atom(Condition::eq("status", "active").unwrap()),
            roles,
        ]);
        assert_eq!(
            where_sql(&expr),
            "(status = $1 AND (role = $2 OR role = $3))"
        );
    }

    #[test]
    fn where_empty_and_is_true() {
        assert_eq!(where_sql(&WhereExpr::And(vec![])), "TRUE");
    }

    #[test]
    fn where_empty_or_is_false() {
        assert_eq!(where_sql(&WhereExpr::Or(vec![])), "FALSE");
    }

    #[test]
    fn where_and_with_combines() {
        let left = WhereExpr::atom(Condition::eq("a", 1_i32).unwrap());
        let right = WhereExpr::atom(Condition::eq("b", 2_i32).unwrap());
        assert_eq!(where_sql(&left.and_with(right)), "(a = $1 AND b = $2)");
    }

    #[test]
    fn where_raw() {
        let expr = WhereExpr::raw("custom_func(x) > 0");
        assert_eq!(where_sql(&expr), "custom_func(x) > 0");
    }

    // ---- OrderBy ----

    #[test]
    fn order_by_single_asc() {
        let order = OrderBy::new().asc("created_at").unwrap();
        assert_eq!(order.to_sql(), "ORDER BY created_at ASC");
    }

    #[test]
    fn order_by_single_desc() {
        let order = OrderBy::new().desc("priority").unwrap();
        assert_eq!(order.to_sql(), "ORDER BY priority DESC");
    }

    #[test]
    fn order_by_multiple() {
        let order = OrderBy::new()
            .asc("status")
            .unwrap()
            .desc("created_at")
            .unwrap();
        assert_eq!(order.to_sql(), "ORDER BY status ASC, created_at DESC");
    }

    #[test]
    fn order_by_with_nulls() {
        let order = OrderBy::new()
            .with_nulls("last_login", SortDir::Desc, NullsOrder::Last)
            .unwrap();
        assert_eq!(order.to_sql(), "ORDER BY last_login DESC NULLS LAST");
    }

    #[test]
    fn order_by_empty() {
        let order = OrderBy::new();
        assert!(order.is_empty());
        assert_eq!(order.to_sql(), "");
    }

    #[test]
    fn order_by_append() {
        let order = OrderBy::new().asc("id").unwrap();
        let mut sql = users_query();
        order.append_to_sql(&mut sql);
        assert_eq!(sql.to_sql(), "SELECT * FROM users ORDER BY id ASC");
    }

    #[test]
    fn order_by_validates_column() {
        let res = OrderBy::new().asc("valid_column; DROP TABLE users;");
        assert!(res.is_err());
    }

    // ---- Keyset pagination ----

    #[test]
    fn keyset1_asc_after_generates_sql() {
        let keyset = Keyset1::asc("id").unwrap().after(100_i64).limit(10);
        let mut sql = users_query();
        keyset.append_to_sql(&mut sql).unwrap();
        assert_eq!(
            sql.to_sql(),
            "SELECT * FROM users WHERE id > $1 ORDER BY id ASC LIMIT $2"
        );
    }

    #[test]
    fn keyset1_desc_after_flips_comparator() {
        let keyset = Keyset1::desc("created_at").unwrap().after(123_i64).limit(5);
        let mut sql = users_query();
        keyset.append_to_sql(&mut sql).unwrap();
        assert_eq!(
            sql.to_sql(),
            "SELECT * FROM users WHERE created_at < $1 ORDER BY created_at DESC LIMIT $2"
        );
    }

    #[test]
    fn keyset1_composes_with_other_where_expr() {
        let keyset = Keyset1::asc("id").unwrap().after(10_i64).limit(3);
        let status = WhereExpr::atom(Condition::eq("status", "active").unwrap());
        let where_expr = status.and_with(keyset.into_where_expr().unwrap());
        let mut sql = Sql::new("SELECT * FROM users WHERE ");
        where_expr.append_to_sql(&mut sql);
        keyset.append_order_by_limit_to_sql(&mut sql).unwrap();
        assert_eq!(
            sql.to_sql(),
            "SELECT * FROM users WHERE (status = $1 AND id > $2) ORDER BY id ASC LIMIT $3"
        );
    }

    #[test]
    fn keyset2_desc_after_generates_tuple_cmp() {
        let keyset = Keyset2::desc("created_at", "id")
            .unwrap()
            .after(100_i64, 42_i64)
            .limit(20);
        let mut sql = users_query();
        keyset.append_to_sql(&mut sql).unwrap();
        assert_eq!(
            sql.to_sql(),
            "SELECT * FROM users WHERE (created_at, id) < ($1, $2) ORDER BY created_at DESC, id DESC LIMIT $3"
        );
    }

    #[test]
    fn keyset2_composes_with_other_where_expr() {
        let keyset = Keyset2::asc("created_at", "id")
            .unwrap()
            .after(123_i64, 456_i64)
            .limit(2);
        let status = WhereExpr::atom(Condition::eq("status", "active").unwrap());
        let where_expr = status.and_with(keyset.into_where_expr().unwrap());
        let mut sql = Sql::new("SELECT * FROM users WHERE ");
        where_expr.append_to_sql(&mut sql);
        keyset.append_order_by_limit_to_sql(&mut sql).unwrap();
        assert_eq!(
            sql.to_sql(),
            "SELECT * FROM users WHERE (status = $1 AND (created_at, id) > ($2, $3)) ORDER BY created_at ASC, id ASC LIMIT $4"
        );
    }

    // ---- Pagination ----

    #[test]
    fn pagination_limit_only() {
        let pag = Pagination::new().limit(10);
        assert_eq!(paginated_users(&pag), "SELECT * FROM users LIMIT $1");
    }

    #[test]
    fn pagination_offset_only() {
        let pag = Pagination::new().offset(20);
        assert_eq!(paginated_users(&pag), "SELECT * FROM users OFFSET $1");
    }

    #[test]
    fn pagination_limit_offset() {
        let pag = Pagination::new().limit(10).offset(20);
        assert_eq!(
            paginated_users(&pag),
            "SELECT * FROM users LIMIT $1 OFFSET $2"
        );
    }

    #[test]
    fn pagination_page() {
        let pag = Pagination::page(3, 25).unwrap();
        assert_eq!(pag.limit, Some(25));
        assert_eq!(pag.offset, Some(50));
    }

    #[test]
    fn pagination_page_one() {
        let pag = Pagination::page(1, 10).unwrap();
        assert_eq!(pag.limit, Some(10));
        assert_eq!(pag.offset, Some(0));
    }

    #[test]
    fn pagination_page_rejects_zero() {
        assert!(Pagination::page(0, 10).is_err());
    }

    #[test]
    fn pagination_page_rejects_negative() {
        assert!(Pagination::page(-1, 10).is_err());
    }

    #[test]
    fn pagination_empty() {
        let pag = Pagination::new();
        assert!(pag.is_empty());
        assert_eq!(paginated_users(&pag), "SELECT * FROM users");
    }
}