1use std::cell::UnsafeCell;
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::atomic::{AtomicU8, Ordering};
11use std::sync::{Arc, Mutex};
12use std::task::{Context, Poll, Waker};
13
/// Channel state: nothing has been sent and neither half has closed the channel.
const EMPTY: u8 = 0;
/// Channel state: the sender has stored a value in the slot for the receiver to take.
const SENT: u8 = 1;
/// Channel state: one half was dropped before a value was delivered.
const CLOSED: u8 = 2;

/// State shared by a `Sender` / `Receiver` pair through an `Arc`.
struct Inner<T> {
    /// Current channel state: one of `EMPTY`, `SENT` or `CLOSED`.
    state: AtomicU8,
    /// Slot holding the transferred value; access is coordinated through `state`.
    value: UnsafeCell<Option<T>>,
    /// Waker stored by the receiver when it polled and found the channel still empty.
    waker: Mutex<Option<Waker>>,
}

// SAFETY: the only non-thread-safe field is `value`. The sender writes it only
// before publishing `SENT` (or reclaims its own write after failing to publish),
// and the receiver only reads it after observing `SENT` with `Acquire` ordering,
// so the cell is never accessed concurrently. `T: Send` is required because the
// value is moved from the sending thread to the receiving one.
unsafe impl<T: Send> Send for Inner<T> {}
// SAFETY: see the `Send` impl above; `state` is atomic and `waker` is behind a `Mutex`.
unsafe impl<T: Send> Sync for Inner<T> {}

impl<T> Inner<T> {
    /// Builds the initial shared state: `EMPTY`, no value, no registered waker.
    fn new() -> Self {
        let slot = UnsafeCell::new(None);
        Inner {
            value: slot,
            waker: Mutex::new(None),
            state: AtomicU8::new(EMPTY),
        }
    }
}
55
/// Error produced by awaiting a `Receiver` when the channel closed without a value.
///
/// The type is fieldless, so it is `Copy` and can be freely duplicated or hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecvError {
    /// The sender was dropped without sending a value.
    Closed,
}
64
65impl std::fmt::Display for RecvError {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 match self {
68 RecvError::Closed => f.write_str("oneshot channel closed without a value"),
69 }
70 }
71}
72
73impl std::error::Error for RecvError {}
74
75pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
77 let inner = Arc::new(Inner::new());
78 (
79 Sender {
80 inner: inner.clone(),
81 sent: false,
82 },
83 Receiver { inner },
84 )
85}
86
87pub struct Sender<T> {
91 inner: Arc<Inner<T>>,
92 sent: bool,
94}
95
96impl<T> Sender<T> {
97 pub fn send(mut self, value: T) -> Result<(), T> {
101 unsafe { *self.inner.value.get() = Some(value) };
107
108 match self.inner.state.compare_exchange(
109 EMPTY,
110 SENT,
111 Ordering::Release, Ordering::Relaxed,
113 ) {
114 Ok(_) => {
115 self.sent = true;
116 if let Some(w) = self.inner.waker.lock().unwrap().take() {
118 w.wake();
119 }
120 Ok(())
121 }
122 Err(_) => {
123 let val = unsafe { (*self.inner.value.get()).take() }.unwrap();
127 Err(val)
128 }
129 }
130 }
131}
132
133impl<T> Drop for Sender<T> {
134 fn drop(&mut self) {
135 if self.sent {
136 return; }
138 let prev = self.inner.state.swap(CLOSED, Ordering::Release);
140 if prev == EMPTY {
141 if let Some(w) = self.inner.waker.lock().unwrap().take() {
142 w.wake();
143 }
144 }
145 }
146}
147
148pub struct Receiver<T> {
152 inner: Arc<Inner<T>>,
153}
154
155impl<T> Future for Receiver<T> {
156 type Output = Result<T, RecvError>;
157
158 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159 let state = self.inner.state.load(Ordering::Acquire);
160 match state {
161 SENT => {
162 let val = unsafe { (*self.inner.value.get()).take() }
166 .expect("oneshot: SENT state but value is None (logic error)");
167 Poll::Ready(Ok(val))
168 }
169 CLOSED => Poll::Ready(Err(RecvError::Closed)),
170 _ => {
171 *self.inner.waker.lock().unwrap() = Some(cx.waker().clone());
173 let state2 = self.inner.state.load(Ordering::Acquire);
175 if state2 == SENT {
176 let val = unsafe { (*self.inner.value.get()).take() }
178 .expect("oneshot: SENT but value None after re-check");
179 Poll::Ready(Ok(val))
180 } else if state2 == CLOSED {
181 Poll::Ready(Err(RecvError::Closed))
182 } else {
183 Poll::Pending
184 }
185 }
186 }
187 }
188}
189
190impl<T> Drop for Receiver<T> {
191 fn drop(&mut self) {
192 let _ =
196 self.inner
197 .state
198 .compare_exchange(EMPTY, CLOSED, Ordering::Relaxed, Ordering::Relaxed);
199 }
200}
201
#[cfg(test)]
mod tests {
    //! Behavioural tests for the oneshot channel, driven by the crate's executor.

    use super::*;
    use crate::executor::{block_on, block_on_with_spawn, spawn};

    #[test]
    fn send_then_recv() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            tx.send(42).unwrap();
            rx.await
        });
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn recv_then_send_via_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<String>();
            let jh = spawn(async move {
                tx.send("hello".to_string()).unwrap();
            });
            let val = rx.await.unwrap();
            jh.await.unwrap();
            val
        });
        assert_eq!(result, "hello");
    }

    #[test]
    fn sender_drop_closes_channel() {
        let result = block_on(async {
            let (tx, rx) = channel::<u32>();
            drop(tx);
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn send_after_receiver_drop_returns_err() {
        let (tx, rx) = channel::<u32>();
        drop(rx);
        assert!(tx.send(1).is_err());
    }

    #[test]
    fn value_types_roundtrip() {
        block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3]).unwrap();
            assert_eq!(rx.await.unwrap(), vec![1, 2, 3]);
        });
    }

    #[test]
    fn oneshot_send_string() {
        let result = block_on(async {
            let (tx, rx) = channel::<String>();
            tx.send("world".to_string()).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), "world");
    }

    #[test]
    fn oneshot_send_struct() {
        #[derive(Debug, PartialEq)]
        struct Point {
            x: i32,
            y: i32,
        }
        let result = block_on(async {
            let (tx, rx) = channel::<Point>();
            tx.send(Point { x: 1, y: 2 }).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), Point { x: 1, y: 2 });
    }

    #[test]
    fn oneshot_send_vec() {
        let result = block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(vec![1, 2, 3, 4, 5]).unwrap();
            rx.await
        });
        assert_eq!(result.unwrap(), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn oneshot_multiple_pairs_concurrent() {
        block_on_with_spawn(async {
            let mut rxs = Vec::new();
            for i in 0u32..10 {
                let (tx, rx) = channel::<u32>();
                spawn(async move {
                    tx.send(i).unwrap();
                });
                rxs.push(rx);
            }
            let mut results: Vec<u32> = Vec::new();
            for rx in rxs {
                results.push(rx.await.unwrap());
            }
            results.sort();
            assert_eq!(results, (0..10).collect::<Vec<_>>());
        });
    }

    #[test]
    fn oneshot_recv_error_display() {
        let err = RecvError::Closed;
        let s = format!("{err}");
        assert!(s.contains("closed") || s.contains("Closed"));
    }

    #[test]
    fn oneshot_send_returns_err_when_rx_dropped() {
        let (tx, rx) = channel::<i32>();
        drop(rx);
        let result = tx.send(42);
        assert_eq!(result, Err(42));
    }

    #[test]
    fn oneshot_send_value_then_recv_in_separate_block_on() {
        let (tx, rx) = channel::<u64>();
        tx.send(12345).unwrap();
        let val = block_on(async { rx.await.unwrap() });
        assert_eq!(val, 12345);
    }

    #[test]
    fn oneshot_sender_drop_closes_from_spawn() {
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<u32>();
            let jh = spawn(async move {
                drop(tx);
            });
            jh.await.unwrap();
            rx.await
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn oneshot_sender_drop_wakes_pending_receiver() {
        // Unlike the test above, the receiver is awaited before the spawned task
        // is joined, so it may be parked waiting on the sender's drop wakeup.
        let result = block_on_with_spawn(async {
            let (tx, rx) = channel::<u32>();
            let jh = spawn(async move {
                drop(tx);
            });
            let res = rx.await;
            jh.await.unwrap();
            res
        });
        assert_eq!(result, Err(RecvError::Closed));
    }

    #[test]
    fn oneshot_unreceived_value_is_dropped_with_channel() {
        let payload = std::sync::Arc::new(7u32);
        let (tx, rx) = channel::<std::sync::Arc<u32>>();
        tx.send(std::sync::Arc::clone(&payload)).unwrap();
        assert_eq!(std::sync::Arc::strong_count(&payload), 2);
        drop(rx);
        assert_eq!(std::sync::Arc::strong_count(&payload), 1);
    }

    #[test]
    fn oneshot_recv_error_is_error_trait() {
        let err = RecvError::Closed;
        let _e: &dyn std::error::Error = &err;
    }

    #[test]
    fn oneshot_u8_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<u8>();
            tx.send(255).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, 255);
    }

    #[test]
    fn oneshot_bool_roundtrip() {
        let result = block_on(async {
            let (tx, rx) = channel::<bool>();
            tx.send(true).unwrap();
            rx.await.unwrap()
        });
        assert!(result);
    }

    #[test]
    fn oneshot_unit_roundtrip() {
        // The `unwrap` is the assertion: comparing `()` to `()` is always true
        // and is rejected by clippy's `unit_cmp` lint.
        block_on(async {
            let (tx, rx) = channel::<()>();
            tx.send(()).unwrap();
            rx.await.unwrap()
        });
    }

    #[test]
    fn oneshot_10_pairs_in_parallel() {
        block_on_with_spawn(async {
            let mut rxs = Vec::new();
            for i in 0..10u32 {
                let (tx, rx) = channel::<u32>();
                let val = i * 3;
                spawn(async move { tx.send(val).unwrap() });
                rxs.push((i, rx));
            }
            for (i, rx) in rxs {
                let v = rx.await.unwrap();
                assert_eq!(v, i * 3);
            }
        });
    }

    #[test]
    fn oneshot_send_before_poll_synchronous() {
        let (tx, rx) = channel::<u32>();
        tx.send(777).unwrap();
        let v = block_on(async { rx.await.unwrap() });
        assert_eq!(v, 777);
    }

    #[test]
    fn oneshot_send_i64_max() {
        let result = block_on(async {
            let (tx, rx) = channel::<i64>();
            tx.send(i64::MAX).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, i64::MAX);
    }

    #[test]
    fn oneshot_send_i64_min() {
        let result = block_on(async {
            let (tx, rx) = channel::<i64>();
            tx.send(i64::MIN).unwrap();
            rx.await.unwrap()
        });
        assert_eq!(result, i64::MIN);
    }

    #[test]
    fn oneshot_send_empty_vec() {
        let result = block_on(async {
            let (tx, rx) = channel::<Vec<u8>>();
            tx.send(Vec::new()).unwrap();
            rx.await.unwrap()
        });
        assert!(result.is_empty());
    }

    #[test]
    fn oneshot_two_separate_channels_independent() {
        block_on(async {
            let (tx1, rx1) = channel::<u32>();
            let (tx2, rx2) = channel::<u32>();
            tx1.send(1).unwrap();
            tx2.send(2).unwrap();
            assert_eq!(rx1.await.unwrap(), 1);
            assert_eq!(rx2.await.unwrap(), 2);
        });
    }
}