rs-zero 0.2.8

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::time::Duration;

use tokio::{sync::mpsc, time::sleep};

use crate::discovery::ServiceInstance;
use crate::discovery_etcd::client::SharedEtcdClient;
use crate::discovery_etcd::{
    EtcdDiscoveryConfig, EtcdDiscoveryError, EtcdDiscoveryResult, decode_instance, service_prefix,
    split_instance_key,
};

/// Watch event kind.
///
/// Produced by `normalize_event`, which maps `etcd_client::EventType::Put`
/// to [`WatchEventKind::Put`] and `etcd_client::EventType::Delete` to
/// [`WatchEventKind::Delete`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WatchEventKind {
    /// Instance created or updated.
    Put,
    /// Instance deleted.
    Delete,
}

/// Normalized etcd watch event.
///
/// Built by `normalize_event` from a raw `etcd_client::Event`: the key is
/// decoded as UTF-8 and split into a service name and an instance id, and the
/// instance payload is decoded when the event carries one.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WatchEvent {
    /// Event kind.
    pub kind: WatchEventKind,
    /// Raw etcd key, decoded as UTF-8.
    pub key: String,
    /// Service name parsed from the key.
    pub service: String,
    /// Service instance id parsed from the key.
    pub id: String,
    /// Service instance when the event carries an instance payload.
    ///
    /// For [`WatchEventKind::Put`] this is decoded from the event's current
    /// key/value; for [`WatchEventKind::Delete`] it is decoded from the
    /// previous key/value. It is `None` when that key/value is absent.
    pub instance: Option<ServiceInstance>,
}

/// Async stream of normalized etcd watch events.
///
/// Wraps the receiving half of a Tokio `mpsc` channel. Each item is either a
/// normalized [`WatchEvent`] or the error reported by the watch task.
#[derive(Debug)]
pub struct EtcdWatchStream {
    // Receiving half of the channel that the background watch task writes to.
    receiver: mpsc::Receiver<EtcdDiscoveryResult<WatchEvent>>,
}

impl EtcdWatchStream {
    /// Receives the next watch event or error.
    ///
    /// Delegates to the channel's `recv`, so (per Tokio's `mpsc::Receiver`
    /// semantics) `None` means the sending side is gone and no buffered items
    /// remain.
    pub async fn recv(&mut self) -> Option<EtcdDiscoveryResult<WatchEvent>> {
        self.receiver.recv().await
    }
}

/// Exponential backoff settings applied between reconnect attempts.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BackoffConfig {
    /// Delay returned for attempt `0`; it doubles with each later attempt.
    pub initial: Duration,
    /// Ceiling that no returned delay ever exceeds.
    pub max: Duration,
}

impl Default for BackoffConfig {
    /// Starts at 200 ms and is capped at 5 s.
    fn default() -> Self {
        let initial = Duration::from_millis(200);
        let max = Duration::from_secs(5);
        Self { initial, max }
    }
}

impl BackoffConfig {
    /// Computes the delay to wait before the given (zero-based) attempt.
    ///
    /// The result is `initial * 2^attempt`, with the exponent limited to 16
    /// and the product limited to `max`. If the multiplication overflows
    /// `Duration`, `max` is returned.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // With the exponent limited to 16 the shift always fits in a `u32`.
        let exponent = attempt.min(16);
        let factor = 1_u32 << exponent;
        match self.initial.checked_mul(factor) {
            Some(delay) if delay < self.max => delay,
            // Overflowed, or reached/exceeded the ceiling.
            _ => self.max,
        }
    }
}

/// Starts a background Tokio task that watches `service` and returns the
/// stream its events arrive on.
///
/// The task and the returned stream are connected by a bounded channel with
/// room for 128 items. The task runs `watch_loop` and is not joined here.
pub(crate) fn spawn_service_watch(
    config: EtcdDiscoveryConfig,
    client: SharedEtcdClient,
    service: String,
) -> EtcdWatchStream {
    let (sender, receiver) = mpsc::channel(128);
    let stream = EtcdWatchStream { receiver };
    // The loop future owns everything it needs, so it can be spawned as-is.
    tokio::spawn(watch_loop(config, client, service, sender));
    stream
}

async fn watch_loop(
    config: EtcdDiscoveryConfig,
    client: SharedEtcdClient,
    service: String,
    sender: mpsc::Sender<EtcdDiscoveryResult<WatchEvent>>,
) {
    let mut attempt = 0;
    while !sender.is_closed() {
        match watch_once(&config, &client, &service, &sender).await {
            Ok(()) => {}
            Err(error) => {
                if sender.send(Err(error)).await.is_err() {
                    break;
                }
            }
        }
        let delay = config.watch_backoff.delay_for_attempt(attempt);
        attempt = attempt.saturating_add(1);
        sleep(delay).await;
    }
}

/// Runs a single watch session over every key under the service prefix.
///
/// Opens an etcd watch (prefix match, previous key/value requested) and
/// forwards each received event to `sender` after normalizing it. An event
/// that cannot be normalized is forwarded as an `Err` item and the session
/// keeps running, so one malformed key or payload neither drops the remaining
/// events of the same response nor forces a reconnect.
///
/// # Returns
///
/// `Ok(())` only when the receiving side of `sender` is gone, which tells the
/// caller there is nobody left to deliver to.
///
/// # Errors
///
/// Returns `EtcdDiscoveryError::Watch` when the watch cannot be opened, when
/// reading from the stream fails, when etcd reports the watch as canceled, or
/// when the stream ends.
async fn watch_once(
    config: &EtcdDiscoveryConfig,
    client: &SharedEtcdClient,
    service: &str,
    sender: &mpsc::Sender<EtcdDiscoveryResult<WatchEvent>>,
) -> EtcdDiscoveryResult<()> {
    let prefix = service_prefix(config, service);
    let options = etcd_client::WatchOptions::new()
        .with_prefix()
        .with_prev_key();
    // The lock guard lives only inside this inner block, so the shared client
    // is released as soon as the watch has been opened. `_watcher` stays bound
    // until the function returns.
    // NOTE(review): presumably dropping the watcher would end the watch;
    // confirm against the etcd_client documentation.
    let (_watcher, mut stream) = {
        let mut client = client.lock().await;
        client
            .watch(prefix, Some(options))
            .await
            .map_err(|error| EtcdDiscoveryError::Watch(error.to_string()))?
    };
    while let Some(response) = stream
        .message()
        .await
        .map_err(|error| EtcdDiscoveryError::Watch(error.to_string()))?
    {
        if response.canceled() {
            return Err(EtcdDiscoveryError::Watch(format!(
                "watch canceled at compact revision {}: {}",
                response.compact_revision(),
                response.cancel_reason()
            )));
        }
        for event in response.events() {
            // Forward the per-event result, success or failure, instead of
            // aborting the whole session on the first bad event.
            if sender.send(normalize_event(config, event)).await.is_err() {
                return Ok(());
            }
        }
    }
    Err(EtcdDiscoveryError::Watch(
        "watch stream closed by etcd".to_string(),
    ))
}

/// Converts a raw etcd event into a [`WatchEvent`].
///
/// The key is taken from the event's current key/value when present and from
/// the previous key/value otherwise.
///
/// # Errors
///
/// * `EtcdDiscoveryError::Watch` when the event carries neither key/value.
/// * `EtcdDiscoveryError::InvalidKey` when the key is not valid UTF-8 or
///   `split_instance_key` yields no service/id pair for it.
/// * Any error returned by `split_instance_key` or by decoding the instance
///   payload.
fn normalize_event(
    config: &EtcdDiscoveryConfig,
    event: &etcd_client::Event,
) -> EtcdDiscoveryResult<WatchEvent> {
    let kv = match event.kv().or_else(|| event.prev_kv()) {
        Some(kv) => kv,
        None => {
            return Err(EtcdDiscoveryError::Watch(
                "watch event did not include a key".to_string(),
            ));
        }
    };

    let key = match String::from_utf8(kv.key().to_vec()) {
        Ok(text) => text,
        Err(_) => return Err(EtcdDiscoveryError::InvalidKey("<non-utf8>".to_string())),
    };

    let parsed = match split_instance_key(config, &key)? {
        Some(parts) => parts,
        None => return Err(EtcdDiscoveryError::InvalidKey(key.clone())),
    };

    let kind = match event.event_type() {
        etcd_client::EventType::Put => WatchEventKind::Put,
        etcd_client::EventType::Delete => WatchEventKind::Delete,
    };
    let instance = event_instance(&kind, event)?;

    Ok(WatchEvent {
        kind,
        key,
        service: parsed.service,
        id: parsed.id,
        instance,
    })
}

/// Decodes the service instance attached to an event, if there is one.
///
/// A `Put` reads the event's current key/value and a `Delete` reads the
/// previous key/value. Returns `Ok(None)` when that key/value is absent, and
/// the decoding error when the payload cannot be decoded.
fn event_instance(
    kind: &WatchEventKind,
    event: &etcd_client::Event,
) -> EtcdDiscoveryResult<Option<ServiceInstance>> {
    // Choose which side of the event holds the payload for this kind.
    let source = match kind {
        WatchEventKind::Put => event.kv(),
        WatchEventKind::Delete => event.prev_kv(),
    };
    match source {
        Some(kv) => decode_instance(kv.value()).map(Some),
        None => Ok(None),
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::BackoffConfig;

    /// A late attempt with the default settings never exceeds the ceiling.
    #[test]
    fn backoff_is_capped() {
        let config = BackoffConfig::default();
        assert!(config.delay_for_attempt(10) <= config.max);
        assert_eq!(config.delay_for_attempt(u32::MAX), config.max);
    }

    /// The delay starts at `initial` and doubles until it reaches `max`.
    #[test]
    fn backoff_doubles_until_capped() {
        let config = BackoffConfig::default();
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(200));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(400));
        assert_eq!(config.delay_for_attempt(4), Duration::from_millis(3200));
        assert_eq!(config.delay_for_attempt(5), Duration::from_secs(5));
    }

    /// The exponent stops growing at 16, so a tiny `initial` plateaus there.
    #[test]
    fn backoff_exponent_is_limited() {
        let config = BackoffConfig {
            initial: Duration::from_nanos(1),
            max: Duration::from_secs(5),
        };
        assert_eq!(config.delay_for_attempt(16), Duration::from_nanos(65_536));
        assert_eq!(config.delay_for_attempt(17), Duration::from_nanos(65_536));
    }

    /// A multiplication that overflows `Duration` falls back to `max`.
    #[test]
    fn backoff_overflow_falls_back_to_max() {
        let config = BackoffConfig {
            initial: Duration::MAX,
            max: Duration::from_secs(9),
        };
        assert_eq!(config.delay_for_attempt(3), Duration::from_secs(9));
    }
}