1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, EntitySchemaDescription, FluentDeleteQuery, FluentLoadQuery,
11 MissingRowPolicy, PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError,
12 Query, QueryError, QueryTracePlan, StorageReport, StoreRegistry, TraceExecutionStrategy,
13 WriteBatchResponse,
14 access::AccessStrategy,
15 commit::EntityRuntimeHooks,
16 cursor::decode_optional_cursor_token,
17 describe::describe_entity_model,
18 executor::{
19 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
20 SaveExecutor,
21 },
22 query::{
23 builder::aggregate::AggregateExpr, explain::ExplainAggregateTerminalPlan,
24 plan::QueryMode,
25 },
26 },
27 error::InternalError,
28 obs::sink::{MetricsSink, with_metrics_sink},
29 traits::{CanisterKind, EntityKind, EntityValue},
30 value::Value,
31};
32use std::thread::LocalKey;
33
34fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
36 match err {
37 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
38 }
39}
40
41fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
44 decode_optional_cursor_token(cursor_token).map_err(|err| QueryError::from(PlanError::from(err)))
45}
46
47pub struct DbSession<C: CanisterKind> {
54 db: Db<C>,
55 debug: bool,
56 metrics: Option<&'static dyn MetricsSink>,
57}
58
59impl<C: CanisterKind> DbSession<C> {
60 #[must_use]
62 pub(crate) const fn new(db: Db<C>) -> Self {
63 Self {
64 db,
65 debug: false,
66 metrics: None,
67 }
68 }
69
70 #[must_use]
72 pub const fn new_with_hooks(
73 store: &'static LocalKey<StoreRegistry>,
74 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
75 ) -> Self {
76 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
77 }
78
79 #[must_use]
81 pub const fn debug(mut self) -> Self {
82 self.debug = true;
83 self
84 }
85
86 #[must_use]
88 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
89 self.metrics = Some(sink);
90 self
91 }
92
93 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
94 if let Some(sink) = self.metrics {
95 with_metrics_sink(sink, f)
96 } else {
97 f()
98 }
99 }
100
101 fn execute_save_with<E, T, R>(
103 &self,
104 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
105 map: impl FnOnce(T) -> R,
106 ) -> Result<R, InternalError>
107 where
108 E: EntityKind<Canister = C> + EntityValue,
109 {
110 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
111
112 Ok(map(value))
113 }
114
115 fn execute_save_entity<E>(
117 &self,
118 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
119 ) -> Result<E, InternalError>
120 where
121 E: EntityKind<Canister = C> + EntityValue,
122 {
123 self.execute_save_with(op, std::convert::identity)
124 }
125
126 fn execute_save_batch<E>(
127 &self,
128 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
129 ) -> Result<WriteBatchResponse<E>, InternalError>
130 where
131 E: EntityKind<Canister = C> + EntityValue,
132 {
133 self.execute_save_with(op, WriteBatchResponse::new)
134 }
135
136 fn execute_save_view<E>(
137 &self,
138 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
139 ) -> Result<E::ViewType, InternalError>
140 where
141 E: EntityKind<Canister = C> + EntityValue,
142 {
143 self.execute_save_with(op, std::convert::identity)
144 }
145
146 #[must_use]
152 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
153 where
154 E: EntityKind<Canister = C>,
155 {
156 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
157 }
158
159 #[must_use]
161 pub const fn load_with_consistency<E>(
162 &self,
163 consistency: MissingRowPolicy,
164 ) -> FluentLoadQuery<'_, E>
165 where
166 E: EntityKind<Canister = C>,
167 {
168 FluentLoadQuery::new(self, Query::new(consistency))
169 }
170
171 #[must_use]
173 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
174 where
175 E: EntityKind<Canister = C>,
176 {
177 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
178 }
179
180 #[must_use]
182 pub fn delete_with_consistency<E>(
183 &self,
184 consistency: MissingRowPolicy,
185 ) -> FluentDeleteQuery<'_, E>
186 where
187 E: EntityKind<Canister = C>,
188 {
189 FluentDeleteQuery::new(self, Query::new(consistency).delete())
190 }
191
192 #[must_use]
196 pub const fn select_one(&self) -> Value {
197 Value::Int(1)
198 }
199
200 #[must_use]
207 pub fn show_indexes<E>(&self) -> Vec<String>
208 where
209 E: EntityKind<Canister = C>,
210 {
211 let mut indexes = Vec::with_capacity(E::MODEL.indexes.len().saturating_add(1));
212 indexes.push(format!("PRIMARY KEY ({})", E::MODEL.primary_key.name));
213
214 for index in E::MODEL.indexes {
215 let kind = if index.is_unique() {
216 "UNIQUE INDEX"
217 } else {
218 "INDEX"
219 };
220 let fields = index.fields().join(", ");
221 indexes.push(format!("{kind} {} ({fields})", index.name()));
222 }
223
224 indexes
225 }
226
227 #[must_use]
232 pub fn describe_entity<E>(&self) -> EntitySchemaDescription
233 where
234 E: EntityKind<Canister = C>,
235 {
236 describe_entity_model(E::MODEL)
237 }
238
239 pub fn storage_report(
241 &self,
242 name_to_path: &[(&'static str, &'static str)],
243 ) -> Result<StorageReport, InternalError> {
244 self.db.storage_report(name_to_path)
245 }
246
247 #[must_use]
252 pub(in crate::db) const fn load_executor<E>(&self) -> LoadExecutor<E>
253 where
254 E: EntityKind<Canister = C> + EntityValue,
255 {
256 LoadExecutor::new(self.db, self.debug)
257 }
258
259 #[must_use]
260 pub(in crate::db) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
261 where
262 E: EntityKind<Canister = C> + EntityValue,
263 {
264 DeleteExecutor::new(self.db, self.debug)
265 }
266
267 #[must_use]
268 pub(in crate::db) const fn save_executor<E>(&self) -> SaveExecutor<E>
269 where
270 E: EntityKind<Canister = C> + EntityValue,
271 {
272 SaveExecutor::new(self.db, self.debug)
273 }
274
275 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
281 where
282 E: EntityKind<Canister = C> + EntityValue,
283 {
284 let plan = query.plan()?.into_executable();
285
286 let result = match query.mode() {
287 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
288 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
289 };
290
291 result.map_err(QueryError::execute)
292 }
293
294 pub(in crate::db) fn execute_load_query_with<E, T>(
297 &self,
298 query: &Query<E>,
299 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
300 ) -> Result<T, QueryError>
301 where
302 E: EntityKind<Canister = C> + EntityValue,
303 {
304 let plan = query.plan()?.into_executable();
305
306 self.with_metrics(|| op(self.load_executor::<E>(), plan))
307 .map_err(QueryError::execute)
308 }
309
310 pub fn trace_query<E>(&self, query: &Query<E>) -> Result<QueryTracePlan, QueryError>
315 where
316 E: EntityKind<Canister = C>,
317 {
318 let compiled = query.plan()?;
319 let explain = compiled.explain();
320 let plan_hash = compiled.plan_hash_hex();
321
322 let executable = compiled.into_executable();
323 let access_strategy = AccessStrategy::from_plan(executable.access()).debug_summary();
324 let execution_strategy = match query.mode() {
325 QueryMode::Load(_) => Some(trace_execution_strategy(
326 executable
327 .execution_strategy()
328 .map_err(QueryError::execute)?,
329 )),
330 QueryMode::Delete(_) => None,
331 };
332
333 Ok(QueryTracePlan::new(
334 plan_hash,
335 access_strategy,
336 execution_strategy,
337 explain,
338 ))
339 }
340
341 pub(crate) fn explain_load_query_terminal_with<E>(
343 query: &Query<E>,
344 aggregate: AggregateExpr,
345 ) -> Result<ExplainAggregateTerminalPlan, QueryError>
346 where
347 E: EntityKind<Canister = C> + EntityValue,
348 {
349 let compiled = query.plan()?;
351 let query_explain = compiled.explain();
352 let terminal = aggregate.kind();
353
354 let executable = compiled.into_executable();
356 let execution = executable.explain_aggregate_terminal_execution_descriptor(aggregate);
357
358 Ok(ExplainAggregateTerminalPlan::new(
359 query_explain,
360 terminal,
361 execution,
362 ))
363 }
364
365 pub(crate) fn execute_load_query_paged_with_trace<E>(
367 &self,
368 query: &Query<E>,
369 cursor_token: Option<&str>,
370 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
371 where
372 E: EntityKind<Canister = C> + EntityValue,
373 {
374 let plan = query.plan()?.into_executable();
376 match plan.execution_strategy().map_err(QueryError::execute)? {
377 ExecutionStrategy::PrimaryKey => {
378 return Err(QueryError::execute(
379 InternalError::query_executor_invariant(
380 "cursor pagination requires explicit or grouped ordering",
381 ),
382 ));
383 }
384 ExecutionStrategy::Ordered => {}
385 ExecutionStrategy::Grouped => {
386 return Err(QueryError::execute(
387 InternalError::query_executor_invariant(
388 "grouped plans require execute_grouped(...)",
389 ),
390 ));
391 }
392 }
393
394 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
396 let cursor = plan
397 .prepare_cursor(cursor_bytes.as_deref())
398 .map_err(map_executor_plan_error)?;
399
400 let (page, trace) = self
402 .with_metrics(|| {
403 self.load_executor::<E>()
404 .execute_paged_with_cursor_traced(plan, cursor)
405 })
406 .map_err(QueryError::execute)?;
407 let next_cursor = page
408 .next_cursor
409 .map(|token| {
410 let Some(token) = token.as_scalar() else {
411 return Err(QueryError::execute(
412 InternalError::query_executor_invariant(
413 "scalar load pagination emitted grouped continuation token",
414 ),
415 ));
416 };
417
418 token.encode().map_err(|err| {
419 QueryError::execute(InternalError::serialize_internal(format!(
420 "failed to serialize continuation cursor: {err}"
421 )))
422 })
423 })
424 .transpose()?;
425
426 Ok(PagedLoadExecutionWithTrace::new(
427 page.items,
428 next_cursor,
429 trace,
430 ))
431 }
432
433 pub fn execute_grouped<E>(
438 &self,
439 query: &Query<E>,
440 cursor_token: Option<&str>,
441 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
442 where
443 E: EntityKind<Canister = C> + EntityValue,
444 {
445 let plan = query.plan()?.into_executable();
447 if !matches!(
448 plan.execution_strategy().map_err(QueryError::execute)?,
449 ExecutionStrategy::Grouped
450 ) {
451 return Err(QueryError::execute(
452 InternalError::query_executor_invariant(
453 "execute_grouped requires grouped logical plans",
454 ),
455 ));
456 }
457
458 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
460 let cursor = plan
461 .prepare_grouped_cursor(cursor_bytes.as_deref())
462 .map_err(map_executor_plan_error)?;
463
464 let (page, trace) = self
466 .with_metrics(|| {
467 self.load_executor::<E>()
468 .execute_grouped_paged_with_cursor_traced(plan, cursor)
469 })
470 .map_err(QueryError::execute)?;
471 let next_cursor = page
472 .next_cursor
473 .map(|token| {
474 let Some(token) = token.as_grouped() else {
475 return Err(QueryError::execute(
476 InternalError::query_executor_invariant(
477 "grouped pagination emitted scalar continuation token",
478 ),
479 ));
480 };
481
482 token.encode().map_err(|err| {
483 QueryError::execute(InternalError::serialize_internal(format!(
484 "failed to serialize grouped continuation cursor: {err}"
485 )))
486 })
487 })
488 .transpose()?;
489
490 Ok(PagedGroupedExecutionWithTrace::new(
491 page.rows,
492 next_cursor,
493 trace,
494 ))
495 }
496
497 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
503 where
504 E: EntityKind<Canister = C> + EntityValue,
505 {
506 self.execute_save_entity(|save| save.insert(entity))
507 }
508
509 pub fn insert_many_atomic<E>(
515 &self,
516 entities: impl IntoIterator<Item = E>,
517 ) -> Result<WriteBatchResponse<E>, InternalError>
518 where
519 E: EntityKind<Canister = C> + EntityValue,
520 {
521 self.execute_save_batch(|save| save.insert_many_atomic(entities))
522 }
523
524 pub fn insert_many_non_atomic<E>(
528 &self,
529 entities: impl IntoIterator<Item = E>,
530 ) -> Result<WriteBatchResponse<E>, InternalError>
531 where
532 E: EntityKind<Canister = C> + EntityValue,
533 {
534 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
535 }
536
537 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
539 where
540 E: EntityKind<Canister = C> + EntityValue,
541 {
542 self.execute_save_entity(|save| save.replace(entity))
543 }
544
545 pub fn replace_many_atomic<E>(
551 &self,
552 entities: impl IntoIterator<Item = E>,
553 ) -> Result<WriteBatchResponse<E>, InternalError>
554 where
555 E: EntityKind<Canister = C> + EntityValue,
556 {
557 self.execute_save_batch(|save| save.replace_many_atomic(entities))
558 }
559
560 pub fn replace_many_non_atomic<E>(
564 &self,
565 entities: impl IntoIterator<Item = E>,
566 ) -> Result<WriteBatchResponse<E>, InternalError>
567 where
568 E: EntityKind<Canister = C> + EntityValue,
569 {
570 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
571 }
572
573 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
575 where
576 E: EntityKind<Canister = C> + EntityValue,
577 {
578 self.execute_save_entity(|save| save.update(entity))
579 }
580
581 pub fn update_many_atomic<E>(
587 &self,
588 entities: impl IntoIterator<Item = E>,
589 ) -> Result<WriteBatchResponse<E>, InternalError>
590 where
591 E: EntityKind<Canister = C> + EntityValue,
592 {
593 self.execute_save_batch(|save| save.update_many_atomic(entities))
594 }
595
596 pub fn update_many_non_atomic<E>(
600 &self,
601 entities: impl IntoIterator<Item = E>,
602 ) -> Result<WriteBatchResponse<E>, InternalError>
603 where
604 E: EntityKind<Canister = C> + EntityValue,
605 {
606 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
607 }
608
609 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
611 where
612 E: EntityKind<Canister = C> + EntityValue,
613 {
614 self.execute_save_view::<E>(|save| save.insert_view(view))
615 }
616
617 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
619 where
620 E: EntityKind<Canister = C> + EntityValue,
621 {
622 self.execute_save_view::<E>(|save| save.replace_view(view))
623 }
624
625 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
627 where
628 E: EntityKind<Canister = C> + EntityValue,
629 {
630 self.execute_save_view::<E>(|save| save.update_view(view))
631 }
632
633 #[cfg(test)]
635 #[doc(hidden)]
636 pub fn clear_stores_for_tests(&self) {
637 self.db.with_store_registry(|reg| {
638 for (_, store) in reg.iter() {
641 store.with_data_mut(DataStore::clear);
642 store.with_index_mut(IndexStore::clear);
643 }
644 });
645 }
646}
647
648const fn trace_execution_strategy(strategy: ExecutionStrategy) -> TraceExecutionStrategy {
649 match strategy {
650 ExecutionStrategy::PrimaryKey => TraceExecutionStrategy::PrimaryKey,
651 ExecutionStrategy::Ordered => TraceExecutionStrategy::Ordered,
652 ExecutionStrategy::Grouped => TraceExecutionStrategy::Grouped,
653 }
654}
655
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::cursor::CursorPlanError;

    /// Assert that `err` is `QueryError::Plan(PlanError::Cursor(..))` and that
    /// the wrapped cursor error satisfies `predicate`.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(matches!(
            err,
            QueryError::Plan(plan_err)
                if matches!(
                    plan_err.as_ref(),
                    PlanError::Cursor(inner) if predicate(inner.as_ref())
                )
        ));
    }

    /// Check that a cursor error maps to the same `QueryError` shape whether it
    /// goes through the executor path (`map_executor_plan_error`) or the
    /// planning path (`PlanError::from`).
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                        expected: 2,
                        found: 1
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorWindowMismatch {
                        expected_offset: 8,
                        actual_offset: 3
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::invalid_continuation_cursor(
                    crate::db::codec::cursor::CursorDecodeError::OddLength,
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::InvalidContinuationCursor {
                        reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                    "id",
                    "ulid",
                    Some(crate::value::Value::Text("not-a-ulid".to_string())),
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                        field,
                        expected,
                        value: Some(crate::value::Value::Text(value))
                    } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                )
            },
        );
    }

    /// Table-driven sweep over every cursor variant exercised above.
    ///
    /// This previously duplicated the boundary-arity test verbatim, so it added
    /// no coverage. Now it runs each (constructor, predicate) pair through the
    /// parity check in one place, so a new variant only needs one extra row.
    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        type Build = fn() -> CursorPlanError;
        type Check = fn(&CursorPlanError) -> bool;

        let cases: [(Build, Check); 4] = [
            (
                || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
                |inner: &CursorPlanError| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                            expected: 2,
                            found: 1
                        }
                    )
                },
            ),
            (
                || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
                |inner: &CursorPlanError| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorWindowMismatch {
                            expected_offset: 8,
                            actual_offset: 3
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::invalid_continuation_cursor(
                        crate::db::codec::cursor::CursorDecodeError::OddLength,
                    )
                },
                |inner: &CursorPlanError| {
                    matches!(
                        inner,
                        CursorPlanError::InvalidContinuationCursor {
                            reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                        "id",
                        "ulid",
                        Some(crate::value::Value::Text("not-a-ulid".to_string())),
                    )
                },
                |inner: &CursorPlanError| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                            field,
                            expected,
                            value: Some(crate::value::Value::Text(value))
                        } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                    )
                },
            ),
        ];

        for (build, check) in cases {
            assert_cursor_mapping_parity(build, check);
        }
    }
}