gstd/sync/rwlock.rs
1// This file is part of Gear.
2
3// Copyright (C) 2021-2025 Gear Technologies Inc.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
19use super::access::AccessQueue;
20use crate::{MessageId, exec, msg};
21use core::{
22 cell::{Cell, UnsafeCell},
23 future::Future,
24 ops::{Deref, DerefMut},
25 pin::Pin,
26 task::{Context, Poll},
27};
28
/// Integer type used to count the read guards that are currently alive.
type ReadersCount = u8;

/// Upper bound on the number of read guards that may be alive at the same
/// time; a read attempt that would exceed it is queued instead.
const READERS_LIMIT: ReadersCount = 32;
31
32/// A reader-writer lock.
33///
34/// This type of lock allows a number of readers or at most one writer at any
35/// point in time. The write portion of this lock typically allows modification
36/// of the underlying data (exclusive access) and the read portion of this lock
37/// typically allows for read-only access (shared access).
38///
39/// In comparison, a [`Mutex`](super::Mutex) does not distinguish between
40/// readers or writers that acquire the lock, therefore blocking any actors
41/// waiting for the lock to become available. An `RwLock` will allow any number
42/// of readers to acquire the lock as long as a writer is not holding the lock.
43///
44/// The type parameter `T` represents the data that this lock protects. The RAII
45/// guards returned from the locking methods implement [`Deref`] (and
46/// [`DerefMut`] for the `write` methods) to allow access to the content of the
47/// lock.
48///
49/// # Examples
50///
51/// The following program processes several messages. It locks the `RwLock` for
52/// reading when processing one of the `get` commands and for writing in the
53/// case of the `inc` command.
54///
55/// ```ignored
56/// use gstd::{msg, sync::RwLock, ActorId};
57///
58/// static mut DEST: ActorId = ActorId::zero();
59/// static RWLOCK: RwLock<u32> = RwLock::new(0);
60///
61/// #[unsafe(no_mangle)]
62/// extern "C" fn init() {
63/// // `some_address` can be obtained from the init payload
64/// # let some_address = ActorId::zero();
65/// unsafe { DEST = some_address };
66/// }
67///
68/// #[gstd::async_main]
69/// async fn main() {
70/// let payload = msg::load_bytes().expect("Unable to load payload bytes");
71///
72/// match payload.as_slice() {
73/// b"get" => {
74/// msg::reply(*RWLOCK.read().await, 0).unwrap();
75/// }
76/// b"inc" => {
77/// let mut val = RWLOCK.write().await;
78/// *val += 1;
79/// }
80/// b"ping&get" => {
81/// let _ = msg::send_bytes_for_reply(unsafe { DEST }, b"PING", 0, 0)
82/// .expect("Unable to send bytes")
83/// .await
84/// .expect("Error in async message processing");
85/// msg::reply(*RWLOCK.read().await, 0).unwrap();
86/// }
87/// b"inc&ping" => {
88/// let mut val = RWLOCK.write().await;
89/// *val += 1;
90/// let _ = msg::send_bytes_for_reply(unsafe { DEST }, b"PING", 0, 0)
91/// .expect("Unable to send bytes")
92/// .await
93/// .expect("Error in async message processing");
94/// }
95/// b"get&ping" => {
96/// let val = RWLOCK.read().await;
97/// let _ = msg::send_bytes_for_reply(unsafe { DEST }, b"PING", 0, 0)
98/// .expect("Unable to send bytes")
99/// .await
100/// .expect("Error in async message processing");
101/// msg::reply(*val, 0).unwrap();
102/// }
103/// _ => {
104/// let _write = RWLOCK.write().await;
105/// RWLOCK.read().await;
106/// }
107/// }
108/// }
109/// # fn main() {}
110/// ```
111pub struct RwLock<T> {
112 locked: UnsafeCell<Option<MessageId>>,
113 value: UnsafeCell<T>,
114 readers: Cell<ReadersCount>,
115 queue: AccessQueue,
116}
117
118impl<T> From<T> for RwLock<T> {
119 fn from(t: T) -> Self {
120 RwLock::new(t)
121 }
122}
123
124impl<T: Default> Default for RwLock<T> {
125 fn default() -> Self {
126 <T as Default>::default().into()
127 }
128}
129
130impl<T> RwLock<T> {
131 /// Limit of readers for `RwLock`
132 pub const READERS_LIMIT: ReadersCount = READERS_LIMIT;
133
134 /// Create a new instance of an `RwLock<T>` which is unlocked.
135 pub const fn new(t: T) -> RwLock<T> {
136 RwLock {
137 value: UnsafeCell::new(t),
138 locked: UnsafeCell::new(None),
139 readers: Cell::new(0),
140 queue: AccessQueue::new(),
141 }
142 }
143
144 /// Locks this rwlock with shared read access, protecting the subsequent
145 /// code from executing by other actors until it can be acquired.
146 ///
147 /// The underlying code section will be blocked until there are no more
148 /// writers who hold the lock. There may be other readers currently inside
149 /// the lock when this method returns. This method does not provide any
150 /// guarantees with respect to the ordering of whether contentious readers
151 /// or writers will acquire the lock first.
152 ///
153 /// Returns an RAII guard, which will release this thread's shared access
154 /// once it is dropped.
155 pub fn read(&self) -> RwLockReadFuture<'_, T> {
156 RwLockReadFuture { lock: self }
157 }
158
159 /// Locks this rwlock with exclusive write access, blocking the underlying
160 /// code section until it can be acquired.
161 ///
162 /// This function will not return while other writers or other readers
163 /// currently have access to the lock.
164 ///
165 /// Returns an RAII guard which will drop the write access of this rwlock
166 /// when dropped.
167 pub fn write(&self) -> RwLockWriteFuture<'_, T> {
168 RwLockWriteFuture { lock: self }
169 }
170}
171
172// we are always single-threaded
173unsafe impl<T> Sync for RwLock<T> {}
174
175/// RAII structure used to release the shared read access of a lock when
176/// dropped.
177///
178/// This structure wrapped in the future is returned by the
179/// [`read`](RwLock::read) method on [`RwLock`].
180pub struct RwLockReadGuard<'a, T> {
181 lock: &'a RwLock<T>,
182 holder_msg_id: MessageId,
183}
184
185impl<T> RwLockReadGuard<'_, T> {
186 fn ensure_access_by_holder(&self) {
187 let current_msg_id = msg::id();
188 if self.holder_msg_id != current_msg_id {
189 panic!(
190 "Read lock guard held by message 0x{} is being accessed by message 0x{}",
191 hex::encode(self.holder_msg_id),
192 hex::encode(current_msg_id)
193 );
194 }
195 }
196}
197
198impl<T> Drop for RwLockReadGuard<'_, T> {
199 fn drop(&mut self) {
200 self.ensure_access_by_holder();
201 unsafe {
202 let readers = &self.lock.readers;
203 let readers_count = readers.get().saturating_sub(1);
204
205 readers.replace(readers_count);
206
207 if readers_count == 0 {
208 *self.lock.locked.get() = None;
209
210 if let Some(message_id) = self.lock.queue.dequeue() {
211 exec::wake(message_id).expect("Failed to wake the message");
212 }
213 }
214 }
215 }
216}
217
218impl<'a, T> AsRef<T> for RwLockReadGuard<'a, T> {
219 fn as_ref(&self) -> &'a T {
220 self.ensure_access_by_holder();
221 unsafe { &*self.lock.value.get() }
222 }
223}
224
225impl<T> Deref for RwLockReadGuard<'_, T> {
226 type Target = T;
227
228 fn deref(&self) -> &T {
229 self.ensure_access_by_holder();
230 unsafe { &*self.lock.value.get() }
231 }
232}
233
234/// RAII structure used to release the exclusive write access of a lock when
235/// dropped.
236///
237/// This structure wrapped in the future is returned by the
238/// [`write`](RwLock::write) method on [`RwLock`].
239pub struct RwLockWriteGuard<'a, T> {
240 lock: &'a RwLock<T>,
241 holder_msg_id: MessageId,
242}
243
244impl<T> RwLockWriteGuard<'_, T> {
245 fn ensure_access_by_holder(&self) {
246 let current_msg_id = msg::id();
247 if self.holder_msg_id != current_msg_id {
248 panic!(
249 "Write lock guard held by message 0x{} is being accessed by message 0x{}",
250 hex::encode(self.holder_msg_id),
251 hex::encode(current_msg_id)
252 );
253 }
254 }
255}
256
257impl<T> Drop for RwLockWriteGuard<'_, T> {
258 fn drop(&mut self) {
259 self.ensure_access_by_holder();
260 unsafe {
261 let locked_by = &mut *self.lock.locked.get();
262 let owner_msg_id = locked_by.unwrap_or_else(|| {
263 panic!(
264 "Write lock guard held by message 0x{} is being dropped for non-existing lock",
265 hex::encode(self.holder_msg_id),
266 );
267 });
268 if owner_msg_id != self.holder_msg_id {
269 panic!(
270 "Write lock guard held by message 0x{} does not match lock owner message 0x{}",
271 hex::encode(self.holder_msg_id),
272 hex::encode(owner_msg_id),
273 );
274 }
275 *locked_by = None;
276 if let Some(message_id) = self.lock.queue.dequeue() {
277 exec::wake(message_id).expect("Failed to wake the message");
278 }
279 }
280 }
281}
282
283impl<'a, T> AsRef<T> for RwLockWriteGuard<'a, T> {
284 fn as_ref(&self) -> &'a T {
285 self.ensure_access_by_holder();
286 unsafe { &*self.lock.value.get() }
287 }
288}
289
290impl<'a, T> AsMut<T> for RwLockWriteGuard<'a, T> {
291 fn as_mut(&mut self) -> &'a mut T {
292 self.ensure_access_by_holder();
293 unsafe { &mut *self.lock.value.get() }
294 }
295}
296
297impl<T> Deref for RwLockWriteGuard<'_, T> {
298 type Target = T;
299
300 fn deref(&self) -> &T {
301 self.ensure_access_by_holder();
302 unsafe { &*self.lock.value.get() }
303 }
304}
305
306impl<T> DerefMut for RwLockWriteGuard<'_, T> {
307 fn deref_mut(&mut self) -> &mut T {
308 self.ensure_access_by_holder();
309 unsafe { &mut *self.lock.value.get() }
310 }
311}
312
313/// The future returned by the [`read`](RwLock::read) method.
314///
315/// The output of the future is the [`RwLockReadGuard`] that can be obtained by
316/// using `await` syntax.
317///
318/// # Examples
319///
320/// The following example explicitly annotates variable types for demonstration
321/// purposes only. Usually, annotating types is unnecessary since
322/// they can be inferred automatically.
323///
324/// ```
325/// use gstd::sync::{RwLock, RwLockReadFuture, RwLockReadGuard};
326///
327/// #[gstd::async_main]
328/// async fn main() {
329/// let rwlock: RwLock<i32> = RwLock::new(42);
330/// let future: RwLockReadFuture<i32> = rwlock.read();
331/// let guard: RwLockReadGuard<i32> = future.await;
332/// let value: i32 = *guard;
333/// assert_eq!(value, 42);
334/// }
335/// # fn main() {}
336/// ```
337pub struct RwLockReadFuture<'a, T> {
338 lock: &'a RwLock<T>,
339}
340
341/// The future returned by the [`write`](RwLock::write) method.
342///
343/// The output of the future is the [`RwLockWriteGuard`] that can be obtained by
344/// using `await` syntax.
345///
346/// # Examples
347///
348/// ```
349/// use gstd::sync::{RwLock, RwLockWriteFuture, RwLockWriteGuard};
350///
351/// #[gstd::async_main]
352/// async fn main() {
353/// let rwlock: RwLock<i32> = RwLock::new(42);
354/// let future: RwLockWriteFuture<i32> = rwlock.write();
355/// let mut guard: RwLockWriteGuard<i32> = future.await;
356/// let value: i32 = *guard;
357/// assert_eq!(value, 42);
358/// *guard = 84;
359/// assert_eq!(*guard, 42);
360/// }
361/// # fn main() {}
362/// ```
363pub struct RwLockWriteFuture<'a, T> {
364 lock: &'a RwLock<T>,
365}
366
367impl<'a, T> Future for RwLockReadFuture<'a, T> {
368 type Output = RwLockReadGuard<'a, T>;
369
370 fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
371 let readers = &self.lock.readers;
372 let readers_count = readers.get().saturating_add(1);
373
374 let current_msg_id = msg::id();
375 let lock = unsafe { &mut *self.lock.locked.get() };
376 if lock.is_none() && readers_count <= READERS_LIMIT {
377 readers.replace(readers_count);
378 Poll::Ready(RwLockReadGuard {
379 lock: self.lock,
380 holder_msg_id: current_msg_id,
381 })
382 } else {
383 // If the message is already in the access queue, and we come here,
384 // it means the message has just been woken up from the waitlist.
385 // In that case we do not want to register yet another access attempt
386 // and just go back to the waitlist.
387 if !self.lock.queue.contains(¤t_msg_id) {
388 self.lock.queue.enqueue(current_msg_id);
389 }
390 Poll::Pending
391 }
392 }
393}
394
395impl<'a, T> Future for RwLockWriteFuture<'a, T> {
396 type Output = RwLockWriteGuard<'a, T>;
397
398 fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
399 let current_msg_id = msg::id();
400 let lock = unsafe { &mut *self.lock.locked.get() };
401 if lock.is_none() && self.lock.readers.get() == 0 {
402 *lock = Some(current_msg_id);
403 Poll::Ready(RwLockWriteGuard {
404 lock: self.lock,
405 holder_msg_id: current_msg_id,
406 })
407 } else {
408 // If the message is already in the access queue, and we come here,
409 // it means the message has just been woken up from the waitlist.
410 // In that case we do not want to register yet another access attempt
411 // and just go back to the waitlist.
412 if !self.lock.queue.contains(¤t_msg_id) {
413 self.lock.queue.enqueue(current_msg_id);
414 }
415 Poll::Pending
416 }
417 }
418}