1use atomic::Atomic;
27use num::traits::*;
28use std::{marker::PhantomData, sync::atomic::Ordering};
29
/// Zero-sized descriptor of a bit field holding a `T` inside a storage word `S`.
///
/// The field occupies `SIZE` bits starting at bit `POSITION` (bit 0 is the least
/// significant bit of the word). When `SIGN_EXTEND` is `true`, the field's top bit is
/// treated as a two's-complement sign bit on decode. The descriptor itself carries no
/// data; all behaviour lives in its [`BitFieldTrait`] implementation.
pub struct BitField<S, T, const POSITION: usize, const SIZE: usize, const SIGN_EXTEND: bool>(
    std::marker::PhantomData<(S, T)>,
);
55
56impl<S, T, const POSITION: usize, const SIZE: usize, const SIGN_EXTEND: bool> BitFieldTrait<S>
57 for BitField<S, T, POSITION, SIZE, SIGN_EXTEND>
58where
59 S: ToPrimitive + FromPrimitive + One + PrimInt,
60 T: FromBitfield<S> + ToBitfield<S> + PartialEq + Eq + Copy,
61{
62 type Type = T;
63
64 const NEXT_BIT: usize = POSITION + SIZE;
65
66 fn mask() -> S {
67 S::from_usize((1 << SIZE) - 1).unwrap()
68 }
69
70 fn mask_in_place() -> S {
71 S::from_usize(((1 << SIZE) - 1) << POSITION).unwrap()
72 }
73
74 fn shift() -> usize {
75 POSITION
76 }
77
78 fn bitsize() -> usize {
79 SIZE
80 }
81
82 fn encode(value: T) -> S {
83 assert!(Self::is_valid(value));
84 Self::encode_unchecked(value)
85 }
86
87 fn decode(value: S) -> T {
88 if SIGN_EXTEND {
89 let u = value.to_u64().unwrap();
90
91 let res = ((u << (64 - Self::NEXT_BIT)) as i64) >> (64 - SIZE);
92
93 T::from_i64(res as _)
94 } else {
95 let u = value;
96
97 T::from_bitfield(u.unsigned_shr(POSITION as _) & Self::mask())
98 }
99 }
100
101 fn update(value: T, original: S) -> S {
102 Self::encode(value) | (!Self::mask_in_place() & original)
103 }
104
105 fn is_valid(value: T) -> bool {
106 Self::decode(Self::encode_unchecked(value)) == value
107 }
108
109 fn encode_unchecked(value: T) -> S {
110 let u = value.to_bitfield();
111
112 (u.bitand(Self::mask())).unsigned_shl(POSITION as _)
113 }
114}
115
/// Converts a field value into raw storage bits of type `S`.
///
/// The result is the value's bit pattern in `S` before any masking or shifting;
/// `BitField::encode_unchecked` applies the field mask and position afterwards.
pub trait ToBitfield<S>
where
    Self: Sized,
{
    /// The additive identity of the value type (`false` for `bool`).
    fn zero() -> Self;

    /// The unit value of the value type (`true` for `bool`).
    fn one() -> Self;

    /// Raw bits of `self` expressed in storage type `S`.
    fn to_bitfield(self) -> S;
}
121
/// Reconstructs a field value from the bits extracted out of a storage word.
pub trait FromBitfield<S>
where
    Self: Sized,
{
    /// Builds the value from an already sign-extended integer; used by the
    /// sign-extending decode path of `BitField`.
    fn from_i64(value: i64) -> Self;

    /// Builds the value from the field bits, already shifted down and masked.
    fn from_bitfield(value: S) -> Self;
}
126
127impl<S: One + Zero> ToBitfield<S> for bool {
128 fn to_bitfield(self) -> S {
129 if self {
130 S::one()
131 } else {
132 S::zero()
133 }
134 }
135
136 fn one() -> Self {
137 true
138 }
139
140 fn zero() -> Self {
141 false
142 }
143}
144
145impl<S: One + Zero + PartialEq + Eq> FromBitfield<S> for bool {
146 fn from_bitfield(value: S) -> Self {
147 value != S::zero()
148 }
149
150 fn from_i64(value: i64) -> Self {
151 value != 0
152 }
153}
154
155macro_rules! impl_tofrom_bitfield {
156 ($($t:ty)*) => {
157 $(
158 impl<S: NumCast + One + Zero + ToPrimitive + FromPrimitive> ToBitfield<S> for $t {
159 fn to_bitfield(self) -> S {
160 <S as NumCast>::from(self).unwrap()
161 }
162
163 fn one() -> Self {
164 1
165 }
166
167 fn zero() -> Self {
168 0
169 }
170 }
171
172 impl<S: One + Zero + ToPrimitive + FromPrimitive> FromBitfield<S> for $t {
173 fn from_bitfield(value: S) -> Self {
174 <$t as NumCast>::from(value).unwrap()
175 }
176
177 fn from_i64(value: i64) -> Self {
178 value as _
179 }
180 }
181 )*
182 };
183}
184
185impl_tofrom_bitfield!(u8 u16 u32 u64 usize);
186macro_rules! impl_tofrom_bitfield_signed {
187 ($($t:ty => $unsigned:ty)*) => {
188 $(
189 impl<S: NumCast + One + Zero + ToPrimitive + FromPrimitive> ToBitfield<S> for $t {
190 fn to_bitfield(self) -> S {
191 <S as NumCast>::from(self as $unsigned).unwrap()
192 }
193
194 fn one() -> Self {
195 1
196 }
197
198 fn zero() -> Self {
199 0
200 }
201 }
202
203 impl<S: One + Zero + ToPrimitive + FromPrimitive> FromBitfield<S> for $t {
204 fn from_bitfield(value: S) -> Self {
205 <$t as NumCast>::from(value).unwrap()
206 }
207
208 fn from_i64(value: i64) -> Self {
209 value as _
210 }
211 }
212 )*
213 };
214}
215
216impl_tofrom_bitfield_signed!(i8 => u8 i16 => u16 i32 => u32 i64 => u64 isize => usize);
217
218pub trait BitFieldTrait<S>
222where
223 S: PrimInt + ToPrimitive + FromPrimitive + One,
224{
225 type Type: Copy + FromBitfield<S> + ToBitfield<S> + PartialEq + Eq;
227
228 const NEXT_BIT: usize;
231 fn mask() -> S;
233 fn mask_in_place() -> S;
235 fn shift() -> usize;
237 fn bitsize() -> usize;
239 fn encode(value: Self::Type) -> S;
245 fn decode(storage: S) -> Self::Type;
247 fn update(value: Self::Type, original: S) -> S;
249 fn is_valid(value: Self::Type) -> bool;
251 fn encode_unchecked(value: Self::Type) -> S;
253}
254
255pub struct AtomicBitfieldContainer<T: bytemuck::NoUninit>(Atomic<T>);
257
258impl<T: bytemuck::NoUninit + FromPrimitive + ToPrimitive + PrimInt> AtomicBitfieldContainer<T> {
259 pub fn new(value: T) -> Self {
261 Self(Atomic::new(value))
262 }
263
264 pub fn load_ignore_race(&self) -> T {
266 unsafe {
267 let ptr = &self.0 as *const Atomic<T> as *const T;
268 ptr.read()
269 }
270 }
271
272 pub fn load(&self, order: Ordering) -> T {
274 self.0.load(order)
275 }
276
277 pub fn store(&self, value: T, order: Ordering) {
279 self.0.store(value, order)
280 }
281
282 pub fn compare_exchange_weak(
285 &self,
286 current: T,
287 new: T,
288 success: Ordering,
289 failure: Ordering,
290 ) -> Result<T, T> {
291 self.0.compare_exchange_weak(current, new, success, failure)
292 }
293
294 pub fn read<B: BitFieldTrait<T>>(&self) -> B::Type {
296 B::decode(self.load(Ordering::Relaxed))
297 }
298
299 pub fn update<B: BitFieldTrait<T>>(&self, value: B::Type) {
301 let mut old_field = self.0.load(Ordering::Relaxed);
302 let mut new_field;
303 loop {
304 new_field = B::update(value, old_field);
305 match self.0.compare_exchange_weak(
306 old_field,
307 new_field,
308 Ordering::Relaxed,
309 Ordering::Relaxed,
310 ) {
311 Ok(_) => {
312 break;
313 }
314 Err(x) => {
315 old_field = x;
316 }
317 }
318 }
319 }
320
321 pub fn update_synchronized<B: BitFieldTrait<T>>(&self, value: B::Type) {
323 self.0.store(
324 B::update(value, self.0.load(Ordering::Relaxed)),
325 Ordering::Relaxed,
326 );
327 }
328
329 pub fn update_conditional<B: BitFieldTrait<T>>(
333 &self,
334 value_to_be_set: B::Type,
335 conditional_old_value: B::Type,
336 ) -> B::Type {
337 let mut old_field = self.0.load(Ordering::Relaxed);
338
339 loop {
340 let old_value = B::decode(old_field);
341 if old_value != conditional_old_value {
342 return old_value;
343 }
344
345 let new_tags = B::update(value_to_be_set, old_field);
346
347 match self.0.compare_exchange_weak(
348 old_field,
349 new_tags,
350 Ordering::Relaxed,
351 Ordering::Relaxed,
352 ) {
353 Ok(_) => {
354 return conditional_old_value;
355 }
356 Err(x) => {
357 old_field = x;
358 }
359 }
360
361 }
363 }
364}
365
366macro_rules! atomic_ops_common {
367 ($($t: ty)*) => {
368 $(impl AtomicBitfieldContainer<$t> {
369 pub fn fetch_or<B: BitFieldTrait<$t>>(&self, value: B::Type) {
371 self.0.fetch_or(B::encode(value), Ordering::Relaxed);
372 }
373 pub fn fetch_and<B: BitFieldTrait<$t>>(&self, value: B::Type) {
375 self.0.fetch_and(B::encode(value), Ordering::Relaxed);
376 }
377 pub fn fetch_xor<B: BitFieldTrait<$t>>(&self, value: B::Type) {
379 self.0.fetch_xor(B::encode(value), Ordering::Relaxed);
380 }
381
382 pub fn try_acquire<B: BitFieldTrait<$t>>(&self) -> bool {
384 let mask = B::encode(B::Type::one());
385 let old_field = self.0.fetch_or(mask, Ordering::Relaxed);
386 B::decode(old_field) == B::Type::zero()
387 }
388
389 pub fn try_release<B: BitFieldTrait<$t>>(&self) -> bool {
391 let mask = !B::encode(B::Type::one());
392 let old_field = self.0.fetch_and(mask, Ordering::Relaxed);
393 B::decode(old_field) != B::Type::zero()
394 }
395
396 pub fn update_bool<B: BitFieldTrait<$t>>(&self, value: bool, order: Ordering)
399 where
400 B::Type: From<bool>,
401 {
402 if value {
403 self.0.fetch_or(B::encode(B::Type::from(true)), order);
404 } else {
405 self.0.fetch_and(!B::encode(B::Type::from(false)), order);
406 }
407 }
408 })*
409 };
410}
411
412atomic_ops_common!(u8 u16 u32 u64 i8 i16 i32 i64 usize isize);
413
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    // Lock flag in bit 0 guarding a 32-bit payload directly above it (bits 1..=32).
    // NOTE(review): DataBits needs `usize` to be at least 33 bits wide, so these
    // tests assume a 64-bit target.
    type LockBit = BitField<usize, bool, 0, 1, false>;
    type DataBits = BitField<usize, u32, { LockBit::NEXT_BIT }, 32, false>;

    /// The spawned thread writes the payload while holding the lock bit. The main
    /// thread repeatedly takes the lock until it observes that payload.
    #[test]
    fn test_acquire_release() {
        let container = Arc::new(AtomicBitfieldContainer::new(0usize));
        let writer = {
            let container = Arc::clone(&container);

            std::thread::spawn(move || {
                while !container.try_acquire::<LockBit>() {
                    std::hint::spin_loop();
                }
                container.update::<DataBits>(42);
                assert!(container.try_release::<LockBit>());
            })
        };

        loop {
            while !container.try_acquire::<LockBit>() {
                std::hint::spin_loop();
            }
            if container.read::<DataBits>() != 0 {
                assert_eq!(container.read::<DataBits>(), 42);
                assert!(container.try_release::<LockBit>());
                break;
            }

            assert!(container.try_release::<LockBit>());
        }

        writer.join().unwrap();
    }

    // Three fields packed into one u32:
    // - A: 1-bit flag at bit 0;
    // - B: 4-bit sign-extended field at bits 1..=4;
    // - C: 16-bit sign-extended field at bits 5..=20.
    type A = BitField<u32, bool, 0, 1, false>;
    type B = BitField<u32, i8, { A::NEXT_BIT }, 4, true>;
    type C = BitField<u32, i16, { B::NEXT_BIT }, 16, true>;

    #[test]
    fn test_encode_decode() {
        let mut storage;

        storage = A::encode(true);
        assert!(A::decode(storage));
        storage = B::update(-1, storage);
        assert_eq!(B::decode(storage), -1);
        storage = B::update(2, storage);
        assert_eq!(B::decode(storage), 2);
        // Updating B must not disturb its neighbours.
        assert!(A::decode(storage));
        assert_eq!(C::decode(storage), 0);
        // A 4-bit two's-complement field holds exactly -8..=7.
        assert!(B::is_valid(7));
        assert!(!B::is_valid(8));
        assert!(B::is_valid(-8));
        assert!(!B::is_valid(-9));
    }
}