1use crate::{
2 db::{
3 Db,
4 executor::{
5 FilterEvaluator,
6 plan::{plan_for, record_plan_metrics, scan_strict, set_rows_from_len},
7 },
8 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
9 query::{LoadQuery, QueryPlan, QueryValidate},
10 response::{Response, ResponseError},
11 store::DataRow,
12 },
13 error::InternalError,
14 obs::sink::{self, ExecKind, MetricsEvent, Span},
15 prelude::*,
16 traits::{EntityKind, FieldValue},
17};
18use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
19
/// Executes load (read) queries against the store for entities of kind `E`.
///
/// The executor is a thin handle: it holds the database handle and a debug
/// flag, and carries `E` purely at the type level.
#[derive(Clone)]
pub struct LoadExecutor<E: EntityKind> {
    /// Database handle used to obtain the entity context and to run scans.
    db: Db<E::Canister>,
    /// When `true`, `debug_log` prints trace messages via `println!`.
    debug: bool,
    /// Binds the executor to `E` without storing an entity value.
    _marker: PhantomData<E>,
}
30
31impl<E: EntityKind> LoadExecutor<E> {
32 #[must_use]
37 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
38 Self {
39 db,
40 debug,
41 _marker: PhantomData,
42 }
43 }
44
45 fn debug_log(&self, s: impl Into<String>) {
46 if self.debug {
47 println!("{}", s.into());
48 }
49 }
50
51 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, InternalError> {
57 self.execute(LoadQuery::new().one::<E>(value))
58 }
59
60 pub fn only(&self) -> Result<Response<E>, InternalError> {
62 self.execute(LoadQuery::new().one::<E>(()))
63 }
64
65 pub fn many<I, V>(&self, values: I) -> Result<Response<E>, InternalError>
67 where
68 I: IntoIterator<Item = V>,
69 V: FieldValue,
70 {
71 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
72 self.execute(query)
73 }
74
75 pub fn all(&self) -> Result<Response<E>, InternalError> {
77 self.execute(LoadQuery::new())
78 }
79
80 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, InternalError>
82 where
83 F: FnOnce(FilterDsl) -> I,
84 I: IntoFilterExpr,
85 {
86 self.execute(LoadQuery::new().filter(f))
87 }
88
89 pub fn require_one(&self, query: LoadQuery) -> Result<(), InternalError> {
95 self.execute(query)?.require_one()
96 }
97
98 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), InternalError> {
100 self.require_one(LoadQuery::new().one::<E>(value))
101 }
102
103 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), InternalError>
105 where
106 F: FnOnce(FilterDsl) -> I,
107 I: IntoFilterExpr,
108 {
109 self.require_one(LoadQuery::new().filter(f))
110 }
111
112 pub fn exists(&self, query: LoadQuery) -> Result<bool, InternalError> {
123 QueryValidate::<E>::validate(&query)?;
124 sink::record(MetricsEvent::ExistsCall {
125 entity_path: E::PATH,
126 });
127
128 let plan = plan_for::<E>(query.filter.as_ref());
129 let filter = query.filter.map(FilterExpr::simplify);
130 let offset = query.limit.as_ref().map_or(0, |lim| lim.offset);
131 let limit = query.limit.as_ref().and_then(|lim| lim.limit);
132 if limit == Some(0) {
133 return Ok(false);
134 }
135 let mut seen = 0u32;
136 let mut scanned = 0u64;
137 let mut found = false;
138
139 scan_strict::<E, _>(&self.db, plan, |_, entity| {
140 scanned = scanned.saturating_add(1);
141 let matches = filter
142 .as_ref()
143 .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));
144
145 if matches {
146 if seen < offset {
147 seen += 1;
148 ControlFlow::Continue(())
149 } else {
150 found = true;
151 ControlFlow::Break(())
152 }
153 } else {
154 ControlFlow::Continue(())
155 }
156 })?;
157
158 sink::record(MetricsEvent::RowsScanned {
159 entity_path: E::PATH,
160 rows_scanned: scanned,
161 });
162
163 Ok(found)
164 }
165
166 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, InternalError> {
168 self.exists(LoadQuery::new().one::<E>(value))
169 }
170
171 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, InternalError>
173 where
174 F: FnOnce(FilterDsl) -> I,
175 I: IntoFilterExpr,
176 {
177 self.exists(LoadQuery::new().filter(f))
178 }
179
180 pub fn exists_any(&self) -> Result<bool, InternalError> {
182 self.exists(LoadQuery::new())
183 }
184
185 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), InternalError> {
191 if self.exists_one(value)? {
192 Ok(())
193 } else {
194 Err(ResponseError::NotFound { entity: E::PATH }.into())
195 }
196 }
197
198 #[allow(clippy::cast_possible_truncation)]
200 pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), InternalError>
201 where
202 I: IntoIterator<Item = V>,
203 V: FieldValue,
204 {
205 let pks: Vec<_> = values.into_iter().collect();
206
207 let expected = pks.len() as u32;
208 if expected == 0 {
209 return Ok(());
210 }
211
212 let res = self.many(pks)?;
213 res.require_len(expected)?;
214
215 Ok(())
216 }
217
218 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), InternalError>
220 where
221 F: FnOnce(FilterDsl) -> I,
222 I: IntoFilterExpr,
223 {
224 if self.exists_filter(f)? {
225 Ok(())
226 } else {
227 Err(ResponseError::NotFound { entity: E::PATH }.into())
228 }
229 }
230
231 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, InternalError> {
237 QueryValidate::<E>::validate(&query)?;
238
239 Ok(plan_for::<E>(query.filter.as_ref()))
240 }
241
242 fn execute_raw(
243 &self,
244 plan: QueryPlan,
245 query: &LoadQuery,
246 ) -> Result<Vec<DataRow>, InternalError> {
247 let ctx = self.db.context::<E>();
248
249 if let Some(lim) = &query.limit {
250 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
251 } else {
252 Ok(ctx.rows_from_plan(plan)?)
253 }
254 }
255
256 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, InternalError> {
262 let mut span = Span::<E>::new(ExecKind::Load);
263 QueryValidate::<E>::validate(&query)?;
264
265 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
266
267 let ctx = self.db.context::<E>();
268 let plan = plan_for::<E>(query.filter.as_ref());
269
270 self.debug_log(format!("📄 Query plan: {plan:?}"));
271 record_plan_metrics(&plan);
272
273 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
275 let mut rows: Vec<(Key, E)> = if pre_paginated {
276 let data_rows = self.execute_raw(plan, &query)?;
277 sink::record(MetricsEvent::RowsScanned {
278 entity_path: E::PATH,
279 rows_scanned: data_rows.len() as u64,
280 });
281
282 self.debug_log(format!(
283 "📦 Scanned {} data rows before deserialization",
284 data_rows.len()
285 ));
286
287 let rows = ctx.deserialize_rows(data_rows)?;
288 self.debug_log(format!(
289 "🧩 Deserialized {} entities before filtering",
290 rows.len()
291 ));
292 rows
293 } else {
294 let data_rows = ctx.rows_from_plan(plan)?;
295 sink::record(MetricsEvent::RowsScanned {
296 entity_path: E::PATH,
297 rows_scanned: data_rows.len() as u64,
298 });
299 self.debug_log(format!(
300 "📦 Scanned {} data rows before deserialization",
301 data_rows.len()
302 ));
303
304 let rows = ctx.deserialize_rows(data_rows)?;
305 self.debug_log(format!(
306 "🧩 Deserialized {} entities before filtering",
307 rows.len()
308 ));
309
310 rows
311 };
312
313 if let Some(f) = &query.filter {
315 let simplified = f.clone().simplify();
316 Self::apply_filter(&mut rows, &simplified);
317
318 self.debug_log(format!(
319 "🔎 Applied filter -> {} entities remaining",
320 rows.len()
321 ));
322 }
323
324 if let Some(sort) = &query.sort
326 && rows.len() > 1
327 {
328 Self::apply_sort(&mut rows, sort);
329 self.debug_log("↕️ Applied sort expression");
330 }
331
332 if let Some(lim) = &query.limit
334 && !pre_paginated
335 {
336 apply_pagination(&mut rows, lim.offset, lim.limit);
337 self.debug_log(format!(
338 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
339 lim.offset,
340 lim.limit,
341 rows.len()
342 ));
343 }
344
345 set_rows_from_len(&mut span, rows.len());
346 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
347
348 Ok(Response(rows))
349 }
350
351 pub fn count(&self, query: LoadQuery) -> Result<u32, InternalError> {
353 Ok(self.execute(query)?.count())
354 }
355
356 pub fn count_all(&self) -> Result<u32, InternalError> {
357 self.count(LoadQuery::new())
358 }
359
360 pub fn group_count_by<K, F>(
369 &self,
370 query: LoadQuery,
371 key_fn: F,
372 ) -> Result<HashMap<K, u32>, InternalError>
373 where
374 K: Eq + Hash,
375 F: Fn(&E) -> K,
376 {
377 let entities = self.execute(query)?.entities();
378
379 let mut counts = HashMap::new();
380 for e in entities {
381 *counts.entry(key_fn(&e)).or_insert(0) += 1;
382 }
383
384 Ok(counts)
385 }
386
387 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
393 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
394 }
395
396 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
398 rows.sort_by(|(_, ea), (_, eb)| {
399 for (field, direction) in sort_expr.iter() {
400 let va = ea.get_value(field);
401 let vb = eb.get_value(field);
402
403 let ordering = match (va, vb) {
405 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
409 Some(ord) => ord,
410 None => continue, },
412 };
413
414 let ordering = match direction {
416 Order::Asc => ordering,
417 Order::Desc => ordering.reverse(),
418 };
419
420 if ordering != Ordering::Equal {
421 return ordering;
422 }
423 }
424
425 Ordering::Equal
427 });
428 }
429}
430
/// Trims `rows` in place to the window `[offset, offset + limit)`.
///
/// * `offset` past the end leaves `rows` empty.
/// * `limit == None` keeps everything from `offset` onward.
/// * `limit == Some(0)` leaves `rows` empty.
///
/// The window end is computed with saturating arithmetic so that a large
/// `offset + limit` cannot overflow `usize` on 32-bit targets (which would
/// panic in debug builds or wrap and wrongly empty the page in release).
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();

    // `u32 -> usize` is lossless on 32- and 64-bit targets.
    let start = usize::min(offset as usize, total);
    let end = limit.map_or(total, |l| start.saturating_add(l as usize).min(total));

    // `start <= end <= total` holds by construction. Dropping the tail first
    // means the subsequent drain only shifts the rows that are kept.
    rows.truncate(end);
    rows.drain(..start);
}