1use self::error::WtmError;
2
3use core::{borrow::Borrow, future::Future, hash::Hash};
4
5use super::*;
6
7mod blocking;
8
/// Write transaction manager for async transactions.
///
/// Holds the read timestamp, the running count/size of buffered writes, the
/// optional conflict manager and pending-write manager, and the lifecycle flags.
pub struct AsyncWtm<K, V, C, P, S>
where
  S: AsyncSpawner,
{
  /// Read timestamp (version) of this transaction.
  pub(super) read_ts: u64,
  /// Running total of `estimate_size` over all writes accepted by `modify`.
  pub(super) size: u64,
  /// Number of writes accepted by `modify`.
  pub(super) count: u64,
  /// Shared oracle: hands out commit timestamps (`new_commit_ts`), is told about
  /// finished commits (`done_commit`) and owns the read mark released by `done_read`.
  pub(super) orc: Arc<Oracle<C, S>>,
  /// Conflict tracker; moved out during `commit_entries` and restored on conflict.
  pub(super) conflict_manager: Option<C>,

  /// Buffered writes; becomes `None` once `commit_entries` obtains a timestamp.
  pub(super) pending_writes: Option<P>,
  /// Earlier entries for a key that were overwritten by a write with a different version.
  pub(super) duplicate_writes: OneOrMore<Entry<K, V>>,

  /// Set by `discard`; most operations then return `TransactionError::Discard`.
  pub(super) discarded: bool,
  /// Whether the read mark for `read_ts` has already been released.
  pub(super) done_read: bool,
}
29
30impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
31where
32 S: AsyncSpawner,
33{
34 #[inline]
36 pub const fn version(&self) -> u64 {
37 self.read_ts
38 }
39
40 #[doc(hidden)]
43 #[inline]
44 pub fn __set_read_version(&mut self, version: u64) {
45 self.read_ts = version;
46 }
47
48 #[inline]
52 pub fn pwm(&self) -> Option<&P> {
53 self.pending_writes.as_ref()
54 }
55
56 #[inline]
60 pub fn cm(&self) -> Option<&C> {
61 self.conflict_manager.as_ref()
62 }
63}
64
65impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
66where
67 C: AsyncCm<Key = K>,
68 S: AsyncSpawner,
69{
70 pub fn marker(&mut self) -> Option<AsyncMarker<'_, C>> {
84 self.conflict_manager.as_mut().map(AsyncMarker::new)
85 }
86
87 pub fn marker_with_pm(&mut self) -> Option<(AsyncMarker<'_, C>, &P)> {
94 self.conflict_manager.as_mut().map(|marker| {
95 (
96 AsyncMarker::new(marker),
97 self.pending_writes.as_ref().unwrap(),
98 )
99 })
100 }
101
102 pub async fn mark_read(&mut self, k: &K) {
104 if let Some(ref mut conflict_manager) = self.conflict_manager {
105 conflict_manager.mark_read(k).await;
106 }
107 }
108
109 pub async fn mark_conflict(&mut self, k: &K) {
111 if let Some(ref mut conflict_manager) = self.conflict_manager {
112 conflict_manager.mark_conflict(k).await;
113 }
114 }
115}
116
117impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
118where
119 C: AsyncCm<Key = K>,
120 P: AsyncPwm<Key = K, Value = V>,
121 S: AsyncSpawner,
122{
123 #[inline]
125 pub async fn rollback(&mut self) -> Result<(), TransactionError<C::Error, P::Error>> {
126 if self.discarded {
127 return Err(TransactionError::Discard);
128 }
129
130 self
131 .pending_writes
132 .as_mut()
133 .unwrap()
134 .rollback()
135 .await
136 .map_err(TransactionError::Pwm)?;
137 self
138 .conflict_manager
139 .as_mut()
140 .unwrap()
141 .rollback()
142 .await
143 .map_err(TransactionError::Cm)?;
144 Ok(())
145 }
146
147 pub async fn insert(
149 &mut self,
150 key: K,
151 value: V,
152 ) -> Result<(), TransactionError<C::Error, P::Error>> {
153 self.insert_with_in(key, value).await
154 }
155
156 pub async fn remove(&mut self, key: K) -> Result<(), TransactionError<C::Error, P::Error>> {
162 self
163 .modify(Entry {
164 data: EntryData::Remove(key),
165 version: 0,
166 })
167 .await
168 }
169
170 pub async fn contains_key(
172 &mut self,
173 key: &K,
174 ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>> {
175 if self.discarded {
176 return Err(TransactionError::Discard);
177 }
178
179 match self
180 .pending_writes
181 .as_ref()
182 .unwrap()
183 .get(key)
184 .await
185 .map_err(TransactionError::pending)?
186 {
187 Some(ent) => {
188 if ent.value.is_none() {
190 return Ok(Some(false));
191 }
192
193 Ok(Some(true))
195 }
196 None => {
197 if let Some(ref mut conflict_manager) = self.conflict_manager {
200 conflict_manager.mark_read(key).await;
201 }
202
203 Ok(None)
204 }
205 }
206 }
207
208 pub async fn get<'a, 'b: 'a>(
211 &'a mut self,
212 key: &'b K,
213 ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>> {
214 if self.discarded {
215 return Err(TransactionError::Discard);
216 }
217
218 if let Some(e) = self
219 .pending_writes
220 .as_ref()
221 .unwrap()
222 .get(key)
223 .await
224 .map_err(TransactionError::Pwm)?
225 {
226 if e.value.is_none() {
228 return Ok(None);
229 }
230
231 Ok(Some(EntryRef {
233 data: match &e.value {
234 Some(value) => EntryDataRef::Insert { key, value },
235 None => EntryDataRef::Remove(key),
236 },
237 version: e.version,
238 }))
239 } else {
240 if let Some(ref mut conflict_manager) = self.conflict_manager {
243 conflict_manager.mark_read(key).await;
244 }
245
246 Ok(None)
247 }
248 }
249
250 pub async fn commit<F, Fut, O, E>(
266 &mut self,
267 apply: F,
268 ) -> Result<O, WtmError<C::Error, P::Error, E>>
269 where
270 Fut: Future<Output = Result<O, E>>,
271 F: FnOnce(OneOrMore<Entry<K, V>>) -> Fut,
272 E: std::error::Error,
273 {
274 if self.pending_writes.as_ref().unwrap().is_empty().await {
275 self.discard();
277 return apply(Default::default()).await.map_err(WtmError::commit);
278 }
279
280 match self.commit_entries().await {
281 Ok((commit_ts, entries)) => match apply(entries).await {
282 Ok(output) => {
283 self.orc.done_commit(commit_ts);
284 self.discard();
285 Ok(output)
286 }
287 Err(e) => {
288 self.orc.done_commit(commit_ts);
289 self.discard();
290 Err(WtmError::commit(e))
291 }
292 },
293 Err(e) => {
294 self.discard();
295 Err(WtmError::transaction(e))
296 }
297 }
298 }
299}
300
301impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
302where
303 C: AsyncCmEquivalent<Key = K>,
304 P: AsyncPwm<Key = K, Value = V>,
305 S: AsyncSpawner,
306{
307 pub async fn mark_read_equivalent<Q>(&mut self, k: &Q)
309 where
310 K: Borrow<Q>,
311 Q: ?Sized + Eq + Hash,
312 {
313 if let Some(ref mut conflict_manager) = self.conflict_manager {
314 conflict_manager.mark_read_equivalent(k).await;
315 }
316 }
317
318 pub async fn mark_conflict_equivalent<Q>(&mut self, k: &Q)
320 where
321 K: Borrow<Q>,
322 Q: ?Sized + Eq + Hash,
323 {
324 if let Some(ref mut conflict_manager) = self.conflict_manager {
325 conflict_manager.mark_conflict_equivalent(k).await;
326 }
327 }
328}
329
330impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
331where
332 C: AsyncCmEquivalent<Key = K>,
333 P: AsyncPwmEquivalent<Key = K, Value = V>,
334 S: AsyncSpawner,
335{
336 pub async fn contains_key_equivalent<'a, 'b: 'a, Q>(
342 &'a mut self,
343 key: &'b Q,
344 ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
345 where
346 K: Borrow<Q>,
347 Q: ?Sized + Eq + Hash,
348 {
349 if self.discarded {
350 return Err(TransactionError::Discard);
351 }
352
353 match self
354 .pending_writes
355 .as_ref()
356 .unwrap()
357 .get_equivalent(key)
358 .await
359 .map_err(TransactionError::pending)?
360 {
361 Some(ent) => {
362 if ent.value.is_none() {
364 return Ok(Some(false));
365 }
366
367 Ok(Some(true))
369 }
370 None => {
371 if let Some(ref mut conflict_manager) = self.conflict_manager {
374 conflict_manager.mark_read_equivalent(key).await;
375 }
376
377 Ok(None)
378 }
379 }
380 }
381
382 pub async fn get_equivalent<'a, 'b: 'a, Q>(
385 &'a mut self,
386 key: &'b Q,
387 ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
388 where
389 K: Borrow<Q>,
390 Q: ?Sized + Eq + Hash,
391 {
392 if self.discarded {
393 return Err(TransactionError::Discard);
394 }
395
396 if let Some((k, e)) = self
397 .pending_writes
398 .as_ref()
399 .unwrap()
400 .get_entry_equivalent(key)
401 .await
402 .map_err(TransactionError::Pwm)?
403 {
404 if e.value.is_none() {
406 return Ok(None);
407 }
408
409 Ok(Some(EntryRef {
411 data: match &e.value {
412 Some(value) => EntryDataRef::Insert { key: k, value },
413 None => EntryDataRef::Remove(k),
414 },
415 version: e.version,
416 }))
417 } else {
418 if let Some(ref mut conflict_manager) = self.conflict_manager {
421 conflict_manager.mark_read_equivalent(key).await;
422 }
423
424 Ok(None)
425 }
426 }
427}
428
429impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
430where
431 C: AsyncCmComparable<Key = K>,
432 P: AsyncPwmEquivalent<Key = K, Value = V>,
433 S: AsyncSpawner,
434{
435 pub async fn contains_key_comparable_cm_equivalent_pm<'a, 'b: 'a, Q>(
441 &'a mut self,
442 key: &'b Q,
443 ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
444 where
445 K: Borrow<Q>,
446 Q: ?Sized + Eq + Ord + Hash,
447 {
448 match self
449 .pending_writes
450 .as_ref()
451 .unwrap()
452 .get_equivalent(key)
453 .await
454 .map_err(TransactionError::pending)?
455 {
456 Some(ent) => {
457 if ent.value.is_none() {
459 return Ok(Some(false));
460 }
461
462 Ok(Some(true))
464 }
465 None => {
466 if let Some(ref mut conflict_manager) = self.conflict_manager {
469 conflict_manager.mark_read_comparable(key).await;
470 }
471
472 Ok(None)
473 }
474 }
475 }
476
477 pub async fn get_comparable_cm_equivalent_pm<'a, 'b: 'a, Q>(
480 &'a mut self,
481 key: &'b Q,
482 ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
483 where
484 K: Borrow<Q>,
485 Q: ?Sized + Eq + Ord + Hash,
486 {
487 if let Some((k, e)) = self
488 .pending_writes
489 .as_ref()
490 .unwrap()
491 .get_entry_equivalent(key)
492 .await
493 .map_err(TransactionError::Pwm)?
494 {
495 if e.value.is_none() {
497 return Ok(None);
498 }
499
500 Ok(Some(EntryRef {
502 data: match &e.value {
503 Some(value) => EntryDataRef::Insert { key: k, value },
504 None => EntryDataRef::Remove(k),
505 },
506 version: e.version,
507 }))
508 } else {
509 if let Some(ref mut conflict_manager) = self.conflict_manager {
512 conflict_manager.mark_read_comparable(key).await;
513 }
514
515 Ok(None)
516 }
517 }
518}
519
520impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
521where
522 C: AsyncCmComparable<Key = K>,
523 S: AsyncSpawner,
524{
525 pub async fn mark_read_comparable<Q>(&mut self, k: &Q)
527 where
528 K: Borrow<Q>,
529 Q: ?Sized + Ord,
530 {
531 if let Some(ref mut conflict_manager) = self.conflict_manager {
532 conflict_manager.mark_read_comparable(k).await;
533 }
534 }
535
536 pub async fn mark_conflict_comparable<Q>(&mut self, k: &Q)
538 where
539 K: Borrow<Q>,
540 Q: ?Sized + Ord,
541 {
542 if let Some(ref mut conflict_manager) = self.conflict_manager {
543 conflict_manager.mark_conflict_comparable(k).await;
544 }
545 }
546}
547
548impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
549where
550 C: AsyncCmComparable<Key = K>,
551 P: AsyncPwmComparable<Key = K, Value = V>,
552 S: AsyncSpawner,
553{
554 pub async fn contains_key_comparable<'a, 'b: 'a, Q>(
560 &'a mut self,
561 key: &'b Q,
562 ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
563 where
564 K: Borrow<Q>,
565 Q: ?Sized + Ord,
566 {
567 match self
568 .pending_writes
569 .as_ref()
570 .unwrap()
571 .get_comparable(key)
572 .await
573 .map_err(TransactionError::pending)?
574 {
575 Some(ent) => {
576 if ent.value.is_none() {
578 return Ok(Some(false));
579 }
580
581 Ok(Some(true))
583 }
584 None => {
585 if let Some(ref mut conflict_manager) = self.conflict_manager {
588 conflict_manager.mark_read_comparable(key).await;
589 }
590
591 Ok(None)
592 }
593 }
594 }
595
596 pub async fn get_comparable<'a, 'b: 'a, Q>(
599 &'a mut self,
600 key: &'b Q,
601 ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
602 where
603 K: Borrow<Q>,
604 Q: ?Sized + Ord,
605 {
606 if let Some((k, e)) = self
607 .pending_writes
608 .as_ref()
609 .unwrap()
610 .get_entry_comparable(key)
611 .await
612 .map_err(TransactionError::Pwm)?
613 {
614 if e.value.is_none() {
616 return Ok(None);
617 }
618
619 Ok(Some(EntryRef {
621 data: match &e.value {
622 Some(value) => EntryDataRef::Insert { key: k, value },
623 None => EntryDataRef::Remove(k),
624 },
625 version: e.version,
626 }))
627 } else {
628 if let Some(ref mut conflict_manager) = self.conflict_manager {
631 conflict_manager.mark_read_comparable(key).await;
632 }
633
634 Ok(None)
635 }
636 }
637}
638
639impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
640where
641 C: AsyncCmEquivalent<Key = K>,
642 P: AsyncPwmComparable<Key = K, Value = V>,
643 S: AsyncSpawner,
644{
645 pub async fn contains_key_equivalent_cm_comparable_pm<'a, 'b: 'a, Q>(
651 &'a mut self,
652 key: &'b Q,
653 ) -> Result<Option<bool>, TransactionError<C::Error, P::Error>>
654 where
655 K: Borrow<Q>,
656 Q: ?Sized + Eq + Ord + Hash,
657 {
658 match self
659 .pending_writes
660 .as_ref()
661 .unwrap()
662 .get_comparable(key)
663 .await
664 .map_err(TransactionError::pending)?
665 {
666 Some(ent) => {
667 if ent.value.is_none() {
669 return Ok(Some(false));
670 }
671
672 Ok(Some(true))
674 }
675 None => {
676 if let Some(ref mut conflict_manager) = self.conflict_manager {
679 conflict_manager.mark_read_equivalent(key).await;
680 }
681
682 Ok(None)
683 }
684 }
685 }
686
687 pub async fn get_equivalent_cm_comparable_pm<'a, 'b: 'a, Q>(
690 &'a mut self,
691 key: &'b Q,
692 ) -> Result<Option<EntryRef<'a, K, V>>, TransactionError<C::Error, P::Error>>
693 where
694 K: Borrow<Q>,
695 Q: ?Sized + Eq + Ord + Hash,
696 {
697 if let Some((k, e)) = self
698 .pending_writes
699 .as_ref()
700 .unwrap()
701 .get_entry_comparable(key)
702 .await
703 .map_err(TransactionError::Pwm)?
704 {
705 if e.value.is_none() {
707 return Ok(None);
708 }
709
710 Ok(Some(EntryRef {
712 data: match &e.value {
713 Some(value) => EntryDataRef::Insert { key: k, value },
714 None => EntryDataRef::Remove(k),
715 },
716 version: e.version,
717 }))
718 } else {
719 if let Some(ref mut conflict_manager) = self.conflict_manager {
722 conflict_manager.mark_read_equivalent(key).await;
723 }
724
725 Ok(None)
726 }
727 }
728}
729
730impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
731where
732 C: AsyncCm<Key = K> + Send,
733 P: AsyncPwm<Key = K, Value = V> + Send,
734 S: AsyncSpawner,
735{
736 pub async fn commit_with_task<F, Fut, CFut, E, R>(
752 &mut self,
753 apply: F,
754 fut: impl FnOnce(Result<(), E>) -> CFut + Send + 'static,
755 ) -> Result<<S as AsyncSpawner>::JoinHandle<R>, WtmError<C::Error, P::Error, E>>
756 where
757 K: Send + 'static,
758 V: Send + 'static,
759 Fut: Future<Output = Result<(), E>> + Send,
760 F: FnOnce(OneOrMore<Entry<K, V>>) -> Fut + Send + 'static,
761 CFut: Future<Output = R> + Send + 'static,
762 E: std::error::Error + Send,
763 C: 'static,
764 R: Send + 'static,
765 {
766 if self.discarded {
767 return Err(WtmError::transaction(TransactionError::Discard));
768 }
769
770 if self.pending_writes.as_ref().unwrap().is_empty().await {
771 self.discard();
773 return Ok(S::spawn(async move { fut(Ok(())).await }));
774 }
775
776 match self.commit_entries().await {
777 Ok((commit_ts, entries)) => {
778 let orc = self.orc.clone();
779 Ok(S::spawn(async move {
780 match apply(entries).await {
781 Ok(_) => {
782 orc.done_commit(commit_ts);
783 fut(Ok(())).await
784 }
785 Err(e) => {
786 orc.done_commit(commit_ts);
787 fut(Err(e)).await
788 }
789 }
790 }))
791 }
792 Err(e) => match e {
793 TransactionError::Conflict => Err(WtmError::transaction(e)),
794 _ => {
795 self.discard();
796 Err(WtmError::transaction(e))
797 }
798 },
799 }
800 }
801}
802
803impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
804where
805 C: AsyncCm<Key = K>,
806 P: AsyncPwm<Key = K, Value = V>,
807 S: AsyncSpawner,
808{
809 async fn insert_with_in(
810 &mut self,
811 key: K,
812 value: V,
813 ) -> Result<(), TransactionError<C::Error, P::Error>> {
814 let ent = Entry {
815 data: EntryData::Insert { key, value },
816 version: self.read_ts,
817 };
818
819 self.modify(ent).await
820 }
821
822 async fn modify(&mut self, ent: Entry<K, V>) -> Result<(), TransactionError<C::Error, P::Error>> {
823 if self.discarded {
824 return Err(TransactionError::Discard);
825 }
826
827 let pending_writes = self.pending_writes.as_mut().unwrap();
828 pending_writes
829 .validate_entry(&ent)
830 .await
831 .map_err(TransactionError::Pwm)?;
832
833 let cnt = self.count + 1;
834 let size = self.size + pending_writes.estimate_size(&ent);
836 if cnt >= pending_writes.max_batch_entries() || size >= pending_writes.max_batch_size() {
837 return Err(TransactionError::LargeTxn);
838 }
839
840 self.count = cnt;
841 self.size = size;
842
843 if let Some(ref mut conflict_manager) = self.conflict_manager {
846 conflict_manager.mark_conflict(ent.key()).await;
847 }
848
849 let eversion = ent.version;
853 let (ek, ev) = ent.split();
854
855 if let Some((old_key, old_value)) = pending_writes
856 .remove_entry(&ek)
857 .await
858 .map_err(TransactionError::Pwm)?
859 {
860 if old_value.version != eversion {
861 self
862 .duplicate_writes
863 .push(Entry::unsplit(old_key, old_value));
864 }
865 }
866 pending_writes
867 .insert(ek, ev)
868 .await
869 .map_err(TransactionError::Pwm)?;
870
871 Ok(())
872 }
873
874 async fn commit_entries(
875 &mut self,
876 ) -> Result<(u64, OneOrMore<Entry<K, V>>), TransactionError<C::Error, P::Error>> {
877 let _write_lock = self.orc.write_serialize_lock.lock().await;
882
883 let conflict_manager = if self.conflict_manager.is_none() {
884 None
885 } else {
886 mem::take(&mut self.conflict_manager)
887 };
888
889 match self
890 .orc
891 .new_commit_ts(&mut self.done_read, self.read_ts, conflict_manager)
892 .await
893 {
894 CreateCommitTimestampResult::Conflict(conflict_manager) => {
895 self.conflict_manager = conflict_manager;
898 Err(TransactionError::Conflict)
899 }
900 CreateCommitTimestampResult::Timestamp(commit_ts) => {
901 let pending_writes = mem::take(&mut self.pending_writes).unwrap();
902 let duplicate_writes = mem::take(&mut self.duplicate_writes);
903 let mut entries =
904 OneOrMore::with_capacity(pending_writes.len().await + self.duplicate_writes.len());
905
906 let process_entry = |entries: &mut OneOrMore<Entry<K, V>>, mut ent: Entry<K, V>| {
907 ent.version = commit_ts;
908 entries.push(ent);
909 };
910 pending_writes
911 .into_iter()
912 .await
913 .for_each(|(k, v)| process_entry(&mut entries, Entry::unsplit(k, v)));
914 duplicate_writes
915 .into_iter()
916 .for_each(|ent| process_entry(&mut entries, ent));
917
918 assert_ne!(commit_ts, 0);
920
921 Ok((commit_ts, entries))
922 }
923 }
924 }
925}
926
927impl<K, V, C, P, S> AsyncWtm<K, V, C, P, S>
928where
929 S: AsyncSpawner,
930{
931 fn done_read(&mut self) {
932 if !self.done_read {
933 self.done_read = true;
934 self.orc().read_mark.done(self.read_ts).unwrap();
935 }
936 }
937
938 #[inline]
939 fn orc(&self) -> &Oracle<C, S> {
940 &self.orc
941 }
942
943 pub fn discard(&mut self) {
948 if self.discarded {
949 return;
950 }
951 self.discarded = true;
952 self.done_read();
953 }
954
955 #[inline]
957 pub const fn is_discard(&self) -> bool {
958 self.discarded
959 }
960}
961
962impl<K, V, C, P, S> Drop for AsyncWtm<K, V, C, P, S>
963where
964 S: AsyncSpawner,
965{
966 fn drop(&mut self) {
967 if !self.discarded {
968 self.discard();
969 }
970 }
971}
972
#[cfg(test)]
mod tests {
  use std::{collections::BTreeSet, convert::Infallible, marker::PhantomData};

  use super::*;

  /// End-to-end walk through the `AsyncWtm` API with a `HashCm` conflict manager and
  /// an `IndexMapPwm` pending-write manager. It covers marking, insert, plain /
  /// equivalent / blocking lookups, remove, rollback and commit.
  #[async_std::test]
  async fn wtm() {
    let tm = AsyncTm::<String, u64, HashCm<String>, IndexMapPwm<String, u64>, wmark::AsyncStdSpawner>::new("test", 0).await;
    let mut wtm = tm
      .write(Default::default(), Default::default())
      .await
      .unwrap();
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());

    // `marker` mutably borrows `wtm`; that borrow ends after the last `marker` use below.
    let mut marker = wtm.marker().unwrap();

    marker.mark(&"1".to_owned()).await;
    marker.mark_equivalent("3").await;
    marker.mark_conflict(&"2".to_owned()).await;
    marker.mark_conflict_equivalent("4").await;
    wtm.mark_read(&"2".to_owned()).await;
    wtm.mark_conflict(&"1".to_owned()).await;
    wtm.mark_conflict_equivalent("2").await;
    wtm.mark_read_equivalent("3").await;

    wtm.insert("5".into(), 5).await.unwrap();

    // The freshly inserted key is visible through the transaction's own pending writes.
    assert_eq!(wtm.contains_key_equivalent("5").await.unwrap(), Some(true));
    assert_eq!(
      wtm
        .get_equivalent("5")
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    assert_eq!(wtm.contains_key(&"5".to_owned()).await.unwrap(), Some(true));
    assert_eq!(
      wtm
        .get(&"5".to_owned())
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // Keys this transaction never wrote are reported as `None`.
    assert_eq!(wtm.contains_key_equivalent("6").await.unwrap(), None);
    assert_eq!(wtm.get_equivalent("6").await.unwrap(), None);
    assert_eq!(wtm.contains_key_blocking(&"6".to_owned()).unwrap(), None);

    wtm.remove("5".into()).await.unwrap();
    wtm.rollback().await.unwrap();

    // `commit` discards the transaction on every path, which the final assert checks.
    wtm
      .commit::<_, _, _, Infallible>(|_| async { Ok(()) })
      .await
      .unwrap();

    assert!(wtm.is_discard());
  }

  /// Same flow as `wtm`, but with a `BTreePwm` pending-write manager. It interleaves
  /// the async and `_blocking` mark APIs and uses the
  /// `*_equivalent_cm_comparable_pm` lookups.
  #[async_std::test]
  async fn wtm2() {
    let tm =
      AsyncTm::<String, u64, HashCm<String>, BTreePwm<String, u64>, wmark::AsyncStdSpawner>::new(
        "test", 0,
      )
      .await;
    let mut wtm = tm.write((), Default::default()).await.unwrap();
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());
    assert!(wtm.blocking_marker().is_some());
    assert!(wtm.marker_with_pm().is_some());

    let mut marker = wtm.marker().unwrap();

    marker.mark(&"1".to_owned()).await;
    marker.mark_blocking(&"3".to_owned());
    marker.mark_equivalent("3").await;
    marker.mark_equivalent_blocking("3");
    marker.mark_conflict(&"2".to_owned()).await;
    marker.mark_conflict_equivalent_blocking("4");
    marker.mark_conflict_equivalent("4").await;
    wtm.mark_read(&"2".to_owned()).await;
    wtm.mark_read_blocking(&"3".to_owned());
    wtm.mark_read_equivalent_blocking("3");
    wtm.mark_conflict(&"1".to_owned()).await;
    wtm.mark_conflict_equivalent("2").await;
    wtm.mark_conflict_equivalent_blocking("2");
    wtm.mark_read_equivalent("3").await;

    wtm.insert("5".into(), 5).await.unwrap();

    assert_eq!(
      wtm
        .contains_key_equivalent_cm_comparable_pm("5")
        .await
        .unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_equivalent_cm_comparable_pm("5")
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    assert_eq!(wtm.contains_key(&"5".to_owned()).await.unwrap(), Some(true));
    assert_eq!(
      wtm
        .get(&"5".to_owned())
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    // Keys this transaction never wrote are reported as `None`.
    assert_eq!(
      wtm
        .contains_key_equivalent_cm_comparable_pm("6")
        .await
        .unwrap(),
      None
    );
    assert_eq!(
      wtm.get_equivalent_cm_comparable_pm("6").await.unwrap(),
      None
    );
    assert_eq!(wtm.contains_key(&"6".to_owned()).await.unwrap(), None);
    assert_eq!(wtm.get(&"6".to_owned()).await.unwrap(), None);

    wtm.remove("5".into()).await.unwrap();
    wtm.rollback().await.unwrap();

    wtm
      .commit::<_, _, _, Infallible>(|_| async { Ok(()) })
      .await
      .unwrap();

    assert!(wtm.is_discard());
  }

  /// Test-only conflict manager that identifies keys by their address (pointer
  /// identity) rather than by value.
  struct TestCm<K> {
    // Addresses of keys marked as conflict keys.
    conflict_keys: BTreeSet<usize>,
    // Addresses of keys marked as read.
    reads: BTreeSet<usize>,
    _m: PhantomData<K>,
  }

  impl<K> Cm for TestCm<K> {
    type Error = Infallible;

    type Key = K;

    type Options = ();

    fn new(_options: Self::Options) -> Result<Self, Self::Error> {
      Ok(Self {
        conflict_keys: BTreeSet::new(),
        reads: BTreeSet::new(),
        _m: PhantomData,
      })
    }

    fn mark_read(&mut self, key: &Self::Key) {
      self.reads.insert(key as *const K as usize);
    }

    fn mark_conflict(&mut self, key: &Self::Key) {
      self.conflict_keys.insert(key as *const K as usize);
    }

    /// Reports a conflict if any key this manager read was marked as a conflict key
    /// by `other`.
    fn has_conflict(&self, other: &Self) -> bool {
      if self.reads.is_empty() {
        return false;
      }

      for ro in self.reads.iter() {
        if other.conflict_keys.contains(ro) {
          return true;
        }
      }
      false
    }

    fn rollback(&mut self) -> Result<(), Self::Error> {
      self.conflict_keys.clear();
      self.reads.clear();
      Ok(())
    }
  }

  impl<K> CmComparable for TestCm<K> {
    // The borrowed form's address is used (thin pointer via `*const ()`), so unsized `Q` works too.
    fn mark_read_comparable<Q>(&mut self, key: &Q)
    where
      Self::Key: Borrow<Q>,
      Q: Ord + ?Sized,
    {
      self.reads.insert(key as *const Q as *const () as usize);
    }

    fn mark_conflict_comparable<Q>(&mut self, key: &Q)
    where
      Self::Key: Borrow<Q>,
      Q: Ord + ?Sized,
    {
      self
        .conflict_keys
        .insert(key as *const Q as *const () as usize);
    }
  }

  /// Comparable conflict manager (`TestCm`) combined with an equivalence-based
  /// pending-write manager (`IndexMapPwm`). Exercises the
  /// `*_comparable_cm_equivalent_pm` lookups in async and blocking form. The
  /// transaction is never committed; `Drop` discards it.
  #[async_std::test]
  async fn wtm3() {
    let tm = AsyncTm::<
      Arc<u64>,
      u64,
      TestCm<Arc<u64>>,
      IndexMapPwm<Arc<u64>, u64>,
      wmark::AsyncStdSpawner,
    >::new("test", 0)
    .await;
    let mut wtm = tm.write(Default::default(), ()).await.unwrap();
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());

    let mut marker = wtm.marker().unwrap();

    let one = Arc::new(1);
    let two = Arc::new(2);
    let three = Arc::new(3);
    let four = Arc::new(4);
    let five = Arc::new(5);
    marker.mark(&one).await;
    marker.mark_comparable(&three).await;
    marker.mark_conflict(&two).await;
    marker.mark_conflict_comparable(&four).await;
    wtm.mark_read(&two).await;
    wtm.mark_conflict(&one).await;
    wtm.mark_conflict_comparable(&two).await;
    wtm.mark_read_comparable(&three).await;

    wtm.insert(five.clone(), 5).await.unwrap();

    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm(&five)
        .await
        .unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable_cm_equivalent_pm(&five)
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm_blocking(&five)
        .unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable_cm_equivalent_pm_blocking(&five)
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    let six = Arc::new(6);

    // `six` was never written, so every lookup reports `None`.
    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm(&six)
        .await
        .unwrap(),
      None
    );
    assert_eq!(
      wtm.get_comparable_cm_equivalent_pm(&six).await.unwrap(),
      None
    );
    assert_eq!(
      wtm
        .contains_key_comparable_cm_equivalent_pm_blocking(&six)
        .unwrap(),
      None
    );
    assert_eq!(
      wtm.get_comparable_cm_equivalent_pm_blocking(&six).unwrap(),
      None
    );
  }

  /// Comparable conflict manager (`TestCm`) combined with a comparable pending-write
  /// manager (`BTreePwm`). Exercises the comparable marking and lookup APIs in async
  /// and blocking form. The transaction is never committed; `Drop` discards it.
  #[async_std::test]
  async fn wtm4() {
    let tm = AsyncTm::<
      Arc<u64>,
      u64,
      TestCm<Arc<u64>>,
      BTreePwm<Arc<u64>, u64>,
      wmark::AsyncStdSpawner,
    >::new("test", 0)
    .await;
    let mut wtm = tm.write((), ()).await.unwrap();
    assert!(!wtm.is_discard());
    assert!(wtm.pwm().is_some());
    assert!(wtm.cm().is_some());

    let mut marker = wtm.marker().unwrap();

    let one = Arc::new(1);
    let two = Arc::new(2);
    let three = Arc::new(3);
    let four = Arc::new(4);
    let five = Arc::new(5);
    marker.mark(&one).await;
    marker.mark_blocking(&one);
    marker.mark_comparable(&three).await;
    marker.mark_comparable_blocking(&three);
    marker.mark_conflict(&two).await;
    marker.mark_conflict_blocking(&two);
    marker.mark_conflict_comparable(&four).await;
    marker.mark_conflict_comparable_blocking(&four);
    wtm.mark_read(&two).await;
    wtm.mark_read_blocking(&two);
    wtm.mark_read_comparable_blocking(&two);
    wtm.mark_conflict(&one).await;
    wtm.mark_conflict_blocking(&one);
    wtm.mark_conflict_comparable(&two).await;
    wtm.mark_conflict_comparable_blocking(&two);
    wtm.mark_read_comparable(&three).await;

    wtm.insert(five.clone(), 5).await.unwrap();

    assert_eq!(
      wtm.contains_key_comparable(&five).await.unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable(&five)
        .await
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    assert_eq!(
      wtm.contains_key_comparable_blocking(&five).unwrap(),
      Some(true)
    );
    assert_eq!(
      wtm
        .get_comparable_blocking(&five)
        .unwrap()
        .unwrap()
        .value()
        .unwrap(),
      &5
    );

    let six = Arc::new(6);

    // `six` was never written, so every lookup reports `None`.
    assert_eq!(wtm.contains_key_comparable(&six).await.unwrap(), None);
    assert_eq!(wtm.get_comparable(&six).await.unwrap(), None);
    assert_eq!(wtm.contains_key_comparable_blocking(&six).unwrap(), None);
    assert_eq!(wtm.get_comparable_blocking(&six).unwrap(), None);
  }
}