use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use bytes::Bytes;
pub use d_engine_proto::client::{WatchEventType, WatchResponse as WatchEvent};
use dashmap::DashMap;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::trace;
use tracing::warn;
/// Caller-side handle for one watch subscription created by
/// [`WatchRegistry::register`].
///
/// Events for the watched key arrive on the contained receiver. Dropping the
/// handle queues an unregister request on the registry's unregister channel,
/// unless the handle was consumed with [`WatcherHandle::into_receiver`].
#[derive(Debug)]
pub struct WatcherHandle {
    /// Watcher id allocated by the registry.
    id: u64,
    /// Key this watcher is subscribed to.
    key: Bytes,
    /// Receiving end of this watcher's bounded event channel.
    receiver: mpsc::Receiver<WatchEvent>,
    /// Sender used by `Drop` to request unregistration; `None` once disarmed.
    unregister_tx: Option<mpsc::UnboundedSender<(u64, Bytes)>>,
}
impl WatcherHandle {
    /// Id assigned to this watcher when it was registered.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Key this watcher is subscribed to.
    pub fn key(&self) -> &Bytes {
        &self.key
    }

    /// Mutable access to the event receiver (e.g. to `recv().await` on it).
    pub fn receiver_mut(&mut self) -> &mut mpsc::Receiver<WatchEvent> {
        &mut self.receiver
    }

    /// Consumes the handle and returns `(id, key, receiver)`.
    ///
    /// Unregister-on-drop is disarmed, so the watcher stays in the registry after
    /// this call. It is only removed once the dispatcher fails to `try_send` to it,
    /// for example after the returned receiver has been dropped and another event
    /// for the key is dispatched.
    /// NOTE(review): a watcher on a key that never receives another event is
    /// therefore never cleaned up — confirm this is acceptable for callers.
    pub fn into_receiver(mut self) -> (u64, Bytes, mpsc::Receiver<WatchEvent>) {
        // Disarm `Drop`: with no sender left, no unregister message is sent.
        drop(self.unregister_tx.take());

        // `Drop` is implemented, so fields cannot be moved out of `self`. Swap in
        // an already-closed placeholder receiver and return the real one.
        let placeholder = {
            let (closed_tx, closed_rx) = mpsc::channel::<WatchEvent>(1);
            drop(closed_tx);
            closed_rx
        };
        let events = std::mem::replace(&mut self.receiver, placeholder);

        (self.id, self.key.clone(), events)
    }
}
impl Drop for WatcherHandle {
fn drop(&mut self) {
if let Some(ref tx) = self.unregister_tx {
let _ = tx.send((self.id, self.key.clone()));
trace!(watcher_id = self.id, key = ?self.key, "Watcher unregistered");
}
}
}
/// Registry-side half of a watch subscription: the sender paired with a
/// `WatcherHandle`'s receiver, tagged with the watcher's id.
#[derive(Debug)]
struct Watcher {
    /// Same id as the corresponding `WatcherHandle`; used to locate this entry
    /// when unregistering.
    id: u64,
    /// Bounded sender the dispatcher `try_send`s events into. It is the only
    /// sender for the watcher's channel, so removing this entry ends the stream.
    sender: mpsc::Sender<WatchEvent>,
}
/// Concurrent map from watched key to the watchers subscribed to it.
///
/// Registration happens directly on this type. Removal is driven by the
/// dispatcher, either from unregister requests or when a watcher's channel is
/// found dead during dispatch.
#[derive(Debug)]
pub struct WatchRegistry {
    /// Watchers per key; a key's entry is removed once its list becomes empty.
    watchers: DashMap<Bytes, Vec<Watcher>>,
    /// Source of watcher ids (starts at 1).
    next_id: AtomicU64,
    /// Capacity of each watcher's bounded event channel.
    watcher_buffer_size: usize,
    /// Cloned into every `WatcherHandle` so that dropping a handle can request
    /// unregistration.
    unregister_tx: mpsc::UnboundedSender<(u64, Bytes)>,
}
impl WatchRegistry {
    /// Creates an empty registry.
    ///
    /// * `watcher_buffer_size` – capacity of each watcher's event channel. A value
    ///   of 0 is treated as 1 when registering, because a zero-capacity tokio
    ///   channel cannot be created.
    /// * `unregister_tx` – cloned into every `WatcherHandle`. A handle sends its
    ///   `(id, key)` on it when dropped.
    pub fn new(
        watcher_buffer_size: usize,
        unregister_tx: mpsc::UnboundedSender<(u64, Bytes)>,
    ) -> Self {
        Self {
            watchers: DashMap::new(),
            next_id: AtomicU64::new(1),
            watcher_buffer_size,
            unregister_tx,
        }
    }

    /// Registers a new watcher for `key` and returns its handle.
    ///
    /// Each watcher gets its own bounded channel with `watcher_buffer_size` slots
    /// (at least 1). The registry keeps the sender; the handle owns the receiver.
    pub fn register(
        &self,
        key: Bytes,
    ) -> WatcherHandle {
        // Relaxed is enough: only uniqueness of the id matters, not ordering.
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        // `mpsc::channel` panics on capacity 0. Clamp so that a misconfigured
        // buffer size degrades gracefully instead of panicking on every registration.
        let capacity = self.watcher_buffer_size.max(1);
        let (sender, receiver) = mpsc::channel(capacity);
        let watcher = Watcher { id, sender };
        self.watchers.entry(key.clone()).or_default().push(watcher);
        trace!(watcher_id = id, key = ?key, "Watcher registered");
        WatcherHandle {
            id,
            key,
            receiver,
            unregister_tx: Some(self.unregister_tx.clone()),
        }
    }

    /// Removes watcher `id` from `key`'s list and drops the key entry once no
    /// watchers remain. An unknown key or id is a no-op.
    fn unregister(
        &self,
        id: u64,
        key: &Bytes,
    ) {
        // `remove_if_mut` performs the retain and the conditional removal under
        // the same shard lock, so a concurrent `register` on this key is not lost.
        self.watchers.remove_if_mut(key, |_key, watchers| {
            watchers.retain(|w| w.id != id);
            watchers.is_empty()
        });
    }

    /// Number of watchers currently registered for `key` (test helper).
    #[cfg(test)]
    pub(crate) fn watcher_count(
        &self,
        key: &Bytes,
    ) -> usize {
        self.watchers.get(key).map(|w| w.len()).unwrap_or(0)
    }

    /// Number of distinct keys with at least one entry (test helper).
    #[cfg(test)]
    pub(crate) fn watched_key_count(&self) -> usize {
        self.watchers.len()
    }
}
/// Background task state that fans events from a broadcast channel out to the
/// watchers in a shared `WatchRegistry` and applies unregister requests.
pub struct WatchDispatcher {
    /// Registry whose watchers receive events.
    registry: Arc<WatchRegistry>,
    /// Source of events to dispatch.
    broadcast_rx: broadcast::Receiver<WatchEvent>,
    /// `(id, key)` unregister requests sent by dropped `WatcherHandle`s.
    unregister_rx: mpsc::UnboundedReceiver<(u64, Bytes)>,
}
impl WatchDispatcher {
pub fn new(
registry: Arc<WatchRegistry>,
broadcast_rx: broadcast::Receiver<WatchEvent>,
unregister_rx: mpsc::UnboundedReceiver<(u64, Bytes)>,
) -> Self {
Self {
registry,
broadcast_rx,
unregister_rx,
}
}
pub async fn run(mut self) {
debug!("WatchDispatcher started");
loop {
tokio::select! {
biased;
Some((id, key)) = self.unregister_rx.recv() => {
self.registry.unregister(id, &key);
}
result = self.broadcast_rx.recv() => {
match result {
Ok(event) => {
self.dispatch_event(event).await;
}
Err(broadcast::error::RecvError::Lagged(n)) => {
warn!("WatchDispatcher lagged {} events (slow watchers)", n);
}
Err(broadcast::error::RecvError::Closed) => {
debug!("Broadcast channel closed, WatchDispatcher stopping");
break;
}
}
}
}
}
debug!("WatchDispatcher stopped");
}
async fn dispatch_event(
&self,
event: WatchEvent,
) {
if let Some(watchers) = self.registry.watchers.get(&event.key) {
let mut dead_watchers = Vec::new();
for watcher in watchers.iter() {
if watcher.sender.try_send(event.clone()).is_err() {
dead_watchers.push(watcher.id);
}
}
drop(watchers);
if !dead_watchers.is_empty() {
for id in dead_watchers {
self.registry.unregister(id, &event.key);
}
}
trace!(key = ?event.key, event_type = ?event.event_type, "Event dispatched");
}
}
}