// bitflagset/atomic_slice.rs
use core::marker::PhantomData;
use core::ops::BitAndAssign;
use core::sync::atomic::Ordering;

use num_traits::{One, PrimInt, Zero};
use radium::Radium;

use super::bitset::PrimBitSetIter;

9/// Unsized shared base for multi-word atomic bitset types.
10///
11/// `AtomicBitSlice<A, V>` is to atomic bitsets what [`BitSlice<T, V>`](super::BitSlice)
12/// is to non-atomic ones: a `#[repr(transparent)]` wrapper around `[A]` that provides
13/// common query and mutation methods. Owned types ([`AtomicBitSet<[A; N], V>`](super::AtomicBitSet)
14/// and [`AtomicBoxedBitSet<A, V>`](super::AtomicBoxedBitSet)) implement
15/// `Deref<Target = AtomicBitSlice<A, V>>`.
16///
17/// # Atomicity guarantees
18///
19/// Each individual method (`insert`, `remove`, `contains`, ...) performs atomic
20/// operations **per word**. However, the bitset as a whole is **not** a single
21/// atomic unit when it spans multiple words. Concurrent modifications to bits
22/// within the **same word** are correctly synchronized via `fetch_or` / `fetch_and`
23/// with `AcqRel` ordering. Modifications to bits in **different words** are
24/// independent atomic operations — there is no cross-word transactional guarantee.
25///
26/// Read-only methods (`len`, `iter`, `contains`, `is_subset`, …) load each word
27/// with `Relaxed` ordering and do not take a consistent snapshot of the entire
28/// bitset. If another thread modifies the set concurrently, these methods may
29/// observe a mix of old and new state across different words.
30#[repr(transparent)]
31pub struct AtomicBitSlice<A, V>(PhantomData<V>, [A]);
32
33impl<A, V> AtomicBitSlice<A, V> {
34    pub(crate) fn from_slice_ref(s: &[A]) -> &Self {
35        // SAFETY: AtomicBitSlice<A, V> is repr(transparent) over [A]
36        // (PhantomData<V> is ZST)
37        unsafe { &*(s as *const [A] as *const Self) }
38    }
39
40    #[inline]
41    pub fn as_raw_slice(&self) -> &[A] {
42        &self.1
43    }
44}
45
46impl<A, V> AtomicBitSlice<A, V>
47where
48    A: Radium,
49    A::Item: PrimInt,
50{
51    const BITS_PER: usize = core::mem::size_of::<A>() * 8;
52
53    #[inline]
54    fn index_of(idx: usize) -> (usize, A::Item) {
55        (
56            idx / Self::BITS_PER,
57            <A::Item as num_traits::One>::one().unsigned_shl((idx % Self::BITS_PER) as u32),
58        )
59    }
60
61    #[inline]
62    pub fn capacity(&self) -> usize {
63        self.1.len() * Self::BITS_PER
64    }
65
66    #[inline]
67    pub fn len(&self) -> usize {
68        self.1
69            .iter()
70            .map(|a| a.load(Ordering::Relaxed).count_ones() as usize)
71            .sum()
72    }
73
74    #[inline]
75    pub fn is_empty(&self) -> bool {
76        self.1.iter().all(|a| a.load(Ordering::Relaxed).is_zero())
77    }
78
79    #[inline]
80    pub fn first(&self) -> Option<V>
81    where
82        V: TryFrom<usize>,
83    {
84        for (i, a) in self.1.iter().enumerate() {
85            let word = a.load(Ordering::Relaxed);
86            if !word.is_zero() {
87                let bit = word.trailing_zeros() as usize;
88                return V::try_from(i * Self::BITS_PER + bit).ok();
89            }
90        }
91        None
92    }
93
94    #[inline]
95    pub fn last(&self) -> Option<V>
96    where
97        V: TryFrom<usize>,
98    {
99        for (i, a) in self.1.iter().enumerate().rev() {
100            let word = a.load(Ordering::Relaxed);
101            if !word.is_zero() {
102                let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
103                return V::try_from(i * Self::BITS_PER + bit).ok();
104            }
105        }
106        None
107    }
108
109    #[inline]
110    pub fn pop_first(&self) -> Option<V>
111    where
112        V: TryFrom<usize>,
113        A::Item: radium::marker::BitOps,
114    {
115        for (i, a) in self.1.iter().enumerate() {
116            loop {
117                let word = a.load(Ordering::Acquire);
118                if word.is_zero() {
119                    break;
120                }
121                let bit = word.trailing_zeros() as usize;
122                let mask = A::Item::one().unsigned_shl(bit as u32);
123                let old = a.fetch_and(!mask, Ordering::AcqRel);
124                if old & mask != A::Item::zero() {
125                    return V::try_from(i * Self::BITS_PER + bit).ok();
126                }
127            }
128        }
129        None
130    }
131
132    #[inline]
133    pub fn pop_last(&self) -> Option<V>
134    where
135        V: TryFrom<usize>,
136        A::Item: radium::marker::BitOps,
137    {
138        for (i, a) in self.1.iter().enumerate().rev() {
139            loop {
140                let word = a.load(Ordering::Acquire);
141                if word.is_zero() {
142                    break;
143                }
144                let bit = Self::BITS_PER - 1 - word.leading_zeros() as usize;
145                let mask = A::Item::one().unsigned_shl(bit as u32);
146                let old = a.fetch_and(!mask, Ordering::AcqRel);
147                if old & mask != A::Item::zero() {
148                    return V::try_from(i * Self::BITS_PER + bit).ok();
149                }
150            }
151        }
152        None
153    }
154
155    #[inline]
156    pub fn contains(&self, id: &V) -> bool
157    where
158        V: Copy + num_traits::AsPrimitive<usize>,
159    {
160        let idx = (*id).as_();
161        let (seg, mask) = Self::index_of(idx);
162        if seg >= self.1.len() {
163            return false;
164        }
165        // SAFETY: seg < self.1.len() checked above.
166        let a = unsafe { self.1.get_unchecked(seg) };
167        a.load(Ordering::Relaxed) & mask != A::Item::zero()
168    }
169
170    #[inline]
171    pub fn insert(&self, id: V) -> bool
172    where
173        V: num_traits::AsPrimitive<usize>,
174        A::Item: radium::marker::BitOps,
175    {
176        let idx = id.as_();
177        let (seg, mask) = Self::index_of(idx);
178        if seg >= self.1.len() {
179            return false;
180        }
181        // SAFETY: seg < self.1.len() checked above.
182        let a = unsafe { self.1.get_unchecked(seg) };
183        let old = a.fetch_or(mask, Ordering::AcqRel);
184        old & mask == A::Item::zero()
185    }
186
187    #[inline]
188    pub fn remove(&self, id: V) -> bool
189    where
190        V: num_traits::AsPrimitive<usize>,
191        A::Item: radium::marker::BitOps,
192    {
193        let idx = id.as_();
194        let (seg, mask) = Self::index_of(idx);
195        if seg >= self.1.len() {
196            return false;
197        }
198        // SAFETY: seg < self.1.len() checked above.
199        let a = unsafe { self.1.get_unchecked(seg) };
200        let old = a.fetch_and(!mask, Ordering::AcqRel);
201        old & mask != A::Item::zero()
202    }
203
204    #[inline]
205    pub fn set(&self, id: V, value: bool)
206    where
207        V: num_traits::AsPrimitive<usize>,
208        A::Item: radium::marker::BitOps,
209    {
210        if value {
211            self.insert(id);
212        } else {
213            self.remove(id);
214        }
215    }
216
217    #[inline]
218    pub fn toggle(&self, id: V)
219    where
220        V: num_traits::AsPrimitive<usize>,
221        A::Item: radium::marker::BitOps,
222    {
223        let idx = id.as_();
224        let (seg, mask) = Self::index_of(idx);
225        if seg >= self.1.len() {
226            return;
227        }
228        // SAFETY: seg < self.1.len() checked above.
229        let a = unsafe { self.1.get_unchecked(seg) };
230        a.fetch_xor(mask, Ordering::AcqRel);
231    }
232
233    #[inline]
234    pub fn clear(&self) {
235        for atomic in self.1.iter() {
236            atomic.store(A::Item::zero(), Ordering::Release);
237        }
238    }
239
240    pub fn retain(&self, mut f: impl FnMut(V) -> bool)
241    where
242        V: TryFrom<usize>,
243        A::Item: radium::marker::BitOps,
244    {
245        for (i, a) in self.1.iter().enumerate() {
246            let word = a.load(Ordering::Relaxed);
247            let mut w = word;
248            while !w.is_zero() {
249                let bit = w.trailing_zeros() as usize;
250                let mask = A::Item::one().unsigned_shl(bit as u32);
251                w = w & !mask;
252                let idx = i * Self::BITS_PER + bit;
253                debug_assert!(V::try_from(idx).is_ok());
254                let value = match V::try_from(idx) {
255                    Ok(v) => v,
256                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
257                };
258                if !f(value) {
259                    a.fetch_and(!mask, Ordering::AcqRel);
260                }
261            }
262        }
263    }
264
265    #[inline]
266    pub fn iter(&self) -> impl Iterator<Item = V> + '_
267    where
268        A::Item: BitAndAssign,
269        V: TryFrom<usize>,
270    {
271        self.1.iter().enumerate().flat_map(move |(i, a)| {
272            let bits = a.load(Ordering::Relaxed);
273            let offset = i * Self::BITS_PER;
274            PrimBitSetIter::<A::Item, usize>(bits, PhantomData).map(move |pos| {
275                let idx = offset + pos;
276                debug_assert!(V::try_from(idx).is_ok());
277                match V::try_from(idx) {
278                    Ok(v) => v,
279                    Err(_) => unsafe { core::hint::unreachable_unchecked() },
280                }
281            })
282        })
283    }
284
285    #[inline]
286    pub fn is_subset(&self, other: &Self) -> bool {
287        let min = self.1.len().min(other.1.len());
288        self.1[..min]
289            .iter()
290            .zip(other.1[..min].iter())
291            .all(|(a, b)| {
292                let va = a.load(Ordering::Relaxed);
293                let vb = b.load(Ordering::Relaxed);
294                va & vb == va
295            })
296            && self.1[min..]
297                .iter()
298                .all(|a| a.load(Ordering::Relaxed).is_zero())
299    }
300
301    #[inline]
302    pub fn is_superset(&self, other: &Self) -> bool {
303        other.is_subset(self)
304    }
305
306    #[inline]
307    pub fn is_disjoint(&self, other: &Self) -> bool {
308        self.1.iter().zip(other.1.iter()).all(|(a, b)| {
309            let va = a.load(Ordering::Relaxed);
310            let vb = b.load(Ordering::Relaxed);
311            (va & vb).is_zero()
312        })
313    }
314
315    #[inline]
316    pub fn union_from(&self, other: &[A::Item])
317    where
318        A::Item: radium::marker::BitOps + Copy,
319    {
320        for (atomic, &value) in self.1.iter().zip(other.iter()) {
321            atomic.fetch_or(value, Ordering::AcqRel);
322        }
323    }
324}