1use crate::{
2 db::{
3 Db, ensure_recovered,
4 executor::{
5 ExecutorError, SaveExecutor, resolve_unique_pk,
6 trace::{QueryTraceSink, TraceAccess, TraceExecutorKind, start_exec_trace},
7 },
8 store::DataKey,
9 traits::FromKey,
10 },
11 error::{ErrorOrigin, InternalError},
12 model::index::IndexModel,
13 sanitize::sanitize,
14 traits::EntityKind,
15};
16use std::marker::PhantomData;
17
18#[derive(Clone, Copy)]
24pub struct UniqueIndexHandle {
25 index: &'static IndexModel,
26}
27
28impl UniqueIndexHandle {
29 #[must_use]
30 pub const fn index(&self) -> &'static IndexModel {
32 self.index
33 }
34
35 pub fn new<E: EntityKind>(index: &'static IndexModel) -> Result<Self, InternalError> {
37 if !E::INDEXES.iter().any(|cand| **cand == *index) {
38 return Err(
39 ExecutorError::IndexNotFound(E::PATH.to_string(), index.fields.join(", ")).into(),
40 );
41 }
42
43 if !index.unique {
44 return Err(ExecutorError::IndexNotUnique(
45 E::PATH.to_string(),
46 index.fields.join(", "),
47 )
48 .into());
49 }
50
51 Ok(Self { index })
52 }
53
54 pub fn for_fields<E: EntityKind>(fields: &[&str]) -> Result<Self, InternalError> {
56 for index in E::INDEXES {
57 if index.fields == fields {
58 return Self::new::<E>(index);
59 }
60 }
61
62 Err(ExecutorError::IndexNotFound(E::PATH.to_string(), fields.join(", ")).into())
63 }
64}
65
/// Outcome of an upsert: the persisted entity plus whether a new row was created.
///
/// Derives the common traits so callers can log, clone and compare results; each
/// derived impl only applies when `E` itself implements the trait.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UpsertResult<E> {
    /// The entity as returned by the underlying save (insert or update).
    pub entity: E,
    /// `true` when the entity was inserted as a new row, `false` when an existing
    /// row was updated.
    pub inserted: bool,
}
76
77#[derive(Clone, Copy)]
82pub struct UpsertExecutor<E: EntityKind> {
83 db: Db<E::Canister>,
84 debug: bool,
85 trace: Option<&'static dyn QueryTraceSink>,
86 _marker: PhantomData<E>,
87}
88
89impl<E: EntityKind> UpsertExecutor<E>
90where
91 E::PrimaryKey: FromKey,
92{
93 #[must_use]
95 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
96 Self {
97 db,
98 debug,
99 trace: None,
100 _marker: PhantomData,
101 }
102 }
103
104 #[must_use]
105 #[allow(dead_code)]
106 pub(crate) const fn with_trace_sink(
107 mut self,
108 sink: Option<&'static dyn QueryTraceSink>,
109 ) -> Self {
110 self.trace = sink;
111 self
112 }
113
114 #[must_use]
116 pub const fn debug(mut self) -> Self {
117 self.debug = true;
118 self
119 }
120
121 fn debug_log(&self, s: impl Into<String>) {
122 if self.debug {
123 println!("{}", s.into());
124 }
125 }
126
127 pub fn by_unique_index(&self, index: UniqueIndexHandle, entity: E) -> Result<E, InternalError> {
129 self.upsert(index.index(), entity)
130 }
131
132 pub fn by_unique_index_merge<F>(
134 &self,
135 index: UniqueIndexHandle,
136 entity: E,
137 merge: F,
138 ) -> Result<E, InternalError>
139 where
140 F: FnOnce(E, E) -> E,
141 {
142 Ok(self
143 .by_unique_index_merge_result(index, entity, merge)?
144 .entity)
145 }
146
147 pub fn by_unique_index_merge_result<F>(
149 &self,
150 index: UniqueIndexHandle,
151 entity: E,
152 merge: F,
153 ) -> Result<UpsertResult<E>, InternalError>
154 where
155 F: FnOnce(E, E) -> E,
156 {
157 self.upsert_merge_result(index.index(), entity, merge)
158 }
159
160 pub fn by_unique_index_result(
162 &self,
163 index: UniqueIndexHandle,
164 entity: E,
165 ) -> Result<UpsertResult<E>, InternalError> {
166 self.upsert_result(index.index(), entity)
167 }
168
169 pub fn by_unique_fields(&self, fields: &[&str], entity: E) -> Result<E, InternalError> {
171 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
172 self.upsert(index.index(), entity)
173 }
174
175 pub fn by_unique_fields_merge<F>(
177 &self,
178 fields: &[&str],
179 entity: E,
180 merge: F,
181 ) -> Result<E, InternalError>
182 where
183 F: FnOnce(E, E) -> E,
184 {
185 Ok(self
186 .by_unique_fields_merge_result(fields, entity, merge)?
187 .entity)
188 }
189
190 pub fn by_unique_fields_merge_result<F>(
192 &self,
193 fields: &[&str],
194 entity: E,
195 merge: F,
196 ) -> Result<UpsertResult<E>, InternalError>
197 where
198 F: FnOnce(E, E) -> E,
199 {
200 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
201 self.upsert_merge_result(index.index(), entity, merge)
202 }
203
204 pub fn by_unique_fields_result(
206 &self,
207 fields: &[&str],
208 entity: E,
209 ) -> Result<UpsertResult<E>, InternalError> {
210 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
211 self.upsert_result(index.index(), entity)
212 }
213
214 fn resolve_existing_pk(
223 &self,
224 index: &'static IndexModel,
225 entity: &E,
226 ) -> Result<Option<E::PrimaryKey>, InternalError> {
227 let mut lookup = entity.clone();
228 sanitize(&mut lookup)?;
229 resolve_unique_pk::<E>(&self.db, index, &lookup)
230 }
231
232 fn upsert(&self, index: &'static IndexModel, entity: E) -> Result<E, InternalError> {
233 Ok(self.upsert_result(index, entity)?.entity)
234 }
235
236 fn upsert_result(
237 &self,
238 index: &'static IndexModel,
239 entity: E,
240 ) -> Result<UpsertResult<E>, InternalError> {
241 let trace = start_exec_trace(
242 self.trace,
243 TraceExecutorKind::Upsert,
244 E::PATH,
245 Some(TraceAccess::UniqueIndex { name: index.name }),
246 Some(index.name),
247 );
248 let result = (|| {
249 self.debug_log(format!(
250 "[debug] upsert on {} (unique index: {})",
251 E::PATH,
252 index.fields.join(", ")
253 ));
254 ensure_recovered(&self.db)?;
256 let existing_pk = self.resolve_existing_pk(index, &entity)?;
257 let inserted = existing_pk.is_none();
258
259 let saver = SaveExecutor::new(self.db, self.debug);
261
262 let entity = match existing_pk {
263 Some(pk) => {
264 let mut entity = entity;
265 entity.set_primary_key(pk);
266 saver.update(entity)?
267 }
268 None => saver.insert(entity)?,
269 };
270
271 Ok(UpsertResult { entity, inserted })
272 })();
273
274 if let Some(trace) = trace {
275 match &result {
276 Ok(_) => trace.finish(1),
277 Err(err) => trace.error(err),
278 }
279 }
280
281 result
282 }
283
284 fn upsert_merge_result<F>(
285 &self,
286 index: &'static IndexModel,
287 entity: E,
288 merge: F,
289 ) -> Result<UpsertResult<E>, InternalError>
290 where
291 F: FnOnce(E, E) -> E,
292 {
293 let trace = start_exec_trace(
294 self.trace,
295 TraceExecutorKind::Upsert,
296 E::PATH,
297 Some(TraceAccess::UniqueIndex { name: index.name }),
298 Some(index.name),
299 );
300 let result = (|| {
301 self.debug_log(format!(
302 "[debug] upsert merge on {} (unique index: {})",
303 E::PATH,
304 index.fields.join(", ")
305 ));
306 ensure_recovered(&self.db)?;
308 let existing_pk = self.resolve_existing_pk(index, &entity)?;
309
310 let saver = SaveExecutor::new(self.db, self.debug);
312
313 let result = if let Some(pk) = existing_pk {
314 let existing = self.load_existing(index, pk)?;
316 let mut merged = merge(existing, entity);
317 merged.set_primary_key(pk);
318
319 let entity = saver.update(merged)?;
320 UpsertResult {
321 entity,
322 inserted: false,
323 }
324 } else {
325 let entity = saver.insert(entity)?;
326 UpsertResult {
327 entity,
328 inserted: true,
329 }
330 };
331
332 Ok(result)
333 })();
334
335 if let Some(trace) = trace {
336 match &result {
337 Ok(_) => trace.finish(1),
338 Err(err) => trace.error(err),
339 }
340 }
341
342 result
343 }
344
345 fn load_existing(
346 &self,
347 index: &'static IndexModel,
348 pk: E::PrimaryKey,
349 ) -> Result<E, InternalError> {
350 let data_key = DataKey::new::<E>(pk.into());
351 let raw_data_key = data_key.to_raw()?;
352 let row = self
353 .db
354 .context::<E>()
355 .with_store(|store| store.get(&raw_data_key))?;
356
357 let Some(row) = row else {
358 return Err(ExecutorError::corruption(
360 ErrorOrigin::Index,
361 format!(
362 "index corrupted: {} ({}) -> {} keys",
363 E::PATH,
364 index.fields.join(", "),
365 1
366 ),
367 )
368 .into());
369 };
370
371 row.try_decode::<E>().map_err(|err| {
372 ExecutorError::corruption(
373 ErrorOrigin::Serialize,
374 format!("failed to deserialize row: {data_key} ({err})"),
375 )
376 .into()
377 })
378 }
379}