1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, FluentDeleteQuery, FluentLoadQuery, MissingRowPolicy,
11 PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError, Query, QueryError,
12 StoreRegistry, WriteBatchResponse,
13 commit::EntityRuntimeHooks,
14 cursor::decode_optional_cursor_token,
15 executor::{
16 DeleteExecutor, ExecutablePlan, ExecutionStrategy, ExecutorPlanError, LoadExecutor,
17 SaveExecutor,
18 },
19 query::plan::QueryMode,
20 },
21 error::InternalError,
22 obs::sink::{MetricsSink, with_metrics_sink},
23 traits::{CanisterKind, EntityKind, EntityValue},
24};
25use std::thread::LocalKey;
26
27fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
29 match err {
30 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
31 }
32}
33
34fn decode_optional_cursor_bytes(cursor_token: Option<&str>) -> Result<Option<Vec<u8>>, QueryError> {
37 decode_optional_cursor_token(cursor_token).map_err(|err| QueryError::from(PlanError::from(err)))
38}
39
40pub struct DbSession<C: CanisterKind> {
47 db: Db<C>,
48 debug: bool,
49 metrics: Option<&'static dyn MetricsSink>,
50}
51
52impl<C: CanisterKind> DbSession<C> {
53 #[must_use]
55 pub(crate) const fn new(db: Db<C>) -> Self {
56 Self {
57 db,
58 debug: false,
59 metrics: None,
60 }
61 }
62
63 #[must_use]
65 pub const fn new_with_hooks(
66 store: &'static LocalKey<StoreRegistry>,
67 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
68 ) -> Self {
69 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
70 }
71
72 #[must_use]
74 pub const fn debug(mut self) -> Self {
75 self.debug = true;
76 self
77 }
78
79 #[must_use]
81 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
82 self.metrics = Some(sink);
83 self
84 }
85
86 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
87 if let Some(sink) = self.metrics {
88 with_metrics_sink(sink, f)
89 } else {
90 f()
91 }
92 }
93
94 fn execute_save_with<E, T, R>(
96 &self,
97 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
98 map: impl FnOnce(T) -> R,
99 ) -> Result<R, InternalError>
100 where
101 E: EntityKind<Canister = C> + EntityValue,
102 {
103 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
104
105 Ok(map(value))
106 }
107
108 fn execute_save_entity<E>(
110 &self,
111 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
112 ) -> Result<E, InternalError>
113 where
114 E: EntityKind<Canister = C> + EntityValue,
115 {
116 self.execute_save_with(op, std::convert::identity)
117 }
118
119 fn execute_save_batch<E>(
120 &self,
121 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
122 ) -> Result<WriteBatchResponse<E>, InternalError>
123 where
124 E: EntityKind<Canister = C> + EntityValue,
125 {
126 self.execute_save_with(op, WriteBatchResponse::new)
127 }
128
129 fn execute_save_view<E>(
130 &self,
131 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
132 ) -> Result<E::ViewType, InternalError>
133 where
134 E: EntityKind<Canister = C> + EntityValue,
135 {
136 self.execute_save_with(op, std::convert::identity)
137 }
138
139 #[must_use]
145 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
146 where
147 E: EntityKind<Canister = C>,
148 {
149 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
150 }
151
152 #[must_use]
154 pub const fn load_with_consistency<E>(
155 &self,
156 consistency: MissingRowPolicy,
157 ) -> FluentLoadQuery<'_, E>
158 where
159 E: EntityKind<Canister = C>,
160 {
161 FluentLoadQuery::new(self, Query::new(consistency))
162 }
163
164 #[must_use]
166 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
167 where
168 E: EntityKind<Canister = C>,
169 {
170 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
171 }
172
173 #[must_use]
175 pub fn delete_with_consistency<E>(
176 &self,
177 consistency: MissingRowPolicy,
178 ) -> FluentDeleteQuery<'_, E>
179 where
180 E: EntityKind<Canister = C>,
181 {
182 FluentDeleteQuery::new(self, Query::new(consistency).delete())
183 }
184
185 #[must_use]
190 pub(crate) const fn load_executor<E>(&self) -> LoadExecutor<E>
191 where
192 E: EntityKind<Canister = C> + EntityValue,
193 {
194 LoadExecutor::new(self.db, self.debug)
195 }
196
197 #[must_use]
198 pub(crate) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
199 where
200 E: EntityKind<Canister = C> + EntityValue,
201 {
202 DeleteExecutor::new(self.db, self.debug)
203 }
204
205 #[must_use]
206 pub(crate) const fn save_executor<E>(&self) -> SaveExecutor<E>
207 where
208 E: EntityKind<Canister = C> + EntityValue,
209 {
210 SaveExecutor::new(self.db, self.debug)
211 }
212
213 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
219 where
220 E: EntityKind<Canister = C> + EntityValue,
221 {
222 let plan = query.plan()?.into_executable();
223
224 let result = match query.mode() {
225 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
226 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
227 };
228
229 result.map_err(QueryError::execute)
230 }
231
232 pub(crate) fn execute_load_query_with<E, T>(
235 &self,
236 query: &Query<E>,
237 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
238 ) -> Result<T, QueryError>
239 where
240 E: EntityKind<Canister = C> + EntityValue,
241 {
242 let plan = query.plan()?.into_executable();
243
244 self.with_metrics(|| op(self.load_executor::<E>(), plan))
245 .map_err(QueryError::execute)
246 }
247
248 pub(crate) fn execute_load_query_paged_with_trace<E>(
250 &self,
251 query: &Query<E>,
252 cursor_token: Option<&str>,
253 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
254 where
255 E: EntityKind<Canister = C> + EntityValue,
256 {
257 let plan = query.plan()?.into_executable();
259 match plan.execution_strategy().map_err(QueryError::execute)? {
260 ExecutionStrategy::PrimaryKey => {
261 return Err(QueryError::execute(invariant(
262 "cursor pagination requires explicit or grouped ordering",
263 )));
264 }
265 ExecutionStrategy::Ordered => {}
266 ExecutionStrategy::Grouped => {
267 return Err(QueryError::execute(invariant(
268 "grouped plans require execute_grouped(...)",
269 )));
270 }
271 }
272
273 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
275 let cursor = plan
276 .prepare_cursor(cursor_bytes.as_deref())
277 .map_err(map_executor_plan_error)?;
278
279 let (page, trace) = self
281 .with_metrics(|| {
282 self.load_executor::<E>()
283 .execute_paged_with_cursor_traced(plan, cursor)
284 })
285 .map_err(QueryError::execute)?;
286 let next_cursor = page
287 .next_cursor
288 .map(|token| {
289 let Some(token) = token.as_scalar() else {
290 return Err(QueryError::execute(invariant(
291 "scalar load pagination emitted grouped continuation token",
292 )));
293 };
294
295 token.encode().map_err(|err| {
296 QueryError::execute(InternalError::serialize_internal(format!(
297 "failed to serialize continuation cursor: {err}"
298 )))
299 })
300 })
301 .transpose()?;
302
303 Ok(PagedLoadExecutionWithTrace::new(
304 page.items,
305 next_cursor,
306 trace,
307 ))
308 }
309
310 pub fn execute_grouped<E>(
315 &self,
316 query: &Query<E>,
317 cursor_token: Option<&str>,
318 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
319 where
320 E: EntityKind<Canister = C> + EntityValue,
321 {
322 let plan = query.plan()?.into_executable();
324 if !matches!(
325 plan.execution_strategy().map_err(QueryError::execute)?,
326 ExecutionStrategy::Grouped
327 ) {
328 return Err(QueryError::execute(invariant(
329 "execute_grouped requires grouped logical plans",
330 )));
331 }
332
333 let cursor_bytes = decode_optional_cursor_bytes(cursor_token)?;
335 let cursor = plan
336 .prepare_grouped_cursor(cursor_bytes.as_deref())
337 .map_err(map_executor_plan_error)?;
338
339 let (page, trace) = self
341 .with_metrics(|| {
342 self.load_executor::<E>()
343 .execute_grouped_paged_with_cursor_traced(plan, cursor)
344 })
345 .map_err(QueryError::execute)?;
346 let next_cursor = page
347 .next_cursor
348 .map(|token| {
349 let Some(token) = token.as_grouped() else {
350 return Err(QueryError::execute(invariant(
351 "grouped pagination emitted scalar continuation token",
352 )));
353 };
354
355 token.encode().map_err(|err| {
356 QueryError::execute(InternalError::serialize_internal(format!(
357 "failed to serialize grouped continuation cursor: {err}"
358 )))
359 })
360 })
361 .transpose()?;
362
363 Ok(PagedGroupedExecutionWithTrace::new(
364 page.rows,
365 next_cursor,
366 trace,
367 ))
368 }
369
370 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
376 where
377 E: EntityKind<Canister = C> + EntityValue,
378 {
379 self.execute_save_entity(|save| save.insert(entity))
380 }
381
382 pub fn insert_many_atomic<E>(
388 &self,
389 entities: impl IntoIterator<Item = E>,
390 ) -> Result<WriteBatchResponse<E>, InternalError>
391 where
392 E: EntityKind<Canister = C> + EntityValue,
393 {
394 self.execute_save_batch(|save| save.insert_many_atomic(entities))
395 }
396
397 pub fn insert_many_non_atomic<E>(
401 &self,
402 entities: impl IntoIterator<Item = E>,
403 ) -> Result<WriteBatchResponse<E>, InternalError>
404 where
405 E: EntityKind<Canister = C> + EntityValue,
406 {
407 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
408 }
409
410 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
412 where
413 E: EntityKind<Canister = C> + EntityValue,
414 {
415 self.execute_save_entity(|save| save.replace(entity))
416 }
417
418 pub fn replace_many_atomic<E>(
424 &self,
425 entities: impl IntoIterator<Item = E>,
426 ) -> Result<WriteBatchResponse<E>, InternalError>
427 where
428 E: EntityKind<Canister = C> + EntityValue,
429 {
430 self.execute_save_batch(|save| save.replace_many_atomic(entities))
431 }
432
433 pub fn replace_many_non_atomic<E>(
437 &self,
438 entities: impl IntoIterator<Item = E>,
439 ) -> Result<WriteBatchResponse<E>, InternalError>
440 where
441 E: EntityKind<Canister = C> + EntityValue,
442 {
443 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
444 }
445
446 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
448 where
449 E: EntityKind<Canister = C> + EntityValue,
450 {
451 self.execute_save_entity(|save| save.update(entity))
452 }
453
454 pub fn update_many_atomic<E>(
460 &self,
461 entities: impl IntoIterator<Item = E>,
462 ) -> Result<WriteBatchResponse<E>, InternalError>
463 where
464 E: EntityKind<Canister = C> + EntityValue,
465 {
466 self.execute_save_batch(|save| save.update_many_atomic(entities))
467 }
468
469 pub fn update_many_non_atomic<E>(
473 &self,
474 entities: impl IntoIterator<Item = E>,
475 ) -> Result<WriteBatchResponse<E>, InternalError>
476 where
477 E: EntityKind<Canister = C> + EntityValue,
478 {
479 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
480 }
481
482 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
484 where
485 E: EntityKind<Canister = C> + EntityValue,
486 {
487 self.execute_save_view::<E>(|save| save.insert_view(view))
488 }
489
490 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
492 where
493 E: EntityKind<Canister = C> + EntityValue,
494 {
495 self.execute_save_view::<E>(|save| save.replace_view(view))
496 }
497
498 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
500 where
501 E: EntityKind<Canister = C> + EntityValue,
502 {
503 self.execute_save_view::<E>(|save| save.update_view(view))
504 }
505
506 #[cfg(test)]
508 #[doc(hidden)]
509 pub fn clear_stores_for_tests(&self) {
510 self.db.with_store_registry(|reg| {
511 for (_, store) in reg.iter() {
514 store.with_data_mut(DataStore::clear);
515 store.with_index_mut(IndexStore::clear);
516 }
517 });
518 }
519}
520
521fn invariant(message: impl Into<String>) -> InternalError {
522 InternalError::query_executor_invariant(message)
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::cursor::CursorPlanError;

    /// Builds one cursor planning error for a parity case.
    type BuildCursorError = fn() -> CursorPlanError;
    /// Recognizes the expected cursor planning error after mapping.
    type MatchCursorError = fn(&CursorPlanError) -> bool;

    /// Asserts that `err` is `QueryError::Plan(PlanError::Cursor(inner))` and that
    /// `predicate(inner)` holds.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(matches!(
            err,
            QueryError::Plan(plan_err)
                if matches!(
                    plan_err.as_ref(),
                    PlanError::Cursor(inner) if predicate(inner.as_ref())
                )
        ));
    }

    /// Checks both routes for a cursor error and asserts they yield the same
    /// `QueryError` shape.
    ///
    /// The routes are `map_executor_plan_error` (the executor path) and a direct
    /// `PlanError` conversion.
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    // ------------------------------------------------------------------
    // Case fixtures: each builder/matcher pair covers one cursor variant.
    // ------------------------------------------------------------------

    fn boundary_arity_error() -> CursorPlanError {
        CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1)
    }

    fn is_boundary_arity_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                expected: 2,
                found: 1
            }
        )
    }

    fn window_mismatch_error() -> CursorPlanError {
        CursorPlanError::continuation_cursor_window_mismatch(8, 3)
    }

    fn is_window_mismatch_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorWindowMismatch {
                expected_offset: 8,
                actual_offset: 3
            }
        )
    }

    fn decode_reason_error() -> CursorPlanError {
        CursorPlanError::invalid_continuation_cursor(
            crate::db::codec::cursor::CursorDecodeError::OddLength,
        )
    }

    fn is_decode_reason_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::InvalidContinuationCursor {
                reason: crate::db::codec::cursor::CursorDecodeError::OddLength
            }
        )
    }

    fn primary_key_type_mismatch_error() -> CursorPlanError {
        CursorPlanError::continuation_cursor_primary_key_type_mismatch(
            "id",
            "ulid",
            Some(crate::value::Value::Text("not-a-ulid".to_string())),
        )
    }

    fn is_primary_key_type_mismatch_error(inner: &CursorPlanError) -> bool {
        matches!(
            inner,
            CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                field,
                expected,
                value: Some(crate::value::Value::Text(value))
            } if field == "id" && expected == "ulid" && value == "not-a-ulid"
        )
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(boundary_arity_error, is_boundary_arity_error);
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(window_mismatch_error, is_window_mismatch_error);
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(decode_reason_error, is_decode_reason_error);
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(
            primary_key_type_mismatch_error,
            is_primary_key_type_mismatch_error,
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        // Matrix over every fixture pair. This test previously re-checked only the
        // boundary-arity case, so the name overstated its coverage.
        let cases: [(BuildCursorError, MatchCursorError); 4] = [
            (boundary_arity_error, is_boundary_arity_error),
            (window_mismatch_error, is_window_mismatch_error),
            (decode_reason_error, is_decode_reason_error),
            (
                primary_key_type_mismatch_error,
                is_primary_key_type_mismatch_error,
            ),
        ];

        for (build, predicate) in cases {
            assert_cursor_mapping_parity(build, predicate);
        }
    }
}