1use crate::{
2 Error, Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, scan_plan, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::{Response, ResponseError},
12 store::DataRow,
13 },
14 obs::metrics,
15 traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
18
/// Read-side query executor for entity type `E`.
///
/// Holds a handle to the canister's `Db` plus a `debug` flag that turns on
/// diagnostic printing during query execution.
///
/// NOTE(review): `#[derive(Clone, Copy)]` adds `E: Clone` / `E: Copy` bounds to
/// the generated impls, so the executor is only `Copy` when `E` itself is.
/// Manual impls would lift that restriction if it is unintended.
#[derive(Clone, Copy)]
pub struct LoadExecutor<E: EntityKind> {
    // Database handle that all queries are run against.
    db: Db<E::Canister>,
    // When `true`, execution prints progress lines to stdout.
    debug: bool,
    // Ties the executor to `E` without storing an `E` value.
    _marker: PhantomData<E>,
}
29
30impl<E: EntityKind> LoadExecutor<E> {
31 #[must_use]
36 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
37 Self {
38 db,
39 debug,
40 _marker: PhantomData,
41 }
42 }
43
44 fn debug_log(&self, s: impl Into<String>) {
45 if self.debug {
46 println!("{}", s.into());
47 }
48 }
49
50 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
56 self.execute(LoadQuery::new().one::<E>(value))
57 }
58
59 pub fn only(&self) -> Result<Response<E>, Error> {
61 self.execute(LoadQuery::new().one::<E>(()))
62 }
63
64 pub fn many<I, V>(&self, values: I) -> Result<Response<E>, Error>
66 where
67 I: IntoIterator<Item = V>,
68 V: FieldValue,
69 {
70 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
71 self.execute(query)
72 }
73
74 pub fn all(&self) -> Result<Response<E>, Error> {
76 self.execute(LoadQuery::new())
77 }
78
79 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, Error>
81 where
82 F: FnOnce(FilterDsl) -> I,
83 I: IntoFilterExpr,
84 {
85 self.execute(LoadQuery::new().filter(f))
86 }
87
88 pub fn require_one(&self, query: LoadQuery) -> Result<(), Error> {
94 self.execute(query)?.require_one()
95 }
96
97 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), Error> {
99 self.require_one(LoadQuery::new().one::<E>(value))
100 }
101
102 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), Error>
104 where
105 F: FnOnce(FilterDsl) -> I,
106 I: IntoFilterExpr,
107 {
108 self.require_one(LoadQuery::new().filter(f))
109 }
110
111 pub fn exists(&self, query: LoadQuery) -> Result<bool, Error> {
117 let query = query.limit_1();
118 QueryValidate::<E>::validate(&query)?;
119
120 let plan = plan_for::<E>(query.filter.as_ref());
121 let filter = query.filter.map(FilterExpr::simplify);
122 let mut found = false;
123
124 scan_plan::<E, _>(&self.db, plan, |_, entity| {
125 let matches = filter
126 .as_ref()
127 .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));
128
129 if matches {
130 found = true;
131 ControlFlow::Break(())
132 } else {
133 ControlFlow::Continue(())
134 }
135 })?;
136
137 Ok(found)
138 }
139
140 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, Error> {
142 self.exists(LoadQuery::new().one::<E>(value))
143 }
144
145 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, Error>
147 where
148 F: FnOnce(FilterDsl) -> I,
149 I: IntoFilterExpr,
150 {
151 self.exists(LoadQuery::new().filter(f))
152 }
153
154 pub fn exists_any(&self) -> Result<bool, Error> {
156 self.exists(LoadQuery::new())
157 }
158
159 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), Error> {
165 if self.exists_one(value)? {
166 Ok(())
167 } else {
168 Err(ResponseError::NotFound { entity: E::PATH }.into())
169 }
170 }
171
172 #[allow(clippy::cast_possible_truncation)]
174 pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), Error>
175 where
176 I: IntoIterator<Item = V>,
177 V: FieldValue,
178 {
179 let pks: Vec<_> = values.into_iter().collect();
180
181 let expected = pks.len() as u32;
182 if expected == 0 {
183 return Ok(());
184 }
185
186 let res = self.many(pks)?;
187 res.require_len(expected)?;
188
189 Ok(())
190 }
191
192 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), Error>
194 where
195 F: FnOnce(FilterDsl) -> I,
196 I: IntoFilterExpr,
197 {
198 if self.exists_filter(f)? {
199 Ok(())
200 } else {
201 Err(ResponseError::NotFound { entity: E::PATH }.into())
202 }
203 }
204
205 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
211 QueryValidate::<E>::validate(&query)?;
212
213 Ok(plan_for::<E>(query.filter.as_ref()))
214 }
215
216 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
217 QueryValidate::<E>::validate(query)?;
218
219 let ctx = self.db.context::<E>();
220 let plan = plan_for::<E>(query.filter.as_ref());
221
222 if let Some(lim) = &query.limit {
223 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
224 } else {
225 Ok(ctx.rows_from_plan(plan)?)
226 }
227 }
228
229 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
231 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
232 QueryValidate::<E>::validate(&query)?;
233
234 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
235
236 let ctx = self.db.context::<E>();
237 let plan = plan_for::<E>(query.filter.as_ref());
238
239 self.debug_log(format!("📄 Query plan: {plan:?}"));
240
241 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
243 let mut rows: Vec<(Key, E)> = if pre_paginated {
244 let data_rows = self.execute_raw(&query)?;
245
246 self.debug_log(format!(
247 "📦 Scanned {} data rows before deserialization",
248 data_rows.len()
249 ));
250
251 let rows = ctx.deserialize_rows(data_rows)?;
252 self.debug_log(format!(
253 "🧩 Deserialized {} entities before filtering",
254 rows.len()
255 ));
256 rows
257 } else {
258 let data_rows = ctx.rows_from_plan(plan)?;
259 self.debug_log(format!(
260 "📦 Scanned {} data rows before deserialization",
261 data_rows.len()
262 ));
263
264 let rows = ctx.deserialize_rows(data_rows)?;
265 self.debug_log(format!(
266 "🧩 Deserialized {} entities before filtering",
267 rows.len()
268 ));
269
270 rows
271 };
272
273 if let Some(f) = &query.filter {
275 let simplified = f.clone().simplify();
276 Self::apply_filter(&mut rows, &simplified);
277
278 self.debug_log(format!(
279 "🔎 Applied filter -> {} entities remaining",
280 rows.len()
281 ));
282 }
283
284 if let Some(sort) = &query.sort
286 && rows.len() > 1
287 {
288 Self::apply_sort(&mut rows, sort);
289 self.debug_log("↕️ Applied sort expression");
290 }
291
292 if let Some(lim) = &query.limit
294 && !pre_paginated
295 {
296 apply_pagination(&mut rows, lim.offset, lim.limit);
297 self.debug_log(format!(
298 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
299 lim.offset,
300 lim.limit,
301 rows.len()
302 ));
303 }
304
305 set_rows_from_len(&mut span, rows.len());
306 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
307
308 Ok(Response(rows))
309 }
310
311 pub fn count(&self, query: LoadQuery) -> Result<u32, Error> {
313 Ok(self.execute(query)?.count())
314 }
315
316 pub fn count_all(&self) -> Result<u32, Error> {
317 self.count(LoadQuery::new())
318 }
319
320 pub fn group_count_by<K, F>(
329 &self,
330 query: LoadQuery,
331 key_fn: F,
332 ) -> Result<HashMap<K, u32>, Error>
333 where
334 K: Eq + Hash,
335 F: Fn(&E) -> K,
336 {
337 let entities = self.execute(query)?.entities();
338
339 let mut counts = HashMap::new();
340 for e in entities {
341 *counts.entry(key_fn(&e)).or_insert(0) += 1;
342 }
343
344 Ok(counts)
345 }
346
347 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
353 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
354 }
355
356 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
358 rows.sort_by(|(_, ea), (_, eb)| {
359 for (field, direction) in sort_expr.iter() {
360 let va = ea.get_value(field);
361 let vb = eb.get_value(field);
362
363 let ordering = match (va, vb) {
365 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
369 Some(ord) => ord,
370 None => continue, },
372 };
373
374 let ordering = match direction {
376 Order::Asc => ordering,
377 Order::Desc => ordering.reverse(),
378 };
379
380 if ordering != Ordering::Equal {
381 return ordering;
382 }
383 }
384
385 Ordering::Equal
387 });
388 }
389}
390
/// Keeps only the window `[offset, offset + limit)` of `rows`, in place.
///
/// - An `offset` at or past the end leaves `rows` empty.
/// - `limit: None` keeps everything from `offset` onward.
/// - A `limit` reaching past the end is clamped to the end.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();

    // Avoid raw `as` + `+`. On 32-bit targets (e.g. wasm32) `usize` is 32
    // bits, so `start + limit` could overflow. In release builds it wraps to a
    // tiny `end` and silently drops rows; in debug builds it panics.
    let start = usize::try_from(offset).map_or(total, |o| o.min(total));
    let end = limit.map_or(total, |l| {
        usize::try_from(l).map_or(total, |l| start.saturating_add(l).min(total))
    });

    // Here `start <= end <= total` always holds. Truncating first and then
    // dropping the prefix leaves exactly `end - start` elements.
    rows.truncate(end);
    rows.drain(..start);
}
404
// Unit tests for the in-memory helpers `apply_pagination` and
// `LoadExecutor::apply_sort`. They use a hand-written test entity so no store
// is needed.
#[cfg(test)]
mod tests {
    use super::{LoadExecutor, apply_pagination};
    use crate::{
        IndexSpec, Key, Value,
        db::primitives::{Order, SortExpr},
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};

    // Minimal entity with:
    // - two integer sort keys (`primary`, `secondary`);
    // - an optional blob, to exercise missing-value ordering.
    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct SortableEntity {
        id: u64,
        primary: i32,
        secondary: i32,
        optional_blob: Option<Vec<u8>>,
    }

    impl SortableEntity {
        // Convenience constructor so the test rows stay one-liners.
        fn new(id: u64, primary: i32, secondary: i32, optional_blob: Option<Vec<u8>>) -> Self {
            Self {
                id,
                primary,
                secondary,
                optional_blob,
            }
        }
    }

    // Marker types that satisfy `EntityKind`'s associated store/canister types.
    struct SortableCanister;
    struct SortableStore;

    impl Path for SortableCanister {
        const PATH: &'static str = "test::canister";
    }

    impl CanisterKind for SortableCanister {}

    impl Path for SortableStore {
        const PATH: &'static str = "test::store";
    }

    impl StoreKind for SortableStore {
        type Canister = SortableCanister;
    }

    impl Path for SortableEntity {
        const PATH: &'static str = "test::sortable";
    }

    // Identity view: the entity is its own view type.
    impl View for SortableEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    // Empty impls: the default trait behaviour is enough for these tests.
    impl SanitizeAuto for SortableEntity {}
    impl SanitizeCustom for SortableEntity {}
    impl ValidateAuto for SortableEntity {}
    impl ValidateCustom for SortableEntity {}
    impl Visitable for SortableEntity {}

    // Field lookup used by `apply_sort`. `optional_blob` yields `None` when
    // unset, and unknown fields yield `None`.
    impl FieldValues for SortableEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            match field {
                "id" => Some(Value::Uint(self.id)),
                "primary" => Some(Value::Int(i64::from(self.primary))),
                "secondary" => Some(Value::Int(i64::from(self.secondary))),
                "optional_blob" => self.optional_blob.clone().map(Value::Blob),
                _ => None,
            }
        }
    }

    impl EntityKind for SortableEntity {
        type PrimaryKey = u64;
        type Store = SortableStore;
        type Canister = SortableCanister;

        const ENTITY_ID: u64 = 99;
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id", "primary", "secondary", "optional_blob"];
        const INDEXES: &'static [&'static IndexSpec] = &[];

        fn key(&self) -> Key {
            self.id.into()
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = key;
        }
    }

    // Paginating an empty vector is a no-op.
    #[test]
    fn pagination_empty_vec() {
        let mut v: Vec<i32> = vec![];
        apply_pagination(&mut v, 0, Some(10));
        assert!(v.is_empty());
    }

    // An offset past the end yields an empty result.
    #[test]
    fn pagination_offset_beyond_len_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 10, Some(5));
        assert!(v.is_empty());
    }

    // With no limit, everything from `offset` onward is kept.
    #[test]
    fn pagination_no_limit_from_offset() {
        let mut v = vec![1, 2, 3, 4, 5];
        apply_pagination(&mut v, 2, None);
        assert_eq!(v, vec![3, 4, 5]);
    }

    // offset=1, limit=3 selects the 2nd through 4th elements.
    #[test]
    fn pagination_exact_window() {
        let mut v = vec![10, 20, 30, 40, 50];
        apply_pagination(&mut v, 1, Some(3));
        assert_eq!(v, vec![20, 30, 40]);
    }

    // A limit larger than the remaining tail is clamped to the end.
    #[test]
    fn pagination_limit_exceeds_tail() {
        let mut v = vec![10, 20, 30];
        apply_pagination(&mut v, 1, Some(999));
        assert_eq!(v, vec![20, 30]);
    }

    // A single descending key reverses the numeric order.
    #[test]
    fn apply_sort_orders_descending() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 10, 1, None)),
            (Key::from(2_u64), SortableEntity::new(2, 30, 2, None)),
            (Key::from(3_u64), SortableEntity::new(3, 20, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![("primary".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let primary: Vec<i32> = rows.iter().map(|(_, e)| e.primary).collect();
        assert_eq!(primary, vec![30, 20, 10]);
    }

    // Rows tied on `primary` (ids 1 and 2) are ordered by `secondary`
    // descending, so id 2 (secondary 8) precedes id 1 (secondary 5).
    #[test]
    fn apply_sort_uses_secondary_field_for_ties() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 1, 5, None)),
            (Key::from(2_u64), SortableEntity::new(2, 1, 8, None)),
            (Key::from(3_u64), SortableEntity::new(3, 2, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![
            ("primary".to_string(), Order::Asc),
            ("secondary".to_string(), Order::Desc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![2, 1, 3]);
    }

    // A missing blob sorts first under `Asc`. The two present blobs are
    // ordered by `Value::partial_cmp`; `[2]` is expected before `[3, 4]`
    // (presumably byte-lexicographic — confirm against `Value`).
    #[test]
    fn apply_sort_places_none_before_some_and_falls_back() {
        let mut rows = vec![
            (
                Key::from(3_u64),
                SortableEntity::new(3, 0, 0, Some(vec![3, 4])),
            ),
            (Key::from(1_u64), SortableEntity::new(1, 0, 0, None)),
            (
                Key::from(2_u64),
                SortableEntity::new(2, 0, 0, Some(vec![2])),
            ),
        ];
        let sort_expr = SortExpr::from(vec![
            ("optional_blob".to_string(), Order::Asc),
            ("id".to_string(), Order::Asc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }
}