asyncs_sync/
watch.rs

1//! Channel to publish and subscribe values.
2
3use std::cell::{Cell, UnsafeCell};
4use std::cmp::Ordering::*;
5use std::fmt::{self, Formatter};
6use std::mem::{ManuallyDrop, MaybeUninit};
7use std::ops::Deref;
8use std::ptr;
9use std::sync::atomic::Ordering::{self, *};
10use std::sync::atomic::{AtomicPtr, AtomicUsize};
11use std::sync::Arc;
12
13use crate::Notify;
14
/// Error returned by [Sender::send] when no [Receiver] is alive.
///
/// The value that could not be sent is handed back through [SendError::into_value].
#[derive(Debug)]
pub struct SendError<T>(T);

impl<T> SendError<T> {
    /// Consumes the error and returns the value that was not sent.
    pub fn into_value(self) -> T {
        self.0
    }
}

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("watch channel has no receivers")
    }
}

impl<T: fmt::Debug> std::error::Error for SendError<T> {}
24
/// Error returned by [Receiver::changed()] once every [Sender] has been dropped and the final
/// value has already been consumed.
#[derive(Debug)]
pub struct RecvError(());

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("watch channel closed: all senders dropped")
    }
}

impl std::error::Error for RecvError {}
28
/// Stamp attached to every published value; each publish uses the successor of the version it
/// replaces.
#[repr(transparent)]
#[derive(Clone, Copy, Default, Debug, PartialEq, PartialOrd, Eq, Ord)]
struct Version(u64);

impl Version {
    /// Returns the version immediately following `self`.
    pub fn next(self) -> Self {
        let Version(current) = self;
        Version(current + 1)
    }
}
38
39struct Slot<T> {
40    refs: AtomicUsize,
41
42    frees: AtomicPtr<Slot<T>>,
43
44    /// Safety: never changed after published and before reclaimed
45    value: UnsafeCell<ManuallyDrop<T>>,
46    version: Cell<Version>,
47}
48
49impl<T> Default for Slot<T> {
50    fn default() -> Self {
51        Self {
52            refs: AtomicUsize::new(0),
53            frees: AtomicPtr::new(ptr::null_mut()),
54            value: UnsafeCell::new(ManuallyDrop::new(unsafe { MaybeUninit::zeroed().assume_init() })),
55            version: Cell::new(Version::default()),
56        }
57    }
58}
59
60impl<T> Slot<T> {
61    fn store(&self, value: T) {
62        debug_assert_eq!(self.refs.load(Relaxed), 0);
63        unsafe {
64            std::ptr::write(self.value.get(), ManuallyDrop::new(value));
65        }
66        self.refs.store(1, Relaxed);
67    }
68
69    /// Retains current or newer version.
70    unsafe fn retain(&self, version: Version) -> Option<&Slot<T>> {
71        let mut refs = self.refs.load(Relaxed);
72        loop {
73            if refs == 0 {
74                return None;
75            }
76            match self.refs.compare_exchange(refs, refs + 1, Relaxed, Relaxed) {
77                Ok(_) => {},
78                Err(updated) => {
79                    refs = updated;
80                    continue;
81                },
82            }
83            match self.version.get().cmp(&version) {
84                Equal | Greater => return Some(self),
85                Less => panic!(
86                    "BUG: version is monotonic, expect version {:?}, got old version {:?}",
87                    version,
88                    self.version.get()
89                ),
90            }
91        }
92    }
93}
94
95#[repr(transparent)]
96struct UnsafeSlot<T>(Slot<T>);
97
98impl<T> UnsafeSlot<T> {
99    pub fn retain(&self, version: Version) -> Option<&Slot<T>> {
100        unsafe { self.0.retain(version) }
101    }
102
103    pub unsafe fn slot(&self) -> &Slot<T> {
104        &self.0
105    }
106}
107
108struct Row<T> {
109    prev: Option<Box<Row<T>>>,
110    slots: [Slot<T>; 16],
111}
112
113impl<T> Default for Row<T> {
114    fn default() -> Self {
115        Self { prev: None, slots: Default::default() }
116    }
117}
118
119// We could also stamp version into the atomic to filter eagerly, but it will require `AtomicU128`.
120struct Latest(AtomicUsize);
121
122impl Latest {
123    const CLOSED: usize = 0x01;
124    const MASK: usize = !Self::CLOSED;
125
126    pub fn new<T>(slot: &Slot<T>) -> Self {
127        let raw = slot as *const _ as usize;
128        Self(AtomicUsize::new(raw))
129    }
130
131    pub fn load<'a, T>(&self, ordering: Ordering) -> (&'a UnsafeSlot<T>, bool) {
132        let raw = self.0.load(ordering);
133        (Self::slot(raw & Self::MASK), raw & Self::CLOSED == Self::CLOSED)
134    }
135
136    fn slot<'a, T>(raw: usize) -> &'a UnsafeSlot<T> {
137        unsafe { &*(raw as *const UnsafeSlot<T>) }
138    }
139
140    fn ptr<T>(slot: &Slot<T>) -> usize {
141        slot as *const Slot<T> as usize
142    }
143
144    pub fn compare_exchange<'a, T>(
145        &self,
146        current: &'a Slot<T>,
147        new: &Slot<T>,
148        success: Ordering,
149        failure: Ordering,
150    ) -> Result<&'a Slot<T>, &'a UnsafeSlot<T>> {
151        match self.0.compare_exchange(Self::ptr(current), Self::ptr(new), success, failure) {
152            Ok(_) => Ok(current),
153            Err(updated) => Err(Self::slot(updated)),
154        }
155    }
156
157    pub fn close(&self) {
158        let u = self.0.load(Relaxed);
159        self.0.store(u | Self::CLOSED, Relaxed);
160    }
161}
162
163struct Shared<T> {
164    rows: UnsafeCell<Box<Row<T>>>,
165    frees: AtomicPtr<Slot<T>>,
166
167    latest: Latest,
168
169    closed: Notify,
170    changes: Notify,
171
172    senders: AtomicUsize,
173    receivers: AtomicUsize,
174}
175
176impl<T> Drop for Shared<T> {
177    fn drop(&mut self) {
178        let slot = self.latest.load(Relaxed).0;
179        self.release(unsafe { slot.slot() });
180    }
181}
182
183impl<T> Shared<T> {
184    fn new(version: Version, value: T) -> Self {
185        let row = Box::<Row<_>>::default();
186        let slot = &row.slots[0];
187        slot.store(value);
188        slot.version.set(version);
189        let latest = Latest::new(slot);
190        let shared = Self {
191            rows: UnsafeCell::new(row),
192            frees: AtomicPtr::new(ptr::null_mut()),
193            latest,
194            closed: Notify::new(),
195            changes: Notify::new(),
196            senders: AtomicUsize::new(1),
197            receivers: AtomicUsize::new(1),
198        };
199        let row = unsafe { &*shared.rows.get() };
200        shared.add_slots(&row.slots[1..]);
201        shared
202    }
203
204    fn new_sender(self: &Arc<Self>) -> Sender<T> {
205        self.senders.fetch_add(1, Relaxed);
206        Sender { shared: self.clone() }
207    }
208
209    fn drop_sender(&self) {
210        if self.senders.fetch_sub(1, Relaxed) != 1 {
211            return;
212        }
213        self.latest.close();
214        self.changes.notify_all();
215    }
216
217    fn new_receiver(self: &Arc<Self>, seen: Version) -> Receiver<T> {
218        self.receivers.fetch_add(1, Relaxed);
219        Receiver { seen, shared: self.clone() }
220    }
221
222    fn drop_receiver(&self) {
223        if self.receivers.fetch_sub(1, Relaxed) == 1 {
224            self.closed.notify_all();
225        }
226    }
227
228    fn add_slots(&self, slots: &[Slot<T>]) {
229        for i in 0..slots.len() - 1 {
230            let curr = unsafe { slots.get_unchecked(i) };
231            let next = unsafe { slots.get_unchecked(i + 1) };
232            curr.frees.store(next as *const _ as *mut _, Relaxed);
233        }
234        let head = unsafe { slots.get_unchecked(0) };
235        let tail = unsafe { slots.get_unchecked(slots.len() - 1) };
236        self.free_slots(head, tail);
237    }
238
239    fn alloc_slot(&self) -> &Slot<T> {
240        // Acquire load to see `slot.frees`.
241        let mut head = self.frees.load(Acquire);
242        loop {
243            if head.is_null() {
244                break;
245            }
246            let slot = unsafe { &*head };
247            let next = slot.frees.load(Relaxed);
248            match self.frees.compare_exchange(head, next, Relaxed, Acquire) {
249                Ok(_) => {
250                    slot.frees.store(ptr::null_mut(), Relaxed);
251                    return slot;
252                },
253                Err(updated) => head = updated,
254            }
255        }
256        let mut row = ManuallyDrop::new(Box::<Row<_>>::default());
257        row.prev = Some(unsafe { ptr::read(self.rows.get() as *const _) });
258        unsafe {
259            ptr::write(self.rows.get(), ManuallyDrop::take(&mut row));
260        }
261        self.add_slots(&row.slots[1..]);
262        unsafe { std::mem::transmute(row.slots.get_unchecked(0)) }
263    }
264
265    fn free_slots(&self, head: &Slot<T>, tail: &Slot<T>) {
266        let mut frees = self.frees.load(Relaxed);
267        loop {
268            tail.frees.store(frees, Relaxed);
269            // Release store to publish `slot.frees`.
270            match self.frees.compare_exchange(frees, head as *const _ as *mut _, Release, Relaxed) {
271                Ok(_) => break,
272                Err(updated) => frees = updated,
273            }
274        }
275    }
276
277    fn free_slot(&self, slot: &Slot<T>) {
278        self.free_slots(slot, slot);
279    }
280
281    fn release(&self, slot: &Slot<T>) {
282        if slot.refs.fetch_sub(1, Relaxed) != 1 {
283            return;
284        }
285        let value = unsafe { &mut *slot.value.get() };
286        let value = unsafe { ManuallyDrop::take(value) };
287        drop(value);
288        self.free_slot(slot);
289    }
290
291    fn publish(&self, value: T) {
292        let slot = self.alloc_slot();
293        slot.store(value);
294        let mut latest = self.latest(Version(0));
295        loop {
296            let version = latest.version().next();
297            slot.version.set(version);
298            // Release store to publish value and version
299            // Acquire load to observe version
300            match self.latest.compare_exchange(latest.slot, slot, Release, Acquire) {
301                Ok(slot) => {
302                    self.release(slot);
303                    self.changes.notify_all();
304                    break;
305                },
306                Err(updated) => match updated.retain(version) {
307                    None => latest = self.latest(version),
308                    Some(slot) => latest = Ref { slot, shared: self, closed: false, changed: true },
309                },
310            }
311        }
312    }
313
314    fn latest(&self, seen: Version) -> Ref<'_, T> {
315        loop {
316            // Acquire load to observe version and value.
317            let (slot, closed) = self.latest.load(Acquire);
318            if let Some(slot) = slot.retain(seen) {
319                return Ref { slot, shared: self, closed, changed: seen != slot.version.get() };
320            }
321        }
322    }
323
324    fn try_changed(&self, seen: Version) -> Result<Option<Ref<'_, T>>, RecvError> {
325        let latest = self.latest(seen);
326        if latest.has_changed() {
327            Ok(Some(latest))
328        } else if latest.closed {
329            Err(RecvError(()))
330        } else {
331            Ok(None)
332        }
333    }
334}
335
336/// Constructs a lock free channel to publish and subscribe values.
337///
338/// ## Differences with [tokio]
339/// * [tokio] holds only a single value, so no allocation.
340/// * [Ref] holds no lock but reference to underlying value, which prevent it from reclamation, so
341///   it is also crucial to drop it as soon as possible.
342///
343/// [tokio]: https://docs.rs/tokio
344pub fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
345    let version = Version(1);
346    let shared = Arc::new(Shared::new(version, value));
347    let sender = Sender { shared: shared.clone() };
348    let receiver = Receiver { seen: version, shared };
349    (sender, receiver)
350}
351
352/// Send part of [Receiver].
353pub struct Sender<T> {
354    shared: Arc<Shared<T>>,
355}
356
357unsafe impl<T> Send for Sender<T> {}
358unsafe impl<T> Sync for Sender<T> {}
359
360impl<T: fmt::Debug> fmt::Debug for Sender<T> {
361    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
362        let latest = self.shared.latest(Version(0));
363        f.debug_struct("Sender")
364            .field("version", &latest.version())
365            .field("value", latest.as_ref())
366            .field("closed", &latest.closed)
367            .finish()
368    }
369}
370
371impl<T> Sender<T> {
372    /// Sends value to [Receiver]s.
373    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
374        if self.shared.receivers.load(Relaxed) == 0 {
375            return Err(SendError(value));
376        }
377        self.publish(value);
378        Ok(())
379    }
380
381    /// Publishes value for existing [Receiver]s and possible future [Receiver]s from
382    /// [Sender::subscribe].
383    pub fn publish(&self, value: T) {
384        self.shared.publish(value);
385    }
386
387    /// Subscribes to future changes.
388    pub fn subscribe(&self) -> Receiver<T> {
389        let latest = self.shared.latest(Version::default());
390        self.shared.receivers.fetch_add(1, Relaxed);
391        Receiver { seen: latest.version(), shared: self.shared.clone() }
392    }
393
394    /// Receiver count.
395    pub fn receiver_count(&self) -> usize {
396        self.shared.receivers.load(Relaxed)
397    }
398
399    /// Blocks until all receivers dropped.
400    pub async fn closed(&self) {
401        // Loop as `subscribe` takes `&self` but not `&mut self`.
402        while !self.is_closed() {
403            let notified = self.shared.closed.notified();
404            if self.is_closed() {
405                return;
406            }
407            notified.await
408        }
409    }
410
411    pub fn is_closed(&self) -> bool {
412        self.receiver_count() == 0
413    }
414}
415
416impl<T> Clone for Sender<T> {
417    fn clone(&self) -> Self {
418        self.shared.new_sender()
419    }
420}
421
422impl<T> Drop for Sender<T> {
423    fn drop(&mut self) {
424        self.shared.drop_sender();
425    }
426}
427
428/// Receive part of [Sender].
429pub struct Receiver<T> {
430    seen: Version,
431    shared: Arc<Shared<T>>,
432}
433
434unsafe impl<T> Send for Receiver<T> {}
435unsafe impl<T> Sync for Receiver<T> {}
436
437impl<T: fmt::Debug> fmt::Debug for Receiver<T> {
438    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
439        let latest = self.borrow();
440        f.debug_struct("Receiver")
441            .field("seen", &self.seen)
442            .field("version", &latest.version())
443            .field("value", latest.as_ref())
444            .field("closed", &latest.closed)
445            .field("changed", &latest.changed)
446            .finish()
447    }
448}
449
450/// Reference to borrowed value.
451///
452/// Holds reference will prevent it from reclamation so drop it as soon as possible.
453pub struct Ref<'a, T> {
454    slot: &'a Slot<T>,
455    shared: &'a Shared<T>,
456    closed: bool,
457    changed: bool,
458}
459
460unsafe impl<T> Send for Ref<'_, T> {}
461unsafe impl<T> Sync for Ref<'_, T> {}
462
463impl<T: fmt::Debug> fmt::Debug for Ref<'_, T> {
464    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
465        f.debug_struct("Ref")
466            .field("version", &self.version())
467            .field("value", self.as_ref())
468            .field("closed", &self.closed)
469            .field("changed", &self.changed)
470            .finish()
471    }
472}
473
474impl<'a, T> Ref<'a, T> {
475    fn version(&self) -> Version {
476        self.slot.version.get()
477    }
478
479    /// Do we ever seen this before from last [Receiver::borrow_and_update] and [Receiver::changed()] ?
480    pub fn has_changed(&self) -> bool {
481        self.changed
482    }
483}
484
485impl<T> Deref for Ref<'_, T> {
486    type Target = T;
487
488    fn deref(&self) -> &Self::Target {
489        unsafe { &*self.slot.value.get() }
490    }
491}
492
493impl<T> AsRef<T> for Ref<'_, T> {
494    fn as_ref(&self) -> &T {
495        self
496    }
497}
498
499impl<T> Drop for Ref<'_, T> {
500    fn drop(&mut self) {
501        self.shared.release(self.slot);
502    }
503}
504
505impl<T> Receiver<T> {
506    /// Borrows the latest value but does not mark it as seen.
507    pub fn borrow(&self) -> Ref<'_, T> {
508        self.shared.latest(self.seen)
509    }
510
511    /// Borrows the latest value and marks it as seen.
512    pub fn borrow_and_update(&mut self) -> Ref<'_, T> {
513        let latest = self.shared.latest(self.seen);
514        self.seen = latest.version();
515        latest
516    }
517
518    /// Blocks and consumes new change since last [Receiver::borrow_and_update] or [Receiver::changed()].
519    ///
520    /// If multiple values are published in the meantime, it is likely that only later one got
521    /// observed. It is guaranteed that the final value after all [Sender]s dropped is always
522    /// observed.
523    ///
524    /// ## Errors
525    /// * [RecvError] after all [Sender]s dropped and final value consumed
526    pub async fn changed(&mut self) -> Result<Ref<'_, T>, RecvError> {
527        loop {
528            // This serves both luck path and recheck after `notified.await`.
529            if let Some(changed) = self.shared.try_changed(self.seen)? {
530                self.seen = changed.version();
531                return Ok(changed);
532            }
533            let notified = self.shared.changes.notified();
534            if let Some(changed) = self.shared.try_changed(self.seen)? {
535                self.seen = changed.version();
536                return Ok(changed);
537            }
538            notified.await;
539        }
540    }
541}
542
543impl<T> Clone for Receiver<T> {
544    fn clone(&self) -> Self {
545        self.shared.new_receiver(self.seen)
546    }
547}
548
549impl<T> Drop for Receiver<T> {
550    fn drop(&mut self) {
551        self.shared.drop_receiver();
552    }
553}
554
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use asyncs::{select, task};

    use crate::{watch, Notify};

    #[asyncs::test]
    async fn channel_sequential() {
        // Given a fresh channel holding 5.
        let (tx, rx) = watch::channel(5);

        // Borrowing before any send yields the initial value, not flagged as changed.
        {
            let current = rx.borrow();
            assert_eq!(*current, 5);
            assert!(!current.has_changed());
        }

        // A successful send is visible, and flagged, on the next borrow.
        tx.send(6).unwrap();
        {
            let current = rx.borrow();
            assert_eq!(*current, 6);
            assert!(current.has_changed());
        }

        // With every receiver gone, `send` hands the value back...
        drop(rx);
        assert_eq!(tx.send(7).unwrap_err().into_value(), 7);

        // ...and publishes nothing: a new subscriber still sees 6, already marked as seen.
        {
            let rx = tx.subscribe();
            let current = rx.borrow();
            assert_eq!(*current, 6);
            assert!(!current.has_changed());
        }

        // `publish` ignores the receiver count, so later subscribers observe the new value.
        tx.publish(7);
        let rx = tx.subscribe();
        let current = rx.borrow();
        assert_eq!(*current, 7);
        assert!(!current.has_changed());
    }

    #[asyncs::test]
    async fn receivers_dropped() {
        let (tx, rx) = watch::channel(5);
        // Dropping the only receiver from another task must resolve `closed`.
        task::spawn(async move { drop(rx) });
        select! {
            _ = tx.closed() => {},
        }

        // A fresh subscriber reopens the channel, so `closed` must stay pending.
        let _rx = tx.subscribe();
        select! {
            default => {},
            _ = tx.closed() => unreachable!(),
        }
    }

    #[asyncs::test]
    async fn senders_dropped() {
        let (tx, mut rx) = watch::channel(());
        // Dropping a clone leaves one sender alive, so there is nothing to report yet.
        let extra = tx.clone();
        drop(extra);
        select! {
            default => {},
            _ = rx.changed() => unreachable!(),
        }

        // Dropping the last sender turns `changed` into an immediate error.
        drop(tx);
        select! {
            default => unreachable!(),
            Err(_) = rx.changed() => {},
        }
    }

    #[asyncs::test]
    async fn changed() {
        let ack = Arc::new(Notify::new());
        let (tx, mut rx) = watch::channel(0);
        let consumer = task::spawn({
            let ack = Arc::clone(&ack);
            async move {
                let mut observed = Vec::new();
                while let Ok(current) = rx.changed().await {
                    observed.push(*current);
                    ack.notify_one();
                }
                observed
            }
        });

        // Hand-shake each value so the consumer observes every one of them.
        for value in 1..=3 {
            tx.send(value).unwrap();
            ack.notified().await;
        }

        // Final value is guaranteed to be seen before error.
        tx.send(4).unwrap();
        drop(tx);
        let observed = consumer.await.unwrap();
        assert_eq!(observed, vec![1, 2, 3, 4]);
    }

    #[asyncs::test]
    async fn ref_drop_release_value() {
        let tracked = Arc::new(());
        let (tx, rx) = watch::channel(Arc::clone(&tracked));
        assert_eq!(Arc::strong_count(&tracked), 2);

        // Two outstanding borrows keep the old value alive across a send.
        let first = rx.borrow();
        let second = rx.borrow();
        assert_eq!(Arc::strong_count(&tracked), 2);
        tx.send(Arc::new(())).unwrap();
        assert_eq!(Arc::strong_count(&tracked), 2);

        // The old value is dropped only when the last borrow goes away.
        drop(first);
        assert_eq!(Arc::strong_count(&tracked), 2);
        drop(second);
        assert_eq!(Arc::strong_count(&tracked), 1);
    }
}