1use crate::{
2 Error, Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::{Response, ResponseError},
12 store::DataRow,
13 },
14 obs::metrics,
15 traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData};
18
19#[derive(Clone, Copy)]
24pub struct LoadExecutor<E: EntityKind> {
25 db: Db<E::Canister>,
26 debug: bool,
27 _marker: PhantomData<E>,
28}
29
30impl<E: EntityKind> LoadExecutor<E> {
31 #[must_use]
36 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
37 Self {
38 db,
39 debug,
40 _marker: PhantomData,
41 }
42 }
43
44 fn debug_log(&self, s: impl Into<String>) {
45 if self.debug {
46 println!("{}", s.into());
47 }
48 }
49
50 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
56 self.execute(LoadQuery::new().one::<E>(value))
57 }
58
59 pub fn only(&self) -> Result<Response<E>, Error> {
61 self.execute(LoadQuery::new().one::<E>(()))
62 }
63
64 pub fn many<I, V>(&self, values: I) -> Result<Response<E>, Error>
66 where
67 I: IntoIterator<Item = V>,
68 V: FieldValue,
69 {
70 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
71 self.execute(query)
72 }
73
74 pub fn all(&self) -> Result<Response<E>, Error> {
76 self.execute(LoadQuery::new())
77 }
78
79 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, Error>
81 where
82 F: FnOnce(FilterDsl) -> I,
83 I: IntoFilterExpr,
84 {
85 self.execute(LoadQuery::new().filter(f))
86 }
87
88 pub fn require_one(&self, query: LoadQuery) -> Result<(), Error> {
94 self.execute(query)?.require_one()
95 }
96
97 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), Error> {
99 self.require_one(LoadQuery::new().one::<E>(value))
100 }
101
102 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), Error>
104 where
105 F: FnOnce(FilterDsl) -> I,
106 I: IntoFilterExpr,
107 {
108 self.require_one(LoadQuery::new().filter(f))
109 }
110
111 pub fn exists(&self, query: LoadQuery) -> Result<bool, Error> {
117 let query = query.limit_1();
118 Ok(!self.execute_raw(&query)?.is_empty())
119 }
120
121 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, Error> {
123 self.exists(LoadQuery::new().one::<E>(value))
124 }
125
126 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, Error>
128 where
129 F: FnOnce(FilterDsl) -> I,
130 I: IntoFilterExpr,
131 {
132 self.exists(LoadQuery::new().filter(f))
133 }
134
135 pub fn exists_any(&self) -> Result<bool, Error> {
137 self.exists(LoadQuery::new())
138 }
139
140 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), Error> {
146 if self.exists_one(value)? {
147 Ok(())
148 } else {
149 Err(ResponseError::NotFound { entity: E::PATH }.into())
150 }
151 }
152
153 #[allow(clippy::cast_possible_truncation)]
155 pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), Error>
156 where
157 I: IntoIterator<Item = V>,
158 V: FieldValue,
159 {
160 let pks: Vec<_> = values.into_iter().collect();
161
162 let expected = pks.len() as u32;
163 if expected == 0 {
164 return Ok(());
165 }
166
167 let res = self.many(pks)?;
168 res.require_len(expected)?;
169
170 Ok(())
171 }
172
173 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), Error>
175 where
176 F: FnOnce(FilterDsl) -> I,
177 I: IntoFilterExpr,
178 {
179 if self.exists_filter(f)? {
180 Ok(())
181 } else {
182 Err(ResponseError::NotFound { entity: E::PATH }.into())
183 }
184 }
185
186 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
192 QueryValidate::<E>::validate(&query)?;
193
194 Ok(plan_for::<E>(query.filter.as_ref()))
195 }
196
197 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
198 QueryValidate::<E>::validate(query)?;
199
200 let ctx = self.db.context::<E>();
201 let plan = plan_for::<E>(query.filter.as_ref());
202
203 if let Some(lim) = &query.limit {
204 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
205 } else {
206 Ok(ctx.rows_from_plan(plan)?)
207 }
208 }
209
210 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
212 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
213 QueryValidate::<E>::validate(&query)?;
214
215 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
216
217 let ctx = self.db.context::<E>();
218 let plan = plan_for::<E>(query.filter.as_ref());
219
220 self.debug_log(format!("📄 Query plan: {plan:?}"));
221
222 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
224 let mut rows: Vec<(Key, E)> = if pre_paginated {
225 let data_rows = self.execute_raw(&query)?;
226
227 self.debug_log(format!(
228 "📦 Scanned {} data rows before deserialization",
229 data_rows.len()
230 ));
231
232 let rows = ctx.deserialize_rows(data_rows)?;
233 self.debug_log(format!(
234 "🧩 Deserialized {} entities before filtering",
235 rows.len()
236 ));
237 rows
238 } else {
239 let data_rows = ctx.rows_from_plan(plan)?;
240 self.debug_log(format!(
241 "📦 Scanned {} data rows before deserialization",
242 data_rows.len()
243 ));
244
245 let rows = ctx.deserialize_rows(data_rows)?;
246 self.debug_log(format!(
247 "🧩 Deserialized {} entities before filtering",
248 rows.len()
249 ));
250
251 rows
252 };
253
254 if let Some(f) = &query.filter {
256 let simplified = f.clone().simplify();
257 Self::apply_filter(&mut rows, &simplified);
258
259 self.debug_log(format!(
260 "🔎 Applied filter -> {} entities remaining",
261 rows.len()
262 ));
263 }
264
265 if let Some(sort) = &query.sort
267 && rows.len() > 1
268 {
269 Self::apply_sort(&mut rows, sort);
270 self.debug_log("↕️ Applied sort expression");
271 }
272
273 if let Some(lim) = &query.limit
275 && !pre_paginated
276 {
277 apply_pagination(&mut rows, lim.offset, lim.limit);
278 self.debug_log(format!(
279 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
280 lim.offset,
281 lim.limit,
282 rows.len()
283 ));
284 }
285
286 set_rows_from_len(&mut span, rows.len());
287 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
288
289 Ok(Response(rows))
290 }
291
292 pub fn count(&self, query: LoadQuery) -> Result<u32, Error> {
294 Ok(self.execute(query)?.count())
295 }
296
297 pub fn count_all(&self) -> Result<u32, Error> {
298 self.count(LoadQuery::new())
299 }
300
301 pub fn group_count_by<K, F>(
310 &self,
311 query: LoadQuery,
312 key_fn: F,
313 ) -> Result<HashMap<K, u32>, Error>
314 where
315 K: Eq + Hash,
316 F: Fn(&E) -> K,
317 {
318 let entities = self.execute(query)?.entities();
319
320 let mut counts = HashMap::new();
321 for e in entities {
322 *counts.entry(key_fn(&e)).or_insert(0) += 1;
323 }
324
325 Ok(counts)
326 }
327
328 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
334 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
335 }
336
337 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
339 rows.sort_by(|(_, ea), (_, eb)| {
340 for (field, direction) in sort_expr.iter() {
341 let va = ea.get_value(field);
342 let vb = eb.get_value(field);
343
344 let ordering = match (va, vb) {
346 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
350 Some(ord) => ord,
351 None => continue, },
353 };
354
355 let ordering = match direction {
357 Order::Asc => ordering,
358 Order::Desc => ordering.reverse(),
359 };
360
361 if ordering != Ordering::Equal {
362 return ordering;
363 }
364 }
365
366 Ordering::Equal
368 });
369 }
370}
371
/// Keeps only the window of `rows` that starts at `offset` and holds at most
/// `limit` elements (`None` = no upper bound), preserving order.
///
/// An offset past the end yields an empty vector. The window end uses
/// saturating arithmetic: on 32-bit `usize` targets (e.g. wasm32), a plain
/// `offset + limit` could overflow for large limits such as `u32::MAX`, which
/// panics in debug builds or wraps and wrongly empties the page in release.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = usize::try_from(offset).map_or(total, |o| o.min(total));
    let end = limit.map_or(total, |l| {
        let len = usize::try_from(l).unwrap_or(usize::MAX);
        start.saturating_add(len).min(total)
    });

    // `start <= end <= total` holds here. Truncating first means the drain only
    // has to shift the elements that are kept.
    rows.truncate(end);
    rows.drain(..start);
}
385
// Unit tests for the in-memory helpers `apply_pagination` and
// `LoadExecutor::apply_sort`; they operate on plain vectors, so no database is needed.
#[cfg(test)]
mod tests {
    use super::{LoadExecutor, apply_pagination};
    use crate::{
        IndexSpec, Key, Value,
        db::primitives::{Order, SortExpr},
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};

    // Entity fixture with two integer sort keys and an optional blob, used to
    // exercise multi-field ordering and missing-value ordering in `apply_sort`.
    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct SortableEntity {
        id: u64,
        primary: i32,
        secondary: i32,
        optional_blob: Option<Vec<u8>>,
    }

    impl SortableEntity {
        fn new(id: u64, primary: i32, secondary: i32, optional_blob: Option<Vec<u8>>) -> Self {
            Self {
                id,
                primary,
                secondary,
                optional_blob,
            }
        }
    }

    // Marker canister/store types that satisfy `EntityKind`'s associated types.
    struct SortableCanister;
    struct SortableStore;

    impl Path for SortableCanister {
        const PATH: &'static str = "test::canister";
    }

    impl CanisterKind for SortableCanister {}

    impl Path for SortableStore {
        const PATH: &'static str = "test::store";
    }

    impl StoreKind for SortableStore {
        type Canister = SortableCanister;
    }

    impl Path for SortableEntity {
        const PATH: &'static str = "test::sortable";
    }

    // Identity view: the fixture is its own view type.
    impl View for SortableEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    // Empty impls relying on the traits' default methods.
    impl SanitizeAuto for SortableEntity {}
    impl SanitizeCustom for SortableEntity {}
    impl ValidateAuto for SortableEntity {}
    impl ValidateCustom for SortableEntity {}
    impl Visitable for SortableEntity {}

    // Field lookup used by `apply_sort`. An unset `optional_blob` yields `None`,
    // which is how the "missing value" ordering is exercised.
    impl FieldValues for SortableEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            match field {
                "id" => Some(Value::Uint(self.id)),
                "primary" => Some(Value::Int(i64::from(self.primary))),
                "secondary" => Some(Value::Int(i64::from(self.secondary))),
                "optional_blob" => self.optional_blob.clone().map(Value::Blob),
                _ => None,
            }
        }
    }

    // Test-only entity description: keyed by `id`, no secondary indexes.
    impl EntityKind for SortableEntity {
        type PrimaryKey = u64;
        type Store = SortableStore;
        type Canister = SortableCanister;

        const ENTITY_ID: u64 = 99;
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id", "primary", "secondary", "optional_blob"];
        const INDEXES: &'static [&'static IndexSpec] = &[];

        fn key(&self) -> Key {
            self.id.into()
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }
    }

    #[test]
    fn pagination_empty_vec() {
        let mut v: Vec<i32> = vec![];
        apply_pagination(&mut v, 0, Some(10));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_offset_beyond_len_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 10, Some(5));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_no_limit_from_offset() {
        let mut v = vec![1, 2, 3, 4, 5];
        apply_pagination(&mut v, 2, None);
        assert_eq!(v, vec![3, 4, 5]);
    }

    #[test]
    fn pagination_exact_window() {
        let mut v = vec![10, 20, 30, 40, 50];
        apply_pagination(&mut v, 1, Some(3));
        // Skip one element, then take three.
        assert_eq!(v, vec![20, 30, 40]);
    }

    #[test]
    fn pagination_limit_exceeds_tail() {
        let mut v = vec![10, 20, 30];
        apply_pagination(&mut v, 1, Some(999));
        // The window is clamped to the end of the vector.
        assert_eq!(v, vec![20, 30]);
    }

    #[test]
    fn apply_sort_orders_descending() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 10, 1, None)),
            (Key::from(2_u64), SortableEntity::new(2, 30, 2, None)),
            (Key::from(3_u64), SortableEntity::new(3, 20, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![("primary".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let primary: Vec<i32> = rows.iter().map(|(_, e)| e.primary).collect();
        assert_eq!(primary, vec![30, 20, 10]);
    }

    // Rows 1 and 2 tie on `primary`, so `secondary` (descending) decides between them.
    #[test]
    fn apply_sort_uses_secondary_field_for_ties() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 1, 5, None)),
            (Key::from(2_u64), SortableEntity::new(2, 1, 8, None)),
            (Key::from(3_u64), SortableEntity::new(3, 2, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![
            ("primary".to_string(), Order::Asc),
            ("secondary".to_string(), Order::Desc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![2, 1, 3]);
    }

    // The entity with no blob sorts first in ascending order. The two blob
    // values are expected to order [2] before [3, 4]; `id` is the fallback key.
    #[test]
    fn apply_sort_places_none_before_some_and_falls_back() {
        let mut rows = vec![
            (
                Key::from(3_u64),
                SortableEntity::new(3, 0, 0, Some(vec![3, 4])),
            ),
            (Key::from(1_u64), SortableEntity::new(1, 0, 0, None)),
            (
                Key::from(2_u64),
                SortableEntity::new(2, 0, 0, Some(vec![2])),
            ),
        ];
        let sort_expr = SortExpr::from(vec![
            ("optional_blob".to_string(), Order::Asc),
            ("id".to_string(), Order::Asc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }
}