use std::{
collections::HashMap,
fmt,
ops::DerefMut,
sync::{Arc, Mutex, MutexGuard},
task::Poll,
};
use futures::stream::Stream;
use futures::task::AtomicWaker;
use uuid::Uuid;
/// Version number given to freshly created shared state and to the handle
/// returned by `Observable::new`.
///
/// Handles that are reset use version `0`, which is below this value, so a
/// reset handle never matches the shared version.
const INITIAL_VERSION: u128 = 1;
/// A handle onto a shared value whose changes can be awaited.
///
/// Every handle created from the same origin points at the same shared state
/// but tracks on its own which version of the value it has already seen.
pub struct Observable<T: Clone> {
    /// State shared between all handles: the value, its version and the
    /// wakers of handles currently waiting for a change.
    inner: Arc<Mutex<Inner<T>>>,
    /// The shared version this handle has most recently observed.
    version: u128,
    /// Key under which this handle registers its waker in the shared state.
    waker: u128,
}
impl<T> Clone for Observable<T>
where
T: Clone,
{
fn clone(&self) -> Self {
Self {
waker: Uuid::new_v4().as_u128(),
inner: self.inner.clone(),
version: self.version,
}
}
}
impl<T> Observable<T>
where
    T: Clone,
{
    /// Creates a new observable that initially holds `value`.
    pub fn new(value: T) -> Self {
        Observable {
            waker: Uuid::new_v4().as_u128(),
            inner: Arc::new(Mutex::new(Inner::new(value))),
            version: INITIAL_VERSION,
        }
    }

    /// Replaces the shared value with `value` and wakes all waiting handles.
    pub fn publish(&mut self, value: T) {
        self.modify(|v| *v = value);
    }

    /// Mutates the shared value in place and wakes all waiting handles.
    ///
    /// The change is always published, even if `modify` leaves the value
    /// as it was.
    pub fn modify<M>(&mut self, modify: M)
    where
        M: FnOnce(&mut T),
    {
        self.modify_conditional(|_| true, modify);
    }

    /// Applies a fallible modification to the shared value.
    ///
    /// When `modify` returns `Ok`, the change is stored and published and the
    /// closure's output is returned. When it returns `Err`, the shared value
    /// is left untouched, nobody is woken and the error is returned.
    pub fn try_modify<M, O, E>(&mut self, modify: M) -> Result<O, E>
    where
        M: FnOnce(&mut T) -> Result<O, E>,
    {
        self.try_apply(modify)
    }

    /// Runs `modify` on the shared value only if `condition` accepts the
    /// current value.
    ///
    /// Returns `true` if the modification was applied and published,
    /// `false` if the condition rejected it.
    pub fn modify_conditional<C, M>(&mut self, condition: C, modify: M) -> bool
    where
        C: FnOnce(&T) -> bool,
        M: FnOnce(&mut T),
    {
        self.apply(|value| {
            if condition(value) {
                modify(value);
                true
            } else {
                false
            }
        })
    }

    /// Core update routine: runs `change` on a copy of the shared value and
    /// commits the copy only if `change` succeeds.
    ///
    /// On success the shared version is incremented and every registered
    /// waker is woken and removed. This handle's own seen-version is not
    /// advanced, so the publishing handle also observes the update.
    #[doc(hidden)]
    pub(crate) fn try_apply<F, O, E>(&mut self, change: F) -> Result<O, E>
    where
        F: FnOnce(&mut T) -> Result<O, E>,
    {
        let mut inner = self.lock();

        // Work on a copy so that a failing `change` cannot leave a
        // half-modified value behind.
        let mut value = inner.value.clone();
        let output = change(&mut value)?;

        inner.value = value;
        inner.version += 1;

        // Wake and remove all waiting handles in a single pass.
        for (_, waker) in inner.waker.drain() {
            waker.wake();
        }

        Ok(output)
    }

    /// Boolean flavour of [`Self::try_apply`]: the change is committed and
    /// published only if `change` returns `true`.
    ///
    /// Returns whether the change was committed.
    #[doc(hidden)]
    pub(crate) fn apply<F>(&mut self, change: F) -> bool
    where
        F: FnOnce(&mut T) -> bool,
    {
        self.try_apply(|value| if change(value) { Ok(()) } else { Err(()) })
            .is_ok()
    }

    /// Creates another handle onto the same shared value whose seen-version
    /// is `0`, so its first poll yields the current value immediately.
    pub fn clone_and_reset(&self) -> Observable<T> {
        Self {
            waker: Uuid::new_v4().as_u128(),
            inner: self.inner.clone(),
            version: 0,
        }
    }

    /// Sets this handle's seen-version to `0`, so its next poll yields the
    /// current value immediately.
    pub fn reset(&mut self) {
        self.version = 0;
    }

    /// Returns a clone of the current value without marking it as seen.
    pub fn latest(&self) -> T {
        let inner = self.lock();
        inner.value.clone()
    }

    /// Waits until the shared value has a version this handle has not seen
    /// yet, marks it as seen and returns a clone of the value.
    #[inline]
    pub async fn next(&mut self) -> T {
        futures::StreamExt::next(self)
            .await
            .expect("internal implementation error: observable update streams cannot end")
    }

    /// Returns a clone of the current value and marks its version as seen by
    /// this handle.
    pub fn synchronize(&mut self) -> T {
        let (value, version) = {
            let inner = self.lock();
            (inner.value.clone(), inner.version)
        };

        self.version = version;
        value
    }

    /// Turns this handle into two handles onto the same shared value: a fresh
    /// clone first, the original second.
    pub fn split(self) -> (Self, Self) {
        (self.clone(), self)
    }

    /// Locks the shared state.
    ///
    /// A poisoned mutex is recovered rather than propagated as a panic.
    pub(crate) fn lock(&self) -> MutexGuard<'_, Inner<T>> {
        self.inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Number of wakers currently registered in the shared state.
    ///
    /// Goes through [`Self::lock`] so it tolerates a poisoned mutex just like
    /// every other accessor.
    #[cfg(test)]
    pub(crate) fn waker_count(&self) -> usize {
        self.lock().waker.len()
    }
}
impl<T: Clone + PartialEq> Observable<T> {
    /// Publishes `value` only if it differs from the value currently held.
    ///
    /// Returns `true` if the value was replaced and published, `false` if it
    /// compared equal and nothing happened.
    pub fn publish_if_changed(&mut self, value: T) -> bool {
        self.apply(|current| {
            let differs = *current != value;
            if differs {
                *current = value;
            }
            differs
        })
    }
}
impl<T: Clone + PartialEq> PartialEq for Observable<T> {
    /// Two handles are equal when the values they currently point at compare
    /// equal; the versions each handle has seen are not considered.
    fn eq(&self, other: &Self) -> bool {
        // Each snapshot is taken under its own short-lived lock.
        let mine = self.latest();
        let theirs = other.latest();
        mine == theirs
    }
}
// `Eq` already implies `PartialEq`, so the bound on `T` only needs `Clone + Eq`.
impl<T> Eq for Observable<T> where T: Clone + Eq {}
impl<T> From<T> for Observable<T>
where
T: Clone,
{
fn from(value: T) -> Self {
Observable::new(value)
}
}
impl<T: Clone + fmt::Debug> fmt::Debug for Observable<T> {
    /// Formats the shared state together with the version this handle has
    /// seen. The shared state is locked while it is being formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let guard = self.lock();
        let mut out = f.debug_struct("Observable");
        out.field("inner", &guard);
        out.field("version", &self.version);
        out.finish()
    }
}
impl<T: Clone> Stream for Observable<T> {
    type Item = T;

    /// Yields the current value as soon as its version differs from the one
    /// this handle has seen; otherwise registers the task's waker and stays
    /// pending. The stream never yields `None`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        let mut guard = self.lock();
        let shared = guard.deref_mut();

        if self.version != shared.version {
            // Something new is available: this handle no longer needs to be
            // woken, so drop its registration before handing out the value.
            shared.waker.remove(&self.waker);
            let fresh_version = shared.version;
            let fresh_value = shared.value.clone();
            // Release the lock before touching `self` mutably.
            drop(guard);
            self.version = fresh_version;
            return Poll::Ready(Some(fresh_value));
        }

        // Nothing new yet: make sure this handle has exactly one entry in the
        // waker map and point it at the task that is polling right now.
        shared
            .waker
            .entry(self.waker)
            .or_insert_with(AtomicWaker::new)
            .register(cx.waker());
        Poll::Pending
    }
}
#[cfg(feature = "serde")]
impl<T> serde::Serialize for Observable<T>
where
    T: serde::Serialize + Clone,
{
    /// Serializes a snapshot of the current value; the observable itself
    /// leaves no trace in the output.
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let snapshot = self.latest();
        serde::Serialize::serialize(&snapshot, serializer)
    }
}
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for Observable<T>
where
    T: Clone + serde::Deserialize<'de>,
{
    /// Deserializes a plain `T` and wraps it in a new observable.
    #[inline]
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let value = T::deserialize(deserializer)?;
        Ok(Observable::new(value))
    }
}
/// State shared by all handles onto the same observable value.
struct Inner<T: Clone> {
    /// The current value.
    value: T,
    /// Counter that is incremented every time a change is committed.
    version: u128,
    /// Wakers of handles waiting for a change, keyed by the handle's waker id.
    waker: HashMap<u128, AtomicWaker>,
}
impl<T> Inner<T>
where
T: Clone,
{
fn new(value: T) -> Self {
Self {
version: INITIAL_VERSION,
value,
waker: Default::default(),
}
}
}
impl<T: Clone + fmt::Debug> fmt::Debug for Inner<T> {
    /// Formats the value and its version; the waker map is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Inner");
        out.field("value", &self.value);
        out.field("version", &self.version);
        out.finish()
    }
}
#[cfg(test)]
mod test {
use super::Observable;
use async_std::future::timeout;
use async_std::task::{sleep, spawn};
use std::time::Duration;
// Pause the spawned publisher task inserts before each publish.
const SLEEP_DURATION: Duration = Duration::from_millis(25);
// How long a test waits on `next()` before concluding that no update arrives.
const TIMEOUT_DURATION: Duration = Duration::from_millis(500);
mod publishing {
    use super::*;
    use async_std::test;

    /// A fork observes every value published on the original handle.
    #[test]
    async fn should_get_notified_sync() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        for &expected in &[2, 3, 0] {
            publisher.publish(expected);
            assert_eq!(subscriber.next().await, expected);
        }
    }

    /// Several forks each observe every published value.
    #[test]
    async fn should_get_notified_sync_multiple() {
        let mut publisher = Observable::new(1);
        let mut first = publisher.clone();
        let mut second = publisher.clone();

        for &expected in &[2, 3, 0] {
            publisher.publish(expected);
            assert_eq!(first.next().await, expected);
            assert_eq!(second.next().await, expected);
        }
    }

    /// In-place modifications are published just like replaced values.
    #[test]
    async fn should_publish_after_modify() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        // (delta applied to the value, value expected afterwards)
        for &(delta, expected) in &[(1, 2), (1, 3), (-2, 1), (-2, -1)] {
            publisher.modify(|i| *i += delta);
            assert_eq!(subscriber.next().await, expected);
        }
    }

    /// The modification only runs when the condition accepts the value.
    #[test]
    async fn should_conditionally_modify() {
        let mut number = Observable::new(1);

        // 1 is odd, so a condition asking for an even value rejects.
        let applied = number.modify_conditional(|i| i % 2 == 0, |i| *i *= 2);
        assert!(!applied);
        assert_eq!(number.latest(), 1);

        // The odd condition accepts and doubles the value.
        let applied = number.modify_conditional(|i| i % 2 == 1, |i| *i *= 2);
        assert!(applied);
        assert_eq!(number.latest(), 2);

        // 2 is even, so the value gets replaced.
        let applied = number.modify_conditional(|i| i % 2 == 0, |i| *i = 1000);
        assert!(applied);
        assert_eq!(number.latest(), 1000);
    }

    /// Publishing an equal value is rejected and produces no update.
    #[test]
    async fn shouldnt_publish_same_change() {
        let mut number = Observable::new(1);

        assert!(!number.publish_if_changed(1));
        let waited = timeout(TIMEOUT_DURATION, number.next()).await;
        assert!(waited.is_err());
    }

    /// Publishing a different value is accepted; repeating it is not.
    #[test]
    async fn should_publish_changed() {
        let mut number = Observable::new(1);

        assert!(number.publish_if_changed(2));
        assert_eq!(number.synchronize(), 2);

        assert!(!number.publish_if_changed(2));
        let waited = timeout(TIMEOUT_DURATION, number.next()).await;
        assert!(waited.is_err());
    }
}
mod versions {
    use super::*;
    use async_std::test;

    /// Only the newest value is delivered when several were published
    /// before the fork polled.
    #[test]
    async fn should_skip_versions() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        for &value in &[2, 3, 0] {
            publisher.publish(value);
        }

        assert_eq!(subscriber.next().await, 0);
    }

    /// After receiving the newest value there is nothing left to wait for.
    #[test]
    async fn should_wait_after_skiped_versions() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        for &value in &[2, 3, 0] {
            publisher.publish(value);
        }

        assert_eq!(subscriber.next().await, 0);
        let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
        assert!(waited.is_err());
    }

    /// Values published between two polls are skipped in favour of the
    /// newest one.
    #[test]
    async fn should_skip_unchecked_updates() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        publisher.publish(2);
        assert_eq!(subscriber.next().await, 2);

        publisher.publish(3);
        publisher.publish(0);
        assert_eq!(subscriber.next().await, 0);
    }

    /// A reset clone yields the current value right away.
    #[test]
    async fn should_clone_and_reset() {
        let origin = Observable::new(1);
        let mut subscriber = origin.clone_and_reset();

        assert_eq!(subscriber.next().await, 1);
    }

    /// Resetting an existing handle makes it yield the current value again.
    #[test]
    async fn should_reset() {
        let (_origin, mut subscriber) = Observable::new(1).split();

        subscriber.reset();
        assert_eq!(subscriber.next().await, 1);
    }
}
mod asynchronous {
    use super::*;
    use async_std::test;

    /// A fork awaiting `next()` is woken by values published from another
    /// task, one value at a time.
    #[test]
    async fn should_wait_for_publisher_task() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        spawn(async move {
            for &value in &[2, 3, 0] {
                sleep(SLEEP_DURATION).await;
                publisher.publish(value);
            }
        });

        for &expected in &[2, 3, 0] {
            assert_eq!(subscriber.next().await, expected);
        }
    }
}
mod synchronization {
    use super::*;
    use async_std::test;

    /// Reading the latest value does not consume the pending update.
    #[test]
    async fn should_get_latest_without_loosing_updates() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        publisher.publish(2);

        for _ in 0..2 {
            assert_eq!(subscriber.latest(), 2);
        }
        assert_eq!(subscriber.next().await, 2);
    }

    /// Synchronizing consumes all pending updates at once.
    #[test]
    async fn should_skip_updates_while_synchronizing() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        publisher.publish(2);
        publisher.publish(3);

        assert_eq!(subscriber.synchronize(), 3);
        let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
        assert!(waited.is_err());
    }

    /// Synchronizing repeatedly keeps returning the newest value.
    #[test]
    async fn should_synchronize_multiple_times() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        publisher.publish(2);
        publisher.publish(3);
        for _ in 0..2 {
            assert_eq!(subscriber.synchronize(), 3);
        }

        publisher.publish(4);
        assert_eq!(subscriber.synchronize(), 4);

        let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
        assert!(waited.is_err());
    }
}
mod future {
    use super::*;
    use futures::task::{noop_waker, Context};
    use futures::Stream;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicU16, Ordering};
    use std::sync::Arc;
    use std::task::Poll;
    use std::thread;
    use std::time::Duration;

    /// Waker that counts how many times it has been woken.
    struct TestWaker {
        called: Arc<AtomicU16>,
    }

    impl futures::task::ArcWake for TestWaker {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.called.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// No waker stays registered once a pending update has been delivered.
    #[async_std::test]
    async fn should_remove_waker_after_resolving() {
        let mut publisher = Observable::new(1);
        let mut subscriber = publisher.clone();

        for _ in 0..100 {
            publisher.publish(1);
            timeout(Duration::from_millis(10), subscriber.next())
                .await
                .ok();
            assert_eq!(publisher.waker_count(), 0);
        }
    }

    /// Without any publish, `next()` never resolves.
    #[async_std::test]
    async fn should_wait_forever() {
        let origin = Observable::new(1);
        let mut subscriber = origin.clone();

        let waited = timeout(TIMEOUT_DURATION, subscriber.next()).await;
        assert!(waited.is_err());
    }

    /// Polling repeatedly before data arrives keeps a single waker entry,
    /// which is woken exactly once by the publish.
    #[test]
    fn supports_multiple_polls_before_data() {
        let mut publisher = Observable::new(0);
        let mut subscriber = publisher.clone();

        let wake_count = Arc::new(AtomicU16::new(0));
        let waker = futures::task::waker(Arc::new(TestWaker {
            called: Arc::clone(&wake_count),
        }));
        let mut cx = Context::from_waker(&waker);

        // Three polls without data: always pending, always one waker.
        for _ in 0..3 {
            let poll = Pin::new(&mut subscriber).poll_next(&mut cx);
            assert_eq!(poll, Poll::Pending);
            assert_eq!(subscriber.waker_count(), 1);
        }

        publisher.publish(42);
        assert_eq!(
            wake_count.load(Ordering::SeqCst),
            1,
            "Waker was not called after publishing data!"
        );
        wake_count.store(0, Ordering::SeqCst);

        let ready = Pin::new(&mut subscriber).poll_next(&mut cx);
        assert_eq!(ready, Poll::Ready(Some(42)));
        assert_eq!(subscriber.waker_count(), 0);
    }

    /// The waker registration survives any number of pending polls.
    #[test]
    fn supports_waker_survival_across_multiple_polls() {
        let mut publisher = Observable::new(0);
        let mut subscriber = publisher.clone();

        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        for attempt in 0..10 {
            let poll = Pin::new(&mut subscriber).poll_next(&mut cx);
            assert_eq!(
                poll,
                Poll::Pending,
                "Poll {} should return Pending",
                attempt
            );
            assert_eq!(
                subscriber.waker_count(),
                1,
                "Should have exactly 1 waker after poll {}",
                attempt
            );
        }

        publisher.publish(99);
        let ready = Pin::new(&mut subscriber).poll_next(&mut cx);
        assert_eq!(ready, Poll::Ready(Some(99)));
    }

    /// A single publish from another thread wakes a repeatedly polling task
    /// exactly once.
    #[async_std::test]
    async fn supports_concurrent_poll_and_publish() {
        let mut publisher = Observable::new(0);
        let mut subscriber = publisher.clone();

        let wake_count = Arc::new(AtomicU16::new(0));
        let waker = futures::task::waker(Arc::new(TestWaker {
            called: Arc::clone(&wake_count),
        }));

        let poller = async_std::task::spawn(async move {
            for _ in 0..100 {
                {
                    // Scoped so the context is gone before the await below.
                    let mut cx = Context::from_waker(&waker);
                    let _ = Pin::new(&mut subscriber).poll_next(&mut cx);
                }
                async_std::task::sleep(Duration::from_millis(1)).await;
            }
            subscriber
        });

        thread::spawn(move || {
            thread::sleep(Duration::from_millis(25));
            publisher.publish(123);
        });

        poller.await;
        assert_eq!(wake_count.load(Ordering::SeqCst), 1);
    }
}
#[cfg(feature = "serde")]
mod serde {
    use super::*;
    use async_std::test;
    use serde_derive::*;

    /// Fixture with observable fields of two different value types.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct Foo {
        uint: Observable<u8>,
        string: Observable<String>,
    }

    /// Observables serialize as their plain value and round-trip through
    /// JSON.
    #[test]
    async fn should_serialize_and_deserialize() {
        let original = Foo {
            uint: 1.into(),
            string: "bar".to_owned().into(),
        };

        let json: String = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"uint":1,"string":"bar"}"#);

        let restored: Foo = serde_json::from_str(&json).unwrap();
        let expected = Foo {
            uint: 1.into(),
            string: "bar".to_owned().into(),
        };
        assert_eq!(restored, expected);
    }

    /// Serialization picks up a value published through another handle.
    #[test]
    async fn should_serialize_latest() {
        let (uint, mut publisher) = Observable::new(1).split();
        let data = Foo {
            uint,
            string: "bar".to_owned().into(),
        };

        publisher.publish(2);

        let json: String = serde_json::to_string(&data).unwrap();
        assert_eq!(json, r#"{"uint":2,"string":"bar"}"#);
    }
}
}