1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, EntitySchemaDescription, FluentDeleteQuery, FluentLoadQuery,
11 MissingRowPolicy, PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError,
12 Query, QueryError, QueryTracePlan, StoreRegistry, TraceExecutionStrategy,
13 WriteBatchResponse,
14 access::AccessStrategy,
15 commit::EntityRuntimeHooks,
16 cursor::decode_optional_cursor_token,
17 describe::describe_entity_model,
18 executor::{
19 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
20 SaveExecutor,
21 },
22 query::{
23 builder::aggregate::AggregateExpr, explain::ExplainAggregateTerminalPlan,
24 plan::QueryMode,
25 },
26 },
27 error::InternalError,
28 obs::sink::{MetricsSink, with_metrics_sink},
29 traits::{CanisterKind, EntityKind, EntityValue},
30 value::Value,
31};
32use std::thread::LocalKey;
33
34fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
36 match err {
37 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
38 }
39}
40
41fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
44 decode_optional_cursor_token(cursor_token).map_err(|err| QueryError::from(PlanError::from(err)))
45}
46
47pub struct DbSession<C: CanisterKind> {
54 db: Db<C>,
55 debug: bool,
56 metrics: Option<&'static dyn MetricsSink>,
57}
58
59impl<C: CanisterKind> DbSession<C> {
60 #[must_use]
62 pub(crate) const fn new(db: Db<C>) -> Self {
63 Self {
64 db,
65 debug: false,
66 metrics: None,
67 }
68 }
69
70 #[must_use]
72 pub const fn new_with_hooks(
73 store: &'static LocalKey<StoreRegistry>,
74 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
75 ) -> Self {
76 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
77 }
78
79 #[must_use]
81 pub const fn debug(mut self) -> Self {
82 self.debug = true;
83 self
84 }
85
86 #[must_use]
88 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
89 self.metrics = Some(sink);
90 self
91 }
92
93 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
94 if let Some(sink) = self.metrics {
95 with_metrics_sink(sink, f)
96 } else {
97 f()
98 }
99 }
100
101 fn execute_save_with<E, T, R>(
103 &self,
104 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
105 map: impl FnOnce(T) -> R,
106 ) -> Result<R, InternalError>
107 where
108 E: EntityKind<Canister = C> + EntityValue,
109 {
110 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
111
112 Ok(map(value))
113 }
114
115 fn execute_save_entity<E>(
117 &self,
118 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
119 ) -> Result<E, InternalError>
120 where
121 E: EntityKind<Canister = C> + EntityValue,
122 {
123 self.execute_save_with(op, std::convert::identity)
124 }
125
126 fn execute_save_batch<E>(
127 &self,
128 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
129 ) -> Result<WriteBatchResponse<E>, InternalError>
130 where
131 E: EntityKind<Canister = C> + EntityValue,
132 {
133 self.execute_save_with(op, WriteBatchResponse::new)
134 }
135
136 fn execute_save_view<E>(
137 &self,
138 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
139 ) -> Result<E::ViewType, InternalError>
140 where
141 E: EntityKind<Canister = C> + EntityValue,
142 {
143 self.execute_save_with(op, std::convert::identity)
144 }
145
146 #[must_use]
152 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
153 where
154 E: EntityKind<Canister = C>,
155 {
156 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
157 }
158
159 #[must_use]
161 pub const fn load_with_consistency<E>(
162 &self,
163 consistency: MissingRowPolicy,
164 ) -> FluentLoadQuery<'_, E>
165 where
166 E: EntityKind<Canister = C>,
167 {
168 FluentLoadQuery::new(self, Query::new(consistency))
169 }
170
171 #[must_use]
173 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
174 where
175 E: EntityKind<Canister = C>,
176 {
177 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
178 }
179
180 #[must_use]
182 pub fn delete_with_consistency<E>(
183 &self,
184 consistency: MissingRowPolicy,
185 ) -> FluentDeleteQuery<'_, E>
186 where
187 E: EntityKind<Canister = C>,
188 {
189 FluentDeleteQuery::new(self, Query::new(consistency).delete())
190 }
191
192 #[must_use]
196 pub const fn select_one(&self) -> Value {
197 Value::Int(1)
198 }
199
200 #[must_use]
207 pub fn show_indexes<E>(&self) -> Vec<String>
208 where
209 E: EntityKind<Canister = C>,
210 {
211 let mut indexes = Vec::with_capacity(E::MODEL.indexes.len().saturating_add(1));
212 indexes.push(format!("PRIMARY KEY ({})", E::MODEL.primary_key.name));
213
214 for index in E::MODEL.indexes {
215 let kind = if index.is_unique() {
216 "UNIQUE INDEX"
217 } else {
218 "INDEX"
219 };
220 let fields = index.fields().join(", ");
221 indexes.push(format!("{kind} {} ({fields})", index.name()));
222 }
223
224 indexes
225 }
226
227 #[must_use]
232 pub fn describe_entity<E>(&self) -> EntitySchemaDescription
233 where
234 E: EntityKind<Canister = C>,
235 {
236 describe_entity_model(E::MODEL)
237 }
238
239 #[must_use]
244 pub(in crate::db) const fn load_executor<E>(&self) -> LoadExecutor<E>
245 where
246 E: EntityKind<Canister = C> + EntityValue,
247 {
248 LoadExecutor::new(self.db, self.debug)
249 }
250
251 #[must_use]
252 pub(in crate::db) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
253 where
254 E: EntityKind<Canister = C> + EntityValue,
255 {
256 DeleteExecutor::new(self.db, self.debug)
257 }
258
259 #[must_use]
260 pub(in crate::db) const fn save_executor<E>(&self) -> SaveExecutor<E>
261 where
262 E: EntityKind<Canister = C> + EntityValue,
263 {
264 SaveExecutor::new(self.db, self.debug)
265 }
266
267 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
273 where
274 E: EntityKind<Canister = C> + EntityValue,
275 {
276 let plan = query.plan()?.into_executable();
277
278 let result = match query.mode() {
279 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
280 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
281 };
282
283 result.map_err(QueryError::execute)
284 }
285
286 pub(in crate::db) fn execute_load_query_with<E, T>(
289 &self,
290 query: &Query<E>,
291 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
292 ) -> Result<T, QueryError>
293 where
294 E: EntityKind<Canister = C> + EntityValue,
295 {
296 let plan = query.plan()?.into_executable();
297
298 self.with_metrics(|| op(self.load_executor::<E>(), plan))
299 .map_err(QueryError::execute)
300 }
301
302 pub fn trace_query<E>(&self, query: &Query<E>) -> Result<QueryTracePlan, QueryError>
307 where
308 E: EntityKind<Canister = C>,
309 {
310 let compiled = query.plan()?;
311 let explain = compiled.explain();
312 let plan_hash = explain.fingerprint().to_string();
313
314 let executable = compiled.into_executable();
315 let access_strategy = AccessStrategy::from_plan(executable.access()).debug_summary();
316 let execution_strategy = match query.mode() {
317 QueryMode::Load(_) => Some(trace_execution_strategy(
318 executable
319 .execution_strategy()
320 .map_err(QueryError::execute)?,
321 )),
322 QueryMode::Delete(_) => None,
323 };
324
325 Ok(QueryTracePlan::new(
326 plan_hash,
327 access_strategy,
328 execution_strategy,
329 explain,
330 ))
331 }
332
333 pub(crate) fn explain_load_query_terminal_with<E>(
335 query: &Query<E>,
336 aggregate: AggregateExpr,
337 ) -> Result<ExplainAggregateTerminalPlan, QueryError>
338 where
339 E: EntityKind<Canister = C> + EntityValue,
340 {
341 let compiled = query.plan()?;
343 let query_explain = compiled.explain();
344 let terminal = aggregate.kind();
345
346 let executable = compiled.into_executable();
348 let execution = executable.explain_aggregate_terminal_execution_descriptor(aggregate);
349
350 Ok(ExplainAggregateTerminalPlan::new(
351 query_explain,
352 terminal,
353 execution,
354 ))
355 }
356
357 pub(crate) fn execute_load_query_paged_with_trace<E>(
359 &self,
360 query: &Query<E>,
361 cursor_token: Option<&str>,
362 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
363 where
364 E: EntityKind<Canister = C> + EntityValue,
365 {
366 let plan = query.plan()?.into_executable();
368 match plan.execution_strategy().map_err(QueryError::execute)? {
369 ExecutionStrategy::PrimaryKey => {
370 return Err(QueryError::execute(crate::db::error::executor_invariant(
371 "cursor pagination requires explicit or grouped ordering",
372 )));
373 }
374 ExecutionStrategy::Ordered => {}
375 ExecutionStrategy::Grouped => {
376 return Err(QueryError::execute(crate::db::error::executor_invariant(
377 "grouped plans require execute_grouped(...)",
378 )));
379 }
380 }
381
382 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
384 let cursor = plan
385 .prepare_cursor(cursor_bytes.as_deref())
386 .map_err(map_executor_plan_error)?;
387
388 let (page, trace) = self
390 .with_metrics(|| {
391 self.load_executor::<E>()
392 .execute_paged_with_cursor_traced(plan, cursor)
393 })
394 .map_err(QueryError::execute)?;
395 let next_cursor = page
396 .next_cursor
397 .map(|token| {
398 let Some(token) = token.as_scalar() else {
399 return Err(QueryError::execute(crate::db::error::executor_invariant(
400 "scalar load pagination emitted grouped continuation token",
401 )));
402 };
403
404 token.encode().map_err(|err| {
405 QueryError::execute(InternalError::serialize_internal(format!(
406 "failed to serialize continuation cursor: {err}"
407 )))
408 })
409 })
410 .transpose()?;
411
412 Ok(PagedLoadExecutionWithTrace::new(
413 page.items,
414 next_cursor,
415 trace,
416 ))
417 }
418
419 pub fn execute_grouped<E>(
424 &self,
425 query: &Query<E>,
426 cursor_token: Option<&str>,
427 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
428 where
429 E: EntityKind<Canister = C> + EntityValue,
430 {
431 let plan = query.plan()?.into_executable();
433 if !matches!(
434 plan.execution_strategy().map_err(QueryError::execute)?,
435 ExecutionStrategy::Grouped
436 ) {
437 return Err(QueryError::execute(crate::db::error::executor_invariant(
438 "execute_grouped requires grouped logical plans",
439 )));
440 }
441
442 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
444 let cursor = plan
445 .prepare_grouped_cursor(cursor_bytes.as_deref())
446 .map_err(map_executor_plan_error)?;
447
448 let (page, trace) = self
450 .with_metrics(|| {
451 self.load_executor::<E>()
452 .execute_grouped_paged_with_cursor_traced(plan, cursor)
453 })
454 .map_err(QueryError::execute)?;
455 let next_cursor = page
456 .next_cursor
457 .map(|token| {
458 let Some(token) = token.as_grouped() else {
459 return Err(QueryError::execute(crate::db::error::executor_invariant(
460 "grouped pagination emitted scalar continuation token",
461 )));
462 };
463
464 token.encode().map_err(|err| {
465 QueryError::execute(InternalError::serialize_internal(format!(
466 "failed to serialize grouped continuation cursor: {err}"
467 )))
468 })
469 })
470 .transpose()?;
471
472 Ok(PagedGroupedExecutionWithTrace::new(
473 page.rows,
474 next_cursor,
475 trace,
476 ))
477 }
478
479 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
485 where
486 E: EntityKind<Canister = C> + EntityValue,
487 {
488 self.execute_save_entity(|save| save.insert(entity))
489 }
490
491 pub fn insert_many_atomic<E>(
497 &self,
498 entities: impl IntoIterator<Item = E>,
499 ) -> Result<WriteBatchResponse<E>, InternalError>
500 where
501 E: EntityKind<Canister = C> + EntityValue,
502 {
503 self.execute_save_batch(|save| save.insert_many_atomic(entities))
504 }
505
506 pub fn insert_many_non_atomic<E>(
510 &self,
511 entities: impl IntoIterator<Item = E>,
512 ) -> Result<WriteBatchResponse<E>, InternalError>
513 where
514 E: EntityKind<Canister = C> + EntityValue,
515 {
516 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
517 }
518
519 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
521 where
522 E: EntityKind<Canister = C> + EntityValue,
523 {
524 self.execute_save_entity(|save| save.replace(entity))
525 }
526
527 pub fn replace_many_atomic<E>(
533 &self,
534 entities: impl IntoIterator<Item = E>,
535 ) -> Result<WriteBatchResponse<E>, InternalError>
536 where
537 E: EntityKind<Canister = C> + EntityValue,
538 {
539 self.execute_save_batch(|save| save.replace_many_atomic(entities))
540 }
541
542 pub fn replace_many_non_atomic<E>(
546 &self,
547 entities: impl IntoIterator<Item = E>,
548 ) -> Result<WriteBatchResponse<E>, InternalError>
549 where
550 E: EntityKind<Canister = C> + EntityValue,
551 {
552 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
553 }
554
555 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
557 where
558 E: EntityKind<Canister = C> + EntityValue,
559 {
560 self.execute_save_entity(|save| save.update(entity))
561 }
562
563 pub fn update_many_atomic<E>(
569 &self,
570 entities: impl IntoIterator<Item = E>,
571 ) -> Result<WriteBatchResponse<E>, InternalError>
572 where
573 E: EntityKind<Canister = C> + EntityValue,
574 {
575 self.execute_save_batch(|save| save.update_many_atomic(entities))
576 }
577
578 pub fn update_many_non_atomic<E>(
582 &self,
583 entities: impl IntoIterator<Item = E>,
584 ) -> Result<WriteBatchResponse<E>, InternalError>
585 where
586 E: EntityKind<Canister = C> + EntityValue,
587 {
588 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
589 }
590
591 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
593 where
594 E: EntityKind<Canister = C> + EntityValue,
595 {
596 self.execute_save_view::<E>(|save| save.insert_view(view))
597 }
598
599 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
601 where
602 E: EntityKind<Canister = C> + EntityValue,
603 {
604 self.execute_save_view::<E>(|save| save.replace_view(view))
605 }
606
607 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
609 where
610 E: EntityKind<Canister = C> + EntityValue,
611 {
612 self.execute_save_view::<E>(|save| save.update_view(view))
613 }
614
615 #[cfg(test)]
617 #[doc(hidden)]
618 pub fn clear_stores_for_tests(&self) {
619 self.db.with_store_registry(|reg| {
620 for (_, store) in reg.iter() {
623 store.with_data_mut(DataStore::clear);
624 store.with_index_mut(IndexStore::clear);
625 }
626 });
627 }
628}
629
630const fn trace_execution_strategy(strategy: ExecutionStrategy) -> TraceExecutionStrategy {
631 match strategy {
632 ExecutionStrategy::PrimaryKey => TraceExecutionStrategy::PrimaryKey,
633 ExecutionStrategy::Ordered => TraceExecutionStrategy::Ordered,
634 ExecutionStrategy::Grouped => TraceExecutionStrategy::Grouped,
635 }
636}
637
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::cursor::CursorPlanError;

    // Builds a fresh cursor planning error for one parity case.
    type CursorErrorBuilder = fn() -> CursorPlanError;

    // Returns `true` when a mapped error still carries the expected variant
    // and payload.
    type CursorErrorPredicate = fn(&CursorPlanError) -> bool;

    // Every cursor error shape with a dedicated parity test below. The matrix
    // test walks this whole table, so each variant is also checked there.
    const CURSOR_PARITY_CASES: [(CursorErrorBuilder, CursorErrorPredicate); 4] = [
        (boundary_arity_error, is_boundary_arity_error),
        (window_mismatch_error, is_window_mismatch_error),
        (odd_length_decode_error, is_odd_length_decode_error),
        (
            primary_key_type_mismatch_error,
            is_primary_key_type_mismatch_error,
        ),
    ];

    // Assert that `err` is `QueryError::Plan(PlanError::Cursor(inner))` and
    // that `inner` satisfies `predicate`.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(
            matches!(
                err,
                QueryError::Plan(plan_err)
                    if matches!(
                        plan_err.as_ref(),
                        PlanError::Cursor(inner) if predicate(inner.as_ref())
                    )
            ),
            "expected QueryError::Plan(PlanError::Cursor(..)) matching the cursor predicate"
        );
    }

    // Map the same cursor error through both the executor path
    // (`map_executor_plan_error`) and the direct planner path, and require
    // both results to satisfy `predicate`.
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    fn boundary_arity_error() -> CursorPlanError {
        CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1)
    }

    fn is_boundary_arity_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                expected: 2,
                found: 1
            }
        )
    }

    fn window_mismatch_error() -> CursorPlanError {
        CursorPlanError::continuation_cursor_window_mismatch(8, 3)
    }

    fn is_window_mismatch_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorWindowMismatch {
                expected_offset: 8,
                actual_offset: 3
            }
        )
    }

    fn odd_length_decode_error() -> CursorPlanError {
        CursorPlanError::invalid_continuation_cursor(
            crate::db::codec::cursor::CursorDecodeError::OddLength,
        )
    }

    fn is_odd_length_decode_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::InvalidContinuationCursor {
                reason: crate::db::codec::cursor::CursorDecodeError::OddLength
            }
        )
    }

    fn primary_key_type_mismatch_error() -> CursorPlanError {
        CursorPlanError::continuation_cursor_primary_key_type_mismatch(
            "id",
            "ulid",
            Some(crate::value::Value::Text("not-a-ulid".to_string())),
        )
    }

    fn is_primary_key_type_mismatch_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                field,
                expected,
                value: Some(crate::value::Value::Text(value))
            } if field == "id" && expected == "ulid" && value == "not-a-ulid"
        )
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(boundary_arity_error, is_boundary_arity_error);
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(window_mismatch_error, is_window_mismatch_error);
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(odd_length_decode_error, is_odd_length_decode_error);
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(
            primary_key_type_mismatch_error,
            is_primary_key_type_mismatch_error,
        );
    }

    // Previously this test was a verbatim copy of the boundary-arity case; it
    // now walks every variant in `CURSOR_PARITY_CASES` as its name promises.
    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        for (build, predicate) in CURSOR_PARITY_CASES {
            assert_cursor_mapping_parity(build, predicate);
        }
    }
}