Skip to main content

d_engine_core/
maybe_clone_oneshot.rs

//! A oneshot channel implementation that can optionally be cloned in test environments.
//!
//! Unlike `StreamResponseSender`, which is specialized for gRPC streaming responses
//! (`Result<tonic::Streaming<SnapshotChunk>, Status>`), this module provides a generic
//! oneshot channel for any `T: Send`:
//! - Production: regular oneshot semantics (non-cloneable)
//! - Tests: backed by a broadcast channel so senders and receivers can be cloned
//!   (requires `T: Clone`)
//!
//! Key differences from `StreamResponseSender`:
//! 1. Generic rather than specialized for gRPC streaming
//! 2. Simpler error handling
//! 3. Same test-friendly cloning pattern
13
14use std::fmt::Debug;
15use std::future::Future;
16use std::pin::Pin;
17use std::task::Context;
18use std::task::Poll;
19
20use d_engine_proto::server::storage::SnapshotChunk;
21#[cfg(any(test, feature = "__test_support"))]
22use tokio::sync::broadcast;
23use tokio::sync::oneshot;
24use tonic::Status;
25
/// Constructor abstraction for a paired sender/receiver carrying values of type `T`.
///
/// Implementors choose the concrete endpoint types; both endpoints must be
/// `Send + Sync`.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Sending half produced by [`RaftOneshot::new`].
    type Sender: Send + Sync;
    /// Receiving half produced by [`RaftOneshot::new`].
    type Receiver: Send + Sync;

    /// Creates a fresh `(sender, receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}
32
/// Zero-sized marker type on which the [`RaftOneshot`] constructors in this file are implemented.
pub struct MaybeCloneOneshot;
34
35pub struct MaybeCloneOneshotSender<T: Send> {
36    #[allow(dead_code)]
37    inner: oneshot::Sender<T>,
38
39    #[cfg(any(test, feature = "__test_support"))]
40    test_inner: Option<broadcast::Sender<T>>, // None for non-cloneable types
41}
42
43impl<T: Send> Debug for MaybeCloneOneshotSender<T> {
44    fn fmt(
45        &self,
46        f: &mut std::fmt::Formatter<'_>,
47    ) -> std::fmt::Result {
48        f.debug_struct("MaybeCloneOneshotSender").finish()
49    }
50}
51
52pub struct MaybeCloneOneshotReceiver<T: Send> {
53    #[allow(dead_code)]
54    inner: oneshot::Receiver<T>,
55
56    #[cfg(any(test, feature = "__test_support"))]
57    test_inner: Option<broadcast::Receiver<T>>, // None for non-cloneable types
58}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Test-build send: forwards `value` to the broadcast sender.
    ///
    /// Takes `&self`, so the sender is not consumed by sending.
    /// Returns whatever `broadcast::Sender::send` returns.
    ///
    /// # Panics
    ///
    /// Panics if this sender has no broadcast half (`test_inner` is `None`).
    pub fn send(
        &self,
        value: T,
    ) -> Result<usize, broadcast::error::SendError<T>> {
        match self.test_inner.as_ref() {
            Some(tx) => tx.send(value),
            // No broadcast half available: there is nothing to send through.
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
73
74#[cfg(not(any(test, feature = "__test_support")))]
75impl<T: Send> MaybeCloneOneshotSender<T> {
76    pub fn send(
77        self,
78        value: T,
79    ) -> Result<(), T> {
80        self.inner.send(value)
81    }
82}
83
84impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
85    #[cfg(any(test, feature = "__test_support"))]
86    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
87        if let Some(rx) = &mut self.test_inner {
88            rx.recv().await
89        } else {
90            // Fallback for non-cloneable types
91            panic!("Cannot broadcast non-cloneable type in tests");
92        }
93    }
94
95    #[cfg(any(test, feature = "__test_support"))]
96    pub fn try_recv(&mut self) -> Result<T, broadcast::error::TryRecvError> {
97        if let Some(rx) = &mut self.test_inner {
98            rx.try_recv()
99        } else {
100            panic!("Cannot try_recv non-cloneable type in tests");
101        }
102    }
103}
104
105#[cfg(not(any(test, feature = "__test_support")))]
106impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
107    type Output = Result<T, oneshot::error::RecvError>;
108
109    fn poll(
110        self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112    ) -> Poll<Self::Output> {
113        unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll(cx)
114    }
115}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Test-build poll: checks the broadcast receiver with `try_recv` and maps the
    /// outcome onto `Poll`.
    ///
    /// # Panics
    ///
    /// Panics if this receiver has no broadcast half (`test_inner` is `None`).
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output> {
        let rx = match self.get_mut().test_inner.as_mut() {
            Some(rx) => rx,
            // Fallback for non-cloneable types
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        };

        match rx.try_recv() {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(broadcast::error::TryRecvError::Closed) => {
                Poll::Ready(Err(broadcast::error::RecvError::Closed))
            }
            Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
                Poll::Ready(Err(broadcast::error::RecvError::Lagged(skipped)))
            }
            Err(broadcast::error::TryRecvError::Empty) => {
                // Nothing is registered with the channel here: the waker is invoked
                // right away, so the task is rescheduled and this method is polled
                // again until a value or an error shows up (busy-polling).
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
148
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Clones the broadcast half; the oneshot half cannot be cloned, so the copy
    /// gets a brand-new oneshot sender whose receiver is discarded on the spot.
    fn clone(&self) -> Self {
        Self {
            // Placeholder only: its paired receiver is dropped with the temporary tuple.
            inner: oneshot::channel().0,
            test_inner: self.test_inner.clone(),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Clones the receiver by re-subscribing to the broadcast channel.
    ///
    /// The oneshot half cannot be cloned, so the copy gets a brand-new oneshot
    /// receiver whose sender is discarded on the spot. Per tokio's documentation,
    /// a re-subscribed broadcast receiver only observes values sent after the
    /// re-subscription.
    ///
    /// A receiver without a broadcast half (`test_inner` is `None`) clones to another
    /// receiver without one instead of panicking, matching how the sender's `Clone`
    /// treats a missing broadcast half.
    fn clone(&self) -> Self {
        Self {
            // Placeholder only: its paired sender is dropped with the temporary tuple.
            inner: oneshot::channel().1,
            test_inner: self.test_inner.as_ref().map(|rx| rx.resubscribe()),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Test-build constructor: creates both a oneshot pair and a capacity-1
    /// broadcast pair, and stores one end of each in the returned sender/receiver.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = broadcast::channel(1);

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };

        (sender, receiver)
    }
}
190#[cfg(not(any(test, feature = "__test_support")))]
191impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
192    type Sender = MaybeCloneOneshotSender<T>;
193    type Receiver = MaybeCloneOneshotReceiver<T>;
194
195    fn new() -> (Self::Sender, Self::Receiver) {
196        let (tx, rx) = oneshot::channel();
197        (
198            MaybeCloneOneshotSender {
199                inner: tx,
200                #[cfg(any(test, feature = "__test_support"))]
201                test_inner: None,
202            },
203            MaybeCloneOneshotReceiver {
204                inner: rx,
205                #[cfg(any(test, feature = "__test_support"))]
206                test_inner: None,
207            },
208        )
209    }
210}
211
212#[derive(Debug)]
213pub struct StreamResponseSender {
214    inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
215
216    #[cfg(any(test, feature = "__test_support"))]
217    test_inner:
218        Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
219}
220
221impl StreamResponseSender {
222    pub fn new() -> (
223        Self,
224        oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
225    ) {
226        let (inner_tx, inner_rx) = oneshot::channel();
227        (
228            Self {
229                inner: inner_tx,
230                #[cfg(any(test, feature = "__test_support"))]
231                test_inner: None,
232            },
233            inner_rx,
234        )
235    }
236
237    pub fn send(
238        self,
239        value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
240    ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
241        #[cfg(not(any(test, feature = "__test_support")))]
242        return self.inner.send(value).map_err(Box::new);
243
244        #[cfg(any(test, feature = "__test_support"))]
245        if let Some(tx) = self.test_inner {
246            tx.send(value).map(|_| ()).map_err(|e| Box::new(e.0))
247        } else {
248            self.inner.send(value).map_err(Box::new)
249        }
250    }
251}