1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, EntityResponse, FluentDeleteQuery, FluentLoadQuery, MissingRowPolicy,
11 PagedGroupedExecutionWithTrace, PagedLoadExecutionWithTrace, PlanError, Query, QueryError,
12 StoreRegistry, WriteBatchResponse,
13 commit::EntityRuntimeHooks,
14 cursor::CursorPlanError,
15 decode_cursor,
16 executor::{DeleteExecutor, ExecutablePlan, ExecutorPlanError, LoadExecutor, SaveExecutor},
17 query::intent::QueryMode,
18 },
19 error::InternalError,
20 obs::sink::{MetricsSink, with_metrics_sink},
21 traits::{CanisterKind, EntityKind, EntityValue},
22};
23use std::thread::LocalKey;
24
25fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
27 match err {
28 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
29 }
30}
31
32pub struct DbSession<C: CanisterKind> {
39 db: Db<C>,
40 debug: bool,
41 metrics: Option<&'static dyn MetricsSink>,
42}
43
44impl<C: CanisterKind> DbSession<C> {
45 #[must_use]
47 pub(crate) const fn new(db: Db<C>) -> Self {
48 Self {
49 db,
50 debug: false,
51 metrics: None,
52 }
53 }
54
55 #[must_use]
57 pub const fn new_with_hooks(
58 store: &'static LocalKey<StoreRegistry>,
59 entity_runtime_hooks: &'static [EntityRuntimeHooks<C>],
60 ) -> Self {
61 Self::new(Db::new_with_hooks(store, entity_runtime_hooks))
62 }
63
64 #[must_use]
66 pub const fn debug(mut self) -> Self {
67 self.debug = true;
68 self
69 }
70
71 #[must_use]
73 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
74 self.metrics = Some(sink);
75 self
76 }
77
78 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
79 if let Some(sink) = self.metrics {
80 with_metrics_sink(sink, f)
81 } else {
82 f()
83 }
84 }
85
86 fn execute_save_with<E, T, R>(
88 &self,
89 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
90 map: impl FnOnce(T) -> R,
91 ) -> Result<R, InternalError>
92 where
93 E: EntityKind<Canister = C> + EntityValue,
94 {
95 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
96
97 Ok(map(value))
98 }
99
100 fn execute_save_entity<E>(
102 &self,
103 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
104 ) -> Result<E, InternalError>
105 where
106 E: EntityKind<Canister = C> + EntityValue,
107 {
108 self.execute_save_with(op, std::convert::identity)
109 }
110
111 fn execute_save_batch<E>(
112 &self,
113 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
114 ) -> Result<WriteBatchResponse<E>, InternalError>
115 where
116 E: EntityKind<Canister = C> + EntityValue,
117 {
118 self.execute_save_with(op, WriteBatchResponse::new)
119 }
120
121 fn execute_save_view<E>(
122 &self,
123 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
124 ) -> Result<E::ViewType, InternalError>
125 where
126 E: EntityKind<Canister = C> + EntityValue,
127 {
128 self.execute_save_with(op, std::convert::identity)
129 }
130
131 #[must_use]
137 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
138 where
139 E: EntityKind<Canister = C>,
140 {
141 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
142 }
143
144 #[must_use]
146 pub const fn load_with_consistency<E>(
147 &self,
148 consistency: MissingRowPolicy,
149 ) -> FluentLoadQuery<'_, E>
150 where
151 E: EntityKind<Canister = C>,
152 {
153 FluentLoadQuery::new(self, Query::new(consistency))
154 }
155
156 #[must_use]
158 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
159 where
160 E: EntityKind<Canister = C>,
161 {
162 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
163 }
164
165 #[must_use]
167 pub fn delete_with_consistency<E>(
168 &self,
169 consistency: MissingRowPolicy,
170 ) -> FluentDeleteQuery<'_, E>
171 where
172 E: EntityKind<Canister = C>,
173 {
174 FluentDeleteQuery::new(self, Query::new(consistency).delete())
175 }
176
177 #[must_use]
182 pub(crate) const fn load_executor<E>(&self) -> LoadExecutor<E>
183 where
184 E: EntityKind<Canister = C> + EntityValue,
185 {
186 LoadExecutor::new(self.db, self.debug)
187 }
188
189 #[must_use]
190 pub(crate) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
191 where
192 E: EntityKind<Canister = C> + EntityValue,
193 {
194 DeleteExecutor::new(self.db, self.debug)
195 }
196
197 #[must_use]
198 pub(crate) const fn save_executor<E>(&self) -> SaveExecutor<E>
199 where
200 E: EntityKind<Canister = C> + EntityValue,
201 {
202 SaveExecutor::new(self.db, self.debug)
203 }
204
205 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<EntityResponse<E>, QueryError>
211 where
212 E: EntityKind<Canister = C> + EntityValue,
213 {
214 let plan = query.plan()?.into_executable();
215
216 let result = match query.mode() {
217 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
218 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
219 };
220
221 result.map_err(QueryError::execute)
222 }
223
224 pub(crate) fn execute_load_query_with<E, T>(
227 &self,
228 query: &Query<E>,
229 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
230 ) -> Result<T, QueryError>
231 where
232 E: EntityKind<Canister = C> + EntityValue,
233 {
234 let plan = query.plan()?.into_executable();
235
236 self.with_metrics(|| op(self.load_executor::<E>(), plan))
237 .map_err(QueryError::execute)
238 }
239
240 pub(crate) fn execute_load_query_paged_with_trace<E>(
242 &self,
243 query: &Query<E>,
244 cursor_token: Option<&str>,
245 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
246 where
247 E: EntityKind<Canister = C> + EntityValue,
248 {
249 let plan = query.plan()?.into_executable();
251 if !plan.supports_continuation() {
252 return Err(QueryError::execute(invariant(
253 "cursor pagination requires load plans",
254 )));
255 }
256 if plan.is_grouped() {
257 return Err(QueryError::execute(invariant(
258 "grouped plans require execute_grouped(...)",
259 )));
260 }
261
262 let cursor_bytes = match cursor_token {
264 Some(token) => Some(decode_cursor(token).map_err(|reason| {
265 QueryError::from(PlanError::from(
266 CursorPlanError::invalid_continuation_cursor(reason),
267 ))
268 })?),
269 None => None,
270 };
271 let cursor = plan
272 .prepare_cursor(cursor_bytes.as_deref())
273 .map_err(map_executor_plan_error)?;
274
275 let (page, trace) = self
277 .with_metrics(|| {
278 self.load_executor::<E>()
279 .execute_paged_with_cursor_traced(plan, cursor)
280 })
281 .map_err(QueryError::execute)?;
282 let next_cursor = page
283 .next_cursor
284 .map(|token| {
285 let Some(token) = token.as_scalar() else {
286 return Err(QueryError::execute(invariant(
287 "scalar load pagination emitted grouped continuation token",
288 )));
289 };
290
291 token.encode().map_err(|err| {
292 QueryError::execute(InternalError::serialize_internal(format!(
293 "failed to serialize continuation cursor: {err}"
294 )))
295 })
296 })
297 .transpose()?;
298
299 Ok(PagedLoadExecutionWithTrace::new(
300 page.items,
301 next_cursor,
302 trace,
303 ))
304 }
305
306 pub fn execute_grouped<E>(
311 &self,
312 query: &Query<E>,
313 cursor_token: Option<&str>,
314 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
315 where
316 E: EntityKind<Canister = C> + EntityValue,
317 {
318 let plan = query.plan()?.into_executable();
320 if !plan.supports_continuation() || !plan.is_grouped() {
321 return Err(QueryError::execute(invariant(
322 "execute_grouped requires grouped logical plans",
323 )));
324 }
325
326 let cursor_bytes = match cursor_token {
328 Some(token) => Some(decode_cursor(token).map_err(|reason| {
329 QueryError::from(PlanError::from(
330 CursorPlanError::invalid_continuation_cursor(reason),
331 ))
332 })?),
333 None => None,
334 };
335 let cursor = plan
336 .prepare_grouped_cursor(cursor_bytes.as_deref())
337 .map_err(map_executor_plan_error)?;
338
339 let (page, trace) = self
341 .with_metrics(|| {
342 self.load_executor::<E>()
343 .execute_grouped_paged_with_cursor_traced(plan, cursor)
344 })
345 .map_err(QueryError::execute)?;
346 let next_cursor = page
347 .next_cursor
348 .map(|token| {
349 let Some(token) = token.as_grouped() else {
350 return Err(QueryError::execute(invariant(
351 "grouped pagination emitted scalar continuation token",
352 )));
353 };
354
355 token.encode().map_err(|err| {
356 QueryError::execute(InternalError::serialize_internal(format!(
357 "failed to serialize grouped continuation cursor: {err}"
358 )))
359 })
360 })
361 .transpose()?;
362
363 Ok(PagedGroupedExecutionWithTrace::new(
364 page.rows,
365 next_cursor,
366 trace,
367 ))
368 }
369
370 pub fn insert<E>(&self, entity: E) -> Result<E, InternalError>
376 where
377 E: EntityKind<Canister = C> + EntityValue,
378 {
379 self.execute_save_entity(|save| save.insert(entity))
380 }
381
382 pub fn insert_many_atomic<E>(
388 &self,
389 entities: impl IntoIterator<Item = E>,
390 ) -> Result<WriteBatchResponse<E>, InternalError>
391 where
392 E: EntityKind<Canister = C> + EntityValue,
393 {
394 self.execute_save_batch(|save| save.insert_many_atomic(entities))
395 }
396
397 pub fn insert_many_non_atomic<E>(
401 &self,
402 entities: impl IntoIterator<Item = E>,
403 ) -> Result<WriteBatchResponse<E>, InternalError>
404 where
405 E: EntityKind<Canister = C> + EntityValue,
406 {
407 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
408 }
409
410 pub fn replace<E>(&self, entity: E) -> Result<E, InternalError>
412 where
413 E: EntityKind<Canister = C> + EntityValue,
414 {
415 self.execute_save_entity(|save| save.replace(entity))
416 }
417
418 pub fn replace_many_atomic<E>(
424 &self,
425 entities: impl IntoIterator<Item = E>,
426 ) -> Result<WriteBatchResponse<E>, InternalError>
427 where
428 E: EntityKind<Canister = C> + EntityValue,
429 {
430 self.execute_save_batch(|save| save.replace_many_atomic(entities))
431 }
432
433 pub fn replace_many_non_atomic<E>(
437 &self,
438 entities: impl IntoIterator<Item = E>,
439 ) -> Result<WriteBatchResponse<E>, InternalError>
440 where
441 E: EntityKind<Canister = C> + EntityValue,
442 {
443 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
444 }
445
446 pub fn update<E>(&self, entity: E) -> Result<E, InternalError>
448 where
449 E: EntityKind<Canister = C> + EntityValue,
450 {
451 self.execute_save_entity(|save| save.update(entity))
452 }
453
454 pub fn update_many_atomic<E>(
460 &self,
461 entities: impl IntoIterator<Item = E>,
462 ) -> Result<WriteBatchResponse<E>, InternalError>
463 where
464 E: EntityKind<Canister = C> + EntityValue,
465 {
466 self.execute_save_batch(|save| save.update_many_atomic(entities))
467 }
468
469 pub fn update_many_non_atomic<E>(
473 &self,
474 entities: impl IntoIterator<Item = E>,
475 ) -> Result<WriteBatchResponse<E>, InternalError>
476 where
477 E: EntityKind<Canister = C> + EntityValue,
478 {
479 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
480 }
481
482 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
484 where
485 E: EntityKind<Canister = C> + EntityValue,
486 {
487 self.execute_save_view::<E>(|save| save.insert_view(view))
488 }
489
490 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
492 where
493 E: EntityKind<Canister = C> + EntityValue,
494 {
495 self.execute_save_view::<E>(|save| save.replace_view(view))
496 }
497
498 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
500 where
501 E: EntityKind<Canister = C> + EntityValue,
502 {
503 self.execute_save_view::<E>(|save| save.update_view(view))
504 }
505
506 #[cfg(test)]
508 #[doc(hidden)]
509 pub fn clear_stores_for_tests(&self) {
510 self.db.with_store_registry(|reg| {
511 for (_, store) in reg.iter() {
514 store.with_data_mut(DataStore::clear);
515 store.with_index_mut(IndexStore::clear);
516 }
517 });
518 }
519}
520
521fn invariant(message: impl Into<String>) -> InternalError {
522 InternalError::query_executor_invariant(message)
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `err` is `QueryError::Plan` wrapping `PlanError::Cursor`, and that
    /// `predicate` accepts the inner `CursorPlanError`.
    fn assert_query_error_is_cursor_plan(
        err: QueryError,
        predicate: impl FnOnce(&CursorPlanError) -> bool,
    ) {
        assert!(matches!(
            err,
            QueryError::Plan(plan_err)
                if matches!(
                    plan_err.as_ref(),
                    PlanError::Cursor(inner) if predicate(inner.as_ref())
                )
        ));
    }

    /// Builds the same cursor error twice and checks that both routes produce the same
    /// cursor-plan shape: through `map_executor_plan_error`, and through the direct
    /// `PlanError` -> `QueryError` conversions.
    fn assert_cursor_mapping_parity(
        build: impl Fn() -> CursorPlanError,
        predicate: impl Fn(&CursorPlanError) -> bool + Copy,
    ) {
        let mapped_via_executor = map_executor_plan_error(ExecutorPlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_executor, predicate);

        let mapped_via_plan = QueryError::from(PlanError::from(build()));
        assert_query_error_is_cursor_plan(mapped_via_plan, predicate);
    }

    #[test]
    fn session_cursor_error_mapping_parity_boundary_arity() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                        expected: 2,
                        found: 1
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_window_mismatch() {
        assert_cursor_mapping_parity(
            || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorWindowMismatch {
                        expected_offset: 8,
                        actual_offset: 3
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_decode_reason() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::invalid_continuation_cursor(
                    crate::db::codec::cursor::CursorDecodeError::OddLength,
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::InvalidContinuationCursor {
                        reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                    }
                )
            },
        );
    }

    #[test]
    fn session_cursor_error_mapping_parity_primary_key_type_mismatch() {
        assert_cursor_mapping_parity(
            || {
                CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                    "id",
                    "ulid",
                    Some(crate::value::Value::Text("not-a-ulid".to_string())),
                )
            },
            |inner| {
                matches!(
                    inner,
                    CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                        field,
                        expected,
                        value: Some(crate::value::Value::Text(value))
                    } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                )
            },
        );
    }

    /// Runs every cursor variant through the parity check from one table.
    ///
    /// Previously this test duplicated the boundary-arity case verbatim and checked no
    /// other variants. Each row is a non-capturing `(builder, predicate)` pair coerced to
    /// fn pointers, so all rows share one type and can sit in a single array.
    #[test]
    fn session_cursor_error_mapping_parity_matrix_preserves_cursor_variants() {
        type Case = (fn() -> CursorPlanError, fn(&CursorPlanError) -> bool);

        let cases: [Case; 4] = [
            (
                || CursorPlanError::continuation_cursor_boundary_arity_mismatch(2, 1),
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorBoundaryArityMismatch {
                            expected: 2,
                            found: 1
                        }
                    )
                },
            ),
            (
                || CursorPlanError::continuation_cursor_window_mismatch(8, 3),
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorWindowMismatch {
                            expected_offset: 8,
                            actual_offset: 3
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::invalid_continuation_cursor(
                        crate::db::codec::cursor::CursorDecodeError::OddLength,
                    )
                },
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::InvalidContinuationCursor {
                            reason: crate::db::codec::cursor::CursorDecodeError::OddLength
                        }
                    )
                },
            ),
            (
                || {
                    CursorPlanError::continuation_cursor_primary_key_type_mismatch(
                        "id",
                        "ulid",
                        Some(crate::value::Value::Text("not-a-ulid".to_string())),
                    )
                },
                |inner| {
                    matches!(
                        inner,
                        CursorPlanError::ContinuationCursorPrimaryKeyTypeMismatch {
                            field,
                            expected,
                            value: Some(crate::value::Value::Text(value))
                        } if field == "id" && expected == "ulid" && value == "not-a-ulid"
                    )
                },
            ),
        ];

        for (build, predicate) in cases {
            assert_cursor_mapping_parity(build, predicate);
        }
    }
}