icydb_core/db/executor/upsert.rs

1use crate::{
2    Error, IndexSpec, Key,
3    db::{
4        Db,
5        executor::{ExecutorError, SaveExecutor},
6        store::{DataKey, IndexKey},
7    },
8    traits::EntityKind,
9};
10use std::{any::type_name, marker::PhantomData};
11
12///
13/// UniqueIndexSpec
14///
15
16#[derive(Clone, Copy)]
17pub struct UniqueIndexSpec {
18    index: &'static IndexSpec,
19}
20
21impl UniqueIndexSpec {
22    #[must_use]
23    /// Return the underlying index specification.
24    pub const fn index(&self) -> &'static IndexSpec {
25        self.index
26    }
27
28    /// Wrap a unique index for the given entity type.
29    pub fn new<E: EntityKind>(index: &'static IndexSpec) -> Result<Self, Error> {
30        if !E::INDEXES.iter().any(|cand| **cand == *index) {
31            return Err(
32                ExecutorError::IndexNotFound(E::PATH.to_string(), index.fields.join(", ")).into(),
33            );
34        }
35
36        if !index.unique {
37            return Err(ExecutorError::IndexNotUnique(
38                E::PATH.to_string(),
39                index.fields.join(", "),
40            )
41            .into());
42        }
43
44        Ok(Self { index })
45    }
46
47    /// Resolve a unique index by its field list for the given entity type.
48    pub fn for_fields<E: EntityKind>(fields: &[&str]) -> Result<Self, Error> {
49        for index in E::INDEXES {
50            if index.fields == fields {
51                return Self::new::<E>(index);
52            }
53        }
54
55        Err(ExecutorError::IndexNotFound(E::PATH.to_string(), fields.join(", ")).into())
56    }
57}
58
59///
60/// UpsertExecutor
61///
62
63#[derive(Clone, Copy)]
64pub struct UpsertExecutor<E: EntityKind> {
65    db: Db<E::Canister>,
66    debug: bool,
67    _marker: PhantomData<E>,
68}
69
70impl<E: EntityKind> UpsertExecutor<E>
71where
72    E::PrimaryKey: PrimaryKeyFromKey,
73{
74    /// Construct a new upsert executor.
75    #[must_use]
76    pub const fn new(db: Db<E::Canister>, debug: bool) -> Self {
77        Self {
78            db,
79            debug,
80            _marker: PhantomData,
81        }
82    }
83
84    /// Enable debug logging for subsequent upsert operations.
85    #[must_use]
86    pub const fn debug(mut self) -> Self {
87        self.debug = true;
88        self
89    }
90
91    /// Upsert using a unique index specification.
92    pub fn by_unique_index(&self, index: UniqueIndexSpec, entity: E) -> Result<E, Error> {
93        self.upsert(index.index(), entity)
94    }
95
96    /// Upsert a view using a unique index specification.
97    pub fn by_unique_index_view(
98        &self,
99        index: UniqueIndexSpec,
100        view: E::ViewType,
101    ) -> Result<E::ViewType, Error> {
102        let entity = E::from_view(view);
103        Ok(self.by_unique_index(index, entity)?.to_view())
104    }
105
106    /// Upsert using a unique index identified by its field list.
107    pub fn by_unique_fields(&self, fields: &[&str], entity: E) -> Result<E, Error> {
108        let index = UniqueIndexSpec::for_fields::<E>(fields)?;
109        self.upsert(index.index(), entity)
110    }
111
112    /// Upsert a view using a unique index identified by its field list.
113    pub fn by_unique_fields_view(
114        &self,
115        fields: &[&str],
116        view: E::ViewType,
117    ) -> Result<E::ViewType, Error> {
118        let entity = E::from_view(view);
119        Ok(self.by_unique_fields(fields, entity)?.to_view())
120    }
121
122    fn upsert(&self, index: &'static IndexSpec, mut entity: E) -> Result<E, Error> {
123        let existing_pk = self.resolve_unique_pk(index, &entity)?;
124        let saver = SaveExecutor::new(self.db, self.debug);
125
126        if let Some(pk) = existing_pk {
127            entity.set_primary_key(pk);
128            saver.update(entity)
129        } else {
130            saver.insert(entity)
131        }
132    }
133
134    fn resolve_unique_pk(
135        &self,
136        index: &'static IndexSpec,
137        entity: &E,
138    ) -> Result<Option<E::PrimaryKey>, Error> {
139        let Some(index_key) = IndexKey::new(entity, index) else {
140            return Err(ExecutorError::IndexKeyMissing(
141                E::PATH.to_string(),
142                index.fields.join(", "),
143            )
144            .into());
145        };
146
147        let store = self.db.with_index(|reg| reg.try_get_store(index.store))?;
148        let entry = store.with_borrow(|s| s.get(&index_key));
149
150        let Some(entry) = entry else {
151            return Ok(None);
152        };
153
154        let len = entry.len();
155        if len == 0 {
156            return Err(ExecutorError::IndexCorrupted(
157                E::PATH.to_string(),
158                index.fields.join(", "),
159                len,
160            )
161            .into());
162        }
163
164        if len > 1 {
165            return Err(ExecutorError::IndexCorrupted(
166                E::PATH.to_string(),
167                index.fields.join(", "),
168                len,
169            )
170            .into());
171        }
172
173        let key = entry.single_key().ok_or_else(|| {
174            ExecutorError::IndexCorrupted(E::PATH.to_string(), index.fields.join(", "), len)
175        })?;
176
177        let data_key = DataKey::new::<E>(key);
178        let exists = self
179            .db
180            .context::<E>()
181            .with_store(|store| store.get(&data_key).is_some())?;
182        if !exists {
183            return Err(ExecutorError::IndexCorrupted(
184                E::PATH.to_string(),
185                index.fields.join(", "),
186                len,
187            )
188            .into());
189        }
190
191        Ok(Some(E::PrimaryKey::try_from_key(key)?))
192    }
193}
194
195/// Convert a stored [`Key`] into a concrete primary key type.
196pub trait PrimaryKeyFromKey: Copy {
197    fn try_from_key(key: Key) -> Result<Self, ExecutorError>;
198}
199
200const fn key_kind(key: &Key) -> &'static str {
201    match key {
202        Key::Account(_) => "Account",
203        Key::Int(_) => "Int",
204        Key::Principal(_) => "Principal",
205        Key::Subaccount(_) => "Subaccount",
206        Key::Timestamp(_) => "Timestamp",
207        Key::Uint(_) => "Uint",
208        Key::Ulid(_) => "Ulid",
209        Key::Unit => "Unit",
210    }
211}
212
213fn key_type_mismatch<T>(key: &Key) -> ExecutorError {
214    ExecutorError::KeyTypeMismatch(type_name::<T>().to_string(), key_kind(key).to_string())
215}
216
217fn key_out_of_range<T>(value: impl std::fmt::Display) -> ExecutorError {
218    ExecutorError::KeyOutOfRange(type_name::<T>().to_string(), value.to_string())
219}
220
/// Implement [`PrimaryKeyFromKey`] for unsigned integer types that are
/// recovered from `Key::Uint` through a checked `TryFrom` conversion.
macro_rules! impl_pk_from_key_uint {
    ( $( $target:ty ),* $(,)? ) => {
        $(
            impl PrimaryKeyFromKey for $target {
                fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
                    if let Key::Uint(raw) = key {
                        // Narrowing can fail; report the original stored value.
                        <$target>::try_from(raw).map_err(|_| key_out_of_range::<$target>(raw))
                    } else {
                        Err(key_type_mismatch::<$target>(&key))
                    }
                }
            }
        )*
    };
}
235
/// Implement [`PrimaryKeyFromKey`] for signed integer types that are
/// recovered from `Key::Int` through a checked `TryFrom` conversion.
macro_rules! impl_pk_from_key_int {
    ( $( $target:ty ),* $(,)? ) => {
        $(
            impl PrimaryKeyFromKey for $target {
                fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
                    if let Key::Int(raw) = key {
                        // Narrowing can fail; report the original stored value.
                        <$target>::try_from(raw).map_err(|_| key_out_of_range::<$target>(raw))
                    } else {
                        Err(key_type_mismatch::<$target>(&key))
                    }
                }
            }
        )*
    };
}
250
251impl PrimaryKeyFromKey for i64 {
252    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
253        match key {
254            Key::Int(v) => Ok(v),
255            other => Err(key_type_mismatch::<Self>(&other)),
256        }
257    }
258}
259
260impl PrimaryKeyFromKey for u64 {
261    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
262        match key {
263            Key::Uint(v) => Ok(v),
264            other => Err(key_type_mismatch::<Self>(&other)),
265        }
266    }
267}
268
269impl PrimaryKeyFromKey for () {
270    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
271        match key {
272            Key::Unit => Ok(()),
273            other => Err(key_type_mismatch::<Self>(&other)),
274        }
275    }
276}
277
278impl PrimaryKeyFromKey for crate::types::Account {
279    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
280        match key {
281            Key::Account(v) => Ok(v),
282            other => Err(key_type_mismatch::<Self>(&other)),
283        }
284    }
285}
286
287impl PrimaryKeyFromKey for crate::types::Principal {
288    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
289        match key {
290            Key::Principal(v) => Ok(v),
291            other => Err(key_type_mismatch::<Self>(&other)),
292        }
293    }
294}
295
296impl PrimaryKeyFromKey for crate::types::Subaccount {
297    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
298        match key {
299            Key::Subaccount(v) => Ok(v),
300            other => Err(key_type_mismatch::<Self>(&other)),
301        }
302    }
303}
304
305impl PrimaryKeyFromKey for crate::types::Timestamp {
306    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
307        match key {
308            Key::Timestamp(v) => Ok(v),
309            other => Err(key_type_mismatch::<Self>(&other)),
310        }
311    }
312}
313
314impl PrimaryKeyFromKey for crate::types::Ulid {
315    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
316        match key {
317            Key::Ulid(v) => Ok(v),
318            other => Err(key_type_mismatch::<Self>(&other)),
319        }
320    }
321}
322
323impl PrimaryKeyFromKey for crate::types::Unit {
324    fn try_from_key(key: Key) -> Result<Self, ExecutorError> {
325        match key {
326            Key::Unit => Ok(Self),
327            other => Err(key_type_mismatch::<Self>(&other)),
328        }
329    }
330}
331
332impl_pk_from_key_uint!(u8, u16, u32);
333impl_pk_from_key_int!(i8, i16, i32);