1use futures::stream::Stream;
47use slab::Slab;
48use std::{
49 fmt,
50 ops::DerefMut,
51 sync::{Arc, Mutex, MutexGuard},
52 task::{Poll, Waker},
53};
54
/// Version number given to a freshly created observable handle and to its shared state.
/// `reset` and `clone_and_reset` set a handle's version to 0, which is below this value.
55const INITIAL_VERSION: u128 = 1;
60
61#[derive(Clone)]
107pub struct Observable<T>
108where
109 T: Clone,
110{
111 inner: Arc<Mutex<Inner<T>>>,
112 version: u128,
113 waker_id: Option<usize>,
114}
115
116impl<T> Observable<T>
117where
118 T: Clone,
119{
120 pub fn new(value: T) -> Self {
122 Observable {
123 inner: Arc::new(Mutex::new(Inner::new(value))),
124 version: INITIAL_VERSION,
125 waker_id: None,
126 }
127 }
128
129 pub fn publish(&mut self, value: T) {
131 self.modify(|v| *v = value);
132 }
133
134 pub fn modify<M>(&mut self, modify: M)
136 where
137 M: FnOnce(&mut T),
138 {
139 self.modify_conditional(|_| true, modify);
140 }
141
142 pub fn modify_conditional<C, M>(&mut self, condition: C, modify: M) -> bool
160 where
161 C: FnOnce(&T) -> bool,
162 M: FnOnce(&mut T),
163 {
164 self.apply(|value| {
165 if condition(value) {
166 modify(value);
167 true
168 } else {
169 false
170 }
171 })
172 }
173
174 #[doc(hidden)]
203 pub(crate) fn apply<F>(&mut self, change: F) -> bool
204 where
205 F: FnOnce(&mut T) -> bool,
206 {
207 let mut inner = self.lock();
208
209 if !change(&mut inner.value) {
210 return false;
211 }
212
213 inner.version += 1;
214
215 for ref waker in inner.waker.iter() {
216 waker.1.wake_by_ref();
217 }
218
219 inner.waker.clear();
220
221 true
222 }
223
224 pub fn clone_and_reset(&self) -> Observable<T> {
237 Self {
238 inner: self.inner.clone(),
239 version: 0,
240 waker_id: None,
241 }
242 }
243
244 pub fn reset(&mut self) {
257 self.version = 0;
258 }
259
260 pub fn latest(&self) -> T {
275 let inner = self.lock();
276 inner.value.clone()
277 }
278
279 #[inline]
297 pub async fn next(&mut self) -> T {
298 futures::StreamExt::next(self)
299 .await
300 .expect("internal implementation error: observable update streams cannot end")
301 }
302
303 pub fn synchronize(&mut self) -> T {
321 let (value, version) = {
322 let inner = self.lock();
323 (inner.value.clone(), inner.version)
324 };
325
326 self.version = version;
327 value
328 }
329
330 pub fn split(self) -> (Self, Self) {
348 (self.clone(), self)
349 }
350
351 pub(crate) fn lock(&self) -> MutexGuard<Inner<T>> {
352 match self.inner.lock() {
353 Ok(guard) => guard,
354 Err(e) => e.into_inner(),
355 }
356 }
357
358 #[cfg(test)]
359 pub(crate) fn waker_count(&self) -> usize {
360 self.inner.lock().unwrap().waker.len()
361 }
362}
363
364impl<T> Observable<T>
365where
366 T: Clone + PartialEq,
367{
368 pub fn publish_if_changed(&mut self, value: T) -> bool {
372 self.apply(|v| {
373 if *v != value {
374 *v = value;
375 true
376 } else {
377 false
378 }
379 })
380 }
381}
382
383impl<T> PartialEq for Observable<T>
384where
385 T: Clone + PartialEq,
386{
387 fn eq(&self, other: &Self) -> bool {
388 self.latest() == other.latest()
389 }
390}
391
392impl<T> Eq for Observable<T> where T: Clone + PartialEq + Eq {}
393
394impl<T> From<T> for Observable<T>
395where
396 T: Clone,
397{
398 fn from(value: T) -> Self {
400 Observable::new(value)
401 }
402}
403
404impl<T> fmt::Debug for Observable<T>
405where
406 T: Clone + fmt::Debug,
407{
408 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
409 let inner = self.lock();
410
411 f.debug_struct("Observable")
412 .field("inner", &inner)
413 .field("version", &self.version)
414 .finish()
415 }
416}
417
418impl<T> Stream for Observable<T>
419where
420 T: Clone,
421{
422 type Item = T;
423
424 fn poll_next(
425 mut self: std::pin::Pin<&mut Self>,
426 cx: &mut std::task::Context<'_>,
427 ) -> Poll<Option<Self::Item>> {
428 let mut guard = self.lock();
429 let inner = guard.deref_mut();
430
431 if self.version == inner.version {
432 if let Some(waker) = self.waker_id {
433 inner.waker.try_remove(waker);
434 }
435
436 let waker_id = inner.waker.insert(cx.waker().clone());
437
438 drop(guard);
439
440 self.waker_id = Some(waker_id);
441
442 Poll::Pending
443 } else {
444 if let Some(waker) = self.waker_id {
445 inner.waker.try_remove(waker);
446 }
447
448 let (version, value) = (inner.version, inner.value.clone());
449
450 drop(guard);
451
452 self.waker_id = None;
453 self.version = version;
454
455 Poll::Ready(Some(value))
456 }
457 }
458}
459
460impl<T> Drop for Observable<T>
461where
462 T: Clone,
463{
464 fn drop(&mut self) {
465 if let Some(waker) = self.waker_id {
466 let mut guard = self.lock();
467 let inner = guard.deref_mut();
468 inner.waker.try_remove(waker);
469 }
470 }
471}
472
#[cfg(feature = "serde")]
impl<T> serde::Serialize for Observable<T>
where
    T: serde::Serialize + Clone,
{
    /// Serializes a clone of the current value; version and waker state are not written.
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let snapshot = self.latest();
        T::serialize(&snapshot, serializer)
    }
}
487
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for Observable<T>
where
    T: Clone + serde::Deserialize<'de>,
{
    /// Deserializes a `T` and wraps it in a new observable.
    #[inline]
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Ok(Self::from(value))
    }
}
502
503struct Inner<T>
504where
505 T: Clone,
506{
507 version: u128,
508 value: T,
509 waker: Slab<Waker>,
510}
511
512impl<T> Inner<T>
513where
514 T: Clone,
515{
516 fn new(value: T) -> Self {
517 Self {
518 version: INITIAL_VERSION,
519 value,
520 waker: Slab::new(),
521 }
522 }
523}
524
525impl<T> fmt::Debug for Inner<T>
526where
527 T: Clone + fmt::Debug,
528{
529 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530 f.debug_struct("Inner")
531 .field("value", &self.value)
532 .field("version", &self.version)
533 .finish()
534 }
535}
536
#[cfg(test)]
mod test {
    use super::Observable;
    use async_std::{
        future::timeout,
        task::{sleep, spawn},
    };
    use std::time::Duration;

    /// Pause between two publications of the background publisher task.
    const SLEEP_DURATION: Duration = Duration::from_millis(25);
    /// How long a test waits before concluding that no update is going to arrive.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(500);

    mod publishing {
        use super::*;

        #[async_std::test]
        async fn should_get_notified_sync() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            for &value in &[2, 3, 0] {
                publisher.publish(value);
                assert_eq!(subscriber.next().await, value);
            }
        }

        #[async_std::test]
        async fn should_get_notified_sync_multiple() {
            let mut publisher = Observable::new(1);
            let mut first = publisher.clone();
            let mut second = publisher.clone();

            for &value in &[2, 3, 0] {
                publisher.publish(value);
                assert_eq!(first.next().await, value);
                assert_eq!(second.next().await, value);
            }
        }

        #[async_std::test]
        async fn should_publish_after_modify() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            // Two increments followed by two decrements by two.
            publisher.modify(|n| *n += 1);
            assert_eq!(subscriber.next().await, 2);

            publisher.modify(|n| *n += 1);
            assert_eq!(subscriber.next().await, 3);

            publisher.modify(|n| *n -= 2);
            assert_eq!(subscriber.next().await, 1);

            publisher.modify(|n| *n -= 2);
            assert_eq!(subscriber.next().await, -1);
        }

        #[async_std::test]
        async fn should_conditionally_modify() {
            let mut number = Observable::new(1);

            // 1 is odd, so the "even" condition rejects the change.
            let changed = number.modify_conditional(|n| n % 2 == 0, |n| *n *= 2);
            assert!(!changed);
            assert_eq!(number.latest(), 1);

            let changed = number.modify_conditional(|n| n % 2 == 1, |n| *n *= 2);
            assert!(changed);
            assert_eq!(number.latest(), 2);

            let changed = number.modify_conditional(|n| n % 2 == 0, |n| *n = 1000);
            assert!(changed);
            assert_eq!(number.latest(), 1000);
        }

        #[async_std::test]
        async fn shouldnt_publish_same_change() {
            let mut number = Observable::new(1);

            let was_published = number.publish_if_changed(1);
            assert!(!was_published);

            let waited = timeout(TIMEOUT_DURATION, number.next()).await;
            assert!(waited.is_err());
        }

        #[async_std::test]
        async fn should_publish_changed() {
            let mut number = Observable::new(1);

            let was_published = number.publish_if_changed(2);
            assert!(was_published);
            assert_eq!(number.synchronize(), 2);

            let was_published = number.publish_if_changed(2);
            assert!(!was_published);

            let waited = timeout(TIMEOUT_DURATION, number.next()).await;
            assert!(waited.is_err());
        }
    }

    mod versions {
        use super::*;

        #[async_std::test]
        async fn should_skip_versions() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            for &value in &[2, 3, 0] {
                publisher.publish(value);
            }

            assert_eq!(subscriber.next().await, 0);
        }

        #[async_std::test]
        async fn should_wait_after_skiped_versions() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            for &value in &[2, 3, 0] {
                publisher.publish(value);
            }

            assert_eq!(subscriber.next().await, 0);

            let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
            assert!(waited.is_err());
        }

        #[async_std::test]
        async fn should_skip_unchecked_updates() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            publisher.publish(2);
            assert_eq!(subscriber.next().await, 2);

            publisher.publish(3);
            publisher.publish(0);
            assert_eq!(subscriber.next().await, 0);
        }

        #[async_std::test]
        async fn should_clone_and_reset() {
            let origin = Observable::new(1);
            let mut subscriber = origin.clone_and_reset();

            assert_eq!(subscriber.next().await, 1);
        }

        #[async_std::test]
        async fn should_reset() {
            let (_origin, mut subscriber) = Observable::new(1).split();

            subscriber.reset();

            assert_eq!(subscriber.next().await, 1);
        }
    }

    mod asynchronous {
        use super::*;

        #[async_std::test]
        async fn should_wait_for_publisher_task() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            // Publish three values from a background task with a pause before each.
            spawn(async move {
                sleep(SLEEP_DURATION).await;
                publisher.publish(2);
                sleep(SLEEP_DURATION).await;
                publisher.publish(3);
                sleep(SLEEP_DURATION).await;
                publisher.publish(0);
            });

            for &expected in &[2, 3, 0] {
                assert_eq!(subscriber.next().await, expected);
            }
        }
    }

    mod synchronization {
        use super::*;

        #[async_std::test]
        async fn should_get_latest_without_loosing_updates() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            publisher.publish(2);

            // Reading the latest value twice must not consume the pending update.
            assert_eq!(subscriber.latest(), 2);
            assert_eq!(subscriber.latest(), 2);

            assert_eq!(subscriber.next().await, 2);
        }

        #[async_std::test]
        async fn should_skip_updates_while_synchronizing() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            publisher.publish(2);
            publisher.publish(3);

            assert_eq!(subscriber.synchronize(), 3);

            let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
            assert!(waited.is_err());
        }

        #[async_std::test]
        async fn should_synchronize_multiple_times() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();

            publisher.publish(2);
            publisher.publish(3);

            assert_eq!(subscriber.synchronize(), 3);
            assert_eq!(subscriber.synchronize(), 3);

            publisher.publish(4);

            assert_eq!(subscriber.synchronize(), 4);

            let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
            assert!(waited.is_err());
        }
    }

    mod future {
        use super::*;

        #[async_std::test]
        async fn should_remove_waker_after_resolving() {
            let mut publisher = Observable::new(1);
            let mut subscriber = publisher.clone();
            let short_wait = Duration::from_millis(10);

            for _ in 0..100 {
                publisher.publish(1);
                let _ = timeout(short_wait, subscriber.next()).await;

                assert_eq!(subscriber.waker_id, None);
                assert_eq!(publisher.waker_count(), 0);
            }
        }

        #[async_std::test]
        async fn should_wait_forever() {
            let origin = Observable::new(1);
            let mut subscriber = origin.clone();

            let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
            assert!(waited.is_err());
        }
    }

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;
        use serde_derive::*;

        #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
        struct Foo {
            uint: Observable<u8>,
            string: Observable<String>,
        }

        #[async_std::test]
        async fn should_serialize_and_deserialize() {
            let original = Foo {
                uint: 1.into(),
                string: "bar".to_owned().into(),
            };

            let json: String = serde_json::to_string(&original).unwrap();
            assert_eq!(json, r#"{"uint":1,"string":"bar"}"#);

            let expected = Foo {
                uint: 1.into(),
                string: "bar".to_owned().into(),
            };
            let round_tripped: Foo = serde_json::from_str(&json).unwrap();
            assert_eq!(round_tripped, expected);
        }

        #[async_std::test]
        async fn should_serialize_latest() {
            let (uint, mut publisher) = Observable::new(1).split();

            let record = Foo {
                uint,
                string: "bar".to_owned().into(),
            };

            // Published through the other handle after `record` was built.
            publisher.publish(2);

            let json: String = serde_json::to_string(&record).unwrap();
            assert_eq!(json, r#"{"uint":2,"string":"bar"}"#);
        }
    }
}