1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, EntitySchemaDescription, FluentDeleteQuery, FluentLoadQuery,
11 MissingRowPolicy, PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError,
12 Query, QueryError, QueryTracePlan, StorageReport, StoreRegistry, TraceExecutionStrategy,
13 WriteBatchResponse,
14 access::AccessStrategy,
15 commit::EntityRuntimeHooks,
16 cursor::decode_optional_cursor_token,
17 describe::describe_entity_model,
18 executor::{
19 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
20 SaveExecutor,
21 },
22 query::{
23 builder::aggregate::AggregateExpr, explain::ExplainAggregateTerminalPlan,
24 plan::QueryMode,
25 },
26 },
27 error::InternalError,
28 obs::sink::{MetricsSink, with_metrics_sink},
29 traits::{CanisterKind, EntityKind, EntityValue},
30 value::Value,
31};
32use std::thread::LocalKey;
33
34fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
36 match err {
37 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
38 }
39}
40
41fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
44 decode_optional_cursor_token(cursor_token).map_err(|err| QueryError::from(PlanError::from(err)))
45}
46
47pub struct DbSession<C: CanisterKind> {
54 db: Db<C>,
55 debug: bool,
56 metrics: Option<&'static dyn MetricsSink>,
57}
58
59impl<C: CanisterKind> DbSession<C> {
60 #[must_use]
62 pub(crate) const fn new(db: Db<C>) -> Self {
63 Self {
64 db,
65 debug: false,
66 metrics: None,
67 }
68 }
69
70 #[must_use]
72 pub const fn new_with_hooks(
73 store: &'static LocalKey<StoreRegistry>,
74 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
75 ) -> Self {
76 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
77 }
78
79 #[must_use]
81 pub const fn debug(mut self) -> Self {
82 self.debug = true;
83 self
84 }
85
86 #[must_use]
88 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
89 self.metrics = Some(sink);
90 self
91 }
92
93 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
94 if let Some(sink) = self.metrics {
95 with_metrics_sink(sink, f)
96 } else {
97 f()
98 }
99 }
100
101 fn execute_save_with<E, T, R>(
103 &self,
104 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
105 map: impl FnOnce(T) -> R,
106 ) -> Result<R, InternalError>
107 where
108 E: EntityKind<Canister = C> + EntityValue,
109 {
110 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
111
112 Ok(map(value))
113 }
114
115 fn execute_save_entity<E>(
117 &self,
118 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
119 ) -> Result<E, InternalError>
120 where
121 E: EntityKind<Canister = C> + EntityValue,
122 {
123 self.execute_save_with(op, std::convert::identity)
124 }
125
126 fn execute_save_batch<E>(
127 &self,
128 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
129 ) -> Result<WriteBatchResponse<E>, InternalError>
130 where
131 E: EntityKind<Canister = C> + EntityValue,
132 {
133 self.execute_save_with(op, WriteBatchResponse::new)
134 }
135
136 fn execute_save_view<E>(
137 &self,
138 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
139 ) -> Result<E::ViewType, InternalError>
140 where
141 E: EntityKind<Canister = C> + EntityValue,
142 {
143 self.execute_save_with(op, std::convert::identity)
144 }
145
146 #[must_use]
152 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
153 where
154 E: EntityKind<Canister = C>,
155 {
156 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
157 }
158
159 #[must_use]
161 pub const fn load_with_consistency<E>(
162 &self,
163 consistency: MissingRowPolicy,
164 ) -> FluentLoadQuery<'_, E>
165 where
166 E: EntityKind<Canister = C>,
167 {
168 FluentLoadQuery::new(self, Query::new(consistency))
169 }
170
171 #[must_use]
173 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
174 where
175 E: EntityKind<Canister = C>,
176 {
177 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
178 }
179
180 #[must_use]
182 pub fn delete_with_consistency<E>(
183 &self,
184 consistency: MissingRowPolicy,
185 ) -> FluentDeleteQuery<'_, E>
186 where
187 E: EntityKind<Canister = C>,
188 {
189 FluentDeleteQuery::new(self, Query::new(consistency).delete())
190 }
191
192 #[must_use]
196 pub const fn select_one(&self) -> Value {
197 Value::Int(1)
198 }
199
200 #[must_use]
207 pub fn show_indexes<E>(&self) -> Vec<String>
208 where
209 E: EntityKind<Canister = C>,
210 {
211 let mut indexes = Vec::with_capacity(E::MODEL.indexes.len().saturating_add(1));
212 indexes.push(format!("PRIMARY KEY ({})", E::MODEL.primary_key.name));
213
214 for index in E::MODEL.indexes {
215 let kind = if index.is_unique() {
216 "UNIQUE INDEX"
217 } else {
218 "INDEX"
219 };
220 let fields = index.fields().join(", ");
221 indexes.push(format!("{kind} {} ({fields})", index.name()));
222 }
223
224 indexes
225 }
226
227 #[must_use]
232 pub fn describe_entity<E>(&self) -> EntitySchemaDescription
233 where
234 E: EntityKind<Canister = C>,
235 {
236 describe_entity_model(E::MODEL)
237 }
238
239 pub fn storage_report(
241 &self,
242 name_to_path: &[(&'static str, &'static str)],
243 ) -> Result<StorageReport, InternalError> {
244 self.db.storage_report(name_to_path)
245 }
246
247 #[must_use]
252 pub(in crate::db) const fn load_executor<E>(&self) -> LoadExecutor<E>
253 where
254 E: EntityKind<Canister = C> + EntityValue,
255 {
256 LoadExecutor::new(self.db, self.debug)
257 }
258
259 #[must_use]
260 pub(in crate::db) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
261 where
262 E: EntityKind<Canister = C> + EntityValue,
263 {
264 DeleteExecutor::new(self.db, self.debug)
265 }
266
267 #[must_use]
268 pub(in crate::db) const fn save_executor<E>(&self) -> SaveExecutor<E>
269 where
270 E: EntityKind<Canister = C> + EntityValue,
271 {
272 SaveExecutor::new(self.db, self.debug)
273 }
274
275 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
281 where
282 E: EntityKind<Canister = C> + EntityValue,
283 {
284 let plan = query.plan()?.into_executable();
285
286 let result = match query.mode() {
287 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
288 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
289 };
290
291 result.map_err(QueryError::execute)
292 }
293
294 pub(in crate::db) fn execute_load_query_with<E, T>(
297 &self,
298 query: &Query<E>,
299 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
300 ) -> Result<T, QueryError>
301 where
302 E: EntityKind<Canister = C> + EntityValue,
303 {
304 let plan = query.plan()?.into_executable();
305
306 self.with_metrics(|| op(self.load_executor::<E>(), plan))
307 .map_err(QueryError::execute)
308 }
309
310 pub fn trace_query<E>(&self, query: &Query<E>) -> Result<QueryTracePlan, QueryError>
315 where
316 E: EntityKind<Canister = C>,
317 {
318 let compiled = query.plan()?;
319 let explain = compiled.explain();
320 let plan_hash = compiled.plan_hash_hex();
321
322 let executable = compiled.into_executable();
323 let access_strategy = AccessStrategy::from_plan(executable.access()).debug_summary();
324 let execution_strategy = match query.mode() {
325 QueryMode::Load(_) => Some(trace_execution_strategy(
326 executable
327 .execution_strategy()
328 .map_err(QueryError::execute)?,
329 )),
330 QueryMode::Delete(_) => None,
331 };
332
333 Ok(QueryTracePlan::new(
334 plan_hash,
335 access_strategy,
336 execution_strategy,
337 explain,
338 ))
339 }
340
341 pub(crate) fn explain_load_query_terminal_with<E>(
343 query: &Query<E>,
344 aggregate: AggregateExpr,
345 ) -> Result<ExplainAggregateTerminalPlan, QueryError>
346 where
347 E: EntityKind<Canister = C> + EntityValue,
348 {
349 let compiled = query.plan()?;
351 let query_explain = compiled.explain();
352 let terminal = aggregate.kind();
353
354 let executable = compiled.into_executable();
356 let execution = executable.explain_aggregate_terminal_execution_descriptor(aggregate);
357
358 Ok(ExplainAggregateTerminalPlan::new(
359 query_explain,
360 terminal,
361 execution,
362 ))
363 }
364
365 pub(crate) fn execute_load_query_paged_with_trace<E>(
367 &self,
368 query: &Query<E>,
369 cursor_token: Option<&str>,
370 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
371 where
372 E: EntityKind<Canister = C> + EntityValue,
373 {
374 let plan = query.plan()?.into_executable();
376 match plan.execution_strategy().map_err(QueryError::execute)? {
377 ExecutionStrategy::PrimaryKey => {
378 return Err(QueryError::execute(crate::db::error::executor_invariant(
379 "cursor pagination requires explicit or grouped ordering",
380 )));
381 }
382 ExecutionStrategy::Ordered => {}
383 ExecutionStrategy::Grouped => {
384 return Err(QueryError::execute(crate::db::error::executor_invariant(
385 "grouped plans require execute_grouped(...)",
386 )));
387 }
388 }
389
390 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
392 let cursor = plan
393 .prepare_cursor(cursor_bytes.as_deref())
394 .map_err(map_executor_plan_error)?;
395
396 let (page, trace) = self
398 .with_metrics(|| {
399 self.load_executor::<E>()
400 .execute_paged_with_cursor_traced(plan, cursor)
401 })
402 .map_err(QueryError::execute)?;
403 let next_cursor = page
404 .next_cursor
405 .map(|token| {
406 let Some(token) = token.as_scalar() else {
407 return Err(QueryError::execute(crate::db::error::executor_invariant(
408 "scalar load pagination emitted grouped continuation token",
409 )));
410 };
411
412 token.encode().map_err(|err| {
413 QueryError::execute(InternalError::serialize_internal(format!(
414 "failed to serialize continuation cursor: {err}"
415 )))
416 })
417 })
418 .transpose()?;
419
420 Ok(PagedLoadExecutionWithTrace::new(
421 page.items,
422 next_cursor,
423 trace,
424 ))
425 }
426
427 pub fn execute_grouped<E>(
432 &self,
433 query: &Query<E>,
434 cursor_token: Option<&str>,
435 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
436 where
437 E: EntityKind<Canister = C> + EntityValue,
438 {
439 let plan = query.plan()?.into_executable();
441 if !matches!(
442 plan.execution_strategy().map_err(QueryError::execute)?,
443 ExecutionStrategy::Grouped
444 ) {
445 return Err(QueryError::execute(crate::db::error::executor_invariant(
446 "execute_grouped requires grouped logical plans",
447 )));
448 }
449
450 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
452 let cursor = plan
453 .prepare_grouped_cursor(cursor_bytes.as_deref())
454 .map_err(map_executor_plan_error)?;
455
456 let (page, trace) = self
458 .with_metrics(|| {
459 self.load_executor::<E>()
460 .execute_grouped_paged_with_cursor_traced(plan, cursor)
461 })
462 .map_err(QueryError::execute)?;
463 let next_cursor = page
464 .next_cursor
465 .map(|token| {
466 let Some(token) = token.as_grouped() else {
467 return Err(QueryError::execute(crate::db::error::executor_invariant(
468 "grouped pagination emitted scalar continuation token",
469 )));
470 };
471
472 token.encode().map_err(|err| {
473 QueryError::execute(InternalError::serialize_internal(format!(
474 "failed to serialize grouped continuation cursor: {err}"
475 )))
476 })
477 })
478 .transpose()?;
479
480 Ok(PagedGroupedExecutionWithTrace::new(
481 page.rows,
482 next_cursor,
483 trace,
484 ))
485 }
486
487 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
493 where
494 E: EntityKind<Canister = C> + EntityValue,
495 {
496 self.execute_save_entity(|save| save.insert(entity))
497 }
498
499 pub fn insert_many_atomic<E>(
505 &self,
506 entities: impl IntoIterator<Item = E>,
507 ) -> Result<WriteBatchResponse<E>, InternalError>
508 where
509 E: EntityKind<Canister = C> + EntityValue,
510 {
511 self.execute_save_batch(|save| save.insert_many_atomic(entities))
512 }
513
514 pub fn insert_many_non_atomic<E>(
518 &self,
519 entities: impl IntoIterator<Item = E>,
520 ) -> Result<WriteBatchResponse<E>, InternalError>
521 where
522 E: EntityKind<Canister = C> + EntityValue,
523 {
524 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
525 }
526
527 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
529 where
530 E: EntityKind<Canister = C> + EntityValue,
531 {
532 self.execute_save_entity(|save| save.replace(entity))
533 }
534
535 pub fn replace_many_atomic<E>(
541 &self,
542 entities: impl IntoIterator<Item = E>,
543 ) -> Result<WriteBatchResponse<E>, InternalError>
544 where
545 E: EntityKind<Canister = C> + EntityValue,
546 {
547 self.execute_save_batch(|save| save.replace_many_atomic(entities))
548 }
549
550 pub fn replace_many_non_atomic<E>(
554 &self,
555 entities: impl IntoIterator<Item = E>,
556 ) -> Result<WriteBatchResponse<E>, InternalError>
557 where
558 E: EntityKind<Canister = C> + EntityValue,
559 {
560 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
561 }
562
563 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
565 where
566 E: EntityKind<Canister = C> + EntityValue,
567 {
568 self.execute_save_entity(|save| save.update(entity))
569 }
570
571 pub fn update_many_atomic<E>(
577 &self,
578 entities: impl IntoIterator<Item = E>,
579 ) -> Result<WriteBatchResponse<E>, InternalError>
580 where
581 E: EntityKind<Canister = C> + EntityValue,
582 {
583 self.execute_save_batch(|save| save.update_many_atomic(entities))
584 }
585
586 pub fn update_many_non_atomic<E>(
590 &self,
591 entities: impl IntoIterator<Item = E>,
592 ) -> Result<WriteBatchResponse<E>, InternalError>
593 where
594 E: EntityKind<Canister = C> + EntityValue,
595 {
596 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
597 }
598
599 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
601 where
602 E: EntityKind<Canister = C> + EntityValue,
603 {
604 self.execute_save_view::<E>(|save| save.insert_view(view))
605 }
606
607 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
609 where
610 E: EntityKind<Canister = C> + EntityValue,
611 {
612 self.execute_save_view::<E>(|save| save.replace_view(view))
613 }
614
615 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
617 where
618 E: EntityKind<Canister = C> + EntityValue,
619 {
620 self.execute_save_view::<E>(|save| save.update_view(view))
621 }
622
623 #[cfg(test)]
625 #[doc(hidden)]
626 pub fn clear_stores_for_tests(&self) {
627 self.db.with_store_registry(|reg| {
628 for (_, store) in reg.iter() {
631 store.with_data_mut(DataStore::clear);
632 store.with_index_mut(IndexStore::clear);
633 }
634 });
635 }
636}
637
638const fn trace_execution_strategy(strategy: ExecutionStrategy) -> TraceExecutionStrategy {
639 match strategy {
640 ExecutionStrategy::PrimaryKey => TraceExecutionStrategy::PrimaryKey,
641 ExecutionStrategy::Ordered => TraceExecutionStrategy::Ordered,
642 ExecutionStrategy::Grouped => TraceExecutionStrategy::Grouped,
643 }
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::cursor::CursorPlanError;

    /// A cursor-error constructor paired with a predicate that recognizes exactly
    /// the variant (and payload) the constructor builds.
    ///
    /// Non-capturing closures coerce to these function pointers, which are `Copy`
    /// and so satisfy the predicate bound of `assert_cursor_mapping_parity`.
    type CursorMappingCase = (fn() -> CursorPlanError, fn(&CursorPlanError) -> bool);

    /// Asserts that `err` is `QueryError::Plan(PlanError::Cursor(..))` and that the
    /// wrapped cursor error satisfies `predicate`.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(matches!(
            err,
            QueryError::Plan(plan_err)
                if matches!(
                    plan_err.as_ref(),
                    PlanError::Cursor(inner) if predicate(inner.as_ref())
                )
        ));
    }

    /// Asserts that an error built by `build` maps to the same cursor plan error
    /// whether it goes through the executor path (`map_executor_plan_error`) or
    /// directly through `PlanError`.
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    fn boundary_arity_case() -> CursorMappingCase {
        (
            || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                        expected: 2,
                        found: 1
                    }
                )
            },
        )
    }

    fn window_mismatch_case() -> CursorMappingCase {
        (
            || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorWindowMismatch {
                        expected_offset: 8,
                        actual_offset: 3
                    }
                )
            },
        )
    }

    fn decode_reason_case() -> CursorMappingCase {
        (
            || {
                CursorPlanError::invalid_continuation_cursor(
                    crate::db::codec::cursor::CursorDecodeError::OddLength,
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::InvalidContinuationCursor {
                        reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                    }
                )
            },
        )
    }

    fn primary_key_type_mismatch_case() -> CursorMappingCase {
        (
            || {
                CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                    "id",
                    "ulid",
                    Some(crate::value::Value::Text("not-a-ulid".to_string())),
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                        field,
                        expected,
                        value: Some(crate::value::Value::Text(value))
                    } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                )
            },
        )
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        let (build, predicate) = boundary_arity_case();
        assert_cursor_mapping_parity(build, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        let (build, predicate) = window_mismatch_case();
        assert_cursor_mapping_parity(build, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        let (build, predicate) = decode_reason_case();
        assert_cursor_mapping_parity(build, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        let (build, predicate) = primary_key_type_mismatch_case();
        assert_cursor_mapping_parity(build, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        // Previously this test duplicated the boundary-arity case verbatim; it now
        // covers every known cursor variant in one table.
        let cases = [
            boundary_arity_case(),
            window_mismatch_case(),
            decode_reason_case(),
            primary_key_type_mismatch_case(),
        ];

        // Every variant survives both mapping paths unchanged.
        for (build, predicate) in cases {
            assert_cursor_mapping_parity(build, predicate);
        }

        // The predicates must discriminate: each one accepts only its own variant,
        // otherwise the parity checks above could pass vacuously.
        for (case_index, (build, _)) in cases.iter().enumerate() {
            let err = build();
            for (predicate_index, (_, predicate)) in cases.iter().enumerate() {
                assert_eq!(
                    predicate(&err),
                    case_index == predicate_index,
                    "case {case_index} evaluated against predicate {predicate_index}"
                );
            }
        }
    }
}