1use crate::{
2 db::{
3 Db,
4 executor::{
5 FilterEvaluator,
6 plan::{plan_for, record_plan_metrics, scan_missing_ok, set_rows_from_len},
7 },
8 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
9 query::{LoadQuery, QueryPlan, QueryValidate},
10 response::{Response, ResponseError},
11 store::{DataKey, DataRow},
12 },
13 error::InternalError,
14 obs::sink::{self, ExecKind, MetricsEvent, Span},
15 prelude::*,
16 traits::{EntityKind, FieldValue},
17};
18use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
19
/// Executor for load (read) queries over entities of type `E`.
///
/// Wraps a [`Db`] handle for the canister that stores `E` and exposes
/// convenience methods (`one`, `many`, `filter`, `exists`, …) that build and
/// run a [`LoadQuery`].
#[derive(Clone)]
pub struct LoadExecutor<E: EntityKind> {
    // Database handle scoped to the canister that owns `E`.
    db: Db<E::Canister>,
    // When true, `debug_log` prints execution traces to stdout.
    debug: bool,
    // Zero-sized marker binding this executor to the entity type `E`.
    _marker: PhantomData<E>,
}
30
impl<E: EntityKind> LoadExecutor<E> {
    /// Creates a new executor over `db`; `debug` enables stdout tracing.
    #[must_use]
    pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
        Self {
            db,
            debug,
            _marker: PhantomData,
        }
    }

    /// Prints `s` to stdout when debug tracing is enabled; no-op otherwise.
    fn debug_log(&self, s: impl Into<String>) {
        if self.debug {
            println!("{}", s.into());
        }
    }

    /// Loads the entity whose primary key equals `value`.
    pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, InternalError> {
        self.execute(LoadQuery::new().one::<E>(value))
    }

    /// Loads the single entity keyed by the unit value `()` — for singleton
    /// entities whose key carries no information.
    pub fn only(&self) -> Result<Response<E>, InternalError> {
        self.execute(LoadQuery::new().one::<E>(()))
    }

    /// Loads every entity whose primary key is in `values`.
    ///
    /// Keys that match nothing are simply absent from the response; this does
    /// not error on misses (see `ensure_exists_many` for that).
    pub fn many<I, V>(&self, values: I) -> Result<Response<E>, InternalError>
    where
        I: IntoIterator<Item = V>,
        V: FieldValue,
    {
        let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
        self.execute(query)
    }

    /// Loads all entities of type `E` (unfiltered, unpaginated query).
    pub fn all(&self) -> Result<Response<E>, InternalError> {
        self.execute(LoadQuery::new())
    }

    /// Loads all entities matching the filter built by `f` from a [`FilterDsl`].
    pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, InternalError>
    where
        F: FnOnce(FilterDsl) -> I,
        I: IntoFilterExpr,
    {
        self.execute(LoadQuery::new().filter(f))
    }

    /// Executes `query` and errors unless it yields exactly one row
    /// (delegates the cardinality check to `Response::require_one`).
    pub fn require_one(&self, query: LoadQuery) -> Result<(), InternalError> {
        self.execute(query)?.require_one()
    }

    /// Errors unless exactly one entity exists with primary key `value`.
    pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), InternalError> {
        self.require_one(LoadQuery::new().one::<E>(value))
    }

    /// Errors unless the filter built by `f` matches exactly one entity.
    pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), InternalError>
    where
        F: FnOnce(FilterDsl) -> I,
        I: IntoFilterExpr,
    {
        self.require_one(LoadQuery::new().filter(f))
    }

    /// Returns whether `query` matches at least one row past its offset,
    /// short-circuiting as soon as a match is found.
    ///
    /// Records an `ExistsCall` metric up front and, on the scan path, a
    /// `RowsScanned` metric with the number of rows visited. A `limit` of
    /// `Some(0)` trivially yields `false` before touching storage.
    pub fn exists(&self, query: LoadQuery) -> Result<bool, InternalError> {
        QueryValidate::<E>::validate(&query)?;
        sink::record(MetricsEvent::ExistsCall {
            entity_path: E::PATH,
        });

        let plan = plan_for::<E>(query.filter.as_ref());
        let offset = query.limit.as_ref().map_or(0, |lim| lim.offset);
        let limit = query.limit.as_ref().and_then(|lim| lim.limit);
        if limit == Some(0) {
            return Ok(false);
        }

        #[allow(clippy::needless_continue)]
        match plan {
            // Point-lookup plan: probe each key directly instead of scanning.
            QueryPlan::Keys(keys) => {
                // Number of matches consumed by the offset so far.
                let mut seen = 0u32;
                for key in keys {
                    let data_key = DataKey::new::<E>(key);
                    match self.db.context::<E>().read(&data_key) {
                        Ok(_) => {
                            if seen < offset {
                                seen += 1;
                            } else {
                                return Ok(true);
                            }
                        }
                        // Missing keys are not an error for existence checks.
                        Err(err) if err.is_not_found() => continue,
                        Err(err) => return Err(err),
                    }
                }
                Ok(false)
            }
            // Any other plan: scan rows, applying the (simplified) filter.
            plan => {
                let filter = query.filter.map(FilterExpr::simplify);
                let mut seen = 0u32;
                let mut scanned = 0u64;
                let mut found = false;

                scan_missing_ok::<E, _>(&self.db, plan, |_, entity| {
                    scanned = scanned.saturating_add(1);
                    // No filter means every row matches.
                    let matches = filter
                        .as_ref()
                        .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));

                    if matches {
                        if seen < offset {
                            seen += 1;
                        } else {
                            // First match past the offset: stop scanning.
                            found = true;
                            return ControlFlow::Break(());
                        }
                    }

                    ControlFlow::Continue(())
                })?;

                sink::record(MetricsEvent::RowsScanned {
                    entity_path: E::PATH,
                    rows_scanned: scanned,
                });

                Ok(found)
            }
        }
    }

    /// Returns whether an entity with primary key `value` exists, via a
    /// single direct key probe (no scan).
    ///
    /// Records `ExistsCall` and a `RowsScanned` metric of 0 or 1. A value
    /// that cannot coerce to a key yields `false`.
    pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, InternalError> {
        let value = value.to_value();
        // Built solely so validation sees the same shape `one()` would run.
        let query = LoadQuery::new().one::<E>(value.clone());
        QueryValidate::<E>::validate(&query)?;
        sink::record(MetricsEvent::ExistsCall {
            entity_path: E::PATH,
        });

        let Some(key) = value.as_key_coerced() else {
            sink::record(MetricsEvent::RowsScanned {
                entity_path: E::PATH,
                rows_scanned: 0,
            });
            return Ok(false);
        };

        let data_key = DataKey::new::<E>(key);
        let found = match self.db.context::<E>().read(&data_key) {
            Ok(_) => true,
            Err(err) if err.is_not_found() => false,
            Err(err) => return Err(err),
        };

        sink::record(MetricsEvent::RowsScanned {
            entity_path: E::PATH,
            rows_scanned: u64::from(found),
        });

        Ok(found)
    }

    /// Returns whether any entity matches the filter built by `f`.
    pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, InternalError>
    where
        F: FnOnce(FilterDsl) -> I,
        I: IntoFilterExpr,
    {
        self.exists(LoadQuery::new().filter(f))
    }

    /// Returns whether any entity of type `E` exists at all.
    pub fn exists_any(&self) -> Result<bool, InternalError> {
        self.exists(LoadQuery::new())
    }

    /// Errors with `ResponseError::NotFound` unless an entity with primary
    /// key `value` exists.
    pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), InternalError> {
        if self.exists_one(value)? {
            Ok(())
        } else {
            Err(ResponseError::NotFound { entity: E::PATH }.into())
        }
    }

    /// Errors unless every primary key in `values` resolves to an entity.
    ///
    /// An empty input trivially succeeds. Duplicate keys in `values` would
    /// inflate `expected` past the number of loadable rows and fail the
    /// length check — callers are assumed to pass distinct keys.
    #[allow(clippy::cast_possible_truncation)]
    pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), InternalError>
    where
        I: IntoIterator<Item = V>,
        V: FieldValue,
    {
        let pks: Vec<_> = values.into_iter().collect();

        let expected = pks.len() as u32;
        if expected == 0 {
            return Ok(());
        }

        let res = self.many(pks)?;
        res.require_len(expected)?;

        Ok(())
    }

    /// Errors with `ResponseError::NotFound` unless the filter built by `f`
    /// matches at least one entity.
    pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), InternalError>
    where
        F: FnOnce(FilterDsl) -> I,
        I: IntoFilterExpr,
    {
        if self.exists_filter(f)? {
            Ok(())
        } else {
            Err(ResponseError::NotFound { entity: E::PATH }.into())
        }
    }

    /// Validates `query` and returns the [`QueryPlan`] the executor would
    /// use, without running it.
    pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, InternalError> {
        QueryValidate::<E>::validate(&query)?;

        Ok(plan_for::<E>(query.filter.as_ref()))
    }

    /// Fetches raw (undeserialized) rows for `plan`, pushing pagination down
    /// into the store when the query carries a limit.
    fn execute_raw(
        &self,
        plan: QueryPlan,
        query: &LoadQuery,
    ) -> Result<Vec<DataRow>, InternalError> {
        let ctx = self.db.context::<E>();

        if let Some(lim) = &query.limit {
            Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
        } else {
            Ok(ctx.rows_from_plan(plan)?)
        }
    }

    /// Runs `query` end to end: plan, scan, deserialize, filter, sort,
    /// paginate — returning the final [`Response`].
    ///
    /// Pagination is pushed down to storage only when there is no filter and
    /// no sort (`pre_paginated`); otherwise all planned rows are loaded and
    /// offset/limit are applied in memory after filtering and sorting.
    /// Emits a `RowsScanned` metric and records the final row count on the
    /// load span.
    pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, InternalError> {
        let mut span = Span::<E>::new(ExecKind::Load);
        QueryValidate::<E>::validate(&query)?;

        self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));

        let ctx = self.db.context::<E>();
        let plan = plan_for::<E>(query.filter.as_ref());

        self.debug_log(format!("📄 Query plan: {plan:?}"));
        record_plan_metrics(&plan);

        // Safe to paginate in storage only when no post-processing can
        // change which rows fall inside the window.
        let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
        let mut rows: Vec<(Key, E)> = if pre_paginated {
            let data_rows = self.execute_raw(plan, &query)?;
            sink::record(MetricsEvent::RowsScanned {
                entity_path: E::PATH,
                rows_scanned: data_rows.len() as u64,
            });

            self.debug_log(format!(
                "📦 Scanned {} data rows before deserialization",
                data_rows.len()
            ));

            let rows = ctx.deserialize_rows(data_rows)?;
            self.debug_log(format!(
                "🧩 Deserialized {} entities before filtering",
                rows.len()
            ));
            rows
        } else {
            let data_rows = ctx.rows_from_plan(plan)?;
            sink::record(MetricsEvent::RowsScanned {
                entity_path: E::PATH,
                rows_scanned: data_rows.len() as u64,
            });
            self.debug_log(format!(
                "📦 Scanned {} data rows before deserialization",
                data_rows.len()
            ));

            let rows = ctx.deserialize_rows(data_rows)?;
            self.debug_log(format!(
                "🧩 Deserialized {} entities before filtering",
                rows.len()
            ));

            rows
        };

        if let Some(f) = &query.filter {
            let simplified = f.clone().simplify();
            Self::apply_filter(&mut rows, &simplified);

            self.debug_log(format!(
                "🔎 Applied filter -> {} entities remaining",
                rows.len()
            ));
        }

        // Sorting a 0/1-row set is a no-op; skip the comparator entirely.
        if let Some(sort) = &query.sort
            && rows.len() > 1
        {
            Self::apply_sort(&mut rows, sort);
            self.debug_log("↕️ Applied sort expression");
        }

        // In-memory pagination, unless storage already applied it above.
        if let Some(lim) = &query.limit
            && !pre_paginated
        {
            apply_pagination(&mut rows, lim.offset, lim.limit);
            self.debug_log(format!(
                "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
                lim.offset,
                lim.limit,
                rows.len()
            ));
        }

        set_rows_from_len(&mut span, rows.len());
        self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));

        Ok(Response(rows))
    }

    /// Executes `query` and returns the number of resulting rows.
    pub fn count(&self, query: LoadQuery) -> Result<u32, InternalError> {
        Ok(self.execute(query)?.count())
    }

    /// Returns the total number of entities of type `E`.
    pub fn count_all(&self) -> Result<u32, InternalError> {
        self.count(LoadQuery::new())
    }

    /// Executes `query` and counts the resulting entities grouped by the key
    /// produced by `key_fn`.
    pub fn group_count_by<K, F>(
        &self,
        query: LoadQuery,
        key_fn: F,
    ) -> Result<HashMap<K, u32>, InternalError>
    where
        K: Eq + Hash,
        F: Fn(&E) -> K,
    {
        let entities = self.execute(query)?.entities();

        let mut counts = HashMap::new();
        for e in entities {
            *counts.entry(key_fn(&e)).or_insert(0) += 1;
        }

        Ok(counts)
    }

    /// Retains only the rows whose entity satisfies `filter`.
    fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
        rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
    }

    /// Sorts rows in place by the fields of `sort_expr`, in order.
    ///
    /// Per field: missing-vs-missing and incomparable values (partial_cmp
    /// `None`) fall through to the next sort field; a missing value orders
    /// before a present one (ascending). `Desc` reverses the per-field
    /// ordering. Rows equal on every field keep their relative order
    /// (`sort_by` is stable).
    fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
        rows.sort_by(|(_, ea), (_, eb)| {
            for (field, direction) in sort_expr.iter() {
                let va = ea.get_value(field);
                let vb = eb.get_value(field);

                let ordering = match (va, vb) {
                    (None, None) => continue,
                    (None, Some(_)) => Ordering::Less,
                    (Some(_), None) => Ordering::Greater,
                    (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
                        Some(ord) => ord,
                        None => continue,
                    },
                };

                let ordering = match direction {
                    Order::Asc => ordering,
                    Order::Desc => ordering.reverse(),
                };

                if ordering != Ordering::Equal {
                    return ordering;
                }
            }

            Ordering::Equal
        });
    }
}
491
/// Applies offset/limit pagination to `rows` in place.
///
/// Drops the first `offset` rows (clamped to the available length), then
/// keeps at most `limit` rows; `limit == None` keeps everything after the
/// offset. An empty window clears the vector.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = usize::min(offset as usize, total);
    // saturating_add: on 32-bit targets `start + limit` could wrap for
    // pathological offset/limit values; saturation clamps to `total` instead.
    let end = limit.map_or(total, |l| usize::min(start.saturating_add(l as usize), total));

    if start >= end {
        rows.clear();
    } else {
        rows.drain(..start);
        rows.truncate(end - start);
    }
}