async_inspect/channel/oneshot.rs

//! Tracked oneshot channel
//!
//! A near drop-in replacement for `tokio::sync::oneshot` that tracks message flow.
//! Unlike tokio's version, [`channel()`] takes a name used for debugging and metrics.
4
5use crate::channel::{ChannelMetrics, ChannelMetricsTracker};
6use std::fmt;
7use std::sync::Arc;
8use tokio::sync::oneshot as tokio_oneshot;
9
10/// Create a tracked oneshot channel.
11///
12/// # Arguments
13///
14/// * `name` - A descriptive name for debugging and metrics
15///
16/// # Example
17///
18/// ```rust,no_run
19/// use async_inspect::channel::oneshot;
20///
21/// #[tokio::main]
22/// async fn main() {
23///     let (tx, rx) = oneshot::channel::<String>("result");
24///
25///     tokio::spawn(async move {
26///         tx.send("done".into()).unwrap();
27///     });
28///
29///     let result = rx.await.unwrap();
30///     println!("Result: {}", result);
31/// }
32/// ```
33pub fn channel<T>(name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
34    let (tx, rx) = tokio_oneshot::channel();
35    let metrics = Arc::new(ChannelMetricsTracker::new());
36    let name = Arc::new(name.into());
37
38    (
39        Sender {
40            inner: Some(tx),
41            metrics: metrics.clone(),
42            name: name.clone(),
43        },
44        Receiver {
45            inner: Some(rx),
46            metrics,
47            name,
48        },
49    )
50}
51
52/// Tracked sender half of a oneshot channel.
53pub struct Sender<T> {
54    inner: Option<tokio_oneshot::Sender<T>>,
55    metrics: Arc<ChannelMetricsTracker>,
56    name: Arc<String>,
57}
58
59impl<T> Sender<T> {
60    /// Send a value.
61    ///
62    /// # Errors
63    ///
64    /// Returns the value if the receiver was dropped.
65    pub fn send(mut self, value: T) -> Result<(), T> {
66        if let Some(tx) = self.inner.take() {
67            match tx.send(value) {
68                Ok(()) => {
69                    self.metrics.record_send(None);
70                    Ok(())
71                }
72                Err(value) => {
73                    self.metrics.mark_closed();
74                    Err(value)
75                }
76            }
77        } else {
78            Err(value)
79        }
80    }
81
82    /// Check if the receiver has been dropped.
83    #[must_use]
84    pub fn is_closed(&self) -> bool {
85        self.inner
86            .as_ref()
87            .map_or(true, tokio::sync::oneshot::Sender::is_closed)
88    }
89
90    /// Get the channel name.
91    #[must_use]
92    pub fn name(&self) -> &str {
93        &self.name
94    }
95}
96
97impl<T> fmt::Debug for Sender<T> {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        f.debug_struct("oneshot::Sender")
100            .field("name", &self.name)
101            .finish()
102    }
103}
104
105impl<T> Drop for Sender<T> {
106    fn drop(&mut self) {
107        if self.inner.is_some() {
108            // Sender dropped without sending
109            self.metrics.mark_closed();
110        }
111    }
112}
113
114/// Tracked receiver half of a oneshot channel.
115pub struct Receiver<T> {
116    inner: Option<tokio_oneshot::Receiver<T>>,
117    metrics: Arc<ChannelMetricsTracker>,
118    name: Arc<String>,
119}
120
121impl<T> Receiver<T> {
122    /// Try to receive the value without waiting.
123    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
124        if let Some(rx) = self.inner.as_mut() {
125            match rx.try_recv() {
126                Ok(value) => {
127                    self.metrics.record_recv(None);
128                    self.inner = None;
129                    Ok(value)
130                }
131                Err(tokio_oneshot::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
132                Err(tokio_oneshot::error::TryRecvError::Closed) => {
133                    self.metrics.mark_closed();
134                    self.inner = None;
135                    Err(TryRecvError::Closed)
136                }
137            }
138        } else {
139            Err(TryRecvError::Closed)
140        }
141    }
142
143    /// Close the receiver, notifying the sender.
144    pub fn close(&mut self) {
145        if let Some(rx) = self.inner.as_mut() {
146            rx.close();
147            self.metrics.mark_closed();
148        }
149    }
150
151    /// Get the channel name.
152    #[must_use]
153    pub fn name(&self) -> &str {
154        &self.name
155    }
156
157    /// Get current metrics.
158    #[must_use]
159    pub fn metrics(&self) -> ChannelMetrics {
160        self.metrics.get_metrics(0)
161    }
162}
163
164impl<T> std::future::Future for Receiver<T> {
165    type Output = Result<T, RecvError>;
166
167    fn poll(
168        mut self: std::pin::Pin<&mut Self>,
169        cx: &mut std::task::Context<'_>,
170    ) -> std::task::Poll<Self::Output> {
171        if let Some(ref mut rx) = self.inner {
172            // SAFETY: We're not moving the inner receiver
173            let rx = unsafe { std::pin::Pin::new_unchecked(rx) };
174            match rx.poll(cx) {
175                std::task::Poll::Ready(Ok(value)) => {
176                    self.metrics.record_recv(None);
177                    self.inner = None;
178                    std::task::Poll::Ready(Ok(value))
179                }
180                std::task::Poll::Ready(Err(_)) => {
181                    self.metrics.mark_closed();
182                    self.inner = None;
183                    std::task::Poll::Ready(Err(RecvError(())))
184                }
185                std::task::Poll::Pending => std::task::Poll::Pending,
186            }
187        } else {
188            std::task::Poll::Ready(Err(RecvError(())))
189        }
190    }
191}
192
193impl<T> fmt::Debug for Receiver<T> {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.debug_struct("oneshot::Receiver")
196            .field("name", &self.name)
197            .finish()
198    }
199}
200
/// Error returned when receiving fails because the sender was dropped.
///
/// Also produced when a `Receiver` that has already resolved is awaited again.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecvError(());

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl std::error::Error for RecvError {}
212
/// Error returned when `try_recv` fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// The channel is empty (sender hasn't sent yet).
    Empty,
    /// The channel is closed.
    Closed,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Self::Empty => "channel empty",
            Self::Closed => "channel closed",
        };
        f.write_str(text)
    }
}

impl std::error::Error for TryRecvError {}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_oneshot_success() {
        let (tx, rx) = channel::<i32>("test");

        tx.send(42).unwrap();
        let value = rx.await.unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_oneshot_sender_dropped() {
        let (tx, rx) = channel::<i32>("test");
        drop(tx);

        assert!(rx.await.is_err());
    }

    #[tokio::test]
    async fn test_oneshot_receiver_dropped() {
        let (tx, rx) = channel::<i32>("test");
        drop(rx);

        assert!(tx.is_closed());
        assert!(tx.send(42).is_err());
    }

    #[tokio::test]
    async fn test_try_recv() {
        let (tx, mut rx) = channel::<i32>("test");

        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        tx.send(42).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 42);
    }

    #[tokio::test]
    async fn test_try_recv_after_value_taken_is_closed() {
        let (tx, mut rx) = channel::<i32>("test");

        tx.send(1).unwrap();
        assert_eq!(rx.try_recv().unwrap(), 1);
        // The receiver is spent once the value has been taken.
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Closed)));
    }

    #[tokio::test]
    async fn test_try_recv_sender_dropped_is_closed() {
        let (tx, mut rx) = channel::<i32>("test");
        drop(tx);

        assert!(matches!(rx.try_recv(), Err(TryRecvError::Closed)));
    }

    #[tokio::test]
    async fn test_receiver_close_rejects_send() {
        let (tx, mut rx) = channel::<i32>("test");
        rx.close();

        assert!(tx.is_closed());
        assert_eq!(tx.send(7), Err(7));
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Closed)));
    }

    #[tokio::test]
    async fn test_poll_after_completion_returns_err() {
        let (tx, mut rx) = channel::<i32>("test");
        tx.send(5).unwrap();

        // `Receiver` is `Unpin`, so `&mut Receiver` can be awaited repeatedly.
        assert_eq!((&mut rx).await, Ok(5));
        assert_eq!((&mut rx).await, Err(RecvError(())));
    }

    #[test]
    fn test_names_are_shared() {
        let (tx, rx) = channel::<()>("jobs");

        assert_eq!(tx.name(), "jobs");
        assert_eq!(rx.name(), "jobs");
    }

    #[test]
    fn test_error_display() {
        assert_eq!(RecvError(()).to_string(), "channel closed");
        assert_eq!(TryRecvError::Empty.to_string(), "channel empty");
        assert_eq!(TryRecvError::Closed.to_string(), "channel closed");
    }
}