1use crate::shim::atomic::{AtomicU8, Ordering};
6
7use super::common::{OneshotStorage, TakeResult};
8
9pub use super::common::RecvError;
11pub use super::common::TryRecvError;
12pub use super::common::error;
13
/// A completion state that can be packed into a single byte.
///
/// [`LiteStorage`] keeps the whole channel state in one `AtomicU8`, so every
/// state value must round-trip through `u8`. Four kinds of byte share that
/// space:
///
/// * bytes produced by [`State::to_u8`] — a state that has been sent,
/// * [`State::pending_value`] — nothing sent yet (the initial state),
/// * [`State::closed_value`] — the sender side marked the channel closed,
/// * [`State::receiver_closed_value`] — the receiver side closed the channel.
///
/// # Invariants
///
/// Implementors must ensure that:
///
/// * the three sentinel values are pairwise distinct,
/// * `to_u8` never returns any of the sentinel values, and
/// * `from_u8(s.to_u8())` yields a state equal to `s` for every state `s`.
///
/// If these are violated, a sent state can be mistaken for a sentinel (or fail
/// to decode), and the receiver reports `Pending`/`Closed` instead of the value.
///
/// The sentinel methods default to `0` (pending), `255` (sender closed) and
/// `254` (receiver closed). Override them only if your encoding needs one of
/// those bytes.
pub trait State: Sized + Send + Sync + 'static {
    /// Encodes this state as a byte. Must not return a sentinel value.
    fn to_u8(&self) -> u8;

    /// Decodes a byte previously produced by [`State::to_u8`].
    ///
    /// Returns `None` if `value` is not a valid encoding.
    fn from_u8(value: u8) -> Option<Self>;

    /// Byte stored while no state has been sent yet. Defaults to `0`.
    #[inline]
    fn pending_value() -> u8 {
        0
    }

    /// Byte stored once the sender side marks the channel closed. Defaults to `255`.
    #[inline]
    fn closed_value() -> u8 {
        255
    }

    /// Byte stored once the receiver side closes the channel. Defaults to `254`.
    #[inline]
    fn receiver_closed_value() -> u8 {
        254
    }
}
116
117impl State for () {
121 #[inline]
122 fn to_u8(&self) -> u8 {
123 1 }
125
126 #[inline]
127 fn from_u8(value: u8) -> Option<Self> {
128 match value {
129 1 => Some(()),
130 _ => None,
131 }
132 }
133
134 #[inline]
135 fn pending_value() -> u8 {
136 0 }
138
139 #[inline]
140 fn closed_value() -> u8 {
141 255 }
143
144 #[inline]
145 fn receiver_closed_value() -> u8 {
146 254 }
148}
149
150pub struct LiteStorage<S: State> {
158 state: AtomicU8,
159 _marker: core::marker::PhantomData<S>,
160}
161
162unsafe impl<S: State> Send for LiteStorage<S> {}
163unsafe impl<S: State> Sync for LiteStorage<S> {}
164
165impl<S: State> OneshotStorage for LiteStorage<S> {
166 type Value = S;
167
168 #[inline]
169 fn new() -> Self {
170 Self {
171 state: AtomicU8::new(S::pending_value()),
172 _marker: core::marker::PhantomData,
173 }
174 }
175
176 #[inline]
177 fn store(&self, value: S) {
178 self.state.store(value.to_u8(), Ordering::Release);
179 }
180
181 #[inline]
182 fn try_take(&self) -> TakeResult<S> {
183 let current = self.state.load(Ordering::Acquire);
184
185 if current == S::closed_value() || current == S::receiver_closed_value() {
186 return TakeResult::Closed;
187 }
188
189 if current == S::pending_value() {
190 return TakeResult::Pending;
191 }
192
193 if let Some(state) = S::from_u8(current) {
195 TakeResult::Ready(state)
196 } else {
197 TakeResult::Pending
198 }
199 }
200
201 #[inline]
202 fn is_sender_dropped(&self) -> bool {
203 self.state.load(Ordering::Acquire) == S::closed_value()
204 }
205
206 #[inline]
207 fn mark_sender_dropped(&self) {
208 self.state.store(S::closed_value(), Ordering::Release);
209 }
210
211 #[inline]
212 fn is_receiver_closed(&self) -> bool {
213 self.state.load(Ordering::Acquire) == S::receiver_closed_value()
214 }
215
216 #[inline]
217 fn mark_receiver_closed(&self) {
218 self.state
219 .store(S::receiver_closed_value(), Ordering::Release);
220 }
221}
222
223pub type Sender<S> = super::common::Sender<LiteStorage<S>>;
231
232pub type Receiver<S> = super::common::Receiver<LiteStorage<S>>;
236
237#[inline]
241pub fn channel<S: State>() -> (Sender<S>, Receiver<S>) {
242 Sender::new()
243}
244
245impl<S: State> Receiver<S> {
250 #[inline]
258 pub async fn recv(self) -> Result<S, RecvError> {
259 self.await
260 }
261
262 #[inline]
272 pub fn try_recv(&mut self) -> Result<S, TryRecvError> {
273 match self.inner.try_recv() {
274 TakeResult::Ready(v) => Ok(v),
275 TakeResult::Pending => Err(TryRecvError::Empty),
276 TakeResult::Closed => Err(TryRecvError::Closed),
277 }
278 }
279}
280
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;

    /// Two-variant completion state used by most tests.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum TestCompletion {
        /// The operation completed normally.
        Called,

        /// The operation was cancelled.
        Cancelled,
    }

    impl State for TestCompletion {
        fn to_u8(&self) -> u8 {
            match self {
                TestCompletion::Called => 1,
                TestCompletion::Cancelled => 2,
            }
        }

        fn from_u8(value: u8) -> Option<Self> {
            match value {
                1 => Some(TestCompletion::Called),
                2 => Some(TestCompletion::Cancelled),
                _ => None,
            }
        }

        fn pending_value() -> u8 {
            0
        }

        fn closed_value() -> u8 {
            255
        }

        fn receiver_closed_value() -> u8 {
            254
        }
    }

    #[tokio::test]
    async fn test_oneshot_called() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(TestCompletion::Called).unwrap();
        });

        let result = receiver.recv().await;
        assert_eq!(result, Ok(TestCompletion::Called));
    }

    #[tokio::test]
    async fn test_oneshot_cancelled() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(TestCompletion::Cancelled).unwrap();
        });

        let result = receiver.recv().await;
        assert_eq!(result, Ok(TestCompletion::Cancelled));
    }

    #[tokio::test]
    async fn test_oneshot_immediate_called() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        notifier.send(TestCompletion::Called).unwrap();

        let result = receiver.recv().await;
        assert_eq!(result, Ok(TestCompletion::Called));
    }

    #[tokio::test]
    async fn test_oneshot_immediate_cancelled() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        notifier.send(TestCompletion::Cancelled).unwrap();

        let result = receiver.recv().await;
        assert_eq!(result, Ok(TestCompletion::Cancelled));
    }

    /// Three-variant state, exercising encodings beyond the first two bytes.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum CustomState {
        Success,
        Failure,
        Timeout,
    }

    impl State for CustomState {
        fn to_u8(&self) -> u8 {
            match self {
                CustomState::Success => 1,
                CustomState::Failure => 2,
                CustomState::Timeout => 3,
            }
        }

        fn from_u8(value: u8) -> Option<Self> {
            match value {
                1 => Some(CustomState::Success),
                2 => Some(CustomState::Failure),
                3 => Some(CustomState::Timeout),
                _ => None,
            }
        }

        fn pending_value() -> u8 {
            0
        }

        fn closed_value() -> u8 {
            255
        }

        fn receiver_closed_value() -> u8 {
            254
        }
    }

    #[tokio::test]
    async fn test_oneshot_custom_state() {
        let (notifier, receiver) = Sender::<CustomState>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(CustomState::Success).unwrap();
        });

        let result = receiver.recv().await;
        assert_eq!(result, Ok(CustomState::Success));
    }

    #[tokio::test]
    async fn test_oneshot_custom_state_timeout() {
        let (notifier, receiver) = Sender::<CustomState>::new();

        notifier.send(CustomState::Timeout).unwrap();

        let result = receiver.recv().await;
        assert_eq!(result, Ok(CustomState::Timeout));
    }

    #[tokio::test]
    async fn test_oneshot_unit_type() {
        let (notifier, receiver) = Sender::<()>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(()).unwrap();
        });

        let result = receiver.recv().await;
        assert_eq!(result, Ok(()));
    }

    #[tokio::test]
    async fn test_oneshot_unit_type_immediate() {
        let (notifier, receiver) = Sender::<()>::new();

        notifier.send(()).unwrap();

        let result = receiver.recv().await;
        assert_eq!(result, Ok(()));
    }

    // The receiver itself is awaitable, without going through `recv`.
    #[tokio::test]
    async fn test_oneshot_into_future_called() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(TestCompletion::Called).unwrap();
        });

        let result = receiver.await;
        assert_eq!(result, Ok(TestCompletion::Called));
    }

    #[tokio::test]
    async fn test_oneshot_into_future_immediate() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        notifier.send(TestCompletion::Cancelled).unwrap();

        let result = receiver.await;
        assert_eq!(result, Ok(TestCompletion::Cancelled));
    }

    #[tokio::test]
    async fn test_oneshot_into_future_unit_type() {
        let (notifier, receiver) = Sender::<()>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(()).unwrap();
        });

        let result = receiver.await;
        assert_eq!(result, Ok(()));
    }

    #[tokio::test]
    async fn test_oneshot_into_future_custom_state() {
        let (notifier, receiver) = Sender::<CustomState>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(CustomState::Failure).unwrap();
        });

        let result = receiver.await;
        assert_eq!(result, Ok(CustomState::Failure));
    }

    // `&mut Receiver` is awaitable too, leaving the receiver usable afterwards.
    #[tokio::test]
    async fn test_oneshot_await_mut_reference() {
        let (notifier, mut receiver) = Sender::<TestCompletion>::new();

        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            notifier.send(TestCompletion::Called).unwrap();
        });

        let result = (&mut receiver).await;
        assert_eq!(result, Ok(TestCompletion::Called));
    }

    #[tokio::test]
    async fn test_oneshot_await_mut_reference_unit_type() {
        let (notifier, mut receiver) = Sender::<()>::new();

        notifier.send(()).unwrap();

        let result = (&mut receiver).await;
        assert_eq!(result, Ok(()));
    }

    // `try_recv` is synchronous, so these tests need no async runtime.
    #[test]
    fn test_oneshot_try_recv_pending() {
        let (_notifier, mut receiver) = Sender::<TestCompletion>::new();

        let result = receiver.try_recv();
        assert_eq!(result, Err(TryRecvError::Empty));
    }

    #[test]
    fn test_oneshot_try_recv_ready() {
        let (notifier, mut receiver) = Sender::<TestCompletion>::new();

        notifier.send(TestCompletion::Called).unwrap();

        let result = receiver.try_recv();
        assert_eq!(result, Ok(TestCompletion::Called));
    }

    #[test]
    fn test_oneshot_try_recv_sender_dropped() {
        let (notifier, mut receiver) = Sender::<TestCompletion>::new();

        drop(notifier);

        let result = receiver.try_recv();
        assert_eq!(result, Err(TryRecvError::Closed));
    }

    // Dropping the sender without sending resolves the receiver with an error.
    #[tokio::test]
    async fn test_oneshot_sender_dropped_before_recv() {
        let (notifier, receiver) = Sender::<TestCompletion>::new();

        drop(notifier);

        let result = receiver.recv().await;
        assert_eq!(result, Err(RecvError));
    }

    #[tokio::test]
    async fn test_oneshot_sender_dropped_unit_type() {
        let (notifier, receiver) = Sender::<()>::new();

        drop(notifier);

        let result = receiver.recv().await;
        assert_eq!(result, Err(RecvError));
    }

    #[tokio::test]
    async fn test_oneshot_sender_dropped_custom_state() {
        let (notifier, receiver) = Sender::<CustomState>::new();

        drop(notifier);

        let result = receiver.recv().await;
        assert_eq!(result, Err(RecvError));
    }

    // Sender-side view of receiver closure.
    #[test]
    fn test_sender_is_closed_initially_false() {
        let (sender, _receiver) = Sender::<()>::new();
        assert!(!sender.is_closed());
    }

    #[test]
    fn test_sender_is_closed_after_receiver_drop() {
        let (sender, receiver) = Sender::<()>::new();
        drop(receiver);
        assert!(sender.is_closed());
    }

    #[test]
    fn test_sender_is_closed_after_receiver_close() {
        let (sender, mut receiver) = Sender::<()>::new();
        receiver.close();
        assert!(sender.is_closed());
    }

    #[test]
    fn test_receiver_close_prevents_send() {
        let (sender, mut receiver) = Sender::<TestCompletion>::new();
        receiver.close();

        assert!(sender.send(TestCompletion::Called).is_err());
    }

    // Blocking receive from a plain (non-async) thread.
    #[test]
    fn test_blocking_recv_immediate() {
        let (sender, receiver) = Sender::<TestCompletion>::new();

        sender.send(TestCompletion::Called).unwrap();

        let result = receiver.blocking_recv();
        assert_eq!(result, Ok(TestCompletion::Called));
    }

    #[test]
    fn test_blocking_recv_with_thread() {
        let (sender, receiver) = Sender::<()>::new();

        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            sender.send(()).unwrap();
        });

        let result = receiver.blocking_recv();
        assert_eq!(result, Ok(()));
    }

    #[test]
    fn test_blocking_recv_sender_dropped() {
        let (sender, receiver) = Sender::<()>::new();

        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            drop(sender);
        });

        let result = receiver.blocking_recv();
        assert_eq!(result, Err(RecvError));
    }

    // The public `channel()` constructor yields a working, connected pair.
    #[test]
    fn test_channel_fn_roundtrip() {
        let (sender, mut receiver) = channel::<CustomState>();

        assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
        sender.send(CustomState::Failure).unwrap();
        assert_eq!(receiver.try_recv(), Ok(CustomState::Failure));
    }

    #[test]
    fn test_channel_fn_sender_dropped() {
        let (sender, receiver) = channel::<()>();

        drop(sender);

        assert_eq!(receiver.blocking_recv(), Err(RecvError));
    }
}