use tokio::sync::watch;
use super::versioned::{ConfigVersion, VersionedConfig};
/// Sending half of a configuration watch channel created by [`channel`].
///
/// Wraps a tokio `watch::Sender`, so receivers only ever observe the most recently
/// published [`VersionedConfig`]; intermediate values may be skipped by slow consumers.
pub struct ConfigWatchSender<T> {
inner: watch::Sender<VersionedConfig<T>>,
}
impl<T> ConfigWatchSender<T> {
    /// Publishes `value` as the new current configuration and notifies receivers
    /// waiting in [`ConfigWatch::changed`].
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when no [`ConfigWatch`] receivers remain. In that case the
    /// value is discarded rather than stored: tokio's `watch::Sender::send` hands it
    /// back inside its error, and this wrapper drops that error.
    pub fn publish(&self, value: VersionedConfig<T>) -> Result<(), ()> {
        match self.inner.send(value) {
            Ok(()) => Ok(()),
            Err(_rejected) => Err(()),
        }
    }

    /// Number of live [`ConfigWatch`] receivers (clones included) attached to this sender.
    pub fn receiver_count(&self) -> usize {
        self.inner.receiver_count()
    }
}
/// Receiving half of a configuration watch channel created by [`channel`].
///
/// Wraps a tokio `watch::Receiver`; every clone is attached to the same channel and
/// observes the same latest published [`VersionedConfig`].
pub struct ConfigWatch<T> {
inner: watch::Receiver<VersionedConfig<T>>,
}
impl<T> Clone for ConfigWatch<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<T> ConfigWatch<T> {
    /// Returns a clone of the configuration currently held by the channel.
    ///
    /// This does not mark the value as seen, so it does not affect what a later
    /// [`ConfigWatch::changed`] call reports.
    pub fn current(&self) -> VersionedConfig<T>
    where
        T: Clone,
    {
        self.inner.borrow().clone()
    }

    /// Returns the version of the configuration currently held by the channel.
    ///
    /// Like [`ConfigWatch::current`], this does not mark the value as seen.
    pub fn current_version(&self) -> ConfigVersion {
        self.inner.borrow().version()
    }

    /// Waits until a value newer than the last one this receiver has seen is
    /// available, then returns a clone of the newest value and marks it as seen.
    ///
    /// If several values were published in the meantime, only the latest is returned.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` when the sending half has been dropped.
    pub async fn changed(&mut self) -> Result<VersionedConfig<T>, ()>
    where
        T: Clone,
    {
        self.inner.changed().await.map_err(|_| ())?;
        // Use `borrow_and_update` rather than `borrow`: a publish may land between
        // `changed()` resolving and this read. We then return that newer value, and
        // it must also be marked seen. Otherwise the next `changed()` call would
        // resolve immediately and deliver the same value a second time.
        Ok(self.inner.borrow_and_update().clone())
    }

    /// Marks the currently held value as seen, so the next [`ConfigWatch::changed`]
    /// call waits for a subsequent publish.
    pub fn mark_seen(&mut self) {
        self.inner.mark_unchanged();
    }
}
/// Creates a configuration watch channel holding `initial` as its current value.
///
/// Returns the sending half and a single receiver. Obtain further receivers by
/// cloning the returned [`ConfigWatch`].
pub fn channel<T>(initial: VersionedConfig<T>) -> (ConfigWatchSender<T>, ConfigWatch<T>) {
    let (sender, receiver) = watch::channel(initial);
    let tx = ConfigWatchSender { inner: sender };
    let rx = ConfigWatch { inner: receiver };
    (tx, rx)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Builds a `VersionedConfig<u32>` holding `x` at version `v`.
    fn vcfg(v: u64, x: u32) -> VersionedConfig<u32> {
        VersionedConfig::new(x, ConfigVersion(v))
    }

    #[test]
    fn current_returns_initial() {
        let (_tx, rx) = channel(vcfg(0, 42));
        assert_eq!(rx.current_version(), ConfigVersion(0));
        assert_eq!(*rx.current().value(), 42);
    }

    #[tokio::test]
    async fn changed_returns_after_publish() {
        let (tx, mut rx) = channel(vcfg(0, 42));
        rx.mark_seen();
        let publisher = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            tx.publish(vcfg(1, 99)).unwrap();
        });
        let next = rx.changed().await.unwrap();
        assert_eq!(next.version(), ConfigVersion(1));
        assert_eq!(*next.value(), 99);
        publisher.await.unwrap();
    }

    #[tokio::test]
    async fn changed_returns_err_when_sender_dropped() {
        let (tx, mut rx) = channel(vcfg(0, 42));
        rx.mark_seen();
        drop(tx);
        let err = rx.changed().await;
        assert!(err.is_err());
    }

    #[tokio::test]
    async fn changed_marks_returned_value_seen() {
        let (tx, mut rx) = channel(vcfg(0, 42));
        rx.mark_seen();
        tx.publish(vcfg(1, 1)).unwrap();
        let first = rx.changed().await.unwrap();
        assert_eq!(first.version(), ConfigVersion(1));
        // Nothing new was published (and `tx` is still alive), so a second
        // `changed()` must keep waiting instead of re-delivering version 1.
        let again = tokio::time::timeout(Duration::from_millis(20), rx.changed()).await;
        assert!(again.is_err(), "changed() resolved without a new publish");
        drop(tx);
    }

    #[test]
    fn receiver_count_reflects_clones() {
        let (tx, rx) = channel(vcfg(0, 42));
        assert_eq!(tx.receiver_count(), 1);
        let _rx2 = rx.clone();
        assert_eq!(tx.receiver_count(), 2);
        drop(rx);
        assert_eq!(tx.receiver_count(), 1);
    }

    #[test]
    fn publish_with_no_receivers_returns_err() {
        let (tx, rx) = channel(vcfg(0, 42));
        drop(rx);
        let result = tx.publish(vcfg(1, 99));
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn multiple_receivers_all_see_same_version() {
        let (tx, rx) = channel(vcfg(0, 42));
        let mut rx1 = rx.clone();
        let mut rx2 = rx.clone();
        rx1.mark_seen();
        rx2.mark_seen();
        tx.publish(vcfg(1, 99)).unwrap();
        let v1 = rx1.changed().await.unwrap();
        let v2 = rx2.changed().await.unwrap();
        // Pin the delivered value itself, not just agreement between receivers.
        assert_eq!(v1.version(), ConfigVersion(1));
        assert_eq!(v2.version(), ConfigVersion(1));
        assert_eq!(*v1.value(), 99);
        assert_eq!(*v2.value(), 99);
    }

    #[tokio::test]
    async fn slow_consumer_sees_only_latest_value() {
        let (tx, mut rx) = channel(vcfg(0, 42));
        rx.mark_seen();
        tx.publish(vcfg(1, 1)).unwrap();
        tx.publish(vcfg(2, 2)).unwrap();
        tx.publish(vcfg(3, 3)).unwrap();
        let v = rx.changed().await.unwrap();
        assert_eq!(v.version(), ConfigVersion(3));
        assert_eq!(*v.value(), 3);
    }

    #[test]
    fn current_clone_shares_arc() {
        let (_tx, rx) = channel(vcfg(0, 42));
        let a = rx.current();
        let b = rx.current();
        assert!(std::sync::Arc::ptr_eq(&a.arc(), &b.arc()));
    }
}