use std::ops::{Add, BitAnd, BitOr, Mul, Not, Sub};
use sea_orm::sea_query::{
Alias, CaseStatement, Expr, ExprTrait, Func as SeaQueryFunc, SelectStatement, SimpleExpr,
};
use sea_orm::{Condition, Value};
/// A column referenced by name, used as the left-hand side of arithmetic
/// expressions such as `F::new("priority") + 2`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct F(pub String);

impl F {
    /// Creates a reference to the column called `field`.
    pub fn new(field: &str) -> Self {
        F(field.to_owned())
    }

    /// Returns the column expression for the stored name.
    pub fn col(&self) -> SimpleExpr {
        let alias = Alias::new(&self.0);
        Expr::col(alias)
    }
}
impl Add<i64> for F {
    type Output = SimpleExpr;

    /// Builds the expression `<column> + rhs`.
    fn add(self, rhs: i64) -> Self::Output {
        let column = self.col();
        column.add(rhs)
    }
}
impl Sub<i64> for F {
    type Output = SimpleExpr;

    /// Builds the expression `<column> - rhs`.
    fn sub(self, rhs: i64) -> Self::Output {
        let column = self.col();
        column.sub(rhs)
    }
}
impl Mul<i64> for F {
    type Output = SimpleExpr;

    /// Builds the expression `<column> * rhs`.
    fn mul(self, rhs: i64) -> Self::Output {
        let column = self.col();
        column.mul(rhs)
    }
}
/// Composable filter that wraps a [`Condition`].
///
/// Filters are combined with [`Q::and`], [`Q::or`] and [`Q::not`], then
/// unwrapped with [`Q::into_condition`].
#[derive(Clone, Debug)]
pub struct Q {
    condition: Condition,
}

impl Q {
    /// Wraps a single expression in an all-of condition.
    pub fn new(expr: SimpleExpr) -> Self {
        let condition = Condition::all().add(expr);
        Self { condition }
    }

    /// Nests both filters inside a new `Condition::all()` group.
    pub fn and(self, other: Q) -> Self {
        let Q { condition: left } = self;
        let Q { condition: right } = other;
        Self {
            condition: Condition::all().add(left).add(right),
        }
    }

    /// Nests both filters inside a new `Condition::any()` group.
    pub fn or(self, other: Q) -> Self {
        let Q { condition: left } = self;
        let Q { condition: right } = other;
        Self {
            condition: Condition::any().add(left).add(right),
        }
    }

    /// Applies `Condition::not` to the wrapped condition.
    pub fn not(self) -> Self {
        let Self { condition } = self;
        Self {
            condition: condition.not(),
        }
    }

    /// Consumes the filter and returns the wrapped condition.
    pub fn into_condition(self) -> Condition {
        let Self { condition } = self;
        condition
    }
}
impl BitAnd for Q {
type Output = Q;
fn bitand(self, rhs: Self) -> Self::Output {
self.and(rhs)
}
}
impl BitOr for Q {
type Output = Q;
fn bitor(self, rhs: Self) -> Self::Output {
self.or(rhs)
}
}
impl Not for Q {
type Output = Q;
fn not(self) -> Self::Output {
Q::not(self)
}
}
/// Turns anything convertible into a [`Value`] into a literal expression.
pub fn value<V: Into<Value>>(v: V) -> SimpleExpr {
    let literal: Value = v.into();
    Expr::value(literal)
}
/// One `WHEN <condition> THEN <then>` arm for a [`Case`] expression.
#[derive(Clone, Debug)]
pub struct When {
    /// Expression tested by this arm.
    pub condition: SimpleExpr,
    /// Expression produced when `condition` holds.
    pub then: SimpleExpr,
}

impl When {
    /// Pairs a condition with the expression it selects.
    pub fn new(condition: SimpleExpr, then: SimpleExpr) -> Self {
        When { condition, then }
    }
}
/// Builder for a SQL `CASE` expression made of [`When`] arms and an optional
/// fallback value.
#[derive(Clone, Debug, Default)]
pub struct Case {
    whens: Vec<When>,
    default: Option<SimpleExpr>,
}

impl Case {
    /// Creates a builder with no arms and no fallback.
    pub fn new() -> Self {
        Self {
            whens: Vec::new(),
            default: None,
        }
    }

    /// Appends an arm; arms are emitted in the order they were added.
    pub fn when(mut self, when: When) -> Self {
        self.whens.push(when);
        self
    }

    /// Sets the fallback (`ELSE`) value, replacing any previously set one.
    ///
    /// Note that this inherent method shares its name with
    /// `Default::default`; use `<Case as Default>::default()` to reach the
    /// derived constructor.
    pub fn default(mut self, value: SimpleExpr) -> Self {
        self.default = Some(value);
        self
    }

    /// Consumes the builder and produces the expression.
    ///
    /// When no arms were added but a fallback was set, the fallback itself is
    /// returned: a `CASE` needs at least one `WHEN` arm to be valid SQL, and a
    /// `CASE` that only has an `ELSE` always evaluates to that value anyway.
    /// With neither arms nor a fallback the (empty) statement is still handed
    /// to sea_query unchanged.
    pub fn build(self) -> SimpleExpr {
        let Self { whens, default } = self;

        // Short-circuit the arm-less case instead of emitting `CASE ELSE .. END`.
        let default = match (whens.is_empty(), default) {
            (true, Some(fallback)) => return fallback,
            (_, default) => default,
        };

        let statement = whens
            .into_iter()
            .fold(CaseStatement::new(), |stmt, arm| {
                stmt.case(arm.condition, arm.then)
            });

        match default {
            Some(fallback) => statement.finally(fallback).into(),
            None => statement.into(),
        }
    }
}
/// A `SELECT` statement intended to be embedded inside another expression.
#[derive(Clone, Debug)]
pub struct Subquery {
    query: SelectStatement,
}

impl Subquery {
    /// Wraps an already-built select statement.
    pub fn new(query: SelectStatement) -> Self {
        Subquery { query }
    }

    /// Converts the wrapped statement into an expression.
    pub fn build(self) -> SimpleExpr {
        let Self { query } = self;
        query.into()
    }
}
/// Builder for an `EXISTS` / `NOT EXISTS` test over a [`Subquery`].
#[derive(Clone, Debug)]
pub struct Exists {
    subquery: Subquery,
    negated: bool,
}

impl Exists {
    /// Creates a non-negated test over `subquery`.
    pub fn new(subquery: Subquery) -> Self {
        Exists {
            subquery,
            negated: false,
        }
    }

    /// Marks the test as negated.
    ///
    /// The flag is set, not toggled, so calling this more than once leaves
    /// the test negated.
    pub fn negate(self) -> Self {
        Exists {
            negated: true,
            ..self
        }
    }

    /// Produces `Expr::not_exists` when negated, `Expr::exists` otherwise.
    pub fn build(self) -> SimpleExpr {
        let Self { subquery, negated } = self;
        let inner = subquery.query;
        match negated {
            true => Expr::not_exists(inner),
            false => Expr::exists(inner),
        }
    }
}
/// A column referenced by name.
///
/// NOTE(review): the name suggests a column of an enclosing query, but the
/// expression produced here is a plain, unqualified column reference.
#[derive(Clone, Debug)]
pub struct OuterRef(pub String);

impl OuterRef {
    /// Returns the column expression for the stored name.
    pub fn col(&self) -> SimpleExpr {
        let alias = Alias::new(&self.0);
        Expr::col(alias)
    }
}
/// Builder for a call to an arbitrary SQL function identified by name.
#[derive(Clone, Debug)]
pub struct Func {
    name: String,
    args: Vec<SimpleExpr>,
}

impl Func {
    /// Starts a call to the function `name` with no arguments.
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Func {
            name,
            args: Vec::new(),
        }
    }

    /// Appends one argument; arguments keep the order in which they were added.
    pub fn arg(mut self, expr: SimpleExpr) -> Self {
        self.args.push(expr);
        self
    }

    /// Consumes the builder and produces the function-call expression.
    pub fn build(self) -> SimpleExpr {
        let Self { name, args } = self;
        let call = SeaQueryFunc::cust(Alias::new(name)).args(args);
        call.into()
    }
}
/// Builds a `COALESCE` call over `args` via sea_query's built-in helper.
pub fn coalesce(args: Vec<SimpleExpr>) -> SimpleExpr {
    let call = SeaQueryFunc::coalesce(args);
    call.into()
}
/// Builds sea_query's `greatest` call over the two expressions.
pub fn greatest(a: SimpleExpr, b: SimpleExpr) -> SimpleExpr {
    let operands = [a, b];
    SeaQueryFunc::greatest(operands).into()
}
/// Builds sea_query's `least` call over the two expressions.
pub fn least(a: SimpleExpr, b: SimpleExpr) -> SimpleExpr {
    let operands = [a, b];
    SeaQueryFunc::least(operands).into()
}
/// Builds a call to the SQL function named `CONCAT` with `args` as its
/// arguments.
pub fn concat(args: Vec<SimpleExpr>) -> SimpleExpr {
    let call = Func::new("CONCAT");
    call.build_with_args(args)
}
/// Builds a cast of `expr` to the SQL type named `type_name`.
pub fn cast(expr: SimpleExpr, type_name: &str) -> SimpleExpr {
    let target = Alias::new(type_name);
    SeaQueryFunc::cast_as(expr, target).into()
}
/// Builder for a window expression of the form
/// `<expr> OVER (PARTITION BY ... ORDER BY ...)`.
#[derive(Clone, Debug)]
pub struct Window {
    /// Expression evaluated over the window.
    pub expr: SimpleExpr,
    /// Expressions listed after `PARTITION BY`, in order.
    pub partition_by: Vec<SimpleExpr>,
    /// Expressions listed after `ORDER BY`; `true` means ascending.
    pub order_by: Vec<(SimpleExpr, bool)>,
}

impl Window {
    /// Starts a window over `expr` with no partitioning and no ordering.
    pub fn new(expr: SimpleExpr) -> Self {
        Window {
            expr,
            partition_by: Vec::new(),
            order_by: Vec::new(),
        }
    }

    /// Appends one `PARTITION BY` expression.
    pub fn partition_by(mut self, col: SimpleExpr) -> Self {
        self.partition_by.push(col);
        self
    }

    /// Appends one `ORDER BY` expression; `asc` selects `ASC` over `DESC`.
    pub fn order_by(mut self, col: SimpleExpr, asc: bool) -> Self {
        self.order_by.push((col, asc));
        self
    }

    /// Consumes the builder and produces the window expression.
    ///
    /// The SQL is a template in which every expression is a `?` placeholder;
    /// the expressions are passed to `Expr::cust_with_exprs` in the same
    /// order as their placeholders appear.
    pub fn build(self) -> SimpleExpr {
        let Self {
            expr,
            partition_by,
            order_by,
        } = self;

        // Each non-empty clause becomes one entry; entries are joined by a space.
        let mut clauses: Vec<String> = Vec::with_capacity(2);
        let mut operands = vec![expr];

        if !partition_by.is_empty() {
            let slots = vec!["?"; partition_by.len()].join(", ");
            clauses.push(format!("PARTITION BY {slots}"));
            operands.extend(partition_by);
        }

        if !order_by.is_empty() {
            let mut terms = Vec::with_capacity(order_by.len());
            for (term, asc) in order_by {
                terms.push(if asc { "? ASC" } else { "? DESC" });
                operands.push(term);
            }
            clauses.push(format!("ORDER BY {}", terms.join(", ")));
        }

        let template = format!("? OVER ({})", clauses.join(" "));
        Expr::cust_with_exprs(template, operands)
    }
}
/// Builds a zero-argument call to `ROW_NUMBER`.
pub fn row_number() -> SimpleExpr {
    let call = Func::new("ROW_NUMBER");
    call.build()
}
/// Builds a zero-argument call to `RANK`.
pub fn rank() -> SimpleExpr {
    let call = Func::new("RANK");
    call.build()
}
/// Builds a zero-argument call to `DENSE_RANK`.
pub fn dense_rank() -> SimpleExpr {
    let call = Func::new("DENSE_RANK");
    call.build()
}
/// Builds `LAG(col, offset)`; the offset is passed as an `i64` literal.
pub fn lag(col: SimpleExpr, offset: u32) -> SimpleExpr {
    let distance = value(i64::from(offset));
    let call = Func::new("LAG").arg(col).arg(distance);
    call.build()
}
/// Builds `LEAD(col, offset)`; the offset is passed as an `i64` literal.
pub fn lead(col: SimpleExpr, offset: u32) -> SimpleExpr {
    let distance = value(i64::from(offset));
    let call = Func::new("LEAD").arg(col).arg(distance);
    call.build()
}
#[derive(Clone, Debug)]
pub struct RawSQL {
pub sql: String,
pub params: Vec<Value>,
}
impl RawSQL {
pub fn new(sql: impl Into<String>, params: Vec<Value>) -> Self {
Self {
sql: sql.into(),
params,
}
}
pub fn build(self) -> SimpleExpr {
SimpleExpr::Custom(self.sql.into())
}
}
impl Func {
fn build_with_args(mut self, args: Vec<SimpleExpr>) -> SimpleExpr {
self.args = args;
self.build()
}
}
// Rendering tests: every helper is rendered with the SQLite query builder and
// the resulting SQL text is checked for the expected fragment.
#[cfg(test)]
mod tests {
use sea_orm::Condition;
use sea_orm::sea_query::{Alias, Expr, ExprTrait, Query, SimpleExpr, SqliteQueryBuilder};
use super::{
Case, Exists, F, Func, OuterRef, Q, RawSQL, Subquery, When, Window, cast, coalesce, concat,
dense_rank, greatest, lag, lead, least, rank, row_number, value,
};
// Renders a SELECT of `expr` aliased as `alias` from the `widgets` table,
// using the SQLite query builder.
fn render_select_expr(expr: SimpleExpr, alias: &str) -> String {
Query::select()
.expr_as(expr, Alias::new(alias))
.from(Alias::new("widgets"))
.to_owned()
.to_string(SqliteQueryBuilder)
}
// Renders a SELECT of `id` from `widgets` filtered by `condition`, using the
// SQLite query builder.
fn render_where(condition: sea_orm::Condition) -> String {
Query::select()
.column(Alias::new("id"))
.from(Alias::new("widgets"))
.cond_where(condition)
.to_owned()
.to_string(SqliteQueryBuilder)
}
// `F` combined with `+`, `-` and `*` renders as arithmetic on the quoted column.
#[test]
fn f_arithmetic_renders_column_math() {
let added = render_select_expr(F::new("priority") + 2, "shifted");
let subtracted = render_select_expr(F::new("priority") - 3, "shifted");
let multiplied = render_select_expr(F::new("priority") * 4, "shifted");
assert!(
added.contains("\"priority\" + 2"),
"expected addition SQL, got: {added}"
);
assert!(
subtracted.contains("\"priority\" - 3"),
"expected subtraction SQL, got: {subtracted}"
);
assert!(
multiplied.contains("\"priority\" * 4"),
"expected multiplication SQL, got: {multiplied}"
);
}
// `&`, `|` and `!` on `Q` keep every leaf condition and introduce AND, OR and
// NOT into the rendered WHERE clause.
#[test]
fn q_bitwise_operators_render_combined_conditions() {
let high_priority = Q::new(Expr::col(Alias::new("priority")).gt(10));
let open_status = Q::new(Expr::col(Alias::new("status")).eq("open"));
let archived_name = Q::new(Expr::col(Alias::new("name")).eq("archived"));
let sql = render_where(((high_priority & open_status) | !archived_name).into_condition());
assert!(
sql.contains("\"priority\" > 10"),
"expected priority condition, got: {sql}"
);
assert!(
sql.contains("\"status\" = 'open'"),
"expected status condition, got: {sql}"
);
assert!(
sql.contains("\"name\" = 'archived'"),
"expected archived condition, got: {sql}"
);
assert!(sql.contains("AND"), "expected AND in SQL, got: {sql}");
assert!(sql.contains("OR"), "expected OR in SQL, got: {sql}");
assert!(sql.contains("NOT"), "expected NOT in SQL, got: {sql}");
}
// `value` renders its argument as a literal.
#[test]
fn value_renders_literal_expression() {
let sql = render_select_expr(value(42_i32), "answer");
assert!(
sql.contains("42"),
"expected literal value in SQL, got: {sql}"
);
}
// A single-arm `Case` with a default renders CASE / WHEN / THEN / ELSE / END.
#[test]
fn case_when_renders_sql() {
let sql = render_select_expr(
Case::new()
.when(When::new(Expr::col(Alias::new("x")).gt(10), value("high")))
.default(value("low"))
.build(),
"bucket",
);
assert!(sql.contains("CASE"), "expected CASE in SQL, got: {sql}");
assert!(
sql.contains("WHEN (\"x\" > 10)"),
"expected WHEN clause, got: {sql}"
);
assert!(
sql.contains("THEN 'high'"),
"expected THEN clause, got: {sql}"
);
assert!(
sql.contains("ELSE 'low' END"),
"expected ELSE clause, got: {sql}"
);
}
// A built `Subquery` renders as a parenthesised SELECT when used as the
// right-hand side of a custom `? IN ?` expression.
#[test]
fn subquery_renders_sql() {
let subquery = Subquery::new(
Query::select()
.column(Alias::new("widget_id"))
.from(Alias::new("orders"))
.to_owned(),
);
let condition =
Expr::cust_with_exprs("? IN ?", [Expr::col(Alias::new("id")), subquery.build()]);
let sql = render_where(Condition::all().add(condition));
assert!(
sql.contains("\"id\" IN (SELECT \"widget_id\" FROM \"orders\")"),
"expected IN subquery SQL, got: {sql}"
);
}
// A non-negated `Exists` renders `EXISTS(<subquery>)`.
#[test]
fn exists_renders_sql() {
let exists = Exists::new(Subquery::new(
Query::select()
.column(Alias::new("id"))
.from(Alias::new("orders"))
.to_owned(),
));
let sql = render_where(Condition::all().add(exists.build()));
assert!(
sql.contains("EXISTS(SELECT \"id\" FROM \"orders\")"),
"expected EXISTS SQL, got: {sql}"
);
}
// After `negate()`, `Exists` renders `NOT EXISTS(<subquery>)`.
#[test]
fn exists_negated_renders_sql() {
let exists = Exists::new(Subquery::new(
Query::select()
.column(Alias::new("id"))
.from(Alias::new("orders"))
.to_owned(),
))
.negate();
let sql = render_where(Condition::all().add(exists.build()));
assert!(
sql.contains("NOT EXISTS(SELECT \"id\" FROM \"orders\")"),
"expected NOT EXISTS SQL, got: {sql}"
);
}
// A `Func` built from a name and one argument renders `NAME(arg)`.
#[test]
fn generic_func_renders_sql() {
let sql = render_select_expr(
Func::new("LOWER")
.arg(Expr::col(Alias::new("name")))
.build(),
"lower_name",
);
assert!(
sql.contains("LOWER(\"name\")"),
"expected LOWER SQL, got: {sql}"
);
}
// Covers coalesce, greatest, least, concat and cast in one pass.
#[test]
fn convenience_funcs_render_sql() {
let coalesce_sql = render_select_expr(
coalesce(vec![Expr::col(Alias::new("nickname")), value("guest")]),
"screen_name",
);
let greatest_sql = render_select_expr(
greatest(
Expr::col(Alias::new("score")),
Expr::col(Alias::new("bonus")),
),
"max_score",
);
let least_sql = render_select_expr(
least(
Expr::col(Alias::new("score")),
Expr::col(Alias::new("bonus")),
),
"min_score",
);
let concat_sql = render_select_expr(
concat(vec![value("a"), Expr::col(Alias::new("name")), value("z")]),
"joined",
);
let cast_sql =
render_select_expr(cast(Expr::col(Alias::new("score")), "INTEGER"), "score_i64");
assert!(
coalesce_sql.contains("COALESCE"),
"expected COALESCE SQL, got: {coalesce_sql}"
);
// Either spelling is accepted for greatest/least: MAX/MIN or GREATEST/LEAST.
assert!(
greatest_sql.contains("MAX(\"score\", \"bonus\")")
|| greatest_sql.contains("GREATEST(\"score\", \"bonus\")"),
"expected GREATEST/MAX SQL, got: {greatest_sql}"
);
assert!(
least_sql.contains("MIN(\"score\", \"bonus\")")
|| least_sql.contains("LEAST(\"score\", \"bonus\")"),
"expected LEAST/MIN SQL, got: {least_sql}"
);
assert!(
concat_sql.contains("CONCAT('a', \"name\", 'z')"),
"expected CONCAT SQL, got: {concat_sql}"
);
assert!(
cast_sql.contains("CAST(\"score\" AS INTEGER)"),
"expected CAST SQL, got: {cast_sql}"
);
}
// A window with one partition column and one descending order column renders
// both clauses inside `OVER (...)`, separated by a single space.
#[test]
fn window_renders_sql() {
let sql = render_select_expr(
Window::new(row_number())
.partition_by(Expr::col(Alias::new("team")))
.order_by(Expr::col(Alias::new("created_at")), false)
.build(),
"row_num",
);
assert!(
sql.contains("ROW_NUMBER() OVER (PARTITION BY \"team\" ORDER BY \"created_at\" DESC)"),
"expected window SQL, got: {sql}"
);
}
// `rank` and `dense_rank` render as zero-argument function calls.
#[test]
fn rank_functions_render_sql() {
let rank_sql = render_select_expr(rank(), "ranking");
let dense_rank_sql = render_select_expr(dense_rank(), "dense_ranking");
assert!(
rank_sql.contains("RANK()"),
"expected RANK SQL, got: {rank_sql}"
);
assert!(
dense_rank_sql.contains("DENSE_RANK()"),
"expected DENSE_RANK SQL, got: {dense_rank_sql}"
);
}
// `lag` and `lead` render the column followed by the offset literal.
#[test]
fn lag_and_lead_render_sql() {
let lag_sql = render_select_expr(
lag(Expr::col(Alias::new("position")), 2),
"previous_position",
);
let lead_sql =
render_select_expr(lead(Expr::col(Alias::new("position")), 1), "next_position");
assert!(
lag_sql.contains("LAG(\"position\", 2)"),
"expected LAG SQL, got: {lag_sql}"
);
assert!(
lead_sql.contains("LEAD(\"position\", 1)"),
"expected LEAD SQL, got: {lead_sql}"
);
}
// `RawSQL` with an empty parameter list appears unchanged in the rendered query.
#[test]
fn raw_sql_passes_through() {
let sql = render_select_expr(
RawSQL::new("json_extract(data, '$.name')", vec![]).build(),
"name",
);
assert!(
sql.contains("json_extract(data, '$.name')"),
"expected raw SQL, got: {sql}"
);
}
// `OuterRef::col` renders the quoted column name.
#[test]
fn outer_ref_renders_column() {
let sql = render_select_expr(OuterRef("parent_id".to_string()).col(), "outer_id");
assert!(
sql.contains("\"parent_id\""),
"expected outer ref column, got: {sql}"
);
}
}