1#[cfg(test)]
3use crate::db::{DataStore, IndexStore};
4use crate::{
5 db::{
6 Db, FluentDeleteQuery, FluentLoadQuery, MissingRowPolicy, PagedGroupedExecutionWithTrace,
7 PagedLoadExecutionWithTrace, PlanError, Query, QueryError, Response, WriteBatchResponse,
8 WriteResponse,
9 cursor::CursorPlanError,
10 decode_cursor,
11 executor::{DeleteExecutor, ExecutablePlan, ExecutorPlanError, LoadExecutor, SaveExecutor},
12 query::intent::QueryMode,
13 },
14 error::InternalError,
15 obs::sink::{MetricsSink, with_metrics_sink},
16 traits::{CanisterKind, EntityKind, EntityValue},
17};
18
19fn map_executor_plan_error(err: ExecutorPlanError) -> QueryError {
21 QueryError::from(err.into_plan_error())
22}
23
/// Session handle over a `Db<C>` used to build queries and run reads, deletes
/// and writes.
///
/// The session carries per-session options (a debug flag and an optional
/// metrics sink) and passes them along to the executors it constructs.
pub struct DbSession<C: CanisterKind> {
    /// Database handle passed to every load, delete and save executor.
    db: Db<C>,
    /// Debug flag forwarded to executor constructors; enabled via `DbSession::debug`.
    debug: bool,
    /// Optional metrics sink; when `Some`, executor work is wrapped in
    /// `with_metrics_sink` with this sink.
    metrics: Option<&'static dyn MetricsSink>,
}
35
36impl<C: CanisterKind> DbSession<C> {
37 #[must_use]
38 pub const fn new(db: Db<C>) -> Self {
39 Self {
40 db,
41 debug: false,
42 metrics: None,
43 }
44 }
45
46 #[must_use]
47 pub const fn debug(mut self) -> Self {
48 self.debug = true;
49 self
50 }
51
52 #[must_use]
53 pub const fn metrics_sink(mut self, sink: &'static dyn MetricsSink) -> Self {
54 self.metrics = Some(sink);
55 self
56 }
57
58 fn with_metrics<T>(&self, f: impl FnOnce() -> T) -> T {
59 if let Some(sink) = self.metrics {
60 with_metrics_sink(sink, f)
61 } else {
62 f()
63 }
64 }
65
66 fn execute_save_with<E, T, R>(
68 &self,
69 op: impl FnOnce(SaveExecutor<E>) -> Result<T, InternalError>,
70 map: impl FnOnce(T) -> R,
71 ) -> Result<R, InternalError>
72 where
73 E: EntityKind<Canister = C> + EntityValue,
74 {
75 let value = self.with_metrics(|| op(self.save_executor::<E>()))?;
76
77 Ok(map(value))
78 }
79
80 fn execute_save_entity<E>(
82 &self,
83 op: impl FnOnce(SaveExecutor<E>) -> Result<E, InternalError>,
84 ) -> Result<WriteResponse<E>, InternalError>
85 where
86 E: EntityKind<Canister = C> + EntityValue,
87 {
88 self.execute_save_with(op, WriteResponse::new)
89 }
90
91 fn execute_save_batch<E>(
92 &self,
93 op: impl FnOnce(SaveExecutor<E>) -> Result<Vec<E>, InternalError>,
94 ) -> Result<WriteBatchResponse<E>, InternalError>
95 where
96 E: EntityKind<Canister = C> + EntityValue,
97 {
98 self.execute_save_with(op, WriteBatchResponse::new)
99 }
100
101 fn execute_save_view<E>(
102 &self,
103 op: impl FnOnce(SaveExecutor<E>) -> Result<E::ViewType, InternalError>,
104 ) -> Result<E::ViewType, InternalError>
105 where
106 E: EntityKind<Canister = C> + EntityValue,
107 {
108 self.execute_save_with(op, std::convert::identity)
109 }
110
111 #[must_use]
116 pub const fn load<E>(&self) -> FluentLoadQuery<'_, E>
117 where
118 E: EntityKind<Canister = C>,
119 {
120 FluentLoadQuery::new(self, Query::new(MissingRowPolicy::Ignore))
121 }
122
123 #[must_use]
124 pub const fn load_with_consistency<E>(
125 &self,
126 consistency: MissingRowPolicy,
127 ) -> FluentLoadQuery<'_, E>
128 where
129 E: EntityKind<Canister = C>,
130 {
131 FluentLoadQuery::new(self, Query::new(consistency))
132 }
133
134 #[must_use]
135 pub fn delete<E>(&self) -> FluentDeleteQuery<'_, E>
136 where
137 E: EntityKind<Canister = C>,
138 {
139 FluentDeleteQuery::new(self, Query::new(MissingRowPolicy::Ignore).delete())
140 }
141
142 #[must_use]
143 pub fn delete_with_consistency<E>(
144 &self,
145 consistency: MissingRowPolicy,
146 ) -> FluentDeleteQuery<'_, E>
147 where
148 E: EntityKind<Canister = C>,
149 {
150 FluentDeleteQuery::new(self, Query::new(consistency).delete())
151 }
152
153 #[must_use]
158 pub(crate) const fn load_executor<E>(&self) -> LoadExecutor<E>
159 where
160 E: EntityKind<Canister = C> + EntityValue,
161 {
162 LoadExecutor::new(self.db, self.debug)
163 }
164
165 #[must_use]
166 pub(crate) const fn delete_executor<E>(&self) -> DeleteExecutor<E>
167 where
168 E: EntityKind<Canister = C> + EntityValue,
169 {
170 DeleteExecutor::new(self.db, self.debug)
171 }
172
173 #[must_use]
174 pub(crate) const fn save_executor<E>(&self) -> SaveExecutor<E>
175 where
176 E: EntityKind<Canister = C> + EntityValue,
177 {
178 SaveExecutor::new(self.db, self.debug)
179 }
180
181 pub fn execute_query<E>(&self, query: &Query<E>) -> Result<Response<E>, QueryError>
186 where
187 E: EntityKind<Canister = C> + EntityValue,
188 {
189 let plan = query.plan()?.into_executable();
190
191 let result = match query.mode() {
192 QueryMode::Load(_) => self.with_metrics(|| self.load_executor::<E>().execute(plan)),
193 QueryMode::Delete(_) => self.with_metrics(|| self.delete_executor::<E>().execute(plan)),
194 };
195
196 result.map_err(QueryError::Execute)
197 }
198
199 pub(crate) fn execute_load_query_with<E, T>(
202 &self,
203 query: &Query<E>,
204 op: impl FnOnce(LoadExecutor<E>, ExecutablePlan<E>) -> Result<T, InternalError>,
205 ) -> Result<T, QueryError>
206 where
207 E: EntityKind<Canister = C> + EntityValue,
208 {
209 let plan = query.plan()?.into_executable();
210
211 self.with_metrics(|| op(self.load_executor::<E>(), plan))
212 .map_err(QueryError::Execute)
213 }
214
215 pub(crate) fn execute_load_query_paged_with_trace<E>(
216 &self,
217 query: &Query<E>,
218 cursor_token: Option<&str>,
219 ) -> Result<PagedLoadExecutionWithTrace<E>, QueryError>
220 where
221 E: EntityKind<Canister = C> + EntityValue,
222 {
223 let plan = query.plan()?.into_executable();
224 if plan.as_inner().grouped_plan().is_some() {
225 return Err(QueryError::Execute(
226 InternalError::query_executor_invariant(
227 "grouped plans require execute_grouped(...)",
228 ),
229 ));
230 }
231 let cursor_bytes = match cursor_token {
232 Some(token) => Some(decode_cursor(token).map_err(|reason| {
233 QueryError::from(PlanError::from(
234 CursorPlanError::InvalidContinuationCursor { reason },
235 ))
236 })?),
237 None => None,
238 };
239 let cursor = plan
240 .prepare_cursor(cursor_bytes.as_deref())
241 .map_err(map_executor_plan_error)?;
242
243 let (page, trace) = self
244 .with_metrics(|| {
245 self.load_executor::<E>()
246 .execute_paged_with_cursor_traced(plan, cursor)
247 })
248 .map_err(QueryError::Execute)?;
249 let next_cursor = page
250 .next_cursor
251 .map(|token| {
252 let Some(token) = token.as_scalar() else {
253 return Err(QueryError::Execute(
254 InternalError::query_executor_invariant(
255 "scalar load pagination emitted grouped continuation token",
256 ),
257 ));
258 };
259
260 token.encode().map_err(|err| {
261 QueryError::Execute(InternalError::serialize_internal(format!(
262 "failed to serialize continuation cursor: {err}"
263 )))
264 })
265 })
266 .transpose()?;
267
268 Ok(PagedLoadExecutionWithTrace::new(
269 page.items,
270 next_cursor,
271 trace,
272 ))
273 }
274
275 pub fn execute_grouped<E>(
280 &self,
281 query: &Query<E>,
282 cursor_token: Option<&str>,
283 ) -> Result<PagedGroupedExecutionWithTrace, QueryError>
284 where
285 E: EntityKind<Canister = C> + EntityValue,
286 {
287 let plan = query.plan()?.into_executable();
288 if plan.as_inner().grouped_plan().is_none() {
289 return Err(QueryError::Execute(
290 InternalError::query_executor_invariant(
291 "execute_grouped requires grouped logical plans",
292 ),
293 ));
294 }
295 let cursor_bytes = match cursor_token {
296 Some(token) => Some(decode_cursor(token).map_err(|reason| {
297 QueryError::from(PlanError::from(
298 CursorPlanError::InvalidContinuationCursor { reason },
299 ))
300 })?),
301 None => None,
302 };
303 let cursor = plan
304 .prepare_grouped_cursor(cursor_bytes.as_deref())
305 .map_err(map_executor_plan_error)?;
306
307 let (page, trace) = self
308 .with_metrics(|| {
309 self.load_executor::<E>()
310 .execute_grouped_paged_with_cursor_traced(plan, cursor)
311 })
312 .map_err(QueryError::Execute)?;
313 let next_cursor = page
314 .next_cursor
315 .map(|token| {
316 let Some(token) = token.as_grouped() else {
317 return Err(QueryError::Execute(
318 InternalError::query_executor_invariant(
319 "grouped pagination emitted scalar continuation token",
320 ),
321 ));
322 };
323
324 token.encode().map_err(|err| {
325 QueryError::Execute(InternalError::serialize_internal(format!(
326 "failed to serialize grouped continuation cursor: {err}"
327 )))
328 })
329 })
330 .transpose()?;
331
332 Ok(PagedGroupedExecutionWithTrace::new(
333 page.rows,
334 next_cursor,
335 trace,
336 ))
337 }
338
339 pub fn insert<E>(&self, entity: E) -> Result<WriteResponse<E>, InternalError>
344 where
345 E: EntityKind<Canister = C> + EntityValue,
346 {
347 self.execute_save_entity(|save| save.insert(entity))
348 }
349
350 pub fn insert_many_atomic<E>(
356 &self,
357 entities: impl IntoIterator<Item = E>,
358 ) -> Result<WriteBatchResponse<E>, InternalError>
359 where
360 E: EntityKind<Canister = C> + EntityValue,
361 {
362 self.execute_save_batch(|save| save.insert_many_atomic(entities))
363 }
364
365 pub fn insert_many_non_atomic<E>(
369 &self,
370 entities: impl IntoIterator<Item = E>,
371 ) -> Result<WriteBatchResponse<E>, InternalError>
372 where
373 E: EntityKind<Canister = C> + EntityValue,
374 {
375 self.execute_save_batch(|save| save.insert_many_non_atomic(entities))
376 }
377
378 pub fn replace<E>(&self, entity: E) -> Result<WriteResponse<E>, InternalError>
379 where
380 E: EntityKind<Canister = C> + EntityValue,
381 {
382 self.execute_save_entity(|save| save.replace(entity))
383 }
384
385 pub fn replace_many_atomic<E>(
391 &self,
392 entities: impl IntoIterator<Item = E>,
393 ) -> Result<WriteBatchResponse<E>, InternalError>
394 where
395 E: EntityKind<Canister = C> + EntityValue,
396 {
397 self.execute_save_batch(|save| save.replace_many_atomic(entities))
398 }
399
400 pub fn replace_many_non_atomic<E>(
404 &self,
405 entities: impl IntoIterator<Item = E>,
406 ) -> Result<WriteBatchResponse<E>, InternalError>
407 where
408 E: EntityKind<Canister = C> + EntityValue,
409 {
410 self.execute_save_batch(|save| save.replace_many_non_atomic(entities))
411 }
412
413 pub fn update<E>(&self, entity: E) -> Result<WriteResponse<E>, InternalError>
414 where
415 E: EntityKind<Canister = C> + EntityValue,
416 {
417 self.execute_save_entity(|save| save.update(entity))
418 }
419
420 pub fn update_many_atomic<E>(
426 &self,
427 entities: impl IntoIterator<Item = E>,
428 ) -> Result<WriteBatchResponse<E>, InternalError>
429 where
430 E: EntityKind<Canister = C> + EntityValue,
431 {
432 self.execute_save_batch(|save| save.update_many_atomic(entities))
433 }
434
435 pub fn update_many_non_atomic<E>(
439 &self,
440 entities: impl IntoIterator<Item = E>,
441 ) -> Result<WriteBatchResponse<E>, InternalError>
442 where
443 E: EntityKind<Canister = C> + EntityValue,
444 {
445 self.execute_save_batch(|save| save.update_many_non_atomic(entities))
446 }
447
448 pub fn insert_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
449 where
450 E: EntityKind<Canister = C> + EntityValue,
451 {
452 self.execute_save_view::<E>(|save| save.insert_view(view))
453 }
454
455 pub fn replace_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
456 where
457 E: EntityKind<Canister = C> + EntityValue,
458 {
459 self.execute_save_view::<E>(|save| save.replace_view(view))
460 }
461
462 pub fn update_view<E>(&self, view: E::ViewType) -> Result<E::ViewType, InternalError>
463 where
464 E: EntityKind<Canister = C> + EntityValue,
465 {
466 self.execute_save_view::<E>(|save| save.update_view(view))
467 }
468
469 #[cfg(test)]
471 #[doc(hidden)]
472 pub fn clear_stores_for_tests(&self) {
473 self.db.with_store_registry(|reg| {
474 for (_, store) in reg.iter() {
477 store.with_data_mut(DataStore::clear);
478 store.with_index_mut(IndexStore::clear);
479 }
480 });
481 }
482}