1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, FluentDeleteQuery, FluentLoadQuery, MissingRowPolicy,
11 PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError, Query, QueryError,
12 StoreRegistry, WriteBatchResponse,
13 commit::EntityRuntimeHooks,
14 cursor::CursorPlanError,
15 decode_cursor,
16 executor::{
17 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
18 SaveExecutor,
19 },
20 query::plan::QueryMode,
21 },
22 error::InternalError,
23 obs::sink::{MetricsSink, with_metrics_sink},
24 traits::{CanisterKind, EntityKind, EntityValue},
25};
26use std::thread::LocalKey;
27
28fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
30 match err {
31 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
32 }
33}
34
35fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
38 cursor_token
39 .map(|token| {
40 decode_cursor(token).map_err(|reason| {
41 QueryError::from(PlanError::from(
42 CursorPlanError::invalid_continuation_cursor(reason),
43 ))
44 })
45 })
46 .transpose()
47}
48
49pub struct DbSession<C: CanisterKind> {
56 db: Db<C>,
57 debug: bool,
58 metrics: Option<&'static dyn MetricsSink>,
59}
60
61impl<C: CanisterKind> DbSession<C> {
62 #[must_use]
64 pub(crate) const fn new(db: Db<C>) -> Self {
65 Self {
66 db,
67 debug: false,
68 metrics: None,
69 }
70 }
71
72 #[must_use]
74 pub const fn new_with_hooks(
75 store: &'static LocalKey<StoreRegistry>,
76 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
77 ) -> Self {
78 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
79 }
80
81 #[must_use]
83 pub const fn debug(mut self) -> Self {
84 self.debug = true;
85 self
86 }
87
88 #[must_use]
90 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
91 self.metrics = Some(sink);
92 self
93 }
94
95 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
96 if let Some(sink) = self.metrics {
97 with_metrics_sink(sink, f)
98 } else {
99 f()
100 }
101 }
102
103 fn execute_save_with<E, T, R>(
105 &self,
106 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
107 map: impl FnOnce(T) -> R,
108 ) -> Result<R, InternalError>
109 where
110 E: EntityKind<Canister = C> + EntityValue,
111 {
112 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
113
114 Ok(map(value))
115 }
116
117 fn execute_save_entity<E>(
119 &self,
120 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
121 ) -> Result<E, InternalError>
122 where
123 E: EntityKind<Canister = C> + EntityValue,
124 {
125 self.execute_save_with(op, std::convert::identity)
126 }
127
128 fn execute_save_batch<E>(
129 &self,
130 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
131 ) -> Result<WriteBatchResponse<E>, InternalError>
132 where
133 E: EntityKind<Canister = C> + EntityValue,
134 {
135 self.execute_save_with(op, WriteBatchResponse::new)
136 }
137
138 fn execute_save_view<E>(
139 &self,
140 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
141 ) -> Result<E::ViewType, InternalError>
142 where
143 E: EntityKind<Canister = C> + EntityValue,
144 {
145 self.execute_save_with(op, std::convert::identity)
146 }
147
148 #[must_use]
154 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
155 where
156 E: EntityKind<Canister = C>,
157 {
158 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
159 }
160
161 #[must_use]
163 pub const fn load_with_consistency<E>(
164 &self,
165 consistency: MissingRowPolicy,
166 ) -> FluentLoadQuery<'_, E>
167 where
168 E: EntityKind<Canister = C>,
169 {
170 FluentLoadQuery::new(self, Query::new(consistency))
171 }
172
173 #[must_use]
175 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
176 where
177 E: EntityKind<Canister = C>,
178 {
179 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
180 }
181
182 #[must_use]
184 pub fn delete_with_consistency<E>(
185 &self,
186 consistency: MissingRowPolicy,
187 ) -> FluentDeleteQuery<'_, E>
188 where
189 E: EntityKind<Canister = C>,
190 {
191 FluentDeleteQuery::new(self, Query::new(consistency).delete())
192 }
193
194 #[must_use]
199 pub(crate) const fn load_executor<E>(&self) -> LoadExecutor<E>
200 where
201 E: EntityKind<Canister = C> + EntityValue,
202 {
203 LoadExecutor::new(self.db, self.debug)
204 }
205
206 #[must_use]
207 pub(crate) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
208 where
209 E: EntityKind<Canister = C> + EntityValue,
210 {
211 DeleteExecutor::new(self.db, self.debug)
212 }
213
214 #[must_use]
215 pub(crate) const fn save_executor<E>(&self) -> SaveExecutor<E>
216 where
217 E: EntityKind<Canister = C> + EntityValue,
218 {
219 SaveExecutor::new(self.db, self.debug)
220 }
221
222 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
228 where
229 E: EntityKind<Canister = C> + EntityValue,
230 {
231 let plan = query.plan()?.into_executable();
232
233 let result = match query.mode() {
234 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
235 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
236 };
237
238 result.map_err(QueryError::execute)
239 }
240
241 pub(crate) fn execute_load_query_with<E, T>(
244 &self,
245 query: &Query<E>,
246 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
247 ) -> Result<T, QueryError>
248 where
249 E: EntityKind<Canister = C> + EntityValue,
250 {
251 let plan = query.plan()?.into_executable();
252
253 self.with_metrics(|| op(self.load_executor::<E>(), plan))
254 .map_err(QueryError::execute)
255 }
256
257 pub(crate) fn execute_load_query_paged_with_trace<E>(
259 &self,
260 query: &Query<E>,
261 cursor_token: Option<&str>,
262 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
263 where
264 E: EntityKind<Canister = C> + EntityValue,
265 {
266 let plan = query.plan()?.into_executable();
268 match plan.execution_strategy().map_err(QueryError::execute)? {
269 ExecutionStrategy::PrimaryKey => {
270 return Err(QueryError::execute(invariant(
271 "cursor pagination requires explicit or grouped ordering",
272 )));
273 }
274 ExecutionStrategy::Ordered => {}
275 ExecutionStrategy::Grouped => {
276 return Err(QueryError::execute(invariant(
277 "grouped plans require execute_grouped(...)",
278 )));
279 }
280 }
281
282 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
284 let cursor = plan
285 .prepare_cursor(cursor_bytes.as_deref())
286 .map_err(map_executor_plan_error)?;
287
288 let (page, trace) = self
290 .with_metrics(|| {
291 self.load_executor::<E>()
292 .execute_paged_with_cursor_traced(plan, cursor)
293 })
294 .map_err(QueryError::execute)?;
295 let next_cursor = page
296 .next_cursor
297 .map(|token| {
298 let Some(token) = token.as_scalar() else {
299 return Err(QueryError::execute(invariant(
300 "scalar load pagination emitted grouped continuation token",
301 )));
302 };
303
304 token.encode().map_err(|err| {
305 QueryError::execute(InternalError::serialize_internal(format!(
306 "failed to serialize continuation cursor: {err}"
307 )))
308 })
309 })
310 .transpose()?;
311
312 Ok(PagedLoadExecutionWithTrace::new(
313 page.items,
314 next_cursor,
315 trace,
316 ))
317 }
318
319 pub fn execute_grouped<E>(
324 &self,
325 query: &Query<E>,
326 cursor_token: Option<&str>,
327 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
328 where
329 E: EntityKind<Canister = C> + EntityValue,
330 {
331 let plan = query.plan()?.into_executable();
333 if !matches!(
334 plan.execution_strategy().map_err(QueryError::execute)?,
335 ExecutionStrategy::Grouped
336 ) {
337 return Err(QueryError::execute(invariant(
338 "execute_grouped requires grouped logical plans",
339 )));
340 }
341
342 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
344 let cursor = plan
345 .prepare_grouped_cursor(cursor_bytes.as_deref())
346 .map_err(map_executor_plan_error)?;
347
348 let (page, trace) = self
350 .with_metrics(|| {
351 self.load_executor::<E>()
352 .execute_grouped_paged_with_cursor_traced(plan, cursor)
353 })
354 .map_err(QueryError::execute)?;
355 let next_cursor = page
356 .next_cursor
357 .map(|token| {
358 let Some(token) = token.as_grouped() else {
359 return Err(QueryError::execute(invariant(
360 "grouped pagination emitted scalar continuation token",
361 )));
362 };
363
364 token.encode().map_err(|err| {
365 QueryError::execute(InternalError::serialize_internal(format!(
366 "failed to serialize grouped continuation cursor: {err}"
367 )))
368 })
369 })
370 .transpose()?;
371
372 Ok(PagedGroupedExecutionWithTrace::new(
373 page.rows,
374 next_cursor,
375 trace,
376 ))
377 }
378
379 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
385 where
386 E: EntityKind<Canister = C> + EntityValue,
387 {
388 self.execute_save_entity(|save| save.insert(entity))
389 }
390
391 pub fn insert_many_atomic<E>(
397 &self,
398 entities: impl IntoIterator<Item = E>,
399 ) -> Result<WriteBatchResponse<E>, InternalError>
400 where
401 E: EntityKind<Canister = C> + EntityValue,
402 {
403 self.execute_save_batch(|save| save.insert_many_atomic(entities))
404 }
405
406 pub fn insert_many_non_atomic<E>(
410 &self,
411 entities: impl IntoIterator<Item = E>,
412 ) -> Result<WriteBatchResponse<E>, InternalError>
413 where
414 E: EntityKind<Canister = C> + EntityValue,
415 {
416 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
417 }
418
419 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
421 where
422 E: EntityKind<Canister = C> + EntityValue,
423 {
424 self.execute_save_entity(|save| save.replace(entity))
425 }
426
427 pub fn replace_many_atomic<E>(
433 &self,
434 entities: impl IntoIterator<Item = E>,
435 ) -> Result<WriteBatchResponse<E>, InternalError>
436 where
437 E: EntityKind<Canister = C> + EntityValue,
438 {
439 self.execute_save_batch(|save| save.replace_many_atomic(entities))
440 }
441
442 pub fn replace_many_non_atomic<E>(
446 &self,
447 entities: impl IntoIterator<Item = E>,
448 ) -> Result<WriteBatchResponse<E>, InternalError>
449 where
450 E: EntityKind<Canister = C> + EntityValue,
451 {
452 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
453 }
454
455 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
457 where
458 E: EntityKind<Canister = C> + EntityValue,
459 {
460 self.execute_save_entity(|save| save.update(entity))
461 }
462
463 pub fn update_many_atomic<E>(
469 &self,
470 entities: impl IntoIterator<Item = E>,
471 ) -> Result<WriteBatchResponse<E>, InternalError>
472 where
473 E: EntityKind<Canister = C> + EntityValue,
474 {
475 self.execute_save_batch(|save| save.update_many_atomic(entities))
476 }
477
478 pub fn update_many_non_atomic<E>(
482 &self,
483 entities: impl IntoIterator<Item = E>,
484 ) -> Result<WriteBatchResponse<E>, InternalError>
485 where
486 E: EntityKind<Canister = C> + EntityValue,
487 {
488 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
489 }
490
491 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
493 where
494 E: EntityKind<Canister = C> + EntityValue,
495 {
496 self.execute_save_view::<E>(|save| save.insert_view(view))
497 }
498
499 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
501 where
502 E: EntityKind<Canister = C> + EntityValue,
503 {
504 self.execute_save_view::<E>(|save| save.replace_view(view))
505 }
506
507 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
509 where
510 E: EntityKind<Canister = C> + EntityValue,
511 {
512 self.execute_save_view::<E>(|save| save.update_view(view))
513 }
514
515 #[cfg(test)]
517 #[doc(hidden)]
518 pub fn clear_stores_for_tests(&self) {
519 self.db.with_store_registry(|reg| {
520 for (_, store) in reg.iter() {
523 store.with_data_mut(DataStore::clear);
524 store.with_index_mut(IndexStore::clear);
525 }
526 });
527 }
528}
529
530fn invariant(message: impl Into<String>) -> InternalError {
531 InternalError::query_executor_invariant(message)
532}
533
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one cursor plan error for a parity case.
    type CaseBuilder = fn() -> CursorPlanError;
    /// Recognizes the exact cursor plan error produced by the matching builder.
    type CasePredicate = fn(&CursorPlanError) -> bool;

    /// Asserts that `err` is `QueryError::Plan` wrapping `PlanError::Cursor`
    /// whose inner `CursorPlanError` satisfies `predicate`.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(matches!(
            err,
            QueryError::Plan(plan_err)
                if matches!(
                    plan_err.as_ref(),
                    PlanError::Cursor(inner) if predicate(inner.as_ref())
                )
        ));
    }

    /// Checks that the executor-side mapping (`map_executor_plan_error`) and the
    /// direct planner conversion (`QueryError::from(PlanError::from(..))`) both
    /// yield a cursor plan error accepted by `predicate`.
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    fn build_boundary_arity_mismatch() -> CursorPlanError {
        CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1)
    }

    fn is_boundary_arity_mismatch(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                expected: 2,
                found: 1
            }
        )
    }

    fn build_window_mismatch() -> CursorPlanError {
        CursorPlanError::continuation_cursor_window_mismatch(8, 3)
    }

    fn is_window_mismatch(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorWindowMismatch {
                expected_offset: 8,
                actual_offset: 3
            }
        )
    }

    fn build_invalid_cursor_odd_length() -> CursorPlanError {
        CursorPlanError::invalid_continuation_cursor(
            crate::db::codec::cursor::CursorDecodeError::OddLength,
        )
    }

    fn is_invalid_cursor_odd_length(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::InvalidContinuationCursor {
                reason: crate::db::codec::cursor::CursorDecodeError::OddLength
            }
        )
    }

    fn build_primary_key_type_mismatch() -> CursorPlanError {
        CursorPlanError::continuation_cursor_primary_key_type_mismatch(
            "id",
            "ulid",
            Some(crate::value::Value::Text("not-a-ulid".to_string())),
        )
    }

    fn is_primary_key_type_mismatch(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                field,
                expected,
                value: Some(crate::value::Value::Text(value))
            } if field == "id" && expected == "ulid" && value == "not-a-ulid"
        )
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(build_boundary_arity_mismatch, is_boundary_arity_mismatch);
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(build_window_mismatch, is_window_mismatch);
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(build_invalid_cursor_odd_length, is_invalid_cursor_odd_length);
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(build_primary_key_type_mismatch, is_primary_key_type_mismatch);
    }

    /// Runs every cursor variant above through both mapping paths in one place.
    /// Previously this test only duplicated the boundary-arity case.
    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        let cases: [(CaseBuilder, CasePredicate); 4] = [
            (build_boundary_arity_mismatch, is_boundary_arity_mismatch),
            (build_window_mismatch, is_window_mismatch),
            (build_invalid_cursor_odd_length, is_invalid_cursor_odd_length),
            (build_primary_key_type_mismatch, is_primary_key_type_mismatch),
        ];

        for (build, predicate) in cases {
            assert_cursor_mapping_parity(build, predicate);
        }
    }
}
659}