use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use d_engine_proto::server::storage::SnapshotChunk;
#[cfg(any(test, feature = "__test_support"))]
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tonic::Status;
/// Constructor abstraction for a one-shot style channel carrying values of type `T`.
///
/// Implementors choose the concrete sending and receiving halves; `new` returns a
/// `(sender, receiver)` pair. Both halves must be `Send + Sync`.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Half used to deliver the value.
    type Sender: Send + Sync;
    /// Half used to obtain the value.
    type Receiver: Send + Sync;

    /// Creates a `(sender, receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}
/// Data-less factory type for channels built through its `RaftOneshot::new` impl.
///
/// In `test` / `__test_support` builds each channel is backed by both a tokio `oneshot`
/// channel and a `broadcast` channel, and the halves implement `Clone`; otherwise only the
/// `oneshot` channel is used.
pub struct MaybeCloneOneshot;
/// Sending half of a channel created by `MaybeCloneOneshot`.
///
/// Production builds send through `inner`; `test` / `__test_support` builds send through
/// `test_inner` and the sender implements `Clone`.
pub struct MaybeCloneOneshotSender<T: Send> {
/// Oneshot sender used by the non-test `send`. Nothing in this file reads it in test
/// builds (the test `send` uses `test_inner`), hence the `dead_code` allowance.
#[allow(dead_code)]
inner: oneshot::Sender<T>,
/// Broadcast sender used in test builds; if it is `None`, the test `send` panics.
#[cfg(any(test, feature = "__test_support"))]
test_inner: Option<broadcast::Sender<T>>, }
impl<T: Send> Debug for MaybeCloneOneshotSender<T> {
    /// Renders the sender as just its type name.
    ///
    /// No field is printed, so no `Debug` bound on `T` is required.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("MaybeCloneOneshotSender")
    }
}
/// Receiving half of a channel created by `MaybeCloneOneshot`.
///
/// Production builds await `inner` through the `Future` impl; `test` / `__test_support`
/// builds read from `test_inner` (via `recv`, `try_recv` or the test `Future` impl) and the
/// receiver implements `Clone`.
pub struct MaybeCloneOneshotReceiver<T: Send> {
/// Oneshot receiver polled by the non-test `Future` impl. Nothing in this file reads it in
/// test builds, hence the `dead_code` allowance.
#[allow(dead_code)]
inner: oneshot::Receiver<T>,
/// Broadcast receiver used in test builds; if it is `None`, the test-only accessors panic.
#[cfg(any(test, feature = "__test_support"))]
test_inner: Option<broadcast::Receiver<T>>, }
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Broadcasts `value` on the test-build broadcast channel.
    ///
    /// On success, `Ok` carries the count returned by `broadcast::Sender::send`.
    ///
    /// # Errors
    /// Returns the `broadcast::error::SendError` (which owns `value`) if the broadcast send
    /// fails.
    ///
    /// # Panics
    /// Panics if this sender has no broadcast half.
    pub fn send(&self, value: T) -> Result<usize, broadcast::error::SendError<T>> {
        match self.test_inner.as_ref() {
            Some(broadcaster) => broadcaster.send(value),
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
#[cfg(not(any(test, feature = "__test_support")))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Sends `value` over the underlying oneshot channel, consuming the sender.
    ///
    /// # Errors
    /// Returns `Err(value)` when the oneshot send fails (tokio reports this when the
    /// receiving half has been closed).
    pub fn send(self, value: T) -> Result<(), T> {
        let Self { inner, .. } = self;
        inner.send(value)
    }
}
impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
    /// Awaits the next value from the broadcast half (test builds only).
    ///
    /// # Errors
    /// Propagates `broadcast::error::RecvError` from the broadcast receiver.
    ///
    /// # Panics
    /// Panics when awaited if this receiver has no broadcast half.
    #[cfg(any(test, feature = "__test_support"))]
    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
        match self.test_inner.as_mut() {
            Some(subscriber) => subscriber.recv().await,
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }

    /// Attempts to take a value from the broadcast half without waiting (test builds only).
    ///
    /// # Errors
    /// Propagates `broadcast::error::TryRecvError` from the broadcast receiver.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast half.
    #[cfg(any(test, feature = "__test_support"))]
    pub fn try_recv(&mut self) -> Result<T, broadcast::error::TryRecvError> {
        match self.test_inner.as_mut() {
            Some(subscriber) => subscriber.try_recv(),
            None => panic!("Cannot try_recv non-cloneable type in tests"),
        }
    }
}
// Production `Future` impl: awaiting the receiver awaits the wrapped tokio oneshot receiver.
// NOTE(review): this impl itself does not need `T: Clone`; the bound presumably mirrors the
// test-support impl so the same code compiles in both builds. Confirm before relaxing it.
#[cfg(not(any(test, feature = "__test_support")))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, oneshot::error::RecvError>;

    /// Polls the wrapped oneshot receiver.
    ///
    /// `oneshot::Receiver` is `Unpin`, and so is this wrapper, so the inner receiver can be
    /// re-pinned with the safe `Pin::new`. No `unsafe` pin projection is needed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Polls the broadcast half without blocking (test builds).
    ///
    /// When no value is available yet, the task wakes itself immediately before returning
    /// `Pending`. The executor therefore keeps re-polling (busy polling) until a value
    /// arrives or the channel closes. `TryRecvError::Closed`/`Lagged` are translated into the
    /// matching `RecvError` variants.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast half.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let receiver = match this.test_inner.as_mut() {
            Some(receiver) => receiver,
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        };
        let outcome = match receiver.try_recv() {
            Ok(value) => Ok(value),
            Err(broadcast::error::TryRecvError::Closed) => {
                Err(broadcast::error::RecvError::Closed)
            }
            Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
                Err(broadcast::error::RecvError::Lagged(skipped))
            }
            Err(broadcast::error::TryRecvError::Empty) => {
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
        };
        Poll::Ready(outcome)
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Clones the broadcast half so both senders feed the same broadcast channel.
    ///
    /// The oneshot half cannot be cloned, so the clone gets a brand-new oneshot sender whose
    /// receiver is dropped immediately.
    fn clone(&self) -> Self {
        let test_inner = self.test_inner.clone();
        let inner = oneshot::channel().0;
        Self { inner, test_inner }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Creates another receiver attached to the same broadcast channel.
    ///
    /// The oneshot half cannot be cloned, so the clone gets a detached oneshot receiver
    /// (its sender is dropped immediately). The broadcast half comes from `resubscribe`,
    /// which per tokio starts at the channel's current tail. A clone made after the value
    /// was sent will therefore not observe that value.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast half.
    fn clone(&self) -> Self {
        let (_, receiver) = oneshot::channel();
        let broadcast_rx = self
            .test_inner
            .as_ref()
            .expect("Cannot clone receiver of non-cloneable type in tests");
        Self {
            inner: receiver,
            test_inner: Some(broadcast_rx.resubscribe()),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Builds a test-build channel pair.
    ///
    /// Each half wraps both a oneshot endpoint and a broadcast endpoint of capacity 1;
    /// sending and receiving in test builds go through the broadcast endpoints.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = broadcast::channel(1);

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };
        (sender, receiver)
    }
}
#[cfg(not(any(test, feature = "__test_support")))]
impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Builds a plain tokio oneshot channel pair (production builds).
    ///
    /// This impl only exists when neither `test` nor `__test_support` is enabled, so the
    /// `test_inner` fields are always compiled out here and need no initializer.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (tx, rx) = oneshot::channel();
        (
            MaybeCloneOneshotSender { inner: tx },
            MaybeCloneOneshotReceiver { inner: rx },
        )
    }
}
/// Consuming sender for a single `Result<tonic::Streaming<SnapshotChunk>, Status>` value.
#[derive(Debug)]
pub struct StreamResponseSender {
/// Oneshot sender used by `send` whenever `test_inner` is absent.
inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
/// Optional broadcast sender (test / `__test_support` builds only). When it is `Some`,
/// `send` delivers through it instead of `inner`.
#[cfg(any(test, feature = "__test_support"))]
test_inner:
Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
}
impl StreamResponseSender {
    /// Creates a sender together with its matching oneshot receiver.
    ///
    /// The sender starts without a broadcast half (`test_inner` is `None`), so `send`
    /// delivers through the returned oneshot receiver.
    pub fn new() -> (
        Self,
        oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
    ) {
        let (inner_tx, inner_rx) = oneshot::channel();
        (
            Self {
                inner: inner_tx,
                #[cfg(any(test, feature = "__test_support"))]
                test_inner: None,
            },
            inner_rx,
        )
    }

    /// Delivers `value`, consuming the sender.
    ///
    /// In test / `__test_support` builds, a present broadcast half takes precedence;
    /// otherwise the oneshot half is used. The oneshot path is written once and shared by
    /// both build configurations.
    ///
    /// # Errors
    /// Returns the undelivered value, boxed, when the underlying tokio `send` fails.
    pub fn send(
        self,
        value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
    ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
        #[cfg(any(test, feature = "__test_support"))]
        if let Some(tx) = self.test_inner {
            // The receiver count reported by broadcast is not part of this API.
            return tx.send(value).map(|_| ()).map_err(|e| Box::new(e.0));
        }
        self.inner.send(value).map_err(Box::new)
    }
}