n0_watcher/
lib.rs

1//! Watchable values.
2//!
3//! A [`Watchable`] exists to keep track of a value which may change over time.  It allows
4//! observers to be notified of changes to the value.  The aim is to always be aware of the
5//! **last** value, not to observe *every* value change.
6//!
7//! The reason for this is ergonomics and predictable resource usage: Requiring every
8//! intermediate value to be observable would mean that either the side that sets new values
9//! using [`Watchable::set`] would need to wait for all "receivers" of these intermediate
10//! values to catch up and thus be an async operation, or it would require the receivers
11//! to buffer intermediate values until they've been "received" on the [`Watcher`]s with
12//! an unlimited buffer size and thus potentially unlimited memory growth.
13//!
14//! # Example
15//!
16//! ```
17//! use n0_future::StreamExt;
18//! use n0_watcher::{Watchable, Watcher as _};
19//!
20//! #[tokio::main(flavor = "current_thread", start_paused = true)]
21//! async fn main() {
22//!     let watchable = Watchable::new(None);
23//!
24//!     // A task that waits for the watcher to be initialized to Some(value) before printing it
25//!     let mut watcher = watchable.watch();
26//!     tokio::spawn(async move {
27//!         let initialized_value = watcher.initialized().await;
28//!         println!("initialized: {initialized_value}");
29//!     });
30//!
31//!     // A task that prints every update to the watcher since the initial one:
32//!     let mut updates = watchable.watch().stream_updates_only();
33//!     tokio::spawn(async move {
34//!         while let Some(update) = updates.next().await {
35//!             println!("update: {update:?}");
36//!         }
37//!     });
38//!
39//!     // A task that prints the current value and then every update it can catch,
40//!     // but it also does something else which makes it very slow to pick up new
41//!     // values, so it'll skip some:
42//!     let mut current_and_updates = watchable.watch().stream();
43//!     tokio::spawn(async move {
44//!         while let Some(update) = current_and_updates.next().await {
45//!             println!("update2: {update:?}");
46//!             tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
47//!         }
48//!     });
49//!
50//!     for i in 0..20 {
51//!         println!("Setting watchable to {i}");
52//!         watchable.set(Some(i)).ok();
53//!         tokio::time::sleep(tokio::time::Duration::from_millis(250)).await;
54//!     }
55//! }
56//! ```
57//!
58//! # Similar but different
59//!
60//! - `async_channel`: This is a multi-producer, multi-consumer channel implementation.
//!   Each "produced" value is received by at most one consumer.
//!   What we want is for every "produced" value to be "broadcast" to every receiver.
63//! - `tokio::broadcast`: Also a multi-producer, multi-consumer channel implementation.
64//!   This is very similar to this crate (`tokio::broadcast::Sender` is like [`Watchable`]
65//!   and `tokio::broadcast::Receiver` is like [`Watcher`]), but you can't get the latest
66//!   value without `.await`ing on the receiver, and it'll internally store a queue of
67//!   intermediate values.
//! - `tokio::watch`: Also a multi-consumer channel, and unlike `tokio::broadcast` only retains the
69//!   latest value. That module has pretty much the same purpose as this crate, but doesn't
70//!   implement a poll-based method of getting updates and doesn't implement combinators.
71//! - [`std::sync::RwLock`]: (wrapped in an [`std::sync::Arc`]) This allows you access
72//!   to the latest values, but might block while it's being set (but that could be short
73//!   enough not to matter for async rust purposes).
74//!   This doesn't allow you to be notified whenever a new value is written.
75//! - The `watchable` crate: We used to use this crate at n0, but we wanted to experiment
76//!   with different APIs and needed Wasm support.
77#[cfg(not(watcher_loom))]
78use std::sync;
79use std::{
80    collections::VecDeque,
81    future::Future,
82    pin::Pin,
83    sync::{Arc, RwLockReadGuard, Weak},
84    task::{self, ready, Poll, Waker},
85};
86
87#[cfg(watcher_loom)]
88use loom::sync;
89use n0_error::StackError;
90use sync::{Mutex, RwLock};
91
92/// A wrapper around a value that notifies [`Watcher`]s when the value is modified.
93///
94/// Only the most recent value is available to any observer, but the observer is guaranteed
95/// to be notified of the most recent value.
96#[derive(Debug, Default)]
97pub struct Watchable<T> {
98    shared: Arc<Shared<T>>,
99}
100
101impl<T> Clone for Watchable<T> {
102    fn clone(&self) -> Self {
103        Self {
104            shared: self.shared.clone(),
105        }
106    }
107}
108
/// Abstracts over "possibly empty" containers such as `Option<T>` and `Vec<T>`.
///
/// Used by [`Watcher::initialized`] to decide whether a watched value holds something.
pub trait Nullable<T> {
    /// Converts this value into an `Option`.
    ///
    /// For a `Vec<T>` this yields its *last* element, or `None` when it is empty.
    fn into_option(self) -> Option<T>;
}

impl<T> Nullable<T> for Option<T> {
    fn into_option(self) -> Option<T> {
        // Already an `Option`: nothing to convert.
        self
    }
}

impl<T> Nullable<T> for Vec<T> {
    fn into_option(self) -> Option<T> {
        let mut items = self;
        items.pop()
    }
}
126
127impl<T: Clone + Eq> Watchable<T> {
128    /// Creates a [`Watchable`] initialized to given value.
129    pub fn new(value: T) -> Self {
130        Self {
131            shared: Arc::new(Shared {
132                state: RwLock::new(State {
133                    value,
134                    epoch: INITIAL_EPOCH,
135                }),
136                watchers: Default::default(),
137            }),
138        }
139    }
140
141    /// Sets a new value.
142    ///
143    /// Returns `Ok(previous_value)` if the value was different from the one set, or
144    /// returns the provided value back as `Err(value)` if the value didn't change.
145    ///
146    /// Watchers are only notified if the value changed.
147    pub fn set(&self, value: T) -> Result<T, T> {
148        // We don't actually write when the value didn't change, but there's unfortunately
149        // no way to upgrade a read guard to a write guard, and locking as read first, then
150        // dropping and locking as write introduces a possible race condition.
151        let mut state = self.shared.state.write().expect("poisoned");
152
153        // Find out if the value changed
154        let changed = state.value != value;
155
156        let ret = if changed {
157            let old = std::mem::replace(&mut state.value, value);
158            state.epoch += 1;
159            Ok(old)
160        } else {
161            Err(value)
162        };
163        drop(state); // No need to write anymore
164
165        // Notify watchers
166        if changed {
167            for watcher in self.shared.watchers.lock().expect("poisoned").drain(..) {
168                watcher.wake();
169            }
170        }
171        ret
172    }
173
174    /// Creates a [`Direct`] [`Watcher`], allowing the value to be observed, but not modified.
175    pub fn watch(&self) -> Direct<T> {
176        Direct {
177            state: self.shared.state().clone(),
178            shared: Arc::downgrade(&self.shared),
179        }
180    }
181
182    /// Returns the currently stored value.
183    pub fn get(&self) -> T {
184        self.shared.get()
185    }
186
187    /// Returns true when there are any watchers actively listening on changes,
188    /// or false when all watchers have been dropped or none have been created yet.
189    pub fn has_watchers(&self) -> bool {
190        // `Watchable`s will increase the strong count
191        // `Direct`s watchers (which all watchers descend from) will increase the weak count
192        Arc::weak_count(&self.shared) != 0
193    }
194}
195
196impl<T> Drop for Watchable<T> {
197    fn drop(&mut self) {
198        let Ok(mut watchers) = self.shared.watchers.lock() else {
199            return; // Poisoned waking?
200        };
201        // Wake all watchers every time we drop.
202        // This allows us to notify `NextFut::poll`s that the underlying
203        // watchable might be dropped.
204        for watcher in watchers.drain(..) {
205            watcher.wake();
206        }
207    }
208}
209
210/// A handle to a value that's represented by one or more underlying [`Watchable`]s.
211///
212/// A [`Watcher`] can get the current value, and will be notified when the value changes.
213/// Only the most recent value is accessible, and if the threads with the underlying [`Watchable`]s
214/// change the value faster than the threads with the [`Watcher`] can keep up with, then
215/// it'll miss in-between values.
216/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always
217/// end up reporting the most recent state eventually.
218///
219/// Watchers can be modified via [`Watcher::map`] to observe a value derived from the original
220/// value via a function.
221///
222/// Watchers can be combined via [`Watcher::or`] to allow observing multiple values at once and
223/// getting an update in case any of the values updates.
224///
225/// One of the underlying [`Watchable`]s might already be dropped. In that case,
226/// the watcher will be "disconnected" and return [`Err(Disconnected)`](Disconnected)
227/// on some function calls or, when turned into a stream, that stream will end.
228/// This property can also be checked with [`Watcher::is_connected`].
229pub trait Watcher: Clone {
230    /// The type of value that can change.
231    ///
232    /// We require `Clone`, because we need to be able to make
233    /// the values have a lifetime that's detached from the original [`Watchable`]'s
234    /// lifetime.
235    ///
236    /// We require `Eq`, to be able to check whether the value actually changed or
237    /// not, so we can notify or not notify accordingly.
238    type Value: Clone + Eq;
239
240    /// Updates the watcher to the latest value and returns that value.
241    ///
242    /// If any of the underlying [`Watchable`] values have been dropped, then this
243    /// might return an outdated value for that watchable, specifically, the latest
244    /// value that was fetched for that watchable, as opposed to the latest value
245    /// that was set on the watchable before it was dropped.
246    ///
247    /// The default implementation for this is simply
248    /// ```ignore
249    /// fn get(&mut self) -> Self::Value {
250    ///     self.update();
251    ///     self.peek().clone()
252    /// }
253    /// ```
254    fn get(&mut self) -> Self::Value {
255        self.update();
256        self.peek().clone()
257    }
258
259    /// Updates the watcher to the latest value and returns whether it changed.
260    ///
261    /// Watchers keep track of the "latest known" value they fetched.
262    /// This function updates that internal value by looking up the latest value
263    /// at the [`Watchable`]\(s\) that this watcher is linked to.
264    fn update(&mut self) -> bool;
265
266    /// Returns a reference to the value currently stored in the watcher.
267    ///
268    /// Watchers keep track of the "latest known" value they fetched.
269    /// Calling this won't update the latest value, unlike [`Watcher::get`] or
270    /// [`Watcher::update`].
271    ///
272    /// This can be useful if you want to avoid copying out the internal value
273    /// frequently like what [`Watcher::get`] will end up doing.
274    fn peek(&self) -> &Self::Value;
275
276    /// Whether this watcher is still connected to all of its underlying [`Watchable`]s.
277    ///
278    /// Returns false when any of the underlying watchables has been dropped.
279    fn is_connected(&self) -> bool;
280
281    /// Polls for the next value, or returns [`Disconnected`] if one of the underlying
282    /// [`Watchable`]s has been dropped.
283    fn poll_updated(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Disconnected>>;
284
285    /// Returns a future completing with `Ok(value)` once a new value is set, or with
286    /// [`Err(Disconnected)`](Disconnected) if the connected [`Watchable`] was dropped.
287    ///
288    /// # Cancel Safety
289    ///
290    /// The returned future is cancel-safe.
291    fn updated(&mut self) -> NextFut<'_, Self> {
292        NextFut { watcher: self }
293    }
294
295    /// Returns a future completing once the value is set to [`Some`] value.
296    ///
297    /// If the current value is [`Some`] value, this future will resolve immediately.
298    ///
299    /// This is a utility for the common case of storing an [`Option`] inside a
300    /// [`Watchable`].
301    ///
302    /// # Cancel Safety
303    ///
304    /// The returned future is cancel-safe.
305    fn initialized<T, W>(&mut self) -> InitializedFut<'_, T, W, Self>
306    where
307        W: Nullable<T> + Clone,
308        Self: Watcher<Value = W>,
309    {
310        InitializedFut {
311            initial: self.get().into_option(),
312            watcher: self,
313        }
314    }
315
316    /// Returns a stream which will yield the most recent values as items.
317    ///
318    /// The first item of the stream is the current value, so that this stream can be easily
319    /// used to operate on the most recent value.
320    ///
321    /// Note however, that only the last item is stored.  If the stream is not polled when an
322    /// item is available it can be replaced with another item by the time it is polled.
323    ///
324    /// This stream ends once the original [`Watchable`] has been dropped.
325    ///
326    /// # Cancel Safety
327    ///
328    /// The returned stream is cancel-safe.
329    fn stream(mut self) -> Stream<Self>
330    where
331        Self: Unpin,
332    {
333        Stream {
334            initial: Some(self.get()),
335            watcher: self,
336        }
337    }
338
339    /// Returns a stream which will yield the most recent values as items, starting from
340    /// the next unobserved future value.
341    ///
342    /// This means this stream will only yield values when the watched value changes,
343    /// the value stored at the time the stream is created is not yielded.
344    ///
345    /// Note however, that only the last item is stored.  If the stream is not polled when an
346    /// item is available it can be replaced with another item by the time it is polled.
347    ///
348    /// This stream ends once the original [`Watchable`] has been dropped.
349    ///
350    /// # Cancel Safety
351    ///
352    /// The returned stream is cancel-safe.
353    fn stream_updates_only(self) -> Stream<Self>
354    where
355        Self: Unpin,
356    {
357        Stream {
358            initial: None,
359            watcher: self,
360        }
361    }
362
363    /// Maps this watcher with a function that transforms the observed values.
364    ///
365    /// The returned watcher will only register updates, when the *mapped* value
366    /// observably changes.
367    fn map<T: Clone + Eq>(
368        mut self,
369        map: impl Fn(Self::Value) -> T + Send + Sync + 'static,
370    ) -> Map<Self, T> {
371        Map {
372            current: (map)(self.get()),
373            map: Arc::new(map),
374            watcher: self,
375        }
376    }
377
378    /// Returns a watcher that updates every time this or the other watcher
379    /// updates, and yields both watcher's items together when that happens.
380    fn or<W: Watcher>(self, other: W) -> Tuple<Self, W> {
381        Tuple::new(self, other)
382    }
383}
384
385/// The immediate, direct observer of a [`Watchable`] value.
386///
387/// This type is mainly used via the [`Watcher`] interface.
388#[derive(Debug, Clone)]
389pub struct Direct<T> {
390    state: State<T>,
391    shared: Weak<Shared<T>>,
392}
393
394impl<T: Clone + Eq> Watcher for Direct<T> {
395    type Value = T;
396
397    fn update(&mut self) -> bool {
398        let Some(shared) = self.shared.upgrade() else {
399            return false;
400        };
401        let state = shared.state();
402        if state.epoch > self.state.epoch {
403            self.state = state.clone();
404            true
405        } else {
406            false
407        }
408    }
409
410    fn peek(&self) -> &Self::Value {
411        &self.state.value
412    }
413
414    fn is_connected(&self) -> bool {
415        self.shared.upgrade().is_some()
416    }
417
418    fn poll_updated(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Disconnected>> {
419        let Some(shared) = self.shared.upgrade() else {
420            return Poll::Ready(Err(Disconnected));
421        };
422        self.state = ready!(shared.poll_updated(cx, self.state.epoch));
423        Poll::Ready(Ok(()))
424    }
425}
426
427#[derive(Debug, Clone)]
428pub struct Tuple<S: Watcher, T: Watcher> {
429    inner: (S, T),
430    current: (S::Value, T::Value),
431}
432
433impl<S: Watcher, T: Watcher> Tuple<S, T> {
434    pub fn new(mut s: S, mut t: T) -> Self {
435        let current = (s.get(), t.get());
436        Self {
437            inner: (s, t),
438            current,
439        }
440    }
441}
442
443impl<S: Watcher, T: Watcher> Watcher for Tuple<S, T> {
444    type Value = (S::Value, T::Value);
445
446    fn update(&mut self) -> bool {
447        // We need to update all watchers! So don't early-return
448        let s_updated = self.inner.0.update();
449        let t_updated = self.inner.1.update();
450        let updated = s_updated || t_updated;
451        if updated {
452            self.current = (self.inner.0.peek().clone(), self.inner.1.peek().clone());
453        }
454        updated
455    }
456
457    fn peek(&self) -> &Self::Value {
458        &self.current
459    }
460
461    fn is_connected(&self) -> bool {
462        self.inner.0.is_connected() && self.inner.1.is_connected()
463    }
464
465    fn poll_updated(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Disconnected>> {
466        let poll_0 = self.inner.0.poll_updated(cx)?;
467        let poll_1 = self.inner.1.poll_updated(cx)?;
468        if poll_0.is_pending() && poll_1.is_pending() {
469            return Poll::Pending;
470        }
471        if poll_0.is_ready() {
472            self.current.0 = self.inner.0.peek().clone();
473        }
474        if poll_1.is_ready() {
475            self.current.1 = self.inner.1.peek().clone();
476        }
477        Poll::Ready(Ok(()))
478    }
479}
480
481#[derive(Debug, Clone)]
482pub struct Triple<S: Watcher, T: Watcher, U: Watcher> {
483    inner: (S, T, U),
484    current: (S::Value, T::Value, U::Value),
485}
486
487impl<S: Watcher, T: Watcher, U: Watcher> Triple<S, T, U> {
488    pub fn new(mut s: S, mut t: T, mut u: U) -> Self {
489        let current = (s.get(), t.get(), u.get());
490        Self {
491            inner: (s, t, u),
492            current,
493        }
494    }
495}
496
497impl<S: Watcher, T: Watcher, U: Watcher> Watcher for Triple<S, T, U> {
498    type Value = (S::Value, T::Value, U::Value);
499
500    fn update(&mut self) -> bool {
501        // We need to update all watchers! So don't early-return
502        let s_updated = self.inner.0.update();
503        let t_updated = self.inner.1.update();
504        let u_updated = self.inner.2.update();
505        let updated = s_updated || t_updated || u_updated;
506        if updated {
507            self.current = (
508                self.inner.0.peek().clone(),
509                self.inner.1.peek().clone(),
510                self.inner.2.peek().clone(),
511            );
512        }
513        updated
514    }
515
516    fn peek(&self) -> &Self::Value {
517        &self.current
518    }
519
520    fn is_connected(&self) -> bool {
521        self.inner.0.is_connected() && self.inner.1.is_connected() && self.inner.2.is_connected()
522    }
523
524    fn poll_updated(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Disconnected>> {
525        let poll_0 = self.inner.0.poll_updated(cx)?;
526        let poll_1 = self.inner.1.poll_updated(cx)?;
527        let poll_2 = self.inner.2.poll_updated(cx)?;
528
529        if poll_0.is_pending() && poll_1.is_pending() && poll_2.is_pending() {
530            return Poll::Pending;
531        }
532        if poll_0.is_ready() {
533            self.current.0 = self.inner.0.peek().clone();
534        }
535        if poll_1.is_ready() {
536            self.current.1 = self.inner.1.peek().clone();
537        }
538        if poll_2.is_ready() {
539            self.current.2 = self.inner.2.peek().clone();
540        }
541        Poll::Ready(Ok(()))
542    }
543}
544
545/// Combinator to join two watchers
546#[derive(Debug, Clone)]
547pub struct Join<T: Clone + Eq, W: Watcher<Value = T>> {
548    // invariant: watchers.len() == current.len()
549    watchers: Vec<W>,
550    current: Vec<T>,
551}
552
553impl<T: Clone + Eq, W: Watcher<Value = T>> Join<T, W> {
554    /// Joins a set of watchers into a single watcher
555    pub fn new(watchers: impl Iterator<Item = W>) -> Self {
556        let mut watchers: Vec<W> = watchers.into_iter().collect();
557
558        let mut current = Vec::with_capacity(watchers.len());
559        for watcher in &mut watchers {
560            current.push(watcher.get());
561        }
562        Self { watchers, current }
563    }
564}
565
566impl<T: Clone + Eq, W: Watcher<Value = T>> Watcher for Join<T, W> {
567    type Value = Vec<T>;
568
569    fn update(&mut self) -> bool {
570        let mut any_updated = false;
571        for (value, watcher) in self.current.iter_mut().zip(self.watchers.iter_mut()) {
572            if watcher.update() {
573                any_updated = true;
574                *value = watcher.peek().clone();
575            }
576        }
577        any_updated
578    }
579
580    fn peek(&self) -> &Self::Value {
581        &self.current
582    }
583
584    fn is_connected(&self) -> bool {
585        self.watchers.iter().all(|w| w.is_connected())
586    }
587
588    fn poll_updated(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Disconnected>> {
589        let mut any_updated = false;
590        for (value, watcher) in self.current.iter_mut().zip(self.watchers.iter_mut()) {
591            if watcher.poll_updated(cx)?.is_ready() {
592                any_updated = true;
593                *value = watcher.peek().clone();
594            }
595        }
596
597        if any_updated {
598            Poll::Ready(Ok(()))
599        } else {
600            Poll::Pending
601        }
602    }
603}
604
605/// Wraps a [`Watcher`] to allow observing a derived value.
606///
607/// See [`Watcher::map`].
608#[derive(derive_more::Debug, Clone)]
609pub struct Map<W: Watcher, T: Clone + Eq> {
610    #[debug("Arc<dyn Fn(W::Value) -> T>")]
611    map: Arc<dyn Fn(W::Value) -> T + Send + Sync + 'static>,
612    watcher: W,
613    current: T,
614}
615
616impl<W: Watcher, T: Clone + Eq> Watcher for Map<W, T> {
617    type Value = T;
618
619    fn update(&mut self) -> bool {
620        if self.watcher.update() {
621            let new = (self.map)(self.watcher.peek().clone());
622            if new != self.current {
623                self.current = new;
624                true
625            } else {
626                false
627            }
628        } else {
629            false
630        }
631    }
632
633    fn peek(&self) -> &Self::Value {
634        &self.current
635    }
636
637    fn is_connected(&self) -> bool {
638        self.watcher.is_connected()
639    }
640
641    fn poll_updated(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Disconnected>> {
642        loop {
643            ready!(self.watcher.poll_updated(cx)?);
644            let new = (self.map)(self.watcher.peek().clone());
645            if new != self.current {
646                self.current = new;
647                return Poll::Ready(Ok(()));
648            }
649        }
650    }
651}
652
653/// Future returning the next item after the current one in a [`Watcher`].
654///
655/// See [`Watcher::updated`].
656///
657/// # Cancel Safety
658///
659/// This future is cancel-safe.
660#[derive(Debug)]
661pub struct NextFut<'a, W: Watcher> {
662    watcher: &'a mut W,
663}
664
665impl<W: Watcher> Future for NextFut<'_, W> {
666    type Output = Result<W::Value, Disconnected>;
667
668    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
669        ready!(self.watcher.poll_updated(cx))?;
670        Poll::Ready(Ok(self.watcher.peek().clone()))
671    }
672}
673
674/// Future returning the current or next value that's [`Some`] value.
675/// in a [`Watcher`].
676///
677/// See [`Watcher::initialized`].
678///
679/// # Cancel Safety
680///
681/// This Future is cancel-safe.
682#[derive(Debug)]
683pub struct InitializedFut<'a, T, V: Nullable<T> + Clone, W: Watcher<Value = V>> {
684    initial: Option<T>,
685    watcher: &'a mut W,
686}
687
688impl<T: Clone + Eq + Unpin, V: Nullable<T> + Clone, W: Watcher<Value = V> + Unpin> Future
689    for InitializedFut<'_, T, V, W>
690{
691    type Output = T;
692
693    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
694        let mut this = self.as_mut();
695        if let Some(value) = this.initial.take() {
696            return Poll::Ready(value);
697        }
698        loop {
699            if ready!(this.watcher.poll_updated(cx)).is_err() {
700                // The value will never be initialized
701                return Poll::Pending;
702            };
703            let value = this.watcher.peek();
704            if let Some(value) = value.clone().into_option() {
705                return Poll::Ready(value);
706            }
707        }
708    }
709}
710
711/// A stream for a [`Watcher`]'s next values.
712///
713/// See [`Watcher::stream`] and [`Watcher::stream_updates_only`].
714///
715/// # Cancel Safety
716///
717/// This stream is cancel-safe.
718#[derive(Debug, Clone)]
719pub struct Stream<W: Watcher + Unpin> {
720    initial: Option<W::Value>,
721    watcher: W,
722}
723
724impl<W: Watcher + Unpin> n0_future::Stream for Stream<W>
725where
726    W::Value: Unpin,
727{
728    type Item = W::Value;
729
730    fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
731        if let Some(value) = self.as_mut().initial.take() {
732            return Poll::Ready(Some(value));
733        }
734        match self.as_mut().watcher.poll_updated(cx) {
735            Poll::Ready(Ok(())) => Poll::Ready(Some(self.as_ref().watcher.peek().clone())),
736            Poll::Ready(Err(Disconnected)) => Poll::Ready(None),
737            Poll::Pending => Poll::Pending,
738        }
739    }
740}
741
742/// The error for when a [`Watcher`] is disconnected from its underlying
743/// [`Watchable`] value, because of that watchable having been dropped.
744#[derive(StackError)]
745#[error("Watcher lost connection to underlying Watchable, it was dropped")]
746pub struct Disconnected;
747
748// Private:
749
/// Epoch of a freshly created value; every change of the value increments it by one.
const INITIAL_EPOCH: u64 = 1;
751
752/// The shared state for a [`Watchable`].
753#[derive(Debug, Default)]
754struct Shared<T> {
755    /// The value to be watched and its current epoch.
756    state: RwLock<State<T>>,
757    watchers: Mutex<VecDeque<Waker>>,
758}
759
760#[derive(Debug, Clone)]
761struct State<T> {
762    value: T,
763    epoch: u64,
764}
765
766impl<T: Default> Default for State<T> {
767    fn default() -> Self {
768        Self {
769            value: Default::default(),
770            epoch: INITIAL_EPOCH,
771        }
772    }
773}
774
775impl<T: Clone> Shared<T> {
776    fn get(&self) -> T {
777        self.state.read().expect("poisoned").value.clone()
778    }
779
780    fn state(&self) -> RwLockReadGuard<'_, State<T>> {
781        self.state.read().expect("poisoned")
782    }
783
784    fn poll_updated(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<State<T>> {
785        {
786            let state = self.state();
787
788            // We might get spurious wakeups due to e.g. a second-to-last Watchable being dropped.
789            // This makes sure we don't accidentally return an update that's not actually an update.
790            if last_epoch < state.epoch {
791                return Poll::Ready(state.clone());
792            }
793        }
794
795        self.watchers
796            .lock()
797            .expect("poisoned")
798            .push_back(cx.waker().to_owned());
799
800        #[cfg(watcher_loom)]
801        loom::thread::yield_now();
802
803        // We check for an update again to prevent races between putting in wakers and looking for updates.
804        {
805            let state = self.state();
806
807            if last_epoch < state.epoch {
808                return Poll::Ready(state.clone());
809            }
810        }
811
812        Poll::Pending
813    }
814}
815
816#[cfg(test)]
817mod tests {
818
819    use n0_future::{future::poll_once, StreamExt};
820    use rand::{rng, Rng};
821    use tokio::{
822        task::JoinSet,
823        time::{Duration, Instant},
824    };
825    use tokio_util::sync::CancellationToken;
826
827    use super::*;
828
829    #[tokio::test]
830    async fn test_watcher() {
831        let cancel = CancellationToken::new();
832        let watchable = Watchable::new(17);
833
834        assert_eq!(watchable.watch().stream().next().await.unwrap(), 17);
835
836        let start = Instant::now();
837        // spawn watchers
838        let mut tasks = JoinSet::new();
839        for i in 0..3 {
840            let mut watch = watchable.watch().stream();
841            let cancel = cancel.clone();
842            tasks.spawn(async move {
843                println!("[{i}] spawn");
844                let mut expected_value = 17;
845                loop {
846                    tokio::select! {
847                        biased;
848                        Some(value) = &mut watch.next() => {
849                            println!("{:?} [{i}] update: {value}", start.elapsed());
850                            assert_eq!(value, expected_value);
851                            if expected_value == 17 {
852                                expected_value = 0;
853                            } else {
854                                expected_value += 1;
855                            }
856                        },
857                        _ = cancel.cancelled() => {
858                            println!("{:?} [{i}] cancel", start.elapsed());
859                            assert_eq!(expected_value, 10);
860                            break;
861                        }
862                    }
863                }
864            });
865        }
866        for i in 0..3 {
867            let mut watch = watchable.watch().stream_updates_only();
868            let cancel = cancel.clone();
869            tasks.spawn(async move {
870                println!("[{i}] spawn");
871                let mut expected_value = 0;
872                loop {
873                    tokio::select! {
874                        biased;
875                        Some(value) = watch.next() => {
876                            println!("{:?} [{i}] stream update: {value}", start.elapsed());
877                            assert_eq!(value, expected_value);
878                            expected_value += 1;
879                        },
880                        _ = cancel.cancelled() => {
881                            println!("{:?} [{i}] cancel", start.elapsed());
882                            assert_eq!(expected_value, 10);
883                            break;
884                        }
885                        else => {
886                            panic!("stream died");
887                        }
888                    }
889                }
890            });
891        }
892
893        // set value
894        for next_value in 0..10 {
895            let sleep = Duration::from_nanos(rng().random_range(0..100_000_000));
896            println!("{:?} sleep {sleep:?}", start.elapsed());
897            tokio::time::sleep(sleep).await;
898
899            let changed = watchable.set(next_value);
900            println!("{:?} set {next_value} changed={changed:?}", start.elapsed());
901        }
902
903        println!("cancel");
904        cancel.cancel();
905        while let Some(res) = tasks.join_next().await {
906            res.expect("task failed");
907        }
908    }
909
910    #[test]
911    fn test_get() {
912        let watchable = Watchable::new(None);
913        assert!(watchable.get().is_none());
914
915        watchable.set(Some(1u8)).ok();
916        assert_eq!(watchable.get(), Some(1u8));
917    }
918
919    #[tokio::test]
920    async fn test_initialize() {
921        let watchable = Watchable::new(None);
922
923        let mut watcher = watchable.watch();
924        let mut initialized = watcher.initialized();
925
926        let poll = poll_once(&mut initialized).await;
927        assert!(poll.is_none());
928
929        watchable.set(Some(1u8)).ok();
930
931        let poll = poll_once(&mut initialized).await;
932        assert_eq!(poll.unwrap(), 1u8);
933    }
934
935    #[tokio::test]
936    async fn test_initialize_already_init() {
937        let watchable = Watchable::new(Some(1u8));
938
939        let mut watcher = watchable.watch();
940        let mut initialized = watcher.initialized();
941
942        let poll = poll_once(&mut initialized).await;
943        assert_eq!(poll.unwrap(), 1u8);
944    }
945
946    #[test]
947    fn test_initialized_always_resolves() {
948        #[cfg(not(watcher_loom))]
949        use std::thread;
950
951        #[cfg(watcher_loom)]
952        use loom::thread;
953
954        let test_case = || {
955            let watchable = Watchable::<Option<u8>>::new(None);
956
957            let mut watch = watchable.watch();
958            let thread = thread::spawn(move || n0_future::future::block_on(watch.initialized()));
959
960            watchable.set(Some(42)).ok();
961
962            thread::yield_now();
963
964            let value: u8 = thread.join().unwrap();
965
966            assert_eq!(value, 42);
967        };
968
969        #[cfg(watcher_loom)]
970        loom::model(test_case);
971        #[cfg(not(watcher_loom))]
972        test_case();
973    }
974
975    #[tokio::test(flavor = "multi_thread")]
976    async fn test_update_cancel_safety() {
977        let watchable = Watchable::new(0);
978        let mut watch = watchable.watch();
979        const MAX: usize = 100_000;
980
981        let handle = tokio::spawn(async move {
982            let mut last_observed = 0;
983
984            while last_observed != MAX {
985                tokio::select! {
986                    val = watch.updated() => {
987                        let Ok(val) = val else {
988                            return;
989                        };
990
991                        assert_ne!(val, last_observed, "never observe the same value twice, even with cancellation");
992                        last_observed = val;
993                    }
994                    _ = tokio::time::sleep(Duration::from_micros(rng().random_range(0..10_000))) => {
995                        // We cancel the other future and start over again
996                        continue;
997                    }
998                }
999            }
1000        });
1001
1002        for i in 1..=MAX {
1003            watchable.set(i).ok();
1004            if rng().random_bool(0.2) {
1005                tokio::task::yield_now().await;
1006            }
1007        }
1008
1009        tokio::time::timeout(Duration::from_secs(10), handle)
1010            .await
1011            .unwrap()
1012            .unwrap()
1013    }
1014
1015    #[tokio::test]
1016    async fn test_join_simple() {
1017        let a = Watchable::new(1u8);
1018        let b = Watchable::new(1u8);
1019
1020        let mut ab = Join::new([a.watch(), b.watch()].into_iter());
1021
1022        let stream = ab.clone().stream();
1023        let handle = tokio::task::spawn(async move { stream.take(5).collect::<Vec<_>>().await });
1024
1025        // get
1026        assert_eq!(ab.get(), vec![1, 1]);
1027        // set a
1028        a.set(2u8).unwrap();
1029        tokio::task::yield_now().await;
1030        assert_eq!(ab.get(), vec![2, 1]);
1031        // set b
1032        b.set(3u8).unwrap();
1033        tokio::task::yield_now().await;
1034        assert_eq!(ab.get(), vec![2, 3]);
1035
1036        a.set(3u8).unwrap();
1037        tokio::task::yield_now().await;
1038        b.set(4u8).unwrap();
1039        tokio::task::yield_now().await;
1040
1041        let values = tokio::time::timeout(Duration::from_secs(5), handle)
1042            .await
1043            .unwrap()
1044            .unwrap();
1045        assert_eq!(
1046            values,
1047            vec![vec![1, 1], vec![2, 1], vec![2, 3], vec![3, 3], vec![3, 4]]
1048        );
1049    }
1050
1051    #[tokio::test]
1052    async fn test_updated_then_disconnect_then_get() {
1053        let watchable = Watchable::new(10);
1054        let mut watcher = watchable.watch();
1055        assert_eq!(watchable.get(), 10);
1056        watchable.set(42).ok();
1057        assert_eq!(watcher.updated().await.unwrap(), 42);
1058        drop(watchable);
1059        assert_eq!(watcher.get(), 42);
1060    }
1061
1062    #[tokio::test(start_paused = true)]
1063    async fn test_update_wakeup_on_watchable_drop() {
1064        let watchable = Watchable::new(10);
1065        let mut watcher = watchable.watch();
1066
1067        let start = Instant::now();
1068        let (_, result) = tokio::time::timeout(Duration::from_secs(2), async move {
1069            tokio::join!(
1070                async move {
1071                    tokio::time::sleep(Duration::from_secs(1)).await;
1072                    drop(watchable);
1073                },
1074                async move { watcher.updated().await }
1075            )
1076        })
1077        .await
1078        .expect("watcher never updated");
1079        // We should've updated 1s after start, since that's when the watchable was dropped.
1080        // If this is 2s, then the watchable dropping didn't wake up the `Watcher::updated` future.
1081        assert_eq!(start.elapsed(), Duration::from_secs(1));
1082        assert!(result.is_err());
1083    }
1084
1085    #[tokio::test(start_paused = true)]
1086    async fn test_update_wakeup_always_a_change() {
1087        let watchable = Watchable::new(10);
1088        let mut watcher = watchable.watch();
1089
1090        let task = tokio::spawn(async move {
1091            let mut last_value = watcher.get();
1092            let mut values = Vec::new();
1093            while let Ok(value) = watcher.updated().await {
1094                values.push(value);
1095                if last_value == value {
1096                    return Err("value duplicated");
1097                }
1098                last_value = value;
1099            }
1100            Ok(values)
1101        });
1102
1103        // wait for the task to get set up and polled till pending for once
1104        tokio::time::sleep(Duration::from_millis(100)).await;
1105
1106        watchable.set(11).ok();
1107        tokio::time::sleep(Duration::from_millis(100)).await;
1108        let clone = watchable.clone();
1109        drop(clone); // this shouldn't trigger an update
1110        tokio::time::sleep(Duration::from_millis(100)).await;
1111        for i in 1..=10 {
1112            watchable.set(i + 11).ok();
1113            tokio::time::sleep(Duration::from_millis(100)).await;
1114        }
1115        drop(watchable);
1116
1117        let values = task
1118            .await
1119            .expect("task panicked")
1120            .expect("value duplicated");
1121        assert_eq!(values, vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]);
1122    }
1123
1124    #[test]
1125    fn test_has_watchers() {
1126        let a = Watchable::new(1u8);
1127        assert!(!a.has_watchers());
1128        let b = a.clone();
1129        assert!(!a.has_watchers());
1130        assert!(!b.has_watchers());
1131
1132        let watcher = a.watch();
1133        assert!(a.has_watchers());
1134        assert!(b.has_watchers());
1135
1136        drop(watcher);
1137
1138        assert!(!a.has_watchers());
1139        assert!(!b.has_watchers());
1140    }
1141
1142    #[tokio::test]
1143    async fn test_three_watchers_basic() {
1144        let watchable = Watchable::new(1u8);
1145
1146        let mut w1 = watchable.watch();
1147        let mut w2 = watchable.watch();
1148        let mut w3 = watchable.watch();
1149
1150        // All see the initial value
1151
1152        assert_eq!(w1.get(), 1);
1153        assert_eq!(w2.get(), 1);
1154        assert_eq!(w3.get(), 1);
1155
1156        // Change  value
1157        watchable.set(42).unwrap();
1158
1159        // All watchers get notified
1160        assert_eq!(w1.updated().await.unwrap(), 42);
1161        assert_eq!(w2.updated().await.unwrap(), 42);
1162        assert_eq!(w3.updated().await.unwrap(), 42);
1163    }
1164
1165    #[tokio::test]
1166    async fn test_three_watchers_skip_intermediate() {
1167        let watchable = Watchable::new(0u8);
1168        let mut watcher = watchable.watch();
1169
1170        watchable.set(1).ok();
1171        watchable.set(2).ok();
1172        watchable.set(3).ok();
1173        watchable.set(4).ok();
1174
1175        let value = watcher.updated().await.unwrap();
1176
1177        assert_eq!(value, 4);
1178    }
1179
1180    #[tokio::test]
1181    async fn test_three_watchers_with_streams() {
1182        let watchable = Watchable::new(10u8);
1183
1184        let mut stream1 = watchable.watch().stream();
1185        let mut stream2 = watchable.watch().stream();
1186        let mut stream3 = watchable.watch().stream_updates_only();
1187
1188        assert_eq!(stream1.next().await.unwrap(), 10);
1189        assert_eq!(stream2.next().await.unwrap(), 10);
1190
1191        // Update the value
1192        watchable.set(20).ok();
1193
1194        // All streams see the update
1195        assert_eq!(stream1.next().await.unwrap(), 20);
1196        assert_eq!(stream2.next().await.unwrap(), 20);
1197        assert_eq!(stream3.next().await.unwrap(), 20);
1198    }
1199
1200    #[tokio::test]
1201    async fn test_three_watchers_independent() {
1202        let watchable = Watchable::new(0u8);
1203
1204        let mut fast_watcher = watchable.watch();
1205        let mut slow_watcher = watchable.watch();
1206        let mut lazy_watcher = watchable.watch();
1207
1208        watchable.set(1).ok();
1209        assert_eq!(fast_watcher.updated().await.unwrap(), 1);
1210
1211        // More updates happen
1212        watchable.set(2).ok();
1213        watchable.set(3).ok();
1214
1215        assert_eq!(slow_watcher.updated().await.unwrap(), 3);
1216        assert_eq!(lazy_watcher.get(), 3);
1217    }
1218
1219    #[tokio::test]
1220    async fn test_combine_three_watchers() {
1221        let a = Watchable::new(1u8);
1222        let b = Watchable::new(2u8);
1223        let c = Watchable::new(3u8);
1224
1225        let mut combined = Triple::new(a.watch(), b.watch(), c.watch());
1226
1227        assert_eq!(combined.get(), (1, 2, 3));
1228
1229        // Update one
1230        b.set(20).ok();
1231
1232        assert_eq!(combined.updated().await.unwrap(), (1, 20, 3));
1233
1234        c.set(30).ok();
1235        assert_eq!(combined.updated().await.unwrap(), (1, 20, 30));
1236    }
1237
1238    #[tokio::test]
1239    async fn test_three_watchers_disconnection() {
1240        let watchable = Watchable::new(5u8);
1241
1242        // All connected
1243        let mut w1 = watchable.watch();
1244        let mut w2 = watchable.watch();
1245        let mut w3 = watchable.watch();
1246
1247        // Drop the watchable
1248        drop(watchable);
1249
1250        // All become disconnected
1251        assert!(!w1.is_connected());
1252        assert!(!w2.is_connected());
1253        assert!(!w3.is_connected());
1254
1255        // Can still get last known value
1256        assert_eq!(w1.get(), 5);
1257        assert_eq!(w2.get(), 5);
1258
1259        // But updates fail
1260        assert!(w3.updated().await.is_err());
1261    }
1262
1263    #[tokio::test]
1264    async fn test_three_watchers_truly_concurrent() {
1265        use tokio::time::sleep;
1266        let watchable = Watchable::new(0u8);
1267
1268        // Spawn three READER tasks
1269        let mut reader_handles = vec![];
1270        for i in 0..3 {
1271            let mut watcher = watchable.watch();
1272            let handle = tokio::spawn(async move {
1273                let mut values = vec![];
1274                // Collect up to 5 updates
1275                for _ in 0..5 {
1276                    if let Ok(value) = watcher.updated().await {
1277                        values.push(value);
1278                    } else {
1279                        break;
1280                    }
1281                }
1282                (i, values)
1283            });
1284            reader_handles.push(handle);
1285        }
1286
1287        // Spawn three WRITER tasks that update concurrently
1288        let mut writer_handles = vec![];
1289        for i in 0..3 {
1290            let watchable_clone = watchable.clone();
1291            let handle = tokio::spawn(async move {
1292                for j in 0..5 {
1293                    let value = (i * 10) + j;
1294                    watchable_clone.set(value).ok();
1295                    sleep(Duration::from_millis(5)).await;
1296                }
1297            });
1298            writer_handles.push(handle);
1299        }
1300
1301        // Wait for writers to finish
1302        for handle in writer_handles {
1303            handle.await.unwrap();
1304        }
1305
1306        // Wait for readers and check results
1307        for handle in reader_handles {
1308            let (task_id, values) = handle.await.unwrap();
1309            println!("Reader {}: saw values {:?}", task_id, values);
1310            assert!(!values.is_empty());
1311        }
1312    }
1313
1314    #[tokio::test]
1315    async fn test_peek() {
1316        let a = Watchable::new(vec![1, 2, 3]);
1317        let mut wa = a.watch();
1318
1319        assert_eq!(wa.get(), vec![1, 2, 3]);
1320        assert_eq!(wa.peek(), &vec![1, 2, 3]);
1321
1322        let mut wa_map = wa.map(|a| a.into_iter().map(|a| a * 2).collect::<Vec<_>>());
1323
1324        assert_eq!(wa_map.get(), vec![2, 4, 6]);
1325        assert_eq!(wa_map.peek(), &vec![2, 4, 6]);
1326
1327        let mut wb = a.watch();
1328
1329        assert_eq!(wb.get(), vec![1, 2, 3]);
1330        assert_eq!(wb.peek(), &vec![1, 2, 3]);
1331
1332        let mut wb_map = wb.map(|a| a.into_iter().map(|a| a * 2).collect::<Vec<_>>());
1333
1334        assert_eq!(wb_map.get(), vec![2, 4, 6]);
1335        assert_eq!(wb_map.peek(), &vec![2, 4, 6]);
1336
1337        let mut w_join = Join::new([wa_map, wb_map].into_iter());
1338
1339        assert_eq!(w_join.get(), vec![vec![2, 4, 6], vec![2, 4, 6]]);
1340        assert_eq!(w_join.peek(), &vec![vec![2, 4, 6], vec![2, 4, 6]]);
1341    }
1342
1343    #[tokio::test]
1344    async fn test_update_updates_peek() {
1345        let value = Watchable::new(42);
1346        let mut watcher = value.watch();
1347
1348        assert_eq!(watcher.peek(), &42);
1349        assert!(!watcher.update());
1350
1351        value.set(50).ok();
1352
1353        assert_eq!(watcher.peek(), &42); // watcher wasn't updated yet
1354        assert!(watcher.update()); // Update returns true, because there was an update
1355        assert_eq!(watcher.peek(), &50);
1356        assert!(!watcher.update());
1357
1358        let mut watcher_map = watcher.clone().map(|v| v * 2);
1359
1360        assert_eq!(watcher_map.peek(), &100);
1361        assert!(!watcher_map.update());
1362
1363        value.set(10).ok();
1364
1365        assert_eq!(watcher_map.peek(), &100);
1366        assert!(watcher_map.update());
1367        assert_eq!(watcher_map.peek(), &20);
1368        assert!(!watcher_map.update());
1369
1370        let value2 = Watchable::new(0);
1371        let mut watcher_join = Join::new([watcher, value2.watch()].into_iter());
1372
1373        assert_eq!(watcher_join.peek(), &vec![10, 0]);
1374        assert!(!watcher_join.update());
1375
1376        value.set(0).ok();
1377        value2.set(1).ok();
1378
1379        assert_eq!(watcher_join.peek(), &vec![10, 0]);
1380        assert!(watcher_join.update());
1381        assert_eq!(watcher_join.peek(), &vec![0, 1]);
1382        assert!(!watcher_join.update());
1383    }
1384
1385    #[tokio::test]
1386    async fn test_get_updates_peek() {
1387        let value = Watchable::new(42);
1388        let mut watcher = value.watch();
1389
1390        assert_eq!(watcher.peek(), &42);
1391        assert!(!watcher.update());
1392
1393        value.set(50).ok();
1394
1395        assert_eq!(watcher.peek(), &42); // watcher wasn't updated yet
1396        assert_eq!(watcher.get(), 50); // Update returns true, because there was an update
1397        assert_eq!(watcher.peek(), &50);
1398        assert!(!watcher.update());
1399
1400        let mut watcher_map = watcher.clone().map(|v| v * 2);
1401
1402        assert_eq!(watcher_map.peek(), &100);
1403        assert!(!watcher_map.update());
1404
1405        value.set(10).ok();
1406
1407        assert_eq!(watcher_map.peek(), &100);
1408        assert_eq!(watcher_map.get(), 20);
1409        assert_eq!(watcher_map.peek(), &20);
1410        assert!(!watcher_map.update());
1411
1412        let value2 = Watchable::new(0);
1413        let mut watcher_join = Join::new([watcher, value2.watch()].into_iter());
1414
1415        assert_eq!(watcher_join.peek(), &vec![10, 0]);
1416        assert!(!watcher_join.update());
1417
1418        value.set(0).ok();
1419        value2.set(1).ok();
1420
1421        assert_eq!(watcher_join.peek(), &vec![10, 0]);
1422        assert_eq!(watcher_join.get(), vec![0, 1]);
1423        assert_eq!(watcher_join.peek(), &vec![0, 1]);
1424        assert!(!watcher_join.update());
1425    }
1426}