// icydb_core/db/executor/load.rs

1use crate::{
2    Error, Key,
3    db::{
4        Db,
5        executor::{
6            FilterEvaluator,
7            plan::{plan_for, scan_plan, set_rows_from_len},
8        },
9        primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10        query::{LoadQuery, QueryPlan, QueryValidate},
11        response::{Response, ResponseError},
12        store::DataRow,
13    },
14    obs::metrics,
15    traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
18
19///
20/// LoadExecutor
21///
22
23#[derive(Clone, Copy)]
24pub struct LoadExecutor<E: EntityKind> {
25    db: Db<E::Canister>,
26    debug: bool,
27    _marker: PhantomData<E>,
28}
29
30impl<E: EntityKind> LoadExecutor<E> {
31    // ======================================================================
32    // Construction & diagnostics
33    // ======================================================================
34
35    #[must_use]
36    pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
37        Self {
38            db,
39            debug,
40            _marker: PhantomData,
41        }
42    }
43
44    fn debug_log(&self, s: impl Into<String>) {
45        if self.debug {
46            println!("{}", s.into());
47        }
48    }
49
50    // ======================================================================
51    // Query builders (execute and return Response)
52    // ======================================================================
53
54    /// Execute a query for a single primary key.
55    pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
56        self.execute(LoadQuery::new().one::<E>(value))
57    }
58
59    /// Execute a query for the unit primary key.
60    pub fn only(&self) -> Result<Response<E>, Error> {
61        self.execute(LoadQuery::new().one::<E>(()))
62    }
63
64    /// Execute a query matching multiple primary keys.
65    pub fn many<I, V>(&self, values: I) -> Result<Response<E>, Error>
66    where
67        I: IntoIterator<Item = V>,
68        V: FieldValue,
69    {
70        let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
71        self.execute(query)
72    }
73
74    /// Execute an unfiltered query for all rows.
75    pub fn all(&self) -> Result<Response<E>, Error> {
76        self.execute(LoadQuery::new())
77    }
78
79    /// Execute a query built from a filter.
80    pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, Error>
81    where
82        F: FnOnce(FilterDsl) -> I,
83        I: IntoFilterExpr,
84    {
85        self.execute(LoadQuery::new().filter(f))
86    }
87
88    // ======================================================================
89    // Cardinality guards (delegated to Response)
90    // ======================================================================
91
92    /// Execute a query and require exactly one row.
93    pub fn require_one(&self, query: LoadQuery) -> Result<(), Error> {
94        self.execute(query)?.require_one()
95    }
96
97    /// Require exactly one row by primary key.
98    pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), Error> {
99        self.require_one(LoadQuery::new().one::<E>(value))
100    }
101
102    /// Require exactly one row from a filter.
103    pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), Error>
104    where
105        F: FnOnce(FilterDsl) -> I,
106        I: IntoFilterExpr,
107    {
108        self.require_one(LoadQuery::new().filter(f))
109    }
110
111    // ======================================================================
112    // Existence checks (≥1 semantics, intentionally weaker)
113    // ======================================================================
114
115    /// Check whether at least one row matches the query.
116    pub fn exists(&self, query: LoadQuery) -> Result<bool, Error> {
117        let query = query.limit_1();
118        QueryValidate::<E>::validate(&query)?;
119
120        let plan = plan_for::<E>(query.filter.as_ref());
121        let filter = query.filter.map(FilterExpr::simplify);
122        let mut found = false;
123
124        scan_plan::<E, _>(&self.db, plan, |_, entity| {
125            let matches = filter
126                .as_ref()
127                .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));
128
129            if matches {
130                found = true;
131                ControlFlow::Break(())
132            } else {
133                ControlFlow::Continue(())
134            }
135        })?;
136
137        Ok(found)
138    }
139
140    /// Check existence by primary key.
141    pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, Error> {
142        self.exists(LoadQuery::new().one::<E>(value))
143    }
144
145    /// Check existence with a filter.
146    pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, Error>
147    where
148        F: FnOnce(FilterDsl) -> I,
149        I: IntoFilterExpr,
150    {
151        self.exists(LoadQuery::new().filter(f))
152    }
153
154    /// Check whether the table contains any rows.
155    pub fn exists_any(&self) -> Result<bool, Error> {
156        self.exists(LoadQuery::new())
157    }
158
159    // ======================================================================
160    // Existence checks with not-found errors (fast path, no deserialization)
161    // ======================================================================
162
163    /// Require at least one row by primary key.
164    pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), Error> {
165        if self.exists_one(value)? {
166            Ok(())
167        } else {
168            Err(ResponseError::NotFound { entity: E::PATH }.into())
169        }
170    }
171
172    /// Require that all provided primary keys exist.
173    #[allow(clippy::cast_possible_truncation)]
174    pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), Error>
175    where
176        I: IntoIterator<Item = V>,
177        V: FieldValue,
178    {
179        let pks: Vec<_> = values.into_iter().collect();
180
181        let expected = pks.len() as u32;
182        if expected == 0 {
183            return Ok(());
184        }
185
186        let res = self.many(pks)?;
187        res.require_len(expected)?;
188
189        Ok(())
190    }
191
192    /// Require at least one row from a filter.
193    pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), Error>
194    where
195        F: FnOnce(FilterDsl) -> I,
196        I: IntoFilterExpr,
197    {
198        if self.exists_filter(f)? {
199            Ok(())
200        } else {
201            Err(ResponseError::NotFound { entity: E::PATH }.into())
202        }
203    }
204
205    // ======================================================================
206    // Execution & planning
207    // ======================================================================
208
209    /// Validate and return the query plan without executing.
210    pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
211        QueryValidate::<E>::validate(&query)?;
212
213        Ok(plan_for::<E>(query.filter.as_ref()))
214    }
215
216    fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
217        QueryValidate::<E>::validate(query)?;
218
219        let ctx = self.db.context::<E>();
220        let plan = plan_for::<E>(query.filter.as_ref());
221
222        if let Some(lim) = &query.limit {
223            Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
224        } else {
225            Ok(ctx.rows_from_plan(plan)?)
226        }
227    }
228
229    /// Execute a full query and return a collection of entities.
230    pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
231        let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
232        QueryValidate::<E>::validate(&query)?;
233
234        self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
235
236        let ctx = self.db.context::<E>();
237        let plan = plan_for::<E>(query.filter.as_ref());
238
239        self.debug_log(format!("📄 Query plan: {plan:?}"));
240
241        // Fast path: pre-pagination
242        let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
243        let mut rows: Vec<(Key, E)> = if pre_paginated {
244            let data_rows = self.execute_raw(&query)?;
245
246            self.debug_log(format!(
247                "📦 Scanned {} data rows before deserialization",
248                data_rows.len()
249            ));
250
251            let rows = ctx.deserialize_rows(data_rows)?;
252            self.debug_log(format!(
253                "🧩 Deserialized {} entities before filtering",
254                rows.len()
255            ));
256            rows
257        } else {
258            let data_rows = ctx.rows_from_plan(plan)?;
259            self.debug_log(format!(
260                "📦 Scanned {} data rows before deserialization",
261                data_rows.len()
262            ));
263
264            let rows = ctx.deserialize_rows(data_rows)?;
265            self.debug_log(format!(
266                "🧩 Deserialized {} entities before filtering",
267                rows.len()
268            ));
269
270            rows
271        };
272
273        // Filtering
274        if let Some(f) = &query.filter {
275            let simplified = f.clone().simplify();
276            Self::apply_filter(&mut rows, &simplified);
277
278            self.debug_log(format!(
279                "🔎 Applied filter -> {} entities remaining",
280                rows.len()
281            ));
282        }
283
284        // Sorting
285        if let Some(sort) = &query.sort
286            && rows.len() > 1
287        {
288            Self::apply_sort(&mut rows, sort);
289            self.debug_log("↕️ Applied sort expression");
290        }
291
292        // Pagination
293        if let Some(lim) = &query.limit
294            && !pre_paginated
295        {
296            apply_pagination(&mut rows, lim.offset, lim.limit);
297            self.debug_log(format!(
298                "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
299                lim.offset,
300                lim.limit,
301                rows.len()
302            ));
303        }
304
305        set_rows_from_len(&mut span, rows.len());
306        self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
307
308        Ok(Response(rows))
309    }
310
311    /// Count rows matching a query.
312    pub fn count(&self, query: LoadQuery) -> Result<u32, Error> {
313        Ok(self.execute(query)?.count())
314    }
315
316    pub fn count_all(&self) -> Result<u32, Error> {
317        self.count(LoadQuery::new())
318    }
319
320    // ======================================================================
321    // Aggregations
322    // ======================================================================
323
324    /// Group rows matching a query and count them by a derived key.
325    ///
326    /// This is intentionally implemented on the executor (not Response)
327    /// so it can later avoid full deserialization.
328    pub fn group_count_by<K, F>(
329        &self,
330        query: LoadQuery,
331        key_fn: F,
332    ) -> Result<HashMap<K, u32>, Error>
333    where
334        K: Eq + Hash,
335        F: Fn(&E) -> K,
336    {
337        let entities = self.execute(query)?.entities();
338
339        let mut counts = HashMap::new();
340        for e in entities {
341            *counts.entry(key_fn(&e)).or_insert(0) += 1;
342        }
343
344        Ok(counts)
345    }
346
347    // ======================================================================
348    // Private Helpers
349    // ======================================================================
350
351    // apply_filter
352    fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
353        rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
354    }
355
356    // apply_sort
357    fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
358        rows.sort_by(|(_, ea), (_, eb)| {
359            for (field, direction) in sort_expr.iter() {
360                let va = ea.get_value(field);
361                let vb = eb.get_value(field);
362
363                // Define how to handle missing values (None)
364                let ordering = match (va, vb) {
365                    (None, None) => continue,             // both missing → move to next field
366                    (None, Some(_)) => Ordering::Less,    // None sorts before Some(_)
367                    (Some(_), None) => Ordering::Greater, // Some(_) sorts after None
368                    (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
369                        Some(ord) => ord,
370                        None => continue, // incomparable values → move to next field
371                    },
372                };
373
374                // Apply direction (Asc/Desc)
375                let ordering = match direction {
376                    Order::Asc => ordering,
377                    Order::Desc => ordering.reverse(),
378                };
379
380                if ordering != Ordering::Equal {
381                    return ordering;
382                }
383            }
384
385            // all fields equal
386            Ordering::Equal
387        });
388    }
389}
390
/// Apply offset/limit pagination to an in-memory vector, in place.
///
/// Skips the first `offset` elements and keeps at most `limit` of the rest
/// (`None` keeps everything after the offset). An offset at or beyond the
/// end, or a zero limit, leaves the vector empty.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = usize::min(offset as usize, total);
    // `saturating_add`: on 32-bit targets (e.g. wasm32) `start + limit` can
    // exceed `usize::MAX` for large limits; a wrapping add would silently empty
    // the page (or panic in debug builds) instead of returning the tail.
    let end = limit.map_or(total, |l| usize::min(start.saturating_add(l as usize), total));

    if start >= end {
        rows.clear();
        return;
    }

    // Drop the tail first so `drain` only has to shift the kept window.
    rows.truncate(end);
    rows.drain(..start);
}
404
///
/// TESTS
///
#[cfg(test)]
mod tests {
    use super::{LoadExecutor, apply_pagination};
    use crate::{
        IndexSpec, Key, Value,
        db::primitives::{Order, SortExpr},
        traits::{
            CanisterKind, EntityKind, FieldValues, Path, SanitizeAuto, SanitizeCustom, StoreKind,
            ValidateAuto, ValidateCustom, View, Visitable,
        },
    };
    use serde::{Deserialize, Serialize};

    // ----------------------------------------------------------------------
    // Fixture entity: two integer sort keys plus an optional blob
    // ----------------------------------------------------------------------

    #[derive(Clone, Debug, Default, Deserialize, PartialEq, Serialize)]
    struct SortableEntity {
        id: u64,
        primary: i32,
        secondary: i32,
        optional_blob: Option<Vec<u8>>,
    }

    impl SortableEntity {
        fn new(id: u64, primary: i32, secondary: i32, optional_blob: Option<Vec<u8>>) -> Self {
            Self {
                id,
                primary,
                secondary,
                optional_blob,
            }
        }
    }

    struct SortableCanister;
    struct SortableStore;

    impl Path for SortableCanister {
        const PATH: &'static str = "test::canister";
    }

    impl CanisterKind for SortableCanister {}

    impl Path for SortableStore {
        const PATH: &'static str = "test::store";
    }

    impl StoreKind for SortableStore {
        type Canister = SortableCanister;
    }

    impl Path for SortableEntity {
        const PATH: &'static str = "test::sortable";
    }

    impl View for SortableEntity {
        type ViewType = Self;

        fn to_view(&self) -> Self::ViewType {
            self.clone()
        }

        fn from_view(view: Self::ViewType) -> Self {
            view
        }
    }

    impl SanitizeAuto for SortableEntity {}
    impl SanitizeCustom for SortableEntity {}
    impl ValidateAuto for SortableEntity {}
    impl ValidateCustom for SortableEntity {}
    impl Visitable for SortableEntity {}

    impl FieldValues for SortableEntity {
        fn get_value(&self, field: &str) -> Option<Value> {
            let value = match field {
                "id" => Value::Uint(self.id),
                "primary" => Value::Int(i64::from(self.primary)),
                "secondary" => Value::Int(i64::from(self.secondary)),
                "optional_blob" => return self.optional_blob.clone().map(Value::Blob),
                _ => return None,
            };
            Some(value)
        }
    }

    impl EntityKind for SortableEntity {
        type PrimaryKey = u64;
        type Store = SortableStore;
        type Canister = SortableCanister;

        const ENTITY_ID: u64 = 99;
        const PRIMARY_KEY: &'static str = "id";
        const FIELDS: &'static [&'static str] = &["id", "primary", "secondary", "optional_blob"];
        const INDEXES: &'static [&'static IndexSpec] = &[];

        fn key(&self) -> Key {
            Key::from(self.id)
        }

        fn primary_key(&self) -> Self::PrimaryKey {
            self.id
        }

        fn set_primary_key(&mut self, key: Self::PrimaryKey) {
            self.id = key;
        }
    }

    // ----------------------------------------------------------------------
    // Helpers
    // ----------------------------------------------------------------------

    /// Pair an entity with a `Key` derived from its own id.
    fn row(entity: SortableEntity) -> (Key, SortableEntity) {
        (Key::from(entity.id), entity)
    }

    /// Build a `SortExpr` from `(field, direction)` pairs.
    fn sort_expr(spec: Vec<(&str, Order)>) -> SortExpr {
        let fields: Vec<(String, Order)> = spec
            .into_iter()
            .map(|(field, order)| (field.to_string(), order))
            .collect();
        SortExpr::from(fields)
    }

    /// Run the executor's in-memory sort and return the entities in order.
    fn sort_rows(mut rows: Vec<(Key, SortableEntity)>, expr: &SortExpr) -> Vec<SortableEntity> {
        LoadExecutor::<SortableEntity>::apply_sort(rows.as_mut_slice(), expr);
        rows.into_iter().map(|(_, entity)| entity).collect()
    }

    // ----------------------------------------------------------------------
    // Pagination
    // ----------------------------------------------------------------------

    #[test]
    fn pagination_empty_vec() {
        let mut items = Vec::<i32>::new();
        apply_pagination(&mut items, 0, Some(10));
        assert!(items.is_empty());
    }

    #[test]
    fn pagination_offset_beyond_len_clears() {
        let mut items = vec![1, 2, 3];
        apply_pagination(&mut items, 10, Some(5));
        assert!(items.is_empty());
    }

    #[test]
    fn pagination_no_limit_from_offset() {
        let mut items = vec![1, 2, 3, 4, 5];
        apply_pagination(&mut items, 2, None);
        assert_eq!(items, [3, 4, 5]);
    }

    #[test]
    fn pagination_exact_window() {
        // skip one, take three
        let mut items = vec![10, 20, 30, 40, 50];
        apply_pagination(&mut items, 1, Some(3));
        assert_eq!(items, [20, 30, 40]);
    }

    #[test]
    fn pagination_limit_exceeds_tail() {
        // skip one, limit far larger than what remains
        let mut items = vec![10, 20, 30];
        apply_pagination(&mut items, 1, Some(999));
        assert_eq!(items, [20, 30]);
    }

    // ----------------------------------------------------------------------
    // Sorting
    // ----------------------------------------------------------------------

    #[test]
    fn apply_sort_orders_descending() {
        let rows = vec![
            row(SortableEntity::new(1, 10, 1, None)),
            row(SortableEntity::new(2, 30, 2, None)),
            row(SortableEntity::new(3, 20, 3, None)),
        ];
        let expr = sort_expr(vec![("primary", Order::Desc)]);

        let primaries: Vec<i32> = sort_rows(rows, &expr).iter().map(|e| e.primary).collect();
        assert_eq!(primaries, [30, 20, 10]);
    }

    #[test]
    fn apply_sort_uses_secondary_field_for_ties() {
        let rows = vec![
            row(SortableEntity::new(1, 1, 5, None)),
            row(SortableEntity::new(2, 1, 8, None)),
            row(SortableEntity::new(3, 2, 3, None)),
        ];
        let expr = sort_expr(vec![("primary", Order::Asc), ("secondary", Order::Desc)]);

        let ids: Vec<u64> = sort_rows(rows, &expr).iter().map(|e| e.id).collect();
        assert_eq!(ids, [2, 1, 3]);
    }

    #[test]
    fn apply_sort_places_none_before_some_and_falls_back() {
        let rows = vec![
            row(SortableEntity::new(3, 0, 0, Some(vec![3, 4]))),
            row(SortableEntity::new(1, 0, 0, None)),
            row(SortableEntity::new(2, 0, 0, Some(vec![2]))),
        ];
        let expr = sort_expr(vec![("optional_blob", Order::Asc), ("id", Order::Asc)]);

        let ids: Vec<u64> = sort_rows(rows, &expr).iter().map(|e| e.id).collect();
        assert_eq!(ids, [1, 2, 3]);
    }
}