use std::time::Duration;
use tokio::{sync::mpsc, time::sleep};
use crate::discovery::ServiceInstance;
use crate::discovery_etcd::client::SharedEtcdClient;
use crate::discovery_etcd::{
EtcdDiscoveryConfig, EtcdDiscoveryError, EtcdDiscoveryResult, decode_instance, service_prefix,
split_instance_key,
};
/// What happened to an instance key under a watched service prefix.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WatchEventKind {
    /// The key was created or its value was overwritten.
    Put,
    /// The key was removed.
    Delete,
}
/// A single normalized change observed on a watched service prefix.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WatchEvent {
    /// Whether the key was written or removed.
    pub kind: WatchEventKind,
    /// The full etcd key, decoded as UTF-8.
    pub key: String,
    /// Service name parsed out of `key`.
    pub service: String,
    /// Instance id parsed out of `key`.
    pub id: String,
    /// Decoded instance payload: the new value for a put, the previous value for a
    /// delete. `None` when the event did not carry that key-value.
    pub instance: Option<ServiceInstance>,
}
/// Receiving side of a background service watch.
///
/// Items are either normalized [`WatchEvent`]s or errors reported by the watch task.
/// The task stops once it observes that this value has been dropped.
#[derive(Debug)]
pub struct EtcdWatchStream {
    receiver: mpsc::Receiver<EtcdDiscoveryResult<WatchEvent>>,
}
impl EtcdWatchStream {
    /// Waits for the next item produced by the watch task.
    ///
    /// Returns `None` after the watch task has exited and every buffered item has been
    /// handed out.
    pub async fn recv(&mut self) -> Option<EtcdDiscoveryResult<WatchEvent>> {
        mpsc::Receiver::recv(&mut self.receiver).await
    }
}
/// Exponential backoff schedule applied between watch re-establishment attempts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackoffConfig {
    /// Delay before the first retry (attempt `0`).
    pub initial: Duration,
    /// Upper bound on any computed delay.
    pub max: Duration,
}
impl Default for BackoffConfig {
    /// Starts at 200 ms and caps at 5 s.
    fn default() -> Self {
        Self {
            initial: Duration::from_millis(200),
            max: Duration::from_secs(5),
        }
    }
}
impl BackoffConfig {
    /// Returns the delay before retry number `attempt` (0-based).
    ///
    /// The result is exactly `min(initial * 2^attempt, max)`:
    /// - it never exceeds `max`;
    /// - an `initial` larger than `max` yields `max`;
    /// - a zero `initial` always yields zero.
    ///
    /// Unlike a fixed cap on the exponent, this reaches `max` for any ratio between
    /// `initial` and `max`.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let mut delay = self.initial.min(self.max);
        for _ in 0..attempt {
            // Once capped (or stuck at zero) further doubling cannot change the result.
            // Any non-zero delay hits `max` within ~94 doublings, so even `u32::MAX`
            // attempts stay cheap.
            if delay >= self.max || delay.is_zero() {
                break;
            }
            delay = delay.saturating_mul(2).min(self.max);
        }
        delay
    }
}
/// Starts a detached tokio task that watches `service` and returns the stream it feeds.
///
/// The channel buffers up to 128 items. When it is full, the task waits for the consumer
/// before forwarding more.
pub(crate) fn spawn_service_watch(
    config: EtcdDiscoveryConfig,
    client: SharedEtcdClient,
    service: String,
) -> EtcdWatchStream {
    let (events_tx, events_rx) = mpsc::channel(128);
    let task = watch_loop(config, client, service, events_tx);
    tokio::spawn(task);
    EtcdWatchStream {
        receiver: events_rx,
    }
}
async fn watch_loop(
config: EtcdDiscoveryConfig,
client: SharedEtcdClient,
service: String,
sender: mpsc::Sender<EtcdDiscoveryResult<WatchEvent>>,
) {
let mut attempt = 0;
while !sender.is_closed() {
match watch_once(&config, &client, &service, &sender).await {
Ok(()) => {}
Err(error) => {
if sender.send(Err(error)).await.is_err() {
break;
}
}
}
let delay = config.watch_backoff.delay_for_attempt(attempt);
attempt = attempt.saturating_add(1);
sleep(delay).await;
}
}
async fn watch_once(
config: &EtcdDiscoveryConfig,
client: &SharedEtcdClient,
service: &str,
sender: &mpsc::Sender<EtcdDiscoveryResult<WatchEvent>>,
) -> EtcdDiscoveryResult<()> {
let prefix = service_prefix(config, service);
let options = etcd_client::WatchOptions::new()
.with_prefix()
.with_prev_key();
let (_watcher, mut stream) = {
let mut client = client.lock().await;
client
.watch(prefix, Some(options))
.await
.map_err(|error| EtcdDiscoveryError::Watch(error.to_string()))?
};
while let Some(response) = stream
.message()
.await
.map_err(|error| EtcdDiscoveryError::Watch(error.to_string()))?
{
if response.canceled() {
return Err(EtcdDiscoveryError::Watch(format!(
"watch canceled at compact revision {}: {}",
response.compact_revision(),
response.cancel_reason()
)));
}
for event in response.events() {
let normalized = normalize_event(config, event)?;
if sender.send(Ok(normalized)).await.is_err() {
return Ok(());
}
}
}
Err(EtcdDiscoveryError::Watch(
"watch stream closed by etcd".to_string(),
))
}
fn normalize_event(
config: &EtcdDiscoveryConfig,
event: &etcd_client::Event,
) -> EtcdDiscoveryResult<WatchEvent> {
let kv = event.kv().or_else(|| event.prev_kv()).ok_or_else(|| {
EtcdDiscoveryError::Watch("watch event did not include a key".to_string())
})?;
let key = String::from_utf8(kv.key().to_vec())
.map_err(|_| EtcdDiscoveryError::InvalidKey("<non-utf8>".to_string()))?;
let parsed = split_instance_key(config, &key)?
.ok_or_else(|| EtcdDiscoveryError::InvalidKey(key.clone()))?;
let kind = match event.event_type() {
etcd_client::EventType::Put => WatchEventKind::Put,
etcd_client::EventType::Delete => WatchEventKind::Delete,
};
let instance = event_instance(&kind, event)?;
Ok(WatchEvent {
kind,
key,
service: parsed.service,
id: parsed.id,
instance,
})
}
/// Decodes the service instance attached to `event`, if any.
///
/// A put is decoded from the event's current key-value; a delete is decoded from its
/// previous key-value. Returns `Ok(None)` when the event does not carry the relevant
/// key-value. Returns the `decode_instance` error when the payload cannot be decoded.
fn event_instance(
    kind: &WatchEventKind,
    event: &etcd_client::Event,
) -> EtcdDiscoveryResult<Option<ServiceInstance>> {
    let source = match kind {
        WatchEventKind::Put => event.kv(),
        WatchEventKind::Delete => event.prev_kv(),
    };
    match source {
        Some(kv) => decode_instance(kv.value()).map(Some),
        None => Ok(None),
    }
}
#[cfg(test)]
mod tests {
    use super::BackoffConfig;
    use std::time::Duration;

    // The backoff must never exceed `max`, even for absurd attempt counts.
    #[test]
    fn backoff_is_capped() {
        let config = BackoffConfig::default();
        assert!(config.delay_for_attempt(10) <= config.max);
        assert_eq!(config.delay_for_attempt(u32::MAX), config.max);
    }

    // Attempt 0 waits `initial`; each later attempt doubles it until the cap.
    #[test]
    fn backoff_doubles_from_initial() {
        let config = BackoffConfig::default();
        assert_eq!(config.delay_for_attempt(0), config.initial);
        assert_eq!(config.delay_for_attempt(1), config.initial * 2);
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(1600));
    }

    // Zero and inverted (`initial > max`) configurations stay well-defined.
    #[test]
    fn backoff_handles_degenerate_configs() {
        let zero = BackoffConfig {
            initial: Duration::ZERO,
            max: Duration::from_secs(1),
        };
        assert_eq!(zero.delay_for_attempt(1_000), Duration::ZERO);

        let inverted = BackoffConfig {
            initial: Duration::from_secs(10),
            max: Duration::from_secs(1),
        };
        assert_eq!(inverted.delay_for_attempt(0), inverted.max);
    }
}