d_engine_core/
maybe_clone_oneshot.rs1use std::fmt::Debug;
15use std::future::Future;
16use std::pin::Pin;
17use std::task::Context;
18use std::task::Poll;
19
20use d_engine_proto::server::storage::SnapshotChunk;
21#[cfg(any(test, feature = "test-utils"))]
22use tokio::sync::broadcast;
23use tokio::sync::oneshot;
24use tonic::Status;
25
/// Factory for a connected `(sender, receiver)` pair carrying values of type `T`.
///
/// Implementors pick the concrete handle types; both handles must be
/// `Send + Sync`.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Handle used to deliver the value.
    type Sender: Send + Sync;
    /// Handle used to obtain the value.
    type Receiver: Send + Sync;

    /// Creates a new, connected `(sender, receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}
32
/// Marker type implementing [`RaftOneshot`]. The channel that backs it depends
/// on whether the `test` cfg or the `test-utils` feature is enabled.
pub struct MaybeCloneOneshot;
34
35pub struct MaybeCloneOneshotSender<T: Send> {
36 #[allow(dead_code)]
37 inner: oneshot::Sender<T>,
38
39 #[cfg(any(test, feature = "test-utils"))]
40 test_inner: Option<broadcast::Sender<T>>, }
42
43impl<T: Send> Debug for MaybeCloneOneshotSender<T> {
44 fn fmt(
45 &self,
46 f: &mut std::fmt::Formatter<'_>,
47 ) -> std::fmt::Result {
48 f.debug_struct("MaybeCloneOneshotSender").finish()
49 }
50}
51
52pub struct MaybeCloneOneshotReceiver<T: Send> {
53 #[allow(dead_code)]
54 inner: oneshot::Receiver<T>,
55
56 #[cfg(any(test, feature = "test-utils"))]
57 test_inner: Option<broadcast::Receiver<T>>, }
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Broadcasts `value` through the test-only broadcast channel.
    ///
    /// Unlike the production `send`, this borrows `&self` instead of consuming
    /// the sender. On success it returns the `usize` reported by
    /// `broadcast::Sender::send`.
    ///
    /// # Errors
    /// Propagates the `SendError` returned by `broadcast::Sender::send`.
    ///
    /// # Panics
    /// Panics if this sender has no broadcast channel (`test_inner` is `None`).
    pub fn send(
        &self,
        value: T,
    ) -> Result<usize, broadcast::error::SendError<T>> {
        match self.test_inner.as_ref() {
            Some(broadcast_tx) => broadcast_tx.send(value),
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
73
74#[cfg(not(any(test, feature = "test-utils")))]
75impl<T: Send> MaybeCloneOneshotSender<T> {
76 pub fn send(
77 self,
78 value: T,
79 ) -> Result<(), T> {
80 self.inner.send(value)
81 }
82}
83
84impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
85 #[cfg(any(test, feature = "test-utils"))]
86 pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
87 if let Some(rx) = &mut self.test_inner {
88 rx.recv().await
89 } else {
90 panic!("Cannot broadcast non-cloneable type in tests");
92 }
93 }
94}
95
96#[cfg(not(any(test, feature = "test-utils")))]
97impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
98 type Output = Result<T, oneshot::error::RecvError>;
99
100 fn poll(
101 self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 ) -> Poll<Self::Output> {
104 unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll(cx)
105 }
106}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Polls the test-only broadcast receiver without blocking.
    ///
    /// Uses `try_recv`. `Closed` and `Lagged` errors are mapped onto the
    /// matching `broadcast::error::RecvError` variants.
    ///
    /// # Panics
    /// Panics if this receiver has no broadcast channel (`test_inner` is `None`).
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();

        if let Some(rx) = &mut this.test_inner {
            // Non-blocking check; unlike `recv().await`, `try_recv` does not
            // register this task's waker with the channel.
            match rx.try_recv() {
                Ok(value) => Poll::Ready(Ok(value)),
                Err(broadcast::error::TryRecvError::Empty) => {
                    // NOTE(review): nothing else will wake this task when a value
                    // arrives, so it asks to be polled again immediately. That
                    // busy-polls the executor until a value is sent or the channel
                    // closes. This is tolerable in tests, but it burns CPU while waiting.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                Err(broadcast::error::TryRecvError::Closed) => {
                    Poll::Ready(Err(broadcast::error::RecvError::Closed))
                }
                Err(broadcast::error::TryRecvError::Lagged(n)) => {
                    Poll::Ready(Err(broadcast::error::RecvError::Lagged(n)))
                }
            }
        } else {
            panic!("Cannot broadcast non-cloneable type in tests");
        }
    }
}
139
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Clones the broadcast handle, preserving `None`.
    ///
    /// The oneshot half cannot be cloned. The copy gets a fresh sender whose
    /// receiver is dropped straight away. This is harmless because the test-only
    /// `send` never uses `inner`.
    fn clone(&self) -> Self {
        let broadcast_tx = self.test_inner.clone();
        Self {
            inner: oneshot::channel().0,
            test_inner: broadcast_tx,
        }
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Creates another receiver on the same broadcast channel.
    ///
    /// The copy is made with `broadcast::Receiver::resubscribe`, so it only sees
    /// values sent *after* this call. Values already queued for `self` are not
    /// carried over.
    ///
    /// A receiver with no broadcast channel clones to another receiver with no
    /// channel instead of panicking. This mirrors the sender's `Clone`.
    fn clone(&self) -> Self {
        // The oneshot half is unused in test builds; give the copy a
        // disconnected placeholder.
        let (_, receiver) = oneshot::channel();

        Self {
            inner: receiver,
            test_inner: self
                .test_inner
                .as_ref()
                .map(broadcast::Receiver::resubscribe),
        }
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Test-build constructor.
    ///
    /// Creates both a tokio oneshot channel and a broadcast channel with
    /// capacity 1. The broadcast halves are stored in `test_inner`, so the
    /// test-only `send`/`recv`/`poll` have a channel to use.
    fn new() -> (Self::Sender, Self::Receiver) {
        let (oneshot_tx, oneshot_rx) = oneshot::channel();
        let (broadcast_tx, broadcast_rx) = broadcast::channel(1);

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };

        (sender, receiver)
    }
}
181
182#[cfg(not(any(test, feature = "test-utils")))]
183impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
184 type Sender = MaybeCloneOneshotSender<T>;
185 type Receiver = MaybeCloneOneshotReceiver<T>;
186
187 fn new() -> (Self::Sender, Self::Receiver) {
188 let (tx, rx) = oneshot::channel();
189 (
190 MaybeCloneOneshotSender {
191 inner: tx,
192 #[cfg(any(test, feature = "test-utils"))]
193 test_inner: None,
194 },
195 MaybeCloneOneshotReceiver {
196 inner: rx,
197 #[cfg(any(test, feature = "test-utils"))]
198 test_inner: None,
199 },
200 )
201 }
202}
203
/// Sender for a single `Result<tonic::Streaming<SnapshotChunk>, Status>`.
///
/// Production builds deliver through a tokio oneshot sender. Builds with the
/// `test` cfg or the `test-utils` feature additionally carry an optional
/// broadcast sender.
#[derive(Debug)]
pub struct StreamResponseSender {
    /// Oneshot sender. `send` uses it in production builds, and also in test
    /// builds when `test_inner` is `None`.
    inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,

    /// Optional broadcast sender, present only in test builds. When `Some`,
    /// `send` routes the value through it instead of `inner`.
    #[cfg(any(test, feature = "test-utils"))]
    test_inner:
        Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
}
212
213impl StreamResponseSender {
214 pub fn new() -> (
215 Self,
216 oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
217 ) {
218 let (inner_tx, inner_rx) = oneshot::channel();
219 (
220 Self {
221 inner: inner_tx,
222 #[cfg(any(test, feature = "test-utils"))]
223 test_inner: None,
224 },
225 inner_rx,
226 )
227 }
228
229 pub fn send(
230 self,
231 value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
232 ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
233 #[cfg(not(any(test, feature = "test-utils")))]
234 return self.inner.send(value).map_err(Box::new);
235
236 #[cfg(any(test, feature = "test-utils"))]
237 if let Some(tx) = self.test_inner {
238 tx.send(value).map(|_| ()).map_err(|e| Box::new(e.0))
239 } else {
240 self.inner.send(value).map_err(Box::new)
241 }
242 }
243}