gluesql_core/executor/aggregate.rs

1mod state;
2
3use {
4    self::state::State,
5    super::{
6        context::{AggregateContext, RowContext},
7        evaluate::evaluate,
8        filter::check_expr,
9    },
10    crate::{
11        ast::{Expr, SelectItem},
12        data::Value,
13        result::Result,
14        store::GStore,
15    },
16    futures::{
17        future::BoxFuture,
18        stream::{self, Stream, StreamExt, TryStreamExt},
19    },
20    std::sync::Arc,
21};
22
23#[derive(futures_enum::Stream)]
24enum S<T1, T2> {
25    NonAggregate(T1),
26    Aggregate(T2),
27}
28
29fn check_aggregate<'a>(fields: &'a [SelectItem], group_by: &'a [Expr]) -> bool {
30    if !group_by.is_empty() {
31        return true;
32    }
33
34    fields.iter().any(|field| match field {
35        SelectItem::Expr { expr, .. } => check(expr),
36        _ => false,
37    })
38}
39
40pub async fn apply<'a, T: GStore, U: Stream<Item = Result<Arc<RowContext<'a>>>> + 'a>(
41    storage: &'a T,
42    fields: &'a [SelectItem],
43    group_by: &'a [Expr],
44    having: Option<&'a Expr>,
45    filter_context: Option<Arc<RowContext<'a>>>,
46    rows: U,
47) -> Result<impl Stream<Item = Result<AggregateContext<'a>>> + use<'a, T, U>> {
48    if !check_aggregate(fields, group_by) {
49        let rows = rows.map_ok(|project_context| AggregateContext {
50            aggregated: None,
51            next: project_context,
52        });
53        return Ok(S::NonAggregate(rows));
54    }
55
56    let state = rows
57        .into_stream()
58        .enumerate()
59        .map(|(i, row)| row.map(|row| (i, row)))
60        .try_fold(State::new(storage), |state, (index, project_context)| {
61            let filter_context = filter_context.clone();
62
63            async move {
64                let filter_context = match filter_context {
65                    Some(filter_context) => Arc::new(RowContext::concat(
66                        Arc::clone(&project_context),
67                        filter_context,
68                    )),
69                    None => Arc::clone(&project_context),
70                };
71                let filter_context = Some(filter_context);
72
73                let group = stream::iter(group_by.iter())
74                    .then(|expr| {
75                        let filter_clone = filter_context.as_ref().map(Arc::clone);
76                        async move {
77                            evaluate(storage, filter_clone, None, expr)
78                                .await?
79                                .try_into()
80                        }
81                    })
82                    .try_collect::<Vec<Value>>()
83                    .await?;
84
85                let state = state.apply(index, group, Arc::clone(&project_context));
86                let state = stream::iter(fields)
87                    .map(Ok)
88                    .try_fold(state, |state, field| {
89                        let filter_clone = filter_context.as_ref().map(Arc::clone);
90
91                        async move {
92                            match field {
93                                SelectItem::Expr { expr, .. } => {
94                                    aggregate(state, filter_clone, expr).await
95                                }
96                                _ => Ok(state),
97                            }
98                        }
99                    })
100                    .await?;
101
102                Ok(state)
103            }
104        })
105        .await?;
106
107    group_by_having(storage, filter_context, having, state)
108        .await
109        .map(S::Aggregate)
110}
111
112async fn group_by_having<'a, T: GStore>(
113    storage: &'a T,
114    filter_context: Option<Arc<RowContext<'a>>>,
115    having: Option<&'a Expr>,
116    state: State<'a, T>,
117) -> Result<impl Stream<Item = Result<AggregateContext<'a>>>> {
118    let rows = state
119        .export()
120        .await?
121        .into_iter()
122        .filter_map(|(aggregated, next)| next.map(|next| (aggregated, next)));
123    let rows = stream::iter(rows)
124        .filter_map(move |(aggregated, next)| {
125            let filter_context = filter_context.as_ref().map(Arc::clone);
126
127            async move {
128                match having {
129                    None => Some(Ok((aggregated, next))),
130                    Some(having) => {
131                        let filter_context = match filter_context {
132                            Some(filter_context) => {
133                                Arc::new(RowContext::concat(Arc::clone(&next), filter_context))
134                            }
135                            None => Arc::clone(&next),
136                        };
137                        let filter_context = Some(filter_context);
138                        let aggr_rc = aggregated.clone().map(Arc::new);
139
140                        check_expr(storage, filter_context, aggr_rc, having)
141                            .await
142                            .map(|pass| pass.then_some((aggregated, next)))
143                            .transpose()
144                    }
145                }
146            }
147        })
148        .map(|res| res.map(|(aggregated, next)| AggregateContext { aggregated, next }));
149
150    Ok(rows)
151}
152
153fn aggregate<'a, T>(
154    state: State<'a, T>,
155    filter_context: Option<Arc<RowContext<'a>>>,
156    expr: &'a Expr,
157) -> BoxFuture<'a, Result<State<'a, T>>>
158where
159    T: GStore + 'a,
160{
161    Box::pin(async move {
162        match expr {
163            Expr::Between {
164                expr, low, high, ..
165            } => {
166                let state = aggregate(state, filter_context.clone(), expr).await?;
167                let state = aggregate(state, filter_context.clone(), low).await?;
168                aggregate(state, filter_context, high).await
169            }
170            Expr::BinaryOp { left, right, .. } => {
171                let state = aggregate(state, filter_context.clone(), left).await?;
172                aggregate(state, filter_context, right).await
173            }
174            Expr::UnaryOp { expr, .. } => aggregate(state, filter_context, expr).await,
175            Expr::Nested(expr) => aggregate(state, filter_context, expr).await,
176            Expr::Case {
177                operand,
178                when_then,
179                else_result,
180            } => {
181                let mut state = match operand.as_deref() {
182                    Some(op) => aggregate(state, filter_context.clone(), op).await?,
183                    None => state,
184                };
185
186                for (when, then) in when_then {
187                    state = aggregate(state, filter_context.clone(), when).await?;
188                    state = aggregate(state, filter_context.clone(), then).await?;
189                }
190
191                if let Some(else_expr) = else_result.as_deref() {
192                    state = aggregate(state, filter_context.clone(), else_expr).await?;
193                }
194
195                Ok(state)
196            }
197            Expr::Aggregate(aggr_expr) => {
198                state.accumulate(filter_context, aggr_expr.as_ref()).await
199            }
200            _ => Ok(state),
201        }
202    })
203}
204
205fn check(expr: &Expr) -> bool {
206    match expr {
207        Expr::Between {
208            expr, low, high, ..
209        } => check(expr) || check(low) || check(high),
210        Expr::BinaryOp { left, right, .. } => check(left) || check(right),
211        Expr::UnaryOp { expr, .. } => check(expr),
212        Expr::Nested(expr) => check(expr),
213        Expr::Case {
214            operand,
215            when_then,
216            else_result,
217        } => {
218            operand.as_ref().map(|expr| check(expr)).unwrap_or(false)
219                || when_then
220                    .iter()
221                    .any(|(when, then)| check(when) || check(then))
222                || else_result
223                    .as_ref()
224                    .map(|expr| check(expr))
225                    .unwrap_or(false)
226        }
227        Expr::Aggregate(_) => true,
228        _ => false,
229    }
230}