Skip to main content

prax_query/operations/
count.rs

1//! Count operation for counting records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8
9/// A count operation for counting records.
10///
11/// # Example
12///
13/// ```rust,ignore
14/// let count = client
15///     .user()
16///     .count()
17///     .r#where(user::active::equals(true))
18///     .exec()
19///     .await?;
20/// ```
21pub struct CountOperation<E: QueryEngine, M: Model> {
22    engine: E,
23    filter: Filter,
24    distinct: Option<String>,
25    _model: PhantomData<M>,
26}
27
28impl<E: QueryEngine, M: Model> CountOperation<E, M> {
29    /// Create a new Count operation.
30    pub fn new(engine: E) -> Self {
31        Self {
32            engine,
33            filter: Filter::None,
34            distinct: None,
35            _model: PhantomData,
36        }
37    }
38
39    /// Add a filter condition.
40    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41        let new_filter = filter.into();
42        self.filter = self.filter.and_then(new_filter);
43        self
44    }
45
46    /// Count distinct values of a column.
47    pub fn distinct(mut self, column: impl Into<String>) -> Self {
48        self.distinct = Some(column.into());
49        self
50    }
51
52    /// Build the SQL query.
53    pub fn build_sql(
54        &self,
55        dialect: &dyn crate::dialect::SqlDialect,
56    ) -> (String, Vec<FilterValue>) {
57        let (where_sql, params) = self.filter.to_sql(0, dialect);
58
59        let mut sql = String::new();
60
61        // SELECT COUNT clause
62        sql.push_str("SELECT COUNT(");
63        match &self.distinct {
64            Some(col) => {
65                sql.push_str("DISTINCT ");
66                sql.push_str(col);
67            }
68            None => sql.push('*'),
69        }
70        sql.push(')');
71
72        // FROM clause
73        sql.push_str(" FROM ");
74        sql.push_str(M::TABLE_NAME);
75
76        // WHERE clause
77        if !self.filter.is_none() {
78            sql.push_str(" WHERE ");
79            sql.push_str(&where_sql);
80        }
81
82        (sql, params)
83    }
84
85    /// Apply a typed `WhereInput`. AND-composes with any previously set filter.
86    pub fn with_where_input<W: crate::inputs::WhereInput<Model = M>>(mut self, w: W) -> Self {
87        let f = w.into_ir();
88        self.filter = self.filter.and_then(f);
89        self
90    }
91
92    /// Doc-hidden accessor for the current filter.
93    #[doc(hidden)]
94    pub fn filter_for_test(&self) -> &Filter {
95        &self.filter
96    }
97
98    /// Execute the count query.
99    pub async fn exec(self) -> QueryResult<u64> {
100        let dialect = self.engine.dialect();
101        let (sql, params) = self.build_sql(dialect);
102        self.engine.count(&sql, params).await
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::{Postgres, SqlDialect};
    use crate::error::QueryError;
    use crate::row::{FromRow, RowError, RowRef};
    use crate::traits::BoxFuture;

    /// The operation type every test builds.
    type Op = CountOperation<MockEngine, TestModel>;

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Minimal model mapped to the `test_models` table.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    impl FromRow for TestModel {
        fn from_row(_row: &impl RowRef) -> Result<Self, RowError> {
            Ok(Self)
        }
    }

    /// Engine stub: every query resolves immediately, and `count` reports
    /// `count_result` (zero by default).
    #[derive(Clone, Default)]
    struct MockEngine {
        count_result: u64,
    }

    impl MockEngine {
        fn with_count(count_result: u64) -> Self {
            Self { count_result }
        }
    }

    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn SqlDialect {
            &Postgres
        }

        fn query_many<T>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>>
        where
            T: Model + FromRow + Send + 'static,
        {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T>(&self, _sql: &str, _params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<T>>
        where
            T: Model + FromRow + Send + 'static,
        {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Option<T>>>
        where
            T: Model + FromRow + Send + 'static,
        {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>>
        where
            T: Model + FromRow + Send + 'static,
        {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>>
        where
            T: Model + FromRow + Send + 'static,
        {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            let reported = self.count_result;
            Box::pin(async move { Ok(reported) })
        }
    }

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Count over `TestModel` whose engine reports zero rows.
    fn op() -> Op {
        Op::new(MockEngine::default())
    }

    /// Count over `TestModel` whose engine reports `rows`.
    fn op_reporting(rows: u64) -> Op {
        Op::new(MockEngine::with_count(rows))
    }

    /// Renders `operation` with the Postgres dialect.
    fn render(operation: &Op) -> (String, Vec<FilterValue>) {
        operation.build_sql(&Postgres)
    }

    /// Wraps a string slice as a `FilterValue::String`.
    fn text(value: &str) -> FilterValue {
        FilterValue::String(value.to_string())
    }

    // ---------------------------------------------------------------------
    // Construction
    // ---------------------------------------------------------------------

    #[test]
    fn test_count_new() {
        let (sql, params) = render(&op());

        assert!(sql.contains("SELECT COUNT(*)"));
        assert!(sql.contains("FROM test_models"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_basic() {
        let (sql, params) = render(&op());

        assert_eq!(sql, "SELECT COUNT(*) FROM test_models");
        assert!(params.is_empty());
    }

    // ---------------------------------------------------------------------
    // Filters
    // ---------------------------------------------------------------------

    #[test]
    fn test_count_with_filter() {
        let (sql, params) =
            render(&op().r#where(Filter::Equals("active".into(), FilterValue::Bool(true))));

        assert!(sql.contains("WHERE"));
        assert!(sql.contains(r#""active" = $1"#));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_compound_filter() {
        let operation = op()
            .r#where(Filter::Equals("status".into(), text("active")))
            .r#where(Filter::Gte("age".into(), FilterValue::Int(18)));
        let (sql, params) = render(&operation);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_or_filter() {
        let either_role = Filter::or([
            Filter::Equals("role".into(), text("admin")),
            Filter::Equals("role".into(), text("moderator")),
        ]);
        let (sql, params) = render(&op().r#where(either_role));

        assert!(sql.contains("OR"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_in_filter() {
        let statuses: Vec<FilterValue> = ["pending", "processing", "completed"]
            .iter()
            .map(|status| text(status))
            .collect();
        let (sql, params) = render(&op().r#where(Filter::In("status".into(), statuses)));

        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_count_without_filter() {
        let (sql, params) = render(&op());

        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_null_filter() {
        let (sql, params) = render(&op().r#where(Filter::IsNull("deleted_at".into())));

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IS NULL"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_not_null_filter() {
        let (sql, params) = render(&op().r#where(Filter::IsNotNull("verified_at".into())));

        assert!(sql.contains("IS NOT NULL"));
        assert!(params.is_empty());
    }

    // ---------------------------------------------------------------------
    // Distinct
    // ---------------------------------------------------------------------

    #[test]
    fn test_count_distinct() {
        let (sql, _) = render(&op().distinct("email"));

        assert!(sql.contains("COUNT(DISTINCT email)"));
        assert!(!sql.contains("COUNT(*)"));
    }

    #[test]
    fn test_count_distinct_with_filter() {
        let operation = op()
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("user_id");
        let (sql, params) = render(&operation);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_distinct_replaces() {
        // The most recent `distinct` call wins.
        let (sql, _) = render(&op().distinct("email").distinct("user_id"));

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(!sql.contains("COUNT(DISTINCT email)"));
    }

    // ---------------------------------------------------------------------
    // SQL structure
    // ---------------------------------------------------------------------

    #[test]
    fn test_count_sql_structure() {
        let (sql, _) = render(&op().r#where(Filter::Equals("id".into(), FilterValue::Int(1))));
        let offset_of = |keyword: &str| sql.find(keyword).unwrap();

        assert!(offset_of("COUNT") < offset_of("FROM"));
        assert!(offset_of("FROM") < offset_of("WHERE"));
    }

    #[test]
    fn test_count_table_name() {
        let (sql, _) = render(&op());

        assert!(sql.contains("test_models"));
    }

    // ---------------------------------------------------------------------
    // Async execution
    // ---------------------------------------------------------------------

    #[tokio::test]
    async fn test_count_exec() {
        let outcome = op_reporting(42).exec().await;

        assert!(matches!(outcome, Ok(42)));
    }

    #[tokio::test]
    async fn test_count_exec_with_filter() {
        let outcome = op_reporting(10)
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .exec()
            .await;

        assert!(matches!(outcome, Ok(10)));
    }

    #[tokio::test]
    async fn test_count_exec_zero() {
        let outcome = op().exec().await;

        assert!(matches!(outcome, Ok(0)));
    }

    // ---------------------------------------------------------------------
    // Method chaining
    // ---------------------------------------------------------------------

    #[test]
    fn test_count_method_chaining() {
        let operation = op()
            .r#where(Filter::Equals("status".into(), text("active")))
            .distinct("user_id");
        let (sql, params) = render(&operation);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    // ---------------------------------------------------------------------
    // Edge cases
    // ---------------------------------------------------------------------

    #[test]
    fn test_count_with_like_filter() {
        let (sql, params) =
            render(&op().r#where(Filter::Contains("email".into(), text("@example.com"))));

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_starts_with() {
        let (sql, params) = render(&op().r#where(Filter::StartsWith("name".into(), text("A"))));

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_not_filter() {
        let excluded = Filter::Equals("status".into(), text("deleted"));
        let (sql, params) = render(&op().r#where(Filter::Not(Box::new(excluded))));

        assert!(sql.contains("NOT"));
        assert_eq!(params.len(), 1);
    }
}