1mod state;
2
3use {
4 self::state::State,
5 super::{
6 context::{AggregateContext, RowContext},
7 evaluate::evaluate,
8 filter::check_expr,
9 },
10 crate::{
11 ast::{Expr, SelectItem},
12 data::Value,
13 result::Result,
14 store::GStore,
15 },
16 futures::{
17 future::BoxFuture,
18 stream::{self, Stream, StreamExt, TryStreamExt},
19 },
20 std::sync::Arc,
21};
22
23#[derive(futures_enum::Stream)]
24enum S<T1, T2> {
25 NonAggregate(T1),
26 Aggregate(T2),
27}
28
29fn check_aggregate<'a>(fields: &'a [SelectItem], group_by: &'a [Expr]) -> bool {
30 if !group_by.is_empty() {
31 return true;
32 }
33
34 fields.iter().any(|field| match field {
35 SelectItem::Expr { expr, .. } => check(expr),
36 _ => false,
37 })
38}
39
40pub async fn apply<'a, T: GStore, U: Stream<Item = Result<Arc<RowContext<'a>>>> + 'a>(
41 storage: &'a T,
42 fields: &'a [SelectItem],
43 group_by: &'a [Expr],
44 having: Option<&'a Expr>,
45 filter_context: Option<Arc<RowContext<'a>>>,
46 rows: U,
47) -> Result<impl Stream<Item = Result<AggregateContext<'a>>> + use<'a, T, U>> {
48 if !check_aggregate(fields, group_by) {
49 let rows = rows.map_ok(|project_context| AggregateContext {
50 aggregated: None,
51 next: project_context,
52 });
53 return Ok(S::NonAggregate(rows));
54 }
55
56 let state = rows
57 .into_stream()
58 .enumerate()
59 .map(|(i, row)| row.map(|row| (i, row)))
60 .try_fold(State::new(storage), |state, (index, project_context)| {
61 let filter_context = filter_context.clone();
62
63 async move {
64 let filter_context = match filter_context {
65 Some(filter_context) => Arc::new(RowContext::concat(
66 Arc::clone(&project_context),
67 filter_context,
68 )),
69 None => Arc::clone(&project_context),
70 };
71 let filter_context = Some(filter_context);
72
73 let group = stream::iter(group_by.iter())
74 .then(|expr| {
75 let filter_clone = filter_context.as_ref().map(Arc::clone);
76 async move {
77 evaluate(storage, filter_clone, None, expr)
78 .await?
79 .try_into()
80 }
81 })
82 .try_collect::<Vec<Value>>()
83 .await?;
84
85 let state = state.apply(index, group, Arc::clone(&project_context));
86 let state = stream::iter(fields)
87 .map(Ok)
88 .try_fold(state, |state, field| {
89 let filter_clone = filter_context.as_ref().map(Arc::clone);
90
91 async move {
92 match field {
93 SelectItem::Expr { expr, .. } => {
94 aggregate(state, filter_clone, expr).await
95 }
96 _ => Ok(state),
97 }
98 }
99 })
100 .await?;
101
102 Ok(state)
103 }
104 })
105 .await?;
106
107 group_by_having(storage, filter_context, having, state)
108 .await
109 .map(S::Aggregate)
110}
111
112async fn group_by_having<'a, T: GStore>(
113 storage: &'a T,
114 filter_context: Option<Arc<RowContext<'a>>>,
115 having: Option<&'a Expr>,
116 state: State<'a, T>,
117) -> Result<impl Stream<Item = Result<AggregateContext<'a>>>> {
118 let rows = state
119 .export()
120 .await?
121 .into_iter()
122 .filter_map(|(aggregated, next)| next.map(|next| (aggregated, next)));
123 let rows = stream::iter(rows)
124 .filter_map(move |(aggregated, next)| {
125 let filter_context = filter_context.as_ref().map(Arc::clone);
126
127 async move {
128 match having {
129 None => Some(Ok((aggregated, next))),
130 Some(having) => {
131 let filter_context = match filter_context {
132 Some(filter_context) => {
133 Arc::new(RowContext::concat(Arc::clone(&next), filter_context))
134 }
135 None => Arc::clone(&next),
136 };
137 let filter_context = Some(filter_context);
138 let aggr_rc = aggregated.clone().map(Arc::new);
139
140 check_expr(storage, filter_context, aggr_rc, having)
141 .await
142 .map(|pass| pass.then_some((aggregated, next)))
143 .transpose()
144 }
145 }
146 }
147 })
148 .map(|res| res.map(|(aggregated, next)| AggregateContext { aggregated, next }));
149
150 Ok(rows)
151}
152
153fn aggregate<'a, T>(
154 state: State<'a, T>,
155 filter_context: Option<Arc<RowContext<'a>>>,
156 expr: &'a Expr,
157) -> BoxFuture<'a, Result<State<'a, T>>>
158where
159 T: GStore + 'a,
160{
161 Box::pin(async move {
162 match expr {
163 Expr::Between {
164 expr, low, high, ..
165 } => {
166 let state = aggregate(state, filter_context.clone(), expr).await?;
167 let state = aggregate(state, filter_context.clone(), low).await?;
168 aggregate(state, filter_context, high).await
169 }
170 Expr::BinaryOp { left, right, .. } => {
171 let state = aggregate(state, filter_context.clone(), left).await?;
172 aggregate(state, filter_context, right).await
173 }
174 Expr::UnaryOp { expr, .. } => aggregate(state, filter_context, expr).await,
175 Expr::Nested(expr) => aggregate(state, filter_context, expr).await,
176 Expr::Case {
177 operand,
178 when_then,
179 else_result,
180 } => {
181 let mut state = match operand.as_deref() {
182 Some(op) => aggregate(state, filter_context.clone(), op).await?,
183 None => state,
184 };
185
186 for (when, then) in when_then {
187 state = aggregate(state, filter_context.clone(), when).await?;
188 state = aggregate(state, filter_context.clone(), then).await?;
189 }
190
191 if let Some(else_expr) = else_result.as_deref() {
192 state = aggregate(state, filter_context.clone(), else_expr).await?;
193 }
194
195 Ok(state)
196 }
197 Expr::Aggregate(aggr_expr) => {
198 state.accumulate(filter_context, aggr_expr.as_ref()).await
199 }
200 _ => Ok(state),
201 }
202 })
203}
204
205fn check(expr: &Expr) -> bool {
206 match expr {
207 Expr::Between {
208 expr, low, high, ..
209 } => check(expr) || check(low) || check(high),
210 Expr::BinaryOp { left, right, .. } => check(left) || check(right),
211 Expr::UnaryOp { expr, .. } => check(expr),
212 Expr::Nested(expr) => check(expr),
213 Expr::Case {
214 operand,
215 when_then,
216 else_result,
217 } => {
218 operand.as_ref().map(|expr| check(expr)).unwrap_or(false)
219 || when_then
220 .iter()
221 .any(|(when, then)| check(when) || check(then))
222 || else_result
223 .as_ref()
224 .map(|expr| check(expr))
225 .unwrap_or(false)
226 }
227 Expr::Aggregate(_) => true,
228 _ => false,
229 }
230}