lite_sync/oneshot/
common.rs1use std::fmt;
6use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{Context, Poll};
10
11use crate::atomic_waker::AtomicWaker;
12
pub mod error {
    //! Error types produced by the oneshot receiver.

    use std::fmt;

    /// Error returned by a receive operation when the channel is closed
    /// without yielding a value.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct RecvError;

    impl fmt::Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl std::error::Error for RecvError {}

    /// Error returned by a non-blocking receive attempt.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum TryRecvError {
        /// The channel currently holds no value.
        Empty,
        /// The channel is closed.
        Closed,
    }

    impl fmt::Display for TryRecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let message = match self {
                Self::Empty => "channel empty",
                Self::Closed => "channel closed",
            };
            f.write_str(message)
        }
    }

    impl std::error::Error for TryRecvError {}
}
62
63pub use self::error::RecvError;
64pub use self::error::TryRecvError;
65
/// Outcome of a non-blocking attempt to take the value out of oneshot storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TakeResult<T> {
    /// The storage yielded a value.
    Ready(T),
    /// No value is available yet.
    Pending,
    /// The storage reports the channel as closed (receivers map this to `RecvError`).
    Closed,
}

impl<T> TakeResult<T> {
    /// Converts into an `Option`, keeping the value only for `Ready`.
    #[inline]
    pub fn ok(self) -> Option<T> {
        match self {
            TakeResult::Ready(value) => Some(value),
            TakeResult::Pending | TakeResult::Closed => None,
        }
    }

    /// Returns `true` only for the `Closed` variant.
    #[inline]
    pub fn is_closed(&self) -> bool {
        match self {
            TakeResult::Closed => true,
            TakeResult::Ready(_) | TakeResult::Pending => false,
        }
    }
}
95
96pub trait OneshotStorage: Send + Sync + Sized {
100 type Value: Send;
102
103 fn new() -> Self;
105
106 fn store(&self, value: Self::Value);
108
109 fn try_take(&self) -> TakeResult<Self::Value>;
111
112 fn is_sender_dropped(&self) -> bool;
114
115 fn mark_sender_dropped(&self);
117
118 fn is_receiver_closed(&self) -> bool;
122
123 fn mark_receiver_closed(&self);
127}
128
129pub struct Inner<S: OneshotStorage> {
135 pub(crate) waker: AtomicWaker,
136 pub(crate) storage: S,
137}
138
139impl<S: OneshotStorage> Inner<S> {
140 #[inline]
141 pub fn new() -> Arc<Self> {
142 Arc::new(Self {
143 waker: AtomicWaker::new(),
144 storage: S::new(),
145 })
146 }
147
148 #[inline]
149 pub fn send(&self, value: S::Value) {
150 self.storage.store(value);
151 self.waker.wake();
152 }
153
154 #[inline]
155 pub fn try_recv(&self) -> TakeResult<S::Value> {
156 self.storage.try_take()
157 }
158
159 #[inline]
160 pub fn register_waker(&self, waker: &std::task::Waker) {
161 self.waker.register(waker);
162 }
163
164 #[inline]
165 pub fn is_sender_dropped(&self) -> bool {
166 self.storage.is_sender_dropped()
167 }
168}
169
170impl<S: OneshotStorage> fmt::Debug for Inner<S> {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 f.debug_struct("Inner").finish_non_exhaustive()
173 }
174}
175
176pub struct Sender<S: OneshotStorage> {
184 pub(crate) inner: Arc<Inner<S>>,
185}
186
187impl<S: OneshotStorage> fmt::Debug for Sender<S> {
188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
189 f.debug_struct("Sender").finish_non_exhaustive()
190 }
191}
192
193impl<S: OneshotStorage> Sender<S> {
194 #[inline]
196 pub fn new() -> (Self, Receiver<S>) {
197 let inner = Inner::new();
198 let sender = Sender { inner: inner.clone() };
199 let receiver = Receiver { inner };
200 (sender, receiver)
201 }
202
203 #[inline]
207 pub fn send(self, value: S::Value) -> Result<(), S::Value> {
208 if self.is_closed() {
209 return Err(value);
210 }
211 self.send_unchecked(value);
212 Ok(())
213 }
214
215 #[inline]
219 pub fn send_unchecked(self, value: S::Value) {
220 self.inner.send(value);
221 std::mem::forget(self);
222 }
223
224 #[inline]
228 pub fn is_closed(&self) -> bool {
229 self.inner.storage.is_receiver_closed() || Arc::strong_count(&self.inner) == 1
231 }
232}
233
234impl<S: OneshotStorage> Drop for Sender<S> {
235 fn drop(&mut self) {
236 self.inner.storage.mark_sender_dropped();
237 self.inner.waker.wake();
238 }
239}
240
241pub struct Receiver<S: OneshotStorage> {
253 pub(crate) inner: Arc<Inner<S>>,
254}
255
256impl<S: OneshotStorage> fmt::Debug for Receiver<S> {
257 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
258 f.debug_struct("Receiver").finish_non_exhaustive()
259 }
260}
261
262impl<S: OneshotStorage> Unpin for Receiver<S> {}
263
264impl<S: OneshotStorage> Receiver<S> {
265 #[inline]
267 pub async fn wait(self) -> Result<S::Value, RecvError> {
268 self.await
269 }
270
271 #[inline]
282 pub fn close(&mut self) {
283 self.inner.storage.mark_receiver_closed();
284 }
285
286 #[inline]
302 pub fn blocking_recv(self) -> Result<S::Value, RecvError> {
303 use std::sync::atomic::{AtomicBool, Ordering};
304 use std::task::{RawWaker, RawWakerVTable, Waker};
305
306 match self.inner.storage.try_take() {
308 TakeResult::Ready(value) => return Ok(value),
309 TakeResult::Closed => return Err(RecvError),
310 TakeResult::Pending => {}
311 }
312
313 struct ThreadParker {
315 thread: std::thread::Thread,
316 notified: AtomicBool,
317 }
318
319 const VTABLE: RawWakerVTable = RawWakerVTable::new(
320 |ptr| unsafe {
321 Arc::increment_strong_count(ptr as *const ThreadParker);
322 RawWaker::new(ptr, &VTABLE)
323 },
324 |ptr| unsafe {
325 let parker = Arc::from_raw(ptr as *const ThreadParker);
326 parker.notified.store(true, Ordering::Release);
327 parker.thread.unpark();
328 },
329 |ptr| unsafe {
330 let parker = &*(ptr as *const ThreadParker);
331 parker.notified.store(true, Ordering::Release);
332 parker.thread.unpark();
333 },
334 |ptr| unsafe { Arc::decrement_strong_count(ptr as *const ThreadParker); },
335 );
336
337 let parker = Arc::new(ThreadParker {
338 thread: std::thread::current(),
339 notified: AtomicBool::new(false),
340 });
341
342 let raw_waker = RawWaker::new(Arc::into_raw(parker.clone()) as *const (), &VTABLE);
343 let waker = unsafe { Waker::from_raw(raw_waker) };
344
345 self.inner.register_waker(&waker);
347
348 loop {
349 match self.inner.storage.try_take() {
350 TakeResult::Ready(value) => return Ok(value),
351 TakeResult::Closed => return Err(RecvError),
352 TakeResult::Pending => {}
353 }
354
355 if Arc::strong_count(&self.inner) == 1 && self.inner.is_sender_dropped() {
357 return Err(RecvError);
358 }
359
360 if !parker.notified.swap(false, Ordering::Acquire) {
362 std::thread::park();
363 }
364 }
365 }
366}
367
368impl<S: OneshotStorage> Future for Receiver<S> {
369 type Output = Result<S::Value, RecvError>;
370
371 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
372 let this = self.get_mut();
373
374 match this.inner.try_recv() {
376 TakeResult::Ready(value) => return Poll::Ready(Ok(value)),
377 TakeResult::Closed => return Poll::Ready(Err(RecvError)),
378 TakeResult::Pending => {}
379 }
380
381 this.inner.register_waker(cx.waker());
383
384 match this.inner.try_recv() {
386 TakeResult::Ready(value) => return Poll::Ready(Ok(value)),
387 TakeResult::Closed => return Poll::Ready(Err(RecvError)),
388 TakeResult::Pending => {}
389 }
390
391 if Arc::strong_count(&this.inner) == 1 && this.inner.is_sender_dropped() {
393 return Poll::Ready(Err(RecvError));
394 }
395
396 Poll::Pending
397 }
398}
399
400#[inline]
402pub fn channel<S: OneshotStorage>() -> (Sender<S>, Receiver<S>) {
403 Sender::new()
404}