1use std::time;
3use core::{ptr, task, pin};
4use core::cell::{Cell, UnsafeCell};
5use core::mem::MaybeUninit;
6use core::sync::atomic::{Ordering, AtomicU8};
7use core::future::Future;
8
// Bit flags stored in `Payload::state`. Each flag occupies its own bit so that
// several of them can be set at the same time.

/// Initial state: no flag is set.
const UNINIT: u8 = 0;
/// The sender has written the value into the payload.
const READY: u8 = 1 << 0;
/// A notifier is stored in the payload.
const WAKER_SET: u8 = 1 << 1;
/// The sending half has been dropped.
const SEND_CLOSED: u8 = 1 << 2;
/// The value has been moved out of the payload by the receiver.
const CONSUMED: u8 = 1 << 3;
/// The receiving half has been dropped.
const RECV_CLOSED: u8 = 1 << 4;
15
16use super::JoinError;
17
/// Handle used to wake up whoever is waiting on the channel.
enum Notifier {
    /// A blocked thread, woken with `Thread::unpark`.
    Thread(std::thread::Thread),
    /// An asynchronous task, woken with `Waker::wake`.
    Waker(task::Waker),
}
22
23struct Payload<T> {
24 state: AtomicU8,
25 value: UnsafeCell<MaybeUninit<T>>,
26 notifier: Cell<MaybeUninit<Notifier>>
27}
28
29impl<T> Payload<T> {
30 const fn new() -> Self {
31 Self {
32 state: AtomicU8::new(UNINIT),
33 value: UnsafeCell::new(MaybeUninit::uninit()),
34 notifier: Cell::new(MaybeUninit::uninit()),
35 }
36 }
37
38 #[inline(never)]
39 fn set_notifier(&self, notifier: Notifier) -> u8 {
41 self.notifier.set(MaybeUninit::new(notifier));
42 self.state.fetch_or(WAKER_SET, Ordering::AcqRel)
43 }
44
45 #[inline(always)]
46 fn take_notifier(&self) -> Notifier {
47 let storage = self.notifier.replace(MaybeUninit::uninit());
48
49 unsafe {
50 storage.assume_init()
51 }
52 }
53}
54
55impl<T> Drop for Payload<T> {
56 fn drop(&mut self) {
57 let state = self.state.load(Ordering::Relaxed);
58 match (state & READY == READY) && (state & CONSUMED != CONSUMED) {
59 true => unsafe {
60 ptr::drop_in_place((*self.value.get()).as_mut_ptr());
61 },
62 _ => (),
63 }
64
65 if state & WAKER_SET == WAKER_SET {
67 self.take_notifier();
68 }
69 }
70}
71
72#[repr(transparent)]
73pub struct Sender<T> {
77 payload: ptr::NonNull<Payload<T>>,
78}
79
80impl<T> Sender<T> {
81 #[inline(always)]
82 fn payload(&self) -> &Payload<T> {
83 unsafe {
84 &*self.payload.as_ptr()
85 }
86 }
87
88 pub fn send(self, value: T) {
90 unsafe {
92 ptr::write((*self.payload().value.get()).as_mut_ptr(), value);
93 }
94
95 let state = self.payload().state.fetch_or(READY, Ordering::AcqRel);
96 if state & WAKER_SET == WAKER_SET {
97 let notifier = self.payload().take_notifier();
98 self.payload().state.fetch_and(!WAKER_SET, Ordering::Release);
99
100 match notifier {
101 Notifier::Thread(thread) => thread.unpark(),
102 Notifier::Waker(waker) => waker.wake(),
103 }
104 }
105 }
106}
107
108impl<T> Drop for Sender<T> {
109 fn drop(&mut self) {
110 let mut state = self.payload().state.load(Ordering::Acquire);
112 if state & WAKER_SET == WAKER_SET {
113 let notifier = self.payload().take_notifier();
114 state = self.payload().state.fetch_xor(WAKER_SET | SEND_CLOSED, Ordering::AcqRel);
116
117 match notifier {
118 Notifier::Thread(thread) => thread.unpark(),
119 Notifier::Waker(waker) => waker.wake(),
120 }
121 } else {
122 state = self.payload().state.fetch_or(SEND_CLOSED, Ordering::AcqRel);
123 }
124
125 if state & RECV_CLOSED == RECV_CLOSED {
126 unsafe {
127 let _ = Box::from_raw(self.payload.as_ptr());
128 }
129 }
130 }
131}
132
133unsafe impl<T: Send> Send for Sender<T> {}
134unsafe impl<T: Sync> Sync for Sender<T> {}
135
136#[repr(transparent)]
137pub struct Receiver<T> {
141 payload: ptr::NonNull<Payload<T>>,
142}
143
144impl<T> Receiver<T> {
145 #[inline(always)]
146 fn payload(&self) -> &Payload<T> {
147 unsafe {
148 &*self.payload.as_ptr()
149 }
150 }
151
152 fn consume(&self) -> T {
153 self.payload().state.fetch_or(CONSUMED, Ordering::Release);
154 let mut result = MaybeUninit::uninit();
155
156 unsafe {
157 ptr::swap(result.as_mut_ptr(), (*self.payload().value.get()).as_mut_ptr());
158
159 result.assume_init()
160 }
161 }
162
163 #[inline(always)]
164 pub fn is_ready(&self) -> bool {
166 self.payload().state.load(Ordering::Acquire) & READY == READY
167 }
168
169 #[inline(always)]
170 pub fn is_consumed(&self) -> bool {
172 self.payload().state.load(Ordering::Acquire) & CONSUMED == CONSUMED
173 }
174
175 pub fn try_recv(&self) -> Result<Option<T>, JoinError> {
178 let state = self.payload().state.load(Ordering::Acquire);
179
180 if state & CONSUMED == CONSUMED {
181 Err(JoinError::AlreadyConsumed)
182 } else if state & READY == READY {
183 Ok(Some(self.consume()))
184 } else if state & SEND_CLOSED == SEND_CLOSED {
185 Err(JoinError::Disconnect)
186 } else {
187 Ok(None)
188 }
189 }
190
191 pub fn recv(self) -> Result<T, JoinError> {
194 let mut state = self.payload().state.load(Ordering::Acquire);
195
196 if state & CONSUMED == CONSUMED {
197 return Err(JoinError::AlreadyConsumed);
198 } else if state & READY == READY {
199 return Ok(self.consume());
200 } else if state & SEND_CLOSED == SEND_CLOSED {
201 return Err(JoinError::Disconnect);
202 }
203
204 state = self.payload().set_notifier(Notifier::Thread(std::thread::current()));
205
206 while state & READY != READY {
207 if state & SEND_CLOSED == SEND_CLOSED {
209 return Err(JoinError::Disconnect);
210 }
211
212 std::thread::park();
213
214 state = self.payload().state.load(Ordering::Acquire);
215 }
216
217 Ok(self.consume())
218 }
219
220 pub fn recv_timeout(&self, mut time: time::Duration) -> Result<Option<T>, JoinError> {
225 let mut state = self.payload().state.load(Ordering::Acquire);
226
227 if state & CONSUMED == CONSUMED {
228 return Err(JoinError::AlreadyConsumed);
229 } else if state & READY == READY {
230 return Ok(Some(self.consume()));
231 } else if state & SEND_CLOSED == SEND_CLOSED {
232 return Err(JoinError::Disconnect);
233 }
234
235 state = self.payload().set_notifier(Notifier::Thread(std::thread::current()));
236
237 let start_time = time::Instant::now();
238 while state & READY != READY {
239 std::thread::park_timeout(time);
240
241 if let Some(left_over) = time.checked_sub(start_time.elapsed()) {
242 time = left_over;
244 state = self.payload().state.load(Ordering::Acquire);
245 } else {
246 break;
247 }
248 }
249 state = self.payload().state.fetch_and(!WAKER_SET, Ordering::AcqRel);
250
251 if state & WAKER_SET == WAKER_SET {
252 self.payload().take_notifier();
253 }
254
255 if state & READY == READY {
256 Ok(Some(self.consume()))
257 } else {
258 Ok(None)
259 }
260 }
261}
262
263impl<T> Drop for Receiver<T> {
264 #[inline(always)]
265 fn drop(&mut self) {
266 let state = self.payload().state.fetch_or(RECV_CLOSED, Ordering::AcqRel);
268 if state & SEND_CLOSED == SEND_CLOSED {
269 unsafe {
270 let _ = Box::from_raw(self.payload.as_ptr());
271 }
272 }
273 }
274}
275
276impl<T> Future for Receiver<T> {
277 type Output = Result<T, JoinError>;
278
279 fn poll(self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
280 let mut state = self.payload().state.load(Ordering::Acquire);
281
282 if state & CONSUMED == CONSUMED {
283 return task::Poll::Ready(Err(JoinError::AlreadyConsumed));
284 } else if state & READY == READY {
285 return task::Poll::Ready(Ok(self.consume()));
286 } else if state & SEND_CLOSED == SEND_CLOSED {
287 return task::Poll::Ready(Err(JoinError::Disconnect));
288 }
289
290 if state & WAKER_SET == WAKER_SET {
292 state = self.payload().state.load(Ordering::Acquire);
293 } else {
294 state = self.payload().set_notifier(Notifier::Waker(cx.waker().clone()));
295 }
296
297 if state & CONSUMED == CONSUMED {
299 return task::Poll::Ready(Err(JoinError::AlreadyConsumed));
300 } else if state & READY == READY {
301 return task::Poll::Ready(Ok(self.consume()));
302 } else if state & SEND_CLOSED == SEND_CLOSED {
303 return task::Poll::Ready(Err(JoinError::Disconnect));
304 } else {
305 task::Poll::Pending
306 }
307 }
308}
309
310unsafe impl<T: Send> Send for Receiver<T> {}
311impl<T> Unpin for Receiver<T> {}
312
313pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
318 let payload = ptr::NonNull::from(Box::leak(Box::new(Payload::new())));
319
320 let sender = Sender {
321 payload,
322 };
323
324 let receiver = Receiver {
325 payload,
326 };
327
328 (sender, receiver)
329}