1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8
9pub struct CountOperation<E: QueryEngine, M: Model> {
22 engine: E,
23 filter: Filter,
24 distinct: Option<String>,
25 _model: PhantomData<M>,
26}
27
28impl<E: QueryEngine, M: Model> CountOperation<E, M> {
29 pub fn new(engine: E) -> Self {
31 Self {
32 engine,
33 filter: Filter::None,
34 distinct: None,
35 _model: PhantomData,
36 }
37 }
38
39 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
41 let new_filter = filter.into();
42 self.filter = self.filter.and_then(new_filter);
43 self
44 }
45
46 pub fn distinct(mut self, column: impl Into<String>) -> Self {
48 self.distinct = Some(column.into());
49 self
50 }
51
52 pub fn build_sql(
54 &self,
55 dialect: &dyn crate::dialect::SqlDialect,
56 ) -> (String, Vec<FilterValue>) {
57 let (where_sql, params) = self.filter.to_sql(0, dialect);
58
59 let mut sql = String::new();
60
61 sql.push_str("SELECT COUNT(");
63 match &self.distinct {
64 Some(col) => {
65 sql.push_str("DISTINCT ");
66 sql.push_str(col);
67 }
68 None => sql.push('*'),
69 }
70 sql.push(')');
71
72 sql.push_str(" FROM ");
74 sql.push_str(M::TABLE_NAME);
75
76 if !self.filter.is_none() {
78 sql.push_str(" WHERE ");
79 sql.push_str(&where_sql);
80 }
81
82 (sql, params)
83 }
84
85 pub fn with_where_input<W: crate::inputs::WhereInput<Model = M>>(mut self, w: W) -> Self {
87 let f = w.into_ir();
88 self.filter = self.filter.and_then(f);
89 self
90 }
91
92 #[doc(hidden)]
94 pub fn filter_for_test(&self) -> &Filter {
95 &self.filter
96 }
97
98 pub async fn exec(self) -> QueryResult<u64> {
100 let dialect = self.engine.dialect();
101 let (sql, params) = self.build_sql(dialect);
102 self.engine.count(&sql, params).await
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::Postgres;
    use crate::error::QueryError;

    /// Minimal model mapped to the `test_models` table.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    impl crate::row::FromRow for TestModel {
        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
            Ok(Self)
        }
    }

    /// Engine stub whose `count` resolves to `count_result`.
    ///
    /// Every other operation returns a fixed empty or not-found result.
    #[derive(Clone)]
    struct MockEngine {
        count_result: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self::with_count(0)
        }

        fn with_count(count_result: u64) -> Self {
            Self { count_result }
        }
    }

    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &Postgres
        }

        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let canned = self.count_result;
            Box::pin(async move { Ok(canned) })
        }
    }

    type Op = CountOperation<MockEngine, TestModel>;

    /// A fresh operation backed by a mock whose count is 0.
    fn op() -> Op {
        Op::new(MockEngine::new())
    }

    /// Renders `query` with the Postgres dialect.
    fn render(query: &Op) -> (String, Vec<FilterValue>) {
        query.build_sql(&Postgres)
    }

    /// Shorthand for a `FilterValue::String`.
    fn text(value: &str) -> FilterValue {
        FilterValue::String(value.to_string())
    }

    #[test]
    fn test_count_new() {
        let (sql, params) = render(&op());

        assert!(sql.contains("SELECT COUNT(*)"));
        assert!(sql.contains("FROM test_models"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_basic() {
        let (sql, params) = render(&op());

        assert_eq!(sql, "SELECT COUNT(*) FROM test_models");
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_filter() {
        let query = op().r#where(Filter::Equals("active".into(), FilterValue::Bool(true)));
        let (sql, params) = render(&query);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains(r#""active" = $1"#));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_compound_filter() {
        let query = op()
            .r#where(Filter::Equals("status".into(), text("active")))
            .r#where(Filter::Gte("age".into(), FilterValue::Int(18)));
        let (sql, params) = render(&query);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_or_filter() {
        let query = op().r#where(Filter::or([
            Filter::Equals("role".into(), text("admin")),
            Filter::Equals("role".into(), text("moderator")),
        ]));
        let (sql, params) = render(&query);

        assert!(sql.contains("OR"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_count_with_in_filter() {
        let statuses = vec![text("pending"), text("processing"), text("completed")];
        let query = op().r#where(Filter::In("status".into(), statuses));
        let (sql, params) = render(&query);

        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_count_without_filter() {
        let (sql, params) = render(&op());

        assert!(!sql.contains("WHERE"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_null_filter() {
        let (sql, params) = render(&op().r#where(Filter::IsNull("deleted_at".into())));

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IS NULL"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_with_not_null_filter() {
        let (sql, params) = render(&op().r#where(Filter::IsNotNull("verified_at".into())));

        assert!(sql.contains("IS NOT NULL"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_count_distinct() {
        let (sql, _) = render(&op().distinct("email"));

        assert!(sql.contains("COUNT(DISTINCT email)"));
        assert!(!sql.contains("COUNT(*)"));
    }

    #[test]
    fn test_count_distinct_with_filter() {
        let query = op()
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .distinct("user_id");
        let (sql, params) = render(&query);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_distinct_replaces() {
        // The most recent `distinct` call wins.
        let (sql, _) = render(&op().distinct("email").distinct("user_id"));

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(!sql.contains("COUNT(DISTINCT email)"));
    }

    #[test]
    fn test_count_sql_structure() {
        let (sql, _) = render(&op().r#where(Filter::Equals("id".into(), FilterValue::Int(1))));

        let offset_of = |keyword: &str| sql.find(keyword).unwrap();
        let count_at = offset_of("COUNT");
        let from_at = offset_of("FROM");
        let where_at = offset_of("WHERE");

        assert!(count_at < from_at);
        assert!(from_at < where_at);
    }

    #[test]
    fn test_count_table_name() {
        let (sql, _) = render(&op());

        assert!(sql.contains("test_models"));
    }

    #[tokio::test]
    async fn test_count_exec() {
        let outcome = Op::new(MockEngine::with_count(42)).exec().await;

        assert!(matches!(outcome, Ok(42)), "unexpected result: {:?}", outcome);
    }

    #[tokio::test]
    async fn test_count_exec_with_filter() {
        let outcome = Op::new(MockEngine::with_count(10))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .exec()
            .await;

        assert!(matches!(outcome, Ok(10)), "unexpected result: {:?}", outcome);
    }

    #[tokio::test]
    async fn test_count_exec_zero() {
        let outcome = op().exec().await;

        assert!(matches!(outcome, Ok(0)), "unexpected result: {:?}", outcome);
    }

    #[test]
    fn test_count_method_chaining() {
        let query = op()
            .r#where(Filter::Equals("status".into(), text("active")))
            .distinct("user_id");
        let (sql, params) = render(&query);

        assert!(sql.contains("COUNT(DISTINCT user_id)"));
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_like_filter() {
        let query = op().r#where(Filter::Contains("email".into(), text("@example.com")));
        let (sql, params) = render(&query);

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_starts_with() {
        let query = op().r#where(Filter::StartsWith("name".into(), text("A")));
        let (sql, params) = render(&query);

        assert!(sql.contains("LIKE"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_count_with_not_filter() {
        let negated = Filter::Not(Box::new(Filter::Equals("status".into(), text("deleted"))));
        let (sql, params) = render(&op().r#where(negated));

        assert!(sql.contains("NOT"));
        assert_eq!(params.len(), 1);
    }
}