Skip to main content

modo_db/
query.rs

1use std::marker::PhantomData;
2
3use sea_orm::sea_query::IntoCondition;
4use sea_orm::sea_query::IntoValueTuple;
5use sea_orm::{
6    ColumnTrait, ConnectionTrait, EntityTrait, FromQueryResult, IntoIdentity, PaginatorTrait,
7    QueryFilter, QueryOrder, QuerySelect, Select,
8};
9
10use crate::error::db_err_to_error;
11use crate::pagination::{
12    CursorParams, CursorResult, PageParams, PageResult, paginate, paginate_cursor,
13};
14
15// ── EntityQuery ──────────────────────────────────────────────────────────────
16
17/// A chainable query builder that wraps SeaORM's `Select<E>` and
18/// auto-converts results to the domain type `T` via `From<E::Model>`.
19///
20/// Construct one from `E::find()` or `E::find_by_id(pk)` and then chain
21/// filter/order/limit/offset calls before executing with a terminal method.
22///
23/// # Example
24///
25/// ```rust,ignore
26/// let todos: Vec<Todo> = EntityQuery::new(TodoEntity::find())
27///     .filter(todo::Column::Done.eq(false))
28///     .order_by_asc(todo::Column::CreatedAt)
29///     .limit(10)
30///     .all(&db)
31///     .await?;
32/// ```
33pub struct EntityQuery<T, E: EntityTrait> {
34    select: Select<E>,
35    _phantom: PhantomData<T>,
36}
37
38impl<T, E> EntityQuery<T, E>
39where
40    E: EntityTrait,
41    T: From<E::Model> + Send + Sync,
42    E::Model: FromQueryResult + Send + Sync,
43{
44    /// Wrap an existing `Select<E>`.
45    pub fn new(select: Select<E>) -> Self {
46        Self {
47            select,
48            _phantom: PhantomData,
49        }
50    }
51
52    // ── Chainable methods ────────────────────────────────────────────────────
53
54    /// Apply a WHERE condition.
55    pub fn filter(self, f: impl IntoCondition) -> Self {
56        Self {
57            select: QueryFilter::filter(self.select, f),
58            _phantom: PhantomData,
59        }
60    }
61
62    /// ORDER BY `col` ASC.
63    pub fn order_by_asc<C: ColumnTrait>(self, col: C) -> Self {
64        Self {
65            select: QueryOrder::order_by_asc(self.select, col),
66            _phantom: PhantomData,
67        }
68    }
69
70    /// ORDER BY `col` DESC.
71    pub fn order_by_desc<C: ColumnTrait>(self, col: C) -> Self {
72        Self {
73            select: QueryOrder::order_by_desc(self.select, col),
74            _phantom: PhantomData,
75        }
76    }
77
78    /// LIMIT `n` rows.
79    pub fn limit(self, n: u64) -> Self {
80        Self {
81            select: QuerySelect::limit(self.select, Some(n)),
82            _phantom: PhantomData,
83        }
84    }
85
86    /// OFFSET `n` rows.
87    pub fn offset(self, n: u64) -> Self {
88        Self {
89            select: QuerySelect::offset(self.select, Some(n)),
90            _phantom: PhantomData,
91        }
92    }
93
94    // ── Terminal methods ─────────────────────────────────────────────────────
95
96    /// Fetch all matching rows and convert each model to `T`.
97    pub async fn all(self, db: &impl ConnectionTrait) -> Result<Vec<T>, modo::Error> {
98        let rows = self.select.all(db).await.map_err(db_err_to_error)?;
99        Ok(rows.into_iter().map(T::from).collect())
100    }
101
102    /// Fetch at most one row and convert to `T` if present.
103    pub async fn one(self, db: &impl ConnectionTrait) -> Result<Option<T>, modo::Error> {
104        let row = self.select.one(db).await.map_err(db_err_to_error)?;
105        Ok(row.map(T::from))
106    }
107
108    /// Return the number of rows that match the current query.
109    pub async fn count(self, db: &impl ConnectionTrait) -> Result<u64, modo::Error> {
110        self.select.count(db).await.map_err(db_err_to_error)
111    }
112
113    /// Offset-based pagination. Results are auto-converted to `T`.
114    pub async fn paginate(
115        self,
116        db: &impl ConnectionTrait,
117        params: &PageParams,
118    ) -> Result<PageResult<T>, modo::Error> {
119        paginate(self.select, db, params)
120            .await
121            .map_err(db_err_to_error)
122            .map(|r| r.map(T::from))
123    }
124
125    /// Cursor-based pagination. Results are auto-converted to `T`.
126    ///
127    /// - `col` — the column to paginate on (e.g. `Column::Id`).
128    /// - `cursor_fn` — extracts the cursor string from a model instance.
129    pub async fn paginate_cursor<C, V, F>(
130        self,
131        col: C,
132        cursor_fn: F,
133        db: &impl ConnectionTrait,
134        params: &CursorParams<V>,
135    ) -> Result<CursorResult<T>, modo::Error>
136    where
137        C: IntoIdentity,
138        V: IntoValueTuple + Clone,
139        F: Fn(&E::Model) -> String,
140    {
141        paginate_cursor(self.select, col, cursor_fn, db, params)
142            .await
143            .map_err(db_err_to_error)
144            .map(|r| r.map(T::from))
145    }
146
147    // ── Escape hatch ─────────────────────────────────────────────────────────
148
149    /// Unwrap the inner `Select<E>` for advanced SeaORM usage.
150    pub fn into_select(self) -> Select<E> {
151        self.select
152    }
153}
154
155// ── EntityUpdateMany ─────────────────────────────────────────────────────────
156
157/// A chainable wrapper around SeaORM's `UpdateMany<E>` that returns
158/// `rows_affected` on execution.
159///
160/// # Example
161///
162/// ```rust,ignore
163/// let affected = EntityUpdateMany::new(TodoEntity::update_many())
164///     .filter(todo::Column::Done.eq(false))
165///     .col_expr(todo::Column::Done, Expr::value(true))
166///     .exec(&db)
167///     .await?;
168/// ```
169pub struct EntityUpdateMany<E: EntityTrait> {
170    update: sea_orm::UpdateMany<E>,
171}
172
173impl<E: EntityTrait> EntityUpdateMany<E> {
174    /// Wrap an existing `UpdateMany<E>`.
175    pub fn new(update: sea_orm::UpdateMany<E>) -> Self {
176        Self { update }
177    }
178
179    /// Apply a WHERE condition.
180    pub fn filter(self, f: impl IntoCondition) -> Self {
181        Self {
182            update: QueryFilter::filter(self.update, f),
183        }
184    }
185
186    /// Set a column to a `SimpleExpr` value.
187    ///
188    /// Use `sea_orm::sea_query::Expr::value` for simple literals.
189    pub fn col_expr<C: sea_orm::sea_query::IntoIden>(
190        self,
191        col: C,
192        expr: sea_orm::sea_query::SimpleExpr,
193    ) -> Self {
194        Self {
195            update: self.update.col_expr(col, expr),
196        }
197    }
198
199    /// Execute the update and return the number of rows affected.
200    pub async fn exec(self, db: &impl ConnectionTrait) -> Result<u64, modo::Error> {
201        self.update
202            .exec(db)
203            .await
204            .map(|r| r.rows_affected)
205            .map_err(db_err_to_error)
206    }
207}
208
209// ── EntityDeleteMany ─────────────────────────────────────────────────────────
210
211/// A chainable wrapper around SeaORM's `DeleteMany<E>` that returns
212/// `rows_affected` on execution.
213///
214/// # Example
215///
216/// ```rust,ignore
217/// let deleted = EntityDeleteMany::new(TodoEntity::delete_many())
218///     .filter(todo::Column::Done.eq(true))
219///     .exec(&db)
220///     .await?;
221/// ```
222pub struct EntityDeleteMany<E: EntityTrait> {
223    delete: sea_orm::DeleteMany<E>,
224}
225
226impl<E: EntityTrait> EntityDeleteMany<E> {
227    /// Wrap an existing `DeleteMany<E>`.
228    pub fn new(delete: sea_orm::DeleteMany<E>) -> Self {
229        Self { delete }
230    }
231
232    /// Apply a WHERE condition.
233    pub fn filter(self, f: impl IntoCondition) -> Self {
234        Self {
235            delete: QueryFilter::filter(self.delete, f),
236        }
237    }
238
239    /// Execute the delete and return the number of rows affected.
240    pub async fn exec(self, db: &impl ConnectionTrait) -> Result<u64, modo::Error> {
241        self.delete
242            .exec(db)
243            .await
244            .map(|r| r.rows_affected)
245            .map_err(db_err_to_error)
246    }
247}