1use crate::{
2 Error, Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::{Response, ResponseError},
12 store::DataRow,
13 },
14 obs::metrics,
15 traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData};
18
19#[derive(Clone, Copy)]
24pub struct LoadExecutor<E: EntityKind> {
25 db: Db<E::Canister>,
26 debug: bool,
27 _marker: PhantomData<E>,
28}
29
30impl<E: EntityKind> LoadExecutor<E> {
31 #[must_use]
36 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
37 Self {
38 db,
39 debug,
40 _marker: PhantomData,
41 }
42 }
43
44 fn debug_log(&self, s: impl Into<String>) {
45 if self.debug {
46 println!("{}", s.into());
47 }
48 }
49
50 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
56 self.execute(LoadQuery::new().one::<E>(value))
57 }
58
59 pub fn only(&self) -> Result<Response<E>, Error> {
61 self.execute(LoadQuery::new().one::<E>(()))
62 }
63
64 pub fn many(
66 &self,
67 values: impl IntoIterator<Item = impl FieldValue>,
68 ) -> Result<Response<E>, Error> {
69 self.execute(LoadQuery::new().many::<E>(values))
70 }
71
72 pub fn all(&self) -> Result<Response<E>, Error> {
74 self.execute(LoadQuery::new())
75 }
76
77 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, Error>
79 where
80 F: FnOnce(FilterDsl) -> I,
81 I: IntoFilterExpr,
82 {
83 self.execute(LoadQuery::new().filter(f))
84 }
85
86 pub fn require_one(&self, query: LoadQuery) -> Result<(), Error> {
92 self.execute(query)?.require_one()
93 }
94
95 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), Error> {
97 self.require_one(LoadQuery::new().one::<E>(value))
98 }
99
100 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), Error>
102 where
103 F: FnOnce(FilterDsl) -> I,
104 I: IntoFilterExpr,
105 {
106 self.require_one(LoadQuery::new().filter(f))
107 }
108
109 pub fn exists(&self, query: LoadQuery) -> Result<bool, Error> {
115 let query = query.limit_1();
116 Ok(!self.execute_raw(&query)?.is_empty())
117 }
118
119 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, Error> {
121 self.exists(LoadQuery::new().one::<E>(value))
122 }
123
124 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, Error>
126 where
127 F: FnOnce(FilterDsl) -> I,
128 I: IntoFilterExpr,
129 {
130 self.exists(LoadQuery::new().filter(f))
131 }
132
133 pub fn exists_any(&self) -> Result<bool, Error> {
135 self.exists(LoadQuery::new())
136 }
137
138 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), Error> {
144 if self.exists_one(value)? {
145 Ok(())
146 } else {
147 Err(ResponseError::NotFound { entity: E::PATH }.into())
148 }
149 }
150
151 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), Error>
153 where
154 F: FnOnce(FilterDsl) -> I,
155 I: IntoFilterExpr,
156 {
157 if self.exists_filter(f)? {
158 Ok(())
159 } else {
160 Err(ResponseError::NotFound { entity: E::PATH }.into())
161 }
162 }
163
164 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
170 QueryValidate::<E>::validate(&query)?;
171
172 Ok(plan_for::<E>(query.filter.as_ref()))
173 }
174
175 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
176 QueryValidate::<E>::validate(query)?;
177
178 let ctx = self.db.context::<E>();
179 let plan = plan_for::<E>(query.filter.as_ref());
180
181 if let Some(lim) = &query.limit {
182 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
183 } else {
184 Ok(ctx.rows_from_plan(plan)?)
185 }
186 }
187
188 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
190 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
191 QueryValidate::<E>::validate(&query)?;
192
193 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
194
195 let ctx = self.db.context::<E>();
196 let plan = plan_for::<E>(query.filter.as_ref());
197
198 self.debug_log(format!("📄 Query plan: {plan:?}"));
199
200 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
202 let mut rows: Vec<(Key, E)> = if pre_paginated {
203 let data_rows = self.execute_raw(&query)?;
204
205 self.debug_log(format!(
206 "📦 Scanned {} data rows before deserialization",
207 data_rows.len()
208 ));
209
210 let rows = ctx.deserialize_rows(data_rows)?;
211 self.debug_log(format!(
212 "🧩 Deserialized {} entities before filtering",
213 rows.len()
214 ));
215 rows
216 } else {
217 let data_rows = ctx.rows_from_plan(plan)?;
218 self.debug_log(format!(
219 "📦 Scanned {} data rows before deserialization",
220 data_rows.len()
221 ));
222
223 let rows = ctx.deserialize_rows(data_rows)?;
224 self.debug_log(format!(
225 "🧩 Deserialized {} entities before filtering",
226 rows.len()
227 ));
228
229 rows
230 };
231
232 if let Some(f) = &query.filter {
234 let simplified = f.clone().simplify();
235 Self::apply_filter(&mut rows, &simplified);
236
237 self.debug_log(format!(
238 "🔎 Applied filter -> {} entities remaining",
239 rows.len()
240 ));
241 }
242
243 if let Some(sort) = &query.sort
245 && rows.len() > 1
246 {
247 Self::apply_sort(&mut rows, sort);
248 self.debug_log("↕️ Applied sort expression");
249 }
250
251 if let Some(lim) = &query.limit
253 && !pre_paginated
254 {
255 apply_pagination(&mut rows, lim.offset, lim.limit);
256 self.debug_log(format!(
257 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
258 lim.offset,
259 lim.limit,
260 rows.len()
261 ));
262 }
263
264 set_rows_from_len(&mut span, rows.len());
265 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
266
267 Ok(Response(rows))
268 }
269
270 pub fn count(&self, query: LoadQuery) -> Result<u32, Error> {
272 Ok(self.execute(query)?.count())
273 }
274
275 pub fn count_all(&self) -> Result<u32, Error> {
276 self.count(LoadQuery::new())
277 }
278
279 pub fn group_count_by<K, F>(
288 &self,
289 query: LoadQuery,
290 key_fn: F,
291 ) -> Result<HashMap<K, u32>, Error>
292 where
293 K: Eq + Hash,
294 F: Fn(&E) -> K,
295 {
296 let entities = self.execute(query)?.entities();
297
298 let mut counts = HashMap::new();
299 for e in entities {
300 *counts.entry(key_fn(&e)).or_insert(0) += 1;
301 }
302
303 Ok(counts)
304 }
305
306 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
312 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
313 }
314
315 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
317 rows.sort_by(|(_, ea), (_, eb)| {
318 for (field, direction) in sort_expr.iter() {
319 let va = ea.get_value(field);
320 let vb = eb.get_value(field);
321
322 let ordering = match (va, vb) {
324 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
328 Some(ord) => ord,
329 None => continue, },
331 };
332
333 let ordering = match direction {
335 Order::Asc => ordering,
336 Order::Desc => ordering.reverse(),
337 };
338
339 if ordering != Ordering::Equal {
340 return ordering;
341 }
342 }
343
344 Ordering::Equal
346 });
347 }
348}
349
/// Keeps only the window `[offset, offset + limit)` of `rows`, in place.
///
/// `limit == None` keeps everything from `offset` to the end. An `offset` past the end
/// yields an empty vector. The window end saturates instead of overflowing, which matters
/// on 32-bit targets (e.g. wasm32) where `offset + limit` can exceed `usize::MAX`.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = usize::try_from(offset).unwrap_or(usize::MAX).min(total);
    let end = limit.map_or(total, |l| {
        start
            .saturating_add(usize::try_from(l).unwrap_or(usize::MAX))
            .min(total)
    });

    // Drop the tail first so `drain` does not shift elements that would be discarded.
    // `start <= end` always holds here, so this leaves exactly `end - start` elements.
    rows.truncate(end);
    rows.drain(..start);
}
363
#[cfg(test)]
mod tests {
    use super::{LoadExecutor, apply_pagination};
    use crate::{
        IndexSpec, Key, Value,
        db::primitives::{Order, SortExpr},
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};

    // ------------------------------------------------------------------
    // Fixture entity: two integer sort keys plus an optional blob field.
    // ------------------------------------------------------------------

    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct SortableEntity {
        id: u64,
        primary: i32,
        secondary: i32,
        optional_blob: Option<Vec<u8>>,
    }

    impl SortableEntity {
        fn new(id: u64, primary: i32, secondary: i32, optional_blob: Option<Vec<u8>>) -> Self {
            Self { id, primary, secondary, optional_blob }
        }
    }

    struct SortableCanister;
    struct SortableStore;

    impl Path for SortableCanister {
        const PATH: &'static str = "test::canister";
    }

    impl CanisterKind for SortableCanister {}

    impl Path for SortableStore {
        const PATH: &'static str = "test::store";
    }

    impl StoreKind for SortableStore {
        type Canister = SortableCanister;
    }

    impl Path for SortableEntity {
        const PATH: &'static str = "test::sortable";
    }

    impl View for SortableEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    impl SanitizeAuto for SortableEntity {}
    impl SanitizeCustom for SortableEntity {}
    impl ValidateAuto for SortableEntity {}
    impl ValidateCustom for SortableEntity {}
    impl Visitable for SortableEntity {}

    impl FieldValues for SortableEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            let value = match field {
                "id" => Value::Uint(self.id),
                "primary" => Value::Int(i64::from(self.primary)),
                "secondary" => Value::Int(i64::from(self.secondary)),
                // Absent blob maps to "no value" so sorting exercises the None branch.
                "optional_blob" => return self.optional_blob.clone().map(Value::Blob),
                _ => return None,
            };
            Some(value)
        }
    }

    impl EntityKind for SortableEntity {
        type PrimaryKey = u64;
        type Store = SortableStore;
        type Canister = SortableCanister;

        const ENTITY_ID: u64 = 99;
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id", "primary", "secondary", "optional_blob"];
        const INDEXES: &'static [&'static IndexSpec] = &[];

        fn key(&self) -> Key {
            Key::from(self.id)
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }
    }

    // ------------------------------------------------------------------
    // Helpers
    // ------------------------------------------------------------------

    /// Pairs an entity with a key derived from its id, the row shape `apply_sort` expects.
    fn keyed(entity: SortableEntity) -> (Key, SortableEntity) {
        (Key::from(entity.id), entity)
    }

    /// Runs `apply_pagination` on an owned vector and hands the result back.
    fn paginated(mut values: Vec<i32>, offset: u32, limit: Option<u32>) -> Vec<i32> {
        apply_pagination(&mut values, offset, limit);
        values
    }

    // ------------------------------------------------------------------
    // Pagination
    // ------------------------------------------------------------------

    #[test]
    fn pagination_empty_vec() {
        assert!(paginated(Vec::new(), 0, Some(10)).is_empty());
    }

    #[test]
    fn pagination_offset_beyond_len_clears() {
        assert!(paginated(vec![1, 2, 3], 10, Some(5)).is_empty());
    }

    #[test]
    fn pagination_no_limit_from_offset() {
        assert_eq!(paginated(vec![1, 2, 3, 4, 5], 2, None), [3, 4, 5]);
    }

    #[test]
    fn pagination_exact_window() {
        assert_eq!(paginated(vec![10, 20, 30, 40, 50], 1, Some(3)), [20, 30, 40]);
    }

    #[test]
    fn pagination_limit_exceeds_tail() {
        assert_eq!(paginated(vec![10, 20, 30], 1, Some(999)), [20, 30]);
    }

    // ------------------------------------------------------------------
    // Sorting
    // ------------------------------------------------------------------

    #[test]
    fn apply_sort_orders_descending() {
        let mut rows = vec![
            keyed(SortableEntity::new(1, 10, 1, None)),
            keyed(SortableEntity::new(2, 30, 2, None)),
            keyed(SortableEntity::new(3, 20, 3, None)),
        ];
        let by_primary_desc = SortExpr::from(vec![("primary".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(&mut rows, &by_primary_desc);

        let primaries = rows.iter().map(|(_, entity)| entity.primary).collect::<Vec<i32>>();
        assert_eq!(primaries, [30, 20, 10]);
    }

    #[test]
    fn apply_sort_uses_secondary_field_for_ties() {
        let mut rows = vec![
            keyed(SortableEntity::new(1, 1, 5, None)),
            keyed(SortableEntity::new(2, 1, 8, None)),
            keyed(SortableEntity::new(3, 2, 3, None)),
        ];
        let primary_then_secondary = SortExpr::from(vec![
            ("primary".to_string(), Order::Asc),
            ("secondary".to_string(), Order::Desc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(&mut rows, &primary_then_secondary);

        let ids = rows.iter().map(|(_, entity)| entity.id).collect::<Vec<u64>>();
        assert_eq!(ids, [2, 1, 3]);
    }

    #[test]
    fn apply_sort_places_none_before_some_and_falls_back() {
        let mut rows = vec![
            keyed(SortableEntity::new(3, 0, 0, Some(vec![3, 4]))),
            keyed(SortableEntity::new(1, 0, 0, None)),
            keyed(SortableEntity::new(2, 0, 0, Some(vec![2]))),
        ];
        let blob_then_id = SortExpr::from(vec![
            ("optional_blob".to_string(), Order::Asc),
            ("id".to_string(), Order::Asc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(&mut rows, &blob_then_id);

        let ids = rows.iter().map(|(_, entity)| entity.id).collect::<Vec<u64>>();
        assert_eq!(ids, [1, 2, 3]);
    }
}