use std::marker::PhantomData;
use derive_more::with_trait::Debug;
use sea_query::{ExprTrait, IntoColumnRef};
use crate::db;
use crate::db::{
Auto, Database, DatabaseBackend, DbFieldValue, DbValue, ForeignKey, FromDbValue, Identifier,
Model, StatementResult, ToDbFieldValue,
};
/// A query over the rows of the model `T`.
///
/// A query holds an optional filter expression together with optional
/// `LIMIT` and `OFFSET` values. It is configured with the chaining setters
/// ([`Query::filter`], [`Query::limit`], [`Query::offset`]) and executed by
/// handing it to a database (e.g. [`Query::all`], [`Query::count`]).
pub struct Query<T> {
    /// Condition added to the statement's `WHERE` clause, if any.
    filter: Option<Expr>,
    /// Value passed to the statement's `LIMIT`, if any.
    limit: Option<u64>,
    /// Value passed to the statement's `OFFSET`, if any.
    offset: Option<u64>,
    /// Ties the query to `T` without storing a `T`. `fn() -> T` is always
    /// `Send + Sync` and covariant in `T`, so those properties of the query
    /// do not depend on `T` itself.
    phantom_data: PhantomData<fn() -> T>,
}
// Written by hand so that `T` does not need to implement `Debug`.
impl<T> Debug for Query<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `Query` without
        // updating this impl becomes a compile error.
        let Self {
            filter,
            limit,
            offset,
            phantom_data,
        } = self;

        f.debug_struct("Query")
            .field("filter", filter)
            .field("limit", limit)
            .field("offset", offset)
            .field("phantom_data", phantom_data)
            .finish()
    }
}
// Not derived: `#[derive(Clone)]` would add a `T: Clone` bound, but a query
// never holds a `T`, so it is cloneable for every model type.
impl<T> Clone for Query<T> {
    fn clone(&self) -> Self {
        let Self {
            filter,
            limit,
            offset,
            phantom_data: _,
        } = self;

        Self {
            filter: filter.clone(),
            limit: *limit,
            offset: *offset,
            phantom_data: PhantomData,
        }
    }
}
// Written by hand so that no `T: PartialEq` bound is required; `T` carries no
// data in a query.
impl<T> PartialEq for Query<T> {
    /// Two queries are equal when their filter, limit and offset all match.
    ///
    /// All three fields affect the generated SQL, so two queries that differ
    /// only in `LIMIT`/`OFFSET` must not compare equal.
    fn eq(&self, other: &Self) -> bool {
        self.filter == other.filter && self.limit == other.limit && self.offset == other.offset
    }
}
impl<T: Model> Default for Query<T> {
    /// Same as [`Query::new`]: a query with no filter, limit or offset.
    fn default() -> Self {
        Self::new()
    }
}
impl<T: Model> Query<T> {
    /// Creates a query with no filter, no limit and no offset.
    #[must_use]
    pub fn new() -> Self {
        Self {
            filter: None,
            limit: None,
            offset: None,
            phantom_data: PhantomData,
        }
    }

    /// Sets the filter expression, replacing any filter set earlier.
    pub fn filter(&mut self, filter: Expr) -> &mut Self {
        self.filter = Some(filter);
        self
    }

    /// Sets the value used for the statement's `LIMIT`.
    pub fn limit(&mut self, limit: u64) -> &mut Self {
        self.limit = Some(limit);
        self
    }

    /// Sets the value used for the statement's `OFFSET`.
    pub fn offset(&mut self, offset: u64) -> &mut Self {
        self.offset = Some(offset);
        self
    }

    /// Runs the query through the backend's `query` method and returns the
    /// resulting models.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub async fn all<DB: DatabaseBackend>(&self, db: &DB) -> db::Result<Vec<T>> {
        db.query(self).await
    }

    /// Runs the query through the backend's `get` method.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub async fn get<DB: DatabaseBackend>(&self, db: &DB) -> db::Result<Option<T>> {
        db.get(self).await
    }

    /// Counts the rows of `T::TABLE_NAME` that match this query's filter.
    ///
    /// Only the filter is applied to the generated `SELECT COUNT(*)`
    /// statement; `limit` and `offset` are not.
    ///
    /// # Errors
    ///
    /// Returns an error if executing the statement fails or the first column
    /// of the returned row cannot be read as an `i64`.
    pub async fn count(&self, db: &Database) -> db::Result<u64> {
        let mut statement = sea_query::Query::select();
        statement
            .from(T::TABLE_NAME)
            .expr(sea_query::Expr::col(sea_query::Asterisk).count());
        self.add_filter_to_statement(&mut statement);

        // No row at all is reported as a count of zero.
        let Some(row) = db.fetch_option(&statement).await? else {
            return Ok(0);
        };
        // COUNT(*) is never negative, so dropping the sign loses nothing.
        #[expect(clippy::cast_sign_loss)]
        let total = row.get::<i64>(0)? as u64;
        Ok(total)
    }

    /// Runs the query through the backend's `exists` method.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub async fn exists<DB: DatabaseBackend>(&self, db: &DB) -> db::Result<bool> {
        db.exists(self).await
    }

    /// Runs the query through the backend's `delete` method.
    ///
    /// # Errors
    ///
    /// Returns whatever error the backend reports.
    pub async fn delete<DB: DatabaseBackend>(&self, db: &DB) -> db::Result<StatementResult> {
        db.delete(self).await
    }

    /// Adds the filter, if one is set, to `statement` via `and_where`, so it
    /// is ANDed with any conditions the statement already has.
    pub(super) fn add_filter_to_statement<S: sea_query::ConditionalStatement>(
        &self,
        statement: &mut S,
    ) {
        let Some(condition) = self.filter.as_ref() else {
            return;
        };
        statement.and_where(condition.as_sea_query_expr());
    }

    /// Applies the limit, if one is set, to `statement`.
    pub(super) fn add_limit_to_statement(&self, statement: &mut sea_query::SelectStatement) {
        let Some(max_rows) = self.limit else {
            return;
        };
        statement.limit(max_rows);
    }

    /// Applies the offset, if one is set, to `statement`.
    pub(super) fn add_offset_to_statement(&self, statement: &mut sea_query::SelectStatement) {
        let Some(skipped_rows) = self.offset else {
            return;
        };
        statement.offset(skipped_rows);
    }
}
/// A database expression tree used to build query filters.
///
/// Leaves are column references ([`Expr::Field`]) and literal values
/// ([`Expr::Value`]); every other variant combines two sub-expressions.
/// The tree is converted to a `sea_query` expression by
/// [`Expr::as_sea_query_expr`]. The enum is `#[non_exhaustive]`, so code
/// outside this crate must include a wildcard arm when matching on it.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub enum Expr {
    /// A reference to a column, by identifier.
    Field(Identifier),
    /// A literal database value.
    Value(DbValue),
    /// Logical AND of two expressions.
    And(Box<Expr>, Box<Expr>),
    /// Logical OR of two expressions.
    Or(Box<Expr>, Box<Expr>),
    /// `lhs = rhs`.
    Eq(Box<Expr>, Box<Expr>),
    /// `lhs <> rhs`.
    Ne(Box<Expr>, Box<Expr>),
    /// `lhs < rhs`.
    Lt(Box<Expr>, Box<Expr>),
    /// `lhs <= rhs`.
    Lte(Box<Expr>, Box<Expr>),
    /// `lhs > rhs`.
    Gt(Box<Expr>, Box<Expr>),
    /// `lhs >= rhs`.
    Gte(Box<Expr>, Box<Expr>),
    /// `lhs + rhs`.
    Add(Box<Expr>, Box<Expr>),
    /// `lhs - rhs`.
    Sub(Box<Expr>, Box<Expr>),
    /// `lhs * rhs`.
    Mul(Box<Expr>, Box<Expr>),
    /// `lhs / rhs`.
    Div(Box<Expr>, Box<Expr>),
}
impl Expr {
    /// Builds a column reference from anything convertible into an
    /// [`Identifier`].
    #[must_use]
    pub fn field<T: Into<Identifier>>(identifier: T) -> Self {
        Self::Field(identifier.into())
    }

    /// Builds a literal from any value that can be turned into a database
    /// field value.
    ///
    /// # Panics
    ///
    /// Panics if `value` converts to `DbFieldValue::Auto` rather than a
    /// concrete `DbFieldValue::Value`.
    #[must_use]
    #[expect(clippy::needless_pass_by_value)]
    pub fn value<T: ToDbFieldValue>(value: T) -> Self {
        let DbFieldValue::Value(concrete) = value.to_db_field_value() else {
            panic!("Cannot create query with a non-value field");
        };
        Self::Value(concrete)
    }

    /// Builds `lhs AND rhs`.
    #[must_use]
    pub fn and(lhs: Self, rhs: Self) -> Self {
        Self::And(lhs.into(), rhs.into())
    }

    /// Builds `lhs OR rhs`.
    #[must_use]
    pub fn or(lhs: Self, rhs: Self) -> Self {
        Self::Or(lhs.into(), rhs.into())
    }

    /// Builds `lhs = rhs`.
    #[must_use]
    pub fn eq(lhs: Self, rhs: Self) -> Self {
        Self::Eq(lhs.into(), rhs.into())
    }

    /// Builds `lhs <> rhs`.
    #[must_use]
    pub fn ne(lhs: Self, rhs: Self) -> Self {
        Self::Ne(lhs.into(), rhs.into())
    }

    /// Builds `lhs < rhs`.
    #[must_use]
    pub fn lt(lhs: Self, rhs: Self) -> Self {
        Self::Lt(lhs.into(), rhs.into())
    }

    /// Builds `lhs <= rhs`.
    #[must_use]
    pub fn lte(lhs: Self, rhs: Self) -> Self {
        Self::Lte(lhs.into(), rhs.into())
    }

    /// Builds `lhs > rhs`.
    #[must_use]
    pub fn gt(lhs: Self, rhs: Self) -> Self {
        Self::Gt(lhs.into(), rhs.into())
    }

    /// Builds `lhs >= rhs`.
    #[must_use]
    pub fn gte(lhs: Self, rhs: Self) -> Self {
        Self::Gte(lhs.into(), rhs.into())
    }

    /// Builds `lhs + rhs`.
    #[expect(clippy::should_implement_trait)]
    #[must_use]
    pub fn add(lhs: Self, rhs: Self) -> Self {
        Self::Add(lhs.into(), rhs.into())
    }

    /// Builds `lhs - rhs`.
    #[expect(clippy::should_implement_trait)]
    #[must_use]
    pub fn sub(lhs: Self, rhs: Self) -> Self {
        Self::Sub(lhs.into(), rhs.into())
    }

    /// Builds `lhs * rhs`.
    #[expect(clippy::should_implement_trait)]
    #[must_use]
    pub fn mul(lhs: Self, rhs: Self) -> Self {
        Self::Mul(lhs.into(), rhs.into())
    }

    /// Builds `lhs / rhs`.
    #[expect(clippy::should_implement_trait)]
    #[must_use]
    pub fn div(lhs: Self, rhs: Self) -> Self {
        Self::Div(lhs.into(), rhs.into())
    }

    /// Recursively converts this tree into a `sea_query` expression.
    ///
    /// Left operands are converted before right operands.
    #[must_use]
    pub fn as_sea_query_expr(&self) -> sea_query::SimpleExpr {
        match self {
            Self::Field(ident) => (*ident).into_column_ref().into(),
            Self::Value(literal) => literal.clone().into(),
            Self::And(left, right) => left.as_sea_query_expr().and(right.as_sea_query_expr()),
            Self::Or(left, right) => left.as_sea_query_expr().or(right.as_sea_query_expr()),
            Self::Eq(left, right) => left.as_sea_query_expr().eq(right.as_sea_query_expr()),
            Self::Ne(left, right) => left.as_sea_query_expr().ne(right.as_sea_query_expr()),
            Self::Lt(left, right) => left.as_sea_query_expr().lt(right.as_sea_query_expr()),
            Self::Lte(left, right) => left.as_sea_query_expr().lte(right.as_sea_query_expr()),
            Self::Gt(left, right) => left.as_sea_query_expr().gt(right.as_sea_query_expr()),
            Self::Gte(left, right) => left.as_sea_query_expr().gte(right.as_sea_query_expr()),
            Self::Add(left, right) => left.as_sea_query_expr().add(right.as_sea_query_expr()),
            Self::Sub(left, right) => left.as_sea_query_expr().sub(right.as_sea_query_expr()),
            Self::Mul(left, right) => left.as_sea_query_expr().mul(right.as_sea_query_expr()),
            Self::Div(left, right) => left.as_sea_query_expr().div(right.as_sea_query_expr()),
        }
    }
}
/// A typed reference to a model field, used to build filter expressions.
///
/// `T` is the Rust type of the field. It is not stored; it only determines
/// which expression traits (e.g. `ExprEq`, `ExprOrd`, `ExprAdd`) are
/// implemented for the reference, and what values they accept.
#[derive(Debug)]
pub struct FieldRef<T> {
    /// Identifier of the referenced column.
    identifier: Identifier,
    phantom_data: PhantomData<T>,
}
impl<T: FromDbValue + ToDbFieldValue> FieldRef<T> {
    /// Creates a reference to the column named by `identifier`.
    ///
    /// This is a `const fn`, so references can be created in constant
    /// context.
    #[must_use]
    pub const fn new(identifier: Identifier) -> Self {
        let phantom_data = PhantomData;
        Self {
            identifier,
            phantom_data,
        }
    }
}
impl<T> FieldRef<T> {
    /// Returns an [`Expr::Field`] that refers to this field's column.
    #[must_use]
    pub fn as_expr(&self) -> Expr {
        let Self { identifier, .. } = self;
        Expr::Field(*identifier)
    }
}
/// Equality comparisons producing an [`Expr`].
///
/// `other` may be any value convertible into `T` via [`IntoField`].
pub trait ExprEq<T> {
    /// Builds an expression testing `self` for equality with `other`.
    fn eq<V: IntoField<T>>(self, other: V) -> Expr;
    /// Builds an expression testing `self` for inequality with `other`.
    fn ne<V: IntoField<T>>(self, other: V) -> Expr;
}
// Compares the referenced column against a literal built from `other`.
// `Expr::value` panics if `other` converts to an auto (non-value) field.
impl<T: ToDbFieldValue + 'static> ExprEq<T> for FieldRef<T> {
    fn eq<V: IntoField<T>>(self, other: V) -> Expr {
        let column = self.as_expr();
        let literal = Expr::value(other.into_field());
        Expr::eq(column, literal)
    }

    fn ne<V: IntoField<T>>(self, other: V) -> Expr {
        let column = self.as_expr();
        let literal = Expr::value(other.into_field());
        Expr::ne(column, literal)
    }
}
/// Addition producing an [`Expr`].
pub trait ExprAdd<T> {
    /// Builds an expression adding `other` to `self`.
    fn add<V: Into<T>>(self, other: V) -> Expr;
}
/// Subtraction producing an [`Expr`].
pub trait ExprSub<T> {
    /// Builds an expression subtracting `other` from `self`.
    fn sub<V: Into<T>>(self, other: V) -> Expr;
}
/// Multiplication producing an [`Expr`].
pub trait ExprMul<T> {
    /// Builds an expression multiplying `self` by `other`.
    fn mul<V: Into<T>>(self, other: V) -> Expr;
}
/// Division producing an [`Expr`].
pub trait ExprDiv<T> {
    /// Builds an expression dividing `self` by `other`.
    fn div<V: Into<T>>(self, other: V) -> Expr;
}
/// Ordering comparisons producing an [`Expr`].
///
/// `other` may be any value convertible into `T` via [`IntoField`].
pub trait ExprOrd<T> {
    /// Builds an expression testing `self < other`.
    fn lt<V: IntoField<T>>(self, other: V) -> Expr;
    /// Builds an expression testing `self <= other`.
    fn lte<V: IntoField<T>>(self, other: V) -> Expr;
    /// Builds an expression testing `self > other`.
    fn gt<V: IntoField<T>>(self, other: V) -> Expr;
    /// Builds an expression testing `self >= other`.
    fn gte<V: IntoField<T>>(self, other: V) -> Expr;
}
// Ordering comparisons are only offered for field types that are `Ord`.
// `Expr::value` panics if `other` converts to an auto (non-value) field.
impl<T: ToDbFieldValue + Ord + 'static> ExprOrd<T> for FieldRef<T> {
    fn lt<V: IntoField<T>>(self, other: V) -> Expr {
        let column = self.as_expr();
        let literal = Expr::value(other.into_field());
        Expr::lt(column, literal)
    }

    fn lte<V: IntoField<T>>(self, other: V) -> Expr {
        let column = self.as_expr();
        let literal = Expr::value(other.into_field());
        Expr::lte(column, literal)
    }

    fn gt<V: IntoField<T>>(self, other: V) -> Expr {
        let column = self.as_expr();
        let literal = Expr::value(other.into_field());
        Expr::gt(column, literal)
    }

    fn gte<V: IntoField<T>>(self, other: V) -> Expr {
        let column = self.as_expr();
        let literal = Expr::value(other.into_field());
        Expr::gte(column, literal)
    }
}
// Implements one arithmetic expression trait (`$trait`, method `$method`)
// for `FieldRef<$ty>`. The operand is converted to `$ty` and wrapped in a
// literal; the matching `Expr::$method` constructor combines the two.
macro_rules! impl_expr {
    ($ty:ty, $trait:ident, $method:ident) => {
        impl $trait<$ty> for FieldRef<$ty> {
            fn $method<V: Into<$ty>>(self, other: V) -> Expr {
                let column = self.as_expr();
                let operand: $ty = other.into();
                Expr::$method(column, Expr::value(operand))
            }
        }
    };
}
macro_rules! impl_num_expr {
($ty:ty) => {
impl_expr!($ty, ExprAdd, add);
impl_expr!($ty, ExprSub, sub);
impl_expr!($ty, ExprMul, mul);
impl_expr!($ty, ExprDiv, div);
};
}
impl_num_expr!(i8);
impl_num_expr!(i16);
impl_num_expr!(i32);
impl_num_expr!(i64);
impl_num_expr!(u8);
impl_num_expr!(u16);
impl_num_expr!(u32);
impl_num_expr!(u64);
impl_num_expr!(f32);
impl_num_expr!(f64);
/// Conversion of a comparison operand into the field type `T`.
///
/// This lets the `ExprEq`/`ExprOrd` methods accept values that are not
/// exactly `T` — e.g. a `&str` for a `String` field, or a model (or a
/// reference to one) for a `ForeignKey` field.
pub trait IntoField<T> {
    /// Converts `self` into the field type `T`.
    fn into_field(self) -> T;
}
// A value that already has the field's type is passed through unchanged.
impl<T: ToDbFieldValue> IntoField<T> for T {
    fn into_field(self) -> T {
        self
    }
}

// A plain value used against an `Auto<T>` field becomes a fixed `Auto`.
impl<T> IntoField<Auto<T>> for T {
    fn into_field(self) -> Auto<T> {
        Auto::fixed(self)
    }
}

// String slices are accepted for `String` fields and copied into an owned string.
impl IntoField<String> for &str {
    fn into_field(self) -> String {
        String::from(self)
    }
}

// An owned model can be compared against a foreign key pointing at its type.
impl<T: Model + Send + Sync> IntoField<ForeignKey<T>> for T {
    fn into_field(self) -> ForeignKey<T> {
        ForeignKey::from(self)
    }
}

// A borrowed model can be compared against a foreign key pointing at its type.
impl<T: Model + Send + Sync> IntoField<ForeignKey<T>> for &T {
    fn into_field(self) -> ForeignKey<T> {
        ForeignKey::from(self)
    }
}
#[cfg(test)]
mod tests {
    use cot_macros::model;

    use super::*;
    use crate::db::{MockDatabaseBackend, RowsNum};

    #[model]
    #[derive(std::fmt::Debug, PartialEq, Eq)]
    struct MockModel {
        #[model(primary_key)]
        id: i32,
    }

    #[test]
    fn query_new() {
        let query: Query<MockModel> = Query::new();
        assert!(query.filter.is_none());
        assert!(query.limit.is_none());
        assert!(query.offset.is_none());
    }

    #[test]
    fn query_default() {
        let query: Query<MockModel> = Query::default();
        assert!(query.filter.is_none());
        assert!(query.limit.is_none());
        assert!(query.offset.is_none());
    }

    #[test]
    fn query_filter() {
        let mut query: Query<MockModel> = Query::new();
        query.filter(Expr::eq(Expr::field("name"), Expr::value("John")));
        // Check the stored expression itself, not just that one is present.
        assert_eq!(
            query.filter,
            Some(Expr::eq(Expr::field("name"), Expr::value("John")))
        );
    }

    #[test]
    fn query_limit() {
        let mut query: Query<MockModel> = Query::new();
        query.limit(10);
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn query_offset() {
        let mut query: Query<MockModel> = Query::new();
        query.offset(10);
        assert_eq!(query.offset, Some(10));
    }

    #[test]
    fn query_clone() {
        let mut query: Query<MockModel> = Query::new();
        query
            .filter(Expr::eq(Expr::field("id"), Expr::value(1)))
            .limit(5)
            .offset(2);
        let cloned = query.clone();
        assert_eq!(cloned.filter, query.filter);
        assert_eq!(cloned.limit, Some(5));
        assert_eq!(cloned.offset, Some(2));
    }

    #[cot::test]
    async fn query_all() {
        let mut db = MockDatabaseBackend::new();
        db.expect_query().returning(|_| Ok(Vec::<MockModel>::new()));
        let query: Query<MockModel> = Query::new();
        let result = query.all(&db).await;
        assert_eq!(result.unwrap(), Vec::<MockModel>::new());
    }

    #[cot::test]
    async fn query_get() {
        let mut db = MockDatabaseBackend::new();
        db.expect_get().returning(|_| Ok(Option::<MockModel>::None));
        let query: Query<MockModel> = Query::new();
        let result = query.get(&db).await;
        assert_eq!(result.unwrap(), Option::<MockModel>::None);
    }

    #[cot::test]
    async fn query_exists() {
        let mut db = MockDatabaseBackend::new();
        db.expect_exists()
            .returning(|_: &Query<MockModel>| Ok(false));
        let query: Query<MockModel> = Query::new();
        let result = query.exists(&db).await;
        // The mocked backend answers `false`; make sure that value is forwarded.
        assert!(!result.unwrap());
    }

    #[cot::test]
    async fn query_delete() {
        let mut db = MockDatabaseBackend::new();
        db.expect_delete()
            .returning(|_: &Query<MockModel>| Ok(StatementResult::new(RowsNum(0))));
        let query: Query<MockModel> = Query::new();
        let result = query.delete(&db).await;
        assert!(result.is_ok());
    }

    #[test]
    fn expr_field() {
        let Expr::Field(identifier) = Expr::field("name") else {
            panic!("Expected Expr::Field");
        };
        assert_eq!(identifier.to_string(), "name");
    }

    #[test]
    fn expr_value() {
        let Expr::Value(value) = Expr::value(30) else {
            panic!("Expected Expr::Value");
        };
        assert_eq!(value.to_string(), "30");
    }

    // Generates a test checking that `Expr::$constructor` builds the
    // `Expr::$match` variant with its operands in (lhs, rhs) order.
    macro_rules! test_expr_constructor {
        ($test_name:ident, $match:ident, $constructor:ident) => {
            #[test]
            fn $test_name() {
                let expr = Expr::$constructor(Expr::field("name"), Expr::value("John"));
                let Expr::$match(lhs, rhs) = expr else {
                    panic!(concat!("Expected Expr::", stringify!($match)));
                };
                assert!(matches!(*lhs, Expr::Field(_)));
                assert!(matches!(*rhs, Expr::Value(_)));
            }
        };
    }

    test_expr_constructor!(expr_and, And, and);
    test_expr_constructor!(expr_or, Or, or);
    test_expr_constructor!(expr_eq, Eq, eq);
    test_expr_constructor!(expr_ne, Ne, ne);
    test_expr_constructor!(expr_lt, Lt, lt);
    test_expr_constructor!(expr_lte, Lte, lte);
    test_expr_constructor!(expr_gt, Gt, gt);
    test_expr_constructor!(expr_gte, Gte, gte);
    test_expr_constructor!(expr_add, Add, add);
    test_expr_constructor!(expr_sub, Sub, sub);
    test_expr_constructor!(expr_mul, Mul, mul);
    test_expr_constructor!(expr_div, Div, div);
}