1use crate::{
2 Error, Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::Response,
12 store::DataRow,
13 },
14 obs::metrics,
15 traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, marker::PhantomData};
18
19#[derive(Clone, Copy)]
24pub struct LoadExecutor<E: EntityKind> {
25 db: Db<E::Canister>,
26 debug: bool,
27 _marker: PhantomData<E>,
28}
29
30impl<E: EntityKind> LoadExecutor<E> {
31 #[must_use]
36 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
37 Self {
38 db,
39 debug,
40 _marker: PhantomData,
41 }
42 }
43
44 #[inline]
45 fn debug_log(&self, s: impl Into<String>) {
46 if self.debug {
47 println!("{}", s.into());
48 }
49 }
50
51 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
57 self.execute(LoadQuery::new().one::<E>(value))
58 }
59
60 pub fn only(&self) -> Result<Response<E>, Error> {
62 self.execute(LoadQuery::new().one::<E>(()))
63 }
64
65 pub fn many(
67 &self,
68 values: impl IntoIterator<Item = impl FieldValue>,
69 ) -> Result<Response<E>, Error> {
70 self.execute(LoadQuery::new().many::<E>(values))
71 }
72
73 pub fn all(&self) -> Result<Response<E>, Error> {
75 self.execute(LoadQuery::new())
76 }
77
78 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, Error>
80 where
81 F: FnOnce(FilterDsl) -> I,
82 I: IntoFilterExpr,
83 {
84 self.execute(LoadQuery::new().filter(f))
85 }
86
87 pub fn require_one(&self, query: LoadQuery) -> Result<(), Error> {
93 self.execute(query)?.require_one()
94 }
95
96 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), Error> {
98 self.require_one(LoadQuery::new().one::<E>(value))
99 }
100
101 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), Error>
103 where
104 F: FnOnce(FilterDsl) -> I,
105 I: IntoFilterExpr,
106 {
107 self.require_one(LoadQuery::new().filter(f))
108 }
109
110 pub fn exists(&self, query: LoadQuery) -> Result<bool, Error> {
116 let query = query.limit_1();
117 Ok(!self.execute_raw(&query)?.is_empty())
118 }
119
120 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, Error> {
122 self.exists(LoadQuery::new().one::<E>(value))
123 }
124
125 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, Error>
127 where
128 F: FnOnce(FilterDsl) -> I,
129 I: IntoFilterExpr,
130 {
131 self.exists(LoadQuery::new().filter(f))
132 }
133
134 pub fn exists_any(&self) -> Result<bool, Error> {
136 self.exists(LoadQuery::new())
137 }
138
139 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
145 QueryValidate::<E>::validate(&query)?;
146
147 Ok(plan_for::<E>(query.filter.as_ref()))
148 }
149
150 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
151 QueryValidate::<E>::validate(query)?;
152
153 let ctx = self.db.context::<E>();
154 let plan = plan_for::<E>(query.filter.as_ref());
155
156 if let Some(lim) = &query.limit {
157 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
158 } else {
159 Ok(ctx.rows_from_plan(plan)?)
160 }
161 }
162
163 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
165 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
166 QueryValidate::<E>::validate(&query)?;
167
168 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
169
170 let ctx = self.db.context::<E>();
171 let plan = plan_for::<E>(query.filter.as_ref());
172
173 self.debug_log(format!("📄 Query plan: {plan:?}"));
174
175 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
177 let mut rows: Vec<(Key, E)> = if pre_paginated {
178 let data_rows = self.execute_raw(&query)?;
179
180 self.debug_log(format!(
181 "📦 Scanned {} data rows before deserialization",
182 data_rows.len()
183 ));
184
185 let rows = ctx.deserialize_rows(data_rows)?;
186 self.debug_log(format!(
187 "🧩 Deserialized {} entities before filtering",
188 rows.len()
189 ));
190 rows
191 } else {
192 let data_rows = ctx.rows_from_plan(plan)?;
193 self.debug_log(format!(
194 "📦 Scanned {} data rows before deserialization",
195 data_rows.len()
196 ));
197
198 let rows = ctx.deserialize_rows(data_rows)?;
199 self.debug_log(format!(
200 "🧩 Deserialized {} entities before filtering",
201 rows.len()
202 ));
203
204 rows
205 };
206
207 if let Some(f) = &query.filter {
209 let simplified = f.clone().simplify();
210 Self::apply_filter(&mut rows, &simplified);
211
212 self.debug_log(format!(
213 "🔎 Applied filter -> {} entities remaining",
214 rows.len()
215 ));
216 }
217
218 if let Some(sort) = &query.sort
220 && rows.len() > 1
221 {
222 Self::apply_sort(&mut rows, sort);
223 self.debug_log("↕️ Applied sort expression");
224 }
225
226 if let Some(lim) = &query.limit
228 && !pre_paginated
229 {
230 apply_pagination(&mut rows, lim.offset, lim.limit);
231 self.debug_log(format!(
232 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
233 lim.offset,
234 lim.limit,
235 rows.len()
236 ));
237 }
238
239 set_rows_from_len(&mut span, rows.len());
240 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
241
242 Ok(Response(rows))
243 }
244
245 pub fn count(&self, query: LoadQuery) -> Result<u32, Error> {
247 Ok(self.execute(query)?.count())
248 }
249
250 pub fn count_all(&self) -> Result<u32, Error> {
251 self.count(LoadQuery::new())
252 }
253
254 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
256 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
257 }
258
259 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
261 rows.sort_by(|(_, ea), (_, eb)| {
262 for (field, direction) in sort_expr.iter() {
263 let va = ea.get_value(field);
264 let vb = eb.get_value(field);
265
266 let ordering = match (va, vb) {
268 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
272 Some(ord) => ord,
273 None => continue, },
275 };
276
277 let ordering = match direction {
279 Order::Asc => ordering,
280 Order::Desc => ordering.reverse(),
281 };
282
283 if ordering != Ordering::Equal {
284 return ordering;
285 }
286 }
287
288 Ordering::Equal
290 });
291 }
292}
293
/// Keeps only the window `[offset, offset + limit)` of `rows`, in place.
///
/// An `offset` at or past the end clears `rows`. `limit == None` keeps
/// everything from `offset` onward. Both window bounds are clamped to
/// `rows.len()`.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = (offset as usize).min(total);
    // `saturating_add`: on 32-bit targets (such as wasm32) `usize` is only as
    // wide as `u32`, so `start + limit` could overflow for very large limits.
    let end = limit.map_or(total, |l| start.saturating_add(l as usize).min(total));

    // Invariant: start <= end <= total. Truncating first means `drain` only
    // shifts the rows that are actually kept.
    rows.truncate(end);
    rows.drain(..start);
}
307
#[cfg(test)]
mod tests {
    use super::{LoadExecutor, apply_pagination};
    use crate::{
        IndexSpec, Key, Value,
        db::primitives::{Order, SortExpr},
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};

    /// Minimal entity used to exercise sorting on present, tied and missing
    /// field values.
    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct SortableEntity {
        id: u64,
        primary: i32,
        secondary: i32,
        optional_blob: Option<Vec<u8>>,
    }

    impl SortableEntity {
        fn new(id: u64, primary: i32, secondary: i32, optional_blob: Option<Vec<u8>>) -> Self {
            Self {
                id,
                primary,
                secondary,
                optional_blob,
            }
        }
    }

    struct SortableCanister;
    struct SortableStore;

    impl Path for SortableCanister {
        const PATH: &'static str = "test::canister";
    }

    impl CanisterKind for SortableCanister {}

    impl Path for SortableStore {
        const PATH: &'static str = "test::store";
    }

    impl StoreKind for SortableStore {
        type Canister = SortableCanister;
    }

    impl Path for SortableEntity {
        const PATH: &'static str = "test::sortable";
    }

    impl View for SortableEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    impl SanitizeAuto for SortableEntity {}
    impl SanitizeCustom for SortableEntity {}
    impl ValidateAuto for SortableEntity {}
    impl ValidateCustom for SortableEntity {}
    impl Visitable for SortableEntity {}

    impl FieldValues for SortableEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            match field {
                "id" => Some(Value::Uint(self.id)),
                "primary" => Some(Value::Int(i64::from(self.primary))),
                "secondary" => Some(Value::Int(i64::from(self.secondary))),
                "optional_blob" => self.optional_blob.clone().map(Value::Blob),
                _ => None,
            }
        }
    }

    impl EntityKind for SortableEntity {
        type PrimaryKey = u64;
        type Store = SortableStore;
        type Canister = SortableCanister;

        const ENTITY_ID: u64 = 99;
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id", "primary", "secondary", "optional_blob"];
        const INDEXES: &'static [&'static IndexSpec] = &[];

        fn key(&self) -> Key {
            self.id.into()
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }
    }

    #[test]
    fn pagination_empty_vec() {
        let mut v: Vec<i32> = vec![];
        apply_pagination(&mut v, 0, Some(10));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_offset_beyond_len_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 10, Some(5));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_no_limit_from_offset() {
        let mut v = vec![1, 2, 3, 4, 5];
        apply_pagination(&mut v, 2, None);
        assert_eq!(v, vec![3, 4, 5]);
    }

    #[test]
    fn pagination_exact_window() {
        let mut v = vec![10, 20, 30, 40, 50];
        apply_pagination(&mut v, 1, Some(3));
        assert_eq!(v, vec![20, 30, 40]);
    }

    #[test]
    fn pagination_limit_exceeds_tail() {
        let mut v = vec![10, 20, 30];
        apply_pagination(&mut v, 1, Some(999));
        assert_eq!(v, vec![20, 30]);
    }

    #[test]
    fn pagination_zero_limit_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 0, Some(0));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_max_limit_keeps_tail() {
        // Guards against `offset + limit` overflowing where `usize` is 32-bit.
        let mut v = vec![10, 20, 30];
        apply_pagination(&mut v, 1, Some(u32::MAX));
        assert_eq!(v, vec![20, 30]);
    }

    #[test]
    fn apply_sort_orders_descending() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 10, 1, None)),
            (Key::from(2_u64), SortableEntity::new(2, 30, 2, None)),
            (Key::from(3_u64), SortableEntity::new(3, 20, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![("primary".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let primary: Vec<i32> = rows.iter().map(|(_, e)| e.primary).collect();
        assert_eq!(primary, vec![30, 20, 10]);
    }

    #[test]
    fn apply_sort_uses_secondary_field_for_ties() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 1, 5, None)),
            (Key::from(2_u64), SortableEntity::new(2, 1, 8, None)),
            (Key::from(3_u64), SortableEntity::new(3, 2, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![
            ("primary".to_string(), Order::Asc),
            ("secondary".to_string(), Order::Desc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![2, 1, 3]);
    }

    #[test]
    fn apply_sort_places_none_before_some_and_falls_back() {
        let mut rows = vec![
            (
                Key::from(3_u64),
                SortableEntity::new(3, 0, 0, Some(vec![3, 4])),
            ),
            (Key::from(1_u64), SortableEntity::new(1, 0, 0, None)),
            (
                Key::from(2_u64),
                SortableEntity::new(2, 0, 0, Some(vec![2])),
            ),
        ];
        let sort_expr = SortExpr::from(vec![
            ("optional_blob".to_string(), Order::Asc),
            ("id".to_string(), Order::Asc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn apply_sort_descending_places_none_last() {
        // `Desc` reverses the whole per-field ordering, including the
        // "missing sorts first" rule, so `None` ends up at the back.
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 0, 0, Some(vec![1]))),
            (Key::from(2_u64), SortableEntity::new(2, 0, 0, None)),
            (Key::from(3_u64), SortableEntity::new(3, 0, 0, Some(vec![2]))),
        ];
        let sort_expr = SortExpr::from(vec![("optional_blob".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![3, 1, 2]);
    }
}
507}