timed_locks/
rwlock.rs

//! Smart pointer to [`tokio::sync::RwLock`] that bounds every lock acquisition with a timeout.
2
3use std::time::Duration;
4
5use tokio::time::timeout;
6
7use crate::{Result, DEFAULT_TIMEOUT_DURATION};
8
9/// Smart pointer to [`tokio::sync::RwLock`].
10///
11/// Wraps acquiring the lock into [`timeout`] with a [`Duration`] of 30 seconds
12/// by default.
13#[derive(Debug)]
14pub struct RwLock<T> {
15	/// The actual [`tokio::sync::Mutex`]
16	inner: tokio::sync::RwLock<T>,
17	/// The timeout duration
18	timeout: Duration,
19}
20
21impl<T> RwLock<T> {
22	/// Create new `RwLock` with default timeout of 30 seconds.
23	pub fn new(value: T) -> Self {
24		Self { inner: tokio::sync::RwLock::new(value), timeout: DEFAULT_TIMEOUT_DURATION }
25	}
26
27	/// Create new `RwLock` with given timeout.
28	pub fn new_with_timeout(value: T, timeout: Duration) -> Self {
29		Self { inner: tokio::sync::RwLock::new(value), timeout }
30	}
31
32	/// Wrapper around [`tokio::sync::RwLock::read()`]. Will time out if the
33	/// lock can’t get acquired until the timeout is reached.
34	///
35	/// # Panics
36	///
37	/// Panics when timeout is reached.
38	pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, T> {
39		let read_guard = match timeout(self.timeout, self.inner.read()).await {
40			Ok(read_guard) => read_guard,
41			Err(_) => panic!(
42				"Timed out while waiting for `read` lock after {} seconds.",
43				self.timeout.as_secs()
44			),
45		};
46
47		read_guard
48	}
49
50	/// Wrapper around [`tokio::sync::RwLock::read()`]. Will time out if the
51	/// lock can't get acquired until the timeout is reached.
52	///
53	/// Returns an error if timeout is reached.
54	pub async fn read_err(&self) -> Result<tokio::sync::RwLockReadGuard<'_, T>> {
55		let read_guard = timeout(self.timeout, self.inner.read())
56			.await
57			.map_err(|_| crate::Error::ReadLockTimeout(self.timeout.as_secs()))?;
58
59		Ok(read_guard)
60	}
61
62	/// Wrapper around [`tokio::sync::RwLock::write()`]. Will time out if
63	/// the lock can't get acquired until the timeout is reached.
64	///
65	///  # Panics
66	///
67	/// Panics when timeout is reached.
68	pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, T> {
69		let write_guard = match timeout(self.timeout, self.inner.write()).await {
70			Ok(write_guard) => write_guard,
71			Err(_) => panic!(
72				"Timed out while waiting for `write` lock after {} seconds.",
73				self.timeout.as_secs()
74			),
75		};
76
77		write_guard
78	}
79
80	/// Wrapper around [`tokio::sync::RwLock::write()`]. Will time out if
81	/// the lock can't get acquired until the timeout is reached.
82	///
83	/// Returns an error if timeout is reached.
84	pub async fn write_err(&self) -> Result<tokio::sync::RwLockWriteGuard<'_, T>> {
85		let write_guard = timeout(self.timeout, self.inner.write())
86			.await
87			.map_err(|_| crate::Error::WriteLockTimeout(self.timeout.as_secs()))?;
88
89		Ok(write_guard)
90	}
91}
92
93impl<T> std::ops::Deref for RwLock<T> {
94	type Target = tokio::sync::RwLock<T>;
95
96	fn deref(&self) -> &Self::Target {
97		&self.inner
98	}
99}
100
101impl<T: Default> Default for RwLock<T> {
102	fn default() -> Self {
103		Self::new(T::default())
104	}
105}
106
107impl<T> From<T> for RwLock<T> {
108	fn from(value: T) -> Self {
109		Self::new(value)
110	}
111}