1use std::{
2 collections::{BTreeSet, VecDeque},
3 future::Future,
4 marker::PhantomData,
5 mem,
6 pin::Pin,
7 sync::{
8 atomic::{AtomicUsize, Ordering},
9 Arc,
10 },
11 task::{Context, Poll, Waker},
12 time::Duration,
13};
14
15use crate::{base_channel::ChannelStorage, data_policy::StorageTryPushOutput, Error, Result};
16use object_id::UniqueId;
17use parking_lot_rt::{Condvar, Mutex};
18use pin_project::{pin_project, pinned_drop};
19
/// Identifier of a `Send`/`Recv` future inside the waiter bookkeeping; values come from
/// `BaseChannelAsync::op_id`.
type ClientId = usize;
21
/// Cheaply clonable handle to one channel; all clones share the same `ChannelInner` via `Arc`.
pub struct BaseChannelAsync<T: Sized, S: ChannelStorage<T>>(pub(crate) Arc<ChannelInner<T, S>>);
24
25impl<T: Sized, S: ChannelStorage<T>> BaseChannelAsync<T, S> {
26 fn id(&self) -> usize {
27 self.0.id.as_usize()
28 }
29}
30
31impl<T: Sized, S: ChannelStorage<T>> Eq for BaseChannelAsync<T, S> {}
32
33impl<T: Sized, S: ChannelStorage<T>> PartialEq for BaseChannelAsync<T, S> {
34 fn eq(&self, other: &Self) -> bool {
35 self.id() == other.id()
36 }
37}
38
39impl<T, S> Clone for BaseChannelAsync<T, S>
40where
41 T: Sized,
42 S: ChannelStorage<T>,
43{
44 fn clone(&self) -> Self {
45 Self(self.0.clone())
46 }
47}
48
/// State shared by every handle of one channel.
pub(crate) struct ChannelInner<T: Sized, S: ChannelStorage<T>> {
    // Channel identity; `BaseChannelAsync::eq` compares it.
    id: UniqueId,
    // Queue, endpoint counters and waiter bookkeeping, all behind one mutex.
    pub(crate) data: Mutex<InnerData<T, S>>,
    // Counter handing out per-operation ids to `Send`/`Recv` futures (see `op_id`).
    next_op_id: AtomicUsize,
    // Condvar that `send_blocking*` waits on; the same `Arc` is stored in `InnerData`
    // so it can be signalled while the data lock is held.
    space_available: Arc<Condvar>,
    // Condvar that `recv_blocking*` waits on; the same `Arc` is stored in `InnerData`.
    data_available: Arc<Condvar>,
}
56
57impl<T: Sized, S: ChannelStorage<T>> BaseChannelAsync<T, S> {
58 pub(crate) fn new(capacity: usize, ordering: bool) -> Self {
59 let pc = InnerData::new(capacity, ordering);
60 let space_available = pc.space_available.clone();
61 let data_available = pc.data_available.clone();
62 Self(
63 ChannelInner {
64 id: <_>::default(),
65 data: Mutex::new(pc),
66 next_op_id: <_>::default(),
67 space_available,
68 data_available,
69 }
70 .into(),
71 )
72 }
73 fn op_id(&self) -> usize {
74 self.0.next_op_id.fetch_add(1, Ordering::SeqCst)
75 }
76}
77
/// Mutex-protected channel state: storage, endpoint counters and waiter bookkeeping.
pub(crate) struct InnerData<T: Sized, S: ChannelStorage<T>> {
    // Value storage; push/get semantics come from the `ChannelStorage` implementation.
    queue: S,
    // Number of live `BaseSenderAsync` handles.
    senders: usize,
    // Number of live `BaseReceiverAsync` handles.
    receivers: usize,
    // FIFO of waiting senders: `Some((waker, id))` for a `Send` future, `None` for a thread
    // blocked in `send_blocking*` (woken through `space_available`).
    pub(crate) send_fut_wakers: VecDeque<Option<(Waker, ClientId)>>,
    // Ids that currently have a `Some` entry in `send_fut_wakers` (prevents duplicates).
    pub(crate) send_fut_waker_ids: BTreeSet<ClientId>,
    // Ids of send futures that were woken but have not yet been polled or dropped.
    pub(crate) send_fut_pending: BTreeSet<ClientId>,
    // FIFO of waiting receivers; `None` marks a thread blocked in `recv_blocking*`.
    pub(crate) recv_fut_wakers: VecDeque<Option<(Waker, ClientId)>>,
    // Ids that currently have a `Some` entry in `recv_fut_wakers`.
    pub(crate) recv_fut_waker_ids: BTreeSet<ClientId>,
    // Ids of receive futures that were woken but have not yet been polled or dropped.
    pub(crate) recv_fut_pending: BTreeSet<ClientId>,
    // Condvar signalled for blocking receivers; same `Arc` as `ChannelInner::data_available`.
    data_available: Arc<Condvar>,
    // Condvar signalled for blocking senders; same `Arc` as `ChannelInner::space_available`.
    space_available: Arc<Condvar>,
    _phatom: PhantomData<T>,
}
92
93impl<T, S> InnerData<T, S>
94where
95 T: Sized,
96 S: ChannelStorage<T>,
97{
98 fn new(capacity: usize, ordering: bool) -> Self {
99 assert!(capacity > 0, "channel capacity MUST be > 0");
100 Self {
101 queue: S::with_capacity_and_ordering(capacity, ordering),
102 senders: 1,
103 receivers: 1,
104 send_fut_wakers: <_>::default(),
105 send_fut_waker_ids: <_>::default(),
106 send_fut_pending: <_>::default(),
107 recv_fut_wakers: <_>::default(),
108 recv_fut_waker_ids: <_>::default(),
109 recv_fut_pending: <_>::default(),
110 data_available: <_>::default(),
111 space_available: <_>::default(),
112 _phatom: PhantomData,
113 }
114 }
115
116 #[inline]
119 fn notify_data_sent(&mut self) {
120 self.wake_next_recv();
121 }
122
123 #[inline]
124 fn wake_next_send(&mut self) {
125 if let Some(w) = self.send_fut_wakers.pop_front() {
126 if let Some((waker, id)) = w {
127 self.send_fut_waker_ids.remove(&id);
128 self.send_fut_pending.insert(id);
129 waker.wake();
130 } else {
131 self.space_available.notify_one();
132 }
133 }
134 }
135 #[inline]
136 fn wake_all_sends(&mut self) {
137 self.send_fut_waker_ids.clear();
138 for (waker, _) in mem::take(&mut self.send_fut_wakers).into_iter().flatten() {
139 waker.wake();
140 }
141 self.space_available.notify_all();
142 }
143
144 #[inline]
145 fn notify_send_fut_drop(&mut self, id: ClientId) {
146 if let Some(pos) = self
147 .send_fut_wakers
148 .iter()
149 .position(|w| w.as_ref().map_or(false, |(_, i)| *i == id))
150 {
151 self.send_fut_wakers.remove(pos);
152 self.send_fut_waker_ids.remove(&id);
153 }
154 if self.send_fut_pending.remove(&id) {
155 self.wake_next_send();
156 }
157 }
158
159 #[inline]
160 fn confirm_send_fut_waked(&mut self, id: ClientId) {
161 self.send_fut_pending.remove(&id);
162 }
163
164 #[inline]
165 fn append_send_fut_waker(&mut self, waker: Waker, id: ClientId) {
166 if !self.send_fut_waker_ids.insert(id) {
167 return;
168 }
169 self.send_fut_wakers.push_back(Some((waker, id)));
170 }
171
172 #[inline]
173 fn append_send_sync_waker(&mut self) {
174 self.send_fut_wakers.push_back(None);
176 }
177
178 #[inline]
181 fn notify_data_received(&mut self) {
182 self.wake_next_send();
183 }
184
185 #[inline]
186 fn wake_next_recv(&mut self) {
187 if let Some(w) = self.recv_fut_wakers.pop_front() {
188 if let Some((waker, id)) = w {
189 self.recv_fut_pending.insert(id);
190 self.recv_fut_waker_ids.remove(&id);
191 waker.wake();
192 } else {
193 self.data_available.notify_one();
194 }
195 }
196 }
197 #[inline]
198 fn wake_all_recvs(&mut self) {
199 for (waker, _) in mem::take(&mut self.recv_fut_wakers).into_iter().flatten() {
200 waker.wake();
201 }
202 self.recv_fut_waker_ids.clear();
203 self.data_available.notify_all();
204 }
205
206 #[inline]
207 fn notify_recv_fut_drop(&mut self, id: ClientId) {
208 if let Some(pos) = self
209 .recv_fut_wakers
210 .iter()
211 .position(|w| w.as_ref().map_or(false, |(_, i)| *i == id))
212 {
213 self.recv_fut_wakers.remove(pos);
214 self.recv_fut_waker_ids.remove(&id);
215 }
216 if self.recv_fut_pending.remove(&id) {
217 self.wake_next_recv();
218 }
219 }
220
221 #[inline]
222 fn confirm_recv_fut_waked(&mut self, id: ClientId) {
223 self.recv_fut_pending.remove(&id);
225 }
226
227 #[inline]
228 fn append_recv_fut_waker(&mut self, waker: Waker, id: ClientId) {
229 if !self.recv_fut_waker_ids.insert(id) {
230 return;
231 }
232 self.recv_fut_wakers.push_back(Some((waker, id)));
233 }
234
235 #[inline]
236 fn append_recv_sync_waker(&mut self) {
237 self.recv_fut_wakers.push_back(None);
239 }
240}
241
/// Future returned by `BaseSenderAsync::send`.
///
/// NOTE(review): this private type shadows `std::marker::Send` inside this module.
#[pin_project(PinnedDrop)]
struct Send<'a, T: Sized, S: ChannelStorage<T>> {
    // Operation id used to register this future in the sender wait queue.
    id: usize,
    // Channel the value is sent to.
    channel: &'a BaseChannelAsync<T, S>,
    // `true` while the future is registered with the channel (waker queued or wake pending);
    // drop uses it to release that registration.
    queued: bool,
    // Value to send; taken for every push attempt and put back when the storage is full.
    value: Option<T>,
}
249
250#[pinned_drop]
251#[allow(clippy::needless_lifetimes)]
252impl<'a, T: Sized, S: ChannelStorage<T>> PinnedDrop for Send<'a, T, S> {
253 fn drop(self: Pin<&mut Self>) {
254 if self.queued {
255 self.channel.0.data.lock().notify_send_fut_drop(self.id);
256 }
257 }
258}
259
260impl<'a, T, S> Future for Send<'a, T, S>
261where
262 T: Sized,
263 S: ChannelStorage<T>,
264{
265 type Output = Result<()>;
266 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267 let mut pc = self.channel.0.data.lock();
268 if self.queued {
269 pc.confirm_send_fut_waked(self.id);
270 }
271 if pc.receivers == 0 {
272 self.queued = false;
273 return Poll::Ready(Err(Error::ChannelClosed));
274 }
275 if pc.send_fut_wakers.is_empty() || self.queued {
276 let push_result = pc.queue.try_push(self.value.take().unwrap());
277 if let StorageTryPushOutput::Full(val) = push_result {
278 self.value = Some(val);
279 } else {
280 self.queued = false;
281 return Poll::Ready(match push_result {
282 StorageTryPushOutput::Pushed => {
283 pc.notify_data_sent();
284 Ok(())
285 }
286 StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
287 StorageTryPushOutput::Full(_) => unreachable!(),
288 });
289 }
290 }
291 self.queued = true;
292 pc.append_send_fut_waker(cx.waker().clone(), self.id);
293 Poll::Pending
294 }
295}
296
297#[derive(Eq, PartialEq)]
299pub struct BaseSenderAsync<T, S>
300where
301 T: Sized,
302 S: ChannelStorage<T>,
303{
304 channel: BaseChannelAsync<T, S>,
305}
306
307impl<T, S> BaseSenderAsync<T, S>
308where
309 T: Sized,
310 S: ChannelStorage<T>,
311{
312 #[inline]
314 pub fn send(&self, value: T) -> impl Future<Output = Result<()>> + '_ {
315 Send {
316 id: self.channel.op_id(),
317 channel: &self.channel,
318 queued: false,
319 value: Some(value),
320 }
321 }
322 pub fn try_send(&self, value: T) -> Result<()> {
324 let mut pc = self.channel.0.data.lock();
325 if pc.receivers == 0 {
326 return Err(Error::ChannelClosed);
327 }
328 match pc.queue.try_push(value) {
329 StorageTryPushOutput::Pushed => {
330 pc.notify_data_sent();
331 Ok(())
332 }
333 StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
334 StorageTryPushOutput::Full(_) => Err(Error::ChannelFull),
335 }
336 }
337 pub fn send_blocking(&self, mut value: T) -> Result<()> {
339 let mut pc = self.channel.0.data.lock();
340 let pushed = loop {
341 if pc.receivers == 0 {
342 return Err(Error::ChannelClosed);
343 }
344 let push_result = pc.queue.try_push(value);
345 let StorageTryPushOutput::Full(val) = push_result else {
346 break push_result;
347 };
348 value = val;
349 pc.append_send_sync_waker();
350 self.channel.0.space_available.wait(&mut pc);
351 };
352 match pushed {
353 StorageTryPushOutput::Pushed => {
354 pc.notify_data_sent();
355 Ok(())
356 }
357 StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
358 StorageTryPushOutput::Full(_) => unreachable!(),
359 }
360 }
361 pub fn send_blocking_timeout(&self, mut value: T, timeout: Duration) -> Result<()> {
363 let mut pc = self.channel.0.data.lock();
364 let pushed = loop {
365 if pc.receivers == 0 {
366 return Err(Error::ChannelClosed);
367 }
368 let push_result = pc.queue.try_push(value);
369 let StorageTryPushOutput::Full(val) = push_result else {
370 break push_result;
371 };
372 value = val;
373 pc.append_send_sync_waker();
374 if self
375 .channel
376 .0
377 .space_available
378 .wait_for(&mut pc, timeout)
379 .timed_out()
380 {
381 return Err(Error::Timeout);
382 }
383 };
384 pc.notify_data_sent();
385 match pushed {
386 StorageTryPushOutput::Pushed => Ok(()),
387 StorageTryPushOutput::Skipped => Err(Error::ChannelSkipped),
388 StorageTryPushOutput::Full(_) => unreachable!(),
389 }
390 }
391 #[inline]
393 pub fn len(&self) -> usize {
394 self.channel.0.data.lock().queue.len()
395 }
396 #[inline]
398 pub fn is_full(&self) -> bool {
399 self.channel.0.data.lock().queue.is_full()
400 }
401 #[inline]
403 pub fn is_empty(&self) -> bool {
404 self.channel.0.data.lock().queue.is_empty()
405 }
406 #[inline]
408 pub fn is_alive(&self) -> bool {
409 self.channel.0.data.lock().receivers > 0
410 }
411}
412
413impl<T, S> Clone for BaseSenderAsync<T, S>
414where
415 T: Sized,
416 S: ChannelStorage<T>,
417{
418 fn clone(&self) -> Self {
419 self.channel.0.data.lock().senders += 1;
420 Self {
421 channel: self.channel.clone(),
422 }
423 }
424}
425
426impl<T, S> Drop for BaseSenderAsync<T, S>
427where
428 T: Sized,
429 S: ChannelStorage<T>,
430{
431 fn drop(&mut self) {
432 let mut pc = self.channel.0.data.lock();
433 pc.senders -= 1;
434 if pc.senders == 0 {
435 pc.wake_all_recvs();
436 }
437 }
438}
439
/// Future returned by `BaseReceiverAsync::recv`.
struct Recv<'a, T: Sized, S: ChannelStorage<T>> {
    // Operation id used to register this future in the receiver wait queue.
    id: usize,
    // Channel values are received from.
    channel: &'a BaseChannelAsync<T, S>,
    // `true` while the future is registered with the channel (waker queued or wake pending).
    queued: bool,
}
445
446impl<'a, T: Sized, S: ChannelStorage<T>> Drop for Recv<'a, T, S> {
447 fn drop(&mut self) {
448 if self.queued {
449 self.channel.0.data.lock().notify_recv_fut_drop(self.id);
450 }
451 }
452}
453
454impl<'a, T, S> Future for Recv<'a, T, S>
455where
456 T: Sized,
457 S: ChannelStorage<T>,
458{
459 type Output = Result<T>;
460 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
461 let mut pc = self.channel.0.data.lock();
462 if self.queued {
463 pc.confirm_recv_fut_waked(self.id);
464 }
465 if pc.recv_fut_wakers.is_empty() || self.queued {
466 if let Some(val) = pc.queue.get() {
467 pc.notify_data_received();
468 self.queued = false;
469 return Poll::Ready(Ok(val));
470 } else if pc.senders == 0 {
471 self.queued = false;
472 return Poll::Ready(Err(Error::ChannelClosed));
473 }
474 }
475 self.queued = true;
476 pc.append_recv_fut_waker(cx.waker().clone(), self.id);
477 Poll::Pending
478 }
479}
480
481#[derive(Eq, PartialEq)]
483pub struct BaseReceiverAsync<T, S>
484where
485 T: Sized,
486 S: ChannelStorage<T>,
487{
488 pub(crate) channel: BaseChannelAsync<T, S>,
489}
490
491impl<T, S> BaseReceiverAsync<T, S>
492where
493 T: Sized,
494 S: ChannelStorage<T>,
495{
496 #[inline]
498 pub fn recv(&self) -> impl Future<Output = Result<T>> + '_ {
499 Recv {
500 id: self.channel.op_id(),
501 channel: &self.channel,
502 queued: false,
503 }
504 }
505 pub fn try_recv(&self) -> Result<T> {
507 let mut pc = self.channel.0.data.lock();
508 if let Some(val) = pc.queue.get() {
509 pc.notify_data_received();
510 Ok(val)
511 } else if pc.senders == 0 {
512 Err(Error::ChannelClosed)
513 } else {
514 Err(Error::ChannelEmpty)
515 }
516 }
517 pub fn recv_blocking(&self) -> Result<T> {
519 let mut pc = self.channel.0.data.lock();
520 loop {
521 if let Some(val) = pc.queue.get() {
522 pc.notify_data_received();
523 return Ok(val);
524 } else if pc.senders == 0 {
525 return Err(Error::ChannelClosed);
526 }
527 pc.append_recv_sync_waker();
528 self.channel.0.data_available.wait(&mut pc);
529 }
530 }
531 pub fn recv_blocking_timeout(&self, timeout: Duration) -> Result<T> {
533 let mut pc = self.channel.0.data.lock();
534 loop {
535 if let Some(val) = pc.queue.get() {
536 pc.notify_data_received();
537 return Ok(val);
538 } else if pc.senders == 0 {
539 return Err(Error::ChannelClosed);
540 }
541 pc.append_recv_sync_waker();
542 if self
543 .channel
544 .0
545 .data_available
546 .wait_for(&mut pc, timeout)
547 .timed_out()
548 {
549 return Err(Error::Timeout);
550 }
551 }
552 }
553 #[inline]
555 pub fn len(&self) -> usize {
556 self.channel.0.data.lock().queue.len()
557 }
558 #[inline]
560 pub fn is_full(&self) -> bool {
561 self.channel.0.data.lock().queue.is_full()
562 }
563 #[inline]
565 pub fn is_empty(&self) -> bool {
566 self.channel.0.data.lock().queue.is_empty()
567 }
568 #[inline]
570 pub fn is_alive(&self) -> bool {
571 self.channel.0.data.lock().senders > 0
572 }
573}
574
575impl<T, S> Clone for BaseReceiverAsync<T, S>
576where
577 T: Sized,
578 S: ChannelStorage<T>,
579{
580 fn clone(&self) -> Self {
581 self.channel.0.data.lock().receivers += 1;
582 Self {
583 channel: self.channel.clone(),
584 }
585 }
586}
587
588impl<T, S> Drop for BaseReceiverAsync<T, S>
589where
590 T: Sized,
591 S: ChannelStorage<T>,
592{
593 fn drop(&mut self) {
594 let mut pc = self.channel.0.data.lock();
595 pc.receivers -= 1;
596 if pc.receivers == 0 {
597 pc.wake_all_sends();
598 }
599 }
600}
601
602pub(crate) fn make_channel<T: Sized, S: ChannelStorage<T>>(
603 ch: BaseChannelAsync<T, S>,
604) -> (BaseSenderAsync<T, S>, BaseReceiverAsync<T, S>) {
605 let tx = BaseSenderAsync {
606 channel: ch.clone(),
607 };
608 let rx = BaseReceiverAsync { channel: ch };
609 (tx, rx)
610}