/// Synchronization primitives used by the channel implementation.
///
/// The channel code in this file pulls `Arc`, `Weak`, `AtomicUsize` and
/// `Ordering` in through `sync` instead of naming `std::sync` directly.
// NOTE(review): this indirection looks like a seam for substituting another
// implementation (e.g. a model checker such as loom) — confirm before relying on it.
pub(crate) mod sync {
    pub use std::sync::atomic::{AtomicUsize, Ordering};
    pub use std::sync::{Arc, Weak};
}
57
58use std::{
59 future::Future,
60 pin::Pin,
61 task::{Context, Poll},
62};
63use sync::{Arc, AtomicUsize, Ordering, Weak};
64
65use atomic_waker::AtomicWaker;
66use take_once::TakeOnce;
67
68pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
93 let chan = Arc::new(Chan::new());
94 (Sender { chan: chan.clone() }, Receiver { chan })
95}
96
97#[derive(Debug)]
98struct Chan<T> {
99 sender_rc: AtomicUsize,
100 data: TakeOnce<T>,
101 waker: AtomicWaker,
102}
103
104impl<T> Chan<T> {
105 const fn new() -> Self {
106 Self {
107 sender_rc: AtomicUsize::new(1),
108 data: TakeOnce::new(),
109 waker: AtomicWaker::new(),
110 }
111 }
112
113 fn set(&self, data: T) -> Result<(), T> {
119 self.data.store(data)?;
120 self.waker.wake();
121
122 Ok(())
123 }
124
125 fn is_dropped(&self) -> bool {
127 self.sender_rc.load(Ordering::Acquire) == 0
128 }
129}
130
131#[derive(Debug)]
138pub struct Sender<T> {
139 chan: Arc<Chan<T>>,
140}
141
142#[derive(Debug)]
149pub struct Receiver<T> {
150 chan: Arc<Chan<T>>,
151}
152
153impl<T> Sender<T> {
154 pub fn send(&self, data: T) -> Result<(), T> {
174 self.chan.set(data)
175 }
176
177 pub fn downgrade(&self) -> WeakSender<T> {
180 WeakSender {
181 chan: Arc::downgrade(&self.chan),
182 }
183 }
184}
185
186impl<T> Clone for Sender<T> {
187 fn clone(&self) -> Self {
188 self.chan.sender_rc.fetch_add(1, Ordering::Release);
189 Self {
190 chan: self.chan.clone(),
191 }
192 }
193}
194
195impl<T> Drop for Sender<T> {
196 fn drop(&mut self) {
197 if self.chan.sender_rc.fetch_sub(1, Ordering::AcqRel) == 1 {
198 self.chan.waker.wake();
199 }
200 }
201}
202
203#[derive(Debug)]
210pub struct Recv<T> {
211 chan: Arc<Chan<T>>,
212}
213
214impl<T> Future for Recv<T> {
215 type Output = Option<T>;
216
217 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
218 if self.chan.data.is_completed() || self.chan.is_dropped() {
220 return Poll::Ready(self.chan.data.take());
222 }
223
224 self.chan.waker.register(cx.waker());
225
226 if self.chan.data.is_completed() || self.chan.is_dropped() {
227 Poll::Ready(self.chan.data.take())
228 } else {
229 Poll::Pending
230 }
231 }
232}
233
// `Receiver` deliberately has no `Drop` impl that clears the registered waker.
//
// `Recv` futures hold their own `Arc<Chan<T>>`, so they can outlive the
// `Receiver` that created them (for example, when handed to `tokio::spawn`).
// Suppose such a future has already been polled and registered its waker.
// Taking the waker when the `Receiver` drops would leave that future pending
// forever: neither a later `send` nor the last `Sender` being dropped would
// have anything to wake.
//
// Trade-off: a waker left behind by an abandoned `Recv` stays stored until the
// channel is woken or the shared state itself is dropped.
239
240impl<T> Receiver<T> {
241 pub fn recv(&self) -> Recv<T> {
265 Recv {
266 chan: self.chan.clone(),
267 }
268 }
269}
270
271#[derive(Debug)]
272pub struct WeakSender<T> {
273 chan: Weak<Chan<T>>,
274}
275
276impl<T> WeakSender<T> {
277 pub fn send(&self, data: T) -> Result<(), T> {
297 match self.chan.upgrade() {
298 Some(chan) => chan.set(data),
299 None => Err(data),
300 }
301 }
302
303 pub fn upgrade(&self) -> Option<Sender<T>> {
304 let chan = self.chan.upgrade()?;
305 if chan.sender_rc.fetch_add(1, Ordering::Acquire) == 0 {
306 chan.sender_rc.fetch_sub(1, Ordering::AcqRel);
308 None
309 } else {
310 Some(Sender { chan })
311 }
312 }
313}
314
315impl<T> Clone for WeakSender<T> {
316 fn clone(&self) -> Self {
317 Self {
318 chan: self.chan.clone(),
319 }
320 }
321}
322
/// Behavioural tests for the one-shot channel.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::{
        task::JoinSet,
        time::{sleep, Duration},
    };

    #[tokio::test]
    async fn test_basic_send_recv() {
        let (tx, rx) = oneshot();
        tx.send(42).unwrap();
        assert_eq!(rx.recv().await, Some(42));
    }

    #[tokio::test]
    async fn test_multiple_sends_fail() {
        let (tx, rx) = oneshot();
        assert!(tx.send(1).is_ok());
        // The rejected value is handed back to the caller.
        assert_eq!(tx.send(2), Err(2));
        assert_eq!(rx.recv().await, Some(1));
    }

    #[tokio::test]
    async fn test_multiple_receives_fail() {
        let (tx, rx) = oneshot();
        tx.send(1).unwrap();
        assert_eq!(rx.recv().await, Some(1));
        assert_eq!(rx.recv().await, None);
    }

    #[tokio::test]
    async fn test_sender_drop_before_send() {
        let (tx, rx) = oneshot::<i32>();
        drop(tx);
        assert_eq!(rx.recv().await, None);
    }

    #[tokio::test]
    async fn test_receiver_drop_before_receive() {
        let (tx, rx) = oneshot();
        // Actually drop the receiver. Binding it to `_rx` would keep it alive
        // until the end of the test.
        drop(rx);
        assert!(tx.send(1).is_ok());
    }

    #[tokio::test]
    async fn test_concurrent_send_receive() {
        for _ in 0..1000 {
            let (tx, rx) = oneshot();

            let tx_handle = tokio::spawn(async move {
                sleep(Duration::from_micros(1)).await;
                tx.send(42)
            });

            let rx_handle = tokio::spawn(async move { rx.recv().await });

            let (send_result, receive_result) = tokio::join!(tx_handle, rx_handle);
            assert!(send_result.unwrap().is_ok());
            assert_eq!(receive_result.unwrap(), Some(42));
        }
    }

    #[tokio::test]
    async fn test_clone_sender() {
        let (tx1, rx) = oneshot();
        let tx2 = tx1.clone();

        let handle1 = tokio::spawn(async move {
            sleep(Duration::from_micros(1)).await;
            tx1.send(1)
        });

        let handle2 = tokio::spawn(async move {
            sleep(Duration::from_micros(1)).await;
            tx2.send(2)
        });

        let (result1, result2) = tokio::join!(handle1, handle2);
        let results = [result1.unwrap(), result2.unwrap()];
        assert_eq!(results.iter().filter(|r| r.is_ok()).count(), 1);
        assert_eq!(results.iter().filter(|r| r.is_err()).count(), 1);

        // The receiver must get exactly the value whose send succeeded.
        let expected = if results[0].is_ok() { 1 } else { 2 };
        assert_eq!(rx.recv().await, Some(expected));
    }

    #[tokio::test]
    async fn test_sender_ref_counting() {
        let (tx1, rx) = oneshot::<i32>();
        let tx2 = tx1.clone();
        let tx3 = tx2.clone();

        assert!(!rx.chan.is_dropped());
        drop(tx1);
        assert!(!rx.chan.is_dropped());
        drop(tx2);
        assert!(!rx.chan.is_dropped());
        drop(tx3);
        assert!(rx.chan.is_dropped());
    }

    #[tokio::test]
    async fn test_concurrent_clone_and_send() {
        for _ in 0..1000 {
            let (tx, rx) = oneshot();
            let tx = Arc::new(tx);

            let mut jset = JoinSet::new();

            for i in 0..10 {
                let tx = Arc::clone(&tx);
                jset.spawn(async move {
                    // Clone the `Sender` itself (not the `Arc`) so that
                    // `Sender::clone` runs concurrently across tasks.
                    let tx = Sender::clone(&tx);
                    sleep(Duration::from_micros(1)).await;
                    tx.send(i)
                });
            }

            let results = jset.join_all().await;
            let ok_count = results.iter().filter(|r| r.is_ok()).count();
            assert_eq!(ok_count, 1);

            let received = rx.recv().await;
            assert!(received.is_some());
        }
    }

    #[test]
    fn test_sync_send() {
        fn assert_sync<T: Sync>() {}
        fn assert_send<T: Send>() {}

        assert_sync::<Chan<i32>>();
        assert_send::<Chan<i32>>();
        assert_sync::<Sender<i32>>();
        assert_send::<Sender<i32>>();
        assert_sync::<Receiver<i32>>();
        assert_send::<Receiver<i32>>();
    }

    #[tokio::test]
    async fn test_concurrent_take_operations() {
        for _ in 0..1000 {
            let (tx, rx) = oneshot();
            let rx = Arc::new(rx);

            tx.send(42).unwrap();

            let mut jset = JoinSet::new();
            for _ in 0..10 {
                let rx = Arc::clone(&rx);
                jset.spawn(async move { rx.recv().await });
            }

            // Exactly one receive wins the value; all others observe `None`.
            let results = jset.join_all().await;
            let taken: Vec<i32> = results.into_iter().flatten().collect();
            assert_eq!(taken, vec![42]);
        }
    }

    #[tokio::test]
    async fn test_receive_after_sender_dropped() {
        let (tx, rx) = oneshot();
        tx.send(42).unwrap();
        drop(tx);
        assert_eq!(rx.recv().await, Some(42));
    }

    #[tokio::test]
    async fn test_receive_timeout() {
        let (tx, rx) = oneshot();

        // Nothing has been sent yet, so the first receive must time out.
        let timeout_result = tokio::time::timeout(Duration::from_millis(10), rx.recv()).await;
        assert!(timeout_result.is_err());

        tx.send(42).unwrap();

        let timeout_result = tokio::time::timeout(Duration::from_millis(10), rx.recv()).await;
        assert!(timeout_result.is_ok());
        assert_eq!(timeout_result.unwrap(), Some(42));
    }

    #[tokio::test]
    async fn test_detailed_sender_drops() {
        // The only sender is dropped without sending.
        let (tx, rx) = oneshot::<i32>();
        drop(tx);
        assert_eq!(rx.recv().await, None);

        // A surviving clone can still send.
        let (tx, rx) = oneshot::<i32>();
        let tx2 = tx.clone();
        drop(tx);
        assert_eq!(tx2.send(42), Ok(()));
        assert_eq!(rx.recv().await, Some(42));

        // Every clone is dropped without sending.
        let (tx, rx) = oneshot::<i32>();
        let tx2 = tx.clone();
        drop(tx);
        drop(tx2);
        assert_eq!(rx.recv().await, None);
    }

    #[tokio::test]
    async fn test_detached_spawn_send() {
        let (tx, rx) = oneshot::<i32>();

        tokio::spawn(async move {
            let send = tx.send(42);
            assert_eq!(send, Ok(()));
            let send = tx.send(43);
            assert_eq!(send, Err(43));
        });

        let data = rx.recv().await;
        assert_eq!(data, Some(42));
        let data = rx.recv().await;
        assert_eq!(data, None);
    }

    #[tokio::test]
    async fn test_detached_spawn_with_clone() {
        let (tx, rx) = oneshot::<i32>();

        tokio::spawn(async move {
            let tx2 = tx.clone();
            let send = tx.send(42);
            assert_eq!(send, Ok(()));
            let send = tx2.send(43);
            assert_eq!(send, Err(43));
        });

        let data = rx.recv().await;
        assert_eq!(data, Some(42));
        let data = rx.recv().await;
        assert_eq!(data, None);
    }

    #[test]
    fn trait_compiles() {
        fn test_send<T: Send>() {}
        fn test_sync<T: Sync>() {}

        test_send::<Sender<i32>>();
        test_send::<Receiver<i32>>();
        test_sync::<Sender<i32>>();
        test_sync::<Receiver<i32>>();
        test_send::<Chan<i32>>();
        test_sync::<Chan<i32>>();
    }

    #[tokio::test]
    async fn test_weak_sender() {
        let (tx, rx) = oneshot();
        let weak_tx = tx.downgrade();
        assert!(weak_tx.send(42).is_ok());
        assert_eq!(rx.recv().await, Some(42));

        assert!(weak_tx.upgrade().is_some());
        drop(tx);
        assert!(weak_tx.upgrade().is_none());
    }
}