1use crate::{
2 Key,
3 db::{
4 Db,
5 executor::{
6 FilterEvaluator,
7 plan::{plan_for, scan_plan, set_rows_from_len},
8 },
9 primitives::{FilterDsl, FilterExpr, FilterExt, IntoFilterExpr, Order, SortExpr},
10 query::{LoadQuery, QueryPlan, QueryValidate},
11 response::{Response, ResponseError},
12 store::DataRow,
13 },
14 obs::metrics,
15 runtime_error::RuntimeError,
16 traits::{EntityKind, FieldValue},
17};
18use std::{cmp::Ordering, collections::HashMap, hash::Hash, marker::PhantomData, ops::ControlFlow};
19
20#[derive(Clone, Copy)]
25pub struct LoadExecutor<E: EntityKind> {
26 db: Db<E::Canister>,
27 debug: bool,
28 _marker: PhantomData<E>,
29}
30
31impl<E: EntityKind> LoadExecutor<E> {
32 #[must_use]
37 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
38 Self {
39 db,
40 debug,
41 _marker: PhantomData,
42 }
43 }
44
45 fn debug_log(&self, s: impl Into<String>) {
46 if self.debug {
47 println!("{}", s.into());
48 }
49 }
50
51 pub fn one(&self, value: impl FieldValue) -> Result<Response<E>, RuntimeError> {
57 self.execute(LoadQuery::new().one::<E>(value))
58 }
59
60 pub fn only(&self) -> Result<Response<E>, RuntimeError> {
62 self.execute(LoadQuery::new().one::<E>(()))
63 }
64
65 pub fn many<I, V>(&self, values: I) -> Result<Response<E>, RuntimeError>
67 where
68 I: IntoIterator<Item = V>,
69 V: FieldValue,
70 {
71 let query = LoadQuery::new().many_by_field(E::PRIMARY_KEY, values);
72 self.execute(query)
73 }
74
75 pub fn all(&self) -> Result<Response<E>, RuntimeError> {
77 self.execute(LoadQuery::new())
78 }
79
80 pub fn filter<F, I>(&self, f: F) -> Result<Response<E>, RuntimeError>
82 where
83 F: FnOnce(FilterDsl) -> I,
84 I: IntoFilterExpr,
85 {
86 self.execute(LoadQuery::new().filter(f))
87 }
88
89 pub fn require_one(&self, query: LoadQuery) -> Result<(), RuntimeError> {
95 self.execute(query)?.require_one()
96 }
97
98 pub fn require_one_pk(&self, value: impl FieldValue) -> Result<(), RuntimeError> {
100 self.require_one(LoadQuery::new().one::<E>(value))
101 }
102
103 pub fn require_one_filter<F, I>(&self, f: F) -> Result<(), RuntimeError>
105 where
106 F: FnOnce(FilterDsl) -> I,
107 I: IntoFilterExpr,
108 {
109 self.require_one(LoadQuery::new().filter(f))
110 }
111
112 pub fn exists(&self, query: LoadQuery) -> Result<bool, RuntimeError> {
123 QueryValidate::<E>::validate(&query)?;
124 metrics::record_exists_call_for::<E>();
125
126 let plan = plan_for::<E>(query.filter.as_ref());
127 let filter = query.filter.map(FilterExpr::simplify);
128 let offset = query.limit.as_ref().map_or(0, |lim| lim.offset);
129 let limit = query.limit.as_ref().and_then(|lim| lim.limit);
130 if limit == Some(0) {
131 return Ok(false);
132 }
133 let mut seen = 0u32;
134 let mut scanned = 0u64;
135 let mut found = false;
136
137 scan_plan::<E, _>(&self.db, plan, |_, entity| {
138 scanned = scanned.saturating_add(1);
139 let matches = filter
140 .as_ref()
141 .is_none_or(|f| FilterEvaluator::new(&entity).eval(f));
142
143 if matches {
144 if seen < offset {
145 seen += 1;
146 ControlFlow::Continue(())
147 } else {
148 found = true;
149 ControlFlow::Break(())
150 }
151 } else {
152 ControlFlow::Continue(())
153 }
154 })?;
155
156 metrics::record_rows_scanned_for::<E>(scanned);
157
158 Ok(found)
159 }
160
161 pub fn exists_one(&self, value: impl FieldValue) -> Result<bool, RuntimeError> {
163 self.exists(LoadQuery::new().one::<E>(value))
164 }
165
166 pub fn exists_filter<F, I>(&self, f: F) -> Result<bool, RuntimeError>
168 where
169 F: FnOnce(FilterDsl) -> I,
170 I: IntoFilterExpr,
171 {
172 self.exists(LoadQuery::new().filter(f))
173 }
174
175 pub fn exists_any(&self) -> Result<bool, RuntimeError> {
177 self.exists(LoadQuery::new())
178 }
179
180 pub fn ensure_exists_one(&self, value: impl FieldValue) -> Result<(), RuntimeError> {
186 if self.exists_one(value)? {
187 Ok(())
188 } else {
189 Err(ResponseError::NotFound { entity: E::PATH }.into())
190 }
191 }
192
193 #[allow(clippy::cast_possible_truncation)]
195 pub fn ensure_exists_many<I, V>(&self, values: I) -> Result<(), RuntimeError>
196 where
197 I: IntoIterator<Item = V>,
198 V: FieldValue,
199 {
200 let pks: Vec<_> = values.into_iter().collect();
201
202 let expected = pks.len() as u32;
203 if expected == 0 {
204 return Ok(());
205 }
206
207 let res = self.many(pks)?;
208 res.require_len(expected)?;
209
210 Ok(())
211 }
212
213 pub fn ensure_exists_filter<F, I>(&self, f: F) -> Result<(), RuntimeError>
215 where
216 F: FnOnce(FilterDsl) -> I,
217 I: IntoFilterExpr,
218 {
219 if self.exists_filter(f)? {
220 Ok(())
221 } else {
222 Err(ResponseError::NotFound { entity: E::PATH }.into())
223 }
224 }
225
226 pub fn explain(self, query: LoadQuery) -> Result<QueryPlan, RuntimeError> {
232 QueryValidate::<E>::validate(&query)?;
233
234 Ok(plan_for::<E>(query.filter.as_ref()))
235 }
236
237 fn execute_raw(&self, query: &LoadQuery) -> Result<Vec<DataRow>, RuntimeError> {
238 let ctx = self.db.context::<E>();
239 let plan = plan_for::<E>(query.filter.as_ref());
240
241 if let Some(lim) = &query.limit {
242 Ok(ctx.rows_from_plan_with_pagination(plan, lim.offset, lim.limit)?)
243 } else {
244 Ok(ctx.rows_from_plan(plan)?)
245 }
246 }
247
248 pub fn execute(&self, query: LoadQuery) -> Result<Response<E>, RuntimeError> {
254 let mut span = metrics::Span::<E>::new(metrics::ExecKind::Load);
255 QueryValidate::<E>::validate(&query)?;
256
257 self.debug_log(format!("🧭 Executing query: {:?} on {}", query, E::PATH));
258
259 let ctx = self.db.context::<E>();
260 let plan = plan_for::<E>(query.filter.as_ref());
261
262 self.debug_log(format!("📄 Query plan: {plan:?}"));
263
264 let pre_paginated = query.filter.is_none() && query.sort.is_none() && query.limit.is_some();
266 let mut rows: Vec<(Key, E)> = if pre_paginated {
267 let data_rows = self.execute_raw(&query)?;
268 metrics::record_rows_scanned_for::<E>(data_rows.len() as u64);
269
270 self.debug_log(format!(
271 "📦 Scanned {} data rows before deserialization",
272 data_rows.len()
273 ));
274
275 let rows = ctx.deserialize_rows(data_rows)?;
276 self.debug_log(format!(
277 "🧩 Deserialized {} entities before filtering",
278 rows.len()
279 ));
280 rows
281 } else {
282 let data_rows = ctx.rows_from_plan(plan)?;
283 metrics::record_rows_scanned_for::<E>(data_rows.len() as u64);
284 self.debug_log(format!(
285 "📦 Scanned {} data rows before deserialization",
286 data_rows.len()
287 ));
288
289 let rows = ctx.deserialize_rows(data_rows)?;
290 self.debug_log(format!(
291 "🧩 Deserialized {} entities before filtering",
292 rows.len()
293 ));
294
295 rows
296 };
297
298 if let Some(f) = &query.filter {
300 let simplified = f.clone().simplify();
301 Self::apply_filter(&mut rows, &simplified);
302
303 self.debug_log(format!(
304 "🔎 Applied filter -> {} entities remaining",
305 rows.len()
306 ));
307 }
308
309 if let Some(sort) = &query.sort
311 && rows.len() > 1
312 {
313 Self::apply_sort(&mut rows, sort);
314 self.debug_log("↕️ Applied sort expression");
315 }
316
317 if let Some(lim) = &query.limit
319 && !pre_paginated
320 {
321 apply_pagination(&mut rows, lim.offset, lim.limit);
322 self.debug_log(format!(
323 "📏 Applied pagination (offset={}, limit={:?}) -> {} entities",
324 lim.offset,
325 lim.limit,
326 rows.len()
327 ));
328 }
329
330 set_rows_from_len(&mut span, rows.len());
331 self.debug_log(format!("✅ Query complete -> {} final rows", rows.len()));
332
333 Ok(Response(rows))
334 }
335
336 pub fn count(&self, query: LoadQuery) -> Result<u32, RuntimeError> {
338 Ok(self.execute(query)?.count())
339 }
340
341 pub fn count_all(&self) -> Result<u32, RuntimeError> {
342 self.count(LoadQuery::new())
343 }
344
345 pub fn group_count_by<K, F>(
354 &self,
355 query: LoadQuery,
356 key_fn: F,
357 ) -> Result<HashMap<K, u32>, RuntimeError>
358 where
359 K: Eq + Hash,
360 F: Fn(&E) -> K,
361 {
362 let entities = self.execute(query)?.entities();
363
364 let mut counts = HashMap::new();
365 for e in entities {
366 *counts.entry(key_fn(&e)).or_insert(0) += 1;
367 }
368
369 Ok(counts)
370 }
371
372 fn apply_filter(rows: &mut Vec<(Key, E)>, filter: &FilterExpr) {
378 rows.retain(|(_, e)| FilterEvaluator::new(e).eval(filter));
379 }
380
381 fn apply_sort(rows: &mut [(Key, E)], sort_expr: &SortExpr) {
383 rows.sort_by(|(_, ea), (_, eb)| {
384 for (field, direction) in sort_expr.iter() {
385 let va = ea.get_value(field);
386 let vb = eb.get_value(field);
387
388 let ordering = match (va, vb) {
390 (None, None) => continue, (None, Some(_)) => Ordering::Less, (Some(_), None) => Ordering::Greater, (Some(va), Some(vb)) => match va.partial_cmp(&vb) {
394 Some(ord) => ord,
395 None => continue, },
397 };
398
399 let ordering = match direction {
401 Order::Asc => ordering,
402 Order::Desc => ordering.reverse(),
403 };
404
405 if ordering != Ordering::Equal {
406 return ordering;
407 }
408 }
409
410 Ordering::Equal
412 });
413 }
414}
415
/// Keeps only the window `[offset, offset + limit)` of `rows`, clamped to its length.
///
/// `limit == None` keeps everything from `offset` onwards. An offset past the end, or a limit
/// of zero, leaves `rows` empty.
fn apply_pagination<T>(rows: &mut Vec<T>, offset: u32, limit: Option<u32>) {
    let total = rows.len();
    // `u32 -> usize` is lossless on the 32/64-bit targets this code runs on.
    let start = (offset as usize).min(total);
    // Saturate: on 32-bit targets (e.g. wasm32) `start + limit` can exceed `usize::MAX`
    // when the limit is near `u32::MAX`.
    let end = limit.map_or(total, |l| start.saturating_add(l as usize).min(total));

    // `start <= end <= total` holds here. Cut the tail first so `drain` only shifts the
    // rows being kept.
    rows.truncate(end);
    rows.drain(..start);
}