use {
crate::{PeerId, UniqueId, groups::*, primitives::UnboundedChannel},
chrono::{DateTime, Utc},
core::{
any::type_name,
marker::PhantomData,
ops::Range,
task::{Context, Poll},
time::Duration,
},
derive_more::From,
parking_lot::RwLock,
protocol::{SnapshotRequest, SnapshotSyncMessage},
provider::SnapshotSyncProvider,
serde::{Deserialize, Serialize, de::DeserializeOwned},
session::SnapshotSyncSession,
std::sync::Arc,
tokio::sync::{
broadcast,
mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel},
},
};
pub(super) mod protocol;
mod provider;
mod session;
/// Shared, cheaply cloneable handle to the snapshot-sync state.
///
/// Every clone points at the same `SnapshotSyncInner` behind an
/// `Arc<RwLock<_>>`, so the configuration, the command factory and the
/// pending-request queue are shared by all clones.
pub struct SnapshotSync<M: SnapshotStateMachine>(
Arc<RwLock<SnapshotSyncInner<M>>>,
);
/// A `StateMachine` whose complete state can be exported as a [`Snapshot`]
/// and later restored from one; required by snapshot-based state sync.
pub trait SnapshotStateMachine: StateMachine {
    /// Snapshot representation produced by `create_snapshot` and accepted by
    /// `install_snapshot`.
    type Snapshot: Snapshot;

    /// Captures the machine's current state. Takes `&self`, so capturing a
    /// snapshot cannot mutate the machine.
    fn create_snapshot(&self) -> Self::Snapshot;

    /// Intended to replace the machine's state with `snapshot`.
    fn install_snapshot(&mut self, snapshot: Self::Snapshot);
}
/// An ordered, indexable collection of [`SnapshotItem`]s describing a state
/// machine's contents at some point in time.
///
/// Items are addressed by `u64` positions in `0..len()`.
pub trait Snapshot: Default + Clone + Send + 'static {
    /// Element type transferred between peers.
    type Item: SnapshotItem;

    /// Number of items held by the snapshot.
    fn len(&self) -> u64;

    /// `true` when the snapshot holds no items; defined in terms of `len`.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }

    /// Yields the items at positions `range`.
    ///
    /// NOTE(review): `None` presumably means the range cannot be served
    /// (e.g. out of bounds) — confirm against implementors.
    fn iter_range(
        &self,
        range: Range<u64>,
    ) -> Option<impl Iterator<Item = Self::Item>>;

    /// Adds `items` to the snapshot.
    fn append(&mut self, items: impl IntoIterator<Item = Self::Item>);
}
/// Bounds required of a single snapshot element: cloneable, shareable across
/// threads, `'static`, and (de)serializable with serde.
///
/// This is a pure alias trait. The blanket impl below covers every type that
/// meets the bounds, so a manual impl would conflict with it.
pub trait SnapshotItem
where
    Self: Clone + Send + Sync + Serialize + DeserializeOwned + 'static,
{
}

impl<T: Clone + Send + Sync + Serialize + DeserializeOwned + 'static>
    SnapshotItem for T
{
}
/// Tuning parameters for snapshot-based state sync.
///
/// Start from `Config::default()` and adjust it with the `with_*` builder
/// methods. All four values are mixed into the sync signature
/// (`SnapshotSyncInner::signature`), so changing any of them changes the
/// signature.
///
/// `PartialEq`/`Eq` are derived so configurations can be compared directly,
/// for example in tests or when checking whether two handles were built
/// alike.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Config {
    /// Batch size used when fetching snapshot items (default 2000).
    /// NOTE(review): presumably the number of items per fetch request —
    /// confirm in the `session` module.
    fetch_batch_size: u64,
    /// Maximum age of a `SnapshotRequest` before `is_expired` reports it as
    /// stale (default 10 s).
    snapshot_ttl: Duration,
    /// Request timeout passed to the provider/session modules (default 15 s).
    snapshot_request_timeout: Duration,
    /// Fetch timeout passed to the provider/session modules (default 5 s).
    fetch_timeout: Duration,
}
impl Config {
    /// Returns this configuration with `fetch_batch_size` replaced.
    #[must_use]
    pub const fn with_fetch_batch_size(self, fetch_batch_size: u64) -> Self {
        Self {
            fetch_batch_size,
            ..self
        }
    }

    /// Returns this configuration with `snapshot_ttl` replaced.
    #[must_use]
    pub const fn with_snapshot_ttl(self, snapshot_ttl: Duration) -> Self {
        Self {
            snapshot_ttl,
            ..self
        }
    }

    /// Returns this configuration with `snapshot_request_timeout` replaced.
    #[must_use]
    pub const fn with_snapshot_request_timeout(
        self,
        snapshot_request_timeout: Duration,
    ) -> Self {
        Self {
            snapshot_request_timeout,
            ..self
        }
    }

    /// Returns this configuration with `fetch_timeout` replaced.
    #[must_use]
    pub const fn with_fetch_timeout(self, fetch_timeout: Duration) -> Self {
        Self {
            fetch_timeout,
            ..self
        }
    }
}
impl Config {
    /// Reports whether `request` is too old to be served.
    ///
    /// The distance between now and `request.requested_at` is taken as an
    /// absolute value. A request stamped in the future (e.g. due to clock
    /// skew) by more than `snapshot_ttl` therefore also counts as expired.
    /// If that distance cannot be converted to a `std` duration, the request
    /// is conservatively treated as expired.
    pub(super) fn is_expired(&self, request: &SnapshotRequest) -> bool {
        let distance = Utc::now()
            .signed_duration_since(request.requested_at)
            .abs();

        match distance.to_std() {
            Ok(elapsed) => elapsed > self.snapshot_ttl,
            Err(_) => true,
        }
    }
}
impl Default for Config {
    /// Defaults: batches of 2000 items, 10 s snapshot TTL, 15 s snapshot
    /// request timeout, 5 s fetch timeout.
    fn default() -> Self {
        let snapshot_ttl = Duration::from_secs(10);
        let snapshot_request_timeout = Duration::from_secs(15);
        let fetch_timeout = Duration::from_secs(5);

        Self {
            fetch_batch_size: 2_000,
            snapshot_ttl,
            snapshot_request_timeout,
            fetch_timeout,
        }
    }
}
/// Shared factory that turns a `SnapshotRequest` into a state-machine command
/// (`M::Command`).
///
/// Stored as an `Arc<dyn Fn + Send + Sync>`, so `SnapshotSyncInner` keeps one
/// copy and hands a clone of the same closure to the provider.
// NOTE(review): rustc's `type_alias_bounds` lint may warn that the
// `M: SnapshotStateMachine` bound is not enforced. It is kept because the
// `M::Command` shorthand relies on it to resolve. A fully qualified
// `<M as Trait>::Command` path would allow dropping the bound — confirm first
// which trait declares `Command`.
pub type SyncInitCommand<M: SnapshotStateMachine> =
Arc<dyn Fn(SnapshotRequest) -> M::Command + Send + Sync>;
impl<M: SnapshotStateMachine> SnapshotSync<M> {
    /// Creates a new handle that owns fresh inner state built from `config`
    /// and the request-to-command factory `to_command`.
    pub(super) fn new(
        config: Config,
        to_command: impl Fn(SnapshotRequest) -> M::Command + Send + Sync + 'static,
    ) -> Self {
        Self(Arc::new(RwLock::new(SnapshotSyncInner::new(
            config, to_command,
        ))))
    }
}
// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `M: Clone` bound even though only the `Arc` is cloned.
impl<M: SnapshotStateMachine> Clone for SnapshotSync<M> {
    /// Returns another handle to the same shared state; nothing inside the
    /// lock is copied.
    fn clone(&self) -> Self {
        let Self(shared) = self;
        Self(Arc::clone(shared))
    }
}
impl<M: SnapshotStateMachine> StateSync for SnapshotSync<M> {
type Machine = M;
type Message = SnapshotSyncMessage<<M::Snapshot as Snapshot>::Item>;
type Provider = provider::SnapshotSyncProvider<M>;
type Session = session::SnapshotSyncSession<M>;
fn signature(&self) -> crate::UniqueId {
self.0.read().signature()
}
fn create_provider(&self, cx: &dyn SyncContext<Self>) -> Self::Provider {
self.0.write().create_provider(cx)
}
fn create_session(
&self,
cx: &mut dyn SyncSessionContext<Self>,
position: Cursor,
leader_commit: Index,
entries: Vec<(M::Command, Term)>,
) -> Self::Session {
self
.0
.read()
.create_session(cx, position, leader_commit, entries)
}
}
impl<M: SnapshotStateMachine> SnapshotSync<M> {
    /// Returns `true` if `request` is older than the configured snapshot TTL.
    pub fn is_expired(&self, request: &SnapshotRequest) -> bool {
        let inner = self.0.read();
        inner.is_expired(request)
    }

    /// Queues `snapshot`, taken at `position`, as the answer to `request`.
    /// Delivery is best effort: if the provider side of the queue is gone,
    /// the request is silently discarded.
    pub fn serve_snapshot(
        &self,
        request: SnapshotRequest,
        position: Cursor,
        snapshot: M::Snapshot,
    ) {
        let inner = self.0.read();
        inner.serve_snapshot(request, position, snapshot);
    }
}
/// A snapshot queued by `serve_snapshot` and delivered to the provider over
/// the `requests_tx` / `requests_rx` channel.
struct PendingRequest<M: SnapshotStateMachine> {
/// The request being answered.
request: SnapshotRequest,
/// Cursor supplied alongside the snapshot.
/// NOTE(review): presumably the log position the snapshot reflects — confirm
/// with callers of `serve_snapshot`.
position: Cursor,
/// Snapshot data supplied by the caller of `serve_snapshot`.
snapshot: M::Snapshot,
}
/// State shared by every clone of a `SnapshotSync` handle.
struct SnapshotSyncInner<M: SnapshotStateMachine> {
/// Tuning parameters; also mixed into `signature()`.
config: Config,
/// Request-to-command factory; a clone is handed to the provider.
to_command: SyncInitCommand<M>,
/// Sending half of the pending-request queue, used by `serve_snapshot`.
requests_tx: UnboundedSender<PendingRequest<M>>,
/// Receiving half of the queue. It is `Some` until `create_provider` takes
/// it, which is why that method can only succeed once.
requests_rx: Option<UnboundedReceiver<PendingRequest<M>>>,
}
impl<M: SnapshotStateMachine> SnapshotSyncInner<M> {
pub fn new(
config: Config,
to_command: impl Fn(SnapshotRequest) -> M::Command + Send + Sync + 'static,
) -> Self {
let (requests_tx, requests_rx) = unbounded_channel();
Self {
config,
requests_tx,
requests_rx: Some(requests_rx),
to_command: Arc::new(to_command),
}
}
pub fn is_expired(&self, request: &SnapshotRequest) -> bool {
self.config.is_expired(request)
}
pub fn serve_snapshot(
&self,
request: SnapshotRequest,
position: Cursor,
snapshot: M::Snapshot,
) {
let _ = self.requests_tx.send(PendingRequest {
request,
position,
snapshot,
});
}
}
impl<M: SnapshotStateMachine> SnapshotSyncInner<M> {
pub fn signature(&self) -> crate::UniqueId {
UniqueId::from("mosaik_collections_snapshot_sync")
.derive(type_name::<M>())
.derive(self.config.fetch_batch_size.to_le_bytes())
.derive(self.config.snapshot_ttl.as_millis().to_le_bytes())
.derive(
self
.config
.snapshot_request_timeout
.as_millis()
.to_le_bytes(),
)
.derive(self.config.fetch_timeout.as_millis().to_le_bytes())
}
pub fn create_provider(
&mut self,
_cx: &dyn SyncContext<SnapshotSync<M>>,
) -> SnapshotSyncProvider<M> {
let Some(requests_rx) = self.requests_rx.take() else {
unreachable!("create_provider called more than once. this is a bug.")
};
SnapshotSyncProvider::<M>::new(
self.config.clone(),
self.to_command.clone(),
requests_rx,
)
}
pub fn create_session(
&self,
cx: &mut dyn SyncSessionContext<SnapshotSync<M>>,
position: Cursor,
leader_commit: Index,
entries: Vec<(M::Command, Term)>,
) -> SnapshotSyncSession<M> {
SnapshotSyncSession::new(&self.config, cx, position, leader_commit, entries)
}
}