ref_swap/lib.rs
1// Copyright (C) 2022 Nitrokey GmbH
2// SPDX-License-Identifier: Apache-2.0 or MIT
3
4#![cfg_attr(not(test), no_std)]
5#![doc = include_str!("../README.md")]
6
7use core::{
8 marker::PhantomData,
9 sync::atomic::{AtomicPtr, Ordering},
10};
11
/// Coerces the ordering of a load so that it has at least `Acquire` semantics.
///
/// Relaxed operations can lead to race conditions: thread A can pass a reference with relaxed
/// ordering to thread B, in which case nothing guarantees that the data seen by thread B after
/// dereferencing the reference includes the writes thread A made before passing it. Promoting
/// `Relaxed` to `Acquire` (paired with at least `Release` on the storing side) rules this out.
///
/// # Panics
///
/// Panics for `Release` and `AcqRel`, which are not valid orderings for a load, and for any
/// other (future) [`Ordering`] variant.
fn load(ordering: Ordering) -> Ordering {
    match ordering {
        Ordering::Relaxed | Ordering::Acquire => Ordering::Acquire,
        Ordering::SeqCst => Ordering::SeqCst,
        // `AtomicPtr::load` and the failure ordering of `compare_exchange`/`fetch_update` reject
        // both of these. Panic here with a message naming the ordering instead of forwarding a
        // value that is guaranteed to make the standard library panic later.
        Ordering::Release | Ordering::AcqRel => {
            panic!("{ordering:?} ordering cannot be used for loads")
        }
        _ => unimplemented!("{ordering:?} is not supported"),
    }
}
25
/// Coerces the ordering of a store so that it has at least `Release` semantics.
///
/// `Relaxed` is promoted to `Release` so that writes made before publishing a reference are
/// visible to any thread that later loads it with (at least) `Acquire` ordering.
///
/// # Panics
///
/// Panics for `Acquire` and `AcqRel`, which are not valid orderings for a store, and for any
/// other (future) [`Ordering`] variant.
fn store(ordering: Ordering) -> Ordering {
    match ordering {
        Ordering::Relaxed | Ordering::Release => Ordering::Release,
        Ordering::SeqCst => Ordering::SeqCst,
        // `AtomicPtr::store` rejects both of these; fail here with a message naming the
        // ordering instead of forwarding a value that is guaranteed to panic later.
        Ordering::Acquire | Ordering::AcqRel => {
            panic!("{ordering:?} ordering cannot be used for stores")
        }
        _ => unimplemented!("{ordering:?} is not supported"),
    }
}
36
/// Coerces the ordering of a read-modify-write operation (swap, compare-exchange,
/// fetch-update) so that it has at least `AcqRel` semantics.
///
/// `SeqCst` is kept unchanged; every weaker ordering is promoted to `AcqRel`, so the operation
/// both acquires the reference it reads and releases the one it writes.
///
/// # Panics
///
/// Panics (via `unimplemented!`) for any [`Ordering`] variant this crate does not know about.
fn load_store(ordering: Ordering) -> Ordering {
    let weaker_than_seq_cst = matches!(
        ordering,
        Ordering::Relaxed | Ordering::Release | Ordering::Acquire | Ordering::AcqRel
    );

    if ordering == Ordering::SeqCst {
        Ordering::SeqCst
    } else if weaker_than_seq_cst {
        Ordering::AcqRel
    } else {
        unimplemented!("{ordering:?} is not supported")
    }
}
47
// SAFETY:
//
// # Lifetimes
// After being loaded, a reference is guaranteed to be alive for 'a, because a `&'a T` was required to store it.
//
// # Thread Safety
// The `load`, `store` and `load_store` functions are used to coerce orderings to have at least Release-Acquire semantics.
// This ensures that any data written *before* the call to a `store` is synchronized with the thread observing the reference through a `load` operation.
56
/// A reference that can atomically be changed using another reference with the same lifetime and type
pub struct RefSwap<'a, T> {
    /// Always a pointer derived from a `&'a T`.
    ptr: AtomicPtr<T>,
    // `&'a T`: makes the type `Send`/`Sync` only when `T: Sync` (an `AtomicPtr` alone is always
    // both), exactly like sharing a `&'a T`, and records that `T: 'a`.
    //
    // `fn(&'a T) -> &'a T`: makes `'a` *invariant*. The swap both hands out and accepts `&'a T`
    // through `&self`, so (like `Cell<&'a T>`) it must not be coercible to a shorter lifetime.
    // With only `PhantomData<&'a T>`, a `&RefSwap<'static, T>` could be shrunk to
    // `&RefSwap<'short, T>`, used to store a short-lived reference, and the original handle
    // would then load it back as a `&'static T`, i.e. a use-after-free in safe code.
    phantom: PhantomData<(&'a T, fn(&'a T) -> &'a T)>,
}
62
63impl<'a, T> RefSwap<'a, T> {
64 pub const fn new(data: &'a T) -> Self {
65 Self {
66 ptr: AtomicPtr::new(data as *const _ as *mut _),
67 phantom: PhantomData,
68 }
69 }
70
71 /// Stores a reference if the current value is the same as the current value.
72 ///
73 /// Be aware that the comparison is only between the reference, not between the value.
74 /// If current point to another adress in memory than the reference currently holds, it will fail,
75 /// Even if both are equal according to a `PartialEq` implementation.
76 ///
77 /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_and_swap`](core::sync::atomic::AtomicPtr::compare_and_swap)
78 #[deprecated(note = "Use `compare_exchange` or `compare_exchange_weak` instead")]
79 pub fn compare_and_swap(&self, current: &'a T, new: &'a T, order: Ordering) -> &'a T {
80 #[allow(deprecated)]
81 let ptr = self.ptr.compare_and_swap(
82 current as *const _ as *mut _,
83 new as *const _ as *mut _,
84 load_store(order),
85 );
86
87 unsafe { &*ptr }
88 }
89
90 /// Stores a reference if the current value is the same as the current value.
91 ///
92 /// Be aware that the comparison is only between the reference, not between the value.
93 /// If current point to another adress in memory than the reference currently holds, it will fail,
94 /// Even if both are equal according to a `PartialEq` implementation.
95 ///
96 /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange`](core::sync::atomic::AtomicPtr::compare_exchange)
97 pub fn compare_exchange(
98 &self,
99 current: &'a T,
100 new: &'a T,
101 success: Ordering,
102 failure: Ordering,
103 ) -> Result<&'a T, &'a T> {
104 let res = self.ptr.compare_exchange(
105 current as *const _ as *mut _,
106 new as *const _ as *mut _,
107 load_store(success),
108 load(failure),
109 );
110
111 res.map(|ptr| unsafe { &*ptr })
112 .map_err(|ptr| unsafe { &*ptr })
113 }
114
115 /// Stores a reference if the current value is the same as the current value.
116 ///
117 /// Be aware that the comparison is only between the reference, not between the value.
118 /// If current point to another adress in memory than the reference currently holds, it will fail,
119 /// Even if both are equal according to a `PartialEq` implementation.
120 ///
121 /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange_weak`](core::sync::atomic::AtomicPtr::compare_exchange_weak)
122 pub fn compare_exchange_weak(
123 &self,
124 current: &'a T,
125 new: &'a T,
126 success: Ordering,
127 failure: Ordering,
128 ) -> Result<&'a T, &'a T> {
129 let res = self.ptr.compare_exchange_weak(
130 current as *const _ as *mut _,
131 new as *const _ as *mut _,
132 load_store(success),
133 load(failure),
134 );
135
136 res.map(|ptr| unsafe { &*ptr })
137 .map_err(|ptr| unsafe { &*ptr })
138 }
139
140 /// Get a mutable reference to the current stored reference.
141 ///
142 /// This is safe because the mutable reference guarantees that no other threads are concurrently accessing the atomic data.
143 pub fn get_mut<'s>(&'s mut self) -> &'s mut &'a T {
144 let res: &'s mut *mut T = self.ptr.get_mut();
145 unsafe { &mut *(res as *mut *mut T as *mut &'a T) }
146 }
147
148 /// Consumes the atomic and returns the contained value.
149 ///
150 /// This is safe because passing `self` by value guarantees that no other threads are concurrently accessing the atomic data.
151 pub fn into_inner(self) -> &'a T {
152 let res = self.ptr.into_inner();
153 unsafe { &*res }
154 }
155
156 /// Fetches the value, and applies a function to it that returns an optional new value. `Returns` a `Result` of `Ok(previous_value)` if the function returned `Some(_)`, else `Err(previous_value)`.
157 ///
158 /// For more information on the orderings, se the documentation of [`AtomicPtr::fetch_update`](core::sync::atomic::AtomicPtr::fetch_update)
159 pub fn fetch_update<F: FnMut(&'a T) -> Option<&'a T>>(
160 &self,
161 set_order: Ordering,
162 fetch_order: Ordering,
163 mut f: F,
164 ) -> Result<&'a T, &'a T> {
165 self.ptr
166 .fetch_update(load_store(set_order), load(fetch_order), |ptr| {
167 f(unsafe { &*ptr }).map(|r| r as *const _ as *mut _)
168 })
169 .map(|ptr| unsafe { &*ptr })
170 .map_err(|ptr| unsafe { &*ptr })
171 }
172
173 /// Loads a value
174 pub fn load(&self, order: Ordering) -> &'a T {
175 let res = self.ptr.load(load(order));
176
177 unsafe { &*res }
178 }
179
180 /// Store a value
181 pub fn store(&self, ptr: &'a T, order: Ordering) {
182 self.ptr.store(ptr as *const _ as *mut _, store(order));
183 }
184
185 /// Stores a value into the pointer, returning the previous value.
186 pub fn swap(&self, ptr: &'a T, order: Ordering) -> &'a T {
187 let res = self.ptr.swap(ptr as *const _ as *mut _, load_store(order));
188
189 unsafe { &*res }
190 }
191}
192
/// An optional reference that can atomically be changed to another optional reference with the same lifetime and type
pub struct OptionRefSwap<'a, T> {
    /// Null when no reference is held; otherwise a pointer derived from a `&'a T`.
    ptr: AtomicPtr<T>,
    // `&'a T`: makes the type `Send`/`Sync` only when `T: Sync` (an `AtomicPtr` alone is always
    // both), exactly like sharing a `&'a T`, and records that `T: 'a`.
    //
    // `fn(&'a T) -> &'a T`: makes `'a` *invariant*. The swap both hands out and accepts
    // `Option<&'a T>` through `&self`, so (like `Cell<Option<&'a T>>`) it must not be coercible
    // to a shorter lifetime; otherwise a short-lived reference stored through a shrunk
    // `&OptionRefSwap<'short, T>` could be loaded back as `&'static T` from the original handle.
    phantom: PhantomData<(&'a T, fn(&'a T) -> &'a T)>,
}
198
/// Converts an optional reference into a raw pointer: `None` becomes the null pointer and
/// `Some(r)` becomes the address `r` points to.
const fn opt_to_ptr<T>(ptr: Option<&T>) -> *mut T {
    if let Some(reference) = ptr {
        reference as *const T as *mut T
    } else {
        core::ptr::null_mut()
    }
}
206
/// Converts a raw pointer produced by `opt_to_ptr` back into an optional reference.
///
/// # Safety
///
/// `ptr` must come from `opt_to_ptr` with the corresponding lifetime `'a`: it is either null or
/// derived from a `&'a T`.
unsafe fn ptr_to_opt<'a, T>(ptr: *mut T) -> Option<&'a T> {
    // SAFETY: by the caller's contract `ptr` is either null (which `as_ref` maps to `None`) or
    // was obtained from a `&'a T`, so reborrowing it for `'a` is valid.
    unsafe { ptr.as_ref() }
}
216
217impl<'a, T> OptionRefSwap<'a, T> {
218 pub const fn new(data: Option<&'a T>) -> Self {
219 Self {
220 ptr: AtomicPtr::new(opt_to_ptr(data)),
221 phantom: PhantomData,
222 }
223 }
224
225 /// Stores a reference if the current value is the same as the current value.
226 ///
227 /// Be aware that the comparison is only between the reference, not between the value.
228 /// If current point to another adress in memory than the reference currently holds, it will fail,
229 /// Even if both are equal according to a `PartialEq` implementation.
230 ///
231 /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_and_swap`](core::sync::atomic::AtomicPtr::compare_and_swap)
232 #[deprecated(note = "Use `compare_exchange` or `compare_exchange_weak` instead")]
233 pub fn compare_and_swap(
234 &self,
235 current: Option<&'a T>,
236 new: Option<&'a T>,
237 order: Ordering,
238 ) -> Option<&'a T> {
239 #[allow(deprecated)]
240 let ptr =
241 self.ptr
242 .compare_and_swap(opt_to_ptr(current), opt_to_ptr(new), load_store(order));
243
244 unsafe { ptr_to_opt(ptr) }
245 }
246
247 /// Stores a reference if the current value is the same as the current value.
248 ///
249 /// Be aware that the comparison is only between the reference, not between the value.
250 /// If current point to another adress in memory than the reference currently holds, it will fail,
251 /// Even if both are equal according to a `PartialEq` implementation.
252 ///
253 /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange`](core::sync::atomic::AtomicPtr::compare_exchange)
254 pub fn compare_exchange(
255 &self,
256 current: Option<&'a T>,
257 new: Option<&'a T>,
258 success: Ordering,
259 failure: Ordering,
260 ) -> Result<Option<&'a T>, Option<&'a T>> {
261 let res = self.ptr.compare_exchange(
262 opt_to_ptr(current),
263 opt_to_ptr(new),
264 load_store(success),
265 load(failure),
266 );
267
268 res.map(|ptr| unsafe { ptr_to_opt(ptr) })
269 .map_err(|ptr| unsafe { ptr_to_opt(ptr) })
270 }
271
272 /// Stores a reference if the current value is the same as the current value.
273 ///
274 /// Be aware that the comparison is only between the reference, not between the value.
275 /// If current point to another adress in memory than the reference currently holds, it will fail,
276 /// Even if both are equal according to a `PartialEq` implementation.
277 ///
278 /// For more information on the orderings, se the documentation of [`AtomicPtr::compare_exchange_weak`](core::sync::atomic::AtomicPtr::compare_exchange_weak)
279 pub fn compare_exchange_weak(
280 &self,
281 current: Option<&'a T>,
282 new: Option<&'a T>,
283 success: Ordering,
284 failure: Ordering,
285 ) -> Result<Option<&'a T>, Option<&'a T>> {
286 let res = self.ptr.compare_exchange_weak(
287 opt_to_ptr(current),
288 opt_to_ptr(new),
289 load_store(success),
290 load(failure),
291 );
292
293 res.map(|ptr| unsafe { ptr_to_opt(ptr) })
294 .map_err(|ptr| unsafe { ptr_to_opt(ptr) })
295 }
296
297 /// Get a mutable reference to the current stored reference.
298 ///
299 /// This is safe because the mutable reference guarantees that no other threads are concurrently accessing the atomic data.
300 #[allow(unused)]
301 fn get_mut<'s>(&'s mut self) -> &'s mut Option<&'a T> {
302 let res: &'s mut *mut T = self.ptr.get_mut();
303
304 // TODO: Is this transmute really safe? Making this function private until I'm sure
305
306 unsafe { core::mem::transmute(res) }
307 }
308
309 /// Consumes the atomic and returns the contained value.
310 ///
311 /// This is safe because passing `self` by value guarantees that no other threads are concurrently accessing the atomic data.
312 pub fn into_inner(self) -> &'a T {
313 let res = self.ptr.into_inner();
314 unsafe { &*res }
315 }
316
317 /// Fetches the value, and applies a function to it that returns an optional new value. `Returns` a `Result` of `Ok(previous_value)` if the function returned `Some(_)`, else `Err(previous_value)`.
318 ///
319 /// For more information on the orderings, se the documentation of [`AtomicPtr::fetch_update`](core::sync::atomic::AtomicPtr::fetch_update)
320 pub fn fetch_update<F: FnMut(Option<&'a T>) -> Option<Option<&'a T>>>(
321 &self,
322 set_order: Ordering,
323 fetch_order: Ordering,
324 mut f: F,
325 ) -> Result<Option<&'a T>, Option<&'a T>> {
326 self.ptr
327 .fetch_update(load_store(set_order), load(fetch_order), |ptr| {
328 f(unsafe { ptr_to_opt(ptr) }).map(opt_to_ptr)
329 })
330 .map(|ptr| unsafe { ptr_to_opt(ptr) })
331 .map_err(|ptr| unsafe { ptr_to_opt(ptr) })
332 }
333
334 /// Loads a value
335 pub fn load(&self, order: Ordering) -> Option<&'a T> {
336 let res = self.ptr.load(load(order));
337
338 unsafe { ptr_to_opt(res) }
339 }
340
341 /// Stores a value
342 pub fn store(&self, ptr: Option<&'a T>, order: Ordering) {
343 self.ptr.store(opt_to_ptr(ptr), store(order));
344 }
345
346 /// Stores a value into the pointer, returning the previous value.
347 pub fn swap(&self, ptr: Option<&'a T>, order: Ordering) -> Option<&'a T> {
348 let res = self.ptr.swap(opt_to_ptr(ptr), load_store(order));
349
350 unsafe { ptr_to_opt(res) }
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that references with a longer lifetime (here `'static`) can be stored
    /// into swaps created from shorter-lived references. It is never called; it only has to
    /// type-check.
    #[allow(unused)]
    fn variance<'a, 'b>(a: &'a u32, b: Option<&'b u32>) {
        let r = RefSwap::new(a);
        let stat: &'static u32 = &123;
        r.store(stat, Ordering::Relaxed);

        let r = OptionRefSwap::new(b);
        let stat: Option<&'static u32> = Some(&123);
        // Store `stat` (not `b` again), otherwise the `'static` coercion is never exercised.
        r.store(stat, Ordering::Relaxed);
    }

    #[test]
    fn ref_swap_operations() {
        let first = 1u32;
        let second = 2u32;
        let same_value_other_address = 1u32;

        let cell = RefSwap::new(&first);
        assert!(core::ptr::eq(cell.load(Ordering::Relaxed), &first));

        cell.store(&second, Ordering::Relaxed);
        assert!(core::ptr::eq(cell.load(Ordering::SeqCst), &second));

        assert!(core::ptr::eq(cell.swap(&first, Ordering::AcqRel), &second));

        // The comparison is by address, so an equal value stored elsewhere does not match.
        let failed = cell.compare_exchange(
            &same_value_other_address,
            &second,
            Ordering::SeqCst,
            Ordering::Relaxed,
        );
        assert!(matches!(failed, Err(actual) if core::ptr::eq(actual, &first)));

        let succeeded = cell.compare_exchange(&first, &second, Ordering::SeqCst, Ordering::Relaxed);
        assert!(matches!(succeeded, Ok(previous) if core::ptr::eq(previous, &first)));

        assert!(core::ptr::eq(cell.into_inner(), &second));
    }

    #[test]
    fn option_ref_swap_operations() {
        let value = 5u32;

        let cell = OptionRefSwap::new(None);
        assert_eq!(cell.load(Ordering::Relaxed), None);

        cell.store(Some(&value), Ordering::Relaxed);
        assert_eq!(cell.load(Ordering::Acquire), Some(&5));

        assert_eq!(cell.swap(None, Ordering::SeqCst), Some(&5));
        assert_eq!(
            cell.compare_exchange(None, Some(&value), Ordering::AcqRel, Ordering::Acquire),
            Ok(None)
        );
        assert_eq!(
            cell.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |_| Some(None)),
            Ok(Some(&5))
        );
        assert_eq!(cell.load(Ordering::SeqCst), None);
    }

    #[test]
    #[should_panic(expected = "cannot be used for loads")]
    fn release_load_panics() {
        let value = 0u32;
        let _ = RefSwap::new(&value).load(Ordering::Release);
    }
}