d_engine_core/
maybe_clone_oneshot.rs1use std::fmt::Debug;
15use std::future::Future;
16use std::pin::Pin;
17use std::task::Context;
18use std::task::Poll;
19
20use d_engine_proto::server::storage::SnapshotChunk;
21#[cfg(any(test, feature = "__test_support"))]
22use tokio::sync::broadcast;
23use tokio::sync::oneshot;
24use tonic::Status;
25
/// Factory for a connected single-value channel pair carrying `T`.
///
/// Implementors choose the concrete sending and receiving halves; both halves
/// must be `Send + Sync` so they can be moved between tasks and shared by
/// reference.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Sending half of the channel.
    type Sender: Send + Sync;
    /// Receiving half of the channel.
    type Receiver: Send + Sync;

    /// Creates a fresh, connected `(sender, receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}
32
/// Marker type whose `RaftOneshot` implementation builds
/// `MaybeCloneOneshotSender` / `MaybeCloneOneshotReceiver` pairs.
pub struct MaybeCloneOneshot;
34
35pub struct MaybeCloneOneshotSender<T: Send> {
36 #[allow(dead_code)]
37 inner: oneshot::Sender<T>,
38
39 #[cfg(any(test, feature = "__test_support"))]
40 test_inner: Option<broadcast::Sender<T>>, }
42
43impl<T: Send> Debug for MaybeCloneOneshotSender<T> {
44 fn fmt(
45 &self,
46 f: &mut std::fmt::Formatter<'_>,
47 ) -> std::fmt::Result {
48 f.debug_struct("MaybeCloneOneshotSender").finish()
49 }
50}
51
52pub struct MaybeCloneOneshotReceiver<T: Send> {
53 #[allow(dead_code)]
54 inner: oneshot::Receiver<T>,
55
56 #[cfg(any(test, feature = "__test_support"))]
57 test_inner: Option<broadcast::Receiver<T>>, }
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Sends `value` on the broadcast half.
    ///
    /// Takes `&self` (unlike the non-test variant) so that cloned senders can
    /// each send. The `Ok` value is whatever `broadcast::Sender::send` reports.
    ///
    /// # Panics
    /// Panics if this sender has no broadcast half.
    pub fn send(&self, value: T) -> Result<usize, broadcast::error::SendError<T>> {
        match &self.test_inner {
            Some(broadcaster) => broadcaster.send(value),
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
73
74#[cfg(not(any(test, feature = "__test_support")))]
75impl<T: Send> MaybeCloneOneshotSender<T> {
76 pub fn send(
77 self,
78 value: T,
79 ) -> Result<(), T> {
80 self.inner.send(value)
81 }
82}
83
// Both methods are test-only, so the whole impl is gated rather than each item.
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
    /// Awaits the next value on the broadcast half.
    ///
    /// # Panics
    /// Panics (when first polled) if this receiver has no broadcast half.
    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
        match self.test_inner.as_mut() {
            Some(subscriber) => subscriber.recv().await,
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }

    /// Non-blocking check of the broadcast half.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast half.
    pub fn try_recv(&mut self) -> Result<T, broadcast::error::TryRecvError> {
        match self.test_inner.as_mut() {
            Some(subscriber) => subscriber.try_recv(),
            None => panic!("Cannot try_recv non-cloneable type in tests"),
        }
    }
}
104
105#[cfg(not(any(test, feature = "__test_support")))]
106impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
107 type Output = Result<T, oneshot::error::RecvError>;
108
109 fn poll(
110 self: Pin<&mut Self>,
111 cx: &mut Context<'_>,
112 ) -> Poll<Self::Output> {
113 unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll(cx)
114 }
115}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Polls the broadcast half via `try_recv`.
    ///
    /// When nothing is available yet, the task wakes itself immediately and
    /// returns `Pending`. This is a busy poll that keeps re-scheduling until a
    /// value, close, or lag is observed. This impl is only compiled for tests
    /// and the `__test_support` feature.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast half.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let subscriber = match self.get_mut().test_inner.as_mut() {
            Some(subscriber) => subscriber,
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        };

        match subscriber.try_recv() {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(broadcast::error::TryRecvError::Empty) => {
                // Ask to be polled again right away (see doc comment above).
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            Err(broadcast::error::TryRecvError::Closed) => {
                Poll::Ready(Err(broadcast::error::RecvError::Closed))
            }
            Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
                Poll::Ready(Err(broadcast::error::RecvError::Lagged(skipped)))
            }
        }
    }
}
148
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Shares the broadcast half with the original.
    ///
    /// The clone's oneshot half is a brand-new sender whose receiver is
    /// dropped immediately, so it is not connected to the original's receiver.
    fn clone(&self) -> Self {
        let detached: oneshot::Sender<T> = oneshot::channel().0;
        Self {
            inner: detached,
            test_inner: self.test_inner.clone(),
        }
    }
}
158}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Creates a new subscriber on the same broadcast channel.
    ///
    /// The clone's oneshot half comes from a fresh channel whose sender is
    /// dropped immediately, so it is not connected to the original's sender.
    /// The broadcast half uses `resubscribe`, which (per the tokio docs) only
    /// sees values sent after the clone is made. A receiver without a
    /// broadcast half clones to another receiver without one, mirroring the
    /// sender's `Clone`, instead of panicking on `unwrap`.
    fn clone(&self) -> Self {
        let (_, receiver) = oneshot::channel();

        Self {
            inner: receiver,
            test_inner: self.test_inner.as_ref().map(|rx| rx.resubscribe()),
        }
    }
}
#[cfg(any(test, feature = "__test_support"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Builds a pair backed by both a oneshot channel and a broadcast channel.
    ///
    /// Test code communicates through the broadcast halves, which can be cloned.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        // Capacity 1, presumably matching the single value a oneshot carries.
        let (broadcast_tx, broadcast_rx) = broadcast::channel(1);

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };
        (sender, receiver)
    }
}
190#[cfg(not(any(test, feature = "__test_support")))]
191impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
192 type Sender = MaybeCloneOneshotSender<T>;
193 type Receiver = MaybeCloneOneshotReceiver<T>;
194
195 fn new() -> (Self::Sender, Self::Receiver) {
196 let (tx, rx) = oneshot::channel();
197 (
198 MaybeCloneOneshotSender {
199 inner: tx,
200 #[cfg(any(test, feature = "__test_support"))]
201 test_inner: None,
202 },
203 MaybeCloneOneshotReceiver {
204 inner: rx,
205 #[cfg(any(test, feature = "__test_support"))]
206 test_inner: None,
207 },
208 )
209 }
210}
211
212#[derive(Debug)]
213pub struct StreamResponseSender {
214 inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
215
216 #[cfg(any(test, feature = "__test_support"))]
217 test_inner:
218 Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
219}
220
221impl StreamResponseSender {
222 pub fn new() -> (
223 Self,
224 oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
225 ) {
226 let (inner_tx, inner_rx) = oneshot::channel();
227 (
228 Self {
229 inner: inner_tx,
230 #[cfg(any(test, feature = "__test_support"))]
231 test_inner: None,
232 },
233 inner_rx,
234 )
235 }
236
237 pub fn send(
238 self,
239 value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
240 ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
241 #[cfg(not(any(test, feature = "__test_support")))]
242 return self.inner.send(value).map_err(Box::new);
243
244 #[cfg(any(test, feature = "__test_support"))]
245 if let Some(tx) = self.test_inner {
246 tx.send(value).map(|_| ()).map_err(|e| Box::new(e.0))
247 } else {
248 self.inner.send(value).map_err(Box::new)
249 }
250 }
251}