gstd/sync/
rwlock.rs

1// This file is part of Gear.
2
3// Copyright (C) 2021-2025 Gear Technologies Inc.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
19use super::access::AccessQueue;
20use crate::{MessageId, exec, msg};
21use core::{
22    cell::{Cell, UnsafeCell},
23    future::Future,
24    ops::{Deref, DerefMut},
25    pin::Pin,
26    task::{Context, Poll},
27};
28
/// Integer type used to count the read guards that are currently alive.
type ReadersCount = u8;
/// Maximum number of read guards that may hold an `RwLock` at the same time.
const READERS_LIMIT: ReadersCount = 32;
31
32/// A reader-writer lock.
33///
34/// This type of lock allows a number of readers or at most one writer at any
35/// point in time. The write portion of this lock typically allows modification
36/// of the underlying data (exclusive access) and the read portion of this lock
37/// typically allows for read-only access (shared access).
38///
39/// In comparison, a [`Mutex`](super::Mutex) does not distinguish between
40/// readers or writers that acquire the lock, therefore blocking any actors
41/// waiting for the lock to become available. An `RwLock` will allow any number
42/// of readers to acquire the lock as long as a writer is not holding the lock.
43///
44/// The type parameter `T` represents the data that this lock protects. The RAII
45/// guards returned from the locking methods implement [`Deref`] (and
46/// [`DerefMut`] for the `write` methods) to allow access to the content of the
47/// lock.
48///
49/// # Examples
50///
51/// The following program processes several messages. It locks the `RwLock` for
52/// reading when processing one of the `get` commands and for writing in the
53/// case of the `inc` command.
54///
55/// ```ignored
56/// use gstd::{msg, sync::RwLock, ActorId};
57///
58/// static mut DEST: ActorId = ActorId::zero();
59/// static RWLOCK: RwLock<u32> = RwLock::new(0);
60///
61/// #[unsafe(no_mangle)]
62/// extern "C" fn init() {
63///     // `some_address` can be obtained from the init payload
64///     # let some_address = ActorId::zero();
65///     unsafe { DEST = some_address };
66/// }
67///
68/// #[gstd::async_main]
69/// async fn main() {
70///     let payload = msg::load_bytes().expect("Unable to load payload bytes");
71///
72///     match payload.as_slice() {
73///         b"get" => {
74///             msg::reply(*RWLOCK.read().await, 0).unwrap();
75///         }
76///         b"inc" => {
77///             let mut val = RWLOCK.write().await;
78///             *val += 1;
79///         }
80///         b"ping&get" => {
81///             let _ = msg::send_bytes_for_reply(unsafe { DEST }, b"PING", 0, 0)
82///                 .expect("Unable to send bytes")
83///                 .await
84///                 .expect("Error in async message processing");
85///             msg::reply(*RWLOCK.read().await, 0).unwrap();
86///         }
87///         b"inc&ping" => {
88///             let mut val = RWLOCK.write().await;
89///             *val += 1;
90///             let _ = msg::send_bytes_for_reply(unsafe { DEST }, b"PING", 0, 0)
91///                 .expect("Unable to send bytes")
92///                 .await
93///                 .expect("Error in async message processing");
94///         }
95///         b"get&ping" => {
96///             let val = RWLOCK.read().await;
97///             let _ = msg::send_bytes_for_reply(unsafe { DEST }, b"PING", 0, 0)
98///                 .expect("Unable to send bytes")
99///                 .await
100///                 .expect("Error in async message processing");
101///             msg::reply(*val, 0).unwrap();
102///         }
103///         _ => {
104///             let _write = RWLOCK.write().await;
105///             RWLOCK.read().await;
106///         }
107///     }
108/// }
109/// # fn main() {}
110/// ```
111pub struct RwLock<T> {
112    locked: UnsafeCell<Option<MessageId>>,
113    value: UnsafeCell<T>,
114    readers: Cell<ReadersCount>,
115    queue: AccessQueue,
116}
117
118impl<T> From<T> for RwLock<T> {
119    fn from(t: T) -> Self {
120        RwLock::new(t)
121    }
122}
123
124impl<T: Default> Default for RwLock<T> {
125    fn default() -> Self {
126        <T as Default>::default().into()
127    }
128}
129
130impl<T> RwLock<T> {
131    /// Limit of readers for `RwLock`
132    pub const READERS_LIMIT: ReadersCount = READERS_LIMIT;
133
134    /// Create a new instance of an `RwLock<T>` which is unlocked.
135    pub const fn new(t: T) -> RwLock<T> {
136        RwLock {
137            value: UnsafeCell::new(t),
138            locked: UnsafeCell::new(None),
139            readers: Cell::new(0),
140            queue: AccessQueue::new(),
141        }
142    }
143
144    /// Locks this rwlock with shared read access, protecting the subsequent
145    /// code from executing by other actors until it can be acquired.
146    ///
147    /// The underlying code section will be blocked until there are no more
148    /// writers who hold the lock. There may be other readers currently inside
149    /// the lock when this method returns. This method does not provide any
150    /// guarantees with respect to the ordering of whether contentious readers
151    /// or writers will acquire the lock first.
152    ///
153    /// Returns an RAII guard, which will release this thread's shared access
154    /// once it is dropped.
155    pub fn read(&self) -> RwLockReadFuture<'_, T> {
156        RwLockReadFuture { lock: self }
157    }
158
159    /// Locks this rwlock with exclusive write access, blocking the underlying
160    /// code section until it can be acquired.
161    ///
162    /// This function will not return while other writers or other readers
163    /// currently have access to the lock.
164    ///
165    /// Returns an RAII guard which will drop the write access of this rwlock
166    /// when dropped.
167    pub fn write(&self) -> RwLockWriteFuture<'_, T> {
168        RwLockWriteFuture { lock: self }
169    }
170}
171
172// we are always single-threaded
173unsafe impl<T> Sync for RwLock<T> {}
174
175/// RAII structure used to release the shared read access of a lock when
176/// dropped.
177///
178/// This structure wrapped in the future is returned by the
179/// [`read`](RwLock::read) method on [`RwLock`].
180pub struct RwLockReadGuard<'a, T> {
181    lock: &'a RwLock<T>,
182    holder_msg_id: MessageId,
183}
184
185impl<T> RwLockReadGuard<'_, T> {
186    fn ensure_access_by_holder(&self) {
187        let current_msg_id = msg::id();
188        if self.holder_msg_id != current_msg_id {
189            panic!(
190                "Read lock guard held by message 0x{} is being accessed by message 0x{}",
191                hex::encode(self.holder_msg_id),
192                hex::encode(current_msg_id)
193            );
194        }
195    }
196}
197
198impl<T> Drop for RwLockReadGuard<'_, T> {
199    fn drop(&mut self) {
200        self.ensure_access_by_holder();
201        unsafe {
202            let readers = &self.lock.readers;
203            let readers_count = readers.get().saturating_sub(1);
204
205            readers.replace(readers_count);
206
207            if readers_count == 0 {
208                *self.lock.locked.get() = None;
209
210                if let Some(message_id) = self.lock.queue.dequeue() {
211                    exec::wake(message_id).expect("Failed to wake the message");
212                }
213            }
214        }
215    }
216}
217
218impl<'a, T> AsRef<T> for RwLockReadGuard<'a, T> {
219    fn as_ref(&self) -> &'a T {
220        self.ensure_access_by_holder();
221        unsafe { &*self.lock.value.get() }
222    }
223}
224
225impl<T> Deref for RwLockReadGuard<'_, T> {
226    type Target = T;
227
228    fn deref(&self) -> &T {
229        self.ensure_access_by_holder();
230        unsafe { &*self.lock.value.get() }
231    }
232}
233
234/// RAII structure used to release the exclusive write access of a lock when
235/// dropped.
236///
237/// This structure wrapped in the future is returned by the
238/// [`write`](RwLock::write) method on [`RwLock`].
239pub struct RwLockWriteGuard<'a, T> {
240    lock: &'a RwLock<T>,
241    holder_msg_id: MessageId,
242}
243
244impl<T> RwLockWriteGuard<'_, T> {
245    fn ensure_access_by_holder(&self) {
246        let current_msg_id = msg::id();
247        if self.holder_msg_id != current_msg_id {
248            panic!(
249                "Write lock guard held by message 0x{} is being accessed by message 0x{}",
250                hex::encode(self.holder_msg_id),
251                hex::encode(current_msg_id)
252            );
253        }
254    }
255}
256
257impl<T> Drop for RwLockWriteGuard<'_, T> {
258    fn drop(&mut self) {
259        self.ensure_access_by_holder();
260        unsafe {
261            let locked_by = &mut *self.lock.locked.get();
262            let owner_msg_id = locked_by.unwrap_or_else(|| {
263                panic!(
264                    "Write lock guard held by message 0x{} is being dropped for non-existing lock",
265                    hex::encode(self.holder_msg_id),
266                );
267            });
268            if owner_msg_id != self.holder_msg_id {
269                panic!(
270                    "Write lock guard held by message 0x{} does not match lock owner message 0x{}",
271                    hex::encode(self.holder_msg_id),
272                    hex::encode(owner_msg_id),
273                );
274            }
275            *locked_by = None;
276            if let Some(message_id) = self.lock.queue.dequeue() {
277                exec::wake(message_id).expect("Failed to wake the message");
278            }
279        }
280    }
281}
282
283impl<'a, T> AsRef<T> for RwLockWriteGuard<'a, T> {
284    fn as_ref(&self) -> &'a T {
285        self.ensure_access_by_holder();
286        unsafe { &*self.lock.value.get() }
287    }
288}
289
290impl<'a, T> AsMut<T> for RwLockWriteGuard<'a, T> {
291    fn as_mut(&mut self) -> &'a mut T {
292        self.ensure_access_by_holder();
293        unsafe { &mut *self.lock.value.get() }
294    }
295}
296
297impl<T> Deref for RwLockWriteGuard<'_, T> {
298    type Target = T;
299
300    fn deref(&self) -> &T {
301        self.ensure_access_by_holder();
302        unsafe { &*self.lock.value.get() }
303    }
304}
305
306impl<T> DerefMut for RwLockWriteGuard<'_, T> {
307    fn deref_mut(&mut self) -> &mut T {
308        self.ensure_access_by_holder();
309        unsafe { &mut *self.lock.value.get() }
310    }
311}
312
313/// The future returned by the [`read`](RwLock::read) method.
314///
315/// The output of the future is the [`RwLockReadGuard`] that can be obtained by
316/// using `await` syntax.
317///
318/// # Examples
319///
320/// The following example explicitly annotates variable types for demonstration
321/// purposes only. Usually, annotating types is unnecessary since
322/// they can be inferred automatically.
323///
324/// ```
325/// use gstd::sync::{RwLock, RwLockReadFuture, RwLockReadGuard};
326///
327/// #[gstd::async_main]
328/// async fn main() {
329///     let rwlock: RwLock<i32> = RwLock::new(42);
330///     let future: RwLockReadFuture<i32> = rwlock.read();
331///     let guard: RwLockReadGuard<i32> = future.await;
332///     let value: i32 = *guard;
333///     assert_eq!(value, 42);
334/// }
335/// # fn main() {}
336/// ```
337pub struct RwLockReadFuture<'a, T> {
338    lock: &'a RwLock<T>,
339}
340
341/// The future returned by the [`write`](RwLock::write) method.
342///
343/// The output of the future is the [`RwLockWriteGuard`] that can be obtained by
344/// using `await` syntax.
345///
346/// # Examples
347///
348/// ```
349/// use gstd::sync::{RwLock, RwLockWriteFuture, RwLockWriteGuard};
350///
351/// #[gstd::async_main]
352/// async fn main() {
353///     let rwlock: RwLock<i32> = RwLock::new(42);
354///     let future: RwLockWriteFuture<i32> = rwlock.write();
355///     let mut guard: RwLockWriteGuard<i32> = future.await;
356///     let value: i32 = *guard;
357///     assert_eq!(value, 42);
358///     *guard = 84;
359///     assert_eq!(*guard, 42);
360/// }
361/// # fn main() {}
362/// ```
363pub struct RwLockWriteFuture<'a, T> {
364    lock: &'a RwLock<T>,
365}
366
367impl<'a, T> Future for RwLockReadFuture<'a, T> {
368    type Output = RwLockReadGuard<'a, T>;
369
370    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
371        let readers = &self.lock.readers;
372        let readers_count = readers.get().saturating_add(1);
373
374        let current_msg_id = msg::id();
375        let lock = unsafe { &mut *self.lock.locked.get() };
376        if lock.is_none() && readers_count <= READERS_LIMIT {
377            readers.replace(readers_count);
378            Poll::Ready(RwLockReadGuard {
379                lock: self.lock,
380                holder_msg_id: current_msg_id,
381            })
382        } else {
383            // If the message is already in the access queue, and we come here,
384            // it means the message has just been woken up from the waitlist.
385            // In that case we do not want to register yet another access attempt
386            // and just go back to the waitlist.
387            if !self.lock.queue.contains(&current_msg_id) {
388                self.lock.queue.enqueue(current_msg_id);
389            }
390            Poll::Pending
391        }
392    }
393}
394
395impl<'a, T> Future for RwLockWriteFuture<'a, T> {
396    type Output = RwLockWriteGuard<'a, T>;
397
398    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
399        let current_msg_id = msg::id();
400        let lock = unsafe { &mut *self.lock.locked.get() };
401        if lock.is_none() && self.lock.readers.get() == 0 {
402            *lock = Some(current_msg_id);
403            Poll::Ready(RwLockWriteGuard {
404                lock: self.lock,
405                holder_msg_id: current_msg_id,
406            })
407        } else {
408            // If the message is already in the access queue, and we come here,
409            // it means the message has just been woken up from the waitlist.
410            // In that case we do not want to register yet another access attempt
411            // and just go back to the waitlist.
412            if !self.lock.queue.contains(&current_msg_id) {
413                self.lock.queue.enqueue(current_msg_id);
414            }
415            Poll::Pending
416        }
417    }
418}