1use crate::{
2 Error, IndexSpec,
3 db::{
4 Db,
5 executor::{ExecutorError, SaveExecutor, resolve_unique_pk},
6 store::DataKey,
7 },
8 deserialize, sanitize,
9 traits::{EntityKind, FromKey},
10};
11use std::marker::PhantomData;
12
13#[derive(Clone, Copy)]
19pub struct UniqueIndexHandle {
20 index: &'static IndexSpec,
21}
22
23impl UniqueIndexHandle {
24 #[must_use]
25 pub const fn index(&self) -> &'static IndexSpec {
27 self.index
28 }
29
30 pub fn new<E: EntityKind>(index: &'static IndexSpec) -> Result<Self, Error> {
32 if !E::INDEXES.iter().any(|cand| **cand == *index) {
33 return Err(
34 ExecutorError::IndexNotFound(E::PATH.to_string(), index.fields.join(", ")).into(),
35 );
36 }
37
38 if !index.unique {
39 return Err(ExecutorError::IndexNotUnique(
40 E::PATH.to_string(),
41 index.fields.join(", "),
42 )
43 .into());
44 }
45
46 Ok(Self { index })
47 }
48
49 pub fn for_fields<E: EntityKind>(fields: &[&str]) -> Result<Self, Error> {
51 for index in E::INDEXES {
52 if index.fields == fields {
53 return Self::new::<E>(index);
54 }
55 }
56
57 Err(ExecutorError::IndexNotFound(E::PATH.to_string(), fields.join(", ")).into())
58 }
59}
60
/// Outcome of an upsert: the entity as saved, and whether a new row was created.
///
/// The derived traits are conditional on `E`; for example, `UpsertResult<E>:
/// Debug` holds exactly when `E: Debug`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UpsertResult<E> {
    /// The entity returned by the save (insert or update) operation.
    pub entity: E,
    /// `true` when no existing row matched and the entity was inserted;
    /// `false` when an existing row was updated.
    pub inserted: bool,
}
70
71#[derive(Clone, Copy)]
76pub struct UpsertExecutor<E: EntityKind> {
77 db: Db<E::Canister>,
78 debug: bool,
79 _marker: PhantomData<E>,
80}
81
82impl<E: EntityKind> UpsertExecutor<E>
83where
84 E::PrimaryKey: FromKey,
85{
86 #[must_use]
88 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
89 Self {
90 db,
91 debug,
92 _marker: PhantomData,
93 }
94 }
95
96 #[must_use]
98 pub const fn debug(mut self) -> Self {
99 self.debug = true;
100 self
101 }
102
103 pub fn by_unique_index(&self, index: UniqueIndexHandle, entity: E) -> Result<E, Error> {
105 self.upsert(index.index(), entity)
106 }
107
108 pub fn by_unique_index_merge<F>(
110 &self,
111 index: UniqueIndexHandle,
112 entity: E,
113 merge: F,
114 ) -> Result<E, Error>
115 where
116 F: FnOnce(E, E) -> E,
117 {
118 Ok(self
119 .by_unique_index_merge_result(index, entity, merge)?
120 .entity)
121 }
122
123 pub fn by_unique_index_merge_result<F>(
125 &self,
126 index: UniqueIndexHandle,
127 entity: E,
128 merge: F,
129 ) -> Result<UpsertResult<E>, Error>
130 where
131 F: FnOnce(E, E) -> E,
132 {
133 self.upsert_merge_result(index.index(), entity, merge)
134 }
135
136 pub fn by_unique_index_result(
138 &self,
139 index: UniqueIndexHandle,
140 entity: E,
141 ) -> Result<UpsertResult<E>, Error> {
142 self.upsert_result(index.index(), entity)
143 }
144
145 pub fn by_unique_fields(&self, fields: &[&str], entity: E) -> Result<E, Error> {
147 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
148 self.upsert(index.index(), entity)
149 }
150
151 pub fn by_unique_fields_merge<F>(
153 &self,
154 fields: &[&str],
155 entity: E,
156 merge: F,
157 ) -> Result<E, Error>
158 where
159 F: FnOnce(E, E) -> E,
160 {
161 Ok(self
162 .by_unique_fields_merge_result(fields, entity, merge)?
163 .entity)
164 }
165
166 pub fn by_unique_fields_merge_result<F>(
168 &self,
169 fields: &[&str],
170 entity: E,
171 merge: F,
172 ) -> Result<UpsertResult<E>, Error>
173 where
174 F: FnOnce(E, E) -> E,
175 {
176 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
177 self.upsert_merge_result(index.index(), entity, merge)
178 }
179
180 pub fn by_unique_fields_result(
182 &self,
183 fields: &[&str],
184 entity: E,
185 ) -> Result<UpsertResult<E>, Error> {
186 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
187 self.upsert_result(index.index(), entity)
188 }
189
190 fn resolve_existing_pk(
199 &self,
200 index: &'static IndexSpec,
201 entity: &E,
202 ) -> Result<Option<E::PrimaryKey>, Error> {
203 let mut lookup = entity.clone();
204 sanitize(&mut lookup)?;
205 resolve_unique_pk::<E>(&self.db, index, &lookup)
206 }
207
208 fn upsert(&self, index: &'static IndexSpec, entity: E) -> Result<E, Error> {
209 Ok(self.upsert_result(index, entity)?.entity)
210 }
211
212 fn upsert_result(
213 &self,
214 index: &'static IndexSpec,
215 entity: E,
216 ) -> Result<UpsertResult<E>, Error> {
217 let existing_pk = self.resolve_existing_pk(index, &entity)?;
218 let inserted = existing_pk.is_none();
219
220 let saver = SaveExecutor::new(self.db, self.debug);
222
223 let entity = match existing_pk {
224 Some(pk) => {
225 let mut entity = entity;
226 entity.set_primary_key(pk);
227 saver.update(entity)?
228 }
229 None => saver.insert(entity)?,
230 };
231
232 Ok(UpsertResult { entity, inserted })
233 }
234
235 fn upsert_merge_result<F>(
236 &self,
237 index: &'static IndexSpec,
238 entity: E,
239 merge: F,
240 ) -> Result<UpsertResult<E>, Error>
241 where
242 F: FnOnce(E, E) -> E,
243 {
244 let existing_pk = self.resolve_existing_pk(index, &entity)?;
245
246 let saver = SaveExecutor::new(self.db, self.debug);
248
249 let result = if let Some(pk) = existing_pk {
250 let existing = self.load_existing(index, pk)?;
252 let mut merged = merge(existing, entity);
253 merged.set_primary_key(pk);
254
255 let entity = saver.update(merged)?;
256 UpsertResult {
257 entity,
258 inserted: false,
259 }
260 } else {
261 let entity = saver.insert(entity)?;
262 UpsertResult {
263 entity,
264 inserted: true,
265 }
266 };
267
268 Ok(result)
269 }
270
271 fn load_existing(&self, index: &'static IndexSpec, pk: E::PrimaryKey) -> Result<E, Error> {
272 let data_key = DataKey::new::<E>(pk.into());
273 let bytes = self
274 .db
275 .context::<E>()
276 .with_store(|store| store.get(&data_key))?;
277
278 let Some(bytes) = bytes else {
279 return Err(ExecutorError::IndexCorrupted(
281 E::PATH.to_string(),
282 index.fields.join(", "),
283 1,
284 )
285 .into());
286 };
287
288 deserialize::<E>(&bytes)
289 }
290}