Skip to main content

atomic_tagged_ptr/
impl.rs

//! A platform-adaptive atomic tagged pointer implementation with ABA protection.
//!
//! # Background & Hardware Realities
//!
//! In lock-free concurrent programming, particularly when constructing intrusive data structures
//! such as a Treiber stack, the **ABA problem** frequently arises. The traditional mitigation
//! pairs the physical pointer with a generation tag and updates both atomically.
//!
//! How the pointer and tag are packed depends on the target architecture:
//! 1. **64-bit systems with large virtual address spaces (52/57-bit)**:
//!    Modern operating systems on x86_64 (with Intel 5-level paging, i.e. 57-bit virtual
//!    addresses) or AArch64 (with 52-bit virtual addresses) can hand out user-space pointers above
//!    the classic 48-bit limit. Code that assumes a 48-bit layout would truncate such pointers and
//!    cause wild-pointer crashes. This module therefore splits the 64-bit word into the
//!    **lower 56 bits** for the pointer (covering user-space addresses up to
//!    `0x00ff_ffff_ffff_ffff`) and the **upper 8 bits** for the tag. Note that an 8-bit tag wraps
//!    around after 256 updates, so it narrows the ABA window rather than eliminating it.
//! 2. **32-bit systems**:
//!    The pointer is 32 bits wide and is paired with a 32-bit generation tag to form a 64-bit
//!    double-word composite. Updates rely on the hardware's double-word atomics (such as
//!    `cmpxchg8b` on x86 or `ldrexd`/`strexd` on ARMv7) to perform CAS natively, without making
//!    assumptions about unused address bits.
//! 3. **Fallback systems (`atomic_fallback` cfg)**:
//!    On targets without usable native 64-bit atomics (for example, constrained microcontrollers)
//!    or where the upper pointer bits are reserved (for example, full MTE tagging under a
//!    hardened hypervisor), the implementation falls back to Mutex-based synchronization while
//!    keeping the same public API.
25use core::fmt;
26use core::ptr::NonNull;
27use core::sync::atomic::Ordering;
28
29use crate::ptr::{Ptr, TaggedPtr};
30use crate::traits::IntoOptionNonNull;
31
32// --- Platform Routing Conditional Compile Sections ---
33
34#[cfg(all(target_pointer_width = "64", not(atomic_fallback)))]
35mod ptr64;
36
37#[cfg(all(target_pointer_width = "64", not(atomic_fallback)))]
38use ptr64::AtomicTaggedPtrImpl;
39
40#[cfg(all(target_pointer_width = "64", not(atomic_fallback)))]
41pub use ptr64::TAG_MASK;
42
43#[cfg(all(target_pointer_width = "32", not(atomic_fallback)))]
44mod ptr32;
45
46#[cfg(all(target_pointer_width = "32", not(atomic_fallback)))]
47use ptr32::AtomicTaggedPtrImpl;
48
49#[cfg(all(target_pointer_width = "32", not(atomic_fallback)))]
50pub use ptr32::TAG_MASK;
51
52#[cfg(atomic_fallback)]
53mod fallback;
54
55#[cfg(atomic_fallback)]
56pub use fallback::TAG_MASK;
57
58#[cfg(atomic_fallback)]
59use fallback::AtomicTaggedPtrImpl;
60
61/// Represents a generation tag used for ABA protection in `AtomicTaggedPtr`.
62///
63/// `Tag` wraps a platform-specific generation count and ensures that any operations
64/// (like wrapping addition or creation) respect the hardware platform's limits and bit-width.
65#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
66pub struct Tag(pub(crate) usize);
67
68impl Tag {
69    /// Creates a new `Tag` from a raw value, applying the platform-specific mask.
70    #[inline]
71    pub const fn new(value: usize) -> Self {
72        Self(value & TAG_MASK)
73    }
74
75    /// Gets the raw tag value.
76    #[inline]
77    pub const fn value(self) -> usize {
78        self.0
79    }
80
81    /// Performs wrapping addition on the tag value.
82    #[inline]
83    pub const fn wrapping_add(self, rhs: usize) -> Self {
84        Self::new(self.0.wrapping_add(rhs))
85    }
86
87    /// Returns the maximum tag value allowed on this platform.
88    #[inline]
89    pub const fn max_value() -> Self {
90        Self(TAG_MASK)
91    }
92}
93
94impl fmt::Debug for Tag {
95    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96        write!(f, "Tag({:#X})", self.0)
97    }
98}
99
100impl fmt::Display for Tag {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        write!(f, "{}", self.0)
103    }
104}
105
106impl From<usize> for Tag {
107    #[inline]
108    fn from(value: usize) -> Self {
109        Self::new(value)
110    }
111}
112
113impl From<Tag> for usize {
114    #[inline]
115    fn from(tag: Tag) -> usize {
116        tag.0
117    }
118}
119
120/// Type alias representing the result of atomic compare and exchange operations.
121pub type TaggedPtrResult<T> = Result<TaggedPtr<T>, TaggedPtr<T>>;
122
123/// Type alias for raw results returned by internal platform implementations.
124pub(crate) type RawTaggedPtrResult<T> =
125    Result<(Option<NonNull<T>>, Tag), (Option<NonNull<T>>, Tag)>;
126
127/// A platform-adaptive atomic tagged pointer supporting thread-safe ABA protection.
128pub struct AtomicTaggedPtr<T> {
129    inner: AtomicTaggedPtrImpl<T>,
130}
131
132// Safety: AtomicTaggedPtr is an atomic synchronizer wrapping pointer locations, safe to send/share across threads.
133unsafe impl<T> Send for AtomicTaggedPtr<T> {}
134unsafe impl<T> Sync for AtomicTaggedPtr<T> {}
135
136impl<T> AtomicTaggedPtr<T> {
137    /// Creates a new `AtomicTaggedPtr` initialized with the given tagged pointer.
138    ///
139    /// # Examples
140    ///
141    /// ```
142    /// use std::ptr::NonNull;
143    /// use atomic_tagged_ptr::{AtomicTaggedPtr, TaggedPtr, Tag};
144    ///
145    /// let value = 42;
146    /// let ptr = NonNull::new(&value as *const i32 as *mut i32);
147    /// let atom = AtomicTaggedPtr::new(TaggedPtr::new(ptr, Tag::new(0)));
148    /// ```
149    #[inline]
150    pub fn new(val: TaggedPtr<T>) -> Self {
151        Self {
152            inner: AtomicTaggedPtrImpl::new(val.ptr.into_option_non_null(), val.tag),
153        }
154    }
155
156    /// Loads the current values of the pointer and tag atomically.
157    ///
158    /// # Panics
159    ///
160    /// Panics if `order` is `Release` or `AcqRel`.
161    #[inline]
162    pub fn load(&self, order: Ordering) -> TaggedPtr<T> {
163        let (raw_ptr, tag) = self.inner.load(order);
164        TaggedPtr {
165            ptr: Ptr::new(raw_ptr),
166            tag,
167        }
168    }
169
170    /// Stores a new pointer and tag atomically.
171    ///
172    /// # Panics
173    ///
174    /// Panics if `order` is `Acquire` or `AcqRel`.
175    #[inline]
176    pub fn store(&self, val: TaggedPtr<T>, order: Ordering) {
177        self.inner
178            .store(val.ptr.into_option_non_null(), val.tag, order);
179    }
180
181    /// Exchanges the current values with new ones if the current values match expectations.
182    ///
183    /// On success, returns `Ok` containing the previous pointer and tag.
184    /// On failure, returns `Err` containing the actual loaded pointer and tag.
185    #[inline]
186    pub fn compare_exchange(
187        &self,
188        current: TaggedPtr<T>,
189        new: TaggedPtr<T>,
190        success: Ordering,
191        failure: Ordering,
192    ) -> TaggedPtrResult<T> {
193        match self.inner.compare_exchange(
194            (current.ptr.into_option_non_null(), current.tag),
195            (new.ptr.into_option_non_null(), new.tag),
196            success,
197            failure,
198        ) {
199            Ok((raw_ptr, tag)) => Ok(TaggedPtr {
200                ptr: Ptr::new(raw_ptr),
201                tag,
202            }),
203            Err((raw_ptr, tag)) => Err(TaggedPtr {
204                ptr: Ptr::new(raw_ptr),
205                tag,
206            }),
207        }
208    }
209
210    /// Exchanges the current values with new ones using weak semantics.
211    ///
212    /// This is a weaker variant of `compare_exchange` which is allowed to fail spuriously,
213    /// but can be significantly more efficient on certain LL/SC-based architectures (such as ARM).
214    #[inline]
215    pub fn compare_exchange_weak(
216        &self,
217        current: TaggedPtr<T>,
218        new: TaggedPtr<T>,
219        success: Ordering,
220        failure: Ordering,
221    ) -> TaggedPtrResult<T> {
222        match self.inner.compare_exchange_weak(
223            (current.ptr.into_option_non_null(), current.tag),
224            (new.ptr.into_option_non_null(), new.tag),
225            success,
226            failure,
227        ) {
228            Ok((raw_ptr, tag)) => Ok(TaggedPtr {
229                ptr: Ptr::new(raw_ptr),
230                tag,
231            }),
232            Err((raw_ptr, tag)) => Err(TaggedPtr {
233                ptr: Ptr::new(raw_ptr),
234                tag,
235            }),
236        }
237    }
238}
239
240// --- Common Trait Implementations ---
241
242impl<T> Default for AtomicTaggedPtr<T> {
243    #[inline]
244    fn default() -> Self {
245        Self::new(TaggedPtr::default())
246    }
247}
248
249impl<T> fmt::Debug for AtomicTaggedPtr<T> {
250    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
251        // Safe load under Relaxed ordering to capture debug state snapshot
252        let val = self.load(Ordering::Relaxed);
253        f.debug_struct("AtomicTaggedPtr")
254            .field("pointer", &val.ptr)
255            .field("tag", &val.tag)
256            .finish()
257    }
258}
259
260// --- Built-in Local Integration Tests ---
261
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::format;

    #[test]
    fn test_default_initializer() {
        let atom: AtomicTaggedPtr<i32> = Default::default();
        let loaded = atom.load(Ordering::Relaxed);
        assert!(loaded.ptr.is_none());
        assert_eq!(loaded.tag, Tag::new(0));
    }

    #[test]
    fn test_debug_formatter() {
        let val = 12345;
        let ptr = NonNull::new(&val as *const i32 as *mut i32);
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(ptr, Tag::new(0)));
        atom.store(TaggedPtr::new(ptr, Tag::new(88)), Ordering::Relaxed);

        let debug_str = format!("{:?}", atom);
        assert!(debug_str.contains("AtomicTaggedPtr"));
        // 88 == 0x58; `Tag`'s `Debug` impl prints the value in hexadecimal.
        assert!(debug_str.contains("tag: Tag(0x58)"));
    }

    /// Several threads race to bump the tag with CAS loops. Every increment must land
    /// exactly once and the pointer must be left untouched.
    #[test]
    fn test_multithreaded_atomic_exchanges() {
        use std::sync::Arc;
        use std::thread;
        use std::vec::Vec;

        // THREADS * INCREMENTS_PER_THREAD stays below 256 so the expected count also fits
        // the 8-bit tag used on 64-bit targets without wrapping.
        const THREADS: usize = 4;
        const INCREMENTS_PER_THREAD: usize = 50;

        let val = 777;
        let ptr = NonNull::new(&val as *const i32 as *mut i32);
        let atom = Arc::new(AtomicTaggedPtr::new(TaggedPtr::new(ptr, Tag::new(0))));

        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let atom = Arc::clone(&atom);
                thread::spawn(move || {
                    for _ in 0..INCREMENTS_PER_THREAD {
                        let mut current = atom.load(Ordering::Acquire);
                        loop {
                            let cur_ptr = current.ptr.option();
                            let cur_tag = current.tag;
                            match atom.compare_exchange_weak(
                                TaggedPtr::new(cur_ptr, cur_tag),
                                TaggedPtr::new(cur_ptr, cur_tag.wrapping_add(1)),
                                Ordering::AcqRel,
                                Ordering::Acquire,
                            ) {
                                Ok(_) => break,
                                // Lost the race (or spurious failure): retry from the
                                // freshly observed value.
                                Err(actual) => current = actual,
                            }
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        let final_state = atom.load(Ordering::Acquire);
        assert_eq!(final_state.ptr.option(), ptr);
        assert_eq!(final_state.tag, Tag::new(THREADS * INCREMENTS_PER_THREAD));
    }

    #[test]
    fn test_into_option_non_null_api() {
        let val1 = 111;
        let raw_ptr1 = &val1 as *const i32;
        let mut_ptr1 = &val1 as *const i32 as *mut i32;
        let non_null1 = NonNull::new(mut_ptr1).unwrap();

        // 1. `new` with every supported pointer argument type.
        // NonNull<T>
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(non_null1, Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), Some(non_null1));

        // Option<NonNull<T>>
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(Some(non_null1), Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), Some(non_null1));

        // *const T
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(raw_ptr1, Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), Some(non_null1));

        // *mut T
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(mut_ptr1, Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), Some(non_null1));

        // Null *const T
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(core::ptr::null::<i32>(), Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), None);

        // Null *mut T
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(core::ptr::null_mut::<i32>(), Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), None);

        // None
        let atom: AtomicTaggedPtr<i32> = AtomicTaggedPtr::new(TaggedPtr::new(None, Tag::new(0)));
        assert_eq!(atom.load(Ordering::Relaxed).ptr.option(), None);

        // 2. `store`
        let atom = AtomicTaggedPtr::new(TaggedPtr::default());
        atom.store(TaggedPtr::new(raw_ptr1, Tag::new(10)), Ordering::Relaxed);
        let loaded = atom.load(Ordering::Relaxed);
        assert_eq!(loaded.ptr.option(), Some(non_null1));
        assert_eq!(loaded.tag, Tag::new(10));

        atom.store(TaggedPtr::new(None, Tag::new(20)), Ordering::Relaxed);
        let loaded = atom.load(Ordering::Relaxed);
        assert_eq!(loaded.ptr.option(), None);
        assert_eq!(loaded.tag, Tag::new(20));

        // 3. `compare_exchange` / `compare_exchange_weak` with mixed pointer argument types.
        let atom = AtomicTaggedPtr::new(TaggedPtr::new(raw_ptr1, Tag::new(0)));
        let res = atom.compare_exchange(
            TaggedPtr::new(raw_ptr1, Tag::new(0)),
            TaggedPtr::new(mut_ptr1, Tag::new(1)),
            Ordering::SeqCst,
            Ordering::SeqCst,
        );
        assert!(res.is_ok());
        let loaded = atom.load(Ordering::Relaxed);
        assert_eq!(loaded.ptr.option(), Some(non_null1));
        assert_eq!(loaded.tag, Tag::new(1));

        // `compare_exchange_weak` may fail spuriously, so retry, but only while the stored
        // value still matches the expected one; any other failure is a real mismatch and must
        // not turn into an endless loop.
        loop {
            match atom.compare_exchange_weak(
                TaggedPtr::new(mut_ptr1, Tag::new(1)),
                TaggedPtr::new(None, Tag::new(2)),
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => break,
                Err(actual) => assert!(
                    actual.ptr == mut_ptr1 && actual.tag == Tag::new(1),
                    "compare_exchange_weak failed against an unexpected value"
                ),
            }
        }
        let loaded = atom.load(Ordering::Relaxed);
        assert_eq!(loaded.ptr.option(), None);
        assert_eq!(loaded.tag, Tag::new(2));
    }

    #[test]
    fn test_ptr_conversions() {
        let val = 42;
        let raw = &val as *const i32;
        let mut_ptr = &val as *const i32 as *mut i32;
        let non_null = NonNull::new(mut_ptr).unwrap();

        let ptr_some = Ptr::new(Some(non_null));
        let ptr_none: Ptr<i32> = Ptr::new(None);

        // option() / as_option()
        assert_eq!(ptr_some.option(), Some(non_null));
        assert_eq!(ptr_none.option(), None);
        assert_eq!(ptr_some.as_option(), Some(non_null));

        // as_ptr()
        assert_eq!(ptr_some.as_ptr(), raw);
        assert_eq!(ptr_none.as_ptr(), core::ptr::null());

        // as_mut_ptr()
        assert_eq!(ptr_some.as_mut_ptr(), mut_ptr);
        assert_eq!(ptr_none.as_mut_ptr(), core::ptr::null_mut());

        // is_null() / is_some() / is_none()
        assert!(ptr_some.is_some());
        assert!(!ptr_some.is_null());
        assert!(!ptr_some.is_none());

        assert!(ptr_none.is_null());
        assert!(ptr_none.is_none());
        assert!(!ptr_none.is_some());

        // PartialEq against every supported pointer representation.
        assert!(ptr_some == Some(non_null));
        assert!(ptr_some == non_null);
        assert!(ptr_some == raw);
        assert!(ptr_some == mut_ptr);

        assert!(ptr_none == None);
        assert!(ptr_none == core::ptr::null::<i32>());
        assert!(ptr_none == core::ptr::null_mut::<i32>());
    }
}