#![doc = include_str!(".crate-docs.md")]
#![forbid(unsafe_code)]
#![warn(
clippy::cargo,
missing_docs,
// clippy::missing_docs_in_private_items,
clippy::pedantic,
future_incompatible,
rust_2018_idioms,
)]
#![allow(clippy::option_if_let_else, clippy::module_name_repetitions)]
use std::{
ops::{Deref, DerefMut},
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::Poll,
time::{Duration, Instant},
};
use event_listener::{Event, EventListener};
use futures_util::{FutureExt, Stream};
use parking_lot::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard};
/// A value that can be observed for changes by any number of [`Watcher`]s.
///
/// Cloning a `Watchable` produces another handle to the same shared state.
/// When the last handle is dropped, [`Watchable::shutdown`] is invoked.
#[derive(Default, Debug)]
pub struct Watchable<T> {
// State shared by every clone of this handle and by every watcher created from it.
data: Arc<Data<T>>,
}
impl<T> Clone for Watchable<T> {
fn clone(&self) -> Self {
self.data.watchables.fetch_add(1, Ordering::AcqRel);
Self {
data: self.data.clone(),
}
}
}
impl<T> Drop for Watchable<T> {
    /// Unregisters this handle; the last handle to go away shuts the
    /// watchable down.
    fn drop(&mut self) {
        // `fetch_sub` yields the count *before* the decrement, so a result of
        // `1` means this was the final live handle.
        let handles_before = self.data.watchables.fetch_sub(1, Ordering::AcqRel);
        if handles_before == 1 {
            self.shutdown();
        }
    }
}
impl<T> Watchable<T> {
    /// Creates a watchable holding `initial_value`, with no watchers attached.
    pub fn new(initial_value: T) -> Self {
        let data = Data {
            value: RwLock::new(initial_value),
            changed: RwLock::new(Some(Event::new())),
            version: AtomicUsize::new(0),
            watchers: AtomicUsize::new(0),
            watchables: AtomicUsize::new(1),
        };
        Self {
            data: Arc::new(data),
        }
    }

    /// Creates a [`Watcher`] that starts out considering the currently stored
    /// value as already observed.
    pub fn watch(&self) -> Watcher<T> {
        let shared = &self.data;
        shared.watchers.fetch_add(1, Ordering::AcqRel);
        let version = AtomicUsize::new(shared.current_version());
        Watcher {
            version,
            watched: Arc::clone(shared),
        }
    }

    /// Stores `new_value`, notifies watchers, and returns the value that was
    /// previously stored.
    ///
    /// Watchers are notified even when the new value equals the old one.
    pub fn replace(&self, new_value: T) -> T {
        let mut slot = self.data.value.write();
        let previous = std::mem::replace(&mut *slot, new_value);
        // The write lock is still held while the new version is published.
        self.data.increment_version();
        previous
    }

    /// Stores `new_value` only if it differs from the currently stored value.
    ///
    /// On success the previous value is returned and watchers are notified.
    ///
    /// # Errors
    ///
    /// Returns `Err(new_value)` without notifying anyone when the stored value
    /// already compares equal to `new_value`.
    pub fn update(&self, new_value: T) -> Result<T, T>
    where
        T: PartialEq,
    {
        let current = self.data.value.upgradable_read();
        if *current == new_value {
            return Err(new_value);
        }

        let mut slot = RwLockUpgradableReadGuard::upgrade(current);
        let previous = std::mem::replace(&mut *slot, new_value);
        self.data.increment_version();
        Ok(previous)
    }

    /// Acquires exclusive access to the stored value.
    ///
    /// Watchers are notified when the returned guard is dropped, but only if
    /// the guard was dereferenced mutably.
    pub fn write(&self) -> WatchableWriteGuard<'_, T> {
        let guard = self.data.value.write();
        WatchableWriteGuard {
            watchable: self,
            guard,
            accessed_mut: false,
        }
    }

    /// Acquires shared read access to the stored value.
    pub fn read(&self) -> WatchableReadGuard<'_, T> {
        let guard = self.data.value.read();
        WatchableReadGuard(guard)
    }

    /// Returns a clone of the currently stored value.
    #[must_use]
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        let guard = self.data.value.read();
        (*guard).clone()
    }

    /// Returns the current value of the watcher counter.
    #[must_use]
    pub fn watchers(&self) -> usize {
        self.data.watchers.load(Ordering::Acquire)
    }

    /// Returns `true` if the watcher counter is non-zero.
    #[must_use]
    pub fn has_watchers(&self) -> bool {
        self.watchers() != 0
    }

    /// Disconnects all watchers.
    ///
    /// The change event is removed and every registered listener is woken.
    /// Calling this again afterwards has no further effect.
    pub fn shutdown(&self) {
        let mut event_slot = self.data.changed.write();
        if let Some(event) = event_slot.take() {
            event.notify(usize::MAX);
        }
    }
}
impl<T> Data<T> {
    /// Returns the version number of the most recently published value.
    fn current_version(&self) -> usize {
        self.version.load(Ordering::Acquire)
    }

    /// Publishes a new version and wakes every registered listener.
    ///
    /// Once the change event has been removed by a shutdown, the version is
    /// still incremented but nobody is notified.
    fn increment_version(&self) {
        self.version.fetch_add(1, Ordering::AcqRel);
        // The read lock stays held for the duration of the notification.
        let event_slot = self.changed.read();
        if let Some(event) = &*event_slot {
            event.notify(usize::MAX);
        }
    }
}
/// A guard providing shared read access to a watched value.
///
/// Wraps the read guard of the lock protecting the value and dereferences to
/// the value itself.
#[must_use]
pub struct WatchableReadGuard<'a, T>(RwLockReadGuard<'a, T>);
impl<T> Deref for WatchableReadGuard<'_, T> {
    type Target = T;

    /// Borrows the guarded value.
    fn deref(&self) -> &Self::Target {
        let value: &T = &self.0;
        value
    }
}
/// A guard providing exclusive access to the value stored in a [`Watchable`].
///
/// When the guard is dropped, a new version is published only if the guard
/// was dereferenced mutably.
#[must_use]
pub struct WatchableWriteGuard<'a, T> {
// Handle whose version counter is incremented when this guard is dropped.
watchable: &'a Watchable<T>,
// Set by `DerefMut::deref_mut`; decides whether dropping publishes a new version.
accessed_mut: bool,
// Exclusive lock on the stored value. Fields are dropped after `Drop::drop`
// has run, so the lock is still held while the version is incremented.
guard: RwLockWriteGuard<'a, T>,
}
impl<T> Deref for WatchableWriteGuard<'_, T> {
    type Target = T;

    /// Borrows the guarded value immutably. This does not mark the guard as
    /// mutably accessed.
    fn deref(&self) -> &Self::Target {
        let value: &T = &self.guard;
        value
    }
}
impl<T> DerefMut for WatchableWriteGuard<'_, T> {
    /// Borrows the guarded value mutably and records that mutable access
    /// happened, so that dropping the guard publishes a new version.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.accessed_mut = true;
        let value: &mut T = &mut self.guard;
        value
    }
}
impl<T> Drop for WatchableWriteGuard<'_, T> {
    /// Publishes a new version if the value was accessed mutably.
    fn drop(&mut self) {
        if !self.accessed_mut {
            // Only read through this guard: nothing to announce.
            return;
        }
        self.watchable.data.increment_version();
    }
}
/// State shared between the [`Watchable`] handles and [`Watcher`]s of one value.
#[derive(Debug)]
struct Data<T> {
// Event used to wake waiting watchers; `None` once the watchable has been shut down.
changed: RwLock<Option<Event>>,
// Incremented every time a new value is published.
version: AtomicUsize,
// Counter adjusted as watchers are created and dropped.
watchers: AtomicUsize,
// Number of live `Watchable` handles; reaching zero triggers shutdown.
watchables: AtomicUsize,
// The currently stored value.
value: RwLock<T>,
}
impl<T> Default for Data<T>
where
    T: Default,
{
    /// Builds the shared state for a watchable holding `T::default()`:
    /// version zero, no watchers, and exactly one watchable handle.
    fn default() -> Self {
        let changed = RwLock::new(Some(Event::new()));
        let value = RwLock::new(T::default());
        Self {
            changed,
            version: AtomicUsize::new(0),
            watchers: AtomicUsize::new(0),
            watchables: AtomicUsize::new(1),
            value,
        }
    }
}
/// An observer of the value stored in a [`Watchable`].
///
/// A watcher remembers the version of the value it last observed and can
/// block, poll, or iterate until a newer version is published.
#[derive(Debug)]
#[must_use]
pub struct Watcher<T> {
// Version of the value this watcher last observed.
version: AtomicUsize,
// State shared with the watchable being observed.
watched: Arc<Data<T>>,
}
impl<T> Drop for Watcher<T> {
    /// Removes this watcher from the shared watcher counter.
    fn drop(&mut self) {
        let shared = &self.watched;
        shared.watchers.fetch_sub(1, Ordering::AcqRel);
    }
}
impl<T> Clone for Watcher<T> {
fn clone(&self) -> Self {
Self {
version: AtomicUsize::new(self.version.load(Ordering::Relaxed)),
watched: self.watched.clone(),
}
}
}
/// Reasons why `Watcher::create_listener_if_needed` did not hand out a listener.
#[derive(Debug)]
enum CreateListenerError {
// The watcher's stored version differs from the current one: there is a
// value to read, so waiting is unnecessary.
NewValueAvailable,
// The change event has been removed by a shutdown and there is nothing new to read.
Disconnected,
}
/// Error returned by a [`Watcher`] when the watchable has been shut down and
/// no unread value remains.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
#[error("all watchable instances have been dropped")]
pub struct Disconnected;
/// Error returned by the deadline-bounded waiting methods of [`Watcher`].
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
pub enum TimeoutError {
/// The watchable has been shut down and no unread value remains.
#[error("all watchable instances have been dropped")]
Disconnected,
/// The deadline passed before a new value was published.
#[error("no new values were written before the timeout elapsed")]
Timeout,
}
impl<T> Watcher<T> {
    /// Registers a listener for the next change, unless waiting is pointless.
    ///
    /// Returns `NewValueAvailable` when an unread value already exists, and
    /// `Disconnected` when the change event has been removed by a shutdown
    /// and there is nothing new to read.
    fn create_listener_if_needed(&self) -> Result<Pin<Box<EventListener>>, CreateListenerError> {
        let changed = self.watched.changed.read();
        match (changed.as_ref(), self.is_current()) {
            // An unread value takes priority even after shutdown, so the last
            // published value is still delivered.
            (_, false) => Err(CreateListenerError::NewValueAvailable),
            (None, _) => Err(CreateListenerError::Disconnected),
            (Some(changed), true) => {
                let listener = changed.listen();
                // Re-check after registering: a value published between the
                // first check and `listen()` would not have woken this listener.
                if self.is_current() {
                    Ok(listener)
                } else {
                    Err(CreateListenerError::NewValueAvailable)
                }
            }
        }
    }

    /// Returns `true` if this watcher has already observed the most recently
    /// published version.
    #[must_use]
    pub fn is_current(&self) -> bool {
        self.version.load(Ordering::Relaxed) == self.watched.current_version()
    }

    /// Marks the current version as observed without reading the value.
    ///
    /// Returns `true` if this call advanced the stored version, and `false`
    /// if the watcher was already up to date.
    pub fn mark_read(&self) -> bool {
        let current_version = self.watched.current_version();
        let mut stored_version = self.version.load(Ordering::Acquire);
        // Retry while another thread using this watcher races us, stopping as
        // soon as the stored version has caught up.
        while stored_version < current_version {
            match self.version.compare_exchange(
                stored_version,
                current_version,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => return true,
                Err(new_stored) => stored_version = new_stored,
            }
        }
        false
    }

    /// Blocks the current thread until an unread value is available.
    ///
    /// Returns immediately if an unread value already exists. The value is
    /// not marked as read.
    ///
    /// # Errors
    ///
    /// Returns [`Disconnected`] if the watchable has been shut down and no
    /// unread value remains.
    pub fn watch(&self) -> Result<(), Disconnected> {
        loop {
            match self.create_listener_if_needed() {
                Ok(mut listener) => {
                    listener.as_mut().wait();
                    if !self.is_current() {
                        break;
                    }
                    // Woken without an unread value (for example by a
                    // shutdown): loop so the next iteration reports it.
                }
                Err(CreateListenerError::Disconnected) => return Err(Disconnected),
                Err(CreateListenerError::NewValueAvailable) => break,
            }
        }
        Ok(())
    }

    /// Blocks the current thread until an unread value is available or
    /// `duration` has elapsed.
    ///
    /// If `duration` is so large that the deadline cannot be represented as
    /// an [`Instant`], the call waits without a deadline instead of panicking.
    ///
    /// # Errors
    ///
    /// - [`TimeoutError::Timeout`] if no value was published in time.
    /// - [`TimeoutError::Disconnected`] if the watchable has been shut down
    ///   and no unread value remains.
    pub fn watch_timeout(&self, duration: Duration) -> Result<(), TimeoutError> {
        match Instant::now().checked_add(duration) {
            Some(deadline) => self.watch_until(deadline),
            // `Instant + Duration` panics on overflow. An unrepresentable
            // deadline can never elapse, so an unbounded wait is equivalent.
            None => self
                .watch()
                .map_err(|Disconnected| TimeoutError::Disconnected),
        }
    }

    /// Blocks the current thread until an unread value is available or
    /// `deadline` has passed.
    ///
    /// # Errors
    ///
    /// - [`TimeoutError::Timeout`] if no value was published before `deadline`.
    /// - [`TimeoutError::Disconnected`] if the watchable has been shut down
    ///   and no unread value remains.
    pub fn watch_until(&self, deadline: Instant) -> Result<(), TimeoutError> {
        loop {
            match self.create_listener_if_needed() {
                Ok(mut listener) => {
                    if listener.as_mut().wait_deadline(deadline).is_none() {
                        return Err(TimeoutError::Timeout);
                    }
                    if !self.is_current() {
                        break;
                    }
                    // Woken without an unread value: loop again. The next
                    // iteration either reports disconnection or waits on a
                    // fresh listener against the same deadline.
                }
                Err(CreateListenerError::Disconnected) => return Err(TimeoutError::Disconnected),
                Err(CreateListenerError::NewValueAvailable) => break,
            }
        }
        Ok(())
    }

    /// Waits asynchronously until an unread value is available.
    ///
    /// Completes immediately if an unread value already exists. The value is
    /// not marked as read.
    ///
    /// # Errors
    ///
    /// Returns [`Disconnected`] if the watchable has been shut down and no
    /// unread value remains.
    pub async fn watch_async(&self) -> Result<(), Disconnected> {
        loop {
            match self.create_listener_if_needed() {
                Ok(listener) => {
                    listener.await;
                    if !self.is_current() {
                        break;
                    }
                }
                Err(CreateListenerError::Disconnected) => return Err(Disconnected),
                Err(CreateListenerError::NewValueAvailable) => break,
            }
        }
        Ok(())
    }

    /// Acquires read access to the current value *without* marking it as read.
    pub fn peek(&self) -> WatchableReadGuard<'_, T> {
        let guard = self.watched.value.read();
        WatchableReadGuard(guard)
    }

    /// Acquires read access to the current value and marks it as read.
    pub fn read(&self) -> WatchableReadGuard<'_, T> {
        let guard = self.watched.value.read();
        // The version is recorded while the read lock is held.
        self.version
            .store(self.watched.current_version(), Ordering::Relaxed);
        WatchableReadGuard(guard)
    }

    /// Returns a clone of the current value and marks it as read.
    #[must_use]
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        self.read().clone()
    }

    /// Blocks until an unread value is available, then returns a clone of it
    /// and marks it as read.
    ///
    /// # Errors
    ///
    /// Returns [`Disconnected`] if the watchable has been shut down and no
    /// unread value remains.
    pub fn next_value(&self) -> Result<T, Disconnected>
    where
        T: Clone,
    {
        self.watch().map(|()| self.read().clone())
    }

    /// Waits asynchronously until an unread value is available, then returns
    /// a clone of it and marks it as read.
    ///
    /// # Errors
    ///
    /// Returns [`Disconnected`] if the watchable has been shut down and no
    /// unread value remains.
    pub async fn next_value_async(&self) -> Result<T, Disconnected>
    where
        T: Clone,
    {
        self.watch_async().await.map(|()| self.read().clone())
    }

    /// Converts this watcher into a [`WatcherStream`] with no pending listener.
    pub fn into_stream(self) -> WatcherStream<T> {
        WatcherStream {
            watcher: self,
            listener: None,
        }
    }
}
impl<T> Iterator for Watcher<T>
where
    T: Clone,
{
    type Item = T;

    /// Blocks until the next unread value is available and yields a clone of
    /// it. Iteration ends once the watcher reports disconnection.
    fn next(&mut self) -> Option<Self::Item> {
        match self.next_value() {
            Ok(value) => Some(value),
            Err(Disconnected) => None,
        }
    }
}
/// An asynchronous stream of values read from a [`Watcher`].
///
/// Created by [`Watcher::into_stream`].
#[derive(Debug)]
#[must_use]
pub struct WatcherStream<T> {
// The watcher that values are read from.
watcher: Watcher<T>,
// Listener kept from a poll that returned `Pending`, reused by the next poll.
listener: Option<Pin<Box<EventListener>>>,
}
impl<T> WatcherStream<T> {
    /// Consumes the stream and returns the wrapped [`Watcher`]; any pending
    /// listener is discarded.
    pub fn into_inner(self) -> Watcher<T> {
        let Self {
            watcher,
            listener: _,
        } = self;
        watcher
    }
}
// Stream adapter: yields a clone of each unread value and finishes once the
// watcher reports disconnection.
impl<T> Stream for WatcherStream<T>
where
T: Clone,
{
type Item = T;
/// Polls for the next unread value.
///
/// Returns `Poll::Ready(Some(value))` when an unread value exists (marking
/// it as read), `Poll::Ready(None)` once the watcher is disconnected, and
/// `Poll::Pending` while waiting for a change.
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
loop {
// Resume the listener saved by an earlier `Pending` poll. If there is
// none, try to register a fresh one. The `Disconnected` passed to `ok_or`
// is only a placeholder that `or_else` discards.
match self
.listener
.take()
.ok_or(CreateListenerError::Disconnected)
.or_else(|_| self.watcher.create_listener_if_needed())
{
Ok(mut listener) => {
match listener.poll_unpin(cx) {
Poll::Ready(()) => {
// Woken: yield only if there is an unread value, otherwise loop
// and re-evaluate (this is how a shutdown is detected).
if !self.watcher.is_current() {
break;
}
}
Poll::Pending => {
// Keep the listener so the next poll continues waiting on it.
self.listener = Some(listener);
return Poll::Pending;
}
}
}
Err(CreateListenerError::NewValueAvailable) => break,
Err(CreateListenerError::Disconnected) => return Poll::Ready(None),
}
}
// Reading marks the value as observed by the watcher.
Poll::Ready(Some(self.watcher.read().clone()))
}
}
/// Exercises watcher counting, `replace`, `peek` versus `read`, and
/// `mark_read` on a single thread.
#[test]
fn basics() {
    let watchable = Watchable::new(1_u32);
    assert!(!watchable.has_watchers());

    let first = watchable.watch();
    let second = watchable.watch();
    // Nothing has been published since the watchers were created.
    assert!(!first.mark_read());
    assert_eq!(watchable.watchers(), 2);

    let previous = watchable.replace(2);
    assert_eq!(previous, 1);

    // `peek` does not mark the value as read, so a second `watch` still
    // returns immediately.
    first.watch().unwrap();
    assert_eq!(*first.peek(), 2);
    first.watch().unwrap();
    assert_eq!(*first.read(), 2);
    // `read` already marked the value as read.
    assert!(!first.mark_read());

    drop(first);
    assert_eq!(watchable.watchers(), 1);

    // The second watcher has not observed the replacement yet.
    assert!(second.mark_read());
    assert_eq!(*second.read(), 2);

    drop(second);
    assert_eq!(watchable.watchers(), 0);
}
/// Checks that every accessor on `Watchable` and `Watcher` exposes the
/// stored value.
#[test]
fn accessing_values() {
    let expected = "hello";
    let watchable = Watchable::new(String::from(expected));
    assert_eq!(watchable.get(), expected);
    assert_eq!(&*watchable.read(), expected);
    assert_eq!(&*watchable.write(), expected);

    let watcher = watchable.watch();
    assert_eq!(watcher.get(), expected);
    assert_eq!(&*watcher.read(), expected);
}
/// Checks that cloned watchables publish to the same watchers and that
/// cloned watchers each receive every value.
#[test]
#[allow(clippy::redundant_clone)]
fn clones() {
    let original = Watchable::default();
    let duplicate = original.clone();
    let first = original.watch();
    let second = first.clone();

    // Publish once through each handle; both watchers must see both values.
    for (publisher, value) in [(&original, 1), (&duplicate, 2)] {
        publisher.replace(value);
        assert_eq!(first.next_value().unwrap(), value);
        assert_eq!(second.next_value().unwrap(), value);
    }
}
/// Checks that a watcher reports `Disconnected` once the only watchable has
/// been dropped and every value has been read.
#[test]
fn drop_watchable() {
    let watchable = Watchable::<u32>::default();
    assert!(!watchable.has_watchers());

    let watcher = watchable.watch();
    watchable.replace(1);
    assert_eq!(watcher.next_value(), Ok(1));

    drop(watchable);
    assert_eq!(watcher.next_value(), Err(Disconnected));
}
/// Checks that dropping the only watchable wakes watchers blocked in
/// `watch_timeout` and `watch_until` with `TimeoutError::Disconnected`, well
/// before their 15-second deadlines.
#[test]
fn drop_watchable_timeouts() {
let watchable = Watchable::new(0_u8);
assert!(!watchable.has_watchers());
let watcher = watchable.watch();
let start = Instant::now();
// First waiter: relative timeout.
let wait_timeout_thread = std::thread::spawn(move || {
assert!(matches!(
watcher.watch_timeout(Duration::from_secs(15)).unwrap_err(),
TimeoutError::Disconnected
));
});
let watcher = watchable.watch();
// Second waiter: absolute deadline.
let wait_until_thread = std::thread::spawn(move || {
assert!(matches!(
watcher
.watch_until(Instant::now().checked_add(Duration::from_secs(15)).unwrap())
.unwrap_err(),
TimeoutError::Disconnected
));
});
// Give both threads time to start waiting before disconnecting them.
std::thread::sleep(Duration::from_millis(100));
drop(watchable);
wait_timeout_thread.join().unwrap();
wait_until_thread.join().unwrap();
// Both waits must have ended because of the drop, not the deadline.
let elapsed = Instant::now().checked_duration_since(start).unwrap();
assert!(elapsed.as_secs() < 1);
}
/// Checks that `watch_timeout` and `watch_until` time out when nothing is
/// published, and succeed when a value is already waiting.
#[test]
fn timeouts() {
let watchable = Watchable::new(1_u32);
let watcher = watchable.watch();
let start = Instant::now();
// No value is published, so both waits must run into their 100 ms limit.
assert!(matches!(
watcher.watch_timeout(Duration::from_millis(100)),
Err(TimeoutError::Timeout)
));
assert!(matches!(
watcher.watch_until(Instant::now() + Duration::from_millis(100)),
Err(TimeoutError::Timeout)
));
// Two 100 ms waits; the threshold leaves a little slack below 200 ms.
let elapsed = Instant::now().checked_duration_since(start).unwrap();
assert!(elapsed.as_millis() >= 180);
// With an unread value available, both waits return without timing out.
watchable.replace(2);
watcher.watch_timeout(Duration::from_secs(1)).unwrap();
watchable.replace(3);
watcher
.watch_until(Instant::now() + Duration::from_secs(1))
.unwrap();
}
/// Checks that a write guard only publishes a new version when it was
/// dereferenced mutably.
#[test]
fn deref_publish() {
    let watchable = Watchable::new(1_u32);
    let watcher = watchable.watch();

    // Read-only use of the write guard must not notify watchers.
    let read_only = watchable.write();
    assert_eq!(*read_only, 1);
    drop(read_only);
    assert!(!watcher.mark_read());

    // Mutable use publishes a new version when the guard is dropped.
    let mut mutated = watchable.write();
    *mutated = 2;
    drop(mutated);
    assert!(watcher.mark_read());
}
/// Checks blocking `watch` across threads, and that `update` only publishes
/// when the value actually changes.
#[test]
fn blocking_tests() {
let watchable = Watchable::new(1_u32);
let watcher = watchable.watch();
// Lets the worker signal that it has read the first published value.
let (sender, receiver) = std::sync::mpsc::sync_channel(1);
let worker_thread = std::thread::spawn(move || {
watcher.watch().unwrap();
assert_eq!(*watcher.read(), 2);
sender.send(()).unwrap();
watcher.watch().unwrap();
*watcher.read()
});
watchable.replace(2);
// Wait until the worker has read `2` before publishing the next value.
receiver.recv().unwrap();
assert!(watchable.update(42).is_ok());
// The same value again is rejected and does not publish a new version.
assert!(watchable.update(42).is_err());
assert_eq!(worker_thread.join().unwrap(), 42);
}
/// Checks that iterating a watcher never yields the same value twice in a
/// row and ends with the last published value once the watchable is dropped.
#[test]
fn iterator_test() {
let watchable = Watchable::new(1_u32);
let watcher = watchable.watch();
let worker_thread = std::thread::spawn(move || {
let mut last_value = watcher.next_value().unwrap();
// The loop ends when the watcher reports disconnection.
for value in watcher {
assert_ne!(last_value, value);
println!("Received {value}");
last_value = value;
}
assert_eq!(last_value, 1000);
});
// Values may be skipped by the worker, but the final one must be observed.
for i in 1..=1000 {
watchable.replace(i);
}
drop(watchable);
worker_thread.join().unwrap();
}
/// Checks that the stream never yields the same value twice in a row, ends
/// after the watchable is dropped, and leaves the watcher fully read.
#[cfg(test)]
#[tokio::test(flavor = "multi_thread")]
async fn stream_test() {
use futures_util::StreamExt;
let watchable = Watchable::default();
let watcher = watchable.watch();
let worker_thread = tokio::task::spawn(async move {
let mut last_value = watcher.next_value_async().await.unwrap();
let mut stream = watcher.into_stream();
while let Some(value) = stream.next().await {
assert_ne!(last_value, value);
println!("Received {value}");
last_value = value;
}
assert_eq!(last_value, 1000);
// The stream stays finished once it has ended.
assert!(stream.next().await.is_none());
let watcher = stream.into_inner();
// Every published value has been read, so there is nothing to mark.
assert!(!watcher.mark_read());
});
for i in 1..=1000 {
watchable.replace(i);
// Pause periodically so the worker task gets a chance to run.
if i % 100 == 0 {
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
drop(watchable);
worker_thread.await.unwrap();
}
/// Runs ten blocking watcher threads against 10,000 updates. Each thread
/// must never see the same value twice in a row and must end on the last value.
#[test]
fn stress_test() {
let watchable = Watchable::new(1_u32);
let mut workers = Vec::new();
for _ in 1..=10 {
let watcher = watchable.watch();
workers.push(std::thread::spawn(move || {
let mut last_value = *watcher.read();
// `watch` fails once the watchable is dropped and nothing is unread.
while watcher.watch().is_ok() {
let current_value = *watcher.read();
assert_ne!(last_value, current_value);
last_value = current_value;
}
assert_eq!(last_value, 10000);
}));
}
for i in 1..=10000 {
// The first update stores `1` again and is rejected; that is expected.
let _ = watchable.update(i);
}
drop(watchable);
for worker in workers {
worker.join().unwrap();
}
}
/// Runs 64 async watcher tasks against 10,000 updates. Each task must never
/// see the same value twice in a row and finishes when it reads the last value.
#[cfg(test)]
#[tokio::test(flavor = "multi_thread")]
async fn stress_test_async() {
let watchable = Watchable::new(1_u32);
let mut workers = Vec::new();
for _ in 1..=64 {
let watcher = watchable.watch();
workers.push(tokio::task::spawn(async move {
let mut last_value = *watcher.read();
loop {
watcher.watch_async().await.unwrap();
let current_value = *watcher.read();
assert_ne!(last_value, current_value);
if current_value == 10000 {
break;
}
last_value = current_value;
}
}));
}
// The producer loop runs on the blocking thread pool.
tokio::task::spawn_blocking(move || {
for i in 1..=10000 {
let _ = watchable.update(i);
}
})
.await
.unwrap();
for worker in workers {
worker.await.unwrap();
}
}
/// Checks that an explicit `shutdown` still delivers the pending value
/// before the watcher reports disconnection.
#[test]
fn shutdown() {
    let watchable = Watchable::new(0);
    let watcher = watchable.watch();
    watchable.replace(1);
    watchable.shutdown();

    // The value published before the shutdown is still readable.
    let pending = watcher.next_value().expect("initial value missing");
    assert_eq!(pending, 1);

    let _disconnected: Disconnected = watcher
        .next_value()
        .expect_err("watcher should be disconnected");
}