icydb_core/db/session/sql/
aggregate.rs1use crate::{
7 db::{
8 DbSession, MissingRowPolicy, PersistedRow, Query, QueryError,
9 executor::{ScalarNumericFieldBoundaryRequest, ScalarProjectionBoundaryRequest},
10 query::plan::{AggregateKind, FieldSlot},
11 session::sql::explain::resolve_sql_aggregate_target_slot,
12 session::sql::{SqlParsedStatement, SqlStatementRoute},
13 sql::lowering::{
14 SqlGlobalAggregateTerminal, compile_sql_global_aggregate_command,
15 is_sql_global_aggregate_statement,
16 },
17 sql::parser::SqlStatement,
18 },
19 traits::{CanisterKind, EntityValue},
20 types::Id,
21 value::Value,
22};
23
24#[derive(Clone, Copy, Debug, Eq, PartialEq)]
25pub(in crate::db::session::sql) enum SqlAggregateSurface {
26 QueryFrom,
27 ExecuteSql,
28 ExecuteSqlGrouped,
29 ExecuteSqlDispatch,
30 GeneratedQuerySurface,
31}
32
33pub(in crate::db::session::sql) fn parsed_requires_dedicated_sql_aggregate_lane(
34 parsed: &SqlParsedStatement,
35) -> bool {
36 is_sql_global_aggregate_statement(&parsed.statement)
37}
38
39pub(in crate::db::session::sql) const fn unsupported_sql_aggregate_lane_message(
40 surface: SqlAggregateSurface,
41) -> &'static str {
42 match surface {
43 SqlAggregateSurface::QueryFrom => {
44 "query_from_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
45 }
46 SqlAggregateSurface::ExecuteSql => {
47 "execute_sql rejects global aggregate SELECT; use execute_sql_aggregate(...)"
48 }
49 SqlAggregateSurface::ExecuteSqlGrouped => {
50 "execute_sql_grouped rejects global aggregate SELECT; use execute_sql_aggregate(...)"
51 }
52 SqlAggregateSurface::ExecuteSqlDispatch => {
53 "execute_sql_dispatch rejects global aggregate SELECT; use execute_sql_aggregate(...)"
54 }
55 SqlAggregateSurface::GeneratedQuerySurface => {
56 "generated SQL query surface rejects global aggregate SELECT; use execute_sql_aggregate(...)"
57 }
58 }
59}
60
61const fn unsupported_sql_aggregate_surface_lane_message(route: &SqlStatementRoute) -> &'static str {
62 match route {
63 SqlStatementRoute::Query { .. } => {
64 "execute_sql_aggregate requires constrained global aggregate SELECT"
65 }
66 SqlStatementRoute::Explain { .. } => {
67 "execute_sql_aggregate rejects EXPLAIN; use execute_sql_dispatch"
68 }
69 SqlStatementRoute::Describe { .. } => {
70 "execute_sql_aggregate rejects DESCRIBE; use execute_sql_dispatch"
71 }
72 SqlStatementRoute::ShowIndexes { .. } => {
73 "execute_sql_aggregate rejects SHOW INDEXES; use execute_sql_dispatch"
74 }
75 SqlStatementRoute::ShowColumns { .. } => {
76 "execute_sql_aggregate rejects SHOW COLUMNS; use execute_sql_dispatch"
77 }
78 SqlStatementRoute::ShowEntities => {
79 "execute_sql_aggregate rejects SHOW ENTITIES; use execute_sql_dispatch"
80 }
81 }
82}
83
/// Rejection text `execute_sql_aggregate` reports for a `SELECT` that carries
/// a `GROUP BY` clause; such queries are redirected to `execute_sql_grouped`.
const fn unsupported_sql_aggregate_grouped_message() -> &'static str {
    const MESSAGE: &str =
        "execute_sql_aggregate rejects grouped SELECT; use execute_sql_grouped(...)";

    MESSAGE
}
87
88impl<C: CanisterKind> DbSession<C> {
89 pub fn execute_sql_aggregate<E>(&self, sql: &str) -> Result<Value, QueryError>
94 where
95 E: PersistedRow<Canister = C> + EntityValue,
96 {
97 let parsed = self.parse_sql_statement(sql)?;
100 match &parsed.statement {
101 SqlStatement::Select(_) if is_sql_global_aggregate_statement(&parsed.statement) => {}
102 SqlStatement::Select(statement) if !statement.group_by.is_empty() => {
103 return Err(QueryError::unsupported_query(
104 unsupported_sql_aggregate_grouped_message(),
105 ));
106 }
107 SqlStatement::Delete(_) => {
108 return Err(QueryError::unsupported_query(
109 "execute_sql_aggregate rejects DELETE; use execute_sql_dispatch",
110 ));
111 }
112 _ => {
113 return Err(QueryError::unsupported_query(
114 unsupported_sql_aggregate_surface_lane_message(parsed.route()),
115 ));
116 }
117 }
118
119 let command = compile_sql_global_aggregate_command::<E>(sql, MissingRowPolicy::Ignore)
123 .map_err(QueryError::from_sql_lowering_error)?;
124
125 match command.terminal() {
128 SqlGlobalAggregateTerminal::CountRows => self
129 .execute_load_query_with(command.query(), |load, plan| {
130 load.execute_scalar_terminal_request(
131 plan,
132 crate::db::executor::ScalarTerminalBoundaryRequest::Count,
133 )?
134 .into_count()
135 })
136 .map(|count| Value::Uint(u64::from(count))),
137 SqlGlobalAggregateTerminal::CountField(field) => {
138 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
139
140 self.execute_load_query_with(command.query(), |load, plan| {
141 load.execute_scalar_projection_boundary(
142 plan,
143 target_slot,
144 ScalarProjectionBoundaryRequest::Values,
145 )?
146 .into_values()
147 })
148 .map(|values| {
149 let count = values
150 .into_iter()
151 .filter(|value| !matches!(value, Value::Null))
152 .count();
153
154 Value::Uint(u64::try_from(count).unwrap_or(u64::MAX))
155 })
156 }
157 SqlGlobalAggregateTerminal::SumField(field) => {
158 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
159
160 self.execute_load_query_with(command.query(), |load, plan| {
161 load.execute_numeric_field_boundary(
162 plan,
163 target_slot,
164 ScalarNumericFieldBoundaryRequest::Sum,
165 )
166 })
167 .map(|value| value.map_or(Value::Null, Value::Decimal))
168 }
169 SqlGlobalAggregateTerminal::AvgField(field) => {
170 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
171
172 self.execute_load_query_with(command.query(), |load, plan| {
173 load.execute_numeric_field_boundary(
174 plan,
175 target_slot,
176 ScalarNumericFieldBoundaryRequest::Avg,
177 )
178 })
179 .map(|value| value.map_or(Value::Null, Value::Decimal))
180 }
181 SqlGlobalAggregateTerminal::MinField(field) => self
182 .execute_ranked_sql_aggregate_field::<E>(
183 command.query(),
184 field,
185 AggregateKind::Min,
186 ),
187 SqlGlobalAggregateTerminal::MaxField(field) => self
188 .execute_ranked_sql_aggregate_field::<E>(
189 command.query(),
190 field,
191 AggregateKind::Max,
192 ),
193 }
194 }
195
196 fn execute_ranked_sql_aggregate_field<E>(
199 &self,
200 query: &Query<E>,
201 field: &str,
202 kind: AggregateKind,
203 ) -> Result<Value, QueryError>
204 where
205 E: PersistedRow<Canister = C> + EntityValue,
206 {
207 let target_slot = resolve_sql_aggregate_target_slot::<E>(field)?;
208 let matched_id = self.execute_ranked_sql_aggregate_id(query, target_slot, kind)?;
209
210 match matched_id {
211 Some(id) => self
212 .load::<E>()
213 .by_id(id)
214 .first_value_by(field)
215 .map(|value| value.unwrap_or(Value::Null)),
216 None => Ok(Value::Null),
217 }
218 }
219
220 fn execute_ranked_sql_aggregate_id<E>(
223 &self,
224 query: &Query<E>,
225 target_slot: FieldSlot,
226 kind: AggregateKind,
227 ) -> Result<Option<Id<E>>, QueryError>
228 where
229 E: PersistedRow<Canister = C> + EntityValue,
230 {
231 if !kind.is_extrema() {
232 return Err(QueryError::invariant(
233 "ranked SQL aggregate id helper only supports MIN/MAX",
234 ));
235 }
236
237 self.execute_load_query_with(query, |load, plan| {
238 load.execute_scalar_terminal_request(
239 plan,
240 crate::db::executor::ScalarTerminalBoundaryRequest::IdBySlot {
241 kind,
242 target_field: target_slot,
243 },
244 )?
245 .into_id()
246 })
247 }
248}