n0_watcher/
lib.rs

1//! Watchable values.
2//!
3//! A [`Watchable`] exists to keep track of a value which may change over time.  It allows
4//! observers to be notified of changes to the value.  The aim is to always be aware of the
5//! **last** value, not to observe *every* value change.
6//!
7//! The reason for this is ergonomics and predictable resource usage: Requiring every
8//! intermediate value to be observable would mean that either the side that sets new values
9//! using [`Watchable::set`] would need to wait for all "receivers" of these intermediate
10//! values to catch up and thus be an async operation, or it would require the receivers
11//! to buffer intermediate values until they've been "received" on the [`Watcher`]s with
12//! an unlimited buffer size and thus potentially unlimited memory growth.
13//!
14//! # Example
15//!
16//! ```
17//! use n0_future::StreamExt;
18//! use n0_watcher::{Watchable, Watcher as _};
19//!
20//! #[tokio::main(flavor = "current_thread", start_paused = true)]
21//! async fn main() {
22//!     let watchable = Watchable::new(None);
23//!
24//!     // A task that waits for the watcher to be initialized to Some(value) before printing it
25//!     let mut watcher = watchable.watch();
26//!     tokio::spawn(async move {
27//!         let initialized_value = watcher.initialized().await;
28//!         println!("initialized: {initialized_value}");
29//!     });
30//!
31//!     // A task that prints every update to the watcher since the initial one:
32//!     let mut updates = watchable.watch().stream_updates_only();
33//!     tokio::spawn(async move {
34//!         while let Some(update) = updates.next().await {
35//!             println!("update: {update:?}");
36//!         }
37//!     });
38//!
39//!     // A task that prints the current value and then every update it can catch,
40//!     // but it also does something else which makes it very slow to pick up new
41//!     // values, so it'll skip some:
42//!     let mut current_and_updates = watchable.watch().stream();
43//!     tokio::spawn(async move {
44//!         while let Some(update) = current_and_updates.next().await {
45//!             println!("update2: {update:?}");
46//!             tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
47//!         }
48//!     });
49//!
50//!     for i in 0..20 {
51//!         println!("Setting watchable to {i}");
52//!         watchable.set(Some(i)).ok();
53//!         tokio::time::sleep(tokio::time::Duration::from_millis(250)).await;
54//!     }
55//! }
56//! ```
57//!
58//! # Similar but different
59//!
60//! - `async_channel`: This is a multi-producer, multi-consumer channel implementation.
61//!   Only at most one consumer will receive each "produced" value.
62//!   What we want is to have every "produced" value to be "broadcast" to every receiver.
63//! - `tokio::broadcast`: Also a multi-producer, multi-consumer channel implementation.
64//!   This is very similar to this crate (`tokio::broadcast::Sender` is like [`Watchable`]
65//!   and `tokio::broadcast::Receiver` is like [`Watcher`]), but you can't get the latest
66//!   value without `.await`ing on the receiver, and it'll internally store a queue of
67//!   intermediate values.
//! - `tokio::watch`: Also a multi-producer, multi-consumer channel, and unlike
//!   `tokio::broadcast` it only retains the latest value. That module has pretty much the
//!   same purpose as this crate, but doesn't implement a poll-based method of getting
//!   updates and doesn't implement combinators.
71//! - [`std::sync::RwLock`]: (wrapped in an [`std::sync::Arc`]) This allows you access
72//!   to the latest values, but might block while it's being set (but that could be short
73//!   enough not to matter for async rust purposes).
74//!   This doesn't allow you to be notified whenever a new value is written.
75//! - The `watchable` crate: We used to use this crate at n0, but we wanted to experiment
76//!   with different APIs and needed Wasm support.
77#[cfg(not(watcher_loom))]
78use std::sync;
79use std::{
80    collections::VecDeque,
81    future::Future,
82    pin::Pin,
83    sync::{Arc, Weak},
84    task::{self, ready, Poll, Waker},
85};
86
87#[cfg(watcher_loom)]
88use loom::sync;
89use n0_error::StackError;
90use sync::{Mutex, RwLock};
91
92/// A wrapper around a value that notifies [`Watcher`]s when the value is modified.
93///
94/// Only the most recent value is available to any observer, but the observer is guaranteed
95/// to be notified of the most recent value.
96#[derive(Debug, Default)]
97pub struct Watchable<T> {
98    shared: Arc<Shared<T>>,
99}
100
101impl<T> Clone for Watchable<T> {
102    fn clone(&self) -> Self {
103        Self {
104            shared: self.shared.clone(),
105        }
106    }
107}
108
/// A value that may or may not contain a `T`; implemented for `Option<T>` and `Vec<T>`.
pub trait Nullable<T> {
    /// Consumes the value and returns the contained `T`, if there is one.
    fn into_option(self) -> Option<T>;
}

impl<T> Nullable<T> for Option<T> {
    /// An `Option` already is the requested representation, so this is the identity.
    fn into_option(self) -> Option<T> {
        let option = self;
        option
    }
}

impl<T> Nullable<T> for Vec<T> {
    /// Yields the *last* element of the vector, or `None` if the vector is empty.
    fn into_option(self) -> Option<T> {
        self.into_iter().next_back()
    }
}
126
127impl<T: Clone + Eq> Watchable<T> {
128    /// Creates a [`Watchable`] initialized to given value.
129    pub fn new(value: T) -> Self {
130        Self {
131            shared: Arc::new(Shared {
132                state: RwLock::new(State {
133                    value,
134                    epoch: INITIAL_EPOCH,
135                }),
136                watchers: Default::default(),
137            }),
138        }
139    }
140
141    /// Sets a new value.
142    ///
143    /// Returns `Ok(previous_value)` if the value was different from the one set, or
144    /// returns the provided value back as `Err(value)` if the value didn't change.
145    ///
146    /// Watchers are only notified if the value changed.
147    pub fn set(&self, value: T) -> Result<T, T> {
148        // We don't actually write when the value didn't change, but there's unfortunately
149        // no way to upgrade a read guard to a write guard, and locking as read first, then
150        // dropping and locking as write introduces a possible race condition.
151        let mut state = self.shared.state.write().expect("poisoned");
152
153        // Find out if the value changed
154        let changed = state.value != value;
155
156        let ret = if changed {
157            let old = std::mem::replace(&mut state.value, value);
158            state.epoch += 1;
159            Ok(old)
160        } else {
161            Err(value)
162        };
163        drop(state); // No need to write anymore
164
165        // Notify watchers
166        if changed {
167            for watcher in self.shared.watchers.lock().expect("poisoned").drain(..) {
168                watcher.wake();
169            }
170        }
171        ret
172    }
173
174    /// Creates a [`Direct`] [`Watcher`], allowing the value to be observed, but not modified.
175    pub fn watch(&self) -> Direct<T> {
176        Direct {
177            state: self.shared.state(),
178            shared: Arc::downgrade(&self.shared),
179        }
180    }
181
182    /// Returns the currently stored value.
183    pub fn get(&self) -> T {
184        self.shared.get()
185    }
186
187    /// Returns true when there are any watchers actively listening on changes,
188    /// or false when all watchers have been dropped or none have been created yet.
189    pub fn has_watchers(&self) -> bool {
190        // `Watchable`s will increase the strong count
191        // `Direct`s watchers (which all watchers descend from) will increase the weak count
192        Arc::weak_count(&self.shared) != 0
193    }
194}
195
196impl<T> Drop for Watchable<T> {
197    fn drop(&mut self) {
198        let Ok(mut watchers) = self.shared.watchers.lock() else {
199            return; // Poisoned waking?
200        };
201        // Wake all watchers every time we drop.
202        // This allows us to notify `NextFut::poll`s that the underlying
203        // watchable might be dropped.
204        for watcher in watchers.drain(..) {
205            watcher.wake();
206        }
207    }
208}
209
210/// A handle to a value that's represented by one or more underlying [`Watchable`]s.
211///
212/// A [`Watcher`] can get the current value, and will be notified when the value changes.
213/// Only the most recent value is accessible, and if the threads with the underlying [`Watchable`]s
214/// change the value faster than the threads with the [`Watcher`] can keep up with, then
215/// it'll miss in-between values.
216/// When the thread changing the [`Watchable`] pauses updating, the [`Watcher`] will always
217/// end up reporting the most recent state eventually.
218///
219/// Watchers can be modified via [`Watcher::map`] to observe a value derived from the original
220/// value via a function.
221///
222/// Watchers can be combined via [`Watcher::or`] to allow observing multiple values at once and
223/// getting an update in case any of the values updates.
224///
225/// One of the underlying [`Watchable`]s might already be dropped. In that case,
226/// the watcher will be "disconnected" and return [`Err(Disconnected)`](Disconnected)
227/// on some function calls or, when turned into a stream, that stream will end.
228/// This property can also be checked with [`Watcher::is_connected`].
229pub trait Watcher: Clone {
230    /// The type of value that can change.
231    ///
232    /// We require `Clone`, because we need to be able to make
233    /// the values have a lifetime that's detached from the original [`Watchable`]'s
234    /// lifetime.
235    ///
236    /// We require `Eq`, to be able to check whether the value actually changed or
237    /// not, so we can notify or not notify accordingly.
238    type Value: Clone + Eq;
239
240    /// Returns the current state of the underlying value.
241    ///
242    /// If any of the underlying [`Watchable`] values has been dropped, then this
243    /// might return an outdated value for that watchable, specifically, the latest
244    /// value that was fetched for that watchable, as opposed to the latest value
245    /// that was set on the watchable before it was dropped.
246    fn get(&mut self) -> Self::Value;
247
248    /// Whether this watcher is still connected to all of its underlying [`Watchable`]s.
249    ///
250    /// Returns false when any of the underlying watchables has been dropped.
251    fn is_connected(&self) -> bool;
252
253    /// Polls for the next value, or returns [`Disconnected`] if one of the underlying
254    /// [`Watchable`]s has been dropped.
255    fn poll_updated(
256        &mut self,
257        cx: &mut task::Context<'_>,
258    ) -> Poll<Result<Self::Value, Disconnected>>;
259
260    /// Returns a future completing with `Ok(value)` once a new value is set, or with
261    /// [`Err(Disconnected)`](Disconnected) if the connected [`Watchable`] was dropped.
262    ///
263    /// # Cancel Safety
264    ///
265    /// The returned future is cancel-safe.
266    fn updated(&mut self) -> NextFut<'_, Self> {
267        NextFut { watcher: self }
268    }
269
270    /// Returns a future completing once the value is set to [`Some`] value.
271    ///
272    /// If the current value is [`Some`] value, this future will resolve immediately.
273    ///
274    /// This is a utility for the common case of storing an [`Option`] inside a
275    /// [`Watchable`].
276    ///
277    /// # Cancel Safety
278    ///
279    /// The returned future is cancel-safe.
280    fn initialized<T, W>(&mut self) -> InitializedFut<'_, T, W, Self>
281    where
282        W: Nullable<T>,
283        Self: Watcher<Value = W>,
284    {
285        InitializedFut {
286            initial: self.get().into_option(),
287            watcher: self,
288        }
289    }
290
291    /// Returns a stream which will yield the most recent values as items.
292    ///
293    /// The first item of the stream is the current value, so that this stream can be easily
294    /// used to operate on the most recent value.
295    ///
296    /// Note however, that only the last item is stored.  If the stream is not polled when an
297    /// item is available it can be replaced with another item by the time it is polled.
298    ///
299    /// This stream ends once the original [`Watchable`] has been dropped.
300    ///
301    /// # Cancel Safety
302    ///
303    /// The returned stream is cancel-safe.
304    fn stream(mut self) -> Stream<Self>
305    where
306        Self: Unpin,
307    {
308        Stream {
309            initial: Some(self.get()),
310            watcher: self,
311        }
312    }
313
314    /// Returns a stream which will yield the most recent values as items, starting from
315    /// the next unobserved future value.
316    ///
317    /// This means this stream will only yield values when the watched value changes,
318    /// the value stored at the time the stream is created is not yielded.
319    ///
320    /// Note however, that only the last item is stored.  If the stream is not polled when an
321    /// item is available it can be replaced with another item by the time it is polled.
322    ///
323    /// This stream ends once the original [`Watchable`] has been dropped.
324    ///
325    /// # Cancel Safety
326    ///
327    /// The returned stream is cancel-safe.
328    fn stream_updates_only(self) -> Stream<Self>
329    where
330        Self: Unpin,
331    {
332        Stream {
333            initial: None,
334            watcher: self,
335        }
336    }
337
338    /// Maps this watcher with a function that transforms the observed values.
339    ///
340    /// The returned watcher will only register updates, when the *mapped* value
341    /// observably changes.
342    fn map<T: Clone + Eq>(
343        mut self,
344        map: impl Fn(Self::Value) -> T + Send + Sync + 'static,
345    ) -> Map<Self, T> {
346        Map {
347            current: (map)(self.get()),
348            map: Arc::new(map),
349            watcher: self,
350        }
351    }
352
353    /// Returns a watcher that updates every time this or the other watcher
354    /// updates, and yields both watcher's items together when that happens.
355    fn or<W: Watcher>(self, other: W) -> (Self, W) {
356        (self, other)
357    }
358}
359
360/// The immediate, direct observer of a [`Watchable`] value.
361///
362/// This type is mainly used via the [`Watcher`] interface.
363#[derive(Debug, Clone)]
364pub struct Direct<T> {
365    state: State<T>,
366    shared: Weak<Shared<T>>,
367}
368
369impl<T: Clone + Eq> Watcher for Direct<T> {
370    type Value = T;
371
372    fn get(&mut self) -> Self::Value {
373        if let Some(shared) = self.shared.upgrade() {
374            self.state = shared.state();
375        }
376        self.state.value.clone()
377    }
378
379    fn is_connected(&self) -> bool {
380        self.shared.upgrade().is_some()
381    }
382
383    fn poll_updated(
384        &mut self,
385        cx: &mut task::Context<'_>,
386    ) -> Poll<Result<Self::Value, Disconnected>> {
387        let Some(shared) = self.shared.upgrade() else {
388            return Poll::Ready(Err(Disconnected));
389        };
390        self.state = ready!(shared.poll_updated(cx, self.state.epoch));
391        Poll::Ready(Ok(self.state.value.clone()))
392    }
393}
394
395impl<S: Watcher, T: Watcher> Watcher for (S, T) {
396    type Value = (S::Value, T::Value);
397
398    fn get(&mut self) -> Self::Value {
399        (self.0.get(), self.1.get())
400    }
401
402    fn is_connected(&self) -> bool {
403        self.0.is_connected() && self.1.is_connected()
404    }
405
406    fn poll_updated(
407        &mut self,
408        cx: &mut task::Context<'_>,
409    ) -> Poll<Result<Self::Value, Disconnected>> {
410        let poll_0 = self.0.poll_updated(cx)?;
411        let poll_1 = self.1.poll_updated(cx)?;
412        match (poll_0, poll_1) {
413            (Poll::Ready(s), Poll::Ready(t)) => Poll::Ready(Ok((s, t))),
414            (Poll::Ready(s), Poll::Pending) => Poll::Ready(Ok((s, self.1.get()))),
415            (Poll::Pending, Poll::Ready(t)) => Poll::Ready(Ok((self.0.get(), t))),
416            (Poll::Pending, Poll::Pending) => Poll::Pending,
417        }
418    }
419}
420
421impl<S: Watcher, T: Watcher, U: Watcher> Watcher for (S, T, U) {
422    type Value = (S::Value, T::Value, U::Value);
423
424    fn get(&mut self) -> Self::Value {
425        (self.0.get(), self.1.get(), self.2.get())
426    }
427
428    fn is_connected(&self) -> bool {
429        self.0.is_connected() && self.1.is_connected() && self.2.is_connected()
430    }
431
432    fn poll_updated(
433        &mut self,
434        cx: &mut task::Context<'_>,
435    ) -> Poll<Result<Self::Value, Disconnected>> {
436        let poll_0 = self.0.poll_updated(cx)?;
437        let poll_1 = self.1.poll_updated(cx)?;
438        let poll_2 = self.2.poll_updated(cx)?;
439
440        if poll_0.is_pending() && poll_1.is_pending() && poll_2.is_pending() {
441            Poll::Pending
442        } else {
443            fn to_option<T>(poll: Poll<T>) -> Option<T> {
444                match poll {
445                    Poll::Ready(t) => Some(t),
446                    Poll::Pending => None,
447                }
448            }
449
450            let s = to_option(poll_0).unwrap_or_else(|| self.0.get());
451            let t = to_option(poll_1).unwrap_or_else(|| self.1.get());
452            let u = to_option(poll_2).unwrap_or_else(|| self.2.get());
453            Poll::Ready(Ok((s, t, u)))
454        }
455    }
456}
457
458/// Combinator to join two watchers
459#[derive(Debug, Clone)]
460pub struct Join<T: Clone + Eq, W: Watcher<Value = T>> {
461    watchers: Vec<W>,
462}
463impl<T: Clone + Eq, W: Watcher<Value = T>> Join<T, W> {
464    /// Joins a set of watchers into a single watcher
465    pub fn new(watchers: impl Iterator<Item = W>) -> Self {
466        let watchers: Vec<W> = watchers.into_iter().collect();
467
468        Self { watchers }
469    }
470}
471
472impl<T: Clone + Eq, W: Watcher<Value = T>> Watcher for Join<T, W> {
473    type Value = Vec<T>;
474
475    fn get(&mut self) -> Self::Value {
476        let mut out = Vec::with_capacity(self.watchers.len());
477        for watcher in &mut self.watchers {
478            out.push(watcher.get());
479        }
480
481        out
482    }
483
484    fn is_connected(&self) -> bool {
485        self.watchers.iter().all(|w| w.is_connected())
486    }
487
488    fn poll_updated(
489        &mut self,
490        cx: &mut task::Context<'_>,
491    ) -> Poll<Result<Self::Value, Disconnected>> {
492        let mut new_value = None;
493        for (i, watcher) in self.watchers.iter_mut().enumerate() {
494            match watcher.poll_updated(cx) {
495                Poll::Pending => {}
496                Poll::Ready(Ok(value)) => {
497                    new_value.replace((i, value));
498                    break;
499                }
500                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
501            }
502        }
503
504        if let Some((j, new_value)) = new_value {
505            let mut new = Vec::with_capacity(self.watchers.len());
506            for (i, watcher) in self.watchers.iter_mut().enumerate() {
507                if i != j {
508                    new.push(watcher.get());
509                } else {
510                    new.push(new_value.clone());
511                }
512            }
513            Poll::Ready(Ok(new))
514        } else {
515            Poll::Pending
516        }
517    }
518}
519
/// Wraps a [`Watcher`] to allow observing a derived value.
///
/// See [`Watcher::map`].
#[derive(derive_more::Debug, Clone)]
pub struct Map<W: Watcher, T: Clone + Eq> {
    /// The mapping function; reference-counted so that clones of this watcher share it.
    #[debug("Arc<dyn Fn(W::Value) -> T + 'static>")]
    map: Arc<dyn Fn(W::Value) -> T + Send + Sync + 'static>,
    /// The wrapped watcher supplying the unmapped values.
    watcher: W,
    /// The mapped value that updates are compared against to detect a change.
    current: T,
}
530
531impl<W: Watcher, T: Clone + Eq> Watcher for Map<W, T> {
532    type Value = T;
533
534    fn get(&mut self) -> Self::Value {
535        (self.map)(self.watcher.get())
536    }
537
538    fn is_connected(&self) -> bool {
539        self.watcher.is_connected()
540    }
541
542    fn poll_updated(
543        &mut self,
544        cx: &mut task::Context<'_>,
545    ) -> Poll<Result<Self::Value, Disconnected>> {
546        loop {
547            let value = ready!(self.watcher.poll_updated(cx)?);
548            let mapped = (self.map)(value);
549            if mapped != self.current {
550                self.current = mapped.clone();
551                return Poll::Ready(Ok(mapped));
552            } else {
553                self.current = mapped;
554            }
555        }
556    }
557}
558
559/// Future returning the next item after the current one in a [`Watcher`].
560///
561/// See [`Watcher::updated`].
562///
563/// # Cancel Safety
564///
565/// This future is cancel-safe.
566#[derive(Debug)]
567pub struct NextFut<'a, W: Watcher> {
568    watcher: &'a mut W,
569}
570
571impl<W: Watcher> Future for NextFut<'_, W> {
572    type Output = Result<W::Value, Disconnected>;
573
574    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
575        self.watcher.poll_updated(cx)
576    }
577}
578
579/// Future returning the current or next value that's [`Some`] value.
580/// in a [`Watcher`].
581///
582/// See [`Watcher::initialized`].
583///
584/// # Cancel Safety
585///
586/// This Future is cancel-safe.
587#[derive(Debug)]
588pub struct InitializedFut<'a, T, V: Nullable<T>, W: Watcher<Value = V>> {
589    initial: Option<T>,
590    watcher: &'a mut W,
591}
592
593impl<T: Clone + Eq + Unpin, V: Nullable<T>, W: Watcher<Value = V> + Unpin> Future
594    for InitializedFut<'_, T, V, W>
595{
596    type Output = T;
597
598    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
599        if let Some(value) = self.as_mut().initial.take() {
600            return Poll::Ready(value);
601        }
602        loop {
603            let Ok(value) = ready!(self.as_mut().watcher.poll_updated(cx)) else {
604                // The value will never be initialized
605                return Poll::Pending;
606            };
607            if let Some(value) = value.into_option() {
608                return Poll::Ready(value);
609            }
610        }
611    }
612}
613
614/// A stream for a [`Watcher`]'s next values.
615///
616/// See [`Watcher::stream`] and [`Watcher::stream_updates_only`].
617///
618/// # Cancel Safety
619///
620/// This stream is cancel-safe.
621#[derive(Debug, Clone)]
622pub struct Stream<W: Watcher + Unpin> {
623    initial: Option<W::Value>,
624    watcher: W,
625}
626
627impl<W: Watcher + Unpin> n0_future::Stream for Stream<W>
628where
629    W::Value: Unpin,
630{
631    type Item = W::Value;
632
633    fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
634        if let Some(value) = self.as_mut().initial.take() {
635            return Poll::Ready(Some(value));
636        }
637        match self.as_mut().watcher.poll_updated(cx) {
638            Poll::Ready(Ok(value)) => Poll::Ready(Some(value)),
639            Poll::Ready(Err(Disconnected)) => Poll::Ready(None),
640            Poll::Pending => Poll::Pending,
641        }
642    }
643}
644
/// The error for when a [`Watcher`] is disconnected from its underlying
/// [`Watchable`] value, because of that watchable having been dropped.
///
/// This is the error type of [`Watcher::poll_updated`] and [`Watcher::updated`].
#[derive(StackError)]
#[error("Watcher lost connection to underlying Watchable, it was dropped")]
pub struct Disconnected;
650
651// Private:
652
/// The epoch given to a newly created state. The epoch is incremented each time
/// `Watchable::set` actually changes the stored value.
const INITIAL_EPOCH: u64 = 1;
654
/// The shared state for a [`Watchable`].
#[derive(Debug, Default)]
struct Shared<T> {
    /// The value to be watched and its current epoch.
    state: RwLock<State<T>>,
    /// Wakers of tasks waiting for an update. They are drained and woken when the value
    /// changes and whenever a [`Watchable`] handle is dropped.
    watchers: Mutex<VecDeque<Waker>>,
}
662
663#[derive(Debug, Clone)]
664struct State<T> {
665    value: T,
666    epoch: u64,
667}
668
669impl<T: Default> Default for State<T> {
670    fn default() -> Self {
671        Self {
672            value: Default::default(),
673            epoch: INITIAL_EPOCH,
674        }
675    }
676}
677
impl<T: Clone> Shared<T> {
    /// Returns a clone of the currently stored value.
    fn get(&self) -> T {
        self.state.read().expect("poisoned").value.clone()
    }

    /// Returns a snapshot of the current state: the value together with its epoch.
    fn state(&self) -> State<T> {
        self.state.read().expect("poisoned").clone()
    }

    /// Polls for a state that is newer than `last_epoch`.
    ///
    /// Returns `Poll::Ready` with a snapshot of the state if its epoch is greater than
    /// `last_epoch`. Otherwise the task's waker is registered and `Poll::Pending` is
    /// returned, unless an update arrived in the meantime (see the second check below).
    ///
    /// NOTE(review): a waker is pushed on every poll that doesn't find an update and the
    /// list is only drained when the value changes or a `Watchable` is dropped, so
    /// repeated polls without a change in between appear to accumulate duplicate wakers.
    fn poll_updated(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<State<T>> {
        {
            let state = self.state.read().expect("poisoned");

            // We might get spurious wakeups due to e.g. a second-to-last Watchable being dropped.
            // This makes sure we don't accidentally return an update that's not actually an update.
            if last_epoch < state.epoch {
                return Poll::Ready(state.clone());
            }
        }

        // No update yet: register interest. The read guard above has been released by
        // now, so the state lock and the waker lock are never held at the same time here.
        self.watchers
            .lock()
            .expect("poisoned")
            .push_back(cx.waker().to_owned());

        // Under loom, yield here to give other threads a chance to run between
        // registering the waker and the second check.
        #[cfg(watcher_loom)]
        loom::thread::yield_now();

        // We check for an update again to prevent races between putting in wakers and looking for updates.
        {
            let state = self.state.read().expect("poisoned");

            if last_epoch < state.epoch {
                return Poll::Ready(state.clone());
            }
        }

        Poll::Pending
    }
}
719
720#[cfg(test)]
721mod tests {
722
723    use n0_future::{future::poll_once, StreamExt};
724    use rand::{rng, Rng};
725    use tokio::{
726        task::JoinSet,
727        time::{Duration, Instant},
728    };
729    use tokio_util::sync::CancellationToken;
730
731    use super::*;
732
    /// Spawns three `stream()` watchers and three `stream_updates_only()` watchers,
    /// then sets the values 0 through 9 with random pauses in between.
    ///
    /// Each watcher task asserts that it sees the values in order, and that it has seen
    /// all of them by the time the cancellation token fires.
    #[tokio::test]
    async fn test_watcher() {
        let cancel = CancellationToken::new();
        let watchable = Watchable::new(17);

        // A fresh `stream()` yields the current value first.
        assert_eq!(watchable.watch().stream().next().await.unwrap(), 17);

        let start = Instant::now();
        // spawn watchers
        let mut tasks = JoinSet::new();
        // Watchers using `stream()`: they first see the initial 17, then 0, 1, 2, ...
        for i in 0..3 {
            let mut watch = watchable.watch().stream();
            let cancel = cancel.clone();
            tasks.spawn(async move {
                println!("[{i}] spawn");
                let mut expected_value = 17;
                loop {
                    tokio::select! {
                        biased;
                        Some(value) = &mut watch.next() => {
                            println!("{:?} [{i}] update: {value}", start.elapsed());
                            assert_eq!(value, expected_value);
                            if expected_value == 17 {
                                expected_value = 0;
                            } else {
                                expected_value += 1;
                            }
                        },
                        _ = cancel.cancelled() => {
                            println!("{:?} [{i}] cancel", start.elapsed());
                            assert_eq!(expected_value, 10);
                            break;
                        }
                    }
                }
            });
        }
        // Watchers using `stream_updates_only()`: they skip the initial 17 and start at 0.
        for i in 0..3 {
            let mut watch = watchable.watch().stream_updates_only();
            let cancel = cancel.clone();
            tasks.spawn(async move {
                println!("[{i}] spawn");
                let mut expected_value = 0;
                loop {
                    tokio::select! {
                        biased;
                        Some(value) = watch.next() => {
                            println!("{:?} [{i}] stream update: {value}", start.elapsed());
                            assert_eq!(value, expected_value);
                            expected_value += 1;
                        },
                        _ = cancel.cancelled() => {
                            println!("{:?} [{i}] cancel", start.elapsed());
                            assert_eq!(expected_value, 10);
                            break;
                        }
                        else => {
                            panic!("stream died");
                        }
                    }
                }
            });
        }

        // set value
        for next_value in 0..10 {
            // Sleep for a random duration below 100ms before each set.
            let sleep = Duration::from_nanos(rng().random_range(0..100_000_000));
            println!("{:?} sleep {sleep:?}", start.elapsed());
            tokio::time::sleep(sleep).await;

            let changed = watchable.set(next_value);
            println!("{:?} set {next_value} changed={changed:?}", start.elapsed());
        }

        println!("cancel");
        cancel.cancel();
        // Surface any assertion failure (panic) from the watcher tasks.
        while let Some(res) = tasks.join_next().await {
            res.expect("task failed");
        }
    }
813
814    #[test]
815    fn test_get() {
816        let watchable = Watchable::new(None);
817        assert!(watchable.get().is_none());
818
819        watchable.set(Some(1u8)).ok();
820        assert_eq!(watchable.get(), Some(1u8));
821    }
822
823    #[tokio::test]
824    async fn test_initialize() {
825        let watchable = Watchable::new(None);
826
827        let mut watcher = watchable.watch();
828        let mut initialized = watcher.initialized();
829
830        let poll = poll_once(&mut initialized).await;
831        assert!(poll.is_none());
832
833        watchable.set(Some(1u8)).ok();
834
835        let poll = poll_once(&mut initialized).await;
836        assert_eq!(poll.unwrap(), 1u8);
837    }
838
839    #[tokio::test]
840    async fn test_initialize_already_init() {
841        let watchable = Watchable::new(Some(1u8));
842
843        let mut watcher = watchable.watch();
844        let mut initialized = watcher.initialized();
845
846        let poll = poll_once(&mut initialized).await;
847        assert_eq!(poll.unwrap(), 1u8);
848    }
849
    /// Checks that `initialized()` resolves when the value is set from another thread
    /// while a thread is blocked on the future.
    ///
    /// With `cfg(watcher_loom)` the test case is run under `loom::model`; otherwise it
    /// runs once on plain `std` threads.
    #[test]
    fn test_initialized_always_resolves() {
        #[cfg(not(watcher_loom))]
        use std::thread;

        #[cfg(watcher_loom)]
        use loom::thread;

        let test_case = || {
            let watchable = Watchable::<Option<u8>>::new(None);

            let mut watch = watchable.watch();
            // The spawned thread blocks until the watcher sees a `Some` value.
            let thread = thread::spawn(move || n0_future::future::block_on(watch.initialized()));

            watchable.set(Some(42)).ok();

            thread::yield_now();

            let value: u8 = thread.join().unwrap();

            assert_eq!(value, 42);
        };

        #[cfg(watcher_loom)]
        loom::model(test_case);
        #[cfg(not(watcher_loom))]
        test_case();
    }
878
    /// Checks that `updated()` is cancel-safe: the future is repeatedly raced against a
    /// random short sleep (and dropped when the sleep wins) while another task sets the
    /// values `1..=MAX`. The watcher must never report the same value twice in a row,
    /// and must observe `MAX` within the timeout.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_update_cancel_safety() {
        let watchable = Watchable::new(0);
        let mut watch = watchable.watch();
        const MAX: usize = 100_000;

        let handle = tokio::spawn(async move {
            let mut last_observed = 0;

            while last_observed != MAX {
                tokio::select! {
                    val = watch.updated() => {
                        // A disconnect ends the task early.
                        let Ok(val) = val else {
                            return;
                        };

                        assert_ne!(val, last_observed, "never observe the same value twice, even with cancellation");
                        last_observed = val;
                    }
                    _ = tokio::time::sleep(Duration::from_micros(rng().random_range(0..10_000))) => {
                        // We cancel the other future and start over again
                        continue;
                    }
                }
            }
        });

        for i in 1..=MAX {
            watchable.set(i).ok();
            // Occasionally yield so the watcher task gets to run between sets.
            if rng().random_bool(0.2) {
                tokio::task::yield_now().await;
            }
        }

        // `watchable` is still alive here, so the watcher task has to finish by
        // observing `MAX` rather than by disconnecting.
        tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .unwrap()
            .unwrap()
    }
918
919    #[tokio::test]
920    async fn test_join_simple() {
921        let a = Watchable::new(1u8);
922        let b = Watchable::new(1u8);
923
924        let mut ab = Join::new([a.watch(), b.watch()].into_iter());
925
926        let stream = ab.clone().stream();
927        let handle = tokio::task::spawn(async move { stream.take(5).collect::<Vec<_>>().await });
928
929        // get
930        assert_eq!(ab.get(), vec![1, 1]);
931        // set a
932        a.set(2u8).unwrap();
933        tokio::task::yield_now().await;
934        assert_eq!(ab.get(), vec![2, 1]);
935        // set b
936        b.set(3u8).unwrap();
937        tokio::task::yield_now().await;
938        assert_eq!(ab.get(), vec![2, 3]);
939
940        a.set(3u8).unwrap();
941        tokio::task::yield_now().await;
942        b.set(4u8).unwrap();
943        tokio::task::yield_now().await;
944
945        let values = tokio::time::timeout(Duration::from_secs(5), handle)
946            .await
947            .unwrap()
948            .unwrap();
949        assert_eq!(
950            values,
951            vec![vec![1, 1], vec![2, 1], vec![2, 3], vec![3, 3], vec![3, 4]]
952        );
953    }
954
955    #[tokio::test]
956    async fn test_updated_then_disconnect_then_get() {
957        let watchable = Watchable::new(10);
958        let mut watcher = watchable.watch();
959        assert_eq!(watchable.get(), 10);
960        watchable.set(42).ok();
961        assert_eq!(watcher.updated().await.unwrap(), 42);
962        drop(watchable);
963        assert_eq!(watcher.get(), 42);
964    }
965
966    #[tokio::test(start_paused = true)]
967    async fn test_update_wakeup_on_watchable_drop() {
968        let watchable = Watchable::new(10);
969        let mut watcher = watchable.watch();
970
971        let start = Instant::now();
972        let (_, result) = tokio::time::timeout(Duration::from_secs(2), async move {
973            tokio::join!(
974                async move {
975                    tokio::time::sleep(Duration::from_secs(1)).await;
976                    drop(watchable);
977                },
978                async move { watcher.updated().await }
979            )
980        })
981        .await
982        .expect("watcher never updated");
983        // We should've updated 1s after start, since that's when the watchable was dropped.
984        // If this is 2s, then the watchable dropping didn't wake up the `Watcher::updated` future.
985        assert_eq!(start.elapsed(), Duration::from_secs(1));
986        assert!(result.is_err());
987    }
988
989    #[tokio::test(start_paused = true)]
990    async fn test_update_wakeup_always_a_change() {
991        let watchable = Watchable::new(10);
992        let mut watcher = watchable.watch();
993
994        let task = tokio::spawn(async move {
995            let mut last_value = watcher.get();
996            let mut values = Vec::new();
997            while let Ok(value) = watcher.updated().await {
998                values.push(value);
999                if last_value == value {
1000                    return Err("value duplicated");
1001                }
1002                last_value = value;
1003            }
1004            Ok(values)
1005        });
1006
1007        // wait for the task to get set up and polled till pending for once
1008        tokio::time::sleep(Duration::from_millis(100)).await;
1009
1010        watchable.set(11).ok();
1011        tokio::time::sleep(Duration::from_millis(100)).await;
1012        let clone = watchable.clone();
1013        drop(clone); // this shouldn't trigger an update
1014        tokio::time::sleep(Duration::from_millis(100)).await;
1015        for i in 1..=10 {
1016            watchable.set(i + 11).ok();
1017            tokio::time::sleep(Duration::from_millis(100)).await;
1018        }
1019        drop(watchable);
1020
1021        let values = task
1022            .await
1023            .expect("task panicked")
1024            .expect("value duplicated");
1025        assert_eq!(values, vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]);
1026    }
1027
1028    #[test]
1029    fn test_has_watchers() {
1030        let a = Watchable::new(1u8);
1031        assert!(!a.has_watchers());
1032        let b = a.clone();
1033        assert!(!a.has_watchers());
1034        assert!(!b.has_watchers());
1035
1036        let watcher = a.watch();
1037        assert!(a.has_watchers());
1038        assert!(b.has_watchers());
1039
1040        drop(watcher);
1041
1042        assert!(!a.has_watchers());
1043        assert!(!b.has_watchers());
1044    }
1045
1046    #[tokio::test]
1047    async fn test_three_watchers_basic() {
1048        let watchable = Watchable::new(1u8);
1049
1050        let mut w1 = watchable.watch();
1051        let mut w2 = watchable.watch();
1052        let mut w3 = watchable.watch();
1053
1054        // All see the initial value
1055
1056        assert_eq!(w1.get(), 1);
1057        assert_eq!(w2.get(), 1);
1058        assert_eq!(w3.get(), 1);
1059
1060        // Change  value
1061        watchable.set(42).unwrap();
1062
1063        // All watchers get notified
1064        assert_eq!(w1.updated().await.unwrap(), 42);
1065        assert_eq!(w2.updated().await.unwrap(), 42);
1066        assert_eq!(w3.updated().await.unwrap(), 42);
1067    }
1068
1069    #[tokio::test]
1070    async fn test_three_watchers_skip_intermediate() {
1071        let watchable = Watchable::new(0u8);
1072        let mut watcher = watchable.watch();
1073
1074        watchable.set(1).ok();
1075        watchable.set(2).ok();
1076        watchable.set(3).ok();
1077        watchable.set(4).ok();
1078
1079        let value = watcher.updated().await.unwrap();
1080
1081        assert_eq!(value, 4);
1082    }
1083
1084    #[tokio::test]
1085    async fn test_three_watchers_with_streams() {
1086        let watchable = Watchable::new(10u8);
1087
1088        let mut stream1 = watchable.watch().stream();
1089        let mut stream2 = watchable.watch().stream();
1090        let mut stream3 = watchable.watch().stream_updates_only();
1091
1092        assert_eq!(stream1.next().await.unwrap(), 10);
1093        assert_eq!(stream2.next().await.unwrap(), 10);
1094
1095        // Update the value
1096        watchable.set(20).ok();
1097
1098        // All streams see the update
1099        assert_eq!(stream1.next().await.unwrap(), 20);
1100        assert_eq!(stream2.next().await.unwrap(), 20);
1101        assert_eq!(stream3.next().await.unwrap(), 20);
1102    }
1103
1104    #[tokio::test]
1105    async fn test_three_watchers_independent() {
1106        let watchable = Watchable::new(0u8);
1107
1108        let mut fast_watcher = watchable.watch();
1109        let mut slow_watcher = watchable.watch();
1110        let mut lazy_watcher = watchable.watch();
1111
1112        watchable.set(1).ok();
1113        assert_eq!(fast_watcher.updated().await.unwrap(), 1);
1114
1115        // More updates happen
1116        watchable.set(2).ok();
1117        watchable.set(3).ok();
1118
1119        assert_eq!(slow_watcher.updated().await.unwrap(), 3);
1120        assert_eq!(lazy_watcher.get(), 3);
1121    }
1122
1123    #[tokio::test]
1124    async fn test_combine_three_watchers() {
1125        let a = Watchable::new(1u8);
1126        let b = Watchable::new(2u8);
1127        let c = Watchable::new(3u8);
1128
1129        let mut combined = (a.watch(), b.watch(), c.watch());
1130
1131        assert_eq!(combined.get(), (1, 2, 3));
1132
1133        // Update one
1134        b.set(20).ok();
1135
1136        assert_eq!(combined.updated().await.unwrap(), (1, 20, 3));
1137
1138        c.set(30).ok();
1139        assert_eq!(combined.updated().await.unwrap(), (1, 20, 30));
1140    }
1141
1142    #[tokio::test]
1143    async fn test_three_watchers_disconnection() {
1144        let watchable = Watchable::new(5u8);
1145
1146        // All connected
1147        let mut w1 = watchable.watch();
1148        let mut w2 = watchable.watch();
1149        let mut w3 = watchable.watch();
1150
1151        // Drop the watchable
1152        drop(watchable);
1153
1154        // All become disconnected
1155        assert!(!w1.is_connected());
1156        assert!(!w2.is_connected());
1157        assert!(!w3.is_connected());
1158
1159        // Can still get last known value
1160        assert_eq!(w1.get(), 5);
1161        assert_eq!(w2.get(), 5);
1162
1163        // But updates fail
1164        assert!(w3.updated().await.is_err());
1165    }
1166
1167    #[tokio::test]
1168    async fn test_three_watchers_truly_concurrent() {
1169        use tokio::time::sleep;
1170        let watchable = Watchable::new(0u8);
1171
1172        // Spawn three READER tasks
1173        let mut reader_handles = vec![];
1174        for i in 0..3 {
1175            let mut watcher = watchable.watch();
1176            let handle = tokio::spawn(async move {
1177                let mut values = vec![];
1178                // Collect up to 5 updates
1179                for _ in 0..5 {
1180                    if let Ok(value) = watcher.updated().await {
1181                        values.push(value);
1182                    } else {
1183                        break;
1184                    }
1185                }
1186                (i, values)
1187            });
1188            reader_handles.push(handle);
1189        }
1190
1191        // Spawn three WRITER tasks that update concurrently
1192        let mut writer_handles = vec![];
1193        for i in 0..3 {
1194            let watchable_clone = watchable.clone();
1195            let handle = tokio::spawn(async move {
1196                for j in 0..5 {
1197                    let value = (i * 10) + j;
1198                    watchable_clone.set(value).ok();
1199                    sleep(Duration::from_millis(5)).await;
1200                }
1201            });
1202            writer_handles.push(handle);
1203        }
1204
1205        // Wait for writers to finish
1206        for handle in writer_handles {
1207            handle.await.unwrap();
1208        }
1209
1210        // Wait for readers and check results
1211        for handle in reader_handles {
1212            let (task_id, values) = handle.await.unwrap();
1213            println!("Reader {}: saw values {:?}", task_id, values);
1214            assert!(!values.is_empty());
1215        }
1216    }
1217}