1use std::{
2 future::Future,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll, Waker},
6};
7
8use parking_lot::{Condvar, Mutex};
9
10use super::Result;
11
/// State shared (behind a mutex) by a `OneShot` and its `OneShotFiller`.
#[derive(Debug)]
struct OneShotState<T> {
    /// True once the filler has either delivered a value or been dropped.
    filled: bool,
    /// True once the future side has returned `Poll::Ready`.
    fused: bool,
    /// The delivered value, if any; taken out by the receiving side.
    item: Option<T>,
    /// Waker stored by the most recent pending poll, if any.
    waker: Option<Waker>,
}

// Written by hand because `#[derive(Default)]` would require `T: Default`,
// which is unnecessary: every field starts empty regardless of `T`.
impl<T> Default for OneShotState<T> {
    fn default() -> Self {
        Self {
            filled: false,
            fused: false,
            item: None,
            waker: None,
        }
    }
}
30
31#[derive(Debug)]
33pub struct OneShot<T> {
34 mu: Arc<Mutex<OneShotState<T>>>,
35 cv: Arc<Condvar>,
36}
37
38pub struct OneShotFiller<T> {
40 mu: Arc<Mutex<OneShotState<T>>>,
41 cv: Arc<Condvar>,
42}
43
44impl<T> OneShot<T> {
45 pub fn pair() -> (OneShotFiller<T>, Self) {
48 let mu = Arc::new(Mutex::new(OneShotState::default()));
49 let cv = Arc::new(Condvar::new());
50 let future = Self {
51 mu: mu.clone(),
52 cv: cv.clone(),
53 };
54 let filler = OneShotFiller { mu, cv };
55
56 (filler, future)
57 }
58
59 pub fn wait(self) -> Option<T> {
62 let mut inner = self.mu.lock();
63 while !inner.filled {
64 self.cv.wait(&mut inner);
65 }
66 inner.item.take()
67 }
68
69 pub fn unwrap(self) -> T {
76 self.wait().unwrap()
77 }
78}
79
80impl<T> Future for OneShot<Result<T>> {
81 type Output = Result<T>;
82
83 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
84 let mut state = self.mu.lock();
85 if state.fused {
86 return Poll::Pending;
87 }
88 if state.filled {
89 state.fused = true;
90 Poll::Ready(state.item.take().unwrap())
91 } else {
92 state.waker = Some(cx.waker().clone());
93 Poll::Pending
94 }
95 }
96}
97
98impl<T> OneShotFiller<T> {
99 pub fn fill(self, inner: T) {
101 let mut state = self.mu.lock();
102
103 if let Some(waker) = state.waker.take() {
104 waker.wake();
105 }
106
107 state.filled = true;
108 state.item = Some(inner);
109
110 self.cv.notify_all();
111 }
112}
113
114impl<T> Drop for OneShotFiller<T> {
115 fn drop(&mut self) {
116 let mut state = self.mu.lock();
117
118 if state.filled {
119 return;
120 }
121
122 if let Some(waker) = state.waker.take() {
123 waker.wake();
124 }
125
126 state.filled = true;
127
128 self.cv.notify_all();
129 }
130}