async_skipdb/optimistic/write.rs

1use std::{convert::Infallible, future::Future};
2
3use async_txn::{error::WtmError, PwmComparableRange};
4use skipdb_core::rev_range::WriteTransactionRevRange;
5
6use super::*;
7
8/// A optimistic concurrency control transaction over the [`OptimisticDb`].
9pub struct OptimisticTransaction<K, V, SP: AsyncSpawner, S = RandomState> {
10  db: OptimisticDb<K, V, SP, S>,
11  pub(super) wtm: AsyncWtm<K, V, HashCm<K, S>, BTreePwm<K, V>, SP>,
12}
13
14impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
15where
16  K: Ord + Hash + Eq,
17  S: BuildHasher + Clone,
18  SP: AsyncSpawner,
19{
20  #[inline]
21  pub(super) async fn new(db: OptimisticDb<K, V, SP, S>, cap: Option<usize>) -> Self {
22    let wtm = db
23      .inner
24      .tm
25      .write_with_blocking_cm_and_pwm(
26        (),
27        HashCmOptions::with_capacity(db.inner.hasher.clone(), cap.unwrap_or(8)),
28      )
29      .await
30      .unwrap();
31    Self { db, wtm }
32  }
33}
34
35impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
36where
37  K: Ord + Hash + Eq + Send + Sync + 'static,
38  V: Send + Sync + 'static,
39  S: BuildHasher + Send + Sync + 'static,
40  SP: AsyncSpawner,
41{
42  /// Commits the transaction, following these steps:
43  ///
44  /// 1. If there are no writes, return immediately.
45  ///
46  /// 2. Check if read rows were updated since txn started. If so, return `TransactionError::Conflict`.
47  ///
48  /// 3. If no conflict, generate a commit timestamp and update written rows' commit ts.
49  ///
50  /// 4. Batch up all writes, write them to database.
51  ///
52  /// 5. If callback is provided, Badger will return immediately after checking
53  /// for conflicts. Writes to the database will happen in the background.  If
54  /// there is a conflict, an error will be returned and the callback will not
55  /// run. If there are no conflicts, the callback will be called in the
56  /// background upon successful completion of writes or any error during write.
57  #[inline]
58  pub async fn commit(
59    &mut self,
60  ) -> Result<(), WtmError<Infallible, Infallible, core::convert::Infallible>> {
61    let db = self.db.clone();
62    self
63      .wtm
64      .commit(|ents| async move {
65        db.inner.map.apply(ents);
66        Ok(())
67      })
68      .await
69  }
70}
71
72impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
73where
74  K: Ord + Hash + Eq + Send + Sync + 'static,
75  V: Send + Sync + 'static,
76  S: BuildHasher + Send + Sync + 'static,
77  SP: AsyncSpawner,
78{
79  /// Acts like [`commit`](WriteTransaction::commit), but takes a callback, which gets run via a
80  /// thread to avoid blocking this function. Following these steps:
81  ///
82  /// 1. If there are no writes, return immediately, callback will be invoked.
83  ///
84  /// 2. Check if read rows were updated since txn started. If so, return `TransactionError::Conflict`.
85  ///
86  /// 3. If no conflict, generate a commit timestamp and update written rows' commit ts.
87  ///
88  /// 4. Batch up all writes, write them to database.
89  ///
90  /// 5. Return immediately after checking for conflicts.
91  /// If there is a conflict, an error will be returned immediately and the callback will not
92  /// run. If there are no conflicts, the callback will be called in the
93  /// background upon successful completion of writes or any error during write.
94  #[inline]
95  pub async fn commit_with_task<Fut, E, R>(
96    &mut self,
97    callback: impl FnOnce(Result<(), E>) -> Fut + Send + 'static,
98  ) -> Result<SP::JoinHandle<R>, WtmError<Infallible, Infallible, E>>
99  where
100    Fut: Future<Output = R> + Send + 'static,
101    E: std::error::Error + Send,
102    R: Send + 'static,
103  {
104    let db = self.db.clone();
105
106    self
107      .wtm
108      .commit_with_task(
109        move |ents| async move {
110          db.inner.map.apply(ents);
111          Ok(())
112        },
113        callback,
114      )
115      .await
116  }
117}
118
119impl<K, V, SP, S> OptimisticTransaction<K, V, SP, S>
120where
121  K: Ord + Hash + Eq,
122  V: 'static,
123  S: BuildHasher,
124  SP: AsyncSpawner,
125{
126  /// Returns the read version of the transaction.
127  #[inline]
128  pub fn version(&self) -> u64 {
129    self.wtm.version()
130  }
131
132  /// Rollback the transaction.
133  #[inline]
134  pub fn rollback(&mut self) -> Result<(), TransactionError<Infallible, Infallible>> {
135    self.wtm.rollback_blocking()
136  }
137
138  /// Returns true if the given key exists in the database.
139  #[inline]
140  pub fn contains_key<Q>(
141    &mut self,
142    key: &Q,
143  ) -> Result<bool, TransactionError<Infallible, Infallible>>
144  where
145    K: Borrow<Q>,
146    Q: Hash + Eq + Ord + ?Sized,
147  {
148    let version = self.wtm.version();
149    match self
150      .wtm
151      .contains_key_equivalent_cm_comparable_pm_blocking(key)?
152    {
153      Some(true) => Ok(true),
154      Some(false) => Ok(false),
155      None => Ok(self.db.inner.map.contains_key(key, version)),
156    }
157  }
158
159  /// Get a value from the database.
160  #[inline]
161  pub fn get<'a, 'b: 'a, Q>(
162    &'a mut self,
163    key: &'b Q,
164  ) -> Result<Option<Ref<'a, K, V>>, TransactionError<Infallible, Infallible>>
165  where
166    K: Borrow<Q>,
167    Q: Hash + Eq + Ord + ?Sized,
168  {
169    let version = self.wtm.version();
170    match self.wtm.get_equivalent_cm_comparable_pm_blocking(key)? {
171      Some(v) => {
172        if v.value().is_some() {
173          Ok(Some(v.into()))
174        } else {
175          Ok(None)
176        }
177      }
178      None => Ok(self.db.inner.map.get(key, version).map(Into::into)),
179    }
180  }
181
182  /// Insert a new key-value pair.
183  #[inline]
184  pub fn insert(
185    &mut self,
186    key: K,
187    value: V,
188  ) -> Result<(), TransactionError<Infallible, Infallible>> {
189    self.wtm.insert_blocking(key, value)
190  }
191
192  /// Remove a key.
193  #[inline]
194  pub fn remove(&mut self, key: K) -> Result<(), TransactionError<Infallible, Infallible>> {
195    self.wtm.remove_blocking(key)
196  }
197
198  /// Iterate over the entries of the write transaction.
199  #[inline]
200  pub fn iter(
201    &mut self,
202  ) -> Result<TransactionIter<'_, K, V, HashCm<K, S>>, TransactionError<Infallible, Infallible>> {
203    let version = self.wtm.version();
204    let (marker, pm) = self
205      .wtm
206      .blocking_marker_with_pm()
207      .ok_or(TransactionError::Discard)?;
208
209    let committed = self.db.inner.map.iter(version);
210    let pendings = pm.iter();
211
212    Ok(TransactionIter::new(pendings, committed, Some(marker)))
213  }
214
215  /// Iterate over the entries of the write transaction in reverse order.
216  #[inline]
217  pub fn iter_rev(
218    &mut self,
219  ) -> Result<
220    WriteTransactionRevIter<'_, K, V, HashCm<K, S>>,
221    TransactionError<Infallible, Infallible>,
222  > {
223    let version = self.wtm.version();
224    let (marker, pm) = self
225      .wtm
226      .blocking_marker_with_pm()
227      .ok_or(TransactionError::Discard)?;
228
229    let committed = self.db.inner.map.iter_rev(version);
230    let pendings = pm.iter().rev();
231
232    Ok(WriteTransactionRevIter::new(
233      pendings,
234      committed,
235      Some(marker),
236    ))
237  }
238
239  /// Returns an iterator over the subset of entries of the database.
240  #[inline]
241  pub fn range<'a, Q, R>(
242    &'a mut self,
243    range: R,
244  ) -> Result<
245    TransactionRange<'a, Q, R, K, V, HashCm<K, S>>,
246    TransactionError<Infallible, Infallible>,
247  >
248  where
249    K: Borrow<Q>,
250    R: RangeBounds<Q> + 'a,
251    Q: Ord + ?Sized,
252  {
253    let version = self.wtm.version();
254    let (marker, pm) = self
255      .wtm
256      .blocking_marker_with_pm()
257      .ok_or(TransactionError::Discard)?;
258    let start = range.start_bound();
259    let end = range.end_bound();
260    let pendings = pm.range_comparable((start, end));
261    let committed = self.db.inner.map.range(range, version);
262
263    Ok(TransactionRange::new(pendings, committed, Some(marker)))
264  }
265
266  /// Returns an iterator over the subset of entries of the database in reverse order.
267  #[inline]
268  pub fn range_rev<'a, Q, R>(
269    &'a mut self,
270    range: R,
271  ) -> Result<
272    WriteTransactionRevRange<'a, Q, R, K, V, HashCm<K, S>>,
273    TransactionError<Infallible, Infallible>,
274  >
275  where
276    K: Borrow<Q>,
277    R: RangeBounds<Q> + 'a,
278    Q: Ord + ?Sized,
279  {
280    let version = self.wtm.version();
281    let (marker, pm) = self
282      .wtm
283      .blocking_marker_with_pm()
284      .ok_or(TransactionError::Discard)?;
285    let start = range.start_bound();
286    let end = range.end_bound();
287    let pendings = pm.range_comparable((start, end));
288    let committed = self.db.inner.map.range_rev(range, version);
289
290    Ok(WriteTransactionRevRange::new(
291      pendings.rev(),
292      committed,
293      Some(marker),
294    ))
295  }
296}