yrs/
atomic.rs

//! The `atomic` module is home to [AtomicRef], a cell-like struct used to perform thread-safe
//! operations using the underlying hardware's atomic intrinsics.
3
4use std::fmt::Formatter;
5use std::ptr::null_mut;
6use std::sync::atomic::{AtomicPtr, Ordering};
7use std::sync::Arc;
8
/// Atomic reference holding a value that is meant to be shared, potentially between multiple
/// threads. Internally the value lives behind an [Arc], a clone of which is returned by
/// [AtomicRef::get]. This cell never hands out `&mut` references to the stored object; instead,
/// updates are performed as a lock-free compare-and-swap loop around the function passed to
/// [AtomicRef::update].
///
/// `AtomicRef<T>` is `Send`/`Sync` only when `T: Send + Sync` — the same requirement as for
/// sharing an `Arc<T>` between threads.
///
/// Example:
/// ```rust
/// use yrs::atomic::AtomicRef;
///
/// let atom = AtomicRef::new(vec!["John"]);
/// atom.update(|users| {
///     let mut users_copy = users.cloned().unwrap_or_else(Vec::default);
///     users_copy.push("Susan");
///     users_copy
/// });
/// let users = atom.get(); // John, Susan
/// ```
/// **Important note**: since [AtomicRef::update] may call the provided function multiple times
/// (when another thread modifies the cell concurrently), that function should be free of side
/// effects and preferably quick to execute.
#[repr(transparent)]
pub struct AtomicRef<T>(AtomicPtr<T>);

// SAFETY: the cell owns a strong reference to an `Arc<T>` allocation. Moving the cell to another
// thread may drop `T` there (needs `T: Send`), and other threads can obtain `Arc<T>` clones that
// share the same `T` (needs `T: Sync`). These are exactly the bounds under which `Arc<T>: Send`.
// Without them, `AtomicPtr`'s unconditional `Send` would make e.g. `AtomicRef<Cell<_>>` `Send`.
unsafe impl<T: Send + Sync> Send for AtomicRef<T> {}
// SAFETY: through `&AtomicRef<T>` any thread can get, swap in, or take out `Arc<T>` values, i.e.
// share and drop `T` across threads — the bounds under which `Arc<T>: Sync`.
unsafe impl<T: Send + Sync> Sync for AtomicRef<T> {}
35
36impl<T> AtomicRef<T> {
37    /// Creates a new instance of [AtomicRef]. This call boxes provided `value` and allocates it
38    /// on a heap.
39    pub fn new(value: T) -> Self {
40        let arc = Arc::new(value);
41        let ptr = Arc::into_raw(arc) as *mut _;
42        AtomicRef(AtomicPtr::new(ptr))
43    }
44
45    /// Returns a reference to current state hold by the [AtomicRef]. Keep in mind that after
46    /// acquiring it, it may not present the current view of the state, but instead be changed by
47    /// the concurrent [AtomicRef::update] call.
48    pub fn get(&self) -> Option<Arc<T>> {
49        let ptr = self.0.load(Ordering::SeqCst);
50        if ptr.is_null() {
51            None
52        } else {
53            let arc = unsafe { Arc::from_raw(ptr) };
54            let result = arc.clone();
55            std::mem::forget(arc);
56            Some(result)
57        }
58    }
59
60    /// Atomically replaces currently stored value with a new one, returning the last stored value.
61    pub fn swap(&self, value: T) -> Option<Arc<T>> {
62        let new_ptr = Arc::into_raw(Arc::new(value)) as *mut _;
63        let prev = self.0.swap(new_ptr, Ordering::Release);
64        if prev.is_null() {
65            None
66        } else {
67            let arc = unsafe { Arc::from_raw(prev) };
68            Some(arc)
69        }
70    }
71
72    /// Atomically replaces currently stored value with a null, returning the last stored value.
73    pub fn take(&self) -> Option<Arc<T>> {
74        let prev = self.0.swap(null_mut(), Ordering::Release);
75        if prev.is_null() {
76            None
77        } else {
78            let arc = unsafe { Arc::from_raw(prev) };
79            Some(arc)
80        }
81    }
82
83    /// Updates stored value in place using provided function `f`, which takes read-only refrence
84    /// to the most recently known state and producing new state in the result.
85    ///
86    /// **Important note**: since [AtomicRef::update] may call provided function multiple times (in
87    /// scenarios, when another thread intercepted update with its own update call), provided
88    /// function should be idempotent and preferably quick to execute.
89    pub fn update<F>(&self, f: F)
90    where
91        F: Fn(Option<&T>) -> T,
92    {
93        loop {
94            let old_ptr = self.0.load(Ordering::SeqCst);
95            let old_value = unsafe { old_ptr.as_ref() };
96
97            // modify copied value
98            let new_value = f(old_value);
99
100            let new_ptr = Arc::into_raw(Arc::new(new_value)) as *mut _;
101
102            let swapped =
103                self.0
104                    .compare_exchange(old_ptr, new_ptr, Ordering::AcqRel, Ordering::Relaxed);
105
106            match swapped {
107                Ok(old) => {
108                    if !old.is_null() {
109                        unsafe { Arc::decrement_strong_count(old) }; // drop reference to old
110                    }
111                    break; // we succeeded
112                }
113                Err(new) => {
114                    if !new.is_null() {
115                        unsafe { Arc::decrement_strong_count(new) }; // drop reference to new and retry
116                    }
117                }
118            }
119        }
120    }
121}
122
123impl<T: Copy> AtomicRef<T> {
124    /// Returns a current state copy hold by the [AtomicRef]. Keep in mind that after
125    /// acquiring it, it may not present the current view of the state, but instead be changed by
126    /// the concurrent [AtomicRef::update] call.
127    pub fn get_owned(&self) -> Option<T> {
128        let ptr = self.0.load(Ordering::SeqCst);
129        if ptr.is_null() {
130            None
131        } else {
132            let arc = unsafe { Arc::from_raw(ptr) };
133            let result = *arc;
134            std::mem::forget(arc);
135            Some(result)
136        }
137    }
138}
139
140impl<T> Drop for AtomicRef<T> {
141    fn drop(&mut self) {
142        unsafe {
143            let ptr = self.0.load(Ordering::Acquire);
144            if !ptr.is_null() {
145                Arc::decrement_strong_count(ptr);
146            }
147        }
148    }
149}
150
151impl<T> PartialEq for AtomicRef<T>
152where
153    T: PartialEq,
154{
155    fn eq(&self, other: &Self) -> bool {
156        let a = self.0.load(Ordering::Acquire);
157        let b = other.0.load(Ordering::Acquire);
158        if std::ptr::eq(a, b) {
159            true
160        } else {
161            unsafe { a.as_ref() == b.as_ref() }
162        }
163    }
164}
165
166impl<T> Eq for AtomicRef<T> where T: Eq {}
167
168impl<T: std::fmt::Debug> std::fmt::Debug for AtomicRef<T> {
169    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170        let value = self.get();
171        write!(f, "AtomicRef({:?})", value.as_deref())
172    }
173}
174
175impl<T> Default for AtomicRef<T> {
176    fn default() -> Self {
177        AtomicRef(AtomicPtr::new(null_mut()))
178    }
179}
180
#[cfg(test)]
mod test {
    use super::AtomicRef;

    #[test]
    fn init_get() {
        let cell = AtomicRef::new(1);
        let current = cell.get().map(|arc| *arc);
        assert_eq!(current, Some(1));
    }

    #[test]
    fn update() {
        let cell = AtomicRef::new(vec!["John"]);
        let before = cell.get().unwrap();
        assert_eq!(before.as_slice(), &["John"]);

        cell.update(|current| {
            let mut next = current.cloned().unwrap_or_default();
            next.push("Susan");
            next
        });

        // an Arc fetched after the update points at the new contents
        let after = cell.get().unwrap();
        assert_eq!(after.as_slice(), &["John", "Susan"]);

        // the Arc fetched before the update still sees the old, unchanged contents
        assert_eq!(before.as_slice(), &["John"]);
    }
}