// rs-zero 0.2.4
//
// Rust-first microservice framework inspired by go-zero engineering practices.
// Documentation
use std::{
    collections::BTreeMap,
    fmt,
    sync::{Arc, Mutex},
};

use async_trait::async_trait;
use etcd_client::{Compare, CompareOp, DeleteOptions, GetOptions, PutOptions, Txn, TxnOp};
use tokio::time::timeout;

use crate::discovery::{Discovery, DiscoveryError, DiscoveryResult, Registry, ServiceInstance};
use crate::discovery_etcd::client::{SharedEtcdClient, validate_config};
use crate::discovery_etcd::codec::decode_healthy_instances;
use crate::discovery_etcd::lease::{
    EtcdLeaseState, EtcdLeaseStatus, LeaseRegistration, LeaseStatuses, set_lease_status,
    spawn_keep_alive_task,
};
use crate::discovery_etcd::validation::{validate_instance, validate_name};
use crate::discovery_etcd::watch::spawn_service_watch;
use crate::discovery_etcd::{
    EtcdClientFactory, EtcdDiscoveryConfig, EtcdDiscoveryError, EtcdDiscoveryResult,
    EtcdWatchStream, decode_instance, encode_instance, instance_key, service_prefix,
};

/// Real etcd-backed service registry and discovery adapter.
///
/// The client handle, the registration map and the lease status map are all
/// reference-counted, so clones of a registry share the same registrations
/// and lease statuses.
#[derive(Clone)]
pub struct EtcdRegistry {
    /// Settings used to build etcd keys and to bound each etcd call
    /// (operation timeout, lease TTL).
    config: EtcdDiscoveryConfig,
    /// Shared etcd client; `None` for a registry built with `new`, in which
    /// case operations that need etcd fail with `EtcdDiscoveryError::NotConnected`.
    client: Option<SharedEtcdClient>,
    /// Keep-alive registrations started by this registry, keyed by the etcd
    /// instance key.
    registrations: Arc<Mutex<BTreeMap<String, LeaseRegistration>>>,
    /// Lease keep-alive statuses; a handle to this map is also passed to each
    /// keep-alive task spawned on registration.
    lease_statuses: LeaseStatuses,
}

/// Debug output shows the configuration and whether a client is attached;
/// the remaining fields are elided and marked as such.
impl fmt::Debug for EtcdRegistry {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Report connectivity as a flag instead of printing the client itself.
        let connected = self.client.is_some();
        let mut output = formatter.debug_struct("EtcdRegistry");
        output.field("config", &self.config);
        output.field("connected", &connected);
        output.finish_non_exhaustive()
    }
}

impl EtcdRegistry {
    /// Builds a registry that holds no etcd client and no in-memory fallback.
    ///
    /// Both the registration map and the lease status map start empty. Any
    /// call that has to reach etcd fails, because `connected_client` returns
    /// `EtcdDiscoveryError::NotConnected` while `client` is `None`.
    pub fn new(config: EtcdDiscoveryConfig) -> Self {
        let registrations = Arc::new(Mutex::new(BTreeMap::new()));
        let lease_statuses = Arc::new(Mutex::new(BTreeMap::new()));
        Self {
            config,
            client: None,
            registrations,
            lease_statuses,
        }
    }

    /// Opens an etcd connection described by `config` and wraps it in a registry.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the client factory, from connecting,
    /// or from [`Self::from_client`].
    pub async fn connect(config: EtcdDiscoveryConfig) -> EtcdDiscoveryResult<Self> {
        // The factory takes its own copy; the original config moves into the registry.
        let factory = EtcdClientFactory::new(config.clone());
        let connected = factory?.connect().await?;
        Self::from_client(config, connected)
    }

    /// Wraps an already connected etcd client in a registry.
    ///
    /// The configuration is validated before anything is constructed. The
    /// client is placed behind an async mutex inside an `Arc` so clones of the
    /// registry share it; registrations and lease statuses start empty.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_config` when `config` is rejected.
    pub fn from_client(
        config: EtcdDiscoveryConfig,
        client: etcd_client::Client,
    ) -> EtcdDiscoveryResult<Self> {
        validate_config(&config)?;
        let shared_client = Arc::new(tokio::sync::Mutex::new(client));
        let registry = Self {
            config,
            client: Some(shared_client),
            registrations: Arc::new(Mutex::new(BTreeMap::new())),
            lease_statuses: Arc::new(Mutex::new(BTreeMap::new())),
        };
        Ok(registry)
    }

    /// Computes the etcd key under which `instance` is stored.
    ///
    /// The key is derived from this registry's configuration plus the
    /// instance's service name and id.
    pub fn key_for(&self, instance: &ServiceInstance) -> String {
        let (service, id) = (&instance.service, &instance.id);
        instance_key(&self.config, service, id)
    }

    /// Starts watching all changes for one service.
    ///
    /// The stream is produced by `spawn_service_watch`, which receives a copy
    /// of the configuration, the shared client and the service name.
    /// Reconnecting with backoff on stream errors is that helper's
    /// responsibility; its implementation lives outside this file.
    ///
    /// # Errors
    ///
    /// Fails when `service` is rejected by `validate_name` or when the
    /// registry has no client.
    pub async fn watch_service(&self, service: &str) -> EtcdDiscoveryResult<EtcdWatchStream> {
        validate_name("service", service)?;
        let shared_client = self.connected_client()?;
        let stream = spawn_service_watch(
            self.config.clone(),
            shared_client,
            service.to_string(),
        );
        Ok(stream)
    }

    /// Looks up the locally recorded lease keep-alive status of an instance.
    ///
    /// Only the in-process status map is consulted; no etcd request is made.
    /// Returns `None` when nothing is recorded for the instance key.
    ///
    /// # Panics
    ///
    /// Panics if the lease status mutex is poisoned.
    pub async fn lease_status(&self, service: &str, id: &str) -> Option<EtcdLeaseStatus> {
        let key = instance_key(&self.config, service, id);
        let statuses = self
            .lease_statuses
            .lock()
            .expect("lease status mutex poisoned");
        statuses.get(&key).cloned()
    }

    /// Rewrites the stored instance with a new `healthy` flag, keeping its lease.
    ///
    /// The current record is fetched and decoded, the flag is replaced, and
    /// the record is written back under the lease reported by etcd (or, if
    /// etcd reports none, the lease tracked locally for that key).
    ///
    /// The read and the write are two separate etcd requests rather than one
    /// transaction, so a change made by another writer in between is overwritten.
    ///
    /// # Errors
    ///
    /// - validation errors for `service` or `id`;
    /// - `EtcdDiscoveryError::MissingInstance` when the key does not exist;
    /// - any error from fetching, decoding or writing the record.
    pub async fn update_health(
        &self,
        service: &str,
        id: &str,
        healthy: bool,
    ) -> EtcdDiscoveryResult<ServiceInstance> {
        validate_name("service", service)?;
        validate_name("id", id)?;
        let key = instance_key(&self.config, service, id);
        let Some(stored) = self.get_key(&key).await? else {
            return Err(EtcdDiscoveryError::MissingInstance {
                service: service.to_string(),
                id: id.to_string(),
            });
        };
        let mut updated = decode_instance(stored.value())?;
        updated.healthy = healthy;
        // Re-attach the same lease so the record still expires with its keep-alive.
        let lease_id = self.effective_lease_id(&key, stored.lease());
        self.put_instance(&key, &updated, lease_id).await?;
        Ok(updated)
    }

    /// Returns a new handle to the shared etcd client.
    ///
    /// # Errors
    ///
    /// Returns `EtcdDiscoveryError::NotConnected` when the registry was built
    /// without a client.
    fn connected_client(&self) -> EtcdDiscoveryResult<SharedEtcdClient> {
        match &self.client {
            Some(shared) => Ok(shared.clone()),
            None => Err(EtcdDiscoveryError::NotConnected),
        }
    }

    /// Registers `instance` under a freshly granted lease.
    ///
    /// The record is written only if its key does not exist yet. On success a
    /// keep-alive task is started for the lease. If the key already exists or
    /// the write fails, the new lease is revoked on a best-effort basis (the
    /// revoke result is ignored) before the error is returned.
    ///
    /// # Errors
    ///
    /// - validation or lease-grant errors, converted to `DiscoveryError`;
    /// - `DiscoveryError::DuplicateInstance` when the key is already present;
    /// - the converted backend error when the write itself fails.
    async fn register_instance(&self, instance: ServiceInstance) -> DiscoveryResult<()> {
        validate_instance(&instance).map_err(EtcdDiscoveryError::into_discovery_error)?;
        let key = self.key_for(&instance);
        let lease_id = self
            .grant_lease()
            .await
            .map_err(EtcdDiscoveryError::into_discovery_error)?;

        let outcome = self.put_new_instance(&key, &instance, lease_id).await;
        if matches!(outcome, Ok(true)) {
            self.spawn_keep_alive(key, lease_id);
            return Ok(());
        }

        // Nothing was stored under this lease; release it without masking the real error.
        let _ = self.revoke_lease(lease_id).await;
        match outcome {
            Ok(_) => Err(DiscoveryError::DuplicateInstance {
                service: instance.service,
                id: instance.id,
            }),
            Err(error) => Err(error.into_discovery_error()),
        }
    }

    /// Lists the instances of `service`, ordered by instance id.
    ///
    /// All keys under the service prefix are fetched and passed through
    /// `decode_healthy_instances` (defined elsewhere; by its name it keeps
    /// only healthy entries).
    ///
    /// # Errors
    ///
    /// - validation, fetch or decode errors, converted to `DiscoveryError`;
    /// - `DiscoveryError::NoInstances` when the decoded list is empty.
    async fn discover_instances(&self, service: &str) -> DiscoveryResult<Vec<ServiceInstance>> {
        validate_name("service", service).map_err(EtcdDiscoveryError::into_discovery_error)?;
        let prefix = service_prefix(&self.config, service);
        let kvs = self
            .get_prefix(&prefix)
            .await
            .map_err(EtcdDiscoveryError::into_discovery_error)?;
        let mut found =
            decode_healthy_instances(kvs).map_err(EtcdDiscoveryError::into_discovery_error)?;
        if found.is_empty() {
            return Err(DiscoveryError::NoInstances {
                service: service.to_string(),
            });
        }
        found.sort_by(|a, b| a.id.cmp(&b.id));
        Ok(found)
    }

    /// Removes one instance from etcd and returns the record that was stored.
    ///
    /// The stored record is fetched and decoded, the local keep-alive task for
    /// the key is stopped, and then the remote state is cleaned up:
    ///
    /// - if a lease is known (reported by etcd, or tracked locally for the
    ///   key), that lease is revoked;
    /// - if the stored key is not attached to any lease, the key is deleted
    ///   explicitly. Revoking a lease only removes keys attached to it, so
    ///   without this step a lease-less key would survive deregistration even
    ///   when a locally tracked lease was revoked.
    ///
    /// # Errors
    ///
    /// - validation, fetch or decode errors, converted to `DiscoveryError`;
    /// - `DiscoveryError::MissingInstance` when the key does not exist;
    /// - the converted error from revoking the lease or deleting the key.
    async fn deregister_instance(
        &self,
        service: &str,
        id: &str,
    ) -> DiscoveryResult<ServiceInstance> {
        validate_name("service", service).map_err(EtcdDiscoveryError::into_discovery_error)?;
        validate_name("id", id).map_err(EtcdDiscoveryError::into_discovery_error)?;
        let key = instance_key(&self.config, service, id);
        let kv = self
            .get_key(&key)
            .await
            .map_err(EtcdDiscoveryError::into_discovery_error)?;
        let kv = kv.ok_or_else(|| DiscoveryError::MissingInstance {
            service: service.to_string(),
            id: id.to_string(),
        })?;
        let instance =
            decode_instance(kv.value()).map_err(EtcdDiscoveryError::into_discovery_error)?;
        let remote_lease_id = kv.lease();
        let lease_id = self.effective_lease_id(&key, remote_lease_id);
        self.stop_keep_alive(&key, lease_id);
        if lease_id != 0 {
            self.revoke_lease(lease_id)
                .await
                .map_err(EtcdDiscoveryError::into_discovery_error)?;
        }
        // A key without a lease is untouched by any revoke, so remove it directly.
        if remote_lease_id == 0 {
            self.delete_key(&key)
                .await
                .map_err(EtcdDiscoveryError::into_discovery_error)?;
        }
        Ok(instance)
    }

    async fn grant_lease(&self) -> EtcdDiscoveryResult<i64> {
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        let response = timeout(
            self.config.operation_timeout,
            client.lease_grant(self.config.lease_ttl, None),
        )
        .await
        .map_err(|_| EtcdDiscoveryError::Timeout {
            operation: "lease_grant",
        })?
        .map_err(|error| EtcdDiscoveryError::Lease(error.to_string()))?;
        Ok(response.id())
    }

    async fn revoke_lease(&self, lease_id: i64) -> EtcdDiscoveryResult<()> {
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        timeout(self.config.operation_timeout, client.lease_revoke(lease_id))
            .await
            .map_err(|_| EtcdDiscoveryError::Timeout {
                operation: "lease_revoke",
            })?
            .map_err(|error| EtcdDiscoveryError::Lease(error.to_string()))?;
        Ok(())
    }

    async fn put_new_instance(
        &self,
        key: &str,
        instance: &ServiceInstance,
        lease_id: i64,
    ) -> EtcdDiscoveryResult<bool> {
        let value = encode_instance(instance)?;
        let options = PutOptions::new().with_lease(lease_id);
        let txn = Txn::new()
            .when([Compare::version(key, CompareOp::Equal, 0)])
            .and_then([TxnOp::put(key, value, Some(options))]);
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        let response = timeout(self.config.operation_timeout, client.txn(txn))
            .await
            .map_err(|_| EtcdDiscoveryError::Timeout { operation: "txn" })?
            .map_err(|error| EtcdDiscoveryError::Backend(error.to_string()))?;
        Ok(response.succeeded())
    }

    async fn put_instance(
        &self,
        key: &str,
        instance: &ServiceInstance,
        lease_id: i64,
    ) -> EtcdDiscoveryResult<()> {
        let value = encode_instance(instance)?;
        let options = (lease_id != 0).then(|| PutOptions::new().with_lease(lease_id));
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        timeout(
            self.config.operation_timeout,
            client.put(key, value, options),
        )
        .await
        .map_err(|_| EtcdDiscoveryError::Timeout { operation: "put" })?
        .map_err(|error| EtcdDiscoveryError::Backend(error.to_string()))?;
        Ok(())
    }

    async fn get_key(&self, key: &str) -> EtcdDiscoveryResult<Option<etcd_client::KeyValue>> {
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        let response = timeout(self.config.operation_timeout, client.get(key, None))
            .await
            .map_err(|_| EtcdDiscoveryError::Timeout { operation: "get" })?
            .map_err(|error| EtcdDiscoveryError::Backend(error.to_string()))?;
        Ok(response.kvs().first().cloned())
    }

    async fn get_prefix(&self, prefix: &str) -> EtcdDiscoveryResult<Vec<etcd_client::KeyValue>> {
        let options = GetOptions::new().with_prefix();
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        let response = timeout(
            self.config.operation_timeout,
            client.get(prefix, Some(options)),
        )
        .await
        .map_err(|_| EtcdDiscoveryError::Timeout { operation: "get" })?
        .map_err(|error| EtcdDiscoveryError::Backend(error.to_string()))?;
        Ok(response.kvs().to_vec())
    }

    async fn delete_key(&self, key: &str) -> EtcdDiscoveryResult<()> {
        let options = DeleteOptions::new().with_prev_key();
        let client = self.connected_client()?;
        let mut client = client.lock().await;
        timeout(
            self.config.operation_timeout,
            client.delete(key, Some(options)),
        )
        .await
        .map_err(|_| EtcdDiscoveryError::Timeout {
            operation: "delete",
        })?
        .map_err(|error| EtcdDiscoveryError::Backend(error.to_string()))?;
        Ok(())
    }

    /// Chooses the lease id to use for `key`.
    ///
    /// A non-zero id reported by etcd wins. Otherwise the id of the locally
    /// tracked registration for `key` is used, or `0` when none is tracked.
    ///
    /// # Panics
    ///
    /// Panics if the registration mutex is poisoned.
    fn effective_lease_id(&self, key: &str, remote_lease_id: i64) -> i64 {
        if remote_lease_id == 0 {
            let tracked = self
                .registrations
                .lock()
                .expect("registration mutex poisoned");
            tracked
                .get(key)
                .map_or(0, |registration| registration.lease_id)
        } else {
            remote_lease_id
        }
    }

    /// Starts a keep-alive task for `lease_id` and records it under `key`.
    ///
    /// The task receives a copy of the configuration, the shared client, the
    /// key, the lease id and a handle to the lease status map. An existing
    /// registration stored under the same key is replaced in the map.
    ///
    /// # Panics
    ///
    /// Panics if the registry has no client or if the registration mutex is
    /// poisoned. The only caller in this file invokes it after a successful
    /// etcd write, which already required a client.
    fn spawn_keep_alive(&self, key: String, lease_id: i64) {
        let config = self.config.clone();
        let shared = self.connected_client().expect("connected client");
        let statuses = Arc::clone(&self.lease_statuses);
        let registration = spawn_keep_alive_task(config, shared, key.clone(), lease_id, statuses);

        let mut tracked = self
            .registrations
            .lock()
            .expect("registration mutex poisoned");
        tracked.insert(key, registration);
    }

    /// Stops and forgets the keep-alive registration tracked for `key`.
    ///
    /// When `lease_id` is non-zero the lease status for the key is also set to
    /// `EtcdLeaseState::Stopped`, whether or not a registration was tracked.
    ///
    /// # Panics
    ///
    /// Panics if the registration mutex is poisoned.
    fn stop_keep_alive(&self, key: &str, lease_id: i64) {
        // Scope the guard so the map lock is released before the task is stopped.
        let removed = {
            let mut tracked = self
                .registrations
                .lock()
                .expect("registration mutex poisoned");
            tracked.remove(key)
        };
        if let Some(registration) = removed {
            registration.stop();
        }
        if lease_id == 0 {
            return;
        }
        self.set_lease_status(key, lease_id, EtcdLeaseState::Stopped);
    }

    /// Records `state` for the lease of `key` in this registry's status map.
    ///
    /// Thin wrapper over the free `set_lease_status` helper imported from the
    /// lease module, supplying `self.lease_statuses` as the target map.
    fn set_lease_status(&self, key: &str, lease_id: i64, state: EtcdLeaseState) {
        set_lease_status(&self.lease_statuses, key, lease_id, state);
    }
}

/// Stops every tracked keep-alive task when the last registry handle goes away.
impl Drop for EtcdRegistry {
    /// Drains the registration map, stops each keep-alive task and marks its
    /// lease status as `EtcdLeaseState::Stopped`.
    ///
    /// Nothing happens while other clones still share the registration map,
    /// so only the final handle performs the cleanup.
    fn drop(&mut self) {
        if Arc::strong_count(&self.registrations) != 1 {
            return;
        }
        // A poisoned lock must not panic here: `drop` may run while the thread
        // is already unwinding, and a second panic would abort the process.
        // The map is only being drained, so recovering the guard is safe.
        let registrations = std::mem::take(
            &mut *self
                .registrations
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
        for (key, registration) in registrations {
            let lease_id = registration.lease_id;
            registration.stop();
            self.set_lease_status(&key, lease_id, EtcdLeaseState::Stopped);
        }
    }
}

/// `Discovery` implementation backed by etcd; delegates to the inherent
/// `discover_instances` method.
#[async_trait]
impl Discovery for EtcdRegistry {
    /// Returns the instances of `service`, or an error when none are found.
    async fn discover(&self, service: &str) -> DiscoveryResult<Vec<ServiceInstance>> {
        self.discover_instances(service).await
    }
}

/// `Registry` implementation backed by etcd; each method delegates to the
/// inherent method of the matching name.
#[async_trait]
impl Registry for EtcdRegistry {
    /// Registers `instance`; see `register_instance` for the lease handling.
    async fn register(&self, instance: ServiceInstance) -> DiscoveryResult<()> {
        self.register_instance(instance).await
    }

    /// Removes the instance identified by `service` and `id` and returns the
    /// record that was stored for it.
    async fn deregister(&self, service: &str, id: &str) -> DiscoveryResult<ServiceInstance> {
        self.deregister_instance(service, id).await
    }
}