1use crate::{
2 db::{
3 Db, ensure_recovered,
4 executor::{ExecutorError, SaveExecutor, resolve_unique_pk},
5 store::DataKey,
6 traits::FromKey,
7 },
8 error::{ErrorOrigin, InternalError},
9 model::index::IndexModel,
10 sanitize::sanitize,
11 traits::EntityKind,
12};
13use std::marker::PhantomData;
14
15#[derive(Clone, Copy)]
21pub struct UniqueIndexHandle {
22 index: &'static IndexModel,
23}
24
25impl UniqueIndexHandle {
26 #[must_use]
27 pub const fn index(&self) -> &'static IndexModel {
29 self.index
30 }
31
32 pub fn new<E: EntityKind>(index: &'static IndexModel) -> Result<Self, InternalError> {
34 if !E::INDEXES.iter().any(|cand| **cand == *index) {
35 return Err(
36 ExecutorError::IndexNotFound(E::PATH.to_string(), index.fields.join(", ")).into(),
37 );
38 }
39
40 if !index.unique {
41 return Err(ExecutorError::IndexNotUnique(
42 E::PATH.to_string(),
43 index.fields.join(", "),
44 )
45 .into());
46 }
47
48 Ok(Self { index })
49 }
50
51 pub fn for_fields<E: EntityKind>(fields: &[&str]) -> Result<Self, InternalError> {
53 for index in E::INDEXES {
54 if index.fields == fields {
55 return Self::new::<E>(index);
56 }
57 }
58
59 Err(ExecutorError::IndexNotFound(E::PATH.to_string(), fields.join(", ")).into())
60 }
61}
62
/// Outcome of an upsert.
///
/// `entity` is the value returned by the write. `inserted` is `true` when no row
/// held the unique-index value, so a new row was inserted. It is `false` when
/// an existing row was updated in place.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UpsertResult<E> {
    pub entity: E,
    pub inserted: bool,
}
72
73#[derive(Clone, Copy)]
78pub struct UpsertExecutor<E: EntityKind> {
79 db: Db<E::Canister>,
80 debug: bool,
81 _marker: PhantomData<E>,
82}
83
84impl<E: EntityKind> UpsertExecutor<E>
85where
86 E::PrimaryKey: FromKey,
87{
88 #[must_use]
90 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
91 Self {
92 db,
93 debug,
94 _marker: PhantomData,
95 }
96 }
97
98 #[must_use]
100 pub const fn debug(mut self) -> Self {
101 self.debug = true;
102 self
103 }
104
105 fn debug_log(&self, s: impl Into<String>) {
106 if self.debug {
107 println!("{}", s.into());
108 }
109 }
110
111 pub fn by_unique_index(&self, index: UniqueIndexHandle, entity: E) -> Result<E, InternalError> {
113 self.upsert(index.index(), entity)
114 }
115
116 pub fn by_unique_index_merge<F>(
118 &self,
119 index: UniqueIndexHandle,
120 entity: E,
121 merge: F,
122 ) -> Result<E, InternalError>
123 where
124 F: FnOnce(E, E) -> E,
125 {
126 Ok(self
127 .by_unique_index_merge_result(index, entity, merge)?
128 .entity)
129 }
130
131 pub fn by_unique_index_merge_result<F>(
133 &self,
134 index: UniqueIndexHandle,
135 entity: E,
136 merge: F,
137 ) -> Result<UpsertResult<E>, InternalError>
138 where
139 F: FnOnce(E, E) -> E,
140 {
141 self.upsert_merge_result(index.index(), entity, merge)
142 }
143
144 pub fn by_unique_index_result(
146 &self,
147 index: UniqueIndexHandle,
148 entity: E,
149 ) -> Result<UpsertResult<E>, InternalError> {
150 self.upsert_result(index.index(), entity)
151 }
152
153 pub fn by_unique_fields(&self, fields: &[&str], entity: E) -> Result<E, InternalError> {
155 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
156 self.upsert(index.index(), entity)
157 }
158
159 pub fn by_unique_fields_merge<F>(
161 &self,
162 fields: &[&str],
163 entity: E,
164 merge: F,
165 ) -> Result<E, InternalError>
166 where
167 F: FnOnce(E, E) -> E,
168 {
169 Ok(self
170 .by_unique_fields_merge_result(fields, entity, merge)?
171 .entity)
172 }
173
174 pub fn by_unique_fields_merge_result<F>(
176 &self,
177 fields: &[&str],
178 entity: E,
179 merge: F,
180 ) -> Result<UpsertResult<E>, InternalError>
181 where
182 F: FnOnce(E, E) -> E,
183 {
184 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
185 self.upsert_merge_result(index.index(), entity, merge)
186 }
187
188 pub fn by_unique_fields_result(
190 &self,
191 fields: &[&str],
192 entity: E,
193 ) -> Result<UpsertResult<E>, InternalError> {
194 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
195 self.upsert_result(index.index(), entity)
196 }
197
198 fn resolve_existing_pk(
207 &self,
208 index: &'static IndexModel,
209 entity: &E,
210 ) -> Result<Option<E::PrimaryKey>, InternalError> {
211 let mut lookup = entity.clone();
212 sanitize(&mut lookup)?;
213 resolve_unique_pk::<E>(&self.db, index, &lookup)
214 }
215
216 fn upsert(&self, index: &'static IndexModel, entity: E) -> Result<E, InternalError> {
217 Ok(self.upsert_result(index, entity)?.entity)
218 }
219
220 fn upsert_result(
221 &self,
222 index: &'static IndexModel,
223 entity: E,
224 ) -> Result<UpsertResult<E>, InternalError> {
225 self.debug_log(format!(
226 "[debug] upsert on {} (unique index: {})",
227 E::PATH,
228 index.fields.join(", ")
229 ));
230 ensure_recovered(&self.db)?;
232 let existing_pk = self.resolve_existing_pk(index, &entity)?;
233 let inserted = existing_pk.is_none();
234
235 let saver = SaveExecutor::new(self.db, self.debug);
237
238 let entity = match existing_pk {
239 Some(pk) => {
240 let mut entity = entity;
241 entity.set_primary_key(pk);
242 saver.update(entity)?
243 }
244 None => saver.insert(entity)?,
245 };
246
247 Ok(UpsertResult { entity, inserted })
248 }
249
250 fn upsert_merge_result<F>(
251 &self,
252 index: &'static IndexModel,
253 entity: E,
254 merge: F,
255 ) -> Result<UpsertResult<E>, InternalError>
256 where
257 F: FnOnce(E, E) -> E,
258 {
259 self.debug_log(format!(
260 "[debug] upsert merge on {} (unique index: {})",
261 E::PATH,
262 index.fields.join(", ")
263 ));
264 ensure_recovered(&self.db)?;
266 let existing_pk = self.resolve_existing_pk(index, &entity)?;
267
268 let saver = SaveExecutor::new(self.db, self.debug);
270
271 let result = if let Some(pk) = existing_pk {
272 let existing = self.load_existing(index, pk)?;
274 let mut merged = merge(existing, entity);
275 merged.set_primary_key(pk);
276
277 let entity = saver.update(merged)?;
278 UpsertResult {
279 entity,
280 inserted: false,
281 }
282 } else {
283 let entity = saver.insert(entity)?;
284 UpsertResult {
285 entity,
286 inserted: true,
287 }
288 };
289
290 Ok(result)
291 }
292
293 fn load_existing(
294 &self,
295 index: &'static IndexModel,
296 pk: E::PrimaryKey,
297 ) -> Result<E, InternalError> {
298 let data_key = DataKey::new::<E>(pk.into());
299 let raw_data_key = data_key.to_raw();
300 let row = self
301 .db
302 .context::<E>()
303 .with_store(|store| store.get(&raw_data_key))?;
304
305 let Some(row) = row else {
306 return Err(ExecutorError::corruption(
308 ErrorOrigin::Index,
309 format!(
310 "index corrupted: {} ({}) -> {} keys",
311 E::PATH,
312 index.fields.join(", "),
313 1
314 ),
315 )
316 .into());
317 };
318
319 row.try_decode::<E>().map_err(|err| {
320 ExecutorError::corruption(
321 ErrorOrigin::Serialize,
322 format!("failed to deserialize row: {data_key} ({err})"),
323 )
324 .into()
325 })
326 }
327}