rocketmq_rust/
rocketmq_tokio_lock.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17use std::time::Duration;
18
19pub struct RocketMQTokioRwLock<T: ?Sized> {
20    lock: tokio::sync::RwLock<T>,
21}
22
23impl<T> Default for RocketMQTokioRwLock<T>
24where
25    T: Default,
26{
27    fn default() -> Self {
28        Self::new(T::default())
29    }
30}
31
32impl<T: ?Sized> RocketMQTokioRwLock<T> {
33    /// Creates a new `RocketMQTokioRwLock` instance containing the given data.
34    ///
35    /// # Arguments
36    ///
37    /// * `data` - The data to be protected by the read-write lock.
38    ///
39    /// # Returns
40    ///
41    /// A new `RocketMQTokioRwLock` instance.
42    pub fn new(data: T) -> Self
43    where
44        T: Sized,
45    {
46        Self {
47            lock: tokio::sync::RwLock::new(data),
48        }
49    }
50
51    /// Creates a new `RocketMQTokioRwLock` instance from an existing `tokio::sync::RwLock`.
52    ///
53    /// # Arguments
54    ///
55    /// * `lock` - An existing `tokio::sync::RwLock` to be used.
56    ///
57    /// # Returns
58    ///
59    /// A new `RocketMQTokioRwLock` instance.
60    pub fn new_rw_lock(lock: tokio::sync::RwLock<T>) -> Self
61    where
62        T: Sized,
63    {
64        Self { lock }
65    }
66
67    /// Acquires a read lock asynchronously, blocking the current task until it is able to do so.
68    ///
69    /// # Returns
70    ///
71    /// A `RwLockReadGuard` that releases the read lock when dropped.
72    pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
73        self.lock.read().await
74    }
75
76    /// Acquires a write lock asynchronously, blocking the current task until it is able to do so.
77    ///
78    /// # Returns
79    ///
80    /// A `RwLockWriteGuard` that releases the write lock when dropped.
81    pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, T> {
82        self.lock.write().await
83    }
84
85    /// Attempts to acquire a read lock asynchronously without blocking.
86    ///
87    /// # Returns
88    ///
89    /// An `Option` containing a `RwLockReadGuard` if the read lock was successfully acquired, or
90    /// `None` if the lock is already held.
91    pub async fn try_read(&self) -> Option<tokio::sync::RwLockReadGuard<'_, T>> {
92        self.lock.try_read().ok()
93    }
94
95    /// Attempts to acquire a write lock asynchronously without blocking.
96    ///
97    /// # Returns
98    ///
99    /// An `Option` containing a `RwLockWriteGuard` if the write lock was successfully acquired, or
100    /// `None` if the lock is already held.
101    pub async fn try_write(&self) -> Option<tokio::sync::RwLockWriteGuard<'_, T>> {
102        self.lock.try_write().ok()
103    }
104
105    /// Attempts to acquire a read lock asynchronously, blocking for up to the specified timeout.
106    ///
107    /// # Arguments
108    ///
109    /// * `timeout` - The maximum duration to wait for the read lock.
110    ///
111    /// # Returns
112    ///
113    /// An `Option` containing a `RwLockReadGuard` if the read lock was successfully acquired within
114    /// the timeout, or `None` if the timeout expired.
115    pub async fn try_read_timeout(
116        &self,
117        timeout: Duration,
118    ) -> Option<tokio::sync::RwLockReadGuard<'_, T>> {
119        (tokio::time::timeout(timeout, self.lock.read()).await).ok()
120    }
121
122    /// Attempts to acquire a write lock asynchronously, blocking for up to the specified timeout.
123    ///
124    /// # Arguments
125    ///
126    /// * `timeout` - The maximum duration to wait for the write lock.
127    ///
128    /// # Returns
129    ///
130    /// An `Option` containing a `RwLockWriteGuard` if the write lock was successfully acquired
131    /// within the timeout, or `None` if the timeout expired.
132    pub async fn try_write_timeout(
133        &self,
134        timeout: Duration,
135    ) -> Option<tokio::sync::RwLockWriteGuard<'_, T>> {
136        (tokio::time::timeout(timeout, self.lock.write()).await).ok()
137    }
138}
139
140pub struct RocketMQTokioMutex<T: ?Sized> {
141    lock: tokio::sync::Mutex<T>,
142}
143
144impl<T: ?Sized> RocketMQTokioMutex<T> {
145    /// Creates a new `RocketMQTokioMutex` instance containing the given data.
146    ///
147    /// # Arguments
148    ///
149    /// * `data` - The data to be protected by the mutex.
150    ///
151    /// # Returns
152    ///
153    /// A new `RocketMQTokioMutex` instance.
154    pub fn new(data: T) -> Self
155    where
156        T: Sized,
157    {
158        Self {
159            lock: tokio::sync::Mutex::new(data),
160        }
161    }
162
163    /// Acquires the lock asynchronously, blocking the current task until it is able to do so.
164    ///
165    /// # Returns
166    ///
167    /// A `MutexGuard` that releases the lock when dropped.
168    pub async fn lock(&self) -> tokio::sync::MutexGuard<'_, T> {
169        self.lock.lock().await
170    }
171
172    /// Attempts to acquire the lock asynchronously without blocking.
173    ///
174    /// # Returns
175    ///
176    /// An `Option` containing a `MutexGuard` if the lock was successfully acquired, or `None` if
177    /// the lock is already held.
178    pub async fn try_lock(&self) -> Option<tokio::sync::MutexGuard<'_, T>> {
179        self.lock.try_lock().ok()
180    }
181
182    /// Attempts to acquire the lock asynchronously, blocking for up to the specified timeout.
183    ///
184    /// # Arguments
185    ///
186    /// * `timeout` - The maximum duration to wait for the lock.
187    ///
188    /// # Returns
189    ///
190    /// An `Option` containing a `MutexGuard` if the lock was successfully acquired within the
191    /// timeout, or `None` if the timeout expired.
192    pub async fn try_lock_timeout(
193        &self,
194        timeout: Duration,
195    ) -> Option<tokio::sync::MutexGuard<'_, T>> {
196        (tokio::time::timeout(timeout, self.lock.lock()).await).ok()
197    }
198}
199
200impl<T> Default for RocketMQTokioMutex<T>
201where
202    T: Default,
203{
204    fn default() -> Self {
205        Self::new(T::default())
206    }
207}
208
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::sync::RwLock;

    use super::*;

    /// How long a background task keeps a lock held in the contention tests.
    ///
    /// Deliberately far larger than `SHORT_TIMEOUT` so that scheduling jitter
    /// cannot let the waiter succeed by accident. The holder task is not
    /// awaited, so the tests do not wait for this duration to elapse.
    const HOLD: Duration = Duration::from_secs(1);

    /// Timeout used by the tests that expect acquisition to fail.
    const SHORT_TIMEOUT: Duration = Duration::from_millis(2);

    /// Timeout used by the tests that expect an uncontended acquisition.
    const AMPLE_TIMEOUT: Duration = Duration::from_millis(100);

    // ---- RocketMQTokioRwLock ----

    #[tokio::test]
    async fn new_creates_instance() {
        let lock = RocketMQTokioRwLock::new(5);
        assert_eq!(*lock.read().await, 5);
    }

    #[tokio::test]
    async fn new_rw_lock_creates_instance() {
        let rw_lock = RwLock::new(5);
        let lock = RocketMQTokioRwLock::new_rw_lock(rw_lock);
        assert_eq!(*lock.read().await, 5);
    }

    #[tokio::test]
    async fn read_locks_and_reads() {
        let lock = RocketMQTokioRwLock::new(5);
        let read_guard = lock.read().await;
        assert_eq!(*read_guard, 5);
    }

    #[tokio::test]
    async fn write_locks_and_writes() {
        let lock = RocketMQTokioRwLock::new(5);
        {
            // Scope the write guard so it is released before reading back.
            let mut write_guard = lock.write().await;
            *write_guard = 10;
        }
        assert_eq!(*lock.read().await, 10);
    }

    #[tokio::test]
    async fn try_read_locks_and_reads() {
        let lock = RocketMQTokioRwLock::new(5);
        let read_guard = lock.try_read().await;
        assert!(read_guard.is_some());
        assert_eq!(*read_guard.unwrap(), 5);
    }

    #[tokio::test]
    async fn try_write_locks_and_writes() {
        let lock = RocketMQTokioRwLock::new(5);
        {
            let write_guard = lock.try_write().await;
            assert!(write_guard.is_some());
            // The unwrapped guard is a temporary and is released at the end
            // of this statement.
            *write_guard.unwrap() = 10;
        }
        assert_eq!(*lock.read().await, 10);
    }

    #[tokio::test]
    async fn try_read_timeout_succeeds_within_timeout() {
        let lock = RocketMQTokioRwLock::new(5);
        let read_guard = lock.try_read_timeout(AMPLE_TIMEOUT).await;
        assert!(read_guard.is_some());
        assert_eq!(*read_guard.unwrap(), 5);
    }

    #[tokio::test]
    async fn try_read_timeout_fails_after_timeout() {
        let lock = Arc::new(RocketMQTokioRwLock::new(5));
        let holder = Arc::clone(&lock);
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            // A held write guard excludes readers.
            let _write_guard = holder.write().await;
            tx.send(()).unwrap();
            // Hold the write lock for much longer than the reader's timeout.
            tokio::time::sleep(HOLD).await;
        });
        // Only start the timed read once the writer is known to hold the lock.
        rx.await.unwrap();
        let read_guard = lock.try_read_timeout(SHORT_TIMEOUT).await;
        assert!(read_guard.is_none());
    }

    #[tokio::test]
    async fn try_write_timeout_succeeds_within_timeout() {
        let lock = RocketMQTokioRwLock::new(5);
        let write_guard = lock.try_write_timeout(AMPLE_TIMEOUT).await;
        assert!(write_guard.is_some());
        // The unwrapped guard is a temporary; it must be released before the
        // read below or that read would never be granted.
        *write_guard.unwrap() = 10;
        assert_eq!(*lock.read().await, 10);
    }

    #[tokio::test]
    async fn try_write_timeout_fails_after_timeout() {
        let lock = Arc::new(RocketMQTokioRwLock::new(5));
        let holder = Arc::clone(&lock);
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            // A held read guard excludes writers.
            let _read_guard = holder.read().await;
            tx.send(()).unwrap();
            // Hold the read lock for much longer than the writer's timeout.
            tokio::time::sleep(HOLD).await;
        });
        // Only start the timed write once the reader is known to hold the lock.
        rx.await.unwrap();
        let write_guard = lock.try_write_timeout(SHORT_TIMEOUT).await;
        assert!(write_guard.is_none());
    }

    // ---- RocketMQTokioMutex ----

    #[tokio::test]
    async fn new_creates_mutex_instance() {
        let mutex = RocketMQTokioMutex::new(5);
        let guard = mutex.lock().await;
        assert_eq!(*guard, 5);
    }

    #[tokio::test]
    async fn lock_acquires_lock_and_allows_mutation() {
        let mutex = RocketMQTokioMutex::new(5);
        {
            // Scope the guard so the mutex is free for the second lock below.
            let mut guard = mutex.lock().await;
            *guard = 10;
        }
        let guard = mutex.lock().await;
        assert_eq!(*guard, 10);
    }

    #[tokio::test]
    async fn try_lock_acquires_lock_if_available() {
        let mutex = RocketMQTokioMutex::new(5);
        let guard = mutex.try_lock().await;
        assert!(guard.is_some());
        assert_eq!(*guard.unwrap(), 5);
    }

    #[tokio::test]
    async fn try_lock_returns_none_if_unavailable() {
        let mutex = Arc::new(RocketMQTokioMutex::new(5));
        let holder = Arc::clone(&mutex);
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let _guard = holder.lock().await;
            tx.send(()).unwrap();
            // Hold the lock until the test completes.
            tokio::time::sleep(HOLD).await;
        });
        rx.await.unwrap();
        let guard = mutex.try_lock().await;
        assert!(guard.is_none());
    }

    #[tokio::test]
    async fn try_lock_timeout_succeeds_within_timeout() {
        let mutex = RocketMQTokioMutex::new(5);
        let guard = mutex.try_lock_timeout(AMPLE_TIMEOUT).await;
        assert!(guard.is_some());
        assert_eq!(*guard.unwrap(), 5);
    }

    #[tokio::test]
    async fn try_lock_timeout_fails_after_timeout() {
        let mutex = Arc::new(RocketMQTokioMutex::new(5));
        let holder = Arc::clone(&mutex);
        let (tx, rx) = tokio::sync::oneshot::channel();
        tokio::spawn(async move {
            let _guard = holder.lock().await;
            tx.send(()).unwrap();
            // Hold the lock for longer than the timeout.
            tokio::time::sleep(HOLD).await;
        });
        rx.await.unwrap();
        let guard = mutex.try_lock_timeout(SHORT_TIMEOUT).await;
        assert!(guard.is_none());
    }
}