async_inspect/channel/
oneshot.rs1use crate::channel::{ChannelMetrics, ChannelMetricsTracker};
6use std::fmt;
7use std::sync::Arc;
8use tokio::sync::oneshot as tokio_oneshot;
9
10pub fn channel<T>(name: impl Into<String>) -> (Sender<T>, Receiver<T>) {
34 let (tx, rx) = tokio_oneshot::channel();
35 let metrics = Arc::new(ChannelMetricsTracker::new());
36 let name = Arc::new(name.into());
37
38 (
39 Sender {
40 inner: Some(tx),
41 metrics: metrics.clone(),
42 name: name.clone(),
43 },
44 Receiver {
45 inner: Some(rx),
46 metrics,
47 name,
48 },
49 )
50}
51
52pub struct Sender<T> {
54 inner: Option<tokio_oneshot::Sender<T>>,
55 metrics: Arc<ChannelMetricsTracker>,
56 name: Arc<String>,
57}
58
59impl<T> Sender<T> {
60 pub fn send(mut self, value: T) -> Result<(), T> {
66 if let Some(tx) = self.inner.take() {
67 match tx.send(value) {
68 Ok(()) => {
69 self.metrics.record_send(None);
70 Ok(())
71 }
72 Err(value) => {
73 self.metrics.mark_closed();
74 Err(value)
75 }
76 }
77 } else {
78 Err(value)
79 }
80 }
81
82 #[must_use]
84 pub fn is_closed(&self) -> bool {
85 self.inner
86 .as_ref()
87 .map_or(true, tokio::sync::oneshot::Sender::is_closed)
88 }
89
90 #[must_use]
92 pub fn name(&self) -> &str {
93 &self.name
94 }
95}
96
97impl<T> fmt::Debug for Sender<T> {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 f.debug_struct("oneshot::Sender")
100 .field("name", &self.name)
101 .finish()
102 }
103}
104
105impl<T> Drop for Sender<T> {
106 fn drop(&mut self) {
107 if self.inner.is_some() {
108 self.metrics.mark_closed();
110 }
111 }
112}
113
114pub struct Receiver<T> {
116 inner: Option<tokio_oneshot::Receiver<T>>,
117 metrics: Arc<ChannelMetricsTracker>,
118 name: Arc<String>,
119}
120
121impl<T> Receiver<T> {
122 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
124 if let Some(rx) = self.inner.as_mut() {
125 match rx.try_recv() {
126 Ok(value) => {
127 self.metrics.record_recv(None);
128 self.inner = None;
129 Ok(value)
130 }
131 Err(tokio_oneshot::error::TryRecvError::Empty) => Err(TryRecvError::Empty),
132 Err(tokio_oneshot::error::TryRecvError::Closed) => {
133 self.metrics.mark_closed();
134 self.inner = None;
135 Err(TryRecvError::Closed)
136 }
137 }
138 } else {
139 Err(TryRecvError::Closed)
140 }
141 }
142
143 pub fn close(&mut self) {
145 if let Some(rx) = self.inner.as_mut() {
146 rx.close();
147 self.metrics.mark_closed();
148 }
149 }
150
151 #[must_use]
153 pub fn name(&self) -> &str {
154 &self.name
155 }
156
157 #[must_use]
159 pub fn metrics(&self) -> ChannelMetrics {
160 self.metrics.get_metrics(0)
161 }
162}
163
164impl<T> std::future::Future for Receiver<T> {
165 type Output = Result<T, RecvError>;
166
167 fn poll(
168 mut self: std::pin::Pin<&mut Self>,
169 cx: &mut std::task::Context<'_>,
170 ) -> std::task::Poll<Self::Output> {
171 if let Some(ref mut rx) = self.inner {
172 let rx = unsafe { std::pin::Pin::new_unchecked(rx) };
174 match rx.poll(cx) {
175 std::task::Poll::Ready(Ok(value)) => {
176 self.metrics.record_recv(None);
177 self.inner = None;
178 std::task::Poll::Ready(Ok(value))
179 }
180 std::task::Poll::Ready(Err(_)) => {
181 self.metrics.mark_closed();
182 self.inner = None;
183 std::task::Poll::Ready(Err(RecvError(())))
184 }
185 std::task::Poll::Pending => std::task::Poll::Pending,
186 }
187 } else {
188 std::task::Poll::Ready(Err(RecvError(())))
189 }
190 }
191}
192
193impl<T> fmt::Debug for Receiver<T> {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("oneshot::Receiver")
196 .field("name", &self.name)
197 .finish()
198 }
199}
200
/// Error yielded when awaiting a [`Receiver`] that cannot produce a value.
///
/// It is returned in two cases:
/// * the sender went away without sending;
/// * the receiver had already completed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecvError(());

impl fmt::Display for RecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

impl std::error::Error for RecvError {}
212
/// Error returned by [`Receiver::try_recv`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TryRecvError {
    /// No value is available yet and the channel is still open.
    Empty,
    /// The channel closed with no value pending, or the value was already taken.
    Closed,
}

impl fmt::Display for TryRecvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "channel empty",
            Self::Closed => "channel closed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TryRecvError {}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// A sent value is delivered to the awaiting receiver.
    #[tokio::test]
    async fn test_oneshot_success() {
        let (sender, receiver) = channel::<i32>("test");

        assert_eq!(sender.send(42), Ok(()));
        assert_eq!(receiver.await, Ok(42));
    }

    /// Dropping the sender without sending makes the receiver resolve to an error.
    #[tokio::test]
    async fn test_oneshot_sender_dropped() {
        let (sender, receiver) = channel::<i32>("test");
        drop(sender);

        assert_eq!(receiver.await, Err(RecvError(())));
    }

    /// Dropping the receiver closes the channel for the sender.
    #[tokio::test]
    async fn test_oneshot_receiver_dropped() {
        let (sender, receiver) = channel::<i32>("test");
        drop(receiver);

        assert!(sender.is_closed());
        assert!(sender.send(42).is_err());
    }

    /// `try_recv` reports `Empty` before the send and the value afterwards.
    #[tokio::test]
    async fn test_try_recv() {
        let (sender, mut receiver) = channel::<i32>("test");

        assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));

        assert_eq!(sender.send(42), Ok(()));
        assert_eq!(receiver.try_recv(), Ok(42));
    }
}