d_engine_core/
maybe_clone_oneshot.rs

//! A oneshot channel implementation that can optionally be cloned in test environments.
//!
//! Unlike `StreamResponseSender`, which is specialized for gRPC streaming responses
//! (`Result<tonic::Streaming<SnapshotChunk>, Status>`), this module provides a generic
//! oneshot channel:
//! - Production: regular oneshot semantics (non-cloneable) for any `T: Send`.
//! - Tests (`cfg(test)` or the `__test_support` feature): backed by a broadcast channel
//!   so that both senders and receivers can be cloned; this requires `T: Send + Clone`.
//!
//! Compared with `StreamResponseSender`:
//! 1. Generic payload rather than one specialized for gRPC streaming.
//! 2. Simpler error handling.
//! 3. The same test-friendly cloning pattern.
13
14use std::fmt::Debug;
15use std::future::Future;
16use std::pin::Pin;
17use std::task::Context;
18use std::task::Poll;
19
20use d_engine_proto::server::storage::SnapshotChunk;
21#[cfg(any(test, feature = "__test_support"))]
22use tokio::sync::broadcast;
23use tokio::sync::oneshot;
24use tonic::Status;
25
/// Abstraction over constructing a connected one-shot sender/receiver pair
/// for payloads of type `T`.
///
/// Implementors choose the concrete halves; both must be `Send + Sync` so they
/// can be moved to, and shared between, threads.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Sending half returned by [`RaftOneshot::new`].
    type Sender: Send + Sync;

    /// Receiving half returned by [`RaftOneshot::new`].
    type Receiver: Send + Sync;

    /// Creates a fresh, connected `(sender, receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}

/// Marker type whose [`RaftOneshot`] implementation produces
/// `MaybeCloneOneshotSender` / `MaybeCloneOneshotReceiver` pairs.
pub struct MaybeCloneOneshot;
34
35pub struct MaybeCloneOneshotSender<T: Send> {
36    #[allow(dead_code)]
37    inner: oneshot::Sender<T>,
38
39    #[cfg(any(test, feature = "__test_support"))]
40    test_inner: Option<broadcast::Sender<T>>, // None for non-cloneable types
41}
42
43impl<T: Send> Debug for MaybeCloneOneshotSender<T> {
44    fn fmt(
45        &self,
46        f: &mut std::fmt::Formatter<'_>,
47    ) -> std::fmt::Result {
48        f.debug_struct("MaybeCloneOneshotSender").finish()
49    }
50}
51
52pub struct MaybeCloneOneshotReceiver<T: Send> {
53    #[allow(dead_code)]
54    inner: oneshot::Receiver<T>,
55
56    #[cfg(any(test, feature = "__test_support"))]
57    test_inner: Option<broadcast::Receiver<T>>, // None for non-cloneable types
58}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Broadcasts `value` to the subscribed receivers (test builds only).
    ///
    /// Takes `&self` rather than `self` so that every clone of the sender can send.
    ///
    /// # Errors
    /// Propagates the error returned by `broadcast::Sender::send`.
    ///
    /// # Panics
    /// Panics if this sender has no broadcast channel (`test_inner` is `None`).
    pub fn send(
        &self,
        value: T,
    ) -> Result<usize, broadcast::error::SendError<T>> {
        match self.test_inner.as_ref() {
            Some(broadcast_tx) => broadcast_tx.send(value),
            // Fallback for non-cloneable types
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
73
74#[cfg(not(any(test, feature = "__test_support")))]
75impl<T: Send> MaybeCloneOneshotSender<T> {
76    pub fn send(
77        self,
78        value: T,
79    ) -> Result<(), T> {
80        self.inner.send(value)
81    }
82}
83
84impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
85    #[cfg(any(test, feature = "__test_support"))]
86    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
87        if let Some(rx) = &mut self.test_inner {
88            rx.recv().await
89        } else {
90            // Fallback for non-cloneable types
91            panic!("Cannot broadcast non-cloneable type in tests");
92        }
93    }
94}
95
96#[cfg(not(any(test, feature = "__test_support")))]
97impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
98    type Output = Result<T, oneshot::error::RecvError>;
99
100    fn poll(
101        self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103    ) -> Poll<Self::Output> {
104        unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll(cx)
105    }
106}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Polls the broadcast receiver without blocking (test builds only).
    ///
    /// `try_recv` takes no `Context`, so it cannot register a waker. When the
    /// channel is empty, this wakes its own task immediately and returns
    /// `Pending`. The executor therefore keeps re-polling (a busy-wait) until a
    /// value or a terminal error (`Closed` / `Lagged`) is available.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast channel (`test_inner` is `None`).
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output> {
        use broadcast::error::{RecvError, TryRecvError};

        let broadcast_rx = match self.get_mut().test_inner.as_mut() {
            Some(rx) => rx,
            // Fallback for non-cloneable types
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        };

        match broadcast_rx.try_recv() {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(TryRecvError::Empty) => {
                // NOTE(review): this re-schedules the task right away rather than
                // waiting for a send, i.e. it spins. Acceptable for tests only.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(TryRecvError::Closed) => Poll::Ready(Err(RecvError::Closed)),
            Err(TryRecvError::Lagged(skipped)) => Poll::Ready(Err(RecvError::Lagged(skipped))),
        }
    }
}
139
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Duplicates the broadcast sender so both copies feed the same channel.
    ///
    /// A `oneshot::Sender` cannot be duplicated, so the clone gets a brand-new
    /// one whose receiver is dropped immediately. Test builds never send through
    /// `inner`, so this placeholder is never used for delivery.
    fn clone(&self) -> Self {
        let detached_tx = oneshot::channel::<T>().0;
        let shared_broadcast = self.test_inner.clone();
        Self {
            inner: detached_tx,
            test_inner: shared_broadcast,
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Creates another receiver on the same broadcast channel.
    ///
    /// The broadcast half is produced with `broadcast::Receiver::resubscribe`, so
    /// the clone only observes values sent *after* it was created. The oneshot
    /// half cannot be shared, so the clone gets a fresh, disconnected
    /// `oneshot::Receiver`.
    ///
    /// A receiver without a broadcast channel (`test_inner` is `None`) clones
    /// into another such receiver, mirroring the sender's `Clone`. Previously
    /// this hit a bare `unwrap()`; now the descriptive "non-cloneable type" panic
    /// in `recv`/`poll` fires only if that receiver is actually used.
    fn clone(&self) -> Self {
        let (_, receiver) = oneshot::channel();

        Self {
            inner: receiver,
            test_inner: self.test_inner.as_ref().map(broadcast::Receiver::resubscribe),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Creates a cloneable sender/receiver pair for test builds.
    ///
    /// A oneshot channel and a broadcast channel are both created. The test-mode
    /// `send`, `recv` and `poll` use the broadcast halves; the oneshot halves only
    /// populate the `inner` fields.
    fn new() -> (Self::Sender, Self::Receiver) {
        // A single buffered slot, in keeping with the one-value nature of a oneshot.
        const BROADCAST_CAPACITY: usize = 1;

        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = broadcast::channel(BROADCAST_CAPACITY);

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };
        (sender, receiver)
    }
}
181#[cfg(not(any(test, feature = "__test_support")))]
182impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
183    type Sender = MaybeCloneOneshotSender<T>;
184    type Receiver = MaybeCloneOneshotReceiver<T>;
185
186    fn new() -> (Self::Sender, Self::Receiver) {
187        let (tx, rx) = oneshot::channel();
188        (
189            MaybeCloneOneshotSender {
190                inner: tx,
191                #[cfg(any(test, feature = "__test_support"))]
192                test_inner: None,
193            },
194            MaybeCloneOneshotReceiver {
195                inner: rx,
196                #[cfg(any(test, feature = "__test_support"))]
197                test_inner: None,
198            },
199        )
200    }
201}
202
203#[derive(Debug)]
204pub struct StreamResponseSender {
205    inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
206
207    #[cfg(any(test, feature = "__test_support"))]
208    test_inner:
209        Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
210}
211
212impl StreamResponseSender {
213    pub fn new() -> (
214        Self,
215        oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
216    ) {
217        let (inner_tx, inner_rx) = oneshot::channel();
218        (
219            Self {
220                inner: inner_tx,
221                #[cfg(any(test, feature = "__test_support"))]
222                test_inner: None,
223            },
224            inner_rx,
225        )
226    }
227
228    pub fn send(
229        self,
230        value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
231    ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
232        #[cfg(not(any(test, feature = "__test_support")))]
233        return self.inner.send(value).map_err(Box::new);
234
235        #[cfg(any(test, feature = "__test_support"))]
236        if let Some(tx) = self.test_inner {
237            tx.send(value).map(|_| ()).map_err(|e| Box::new(e.0))
238        } else {
239            self.inner.send(value).map_err(Box::new)
240        }
241    }
242}