async_observe/
lib.rs

// Crate-level documentation: rustdoc builds that are not doctest runs use the
// README as the crate docs; every other configuration (normal builds and
// doctest runs) uses a one-line summary, so the README is not included there.
#![cfg_attr(all(doc, not(doctest)), doc = include_str!("../README.md"))]
#![cfg_attr(
    not(all(doc, not(doctest))),
    doc = "Async single-producer, multi-consumer channel that only retains the last sent value"
)]
6
7use {
8    event_listener::Event,
9    futures_lite::{Stream, stream},
10    std::{
11        error, fmt,
12        sync::{
13            Arc, RwLock, RwLockReadGuard, RwLockWriteGuard,
14            atomic::{AtomicUsize, Ordering},
15        },
16    },
17};
18
19/// Creates a new observer channel, returning the sender and receiver halves.
20///
21/// All values sent by [`Sender`] will become visible to the [`Receiver`]
22/// handles. Only the last value sent is made available to the [`Receiver`]
23/// half. All intermediate values are dropped.
24///
25/// # Examples
26///
27/// This example prints numbers from 0 to 9:
28///
29/// ```
30/// # fn f() {
31/// use {
32///     futures_lite::future,
33///     std::{thread, time::Duration},
34/// };
35///
36/// let (tx, mut rx) = async_observe::channel(0);
37///
38/// // Perform computations in another thread
39/// thread::spawn(move || {
40///     for n in 1..10 {
41///         thread::sleep(Duration::from_secs(1));
42///
43///         // Send a new value without blocking the thread.
44///         // If sending fails, it means the sender was dropped.
45///         // In that case stop the computation.
46///         if tx.send(n).is_err() {
47///             break;
48///         }
49///     }
50/// });
51///
52/// // Print the initial value (0)
53/// let n = rx.observe(|n| *n);
54/// println!("{n}");
55///
56/// future::block_on(async {
57///     // Print the value whenever it changes
58///     while let Ok(n) = rx.recv().await {
59///         println!("{n}");
60///     }
61/// });
62/// # }
63/// ```
64///
65/// In this example a new thread is spawned, but you can also use the channel
66/// to update values from a future/async task.
67///
68/// Note that this channel does not have a message queue - it only stores the
69/// latest update. Therefore, it does *not* guarantee that you will observe
70/// every intermediate value. If you need to observe each change, use a
71/// message-queue channel such as [`async-channel`].
72///
73/// [`async-channel`]: https://docs.rs/async-channel/latest/async_channel
74pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
75    let shared = Arc::new(Shared {
76        value: RwLock::new(init),
77        state: State::new(),
78        rx_count: AtomicUsize::new(1),
79        changed: Event::new(),
80        all_receivers_dropped: Event::new(),
81    });
82
83    let tx = Sender {
84        shared: shared.clone(),
85    };
86
87    let rx = Receiver {
88        shared,
89        last_version: 0,
90    };
91
92    (tx, rx)
93}
94
95/// Sends values to the associated [`Receiver`]s.
96///
97/// Created by the [`channel`] function.
98#[derive(Debug)]
99pub struct Sender<T> {
100    /// The inner shared state.
101    shared: Arc<Shared<T>>,
102}
103
104impl<T> Sender<T> {
105    /// Sends a new value to the channel and notifies all receivers.
106    ///
107    /// # Examples
108    ///
109    /// ```
110    /// let (tx, rx) = async_observe::channel(0);
111    /// assert_eq!(rx.observe(|n| *n), 0);
112    ///
113    /// // Send a new value
114    /// tx.send(1);
115    ///
116    /// // Now the receiver can see it
117    /// assert_eq!(rx.observe(|n| *n), 1);
118    /// ```
119    ///
120    /// To wait until the value is updated, use the receiver's async methods
121    /// [`changed`](Receiver::changed) or [`recv`](Receiver::recv).
122    pub fn send(&self, value: T) -> Result<(), SendError<T>> {
123        if self.shared.rx_count.load(Ordering::Relaxed) == 0 {
124            // All receivers have been dropped
125            return Err(SendError(value));
126        }
127
128        // Replace the value
129        *self.shared.write_value() = value;
130        self.shared.state.increment_version();
131
132        // Notify all receivers
133        self.shared.changed.notify(usize::MAX);
134
135        Ok(())
136    }
137
138    /// Waits until all receivers are dropped.
139    ///
140    /// # Examples
141    ///
142    /// A producer can wait until no consumers is interested in its updates
143    /// anymore and then stop working.
144    ///
145    /// ```
146    /// # futures_lite::future::block_on(async {
147    /// # use futures_lite::future::yield_now as wait_long_time;
148    /// use futures_lite::future;
149    ///
150    /// let (tx, mut rx) = async_observe::channel(0);
151    ///
152    /// // The producer runs concurrently and waits until the channel
153    /// // is closed, then cancels the main future
154    /// let producer = future::or(
155    ///     async {
156    ///         let mut n = 0;
157    ///         loop {
158    ///             wait_long_time().await;
159    ///
160    ///             // The producer could check if the channel is closed on send,
161    ///             // but computing a new value might take a long time.
162    ///             // Instead, we cancel the whole working future.
163    ///             _ = tx.send(n);
164    ///             n += 1;
165    ///         }
166    ///     },
167    ///     tx.closed(),
168    /// );
169    ///
170    /// let consumer = async move {
171    /// //                   ^^^^
172    /// // Note: rx is moved into this future,
173    /// // so it will be dropped when the loop ends.
174    /// // Alternatively, you could call `drop(rx);` explicitly.
175    ///
176    ///     while let Ok(n) = rx.recv().await {
177    ///         // After receiving number 5,
178    ///         // the consumer is no longer interested.
179    ///         if n == 5 {
180    ///             break;
181    ///         }
182    ///     }
183    /// };
184    ///
185    /// future::zip(producer, consumer).await;
186    /// # });
187    /// ```
188    ///
189    /// The receiver is dropped after getting the number 5. Since there are no
190    /// other receivers, this makes `tx.closed()` finish and the sender stops
191    /// sending new values.
192    pub async fn closed(&self) {
193        if self.shared.rx_count.load(Ordering::Relaxed) == 0 {
194            return;
195        }
196
197        // In order to avoid a notification loss, we first request a notification,
198        // **then** check the current `rx_count`.
199        // If `rx_count` is 0, the notification request is dropped.
200        event_listener::listener!(self.shared.all_receivers_dropped => listener);
201
202        if self.shared.rx_count.load(Ordering::Relaxed) == 0 {
203            return;
204        }
205
206        listener.await;
207
208        debug_assert_eq!(self.shared.rx_count.load(Ordering::Relaxed), 0);
209    }
210}
211
212impl<T> Drop for Sender<T> {
213    fn drop(&mut self) {
214        self.shared.state.close();
215        self.shared.changed.notify(usize::MAX);
216    }
217}
218
/// Error produced when sending a value fails.
///
/// Returned by [`Sender::send`] when every [`Receiver`] has been dropped.
/// The value that could not be sent is kept in the public field, so it can
/// be recovered.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct SendError<T>(pub T);

impl<T> fmt::Display for SendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("sending on a closed channel")
    }
}

impl<T> fmt::Debug for SendError<T> {
    /// Renders the same message as `Display`. The payload is not printed,
    /// so `T` is not required to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl<T> error::Error for SendError<T> {}
236
237/// Receives values from the associated [`Sender`].
238///
239/// Created by the [`channel`] function.
240#[derive(Debug)]
241pub struct Receiver<T> {
242    /// The inner shared state.
243    shared: Arc<Shared<T>>,
244
245    /// Last observed version.
246    last_version: usize,
247}
248
249impl<T> Receiver<T> {
250    /// Observes the latest value sent to the channel.
251    ///
252    /// This method takes a closure and calls it with a reference to the value.
253    /// While the closure is running [`send`](Sender::send) calls are blocked.
254    /// Because of this, the closure should run only as long as needed.
255    /// A common pattern is to copy or clone the value inside the closure, then
256    /// return and work with the copy outside.
257    ///
258    /// You can observe the value at any time, but usually you want to wait
259    /// until it changes. For that, use the [`changed`](Receiver::changed)
260    /// async method.
261    ///
262    /// # Examples
263    ///
264    /// ```
265    /// # futures_lite::future::block_on(async {
266    /// let (tx, mut rx) = async_observe::channel(0);
267    ///
268    /// // Send a new value
269    /// tx.send(1);
270    ///
271    /// // Wait until the value changes
272    /// rx.changed().await?;
273    ///
274    /// // Now we can read the new value
275    /// let n = rx.observe(|n| *n);
276    /// assert_eq!(n, 1);
277    /// # Ok::<_, async_observe::RecvError>(())
278    /// # });
279    /// ```
280    ///
281    /// If the value type implements `Clone`, you can use
282    /// [`recv`](Receiver::recv) instead, which waits for a change and returns
283    /// the new value.
284    ///
285    /// ```
286    /// # futures_lite::future::block_on(async {
287    /// let (tx, mut rx) = async_observe::channel(0);
288    ///
289    /// // Send a new value
290    /// tx.send(1);
291    ///
292    /// // Wait until the value changes and read it
293    /// let n = rx.recv().await?;
294    /// assert_eq!(n, 1);
295    /// # Ok::<_, async_observe::RecvError>(())
296    /// # });
297    /// ```
298    ///
299    /// # Possible deadlock
300    ///
301    /// Calling [`send`](Sender::send) inside the closure will deadlock:
302    ///
303    /// ```no_run
304    /// let (tx, rx) = async_observe::channel(0);
305    /// rx.observe(|n| {
306    ///     _ = tx.send(n + 1);
307    /// });
308    /// ```
309    pub fn observe<F, R>(&self, f: F) -> R
310    where
311        F: FnOnce(&T) -> R,
312    {
313        f(&self.shared.read_value())
314    }
315
316    /// Waits for the value to change.
317    ///
318    /// Call [`observe`](Receiver::observe) to read the new value.
319    pub async fn changed(&mut self) -> Result<(), RecvError> {
320        if self
321            .shared
322            .state
323            .version_changed(&mut self.last_version)
324            .ok_or(RecvError)?
325        {
326            return Ok(());
327        }
328
329        // In order to avoid a notification loss, we first request a notification,
330        // **then** check the current value's version.
331        // If a new version exists, the notification request is dropped.
332        event_listener::listener!(self.shared.changed => listener);
333
334        if self
335            .shared
336            .state
337            .version_changed(&mut self.last_version)
338            .ok_or(RecvError)?
339        {
340            return Ok(());
341        }
342
343        listener.await;
344
345        let changed = self
346            .shared
347            .state
348            .version_changed(&mut self.last_version)
349            .ok_or(RecvError)?;
350
351        debug_assert!(changed);
352        Ok(())
353    }
354
355    /// Waits for the value to change and then returns a clone of it.
356    ///
357    /// # Examples
358    ///
359    /// ```
360    /// # futures_lite::future::block_on(async {
361    /// let (tx, mut rx) = async_observe::channel(0);
362    ///
363    /// // Send a new value
364    /// tx.send(1);
365    ///
366    /// // Wait until the value changes and read it
367    /// let n = rx.recv().await?;
368    /// assert_eq!(n, 1);
369    /// # Ok::<_, async_observe::RecvError>(())
370    /// # });
371    /// ```
372    pub async fn recv(&mut self) -> Result<T, RecvError>
373    where
374        T: Clone,
375    {
376        self.changed().await?;
377        Ok(self.observe(T::clone))
378    }
379
380    /// Creates a [stream](Stream) from the receiver.
381    ///
382    /// The stream ends when the [`Sender`] is dropped.
383    ///
384    /// # Examples
385    ///
386    /// ```
387    /// # futures_lite::future::block_on(async {
388    /// # use futures_lite::future::yield_now as wait_long_time;
389    /// use futures_lite::{StreamExt, future};
390    ///
391    /// let (tx, rx) = async_observe::channel(0);
392    ///
393    /// let producer = async move {
394    /// //                   ^^^^
395    /// // Move tx into the future so it is dropped after the loop ends.
396    /// // Dropping the sender is important so the receiver can
397    /// // see it and stop the stream.
398    ///
399    ///     for n in 1..10 {
400    ///         wait_long_time().await;
401    ///         _ = tx.send(n);
402    ///     }
403    /// };
404    ///
405    /// // Create a stream from the receiver
406    /// let consumer = rx
407    ///     .into_stream()
408    ///     .for_each(|n| println!("{n}"));
409    ///
410    /// future::zip(producer, consumer).await;
411    /// # });
412    /// ```
413    ///
414    /// # Return type
415    ///
416    /// Due to implementation details, the stream currently does not have
417    /// a named type. For example, you cannot store this stream as a struct
418    /// field (without a generic) or use it as an associated type in a trait
419    /// implementation. As long as you don't need that, just use the returned
420    /// anonymous type directly.
421    ///
422    /// However, if you need a named type, one solution is to box the stream:
423    ///
424    /// ```
425    /// use {
426    ///     async_observe::Receiver,
427    ///     futures_lite::{StreamExt, stream::Boxed},
428    /// };
429    ///
430    /// fn named_stream(rx: Receiver<u32>) -> Boxed<u32> {
431    ///     rx.into_stream().boxed()
432    /// }
433    /// ```
434    pub fn into_stream(self) -> impl Stream<Item = T>
435    where
436        T: Clone,
437    {
438        stream::unfold(self, async |mut me| {
439            let value = me.recv().await.ok()?;
440            Some((value, me))
441        })
442    }
443}
444
445impl<T> Clone for Receiver<T> {
446    fn clone(&self) -> Self {
447        self.shared.rx_count.fetch_add(1, Ordering::Relaxed);
448        Self {
449            shared: self.shared.clone(),
450            last_version: self.last_version,
451        }
452    }
453}
454
455impl<T> Drop for Receiver<T> {
456    fn drop(&mut self) {
457        if self.shared.rx_count.fetch_sub(1, Ordering::Relaxed) == 1 {
458            // Notify all senders.
459            // Even though `Sender` is not `Clone`, it can still wait for
460            // the channel to close from multiple places, since the `closed`
461            // method takes `&self`.
462            self.shared.all_receivers_dropped.notify(usize::MAX);
463        }
464    }
465}
466
/// Error produced when receiving a change notification.
///
/// Returned by [`Receiver::changed`] and [`Receiver::recv`] once the
/// [`Sender`] has been dropped and the receiver has already been notified
/// about the last sent value.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct RecvError;

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("receiving on a closed channel")
    }
}

impl fmt::Debug for RecvError {
    /// Renders the same message as `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl error::Error for RecvError {}
484
485#[derive(Debug)]
486struct Shared<T> {
487    /// The most recent value.
488    value: RwLock<T>,
489
490    /// The current state.
491    state: State,
492
493    /// Tracks the number of `Receiver` instances.
494    rx_count: AtomicUsize,
495
496    /// Event when the value has changed or the `Sender` has been dropped.
497    changed: Event,
498
499    /// Event when all `Receiver`s have been dropped.
500    all_receivers_dropped: Event,
501}
502
503impl<T> Shared<T> {
504    fn read_value(&self) -> RwLockReadGuard<'_, T> {
505        // The `RwLock` has no poisoned state
506        match self.value.read() {
507            Ok(guard) => guard,
508            Err(e) => e.into_inner(),
509        }
510    }
511
512    fn write_value(&self) -> RwLockWriteGuard<'_, T> {
513        // The `RwLock` has no poisoned state
514        match self.value.write() {
515            Ok(guard) => guard,
516            Err(e) => e.into_inner(),
517        }
518    }
519}
520
/// Version counter packed together with a "closed" flag in a single word.
#[derive(Debug)]
struct State(AtomicUsize);

impl State {
    /// Low bit of the packed word; set once the channel is closed.
    const CLOSED_BIT: usize = 1;

    /// Amount added per version bump. Stepping by 2 never touches `CLOSED_BIT`.
    const VERSION_STEP: usize = 2;

    /// Starts at version 0 with the channel open.
    fn new() -> Self {
        State(AtomicUsize::new(0))
    }

    /// Publishes a new version.
    fn increment_version(&self) {
        self.0.fetch_add(Self::VERSION_STEP, Ordering::Release);
    }

    /// Compares the current version with `last_version`.
    ///
    /// Returns `Some(true)` (and records the new version in `last_version`)
    /// when the version differs, `None` when it is unchanged and the channel
    /// is closed, and `Some(false)` when it is unchanged and still open.
    fn version_changed(&self, last_version: &mut usize) -> Option<bool> {
        let snapshot = self.0.load(Ordering::Acquire);
        let current = snapshot & !Self::CLOSED_BIT;
        let is_closed = (snapshot & Self::CLOSED_BIT) != 0;

        match (current != *last_version, is_closed) {
            (true, _) => {
                *last_version = current;
                Some(true)
            }
            (false, true) => None,
            (false, false) => Some(false),
        }
    }

    /// Sets the closed flag, keeping the version bits intact.
    fn close(&self) {
        self.0.fetch_or(Self::CLOSED_BIT, Ordering::Release);
    }
}