async_txn/
write.rs

1use self::error::WtmError;
2
3use core::{borrow::Borrow, future::Future, hash::Hash};
4
5use super::*;
6
7mod blocking;
8
/// AsyncWtm is used to perform writes to the database. It is created by
/// calling [`AsyncTm::write`].
///
/// Writes are buffered in a pending writes manager (`P`), and keys that are read
/// or written can be tracked in an optional conflict manager (`C`) until the
/// transaction is committed, rolled back or discarded.
pub struct AsyncWtm<K, V, C, P, S>
where
  S: AsyncSpawner,
{
  // The read version this transaction observes; also used as the version of
  // inserted entries until commit re-stamps them with the commit timestamp.
  pub(super) read_ts: u64,
  // Running total of `estimate_size` over buffered writes, checked against
  // the pending writes manager's `max_batch_size`.
  pub(super) size: u64,
  // Number of writes buffered so far, checked against `max_batch_entries`.
  pub(super) count: u64,
  // Shared oracle: issues commit timestamps, owns the read mark and the write
  // serialization lock.
  pub(super) orc: Arc<Oracle<C, S>>,
  // Tracks read/conflict keys. Taken (left as `None`) when a commit timestamp is
  // issued; presumably also `None` when conflict detection is disabled.
  pub(super) conflict_manager: Option<C>,

  // buffer stores any writes done by txn. Taken (left as `None`) when the
  // entries are handed off at commit.
  pub(super) pending_writes: Option<P>,
  // Used in managed mode to store duplicate entries: an older entry for the
  // same key with a different version, committed alongside the pending writes.
  pub(super) duplicate_writes: OneOrMore<Entry<K, V>>,

  // Set by `discard`; most operations then fail with `TransactionError::Discard`.
  pub(super) discarded: bool,
  // Set once the read mark for `read_ts` has been released.
  pub(super) done_read: bool,
}
29
30impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
31where
32  S: AsyncSpawner,
33{
34  /// Returns the version of this read transaction.
35  #[inline]
36  pub const fn version(&self) -> u64 {
37    self.read_ts
38  }
39
40  /// Sets the current read version of the transaction manager.
41  // This should be used only for testing purposes.
42  #[doc(hidden)]
43  #[inline]
44  pub fn __set_read_version(&mut self, version: u64) {
45    self.read_ts = version;
46  }
47
48  /// Returns the pending writes manager.
49  ///
50  /// `None` means the transaction has already been discarded.
51  #[inline]
52  pub fn pwm(&self) -> Option<&P> {
53    self.pending_writes.as_ref()
54  }
55
56  /// Returns the conflict manager.
57  ///
58  /// `None` means the transaction has already been discarded.
59  #[inline]
60  pub fn cm(&self) -> Option<&C> {
61    self.conflict_manager.as_ref()
62  }
63}
64
65impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
66where
67  C: AsyncCm<Key = K>,
68  S: AsyncSpawner,
69{
70  /// This method is used to create a marker for the keys that are operated.
71  /// It must be used to mark keys when end user is implementing iterators to
72  /// make sure the transaction manager works correctly.
73  ///
74  /// `None` means the transaction has already been discarded.
75  ///
76  /// e.g.
77  ///
78  /// ```ignore, rust
79  /// let mut txn = custom_database.write(conflict_manger_opts, pending_manager_opts).unwrap();
80  /// let mut marker = txn.marker();
81  /// custom_database.iter().map(|k, v| marker.mark(&k));
82  /// ```
83  pub fn marker(&mut self) -> Option<AsyncMarker<'_, C>> {
84    self.conflict_manager.as_mut().map(AsyncMarker::new)
85  }
86
87  /// Returns a marker for the keys that are operated and the pending writes manager.
88  ///
89  /// `None` means the transaction has already been discarded.
90  ///
91  /// As Rust's borrow checker does not allow to borrow mutable marker and the immutable pending writes manager at the same
92  /// time, this method is used to solve this problem.
93  pub fn marker_with_pm(&mut self) -> Option<(AsyncMarker<'_, C>, &P)> {
94    self.conflict_manager.as_mut().map(|marker| {
95      (
96        AsyncMarker::new(marker),
97        self.pending_writes.as_ref().unwrap(),
98      )
99    })
100  }
101
102  /// Marks a key is read.
103  pub async fn mark_read(&mut self, k: &K) {
104    if let Some(ref mut conflict_manager) = self.conflict_manager {
105      conflict_manager.mark_read(k).await;
106    }
107  }
108
109  /// Marks a key is conflict.
110  pub async fn mark_conflict(&mut self, k: &K) {
111    if let Some(ref mut conflict_manager) = self.conflict_manager {
112      conflict_manager.mark_conflict(k).await;
113    }
114  }
115}
116
117impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
118where
119  C: AsyncCm<Key = K>,
120  P: AsyncPwm<Key = K, Value = V>,
121  S: AsyncSpawner,
122{
123  /// Rolls back the transaction.
124  #[inline]
125  pub async fn rollback(&mut self) -> Result<(), TransactionError<C::Error, P::Error>> {
126    if self.discarded {
127      return Err(TransactionError::Discard);
128    }
129
130    self
131      .pending_writes
132      .as_mut()
133      .unwrap()
134      .rollback()
135      .await
136      .map_err(TransactionError::Pwm)?;
137    self
138      .conflict_manager
139      .as_mut()
140      .unwrap()
141      .rollback()
142      .await
143      .map_err(TransactionError::Cm)?;
144    Ok(())
145  }
146
147  /// Insert a key-value pair to the transaction.
148  pub async fn insert(
149    &mut self,
150    key: K,
151    value: V,
152  ) -> Result<(), TransactionError<C::Error, P::Error>> {
153    self.insert_with_in(key, value).await
154  }
155
156  /// Removes a key.
157  ///
158  /// This is done by adding a delete marker for the key at commit timestamp.  Any
159  /// reads happening before this timestamp would be unaffected. Any reads after
160  /// this commit would see the deletion.
161  pub async fn remove(&mut self, key: K) -> Result<(), TransactionError<C::Error, P::Error>> {
162    self
163      .modify(Entry {
164        data: EntryData::Remove(key),
165        version: 0,
166      })
167      .await
168  }
169
170  /// Returns `true` if the pending writes contains the key.
171  pub async fn contains_key(
172    &mut self,
173    key: &K,
174  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>> {
175    if self.discarded {
176      return Err(TransactionError::Discard);
177    }
178
179    match self
180      .pending_writes
181      .as_ref()
182      .unwrap()
183      .get(key)
184      .await
185      .map_err(TransactionError::pending)?
186    {
187      Some(ent) => {
188        // If the value is None, it means that the key is removed.
189        if ent.value.is_none() {
190          return Ok(Some(false));
191        }
192
193        // Fulfill from buffer.
194        Ok(Some(true))
195      }
196      None => {
197        // track reads. No need to track read if txn serviced it
198        // internally.
199        if let Some(ref mut conflict_manager) = self.conflict_manager {
200          conflict_manager.mark_read(key).await;
201        }
202
203        Ok(None)
204      }
205    }
206  }
207
208  /// Looks for the key in the pending writes, if such key is not in the pending writes,
209  /// the end user can read the key from the database.
210  pub async fn get<'a, 'b: 'a>(
211    &'a mut self,
212    key: &'b K,
213  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>> {
214    if self.discarded {
215      return Err(TransactionError::Discard);
216    }
217
218    if let Some(e) = self
219      .pending_writes
220      .as_ref()
221      .unwrap()
222      .get(key)
223      .await
224      .map_err(TransactionError::Pwm)?
225    {
226      // If the value is None, it means that the key is removed.
227      if e.value.is_none() {
228        return Ok(None);
229      }
230
231      // Fulfill from buffer.
232      Ok(Some(EntryRef {
233        data: match &e.value {
234          Some(value) => EntryDataRef::Insert { key, value },
235          None => EntryDataRef::Remove(key),
236        },
237        version: e.version,
238      }))
239    } else {
240      // track reads. No need to track read if txn serviced it
241      // internally.
242      if let Some(ref mut conflict_manager) = self.conflict_manager {
243        conflict_manager.mark_read(key).await;
244      }
245
246      Ok(None)
247    }
248  }
249
250  /// Commits the transaction, following these steps:
251  ///
252  /// 1. If there are no writes, return immediately.
253  ///
254  /// 2. Check if read rows were updated since txn started. If so, return `TransactionError::Conflict`.
255  ///
256  /// 3. If no conflict, generate a commit timestamp and update written rows' commit ts.
257  ///
258  /// 4. Batch up all writes, write them to database.
259  ///
260  /// 5. If callback is provided, Badger will return immediately after checking
261  /// for conflicts. Writes to the database will happen in the background.  If
262  /// there is a conflict, an error will be returned and the callback will not
263  /// run. If there are no conflicts, the callback will be called in the
264  /// background upon successful completion of writes or any error during write.
265  pub async fn commit<F, Fut, O, E>(
266    &mut self,
267    apply: F,
268  ) -> Result<O, WtmError<C::Error, P::Error, E>>
269  where
270    Fut: Future<Output = Result<O, E>>,
271    F: FnOnce(OneOrMore<Entry<K, V>>) -> Fut,
272    E: std::error::Error,
273  {
274    if self.pending_writes.as_ref().unwrap().is_empty().await {
275      // Nothing to commit
276      self.discard();
277      return apply(Default::default()).await.map_err(WtmError::commit);
278    }
279
280    match self.commit_entries().await {
281      Ok((commit_ts, entries)) => match apply(entries).await {
282        Ok(output) => {
283          self.orc.done_commit(commit_ts);
284          self.discard();
285          Ok(output)
286        }
287        Err(e) => {
288          self.orc.done_commit(commit_ts);
289          self.discard();
290          Err(WtmError::commit(e))
291        }
292      },
293      Err(e) => {
294        self.discard();
295        Err(WtmError::transaction(e))
296      }
297    }
298  }
299}
300
301impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
302where
303  C: AsyncCmEquivalent<Key = K>,
304  P: AsyncPwm<Key = K, Value = V>,
305  S: AsyncSpawner,
306{
307  /// Marks a key is read.
308  pub async fn mark_read_equivalent<Q>(&mut self, k: &Q)
309  where
310    K: Borrow<Q>,
311    Q: ?Sized + Eq + Hash,
312  {
313    if let Some(ref mut conflict_manager) = self.conflict_manager {
314      conflict_manager.mark_read_equivalent(k).await;
315    }
316  }
317
318  /// Marks a key is conflict.
319  pub async fn mark_conflict_equivalent<Q>(&mut self, k: &Q)
320  where
321    K: Borrow<Q>,
322    Q: ?Sized + Eq + Hash,
323  {
324    if let Some(ref mut conflict_manager) = self.conflict_manager {
325      conflict_manager.mark_conflict_equivalent(k).await;
326    }
327  }
328}
329
330impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
331where
332  C: AsyncCmEquivalent<Key = K>,
333  P: AsyncPwmEquivalent<Key = K, Value = V>,
334  S: AsyncSpawner,
335{
336  /// Returns `true` if the pending writes contains the key.
337  ///
338  /// - `Ok(None)`: means the key is not in the pending writes, the end user can read the key from the database.
339  /// - `Ok(Some(true))`: means the key is in the pending writes.
340  /// - `Ok(Some(false))`: means the key is in the pending writes and but is a remove entry.
341  pub async fn contains_key_equivalent<'a, 'b: 'a, Q>(
342    &'a mut self,
343    key: &'b Q,
344  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
345  where
346    K: Borrow<Q>,
347    Q: ?Sized + Eq + Hash,
348  {
349    if self.discarded {
350      return Err(TransactionError::Discard);
351    }
352
353    match self
354      .pending_writes
355      .as_ref()
356      .unwrap()
357      .get_equivalent(key)
358      .await
359      .map_err(TransactionError::pending)?
360    {
361      Some(ent) => {
362        // If the value is None, it means that the key is removed.
363        if ent.value.is_none() {
364          return Ok(Some(false));
365        }
366
367        // Fulfill from buffer.
368        Ok(Some(true))
369      }
370      None => {
371        // track reads. No need to track read if txn serviced it
372        // internally.
373        if let Some(ref mut conflict_manager) = self.conflict_manager {
374          conflict_manager.mark_read_equivalent(key).await;
375        }
376
377        Ok(None)
378      }
379    }
380  }
381
382  /// Looks for the key in the pending writes, if such key is not in the pending writes,
383  /// the end user can read the key from the database.
384  pub async fn get_equivalent<'a, 'b: 'a, Q>(
385    &'a mut self,
386    key: &'b Q,
387  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
388  where
389    K: Borrow<Q>,
390    Q: ?Sized + Eq + Hash,
391  {
392    if self.discarded {
393      return Err(TransactionError::Discard);
394    }
395
396    if let Some((k, e)) = self
397      .pending_writes
398      .as_ref()
399      .unwrap()
400      .get_entry_equivalent(key)
401      .await
402      .map_err(TransactionError::Pwm)?
403    {
404      // If the value is None, it means that the key is removed.
405      if e.value.is_none() {
406        return Ok(None);
407      }
408
409      // Fulfill from buffer.
410      Ok(Some(EntryRef {
411        data: match &e.value {
412          Some(value) => EntryDataRef::Insert { key: k, value },
413          None => EntryDataRef::Remove(k),
414        },
415        version: e.version,
416      }))
417    } else {
418      // track reads. No need to track read if txn serviced it
419      // internally.
420      if let Some(ref mut conflict_manager) = self.conflict_manager {
421        conflict_manager.mark_read_equivalent(key).await;
422      }
423
424      Ok(None)
425    }
426  }
427}
428
429impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
430where
431  C: AsyncCmComparable<Key = K>,
432  P: AsyncPwmEquivalent<Key = K, Value = V>,
433  S: AsyncSpawner,
434{
435  /// Returns `true` if the pending writes contains the key.
436  ///
437  /// - `Ok(None)`: means the key is not in the pending writes, the end user can read the key from the database.
438  /// - `Ok(Some(true))`: means the key is in the pending writes.
439  /// - `Ok(Some(false))`: means the key is in the pending writes and but is a remove entry.
440  pub async fn contains_key_comparable_cm_equivalent_pm<'a, 'b: 'a, Q>(
441    &'a mut self,
442    key: &'b Q,
443  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
444  where
445    K: Borrow<Q>,
446    Q: ?Sized + Eq + Ord + Hash,
447  {
448    match self
449      .pending_writes
450      .as_ref()
451      .unwrap()
452      .get_equivalent(key)
453      .await
454      .map_err(TransactionError::pending)?
455    {
456      Some(ent) => {
457        // If the value is None, it means that the key is removed.
458        if ent.value.is_none() {
459          return Ok(Some(false));
460        }
461
462        // Fulfill from buffer.
463        Ok(Some(true))
464      }
465      None => {
466        // track reads. No need to track read if txn serviced it
467        // internally.
468        if let Some(ref mut conflict_manager) = self.conflict_manager {
469          conflict_manager.mark_read_comparable(key).await;
470        }
471
472        Ok(None)
473      }
474    }
475  }
476
477  /// Looks for the key in the pending writes, if such key is not in the pending writes,
478  /// the end user can read the key from the database.
479  pub async fn get_comparable_cm_equivalent_pm<'a, 'b: 'a, Q>(
480    &'a mut self,
481    key: &'b Q,
482  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
483  where
484    K: Borrow<Q>,
485    Q: ?Sized + Eq + Ord + Hash,
486  {
487    if let Some((k, e)) = self
488      .pending_writes
489      .as_ref()
490      .unwrap()
491      .get_entry_equivalent(key)
492      .await
493      .map_err(TransactionError::Pwm)?
494    {
495      // If the value is None, it means that the key is removed.
496      if e.value.is_none() {
497        return Ok(None);
498      }
499
500      // Fulfill from buffer.
501      Ok(Some(EntryRef {
502        data: match &e.value {
503          Some(value) => EntryDataRef::Insert { key: k, value },
504          None => EntryDataRef::Remove(k),
505        },
506        version: e.version,
507      }))
508    } else {
509      // track reads. No need to track read if txn serviced it
510      // internally.
511      if let Some(ref mut conflict_manager) = self.conflict_manager {
512        conflict_manager.mark_read_comparable(key).await;
513      }
514
515      Ok(None)
516    }
517  }
518}
519
520impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
521where
522  C: AsyncCmComparable<Key = K>,
523  S: AsyncSpawner,
524{
525  /// Marks a key is read.
526  pub async fn mark_read_comparable<Q>(&mut self, k: &Q)
527  where
528    K: Borrow<Q>,
529    Q: ?Sized + Ord,
530  {
531    if let Some(ref mut conflict_manager) = self.conflict_manager {
532      conflict_manager.mark_read_comparable(k).await;
533    }
534  }
535
536  /// Marks a key is conflict.
537  pub async fn mark_conflict_comparable<Q>(&mut self, k: &Q)
538  where
539    K: Borrow<Q>,
540    Q: ?Sized + Ord,
541  {
542    if let Some(ref mut conflict_manager) = self.conflict_manager {
543      conflict_manager.mark_conflict_comparable(k).await;
544    }
545  }
546}
547
548impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
549where
550  C: AsyncCmComparable<Key = K>,
551  P: AsyncPwmComparable<Key = K, Value = V>,
552  S: AsyncSpawner,
553{
554  /// Returns `true` if the pending writes contains the key.
555  ///
556  /// - `Ok(None)`: means the key is not in the pending writes, the end user can read the key from the database.
557  /// - `Ok(Some(true))`: means the key is in the pending writes.
558  /// - `Ok(Some(false))`: means the key is in the pending writes and but is a remove entry.
559  pub async fn contains_key_comparable<'a, 'b: 'a, Q>(
560    &'a mut self,
561    key: &'b Q,
562  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
563  where
564    K: Borrow<Q>,
565    Q: ?Sized + Ord,
566  {
567    match self
568      .pending_writes
569      .as_ref()
570      .unwrap()
571      .get_comparable(key)
572      .await
573      .map_err(TransactionError::pending)?
574    {
575      Some(ent) => {
576        // If the value is None, it means that the key is removed.
577        if ent.value.is_none() {
578          return Ok(Some(false));
579        }
580
581        // Fulfill from buffer.
582        Ok(Some(true))
583      }
584      None => {
585        // track reads. No need to track read if txn serviced it
586        // internally.
587        if let Some(ref mut conflict_manager) = self.conflict_manager {
588          conflict_manager.mark_read_comparable(key).await;
589        }
590
591        Ok(None)
592      }
593    }
594  }
595
596  /// Looks for the key in the pending writes, if such key is not in the pending writes,
597  /// the end user can read the key from the database.
598  pub async fn get_comparable<'a, 'b: 'a, Q>(
599    &'a mut self,
600    key: &'b Q,
601  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
602  where
603    K: Borrow<Q>,
604    Q: ?Sized + Ord,
605  {
606    if let Some((k, e)) = self
607      .pending_writes
608      .as_ref()
609      .unwrap()
610      .get_entry_comparable(key)
611      .await
612      .map_err(TransactionError::Pwm)?
613    {
614      // If the value is None, it means that the key is removed.
615      if e.value.is_none() {
616        return Ok(None);
617      }
618
619      // Fulfill from buffer.
620      Ok(Some(EntryRef {
621        data: match &e.value {
622          Some(value) => EntryDataRef::Insert { key: k, value },
623          None => EntryDataRef::Remove(k),
624        },
625        version: e.version,
626      }))
627    } else {
628      // track reads. No need to track read if txn serviced it
629      // internally.
630      if let Some(ref mut conflict_manager) = self.conflict_manager {
631        conflict_manager.mark_read_comparable(key).await;
632      }
633
634      Ok(None)
635    }
636  }
637}
638
639impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
640where
641  C: AsyncCmEquivalent<Key = K>,
642  P: AsyncPwmComparable<Key = K, Value = V>,
643  S: AsyncSpawner,
644{
645  /// Returns `true` if the pending writes contains the key.
646  ///
647  /// - `Ok(None)`: means the key is not in the pending writes, the end user can read the key from the database.
648  /// - `Ok(Some(true))`: means the key is in the pending writes.
649  /// - `Ok(Some(false))`: means the key is in the pending writes and but is a remove entry.
650  pub async fn contains_key_equivalent_cm_comparable_pm<'a, 'b: 'a, Q>(
651    &'a mut self,
652    key: &'b Q,
653  ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
654  where
655    K: Borrow<Q>,
656    Q: ?Sized + Eq + Ord + Hash,
657  {
658    match self
659      .pending_writes
660      .as_ref()
661      .unwrap()
662      .get_comparable(key)
663      .await
664      .map_err(TransactionError::pending)?
665    {
666      Some(ent) => {
667        // If the value is None, it means that the key is removed.
668        if ent.value.is_none() {
669          return Ok(Some(false));
670        }
671
672        // Fulfill from buffer.
673        Ok(Some(true))
674      }
675      None => {
676        // track reads. No need to track read if txn serviced it
677        // internally.
678        if let Some(ref mut conflict_manager) = self.conflict_manager {
679          conflict_manager.mark_read_equivalent(key).await;
680        }
681
682        Ok(None)
683      }
684    }
685  }
686
687  /// Looks for the key in the pending writes, if such key is not in the pending writes,
688  /// the end user can read the key from the database.
689  pub async fn get_equivalent_cm_comparable_pm<'a, 'b: 'a, Q>(
690    &'a mut self,
691    key: &'b Q,
692  ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
693  where
694    K: Borrow<Q>,
695    Q: ?Sized + Eq + Ord + Hash,
696  {
697    if let Some((k, e)) = self
698      .pending_writes
699      .as_ref()
700      .unwrap()
701      .get_entry_comparable(key)
702      .await
703      .map_err(TransactionError::Pwm)?
704    {
705      // If the value is None, it means that the key is removed.
706      if e.value.is_none() {
707        return Ok(None);
708      }
709
710      // Fulfill from buffer.
711      Ok(Some(EntryRef {
712        data: match &e.value {
713          Some(value) => EntryDataRef::Insert { key: k, value },
714          None => EntryDataRef::Remove(k),
715        },
716        version: e.version,
717      }))
718    } else {
719      // track reads. No need to track read if txn serviced it
720      // internally.
721      if let Some(ref mut conflict_manager) = self.conflict_manager {
722        conflict_manager.mark_read_equivalent(key).await;
723      }
724
725      Ok(None)
726    }
727  }
728}
729
730impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
731where
732  C: AsyncCm<Key = K> + Send,
733  P: AsyncPwm<Key = K, Value = V> + Send,
734  S: AsyncSpawner,
735{
736  /// Acts like [`commit`](AsyncWtm::commit), but takes a future and a spawner, which gets run via a
737  /// task to avoid blocking this function. Following these steps:
738  ///
739  /// 1. If there are no writes, return immediately, a new task will be spawned, and future will be invoked.
740  ///
741  /// 2. Check if read rows were updated since txn started. If so, return `TransactionError::Conflict`.
742  ///
743  /// 3. If no conflict, generate a commit timestamp and update written rows' commit ts.
744  ///
745  /// 4. Batch up all writes, write them to database.
746  ///
747  /// 5. Return immediately after checking for conflicts.
748  /// If there is a conflict, an error will be returned immediately and the no task will be spawned
749  /// run. If there are no conflicts, a task will be spawned and the future will be called in the
750  /// background upon successful completion of writes or any error during write.
751  pub async fn commit_with_task<F, Fut, CFut, E, R>(
752    &mut self,
753    apply: F,
754    fut: impl FnOnce(Result<(), E>) -> CFut + Send + 'static,
755  ) -> Result<<S as AsyncSpawner>::JoinHandle<R>, WtmError<C::Error, P::Error, E>>
756  where
757    K: Send + 'static,
758    V: Send + 'static,
759    Fut: Future<Output = Result<(), E>> + Send,
760    F: FnOnce(OneOrMore<Entry<K, V>>) -> Fut + Send + 'static,
761    CFut: Future<Output = R> + Send + 'static,
762    E: std::error::Error + Send,
763    C: 'static,
764    R: Send + 'static,
765  {
766    if self.discarded {
767      return Err(WtmError::transaction(TransactionError::Discard));
768    }
769
770    if self.pending_writes.as_ref().unwrap().is_empty().await {
771      // Nothing to commit
772      self.discard();
773      return Ok(S::spawn(async move { fut(Ok(())).await }));
774    }
775
776    match self.commit_entries().await {
777      Ok((commit_ts, entries)) => {
778        let orc = self.orc.clone();
779        Ok(S::spawn(async move {
780          match apply(entries).await {
781            Ok(_) => {
782              orc.done_commit(commit_ts);
783              fut(Ok(())).await
784            }
785            Err(e) => {
786              orc.done_commit(commit_ts);
787              fut(Err(e)).await
788            }
789          }
790        }))
791      }
792      Err(e) => match e {
793        TransactionError::Conflict => Err(WtmError::transaction(e)),
794        _ => {
795          self.discard();
796          Err(WtmError::transaction(e))
797        }
798      },
799    }
800  }
801}
802
803impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
804where
805  C: AsyncCm<Key = K>,
806  P: AsyncPwm<Key = K, Value = V>,
807  S: AsyncSpawner,
808{
809  async fn insert_with_in(
810    &mut self,
811    key: K,
812    value: V,
813  ) -> Result<(), TransactionError<C::Error, P::Error>> {
814    let ent = Entry {
815      data: EntryData::Insert { key, value },
816      version: self.read_ts,
817    };
818
819    self.modify(ent).await
820  }
821
822  async fn modify(&mut self, ent: Entry<K, V>) -> Result<(), TransactionError<C::Error, P::Error>> {
823    if self.discarded {
824      return Err(TransactionError::Discard);
825    }
826
827    let pending_writes = self.pending_writes.as_mut().unwrap();
828    pending_writes
829      .validate_entry(&ent)
830      .await
831      .map_err(TransactionError::Pwm)?;
832
833    let cnt = self.count + 1;
834    // Extra bytes for the version in key.
835    let size = self.size + pending_writes.estimate_size(&ent);
836    if cnt >= pending_writes.max_batch_entries() || size >= pending_writes.max_batch_size() {
837      return Err(TransactionError::LargeTxn);
838    }
839
840    self.count = cnt;
841    self.size = size;
842
843    // The conflict_manager is used for conflict detection. If conflict detection
844    // is disabled, we don't need to store key hashes in the conflict_manager.
845    if let Some(ref mut conflict_manager) = self.conflict_manager {
846      conflict_manager.mark_conflict(ent.key()).await;
847    }
848
849    // If a duplicate entry was inserted in managed mode, move it to the duplicate writes slice.
850    // Add the entry to duplicateWrites only if both the entries have different versions. For
851    // same versions, we will overwrite the existing entry.
852    let eversion = ent.version;
853    let (ek, ev) = ent.split();
854
855    if let Some((old_key, old_value)) = pending_writes
856      .remove_entry(&ek)
857      .await
858      .map_err(TransactionError::Pwm)?
859    {
860      if old_value.version != eversion {
861        self
862          .duplicate_writes
863          .push(Entry::unsplit(old_key, old_value));
864      }
865    }
866    pending_writes
867      .insert(ek, ev)
868      .await
869      .map_err(TransactionError::Pwm)?;
870
871    Ok(())
872  }
873
874  async fn commit_entries(
875    &mut self,
876  ) -> Result<(u64, OneOrMore<Entry<K, V>>), TransactionError<C::Error, P::Error>> {
877    // Ensure that the order in which we get the commit timestamp is the same as
878    // the order in which we push these updates to the write channel. So, we
879    // acquire a writeChLock before getting a commit timestamp, and only release
880    // it after pushing the entries to it.
881    let _write_lock = self.orc.write_serialize_lock.lock().await;
882
883    let conflict_manager = if self.conflict_manager.is_none() {
884      None
885    } else {
886      mem::take(&mut self.conflict_manager)
887    };
888
889    match self
890      .orc
891      .new_commit_ts(&mut self.done_read, self.read_ts, conflict_manager)
892      .await
893    {
894      CreateCommitTimestampResult::Conflict(conflict_manager) => {
895        // If there is a conflict, we should not send the updates to the write channel.
896        // Instead, we should return the conflict error to the user.
897        self.conflict_manager = conflict_manager;
898        Err(TransactionError::Conflict)
899      }
900      CreateCommitTimestampResult::Timestamp(commit_ts) => {
901        let pending_writes = mem::take(&mut self.pending_writes).unwrap();
902        let duplicate_writes = mem::take(&mut self.duplicate_writes);
903        let mut entries =
904          OneOrMore::with_capacity(pending_writes.len().await + self.duplicate_writes.len());
905
906        let process_entry = |entries: &mut OneOrMore<Entry<K, V>>, mut ent: Entry<K, V>| {
907          ent.version = commit_ts;
908          entries.push(ent);
909        };
910        pending_writes
911          .into_iter()
912          .await
913          .for_each(|(k, v)| process_entry(&mut entries, Entry::unsplit(k, v)));
914        duplicate_writes
915          .into_iter()
916          .for_each(|ent| process_entry(&mut entries, ent));
917
918        // CommitTs should not be zero if we're inserting transaction markers.
919        assert_ne!(commit_ts, 0);
920
921        Ok((commit_ts, entries))
922      }
923    }
924  }
925}
926
927impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
928where
929  S: AsyncSpawner,
930{
931  fn done_read(&mut self) {
932    if !self.done_read {
933      self.done_read = true;
934      self.orc().read_mark.done(self.read_ts).unwrap();
935    }
936  }
937
938  #[inline]
939  fn orc(&self) -> &Oracle<C, S> {
940    &self.orc
941  }
942
943  /// Discards a created transaction. This method is very important and must be called. `commit*`
944  /// methods calls this internally.
945  ///
946  /// NOTE: If any operations are run on a discarded transaction, [`TransactionError::Discard`] is returned.
947  pub fn discard(&mut self) {
948    if self.discarded {
949      return;
950    }
951    self.discarded = true;
952    self.done_read();
953  }
954
955  /// Returns true if the transaction is discarded.
956  #[inline]
957  pub const fn is_discard(&self) -> bool {
958    self.discarded
959  }
960}
961
962impl<K, V, C, P, S> Drop for AsyncWtm<K, V, C, P, S>
963where
964  S: AsyncSpawner,
965{
966  fn drop(&mut self) {
967    if !self.discarded {
968      self.discard();
969    }
970  }
971}
972
973#[cfg(test)]
974mod tests {
975  use std::{collections::BTreeSet, convert::Infallible, marker::PhantomData};
976
977  use super::*;
978
  /// Exercises a write transaction backed by `HashCm` and `IndexMapPwm`:
  /// marking keys (through the marker and directly on the transaction),
  /// insert and lookup by owned key and by equivalent `&str`, remove,
  /// rollback, and commit.
  #[async_std::test]
  async fn wtm() {
    let tm = AsyncTm::<String, u64, HashCm<String>, IndexMapPwm<String, u64>, wmark::AsyncStdSpawner>::new("test", 0).await;
    let mut wtm = tm
      .write(Default::default(), Default::default())
      .await
      .unwrap();
    // A freshly opened transaction is live and exposes both managers.
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());

    let mut marker = wtm.marker().unwrap();

    // Mark reads/conflicts through the marker, then directly on `wtm`.
    marker.mark(&"1".to_owned()).await;
    marker.mark_equivalent("3").await;
    marker.mark_conflict(&"2".to_owned()).await;
    marker.mark_conflict_equivalent("4").await;
    wtm.mark_read(&"2".to_owned()).await;
    wtm.mark_conflict(&"1".to_owned()).await;
    wtm.mark_conflict_equivalent("2").await;
    wtm.mark_read_equivalent("3").await;

    wtm.insert("5".into(), 5).await.unwrap();

    // The inserted key is visible via the equivalent (`&str`) lookups...
    assert_eq!(wtm.contains_key_equivalent("5").await.unwrap(), Some(true));
    assert_eq!(
      wtm
        .get_equivalent("5")
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // ...and via the owned-key lookups.
    assert_eq!(wtm.contains_key(&"5".to_owned()).await.unwrap(), Some(true));
    assert_eq!(
      wtm
        .get(&"5".to_owned())
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // A key that was never inserted reports `None`.
    assert_eq!(wtm.contains_key_equivalent("6").await.unwrap(), None);
    assert_eq!(wtm.get_equivalent("6").await.unwrap(), None);
    assert_eq!(wtm.contains_key_blocking(&"6".to_owned()).unwrap(), None);

    wtm.remove("5".into()).await.unwrap();
    wtm.rollback().await.unwrap();

    // The commit callback is a no-op that always succeeds.
    wtm
      .commit::<_, _, _, Infallible>(|_| async { Ok(()) })
      .await
      .unwrap();

    // Committing leaves the transaction marked as discarded.
    assert!(wtm.is_discard());
  }
1041
  /// Same scenario as `wtm`, but with a `BTreePwm` pending-write manager.
  /// It additionally covers the blocking marker/mark helpers and the
  /// `*_equivalent_cm_comparable_pm` lookup variants.
  #[async_std::test]
  async fn wtm2() {
    let tm =
      AsyncTm::<String, u64, HashCm<String>, BTreePwm<String, u64>, wmark::AsyncStdSpawner>::new(
        "test", 0,
      )
      .await;
    let mut wtm = tm.write((), Default::default()).await.unwrap();
    // A freshly opened transaction is live; every marker accessor is available.
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());
    assert!(wtm.blocking_marker().is_some());
    assert!(wtm.marker_with_pm().is_some());

    let mut marker = wtm.marker().unwrap();

    // Interleave async and blocking marking, on the marker and on `wtm`.
    marker.mark(&"1".to_owned()).await;
    marker.mark_blocking(&"3".to_owned());
    marker.mark_equivalent("3").await;
    marker.mark_equivalent_blocking("3");
    marker.mark_conflict(&"2".to_owned()).await;
    marker.mark_conflict_equivalent_blocking("4");
    marker.mark_conflict_equivalent("4").await;
    wtm.mark_read(&"2".to_owned()).await;
    wtm.mark_read_blocking(&"3".to_owned());
    wtm.mark_read_equivalent_blocking("3");
    wtm.mark_conflict(&"1".to_owned()).await;
    wtm.mark_conflict_equivalent("2").await;
    wtm.mark_conflict_equivalent_blocking("2");
    wtm.mark_read_equivalent("3").await;

    wtm.insert("5".into(), 5).await.unwrap();

    // The inserted key is visible via the `*_equivalent_cm_comparable_pm` lookups...
    assert_eq!(
      wtm
        .contains_key_equivalent_cm_comparable_pm("5")
        .await
        .unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_equivalent_cm_comparable_pm("5")
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // ...and via the owned-key lookups.
    assert_eq!(wtm.contains_key(&"5".to_owned()).await.unwrap(), Some(true));
    assert_eq!(
      wtm
        .get(&"5".to_owned())
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // A key that was never inserted reports `None` from every lookup flavour.
    assert_eq!(
      wtm
        .contains_key_equivalent_cm_comparable_pm("6")
        .await
        .unwrap(),
      None
    );
    assert_eq!(
      wtm.get_equivalent_cm_comparable_pm("6").await.unwrap(),
      None
    );
    assert_eq!(wtm.contains_key(&"6".to_owned()).await.unwrap(), None);
    assert_eq!(wtm.get(&"6".to_owned()).await.unwrap(), None);

    wtm.remove("5".into()).await.unwrap();
    wtm.rollback().await.unwrap();

    // The commit callback is a no-op that always succeeds.
    wtm
      .commit::<_, _, _, Infallible>(|_| async { Ok(()) })
      .await
      .unwrap();

    // Committing leaves the transaction marked as discarded.
    assert!(wtm.is_discard());
  }
1129
1130  struct TestCm<K> {
1131    conflict_keys: BTreeSet<usize>,
1132    reads: BTreeSet<usize>,
1133    _m: PhantomData<K>,
1134  }
1135
1136  impl<K> Cm for TestCm<K> {
1137    type Error = Infallible;
1138
1139    type Key = K;
1140
1141    type Options = ();
1142
1143    fn new(_options: Self::Options) -> Result<Self, Self::Error> {
1144      Ok(Self {
1145        conflict_keys: BTreeSet::new(),
1146        reads: BTreeSet::new(),
1147        _m: PhantomData,
1148      })
1149    }
1150
1151    fn mark_read(&mut self, key: &Self::Key) {
1152      self.reads.insert(key as *const K as usize);
1153    }
1154
1155    fn mark_conflict(&mut self, key: &Self::Key) {
1156      self.conflict_keys.insert(key as *const K as usize);
1157    }
1158
1159    fn has_conflict(&self, other: &Self) -> bool {
1160      if self.reads.is_empty() {
1161        return false;
1162      }
1163
1164      for ro in self.reads.iter() {
1165        if other.conflict_keys.contains(ro) {
1166          return true;
1167        }
1168      }
1169      false
1170    }
1171
1172    fn rollback(&mut self) -> Result<(), Self::Error> {
1173      self.conflict_keys.clear();
1174      self.reads.clear();
1175      Ok(())
1176    }
1177  }
1178
1179  impl<K> CmComparable for TestCm<K> {
1180    fn mark_read_comparable<Q>(&mut self, key: &Q)
1181    where
1182      Self::Key: Borrow<Q>,
1183      Q: Ord + ?Sized,
1184    {
1185      self.reads.insert(key as *const Q as *const () as usize);
1186    }
1187
1188    fn mark_conflict_comparable<Q>(&mut self, key: &Q)
1189    where
1190      Self::Key: Borrow<Q>,
1191      Q: Ord + ?Sized,
1192    {
1193      self
1194        .conflict_keys
1195        .insert(key as *const Q as *const () as usize);
1196    }
1197  }
1198
  /// Uses the address-based `TestCm` together with `IndexMapPwm` to exercise
  /// the `*_comparable_cm_equivalent_pm` lookups, both async and blocking.
  #[async_std::test]
  async fn wtm3() {
    let tm = AsyncTm::<
      Arc<u64>,
      u64,
      TestCm<Arc<u64>>,
      IndexMapPwm<Arc<u64>, u64>,
      wmark::AsyncStdSpawner,
    >::new("test", 0)
    .await;
    let mut wtm = tm.write(Default::default(), ()).await.unwrap();
    // A freshly opened transaction is live and exposes both managers.
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());

    let mut marker = wtm.marker().unwrap();

    // `TestCm` records keys by address, so each key is bound to a named `Arc`.
    let one = Arc::new(1);
    let two = Arc::new(2);
    let three = Arc::new(3);
    let four = Arc::new(4);
    let five = Arc::new(5);
    marker.mark(&one).await;
    marker.mark_comparable(&three).await;
    marker.mark_conflict(&two).await;
    marker.mark_conflict_comparable(&four).await;
    wtm.mark_read(&two).await;
    wtm.mark_conflict(&one).await;
    wtm.mark_conflict_comparable(&two).await;
    wtm.mark_read_comparable(&three).await;

    wtm.insert(five.clone(), 5).await.unwrap();

    // The inserted key is visible through the async lookups...
    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm(&five)
        .await
        .unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable_cm_equivalent_pm(&five)
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // ...and through their blocking counterparts.
    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm_blocking(&five)
        .unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable_cm_equivalent_pm_blocking(&five)
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // A key that was never inserted reports `None` from every lookup.
    let six = Arc::new(6);

    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm(&six)
        .await
        .unwrap(),
      None
    );
    assert_eq!(
      wtm.get_comparable_cm_equivalent_pm(&six).await.unwrap(),
      None
    );
    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm_blocking(&six)
        .unwrap(),
      None
    );
    assert_eq!(
      wtm.get_comparable_cm_equivalent_pm_blocking(&six).unwrap(),
      None
    );
    // No explicit commit/rollback: when `wtm` is dropped here, the `Drop`
    // impl calls `discard` because the transaction was not discarded yet.
  }
1290
  /// Uses the address-based `TestCm` together with `BTreePwm` to exercise the
  /// `*_comparable` lookups and the blocking marker/mark helpers.
  #[async_std::test]
  async fn wtm4() {
    let tm = AsyncTm::<
      Arc<u64>,
      u64,
      TestCm<Arc<u64>>,
      BTreePwm<Arc<u64>, u64>,
      wmark::AsyncStdSpawner,
    >::new("test", 0)
    .await;
    let mut wtm = tm.write((), ()).await.unwrap();
    // A freshly opened transaction is live and exposes both managers.
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());

    let mut marker = wtm.marker().unwrap();

    // `TestCm` records keys by address, so each key is bound to a named `Arc`.
    let one = Arc::new(1);
    let two = Arc::new(2);
    let three = Arc::new(3);
    let four = Arc::new(4);
    let five = Arc::new(5);
    // Each marking call is issued in both its async and blocking form.
    marker.mark(&one).await;
    marker.mark_blocking(&one);
    marker.mark_comparable(&three).await;
    marker.mark_comparable_blocking(&three);
    marker.mark_conflict(&two).await;
    marker.mark_conflict_blocking(&two);
    marker.mark_conflict_comparable(&four).await;
    marker.mark_conflict_comparable_blocking(&four);
    wtm.mark_read(&two).await;
    wtm.mark_read_blocking(&two);
    wtm.mark_read_comparable_blocking(&two);
    wtm.mark_conflict(&one).await;
    wtm.mark_conflict_blocking(&one);
    wtm.mark_conflict_comparable(&two).await;
    wtm.mark_conflict_comparable_blocking(&two);
    wtm.mark_read_comparable(&three).await;

    wtm.insert(five.clone(), 5).await.unwrap();

    // The inserted key is visible through the async lookups...
    assert_eq!(
      wtm.contains_key_comparable(&five).await.unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable(&five)
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // ...and through their blocking counterparts.
    assert_eq!(
      wtm.contains_key_comparable_blocking(&five).unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable_blocking(&five)
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // A key that was never inserted reports `None` from every lookup.
    let six = Arc::new(6);

    assert_eq!(wtm.contains_key_comparable(&six).await.unwrap(), None);
    assert_eq!(wtm.get_comparable(&six).await.unwrap(), None);
    assert_eq!(wtm.contains_key_comparable_blocking(&six).unwrap(), None);
    assert_eq!(wtm.get_comparable_blocking(&six).unwrap(), None);
    // No explicit commit/rollback: when `wtm` is dropped here, the `Drop`
    // impl calls `discard` because the transaction was not discarded yet.
  }
1368}