1use crate::{
2 db::{
3 Db, ensure_recovered,
4 executor::{ExecutorError, SaveExecutor, resolve_unique_pk},
5 store::DataKey,
6 traits::FromKey,
7 },
8 error::{ErrorOrigin, InternalError},
9 model::index::IndexModel,
10 sanitize::sanitize,
11 traits::EntityKind,
12};
13use std::marker::PhantomData;
14
15#[derive(Clone, Copy)]
21pub struct UniqueIndexHandle {
22 index: &'static IndexModel,
23}
24
25impl UniqueIndexHandle {
26 #[must_use]
27 pub const fn index(&self) -> &'static IndexModel {
29 self.index
30 }
31
32 pub fn new<E: EntityKind>(index: &'static IndexModel) -> Result<Self, InternalError> {
34 if !E::INDEXES.iter().any(|cand| **cand == *index) {
35 return Err(
36 ExecutorError::IndexNotFound(E::PATH.to_string(), index.fields.join(", ")).into(),
37 );
38 }
39
40 if !index.unique {
41 return Err(ExecutorError::IndexNotUnique(
42 E::PATH.to_string(),
43 index.fields.join(", "),
44 )
45 .into());
46 }
47
48 Ok(Self { index })
49 }
50
51 pub fn for_fields<E: EntityKind>(fields: &[&str]) -> Result<Self, InternalError> {
53 for index in E::INDEXES {
54 if index.fields == fields {
55 return Self::new::<E>(index);
56 }
57 }
58
59 Err(ExecutorError::IndexNotFound(E::PATH.to_string(), fields.join(", ")).into())
60 }
61}
62
/// Outcome of an upsert: the saved entity plus whether it was newly inserted.
///
/// The derived `Debug`, `Clone`, `PartialEq` and `Eq` implementations are each
/// available only when `E` implements the same trait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpsertResult<E> {
    /// The entity as returned by the save operation.
    pub entity: E,
    /// `true` when the row was inserted, `false` when an existing row was updated.
    pub inserted: bool,
}
72
73#[derive(Clone, Copy)]
78pub struct UpsertExecutor<E: EntityKind> {
79 db: Db<E::Canister>,
80 debug: bool,
81 _marker: PhantomData<E>,
82}
83
84impl<E: EntityKind> UpsertExecutor<E>
85where
86 E::PrimaryKey: FromKey,
87{
88 #[must_use]
90 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
91 Self {
92 db,
93 debug,
94 _marker: PhantomData,
95 }
96 }
97
98 #[must_use]
100 pub const fn debug(mut self) -> Self {
101 self.debug = true;
102 self
103 }
104
105 pub fn by_unique_index(&self, index: UniqueIndexHandle, entity: E) -> Result<E, InternalError> {
107 self.upsert(index.index(), entity)
108 }
109
110 pub fn by_unique_index_merge<F>(
112 &self,
113 index: UniqueIndexHandle,
114 entity: E,
115 merge: F,
116 ) -> Result<E, InternalError>
117 where
118 F: FnOnce(E, E) -> E,
119 {
120 Ok(self
121 .by_unique_index_merge_result(index, entity, merge)?
122 .entity)
123 }
124
125 pub fn by_unique_index_merge_result<F>(
127 &self,
128 index: UniqueIndexHandle,
129 entity: E,
130 merge: F,
131 ) -> Result<UpsertResult<E>, InternalError>
132 where
133 F: FnOnce(E, E) -> E,
134 {
135 self.upsert_merge_result(index.index(), entity, merge)
136 }
137
138 pub fn by_unique_index_result(
140 &self,
141 index: UniqueIndexHandle,
142 entity: E,
143 ) -> Result<UpsertResult<E>, InternalError> {
144 self.upsert_result(index.index(), entity)
145 }
146
147 pub fn by_unique_fields(&self, fields: &[&str], entity: E) -> Result<E, InternalError> {
149 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
150 self.upsert(index.index(), entity)
151 }
152
153 pub fn by_unique_fields_merge<F>(
155 &self,
156 fields: &[&str],
157 entity: E,
158 merge: F,
159 ) -> Result<E, InternalError>
160 where
161 F: FnOnce(E, E) -> E,
162 {
163 Ok(self
164 .by_unique_fields_merge_result(fields, entity, merge)?
165 .entity)
166 }
167
168 pub fn by_unique_fields_merge_result<F>(
170 &self,
171 fields: &[&str],
172 entity: E,
173 merge: F,
174 ) -> Result<UpsertResult<E>, InternalError>
175 where
176 F: FnOnce(E, E) -> E,
177 {
178 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
179 self.upsert_merge_result(index.index(), entity, merge)
180 }
181
182 pub fn by_unique_fields_result(
184 &self,
185 fields: &[&str],
186 entity: E,
187 ) -> Result<UpsertResult<E>, InternalError> {
188 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
189 self.upsert_result(index.index(), entity)
190 }
191
192 fn resolve_existing_pk(
201 &self,
202 index: &'static IndexModel,
203 entity: &E,
204 ) -> Result<Option<E::PrimaryKey>, InternalError> {
205 let mut lookup = entity.clone();
206 sanitize(&mut lookup)?;
207 resolve_unique_pk::<E>(&self.db, index, &lookup)
208 }
209
210 fn upsert(&self, index: &'static IndexModel, entity: E) -> Result<E, InternalError> {
211 Ok(self.upsert_result(index, entity)?.entity)
212 }
213
214 fn upsert_result(
215 &self,
216 index: &'static IndexModel,
217 entity: E,
218 ) -> Result<UpsertResult<E>, InternalError> {
219 ensure_recovered(&self.db)?;
221 let existing_pk = self.resolve_existing_pk(index, &entity)?;
222 let inserted = existing_pk.is_none();
223
224 let saver = SaveExecutor::new(self.db, self.debug);
226
227 let entity = match existing_pk {
228 Some(pk) => {
229 let mut entity = entity;
230 entity.set_primary_key(pk);
231 saver.update(entity)?
232 }
233 None => saver.insert(entity)?,
234 };
235
236 Ok(UpsertResult { entity, inserted })
237 }
238
239 fn upsert_merge_result<F>(
240 &self,
241 index: &'static IndexModel,
242 entity: E,
243 merge: F,
244 ) -> Result<UpsertResult<E>, InternalError>
245 where
246 F: FnOnce(E, E) -> E,
247 {
248 ensure_recovered(&self.db)?;
250 let existing_pk = self.resolve_existing_pk(index, &entity)?;
251
252 let saver = SaveExecutor::new(self.db, self.debug);
254
255 let result = if let Some(pk) = existing_pk {
256 let existing = self.load_existing(index, pk)?;
258 let mut merged = merge(existing, entity);
259 merged.set_primary_key(pk);
260
261 let entity = saver.update(merged)?;
262 UpsertResult {
263 entity,
264 inserted: false,
265 }
266 } else {
267 let entity = saver.insert(entity)?;
268 UpsertResult {
269 entity,
270 inserted: true,
271 }
272 };
273
274 Ok(result)
275 }
276
277 fn load_existing(
278 &self,
279 index: &'static IndexModel,
280 pk: E::PrimaryKey,
281 ) -> Result<E, InternalError> {
282 let data_key = DataKey::new::<E>(pk.into());
283 let raw_data_key = data_key.to_raw();
284 let row = self
285 .db
286 .context::<E>()
287 .with_store(|store| store.get(&raw_data_key))?;
288
289 let Some(row) = row else {
290 return Err(ExecutorError::corruption(
292 ErrorOrigin::Index,
293 format!(
294 "index corrupted: {} ({}) -> {} keys",
295 E::PATH,
296 index.fields.join(", "),
297 1
298 ),
299 )
300 .into());
301 };
302
303 row.try_decode::<E>().map_err(|err| {
304 ExecutorError::corruption(
305 ErrorOrigin::Serialize,
306 format!("failed to deserialize row: {data_key} ({err})"),
307 )
308 .into()
309 })
310 }
311}