1use remoc::{
2 codec,
3 rch::{mpsc, oneshot},
4 robj::rw_lock::LockError,
5 RemoteSend,
6};
7use serde::{Deserialize, Serialize};
8use std::{
9 error::Error,
10 fmt,
11 ops::{Deref, DerefMut},
12 sync::{Arc, Weak},
13};
14use tokio::sync::{
15 RwLock as TokioRwLock, RwLockReadGuard as TokioRwLockReadGuard,
16 RwLockWriteGuard as TokioRwLockWriteGuard,
17};
18
/// Error returned by [`RwLockWriteGuard::commit`].
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum CommitError {
    /// The new value could not be sent because the channel was closed
    /// (displayed as "host dropped").
    Dropped,
    /// Sending the new value failed, or the confirmation could not be received.
    Failed,
}
27
28impl fmt::Display for CommitError {
29 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
30 match self {
31 Self::Dropped => write!(f, "host dropped"),
32 Self::Failed => write!(f, "commit failed"),
33 }
34 }
35}
36
37impl<T> From<oneshot::SendError<T>> for CommitError {
38 fn from(err: oneshot::SendError<T>) -> Self {
39 match err {
40 oneshot::SendError::Closed(_) => Self::Dropped,
41 oneshot::SendError::Failed => Self::Failed,
42 }
43 }
44}
45
46impl From<oneshot::RecvError> for CommitError {
47 fn from(_: oneshot::RecvError) -> Self {
48 Self::Failed
49 }
50}
51
// Marker impl: `CommitError` relies on the default `Error` methods (no source).
impl Error for CommitError {}
53
/// A read/write lock whose value lives on a host and can be accessed through
/// handles that were serialized and sent elsewhere.
///
/// Only the `remote` part is serialized; a deserialized handle has no `host`.
pub struct RwLock<T, Codec = codec::Default> {
    /// Host-side storage; `Some` for handles created by `new`, `None` for
    /// handles created from a deserialized `RwLockRemote`.
    host: Option<RwLockOwner<T, Codec>>,
    /// Request senders (and the `is_frivolous` flag) shared by all clones of this handle.
    remote: Arc<RwLockRemote<T, Codec>>,
}
61
62impl<T, Codec> RwLock<T, Codec> {
63 pub fn host(&self) -> Option<&Arc<TokioRwLock<T>>> {
65 self.host.as_ref().map(|host| &host.value)
66 }
67
68 pub async fn host_read(&self) -> Option<TokioRwLockReadGuard<'_, T>> {
71 let host = self.host.as_ref()?.value.read().await;
72
73 Some(host)
74 }
75
76 pub async fn host_write(&self) -> Option<TokioRwLockWriteGuard<T>> {
79 let host = self.host.as_ref()?.value.write().await;
80
81 Some(host)
82 }
83
84 fn new_remote(remote: RwLockRemote<T, Codec>) -> Self {
85 Self {
86 host: None,
87 remote: Arc::new(remote),
88 }
89 }
90}
91
92impl<T: RemoteSend + Clone + Sync, Codec: codec::Codec> RwLock<T, Codec> {
93 pub fn new(value: T) -> Self {
95 let (read_req_tx, read_req_rx) = mpsc::channel(1);
96 let read_req_tx = read_req_tx.set_buffer();
97 let read_req_rx = read_req_rx.set_buffer();
98 let (write_req_tx, write_req_rx) = mpsc::channel(1);
99 let write_req_tx = write_req_tx.set_buffer();
100 let write_req_rx = write_req_rx.set_buffer();
101
102 let host = RwLockOwner {
103 value: Arc::new(TokioRwLock::new(value)),
104 drop_chanel: Arc::new(TokioRwLock::new(None)),
105 };
106 let weak_host = host.as_weak();
107
108 let rw_lock = Self {
109 host: Some(host),
110 remote: Arc::new(RwLockRemote {
111 read_req_tx,
112 write_req_tx,
113 is_frivolous: false,
114 }),
115 };
116
117 tokio::spawn(Self::handle_host_requests(
118 weak_host,
119 read_req_rx,
120 write_req_rx,
121 ));
122
123 rw_lock
124 }
125
126 #[doc(hidden)]
127 pub fn new_frivolous(value: T) -> Self {
128 let mut this = Self::new(value);
129
130 let remote = Arc::get_mut(&mut this.remote).unwrap();
131 remote.is_frivolous = true;
132
133 this
134 }
135
136 pub async fn read(&self) -> Result<RwLockReadGuard<T, Codec>, LockError> {
138 if let Some(host) = self.host.as_ref() {
139 let (value_guard, _remote_gruard) =
140 tokio::join!(host.value.read(), host.drop_chanel.read());
141
142 return Ok(RwLockReadGuard::new_host(value_guard));
143 }
144
145 let (value_tx, value_rx) = oneshot::channel();
146 let _ = self.remote.read_req_tx.send(ReadRequest { value_tx }).await;
147 let value = value_rx.await?;
148
149 Ok(RwLockReadGuard::new_remote(value))
150 }
151
152 pub async fn write(&self) -> Result<RwLockWriteGuard<T, Codec>, LockError> {
157 if let Some(host) = self.host.as_ref() {
158 let (value_guard, _remote_gruard) = tokio::join!(
159 host.value.write(),
160 drop_remote_read_guard(&host.drop_chanel)
161 );
162
163 if self.remote.is_frivolous {
164 return Ok(RwLockWriteGuard::new_host_frivolous(value_guard));
165 }
166
167 return Ok(RwLockWriteGuard::new_host(value_guard));
168 }
169
170 let (value_tx, value_rx) = oneshot::channel();
171 let (new_value_tx, new_value_rx) = oneshot::channel();
172 let (confirm_tx, confirm_rx) = oneshot::channel();
173
174 let _ = self
175 .remote
176 .write_req_tx
177 .send(WriteRequest {
178 value_tx,
179 new_value_rx,
180 confirm_tx,
181 })
182 .await;
183 let value = value_rx.await?;
184
185 Ok(RwLockWriteGuard::new_remote(
186 value,
187 new_value_tx,
188 confirm_rx,
189 ))
190 }
191
    /// Serves read and write requests arriving from handles without a host part.
    ///
    /// Runs until either request channel is closed or reports a final error,
    /// or until the host storage behind `weak_host` can no longer be upgraded.
    /// Non-final receive errors are skipped.
    async fn handle_host_requests(
        weak_host: WeakRwLockOwner<T, Codec>,
        mut read_req_rx: mpsc::Receiver<ReadRequest<T, Codec>, Codec, 1>,
        mut write_req_rx: mpsc::Receiver<WriteRequest<T, Codec>, Codec, 1>,
    ) {
        loop {
            tokio::select! {
                // Branches are polled in declaration order: write requests first.
                biased;

                res = write_req_rx.recv() => {
                    let WriteRequest {value_tx, new_value_rx, confirm_tx} = match res {
                        Ok(Some(req)) => req,
                        Ok(None) => break,
                        Err(err) if err.is_final() => break,
                        Err(_) => continue,
                    };

                    // Stop serving once the host storage is gone.
                    let Some(host) = weak_host.upgrade() else {
                        break
                    };

                    {
                        // Held for the whole write session: from handing out the
                        // current value until the new value is stored or the
                        // requester gives up.
                        let _remote_write_guard = drop_remote_read_guard(&host.drop_chanel).await;
                        // The value write lock is only held while cloning.
                        let remote_value = host.value.write().await.clone();
                        if value_tx.send(remote_value).is_err() {
                            continue
                        }

                        // An error here means no new value arrived; the stored
                        // value is left unchanged and nothing is confirmed.
                        if let Ok(new_value) = new_value_rx.await {
                            *host.value.write().await = new_value;
                            let _ = confirm_tx.send(());
                        }
                    }
                }

                res = read_req_rx.recv() => {
                    let ReadRequest {value_tx} = match res {
                        Ok(Some(req)) => req,
                        Ok(None) => break,
                        Err(err) if err.is_final() => break,
                        Err(_) => continue,
                    };

                    // `None` means the host storage has been dropped.
                    let Some(remote_value) = weak_host.make_remote_value().await else {
                        break
                    };
                    // The requester may already be gone; that is not an error here.
                    let _ = value_tx.send(remote_value);
                }
            }
        }
    }
245}
246
247impl<T: RemoteSend + Clone + Sync, Codec: codec::Codec> serde::Serialize for RwLock<T, Codec> {
248 #[inline]
249 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
250 where
251 S: serde::Serializer,
252 {
253 self.remote.serialize(serializer)
254 }
255}
256
257impl<'de, T: RemoteSend + Clone + Sync, Codec: codec::Codec> serde::Deserialize<'de>
258 for RwLock<T, Codec>
259{
260 #[inline]
261 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
262 where
263 D: serde::Deserializer<'de>,
264 {
265 let remote = RwLockRemote::deserialize(deserializer)?;
266
267 Ok(Self::new_remote(remote))
268 }
269}
270
271impl<T, Codec> Clone for RwLock<T, Codec> {
272 fn clone(&self) -> Self {
273 Self {
274 host: self.host.clone(),
275 remote: self.remote.clone(),
276 }
277 }
278}
279
/// Read access to the value of an [`RwLock`], obtained from [`RwLock::read`].
///
/// Dereferences either to the locally locked value or to a value received
/// from the host.
pub struct RwLockReadGuard<'a, T, Codec = codec::Default> {
    /// Host-side Tokio read guard or the value received from the host.
    inner: RwLockReadGuardInner<'a, T, Codec>,
}
286
287impl<'a, T, Codec> RwLockReadGuard<'a, T, Codec> {
288 fn new_host(value_guard: TokioRwLockReadGuard<'a, T>) -> Self {
289 Self {
290 inner: RwLockReadGuardInner::Owner(value_guard),
291 }
292 }
293
294 fn new_remote(value: RemoteValue<T, Codec>) -> Self {
295 Self {
296 inner: RwLockReadGuardInner::Remote(value),
297 }
298 }
299}
300
301impl<'a, T, Codec> Deref for RwLockReadGuard<'a, T, Codec> {
302 type Target = T;
303
304 fn deref(&self) -> &Self::Target {
305 match &self.inner {
306 RwLockReadGuardInner::Owner(value_guard) => value_guard,
307 RwLockReadGuardInner::Remote(value) => &value.value,
308 }
309 }
310}
311
/// Write access to the value of an [`RwLock`], obtained from [`RwLock::write`].
///
/// Except for frivolous host-side guards, modifications are applied to a
/// separate copy and only stored by [`RwLockWriteGuard::commit`].
pub struct RwLockWriteGuard<'a, T, Codec = codec::Default> {
    /// Host-side write guard (with optional staged copy) or the remote working
    /// copy plus its commit channels.
    inner: RwLockWriteGuardInner<'a, T, Codec>,
}
319
320impl<'a, T, Codec> RwLockWriteGuard<'a, T, Codec> {
321 fn new_host_frivolous(value_guard: TokioRwLockWriteGuard<'a, T>) -> Self {
322 Self {
323 inner: RwLockWriteGuardInner::Owner {
324 new_value: None,
325 value_guard,
326 },
327 }
328 }
329}
330
331impl<'a, T: Clone, Codec> RwLockWriteGuard<'a, T, Codec> {
332 fn new_host(value_guard: TokioRwLockWriteGuard<'a, T>) -> Self {
333 Self {
334 inner: RwLockWriteGuardInner::Owner {
335 new_value: Some(value_guard.clone()),
336 value_guard,
337 },
338 }
339 }
340
341 fn new_remote(
342 value: T,
343 new_value_tx: oneshot::Sender<T, Codec>,
344 confirm_rx: oneshot::Receiver<(), Codec>,
345 ) -> Self {
346 Self {
347 inner: RwLockWriteGuardInner::Remote {
348 value,
349 new_value_tx,
350 confirm_rx,
351 },
352 }
353 }
354}
355
356impl<'a, T: RemoteSend, Codec: codec::Codec> RwLockWriteGuard<'a, T, Codec> {
357 pub async fn commit(self) -> Result<(), CommitError> {
359 match self.inner {
360 RwLockWriteGuardInner::Owner {
361 new_value,
362 mut value_guard,
363 } => {
364 if let Some(new_value) = new_value {
365 *value_guard = new_value;
366 }
367
368 Ok(())
369 }
370 RwLockWriteGuardInner::Remote {
371 value,
372 new_value_tx,
373 confirm_rx,
374 } => {
375 new_value_tx.send(value)?;
376 confirm_rx.await?;
377
378 Ok(())
379 }
380 }
381 }
382}
383
384impl<'a, T, Codec> Deref for RwLockWriteGuard<'a, T, Codec> {
385 type Target = T;
386
387 fn deref(&self) -> &Self::Target {
388 match &self.inner {
389 RwLockWriteGuardInner::Owner {
390 new_value,
391 value_guard,
392 } => new_value.as_ref().unwrap_or(value_guard),
393 RwLockWriteGuardInner::Remote { value, .. } => &value,
394 }
395 }
396}
397
398impl<'a, T, Codec> DerefMut for RwLockWriteGuard<'a, T, Codec> {
399 fn deref_mut(&mut self) -> &mut Self::Target {
400 match &mut self.inner {
401 RwLockWriteGuardInner::Owner {
402 new_value,
403 value_guard,
404 } => new_value.as_mut().unwrap_or(value_guard),
405 RwLockWriteGuardInner::Remote { value, .. } => value,
406 }
407 }
408}
409
/// Host-side storage of an `RwLock`: the value and the drop channel.
struct RwLockOwner<T, Codec> {
    /// The value, protected by a Tokio read/write lock.
    value: Arc<TokioRwLock<T>>,
    /// Sender/receiver pair whose sender is cloned into every `RemoteValue`;
    /// `None` until `make_remote_value` creates it and again after
    /// `drop_remote_read_guard` has taken it.
    drop_chanel:
        Arc<TokioRwLock<Option<(mpsc::Sender<(), Codec, 1>, mpsc::Receiver<(), Codec, 1>)>>>,
}
415
416impl<T, Codec> RwLockOwner<T, Codec> {
417 fn as_weak(&self) -> WeakRwLockOwner<T, Codec> {
418 WeakRwLockOwner {
419 value: Arc::downgrade(&self.value),
420 drop_chanel: Arc::downgrade(&self.drop_chanel),
421 }
422 }
423}
424
425impl<T, Codec> Clone for RwLockOwner<T, Codec> {
426 fn clone(&self) -> Self {
427 Self {
428 value: self.value.clone(),
429 drop_chanel: self.drop_chanel.clone(),
430 }
431 }
432}
433
/// Sender/receiver pair of the drop channel; the sender is cloned into each
/// `RemoteValue`, the receiver is drained by `drop_remote_read_guard`.
type DropChanel<Codec> = (mpsc::Sender<(), Codec, 1>, mpsc::Receiver<(), Codec, 1>);
/// Non-owning counterpart of `RwLockOwner`, held by the request handler task.
struct WeakRwLockOwner<T, Codec> {
    /// Weak reference to the locked value.
    value: Weak<TokioRwLock<T>>,
    /// Weak reference to the drop channel slot.
    drop_chanel: Weak<TokioRwLock<Option<DropChanel<Codec>>>>,
}
439
440impl<T, Codec> WeakRwLockOwner<T, Codec> {
441 fn upgrade(&self) -> Option<RwLockOwner<T, Codec>> {
442 Some(RwLockOwner {
443 value: self.value.upgrade()?,
444 drop_chanel: self.drop_chanel.upgrade()?,
445 })
446 }
447}
448impl<T: Clone, Codec> WeakRwLockOwner<T, Codec> {
449 async fn make_remote_value(&self) -> Option<RemoteValue<T, Codec>> {
450 let value = self.value.upgrade()?;
451 let drop_chanel = self.drop_chanel.upgrade()?;
452
453 let dropped_tx = drop_chanel
454 .read()
455 .await
456 .as_ref()
457 .map(|drop_chanel| drop_chanel.0.clone());
458
459 let dropped_tx = if let Some(dropped_tx) = dropped_tx {
460 dropped_tx
461 } else {
462 let (dropped_tx, dropped_rx) = mpsc::channel(1);
463 let dropped_tx = dropped_tx.set_buffer();
464 let dropped_rx = dropped_rx.set_buffer();
465
466 {
467 let mut drop_chanel = drop_chanel.write().await;
468 *drop_chanel = Some((dropped_tx.clone(), dropped_rx));
469 }
470
471 dropped_tx
472 };
473
474 let value = value.read().await;
475 Some(RemoteValue {
476 value: value.clone(),
477 dropped_tx: dropped_tx.clone(),
478 })
479 }
480}
481
482async fn drop_remote_read_guard<Codec>(
483 drop_chanel: &Arc<TokioRwLock<Option<DropChanel<Codec>>>>,
484) -> TokioRwLockWriteGuard<Option<DropChanel<Codec>>> {
485 let mut drop_chanel_write_guard = drop_chanel.write().await;
486 if let Some(drop_chanel) = drop_chanel_write_guard.take() {
487 let (dropped_tx, mut dropped_rx) = drop_chanel;
488 drop(dropped_tx);
489 loop {
490 if let Ok(None) = dropped_rx.recv().await {
491 break;
492 }
493 }
494 }
495
496 drop_chanel_write_guard
497}
498
/// The serializable part of an `RwLock`: the senders used to reach the host's
/// request handler.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
struct RwLockRemote<T, Codec> {
    /// Sends read requests to the host.
    read_req_tx: mpsc::Sender<ReadRequest<T, Codec>, Codec, 1>,
    /// Sends write requests to the host.
    write_req_tx: mpsc::Sender<WriteRequest<T, Codec>, Codec, 1>,
    /// When set, host-side write guards modify the stored value directly
    /// instead of a staged copy. Only consulted on the host path of `write`.
    is_frivolous: bool,
}
507
/// Backing storage of an `RwLockReadGuard`.
enum RwLockReadGuardInner<'a, T, Codec> {
    /// Read guard on the host-side value.
    Owner(TokioRwLockReadGuard<'a, T>),
    /// Value received from the host.
    Remote(RemoteValue<T, Codec>),
}
512
/// Backing storage of an `RwLockWriteGuard`.
enum RwLockWriteGuardInner<'a, T, Codec> {
    /// Guard on the host-side value.
    Owner {
        /// Staged copy that receives modifications and is stored on commit;
        /// `None` for frivolous guards, which modify `value_guard` directly.
        new_value: Option<T>,
        /// Write guard on the stored value.
        value_guard: TokioRwLockWriteGuard<'a, T>,
    },
    /// Guard held by a handle without a host part.
    Remote {
        /// Working copy received from the host.
        value: T,
        /// Sends the working copy back to the host on commit.
        new_value_tx: oneshot::Sender<T, Codec>,
        /// Receives the host's confirmation that the value was stored.
        confirm_rx: oneshot::Receiver<(), Codec>,
    },
}
524
/// A value handed to a remote reader.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
struct RemoteValue<T, Codec = codec::Default> {
    /// Clone of the host's value at the time of the request.
    value: T,
    /// Sender of the host's drop channel; never sent on in this file.
    /// Presumably its drop is what tells the host that this value is no
    /// longer in use (see `drop_remote_read_guard`).
    dropped_tx: mpsc::Sender<(), Codec, 1>,
}
532
/// Request for read access, sent from a handle without a host part to the host.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
struct ReadRequest<T, Codec = codec::Default> {
    /// Channel on which the host sends back the value.
    value_tx: oneshot::Sender<RemoteValue<T, Codec>, Codec>,
}
539
/// Request for write access, sent from a handle without a host part to the host.
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "T: RemoteSend, Codec: codec::Codec"))]
#[serde(bound(deserialize = "T: RemoteSend, Codec: codec::Codec"))]
struct WriteRequest<T, Codec = codec::Default> {
    /// Channel on which the host sends the current value.
    value_tx: oneshot::Sender<T, Codec>,
    /// Channel on which the host receives the new value to store.
    new_value_rx: oneshot::Receiver<T, Codec>,
    /// Channel on which the host confirms that the new value was stored.
    confirm_tx: oneshot::Sender<(), Codec>,
}