d-engine-core 0.2.3

Pure Raft consensus algorithm - for building custom Raft-based systems
Documentation
//! A oneshot channel implementation whose halves can be cloned in test builds.
//!
//! Unlike `StreamResponseSender`, which is specialized for gRPC streaming responses
//! (`Result<tonic::Streaming<SnapshotChunk>, Status>`), this provides a generic oneshot
//! channel for any `T: Send`:
//! - Production: regular `tokio::sync::oneshot` semantics (non-cloneable).
//! - Tests (`cfg(test)` or the `__test_support` feature): backed by a capacity-1
//!   `tokio::sync::broadcast` channel so that both senders and receivers can be cloned;
//!   this requires `T: Clone`.
//!
//! Key differences from `StreamResponseSender`:
//! 1. Generic vs specialized (gRPC streaming)
//! 2. Simpler error handling
//! 3. Same test-friendly cloning pattern

use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use d_engine_proto::server::storage::SnapshotChunk;
#[cfg(any(test, feature = "__test_support"))]
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tonic::Status;

/// Factory for a sender/receiver pair carrying a single value of type `T`.
///
/// Implementors pick the concrete half types; both must be `Send + Sync`.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Sending half of the pair.
    type Sender: Send + Sync;

    /// Receiving half of the pair.
    type Receiver: Send + Sync;

    /// Creates a new `(Sender, Receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}

/// [`RaftOneshot`] factory whose halves are [`MaybeCloneOneshotSender`] and
/// [`MaybeCloneOneshotReceiver`].
///
/// Test builds (`cfg(test)` / `__test_support`) require `T: Clone`; production builds do not.
pub struct MaybeCloneOneshot;

/// Sending half produced by [`MaybeCloneOneshot`].
///
/// Production builds send through `inner` (a `oneshot::Sender`). Test builds
/// (`cfg(test)` / `__test_support`) send through the broadcast sender in `test_inner`
/// instead, which is what allows this type to be cloned there.
pub struct MaybeCloneOneshotSender<T: Send> {
    // Read only by the production `send`; unused in test builds, hence the allow.
    #[allow(dead_code)]
    inner: oneshot::Sender<T>,

    // `Some` when built by `MaybeCloneOneshot::new` in test builds; the test-build
    // `send` panics if this is `None`.
    #[cfg(any(test, feature = "__test_support"))]
    test_inner: Option<broadcast::Sender<T>>, // None for non-cloneable types
}

impl<T> Debug for MaybeCloneOneshotSender<T>
where
    T: Send,
{
    /// Prints only the type name; the channel internals are deliberately not shown.
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        // An empty `debug_struct(..).finish()` emits exactly the name, so write it directly.
        f.write_str("MaybeCloneOneshotSender")
    }
}

/// Receiving half produced by [`MaybeCloneOneshot`].
///
/// In production builds it is awaited directly (it implements `Future` by polling `inner`).
/// In test builds (`cfg(test)` / `__test_support`) every receive path reads from the
/// broadcast receiver in `test_inner` instead, which is what allows cloning.
pub struct MaybeCloneOneshotReceiver<T: Send> {
    // Polled only by the production `Future` impl; unused in test builds, hence the allow.
    #[allow(dead_code)]
    inner: oneshot::Receiver<T>,

    // `Some` when built by `MaybeCloneOneshot::new` in test builds; the test-build
    // receive methods panic if this is `None`.
    #[cfg(any(test, feature = "__test_support"))]
    test_inner: Option<broadcast::Receiver<T>>, // None for non-cloneable types
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Broadcasts `value` on the shared test channel.
    ///
    /// Takes `&self` (unlike the production `send(self, ..)`), so every clone of the
    /// sender can send.
    ///
    /// # Errors
    /// Forwards the result of `broadcast::Sender::send`: the receiver count on success,
    /// or a `SendError` that hands `value` back on failure.
    ///
    /// # Panics
    /// Panics if this sender has no broadcast channel (`test_inner` is `None`).
    pub fn send(
        &self,
        value: T,
    ) -> Result<usize, broadcast::error::SendError<T>> {
        match self.test_inner.as_ref() {
            Some(broadcast_tx) => broadcast_tx.send(value),
            // Fallback for non-cloneable types
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}

#[cfg(not(any(test, feature = "__test_support")))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Sends `value` over the underlying oneshot channel, consuming the sender.
    ///
    /// # Errors
    /// Returns `Err(value)` when `oneshot::Sender::send` cannot deliver it.
    /// Per tokio's oneshot contract, this happens when the receiver is gone.
    pub fn send(
        self,
        value: T,
    ) -> Result<(), T> {
        // In production builds `inner` is the struct's only field.
        let Self { inner } = self;
        inner.send(value)
    }
}

impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
    /// Awaits the next value on the shared test broadcast channel.
    ///
    /// # Errors
    /// Propagates `broadcast::error::RecvError` from the underlying receiver.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast channel (`test_inner` is `None`).
    #[cfg(any(test, feature = "__test_support"))]
    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
        match self.test_inner.as_mut() {
            Some(broadcast_rx) => broadcast_rx.recv().await,
            // Fallback for non-cloneable types
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }

    /// Non-blocking counterpart of `recv` on the shared test broadcast channel.
    ///
    /// # Errors
    /// Propagates `broadcast::error::TryRecvError`, e.g. `Empty` when nothing is queued.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast channel (`test_inner` is `None`).
    #[cfg(any(test, feature = "__test_support"))]
    pub fn try_recv(&mut self) -> Result<T, broadcast::error::TryRecvError> {
        match self.test_inner.as_mut() {
            Some(broadcast_rx) => broadcast_rx.try_recv(),
            None => panic!("Cannot try_recv non-cloneable type in tests"),
        }
    }
}

#[cfg(not(any(test, feature = "__test_support")))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, oneshot::error::RecvError>;

    /// Resolves when the wrapped `oneshot::Receiver` resolves.
    ///
    /// In production builds the struct's only field is `inner: oneshot::Receiver<T>`.
    /// That type is `Unpin`, so the wrapper is `Unpin` as well. `Pin::get_mut` therefore
    /// hands out `&mut Self` safely, and the field can be re-pinned with `Pin::new`.
    /// No `unsafe` pin projection is required. This mirrors the test-build impl, which
    /// also uses `get_mut`.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output> {
        Pin::new(&mut self.get_mut().inner).poll(cx)
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Polls the shared test broadcast channel without blocking.
    ///
    /// NOTE(review): this does not register `cx`'s waker with the channel. When the channel
    /// is empty, it immediately re-wakes the task (`wake_by_ref`) and returns `Pending`.
    /// The task is therefore busy-polled until a value arrives or the channel closes.
    /// That is tolerable in test-only code but burns CPU while waiting.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast channel (`test_inner` is `None`).
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output> {
        use broadcast::error::{RecvError, TryRecvError};

        let broadcast_rx = match self.get_mut().test_inner.as_mut() {
            Some(rx) => rx,
            // Fallback for non-cloneable types
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        };

        let outcome = match broadcast_rx.try_recv() {
            Ok(value) => Ok(value),
            Err(TryRecvError::Empty) => {
                // Ask to be polled again right away (busy-poll; see NOTE above).
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Err(TryRecvError::Closed) => Err(RecvError::Closed),
            Err(TryRecvError::Lagged(skipped)) => Err(RecvError::Lagged(skipped)),
        };

        Poll::Ready(outcome)
    }
}

#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Clones the sender by sharing its broadcast channel (`test_inner`).
    ///
    /// A `oneshot::Sender` cannot be cloned, so the clone's `inner` is a fresh placeholder.
    /// Its receiver half is discarded when it is created.
    fn clone(&self) -> Self {
        Self {
            inner: oneshot::channel().0,
            test_inner: self.test_inner.clone(),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Clones the receiver by re-subscribing to the shared broadcast channel.
    ///
    /// `resubscribe` starts from the channel's current tail. The clone therefore only
    /// observes values sent *after* it was created, so clone receivers before sending.
    /// The clone's `inner` is a fresh placeholder oneshot receiver whose sender is
    /// discarded immediately.
    ///
    /// A receiver without a broadcast channel (`test_inner == None`) clones to another one
    /// without a channel, matching `MaybeCloneOneshotSender::clone`. It no longer panics
    /// with an unexplained `unwrap()`; the receive methods still panic with a descriptive
    /// message if used on it.
    fn clone(&self) -> Self {
        let (_, receiver) = oneshot::channel();

        Self {
            inner: receiver,
            test_inner: self.test_inner.as_ref().map(|rx| rx.resubscribe()),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Creates a cloneable sender/receiver pair for test builds.
    ///
    /// Both a oneshot channel and a capacity-1 broadcast channel are created. The
    /// test-build `send`/`recv`/`try_recv`/`poll` in this module only touch the broadcast
    /// halves (`test_inner`), which are always `Some` here.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = broadcast::channel(1);

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };

        (sender, receiver)
    }
}
#[cfg(not(any(test, feature = "__test_support")))]
impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Creates a plain `tokio::sync::oneshot` pair for production builds.
    ///
    /// This impl is compiled only when neither `test` nor `__test_support` is enabled.
    /// The cfg-gated `test_inner` fields therefore do not exist here, and only `inner`
    /// is initialized. Unlike the test-build impl, this one does not require `T: Clone`.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (tx, rx) = oneshot::channel();
        (
            MaybeCloneOneshotSender { inner: tx },
            MaybeCloneOneshotReceiver { inner: rx },
        )
    }
}

/// One-shot sender for a snapshot stream response
/// (`Result<tonic::Streaming<SnapshotChunk>, Status>`).
///
/// `inner` is a `oneshot::Sender`. In test builds (`cfg(test)` / `__test_support`),
/// `send` uses `test_inner` instead whenever it is `Some`.
#[derive(Debug)]
pub struct StreamResponseSender {
    inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,

    // Optional broadcast sender that takes precedence in `send` (test builds only).
    #[cfg(any(test, feature = "__test_support"))]
    test_inner:
        Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
}

impl StreamResponseSender {
    /// Creates a sender together with the `oneshot::Receiver` that will observe its value.
    ///
    /// In test builds the returned sender has no broadcast channel (`test_inner` is
    /// `None`), so `send` goes through the oneshot channel.
    pub fn new() -> (
        Self,
        oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
    ) {
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let sender = Self {
            inner: oneshot_tx,
            #[cfg(any(test, feature = "__test_support"))]
            test_inner: None,
        };
        (sender, oneshot_rx)
    }

    /// Delivers `value`, consuming the sender.
    ///
    /// In test builds, an attached broadcast sender (`test_inner`) is used instead of the
    /// oneshot channel. The receiver count it reports is discarded.
    ///
    /// # Errors
    /// Returns the undelivered value, boxed, when the underlying channel rejects it.
    pub fn send(
        self,
        value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
    ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
        #[cfg(any(test, feature = "__test_support"))]
        if let Some(broadcast_tx) = self.test_inner {
            return broadcast_tx
                .send(value)
                .map(|_receiver_count| ())
                .map_err(|rejected| Box::new(rejected.0));
        }

        self.inner.send(value).map_err(Box::new)
    }
}