d_engine_core/
maybe_clone_oneshot.rs

//! A oneshot channel wrapper whose halves can be cloned in test builds.
//!
//! Unlike `StreamResponseSender`, which is specialized for gRPC streaming responses
//! (`Result<tonic::Streaming<SnapshotChunk>, Status>`), this provides a generic oneshot
//! channel for any `T: Send`:
//! - Production: regular oneshot semantics (neither sender nor receiver is cloneable).
//! - Tests (`cfg(test)` or the `test-utils` feature): a broadcast channel backs both
//!   halves so that senders and receivers can be cloned; this requires `T: Clone`.
//!
//! Key differences from `StreamResponseSender`:
//! 1. Generic over `T` vs. specialized to gRPC snapshot streaming.
//! 2. Simpler error handling: a failed production `send` hands the value back
//!    directly instead of boxing it.
//! 3. Same test-friendly layout: a real oneshot plus a test-only broadcast channel.
13
14use std::fmt::Debug;
15use std::future::Future;
16use std::pin::Pin;
17use std::task::Context;
18use std::task::Poll;
19
20use d_engine_proto::server::storage::SnapshotChunk;
21#[cfg(any(test, feature = "test-utils"))]
22use tokio::sync::broadcast;
23use tokio::sync::oneshot;
24use tonic::Status;
25
/// Constructor abstraction for oneshot-style channels carrying values of type `T`.
///
/// Implementors pick the concrete sender/receiver handle types; both handles
/// are required to be `Send + Sync`.
pub trait RaftOneshot<T>
where
    T: Send,
{
    /// Sending half of the channel.
    type Sender: Send + Sync;

    /// Receiving half of the channel.
    type Receiver: Send + Sync;

    /// Creates a connected `(sender, receiver)` pair.
    fn new() -> (Self::Sender, Self::Receiver);
}
32
/// Zero-sized factory type whose `RaftOneshot::new` builds a
/// `MaybeCloneOneshotSender` / `MaybeCloneOneshotReceiver` pair.
///
/// Being stateless, it is cheap to copy, and it derives `Debug` like other
/// public types should.
#[derive(Debug, Clone, Copy, Default)]
pub struct MaybeCloneOneshot;
34
35pub struct MaybeCloneOneshotSender<T: Send> {
36    #[allow(dead_code)]
37    inner: oneshot::Sender<T>,
38
39    #[cfg(any(test, feature = "test-utils"))]
40    test_inner: Option<broadcast::Sender<T>>, // None for non-cloneable types
41}
42
43impl<T: Send> Debug for MaybeCloneOneshotSender<T> {
44    fn fmt(
45        &self,
46        f: &mut std::fmt::Formatter<'_>,
47    ) -> std::fmt::Result {
48        f.debug_struct("MaybeCloneOneshotSender").finish()
49    }
50}
51
52pub struct MaybeCloneOneshotReceiver<T: Send> {
53    #[allow(dead_code)]
54    inner: oneshot::Receiver<T>,
55
56    #[cfg(any(test, feature = "test-utils"))]
57    test_inner: Option<broadcast::Receiver<T>>, // None for non-cloneable types
58}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send> MaybeCloneOneshotSender<T> {
    /// Test-build send: publishes `value` on the broadcast channel.
    ///
    /// Takes `&self` (the production variant consumes the sender), so every
    /// clone of this sender can send.
    ///
    /// # Errors
    /// Propagates the `broadcast::error::SendError` from the broadcast sender.
    /// On success the `usize` is whatever the broadcast sender reports (per
    /// tokio's docs, the number of receivers the value reached).
    ///
    /// # Panics
    /// Panics if `test_inner` is `None`.
    pub fn send(
        &self,
        value: T,
    ) -> Result<usize, broadcast::error::SendError<T>> {
        match self.test_inner.as_ref() {
            Some(tx) => tx.send(value),
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
73
74#[cfg(not(any(test, feature = "test-utils")))]
75impl<T: Send> MaybeCloneOneshotSender<T> {
76    pub fn send(
77        self,
78        value: T,
79    ) -> Result<(), T> {
80        self.inner.send(value)
81    }
82}
83
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> MaybeCloneOneshotReceiver<T> {
    /// Test-build receive: awaits the next value on the broadcast channel.
    ///
    /// # Errors
    /// Forwards the `broadcast::error::RecvError` produced by the underlying
    /// broadcast receiver.
    ///
    /// # Panics
    /// Panics if `test_inner` is `None`.
    pub async fn recv(&mut self) -> Result<T, broadcast::error::RecvError> {
        match self.test_inner.as_mut() {
            Some(rx) => rx.recv().await,
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        }
    }
}
95
96#[cfg(not(any(test, feature = "test-utils")))]
97impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
98    type Output = Result<T, oneshot::error::RecvError>;
99
100    fn poll(
101        self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103    ) -> Poll<Self::Output> {
104        unsafe { self.map_unchecked_mut(|s| &mut s.inner) }.poll(cx)
105    }
106}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> Future for MaybeCloneOneshotReceiver<T> {
    type Output = Result<T, broadcast::error::RecvError>;

    /// Test-build poll: checks the broadcast receiver without blocking.
    ///
    /// NOTE: when no value is available yet, this wakes its own task right away
    /// and returns `Pending`, i.e. it busy-polls instead of registering for a
    /// real notification. That is tolerable in tests but burns CPU while waiting.
    ///
    /// # Panics
    /// Panics if `test_inner` is `None`.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Self::Output> {
        let rx = match &mut self.get_mut().test_inner {
            Some(rx) => rx,
            None => panic!("Cannot broadcast non-cloneable type in tests"),
        };

        match rx.try_recv() {
            Ok(value) => Poll::Ready(Ok(value)),
            Err(broadcast::error::TryRecvError::Closed) => {
                Poll::Ready(Err(broadcast::error::RecvError::Closed))
            }
            Err(broadcast::error::TryRecvError::Lagged(skipped)) => {
                Poll::Ready(Err(broadcast::error::RecvError::Lagged(skipped)))
            }
            Err(broadcast::error::TryRecvError::Empty) => {
                // Ask to be polled again immediately (see NOTE above).
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
139
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotSender<T> {
    /// Produces another sender sharing the same test-only broadcast channel.
    ///
    /// The oneshot half is not clonable, so the clone gets a fresh oneshot
    /// sender whose receiver is discarded; test builds never send through it.
    fn clone(&self) -> Self {
        let test_inner = self.test_inner.clone();
        let inner = oneshot::channel().0;
        Self { inner, test_inner }
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> Clone for MaybeCloneOneshotReceiver<T> {
    /// Produces another receiver on the same test-only broadcast channel.
    ///
    /// The clone is created with `broadcast::Receiver::resubscribe`, so it only
    /// observes values sent *after* this call; a value already queued for `self`
    /// is not replayed to the clone.
    ///
    /// A receiver whose `test_inner` is `None` clones to another `None` receiver
    /// (matching how the sender's `Clone` propagates its `Option`) instead of
    /// panicking here with a context-free `unwrap`.
    fn clone(&self) -> Self {
        // The oneshot half is never read in test builds; it only fills the field.
        let (_, receiver) = oneshot::channel();

        Self {
            inner: receiver,
            test_inner: self.test_inner.as_ref().map(broadcast::Receiver::resubscribe),
        }
    }
}
#[cfg(any(test, feature = "test-utils"))]
impl<T: Send + Clone> RaftOneshot<T> for MaybeCloneOneshot {
    type Sender = MaybeCloneOneshotSender<T>;
    type Receiver = MaybeCloneOneshotReceiver<T>;

    /// Test-build constructor: pairs a broadcast channel (used for delivery)
    /// with a real oneshot channel (carried along but unused in test builds).
    fn new() -> (Self::Sender, Self::Receiver) {
        // Capacity 1: a single value is buffered. Per tokio's broadcast
        // semantics, a second send before the first is received makes the
        // receiver report `Lagged`.
        let (broadcast_tx, broadcast_rx) = broadcast::channel(1);
        let (oneshot_tx, oneshot_rx) = oneshot::channel();

        let sender = MaybeCloneOneshotSender {
            inner: oneshot_tx,
            test_inner: Some(broadcast_tx),
        };
        let receiver = MaybeCloneOneshotReceiver {
            inner: oneshot_rx,
            test_inner: Some(broadcast_rx),
        };
        (sender, receiver)
    }
}
181
182#[cfg(not(any(test, feature = "test-utils")))]
183impl<T: Send> RaftOneshot<T> for MaybeCloneOneshot {
184    type Sender = MaybeCloneOneshotSender<T>;
185    type Receiver = MaybeCloneOneshotReceiver<T>;
186
187    fn new() -> (Self::Sender, Self::Receiver) {
188        let (tx, rx) = oneshot::channel();
189        (
190            MaybeCloneOneshotSender {
191                inner: tx,
192                #[cfg(any(test, feature = "test-utils"))]
193                test_inner: None,
194            },
195            MaybeCloneOneshotReceiver {
196                inner: rx,
197                #[cfg(any(test, feature = "test-utils"))]
198                test_inner: None,
199            },
200        )
201    }
202}
203
204#[derive(Debug)]
205pub struct StreamResponseSender {
206    inner: oneshot::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
207
208    #[cfg(any(test, feature = "test-utils"))]
209    test_inner:
210        Option<broadcast::Sender<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>>,
211}
212
213impl StreamResponseSender {
214    pub fn new() -> (
215        Self,
216        oneshot::Receiver<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>,
217    ) {
218        let (inner_tx, inner_rx) = oneshot::channel();
219        (
220            Self {
221                inner: inner_tx,
222                #[cfg(any(test, feature = "test-utils"))]
223                test_inner: None,
224            },
225            inner_rx,
226        )
227    }
228
229    pub fn send(
230        self,
231        value: std::result::Result<tonic::Streaming<SnapshotChunk>, Status>,
232    ) -> Result<(), Box<std::result::Result<tonic::Streaming<SnapshotChunk>, Status>>> {
233        #[cfg(not(any(test, feature = "test-utils")))]
234        return self.inner.send(value).map_err(Box::new);
235
236        #[cfg(any(test, feature = "test-utils"))]
237        if let Some(tx) = self.test_inner {
238            tx.send(value).map(|_| ()).map_err(|e| Box::new(e.0))
239        } else {
240            self.inner.send(value).map_err(Box::new)
241        }
242    }
243}