1use crate::{
2 Error, Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, scan_plan, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::{Response, ResponseError},
12 store::DataRow,
13 },
14 obs::metrics,
15 traits::{EntityKind, FieldValue},
16};
17use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
18
19#[derive(Clone, Copy)]
24pub struct LoadExecutor<E: EntityKind> {
25 db: Db<E::Canister>,
26 debug: bool,
27 _marker: PhantomData<E>,
28}
29
30impl<E: EntityKind> LoadExecutor<E> {
31 #[must_use]
36 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
37 Self {
38 db,
39 debug,
40 _marker: PhantomData,
41 }
42 }
43
44 fn debug_log(&self, s: impl Into<String>) {
45 if self.debug {
46 println!("{}", s.into());
47 }
48 }
49
50 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, Error> {
56 self.execute(LoadQuery::new().one::<E>(value))
57 }
58
59 pub fn only(&self) -> Result<Response<E>, Error> {
61 self.execute(LoadQuery::new().one::<E>(()))
62 }
63
64 pub fn many<I, V>(&self, values: I) -> Result<Response<E>, Error>
66 where
67 I: IntoIterator<Item = V>,
68 V: FieldValue,
69 {
70 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
71 self.execute(query)
72 }
73
74 pub fn all(&self) -> Result<Response<E>, Error> {
76 self.execute(LoadQuery::new())
77 }
78
79 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, Error>
81 where
82 F: FnOnce(FilterDsl) -> I,
83 I: IntoFilterExpr,
84 {
85 self.execute(LoadQuery::new().filter(f))
86 }
87
88 pub fn require_one(&self, query: LoadQuery) -> Result<(), Error> {
94 self.execute(query)?.require_one()
95 }
96
97 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), Error> {
99 self.require_one(LoadQuery::new().one::<E>(value))
100 }
101
102 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), Error>
104 where
105 F: FnOnce(FilterDsl) -> I,
106 I: IntoFilterExpr,
107 {
108 self.require_one(LoadQuery::new().filter(f))
109 }
110
111 pub fn exists(&self, query: LoadQuery) -> Result<bool, Error> {
122 QueryValidate::<E>::validate(&query)?;
123 metrics::record_exists_call_for::<E>();
124
125 let plan = plan_for::<E>(query.filter.as_ref());
126 let filter = query.filter.map(FilterExpr::simplify);
127 let offset = query.limit.as_ref().map_or(0, |lim| lim.offset);
128 let limit = query.limit.as_ref().and_then(|lim| lim.limit);
129 if limit == Some(0) {
130 return Ok(false);
131 }
132 let mut seen = 0u32;
133 let mut scanned = 0u64;
134 let mut found = false;
135
136 scan_plan::<E, _>(&self.db, plan, |_, entity| {
137 scanned = scanned.saturating_add(1);
138 let matches = filter
139 .as_ref()
140 .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));
141
142 if matches {
143 if seen < offset {
144 seen += 1;
145 ControlFlow::Continue(())
146 } else {
147 found = true;
148 ControlFlow::Break(())
149 }
150 } else {
151 ControlFlow::Continue(())
152 }
153 })?;
154
155 metrics::record_rows_scanned_for::<E>(scanned);
156
157 Ok(found)
158 }
159
160 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, Error> {
162 self.exists(LoadQuery::new().one::<E>(value))
163 }
164
165 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, Error>
167 where
168 F: FnOnce(FilterDsl) -> I,
169 I: IntoFilterExpr,
170 {
171 self.exists(LoadQuery::new().filter(f))
172 }
173
174 pub fn exists_any(&self) -> Result<bool, Error> {
176 self.exists(LoadQuery::new())
177 }
178
179 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), Error> {
185 if self.exists_one(value)? {
186 Ok(())
187 } else {
188 Err(ResponseError::NotFound { entity: E::PATH }.into())
189 }
190 }
191
192 #[allow(clippy::cast_possible_truncation)]
194 pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), Error>
195 where
196 I: IntoIterator<Item = V>,
197 V: FieldValue,
198 {
199 let pks: Vec<_> = values.into_iter().collect();
200
201 let expected = pks.len() as u32;
202 if expected == 0 {
203 return Ok(());
204 }
205
206 let res = self.many(pks)?;
207 res.require_len(expected)?;
208
209 Ok(())
210 }
211
212 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), Error>
214 where
215 F: FnOnce(FilterDsl) -> I,
216 I: IntoFilterExpr,
217 {
218 if self.exists_filter(f)? {
219 Ok(())
220 } else {
221 Err(ResponseError::NotFound { entity: E::PATH }.into())
222 }
223 }
224
225 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, Error> {
231 QueryValidate::<E>::validate(&query)?;
232
233 Ok(plan_for::<E>(query.filter.as_ref()))
234 }
235
236 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, Error> {
237 let ctx = self.db.context::<E>();
238 let plan = plan_for::<E>(query.filter.as_ref());
239
240 if let Some(lim) = &query.limit {
241 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
242 } else {
243 Ok(ctx.rows_from_plan(plan)?)
244 }
245 }
246
247 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, Error> {
253 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
254 QueryValidate::<E>::validate(&query)?;
255
256 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
257
258 let ctx = self.db.context::<E>();
259 let plan = plan_for::<E>(query.filter.as_ref());
260
261 self.debug_log(format!("📄 Query plan: {plan:?}"));
262
263 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
265 let mut rows: Vec<(Key, E)> = if pre_paginated {
266 let data_rows = self.execute_raw(&query)?;
267 metrics::record_rows_scanned_for::<E>(data_rows.len() as u64);
268
269 self.debug_log(format!(
270 "📦 Scanned {} data rows before deserialization",
271 data_rows.len()
272 ));
273
274 let rows = ctx.deserialize_rows(data_rows)?;
275 self.debug_log(format!(
276 "🧩 Deserialized {} entities before filtering",
277 rows.len()
278 ));
279 rows
280 } else {
281 let data_rows = ctx.rows_from_plan(plan)?;
282 metrics::record_rows_scanned_for::<E>(data_rows.len() as u64);
283 self.debug_log(format!(
284 "📦 Scanned {} data rows before deserialization",
285 data_rows.len()
286 ));
287
288 let rows = ctx.deserialize_rows(data_rows)?;
289 self.debug_log(format!(
290 "🧩 Deserialized {} entities before filtering",
291 rows.len()
292 ));
293
294 rows
295 };
296
297 if let Some(f) = &query.filter {
299 let simplified = f.clone().simplify();
300 Self::apply_filter(&mut rows, &simplified);
301
302 self.debug_log(format!(
303 "🔎 Applied filter -> {} entities remaining",
304 rows.len()
305 ));
306 }
307
308 if let Some(sort) = &query.sort
310 && rows.len() > 1
311 {
312 Self::apply_sort(&mut rows, sort);
313 self.debug_log("↕️ Applied sort expression");
314 }
315
316 if let Some(lim) = &query.limit
318 && !pre_paginated
319 {
320 apply_pagination(&mut rows, lim.offset, lim.limit);
321 self.debug_log(format!(
322 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
323 lim.offset,
324 lim.limit,
325 rows.len()
326 ));
327 }
328
329 set_rows_from_len(&mut span, rows.len());
330 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
331
332 Ok(Response(rows))
333 }
334
335 pub fn count(&self, query: LoadQuery) -> Result<u32, Error> {
337 Ok(self.execute(query)?.count())
338 }
339
340 pub fn count_all(&self) -> Result<u32, Error> {
341 self.count(LoadQuery::new())
342 }
343
344 pub fn group_count_by<K, F>(
353 &self,
354 query: LoadQuery,
355 key_fn: F,
356 ) -> Result<HashMap<K, u32>, Error>
357 where
358 K: Eq + Hash,
359 F: Fn(&E) -> K,
360 {
361 let entities = self.execute(query)?.entities();
362
363 let mut counts = HashMap::new();
364 for e in entities {
365 *counts.entry(key_fn(&e)).or_insert(0) += 1;
366 }
367
368 Ok(counts)
369 }
370
371 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
377 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
378 }
379
380 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
382 rows.sort_by(|(_, ea), (_, eb)| {
383 for (field, direction) in sort_expr.iter() {
384 let va = ea.get_value(field);
385 let vb = eb.get_value(field);
386
387 let ordering = match (va, vb) {
389 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
393 Some(ord) => ord,
394 None => continue, },
396 };
397
398 let ordering = match direction {
400 Order::Asc => ordering,
401 Order::Desc => ordering.reverse(),
402 };
403
404 if ordering != Ordering::Equal {
405 return ordering;
406 }
407 }
408
409 Ordering::Equal
411 });
412 }
413}
414
/// Restricts `rows` in place to the window `[offset, offset + limit)`.
///
/// An `offset` past the end yields an empty vector. `limit = None` keeps every
/// row from `offset` onward.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    // `u32 -> usize` is lossless wherever `usize` is at least 32 bits wide.
    let start = (offset as usize).min(total);
    // Saturate: when `usize` is 32 bits (e.g. wasm32), `start + limit` would
    // otherwise overflow for large limits such as `u32::MAX`. In release
    // builds the wrapped `end` would fall below `start` and wipe every row.
    let end = limit.map_or(total, |l| start.saturating_add(l as usize).min(total));

    // Invariant: start <= end <= total, so both calls stay in bounds.
    // Truncating first avoids shifting tail rows that are about to be dropped.
    rows.truncate(end);
    rows.drain(..start);
}