1use std::{
47 collections::HashMap,
48 fmt,
49 ops::DerefMut,
50 sync::{Arc, Mutex, MutexGuard},
51 task::Poll,
52};
53
54use futures::stream::Stream;
55use futures::task::AtomicWaker;
56use uuid::Uuid;
57
/// Version that every freshly created shared state (and its first handle) starts at.
///
/// It is deliberately non-zero: `reset` and `clone_and_reset` set a handle's observed
/// version to `0`, and because the shared version starts here and is only ever
/// incremented, such a handle is guaranteed to see the current value on its next poll.
const INITIAL_VERSION: u128 = 1;
63
64pub struct Observable<T>
110where
111 T: Clone,
112{
113 inner: Arc<Mutex<Inner<T>>>,
114 waker: u128,
115 version: u128,
116}
117
118impl<T> Clone for Observable<T>
119where
120 T: Clone,
121{
122 fn clone(&self) -> Self {
123 Self {
124 waker: Uuid::new_v4().as_u128(),
125 inner: self.inner.clone(),
126 version: self.version,
127 }
128 }
129}
130
131impl<T> Observable<T>
132where
133 T: Clone,
134{
135 pub fn new(value: T) -> Self {
137 Observable {
138 waker: Uuid::new_v4().as_u128(),
139 inner: Arc::new(Mutex::new(Inner::new(value))),
140 version: INITIAL_VERSION,
141 }
142 }
143
144 pub fn publish(&mut self, value: T) {
146 self.modify(|v| *v = value);
147 }
148
149 pub fn modify<M>(&mut self, modify: M)
151 where
152 M: FnOnce(&mut T),
153 {
154 self.modify_conditional(|_| true, modify);
155 }
156
157 pub fn try_modify<M, O, E>(&mut self, modify: M) -> Result<O, E>
176 where
177 M: FnOnce(&mut T) -> Result<O, E>,
178 {
179 self.try_apply(modify)
180 }
181
182 pub fn modify_conditional<C, M>(&mut self, condition: C, modify: M) -> bool
200 where
201 C: FnOnce(&T) -> bool,
202 M: FnOnce(&mut T),
203 {
204 self.apply(|value| {
205 if condition(value) {
206 modify(value);
207 true
208 } else {
209 false
210 }
211 })
212 }
213
214 #[doc(hidden)]
235 pub(crate) fn try_apply<F, O, E>(&mut self, change: F) -> Result<O, E>
236 where
237 F: FnOnce(&mut T) -> Result<O, E>,
238 {
239 let mut inner = self.lock();
240
241 let mut value = inner.value.clone();
242
243 let output = change(&mut value)?;
244
245 inner.value = value;
246 inner.version += 1;
247
248 for (_, waker) in inner.waker.iter() {
249 waker.wake();
250 }
251
252 inner.waker.clear();
253
254 Ok(output)
255 }
256
257 #[doc(hidden)]
286 pub(crate) fn apply<F>(&mut self, change: F) -> bool
287 where
288 F: FnOnce(&mut T) -> bool,
289 {
290 self.try_apply(|m| {
291 if change(m) {
292 return Ok(());
293 }
294
295 Err(())
296 })
297 .is_ok()
298 }
299
300 pub fn clone_and_reset(&self) -> Observable<T> {
313 Self {
314 waker: Uuid::new_v4().as_u128(),
315 inner: self.inner.clone(),
316 version: 0,
317 }
318 }
319
320 pub fn reset(&mut self) {
333 self.version = 0;
334 }
335
336 pub fn latest(&self) -> T {
351 let inner = self.lock();
352 inner.value.clone()
353 }
354
355 #[inline]
373 pub async fn next(&mut self) -> T {
374 futures::StreamExt::next(self)
375 .await
376 .expect("internal implementation error: observable update streams cannot end")
377 }
378
379 pub fn synchronize(&mut self) -> T {
397 let (value, version) = {
398 let inner = self.lock();
399
400 (inner.value.clone(), inner.version)
401 };
402
403 self.version = version;
404
405 value
406 }
407
408 pub fn split(self) -> (Self, Self) {
426 (self.clone(), self)
427 }
428
429 pub(crate) fn lock<'a>(&'a self) -> MutexGuard<'a, Inner<T>> {
430 match self.inner.lock() {
431 Ok(guard) => guard,
432 Err(e) => e.into_inner(),
433 }
434 }
435
436 #[cfg(test)]
437 pub(crate) fn waker_count(&self) -> usize {
438 self.inner.lock().unwrap().waker.len()
439 }
440}
441
442impl<T> Observable<T>
443where
444 T: Clone + PartialEq,
445{
446 pub fn publish_if_changed(&mut self, value: T) -> bool {
450 self.apply(|v| {
451 if *v != value {
452 *v = value;
453 true
454 } else {
455 false
456 }
457 })
458 }
459}
460
461impl<T> PartialEq for Observable<T>
462where
463 T: Clone + PartialEq,
464{
465 fn eq(&self, other: &Self) -> bool {
466 self.latest() == other.latest()
467 }
468}
469
470impl<T> Eq for Observable<T> where T: Clone + PartialEq + Eq {}
471
472impl<T> From<T> for Observable<T>
473where
474 T: Clone,
475{
476 fn from(value: T) -> Self {
478 Observable::new(value)
479 }
480}
481
482impl<T> fmt::Debug for Observable<T>
483where
484 T: Clone + fmt::Debug,
485{
486 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
487 let inner = self.lock();
488
489 f.debug_struct("Observable")
490 .field("inner", &inner)
491 .field("version", &self.version)
492 .finish()
493 }
494}
495
496impl<T> Stream for Observable<T>
497where
498 T: Clone,
499{
500 type Item = T;
501
502 fn poll_next(
503 mut self: std::pin::Pin<&mut Self>,
504 cx: &mut std::task::Context<'_>,
505 ) -> Poll<Option<Self::Item>> {
506 let mut guard = self.lock();
507
508 let inner = guard.deref_mut();
509
510 if self.version == inner.version {
511 inner
512 .waker
513 .entry(self.waker)
514 .and_modify(|w| {
515 w.register(cx.waker());
516 })
517 .or_insert_with(|| {
518 let waker = AtomicWaker::new();
519 waker.register(cx.waker());
520 waker
521 });
522
523 drop(guard);
524
525 Poll::Pending
526 } else {
527 inner.waker.remove(&self.waker);
528
529 let (version, value) = (inner.version, inner.value.clone());
530
531 drop(guard);
532
533 self.version = version;
534
535 Poll::Ready(Some(value))
536 }
537 }
538}
539
#[cfg(feature = "serde")]
impl<T> serde::Serialize for Observable<T>
where
    T: serde::Serialize + Clone,
{
    /// Serializes transparently as the current value (a snapshot taken via `latest`).
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let snapshot = self.latest();
        serde::Serialize::serialize(&snapshot, serializer)
    }
}
554
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for Observable<T>
where
    T: Clone + serde::Deserialize<'de>,
{
    /// Deserializes a plain `T` and wraps it in a brand-new observable.
    #[inline]
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Ok(Self::from(value))
    }
}
569
570struct Inner<T>
571where
572 T: Clone,
573{
574 version: u128,
575 value: T,
576 waker: HashMap<u128, AtomicWaker>,
577}
578
579impl<T> Inner<T>
580where
581 T: Clone,
582{
583 fn new(value: T) -> Self {
584 Self {
585 version: INITIAL_VERSION,
586 value,
587 waker: Default::default(),
588 }
589 }
590}
591
592impl<T> fmt::Debug for Inner<T>
593where
594 T: Clone + fmt::Debug,
595{
596 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
597 f.debug_struct("Inner")
598 .field("value", &self.value)
599 .field("version", &self.version)
600 .finish()
601 }
602}
603
#[cfg(test)]
mod test {
    use super::Observable;
    use async_std::future::timeout;
    use async_std::task::{sleep, spawn};
    use std::time::Duration;

    /// Delay between publications in the asynchronous tests.
    const SLEEP_DURATION: Duration = Duration::from_millis(25);
    /// How long a test waits before concluding that no update will arrive.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(500);

    /// Publishing and modifying values and observing them from forks.
    mod publishing {
        use super::*;
        use async_std::test;

        #[test]
        async fn should_get_notified_sync() {
            let mut int = Observable::new(1);
            let mut other = int.clone();

            int.publish(2);
            assert_eq!(other.next().await, 2);
            int.publish(3);
            assert_eq!(other.next().await, 3);
            int.publish(0);
            assert_eq!(other.next().await, 0);
        }

        #[test]
        async fn should_get_notified_sync_multiple() {
            let mut int = Observable::new(1);
            let mut fork_one = int.clone();
            let mut fork_two = int.clone();

            int.publish(2);
            assert_eq!(fork_one.next().await, 2);
            assert_eq!(fork_two.next().await, 2);

            int.publish(3);
            assert_eq!(fork_one.next().await, 3);
            assert_eq!(fork_two.next().await, 3);

            int.publish(0);
            assert_eq!(fork_one.next().await, 0);
            assert_eq!(fork_two.next().await, 0);
        }

        #[test]
        async fn should_publish_after_modify() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.modify(|i| *i += 1);
            assert_eq!(fork.next().await, 2);

            int.modify(|i| *i += 1);
            assert_eq!(fork.next().await, 3);

            int.modify(|i| *i -= 2);
            assert_eq!(fork.next().await, 1);

            int.modify(|i| *i -= 2);
            assert_eq!(fork.next().await, -1);
        }

        #[test]
        async fn should_conditionally_modify() {
            let mut int = Observable::new(1);

            let modified = int.modify_conditional(|i| i % 2 == 0, |i| *i *= 2);
            assert!(!modified);
            assert_eq!(int.latest(), 1);

            let modified = int.modify_conditional(|i| i % 2 == 1, |i| *i *= 2);
            assert!(modified);
            assert_eq!(int.latest(), 2);

            let modified = int.modify_conditional(|i| i % 2 == 0, |i| *i = 1000);
            assert!(modified);
            assert_eq!(int.latest(), 1000);
        }

        #[test]
        async fn shouldnt_publish_same_change() {
            let mut int = Observable::new(1);
            let published = int.publish_if_changed(1);
            assert!(!published);
            assert!(timeout(TIMEOUT_DURATION, int.next()).await.is_err());
        }

        #[test]
        async fn should_publish_changed() {
            let mut int = Observable::new(1);

            let published = int.publish_if_changed(2);
            assert!(published);
            assert_eq!(int.synchronize(), 2);

            let published = int.publish_if_changed(2);
            assert!(!published);
            assert!(timeout(TIMEOUT_DURATION, int.next()).await.is_err());
        }
    }

    /// Version tracking: coalescing of updates and resetting a handle's observed version.
    mod versions {
        use super::*;
        use async_std::test;

        #[test]
        async fn should_skip_versions() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.publish(2);
            int.publish(3);
            int.publish(0);

            assert_eq!(fork.next().await, 0);
        }

        #[test]
        async fn should_wait_after_skipped_versions() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.publish(2);
            int.publish(3);
            int.publish(0);

            assert_eq!(fork.next().await, 0);
            assert!(timeout(TIMEOUT_DURATION, fork.next()).await.is_err());
        }

        #[test]
        async fn should_skip_unchecked_updates() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.publish(2);
            assert_eq!(fork.next().await, 2);
            int.publish(3);
            int.publish(0);
            assert_eq!(fork.next().await, 0);
        }

        #[test]
        async fn should_clone_and_reset() {
            let int = Observable::new(1);
            let mut fork = int.clone_and_reset();
            assert_eq!(fork.next().await, 1);
        }

        #[test]
        async fn should_reset() {
            let (_int, mut fork) = Observable::new(1).split();
            fork.reset();
            assert_eq!(fork.next().await, 1);
        }
    }

    /// Publisher and subscriber running in different tasks.
    mod asynchronous {
        use super::*;
        use async_std::test;

        #[test]
        async fn should_wait_for_publisher_task() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            spawn(async move {
                sleep(SLEEP_DURATION).await;
                int.publish(2);
                sleep(SLEEP_DURATION).await;
                int.publish(3);
                sleep(SLEEP_DURATION).await;
                int.publish(0);
            });

            assert_eq!(fork.next().await, 2);
            assert_eq!(fork.next().await, 3);
            assert_eq!(fork.next().await, 0);
        }
    }

    /// `latest` (peek without consuming) versus `synchronize` (consume without waiting).
    mod synchronization {
        use super::*;
        use async_std::test;

        #[test]
        async fn should_get_latest_without_losing_updates() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.publish(2);

            assert_eq!(fork.latest(), 2);
            assert_eq!(fork.latest(), 2);

            assert_eq!(fork.next().await, 2);
        }

        #[test]
        async fn should_skip_updates_while_synchronizing() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.publish(2);
            int.publish(3);

            assert_eq!(fork.synchronize(), 3);

            assert!(timeout(TIMEOUT_DURATION, fork.next()).await.is_err());
        }

        #[test]
        async fn should_synchronize_multiple_times() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            int.publish(2);
            int.publish(3);

            assert_eq!(fork.synchronize(), 3);
            assert_eq!(fork.synchronize(), 3);

            int.publish(4);

            assert_eq!(fork.synchronize(), 4);

            assert!(timeout(TIMEOUT_DURATION, fork.next()).await.is_err());
        }
    }

    /// Low-level `Stream::poll_next` behaviour: waker registration and cleanup.
    mod future {
        use super::*;
        use futures::task::{noop_waker, Context};
        use futures::Stream;
        use std::pin::Pin;
        use std::sync::atomic::{AtomicU16, Ordering};
        use std::sync::Arc;
        use std::task::Poll;
        use std::thread;
        use std::time::Duration;

        /// Waker that counts how often it has been woken.
        struct TestWaker {
            called: Arc<AtomicU16>,
        }

        impl futures::task::ArcWake for TestWaker {
            fn wake_by_ref(arc_self: &Arc<Self>) {
                arc_self.called.fetch_add(1, Ordering::SeqCst);
            }
        }

        #[async_std::test]
        async fn should_remove_waker_after_resolving() {
            let mut int = Observable::new(1);
            let mut fork = int.clone();

            for _ in 0..100 {
                int.publish(1);
                timeout(Duration::from_millis(10), fork.next()).await.ok();

                assert_eq!(int.waker_count(), 0);
            }
        }

        #[async_std::test]
        async fn should_wait_forever() {
            let int = Observable::new(1);
            let mut fork = int.clone();

            assert!(timeout(TIMEOUT_DURATION, fork.next()).await.is_err());
        }

        #[test]
        fn supports_multiple_polls_before_data() {
            let mut observable = Observable::new(0);
            let mut fork = observable.clone();

            let called = Arc::new(AtomicU16::new(0));

            let waker = futures::task::waker(Arc::new(TestWaker {
                called: called.clone(),
            }));
            let mut cx = Context::from_waker(&waker);

            let poll1 = Pin::new(&mut fork).poll_next(&mut cx);
            assert_eq!(poll1, Poll::Pending);
            assert_eq!(fork.waker_count(), 1);

            let poll2 = Pin::new(&mut fork).poll_next(&mut cx);
            assert_eq!(poll2, Poll::Pending);
            assert_eq!(fork.waker_count(), 1);

            let poll3 = Pin::new(&mut fork).poll_next(&mut cx);
            assert_eq!(poll3, Poll::Pending);
            assert_eq!(fork.waker_count(), 1);

            observable.publish(42);

            assert_eq!(
                called.load(Ordering::SeqCst),
                1,
                "Waker was not called after publishing data!"
            );

            called.store(0, Ordering::SeqCst);

            let poll4 = Pin::new(&mut fork).poll_next(&mut cx);
            assert_eq!(poll4, Poll::Ready(Some(42)));
            assert_eq!(fork.waker_count(), 0);
        }

        #[test]
        fn supports_waker_survival_across_multiple_polls() {
            let mut observable = Observable::new(0);
            let mut fork = observable.clone();

            let waker = noop_waker();
            let mut cx = Context::from_waker(&waker);

            for i in 0..10 {
                let poll = Pin::new(&mut fork).poll_next(&mut cx);
                assert_eq!(poll, Poll::Pending, "Poll {} should return Pending", i);

                assert_eq!(
                    fork.waker_count(),
                    1,
                    "Should have exactly 1 waker after poll {}",
                    i
                );
            }

            observable.publish(99);

            let last = Pin::new(&mut fork).poll_next(&mut cx);
            assert_eq!(last, Poll::Ready(Some(99)));
        }

        #[async_std::test]
        async fn supports_concurrent_poll_and_publish() {
            let mut observable = Observable::new(0);
            let mut fork = observable.clone();

            let called = Arc::new(AtomicU16::new(0));

            let waker = futures::task::waker(Arc::new(TestWaker {
                called: called.clone(),
            }));

            let handle = async_std::task::spawn(async move {
                for _ in 0..100 {
                    {
                        let mut cx = Context::from_waker(&waker);
                        let _ = Pin::new(&mut fork).poll_next(&mut cx);
                    }
                    async_std::task::sleep(Duration::from_millis(1)).await;
                }
                fork
            });

            let publisher = thread::spawn(move || {
                thread::sleep(Duration::from_millis(25));
                observable.publish(123);
            });

            handle.await;

            // Ensure the publication has actually happened before counting wakeups
            // (otherwise a delayed publisher makes this assertion racy), and surface
            // any panic from the publisher thread.
            publisher.join().expect("publisher thread panicked");

            assert_eq!(called.load(Ordering::SeqCst), 1);
        }
    }

    /// Transparent (de)serialization of observables as their current value.
    #[cfg(feature = "serde")]
    mod serde {
        use super::*;
        use async_std::test;
        use serde_derive::*;

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        struct Foo {
            uint: Observable<u8>,
            string: Observable<String>,
        }

        #[test]
        async fn should_serialize_and_deserialize() {
            let data = Foo {
                uint: 1.into(),
                string: "bar".to_owned().into(),
            };

            let serialized: String = serde_json::to_string(&data).unwrap();
            assert_eq!(serialized, r#"{"uint":1,"string":"bar"}"#);

            let deserialized: Foo = serde_json::from_str(&serialized).unwrap();
            assert_eq!(
                deserialized,
                Foo {
                    uint: 1.into(),
                    string: "bar".to_owned().into()
                }
            );
        }

        #[test]
        async fn should_serialize_latest() {
            let (uint, mut other) = Observable::new(1).split();

            let data = Foo {
                uint,
                string: "bar".to_owned().into(),
            };

            other.publish(2);

            let serialized: String = serde_json::to_string(&data).unwrap();
            assert_eq!(serialized, r#"{"uint":2,"string":"bar"}"#);
        }
    }
}