mwcas/
lib.rs

1//! Multi-word CAS.
2//!
//! The Rust standard library provides atomic types in the [`atomic`](std::sync::atomic) module.
//! Atomic types provide a lock-free way to atomically update the value of a single pointer. Many
//! concurrent data structures, however, require atomic updates of more than one pointer at a
//! time, for example [`BzTree`](http://www.vldb.org/pvldb/vol11/p553-arulraj.pdf).
7//!
//! This crate provides a concurrency primitive called [`MwCas`] which can atomically update
//! several pointers at a time. It is based on the paper
//! [`Easy Lock-Free Indexing in Non-Volatile Memory`](http://justinlevandoski.org/papers/ICDE18_mwcas.pdf).
//! The current implementation doesn't provide the features for non-volatile (persistent)
//! memory and only covers multi-word CAS on DRAM.
13//!
14//! # Platform support
//! Currently, [`MwCas`] supports only the x86_64 platform because it relies on a
//! platform-specific trick: it uses the upper bits of a pointer's virtual address to represent
//! internal state. Today's x86_64 CPUs use only the lower 48 bits of a virtual address, so the
//! upper 16 bits of user-space addresses are 0. The paper describes how the upper 3 bits are used.
19//!
20//! # Usage
//! The multi-word CAS API is provided by the [`MwCas`] struct, which can operate on two kinds of
//! pointers:
//! - a pointer to heap-allocated data ([`HeapPointer`]);
//! - a pointer to a `u64` ([`U64Pointer`]).
//!
//! [`HeapPointer`] can be used to execute a multi-word CAS on any data type, at the cost of a
//! heap allocation. [`U64Pointer`] does not allocate anything on the heap and has the same memory
//! overhead as a `u64`. Because the upper bits of every word are reserved for internal state, the
//! values it holds must be smaller than 2^62.
//!
//! [`MwCas`] is a container for a chain of `compare_exchange` operations. Once the caller has
//! registered all required CASes, it performs the multi-word CAS by calling the `exec` method,
//! which returns a `bool` indicating whether the MwCAS was successful.
32//!
33//! Example of `MwCAS` usage:
34//! ```
35//! use mwcas::{MwCas, HeapPointer, U64Pointer};
36//!
37//! let ptr = HeapPointer::new(String::new());
38//! let val = U64Pointer::new(0);
39//! let guard = crossbeam_epoch::pin();
40//! let cur_ptr_val: &String = ptr.read(&guard);
41//!
42//! let mut mwcas = MwCas::new();
43//! mwcas.compare_exchange(&ptr, cur_ptr_val, String::from("new_string"));
44//! mwcas.compare_exchange_u64(&val, 0, 1);
45//!
46//! assert!(mwcas.exec(&guard));
47//! assert_eq!(ptr.read(&guard), &String::from("new_string"));
48//! assert_eq!(val.read(&guard), 1);
49//! ```
50//!
51//! # Memory reclamation
//! Values pointed to by a `HeapPointer` that were replaced by a new value during an MwCAS are
//! dropped through [`crossbeam_epoch`] memory reclamation.
54
55use crossbeam_epoch::Guard;
56use std::borrow::Borrow;
57use std::marker::PhantomData;
58use std::mem::{align_of_val, size_of};
59use std::ops::Deref;
60use std::option::Option::Some;
61use std::ptr;
62use std::rc::Rc;
63use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
64
/// Descriptor status while CAS operations are still being installed (phase 1).
const STATUS_PREPARE: u8 = 0x0;
/// Final descriptor status: every CAS matched, so targets receive their new values.
const STATUS_COMPLETED: u8 = 0x1;
/// Final descriptor status: some CAS did not match, so targets are rolled back.
const STATUS_FAILED: u8 = 0x2;
68
/// Owning pointer to a heap-allocated `T` that can take part in a multi-word CAS.
///
/// The pointer is stored as a plain 64-bit address inside an `AtomicU64`, so it has the same
/// size as a `u64`.
///
/// # Drop
/// Dropping a `HeapPointer` releases the heap value it currently points to.
#[repr(transparent)]
#[derive(Debug)]
pub struct HeapPointer<T> {
    // address of the currently owned `T` (may temporarily hold an MwCAS descriptor)
    ptr: AtomicU64,
    // records that this type logically owns a `T`
    phantom: PhantomData<T>,
}
80
81impl<T> HeapPointer<T> {
82    /// Create new `HeapPointer` which allocates memory for `val` on heap.
83    #[inline]
84    pub fn new(val: T) -> Self {
85        let val_address = Box::into_raw(Box::new(val)) as u64;
86        HeapPointer {
87            ptr: AtomicU64::new(val_address),
88            phantom: PhantomData {},
89        }
90    }
91
92    /// Read current value of `HeapPointer` and return reference to it.
93    #[inline]
94    pub fn read<'g>(&'g self, guard: &'g Guard) -> &'g T {
95        unsafe { &*self.read_ptr(guard) }
96    }
97
98    /// Read current value of `HeapPointer` and return mutable reference to it.
99    #[inline]
100    pub fn read_mut<'g>(&'g mut self, guard: &'g Guard) -> &'g mut T {
101        unsafe { &mut *self.read_ptr(guard) }
102    }
103
104    #[inline]
105    fn read_ptr(&self, guard: &Guard) -> *mut T {
106        read_val(&self.ptr, guard) as *mut u8 as *mut T
107    }
108}
109
110#[inline]
111fn read_val(ptr: &AtomicU64, guard: &Guard) -> u64 {
112    loop {
113        let cur_val = ptr.load(Ordering::Acquire);
114        if let Some(mwcas_ptr) = MwCasPointer::from_poisoned(cur_val, guard) {
115            mwcas_ptr.exec_internal(guard);
116        } else {
117            return cur_val;
118        }
119    }
120}
121
122impl<T: Clone> Clone for HeapPointer<T> {
123    fn clone(&self) -> Self {
124        let val = self.read(&crossbeam_epoch::pin()).clone();
125        HeapPointer::new(val)
126    }
127}
128
129impl<T> Drop for HeapPointer<T> {
130    fn drop(&mut self) {
131        unsafe {
132            drop(Box::from_raw(
133                // this heap pointer cannot be part of any running MwCAS,
134                // we can safely use crossbeam_epoch::unprotected()
135                self.read_ptr(crossbeam_epoch::unprotected()),
136            ));
137        }
138    }
139}
140
141unsafe impl<T: Send> Send for HeapPointer<T> {}
142unsafe impl<T: Sync> Sync for HeapPointer<T> {}
143
/// Holder of a `u64` that can take part in a multi-word CAS.
///
/// Despite its name this is a holder rather than a pointer: it stores the integer inline. It
/// exists so that plain integers get the same interface as [`HeapPointer`], including safe
/// reads of the current value while an MwCAS may be in flight.
#[repr(transparent)]
#[derive(Debug)]
pub struct U64Pointer {
    // the value itself (may temporarily hold an MwCAS descriptor)
    val: AtomicU64,
}
154
155impl U64Pointer {
156    /// Create new `U64Pointer` with initial value.  
157    #[inline]
158    pub fn new(val: u64) -> Self {
159        Self {
160            val: AtomicU64::new(val),
161        }
162    }
163
164    /// Read current value of pointer.
165    #[inline]
166    pub fn read(&self, guard: &Guard) -> u64 {
167        read_val(&self.val, guard)
168    }
169}
170
171impl Clone for U64Pointer {
172    fn clone(&self) -> Self {
173        U64Pointer::new(self.read(&crossbeam_epoch::pin()))
174    }
175}
176
177unsafe impl Send for U64Pointer {}
178unsafe impl Sync for U64Pointer {}
179
180/// Multi-word CAS structure.
181///
182/// [`MwCas`] contains multi-word CAS state, including pointers which should be changed,
183/// original and new pointer values.
184/// [`MwCas`] provides `compare and exchange` operations to register CAS operations on pointers.
185/// When all `compare and exchange` operations registered, caller should execute `exec` method to
186/// actually perform multi-word CAS.
187#[cfg(target_arch = "x86_64")]
188pub struct MwCas<'g> {
189    // allocated on heap to be safely accessible by 'assisting' threads.
190    inner: Box<MwCasInner<'g>>,
191    // is MwCAS completed successfully. Used during MwCAS drop.
192    success: AtomicBool,
193    // Rc used to make this type !Send and !Sync,
194    phantom: PhantomData<Rc<u8>>,
195}
196
197impl<'g> Default for MwCas<'g> {
198    fn default() -> Self {
199        Self::new()
200    }
201}
202
203impl<'g> MwCas<'g> {
204    /// Create new `MwCAS`.
205    #[inline]
206    pub fn new() -> Self {
207        MwCas {
208            inner: Box::new(MwCasInner {
209                status: AtomicU8::new(STATUS_PREPARE),
210                cas_ops: Vec::with_capacity(2),
211            }),
212            success: AtomicBool::new(false),
213            phantom: PhantomData {},
214        }
215    }
216
217    /// Add compare-exchange operation to MwCAS for heap allocated data.
218    ///
219    /// - `target` points to heap allocated data which should be replaced by new value.  
220    /// - `orig_val` is value is from `target` pointer at some point in time, using
221    /// `HeapPointer.read()` method.
222    /// - `new_val` will be installed to `target` on `MwCas` success. If `MwCas` will fail, then
223    /// `new_val` will be dropped.
224    #[inline]
225    pub fn compare_exchange<T>(&mut self, target: &'g HeapPointer<T>, orig_val: &'g T, new_val: T) {
226        #[cfg(debug_assertions)]
227        {
228            for cas in &self.inner.cas_ops {
229                if ptr::eq(cas.target_ptr, &target.ptr as *const AtomicU64) {
230                    panic!(
231                        "MwCAS cannot compare-and-swap the same {} several times in one execution. 
232                        Remove duplicate target reference passed to 'add/with' method. 
233                        This can happen if you use unsafe code which skips borrowing rules 
234                        checker of Rust: target parameter declared as mutable reference and 
235                        cannot be added twice to MwCAS by 'safe' code.",
236                        std::any::type_name::<HeapPointer<T>>()
237                    )
238                }
239            }
240        }
241        let orig_val_ptr = orig_val as *const T as *mut T;
242        let orig_val_addr = orig_val_ptr as u64;
243        let new_val_ptr = Box::into_raw(Box::new(new_val));
244        let new_val_addr = new_val_ptr as u64;
245        let drop_fn: Box<dyn Fn(bool) + 'g> = Box::new(move |success| {
246            if success {
247                drop(unsafe { Box::from_raw(orig_val_ptr) })
248            } else {
249                drop(unsafe { Box::from_raw(new_val_ptr) })
250            }
251        });
252        self.inner.cas_ops.push(Cas::new(
253            &target.ptr as *const AtomicU64 as *mut AtomicU64,
254            orig_val_addr,
255            new_val_addr,
256            drop_fn,
257        ));
258    }
259
260    /// Add compare-exchange operation to MwCAS for simple u64.
261    ///
262    /// - `target` struct contains u64 which should be replaced by `MwCas`.  
263    /// - `orig_val` is expected value of `target` during CAS.
264    /// - `new_val` will be installed to `target` on `MwCas` success.
265    #[inline]
266    pub fn compare_exchange_u64(&mut self, target: &'g U64Pointer, orig_val: u64, new_val: u64) {
267        #[cfg(debug_assertions)]
268        {
269            for cas in &self.inner.cas_ops {
270                if ptr::eq(cas.target_ptr, &target.val as *const AtomicU64) {
271                    panic!(
272                        "MwCAS cannot compare-and-swap the same {} several times in one execution. 
273                        Remove duplicate target reference passed to 'add/with' method. 
274                        This can happen if you use unsafe code which skips borrowing rules 
275                        checker of Rust: target parameter declared as mutable reference and 
276                        cannot be added twice to MwCAS by 'safe' code.",
277                        std::any::type_name::<U64Pointer>()
278                    )
279                }
280            }
281        }
282
283        let drop_fn: Box<dyn Fn(bool) + 'g> = Box::new(move |_| {});
284        self.inner.cas_ops.push(Cas::new(
285            &target.val as *const AtomicU64 as *mut AtomicU64,
286            *orig_val.borrow(),
287            *new_val.borrow(),
288            drop_fn,
289        ));
290    }
291
292    /// Execute all registered CAS operations and return result status.
293    ///
294    /// `guard` is used for reclamation of memory used by previous values
295    /// which were replaced during `MwCas` by new one.
296    #[inline]
297    pub fn exec(self, guard: &Guard) -> bool {
298        let successful_cas = self.inner.exec_internal(guard);
299        // delay drop of MwCAS until all thread which can assist to it,
300        // e.g. can access this MwCAS by pointer.
301        self.success.store(successful_cas, Ordering::Release);
302        unsafe {
303            guard.defer_unchecked(move || {
304                drop(self);
305            });
306        }
307        successful_cas
308    }
309}
310
311impl<'g> Drop for MwCas<'g> {
312    fn drop(&mut self) {
313        // if CAS was successful, free memory used by previous value(e.g., value which
314        // was replaced). Otherwise, free memory used by 'candidate' value which not
315        // used anymore and never will be seen by other threads.
316        for cas in &self.inner.cas_ops {
317            (cas.drop_fn)(self.success.load(Ordering::Acquire));
318        }
319    }
320}
321
322struct MwCasInner<'g> {
323    // MwCAS status(described by const values)
324    status: AtomicU8,
325    // list of registered CAS operations
326    cas_ops: Vec<Cas<'g>>,
327}
328
/// Descriptor protocol. These methods may run concurrently on the owning thread and on
/// assisting threads that found this descriptor installed in a target word (see `read_val` and
/// the conflict branch of `phase_one`).
impl<'g> MwCasInner<'g> {
    /// Current status (`STATUS_PREPARE`, `STATUS_COMPLETED` or `STATUS_FAILED`).
    #[inline(always)]
    fn status(&self) -> u8 {
        self.status.load(Ordering::Acquire)
    }

    /// Drive this MwCAS to completion and report whether it succeeded.
    ///
    /// Runs phase 1, then tries to publish its outcome as the final status. Phase 2 then runs
    /// with whichever status actually won: the one computed here (`Ok`) or the one another
    /// thread had already published (`Err`). Returns `true` iff that status is
    /// `STATUS_COMPLETED`.
    #[inline]
    fn exec_internal(&self, guard: &Guard) -> bool {
        let phase_one_status = self.phase_one(guard);
        let phase_two_status = self.update_status(phase_one_status);
        match phase_two_status {
            Ok(status) => self.phase_two(status),
            Err(cur_status) => {
                self.phase_two(cur_status);
            }
        }
        // both variants carry the final status
        phase_two_status.map_or_else(|status| status, |status| status) == STATUS_COMPLETED
    }

    /// Phase 1 according to paper: install this descriptor into every target word.
    ///
    /// Targets are processed in registration order. On a conflict with another descriptor, we
    /// help that MwCAS finish and then retry the same target. Finding our own descriptor means
    /// an assisting thread already installed it. Returns `STATUS_FAILED` as soon as a target
    /// holds an unexpected ordinary value, otherwise `STATUS_COMPLETED`.
    ///
    /// NOTE(review): installation (`Cas::prepare`) is a plain CAS that does not re-check
    /// `status`. A thread that runs this phase after the status was already decided can
    /// therefore still install the descriptor, which looks like the hazard quoted in
    /// `update_status`. The `#[ignore]`d `test_mwcas_race_in_phase_one_*` tests seem to target
    /// this. The paper presumably uses a status-conditional install here; confirm before relying
    /// on it.
    fn phase_one(&self, guard: &Guard) -> u8 {
        for cas in &self.cas_ops {
            loop {
                match cas.prepare(self, guard) {
                    CasPrepareResult::Conflict(mwcas_ptr) => {
                        if &mwcas_ptr != self.deref() {
                            // Another MwCAS owns this word: help it complete, then retry.
                            mwcas_ptr.exec_internal(guard);
                        } else {
                            // Our own descriptor is already installed (another thread assisted
                            // us), so this target is done.
                            break;
                        }
                    }
                    CasPrepareResult::Success => break,
                    CasPrepareResult::Failed => return STATUS_FAILED,
                }
            }
        }
        STATUS_COMPLETED
    }

    /// Try to move the status from `STATUS_PREPARE` to `new_status`.
    ///
    /// Returns `Ok(new_status)` if this call decided the outcome, or `Err(current)` with the
    /// status some other thread already published.
    #[inline]
    fn update_status(&self, new_status: u8) -> Result<u8, u8> {
        if let Err(prev_status) = self.status.compare_exchange(
            STATUS_PREPARE,
            new_status,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            // Some other thread already decided this MwCAS. The caller must finish phase 2
            // with that decision rather than its own phase-1 result.
            // Related hazard, from the paper:
            // Installation of a descriptor for a completed PMwCAS (p1) that might
            // inadvertently overwrite the result of another PMwCAS (p2), where
            // p2 should occur after p1. This can happen if a thread T executing p1
            // is about to install a descriptor in a target address A over an existing
            // value V, but goes to sleep. While T sleeps, another thread may complete p1
            // (given the cooperative nature of PMwCAS ) and subsequently
            // p2 executes to set a back to V. If T were to wake up and try to
            // overwrite V (the value it expects) in address A, it would actually be
            // overwriting the result of p2, violating the linearizable schedule for
            // updates to A.
            Err(prev_status)
        } else {
            Ok(new_status)
        }
    }

    /// Phase 2 according to paper: replace the descriptor in every target with its final value.
    ///
    /// Each target gets `new_val` for `STATUS_COMPLETED` or `orig_val` for `STATUS_FAILED`.
    /// Targets that no longer hold this descriptor are left untouched (see `Cas::complete`).
    fn phase_two(&self, mwcas_status: u8) {
        // in any case (success or failure), we should complete CAS
        // on each pointer to obtain a consistent state.
        let mwcas_ptr = MwCasPointer::from(self.deref());
        for cas in &self.cas_ops {
            cas.complete(mwcas_status, &mwcas_ptr);
        }
    }
}
409
410#[derive(Copy, Clone)]
411#[repr(transparent)]
412struct MwCasPointer<'g> {
413    mwcas: &'g MwCasInner<'g>,
414}
415
416impl<'g> Deref for MwCasPointer<'g> {
417    type Target = MwCasInner<'g>;
418
419    fn deref(&self) -> &Self::Target {
420        self.mwcas
421    }
422}
423
424impl<'g> MwCasPointer<'g> {
425    const MWCAS_FLAG: u64 = 0x4000_0000_0000_0000;
426
427    /// Try to create pointer to existing `MwCAS` based on address installed on CAS target pointer.
428    #[inline]
429    fn from_poisoned(poisoned_addr: u64, _: &'g Guard) -> Option<MwCasPointer<'g>> {
430        let valid_addr = poisoned_addr & !Self::MWCAS_FLAG;
431        if poisoned_addr != valid_addr {
432            Option::Some(MwCasPointer {
433                // we observe existing MwCas during of guard lifetime
434                // it's safe to access it until guard is alive
435                mwcas: unsafe { &*(valid_addr as *const u64 as *const MwCasInner) },
436            })
437        } else {
438            // passed address is not `poisoned` address,
439            // e.g. not an address of some existing `MwCAS`
440            Option::None
441        }
442    }
443
444    /// Return address of MwCas structure but with modified high bits which
445    /// indicate that this address is not valid address of MwCas structure
446    /// but a special pointer to MwCas.
447    #[inline(always)]
448    fn poisoned(&self) -> u64 {
449        let addr = self.mwcas as *const MwCasInner as *const u64 as u64;
450        addr | Self::MWCAS_FLAG
451    }
452}
453
454impl<'g> From<&'g MwCasInner<'g>> for MwCasPointer<'g> {
455    fn from(mwcas: &'g MwCasInner) -> Self {
456        MwCasPointer { mwcas }
457    }
458}
459
460impl<'g> Eq for MwCasPointer<'g> {}
461
462impl<'g> PartialEq for MwCasPointer<'g> {
463    fn eq(&self, other: &MwCasPointer) -> bool {
464        ptr::eq(self.mwcas, other.mwcas)
465    }
466}
467
468impl<'g> PartialEq<MwCasInner<'g>> for MwCasPointer<'g> {
469    fn eq(&self, other: &MwCasInner) -> bool {
470        ptr::eq(self.mwcas, other)
471    }
472}
473
/// A single compare-and-exchange operation registered in an `MwCas`.
struct Cas<'g> {
    /// Word that is compared and exchanged.
    target_ptr: *mut AtomicU64,
    /// Value the word must hold for the CAS to match; restored if the MwCAS fails.
    orig_val: u64,
    /// Value stored into the word if the whole MwCAS succeeds.
    new_val: u64,
    /// Called with the final success flag when the owning `MwCas` is dropped, to release the
    /// original or new value, whichever is no longer needed.
    drop_fn: Box<dyn Fn(bool) + 'g>,
}
482
483unsafe impl<'g> Send for Cas<'g> {}
484unsafe impl<'g> Sync for Cas<'g> {}
485
486#[derive(PartialEq, Copy, Clone)]
487enum CasPrepareResult<'g> {
488    Success,
489    Conflict(MwCasPointer<'g>),
490    Failed,
491}
492
493impl<'g> Cas<'g> {
494    fn new(
495        pointer: *mut AtomicU64,
496        orig_val: u64,
497        new_val: u64,
498        drop_fn: Box<dyn Fn(bool) + 'g>,
499    ) -> Self {
500        let max_addr: u64 = 0xDFFF_FFFF_FFFF_FFFF;
501        assert!(!pointer.is_null(), "Pointer must be non null");
502        debug_assert!(
503            (pointer as u64) < max_addr,
504            "Pointer must point to memory in range [0x{:X}, 0x{:X}], because MwCas \
505             use highest 3 bits of address for internal use. Actual address to which pointer \
506             points was 0x{:x}",
507            0,
508            max_addr,
509            pointer as u64
510        );
511        unsafe {
512            let align = align_of_val(&*pointer);
513            debug_assert_eq!(
514                align,
515                size_of::<u64>(),
516                "Pointer must be align on {} bytes, but pointer was aligned on {}",
517                size_of::<u64>(),
518                align
519            )
520        }
521        debug_assert!(
522            orig_val < MwCasPointer::MWCAS_FLAG,
523            "MwCas can be applied only for original values < {}. Actual value was {}",
524            MwCasPointer::MWCAS_FLAG,
525            orig_val
526        );
527        debug_assert!(
528            new_val < MwCasPointer::MWCAS_FLAG,
529            "MwCas can be applied only for new values < {}. Actual value was {}",
530            MwCasPointer::MWCAS_FLAG,
531            new_val
532        );
533
534        Cas {
535            target_ptr: pointer,
536            orig_val,
537            new_val,
538            drop_fn,
539        }
540    }
541
542    /// Try to install pointer to `MwCAS` into value of current CAS target.
543    fn prepare<'a>(&self, mwcas: &MwCasInner, guard: &'a Guard) -> CasPrepareResult<'a> {
544        let new_val = MwCasPointer::from(mwcas.deref()).poisoned();
545        let prev = unsafe {
546            (*self.target_ptr)
547                .compare_exchange(self.orig_val, new_val, Ordering::AcqRel, Ordering::Acquire)
548                .map_or_else(|v| v, |v| v)
549        };
550
551        if prev == self.orig_val {
552            CasPrepareResult::Success
553        } else if let Some(mwcas_ptr) = MwCasPointer::from_poisoned(prev, guard) {
554            // found MWCAS pointer installed by some other
555            CasPrepareResult::Conflict(mwcas_ptr)
556        } else {
557            CasPrepareResult::Failed
558        }
559    }
560
561    /// Complete CAS operation for current pointer: set new value on MwCAS success or rollback to
562    /// original value if MwCAS failed.
563    fn complete(&self, status: u8, mwcas: &MwCasPointer) {
564        let new_val = match status {
565            STATUS_COMPLETED => self.new_val,
566            STATUS_FAILED => self.orig_val,
567            _ => panic!("CAS cannot be completed for not prepared MWCAS"),
568        };
569        let expected_val = mwcas.poisoned();
570        unsafe {
571            let _ = (*self.target_ptr).compare_exchange(
572                expected_val,
573                new_val,
574                Ordering::AcqRel,
575                Ordering::Acquire,
576            );
577        };
578        // if CAS above failed, then some other thread completed our MwCAS,
579        // e.g assist us. This is expected case, no additional actions required.
580        // Or we found MwCas of installed by other thread. This is also expected
581        // case when we fail our MwCAS and some other MwCas install it pointer to same memory cell.
582    }
583}
584
585#[cfg(test)]
586mod tests {
587    use crate::Cas;
588    use std::sync::atomic::Ordering;
589
590    mod simple {
591        use crate::{HeapPointer, MwCas, U64Pointer, STATUS_COMPLETED, STATUS_FAILED};
592        use std::ops::Deref;
593        use std::ptr::NonNull;
594        use std::sync::atomic::Ordering;
595
596        #[test]
597        fn test_mwcas_add_ptr() {
598            let guard = crossbeam_epoch::pin();
599            let val1 = HeapPointer::new(5);
600            let val2 = HeapPointer::new(10);
601            let val3 = U64Pointer::new(15);
602            let new_val1 = 15;
603            let new_val2 = 20;
604            let new_val3 = 25;
605            let orig_val1 = val1.read(&guard);
606            let orig_val2 = val2.read(&guard);
607            let orig_val3 = val3.read(&guard);
608
609            let mut mw_cas = MwCas::new();
610            mw_cas.compare_exchange(&val1, orig_val1, new_val1);
611            mw_cas.compare_exchange(&val2, orig_val2, new_val2);
612            mw_cas.compare_exchange_u64(&val3, orig_val3, new_val3);
613            assert!(mw_cas.exec(&guard));
614            assert_eq!(*val1.read(&guard), new_val1);
615            assert_eq!(*val2.read(&guard), new_val2);
616            assert_eq!(val3.read(&guard), new_val3);
617        }
618
619        #[test]
620        #[should_panic]
621        fn test_add_same_ptr() {
622            let guard = crossbeam_epoch::pin();
623            let val1 = HeapPointer::new(5);
624            let new_val1 = 15;
625            let orig_val1 = val1.read(&guard);
626
627            let mut mw_cas = MwCas::new();
628            mw_cas.compare_exchange(&val1, orig_val1, new_val1);
629            mw_cas.compare_exchange(&val1, orig_val1, new_val1);
630        }
631
632        #[test]
633        #[should_panic]
634        fn test_add_same_u64_val() {
635            let guard = crossbeam_epoch::pin();
636            let val1 = U64Pointer::new(5);
637            let new_val1 = 15;
638            let orig_val1 = val1.read(&guard);
639
640            let mut mw_cas = MwCas::new();
641            mw_cas.compare_exchange_u64(&val1, orig_val1, new_val1);
642            mw_cas.compare_exchange_u64(&val1, orig_val1, new_val1);
643        }
644
        /// Descriptors installed before `exec` runs (as an assisting thread would do) must be
        /// recognised by `exec` as its own, and the MwCAS must still complete successfully.
        /// Covers both "all targets pre-installed" and "only the last target pre-installed".
        #[test]
        fn test_prepared_cas_completion_assist() {
            let val1 = HeapPointer::new(1);
            let val2 = HeapPointer::new(2);
            let guard = crossbeam_epoch::pin();
            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let mut mwcas = MwCas::new();
            mwcas.compare_exchange(&val1, orig_val1, 2);
            mwcas.compare_exchange(&val2, orig_val2, 3);

            // emulate that some other thread begins our MwCAS:
            // install the descriptor into both targets before `exec`
            let cas1 = mwcas.inner.cas_ops.first().unwrap();
            let cas2 = mwcas.inner.cas_ops.get(1).unwrap();
            cas1.prepare(mwcas.inner.deref(), &guard);
            cas2.prepare(mwcas.inner.deref(), &guard);

            assert!(mwcas.exec(&guard));
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);

            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let mut mwcas = MwCas::new();
            mwcas.compare_exchange(&val1, orig_val1, 3);
            mwcas.compare_exchange(&val2, orig_val2, 4);
            // emulate that some other thread begins our MwCAS
            // (despite its name, `cas1` here is the LAST registered op, i.e. the one on `val2`)
            let cas1 = mwcas.inner.cas_ops.last().unwrap();
            cas1.prepare(mwcas.inner.deref(), &guard);

            assert!(mwcas.exec(&guard));
            assert_eq!(*val1.read(&guard), 3);
            assert_eq!(*val2.read(&guard), 4);
        }
679
        /// An MwCAS whose descriptor was installed into only some of its targets can still be
        /// finished by its own `exec`, while an unrelated MwCAS runs in between.
        #[test]
        fn test_cas_completion_assist_on_subset_of_references() {
            let val1 = HeapPointer::new(1);
            let val2 = HeapPointer::new(2);
            let val3 = HeapPointer::new(3);
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            let orig_val1 = val1.read(&guard);
            let orig_val2 = val2.read(&guard);
            let orig_val3 = val3.read(&guard);
            mwcas1.compare_exchange(&val1, orig_val1, 2);
            mwcas1.compare_exchange(&val2, orig_val2, 3);
            mwcas2.compare_exchange(&val3, orig_val3, 4);

            // emulate an assisting thread that installed mwcas1's descriptor into val1 only
            let cas1 = mwcas1.inner.cas_ops.first().unwrap();
            cas1.prepare(mwcas1.inner.deref(), &guard);

            // NOTE(review): mwcas2 only registers val3, so it does not appear to encounter
            // mwcas1's descriptor (which sits in val1). mwcas1 is then finished by its own
            // `exec` below, which recognises its pre-installed descriptor in val1.
            assert!(mwcas2.exec(&guard));
            assert_eq!(*val3.read(&guard), 4);
            assert!(mwcas1.exec(&guard));
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);
        }
707
708        #[test]
709        fn test_assist_not_change_cas_result() {
710            let mut val1 = HeapPointer::new(1);
711            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
712            let mut val2 = HeapPointer::new(2);
713            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
714            let guard = crossbeam_epoch::pin();
715            let mut mwcas1 = MwCas::new();
716            let mut mwcas2 = MwCas::new();
717            let val1_ref = val1.read(&guard);
718            unsafe {
719                mwcas1.compare_exchange(&*value1.as_ptr(), val1_ref, 2);
720                mwcas1.compare_exchange(&*value2.as_ptr(), val1_ref, 2);
721            }
722            assert_eq!(mwcas1.inner.phase_one(&guard), STATUS_FAILED);
723            mwcas1.inner.update_status(STATUS_FAILED).unwrap();
724
725            // this cause assist to mwcas-1 which already on fail path
726            unsafe {
727                mwcas2.compare_exchange(&*value1.as_ptr(), val1_ref, 2);
728            }
729            assert!(mwcas2.exec(&guard));
730            assert_eq!(mwcas1.inner.status(), STATUS_FAILED);
731            assert!(!mwcas1.exec(&guard));
732
733            assert_eq!(*val1.read(&guard), 2);
734            assert_eq!(*val2.read(&guard), 2);
735        }
736
        /// Intended scenario: another MwCAS runs while mwcas1 has finished phase 1 but has not
        /// yet published its status.
        ///
        /// NOTE(review): mwcas2 only registers `value3`, and mwcas1 registers new values 2/3
        /// for val1/val2. The expectations `*val2 == 4` and, after `phase_two`, `*val1 == 1` do
        /// not appear to follow from these registrations. The test is `#[ignore]`d; confirm the
        /// intended setup before re-enabling.
        #[test]
        #[ignore]
        fn test_mwcas_race_in_phase_one_before_status_update() {
            let mut val1 = HeapPointer::new(1);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let mut val2 = HeapPointer::new(2);
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            let mut val3 = HeapPointer::new(3);
            let value3 = unsafe { NonNull::new_unchecked(&mut val3) };
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), val1.read(&guard), 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), val2.read(&guard), 3);
                mwcas2.compare_exchange(&*value3.as_ptr(), val3.read(&guard), 4);
            }

            // start phase 1 of 1st mwcas
            let status = mwcas1.inner.phase_one(&guard);
            assert_eq!(status, STATUS_COMPLETED);
            // execute 2nd mwcas which should find conflicting 1st MwCAS in value2,
            // assist it and complete both MwCASs
            assert!(mwcas2.exec(&guard));
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 4);
            assert_eq!(*val3.read(&guard), 4);
            // execute phase 2 for completed MwCas and check that result remains the same
            mwcas1.inner.phase_two(STATUS_COMPLETED);
            assert_eq!(*val1.read(&guard), 1);
            assert_eq!(*val2.read(&guard), 4);
            assert_eq!(*val3.read(&guard), 4);
            // mwcas1 is never `exec`uted, so mark it successful by hand: its `Drop` then frees
            // the replaced original values instead of the new values now stored in val1/val2.
            mwcas1.success.store(true, Ordering::Release);
        }
771
        /// Intended scenario: another MwCAS runs after mwcas1 published its status but before
        /// its phase 2, and a late phase 2 must not change the result.
        ///
        /// NOTE(review): the original values passed here are `&1`, `&2`, `&3`, references to
        /// constants rather than the values read from val1/val2. `compare_exchange` matches by
        /// address, so these can never equal the boxed values, and phase 1 cannot succeed as the
        /// comments assume. The test is `#[ignore]`d; confirm intent before re-enabling.
        #[test]
        #[ignore]
        fn test_mwcas_race_in_phase_one_after_status_update() {
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();

            let mut val1 = HeapPointer::new(1);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let mut val2 = HeapPointer::new(2);
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), &1, 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), &2, 3);
                mwcas2.compare_exchange(&*value2.as_ptr(), &3, 4);
            }

            let guard = crossbeam_epoch::pin();
            // start phase 1 of 1st mwcas
            let status = mwcas1.inner.phase_one(&guard);
            mwcas1.inner.update_status(status).unwrap();
            // execute 2nd mwcas which should find conflicting 1st MwCAS in value2,
            // assist it and complete both MwCASs
            mwcas2.exec(&guard);
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);
            // execute phase 2 for completed MwCas and check that result remains the same
            mwcas1.inner.phase_two(status);
            assert_eq!(*val1.read(&guard), 2);
            assert_eq!(*val2.read(&guard), 3);
        }
802
        /// mwcas1 has installed its descriptor into val1 only. mwcas2 then changes val2, so
        /// mwcas1's phase 1 must fail on val2 and roll val1 back. The reason for `#[ignore]`
        /// is not visible here.
        #[test]
        #[ignore]
        fn test_mwcas_fail_when_concurrent_mwcas_won_race() {
            let mut val1 = HeapPointer::new(1);
            let mut val2 = HeapPointer::new(2);
            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
            let guard = crossbeam_epoch::pin();
            let mut mwcas1 = MwCas::new();
            let mut mwcas2 = MwCas::new();
            unsafe {
                mwcas1.compare_exchange(&*value1.as_ptr(), val1.read(&guard), 2);
                mwcas1.compare_exchange(&*value2.as_ptr(), val2.read(&guard), 3);
                // emulate race with 2nd MwCAS on same value
                mwcas2.compare_exchange(&*value2.as_ptr(), val2.read(&guard), 4);
            }

            let cas = mwcas1.inner.cas_ops.first().unwrap();
            // emulate that only 1 CAS started in 1st MwCAS
            cas.prepare(mwcas1.inner.deref(), &guard);

            mwcas2.exec(&guard);
            assert_eq!(*val2.read(&guard), 4);

            // try to complete 1st MwCAS, which should fail because the 2nd already
            // updated the expected field value
            assert!(!mwcas1.exec(&guard));
        }
831
832        #[test]
833        #[ignore]
834        fn test_mwcas_linearization() {
835            let mut mwcas1 = MwCas::new();
836            let mut mwcas2 = MwCas::new();
837
838            let mut val1 = HeapPointer::new(1);
839            let value1 = unsafe { NonNull::new_unchecked(&mut val1) };
840            let mut val2 = HeapPointer::new(2);
841            let value2 = unsafe { NonNull::new_unchecked(&mut val2) };
842            unsafe {
843                mwcas1.compare_exchange(&*value1.as_ptr(), &1, 2);
844                mwcas1.compare_exchange(&*value2.as_ptr(), &2, 3);
845                mwcas2.compare_exchange(&*value1.as_ptr(), &2, 1);
846                mwcas2.compare_exchange(&*value2.as_ptr(), &3, 2);
847            }
848
849            let guard = crossbeam_epoch::pin();
850            // emulate start of 1st MwCAS without status update
851            mwcas1.inner.phase_one(&guard);
852
853            // 2nd MwCAS will assist to 1st MwCAS, complete itself(rollback
854            // all fields to original values)
855            assert!(mwcas2.exec(&guard));
856            // 1st MwCAS should skip all field updates because someone already done it's work
857            // and revert field values back
858            assert!(mwcas1.exec(&guard));
859
860            assert_eq!(*val1.read(&guard), 1);
861            assert_eq!(*val2.read(&guard), 2);
862        }
863
864        #[test]
865        fn test_mwcas_completion_on_pointer_read() {
866            let mut val = HeapPointer::new(1);
867            let value = unsafe { NonNull::new_unchecked(&mut val) };
868            let guard = crossbeam_epoch::pin();
869            let mut mwcas = MwCas::new();
870            unsafe {
871                mwcas.compare_exchange(&*value.as_ptr(), val.read(&guard), 2);
872            }
873
874            assert_eq!(*val.read(&guard), 1);
875            assert_eq!(mwcas.inner.phase_one(&guard), STATUS_COMPLETED);
876            assert_eq!(*val.read(&guard), 2);
877            mwcas.success.store(true, Ordering::Release);
878        }
879    }
880
881    impl<'g> Cas<'g> {
882        #[inline]
883        fn current_value(&self) -> u64 {
884            unsafe { (*self.target_ptr).load(Ordering::Acquire) }
885        }
886    }
887
888    mod mwcas_pointer_test {
889        use crate::{MwCas, MwCasPointer};
890        use std::ops::Deref;
891        use std::ptr;
892
893        #[test]
894        fn create_pointer_from_structure() {
895            let mw_cas = MwCas::new();
896            let ptr = MwCasPointer::from(mw_cas.inner.deref());
897            assert!(ptr::eq(ptr.deref(), mw_cas.inner.deref()));
898            let guard = crossbeam_epoch::pin();
899            assert!(matches!(
900                MwCasPointer::from_poisoned(ptr.poisoned(), &guard),
901                Some(_)
902            ));
903        }
904
905        #[test]
906        fn create_pointer_from_address() {
907            let guard = crossbeam_epoch::pin();
908            let mw_cas = MwCas::new();
909            let parsed_ptr = MwCasPointer::from_poisoned(
910                MwCasPointer::from(mw_cas.inner.deref()).poisoned(),
911                &guard,
912            );
913            assert!(parsed_ptr.is_some());
914            let ptr = parsed_ptr.unwrap();
915            assert!(ptr::eq(ptr.deref(), mw_cas.inner.deref()));
916
917            assert_eq!(
918                ptr.poisoned(),
919                MwCasPointer::from(mw_cas.inner.deref()).poisoned()
920            );
921        }
922
923        #[test]
924        fn create_pointer_from_invalid_address() {
925            let mw_cas = MwCas::new();
926            let addr = &mw_cas as *const MwCas as u64;
927            let guard = crossbeam_epoch::pin();
928            let parsed_ptr = MwCasPointer::from_poisoned(addr, &guard);
929            assert!(parsed_ptr.is_none());
930        }
931    }
932
933    mod cas_tests {
934        use crate::{
935            CasPrepareResult, HeapPointer, MwCas, MwCasPointer, STATUS_COMPLETED, STATUS_FAILED,
936        };
937        use std::ops::Deref;
938        use std::sync::atomic::Ordering;
939
940        #[test]
941        fn test_cas_success_completion() {
942            let guard = crossbeam_epoch::pin();
943            let cur_val = HeapPointer::new(1);
944            let mut mwcas = MwCas::new();
945            let orig_val = cur_val.read(&guard);
946            mwcas.compare_exchange(&cur_val, orig_val, 2);
947            let cas = mwcas.inner.cas_ops.first().unwrap();
948
949            assert!(matches!(
950                cas.prepare(mwcas.inner.deref(), &guard),
951                CasPrepareResult::Success
952            ));
953
954            let mwcas_ptr = MwCasPointer::from(mwcas.inner.deref());
955            assert!(
956                matches!(MwCasPointer::from_poisoned(cas.current_value(), &guard),
957                    Some(ptr) if mwcas_ptr == ptr)
958            );
959
960            cas.complete(STATUS_COMPLETED, &mwcas_ptr);
961            mwcas.success.store(true, Ordering::Release);
962            assert_eq!(*cur_val.read(&guard), 2);
963        }
964
965        #[test]
966        fn test_complete_cas_with_failure() {
967            let guard = crossbeam_epoch::pin();
968            let value = HeapPointer::new(1);
969            let mut mwcas = MwCas::new();
970            let orig_val = value.read(&guard);
971            mwcas.compare_exchange(&value, orig_val, 2);
972            let cas = mwcas.inner.cas_ops.first().unwrap();
973
974            assert!(matches!(
975                cas.prepare(mwcas.inner.deref(), &guard),
976                CasPrepareResult::Success
977            ));
978            let mwcas_ptr = MwCasPointer::from(mwcas.inner.deref());
979            assert!(
980                matches!(MwCasPointer::from_poisoned(cas.current_value(), &guard),
981                    Some(ptr) if mwcas_ptr == ptr)
982            );
983
984            cas.complete(STATUS_FAILED, &mwcas_ptr);
985            mwcas.success.store(false, Ordering::Release);
986            assert_eq!(*value.read(&guard), 1);
987        }
988
989        #[test]
990        fn test_same_cas_conflict() {
991            let guard = crossbeam_epoch::pin();
992            let val1 = HeapPointer::new(1);
993            let mut mwcas = MwCas::new();
994            let orig_val = val1.read(&guard);
995            mwcas.compare_exchange(&val1, orig_val, 2);
996            let cas = mwcas.inner.cas_ops.first().unwrap();
997            let mwcas_ptr = MwCasPointer::from(mwcas.inner.deref());
998            assert!(matches!(
999                cas.prepare(mwcas.inner.deref(), &guard),
1000                CasPrepareResult::Success
1001            ));
1002            assert!(matches!(
1003                cas.prepare(mwcas.inner.deref(), &guard),
1004                CasPrepareResult::Conflict(ptr) if ptr == mwcas_ptr
1005            ));
1006            cas.complete(STATUS_COMPLETED, &mwcas_ptr);
1007            mwcas.success.store(true, Ordering::Release);
1008        }
1009
1010        #[test]
1011        #[should_panic]
1012        fn test_cas_completion_with_invalid_status() {
1013            let mut value = HeapPointer::new(1);
1014            let mut mwcas = MwCas::new();
1015            mwcas.compare_exchange(&value, &1, 2);
1016            let cas = mwcas.inner.cas_ops.first().unwrap();
1017            cas.complete(u8::MAX, &MwCasPointer::from(mwcas.inner.deref()));
1018        }
1019    }
1020}