Skip to main content

rust_query/value/
aggregate.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4    rc::Rc,
5};
6
7use crate::{
8    Expr, IntoExpr, lower,
9    rows::Rows,
10    value::{EqTyp, NumTyp},
11};
12
13/// This is the argument type used for [aggregate].
14pub struct Aggregate<'outer, 'inner, S> {
15    pub(crate) query: Rows<'inner, S>,
16    _p: PhantomData<&'inner &'outer ()>,
17}
18
19impl<'inner, S> Deref for Aggregate<'_, 'inner, S> {
20    type Target = Rows<'inner, S>;
21
22    fn deref(&self) -> &Self::Target {
23        &self.query
24    }
25}
26
27impl<S> DerefMut for Aggregate<'_, '_, S> {
28    fn deref_mut(&mut self) -> &mut Self::Target {
29        &mut self.query
30    }
31}
32
33impl<'outer, 'inner, S: 'static> Aggregate<'outer, 'inner, S> {
34    /// This must be used with an aggregating expression.
35    /// otherwise there is a chance that there are multiple rows.
36    fn select_func(&self, agg_func: &'static str, val: Rc<lower::Expr>) -> Rc<lower::Expr> {
37        let expr = Rc::new(lower::Expr::Func(agg_func, Box::new([val])));
38        Rc::new(lower::Expr::AggrIndex(self.ast.clone(), expr))
39    }
40
41    /// Return the average value in a column, this is [None] if there are zero rows.
42    ///
43    /// ```
44    /// # use rust_query::private::doctest_aggregate::*;
45    /// # get_txn(|txn| {
46    /// for x in [1, 2, 3] {
47    ///     txn.insert_ok(Val { x });
48    /// }
49    /// let (avg1, avg2) = txn.query_one(aggregate(|rows| {
50    ///     let val = rows.join(Val);
51    ///     let avg1 = rows.avg(val.x.to_f64());
52    ///     rows.filter(false); // remove all rows
53    ///     let avg2 = rows.avg(val.x.to_f64());
54    ///     (avg1, avg2)
55    /// }));
56    /// assert_eq!(avg1, Some(2.0));
57    /// assert_eq!(avg2, None);
58    /// # });
59    /// ```
60    pub fn avg(&self, val: impl IntoExpr<'inner, S, Typ = f64>) -> Expr<'outer, S, Option<f64>> {
61        let val = val.into_expr().inner;
62        Expr::new(self.select_func("avg", val))
63    }
64
65    /// Return the maximum value in a column, this is [None] if there are zero rows.
66    ///
67    /// ```
68    /// # use rust_query::private::doctest_aggregate::*;
69    /// # get_txn(|txn| {
70    /// for x in [-100, 10, 42] {
71    ///     txn.insert_ok(Val { x });
72    /// }
73    /// let (max1, max2) = txn.query_one(aggregate(|rows| {
74    ///     let val = rows.join(Val);
75    ///     let max1 = rows.max(&val.x);
76    ///     rows.filter(false); // remove all rows
77    ///     let max2 = rows.max(&val.x);
78    ///     (max1, max2)
79    /// }));
80    /// assert_eq!(max1, Some(42));
81    /// assert_eq!(max2, None);
82    /// # });
83    /// ```
84    pub fn max<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
85    where
86        T: EqTyp,
87    {
88        let val = val.into_expr().inner;
89        Expr::new(self.select_func("max", val))
90    }
91
92    /// Return the minimum value in a column, this is [None] if there are zero rows.
93    ///
94    /// ```
95    /// # use rust_query::private::doctest_aggregate::*;
96    /// # get_txn(|txn| {
97    /// for x in [-100, 10, 42] {
98    ///     txn.insert_ok(Val { x });
99    /// }
100    /// let (min1, min2) = txn.query_one(aggregate(|rows| {
101    ///     let val = rows.join(Val);
102    ///     let min1 = rows.min(&val.x);
103    ///     rows.filter(false); // remove all rows
104    ///     let min2 = rows.min(&val.x);
105    ///     (min1, min2)
106    /// }));
107    /// assert_eq!(min1, Some(-100));
108    /// assert_eq!(min2, None);
109    /// # });
110    /// ```
111    pub fn min<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, Option<T>>
112    where
113        T: EqTyp,
114    {
115        let val = val.into_expr().inner;
116        Expr::new(self.select_func("min", val))
117    }
118
119    /// Return the sum of a column.
120    ///
121    /// ```
122    /// # use rust_query::private::doctest_aggregate::*;
123    /// # get_txn(|txn| {
124    /// for x in [-100, 10, 42] {
125    ///     txn.insert_ok(Val { x });
126    /// }
127    /// let (sum1, sum2) = txn.query_one(aggregate(|rows| {
128    ///     let val = rows.join(Val);
129    ///     let sum1 = rows.sum(&val.x);
130    ///     rows.filter(false); // remove all rows
131    ///     let sum2 = rows.sum(&val.x);
132    ///     (sum1, sum2)
133    /// }));
134    /// assert_eq!(sum1, -48);
135    /// assert_eq!(sum2, 0);
136    /// # });
137    /// ```
138    pub fn sum<T>(&self, val: impl IntoExpr<'inner, S, Typ = T>) -> Expr<'outer, S, T>
139    where
140        T: NumTyp,
141    {
142        let val = val.into_expr().inner;
143        let val = self.select_func("sum", val);
144
145        Expr::adhoc(lower::Expr::Func(
146            "IFNULL",
147            Box::new([val, Rc::new(lower::Expr::Constant(T::ZERO))]),
148        ))
149    }
150
151    /// Return the number of distinct values in a column.
152    ///
153    /// ```
154    /// # use rust_query::private::doctest_aggregate::*;
155    /// # get_txn(|txn| {
156    /// for x in [-100, 10, 42, 10] {
157    ///     txn.insert_ok(Val { x });
158    /// }
159    /// let (count1, count2) = txn.query_one(aggregate(|rows| {
160    ///     let val = rows.join(Val);
161    ///     let count1 = rows.count_distinct(&val.x);
162    ///     rows.filter(false); // remove all rows
163    ///     let count2 = rows.count_distinct(&val.x);
164    ///     (count1, count2)
165    /// }));
166    /// assert_eq!(count1, 3);
167    /// assert_eq!(count2, 0);
168    /// # });
169    /// ```
170    pub fn count_distinct<T: EqTyp + 'static>(
171        &self,
172        val: impl IntoExpr<'inner, S, Typ = T>,
173    ) -> Expr<'outer, S, i64> {
174        let val = val.into_expr().inner;
175        let val = self.select_func("COUNT", Rc::new(lower::Expr::Prefix("DISTINCT ", val)));
176        // technically the `if_null` here is only required for correlated sub queries
177        Expr::adhoc(lower::Expr::Func(
178            "IFNULL",
179            Box::new([val, Rc::new(lower::Expr::Constant(i64::ZERO))]),
180        ))
181    }
182
183    /// Return whether there are any rows.
184    ///
185    /// ```
186    /// # use rust_query::private::doctest_aggregate::*;
187    /// # get_txn(|txn| {
188    /// for x in [10, 42, 10] {
189    ///     txn.insert_ok(Val { x });
190    /// }
191    /// let (e1, e2) = txn.query_one(aggregate(|rows| {
192    ///     rows.join(Val);
193    ///     let e1 = rows.exists();
194    ///     rows.filter(false); // removes all rows
195    ///     let e2 = rows.exists();
196    ///     (e1, e2)
197    /// }));
198    /// assert_eq!(e1, true);
199    /// assert_eq!(e2, false);
200    /// # });
201    /// ```
202    pub fn exists(&self) -> Expr<'outer, S, bool> {
203        let zero_expr = Expr::<_, i64>::adhoc(lower::CONST_0);
204        self.count_distinct(zero_expr.clone()).neq(zero_expr)
205    }
206}
207
208/// Perform an aggregate that returns a single result for each of the current rows.
209///
210/// One can filter the rows in the aggregate based on values from the outer query.
211/// See the documentation for [Aggregate] for more information.
212///
213/// ```
214/// # use rust_query::migration::{schema, Config};
215/// # use rust_query::{Database, aggregate};
216/// #[schema(Site)]
217/// pub mod vN {
218///     pub struct Review {
219///         #[index]
220///         pub book: rust_query::TableRow<Book>,
221///         pub rating: f64,
222///     }
223///     pub struct Book {
224///         pub name: String
225///     }
226/// }
227/// use v0::*;
228///
229/// Database::new(Config::open_in_memory()).transaction(|txn| {
230///     let books = txn.query(|rows| {
231///         let book = rows.join(Book);
232///         let rating = aggregate(|aggr| {
233///             let review = aggr.join(Review.book(&book));
234///             // books without reviews will get a rating of 0.0
235///             aggr.avg(&review.rating).unwrap_or(0.0)
236///         });
237///         // top 10 highest rated books
238///         rows.order_by()
239///             .desc(rating)
240///             .into_iter(book)
241///             .take(10)
242///     });
243/// });
244/// ```
245pub fn aggregate<'outer, S, F, R>(f: F) -> R
246where
247    F: for<'inner> FnOnce(&mut Aggregate<'outer, 'inner, S>) -> R,
248{
249    let inner = Rows {
250        phantom: PhantomData,
251        ast: Default::default(),
252        _p: PhantomData,
253    };
254    let mut group = Aggregate {
255        query: inner,
256        _p: PhantomData,
257    };
258    f(&mut group)
259}