ref_swap/
lib.rs

1// Copyright (C) 2022 Nitrokey GmbH
2// SPDX-License-Identifier: Apache-2.0 or MIT
3
4#![cfg_attr(not(test), no_std)]
5#![doc = include_str!("../README.md")]
6
7use core::{
8    marker::PhantomData,
9    sync::atomic::{AtomicPtr, Ordering},
10};
11
/// Coerces `ordering` into an ordering valid for a load with at least `Acquire` semantics.
///
/// Relaxed operations can lead to race conditions: thread A can pass a reference with relaxed
/// ordering to thread B, which means that no guarantees are made that the data seen by thread B
/// after dereferencing the reference will include the mutations that thread A made before
/// publishing it. Upgrading loads to (at least) `Acquire` pairs with stores upgraded to (at least)
/// `Release` so that those mutations are visible.
///
/// # Panics
///
/// Panics if `ordering` is `Release` or `AcqRel`: neither is a valid ordering for a load (the
/// standard library atomics reject them as well). Panics on orderings unknown to this crate.
fn load(ordering: Ordering) -> Ordering {
    match ordering {
        Ordering::Relaxed | Ordering::Acquire => Ordering::Acquire,
        Ordering::SeqCst => Ordering::SeqCst,
        Ordering::Release => panic!("Release ordering cannot be used for loads"),
        // `AtomicPtr::load` and the failure ordering of `compare_exchange` panic on `AcqRel`;
        // reject it here with an explicit message instead of passing it through.
        Ordering::AcqRel => panic!("AcqRel ordering cannot be used for loads"),
        _ => unimplemented!("{ordering:?} is not supported"),
    }
}
25
/// Coerces `ordering` into an ordering valid for a store with at least `Release` semantics.
///
/// A store upgraded to (at least) `Release` pairs with loads upgraded to (at least) `Acquire`, so
/// that data written before publishing a reference is visible to the thread that loads it.
///
/// # Panics
///
/// Panics if `ordering` is `Acquire` or `AcqRel`: neither is a valid ordering for a store (the
/// standard library atomics reject them as well). Panics on orderings unknown to this crate.
fn store(ordering: Ordering) -> Ordering {
    match ordering {
        Ordering::Relaxed | Ordering::Release => Ordering::Release,
        Ordering::SeqCst => Ordering::SeqCst,
        Ordering::Acquire => panic!("Acquire ordering cannot be used for stores"),
        // `AtomicPtr::store` panics on `AcqRel`; reject it here with an explicit message instead
        // of passing it through.
        Ordering::AcqRel => panic!("AcqRel ordering cannot be used for stores"),
        _ => unimplemented!("{ordering:?} is not supported"),
    }
}
36
/// Coerces `ordering` for a read-modify-write operation.
///
/// `SeqCst` is kept as is; every weaker known ordering is raised to `AcqRel`, so the operation
/// both acquires what it reads and releases what it writes.
///
/// # Panics
///
/// Panics on orderings unknown to this crate.
fn load_store(ordering: Ordering) -> Ordering {
    match ordering {
        Ordering::SeqCst => Ordering::SeqCst,
        Ordering::Relaxed | Ordering::Acquire | Ordering::Release | Ordering::AcqRel => {
            Ordering::AcqRel
        }
        other => unimplemented!("{other:?} is not supported"),
    }
}
47
// SAFETY:
//
// # Lifetimes
// After being loaded, a reference is guaranteed to be alive for 'a, because a reference living for 'a was required when it was stored.
//
// # Thread Safety
// The `load`, `store` and `load_store` functions are used to coerce orderings to have at least Release-Acquire semantics.
// This ensures that any data written *before* the call to a `store` is synchronized with the thread observing the reference through a `load` operation.
56
/// A reference that can atomically be changed using another reference with the same lifetime and type.
pub struct RefSwap<'a, T> {
    // Invariant: always holds a pointer obtained from a `&'a T` (every constructor and setter
    // takes `&'a T`), so it is non-null and valid for `'a`.
    ptr: AtomicPtr<T>,
    // `&'a T`: hands out `&'a T` to any thread that can reach the swap, so `RefSwap` must only be
    // `Send`/`Sync` when `T: Sync` (`AtomicPtr<T>` alone is unconditionally `Send + Sync`).
    // `fn(&'a T) -> &'a T`: makes the type *invariant* in `'a`. Since `store` takes `&self`,
    // covariance would let a `&RefSwap<'long, T>` be used as a `&RefSwap<'short, T>`, store a
    // short-lived reference through it, and later load it back as a dangling `&'long T`
    // (the same reason `Cell<&'a T>` is invariant).
    phantom: PhantomData<(&'a T, fn(&'a T) -> &'a T)>,
}
62
63impl<'a, T> RefSwap<'a, T> {
64    pub const fn new(data: &'a T) -> Self {
65        Self {
66            ptr: AtomicPtr::new(data as *const _ as *mut _),
67            phantom: PhantomData,
68        }
69    }
70
71    /// Stores a reference if the current value is the same as the current value.
72    ///
73    /// Be aware that the comparison is only between the reference, not between the value.
74    /// If current point to another adress in memory than the reference currently holds, it will fail,
75    /// Even if both are equal according to a `PartialEq` implementation.
76    ///
77    /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_and_swap`](core::sync::atomic::AtomicPtr::compare_and_swap)
78    #[deprecated(note = "Use `compare_exchange` or `compare_exchange_weak` instead")]
79    pub fn compare_and_swap(&self, current: &'a T, new: &'a T, order: Ordering) -> &'a T {
80        #[allow(deprecated)]
81        let ptr = self.ptr.compare_and_swap(
82            current as *const _ as *mut _,
83            new as *const _ as *mut _,
84            load_store(order),
85        );
86
87        unsafe { &*ptr }
88    }
89
90    /// Stores a reference if the current value is the same as the current value.
91    ///
92    /// Be aware that the comparison is only between the reference, not between the value.
93    /// If current point to another adress in memory than the reference currently holds, it will fail,
94    /// Even if both are equal according to a `PartialEq` implementation.
95    ///
96    /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange`](core::sync::atomic::AtomicPtr::compare_exchange)
97    pub fn compare_exchange(
98        &self,
99        current: &'a T,
100        new: &'a T,
101        success: Ordering,
102        failure: Ordering,
103    ) -> Result<&'a T, &'a T> {
104        let res = self.ptr.compare_exchange(
105            current as *const _ as *mut _,
106            new as *const _ as *mut _,
107            load_store(success),
108            load(failure),
109        );
110
111        res.map(|ptr| unsafe { &*ptr })
112            .map_err(|ptr| unsafe { &*ptr })
113    }
114
115    /// Stores a reference if the current value is the same as the current value.
116    ///
117    /// Be aware that the comparison is only between the reference, not between the value.
118    /// If current point to another adress in memory than the reference currently holds, it will fail,
119    /// Even if both are equal according to a `PartialEq` implementation.
120    ///
121    /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange_weak`](core::sync::atomic::AtomicPtr::compare_exchange_weak)
122    pub fn compare_exchange_weak(
123        &self,
124        current: &'a T,
125        new: &'a T,
126        success: Ordering,
127        failure: Ordering,
128    ) -> Result<&'a T, &'a T> {
129        let res = self.ptr.compare_exchange_weak(
130            current as *const _ as *mut _,
131            new as *const _ as *mut _,
132            load_store(success),
133            load(failure),
134        );
135
136        res.map(|ptr| unsafe { &*ptr })
137            .map_err(|ptr| unsafe { &*ptr })
138    }
139
140    /// Get a mutable reference to the current stored reference.
141    ///
142    /// This is safe because the mutable reference guarantees that no other threads are concurrently accessing the atomic data.
143    pub fn get_mut<'s>(&'s mut self) -> &'s mut &'a T {
144        let res: &'s mut *mut T = self.ptr.get_mut();
145        unsafe { &mut *(res as *mut *mut T as *mut &'a T) }
146    }
147
148    /// Consumes the atomic and returns the contained value.
149    ///
150    /// This is safe because passing `self` by value guarantees that no other threads are concurrently accessing the atomic data.
151    pub fn into_inner(self) -> &'a T {
152        let res = self.ptr.into_inner();
153        unsafe { &*res }
154    }
155
156    /// Fetches the value, and applies a function to it that returns an optional new value. `Returns` a `Result` of `Ok(previous_value)` if the function returned `Some(_)`, else `Err(previous_value)`.
157    ///
158    /// For more information on the orderings, se the documentation of [`AtomicPtr::fetch_update`](core::sync::atomic::AtomicPtr::fetch_update)
159    pub fn fetch_update<F: FnMut(&'a T) -> Option<&'a T>>(
160        &self,
161        set_order: Ordering,
162        fetch_order: Ordering,
163        mut f: F,
164    ) -> Result<&'a T, &'a T> {
165        self.ptr
166            .fetch_update(load_store(set_order), load(fetch_order), |ptr| {
167                f(unsafe { &*ptr }).map(|r| r as *const _ as *mut _)
168            })
169            .map(|ptr| unsafe { &*ptr })
170            .map_err(|ptr| unsafe { &*ptr })
171    }
172
173    /// Loads a value
174    pub fn load(&self, order: Ordering) -> &'a T {
175        let res = self.ptr.load(load(order));
176
177        unsafe { &*res }
178    }
179
180    /// Store a value
181    pub fn store(&self, ptr: &'a T, order: Ordering) {
182        self.ptr.store(ptr as *const _ as *mut _, store(order));
183    }
184
185    /// Stores a value into the pointer, returning the previous value.
186    pub fn swap(&self, ptr: &'a T, order: Ordering) -> &'a T {
187        let res = self.ptr.swap(ptr as *const _ as *mut _, load_store(order));
188
189        unsafe { &*res }
190    }
191}
192
/// An optional reference that can atomically be changed to another optional reference with the same lifetime and type.
pub struct OptionRefSwap<'a, T> {
    // Invariant: holds either the null pointer (`None`) or a pointer obtained from a `&'a T`.
    ptr: AtomicPtr<T>,
    // `&'a T`: hands out `&'a T` to any thread that can reach the swap, so `OptionRefSwap` must
    // only be `Send`/`Sync` when `T: Sync` (`AtomicPtr<T>` alone is unconditionally `Send + Sync`).
    // `fn(&'a T) -> &'a T`: makes the type *invariant* in `'a`. Since `store` takes `&self`,
    // covariance would let a `&OptionRefSwap<'long, T>` be used as a `&OptionRefSwap<'short, T>`,
    // store a short-lived reference through it, and later load it back as a dangling `&'long T`
    // (the same reason `Cell<Option<&'a T>>` is invariant).
    phantom: PhantomData<(&'a T, fn(&'a T) -> &'a T)>,
}
198
/// Converts an optional reference into a raw pointer.
///
/// `Some(r)` becomes the address of `r`; `None` becomes the null pointer.
const fn opt_to_ptr<T>(ptr: Option<&T>) -> *mut T {
    if let Some(reference) = ptr {
        reference as *const T as *mut T
    } else {
        core::ptr::null_mut()
    }
}
206
/// Converts a raw pointer produced by `opt_to_ptr` back into an optional reference.
///
/// The null pointer maps to `None`; any other pointer is turned into `Some(&'a T)`.
///
/// # Safety
///
/// `ptr` must come from `opt_to_ptr` with the corresponding lifetime: a non-null `ptr` is
/// dereferenced and must point to a `T` that is valid for `'a`.
unsafe fn ptr_to_opt<'a, T>(ptr: *mut T) -> Option<&'a T> {
    // SAFETY: the caller guarantees that a non-null `ptr` originates from a `&'a T`;
    // `as_ref` returns `None` for the null pointer without dereferencing it.
    unsafe { ptr.as_ref() }
}
216
217impl<'a, T> OptionRefSwap<'a, T> {
218    pub const fn new(data: Option<&'a T>) -> Self {
219        Self {
220            ptr: AtomicPtr::new(opt_to_ptr(data)),
221            phantom: PhantomData,
222        }
223    }
224
225    /// Stores a reference if the current value is the same as the current value.
226    ///
227    /// Be aware that the comparison is only between the reference, not between the value.
228    /// If current point to another adress in memory than the reference currently holds, it will fail,
229    /// Even if both are equal according to a `PartialEq` implementation.
230    ///
231    /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_and_swap`](core::sync::atomic::AtomicPtr::compare_and_swap)
232    #[deprecated(note = "Use `compare_exchange` or `compare_exchange_weak` instead")]
233    pub fn compare_and_swap(
234        &self,
235        current: Option<&'a T>,
236        new: Option<&'a T>,
237        order: Ordering,
238    ) -> Option<&'a T> {
239        #[allow(deprecated)]
240        let ptr =
241            self.ptr
242                .compare_and_swap(opt_to_ptr(current), opt_to_ptr(new), load_store(order));
243
244        unsafe { ptr_to_opt(ptr) }
245    }
246
247    /// Stores a reference if the current value is the same as the current value.
248    ///
249    /// Be aware that the comparison is only between the reference, not between the value.
250    /// If current point to another adress in memory than the reference currently holds, it will fail,
251    /// Even if both are equal according to a `PartialEq` implementation.
252    ///
253    /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange`](core::sync::atomic::AtomicPtr::compare_exchange)
254    pub fn compare_exchange(
255        &self,
256        current: Option<&'a T>,
257        new: Option<&'a T>,
258        success: Ordering,
259        failure: Ordering,
260    ) -> Result<Option<&'a T>, Option<&'a T>> {
261        let res = self.ptr.compare_exchange(
262            opt_to_ptr(current),
263            opt_to_ptr(new),
264            load_store(success),
265            load(failure),
266        );
267
268        res.map(|ptr| unsafe { ptr_to_opt(ptr) })
269            .map_err(|ptr| unsafe { ptr_to_opt(ptr) })
270    }
271
272    /// Stores a reference if the current value is the same as the current value.
273    ///
274    /// Be aware that the comparison is only between the reference, not between the value.
275    /// If current point to another adress in memory than the reference currently holds, it will fail,
276    /// Even if both are equal according to a `PartialEq` implementation.
277    ///
278    /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange_weak`](core::sync::atomic::AtomicPtr::compare_exchange_weak)
279    pub fn compare_exchange_weak(
280        &self,
281        current: Option<&'a T>,
282        new: Option<&'a T>,
283        success: Ordering,
284        failure: Ordering,
285    ) -> Result<Option<&'a T>, Option<&'a T>> {
286        let res = self.ptr.compare_exchange_weak(
287            opt_to_ptr(current),
288            opt_to_ptr(new),
289            load_store(success),
290            load(failure),
291        );
292
293        res.map(|ptr| unsafe { ptr_to_opt(ptr) })
294            .map_err(|ptr| unsafe { ptr_to_opt(ptr) })
295    }
296
297    /// Get a mutable reference to the current stored reference.
298    ///
299    /// This is safe because the mutable reference guarantees that no other threads are concurrently accessing the atomic data.
300    #[allow(unused)]
301    fn get_mut<'s>(&'s mut self) -> &'s mut Option<&'a T> {
302        let res: &'s mut *mut T = self.ptr.get_mut();
303
304        // TODO: Is this transmute really safe? Making this function private until I'm sure
305
306        unsafe { core::mem::transmute(res) }
307    }
308
309    /// Consumes the atomic and returns the contained value.
310    ///
311    /// This is safe because passing `self` by value guarantees that no other threads are concurrently accessing the atomic data.
312    pub fn into_inner(self) -> &'a T {
313        let res = self.ptr.into_inner();
314        unsafe { &*res }
315    }
316
317    /// Fetches the value, and applies a function to it that returns an optional new value. `Returns` a `Result` of `Ok(previous_value)` if the function returned `Some(_)`, else `Err(previous_value)`.
318    ///
319    /// For more information on the orderings, se the documentation of [`AtomicPtr::fetch_update`](core::sync::atomic::AtomicPtr::fetch_update)
320    pub fn fetch_update<F: FnMut(Option<&'a T>) -> Option<Option<&'a T>>>(
321        &self,
322        set_order: Ordering,
323        fetch_order: Ordering,
324        mut f: F,
325    ) -> Result<Option<&'a T>, Option<&'a T>> {
326        self.ptr
327            .fetch_update(load_store(set_order), load(fetch_order), |ptr| {
328                f(unsafe { ptr_to_opt(ptr) }).map(opt_to_ptr)
329            })
330            .map(|ptr| unsafe { ptr_to_opt(ptr) })
331            .map_err(|ptr| unsafe { ptr_to_opt(ptr) })
332    }
333
334    /// Loads a value
335    pub fn load(&self, order: Ordering) -> Option<&'a T> {
336        let res = self.ptr.load(load(order));
337
338        unsafe { ptr_to_opt(res) }
339    }
340
341    /// Stores a value
342    pub fn store(&self, ptr: Option<&'a T>, order: Ordering) {
343        self.ptr.store(opt_to_ptr(ptr), store(order));
344    }
345
346    /// Stores a value into the pointer, returning the previous value.
347    pub fn swap(&self, ptr: Option<&'a T>, order: Ordering) -> Option<&'a T> {
348        let res = self.ptr.swap(opt_to_ptr(ptr), load_store(order));
349
350        unsafe { ptr_to_opt(res) }
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: a reference that outlives the swap's lifetime (here `'static`) can be
    /// stored into it.
    #[allow(unused)] // never called: it only needs to type-check
    fn variance<'a, 'b>(a: &'a u32, b: Option<&'b u32>) {
        let r = RefSwap::new(a);
        let stat: &'static u32 = &123;
        r.store(stat, Ordering::Relaxed);

        let r = OptionRefSwap::new(b);
        let stat: Option<&'static u32> = Some(&123);
        r.store(stat, Ordering::Relaxed);
    }

    #[test]
    fn ref_swap_round_trip() {
        let (x, y) = (1u32, 2u32);
        let r = RefSwap::new(&x);
        assert_eq!(*r.load(Ordering::Relaxed), 1);
        assert_eq!(*r.swap(&y, Ordering::Relaxed), 1);
        assert_eq!(*r.load(Ordering::SeqCst), 2);
        // Comparison is by address: `&x` is not the currently held `&y`.
        assert!(r
            .compare_exchange(&x, &x, Ordering::Relaxed, Ordering::Relaxed)
            .is_err());
        assert_eq!(
            *r.compare_exchange(&y, &x, Ordering::Relaxed, Ordering::Relaxed)
                .unwrap(),
            2
        );
        assert_eq!(*r.into_inner(), 1);
    }

    #[test]
    fn option_ref_swap_round_trip() {
        let x = 1u32;
        let r = OptionRefSwap::new(None);
        assert_eq!(r.load(Ordering::Relaxed), None);
        assert_eq!(r.swap(Some(&x), Ordering::Relaxed), None);
        assert_eq!(r.load(Ordering::Acquire), Some(&1));
        assert_eq!(
            r.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |_| Some(None)),
            Ok(Some(&1))
        );
        assert_eq!(r.load(Ordering::SeqCst), None);
    }
}