use std::cell::RefCell;
use std::collections::HashMap;
use std::hash::Hash as StdHash;
use tokio::sync::mpsc;
/// Flag passed when subscribing: `true` means the subscriber receives only
/// future updates, `false` means the current value is also delivered
/// immediately on subscription.
pub type UpdatesOnly = bool;
/// A piece of state that a `Watcher` can observe for changes.
///
/// Both methods take `&self`, so an implementation that needs to replace its
/// stored value inside `update_if_changed` has to use interior mutability.
pub trait Watched {
/// The value type handed to subscribers. It must be `Clone` because the
/// watcher clones the payload once per subscriber when broadcasting.
type Value: Clone;
/// Returns the current value.
fn current(&self) -> Self::Value;
/// Compares `cmp` against the current state and reports the outcome.
///
/// Returning `UpdateResult::Changed` makes `Watcher::update` broadcast the
/// carried `WatchedValue`; `UpdateResult::Unchanged` results in no message.
/// NOTE(review): implementations are presumably expected to also store
/// `cmp` as the new current value when it differs (the test implementation
/// in this file does) - confirm against other implementors.
fn update_if_changed(&self, cmp: &Self::Value) -> UpdateResult<Self::Value>;
}
/// Outcome of `Watched::update_if_changed`.
pub enum UpdateResult<V>
where
V: Clone,
{
/// The candidate value did not differ; `Watcher::update` sends nothing.
Unchanged,
/// The value changed; carries the message that `Watcher::update`
/// broadcasts to every subscriber.
Changed(WatchedValue<V>),
}
/// Message delivered to subscribers of a `Watcher`.
#[derive(Clone, Debug, PartialEq)]
pub struct WatchedValue<V>
where
V: Clone,
{
/// Optional description of what changed. For the initial snapshot sent by
/// `Watcher::subscribe` this is a copy of the full value; for later
/// updates its content is decided by the `Watched` implementation.
pub difference: Option<V>,
/// The watched value itself (for the initial snapshot, the value returned
/// by `Watched::current`).
pub value: V,
}
/// Sending half of a subscriber channel: an unbounded tokio mpsc sender
/// carrying `WatchedValue<V>` messages.
pub type WatcherSender<V> = mpsc::UnboundedSender<WatchedValue<V>>;
/// Receiving half of a subscriber channel, as returned by `subscribe`.
pub type WatcherReceiver<V> = mpsc::UnboundedReceiver<WatchedValue<V>>;
/// Broadcasts changes of a single `Watched` item to any number of
/// subscribers, each of which owns the receiving end of its own channel.
pub struct Watcher<T>
where
T: Watched,
{
// The observed state.
watched: T,
// Sending halves of all subscriber channels. The `RefCell` provides
// interior mutability so `&self` methods can add and prune entries.
subscribers: RefCell<Vec<WatcherSender<T::Value>>>,
}
impl<T> Watcher<T>
where
    T: Watched,
{
    /// Creates a watcher around `initial` with no subscribers.
    pub fn new(initial: T) -> Self {
        Self {
            watched: initial,
            subscribers: RefCell::new(Vec::new()),
        }
    }

    /// Offers `value` to the watched item and, if it reports a change,
    /// broadcasts the resulting [`WatchedValue`] to every subscriber.
    ///
    /// Nothing is sent when the watched item reports
    /// [`UpdateResult::Unchanged`].
    pub fn update(&self, value: T::Value) {
        if let UpdateResult::Changed(result) = self.watched.update_if_changed(&value) {
            self.notify(result);
        }
    }

    /// Registers a new subscriber and returns the receiving end of its
    /// channel.
    ///
    /// When `updates_only` is `false`, the current value is queued on the
    /// channel right away, with `difference` set to a copy of that same
    /// value. When it is `true`, the receiver stays empty until the next
    /// change is broadcast.
    pub fn subscribe(&self, updates_only: UpdatesOnly) -> WatcherReceiver<T::Value> {
        let (tx, rx) = mpsc::unbounded_channel();
        if !updates_only {
            // Take a single snapshot and reuse it for both fields, so the
            // value is produced once and the two fields always agree.
            let value = self.watched.current();
            // `rx` is still alive in this scope, so the send cannot fail;
            // the result is ignored on purpose.
            let _ = tx.send(WatchedValue {
                difference: Some(value.clone()),
                value,
            });
        }
        self.subscribers.borrow_mut().push(tx);
        rx
    }

    /// Returns `true` when no subscriber senders are stored.
    ///
    /// Senders are only pruned while a change is being broadcast, so a
    /// subscriber whose receiver has been dropped still counts here until
    /// the next change is sent.
    pub fn is_empty(&self) -> bool {
        self.subscribers.borrow().is_empty()
    }

    /// Sends a clone of `value` to every subscriber and forgets those whose
    /// send fails (i.e. whose receiving end is gone).
    fn notify(&self, value: WatchedValue<T::Value>) {
        self.subscribers
            .borrow_mut()
            .retain(|tx| tx.send(value.clone()).is_ok());
    }
}
/// A collection of `Watcher`s addressed by key.
pub struct WatcherSet<K, T>
where
T: Watched,
{
// One watcher per key. The `RefCell` provides interior mutability so
// `&self` methods can insert and remove entries.
watchers: RefCell<HashMap<K, Watcher<T>>>,
}
impl<K, T> Default for WatcherSet<K, T>
where
K: Eq + StdHash,
T: Watched,
{
fn default() -> Self {
Self::new()
}
}
impl<K, T> WatcherSet<K, T>
where
    K: Eq + StdHash,
    T: Watched,
{
    /// Creates an empty set with no watchers.
    pub fn new() -> Self {
        Self {
            watchers: RefCell::new(HashMap::new()),
        }
    }

    /// Subscribes to the watcher stored under `key`, creating that watcher
    /// from `initial` when the key is not present yet.
    ///
    /// `initial` is only used for a new key; if a watcher already exists it
    /// is dropped unused and the existing watcher's state is kept.
    /// `updates_only` is forwarded to [`Watcher::subscribe`].
    pub fn subscribe(
        &self,
        key: K,
        updates_only: UpdatesOnly,
        initial: T,
    ) -> WatcherReceiver<T::Value> {
        // The entry API resolves "existing or freshly inserted" in a single
        // lookup and without a panic path.
        self.watchers
            .borrow_mut()
            .entry(key)
            .or_insert_with(|| Watcher::new(initial))
            .subscribe(updates_only)
    }

    /// Forwards `value` to the watcher stored under `key`, if there is one,
    /// and removes that watcher afterwards when it has no subscribers left.
    ///
    /// Unknown keys are ignored.
    pub fn update(&self, key: &K, value: T::Value) {
        let mut watchers = self.watchers.borrow_mut();
        let Some(watcher) = watchers.get(key) else {
            return;
        };
        watcher.update(value);
        // Broadcasting prunes disconnected subscribers, so the watcher may
        // have become empty as a result of the update above.
        if watcher.is_empty() {
            watchers.remove(key);
        }
    }
}
#[cfg(test)]
mod tests {
use std::cell::RefCell;
use std::collections::HashSet;
use tokio::sync::mpsc::error::TryRecvError;
use super::{UpdateResult, Watched, WatchedValue, Watcher};
// Exercises a `Watcher` end to end with two subscribers: one that asked for
// updates only and one that also wants the initial snapshot.
#[test]
fn subscribe_to_changes() {
// Test double: a set of integers behind a `RefCell`, so it can be replaced
// through the `&self` receiver that `Watched` requires.
struct WatchedSet(RefCell<HashSet<u64>>);
impl WatchedSet {
pub fn new(set: HashSet<u64>) -> Self {
Self(RefCell::new(set))
}
}
impl Watched for WatchedSet {
type Value = HashSet<u64>;
fn current(&self) -> Self::Value {
self.0.borrow().clone()
}
fn update_if_changed(&self, cmp: &Self::Value) -> UpdateResult<Self::Value> {
// The reported difference is the symmetric difference: elements that are
// in exactly one of the old and the new set.
let difference: HashSet<u64> =
self.0.borrow().symmetric_difference(cmp).cloned().collect();
if difference.is_empty() {
UpdateResult::Unchanged
} else {
// Store the new set before reporting the change.
self.0.replace(cmp.to_owned());
UpdateResult::Changed(WatchedValue {
difference: Some(difference),
value: cmp.to_owned(),
})
}
}
}
let set = WatchedSet::new(HashSet::from_iter([1, 2, 3]));
let watcher = Watcher::new(set);
let mut updates_only_rx = watcher.subscribe(true);
let mut rx = watcher.subscribe(false);
// The updates-only subscriber must not receive the initial snapshot.
assert!(matches!(
updates_only_rx.try_recv(),
Err(TryRecvError::Empty)
));
// The other subscriber gets the current value, with `difference` equal to
// the full value.
let result = rx.try_recv().expect("should return Ok");
assert_eq!(result.value, HashSet::from_iter([1, 2, 3]),);
assert_eq!(result.difference, Some(result.value));
// Updating with an identical set must not notify anyone.
watcher.update(HashSet::from_iter([1, 2, 3]));
assert!(matches!(
updates_only_rx.try_recv(),
Err(TryRecvError::Empty)
));
assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
// Adding an element notifies both subscribers with the same message.
watcher.update(HashSet::from_iter([1, 2, 3, 4]));
let result_1 = rx.try_recv().expect("should return Ok");
let result_2 = updates_only_rx.try_recv().expect("should return Ok");
assert_eq!(result_1, result_2);
assert_eq!(result_1.value, HashSet::from_iter([1, 2, 3, 4]),);
assert_eq!(result_1.difference, Some(HashSet::from_iter([4])));
// Removing the element again also reports `{4}`, because the difference
// is symmetric.
watcher.update(HashSet::from_iter([1, 2, 3]));
let result_1 = rx.try_recv().expect("should return Ok");
let result_2 = updates_only_rx.try_recv().expect("should return Ok");
assert_eq!(result_1, result_2);
assert_eq!(result_1.value, HashSet::from_iter([1, 2, 3]),);
assert_eq!(result_1.difference, Some(HashSet::from_iter([4])));
}
}