1use crossbeam_epoch::Guard;
56use std::borrow::Borrow;
57use std::marker::PhantomData;
58use std::mem::{align_of_val, size_of};
59use std::ops::Deref;
60use std::option::Option::Some;
61use std::ptr;
62use std::rc::Rc;
63use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
64
/// Initial status of a multi-word CAS: its outcome has not been decided yet.
const STATUS_PREPARE: u8 = 0;
/// Final status of a multi-word CAS whose targets receive their new values.
const STATUS_COMPLETED: u8 = 1;
/// Final status of a multi-word CAS whose targets are rolled back to their original values.
const STATUS_FAILED: u8 = 2;
68
/// Pointer to a heap-allocated `T` which can take part in a multi-word CAS.
///
/// The address of the value is kept as a `u64` inside an `AtomicU64`. While an
/// `MwCas` is in flight the same word may temporarily hold a flagged descriptor
/// address, which is why every read goes through `read_val`.
#[derive(Debug)]
#[repr(transparent)]
pub struct HeapPointer<T> {
    /// Address of the boxed value, or a flagged `MwCasInner` address during an `MwCas`.
    ptr: AtomicU64,
    /// Records the pointee type; no `T` is stored inline.
    phantom: PhantomData<T>,
}
80
81impl<T> HeapPointer<T> {
82 #[inline]
84 pub fn new(val: T) -> Self {
85 let val_address = Box::into_raw(Box::new(val)) as u64;
86 HeapPointer {
87 ptr: AtomicU64::new(val_address),
88 phantom: PhantomData {},
89 }
90 }
91
92 #[inline]
94 pub fn read<'g>(&'g self, guard: &'g Guard) -> &'g T {
95 unsafe { &*self.read_ptr(guard) }
96 }
97
98 #[inline]
100 pub fn read_mut<'g>(&'g mut self, guard: &'g Guard) -> &'g mut T {
101 unsafe { &mut *self.read_ptr(guard) }
102 }
103
104 #[inline]
105 fn read_ptr(&self, guard: &Guard) -> *mut T {
106 read_val(&self.ptr, guard) as *mut u8 as *mut T
107 }
108}
109
110#[inline]
111fn read_val(ptr: &AtomicU64, guard: &Guard) -> u64 {
112 loop {
113 let cur_val = ptr.load(Ordering::Acquire);
114 if let Some(mwcas_ptr) = MwCasPointer::from_poisoned(cur_val, guard) {
115 mwcas_ptr.exec_internal(guard);
116 } else {
117 return cur_val;
118 }
119 }
120}
121
122impl<T: Clone> Clone for HeapPointer<T> {
123 fn clone(&self) -> Self {
124 let val = self.read(&crossbeam_epoch::pin()).clone();
125 HeapPointer::new(val)
126 }
127}
128
129impl<T> Drop for HeapPointer<T> {
130 fn drop(&mut self) {
131 unsafe {
132 drop(Box::from_raw(
133 self.read_ptr(crossbeam_epoch::unprotected()),
136 ));
137 }
138 }
139}
140
// The pointer owns its heap value, so sending the pointer sends the `T`; sharing the
// pointer hands out `&T` through `read`.
// NOTE(review): `MwCas::compare_exchange` installs a `T` through a shared reference and the
// replaced value is released later via `Guard::defer_unchecked`, presumably on whichever
// thread runs the deferred closure — confirm whether `Sync` should also require `T: Send`.
unsafe impl<T: Send> Send for HeapPointer<T> {}
unsafe impl<T: Sync> Sync for HeapPointer<T> {}
143
/// Plain `u64` value which can take part in a multi-word CAS.
///
/// Unlike `HeapPointer`, the stored word is the value itself, not an address of a
/// heap allocation.
#[derive(Debug)]
#[repr(transparent)]
pub struct U64Pointer {
    /// Current value, or a flagged `MwCasInner` address during an `MwCas`.
    val: AtomicU64,
}
154
155impl U64Pointer {
156 #[inline]
158 pub fn new(val: u64) -> Self {
159 Self {
160 val: AtomicU64::new(val),
161 }
162 }
163
164 #[inline]
166 pub fn read(&self, guard: &Guard) -> u64 {
167 read_val(&self.val, guard)
168 }
169}
170
171impl Clone for U64Pointer {
172 fn clone(&self) -> Self {
173 U64Pointer::new(self.read(&crossbeam_epoch::pin()))
174 }
175}
176
// NOTE(review): the only field of `U64Pointer` is an `AtomicU64`, which is already
// `Send + Sync`, so these manual impls look redundant — confirm before removing them.
unsafe impl Send for U64Pointer {}
unsafe impl Sync for U64Pointer {}
179
/// Builder and executor of one multi-word compare-and-swap.
///
/// Operations are registered with `compare_exchange` / `compare_exchange_u64` and applied
/// together by `exec`.
///
/// NOTE(review): only this struct is gated on `target_arch = "x86_64"`; the `impl` blocks
/// that follow are not, so other targets would presumably fail to compile — confirm intent.
#[cfg(target_arch = "x86_64")]
pub struct MwCas<'g> {
    /// Heap-allocated descriptor; its address is what gets written (flagged) into targets.
    inner: Box<MwCasInner<'g>>,
    /// Outcome recorded by `exec`; read in `Drop` to decide which values to free.
    success: AtomicBool,
    /// `Rc` marker: makes `MwCas` neither `Send` nor `Sync` automatically.
    phantom: PhantomData<Rc<u8>>,
}
196
197impl<'g> Default for MwCas<'g> {
198 fn default() -> Self {
199 Self::new()
200 }
201}
202
203impl<'g> MwCas<'g> {
204 #[inline]
206 pub fn new() -> Self {
207 MwCas {
208 inner: Box::new(MwCasInner {
209 status: AtomicU8::new(STATUS_PREPARE),
210 cas_ops: Vec::with_capacity(2),
211 }),
212 success: AtomicBool::new(false),
213 phantom: PhantomData {},
214 }
215 }
216
217 #[inline]
225 pub fn compare_exchange<T>(&mut self, target: &'g HeapPointer<T>, orig_val: &'g T, new_val: T) {
226 #[cfg(debug_assertions)]
227 {
228 for cas in &self.inner.cas_ops {
229 if ptr::eq(cas.target_ptr, &target.ptr as *const AtomicU64) {
230 panic!(
231 "MwCAS cannot compare-and-swap the same {} several times in one execution.
232 Remove duplicate target reference passed to 'add/with' method.
233 This can happen if you use unsafe code which skips borrowing rules
234 checker of Rust: target parameter declared as mutable reference and
235 cannot be added twice to MwCAS by 'safe' code.",
236 std::any::type_name::<HeapPointer<T>>()
237 )
238 }
239 }
240 }
241 let orig_val_ptr = orig_val as *const T as *mut T;
242 let orig_val_addr = orig_val_ptr as u64;
243 let new_val_ptr = Box::into_raw(Box::new(new_val));
244 let new_val_addr = new_val_ptr as u64;
245 let drop_fn: Box<dyn Fn(bool) + 'g> = Box::new(move |success| {
246 if success {
247 drop(unsafe { Box::from_raw(orig_val_ptr) })
248 } else {
249 drop(unsafe { Box::from_raw(new_val_ptr) })
250 }
251 });
252 self.inner.cas_ops.push(Cas::new(
253 &target.ptr as *const AtomicU64 as *mut AtomicU64,
254 orig_val_addr,
255 new_val_addr,
256 drop_fn,
257 ));
258 }
259
260 #[inline]
266 pub fn compare_exchange_u64(&mut self, target: &'g U64Pointer, orig_val: u64, new_val: u64) {
267 #[cfg(debug_assertions)]
268 {
269 for cas in &self.inner.cas_ops {
270 if ptr::eq(cas.target_ptr, &target.val as *const AtomicU64) {
271 panic!(
272 "MwCAS cannot compare-and-swap the same {} several times in one execution.
273 Remove duplicate target reference passed to 'add/with' method.
274 This can happen if you use unsafe code which skips borrowing rules
275 checker of Rust: target parameter declared as mutable reference and
276 cannot be added twice to MwCAS by 'safe' code.",
277 std::any::type_name::<U64Pointer>()
278 )
279 }
280 }
281 }
282
283 let drop_fn: Box<dyn Fn(bool) + 'g> = Box::new(move |_| {});
284 self.inner.cas_ops.push(Cas::new(
285 &target.val as *const AtomicU64 as *mut AtomicU64,
286 *orig_val.borrow(),
287 *new_val.borrow(),
288 drop_fn,
289 ));
290 }
291
292 #[inline]
297 pub fn exec(self, guard: &Guard) -> bool {
298 let successful_cas = self.inner.exec_internal(guard);
299 self.success.store(successful_cas, Ordering::Release);
302 unsafe {
303 guard.defer_unchecked(move || {
304 drop(self);
305 });
306 }
307 successful_cas
308 }
309}
310
311impl<'g> Drop for MwCas<'g> {
312 fn drop(&mut self) {
313 for cas in &self.inner.cas_ops {
317 (cas.drop_fn)(self.success.load(Ordering::Acquire));
318 }
319 }
320}
321
/// Descriptor of a multi-word CAS: the part whose address is published in target words.
struct MwCasInner<'g> {
    /// One of `STATUS_PREPARE`, `STATUS_COMPLETED` or `STATUS_FAILED`.
    status: AtomicU8,
    /// Single-word operations which must be applied together.
    cas_ops: Vec<Cas<'g>>,
}
328
329impl<'g> MwCasInner<'g> {
330 #[inline(always)]
331 fn status(&self) -> u8 {
332 self.status.load(Ordering::Acquire)
333 }
334
335 #[inline]
336 fn exec_internal(&self, guard: &Guard) -> bool {
337 let phase_one_status = self.phase_one(guard);
338 let phase_two_status = self.update_status(phase_one_status);
339 match phase_two_status {
340 Ok(status) => self.phase_two(status),
341 Err(cur_status) => {
342 self.phase_two(cur_status);
343 }
344 }
345 phase_two_status.map_or_else(|status| status, |status| status) == STATUS_COMPLETED
346 }
347
348 fn phase_one(&self, guard: &Guard) -> u8 {
350 for cas in &self.cas_ops {
351 loop {
352 match cas.prepare(self, guard) {
353 CasPrepareResult::Conflict(mwcas_ptr) => {
354 if &mwcas_ptr != self.deref() {
355 mwcas_ptr.exec_internal(guard);
357 } else {
358 break;
361 }
362 }
363 CasPrepareResult::Success => break,
364 CasPrepareResult::Failed => return STATUS_FAILED,
365 }
366 }
367 }
368 STATUS_COMPLETED
369 }
370
371 #[inline]
372 fn update_status(&self, new_status: u8) -> Result<u8, u8> {
373 if let Err(prev_status) = self.status.compare_exchange(
374 STATUS_PREPARE,
375 new_status,
376 Ordering::AcqRel,
377 Ordering::Acquire,
378 ) {
379 Err(prev_status)
394 } else {
395 Ok(new_status)
396 }
397 }
398
399 fn phase_two(&self, mwcas_status: u8) {
401 let mwcas_ptr = MwCasPointer::from(self.deref());
404 for cas in &self.cas_ops {
405 cas.complete(mwcas_status, &mwcas_ptr);
406 }
407 }
408}
409
/// Reference to a multi-word CAS descriptor which can be converted to and from the
/// flagged ("poisoned") `u64` form stored in target words.
#[derive(Copy, Clone)]
#[repr(transparent)]
struct MwCasPointer<'g> {
    /// Descriptor this pointer refers to.
    mwcas: &'g MwCasInner<'g>,
}
415
/// Lets a `MwCasPointer` be used wherever a `&MwCasInner` is expected.
impl<'g> Deref for MwCasPointer<'g> {
    type Target = MwCasInner<'g>;

    /// Returns the referenced descriptor.
    fn deref(&self) -> &Self::Target {
        self.mwcas
    }
}
423
424impl<'g> MwCasPointer<'g> {
425 const MWCAS_FLAG: u64 = 0x4000_0000_0000_0000;
426
427 #[inline]
429 fn from_poisoned(poisoned_addr: u64, _: &'g Guard) -> Option<MwCasPointer<'g>> {
430 let valid_addr = poisoned_addr & !Self::MWCAS_FLAG;
431 if poisoned_addr != valid_addr {
432 Option::Some(MwCasPointer {
433 mwcas: unsafe { &*(valid_addr as *const u64 as *const MwCasInner) },
436 })
437 } else {
438 Option::None
441 }
442 }
443
444 #[inline(always)]
448 fn poisoned(&self) -> u64 {
449 let addr = self.mwcas as *const MwCasInner as *const u64 as u64;
450 addr | Self::MWCAS_FLAG
451 }
452}
453
/// Wraps a descriptor reference without changing it.
impl<'g> From<&'g MwCasInner<'g>> for MwCasPointer<'g> {
    fn from(mwcas: &'g MwCasInner) -> Self {
        MwCasPointer { mwcas }
    }
}
459
// Equality is descriptor identity (see the `PartialEq` impl), so it is a full equivalence.
impl<'g> Eq for MwCasPointer<'g> {}
461
/// Two pointers are equal when they refer to the same descriptor (address comparison).
impl<'g> PartialEq for MwCasPointer<'g> {
    fn eq(&self, other: &MwCasPointer) -> bool {
        ptr::eq(self.mwcas, other.mwcas)
    }
}
467
/// A pointer equals a descriptor when it refers to exactly that descriptor
/// (address comparison, the descriptor contents are not inspected).
impl<'g> PartialEq<MwCasInner<'g>> for MwCasPointer<'g> {
    fn eq(&self, other: &MwCasInner) -> bool {
        ptr::eq(self.mwcas, other)
    }
}
473
/// One single-word compare-and-swap which is part of a multi-word CAS.
struct Cas<'g> {
    /// Word to update.
    target_ptr: *mut AtomicU64,
    /// Value the word is expected to hold before the operation.
    orig_val: u64,
    /// Value the word receives when the whole operation completes successfully.
    new_val: u64,
    /// Cleanup callback; invoked from `MwCas::drop` with the outcome of the operation.
    drop_fn: Box<dyn Fn(bool) + 'g>,
}
482
// `Cas` holds a raw pointer and a boxed closure without `Send`/`Sync` bounds, so it gets
// neither marker automatically; these impls assert both manually.
// NOTE(review): presumably needed because helping threads reach `Cas` through a shared
// descriptor — confirm the closure captured in `drop_fn` is sound to run on any thread.
unsafe impl<'g> Send for Cas<'g> {}
unsafe impl<'g> Sync for Cas<'g> {}
485
/// Outcome of `Cas::prepare`, i.e. of the attempt to install a descriptor into a target.
#[derive(PartialEq, Copy, Clone)]
enum CasPrepareResult<'g> {
    /// The target held the expected value and now holds the descriptor.
    Success,
    /// The target holds a descriptor (possibly the caller's own one).
    Conflict(MwCasPointer<'g>),
    /// The target holds a plain value different from the expected one.
    Failed,
}
492
493impl<'g> Cas<'g> {
494 fn new(
495 pointer: *mut AtomicU64,
496 orig_val: u64,
497 new_val: u64,
498 drop_fn: Box<dyn Fn(bool) + 'g>,
499 ) -> Self {
500 let max_addr: u64 = 0xDFFF_FFFF_FFFF_FFFF;
501 assert!(!pointer.is_null(), "Pointer must be non null");
502 debug_assert!(
503 (pointer as u64) < max_addr,
504 "Pointer must point to memory in range [0x{:X}, 0x{:X}], because MwCas \
505 use highest 3 bits of address for internal use. Actual address to which pointer \
506 points was 0x{:x}",
507 0,
508 max_addr,
509 pointer as u64
510 );
511 unsafe {
512 let align = align_of_val(&*pointer);
513 debug_assert_eq!(
514 align,
515 size_of::<u64>(),
516 "Pointer must be align on {} bytes, but pointer was aligned on {}",
517 size_of::<u64>(),
518 align
519 )
520 }
521 debug_assert!(
522 orig_val < MwCasPointer::MWCAS_FLAG,
523 "MwCas can be applied only for original values < {}. Actual value was {}",
524 MwCasPointer::MWCAS_FLAG,
525 orig_val
526 );
527 debug_assert!(
528 new_val < MwCasPointer::MWCAS_FLAG,
529 "MwCas can be applied only for new values < {}. Actual value was {}",
530 MwCasPointer::MWCAS_FLAG,
531 new_val
532 );
533
534 Cas {
535 target_ptr: pointer,
536 orig_val,
537 new_val,
538 drop_fn,
539 }
540 }
541
542 fn prepare<'a>(&self, mwcas: &MwCasInner, guard: &'a Guard) -> CasPrepareResult<'a> {
544 let new_val = MwCasPointer::from(mwcas.deref()).poisoned();
545 let prev = unsafe {
546 (*self.target_ptr)
547 .compare_exchange(self.orig_val, new_val, Ordering::AcqRel, Ordering::Acquire)
548 .map_or_else(|v| v, |v| v)
549 };
550
551 if prev == self.orig_val {
552 CasPrepareResult::Success
553 } else if let Some(mwcas_ptr) = MwCasPointer::from_poisoned(prev, guard) {
554 CasPrepareResult::Conflict(mwcas_ptr)
556 } else {
557 CasPrepareResult::Failed
558 }
559 }
560
561 fn complete(&self, status: u8, mwcas: &MwCasPointer) {
564 let new_val = match status {
565 STATUS_COMPLETED => self.new_val,
566 STATUS_FAILED => self.orig_val,
567 _ => panic!("CAS cannot be completed for not prepared MWCAS"),
568 };
569 let expected_val = mwcas.poisoned();
570 unsafe {
571 let _ = (*self.target_ptr).compare_exchange(
572 expected_val,
573 new_val,
574 Ordering::AcqRel,
575 Ordering::Acquire,
576 );
577 };
578 }
583}
584
// Unit tests: end-to-end `MwCas` scenarios, descriptor pointer tagging and single `Cas` steps.
#[cfg(test)]
mod tests {
    use crate::Cas;
    use std::sync::atomic::Ordering;

    // Scenarios driven through the public `MwCas` API, some of them stepping through the
    // internal phases by hand to emulate interleavings with other operations.
    mod simple {
        use crate::{HeapPointer, MwCas, U64Pointer, STATUS_COMPLETED, STATUS_FAILED};
        use std::ops::Deref;
        use std::ptr::NonNull;
        use std::sync::atomic::Ordering;

        // Two heap pointers and one `u64` are swapped by a single `MwCas`.
        #[test]
        fn test_mwcas_add_ptr() {
            let guard = crossbeam_epoch::pin();
            let val1 = HeapPointer::new(5);
            let val2 = HeapPointer::new(10);
            let val3 = U64Pointer::new(15);
            let new_val1 = 15;
            let new_val2 = 20;
            let new_val3 = 25;
            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let orig_val3 = val3.read(&guard);

            let mut mw_cas = MwCas::new();
            mw_cas.compare_exchange(&val1, orig_val1, new_val1);
            mw_cas.compare_exchange(&val2, orig_val2, new_val2);
            mw_cas.compare_exchange_u64(&val3, orig_val3, new_val3);
            assert!(mw_cas.exec(&guard));
            assert_eq!(*val1.read(&guard), new_val1);
            assert_eq!(*val2.read(&guard), new_val2);
            assert_eq!(val3.read(&guard), new_val3);
        }

        // Registering the same `HeapPointer` twice is expected to panic.
        #[test]
        #[should_panic]
        fn test_add_same_ptr() {
            let guard = crossbeam_epoch::pin();
            let val1 = HeapPointer::new(5);
            let new_val1 = 15;
            let orig_val1 = val1.read(&guard);

            let mut mw_cas = MwCas::new();
            mw_cas.compare_exchange(&val1, orig_val1, new_val1);
            mw_cas.compare_exchange(&val1, orig_val1, new_val1);
        }

        // Registering the same `U64Pointer` twice is expected to panic.
        #[test]
        #[should_panic]
        fn test_add_same_u64_val() {
            let guard = crossbeam_epoch::pin();
            let val1 = U64Pointer::new(5);
            let new_val1 = 15;
            let orig_val1 = val1.read(&guard);

            let mut mw_cas = MwCas::new();
            mw_cas.compare_exchange_u64(&val1, orig_val1, new_val1);
            mw_cas.compare_exchange_u64(&val1, orig_val1, new_val1);
        }

        // `exec` must succeed when some or all targets already carry the descriptor.
        #[test]
        fn test_prepared_cas_completion_assist() {
            let val1 = HeapPointer::new(1);
            let val2 = HeapPointer::new(2);
            let guard = crossbeam_epoch::pin();
            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let mut mwcas = MwCas::new();
            mwcas.compare_exchange(&val1, orig_val1, 2);
            mwcas.compare_exchange(&val2, orig_val2, 3);

            // Install the descriptor into both targets before `exec` runs.
            let cas1 = mwcas.inner.cas_ops.first().unwrap();
            let cas2 = mwcas.inner.cas_ops.get(1).unwrap();
            cas1.prepare(mwcas.inner.deref(), &guard);
            cas2.prepare(mwcas.inner.deref(), &guard);

            assert!(mwcas.exec(&guard));
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);

            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let mut mwcas = MwCas::new();
            mwcas.compare_exchange(&val1, orig_val1, 3);
            mwcas.compare_exchange(&val2, orig_val2, 4);
            // This time only the last operation is prepared in advance.
            let cas1 = mwcas.inner.cas_ops.last().unwrap();
            cas1.prepare(mwcas.inner.deref(), &guard);

            assert!(mwcas.exec(&guard));
            assert_eq!(*val1.read(&guard), 3);
            assert_eq!(*val2.read(&guard), 4);
        }

        // A partially prepared `MwCas` does not disturb an `MwCas` over an unrelated target
        // and can still be executed afterwards.
        #[test]
        fn test_cas_completion_assist_on_subset_of_references() {
            let val1 = HeapPointer::new(1);
            let val2 = HeapPointer::new(2);
            let val3 = HeapPointer::new(3);
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let orig_val3 = val3.read(&guard);
            mwcas1.compare_exchange(&val1, orig_val1, 2);
            mwcas1.compare_exchange(&val2, orig_val2, 3);
            mwcas2.compare_exchange(&val3, orig_val3, 4);

            // Prepare only the first operation of `mwcas1`.
            let cas1 = mwcas1.inner.cas_ops.first().unwrap();
            cas1.prepare(mwcas1.inner.deref(), &guard);

            assert!(mwcas2.exec(&guard));
            assert_eq!(*val3.read(&guard), 4);
            assert!(mwcas1.exec(&guard));
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);
        }

        // An `MwCas` whose status is already `STATUS_FAILED` stays failed even when another
        // `MwCas` runs into its descriptor and helps it.
        #[test]
        fn test_assist_not_change_cas_result() {
            let mut val1 = HeapPointer::new(1);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let mut val2 = HeapPointer::new(2);
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            let val1_ref = val1.read(&guard);
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), val1_ref, 2);
                // `val2` does not hold `val1_ref`, so this operation cannot be prepared.
                mwcas1.compare_exchange(&*value2.as_ptr(), val1_ref, 2);
            }
            assert_eq!(mwcas1.inner.phase_one(&guard), STATUS_FAILED);
            mwcas1.inner.update_status(STATUS_FAILED).unwrap();

            unsafe {
                mwcas2.compare_exchange(&*value1.as_ptr(), val1_ref, 2);
            }
            assert!(mwcas2.exec(&guard));
            assert_eq!(mwcas1.inner.status(), STATUS_FAILED);
            assert!(!mwcas1.exec(&guard));

            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 2);
        }

        // Currently ignored; the reason is not visible in this file.
        // NOTE(review): `val2` is expected to become 4 although `mwcas2` only targets
        // `val3` — confirm the expectations before re-enabling.
        #[test]
        #[ignore]
        fn test_mwcas_race_in_phase_one_before_status_update() {
            let mut val1 = HeapPointer::new(1);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let mut val2 = HeapPointer::new(2);
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            let mut val3 = HeapPointer::new(3);
            let value3 = unsafe { NonNull::new_unchecked(&mut val3) };
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), val1.read(&guard), 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), val2.read(&guard), 3);
                mwcas2.compare_exchange(&*value3.as_ptr(), val3.read(&guard), 4);
            }

            // Run only phase one of `mwcas1`: descriptors installed, status undecided.
            let status = mwcas1.inner.phase_one(&guard);
            assert_eq!(status, STATUS_COMPLETED);
            assert!(mwcas2.exec(&guard));
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 4);
            assert_eq!(*val3.read(&guard), 4);
            mwcas1.inner.phase_two(STATUS_COMPLETED);
            assert_eq!(*val1.read(&guard), 1);
            assert_eq!(*val2.read(&guard), 4);
            assert_eq!(*val3.read(&guard), 4);
            mwcas1.success.store(true, Ordering::Release);
        }

        // Currently ignored; the reason is not visible in this file.
        // NOTE(review): the expected values are references to literals (`&1`, `&2`, `&3`),
        // whose addresses presumably never match the boxed values — confirm.
        #[test]
        #[ignore]
        fn test_mwcas_race_in_phase_one_after_status_update() {
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();

            let mut val1 = HeapPointer::new(1);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let mut val2 = HeapPointer::new(2);
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), &1, 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), &2, 3);
                mwcas2.compare_exchange(&*value2.as_ptr(), &3, 4);
            }

            let guard = crossbeam_epoch::pin();
            // Phase one and the status update of `mwcas1` are done by hand; phase two is
            // delayed until `mwcas2` has been executed.
            let status = mwcas1.inner.phase_one(&guard);
            mwcas1.inner.update_status(status).unwrap();
            mwcas2.exec(&guard);
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);
            mwcas1.inner.phase_two(status);
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);
        }

        // Currently ignored; the reason is not visible in this file.
        // `mwcas1` must fail when `mwcas2` changed a shared target first.
        #[test]
        #[ignore]
        fn test_mwcas_fail_when_concurrent_mwcas_won_race() {
            let mut val1 = HeapPointer::new(1);
            let mut val2 = HeapPointer::new(2);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), val1.read(&guard), 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), val2.read(&guard), 3);
                mwcas2.compare_exchange(&*value2.as_ptr(), val2.read(&guard), 4);
            }

            // `mwcas1` claims only its first target ...
            let cas = mwcas1.inner.cas_ops.first().unwrap();
            cas.prepare(mwcas1.inner.deref(), &guard);

            // ... then `mwcas2` updates the second one.
            mwcas2.exec(&guard);
            assert_eq!(*val2.read(&guard), 4);

            assert!(!mwcas1.exec(&guard));
        }

        // Currently ignored; the reason is not visible in this file.
        // NOTE(review): expected values are references to literals, as in
        // `test_mwcas_race_in_phase_one_after_status_update` — confirm before re-enabling.
        #[test]
        #[ignore]
        fn test_mwcas_linearization() {
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();

            let mut val1 = HeapPointer::new(1);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let mut val2 = HeapPointer::new(2);
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), &1, 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), &2, 3);
                mwcas2.compare_exchange(&*value1.as_ptr(), &2, 1);
                mwcas2.compare_exchange(&*value2.as_ptr(), &3, 2);
            }

            let guard = crossbeam_epoch::pin();
            mwcas1.inner.phase_one(&guard);

            assert!(mwcas2.exec(&guard));
            assert!(mwcas1.exec(&guard));

            assert_eq!(*val1.read(&guard), 1);
            assert_eq!(*val2.read(&guard), 2);
        }

        // Reading a pointer which holds a descriptor completes the pending `MwCas`.
        #[test]
        fn test_mwcas_completion_on_pointer_read() {
            let mut val = HeapPointer::new(1);
            let value = unsafe { NonNull::new_unchecked(&mut val) };
            let guard = crossbeam_epoch::pin();
            let mut mwcas = MwCas::new();
            unsafe {
                mwcas.compare_exchange(&*value.as_ptr(), val.read(&guard), 2);
            }

            assert_eq!(*val.read(&guard), 1);
            assert_eq!(mwcas.inner.phase_one(&guard), STATUS_COMPLETED);
            assert_eq!(*val.read(&guard), 2);
            // `exec` was bypassed, so record the outcome by hand: `Drop` then frees the
            // replaced value instead of the installed one.
            mwcas.success.store(true, Ordering::Release);
        }
    }

    // Test-only helper for inspecting the raw word of a target.
    impl<'g> Cas<'g> {
        // Returns the word currently stored in the target, descriptor flag included.
        #[inline]
        fn current_value(&self) -> u64 {
            unsafe { (*self.target_ptr).load(Ordering::Acquire) }
        }
    }

    // Conversions between `MwCasPointer` and its flagged `u64` form.
    mod mwcas_pointer_test {
        use crate::{MwCas, MwCasPointer};
        use std::ops::Deref;
        use std::ptr;

        // A pointer built from a descriptor refers to it and survives the flagged form.
        #[test]
        fn create_pointer_from_structure() {
            let mw_cas = MwCas::new();
            let ptr = MwCasPointer::from(mw_cas.inner.deref());
            assert!(ptr::eq(ptr.deref(), mw_cas.inner.deref()));
            let guard = crossbeam_epoch::pin();
            assert!(matches!(
                MwCasPointer::from_poisoned(ptr.poisoned(), &guard),
                Some(_)
            ));
        }

        // Parsing a flagged address yields a pointer to the same descriptor.
        #[test]
        fn create_pointer_from_address() {
            let guard = crossbeam_epoch::pin();
            let mw_cas = MwCas::new();
            let parsed_ptr = MwCasPointer::from_poisoned(
                MwCasPointer::from(mw_cas.inner.deref()).poisoned(),
                &guard,
            );
            assert!(parsed_ptr.is_some());
            let ptr = parsed_ptr.unwrap();
            assert!(ptr::eq(ptr.deref(), mw_cas.inner.deref()));

            assert_eq!(
                ptr.poisoned(),
                MwCasPointer::from(mw_cas.inner.deref()).poisoned()
            );
        }

        // An address without the flag bit is not treated as a descriptor.
        #[test]
        fn create_pointer_from_invalid_address() {
            let mw_cas = MwCas::new();
            let addr = &mw_cas as *const MwCas as u64;
            let guard = crossbeam_epoch::pin();
            let parsed_ptr = MwCasPointer::from_poisoned(addr, &guard);
            assert!(parsed_ptr.is_none());
        }
    }

    // Single `Cas` steps: `prepare` followed by `complete`.
    mod cas_tests {
        use crate::{
            CasPrepareResult, HeapPointer, MwCas, MwCasPointer, STATUS_COMPLETED, STATUS_FAILED,
        };
        use std::ops::Deref;
        use std::sync::atomic::Ordering;

        // `prepare` installs the descriptor, `complete(STATUS_COMPLETED)` the new value.
        #[test]
        fn test_cas_success_completion() {
            let guard = crossbeam_epoch::pin();
            let cur_val = HeapPointer::new(1);
            let mut mwcas = MwCas::new();
            let orig_val = cur_val.read(&guard);
            mwcas.compare_exchange(&cur_val, orig_val, 2);
            let cas = mwcas.inner.cas_ops.first().unwrap();

            assert!(matches!(
                cas.prepare(mwcas.inner.deref(), &guard),
                CasPrepareResult::Success
            ));

            let mwcas_ptr = MwCasPointer::from(mwcas.inner.deref());
            assert!(
                matches!(MwCasPointer::from_poisoned(cas.current_value(), &guard),
                    Some(ptr) if mwcas_ptr == ptr)
            );

            cas.complete(STATUS_COMPLETED, &mwcas_ptr);
            // Record the outcome by hand so that `Drop` frees the replaced value.
            mwcas.success.store(true, Ordering::Release);
            assert_eq!(*cur_val.read(&guard), 2);
        }

        // `complete(STATUS_FAILED)` restores the original value.
        #[test]
        fn test_complete_cas_with_failure() {
            let guard = crossbeam_epoch::pin();
            let value = HeapPointer::new(1);
            let mut mwcas = MwCas::new();
            let orig_val = value.read(&guard);
            mwcas.compare_exchange(&value, orig_val, 2);
            let cas = mwcas.inner.cas_ops.first().unwrap();

            assert!(matches!(
                cas.prepare(mwcas.inner.deref(), &guard),
                CasPrepareResult::Success
            ));
            let mwcas_ptr = MwCasPointer::from(mwcas.inner.deref());
            assert!(
                matches!(MwCasPointer::from_poisoned(cas.current_value(), &guard),
                    Some(ptr) if mwcas_ptr == ptr)
            );

            cas.complete(STATUS_FAILED, &mwcas_ptr);
            mwcas.success.store(false, Ordering::Release);
            assert_eq!(*value.read(&guard), 1);
        }

        // Preparing the same operation twice reports a conflict with its own descriptor.
        #[test]
        fn test_same_cas_conflict() {
            let guard = crossbeam_epoch::pin();
            let val1 = HeapPointer::new(1);
            let mut mwcas = MwCas::new();
            let orig_val = val1.read(&guard);
            mwcas.compare_exchange(&val1, orig_val, 2);
            let cas = mwcas.inner.cas_ops.first().unwrap();
            let mwcas_ptr = MwCasPointer::from(mwcas.inner.deref());
            assert!(matches!(
                cas.prepare(mwcas.inner.deref(), &guard),
                CasPrepareResult::Success
            ));
            assert!(matches!(
                cas.prepare(mwcas.inner.deref(), &guard),
                CasPrepareResult::Conflict(ptr) if ptr == mwcas_ptr
            ));
            cas.complete(STATUS_COMPLETED, &mwcas_ptr);
            mwcas.success.store(true, Ordering::Release);
        }

        // `complete` panics for a status other than `STATUS_COMPLETED` / `STATUS_FAILED`.
        #[test]
        #[should_panic]
        fn test_cas_completion_with_invalid_status() {
            let mut value = HeapPointer::new(1);
            let mut mwcas = MwCas::new();
            mwcas.compare_exchange(&value, &1, 2);
            let cas = mwcas.inner.cas_ops.first().unwrap();
            cas.complete(u8::MAX, &MwCasPointer::from(mwcas.inner.deref()));
        }
    }
}