1use crate::{
2 Error, Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::Response,
12 store::DataRow,
13 },
14 obs::metrics,
15 traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, marker::PhantomData};
18
19#[derive(Clone, Copy)]
24pub struct LoadExecutor<E: EntityKind> {
25 db: Db<E::Canister>,
26 debug: bool,
27 _marker: PhantomData<E>,
28}
29
30impl<E: EntityKind> LoadExecutor<E> {
31 #[must_use]
32 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
33 Self {
34 db,
35 debug,
36 _marker: PhantomData,
37 }
38 }
39
40 #[inline]
41 fn debug_log(&self, s: impl Into<String>) {
42 if self.debug {
43 println!("{}", s.into());
44 }
45 }
46
47 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
53 let query = LoadQuery::new().one::<E>(value);
54 self.execute(query)
55 }
56
57 pub fn only(&self) -> Result<Response<E>, Error> {
59 let query = LoadQuery::new().one::<E>(());
60 self.execute(query)
61 }
62
63 pub fn many(
65 &self,
66 values: impl IntoIterator<Item = impl FieldValue>,
67 ) -> Result<Response<E>, Error> {
68 let query = LoadQuery::new().many::<E>(values);
69 self.execute(query)
70 }
71
72 pub fn all(&self) -> Result<Response<E>, Error> {
74 let query = LoadQuery::new();
75 self.execute(query)
76 }
77
78 pub fn filter<F, I>(self, f: F) -> Result<Response<E>, Error>
80 where
81 F: FnOnce(FilterDsl) -> I,
82 I: IntoFilterExpr,
83 {
84 let query = LoadQuery::new().filter(f);
85 self.execute(query)
86 }
87
88 pub fn count_all_rows(self) -> Result<u32, Error> {
90 let query = LoadQuery::all();
91 self.count(query)
92 }
93
94 pub fn exists(self, query: LoadQuery) -> Result<bool, Error> {
100 let query = query.limit_1();
101
102 Ok(!self.execute_raw(&query)?.is_empty())
103 }
104
105 pub fn exists_one(self, value: impl FieldValue) -> Result<bool, Error> {
107 let query = LoadQuery::new().one::<E>(value);
108 self.exists(query)
109 }
110
111 pub fn exists_filter<F, I>(self, f: F) -> Result<bool, Error>
113 where
114 F: FnOnce(FilterDsl) -> I,
115 I: IntoFilterExpr,
116 {
117 let query = LoadQuery::new().filter(f);
118 self.exists(query)
119 }
120
121 pub fn exists_any(self) -> Result<bool, Error> {
123 self.exists(LoadQuery::new())
124 }
125
126 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
133 QueryValidate::<E>::validate(&query)?;
134
135 Ok(plan_for::<E>(query.filter.as_ref()))
136 }
137
138 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
139 QueryValidate::<E>::validate(query)?;
140
141 let ctx = self.db.context::<E>();
142 let plan = plan_for::<E>(query.filter.as_ref());
143
144 if let Some(lim) = &query.limit {
145 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
146 } else {
147 Ok(ctx.rows_from_plan(plan)?)
148 }
149 }
150
151 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
153 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
154 QueryValidate::<E>::validate(&query)?;
155
156 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
157
158 let ctx = self.db.context::<E>();
159 let plan = plan_for::<E>(query.filter.as_ref());
160
161 self.debug_log(format!("📄 Query plan: {plan:?}"));
162
163 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
165 let mut rows: Vec<(Key, E)> = if pre_paginated {
166 let data_rows = self.execute_raw(&query)?;
167
168 self.debug_log(format!(
169 "📦 Scanned {} data rows before deserialization",
170 data_rows.len()
171 ));
172
173 let rows = ctx.deserialize_rows(data_rows)?;
174 self.debug_log(format!(
175 "🧩 Deserialized {} entities before filtering",
176 rows.len()
177 ));
178 rows
179 } else {
180 let data_rows = ctx.rows_from_plan(plan)?;
181 self.debug_log(format!(
182 "📦 Scanned {} data rows before deserialization",
183 data_rows.len()
184 ));
185
186 let rows = ctx.deserialize_rows(data_rows)?;
187 self.debug_log(format!(
188 "🧩 Deserialized {} entities before filtering",
189 rows.len()
190 ));
191
192 rows
193 };
194
195 if let Some(f) = &query.filter {
197 let simplified = f.clone().simplify();
198 Self::apply_filter(&mut rows, &simplified);
199
200 self.debug_log(format!(
201 "🔎 Applied filter -> {} entities remaining",
202 rows.len()
203 ));
204 }
205
206 if let Some(sort) = &query.sort
208 && rows.len() > 1
209 {
210 Self::apply_sort(&mut rows, sort);
211 self.debug_log("↕️ Applied sort expression");
212 }
213
214 if let Some(lim) = &query.limit
216 && !pre_paginated
217 {
218 apply_pagination(&mut rows, lim.offset, lim.limit);
219 self.debug_log(format!(
220 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
221 lim.offset,
222 lim.limit,
223 rows.len()
224 ));
225 }
226
227 set_rows_from_len(&mut span, rows.len());
228 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
229
230 Ok(Response(rows))
231 }
232
233 #[allow(clippy::cast_possible_truncation)]
236 pub fn count(self, query: LoadQuery) -> Result<u32, Error> {
237 let count = self.execute(query)?.count();
238
239 Ok(count)
240 }
241
242 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
244 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
245 }
246
247 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
249 rows.sort_by(|(_, ea), (_, eb)| {
250 for (field, direction) in sort_expr.iter() {
251 let va = ea.get_value(field);
252 let vb = eb.get_value(field);
253
254 let ordering = match (va, vb) {
256 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
260 Some(ord) => ord,
261 None => continue, },
263 };
264
265 let ordering = match direction {
267 Order::Asc => ordering,
268 Order::Desc => ordering.reverse(),
269 };
270
271 if ordering != Ordering::Equal {
272 return ordering;
273 }
274 }
275
276 Ordering::Equal
278 });
279 }
280}
281
/// Keeps only the window `[offset, offset + limit)` of `rows`, in place.
///
/// `limit == None` keeps everything from `offset` to the end. An offset at or
/// past the end clears the vector. The window end uses saturating arithmetic,
/// so a huge `limit` cannot overflow `usize` on 32-bit targets such as wasm32.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = usize::try_from(offset).unwrap_or(usize::MAX).min(total);
    let end = limit.map_or(total, |l| {
        let l = usize::try_from(l).unwrap_or(usize::MAX);
        start.saturating_add(l).min(total)
    });

    // Invariant: start <= end <= total. Truncate first so the discarded tail
    // is never shifted by the drain.
    rows.truncate(end);
    rows.drain(..start);
}
295
#[cfg(test)]
mod tests {
    use super::{LoadExecutor, apply_pagination};
    use crate::{
        IndexSpec, Key, Value,
        db::primitives::{Order, SortExpr},
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};

    /// Minimal entity with two integer sort keys and an optional field, used
    /// to exercise `apply_sort` without a database.
    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct SortableEntity {
        id: u64,
        primary: i32,
        secondary: i32,
        optional_blob: Option<Vec<u8>>,
    }

    impl SortableEntity {
        fn new(id: u64, primary: i32, secondary: i32, optional_blob: Option<Vec<u8>>) -> Self {
            Self {
                id,
                primary,
                secondary,
                optional_blob,
            }
        }
    }

    struct SortableCanister;
    struct SortableStore;

    impl Path for SortableCanister {
        const PATH: &'static str = "test::canister";
    }

    impl CanisterKind for SortableCanister {}

    impl Path for SortableStore {
        const PATH: &'static str = "test::store";
    }

    impl StoreKind for SortableStore {
        type Canister = SortableCanister;
    }

    impl Path for SortableEntity {
        const PATH: &'static str = "test::sortable";
    }

    impl View for SortableEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    impl SanitizeAuto for SortableEntity {}
    impl SanitizeCustom for SortableEntity {}
    impl ValidateAuto for SortableEntity {}
    impl ValidateCustom for SortableEntity {}
    impl Visitable for SortableEntity {}

    impl FieldValues for SortableEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            match field {
                "id" => Some(Value::Uint(self.id)),
                "primary" => Some(Value::Int(i64::from(self.primary))),
                "secondary" => Some(Value::Int(i64::from(self.secondary))),
                "optional_blob" => self.optional_blob.clone().map(Value::Blob),
                _ => None,
            }
        }
    }

    impl EntityKind for SortableEntity {
        type PrimaryKey = u64;
        type Store = SortableStore;
        type Canister = SortableCanister;

        const ENTITY_ID: u64 = 99;
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id", "primary", "secondary", "optional_blob"];
        const INDEXES: &'static [&'static IndexSpec] = &[];

        fn key(&self) -> Key {
            self.id.into()
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }
    }

    #[test]
    fn pagination_empty_vec() {
        let mut v: Vec<i32> = vec![];
        apply_pagination(&mut v, 0, Some(10));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_offset_beyond_len_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 10, Some(5));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_offset_equal_to_len_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 3, None);
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_no_limit_from_offset() {
        let mut v = vec![1, 2, 3, 4, 5];
        apply_pagination(&mut v, 2, None);
        assert_eq!(v, vec![3, 4, 5]);
    }

    #[test]
    fn pagination_exact_window() {
        let mut v = vec![10, 20, 30, 40, 50];
        apply_pagination(&mut v, 1, Some(3));
        assert_eq!(v, vec![20, 30, 40]);
    }

    #[test]
    fn pagination_limit_exceeds_tail() {
        let mut v = vec![10, 20, 30];
        apply_pagination(&mut v, 1, Some(999));
        assert_eq!(v, vec![20, 30]);
    }

    #[test]
    fn pagination_zero_limit_clears() {
        let mut v = vec![1, 2, 3];
        apply_pagination(&mut v, 0, Some(0));
        assert!(v.is_empty());
    }

    #[test]
    fn pagination_huge_limit_with_offset_keeps_tail() {
        // `offset + limit` must not overflow `usize`, which is 32 bits on wasm32.
        let mut v = vec![10, 20, 30];
        apply_pagination(&mut v, 1, Some(u32::MAX));
        assert_eq!(v, vec![20, 30]);
    }

    #[test]
    fn apply_sort_orders_descending() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 10, 1, None)),
            (Key::from(2_u64), SortableEntity::new(2, 30, 2, None)),
            (Key::from(3_u64), SortableEntity::new(3, 20, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![("primary".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let primary: Vec<i32> = rows.iter().map(|(_, e)| e.primary).collect();
        assert_eq!(primary, vec![30, 20, 10]);
    }

    #[test]
    fn apply_sort_uses_secondary_field_for_ties() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 1, 5, None)),
            (Key::from(2_u64), SortableEntity::new(2, 1, 8, None)),
            (Key::from(3_u64), SortableEntity::new(3, 2, 3, None)),
        ];
        let sort_expr = SortExpr::from(vec![
            ("primary".to_string(), Order::Asc),
            ("secondary".to_string(), Order::Desc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![2, 1, 3]);
    }

    #[test]
    fn apply_sort_places_none_before_some_and_falls_back() {
        let mut rows = vec![
            (
                Key::from(3_u64),
                SortableEntity::new(3, 0, 0, Some(vec![3, 4])),
            ),
            (Key::from(1_u64), SortableEntity::new(1, 0, 0, None)),
            (
                Key::from(2_u64),
                SortableEntity::new(2, 0, 0, Some(vec![2])),
            ),
        ];
        let sort_expr = SortExpr::from(vec![
            ("optional_blob".to_string(), Order::Asc),
            ("id".to_string(), Order::Asc),
        ]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![1, 2, 3]);
    }

    #[test]
    fn apply_sort_places_none_after_some_when_descending() {
        let mut rows = vec![
            (Key::from(1_u64), SortableEntity::new(1, 0, 0, None)),
            (
                Key::from(2_u64),
                SortableEntity::new(2, 0, 0, Some(vec![7])),
            ),
        ];
        let sort_expr = SortExpr::from(vec![("optional_blob".to_string(), Order::Desc)]);

        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), &sort_expr);

        let ids: Vec<u64> = rows.iter().map(|(_, e)| e.id).collect();
        assert_eq!(ids, vec![2, 1]);
    }
}