1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, EntitySchemaDescription, FluentDeleteQuery, FluentLoadQuery,
11 MissingRowPolicy, PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError,
12 Query, QueryError, QueryTracePlan, StoreRegistry, TraceExecutionStrategy,
13 WriteBatchResponse,
14 access::AccessStrategy,
15 commit::EntityRuntimeHooks,
16 cursor::decode_optional_cursor_token,
17 describe::describe_entity_model,
18 executor::{
19 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
20 SaveExecutor,
21 },
22 query::{
23 builder::aggregate::AggregateExpr, explain::ExplainAggregateTerminalPlan,
24 plan::QueryMode,
25 },
26 },
27 error::InternalError,
28 obs::sink::{MetricsSink, with_metrics_sink},
29 traits::{CanisterKind, EntityKind, EntityValue},
30};
31use std::thread::LocalKey;
32
33fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
35 match err {
36 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
37 }
38}
39
40fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
43 decode_optional_cursor_token(cursor_token).map_err(|err| QueryError::from(PlanError::from(err)))
44}
45
46pub struct DbSession<C: CanisterKind> {
53 db: Db<C>,
54 debug: bool,
55 metrics: Option<&'static dyn MetricsSink>,
56}
57
58impl<C: CanisterKind> DbSession<C> {
59 #[must_use]
61 pub(crate) const fn new(db: Db<C>) -> Self {
62 Self {
63 db,
64 debug: false,
65 metrics: None,
66 }
67 }
68
69 #[must_use]
71 pub const fn new_with_hooks(
72 store: &'static LocalKey<StoreRegistry>,
73 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
74 ) -> Self {
75 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
76 }
77
78 #[must_use]
80 pub const fn debug(mut self) -> Self {
81 self.debug = true;
82 self
83 }
84
85 #[must_use]
87 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
88 self.metrics = Some(sink);
89 self
90 }
91
92 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
93 if let Some(sink) = self.metrics {
94 with_metrics_sink(sink, f)
95 } else {
96 f()
97 }
98 }
99
100 fn execute_save_with<E, T, R>(
102 &self,
103 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
104 map: impl FnOnce(T) -> R,
105 ) -> Result<R, InternalError>
106 where
107 E: EntityKind<Canister = C> + EntityValue,
108 {
109 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
110
111 Ok(map(value))
112 }
113
114 fn execute_save_entity<E>(
116 &self,
117 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
118 ) -> Result<E, InternalError>
119 where
120 E: EntityKind<Canister = C> + EntityValue,
121 {
122 self.execute_save_with(op, std::convert::identity)
123 }
124
125 fn execute_save_batch<E>(
126 &self,
127 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
128 ) -> Result<WriteBatchResponse<E>, InternalError>
129 where
130 E: EntityKind<Canister = C> + EntityValue,
131 {
132 self.execute_save_with(op, WriteBatchResponse::new)
133 }
134
135 fn execute_save_view<E>(
136 &self,
137 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
138 ) -> Result<E::ViewType, InternalError>
139 where
140 E: EntityKind<Canister = C> + EntityValue,
141 {
142 self.execute_save_with(op, std::convert::identity)
143 }
144
145 #[must_use]
151 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
152 where
153 E: EntityKind<Canister = C>,
154 {
155 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
156 }
157
158 #[must_use]
160 pub const fn load_with_consistency<E>(
161 &self,
162 consistency: MissingRowPolicy,
163 ) -> FluentLoadQuery<'_, E>
164 where
165 E: EntityKind<Canister = C>,
166 {
167 FluentLoadQuery::new(self, Query::new(consistency))
168 }
169
170 #[must_use]
172 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
173 where
174 E: EntityKind<Canister = C>,
175 {
176 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
177 }
178
179 #[must_use]
181 pub fn delete_with_consistency<E>(
182 &self,
183 consistency: MissingRowPolicy,
184 ) -> FluentDeleteQuery<'_, E>
185 where
186 E: EntityKind<Canister = C>,
187 {
188 FluentDeleteQuery::new(self, Query::new(consistency).delete())
189 }
190
191 #[must_use]
198 pub fn show_indexes<E>(&self) -> Vec<String>
199 where
200 E: EntityKind<Canister = C>,
201 {
202 let mut indexes = Vec::with_capacity(E::MODEL.indexes.len().saturating_add(1));
203 indexes.push(format!("PRIMARY KEY ({})", E::MODEL.primary_key.name));
204
205 for index in E::MODEL.indexes {
206 let kind = if index.unique {
207 "UNIQUE INDEX"
208 } else {
209 "INDEX"
210 };
211 let fields = index.fields.join(", ");
212 indexes.push(format!("{kind} {} ({fields})", index.name));
213 }
214
215 indexes
216 }
217
218 #[must_use]
223 pub fn describe_entity<E>(&self) -> EntitySchemaDescription
224 where
225 E: EntityKind<Canister = C>,
226 {
227 describe_entity_model(E::MODEL)
228 }
229
230 #[must_use]
235 pub(crate) const fn load_executor<E>(&self) -> LoadExecutor<E>
236 where
237 E: EntityKind<Canister = C> + EntityValue,
238 {
239 LoadExecutor::new(self.db, self.debug)
240 }
241
242 #[must_use]
243 pub(crate) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
244 where
245 E: EntityKind<Canister = C> + EntityValue,
246 {
247 DeleteExecutor::new(self.db, self.debug)
248 }
249
250 #[must_use]
251 pub(crate) const fn save_executor<E>(&self) -> SaveExecutor<E>
252 where
253 E: EntityKind<Canister = C> + EntityValue,
254 {
255 SaveExecutor::new(self.db, self.debug)
256 }
257
258 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
264 where
265 E: EntityKind<Canister = C> + EntityValue,
266 {
267 let plan = query.plan()?.into_executable();
268
269 let result = match query.mode() {
270 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
271 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
272 };
273
274 result.map_err(QueryError::execute)
275 }
276
277 pub(crate) fn execute_load_query_with<E, T>(
280 &self,
281 query: &Query<E>,
282 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
283 ) -> Result<T, QueryError>
284 where
285 E: EntityKind<Canister = C> + EntityValue,
286 {
287 let plan = query.plan()?.into_executable();
288
289 self.with_metrics(|| op(self.load_executor::<E>(), plan))
290 .map_err(QueryError::execute)
291 }
292
293 pub fn trace_query<E>(&self, query: &Query<E>) -> Result<QueryTracePlan, QueryError>
298 where
299 E: EntityKind<Canister = C>,
300 {
301 let compiled = query.plan()?;
302 let explain = compiled.explain();
303 let plan_hash = explain.fingerprint().to_string();
304
305 let executable = compiled.into_executable();
306 let access_strategy = AccessStrategy::from_plan(executable.access()).debug_summary();
307 let execution_strategy = match query.mode() {
308 QueryMode::Load(_) => Some(trace_execution_strategy(
309 executable
310 .execution_strategy()
311 .map_err(QueryError::execute)?,
312 )),
313 QueryMode::Delete(_) => None,
314 };
315
316 Ok(QueryTracePlan::new(
317 plan_hash,
318 access_strategy,
319 execution_strategy,
320 explain,
321 ))
322 }
323
324 pub(crate) fn explain_load_query_terminal_with<E>(
326 query: &Query<E>,
327 aggregate: AggregateExpr,
328 ) -> Result<ExplainAggregateTerminalPlan, QueryError>
329 where
330 E: EntityKind<Canister = C> + EntityValue,
331 {
332 let compiled = query.plan()?;
334 let query_explain = compiled.explain();
335 let terminal = aggregate.kind();
336
337 let executable = compiled.into_executable();
339 let execution = executable.explain_aggregate_terminal_execution_descriptor(aggregate);
340
341 Ok(ExplainAggregateTerminalPlan::new(
342 query_explain,
343 terminal,
344 execution,
345 ))
346 }
347
348 pub(crate) fn execute_load_query_paged_with_trace<E>(
350 &self,
351 query: &Query<E>,
352 cursor_token: Option<&str>,
353 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
354 where
355 E: EntityKind<Canister = C> + EntityValue,
356 {
357 let plan = query.plan()?.into_executable();
359 match plan.execution_strategy().map_err(QueryError::execute)? {
360 ExecutionStrategy::PrimaryKey => {
361 return Err(QueryError::execute(invariant(
362 "cursor pagination requires explicit or grouped ordering",
363 )));
364 }
365 ExecutionStrategy::Ordered => {}
366 ExecutionStrategy::Grouped => {
367 return Err(QueryError::execute(invariant(
368 "grouped plans require execute_grouped(...)",
369 )));
370 }
371 }
372
373 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
375 let cursor = plan
376 .prepare_cursor(cursor_bytes.as_deref())
377 .map_err(map_executor_plan_error)?;
378
379 let (page, trace) = self
381 .with_metrics(|| {
382 self.load_executor::<E>()
383 .execute_paged_with_cursor_traced(plan, cursor)
384 })
385 .map_err(QueryError::execute)?;
386 let next_cursor = page
387 .next_cursor
388 .map(|token| {
389 let Some(token) = token.as_scalar() else {
390 return Err(QueryError::execute(invariant(
391 "scalar load pagination emitted grouped continuation token",
392 )));
393 };
394
395 token.encode().map_err(|err| {
396 QueryError::execute(InternalError::serialize_internal(format!(
397 "failed to serialize continuation cursor: {err}"
398 )))
399 })
400 })
401 .transpose()?;
402
403 Ok(PagedLoadExecutionWithTrace::new(
404 page.items,
405 next_cursor,
406 trace,
407 ))
408 }
409
410 pub fn execute_grouped<E>(
415 &self,
416 query: &Query<E>,
417 cursor_token: Option<&str>,
418 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
419 where
420 E: EntityKind<Canister = C> + EntityValue,
421 {
422 let plan = query.plan()?.into_executable();
424 if !matches!(
425 plan.execution_strategy().map_err(QueryError::execute)?,
426 ExecutionStrategy::Grouped
427 ) {
428 return Err(QueryError::execute(invariant(
429 "execute_grouped requires grouped logical plans",
430 )));
431 }
432
433 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
435 let cursor = plan
436 .prepare_grouped_cursor(cursor_bytes.as_deref())
437 .map_err(map_executor_plan_error)?;
438
439 let (page, trace) = self
441 .with_metrics(|| {
442 self.load_executor::<E>()
443 .execute_grouped_paged_with_cursor_traced(plan, cursor)
444 })
445 .map_err(QueryError::execute)?;
446 let next_cursor = page
447 .next_cursor
448 .map(|token| {
449 let Some(token) = token.as_grouped() else {
450 return Err(QueryError::execute(invariant(
451 "grouped pagination emitted scalar continuation token",
452 )));
453 };
454
455 token.encode().map_err(|err| {
456 QueryError::execute(InternalError::serialize_internal(format!(
457 "failed to serialize grouped continuation cursor: {err}"
458 )))
459 })
460 })
461 .transpose()?;
462
463 Ok(PagedGroupedExecutionWithTrace::new(
464 page.rows,
465 next_cursor,
466 trace,
467 ))
468 }
469
470 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
476 where
477 E: EntityKind<Canister = C> + EntityValue,
478 {
479 self.execute_save_entity(|save| save.insert(entity))
480 }
481
482 pub fn insert_many_atomic<E>(
488 &self,
489 entities: impl IntoIterator<Item = E>,
490 ) -> Result<WriteBatchResponse<E>, InternalError>
491 where
492 E: EntityKind<Canister = C> + EntityValue,
493 {
494 self.execute_save_batch(|save| save.insert_many_atomic(entities))
495 }
496
497 pub fn insert_many_non_atomic<E>(
501 &self,
502 entities: impl IntoIterator<Item = E>,
503 ) -> Result<WriteBatchResponse<E>, InternalError>
504 where
505 E: EntityKind<Canister = C> + EntityValue,
506 {
507 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
508 }
509
510 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
512 where
513 E: EntityKind<Canister = C> + EntityValue,
514 {
515 self.execute_save_entity(|save| save.replace(entity))
516 }
517
518 pub fn replace_many_atomic<E>(
524 &self,
525 entities: impl IntoIterator<Item = E>,
526 ) -> Result<WriteBatchResponse<E>, InternalError>
527 where
528 E: EntityKind<Canister = C> + EntityValue,
529 {
530 self.execute_save_batch(|save| save.replace_many_atomic(entities))
531 }
532
533 pub fn replace_many_non_atomic<E>(
537 &self,
538 entities: impl IntoIterator<Item = E>,
539 ) -> Result<WriteBatchResponse<E>, InternalError>
540 where
541 E: EntityKind<Canister = C> + EntityValue,
542 {
543 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
544 }
545
546 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
548 where
549 E: EntityKind<Canister = C> + EntityValue,
550 {
551 self.execute_save_entity(|save| save.update(entity))
552 }
553
554 pub fn update_many_atomic<E>(
560 &self,
561 entities: impl IntoIterator<Item = E>,
562 ) -> Result<WriteBatchResponse<E>, InternalError>
563 where
564 E: EntityKind<Canister = C> + EntityValue,
565 {
566 self.execute_save_batch(|save| save.update_many_atomic(entities))
567 }
568
569 pub fn update_many_non_atomic<E>(
573 &self,
574 entities: impl IntoIterator<Item = E>,
575 ) -> Result<WriteBatchResponse<E>, InternalError>
576 where
577 E: EntityKind<Canister = C> + EntityValue,
578 {
579 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
580 }
581
582 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
584 where
585 E: EntityKind<Canister = C> + EntityValue,
586 {
587 self.execute_save_view::<E>(|save| save.insert_view(view))
588 }
589
590 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
592 where
593 E: EntityKind<Canister = C> + EntityValue,
594 {
595 self.execute_save_view::<E>(|save| save.replace_view(view))
596 }
597
598 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
600 where
601 E: EntityKind<Canister = C> + EntityValue,
602 {
603 self.execute_save_view::<E>(|save| save.update_view(view))
604 }
605
606 #[cfg(test)]
608 #[doc(hidden)]
609 pub fn clear_stores_for_tests(&self) {
610 self.db.with_store_registry(|reg| {
611 for (_, store) in reg.iter() {
614 store.with_data_mut(DataStore::clear);
615 store.with_index_mut(IndexStore::clear);
616 }
617 });
618 }
619}
620
621fn invariant(message: impl Into<String>) -> InternalError {
622 InternalError::query_executor_invariant(message)
623}
624
625const fn trace_execution_strategy(strategy: ExecutionStrategy) -> TraceExecutionStrategy {
626 match strategy {
627 ExecutionStrategy::PrimaryKey => TraceExecutionStrategy::PrimaryKey,
628 ExecutionStrategy::Ordered => TraceExecutionStrategy::Ordered,
629 ExecutionStrategy::Grouped => TraceExecutionStrategy::Grouped,
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::cursor::CursorPlanError;

    /// Asserts that `err` is `QueryError::Plan(PlanError::Cursor(inner))` and
    /// that `predicate(inner)` holds.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(matches!(
            err,
            QueryError::Plan(plan_err)
                if matches!(
                    plan_err.as_ref(),
                    PlanError::Cursor(inner) if predicate(inner.as_ref())
                )
        ));
    }

    /// Checks that a cursor error maps to the same `QueryError` shape whether
    /// it goes through `map_executor_plan_error` or straight through `PlanError`.
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                        expected: 2,
                        found: 1
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorWindowMismatch {
                        expected_offset: 8,
                        actual_offset: 3
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::invalid_continuation_cursor(
                    crate::db::codec::cursor::CursorDecodeError::OddLength,
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::InvalidContinuationCursor {
                        reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                    "id",
                    "ulid",
                    Some(crate::value::Value::Text("not-a-ulid".to_string())),
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                        field,
                        expected,
                        value: Some(crate::value::Value::Text(value))
                    } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                )
            },
        );
    }

    /// Table-driven parity check over every cursor-error variant exercised in
    /// this module. Each case pairs a constructor with a predicate that pins
    /// the exact variant and payload expected after mapping.
    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        type ParityCase = (fn() -> CursorPlanError, fn(&CursorPlanError) -> bool);

        let cases: [ParityCase; 4] = [
            (
                || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                            expected: 2,
                            found: 1
                        }
                    )
                },
            ),
            (
                || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorWindowMismatch {
                            expected_offset: 8,
                            actual_offset: 3
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::invalid_continuation_cursor(
                        crate::db::codec::cursor::CursorDecodeError::OddLength,
                    )
                },
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::InvalidContinuationCursor {
                            reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                        "id",
                        "ulid",
                        Some(crate::value::Value::Text("not-a-ulid".to_string())),
                    )
                },
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                            field,
                            expected,
                            value: Some(crate::value::Value::Text(value))
                        } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                    )
                },
            ),
        ];

        for (build, predicate) in cases {
            assert_cursor_mapping_parity(build, predicate);
        }
    }
}