dremoc_sync/
rwlock.rs

1use remoc::{
2    codec,
3    rch::{mpsc, oneshot},
4    robj::rw_lock::LockError,
5    RemoteSend,
6};
7use serde::{Deserialize, Serialize};
8use std::{
9    error::Error,
10    fmt,
11    ops::{Deref, DerefMut},
12    sync::{Arc, Weak},
13};
14use tokio::sync::{
15    RwLock as TokioRwLock, RwLockReadGuard as TokioRwLockReadGuard,
16    RwLockWriteGuard as TokioRwLockWriteGuard,
17};
18
19/// An error occurred during committing an RwLock value
20#[derive(Clone, Debug, Serialize, Deserialize)]
21pub enum CommitError {
22    /// The host has been dropped
23    Dropped,
24    /// Commit failed
25    Failed,
26}
27
28impl fmt::Display for CommitError {
29    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30        match self {
31            Self::Dropped => write!(f, "host dropped"),
32            Self::Failed => write!(f, "commit failed"),
33        }
34    }
35}
36
37impl<T> From<oneshot::SendError<T>> for CommitError {
38    fn from(err: oneshot::SendError<T>) -> Self {
39        match err {
40            oneshot::SendError::Closed(_) => Self::Dropped,
41            oneshot::SendError::Failed => Self::Failed,
42        }
43    }
44}
45
46impl From<oneshot::RecvError> for CommitError {
47    fn from(_: oneshot::RecvError) -> Self {
48        Self::Failed
49    }
50}
51
52impl Error for CommitError {}
53
/// A lock that allows reading and writing of a shared value, possibly stored on a remote endpoint.
///
/// This can be cloned and sent to remote endpoints. Serialization only transfers the
/// request channels, so a deserialized handle has no host value and forwards every lock
/// operation to the endpoint that created the lock.
pub struct RwLock<T, Codec = codec::Default> {
    /// The locally stored value: `Some` for handles created with `new`/`new_frivolous`
    /// (and their clones), `None` for deserialized handles.
    host: Option<RwLockOwner<T, Codec>>,
    /// Channels through which read and write requests reach the host's request handler.
    remote: Arc<RwLockRemote<T, Codec>>,
}
61
62impl<T, Codec> RwLock<T, Codec> {
63    /// Returns host if it exist
64    pub fn host(&self) -> Option<&Arc<TokioRwLock<T>>> {
65        self.host.as_ref().map(|host| &host.value)
66    }
67
68    /// Locks the host shared value for reading and returns a reference to it,
69    /// without synchronizing with remote endpoints
70    pub async fn host_read(&self) -> Option<TokioRwLockReadGuard<'_, T>> {
71        let host = self.host.as_ref()?.value.read().await;
72
73        Some(host)
74    }
75
76    /// Locks the host shared value for reading and writing and returns a mutable reference to it,
77    /// without synchronizing with remote endpoints
78    pub async fn host_write(&self) -> Option<TokioRwLockWriteGuard<T>> {
79        let host = self.host.as_ref()?.value.write().await;
80
81        Some(host)
82    }
83
84    fn new_remote(remote: RwLockRemote<T, Codec>) -> Self {
85        Self {
86            host: None,
87            remote: Arc::new(remote),
88        }
89    }
90}
91
92impl<T: RemoteSend + Clone + Sync, Codec: codec::Codec> RwLock<T, Codec> {
93    /// Creates a new `RwLock<T>` with an host with the specified shared value.
94    pub fn new(value: T) -> Self {
95        let (read_req_tx, read_req_rx) = mpsc::channel(1);
96        let read_req_tx = read_req_tx.set_buffer();
97        let read_req_rx = read_req_rx.set_buffer();
98        let (write_req_tx, write_req_rx) = mpsc::channel(1);
99        let write_req_tx = write_req_tx.set_buffer();
100        let write_req_rx = write_req_rx.set_buffer();
101
102        let host = RwLockOwner {
103            value: Arc::new(TokioRwLock::new(value)),
104            drop_chanel: Arc::new(TokioRwLock::new(None)),
105        };
106        let weak_host = host.as_weak();
107
108        let rw_lock = Self {
109            host: Some(host),
110            remote: Arc::new(RwLockRemote {
111                read_req_tx,
112                write_req_tx,
113                is_frivolous: false,
114            }),
115        };
116
117        tokio::spawn(Self::handle_host_requests(
118            weak_host,
119            read_req_rx,
120            write_req_rx,
121        ));
122
123        rw_lock
124    }
125
126    #[doc(hidden)]
127    pub fn new_frivolous(value: T) -> Self {
128        let mut this = Self::new(value);
129
130        let remote = Arc::get_mut(&mut this.remote).unwrap();
131        remote.is_frivolous = true;
132
133        this
134    }
135
136    /// Locks the current shared value for reading and returns a reference to it.
137    pub async fn read(&self) -> Result<RwLockReadGuard<T, Codec>, LockError> {
138        if let Some(host) = self.host.as_ref() {
139            let (value_guard, _remote_gruard) =
140                tokio::join!(host.value.read(), host.drop_chanel.read());
141
142            return Ok(RwLockReadGuard::new_host(value_guard));
143        }
144
145        let (value_tx, value_rx) = oneshot::channel();
146        let _ = self.remote.read_req_tx.send(ReadRequest { value_tx }).await;
147        let value = value_rx.await?;
148
149        Ok(RwLockReadGuard::new_remote(value))
150    }
151
152    /// Locks the current shared value for reading and writing and returns a mutable reference to it.
153    ///
154    /// To commit the new value [RwLockWriteGuard::commit] must be called, otherwise the
155    /// changes will be lost.
156    pub async fn write(&self) -> Result<RwLockWriteGuard<T, Codec>, LockError> {
157        if let Some(host) = self.host.as_ref() {
158            let (value_guard, _remote_gruard) = tokio::join!(
159                host.value.write(),
160                drop_remote_read_guard(&host.drop_chanel)
161            );
162
163            if self.remote.is_frivolous {
164                return Ok(RwLockWriteGuard::new_host_frivolous(value_guard));
165            }
166
167            return Ok(RwLockWriteGuard::new_host(value_guard));
168        }
169
170        let (value_tx, value_rx) = oneshot::channel();
171        let (new_value_tx, new_value_rx) = oneshot::channel();
172        let (confirm_tx, confirm_rx) = oneshot::channel();
173
174        let _ = self
175            .remote
176            .write_req_tx
177            .send(WriteRequest {
178                value_tx,
179                new_value_rx,
180                confirm_tx,
181            })
182            .await;
183        let value = value_rx.await?;
184
185        Ok(RwLockWriteGuard::new_remote(
186            value,
187            new_value_tx,
188            confirm_rx,
189        ))
190    }
191
192    async fn handle_host_requests(
193        weak_host: WeakRwLockOwner<T, Codec>,
194        mut read_req_rx: mpsc::Receiver<ReadRequest<T, Codec>, Codec, 1>,
195        mut write_req_rx: mpsc::Receiver<WriteRequest<T, Codec>, Codec, 1>,
196    ) {
197        loop {
198            tokio::select! {
199                biased;
200
201                // Write value request
202                res = write_req_rx.recv() => {
203                    let WriteRequest {value_tx, new_value_rx, confirm_tx} = match res {
204                        Ok(Some(req)) => req,
205                        Ok(None) => break,
206                        Err(err) if err.is_final() => break,
207                        Err(_) => continue,
208                    };
209
210                    let Some(host) = weak_host.upgrade() else {
211                        break
212                    };
213
214                    {
215                        let _remote_write_guard = drop_remote_read_guard(&host.drop_chanel).await;
216                        let remote_value = host.value.write().await.clone();
217                        if value_tx.send(remote_value).is_err() {
218                            continue
219                        }
220
221                        if let Ok(new_value) = new_value_rx.await {
222                            *host.value.write().await = new_value;
223                            let _ = confirm_tx.send(());
224                        }
225                    }
226                }
227
228                // Read value request
229                res = read_req_rx.recv() => {
230                    let ReadRequest {value_tx} = match res {
231                        Ok(Some(req)) => req,
232                        Ok(None) => break,
233                        Err(err) if err.is_final() => break,
234                        Err(_) => continue,
235                    };
236
237                    let Some(remote_value) = weak_host.make_remote_value().await else {
238                        break
239                    };
240                    let _ = value_tx.send(remote_value);
241                }
242            }
243        }
244    }
245}
246
247impl<T: RemoteSend + Clone + Sync, Codec: codec::Codec> serde::Serialize for RwLock<T, Codec> {
248    #[inline]
249    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
250    where
251        S: serde::Serializer,
252    {
253        self.remote.serialize(serializer)
254    }
255}
256
257impl<'de, T: RemoteSend + Clone + Sync, Codec: codec::Codec> serde::Deserialize<'de>
258    for RwLock<T, Codec>
259{
260    #[inline]
261    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
262    where
263        D: serde::Deserializer<'de>,
264    {
265        let remote = RwLockRemote::deserialize(deserializer)?;
266
267        Ok(Self::new_remote(remote))
268    }
269}
270
271impl<T, Codec> Clone for RwLock<T, Codec> {
272    fn clone(&self) -> Self {
273        Self {
274            host: self.host.clone(),
275            remote: self.remote.clone(),
276        }
277    }
278}
279
280/// RAII structure used to release the shared read access of a lock when dropped.
281///
282/// As long as this is held, no write access to the lock can occur.
283pub struct RwLockReadGuard<'a, T, Codec = codec::Default> {
284    inner: RwLockReadGuardInner<'a, T, Codec>,
285}
286
287impl<'a, T, Codec> RwLockReadGuard<'a, T, Codec> {
288    fn new_host(value_guard: TokioRwLockReadGuard<'a, T>) -> Self {
289        Self {
290            inner: RwLockReadGuardInner::Owner(value_guard),
291        }
292    }
293
294    fn new_remote(value: RemoteValue<T, Codec>) -> Self {
295        Self {
296            inner: RwLockReadGuardInner::Remote(value),
297        }
298    }
299}
300
301impl<'a, T, Codec> Deref for RwLockReadGuard<'a, T, Codec> {
302    type Target = T;
303
304    fn deref(&self) -> &Self::Target {
305        match &self.inner {
306            RwLockReadGuardInner::Owner(value_guard) => value_guard,
307            RwLockReadGuardInner::Remote(value) => &value.value,
308        }
309    }
310}
311
312/// RAII structure used to release the exclusive write access of a lock when dropped.
313///
314/// To commit changes [commit](Self::commit) must be called.
315/// Dropping the guard will result in the changes to be not applied to the shared value.
316pub struct RwLockWriteGuard<'a, T, Codec = codec::Default> {
317    inner: RwLockWriteGuardInner<'a, T, Codec>,
318}
319
320impl<'a, T, Codec> RwLockWriteGuard<'a, T, Codec> {
321    fn new_host_frivolous(value_guard: TokioRwLockWriteGuard<'a, T>) -> Self {
322        Self {
323            inner: RwLockWriteGuardInner::Owner {
324                new_value: None,
325                value_guard,
326            },
327        }
328    }
329}
330
331impl<'a, T: Clone, Codec> RwLockWriteGuard<'a, T, Codec> {
332    fn new_host(value_guard: TokioRwLockWriteGuard<'a, T>) -> Self {
333        Self {
334            inner: RwLockWriteGuardInner::Owner {
335                new_value: Some(value_guard.clone()),
336                value_guard,
337            },
338        }
339    }
340
341    fn new_remote(
342        value: T,
343        new_value_tx: oneshot::Sender<T, Codec>,
344        confirm_rx: oneshot::Receiver<(), Codec>,
345    ) -> Self {
346        Self {
347            inner: RwLockWriteGuardInner::Remote {
348                value,
349                new_value_tx,
350                confirm_rx,
351            },
352        }
353    }
354}
355
356impl<'a, T: RemoteSend, Codec: codec::Codec> RwLockWriteGuard<'a, T, Codec> {
357    /// Consumes the guard and commits the changes to the shared value.
358    pub async fn commit(self) -> Result<(), CommitError> {
359        match self.inner {
360            RwLockWriteGuardInner::Owner {
361                new_value,
362                mut value_guard,
363            } => {
364                if let Some(new_value) = new_value {
365                    *value_guard = new_value;
366                }
367
368                Ok(())
369            }
370            RwLockWriteGuardInner::Remote {
371                value,
372                new_value_tx,
373                confirm_rx,
374            } => {
375                new_value_tx.send(value)?;
376                confirm_rx.await?;
377
378                Ok(())
379            }
380        }
381    }
382}
383
384impl<'a, T, Codec> Deref for RwLockWriteGuard<'a, T, Codec> {
385    type Target = T;
386
387    fn deref(&self) -> &Self::Target {
388        match &self.inner {
389            RwLockWriteGuardInner::Owner {
390                new_value,
391                value_guard,
392            } => new_value.as_ref().unwrap_or(value_guard),
393            RwLockWriteGuardInner::Remote { value, .. } => &value,
394        }
395    }
396}
397
398impl<'a, T, Codec> DerefMut for RwLockWriteGuard<'a, T, Codec> {
399    fn deref_mut(&mut self) -> &mut Self::Target {
400        match &mut self.inner {
401            RwLockWriteGuardInner::Owner {
402                new_value,
403                value_guard,
404            } => new_value.as_mut().unwrap_or(value_guard),
405            RwLockWriteGuardInner::Remote { value, .. } => value,
406        }
407    }
408}
409
410struct RwLockOwner<T, Codec> {
411    value: Arc<TokioRwLock<T>>,
412    drop_chanel:
413        Arc<TokioRwLock<Option<(mpsc::Sender<(), Codec, 1>, mpsc::Receiver<(), Codec, 1>)>>>,
414}
415
416impl<T, Codec> RwLockOwner<T, Codec> {
417    fn as_weak(&self) -> WeakRwLockOwner<T, Codec> {
418        WeakRwLockOwner {
419            value: Arc::downgrade(&self.value),
420            drop_chanel: Arc::downgrade(&self.drop_chanel),
421        }
422    }
423}
424
425impl<T, Codec> Clone for RwLockOwner<T, Codec> {
426    fn clone(&self) -> Self {
427        Self {
428            value: self.value.clone(),
429            drop_chanel: self.drop_chanel.clone(),
430        }
431    }
432}
433
/// Sender/receiver pair used to wait for remote read guards. Each remote read guard carries
/// a clone of the sender. `drop_remote_read_guard` drops the stored sender and waits until
/// the receiver yields `None`.
type DropChanel<Codec> = (mpsc::Sender<(), Codec, 1>, mpsc::Receiver<(), Codec, 1>);
/// Weak counterpart of [`RwLockOwner`], held by the host request handler.
struct WeakRwLockOwner<T, Codec> {
    value: Weak<TokioRwLock<T>>,
    drop_chanel: Weak<TokioRwLock<Option<DropChanel<Codec>>>>,
}
439
440impl<T, Codec> WeakRwLockOwner<T, Codec> {
441    fn upgrade(&self) -> Option<RwLockOwner<T, Codec>> {
442        Some(RwLockOwner {
443            value: self.value.upgrade()?,
444            drop_chanel: self.drop_chanel.upgrade()?,
445        })
446    }
447}
448impl<T: Clone, Codec> WeakRwLockOwner<T, Codec> {
449    async fn make_remote_value(&self) -> Option<RemoteValue<T, Codec>> {
450        let value = self.value.upgrade()?;
451        let drop_chanel = self.drop_chanel.upgrade()?;
452
453        let dropped_tx = drop_chanel
454            .read()
455            .await
456            .as_ref()
457            .map(|drop_chanel| drop_chanel.0.clone());
458
459        let dropped_tx = if let Some(dropped_tx) = dropped_tx {
460            dropped_tx
461        } else {
462            let (dropped_tx, dropped_rx) = mpsc::channel(1);
463            let dropped_tx = dropped_tx.set_buffer();
464            let dropped_rx = dropped_rx.set_buffer();
465
466            {
467                let mut drop_chanel = drop_chanel.write().await;
468                *drop_chanel = Some((dropped_tx.clone(), dropped_rx));
469            }
470
471            dropped_tx
472        };
473
474        let value = value.read().await;
475        Some(RemoteValue {
476            value: value.clone(),
477            dropped_tx: dropped_tx.clone(),
478        })
479    }
480}
481
482async fn drop_remote_read_guard<Codec>(
483    drop_chanel: &Arc<TokioRwLock<Option<DropChanel<Codec>>>>,
484) -> TokioRwLockWriteGuard<Option<DropChanel<Codec>>> {
485    let mut drop_chanel_write_guard = drop_chanel.write().await;
486    if let Some(drop_chanel) = drop_chanel_write_guard.take() {
487        let (dropped_tx, mut dropped_rx) = drop_chanel;
488        drop(dropped_tx);
489        loop {
490            if let Ok(None) = dropped_rx.recv().await {
491                break;
492            }
493        }
494    }
495
496    drop_chanel_write_guard
497}
498
499#[derive(Serialize, Deserialize)]
500#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
501#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
502struct RwLockRemote<T, Codec> {
503    read_req_tx: mpsc::Sender<ReadRequest<T, Codec>, Codec, 1>,
504    write_req_tx: mpsc::Sender<WriteRequest<T, Codec>, Codec, 1>,
505    is_frivolous: bool,
506}
507
/// Storage of a [`RwLockReadGuard`].
enum RwLockReadGuardInner<'a, T, Codec> {
    /// Read guard on the locally hosted value.
    Owner(TokioRwLockReadGuard<'a, T>),
    /// Copy of the value received from the host, with the sender whose drop ends the read.
    Remote(RemoteValue<T, Codec>),
}

/// Storage of a [`RwLockWriteGuard`].
enum RwLockWriteGuardInner<'a, T, Codec> {
    /// Write guard on the locally hosted value.
    Owner {
        /// Staged copy that edits go to. `None` for frivolous locks, whose edits go
        /// straight to `value_guard`.
        new_value: Option<T>,
        value_guard: TokioRwLockWriteGuard<'a, T>,
    },
    /// Copy of the value received from the host.
    Remote {
        value: T,
        /// Sends the modified value back to the host on commit.
        new_value_tx: oneshot::Sender<T, Codec>,
        /// Receives the host's confirmation that the new value was stored.
        confirm_rx: oneshot::Receiver<(), Codec>,
    },
}
524
525#[derive(Serialize, Deserialize)]
526#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
527#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
528struct RemoteValue<T, Codec = codec::Default> {
529    value: T,
530    dropped_tx: mpsc::Sender<(), Codec, 1>,
531}
532
533#[derive(Serialize, Deserialize)]
534#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
535#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
536struct ReadRequest<T, Codec = codec::Default> {
537    value_tx: oneshot::Sender<RemoteValue<T, Codec>, Codec>,
538}
539
540#[derive(Serialize, Deserialize)]
541#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
542#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
543struct WriteRequest<T, Codec = codec::Default> {
544    value_tx: oneshot::Sender<T, Codec>,
545    new_value_rx: oneshot::Receiver<T, Codec>,
546    confirm_tx: oneshot::Sender<(), Codec>,
547}