// prax_query/operations/count.rs

//! Count operation: builds and executes `SELECT COUNT(...)` queries for a model.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8
9/// A count operation for counting records.
10///
11/// # Example
12///
13/// ```rust,ignore
14/// let count = client
15///     .user()
16///     .count()
17///     .r#where(user::active::equals(true))
18///     .exec()
19///     .await?;
20/// ```
21pub struct CountOperation<E: QueryEngine, M: Model> {
22    engine: E,
23    filter: Filter,
24    distinct: Option<String>,
25    _model: PhantomData<M>,
26}
27
28impl<E: QueryEngine, M: Model> CountOperation<E, M> {
29    /// Create a new Count operation.
30    pub fn new(engine: E) -> Self {
31        Self {
32            engine,
33            filter: Filter::None,
34            distinct: None,
35            _model: PhantomData,
36        }
37    }
38
39    /// Add a filter condition.
40    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41        let new_filter = filter.into();
42        self.filter = self.filter.and_then(new_filter);
43        self
44    }
45
46    /// Count distinct values of a column.
47    pub fn distinct(mut self, column: impl Into<String>) -> Self {
48        self.distinct = Some(column.into());
49        self
50    }
51
52    /// Build the SQL query.
53    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
54        let (where_sql, params) = self.filter.to_sql(0);
55
56        let mut sql = String::new();
57
58        // SELECT COUNT clause
59        sql.push_str("SELECT COUNT(");
60        match &self.distinct {
61            Some(col) => {
62                sql.push_str("DISTINCT ");
63                sql.push_str(col);
64            }
65            None => sql.push('*'),
66        }
67        sql.push(')');
68
69        // FROM clause
70        sql.push_str(" FROM ");
71        sql.push_str(M::TABLE_NAME);
72
73        // WHERE clause
74        if !self.filter.is_none() {
75            sql.push_str(" WHERE ");
76            sql.push_str(&where_sql);
77        }
78
79        (sql, params)
80    }
81
82    /// Execute the count query.
83    pub async fn exec(self) -> QueryResult<u64> {
84        let (sql, params) = self.build_sql();
85        self.engine.count(&sql, params).await
86    }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use std::sync::{Arc, Mutex};

    /// Minimal model mapped to the `test_models` table.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// `(sql, param_count)` captured from a `count` call.
    type RecordedCall = (String, usize);

    /// Test double for `QueryEngine`.
    ///
    /// `count` returns a canned value and records the SQL and number of
    /// parameters it received. This lets execution tests check that `exec`
    /// forwards exactly what `build_sql` produced. All other methods ignore
    /// their input and return fixed stub results.
    #[derive(Clone)]
    struct MockEngine {
        count_result: u64,
        // Shared between clones, so a test can keep a handle after moving an
        // engine into the operation under test.
        recorded_count_call: Arc<Mutex<Option<RecordedCall>>>,
    }

    impl MockEngine {
        fn new() -> Self {
            Self::with_count(0)
        }

        fn with_count(count: u64) -> Self {
            Self {
                count_result: count,
                recorded_count_call: Arc::new(Mutex::new(None)),
            }
        }

        /// The most recent call recorded by `count`, if any.
        fn last_count_call(&self) -> Option<RecordedCall> {
            self.recorded_count_call
                .lock()
                .expect("mock engine mutex poisoned")
                .clone()
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            // Record before returning, so the capture happens even if the
            // returned future is never polled.
            *self
                .recorded_count_call
                .lock()
                .expect("mock engine mutex poisoned") = Some((sql.to_owned(), params.len()));
            let count = self.count_result;
            Box::pin(async move { Ok(count) })
        }
    }

    // ========== Construction Tests ==========

    #[test]
    fn test_count_new() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("SELECT COUNT(*)"));
        assert!(sql.contains("FROM test_models"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_basic() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert_eq!(sql, "SELECT COUNT(*) FROM test_models");
        assert!(params.is_empty());
    }

    // ========== Filter Tests ==========

    #[test]
    fn test_count_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("active = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_compound_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("status".into(), FilterValue::String("active".to_string())))
            .r#where(Filter::Gte("age".into(), FilterValue::Int(18)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_or_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::or([
                Filter::Equals("role".into(), FilterValue::String("admin".to_string())),
                Filter::Equals("role".into(), FilterValue::String("moderator".to_string())),
            ]));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("OR"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_in_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "status".into(),
                vec![
                    FilterValue::String("pending".to_string()),
                    FilterValue::String("processing".to_string()),
                    FilterValue::String("completed".to_string()),
                ],
            ));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_count_without_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_null_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::IsNull("deleted_at".into()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IS NULL"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_not_null_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::IsNotNull("verified_at".into()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("IS NOT NULL"));
        assert!(params.is_empty());
    }

    // ========== Distinct Tests ==========

    #[test]
    fn test_count_distinct() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .distinct("email");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT email)"));
        assert!(!sql.contains("COUNT(*)"));
    }

    #[test]
    fn test_count_distinct_with_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("user_id");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_distinct_replaces() {
        // Later distinct should replace the previous one
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .distinct("email")
            .distinct("user_id");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(!sql.contains("COUNT(DISTINCT email)"));
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_count_sql_structure() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, _) = op.build_sql();

        let count_pos = sql.find("COUNT").unwrap();
        let from_pos = sql.find("FROM").unwrap();
        let where_pos = sql.find("WHERE").unwrap();

        assert!(count_pos < from_pos);
        assert!(from_pos < where_pos);
    }

    #[test]
    fn test_count_table_name() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, _) = op.build_sql();

        assert!(sql.contains("test_models"));
    }

    // ========== Async Execution Tests ==========
    //
    // Each test keeps a clone of the engine so it can verify that `exec`
    // forwarded exactly the SQL and parameters produced by `build_sql`.

    #[tokio::test]
    async fn test_count_exec() {
        let engine = MockEngine::with_count(42);
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone());
        let (expected_sql, expected_params) = op.build_sql();

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(
            engine.last_count_call(),
            Some((expected_sql, expected_params.len()))
        );
    }

    #[tokio::test]
    async fn test_count_exec_with_filter() {
        let engine = MockEngine::with_count(10);
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));
        let (expected_sql, expected_params) = op.build_sql();

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 10);
        assert_eq!(expected_params.len(), 1);
        assert_eq!(
            engine.last_count_call(),
            Some((expected_sql, expected_params.len()))
        );
    }

    #[tokio::test]
    async fn test_count_exec_zero() {
        let engine = MockEngine::new();
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone());

        let result = op.exec().await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);
        assert!(engine.last_count_call().is_some());
    }

    #[tokio::test]
    async fn test_count_exec_forwards_distinct_and_filter() {
        let engine = MockEngine::with_count(7);
        let op = CountOperation::<MockEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("email");
        let (expected_sql, expected_params) = op.build_sql();

        let result = op.exec().await;

        assert_eq!(result.unwrap(), 7);
        let (sent_sql, sent_param_count) = engine
            .last_count_call()
            .expect("exec should call QueryEngine::count");
        assert!(sent_sql.contains("COUNT(DISTINCT email)"));
        assert!(sent_sql.contains("WHERE"));
        assert_eq!(sent_sql, expected_sql);
        assert_eq!(sent_param_count, expected_params.len());
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_count_method_chaining() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("status".into(), FilterValue::String("active".to_string())))
            .distinct("user_id");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    // ========== Edge Cases ==========

    #[test]
    fn test_count_with_like_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Contains("email".into(), FilterValue::String("@example.com".to_string())));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_starts_with() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::StartsWith("name".into(), FilterValue::String("A".to_string())));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_not_filter() {
        let op = CountOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Not(Box::new(Filter::Equals(
                "status".into(),
                FilterValue::String("deleted".to_string()),
            ))));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("NOT"));
        assert_eq!(params.len(), 1);
    }
}
447