1#[cfg(not(watcher_loom))]
17use std::sync;
18use std::{
19 collections::VecDeque,
20 future::Future,
21 pin::Pin,
22 sync::{Arc, Weak},
23 task::{self, ready, Poll, Waker},
24};
25
26#[cfg(watcher_loom)]
27use loom::sync;
28use snafu::Snafu;
29use sync::{Mutex, RwLock};
30
31#[derive(Debug, Default)]
36pub struct Watchable<T> {
37 shared: Arc<Shared<T>>,
38}
39
40impl<T> Clone for Watchable<T> {
41 fn clone(&self) -> Self {
42 Self {
43 shared: self.shared.clone(),
44 }
45 }
46}
47
/// Abstraction over types that may or may not carry a `T`.
///
/// Used by [`Watcher::initialized`] to wait until a watched value actually holds
/// something.
pub trait Nullable<T> {
    /// Converts `self` into an `Option<T>`, yielding `None` when no value is present.
    fn into_option(self) -> Option<T>;
}
53
54impl<T> Nullable<T> for Option<T> {
55 fn into_option(self) -> Option<T> {
56 self
57 }
58}
59
60impl<T> Nullable<T> for Vec<T> {
61 fn into_option(mut self) -> Option<T> {
62 self.pop()
63 }
64}
65
66impl<T: Clone + Eq> Watchable<T> {
67 pub fn new(value: T) -> Self {
69 Self {
70 shared: Arc::new(Shared {
71 state: RwLock::new(State {
72 value,
73 epoch: INITIAL_EPOCH,
74 }),
75 watchers: Default::default(),
76 }),
77 }
78 }
79
80 pub fn set(&self, value: T) -> Result<T, T> {
87 let mut state = self.shared.state.write().expect("poisoned");
91
92 let changed = state.value != value;
94
95 let ret = if changed {
96 let old = std::mem::replace(&mut state.value, value);
97 state.epoch += 1;
98 Ok(old)
99 } else {
100 Err(value)
101 };
102 drop(state); if changed {
106 for watcher in self.shared.watchers.lock().expect("poisoned").drain(..) {
107 watcher.wake();
108 }
109 }
110 ret
111 }
112
113 pub fn watch(&self) -> Direct<T> {
115 Direct {
116 epoch: self.shared.state.read().expect("poisoned").epoch,
117 shared: Arc::downgrade(&self.shared),
118 }
119 }
120
121 pub fn get(&self) -> T {
123 self.shared.get()
124 }
125}
126
127pub trait Watcher: Clone {
146 type Value: Clone + Eq;
155
156 fn get(&self) -> Result<Self::Value, Disconnected>;
159
160 fn poll_updated(
163 &mut self,
164 cx: &mut task::Context<'_>,
165 ) -> Poll<Result<Self::Value, Disconnected>>;
166
167 fn updated(&mut self) -> NextFut<Self> {
174 NextFut { watcher: self }
175 }
176
177 fn initialized<T, W>(&mut self) -> InitializedFut<T, W, Self>
188 where
189 W: Nullable<T>,
190 Self: Watcher<Value = W>,
191 {
192 InitializedFut {
193 initial: match self.get() {
194 Ok(value) => value.into_option().map(Ok),
195 Err(Disconnected) => Some(Err(Disconnected)),
196 },
197 watcher: self,
198 }
199 }
200
201 fn stream(self) -> Stream<Self>
215 where
216 Self: Unpin,
217 {
218 Stream {
219 initial: self.get().ok(),
220 watcher: self,
221 }
222 }
223
224 fn stream_updates_only(self) -> Stream<Self>
239 where
240 Self: Unpin,
241 {
242 Stream {
243 initial: None,
244 watcher: self,
245 }
246 }
247
248 fn map<T: Clone + Eq>(
253 self,
254 map: impl Fn(Self::Value) -> T + Send + Sync + 'static,
255 ) -> Result<Map<Self, T>, Disconnected> {
256 Ok(Map {
257 current: (map)(self.get()?),
258 map: Arc::new(map),
259 watcher: self,
260 })
261 }
262
263 fn or<W: Watcher>(self, other: W) -> (Self, W) {
266 (self, other)
267 }
268}
269
270#[derive(Debug, Clone)]
274pub struct Direct<T> {
275 epoch: u64,
276 shared: Weak<Shared<T>>,
277}
278
279impl<T: Clone + Eq> Watcher for Direct<T> {
280 type Value = T;
281
282 fn get(&self) -> Result<Self::Value, Disconnected> {
283 let shared = self.shared.upgrade().ok_or(Disconnected)?;
284 Ok(shared.get())
285 }
286
287 fn poll_updated(
288 &mut self,
289 cx: &mut task::Context<'_>,
290 ) -> Poll<Result<Self::Value, Disconnected>> {
291 let Some(shared) = self.shared.upgrade() else {
292 return Poll::Ready(Err(Disconnected));
293 };
294 match shared.poll_updated(cx, self.epoch) {
295 Poll::Pending => Poll::Pending,
296 Poll::Ready((current_epoch, value)) => {
297 self.epoch = current_epoch;
298 Poll::Ready(Ok(value))
299 }
300 }
301 }
302}
303
304impl<S: Watcher, T: Watcher> Watcher for (S, T) {
305 type Value = (S::Value, T::Value);
306
307 fn get(&self) -> Result<Self::Value, Disconnected> {
308 Ok((self.0.get()?, self.1.get()?))
309 }
310
311 fn poll_updated(
312 &mut self,
313 cx: &mut task::Context<'_>,
314 ) -> Poll<Result<Self::Value, Disconnected>> {
315 let poll_0 = self.0.poll_updated(cx)?;
316 let poll_1 = self.1.poll_updated(cx)?;
317 match (poll_0, poll_1) {
318 (Poll::Ready(s), Poll::Ready(t)) => Poll::Ready(Ok((s, t))),
319 (Poll::Ready(s), Poll::Pending) => Poll::Ready(self.1.get().map(move |t| (s, t))),
320 (Poll::Pending, Poll::Ready(t)) => Poll::Ready(self.0.get().map(move |s| (s, t))),
321 (Poll::Pending, Poll::Pending) => Poll::Pending,
322 }
323 }
324}
325
326#[derive(Debug, Clone)]
328pub struct Join<T: Clone + Eq, W: Watcher<Value = T>> {
329 watchers: Vec<W>,
330}
331impl<T: Clone + Eq, W: Watcher<Value = T>> Join<T, W> {
332 pub fn new(watchers: impl Iterator<Item = W>) -> Self {
334 let watchers: Vec<W> = watchers.into_iter().collect();
335
336 Self { watchers }
337 }
338}
339
340impl<T: Clone + Eq, W: Watcher<Value = T>> Watcher for Join<T, W> {
341 type Value = Vec<T>;
342
343 fn get(&self) -> Result<Self::Value, Disconnected> {
344 let mut out = Vec::with_capacity(self.watchers.len());
345 for watcher in &self.watchers {
346 out.push(watcher.get()?);
347 }
348
349 Ok(out)
350 }
351
352 fn poll_updated(
353 &mut self,
354 cx: &mut task::Context<'_>,
355 ) -> Poll<Result<Self::Value, Disconnected>> {
356 let mut new_value = None;
357 for (i, watcher) in self.watchers.iter_mut().enumerate() {
358 match watcher.poll_updated(cx) {
359 Poll::Pending => {}
360 Poll::Ready(Ok(value)) => {
361 new_value.replace((i, value));
362 break;
363 }
364 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
365 }
366 }
367
368 if let Some((j, new_value)) = new_value {
369 let mut new = Vec::with_capacity(self.watchers.len());
370 for (i, watcher) in self.watchers.iter().enumerate() {
371 if i != j {
372 new.push(watcher.get()?);
373 } else {
374 new.push(new_value.clone());
375 }
376 }
377 Poll::Ready(Ok(new))
378 } else {
379 Poll::Pending
380 }
381 }
382}
383
384#[derive(derive_more::Debug, Clone)]
388pub struct Map<W: Watcher, T: Clone + Eq> {
389 #[debug("Arc<dyn Fn(W::Value) -> T + 'static>")]
390 map: Arc<dyn Fn(W::Value) -> T + Send + Sync + 'static>,
391 watcher: W,
392 current: T,
393}
394
395impl<W: Watcher, T: Clone + Eq> Watcher for Map<W, T> {
396 type Value = T;
397
398 fn get(&self) -> Result<Self::Value, Disconnected> {
399 Ok((self.map)(self.watcher.get()?))
400 }
401
402 fn poll_updated(
403 &mut self,
404 cx: &mut task::Context<'_>,
405 ) -> Poll<Result<Self::Value, Disconnected>> {
406 loop {
407 let value = ready!(self.watcher.poll_updated(cx)?);
408 let mapped = (self.map)(value);
409 if mapped != self.current {
410 self.current = mapped.clone();
411 return Poll::Ready(Ok(mapped));
412 } else {
413 self.current = mapped;
414 }
415 }
416 }
417}
418
419#[derive(Debug)]
427pub struct NextFut<'a, W: Watcher> {
428 watcher: &'a mut W,
429}
430
431impl<W: Watcher> Future for NextFut<'_, W> {
432 type Output = Result<W::Value, Disconnected>;
433
434 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
435 self.watcher.poll_updated(cx)
436 }
437}
438
439#[derive(Debug)]
448pub struct InitializedFut<'a, T, V: Nullable<T>, W: Watcher<Value = V>> {
449 initial: Option<Result<T, Disconnected>>,
450 watcher: &'a mut W,
451}
452
453impl<T: Clone + Eq + Unpin, V: Nullable<T>, W: Watcher<Value = V> + Unpin> Future
454 for InitializedFut<'_, T, V, W>
455{
456 type Output = Result<T, Disconnected>;
457
458 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
459 if let Some(value) = self.as_mut().initial.take() {
460 return Poll::Ready(value);
461 }
462 loop {
463 let value = ready!(self.as_mut().watcher.poll_updated(cx)?);
464 if let Some(value) = value.into_option() {
465 return Poll::Ready(Ok(value));
466 }
467 }
468 }
469}
470
471#[derive(Debug, Clone)]
479pub struct Stream<W: Watcher + Unpin> {
480 initial: Option<W::Value>,
481 watcher: W,
482}
483
484impl<W: Watcher + Unpin> n0_future::Stream for Stream<W>
485where
486 W::Value: Unpin,
487{
488 type Item = W::Value;
489
490 fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Option<Self::Item>> {
491 if let Some(value) = self.as_mut().initial.take() {
492 return Poll::Ready(Some(value));
493 }
494 match self.as_mut().watcher.poll_updated(cx) {
495 Poll::Ready(Ok(value)) => Poll::Ready(Some(value)),
496 Poll::Ready(Err(Disconnected)) => Poll::Ready(None),
497 Poll::Pending => Poll::Pending,
498 }
499 }
500}
501
502#[derive(Snafu, Debug)]
505#[snafu(display("Watcher lost connection to underlying Watchable, it was dropped"))]
506pub struct Disconnected;
507
/// Epoch assigned to freshly created state; every successful `set` increments it.
const INITIAL_EPOCH: u64 = 1;
511
512#[derive(Debug, Default)]
514struct Shared<T> {
515 state: RwLock<State<T>>,
517 watchers: Mutex<VecDeque<Waker>>,
518}
519
520#[derive(Debug)]
521struct State<T> {
522 value: T,
523 epoch: u64,
524}
525
526impl<T: Default> Default for State<T> {
527 fn default() -> Self {
528 Self {
529 value: Default::default(),
530 epoch: INITIAL_EPOCH,
531 }
532 }
533}
534
535impl<T: Clone> Shared<T> {
536 fn get(&self) -> T {
538 self.state.read().expect("poisoned").value.clone()
539 }
540
541 fn poll_updated(&self, cx: &mut task::Context<'_>, last_epoch: u64) -> Poll<(u64, T)> {
542 {
543 let state = self.state.read().expect("poisoned");
544 let epoch = state.epoch;
545
546 if last_epoch < epoch {
547 return Poll::Ready((epoch, state.value.clone()));
550 }
551 }
552
553 self.watchers
554 .lock()
555 .expect("poisoned")
556 .push_back(cx.waker().to_owned());
557
558 #[cfg(watcher_loom)]
559 loom::thread::yield_now();
560
561 {
562 let state = self.state.read().expect("poisoned");
563 let epoch = state.epoch;
564
565 if last_epoch < epoch {
566 return Poll::Ready((epoch, state.value.clone()));
569 }
570 }
571
572 Poll::Pending
573 }
574}
575
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use n0_future::{future::poll_once, StreamExt};
    use rand::{thread_rng, Rng};
    use tokio::task::JoinSet;
    use tokio_util::sync::CancellationToken;

    use super::*;

    /// Spawns watcher tasks while the main task sets the values 0..10 with random
    /// delays, then cancels them. Each task asserts that it received the values in
    /// order and that it ended at `expected_value == 10`.
    #[tokio::test]
    async fn test_watcher() {
        let cancel = CancellationToken::new();
        let watchable = Watchable::new(17);

        assert_eq!(watchable.watch().stream().next().await.unwrap(), 17);

        let start = Instant::now();
        let mut tasks = JoinSet::new();
        // `stream()` watchers: the initial value 17 comes first, then 0, 1, 2, ...
        for i in 0..3 {
            let mut watch = watchable.watch().stream();
            let cancel = cancel.clone();
            tasks.spawn(async move {
                println!("[{i}] spawn");
                let mut expected_value = 17;
                loop {
                    tokio::select! {
                        biased;
                        Some(value) = &mut watch.next() => {
                            println!("{:?} [{i}] update: {value}", start.elapsed());
                            assert_eq!(value, expected_value);
                            if expected_value == 17 {
                                expected_value = 0;
                            } else {
                                expected_value += 1;
                            }
                        },
                        _ = cancel.cancelled() => {
                            println!("{:?} [{i}] cancel", start.elapsed());
                            assert_eq!(expected_value, 10);
                            break;
                        }
                    }
                }
            });
        }
        // `stream_updates_only()` watchers: no initial value, only 0, 1, 2, ...
        for i in 0..3 {
            let mut watch = watchable.watch().stream_updates_only();
            let cancel = cancel.clone();
            tasks.spawn(async move {
                println!("[{i}] spawn");
                let mut expected_value = 0;
                loop {
                    tokio::select! {
                        biased;
                        Some(value) = watch.next() => {
                            println!("{:?} [{i}] stream update: {value}", start.elapsed());
                            assert_eq!(value, expected_value);
                            expected_value += 1;
                        },
                        _ = cancel.cancelled() => {
                            println!("{:?} [{i}] cancel", start.elapsed());
                            assert_eq!(expected_value, 10);
                            break;
                        }
                        // Reached only if the stream ended (returned `None`) before
                        // cancellation.
                        else => {
                            panic!("stream died");
                        }
                    }
                }
            });
        }

        for next_value in 0..10 {
            // Random delay of up to 100ms between updates.
            let sleep = Duration::from_nanos(thread_rng().gen_range(0..100_000_000));
            println!("{:?} sleep {sleep:?}", start.elapsed());
            tokio::time::sleep(sleep).await;

            let changed = watchable.set(next_value);
            println!("{:?} set {next_value} changed={changed:?}", start.elapsed());
        }

        println!("cancel");
        cancel.cancel();
        // Surface any assertion failure from the spawned tasks.
        while let Some(res) = tasks.join_next().await {
            res.expect("task failed");
        }
    }

    /// `get` returns the value most recently passed to `set`.
    #[test]
    fn test_get() {
        let watchable = Watchable::new(None);
        assert!(watchable.get().is_none());

        watchable.set(Some(1u8)).ok();
        assert_eq!(watchable.get(), Some(1u8));
    }

    /// `initialized()` stays pending while the value is `None` and resolves once it
    /// is set to `Some`.
    #[tokio::test]
    async fn test_initialize() {
        let watchable = Watchable::new(None);

        let mut watcher = watchable.watch();
        let mut initialized = watcher.initialized();

        let poll = poll_once(&mut initialized).await;
        assert!(poll.is_none());

        watchable.set(Some(1u8)).ok();

        let poll = poll_once(&mut initialized).await;
        assert_eq!(poll.unwrap().unwrap(), 1u8);
    }

    /// `initialized()` resolves on its first poll when the value is already `Some`.
    #[tokio::test]
    async fn test_initialize_already_init() {
        let watchable = Watchable::new(Some(1u8));

        let mut watcher = watchable.watch();
        let mut initialized = watcher.initialized();

        let poll = poll_once(&mut initialized).await;
        assert_eq!(poll.unwrap().unwrap(), 1u8);
    }

    /// Races a `set` on the current thread against `initialized()` blocked on another
    /// thread. With `--cfg watcher_loom`, the body runs under `loom::model`; otherwise
    /// it runs once with std threads.
    #[test]
    fn test_initialized_always_resolves() {
        #[cfg(not(watcher_loom))]
        use std::thread;

        #[cfg(watcher_loom)]
        use loom::thread;

        let test_case = || {
            let watchable = Watchable::<Option<u8>>::new(None);

            let mut watch = watchable.watch();
            let thread = thread::spawn(move || n0_future::future::block_on(watch.initialized()));

            watchable.set(Some(42)).ok();

            thread::yield_now();

            let value: u8 = thread.join().unwrap().unwrap();

            assert_eq!(value, 42);
        };

        #[cfg(watcher_loom)]
        loom::model(test_case);
        #[cfg(not(watcher_loom))]
        test_case();
    }

    /// Sets 1..=MAX in quick succession while another task races `updated()` against
    /// a random sleep in `select!`, which drops the `updated()` future whenever the
    /// sleep wins. Asserts that a received value never equals the previously
    /// received one.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_update_cancel_safety() {
        let watchable = Watchable::new(0);
        let mut watch = watchable.watch();
        const MAX: usize = 100_000;

        let handle = tokio::spawn(async move {
            let mut last_observed = 0;

            while last_observed != MAX {
                tokio::select! {
                    val = watch.updated() => {
                        // Disconnected: the watchable is gone, stop observing.
                        let Ok(val) = val else {
                            return;
                        };

                        assert_ne!(val, last_observed, "never observe the same value twice, even with cancellation");
                        last_observed = val;
                    }
                    _ = tokio::time::sleep(Duration::from_micros(thread_rng().gen_range(0..10_000))) => {
                        continue;
                    }
                }
            }
        });

        for i in 1..=MAX {
            watchable.set(i).ok();
            if thread_rng().gen_bool(0.2) {
                tokio::task::yield_now().await;
            }
        }

        tokio::time::timeout(Duration::from_secs(10), handle)
            .await
            .unwrap()
            .unwrap()
    }

    /// A `Join` of two watchers emits the combined vector after each individual
    /// change. The `yield_now` calls give the spawned stream task a chance to
    /// observe each change before the next one.
    #[tokio::test]
    async fn test_join_simple() {
        let a = Watchable::new(1u8);
        let b = Watchable::new(1u8);

        let ab = Join::new([a.watch(), b.watch()].into_iter());

        let stream = ab.clone().stream();
        let handle = tokio::task::spawn(async move { stream.take(5).collect::<Vec<_>>().await });

        assert_eq!(ab.get().unwrap(), vec![1, 1]);
        a.set(2u8).unwrap();
        tokio::task::yield_now().await;
        assert_eq!(ab.get().unwrap(), vec![2, 1]);
        b.set(3u8).unwrap();
        tokio::task::yield_now().await;
        assert_eq!(ab.get().unwrap(), vec![2, 3]);

        a.set(3u8).unwrap();
        tokio::task::yield_now().await;
        b.set(4u8).unwrap();
        tokio::task::yield_now().await;

        let values = tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(
            values,
            vec![vec![1, 1], vec![2, 1], vec![2, 3], vec![3, 3], vec![3, 4]]
        );
    }
}
808}