use super::{BroadcastedView, WriteCommand};
use super::{Delta, Durability, WriteError, WriteResult};
use crate::StorageRead;
use crate::coordinator::traits::EpochStamped;
use crate::storage::StorageSnapshot;
use futures::FutureExt;
use futures::future::Shared;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, oneshot, watch};
/// Bundle of readable coordinator state for a delta type `D`.
///
/// Instances are handed out behind an `Arc` by
/// `WriteCoordinatorHandle::view` and `WriteCoordinatorHandle::subscribe`.
pub struct View<D: Delta> {
/// View of the current delta.
// NOTE(review): presumably the in-progress, not-yet-frozen delta — confirm against the coordinator.
pub current: D::DeltaView,
/// Frozen delta views, each stamped with an epoch.
// NOTE(review): ordering of this vector is not established in this file — confirm before relying on it.
pub frozen: Vec<EpochStamped<D::FrozenView>>,
/// Shared handle to a storage snapshot.
pub snapshot: Arc<dyn StorageSnapshot>,
/// Epoch-stamped frozen view of the last written delta, if there is one.
// NOTE(review): "written" presumably matches `Durability::Written` — confirm.
pub last_written_delta: Option<EpochStamped<D::FrozenView>>,
}
impl<D: Delta> Clone for View<D> {
fn clone(&self) -> Self {
Self {
current: self.current.clone(),
frozen: self.frozen.clone(),
snapshot: self.snapshot.clone(),
last_written_delta: self.last_written_delta.clone(),
}
}
}
/// Set of `watch` receivers carrying epoch watermarks, one per
/// [`Durability`] level (see [`EpochWatcher::wait`] for the mapping).
#[derive(Clone)]
pub struct EpochWatcher {
/// Watermark consulted for `Durability::Applied`.
pub applied_rx: watch::Receiver<u64>,
/// Watermark consulted for `Durability::Written`.
pub written_rx: watch::Receiver<u64>,
/// Watermark consulted for `Durability::Durable`.
pub durable_rx: watch::Receiver<u64>,
}
impl EpochWatcher {
    /// Waits until the watermark for `durability` is at least `epoch`.
    ///
    /// Resolves immediately when the watermark has already reached `epoch`.
    ///
    /// # Errors
    /// Returns the `RecvError` produced by `watch::Receiver::wait_for`, which
    /// happens when the sending side of the selected channel is gone before
    /// the watermark reaches `epoch`.
    pub async fn wait(
        &mut self,
        epoch: u64,
        durability: Durability,
    ) -> Result<(), watch::error::RecvError> {
        // Pick the receiver that tracks the requested durability level.
        let receiver = match durability {
            Durability::Applied => &mut self.applied_rx,
            Durability::Written => &mut self.written_rx,
            Durability::Durable => &mut self.durable_rx,
        };

        // The observed value itself is not needed, only the fact that it was reached.
        receiver.wait_for(|watermark| *watermark >= epoch).await?;
        Ok(())
    }
}
/// Success payload delivered to a `WriteHandle`: the epoch reported for the
/// operation together with a result value of type `M`.
#[derive(Clone, Debug)]
pub(crate) struct WriteApplied<M> {
/// Epoch reported for the operation; `WriteHandle::epoch` returns this value.
pub epoch: u64,
/// Result value; `WriteHandle::wait` returns it once the watermark is reached.
pub result: M,
}
/// Failure payload delivered to a `WriteHandle`; surfaced to callers as
/// `WriteError::ApplyError(epoch, error)`.
#[derive(Clone, Debug)]
pub(crate) struct WriteFailed {
/// Epoch reported alongside the failure.
pub epoch: u64,
/// Error message text.
pub error: String,
}
/// Outcome sent over the oneshot channel that backs a `WriteHandle`:
/// either a `WriteApplied<M>` or a `WriteFailed`.
pub(crate) type EpochResult<M> = Result<WriteApplied<M>, WriteFailed>;
/// Handle for a command submitted through `WriteCoordinatorHandle`.
///
/// Use `epoch` to learn the epoch reported for the command, or `wait` to also
/// wait for a durability watermark to reach that epoch.
pub struct WriteHandle<M: Clone + Send + 'static = ()> {
// Shared so the single oneshot reply can be awaited any number of times.
inner: Shared<oneshot::Receiver<EpochResult<M>>>,
// Watermark receivers used by `wait`.
watchers: EpochWatcher,
}
impl<M: Clone + Send + 'static> WriteHandle<M> {
pub(crate) fn new(rx: oneshot::Receiver<EpochResult<M>>, watchers: EpochWatcher) -> Self {
Self {
inner: rx.shared(),
watchers,
}
}
async fn recv(&self) -> WriteResult<WriteApplied<M>> {
self.inner
.clone()
.await
.map_err(|_| WriteError::Shutdown)?
.map_err(|e| WriteError::ApplyError(e.epoch, e.error))
}
pub async fn epoch(&self) -> WriteResult<u64> {
Ok(self.recv().await?.epoch)
}
pub async fn wait(&mut self, durability: Durability) -> WriteResult<M> {
let WriteApplied { epoch, result } = self.recv().await?;
self.watchers
.wait(epoch, durability)
.await
.map_err(|_| WriteError::Shutdown)?;
Ok(result)
}
}
/// Cloneable handle used to submit commands to a write coordinator and to
/// read the views it publishes.
pub struct WriteCoordinatorHandle<D: Delta> {
// Channel on which `WriteCommand`s are sent.
write_tx: mpsc::Sender<WriteCommand<D>>,
// Watermark receivers; a clone is given to every `WriteHandle` created here.
watchers: EpochWatcher,
// Source of the current view and of view subscriptions.
view: Arc<BroadcastedView<D>>,
}
impl<D: Delta> WriteCoordinatorHandle<D> {
    /// Builds a handle from the command sender, the watermark receivers and
    /// the shared view broadcaster.
    pub(crate) fn new(
        write_tx: mpsc::Sender<WriteCommand<D>>,
        watchers: EpochWatcher,
        view: Arc<BroadcastedView<D>>,
    ) -> Self {
        Self {
            view,
            watchers,
            write_tx,
        }
    }

    /// Returns the value currently held by the `written` watermark channel.
    pub fn flushed_epoch(&self) -> u64 {
        let watermark = self.watchers.written_rx.borrow();
        *watermark
    }
}
impl<D: Delta> WriteCoordinatorHandle<D> {
pub async fn write_timeout(
&self,
write: D::Write,
timeout: Duration,
) -> Result<WriteHandle<D::ApplyResult>, WriteError<D::Write>> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send_timeout(
WriteCommand::Write {
write,
result_tx: tx,
},
timeout,
)
.await
.map_err(|e| match e {
mpsc::error::SendTimeoutError::Timeout(WriteCommand::Write { write, .. }) => {
WriteError::TimeoutError(write)
}
mpsc::error::SendTimeoutError::Closed(WriteCommand::Write { write, .. }) => {
WriteError::Shutdown
}
_ => unreachable!("sent a Write command"),
})?;
Ok(WriteHandle::new(rx, self.watchers.clone()))
}
pub async fn write(
&self,
write: D::Write,
) -> Result<WriteHandle<D::ApplyResult>, WriteError<D::Write>> {
let (tx, rx) = oneshot::channel();
self.write_tx
.send(WriteCommand::Write {
write,
result_tx: tx,
})
.await
.map_err(|e| match e {
mpsc::error::SendError(WriteCommand::Write { write, .. }) => WriteError::Shutdown,
_ => unreachable!("sent a Write command"),
})?;
Ok(WriteHandle::new(rx, self.watchers.clone()))
}
pub async fn try_write(
&self,
write: D::Write,
) -> Result<WriteHandle<D::ApplyResult>, WriteError<D::Write>> {
let (tx, rx) = oneshot::channel();
self.write_tx
.try_send(WriteCommand::Write {
write,
result_tx: tx,
})
.map_err(|e| match e {
mpsc::error::TrySendError::Full(WriteCommand::Write { write, .. }) => {
WriteError::Backpressure(write)
}
mpsc::error::TrySendError::Closed(WriteCommand::Write { write, .. }) => {
WriteError::Shutdown
}
_ => unreachable!("sent a Write command"),
})?;
Ok(WriteHandle::new(rx, self.watchers.clone()))
}
pub async fn flush(&self, flush_storage: bool) -> WriteResult<WriteHandle> {
let (tx, rx) = oneshot::channel();
self.write_tx
.try_send(WriteCommand::Flush {
epoch_tx: tx,
flush_storage,
})
.map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => WriteError::Backpressure(()),
mpsc::error::TrySendError::Closed(_) => WriteError::Shutdown,
})?;
Ok(WriteHandle::new(rx, self.watchers.clone()))
}
pub fn view(&self) -> Arc<View<D>> {
self.view.current()
}
pub fn subscribe(&self) -> (broadcast::Receiver<Arc<View<D>>>, Arc<View<D>>) {
self.view.subscribe()
}
}
impl<D: Delta> Clone for WriteCoordinatorHandle<D> {
fn clone(&self) -> Self {
Self {
write_tx: self.write_tx.clone(),
watchers: self.watchers.clone(),
view: self.view.clone(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::watch;

    /// Packs the three watermark receivers into an `EpochWatcher`.
    fn create_watchers(
        applied_rx: watch::Receiver<u64>,
        written_rx: watch::Receiver<u64>,
        durable_rx: watch::Receiver<u64>,
    ) -> EpochWatcher {
        EpochWatcher {
            applied_rx,
            written_rx,
            durable_rx,
        }
    }

    #[tokio::test]
    async fn should_return_epoch_when_assigned() {
        // given: a handle whose reply has already been delivered
        let (reply_tx, reply_rx) = oneshot::channel();
        let (_applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let handle: WriteHandle<()> = WriteHandle::new(reply_rx, watchers);
        let reply = WriteApplied {
            epoch: 42,
            result: (),
        };
        reply_tx.send(Ok(reply)).unwrap();

        // when
        let epoch = handle.epoch().await;

        // then
        assert!(epoch.is_ok());
        assert_eq!(epoch.unwrap(), 42);
    }

    #[tokio::test]
    async fn should_allow_multiple_epoch_calls() {
        // given
        let (reply_tx, reply_rx) = oneshot::channel();
        let (_applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let handle: WriteHandle<()> = WriteHandle::new(reply_rx, watchers);
        let reply = WriteApplied {
            epoch: 42,
            result: (),
        };
        reply_tx.send(Ok(reply)).unwrap();

        // when: the same reply is observed three times
        let first = handle.epoch().await;
        let second = handle.epoch().await;
        let third = handle.epoch().await;

        // then
        assert_eq!(first.unwrap(), 42);
        assert_eq!(second.unwrap(), 42);
        assert_eq!(third.unwrap(), 42);
    }

    #[tokio::test]
    async fn should_return_apply_result_from_wait() {
        // given: the applied watermark is already past the epoch
        let (reply_tx, reply_rx) = oneshot::channel();
        let (_applied_tx, applied_rx) = watch::channel(100u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let mut handle: WriteHandle<String> = WriteHandle::new(reply_rx, watchers);
        let reply = WriteApplied {
            epoch: 1,
            result: "hello".to_string(),
        };
        reply_tx.send(Ok(reply)).unwrap();

        // when / then
        assert_eq!(handle.wait(Durability::Applied).await.unwrap(), "hello");
    }

    #[tokio::test]
    async fn should_return_immediately_when_watermark_already_reached() {
        // given: watermark (100) is ahead of the epoch (50)
        let (reply_tx, reply_rx) = oneshot::channel();
        let (_applied_tx, applied_rx) = watch::channel(100u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let mut handle: WriteHandle<()> = WriteHandle::new(reply_rx, watchers);
        let reply = WriteApplied {
            epoch: 50,
            result: (),
        };
        reply_tx.send(Ok(reply)).unwrap();

        // when
        let outcome = handle.wait(Durability::Applied).await;

        // then
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn should_wait_until_watermark_reaches_epoch() {
        // given: watermark starts below the epoch
        let (reply_tx, reply_rx) = oneshot::channel();
        let (applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let mut handle: WriteHandle<()> = WriteHandle::new(reply_rx, watchers);
        let reply = WriteApplied {
            epoch: 10,
            result: (),
        };
        reply_tx.send(Ok(reply)).unwrap();

        // when: the watermark advances in two steps while the waiter is parked
        let waiter = tokio::spawn(async move { handle.wait(Durability::Applied).await });
        tokio::task::yield_now().await;
        applied_tx.send(5).unwrap();
        tokio::task::yield_now().await;
        applied_tx.send(10).unwrap();

        // then
        let outcome = waiter.await.unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn should_wait_for_correct_durability_level() {
        // given: applied and written are past the epoch, only durable lags behind
        let (reply_tx, reply_rx) = oneshot::channel();
        let (_applied_tx, applied_rx) = watch::channel(100u64);
        let (_flushed_tx, flushed_rx) = watch::channel(50u64);
        let (durable_tx, durable_rx) = watch::channel(10u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let mut handle: WriteHandle<()> = WriteHandle::new(reply_rx, watchers);
        let reply = WriteApplied {
            epoch: 25,
            result: (),
        };
        reply_tx.send(Ok(reply)).unwrap();

        // when
        let waiter = tokio::spawn(async move { handle.wait(Durability::Durable).await });
        tokio::task::yield_now().await;
        durable_tx.send(25).unwrap();

        // then
        let outcome = waiter.await.unwrap();
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn should_propagate_epoch_error_in_wait() {
        // given: the reply sender is dropped without ever replying
        let (reply_tx, reply_rx) = oneshot::channel::<EpochResult<()>>();
        let (_applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let mut handle = WriteHandle::new(reply_rx, watchers);
        drop(reply_tx);

        // when
        let outcome = handle.wait(Durability::Applied).await;

        // then
        assert!(matches!(outcome, Err(WriteError::Shutdown)));
    }

    #[tokio::test]
    async fn should_propagate_apply_error_in_wait() {
        // given: the reply reports a failure
        let (reply_tx, reply_rx) = oneshot::channel();
        let (_applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let watchers = create_watchers(applied_rx, flushed_rx, durable_rx);
        let mut handle: WriteHandle<()> = WriteHandle::new(reply_rx, watchers);
        let failure = WriteFailed {
            epoch: 1,
            error: "apply error".into(),
        };
        reply_tx.send(Err(failure)).unwrap();

        // when
        let result = handle.wait(Durability::Applied).await;

        // then
        assert!(
            matches!(result, Err(WriteError::ApplyError(epoch, msg)) if epoch == 1 && msg == "apply error")
        );
    }

    #[tokio::test]
    async fn epoch_watcher_should_resolve_when_watermark_reached() {
        // given
        let (applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let mut watcher = create_watchers(applied_rx, flushed_rx, durable_rx);

        // when
        let waiter = tokio::spawn(async move { watcher.wait(5, Durability::Applied).await });
        tokio::task::yield_now().await;
        applied_tx.send(5).unwrap();

        // then
        assert!(waiter.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn epoch_watcher_should_resolve_immediately_when_already_reached() {
        // given: watermark (10) is ahead of the requested epoch (5)
        let (_applied_tx, applied_rx) = watch::channel(10u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let mut watcher = create_watchers(applied_rx, flushed_rx, durable_rx);

        // when / then
        assert!(watcher.wait(5, Durability::Applied).await.is_ok());
    }

    #[tokio::test]
    async fn epoch_watcher_should_select_correct_durability_receiver() {
        // given
        let (_applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (durable_tx, durable_rx) = watch::channel(0u64);
        let mut watcher = create_watchers(applied_rx, flushed_rx, durable_rx);

        // when: only the durable watermark advances
        let waiter = tokio::spawn(async move { watcher.wait(3, Durability::Durable).await });
        tokio::task::yield_now().await;
        durable_tx.send(3).unwrap();

        // then
        assert!(waiter.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn epoch_watcher_should_return_error_on_sender_drop() {
        // given: the applied sender is gone before the watermark reaches the epoch
        let (applied_tx, applied_rx) = watch::channel(0u64);
        let (_flushed_tx, flushed_rx) = watch::channel(0u64);
        let (_durable_tx, durable_rx) = watch::channel(0u64);
        let mut watcher = create_watchers(applied_rx, flushed_rx, durable_rx);
        drop(applied_tx);

        // when
        let outcome = watcher.wait(1, Durability::Applied).await;

        // then
        assert!(outcome.is_err());
    }
}