use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::time::Duration;
use bytes::Bytes;
use d_engine_proto::client::WatchResponse;
use dashmap::DashMap;
use tokio::sync::broadcast;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::trace;
use tracing::warn;
/// Kind of notification a watcher can receive.
///
/// Every variant has a same-named counterpart in the protobuf
/// `WatchEventType`; `proto_to_event` and the `From<&WatchEvent>` impl for
/// `WatchResponse` translate between the two.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WatchEventType {
/// Counterpart of the protobuf `Put` type (presumably a key write; the
/// exact semantics are defined by whoever produces the event).
Put,
/// Counterpart of the protobuf `Delete` type (presumably a key removal).
Delete,
/// Counterpart of the protobuf `Canceled` type. `proto_to_event` also falls
/// back to this variant when the wire value cannot be decoded.
Canceled,
/// Counterpart of the protobuf `Progress` type. The dispatcher's heartbeat
/// emits these with empty key/value and the last applied revision.
Progress,
}
/// In-process form of a watch notification; this is the item type carried by
/// each watcher's `mpsc` channel.
#[derive(Debug, Clone)]
pub struct WatchEvent {
/// What kind of notification this is.
pub event_type: WatchEventType,
/// Key of the event (copied from the `WatchResponse` when built by
/// `proto_to_event`).
pub key: Bytes,
/// Value of the event (copied from the `WatchResponse` when built by
/// `proto_to_event`).
pub value: Bytes,
/// Previous value. `proto_to_event` fills this with `Some` only for
/// watchers that requested `prev_kv`; otherwise it is `None`.
pub prev_value: Option<Bytes>,
/// Revision associated with the event (copied from the `WatchResponse`
/// when built by `proto_to_event`).
pub revision: u64,
}
/// Converts a protobuf [`WatchResponse`] into the in-process [`WatchEvent`].
///
/// * `proto` - the wire-level event; `key`, `value` and `revision` are copied
///   over unchanged.
/// * `prev_kv` - when `true`, the proto's `prev_value` is exposed as
///   `Some(..)`; when `false`, `prev_value` is `None`.
///
/// An `event_type` discriminant that the protobuf enum cannot decode is
/// reported as [`WatchEventType::Canceled`].
fn proto_to_event(
    proto: &WatchResponse,
    prev_kv: bool,
) -> WatchEvent {
    use d_engine_proto::client::WatchEventType as ProtoType;

    // Decode the raw i32 first, then translate variant-for-variant; any
    // decoding failure collapses to `Canceled`.
    let kind = ProtoType::try_from(proto.event_type)
        .map(|decoded| match decoded {
            ProtoType::Put => WatchEventType::Put,
            ProtoType::Delete => WatchEventType::Delete,
            ProtoType::Canceled => WatchEventType::Canceled,
            ProtoType::Progress => WatchEventType::Progress,
        })
        .unwrap_or(WatchEventType::Canceled);

    // Only clone the previous value for watchers that asked for it.
    let previous = prev_kv.then(|| proto.prev_value.clone());

    WatchEvent {
        event_type: kind,
        key: proto.key.clone(),
        value: proto.value.clone(),
        prev_value: previous,
        revision: proto.revision,
    }
}
/// Converts an in-process [`WatchEvent`] back into its protobuf form.
///
/// A missing `prev_value` becomes the default (empty) `Bytes`, and `error` is
/// always set to `0`.
impl From<&WatchEvent> for WatchResponse {
    fn from(e: &WatchEvent) -> Self {
        use d_engine_proto::client::WatchEventType as ProtoType;

        // Translate variant-for-variant; the cast to i32 happens once below.
        let proto_kind = match e.event_type {
            WatchEventType::Put => ProtoType::Put,
            WatchEventType::Delete => ProtoType::Delete,
            WatchEventType::Canceled => ProtoType::Canceled,
            WatchEventType::Progress => ProtoType::Progress,
        };
        let previous = e.prev_value.as_ref().cloned().unwrap_or_default();

        WatchResponse {
            key: e.key.clone(),
            value: e.value.clone(),
            prev_value: previous,
            event_type: proto_kind as i32,
            error: 0,
            revision: e.revision,
        }
    }
}
/// Errors returned when registering a watcher.
#[derive(Debug)]
pub enum WatchError {
    /// The registry already holds its configured maximum number of watchers;
    /// the payload is that maximum.
    LimitExceeded(usize),
    /// The prefix passed to `register_prefix` did not both start and end
    /// with `/`.
    InvalidPrefix,
}

impl std::fmt::Display for WatchError {
    /// Writes the human-readable message for this error.
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        match *self {
            Self::InvalidPrefix => f.write_str("prefix must start with '/' and end with '/'"),
            Self::LimitExceeded(limit) => write!(f, "watcher limit ({limit}) exceeded"),
        }
    }
}

impl std::error::Error for WatchError {}
/// Returns every leading slice of `key` that ends in `/`, shortest first.
///
/// For `/a/b/c` this yields `/`, `/a/` and `/a/b/`. A key that itself ends in
/// `/` is included as the last segment; a key without any `/` yields an empty
/// vector.
pub(crate) fn prefix_segments(key: &Bytes) -> Vec<Bytes> {
    let mut segments = Vec::new();
    for (idx, byte) in key.iter().enumerate() {
        if *byte == b'/' {
            // Inclusive end so the separator itself is part of the segment.
            segments.push(key.slice(..=idx));
        }
    }
    segments
}
/// Receiving side of a registered watcher.
///
/// Dropping the handle sends `(id, key)` on `unregister_tx` (while it is still
/// set) so that the registration can be removed; see the `Drop` impl.
pub struct WatcherHandle {
/// Identifier assigned by the registry at registration time.
id: u64,
/// Key or prefix this watcher was registered under.
key: Bytes,
/// `true` when the watcher was registered as a prefix watcher.
is_prefix: bool,
/// Channel on which dispatched `WatchEvent`s arrive.
receiver: mpsc::Receiver<WatchEvent>,
/// Channel for unregister requests. `into_receiver` sets this to `None` so
/// that dropping the handle no longer requests unregistration.
unregister_tx: Option<mpsc::UnboundedSender<(u64, Bytes)>>,
}
impl WatcherHandle {
    /// Identifier assigned to this watcher at registration time.
    pub fn id(&self) -> u64 {
        self.id
    }

    /// Key or prefix this watcher was registered under.
    pub fn key(&self) -> &Bytes {
        &self.key
    }

    /// Whether this watcher was registered as a prefix watcher.
    pub fn is_prefix(&self) -> bool {
        self.is_prefix
    }

    /// Mutable access to the event receiver, for polling events in place.
    pub fn receiver_mut(&mut self) -> &mut mpsc::Receiver<WatchEvent> {
        &mut self.receiver
    }

    /// Consumes the handle and returns `(id, key, receiver)`.
    ///
    /// The unregister channel is cleared first, so dropping the consumed
    /// handle does not send an unregister request; the registration therefore
    /// outlives the handle.
    pub fn into_receiver(mut self) -> (u64, Bytes, mpsc::Receiver<WatchEvent>) {
        // Disarm the `Drop` impl before `self` goes out of scope.
        drop(self.unregister_tx.take());

        // A field cannot be moved out of a type that implements `Drop`, so
        // swap in a placeholder receiver whose sender is already gone.
        let (placeholder_tx, placeholder_rx) = mpsc::channel(1);
        drop(placeholder_tx);
        let live_receiver = std::mem::replace(&mut self.receiver, placeholder_rx);

        (self.id, self.key.clone(), live_receiver)
    }
}
impl Drop for WatcherHandle {
fn drop(&mut self) {
if let Some(ref tx) = self.unregister_tx {
let _ = tx.send((self.id, self.key.clone()));
trace!(
watcher_id = self.id,
key = ?self.key,
is_prefix = self.is_prefix,
"Watcher unregistered"
);
}
}
}
/// Registry-side record of one watcher.
#[derive(Debug)]
struct Watcher {
/// Identifier shared with the matching `WatcherHandle`.
id: u64,
/// Sending half of the watcher's event channel.
sender: mpsc::Sender<WatchEvent>,
/// Whether events built for this watcher carry the previous value.
prev_kv: bool,
}
/// Holds all registered watchers, split into exact-key and prefix watchers.
pub struct WatchRegistry {
/// Watchers registered on an exact key, grouped by that key.
exact: DashMap<Bytes, Vec<Watcher>>,
/// Watchers registered on a prefix, grouped by that prefix.
prefix: DashMap<Bytes, Vec<Watcher>>,
/// Source of watcher identifiers; starts at 1.
next_id: AtomicU64,
/// Number of currently registered watchers (exact and prefix combined).
total_count: AtomicUsize,
/// Number of currently registered watchers that requested `prev_kv`;
/// shared with outside code through `prev_kv_watcher_count_arc`.
prev_kv_watcher_count: Arc<AtomicUsize>,
/// Configured per-watcher buffer size; each channel is created with one
/// slot more than this.
watcher_buffer_size: usize,
/// Upper bound on `total_count`; registration fails once it is reached.
max_watcher_count: usize,
/// Unregister-request sender cloned into every `WatcherHandle`.
unregister_tx: mpsc::UnboundedSender<(u64, Bytes)>,
}
impl WatchRegistry {
pub fn new(
watcher_buffer_size: usize,
unregister_tx: mpsc::UnboundedSender<(u64, Bytes)>,
) -> Self {
Self::new_with_limits(watcher_buffer_size, usize::MAX, unregister_tx)
}
pub fn new_with_limits(
watcher_buffer_size: usize,
max_watcher_count: usize,
unregister_tx: mpsc::UnboundedSender<(u64, Bytes)>,
) -> Self {
Self {
exact: DashMap::new(),
prefix: DashMap::new(),
next_id: AtomicU64::new(1),
total_count: AtomicUsize::new(0),
prev_kv_watcher_count: Arc::new(AtomicUsize::new(0)),
watcher_buffer_size,
max_watcher_count,
unregister_tx,
}
}
pub fn register(
&self,
key: Bytes,
prev_kv: bool,
) -> Result<WatcherHandle, WatchError> {
self.do_register(key, false, prev_kv)
}
pub fn register_prefix(
&self,
prefix: Bytes,
prev_kv: bool,
) -> Result<WatcherHandle, WatchError> {
if !prefix.starts_with(b"/") || !prefix.ends_with(b"/") {
return Err(WatchError::InvalidPrefix);
}
self.do_register(prefix, true, prev_kv)
}
fn do_register(
&self,
key: Bytes,
is_prefix: bool,
prev_kv: bool,
) -> Result<WatcherHandle, WatchError> {
let prev = self.total_count.fetch_add(1, Ordering::Relaxed);
if prev >= self.max_watcher_count {
self.total_count.fetch_sub(1, Ordering::Relaxed);
return Err(WatchError::LimitExceeded(self.max_watcher_count));
}
if prev_kv {
self.prev_kv_watcher_count.fetch_add(1, Ordering::Relaxed);
}
let id = self.next_id.fetch_add(1, Ordering::Relaxed);
let (sender, receiver) = mpsc::channel(self.watcher_buffer_size + 1);
let watcher = Watcher {
id,
sender,
prev_kv,
};
if is_prefix {
self.prefix.entry(key.clone()).or_default().push(watcher);
} else {
self.exact.entry(key.clone()).or_default().push(watcher);
}
trace!(watcher_id = id, key = ?key, is_prefix, prev_kv, "Watcher registered");
Ok(WatcherHandle {
id,
key,
is_prefix,
receiver,
unregister_tx: Some(self.unregister_tx.clone()),
})
}
fn unregister(
&self,
id: u64,
key: &Bytes,
) {
let mut found = false;
let mut had_prev_kv = false;
self.exact.remove_if_mut(key, |_, watchers| {
if let Some(pos) = watchers.iter().position(|w| w.id == id) {
had_prev_kv = watchers[pos].prev_kv;
watchers.remove(pos);
found = true;
}
watchers.is_empty()
});
if !found {
self.prefix.remove_if_mut(key, |_, watchers| {
if let Some(pos) = watchers.iter().position(|w| w.id == id) {
had_prev_kv = watchers[pos].prev_kv;
watchers.remove(pos);
found = true;
}
watchers.is_empty()
});
}
if found {
self.total_count.fetch_sub(1, Ordering::Relaxed);
if had_prev_kv {
self.prev_kv_watcher_count.fetch_sub(1, Ordering::Relaxed);
}
}
}
pub fn prev_kv_watcher_count(&self) -> usize {
self.prev_kv_watcher_count.load(Ordering::Relaxed)
}
pub fn prev_kv_watcher_count_arc(&self) -> Arc<AtomicUsize> {
Arc::clone(&self.prev_kv_watcher_count)
}
#[cfg(test)]
pub(crate) fn watcher_count(
&self,
key: &Bytes,
) -> usize {
self.exact.get(key).map(|w| w.len()).unwrap_or(0)
}
#[cfg(test)]
pub(crate) fn watched_key_count(&self) -> usize {
self.exact.len()
}
#[cfg(test)]
pub(crate) fn prefix_watcher_count(
&self,
prefix: &Bytes,
) -> usize {
self.prefix.get(prefix).map(|w| w.len()).unwrap_or(0)
}
}
/// Task state for fanning broadcast events out to registered watchers.
pub struct WatchDispatcher {
/// Registry whose watchers receive the events.
registry: Arc<WatchRegistry>,
/// Source of `WatchResponse` events to dispatch.
broadcast_rx: broadcast::Receiver<WatchResponse>,
/// Incoming `(watcher id, key)` unregister requests.
unregister_rx: mpsc::UnboundedReceiver<(u64, Bytes)>,
/// Read when building `Progress` notifications to stamp their revision.
last_applied: Arc<std::sync::atomic::AtomicU64>,
/// Heartbeat period in milliseconds; `0` disables heartbeats.
heartbeat_interval_ms: u64,
}
impl WatchDispatcher {
    /// Assembles a dispatcher from its parts; nothing runs until [`run`](Self::run)
    /// is awaited.
    ///
    /// A `heartbeat_interval_ms` of `0` disables progress heartbeats.
    pub fn new(
        registry: Arc<WatchRegistry>,
        broadcast_rx: broadcast::Receiver<WatchResponse>,
        unregister_rx: mpsc::UnboundedReceiver<(u64, Bytes)>,
        last_applied: Arc<AtomicU64>,
        heartbeat_interval_ms: u64,
    ) -> Self {
        Self {
            registry,
            broadcast_rx,
            unregister_rx,
            last_applied,
            heartbeat_interval_ms,
        }
    }

    /// Runs the dispatch loop until the broadcast channel is closed.
    ///
    /// Each iteration handles, in this priority order (`biased`): a pending
    /// unregister request, the next broadcast event, or a heartbeat tick that
    /// triggers a progress notification.
    pub async fn run(mut self) {
        debug!("WatchDispatcher started");
        let mut heartbeat = Self::jittered_heartbeat(self.heartbeat_interval_ms);

        loop {
            tokio::select! {
                biased;
                Some((id, key)) = self.unregister_rx.recv() => {
                    self.registry.unregister(id, &key);
                }
                received = self.broadcast_rx.recv() => {
                    match received {
                        Ok(event) => self.dispatch_event(event).await,
                        Err(broadcast::error::RecvError::Lagged(n)) => {
                            warn!("WatchDispatcher lagged {} events (slow watchers)", n);
                        }
                        Err(broadcast::error::RecvError::Closed) => {
                            debug!("Broadcast channel closed, WatchDispatcher stopping");
                            break;
                        }
                    }
                }
                _ = Self::next_heartbeat(&mut heartbeat) => {
                    self.broadcast_progress().await;
                }
            }
        }

        debug!("WatchDispatcher stopped");
    }

    /// Builds the heartbeat timer, or `None` when `interval_ms` is `0`.
    ///
    /// The first tick is shifted by a pseudo-random offset of roughly ±10% of
    /// the interval (at least 1 ms), seeded from the current thread id and the
    /// system time; later ticks use the plain interval and skip missed ticks.
    fn jittered_heartbeat(interval_ms: u64) -> Option<tokio::time::Interval> {
        if interval_ms == 0 {
            return None;
        }

        let jitter = (interval_ms / 10).max(1);
        let mut hasher = DefaultHasher::new();
        std::thread::current().id().hash(&mut hasher);
        std::time::SystemTime::now().hash(&mut hasher);
        let offset = hasher.finish() % (jitter * 2);
        let first_tick_ms = interval_ms.saturating_sub(jitter) + offset;

        let start = tokio::time::Instant::now() + Duration::from_millis(first_tick_ms);
        let mut ticker = tokio::time::interval_at(start, Duration::from_millis(interval_ms));
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        Some(ticker)
    }

    /// Completes on the next heartbeat tick, or never when heartbeats are
    /// disabled.
    async fn next_heartbeat(heartbeat: &mut Option<tokio::time::Interval>) {
        match heartbeat {
            Some(ticker) => {
                ticker.tick().await;
            }
            None => std::future::pending::<()>().await,
        }
    }

    /// Copies the current keys of `map` so no map guard is held while
    /// dispatching.
    fn snapshot_keys(map: &DashMap<Bytes, Vec<Watcher>>) -> Vec<Bytes> {
        map.iter().map(|entry| entry.key().clone()).collect()
    }

    /// Sends a `Progress` notification, stamped with the last applied
    /// revision, to every registered watcher (exact watchers first, then
    /// prefix watchers).
    async fn broadcast_progress(&self) {
        let revision = self.last_applied.load(Ordering::Relaxed);
        let progress = WatchResponse {
            key: Bytes::new(),
            value: Bytes::new(),
            prev_value: Bytes::new(),
            event_type: d_engine_proto::client::WatchEventType::Progress as i32,
            error: 0,
            revision,
        };

        // Snapshot both key sets up front, before any dispatching happens.
        let exact_keys = Self::snapshot_keys(&self.registry.exact);
        let prefix_keys = Self::snapshot_keys(&self.registry.prefix);

        for watched in &exact_keys {
            self.dispatch_to_map(&self.registry.exact, watched, &progress).await;
        }
        for watched in &prefix_keys {
            self.dispatch_to_map(&self.registry.prefix, watched, &progress).await;
        }
    }

    /// Delivers `event` to the exact watchers of its key and to the prefix
    /// watchers of every `/`-terminated prefix of that key.
    async fn dispatch_event(
        &self,
        event: WatchResponse,
    ) {
        self.dispatch_to_map(&self.registry.exact, &event.key, &event).await;
        for segment in prefix_segments(&event.key) {
            self.dispatch_to_map(&self.registry.prefix, &segment, &event).await;
        }
    }

    /// Delivers `event` to every watcher stored under `lookup_key` in `map`.
    ///
    /// A watcher with at most one free buffer slot is treated as overflowing:
    /// if exactly one slot is left it is used for a cancel event, and the
    /// watcher is unregistered either way. Watchers whose receiver is closed
    /// are unregistered as well. Unregistration happens after the map guard
    /// has been released.
    async fn dispatch_to_map(
        &self,
        map: &DashMap<Bytes, Vec<Watcher>>,
        lookup_key: &Bytes,
        event: &WatchResponse,
    ) {
        let watchers = match map.get(lookup_key) {
            Some(guard) => guard,
            None => return,
        };

        let mut dead_watchers = Vec::new();
        for watcher in watchers.iter() {
            let available = watcher.sender.capacity();

            if available > 1 {
                // Normal delivery; only a closed receiver marks the watcher dead.
                let outcome = watcher.sender.try_send(proto_to_event(event, watcher.prev_kv));
                if matches!(outcome, Err(mpsc::error::TrySendError::Closed(_))) {
                    dead_watchers.push(watcher.id);
                }
                continue;
            }

            // Overflow: spend the last free slot (if any) on a cancel notice.
            if available == 1 {
                warn!(
                    watcher_id = watcher.id,
                    key = ?event.key,
                    buffer_capacity = watcher.sender.max_capacity(),
                    buffer_len = watcher.sender.max_capacity() - available,
                    "watcher buffer overflow, sending cancel"
                );
                let _ = watcher
                    .sender
                    .try_send(crate::watch::make_cancel_event(event.key.clone()));
            }
            dead_watchers.push(watcher.id);
        }

        // Release the map guard before `unregister` touches the same entry.
        drop(watchers);
        for id in dead_watchers {
            self.registry.unregister(id, lookup_key);
        }

        trace!(
            event_key = ?event.key,
            lookup_key = ?lookup_key,
            event_type = ?event.event_type,
            "Event dispatched"
        );
    }
}