prax_query/operations/count.rs

1//! Count operation for counting records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8
9/// A count operation for counting records.
10///
11/// # Example
12///
13/// ```rust,ignore
14/// let count = client
15///     .user()
16///     .count()
17///     .r#where(user::active::equals(true))
18///     .exec()
19///     .await?;
20/// ```
21pub struct CountOperation<E: QueryEngine, M: Model> {
22    engine: E,
23    filter: Filter,
24    distinct: Option<String>,
25    _model: PhantomData<M>,
26}
27
28impl<E: QueryEngine, M: Model> CountOperation<E, M> {
29    /// Create a new Count operation.
30    pub fn new(engine: E) -> Self {
31        Self {
32            engine,
33            filter: Filter::None,
34            distinct: None,
35            _model: PhantomData,
36        }
37    }
38
39    /// Add a filter condition.
40    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41        let new_filter = filter.into();
42        self.filter = self.filter.and_then(new_filter);
43        self
44    }
45
46    /// Count distinct values of a column.
47    pub fn distinct(mut self, column: impl Into<String>) -> Self {
48        self.distinct = Some(column.into());
49        self
50    }
51
52    /// Build the SQL query.
53    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
54        let (where_sql, params) = self.filter.to_sql(0);
55
56        let mut sql = String::new();
57
58        // SELECT COUNT clause
59        sql.push_str("SELECT COUNT(");
60        match &self.distinct {
61            Some(col) => {
62                sql.push_str("DISTINCT ");
63                sql.push_str(col);
64            }
65            None => sql.push('*'),
66        }
67        sql.push(')');
68
69        // FROM clause
70        sql.push_str(" FROM ");
71        sql.push_str(M::TABLE_NAME);
72
73        // WHERE clause
74        if !self.filter.is_none() {
75            sql.push_str(" WHERE ");
76            sql.push_str(&where_sql);
77        }
78
79        (sql, params)
80    }
81
82    /// Execute the count query.
83    pub async fn exec(self) -> QueryResult<u64> {
84        let (sql, params) = self.build_sql();
85        self.engine.count(&sql, params).await
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use std::sync::{Arc, Mutex};

    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// SQL text and bind-parameter count observed by `MockEngine::count`.
    type CountCall = (String, usize);

    /// Query engine stub.
    ///
    /// `count` returns a fixed value and records the SQL and parameter count it
    /// received, so tests can verify what `exec` forwards to the engine.
    #[derive(Clone)]
    struct MockEngine {
        count_result: u64,
        // Shared across clones: a test keeps one handle while the other is
        // moved into the `CountOperation`.
        recorded_count_call: Arc<Mutex<Option<CountCall>>>,
    }

    impl MockEngine {
        fn new() -> Self {
            Self::with_count(0)
        }

        fn with_count(count: u64) -> Self {
            Self {
                count_result: count,
                recorded_count_call: Arc::new(Mutex::new(None)),
            }
        }

        /// The most recent arguments passed to `count`, if it was called.
        fn last_count_call(&self) -> Option<CountCall> {
            self.recorded_count_call
                .lock()
                .expect("mock mutex poisoned")
                .clone()
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            *self
                .recorded_count_call
                .lock()
                .expect("mock mutex poisoned") = Some((sql.to_string(), params.len()));
            let count = self.count_result;
            Box::pin(async move { Ok(count) })
        }
    }

    // ========== Construction Tests ==========

    #[test]
    fn test_count_new() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("SELECT COUNT(*)"));
        assert!(sql.contains("FROM test_models"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_basic() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert_eq!(sql, "SELECT COUNT(*) FROM test_models");
        assert!(params.is_empty());
    }

    // ========== Filter Tests ==========

    #[test]
    fn test_count_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("active = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_compound_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gte("age".into(), FilterValue::Int(18)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_or_filter() {
        let op =
            CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(Filter::or([
                Filter::Equals("role".into(), FilterValue::String("admin".to_string())),
                Filter::Equals("role".into(), FilterValue::String("moderator".to_string())),
            ]));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("OR"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_in_filter() {
        let op =
            CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(Filter::In(
                "status".into(),
                vec![
                    FilterValue::String("pending".to_string()),
                    FilterValue::String("processing".to_string()),
                    FilterValue::String("completed".to_string()),
                ],
            ));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_count_without_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_null_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::IsNull("deleted_at".into()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IS NULL"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_not_null_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::IsNotNull("verified_at".into()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("IS NOT NULL"));
        assert!(params.is_empty());
    }

    // ========== Distinct Tests ==========

    #[test]
    fn test_count_distinct() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).distinct("email");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT email)"));
        assert!(!sql.contains("COUNT(*)"));
    }

    #[test]
    fn test_count_distinct_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("user_id");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_distinct_replaces() {
        // Later distinct should replace the previous one
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .distinct("email")
            .distinct("user_id");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(!sql.contains("COUNT(DISTINCT email)"));
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_count_sql_structure() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, _) = op.build_sql();

        let count_pos = sql.find("COUNT").expect("COUNT missing");
        let from_pos = sql.find("FROM").expect("FROM missing");
        let where_pos = sql.find("WHERE").expect("WHERE missing");

        assert!(count_pos < from_pos);
        assert!(from_pos < where_pos);
    }

    #[test]
    fn test_count_table_name() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, _) = op.build_sql();

        assert!(sql.contains("test_models"));
    }

    // ========== Async Execution Tests ==========

    #[tokio::test]
    async fn test_count_exec() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::with_count(42));

        let count = op.exec().await.expect("count should succeed");

        assert_eq!(count, 42);
    }

    #[tokio::test]
    async fn test_count_exec_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::with_count(10))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));

        let count = op.exec().await.expect("count should succeed");

        assert_eq!(count, 10);
    }

    #[tokio::test]
    async fn test_count_exec_zero() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());

        let count = op.exec().await.expect("count should succeed");

        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_count_exec_forwards_built_sql() {
        // `exec` must hand the engine exactly what `build_sql` produces.
        let engine = MockEngine::with_count(5);
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("email");
        let (expected_sql, expected_params) = op.build_sql();

        let count = op.exec().await.expect("count should succeed");

        assert_eq!(count, 5);
        let (sql, param_count) = engine
            .last_count_call()
            .expect("exec() must call QueryEngine::count");
        assert_eq!(sql, expected_sql);
        assert_eq!(param_count, expected_params.len());
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_count_method_chaining() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .distinct("user_id");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    // ========== Edge Cases ==========

    #[test]
    fn test_count_with_like_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(
            Filter::Contains(
                "email".into(),
                FilterValue::String("@example.com".to_string()),
            ),
        );

        let (sql, params) = op.build_sql();

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_starts_with() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(
            Filter::StartsWith("name".into(), FilterValue::String("A".to_string())),
        );

        let (sql, params) = op.build_sql();

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_not_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(
            Filter::Not(Box::new(Filter::Equals(
                "status".into(),
                FilterValue::String("deleted".to_string()),
            ))),
        );

        let (sql, params) = op.build_sql();

        assert!(sql.contains("NOT"));
        assert_eq!(params.len(), 1);
    }
}