Skip to main content

chopin_orm/
builder.rs

1use crate::{Model, OrmError, OrmResult, PgValue};
2use std::marker::PhantomData;
3
4/// A type alias for `Condition<M>`, representing a SQL expression for a specific Model.
5pub type Expr<M> = Condition<M>;
6
7/// A tree representing a SQL condition (WHERE or HAVING clause) used to filter queries.
8/// Supports generic `And` and `Or` nesting and intelligently binds indexed parameters.
9pub enum Condition<M> {
10    Raw(String, Vec<PgValue>, PhantomData<M>),
11    And(Vec<Condition<M>>),
12    Or(Vec<Condition<M>>),
13}
14
15impl<M> Clone for Condition<M> {
16    fn clone(&self) -> Self {
17        match self {
18            Condition::Raw(c, p, _) => Condition::Raw(c.clone(), p.clone(), PhantomData),
19            Condition::And(c) => Condition::And(c.clone()),
20            Condition::Or(c) => Condition::Or(c.clone()),
21        }
22    }
23}
24
25impl<M> Condition<M> {
26    /// Create a raw SQL condition segment.
27    /// Use `{}` as a placeholder for parameterized values.
28    /// ```ignore
29    /// Condition::new("age > {}", vec![25.to_param()])
30    /// ```
31    pub fn new(clause: impl Into<String>, params: Vec<PgValue>) -> Self {
32        Condition::Raw(clause.into(), params, PhantomData)
33    }
34
35    /// Combine this condition with another using `AND`.
36    pub fn and(self, other: Self) -> Self {
37        match self {
38            Condition::And(mut conds) => {
39                conds.push(other);
40                Condition::And(conds)
41            }
42            _ => Condition::And(vec![self, other]),
43        }
44    }
45
46    /// Combine this condition with another using `OR`.
47    pub fn or(self, other: Self) -> Self {
48        match self {
49            Condition::Or(mut conds) => {
50                conds.push(other);
51                Condition::Or(conds)
52            }
53            _ => Condition::Or(vec![self, other]),
54        }
55    }
56
57    /// Resolves the condition tree into a parameterized SQL string.
58    ///
59    /// Collects references to the `PgValue` parameters owned by this condition tree.
60    /// The returned references are valid as long as `self` is alive.
61    ///
62    /// # Safety / Security
63    /// This method performs placeholder mapping from `{}` to `$n`. While standard
64    /// DSL methods in `ColumnTrait` are safe, using `Condition::Raw` with unsanitized
65    /// user input in the clause string can lead to SQL injection. Always use `{}`
66    /// for values and pass them via the `params` vector.
67    fn resolve<'a>(&'a self, param_idx: &mut usize, params_out: &mut Vec<&'a PgValue>) -> String {
68        match self {
69            Condition::Raw(clause, params, _) => {
70                let mut resolved = String::with_capacity(clause.len());
71                let mut chars = clause.chars().peekable();
72                while let Some(c) = chars.next() {
73                    if c == '{' && chars.peek() == Some(&'}') {
74                        chars.next();
75                        resolved.push('$');
76                        resolved.push_str(&param_idx.to_string());
77                        *param_idx += 1;
78                    } else {
79                        resolved.push(c);
80                    }
81                }
82                params_out.extend(params.iter());
83                resolved
84            }
85            Condition::And(conds) => {
86                let resolved: Vec<_> = conds
87                    .iter()
88                    .map(|c| c.resolve(param_idx, params_out))
89                    .collect();
90                format!("({})", resolved.join(" AND "))
91            }
92            Condition::Or(conds) => {
93                let resolved: Vec<_> = conds
94                    .iter()
95                    .map(|c| c.resolve(param_idx, params_out))
96                    .collect();
97                format!("({})", resolved.join(" OR "))
98            }
99        }
100    }
101}
102
103/// Trait for defining operations natively on database columns.
104pub trait ColumnTrait<M: Model> {
105    fn column_name(&self) -> &'static str;
106
107    fn eq(self, val: impl crate::ToSql) -> Expr<M>
108    where
109        Self: Sized,
110    {
111        Expr::new(format!("{} = {{}}", self.column_name()), vec![val.to_sql()])
112    }
113    fn neq(self, val: impl crate::ToSql) -> Expr<M>
114    where
115        Self: Sized,
116    {
117        Expr::new(
118            format!("{} != {{}}", self.column_name()),
119            vec![val.to_sql()],
120        )
121    }
122    fn gt(self, val: impl crate::ToSql) -> Expr<M>
123    where
124        Self: Sized,
125    {
126        Expr::new(format!("{} > {{}}", self.column_name()), vec![val.to_sql()])
127    }
128    fn gte(self, val: impl crate::ToSql) -> Expr<M>
129    where
130        Self: Sized,
131    {
132        Expr::new(
133            format!("{} >= {{}}", self.column_name()),
134            vec![val.to_sql()],
135        )
136    }
137    fn lt(self, val: impl crate::ToSql) -> Expr<M>
138    where
139        Self: Sized,
140    {
141        Expr::new(format!("{} < {{}}", self.column_name()), vec![val.to_sql()])
142    }
143    fn lte(self, val: impl crate::ToSql) -> Expr<M>
144    where
145        Self: Sized,
146    {
147        Expr::new(
148            format!("{} <= {{}}", self.column_name()),
149            vec![val.to_sql()],
150        )
151    }
152    #[allow(clippy::wrong_self_convention)]
153    fn is_null(self) -> Expr<M>
154    where
155        Self: Sized,
156    {
157        Expr::new(format!("{} IS NULL", self.column_name()), vec![])
158    }
159    #[allow(clippy::wrong_self_convention)]
160    fn is_not_null(self) -> Expr<M>
161    where
162        Self: Sized,
163    {
164        Expr::new(format!("{} IS NOT NULL", self.column_name()), vec![])
165    }
166    fn count(self) -> Expr<M>
167    where
168        Self: Sized,
169    {
170        Expr::new(format!("COUNT({})", self.column_name()), vec![])
171    }
172    fn sum(self) -> Expr<M>
173    where
174        Self: Sized,
175    {
176        Expr::new(format!("SUM({})", self.column_name()), vec![])
177    }
178    fn max(self) -> Expr<M>
179    where
180        Self: Sized,
181    {
182        Expr::new(format!("MAX({})", self.column_name()), vec![])
183    }
184    fn min(self) -> Expr<M>
185    where
186        Self: Sized,
187    {
188        Expr::new(format!("MIN({})", self.column_name()), vec![])
189    }
190    fn like(self, val: impl crate::ToSql) -> Expr<M>
191    where
192        Self: Sized,
193    {
194        Expr::new(
195            format!("{} LIKE {{}}", self.column_name()),
196            vec![val.to_sql()],
197        )
198    }
199    fn ilike(self, val: impl crate::ToSql) -> Expr<M>
200    where
201        Self: Sized,
202    {
203        Expr::new(
204            format!("{} ILIKE {{}}", self.column_name()),
205            vec![val.to_sql()],
206        )
207    }
208    #[allow(clippy::wrong_self_convention)]
209    fn is_in<T: crate::ToSql>(self, vals: Vec<T>) -> Expr<M>
210    where
211        Self: Sized,
212    {
213        let placeholders: Vec<String> = (0..vals.len()).map(|_| "{}".to_string()).collect();
214        let params: Vec<PgValue> = vals.into_iter().map(|v| v.to_sql()).collect();
215        Expr::new(
216            format!("{} IN ({})", self.column_name(), placeholders.join(", ")),
217            params,
218        )
219    }
220}
221
222/// A type-safe SQL query builder.
223///
224/// Constructed primarily via `<Model>::find()` and `<Model>::select(...)`.
225/// Chain methods like `.filter()`, `.order_by()`, and `.limit()` before executing.
226#[must_use = "QueryBuilder does nothing until executed with .all(), .one(), .count(), etc."]
227pub struct QueryBuilder<M> {
228    _marker: PhantomData<M>,
229    select_override: Option<Vec<Expr<M>>>,
230    joins: Vec<String>,
231    filters: Vec<Expr<M>>,
232    group_by: Option<String>,
233    having: Vec<Expr<M>>,
234    order_by: Option<String>,
235    limit: Option<usize>,
236    offset: Option<usize>,
237}
238
239impl<M> Clone for QueryBuilder<M> {
240    fn clone(&self) -> Self {
241        Self {
242            _marker: PhantomData,
243            select_override: self.select_override.clone(),
244            joins: self.joins.clone(),
245            filters: self.filters.clone(),
246            group_by: self.group_by.clone(),
247            having: self.having.clone(),
248            order_by: self.order_by.clone(),
249            limit: self.limit,
250            offset: self.offset,
251        }
252    }
253}
254
255impl<M: Model + Send + Sync> Default for QueryBuilder<M> {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261impl<M: Model + Send + Sync> QueryBuilder<M> {
262    pub fn new() -> Self {
263        Self {
264            _marker: PhantomData,
265            select_override: None,
266            joins: Vec::new(),
267            filters: Vec::new(),
268            group_by: None,
269            having: Vec::new(),
270            order_by: None,
271            limit: None,
272            offset: None,
273        }
274    }
275
276    /// Add a WHERE filter clause, e.g., filter(Expr::new("age >= {}", vec![18.to_param()]))
277    /// Or using DSL: filter(UserColumn::Age.gte(18.to_param()))
278    /// Backwards compatibility raw signature with $1 etc is also supported if string literal passes without `{}`
279    pub fn filter<E>(mut self, expr: E) -> Self
280    where
281        E: IntoExpr<M>,
282    {
283        self.filters.push(expr.into_expr());
284        self
285    }
286
287    pub fn select_only<E: IntoExpr<M>>(mut self, exprs: Vec<E>) -> Self {
288        self.select_override = Some(exprs.into_iter().map(|e| e.into_expr()).collect());
289        self
290    }
291
292    pub fn join(mut self, clause: &str) -> Self {
293        self.joins.push(clause.into());
294        self
295    }
296
297    /// Adds an `INNER JOIN` automatically resolving foreign keys using `HasForeignKey`.
298    pub fn join_child<R: Model + crate::HasForeignKey<M>>(mut self) -> Self {
299        let (other_table, mappings) = R::foreign_key_info();
300        let my_table = M::table_name();
301
302        let join_on = mappings
303            .iter()
304            .map(|(child_col, parent_col)| {
305                format!(
306                    "{}.{} = {}.{}",
307                    other_table, child_col, my_table, parent_col
308                )
309            })
310            .collect::<Vec<_>>()
311            .join(" AND ");
312
313        self.joins
314            .push(format!("JOIN {} ON {}", other_table, join_on));
315        self
316    }
317
318    /// Adds an `INNER JOIN` automatically resolving the parent entity foreign keys.
319    pub fn join_parent<R: Model>(mut self) -> Self
320    where
321        M: crate::HasForeignKey<R>,
322    {
323        let (my_table, mappings) = M::foreign_key_info();
324        let other_table = R::table_name();
325
326        let join_on = mappings
327            .iter()
328            .map(|(local_col, parent_col)| {
329                format!(
330                    "{}.{} = {}.{}",
331                    other_table, parent_col, my_table, local_col
332                )
333            })
334            .collect::<Vec<_>>()
335            .join(" AND ");
336
337        self.joins
338            .push(format!("JOIN {} ON {}", other_table, join_on));
339        self
340    }
341
342    pub fn group_by(mut self, clause: &str) -> Self {
343        self.group_by = Some(clause.into());
344        self
345    }
346
347    pub fn having<E: IntoExpr<M>>(mut self, expr: E) -> Self {
348        self.having.push(expr.into_expr());
349        self
350    }
351
352    pub fn order_by(mut self, clause: &str) -> Self {
353        self.order_by = Some(clause.to_string());
354        self
355    }
356
357    pub fn limit(mut self, limit: usize) -> Self {
358        self.limit = Some(limit);
359        self
360    }
361
362    pub fn offset(mut self, offset: usize) -> Self {
363        self.offset = Some(offset);
364        self
365    }
366
367    pub(crate) fn build_query(&self) -> (String, Vec<&PgValue>) {
368        let mut all_params: Vec<&PgValue> = Vec::new();
369        let mut param_idx = 1;
370
371        let select_clause = if let Some(exprs) = &self.select_override {
372            let mapped: Vec<_> = exprs
373                .iter()
374                .map(|e| e.resolve(&mut param_idx, &mut all_params))
375                .collect();
376            mapped.join(", ")
377        } else {
378            M::select_clause().to_string()
379        };
380
381        let mut query = format!("SELECT {} FROM {}", select_clause, M::table_name());
382
383        if !self.joins.is_empty() {
384            query.push(' ');
385            query.push_str(&self.joins.join(" "));
386        }
387
388        if !self.filters.is_empty() {
389            query.push_str(" WHERE ");
390            let filter_strings: Vec<_> = self
391                .filters
392                .iter()
393                .map(|e| e.resolve(&mut param_idx, &mut all_params))
394                .collect();
395            query.push_str(&filter_strings.join(" AND "));
396        }
397
398        if let Some(gb) = &self.group_by {
399            query.push_str(" GROUP BY ");
400            query.push_str(gb);
401        }
402
403        if !self.having.is_empty() {
404            query.push_str(" HAVING ");
405            let having_strings: Vec<_> = self
406                .having
407                .iter()
408                .map(|e| e.resolve(&mut param_idx, &mut all_params))
409                .collect();
410            query.push_str(&having_strings.join(" AND "));
411        }
412
413        if let Some(order) = &self.order_by {
414            query.push_str(" ORDER BY ");
415            query.push_str(order);
416        }
417
418        if let Some(limit) = self.limit {
419            query.push_str(&format!(" LIMIT {}", limit));
420        }
421
422        if let Some(offset) = self.offset {
423            query.push_str(&format!(" OFFSET {}", offset));
424        }
425
426        (query, all_params)
427    }
428
429    /// Executes the query and returns raw `Row` results without model mapping.
430    pub fn into_raw(self, executor: &mut impl crate::Executor) -> OrmResult<Vec<crate::Row>> {
431        let (query, all_params) = self.build_query();
432        #[cfg(feature = "log")]
433        log::debug!("into_raw: {} | params: {}", query, all_params.len());
434        let params_ref: Vec<&dyn chopin_pg::types::ToSql> =
435            all_params.iter().map(|p| *p as _).collect();
436        executor.query(&query, &params_ref)
437    }
438
439    /// Executes the query and returns a list of models.
440    pub fn all(self, executor: &mut impl crate::Executor) -> OrmResult<Vec<M>> {
441        let (query, all_params) = self.build_query();
442
443        let params_ref: Vec<&dyn chopin_pg::types::ToSql> =
444            all_params.iter().map(|p| *p as _).collect();
445
446        let rows = executor.query(&query, &params_ref)?;
447
448        let mut result = Vec::with_capacity(rows.len());
449        for row in rows {
450            result.push(M::from_row(&row)?);
451        }
452        Ok(result)
453    }
454
455    /// Converts this query builder into a `Paginator` using the specified page size.
456    pub fn paginate(self, page_size: usize) -> Paginator<M> {
457        Paginator::new(self, page_size)
458    }
459
460    /// Executes the query, returning the first matching model, or `None` if not found.
461    pub fn one(mut self, executor: &mut impl crate::Executor) -> OrmResult<Option<M>> {
462        self.limit = Some(1);
463        let mut all = self.all(executor)?;
464        Ok(all.pop())
465    }
466
467    /// Executes a `COUNT(*)` query for the current filters.
468    pub fn count(mut self, executor: &mut impl crate::Executor) -> OrmResult<i64> {
469        self.select_override = Some(vec![Expr::new("COUNT(*)", vec![])]);
470        let (query, all_params) = self.build_query();
471
472        let params_ref: Vec<&dyn chopin_pg::types::ToSql> =
473            all_params.iter().map(|p| *p as _).collect();
474
475        let rows = executor.query(&query, &params_ref)?;
476        if let Some(row) = rows.first() {
477            let val: PgValue = row.get(0).map_err(OrmError::from)?;
478            return Ok(match val {
479                PgValue::Int8(v) => v,
480                PgValue::Int4(v) => v as i64,
481                PgValue::Text(s) => s.parse().unwrap_or(0),
482                _ => 0,
483            });
484        }
485        Ok(0)
486    }
487}
488
/// One page of query results plus the metadata needed to move between pages.
#[derive(Debug)]
pub struct Page<M> {
    /// The models on this page.
    pub items: Vec<M>,
    /// Number of matching rows across all pages.
    pub total: i64,
    /// The 1-based page number this page was fetched for.
    pub page: usize,
    /// Maximum number of items per page.
    pub page_size: usize,
    /// Number of pages needed to hold `total` items.
    pub total_pages: usize,
}
498
499impl<M> Page<M> {
500    /// Returns `true` if there are more pages after this one.
501    pub fn has_next(&self) -> bool {
502        self.page < self.total_pages
503    }
504
505    /// Returns `true` if there are pages before this one.
506    pub fn has_prev(&self) -> bool {
507        self.page > 1
508    }
509}
510
511/// An iterator-like coordinator wrapping a `QueryBuilder` for pagination slicing.
512#[must_use = "Paginator does nothing until .fetch() is called"]
513pub struct Paginator<M> {
514    builder: QueryBuilder<M>,
515    page_size: usize,
516    page: usize,
517}
518
519impl<M: Model + Send + Sync> Paginator<M> {
520    pub fn new(builder: QueryBuilder<M>, page_size: usize) -> Self {
521        Self {
522            builder,
523            page_size,
524            page: 1,
525        }
526    }
527
528    /// Advances the paginator to the specified page (1-indexed).
529    pub fn page(mut self, page: usize) -> Self {
530        self.page = page;
531        self
532    }
533
534    /// Executes the underlying count and slice queries, returning a populated `Page`.
535    pub fn fetch(self, executor: &mut impl crate::Executor) -> OrmResult<Page<M>> {
536        let total = self.builder.clone().count(executor)?;
537        let offset = self.page_size * self.page.saturating_sub(1);
538
539        let items = self
540            .builder
541            .limit(self.page_size)
542            .offset(offset)
543            .all(executor)?;
544
545        let total_pages = (total as usize).div_ceil(self.page_size);
546
547        Ok(Page {
548            items,
549            total,
550            page: self.page,
551            page_size: self.page_size,
552            total_pages,
553        })
554    }
555}
556
557pub trait IntoExpr<M> {
558    fn into_expr(self) -> Expr<M>;
559}
560
561impl<M> IntoExpr<M> for Expr<M> {
562    fn into_expr(self) -> Expr<M> {
563        self
564    }
565}
566
567// For backwards compatibility filter("age = $1", vec![...])
568impl<M, S: Into<String>> IntoExpr<M> for (S, Vec<PgValue>) {
569    fn into_expr(self) -> Expr<M> {
570        Expr::new(self.0.into(), self.1)
571    }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FromRow, Model};
    use chopin_pg::Row;

    /// Bare-bones model mapped to the `mocks` table (`id`, `name`).
    struct MockModel {
        pub id: i32,
    }

    impl crate::Validate for MockModel {}

    impl FromRow for MockModel {
        fn from_row(_row: &Row) -> OrmResult<Self> {
            Ok(MockModel { id: 0 })
        }
    }

    impl Model for MockModel {
        fn table_name() -> &'static str {
            "mocks"
        }
        fn primary_key_columns() -> &'static [&'static str] {
            &["id"]
        }
        fn generated_columns() -> &'static [&'static str] {
            &["id"]
        }
        fn columns() -> &'static [&'static str] {
            &["id", "name"]
        }
        fn select_clause() -> &'static str {
            "id, name"
        }
        fn primary_key_values(&self) -> Vec<PgValue> {
            vec![PgValue::Int4(self.id)]
        }
        fn set_generated_values(&mut self, vals: Vec<PgValue>) -> OrmResult<()> {
            // Only the first generated value (the `id`) is applied.
            if let Some(PgValue::Int4(generated_id)) = vals.into_iter().next() {
                self.id = generated_id;
            }
            Ok(())
        }
        fn get_values(&self) -> Vec<PgValue> {
            Vec::new()
        }
        fn create_table_stmt() -> String {
            String::new()
        }
        fn column_definitions() -> Vec<(&'static str, &'static str)> {
            Vec::new()
        }
    }

    enum MockColumn {
        Id,
        Name,
    }

    impl ColumnTrait<MockModel> for MockColumn {
        fn column_name(&self) -> &'static str {
            match self {
                MockColumn::Id => "id",
                MockColumn::Name => "name",
            }
        }
    }

    /// Renders a builder's SQL text, discarding the bound parameters.
    fn render(builder: &QueryBuilder<MockModel>) -> String {
        let (sql, _params) = builder.build_query();
        sql
    }

    #[test]
    fn test_query_builder_sql_generation() {
        let plain: QueryBuilder<MockModel> = QueryBuilder::new();
        let filtered = QueryBuilder::<MockModel>::new()
            .filter(("name = $1", vec![]))
            .filter(("id > $2", vec![]));
        let paged = QueryBuilder::<MockModel>::new()
            .order_by("name DESC")
            .limit(10)
            .offset(5);

        assert_eq!(render(&plain), "SELECT id, name FROM mocks");
        assert_eq!(
            render(&filtered),
            "SELECT id, name FROM mocks WHERE name = $1 AND id > $2"
        );
        assert_eq!(
            render(&paged),
            "SELECT id, name FROM mocks ORDER BY name DESC LIMIT 10 OFFSET 5"
        );
    }

    #[test]
    fn test_order_by_without_where() {
        let builder = QueryBuilder::<MockModel>::new().order_by("id ASC");
        assert_eq!(render(&builder), "SELECT id, name FROM mocks ORDER BY id ASC");
    }

    #[test]
    fn test_limit_only() {
        let builder = QueryBuilder::<MockModel>::new().limit(20);
        assert_eq!(render(&builder), "SELECT id, name FROM mocks LIMIT 20");
    }

    #[test]
    fn test_offset_only() {
        let builder = QueryBuilder::<MockModel>::new().offset(15);
        assert_eq!(render(&builder), "SELECT id, name FROM mocks OFFSET 15");
    }

    #[test]
    fn test_limit_and_offset_without_where() {
        let builder = QueryBuilder::<MockModel>::new().limit(5).offset(10);
        assert_eq!(render(&builder), "SELECT id, name FROM mocks LIMIT 5 OFFSET 10");
    }

    #[test]
    fn test_multiple_filters() {
        let builder = QueryBuilder::<MockModel>::new()
            .filter(("id > $1", vec![]))
            .filter(("name = $2", vec![]))
            .filter(("active = $3", vec![]));
        assert_eq!(
            render(&builder),
            "SELECT id, name FROM mocks WHERE id > $1 AND name = $2 AND active = $3"
        );
    }

    #[test]
    fn test_full_query_with_all_clauses() {
        let builder = QueryBuilder::<MockModel>::new()
            .filter(("status = $1", vec![]))
            .order_by("created_at DESC")
            .limit(25)
            .offset(50);
        assert_eq!(
            render(&builder),
            "SELECT id, name FROM mocks WHERE status = $1 ORDER BY created_at DESC LIMIT 25 OFFSET 50"
        );
    }

    #[test]
    fn test_default_equals_new() {
        let via_default: QueryBuilder<MockModel> = Default::default();
        let via_new: QueryBuilder<MockModel> = QueryBuilder::new();
        assert_eq!(render(&via_default), render(&via_new));
    }

    #[test]
    fn test_no_clauses_is_plain_select() {
        let builder: QueryBuilder<MockModel> = QueryBuilder::new();
        assert_eq!(render(&builder), "SELECT id, name FROM mocks");
    }

    #[test]
    fn test_dsl_generation() {
        let builder = QueryBuilder::<MockModel>::new()
            .filter(MockColumn::Id.gt(10))
            .filter(MockColumn::Name.eq("test"));
        assert_eq!(
            render(&builder),
            "SELECT id, name FROM mocks WHERE id > $1 AND name = $2"
        );
    }
}