1use crate::{
2 Error, IndexSpec,
3 db::{
4 Db,
5 executor::{ExecutorError, SaveExecutor, resolve_unique_pk},
6 store::DataKey,
7 },
8 deserialize,
9 traits::{EntityKind, FromKey},
10 visitor::sanitize,
11};
12use std::marker::PhantomData;
13
14#[derive(Clone, Copy)]
20pub struct UniqueIndexHandle {
21 index: &'static IndexSpec,
22}
23
24impl UniqueIndexHandle {
25 #[must_use]
26 pub const fn index(&self) -> &'static IndexSpec {
28 self.index
29 }
30
31 pub fn new<E: EntityKind>(index: &'static IndexSpec) -> Result<Self, Error> {
33 if !E::INDEXES.iter().any(|cand| **cand == *index) {
34 return Err(
35 ExecutorError::IndexNotFound(E::PATH.to_string(), index.fields.join(", ")).into(),
36 );
37 }
38
39 if !index.unique {
40 return Err(ExecutorError::IndexNotUnique(
41 E::PATH.to_string(),
42 index.fields.join(", "),
43 )
44 .into());
45 }
46
47 Ok(Self { index })
48 }
49
50 pub fn for_fields<E: EntityKind>(fields: &[&str]) -> Result<Self, Error> {
52 for index in E::INDEXES {
53 if index.fields == fields {
54 return Self::new::<E>(index);
55 }
56 }
57
58 Err(ExecutorError::IndexNotFound(E::PATH.to_string(), fields.join(", ")).into())
59 }
60}
61
/// Outcome of an upsert: the entity as saved, plus which write path was taken.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct UpsertResult<E> {
    /// The entity returned by the save executor after the write.
    pub entity: E,
    /// `true` when no existing row matched and the entity was inserted;
    /// `false` when an existing row was updated.
    pub inserted: bool,
}
71
72#[derive(Clone, Copy)]
77pub struct UpsertExecutor<E: EntityKind> {
78 db: Db<E::Canister>,
79 debug: bool,
80 _marker: PhantomData<E>,
81}
82
83impl<E: EntityKind> UpsertExecutor<E>
84where
85 E::PrimaryKey: FromKey,
86{
87 #[must_use]
89 pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
90 Self {
91 db,
92 debug,
93 _marker: PhantomData,
94 }
95 }
96
97 #[must_use]
99 pub const fn debug(mut self) -> Self {
100 self.debug = true;
101 self
102 }
103
104 pub fn by_unique_index(&self, index: UniqueIndexHandle, entity: E) -> Result<E, Error> {
106 self.upsert(index.index(), entity)
107 }
108
109 pub fn by_unique_index_merge<F>(
111 &self,
112 index: UniqueIndexHandle,
113 entity: E,
114 merge: F,
115 ) -> Result<E, Error>
116 where
117 F: FnOnce(E, E) -> E,
118 {
119 Ok(self
120 .by_unique_index_merge_result(index, entity, merge)?
121 .entity)
122 }
123
124 pub fn by_unique_index_merge_result<F>(
126 &self,
127 index: UniqueIndexHandle,
128 entity: E,
129 merge: F,
130 ) -> Result<UpsertResult<E>, Error>
131 where
132 F: FnOnce(E, E) -> E,
133 {
134 self.upsert_merge_result(index.index(), entity, merge)
135 }
136
137 pub fn by_unique_index_result(
139 &self,
140 index: UniqueIndexHandle,
141 entity: E,
142 ) -> Result<UpsertResult<E>, Error> {
143 self.upsert_result(index.index(), entity)
144 }
145
146 pub fn by_unique_fields(&self, fields: &[&str], entity: E) -> Result<E, Error> {
148 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
149 self.upsert(index.index(), entity)
150 }
151
152 pub fn by_unique_fields_merge<F>(
154 &self,
155 fields: &[&str],
156 entity: E,
157 merge: F,
158 ) -> Result<E, Error>
159 where
160 F: FnOnce(E, E) -> E,
161 {
162 Ok(self
163 .by_unique_fields_merge_result(fields, entity, merge)?
164 .entity)
165 }
166
167 pub fn by_unique_fields_merge_result<F>(
169 &self,
170 fields: &[&str],
171 entity: E,
172 merge: F,
173 ) -> Result<UpsertResult<E>, Error>
174 where
175 F: FnOnce(E, E) -> E,
176 {
177 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
178 self.upsert_merge_result(index.index(), entity, merge)
179 }
180
181 pub fn by_unique_fields_result(
183 &self,
184 fields: &[&str],
185 entity: E,
186 ) -> Result<UpsertResult<E>, Error> {
187 let index = UniqueIndexHandle::for_fields::<E>(fields)?;
188 self.upsert_result(index.index(), entity)
189 }
190
191 fn resolve_existing_pk(
200 &self,
201 index: &'static IndexSpec,
202 entity: &E,
203 ) -> Result<Option<E::PrimaryKey>, Error> {
204 let mut lookup = entity.clone();
205 sanitize(&mut lookup);
206 resolve_unique_pk::<E>(&self.db, index, &lookup)
207 }
208
209 fn upsert(&self, index: &'static IndexSpec, entity: E) -> Result<E, Error> {
210 Ok(self.upsert_result(index, entity)?.entity)
211 }
212
213 fn upsert_result(
214 &self,
215 index: &'static IndexSpec,
216 entity: E,
217 ) -> Result<UpsertResult<E>, Error> {
218 let existing_pk = self.resolve_existing_pk(index, &entity)?;
219 let inserted = existing_pk.is_none();
220
221 let saver = SaveExecutor::new(self.db, self.debug);
223
224 let entity = match existing_pk {
225 Some(pk) => {
226 let mut entity = entity;
227 entity.set_primary_key(pk);
228 saver.update(entity)?
229 }
230 None => saver.insert(entity)?,
231 };
232
233 Ok(UpsertResult { entity, inserted })
234 }
235
236 fn upsert_merge_result<F>(
237 &self,
238 index: &'static IndexSpec,
239 entity: E,
240 merge: F,
241 ) -> Result<UpsertResult<E>, Error>
242 where
243 F: FnOnce(E, E) -> E,
244 {
245 let existing_pk = self.resolve_existing_pk(index, &entity)?;
246
247 let saver = SaveExecutor::new(self.db, self.debug);
249
250 let result = if let Some(pk) = existing_pk {
251 let existing = self.load_existing(index, pk)?;
253 let mut merged = merge(existing, entity);
254 merged.set_primary_key(pk);
255
256 let entity = saver.update(merged)?;
257 UpsertResult {
258 entity,
259 inserted: false,
260 }
261 } else {
262 let entity = saver.insert(entity)?;
263 UpsertResult {
264 entity,
265 inserted: true,
266 }
267 };
268
269 Ok(result)
270 }
271
272 fn load_existing(&self, index: &'static IndexSpec, pk: E::PrimaryKey) -> Result<E, Error> {
273 let data_key = DataKey::new::<E>(pk.into());
274 let bytes = self
275 .db
276 .context::<E>()
277 .with_store(|store| store.get(&data_key))?;
278
279 let Some(bytes) = bytes else {
280 return Err(ExecutorError::IndexCorrupted(
282 E::PATH.to_string(),
283 index.fields.join(", "),
284 1,
285 )
286 .into());
287 };
288
289 deserialize::<E>(&bytes)
290 }
291}