1use crate::{
2 db::{
3 Db,
4 executor::{
5 FilterEvaluator,
6 plan::{plan_for, record_plan_metrics, scan_missing_ok, set_rows_from_len},
7 },
8 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
9 query::{LoadQuery, QueryPlan, QueryValidate},
10 response::{Response, ResponseError},
11 store::{DataKey, DataRow},
12 },
13 error::InternalError,
14 obs::sink::{self, ExecKind, MetricsEvent, Span},
15 prelude::*,
16 traits::{EntityKind, FieldValue},
17};
18use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
19
20#[derive(Clone)]
25pub struct LoadExecutor<E: EntityKind> {
26 db: Db<E::Canister>,
27 debug: bool,
28 _marker: PhantomData<E>,
29}
30
31impl<E: EntityKind> LoadExecutor<E> {
32 #[must_use]
37 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
38 Self {
39 db,
40 debug,
41 _marker: PhantomData,
42 }
43 }
44
45 fn debug_log(&self, s: impl Into<String>) {
46 if self.debug {
47 println!("{}", s.into());
48 }
49 }
50
51 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, InternalError> {
57 self.execute(LoadQuery::new().one::<E>(value))
58 }
59
60 pub fn only(&self) -> Result<Response<E>, InternalError> {
62 self.execute(LoadQuery::new().one::<E>(()))
63 }
64
65 pub fn many<I, V>(&self, values: I) -> Result<Response<E>, InternalError>
67 where
68 I: IntoIterator<Item = V>,
69 V: FieldValue,
70 {
71 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
72 self.execute(query)
73 }
74
75 pub fn all(&self) -> Result<Response<E>, InternalError> {
77 self.execute(LoadQuery::new())
78 }
79
80 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, InternalError>
82 where
83 F: FnOnce(FilterDsl) -> I,
84 I: IntoFilterExpr,
85 {
86 self.execute(LoadQuery::new().filter(f))
87 }
88
89 pub fn require_one(&self, query: LoadQuery) -> Result<(), InternalError> {
95 self.execute(query)?.require_one()
96 }
97
98 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), InternalError> {
100 self.require_one(LoadQuery::new().one::<E>(value))
101 }
102
103 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), InternalError>
105 where
106 F: FnOnce(FilterDsl) -> I,
107 I: IntoFilterExpr,
108 {
109 self.require_one(LoadQuery::new().filter(f))
110 }
111
112 pub fn exists(&self, query: LoadQuery) -> Result<bool, InternalError> {
125 QueryValidate::<E>::validate(&query)?;
126 sink::record(MetricsEvent::ExistsCall {
127 entity_path: E::PATH,
128 });
129
130 self.debug_log(format!("[debug] exists query {:?} on {}", query, E::PATH));
131
132 let plan = plan_for::<E>(query.filter.as_ref());
133 let offset = query.limit.as_ref().map_or(0, |lim| lim.offset);
134 let limit = query.limit.as_ref().and_then(|lim| lim.limit);
135 if limit == Some(0) {
136 return Ok(false);
137 }
138
139 #[allow(clippy::needless_continue)]
140 match plan {
141 QueryPlan::Keys(keys) => {
142 let mut seen = 0u32;
144 for key in keys {
145 let data_key = DataKey::new::<E>(key);
146 match self.db.context::<E>().read(&data_key) {
147 Ok(_) => {
148 if seen < offset {
149 seen += 1;
150 } else {
151 return Ok(true);
152 }
153 }
154 Err(err) if err.is_not_found() => continue,
155 Err(err) => return Err(err),
156 }
157 }
158 Ok(false)
159 }
160 plan => {
161 let filter = query.filter.map(FilterExpr::simplify);
162 let mut seen = 0u32;
163 let mut scanned = 0u64;
164 let mut found = false;
165
166 scan_missing_ok::<E, _>(&self.db, plan, |_, entity| {
168 scanned = scanned.saturating_add(1);
169 let matches = filter
170 .as_ref()
171 .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));
172
173 if matches {
174 if seen < offset {
175 seen += 1;
176 } else {
177 found = true;
178 return ControlFlow::Break(());
179 }
180 }
181
182 ControlFlow::Continue(())
183 })?;
184
185 sink::record(MetricsEvent::RowsScanned {
186 entity_path: E::PATH,
187 rows_scanned: scanned,
188 });
189
190 Ok(found)
191 }
192 }
193 }
194
195 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, InternalError> {
204 let value = value.to_value();
205 let query = LoadQuery::new().one::<E>(value.clone());
206 QueryValidate::<E>::validate(&query)?;
207 sink::record(MetricsEvent::ExistsCall {
208 entity_path: E::PATH,
209 });
210
211 self.debug_log(format!(
212 "[debug] exists_one on {} (value={:?})",
213 E::PATH,
214 value
215 ));
216
217 let Some(key) = value.as_key_coerced() else {
218 sink::record(MetricsEvent::RowsScanned {
219 entity_path: E::PATH,
220 rows_scanned: 0,
221 });
222 return Ok(false);
223 };
224
225 let data_key = DataKey::new::<E>(key);
226 let found = match self.db.context::<E>().read(&data_key) {
227 Ok(_) => true,
228 Err(err) if err.is_not_found() => false,
229 Err(err) => return Err(err),
230 };
231
232 sink::record(MetricsEvent::RowsScanned {
233 entity_path: E::PATH,
234 rows_scanned: u64::from(found),
235 });
236
237 Ok(found)
238 }
239
240 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, InternalError>
242 where
243 F: FnOnce(FilterDsl) -> I,
244 I: IntoFilterExpr,
245 {
246 self.exists(LoadQuery::new().filter(f))
247 }
248
249 pub fn exists_any(&self) -> Result<bool, InternalError> {
251 self.exists(LoadQuery::new())
252 }
253
254 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), InternalError> {
260 let value = value.to_value();
261 let query = LoadQuery::new().one::<E>(value.clone());
262 QueryValidate::<E>::validate(&query)?;
263
264 if self.exists_one(value)? {
265 Ok(())
266 } else {
267 Err(ResponseError::NotFound { entity: E::PATH }.into())
268 }
269 }
270
271 pub fn ensure_exists_all<I, V>(&self, values: I) -> Result<(), InternalError>
276 where
277 I: IntoIterator<Item = V>,
278 V: FieldValue,
279 {
280 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
281 QueryValidate::<E>::validate(&query)?;
282
283 let QueryPlan::Keys(keys) = plan_for::<E>(query.filter.as_ref()) else {
284 return Ok(());
285 };
286
287 if keys.is_empty() {
288 return Ok(());
289 }
290
291 let ctx = self.db.context::<E>();
292 for key in keys {
293 let data_key = DataKey::new::<E>(key);
294 match ctx.read(&data_key) {
295 Ok(_) => {}
296 Err(err) if err.is_not_found() => {
297 return Err(ResponseError::NotFound { entity: E::PATH }.into());
298 }
299 Err(err) => return Err(err),
300 }
301 }
302
303 Ok(())
304 }
305
306 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), InternalError>
308 where
309 F: FnOnce(FilterDsl) -> I,
310 I: IntoFilterExpr,
311 {
312 if self.exists_filter(f)? {
313 Ok(())
314 } else {
315 Err(ResponseError::NotFound { entity: E::PATH }.into())
316 }
317 }
318
319 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, InternalError> {
325 QueryValidate::<E>::validate(&query)?;
326
327 Ok(plan_for::<E>(query.filter.as_ref()))
328 }
329
330 fn execute_raw(
331 &self,
332 plan: QueryPlan,
333 query: &LoadQuery,
334 ) -> Result<Vec<DataRow>, InternalError> {
335 let ctx = self.db.context::<E>();
336
337 if let Some(lim) = &query.limit {
338 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
339 } else {
340 Ok(ctx.rows_from_plan(plan)?)
341 }
342 }
343
344 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, InternalError> {
350 let mut span = Span::<E>::new(ExecKind::Load);
351 QueryValidate::<E>::validate(&query)?;
352
353 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
354
355 let ctx = self.db.context::<E>();
356 let plan = plan_for::<E>(query.filter.as_ref());
357
358 self.debug_log(format!("📄 Query plan: {plan:?}"));
359 record_plan_metrics(&plan);
360
361 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
363 let mut rows: Vec<(Key, E)> = if pre_paginated {
364 let data_rows = self.execute_raw(plan, &query)?;
365 sink::record(MetricsEvent::RowsScanned {
366 entity_path: E::PATH,
367 rows_scanned: data_rows.len() as u64,
368 });
369
370 self.debug_log(format!(
371 "📦 Scanned {} data rows before deserialization",
372 data_rows.len()
373 ));
374
375 let rows = ctx.deserialize_rows(data_rows)?;
376 self.debug_log(format!(
377 "🧩 Deserialized {} entities before filtering",
378 rows.len()
379 ));
380 rows
381 } else {
382 let data_rows = ctx.rows_from_plan(plan)?;
383 sink::record(MetricsEvent::RowsScanned {
384 entity_path: E::PATH,
385 rows_scanned: data_rows.len() as u64,
386 });
387 self.debug_log(format!(
388 "📦 Scanned {} data rows before deserialization",
389 data_rows.len()
390 ));
391
392 let rows = ctx.deserialize_rows(data_rows)?;
393 self.debug_log(format!(
394 "🧩 Deserialized {} entities before filtering",
395 rows.len()
396 ));
397
398 rows
399 };
400
401 if let Some(f) = &query.filter {
403 let simplified = f.clone().simplify();
404 Self::apply_filter(&mut rows, &simplified);
405
406 self.debug_log(format!(
407 "🔎 Applied filter -> {} entities remaining",
408 rows.len()
409 ));
410 }
411
412 if let Some(sort) = &query.sort
414 && rows.len() > 1
415 {
416 Self::apply_sort(&mut rows, sort);
417 self.debug_log("↕️ Applied sort expression");
418 }
419
420 if let Some(lim) = &query.limit
422 && !pre_paginated
423 {
424 apply_pagination(&mut rows, lim.offset, lim.limit);
425 self.debug_log(format!(
426 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
427 lim.offset,
428 lim.limit,
429 rows.len()
430 ));
431 }
432
433 set_rows_from_len(&mut span, rows.len());
434 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
435
436 Ok(Response(rows))
437 }
438
439 pub fn count(&self, query: LoadQuery) -> Result<u32, InternalError> {
441 Ok(self.execute(query)?.count())
442 }
443
444 pub fn count_all(&self) -> Result<u32, InternalError> {
445 self.count(LoadQuery::new())
446 }
447
448 pub fn group_count_by<K, F>(
457 &self,
458 query: LoadQuery,
459 key_fn: F,
460 ) -> Result<HashMap<K, u32>, InternalError>
461 where
462 K: Eq + Hash,
463 F: Fn(&E) -> K,
464 {
465 let entities = self.execute(query)?.entities();
466
467 let mut counts = HashMap::new();
468 for e in entities {
469 *counts.entry(key_fn(&e)).or_insert(0) += 1;
470 }
471
472 Ok(counts)
473 }
474
475 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
481 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
482 }
483
484 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
486 rows.sort_by(|(_, ea), (_, eb)| {
487 for (field, direction) in sort_expr.iter() {
488 let va = ea.get_value(field);
489 let vb = eb.get_value(field);
490
491 let ordering = match (va, vb) {
493 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
497 Some(ord) => ord,
498 None => continue, },
500 };
501
502 let ordering = match direction {
504 Order::Asc => ordering,
505 Order::Desc => ordering.reverse(),
506 };
507
508 if ordering != Ordering::Equal {
509 return ordering;
510 }
511 }
512
513 Ordering::Equal
515 });
516 }
517}
518
/// Applies an offset/limit window to `rows` in place.
///
/// The first `offset` rows are dropped, then at most `limit` rows are kept
/// (all remaining rows when `limit` is `None`). An offset past the end
/// leaves `rows` empty.
///
/// The arithmetic saturates, so a large `offset + limit` cannot overflow
/// `usize` on narrow targets.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    let start = usize::try_from(offset).unwrap_or(usize::MAX).min(total);
    let end = limit.map_or(total, |l| {
        start
            .saturating_add(usize::try_from(l).unwrap_or(usize::MAX))
            .min(total)
    });

    // `start <= end <= total` holds by construction. Cut the tail first so
    // that draining the head shifts as few elements as possible.
    rows.truncate(end);
    rows.drain(..start);
}