1#[cfg(test)]
7use crate::db::{DataStore, IndexStore};
8use crate::{
9 db::{
10 Db, FluentDeleteQuery, FluentLoadQuery, MissingRowPolicy, PagedGroupedExecutionWithTrace,
11 PagedLoadExecutionWithTrace, PlanError, Query, QueryError, Response, WriteBatchResponse,
12 WriteResponse,
13 cursor::CursorPlanError,
14 decode_cursor,
15 executor::{DeleteExecutor, ExecutablePlan, ExecutorPlanError, LoadExecutor, SaveExecutor},
16 query::intent::QueryMode,
17 },
18 error::InternalError,
19 obs::sink::{MetricsSink, with_metrics_sink},
20 traits::{CanisterKind, EntityKind, EntityValue},
21};
22
23fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
25 match err {
26 ExecutorPlanError::Cursor(err) => QueryError::from(PlanError::from(*err)),
27 }
28}
29
30pub struct DbSession<C: CanisterKind> {
37 db: Db<C>,
38 debug: bool,
39 metrics: Option<&'static dyn MetricsSink>,
40}
41
42impl<C: CanisterKind> DbSession<C> {
43 #[must_use]
45 pub const fn new(db: Db<C>) -> Self {
46 Self {
47 db,
48 debug: false,
49 metrics: None,
50 }
51 }
52
53 #[must_use]
55 pub const fn debug(mut self) -> Self {
56 self.debug = true;
57 self
58 }
59
60 #[must_use]
62 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
63 self.metrics = Some(sink);
64 self
65 }
66
67 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
68 if let Some(sink) = self.metrics {
69 with_metrics_sink(sink, f)
70 } else {
71 f()
72 }
73 }
74
75 fn execute_save_with<E, T, R>(
77 &self,
78 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
79 map: impl FnOnce(T) -> R,
80 ) -> Result<R, InternalError>
81 where
82 E: EntityKind<Canister = C> + EntityValue,
83 {
84 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
85
86 Ok(map(value))
87 }
88
89 fn execute_save_entity<E>(
91 &self,
92 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
93 ) -> Result<WriteResponse<E>, InternalError>
94 where
95 E: EntityKind<Canister = C> + EntityValue,
96 {
97 self.execute_save_with(op, WriteResponse::new)
98 }
99
100 fn execute_save_batch<E>(
101 &self,
102 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
103 ) -> Result<WriteBatchResponse<E>, InternalError>
104 where
105 E: EntityKind<Canister = C> + EntityValue,
106 {
107 self.execute_save_with(op, WriteBatchResponse::new)
108 }
109
110 fn execute_save_view<E>(
111 &self,
112 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
113 ) -> Result<E::ViewType, InternalError>
114 where
115 E: EntityKind<Canister = C> + EntityValue,
116 {
117 self.execute_save_with(op, std::convert::identity)
118 }
119
120 #[must_use]
126 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
127 where
128 E: EntityKind<Canister = C>,
129 {
130 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
131 }
132
133 #[must_use]
135 pub const fn load_with_consistency<E>(
136 &self,
137 consistency: MissingRowPolicy,
138 ) -> FluentLoadQuery<'_, E>
139 where
140 E: EntityKind<Canister = C>,
141 {
142 FluentLoadQuery::new(self, Query::new(consistency))
143 }
144
145 #[must_use]
147 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
148 where
149 E: EntityKind<Canister = C>,
150 {
151 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
152 }
153
154 #[must_use]
156 pub fn delete_with_consistency<E>(
157 &self,
158 consistency: MissingRowPolicy,
159 ) -> FluentDeleteQuery<'_, E>
160 where
161 E: EntityKind<Canister = C>,
162 {
163 FluentDeleteQuery::new(self, Query::new(consistency).delete())
164 }
165
166 #[must_use]
171 pub(crate) const fn load_executor<E>(&self) -> LoadExecutor<E>
172 where
173 E: EntityKind<Canister = C> + EntityValue,
174 {
175 LoadExecutor::new(self.db, self.debug)
176 }
177
178 #[must_use]
179 pub(crate) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
180 where
181 E: EntityKind<Canister = C> + EntityValue,
182 {
183 DeleteExecutor::new(self.db, self.debug)
184 }
185
186 #[must_use]
187 pub(crate) const fn save_executor<E>(&self) -> SaveExecutor<E>
188 where
189 E: EntityKind<Canister = C> + EntityValue,
190 {
191 SaveExecutor::new(self.db, self.debug)
192 }
193
194 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<Response<E>, QueryError>
200 where
201 E: EntityKind<Canister = C> + EntityValue,
202 {
203 let plan = query.plan()?.into_executable();
204
205 let result = match query.mode() {
206 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
207 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
208 };
209
210 result.map_err(QueryError::Execute)
211 }
212
213 pub(crate) fn execute_load_query_with<E, T>(
216 &self,
217 query: &Query<E>,
218 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
219 ) -> Result<T, QueryError>
220 where
221 E: EntityKind<Canister = C> + EntityValue,
222 {
223 let plan = query.plan()?.into_executable();
224
225 self.with_metrics(|| op(self.load_executor::<E>(), plan))
226 .map_err(QueryError::Execute)
227 }
228
229 pub(crate) fn execute_load_query_paged_with_trace<E>(
231 &self,
232 query: &Query<E>,
233 cursor_token: Option<&str>,
234 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
235 where
236 E: EntityKind<Canister = C> + EntityValue,
237 {
238 let plan = query.plan()?.into_executable();
240 if plan.as_inner().grouped_plan().is_some() {
241 return Err(QueryError::Execute(
242 InternalError::query_executor_invariant(
243 "grouped plans require execute_grouped(...)",
244 ),
245 ));
246 }
247
248 let cursor_bytes = match cursor_token {
250 Some(token) => Some(decode_cursor(token).map_err(|reason| {
251 QueryError::from(PlanError::from(
252 CursorPlanError::invalid_continuation_cursor(reason),
253 ))
254 })?),
255 None => None,
256 };
257 let cursor = plan
258 .prepare_cursor(cursor_bytes.as_deref())
259 .map_err(map_executor_plan_error)?;
260
261 let (page, trace) = self
263 .with_metrics(|| {
264 self.load_executor::<E>()
265 .execute_paged_with_cursor_traced(plan, cursor)
266 })
267 .map_err(QueryError::Execute)?;
268 let next_cursor = page
269 .next_cursor
270 .map(|token| {
271 let Some(token) = token.as_scalar() else {
272 return Err(QueryError::Execute(
273 InternalError::query_executor_invariant(
274 "scalar load pagination emitted grouped continuation token",
275 ),
276 ));
277 };
278
279 token.encode().map_err(|err| {
280 QueryError::Execute(InternalError::serialize_internal(format!(
281 "failed to serialize continuation cursor: {err}"
282 )))
283 })
284 })
285 .transpose()?;
286
287 Ok(PagedLoadExecutionWithTrace::new(
288 page.items,
289 next_cursor,
290 trace,
291 ))
292 }
293
294 pub fn execute_grouped<E>(
299 &self,
300 query: &Query<E>,
301 cursor_token: Option<&str>,
302 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
303 where
304 E: EntityKind<Canister = C> + EntityValue,
305 {
306 let plan = query.plan()?.into_executable();
308 if plan.as_inner().grouped_plan().is_none() {
309 return Err(QueryError::Execute(
310 InternalError::query_executor_invariant(
311 "execute_grouped requires grouped logical plans",
312 ),
313 ));
314 }
315
316 let cursor_bytes = match cursor_token {
318 Some(token) => Some(decode_cursor(token).map_err(|reason| {
319 QueryError::from(PlanError::from(
320 CursorPlanError::invalid_continuation_cursor(reason),
321 ))
322 })?),
323 None => None,
324 };
325 let cursor = plan
326 .prepare_grouped_cursor(cursor_bytes.as_deref())
327 .map_err(map_executor_plan_error)?;
328
329 let (page, trace) = self
331 .with_metrics(|| {
332 self.load_executor::<E>()
333 .execute_grouped_paged_with_cursor_traced(plan, cursor)
334 })
335 .map_err(QueryError::Execute)?;
336 let next_cursor = page
337 .next_cursor
338 .map(|token| {
339 let Some(token) = token.as_grouped() else {
340 return Err(QueryError::Execute(
341 InternalError::query_executor_invariant(
342 "grouped pagination emitted scalar continuation token",
343 ),
344 ));
345 };
346
347 token.encode().map_err(|err| {
348 QueryError::Execute(InternalError::serialize_internal(format!(
349 "failed to serialize grouped continuation cursor: {err}"
350 )))
351 })
352 })
353 .transpose()?;
354
355 Ok(PagedGroupedExecutionWithTrace::new(
356 page.rows,
357 next_cursor,
358 trace,
359 ))
360 }
361
362 pub fn insert<E>(&self, entity: E) -> Result<WriteResponse<E>, InternalError>
368 where
369 E: EntityKind<Canister = C> + EntityValue,
370 {
371 self.execute_save_entity(|save| save.insert(entity))
372 }
373
374 pub fn insert_many_atomic<E>(
380 &self,
381 entities: impl IntoIterator<Item = E>,
382 ) -> Result<WriteBatchResponse<E>, InternalError>
383 where
384 E: EntityKind<Canister = C> + EntityValue,
385 {
386 self.execute_save_batch(|save| save.insert_many_atomic(entities))
387 }
388
389 pub fn insert_many_non_atomic<E>(
393 &self,
394 entities: impl IntoIterator<Item = E>,
395 ) -> Result<WriteBatchResponse<E>, InternalError>
396 where
397 E: EntityKind<Canister = C> + EntityValue,
398 {
399 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
400 }
401
402 pub fn replace<E>(&self, entity: E) -> Result<WriteResponse<E>, InternalError>
404 where
405 E: EntityKind<Canister = C> + EntityValue,
406 {
407 self.execute_save_entity(|save| save.replace(entity))
408 }
409
410 pub fn replace_many_atomic<E>(
416 &self,
417 entities: impl IntoIterator<Item = E>,
418 ) -> Result<WriteBatchResponse<E>, InternalError>
419 where
420 E: EntityKind<Canister = C> + EntityValue,
421 {
422 self.execute_save_batch(|save| save.replace_many_atomic(entities))
423 }
424
425 pub fn replace_many_non_atomic<E>(
429 &self,
430 entities: impl IntoIterator<Item = E>,
431 ) -> Result<WriteBatchResponse<E>, InternalError>
432 where
433 E: EntityKind<Canister = C> + EntityValue,
434 {
435 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
436 }
437
438 pub fn update<E>(&self, entity: E) -> Result<WriteResponse<E>, InternalError>
440 where
441 E: EntityKind<Canister = C> + EntityValue,
442 {
443 self.execute_save_entity(|save| save.update(entity))
444 }
445
446 pub fn update_many_atomic<E>(
452 &self,
453 entities: impl IntoIterator<Item = E>,
454 ) -> Result<WriteBatchResponse<E>, InternalError>
455 where
456 E: EntityKind<Canister = C> + EntityValue,
457 {
458 self.execute_save_batch(|save| save.update_many_atomic(entities))
459 }
460
461 pub fn update_many_non_atomic<E>(
465 &self,
466 entities: impl IntoIterator<Item = E>,
467 ) -> Result<WriteBatchResponse<E>, InternalError>
468 where
469 E: EntityKind<Canister = C> + EntityValue,
470 {
471 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
472 }
473
474 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
476 where
477 E: EntityKind<Canister = C> + EntityValue,
478 {
479 self.execute_save_view::<E>(|save| save.insert_view(view))
480 }
481
482 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
484 where
485 E: EntityKind<Canister = C> + EntityValue,
486 {
487 self.execute_save_view::<E>(|save| save.replace_view(view))
488 }
489
490 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
492 where
493 E: EntityKind<Canister = C> + EntityValue,
494 {
495 self.execute_save_view::<E>(|save| save.update_view(view))
496 }
497
498 #[cfg(test)]
500 #[doc(hidden)]
501 pub fn clear_stores_for_tests(&self) {
502 self.db.with_store_registry(|reg| {
503 for (_, store) in reg.iter() {
506 store.with_data_mut(DataStore::clear);
507 store.with_index_mut(IndexStore::clear);
508 }
509 });
510 }
511}