// rtsc/base_channel_async.rs

use std::{
    collections::{BTreeSet, VecDeque},
    future::Future,
    marker::PhantomData,
    mem,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    task::{Context, Poll, Waker},
    time::{Duration, Instant},
};

use crate::{base_channel::ChannelStorage, data_policy::StorageTryPushOutput, Error, Result};
use object_id::UniqueId;
use parking_lot_rt::{Condvar, Mutex};
use pin_project::{pin_project, pinned_drop};

20type ClientId = usize;
21
22/// Base async channel
23pub struct BaseChannelAsync<T: Sized, S: ChannelStorage<T>>(pub(crate) Arc<ChannelInner<T, S>>);
24
25impl<T: Sized, S: ChannelStorage<T>> BaseChannelAsync<T, S> {
26    fn id(&self) -> usize {
27        self.0.id.as_usize()
28    }
29}
30
31impl<T: Sized, S: ChannelStorage<T>> Eq for BaseChannelAsync<T, S> {}
32
33impl<T: Sized, S: ChannelStorage<T>> PartialEq for BaseChannelAsync<T, S> {
34    fn eq(&self, other: &Self) -> bool {
35        self.id() == other.id()
36    }
37}
38
39impl<T, S> Clone for BaseChannelAsync<T, S>
40where
41    T: Sized,
42    S: ChannelStorage<T>,
43{
44    fn clone(&self) -> Self {
45        Self(self.0.clone())
46    }
47}
48
49pub(crate) struct ChannelInner<T: Sized, S: ChannelStorage<T>> {
50    id: UniqueId,
51    pub(crate) data: Mutex<InnerData<T, S>>,
52    next_op_id: AtomicUsize,
53    space_available: Arc<Condvar>,
54    data_available: Arc<Condvar>,
55}
56
57impl<T: Sized, S: ChannelStorage<T>> BaseChannelAsync<T, S> {
58    pub(crate) fn new(capacity: usize, ordering: bool) -> Self {
59        let pc = InnerData::new(capacity, ordering);
60        let space_available = pc.space_available.clone();
61        let data_available = pc.data_available.clone();
62        Self(
63            ChannelInner {
64                id: <_>::default(),
65                data: Mutex::new(pc),
66                next_op_id: <_>::default(),
67                space_available,
68                data_available,
69            }
70            .into(),
71        )
72    }
73    fn op_id(&self) -> usize {
74        self.0.next_op_id.fetch_add(1, Ordering::SeqCst)
75    }
76}
77
78pub(crate) struct InnerData<T: Sized, S: ChannelStorage<T>> {
79    queue: S,
80    senders: usize,
81    receivers: usize,
82    pub(crate) send_fut_wakers: VecDeque<Option<(Waker, ClientId)>>,
83    pub(crate) send_fut_waker_ids: BTreeSet<ClientId>,
84    pub(crate) send_fut_pending: BTreeSet<ClientId>,
85    pub(crate) recv_fut_wakers: VecDeque<Option<(Waker, ClientId)>>,
86    pub(crate) recv_fut_waker_ids: BTreeSet<ClientId>,
87    pub(crate) recv_fut_pending: BTreeSet<ClientId>,
88    data_available: Arc<Condvar>,
89    space_available: Arc<Condvar>,
90    _phatom: PhantomData<T>,
91}
92
93impl<T, S> InnerData<T, S>
94where
95    T: Sized,
96    S: ChannelStorage<T>,
97{
98    fn new(capacity: usize, ordering: bool) -> Self {
99        assert!(capacity > 0, "channel capacity MUST be > 0");
100        Self {
101            queue: S::with_capacity_and_ordering(capacity, ordering),
102            senders: 1,
103            receivers: 1,
104            send_fut_wakers: <_>::default(),
105            send_fut_waker_ids: <_>::default(),
106            send_fut_pending: <_>::default(),
107            recv_fut_wakers: <_>::default(),
108            recv_fut_waker_ids: <_>::default(),
109            recv_fut_pending: <_>::default(),
110            data_available: <_>::default(),
111            space_available: <_>::default(),
112            _phatom: PhantomData,
113        }
114    }
115
116    // senders
117
118    #[inline]
119    fn notify_data_sent(&mut self) {
120        self.wake_next_recv();
121    }
122
123    #[inline]
124    fn wake_next_send(&mut self) {
125        if let Some(w) = self.send_fut_wakers.pop_front() {
126            if let Some((waker, id)) = w {
127                self.send_fut_waker_ids.remove(&id);
128                self.send_fut_pending.insert(id);
129                waker.wake();
130            } else {
131                self.space_available.notify_one();
132            }
133        }
134    }
135    #[inline]
136    fn wake_all_sends(&mut self) {
137        self.send_fut_waker_ids.clear();
138        for (waker, _) in mem::take(&mut self.send_fut_wakers).into_iter().flatten() {
139            waker.wake();
140        }
141        self.space_available.notify_all();
142    }
143
144    #[inline]
145    fn notify_send_fut_drop(&mut self, id: ClientId) {
146        if let Some(pos) = self
147            .send_fut_wakers
148            .iter()
149            .position(|w| w.as_ref().map_or(false, |(_, i)| *i == id))
150        {
151            self.send_fut_wakers.remove(pos);
152            self.send_fut_waker_ids.remove(&id);
153        }
154        if self.send_fut_pending.remove(&id) {
155            self.wake_next_send();
156        }
157    }
158
159    #[inline]
160    fn confirm_send_fut_waked(&mut self, id: ClientId) {
161        self.send_fut_pending.remove(&id);
162    }
163
164    #[inline]
165    fn append_send_fut_waker(&mut self, waker: Waker, id: ClientId) {
166        if !self.send_fut_waker_ids.insert(id) {
167            return;
168        }
169        self.send_fut_wakers.push_back(Some((waker, id)));
170    }
171
172    #[inline]
173    fn append_send_sync_waker(&mut self) {
174        // use condvar
175        self.send_fut_wakers.push_back(None);
176    }
177
178    // receivers
179
180    #[inline]
181    fn notify_data_received(&mut self) {
182        self.wake_next_send();
183    }
184
185    #[inline]
186    fn wake_next_recv(&mut self) {
187        if let Some(w) = self.recv_fut_wakers.pop_front() {
188            if let Some((waker, id)) = w {
189                self.recv_fut_pending.insert(id);
190                self.recv_fut_waker_ids.remove(&id);
191                waker.wake();
192            } else {
193                self.data_available.notify_one();
194            }
195        }
196    }
197    #[inline]
198    fn wake_all_recvs(&mut self) {
199        for (waker, _) in mem::take(&mut self.recv_fut_wakers).into_iter().flatten() {
200            waker.wake();
201        }
202        self.recv_fut_waker_ids.clear();
203        self.data_available.notify_all();
204    }
205
206    #[inline]
207    fn notify_recv_fut_drop(&mut self, id: ClientId) {
208        if let Some(pos) = self
209            .recv_fut_wakers
210            .iter()
211            .position(|w| w.as_ref().map_or(false, |(_, i)| *i == id))
212        {
213            self.recv_fut_wakers.remove(pos);
214            self.recv_fut_waker_ids.remove(&id);
215        }
216        if self.recv_fut_pending.remove(&id) {
217            self.wake_next_recv();
218        }
219    }
220
221    #[inline]
222    fn confirm_recv_fut_waked(&mut self, id: ClientId) {
223        // the resource is taken, remove from pending
224        self.recv_fut_pending.remove(&id);
225    }
226
227    #[inline]
228    fn append_recv_fut_waker(&mut self, waker: Waker, id: ClientId) {
229        if !self.recv_fut_waker_ids.insert(id) {
230            return;
231        }
232        self.recv_fut_wakers.push_back(Some((waker, id)));
233    }
234
235    #[inline]
236    fn append_recv_sync_waker(&mut self) {
237        // use condvar
238        self.recv_fut_wakers.push_back(None);
239    }
240}
241
242#[pin_project(PinnedDrop)]
243struct Send<'a, T: Sized, S: ChannelStorage<T>> {
244    id: usize,
245    channel: &'a BaseChannelAsync<T, S>,
246    queued: bool,
247    value: Option<T>,
248}
249
250#[pinned_drop]
251#[allow(clippy::needless_lifetimes)]
252impl<'a, T: Sized, S: ChannelStorage<T>> PinnedDrop for Send<'a, T, S> {
253    fn drop(self: Pin<&mut Self>) {
254        if self.queued {
255            self.channel.0.data.lock().notify_send_fut_drop(self.id);
256        }
257    }
258}
259
260impl<'a, T, S> Future for Send<'a, T, S>
261where
262    T: Sized,
263    S: ChannelStorage<T>,
264{
265    type Output = Result<()>;
266    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267        let mut pc = self.channel.0.data.lock();
268        if self.queued {
269            pc.confirm_send_fut_waked(self.id);
270        }
271        if pc.receivers == 0 {
272            self.queued = false;
273            return Poll::Ready(Err(Error::ChannelClosed));
274        }
275        if pc.send_fut_wakers.is_empty() || self.queued {
276            let push_result = pc.queue.try_push(self.value.take().unwrap());
277            if let StorageTryPushOutput::Full(val) = push_result {
278                self.value = Some(val);
279            } else {
280                self.queued = false;
281                return Poll::Ready(match push_result {
282                    StorageTryPushOutput::Pushed => {
283                        pc.notify_data_sent();
284                        Ok(())
285                    }
286                    StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
287                    StorageTryPushOutput::Full(_) => unreachable!(),
288                });
289            }
290        }
291        self.queued = true;
292        pc.append_send_fut_waker(cx.waker().clone(), self.id);
293        Poll::Pending
294    }
295}
296
297/// Base async sender
298#[derive(Eq, PartialEq)]
299pub struct BaseSenderAsync<T, S>
300where
301    T: Sized,
302    S: ChannelStorage<T>,
303{
304    channel: BaseChannelAsync<T, S>,
305}
306
307impl<T, S> BaseSenderAsync<T, S>
308where
309    T: Sized,
310    S: ChannelStorage<T>,
311{
312    /// Sends a value to the channel
313    #[inline]
314    pub fn send(&self, value: T) -> impl Future<Output = Result<()>> + '_ {
315        Send {
316            id: self.channel.op_id(),
317            channel: &self.channel,
318            queued: false,
319            value: Some(value),
320        }
321    }
322    /// Tries to send a value to the channel
323    pub fn try_send(&self, value: T) -> Result<()> {
324        let mut pc = self.channel.0.data.lock();
325        if pc.receivers == 0 {
326            return Err(Error::ChannelClosed);
327        }
328        match pc.queue.try_push(value) {
329            StorageTryPushOutput::Pushed => {
330                pc.notify_data_sent();
331                Ok(())
332            }
333            StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
334            StorageTryPushOutput::Full(_) => Err(Error::ChannelFull),
335        }
336    }
337    /// Sends a value to the channel in a blocking (synchronous) way
338    pub fn send_blocking(&self, mut value: T) -> Result<()> {
339        let mut pc = self.channel.0.data.lock();
340        let pushed = loop {
341            if pc.receivers == 0 {
342                return Err(Error::ChannelClosed);
343            }
344            let push_result = pc.queue.try_push(value);
345            let StorageTryPushOutput::Full(val) = push_result else {
346                break push_result;
347            };
348            value = val;
349            pc.append_send_sync_waker();
350            self.channel.0.space_available.wait(&mut pc);
351        };
352        match pushed {
353            StorageTryPushOutput::Pushed => {
354                pc.notify_data_sent();
355                Ok(())
356            }
357            StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
358            StorageTryPushOutput::Full(_) => unreachable!(),
359        }
360    }
361    /// Sends a value to the channel in a blocking (synchronous) way with a given tiemout
362    pub fn send_blocking_timeout(&self, mut value: T, timeout: Duration) -> Result<()> {
363        let mut pc = self.channel.0.data.lock();
364        let pushed = loop {
365            if pc.receivers == 0 {
366                return Err(Error::ChannelClosed);
367            }
368            let push_result = pc.queue.try_push(value);
369            let StorageTryPushOutput::Full(val) = push_result else {
370                break push_result;
371            };
372            value = val;
373            pc.append_send_sync_waker();
374            if self
375                .channel
376                .0
377                .space_available
378                .wait_for(&mut pc, timeout)
379                .timed_out()
380            {
381                return Err(Error::Timeout);
382            }
383        };
384        pc.notify_data_sent();
385        match pushed {
386            StorageTryPushOutput::Pushed => Ok(()),
387            StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
388            StorageTryPushOutput::Full(_) => unreachable!(),
389        }
390    }
391    /// Returns the number of items in the channel
392    #[inline]
393    pub fn len(&self) -> usize {
394        self.channel.0.data.lock().queue.len()
395    }
396    /// Returns true if the channel is full
397    #[inline]
398    pub fn is_full(&self) -> bool {
399        self.channel.0.data.lock().queue.is_full()
400    }
401    /// Returns true if the channel is empty
402    #[inline]
403    pub fn is_empty(&self) -> bool {
404        self.channel.0.data.lock().queue.is_empty()
405    }
406    /// Returns true if the channel is still alive
407    #[inline]
408    pub fn is_alive(&self) -> bool {
409        self.channel.0.data.lock().receivers > 0
410    }
411}
412
413impl<T, S> Clone for BaseSenderAsync<T, S>
414where
415    T: Sized,
416    S: ChannelStorage<T>,
417{
418    fn clone(&self) -> Self {
419        self.channel.0.data.lock().senders += 1;
420        Self {
421            channel: self.channel.clone(),
422        }
423    }
424}
425
426impl<T, S> Drop for BaseSenderAsync<T, S>
427where
428    T: Sized,
429    S: ChannelStorage<T>,
430{
431    fn drop(&mut self) {
432        let mut pc = self.channel.0.data.lock();
433        pc.senders -= 1;
434        if pc.senders == 0 {
435            pc.wake_all_recvs();
436        }
437    }
438}
439
440struct Recv<'a, T: Sized, S: ChannelStorage<T>> {
441    id: usize,
442    channel: &'a BaseChannelAsync<T, S>,
443    queued: bool,
444}
445
446impl<'a, T: Sized, S: ChannelStorage<T>> Drop for Recv<'a, T, S> {
447    fn drop(&mut self) {
448        if self.queued {
449            self.channel.0.data.lock().notify_recv_fut_drop(self.id);
450        }
451    }
452}
453
454impl<'a, T, S> Future for Recv<'a, T, S>
455where
456    T: Sized,
457    S: ChannelStorage<T>,
458{
459    type Output = Result<T>;
460    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461        let mut pc = self.channel.0.data.lock();
462        if self.queued {
463            pc.confirm_recv_fut_waked(self.id);
464        }
465        if pc.recv_fut_wakers.is_empty() || self.queued {
466            if let Some(val) = pc.queue.get() {
467                pc.notify_data_received();
468                self.queued = false;
469                return Poll::Ready(Ok(val));
470            } else if pc.senders == 0 {
471                self.queued = false;
472                return Poll::Ready(Err(Error::ChannelClosed));
473            }
474        }
475        self.queued = true;
476        pc.append_recv_fut_waker(cx.waker().clone(), self.id);
477        Poll::Pending
478    }
479}
480
481/// Base async receiver
482#[derive(Eq, PartialEq)]
483pub struct BaseReceiverAsync<T, S>
484where
485    T: Sized,
486    S: ChannelStorage<T>,
487{
488    pub(crate) channel: BaseChannelAsync<T, S>,
489}
490
491impl<T, S> BaseReceiverAsync<T, S>
492where
493    T: Sized,
494    S: ChannelStorage<T>,
495{
496    /// Receives a value from the channel
497    #[inline]
498    pub fn recv(&self) -> impl Future<Output = Result<T>> + '_ {
499        Recv {
500            id: self.channel.op_id(),
501            channel: &self.channel,
502            queued: false,
503        }
504    }
505    /// Tries to receive a value from the channel
506    pub fn try_recv(&self) -> Result<T> {
507        let mut pc = self.channel.0.data.lock();
508        if let Some(val) = pc.queue.get() {
509            pc.notify_data_received();
510            Ok(val)
511        } else if pc.senders == 0 {
512            Err(Error::ChannelClosed)
513        } else {
514            Err(Error::ChannelEmpty)
515        }
516    }
517    /// Receives a value from the channel in a blocking (synchronous) way
518    pub fn recv_blocking(&self) -> Result<T> {
519        let mut pc = self.channel.0.data.lock();
520        loop {
521            if let Some(val) = pc.queue.get() {
522                pc.notify_data_received();
523                return Ok(val);
524            } else if pc.senders == 0 {
525                return Err(Error::ChannelClosed);
526            }
527            pc.append_recv_sync_waker();
528            self.channel.0.data_available.wait(&mut pc);
529        }
530    }
531    /// Receives a value from the channel in a blocking (synchronous) way with a given timeout
532    pub fn recv_blocking_timeout(&self, timeout: Duration) -> Result<T> {
533        let mut pc = self.channel.0.data.lock();
534        loop {
535            if let Some(val) = pc.queue.get() {
536                pc.notify_data_received();
537                return Ok(val);
538            } else if pc.senders == 0 {
539                return Err(Error::ChannelClosed);
540            }
541            pc.append_recv_sync_waker();
542            if self
543                .channel
544                .0
545                .data_available
546                .wait_for(&mut pc, timeout)
547                .timed_out()
548            {
549                return Err(Error::Timeout);
550            }
551        }
552    }
553    /// Returns the number of items in the channel
554    #[inline]
555    pub fn len(&self) -> usize {
556        self.channel.0.data.lock().queue.len()
557    }
558    /// Returns true if the channel is full
559    #[inline]
560    pub fn is_full(&self) -> bool {
561        self.channel.0.data.lock().queue.is_full()
562    }
563    /// Returns true if the channel is empty
564    #[inline]
565    pub fn is_empty(&self) -> bool {
566        self.channel.0.data.lock().queue.is_empty()
567    }
568    /// Returns true if the channel is still alive
569    #[inline]
570    pub fn is_alive(&self) -> bool {
571        self.channel.0.data.lock().senders > 0
572    }
573}
574
575impl<T, S> Clone for BaseReceiverAsync<T, S>
576where
577    T: Sized,
578    S: ChannelStorage<T>,
579{
580    fn clone(&self) -> Self {
581        self.channel.0.data.lock().receivers += 1;
582        Self {
583            channel: self.channel.clone(),
584        }
585    }
586}
587
588impl<T, S> Drop for BaseReceiverAsync<T, S>
589where
590    T: Sized,
591    S: ChannelStorage<T>,
592{
593    fn drop(&mut self) {
594        let mut pc = self.channel.0.data.lock();
595        pc.receivers -= 1;
596        if pc.receivers == 0 {
597            pc.wake_all_sends();
598        }
599    }
600}
601
602pub(crate) fn make_channel<T: Sized, S: ChannelStorage<T>>(
603    ch: BaseChannelAsync<T, S>,
604) -> (BaseSenderAsync<T, S>, BaseReceiverAsync<T, S>) {
605    let tx = BaseSenderAsync {
606        channel: ch.clone(),
607    };
608    let rx = BaseReceiverAsync { channel: ch };
609    (tx, rx)
610}