1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, EntitySchemaDescription, FluentDeleteQuery, FluentLoadQuery,
11 MissingRowPolicy, PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError,
12 Query, QueryError, QueryTracePlan, StorageReport, StoreRegistry, TraceExecutionStrategy,
13 WriteBatchResponse,
14 access::AccessStrategy,
15 commit::EntityRuntimeHooks,
16 cursor::decode_optional_cursor_token,
17 executor::{
18 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
19 SaveExecutor,
20 },
21 query::{
22 builder::aggregate::AggregateExpr, explain::ExplainAggregateTerminalPlan,
23 plan::QueryMode,
24 },
25 schema::{describe_entity_model, show_indexes_for_model},
26 },
27 error::InternalError,
28 metrics::sink::{MetricsSink, with_metrics_sink},
29 traits::{CanisterKind, EntityKind, EntityValue},
30 value::Value,
31};
32use std::thread::LocalKey;
33
34fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
36 match err {
37 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
38 }
39}
40
41fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
44 decode_optional_cursor_token(cursor_token).map_err(|err| QueryError::from(PlanError::from(err)))
45}
46
47pub struct DbSession<C: CanisterKind> {
54 db: Db<C>,
55 debug: bool,
56 metrics: Option<&'static dyn MetricsSink>,
57}
58
59impl<C: CanisterKind> DbSession<C> {
60 #[must_use]
62 pub(crate) const fn new(db: Db<C>) -> Self {
63 Self {
64 db,
65 debug: false,
66 metrics: None,
67 }
68 }
69
70 #[must_use]
72 pub const fn new_with_hooks(
73 store: &'static LocalKey<StoreRegistry>,
74 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
75 ) -> Self {
76 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
77 }
78
79 #[must_use]
81 pub const fn debug(mut self) -> Self {
82 self.debug = true;
83 self
84 }
85
86 #[must_use]
88 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
89 self.metrics = Some(sink);
90 self
91 }
92
93 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
94 if let Some(sink) = self.metrics {
95 with_metrics_sink(sink, f)
96 } else {
97 f()
98 }
99 }
100
101 fn execute_save_with<E, T, R>(
103 &self,
104 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
105 map: impl FnOnce(T) -> R,
106 ) -> Result<R, InternalError>
107 where
108 E: EntityKind<Canister = C> + EntityValue,
109 {
110 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
111
112 Ok(map(value))
113 }
114
115 fn execute_save_entity<E>(
117 &self,
118 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
119 ) -> Result<E, InternalError>
120 where
121 E: EntityKind<Canister = C> + EntityValue,
122 {
123 self.execute_save_with(op, std::convert::identity)
124 }
125
126 fn execute_save_batch<E>(
127 &self,
128 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
129 ) -> Result<WriteBatchResponse<E>, InternalError>
130 where
131 E: EntityKind<Canister = C> + EntityValue,
132 {
133 self.execute_save_with(op, WriteBatchResponse::new)
134 }
135
136 fn execute_save_view<E>(
137 &self,
138 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
139 ) -> Result<E::ViewType, InternalError>
140 where
141 E: EntityKind<Canister = C> + EntityValue,
142 {
143 self.execute_save_with(op, std::convert::identity)
144 }
145
146 #[must_use]
152 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
153 where
154 E: EntityKind<Canister = C>,
155 {
156 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
157 }
158
159 #[must_use]
161 pub const fn load_with_consistency<E>(
162 &self,
163 consistency: MissingRowPolicy,
164 ) -> FluentLoadQuery<'_, E>
165 where
166 E: EntityKind<Canister = C>,
167 {
168 FluentLoadQuery::new(self, Query::new(consistency))
169 }
170
171 #[must_use]
173 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
174 where
175 E: EntityKind<Canister = C>,
176 {
177 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
178 }
179
180 #[must_use]
182 pub fn delete_with_consistency<E>(
183 &self,
184 consistency: MissingRowPolicy,
185 ) -> FluentDeleteQuery<'_, E>
186 where
187 E: EntityKind<Canister = C>,
188 {
189 FluentDeleteQuery::new(self, Query::new(consistency).delete())
190 }
191
192 #[must_use]
196 pub const fn select_one(&self) -> Value {
197 Value::Int(1)
198 }
199
200 #[must_use]
207 pub fn show_indexes<E>(&self) -> Vec<String>
208 where
209 E: EntityKind<Canister = C>,
210 {
211 show_indexes_for_model(E::MODEL)
212 }
213
214 #[must_use]
219 pub fn describe_entity<E>(&self) -> EntitySchemaDescription
220 where
221 E: EntityKind<Canister = C>,
222 {
223 describe_entity_model(E::MODEL)
224 }
225
226 pub fn storage_report(
228 &self,
229 name_to_path: &[(&'static str, &'static str)],
230 ) -> Result<StorageReport, InternalError> {
231 self.db.storage_report(name_to_path)
232 }
233
234 #[must_use]
239 pub(in crate::db) const fn load_executor<E>(&self) -> LoadExecutor<E>
240 where
241 E: EntityKind<Canister = C> + EntityValue,
242 {
243 LoadExecutor::new(self.db, self.debug)
244 }
245
246 #[must_use]
247 pub(in crate::db) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
248 where
249 E: EntityKind<Canister = C> + EntityValue,
250 {
251 DeleteExecutor::new(self.db, self.debug)
252 }
253
254 #[must_use]
255 pub(in crate::db) const fn save_executor<E>(&self) -> SaveExecutor<E>
256 where
257 E: EntityKind<Canister = C> + EntityValue,
258 {
259 SaveExecutor::new(self.db, self.debug)
260 }
261
262 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
268 where
269 E: EntityKind<Canister = C> + EntityValue,
270 {
271 let plan = query.plan()?.into_executable();
272
273 let result = match query.mode() {
274 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
275 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
276 };
277
278 result.map_err(QueryError::execute)
279 }
280
281 pub(in crate::db) fn execute_load_query_with<E, T>(
284 &self,
285 query: &Query<E>,
286 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
287 ) -> Result<T, QueryError>
288 where
289 E: EntityKind<Canister = C> + EntityValue,
290 {
291 let plan = query.plan()?.into_executable();
292
293 self.with_metrics(|| op(self.load_executor::<E>(), plan))
294 .map_err(QueryError::execute)
295 }
296
297 pub fn trace_query<E>(&self, query: &Query<E>) -> Result<QueryTracePlan, QueryError>
302 where
303 E: EntityKind<Canister = C>,
304 {
305 let compiled = query.plan()?;
306 let explain = compiled.explain();
307 let plan_hash = compiled.plan_hash_hex();
308
309 let executable = compiled.into_executable();
310 let access_strategy = AccessStrategy::from_plan(executable.access()).debug_summary();
311 let execution_strategy = match query.mode() {
312 QueryMode::Load(_) => Some(trace_execution_strategy(
313 executable
314 .execution_strategy()
315 .map_err(QueryError::execute)?,
316 )),
317 QueryMode::Delete(_) => None,
318 };
319
320 Ok(QueryTracePlan::new(
321 plan_hash,
322 access_strategy,
323 execution_strategy,
324 explain,
325 ))
326 }
327
328 pub(crate) fn explain_load_query_terminal_with<E>(
330 query: &Query<E>,
331 aggregate: AggregateExpr,
332 ) -> Result<ExplainAggregateTerminalPlan, QueryError>
333 where
334 E: EntityKind<Canister = C> + EntityValue,
335 {
336 let compiled = query.plan()?;
338 let query_explain = compiled.explain();
339 let terminal = aggregate.kind();
340
341 let executable = compiled.into_executable();
343 let execution = executable.explain_aggregate_terminal_execution_descriptor(aggregate);
344
345 Ok(ExplainAggregateTerminalPlan::new(
346 query_explain,
347 terminal,
348 execution,
349 ))
350 }
351
352 pub(crate) fn execute_load_query_paged_with_trace<E>(
354 &self,
355 query: &Query<E>,
356 cursor_token: Option<&str>,
357 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
358 where
359 E: EntityKind<Canister = C> + EntityValue,
360 {
361 let plan = query.plan()?.into_executable();
363 match plan.execution_strategy().map_err(QueryError::execute)? {
364 ExecutionStrategy::PrimaryKey => {
365 return Err(QueryError::execute(
366 InternalError::query_executor_invariant(
367 "cursor pagination requires explicit or grouped ordering",
368 ),
369 ));
370 }
371 ExecutionStrategy::Ordered => {}
372 ExecutionStrategy::Grouped => {
373 return Err(QueryError::execute(
374 InternalError::query_executor_invariant(
375 "grouped plans require execute_grouped(...)",
376 ),
377 ));
378 }
379 }
380
381 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
383 let cursor = plan
384 .prepare_cursor(cursor_bytes.as_deref())
385 .map_err(map_executor_plan_error)?;
386
387 let (page, trace) = self
389 .with_metrics(|| {
390 self.load_executor::<E>()
391 .execute_paged_with_cursor_traced(plan, cursor)
392 })
393 .map_err(QueryError::execute)?;
394 let next_cursor = page
395 .next_cursor
396 .map(|token| {
397 let Some(token) = token.as_scalar() else {
398 return Err(QueryError::execute(
399 InternalError::query_executor_invariant(
400 "scalar load pagination emitted grouped continuation token",
401 ),
402 ));
403 };
404
405 token.encode().map_err(|err| {
406 QueryError::execute(InternalError::serialize_internal(format!(
407 "failed to serialize continuation cursor: {err}"
408 )))
409 })
410 })
411 .transpose()?;
412
413 Ok(PagedLoadExecutionWithTrace::new(
414 page.items,
415 next_cursor,
416 trace,
417 ))
418 }
419
420 pub fn execute_grouped<E>(
425 &self,
426 query: &Query<E>,
427 cursor_token: Option<&str>,
428 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
429 where
430 E: EntityKind<Canister = C> + EntityValue,
431 {
432 let plan = query.plan()?.into_executable();
434 if !matches!(
435 plan.execution_strategy().map_err(QueryError::execute)?,
436 ExecutionStrategy::Grouped
437 ) {
438 return Err(QueryError::execute(
439 InternalError::query_executor_invariant(
440 "execute_grouped requires grouped logical plans",
441 ),
442 ));
443 }
444
445 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
447 let cursor = plan
448 .prepare_grouped_cursor(cursor_bytes.as_deref())
449 .map_err(map_executor_plan_error)?;
450
451 let (page, trace) = self
453 .with_metrics(|| {
454 self.load_executor::<E>()
455 .execute_grouped_paged_with_cursor_traced(plan, cursor)
456 })
457 .map_err(QueryError::execute)?;
458 let next_cursor = page
459 .next_cursor
460 .map(|token| {
461 let Some(token) = token.as_grouped() else {
462 return Err(QueryError::execute(
463 InternalError::query_executor_invariant(
464 "grouped pagination emitted scalar continuation token",
465 ),
466 ));
467 };
468
469 token.encode().map_err(|err| {
470 QueryError::execute(InternalError::serialize_internal(format!(
471 "failed to serialize grouped continuation cursor: {err}"
472 )))
473 })
474 })
475 .transpose()?;
476
477 Ok(PagedGroupedExecutionWithTrace::new(
478 page.rows,
479 next_cursor,
480 trace,
481 ))
482 }
483
484 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
490 where
491 E: EntityKind<Canister = C> + EntityValue,
492 {
493 self.execute_save_entity(|save| save.insert(entity))
494 }
495
496 pub fn insert_many_atomic<E>(
502 &self,
503 entities: impl IntoIterator<Item = E>,
504 ) -> Result<WriteBatchResponse<E>, InternalError>
505 where
506 E: EntityKind<Canister = C> + EntityValue,
507 {
508 self.execute_save_batch(|save| save.insert_many_atomic(entities))
509 }
510
511 pub fn insert_many_non_atomic<E>(
515 &self,
516 entities: impl IntoIterator<Item = E>,
517 ) -> Result<WriteBatchResponse<E>, InternalError>
518 where
519 E: EntityKind<Canister = C> + EntityValue,
520 {
521 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
522 }
523
524 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
526 where
527 E: EntityKind<Canister = C> + EntityValue,
528 {
529 self.execute_save_entity(|save| save.replace(entity))
530 }
531
532 pub fn replace_many_atomic<E>(
538 &self,
539 entities: impl IntoIterator<Item = E>,
540 ) -> Result<WriteBatchResponse<E>, InternalError>
541 where
542 E: EntityKind<Canister = C> + EntityValue,
543 {
544 self.execute_save_batch(|save| save.replace_many_atomic(entities))
545 }
546
547 pub fn replace_many_non_atomic<E>(
551 &self,
552 entities: impl IntoIterator<Item = E>,
553 ) -> Result<WriteBatchResponse<E>, InternalError>
554 where
555 E: EntityKind<Canister = C> + EntityValue,
556 {
557 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
558 }
559
560 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
562 where
563 E: EntityKind<Canister = C> + EntityValue,
564 {
565 self.execute_save_entity(|save| save.update(entity))
566 }
567
568 pub fn update_many_atomic<E>(
574 &self,
575 entities: impl IntoIterator<Item = E>,
576 ) -> Result<WriteBatchResponse<E>, InternalError>
577 where
578 E: EntityKind<Canister = C> + EntityValue,
579 {
580 self.execute_save_batch(|save| save.update_many_atomic(entities))
581 }
582
583 pub fn update_many_non_atomic<E>(
587 &self,
588 entities: impl IntoIterator<Item = E>,
589 ) -> Result<WriteBatchResponse<E>, InternalError>
590 where
591 E: EntityKind<Canister = C> + EntityValue,
592 {
593 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
594 }
595
596 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
598 where
599 E: EntityKind<Canister = C> + EntityValue,
600 {
601 self.execute_save_view::<E>(|save| save.insert_view(view))
602 }
603
604 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
606 where
607 E: EntityKind<Canister = C> + EntityValue,
608 {
609 self.execute_save_view::<E>(|save| save.replace_view(view))
610 }
611
612 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
614 where
615 E: EntityKind<Canister = C> + EntityValue,
616 {
617 self.execute_save_view::<E>(|save| save.update_view(view))
618 }
619
620 #[cfg(test)]
622 #[doc(hidden)]
623 pub fn clear_stores_for_tests(&self) {
624 self.db.with_store_registry(|reg| {
625 for (_, store) in reg.iter() {
628 store.with_data_mut(DataStore::clear);
629 store.with_index_mut(IndexStore::clear);
630 }
631 });
632 }
633}
634
635const fn trace_execution_strategy(strategy: ExecutionStrategy) -> TraceExecutionStrategy {
636 match strategy {
637 ExecutionStrategy::PrimaryKey => TraceExecutionStrategy::PrimaryKey,
638 ExecutionStrategy::Ordered => TraceExecutionStrategy::Ordered,
639 ExecutionStrategy::Grouped => TraceExecutionStrategy::Grouped,
640 }
641}
642
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::cursor::CursorPlanError;

    /// Assert that `err` is `QueryError::Plan(PlanError::Cursor(inner))` and
    /// that `predicate(inner)` holds.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(
            matches!(
                err,
                QueryError::Plan(plan_err)
                    if matches!(
                        plan_err.as_ref(),
                        PlanError::Cursor(inner) if predicate(inner.as_ref())
                    )
            ),
            "expected QueryError::Plan(PlanError::Cursor(..)) matching the cursor-variant predicate",
        );
    }

    /// Check that a cursor error built by `build` maps to the same
    /// `QueryError` shape through both paths.
    ///
    /// The two paths are the executor path (`map_executor_plan_error`) and the
    /// direct planner path (`QueryError::from(PlanError::from(..))`).
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                        expected: 2,
                        found: 1
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorWindowMismatch {
                        expected_offset: 8,
                        actual_offset: 3
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::invalid_continuation_cursor(
                    crate::db::codec::cursor::CursorDecodeError::OddLength,
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::InvalidContinuationCursor {
                        reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                    "id",
                    "ulid",
                    Some(crate::value::Value::Text("not-a-ulid".to_string())),
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                        field,
                        expected,
                        value: Some(crate::value::Value::Text(value))
                    } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        // Previously this test was a verbatim copy of the boundary-arity case.
        // It now runs every cursor-plan variant exercised above through the
        // shared parity helper. Each row pairs a constructor with a predicate
        // that recognises exactly that variant and its payload.
        type Build = fn() -> CursorPlanError;
        type Check = fn(&CursorPlanError) -> bool;

        let cases: [(Build, Check); 4] = [
            (
                || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                            expected: 2,
                            found: 1
                        }
                    )
                },
            ),
            (
                || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorWindowMismatch {
                            expected_offset: 8,
                            actual_offset: 3
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::invalid_continuation_cursor(
                        crate::db::codec::cursor::CursorDecodeError::OddLength,
                    )
                },
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::InvalidContinuationCursor {
                            reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                        "id",
                        "ulid",
                        Some(crate::value::Value::Text("not-a-ulid".to_string())),
                    )
                },
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                            field,
                            expected,
                            value: Some(crate::value::Value::Text(value))
                        } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                    )
                },
            ),
        ];

        for (build, check) in cases {
            assert_cursor_mapping_parity(build, check);
        }
    }
}
769}