eventdbx 1.12.4

An event-sourced, key-value, write-side database system.
Documentation
/// Generated bindings for the `eventdbx.replication` protobuf package.
///
/// The module body is supplied by `tonic::include_proto!`, which includes the
/// code generated for that package at build time.
pub mod proto {
    tonic::include_proto!("eventdbx.replication");
}

use std::{collections::BTreeMap, sync::Arc, time::Duration};

use proto::{
    AggregatePosition, ApplySchemasRequest, ApplySchemasResponse, EventBatch,
    EventRecord as ProtoEventRecord, HeartbeatRequest, HeartbeatResponse, ListPositionsRequest,
    ListPositionsResponse, PullEventsRequest, PullEventsResponse, PullSchemasRequest,
    PullSchemasResponse, ReplicationAck, SnapshotChunk, SnapshotRequest,
    replication_server::Replication,
};
use tokio::{sync::mpsc, time::sleep};
use tokio_stream::{StreamExt, wrappers::ReceiverStream};
use tonic::{Request, Response, Status};
use tracing::{error, info, warn};

use crate::{
    config::{Config, RemoteConfig},
    error::{EventError, Result},
    replication_capnp_client::{
        CapnpReplicationClient, decode_public_key_bytes, normalize_capnp_endpoint,
    },
    schema::SchemaManager,
    store::{EventMetadata, EventRecord, EventStore},
};

/// Replication progress shared between the service and anyone holding a clone.
///
/// Clones share the same counter because it lives behind an `Arc`.
#[derive(Clone, Default)]
pub struct ReplicationState {
    /// Highest batch sequence recorded by `apply_events`; `Default` starts it at 0.
    pub last_sequence: Arc<tokio::sync::Mutex<u64>>,
}

/// Implementation of the generated `Replication` gRPC service.
#[derive(Clone)]
pub struct ReplicationService {
    /// Progress counter reported by heartbeats and replication acks.
    state: ReplicationState,
    /// Store that receives replicated events and answers position/event pulls.
    store: Arc<EventStore>,
    /// Schema registry served by `pull_schemas` and replaced by `apply_schemas`.
    schemas: Arc<SchemaManager>,
}

impl ReplicationService {
    /// Builds a service over `store` and `schemas` with a fresh, default replication state.
    pub fn new(store: Arc<EventStore>, schemas: Arc<SchemaManager>) -> Self {
        let state = ReplicationState::default();
        Self {
            state,
            store,
            schemas,
        }
    }

    /// Returns a handle to the replication state.
    ///
    /// The returned value shares the underlying counter with this service.
    pub fn state(&self) -> ReplicationState {
        ReplicationState::clone(&self.state)
    }
}

/// gRPC handlers of the replication protocol, served from the local event
/// store and schema manager.
#[tonic::async_trait]
impl Replication for ReplicationService {
    /// Stream type returned by `bootstrap_snapshot`.
    type BootstrapSnapshotStream = ReceiverStream<std::result::Result<SnapshotChunk, Status>>;

    /// Returns a snapshot stream that ends without yielding any chunk.
    ///
    /// The sending half of the channel is dropped before the stream is handed
    /// back, so the receiver immediately observes end-of-stream.
    async fn bootstrap_snapshot(
        &self,
        _request: Request<SnapshotRequest>,
    ) -> std::result::Result<Response<Self::BootstrapSnapshotStream>, Status> {
        let (tx, rx) = mpsc::channel(1);
        drop(tx);
        Ok(Response::new(ReceiverStream::new(rx)))
    }

    /// Decodes and appends every event of every streamed batch to the store.
    ///
    /// After a batch has been fully applied its sequence is folded into the
    /// shared state (which never moves backwards), so progress is recorded
    /// even if a later batch fails. The ack carries the state's sequence once
    /// the stream ends.
    ///
    /// # Errors
    ///
    /// Propagates stream errors as-is; decode and append failures are
    /// reported as `Status::internal`. Events applied before the failure
    /// stay applied.
    async fn apply_events(
        &self,
        request: Request<tonic::Streaming<EventBatch>>,
    ) -> std::result::Result<Response<ReplicationAck>, Status> {
        let mut stream = request.into_inner();
        while let Some(batch) = stream.next().await {
            let batch = batch?;
            for event in &batch.events {
                let record =
                    decode_event(event).map_err(|err| Status::internal(err.to_string()))?;
                self.store
                    .append_replica(record)
                    .map_err(|err| Status::internal(err.to_string()))?;
            }

            // Record progress per batch rather than once at the end of the
            // stream: a failure further down the stream must not hide the
            // batches that were already applied. `max` keeps the counter
            // monotonic even if batches arrive with out-of-order sequences.
            {
                let mut applied = self.state.last_sequence.lock().await;
                *applied = (*applied).max(batch.sequence);
            }
        }

        let applied_sequence = *self.state.last_sequence.lock().await;
        Ok(Response::new(ReplicationAck { applied_sequence }))
    }

    /// Reports the current applied sequence; `pending_events` is always 0.
    async fn heartbeat(
        &self,
        _request: Request<HeartbeatRequest>,
    ) -> std::result::Result<Response<HeartbeatResponse>, Status> {
        let applied = *self.state.last_sequence.lock().await;
        Ok(Response::new(HeartbeatResponse {
            applied_sequence: applied,
            pending_events: 0,
        }))
    }

    /// Lists the type, id and version of every aggregate position the store reports.
    ///
    /// # Errors
    ///
    /// Store failures are reported as `Status::internal`.
    async fn list_positions(
        &self,
        _request: Request<ListPositionsRequest>,
    ) -> std::result::Result<Response<ListPositionsResponse>, Status> {
        let positions = self
            .store
            .aggregate_positions()
            .map_err(|err| Status::internal(err.to_string()))?;

        let proto_positions = positions
            .into_iter()
            .map(|entry| AggregatePosition {
                aggregate_type: entry.aggregate_type,
                aggregate_id: entry.aggregate_id,
                version: entry.version,
            })
            .collect();

        Ok(Response::new(ListPositionsResponse {
            positions: proto_positions,
        }))
    }

    /// Returns the events the store yields for one aggregate after `from_version`.
    ///
    /// A `limit` of 0 in the request means "no limit".
    ///
    /// # Errors
    ///
    /// Store and encoding failures are reported as `Status::internal`.
    async fn pull_events(
        &self,
        request: Request<PullEventsRequest>,
    ) -> std::result::Result<Response<PullEventsResponse>, Status> {
        let req = request.into_inner();
        let limit = if req.limit == 0 {
            None
        } else {
            Some(req.limit as usize)
        };

        let events = self
            .store
            .events_after(
                &req.aggregate_type,
                &req.aggregate_id,
                req.from_version,
                limit,
            )
            .map_err(|err| Status::internal(err.to_string()))?;

        // Collecting into `Result` stops at the first event that fails to encode.
        let proto_events = events
            .into_iter()
            .map(|event| convert_event(&event).map_err(|err| Status::internal(err.to_string())))
            .collect::<std::result::Result<Vec<_>, Status>>()?;

        Ok(Response::new(PullEventsResponse {
            events: proto_events,
        }))
    }

    /// Returns the schema manager's current snapshot serialized as JSON bytes.
    ///
    /// # Errors
    ///
    /// Serialization failures are reported as `Status::internal`.
    async fn pull_schemas(
        &self,
        _request: Request<PullSchemasRequest>,
    ) -> std::result::Result<Response<PullSchemasResponse>, Status> {
        let snapshot = self.schemas.snapshot();
        let payload =
            serde_json::to_vec(&snapshot).map_err(|err| Status::internal(err.to_string()))?;
        Ok(Response::new(PullSchemasResponse {
            schemas_json: payload,
        }))
    }

    /// Replaces all schemas with the JSON map carried by the request.
    ///
    /// An empty `schemas_json` is treated as an empty map. The response
    /// reports how many aggregate schemas were applied.
    ///
    /// # Errors
    ///
    /// JSON parsing and replacement failures are reported as `Status::internal`.
    async fn apply_schemas(
        &self,
        request: Request<ApplySchemasRequest>,
    ) -> std::result::Result<Response<ApplySchemasResponse>, Status> {
        let payload = request.into_inner();
        let map: BTreeMap<String, crate::schema::AggregateSchema> =
            if payload.schemas_json.is_empty() {
                BTreeMap::new()
            } else {
                serde_json::from_slice(&payload.schemas_json)
                    .map_err(|err| Status::internal(err.to_string()))?
            };

        // Take the count before `replace_all` consumes the map.
        let aggregate_count = map.len() as u32;
        self.schemas
            .replace_all(map)
            .map_err(|err| Status::internal(err.to_string()))?;

        Ok(Response::new(ApplySchemasResponse { aggregate_count }))
    }
}

/// Fans event batches out to one background worker per configured remote.
///
/// Clones share the same set of remote handles through the `Arc`.
#[derive(Clone)]
pub struct ReplicationManager {
    /// One queue handle per remote, created in `from_config`.
    remotes: Arc<Vec<RemoteHandle>>,
}

/// Sending side of one remote's replication queue.
struct RemoteHandle {
    /// Remote name as listed in the configuration; used in log messages.
    name: String,
    /// Channel that feeds event batches to the remote's background worker.
    sender: mpsc::Sender<Vec<EventRecord>>,
}

impl ReplicationManager {
    pub fn from_config(config: &Config) -> Self {
        let mut handles = Vec::new();
        for (name, remote_config) in &config.remotes {
            let (tx, rx) = mpsc::channel(64);
            let worker = RemoteWorker::new(name.clone(), remote_config.clone(), rx);
            tokio::spawn(async move { worker.run().await });
            handles.push(RemoteHandle {
                name: name.clone(),
                sender: tx,
            });
        }

        Self {
            remotes: Arc::new(handles),
        }
    }

    pub fn is_empty(&self) -> bool {
        self.remotes.is_empty()
    }

    pub fn enqueue(&self, events: &[EventRecord]) {
        if events.is_empty() {
            return;
        }

        for handle in self.remotes.iter() {
            let mut batch = Vec::with_capacity(events.len());
            for event in events {
                batch.push(event.clone());
            }

            let sender = handle.sender.clone();
            let remote_name = handle.name.clone();
            tokio::spawn(async move {
                if let Err(err) = sender.send(batch).await {
                    warn!(
                        "failed to queue replication batch for {}: {}",
                        remote_name, err
                    );
                }
            });
        }
    }
}

/// Background task state that delivers queued event batches to one remote.
struct RemoteWorker {
    /// Remote name, used in log messages.
    name: String,
    /// Connection settings; `endpoint` and `public_key` are read when connecting.
    config: RemoteConfig,
    /// Queue of event batches to deliver.
    receiver: mpsc::Receiver<Vec<EventRecord>>,
    /// Counter passed with each `apply_events` call; starts at 0 and is
    /// incremented (wrapping) before each send.
    sequence: u64,
}

impl RemoteWorker {
    /// Creates a worker for remote `name` that drains `receiver`; the
    /// sequence counter starts at 0.
    fn new(name: String, config: RemoteConfig, receiver: mpsc::Receiver<Vec<EventRecord>>) -> Self {
        Self {
            name,
            config,
            receiver,
            sequence: 0,
        }
    }

    /// Delivers queued batches one at a time until every sender is dropped.
    ///
    /// A failed delivery is retried indefinitely, sleeping for
    /// [`Self::backoff_delay`] between attempts; the next batch is not read
    /// until the current one has been delivered.
    async fn run(mut self) {
        while let Some(events) = self.receiver.recv().await {
            if events.is_empty() {
                continue;
            }

            let mut attempt: u32 = 0;
            while let Err(err) = self.send_batch(&events).await {
                attempt += 1;
                warn!(
                    "replication to remote '{}' failed on attempt {}: {}",
                    self.name, attempt, err
                );
                sleep(Self::backoff_delay(attempt)).await;
            }
        }
    }

    /// Sends the encodable subset of `events` to the remote under the next sequence number.
    ///
    /// Events that fail `convert_event` are logged and skipped. Filtering
    /// happens before connecting: when nothing is left to send the call
    /// succeeds without touching the network or consuming a sequence number,
    /// so such a batch can never get stuck in the retry loop.
    ///
    /// # Errors
    ///
    /// Returns the connection error from [`Self::connect`], or
    /// `EventError::Storage` when the remote call fails.
    async fn send_batch(&mut self, events: &[EventRecord]) -> Result<()> {
        let sendable: Vec<EventRecord> = events
            .iter()
            .filter(|record| match convert_event(record) {
                Ok(_) => true,
                Err(err) => {
                    error!(
                        "failed to encode event {}::{} version {} for replication: {}",
                        record.aggregate_type, record.aggregate_id, record.version, err
                    );
                    false
                }
            })
            .cloned()
            .collect();

        if sendable.is_empty() {
            return Ok(());
        }

        let mut client = self.connect().await?;
        self.sequence = self.sequence.wrapping_add(1);

        let applied = client
            .apply_events(self.sequence, &sendable)
            .await
            .map_err(|err| EventError::Storage(err.to_string()))?;
        info!(
            "remote '{}' acknowledged sequence {} (applied {})",
            self.name, self.sequence, applied
        );
        Ok(())
    }

    /// Opens a new `CapnpReplicationClient` using the configured endpoint and public key.
    ///
    /// # Errors
    ///
    /// Endpoint normalization and key decoding failures become
    /// `EventError::Config`; connection failures become `EventError::Storage`.
    async fn connect(&self) -> Result<CapnpReplicationClient> {
        let endpoint = normalize_capnp_endpoint(&self.config.endpoint)
            .map_err(|err| EventError::Config(err.to_string()))?;
        let expected_key = decode_public_key_bytes(&self.config.public_key)
            .map_err(|err| EventError::Config(err.to_string()))?;
        CapnpReplicationClient::connect(&endpoint, &expected_key)
            .await
            .map_err(|err| EventError::Storage(err.to_string()))
    }

    /// Delay before the next retry: 1s, 2s and 4s for attempts 1 to 3, then 10s.
    fn backoff_delay(attempt: u32) -> Duration {
        match attempt {
            0 | 1 => Duration::from_secs(1),
            2 => Duration::from_secs(2),
            3 => Duration::from_secs(4),
            _ => Duration::from_secs(10),
        }
    }
}

pub fn convert_event(record: &EventRecord) -> Result<ProtoEventRecord> {
    let payload = serde_json::to_vec(&record.payload)?;
    let metadata = serde_json::to_vec(&record.metadata)?;

    Ok(ProtoEventRecord {
        aggregate_type: record.aggregate_type.clone(),
        aggregate_id: record.aggregate_id.clone(),
        event_type: record.event_type.clone(),
        version: record.version,
        merkle_root: record.merkle_root.clone(),
        payload,
        metadata,
        hash: record.hash.clone(),
    })
}

/// Decodes a protobuf event back into a stored event record.
///
/// `payload` is parsed as arbitrary JSON and `metadata` as [`EventMetadata`]
/// (payload first); every other field is copied over unchanged.
///
/// # Errors
///
/// Returns the converted `serde_json` error if either byte field is not valid
/// JSON of the expected shape.
pub fn decode_event(proto: &ProtoEventRecord) -> Result<EventRecord> {
    let parsed_payload = serde_json::from_slice::<serde_json::Value>(&proto.payload)?;
    let parsed_metadata = serde_json::from_slice::<EventMetadata>(&proto.metadata)?;

    let record = EventRecord {
        aggregate_type: proto.aggregate_type.clone(),
        aggregate_id: proto.aggregate_id.clone(),
        event_type: proto.event_type.clone(),
        payload: parsed_payload,
        metadata: parsed_metadata,
        version: proto.version,
        hash: proto.hash.clone(),
        merkle_root: proto.merkle_root.clone(),
    };
    Ok(record)
}

/// Rewrites a replication endpoint into an `http://` or `https://` URL.
///
/// * `http://…` and `https://…` are returned unchanged.
/// * `grpc://…` becomes `http://…`; `grpcs://…` becomes `https://…`.
///
/// Scheme matching is case-sensitive.
///
/// # Errors
///
/// Returns `EventError::Config` for any other input, including endpoints
/// without a scheme.
pub fn normalize_endpoint(endpoint: &str) -> Result<String> {
    const PASS_THROUGH: [&str; 2] = ["http://", "https://"];
    const REWRITES: [(&str, &str); 2] = [("grpc://", "http://"), ("grpcs://", "https://")];

    if PASS_THROUGH
        .iter()
        .any(|&scheme| endpoint.starts_with(scheme))
    {
        return Ok(endpoint.to_string());
    }

    let rewritten = REWRITES.iter().find_map(|&(from, to)| {
        endpoint
            .strip_prefix(from)
            .map(|rest| format!("{}{}", to, rest))
    });

    match rewritten {
        Some(url) => Ok(url),
        None => Err(EventError::Config(format!(
            "unsupported replication endpoint scheme: {}",
            endpoint
        ))),
    }
}