Skip to main content

prax_query/operations/
count.rs

1//! Count operation for counting records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8
9/// A count operation for counting records.
10///
11/// # Example
12///
13/// ```rust,ignore
14/// let count = client
15///     .user()
16///     .count()
17///     .r#where(user::active::equals(true))
18///     .exec()
19///     .await?;
20/// ```
21pub struct CountOperation<E: QueryEngine, M: Model> {
22    engine: E,
23    filter: Filter,
24    distinct: Option<String>,
25    _model: PhantomData<M>,
26}
27
28impl<E: QueryEngine, M: Model> CountOperation<E, M> {
29    /// Create a new Count operation.
30    pub fn new(engine: E) -> Self {
31        Self {
32            engine,
33            filter: Filter::None,
34            distinct: None,
35            _model: PhantomData,
36        }
37    }
38
39    /// Add a filter condition.
40    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41        let new_filter = filter.into();
42        self.filter = self.filter.and_then(new_filter);
43        self
44    }
45
46    /// Count distinct values of a column.
47    pub fn distinct(mut self, column: impl Into<String>) -> Self {
48        self.distinct = Some(column.into());
49        self
50    }
51
52    /// Build the SQL query.
53    pub fn build_sql(
54        &self,
55        dialect: &dyn crate::dialect::SqlDialect,
56    ) -> (String, Vec<FilterValue>) {
57        let (where_sql, params) = self.filter.to_sql(0, dialect);
58
59        let mut sql = String::new();
60
61        // SELECT COUNT clause
62        sql.push_str("SELECT COUNT(");
63        match &self.distinct {
64            Some(col) => {
65                sql.push_str("DISTINCT ");
66                sql.push_str(col);
67            }
68            None => sql.push('*'),
69        }
70        sql.push(')');
71
72        // FROM clause
73        sql.push_str(" FROM ");
74        sql.push_str(M::TABLE_NAME);
75
76        // WHERE clause
77        if !self.filter.is_none() {
78            sql.push_str(" WHERE ");
79            sql.push_str(&where_sql);
80        }
81
82        (sql, params)
83    }
84
85    /// Execute the count query.
86    pub async fn exec(self) -> QueryResult<u64> {
87        let dialect = self.engine.dialect();
88        let (sql, params) = self.build_sql(dialect);
89        self.engine.count(&sql, params).await
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use std::sync::{Arc, Mutex};

    /// Minimal model backed by the `test_models` table.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    impl crate::row::FromRow for TestModel {
        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
            Ok(TestModel)
        }
    }

    /// Arguments observed by `MockEngine::count`.
    #[derive(Debug, Clone)]
    struct RecordedCount {
        sql: String,
        param_count: usize,
    }

    /// Engine stub whose `count` returns a fixed value and records the SQL and
    /// parameter count it received; every other method returns an empty, zero,
    /// or not-found result.
    ///
    /// Clones share the recording (it lives behind an `Arc`), so a test can
    /// keep a clone of the engine after the original is moved into
    /// `CountOperation::exec`.
    #[derive(Clone)]
    struct MockEngine {
        count_result: u64,
        recorded: Arc<Mutex<Option<RecordedCount>>>,
    }

    impl MockEngine {
        fn new() -> Self {
            Self::with_count(0)
        }

        fn with_count(count: u64) -> Self {
            Self {
                count_result: count,
                recorded: Arc::new(Mutex::new(None)),
            }
        }

        /// Arguments of the most recent `count` call, or `None` if it never ran.
        fn last_count(&self) -> Option<RecordedCount> {
            self.recorded.lock().unwrap().clone()
        }
    }

    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }

        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            // Record what exec forwarded so tests can verify the wiring.
            *self.recorded.lock().unwrap() = Some(RecordedCount {
                sql: sql.to_owned(),
                param_count: params.len(),
            });
            let count = self.count_result;
            Box::pin(async move { Ok(count) })
        }
    }

    // ========== Construction Tests ==========

    #[test]
    fn test_count_new() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("SELECT COUNT(*)"));
        assert!(sql.contains("FROM test_models"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_basic() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert_eq!(sql, "SELECT COUNT(*) FROM test_models");
        assert!(params.is_empty());
    }

    // ========== Filter Tests ==========

    #[test]
    fn test_count_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains(r#""active" = $1"#));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_compound_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gte("age".into(), FilterValue::Int(18)));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_or_filter() {
        let op =
            CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(Filter::or([
                Filter::Equals("role".into(), FilterValue::String("admin".to_string())),
                Filter::Equals("role".into(), FilterValue::String("moderator".to_string())),
            ]));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("OR"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_in_filter() {
        let op =
            CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(Filter::In(
                "status".into(),
                vec![
                    FilterValue::String("pending".to_string()),
                    FilterValue::String("processing".to_string()),
                    FilterValue::String("completed".to_string()),
                ],
            ));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_count_without_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_null_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::IsNull("deleted_at".into()));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IS NULL"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_not_null_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::IsNotNull("verified_at".into()));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("IS NOT NULL"));
        assert!(params.is_empty());
    }

    // ========== Distinct Tests ==========

    #[test]
    fn test_count_distinct() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).distinct("email");

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("COUNT(DISTINCT email)"));
        assert!(!sql.contains("COUNT(*)"));
    }

    #[test]
    fn test_count_distinct_exact_sql() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).distinct("email");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert_eq!(sql, "SELECT COUNT(DISTINCT email) FROM test_models");
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_distinct_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("user_id");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_distinct_replaces() {
        // Later distinct should replace the previous one
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .distinct("email")
            .distinct("user_id");

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(!sql.contains("COUNT(DISTINCT email)"));
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_count_sql_structure() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        let count_pos = sql.find("COUNT").unwrap();
        let from_pos = sql.find("FROM").unwrap();
        let where_pos = sql.find("WHERE").unwrap();

        assert!(count_pos < from_pos);
        assert!(from_pos < where_pos);
    }

    #[test]
    fn test_count_table_name() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("test_models"));
    }

    // ========== Async Execution Tests ==========
    //
    // Each test keeps a clone of the engine so it can check, after `exec`
    // consumes the operation, exactly what SQL and how many parameters were
    // forwarded to `QueryEngine::count`.

    #[tokio::test]
    async fn test_count_exec() {
        let engine = MockEngine::with_count(42);
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone());

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);

        let recorded = engine
            .last_count()
            .expect("exec should delegate to QueryEngine::count");
        assert_eq!(recorded.sql, "SELECT COUNT(*) FROM test_models");
        assert_eq!(recorded.param_count, 0);
    }

    #[tokio::test]
    async fn test_count_exec_with_filter() {
        let engine = MockEngine::with_count(10);
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));
        // The mock engine reports the Postgres dialect, so this is what exec must send.
        let (expected_sql, expected_params) = op.build_sql(&crate::dialect::Postgres);

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 10);

        let recorded = engine
            .last_count()
            .expect("exec should delegate to QueryEngine::count");
        assert_eq!(recorded.sql, expected_sql);
        assert_eq!(recorded.param_count, expected_params.len());
    }

    #[tokio::test]
    async fn test_count_exec_zero() {
        let engine = MockEngine::new();
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone());

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);

        let recorded = engine
            .last_count()
            .expect("exec should delegate to QueryEngine::count");
        assert_eq!(recorded.sql, "SELECT COUNT(*) FROM test_models");
        assert_eq!(recorded.param_count, 0);
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_count_method_chaining() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .distinct("user_id");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    // ========== Edge Cases ==========

    #[test]
    fn test_count_with_like_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(
            Filter::Contains(
                "email".into(),
                FilterValue::String("@example.com".to_string()),
            ),
        );

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_starts_with() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(
            Filter::StartsWith("name".into(), FilterValue::String("A".to_string())),
        );

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_not_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new()).r#where(
            Filter::Not(Box::new(Filter::Equals(
                "status".into(),
                FilterValue::String("deleted".to_string()),
            ))),
        );

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("NOT"));
        assert_eq!(params.len(), 1);
    }
}