1use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
44use std::time::Duration;
45
46use crossbeam_utils::Backoff;
47
48use crate::queue::mpsc as queue;
49
/// Park interval used by `Receiver::recv` when the caller supplies no timeout.
///
/// The receiver re-polls the queue and the sender count after each park, so this
/// bounds how long it sleeps between polls when no unpark arrives.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_millis(100);
52
53pub fn channel(capacity: usize) -> (Sender, Receiver) {
61 let (producer, consumer) = queue::new(capacity);
62
63 let parker = crossbeam_utils::sync::Parker::new();
64 let unparker = parker.unparker().clone();
65
66 let shared = Arc::new(ChannelShared {
67 receiver_waiting: AtomicBool::new(false),
68 receiver_unparker: unparker,
69 sender_count: AtomicUsize::new(1),
70 receiver_disconnected: AtomicBool::new(false),
71 });
72
73 (
74 Sender {
75 inner: producer,
76 shared: Arc::clone(&shared),
77 },
78 Receiver {
79 inner: consumer,
80 parker,
81 shared,
82 },
83 )
84}
85
/// State shared (behind an `Arc`) by every `Sender` clone and the `Receiver`.
struct ChannelShared {
    /// Set by the receiver just before it parks and cleared when it wakes;
    /// `Sender::notify` only unparks when this reads `true`.
    receiver_waiting: AtomicBool,
    /// Handle paired with the receiver's parker, used by senders to wake it.
    receiver_unparker: crossbeam_utils::sync::Unparker,
    /// Number of live senders: starts at 1, incremented by `Sender::clone`,
    /// decremented by `Sender::drop`.
    sender_count: AtomicUsize,
    /// Set to `true` when the receiver is dropped.
    receiver_disconnected: AtomicBool,
}
97
/// Sending half of the channel.
///
/// Cloning produces another sender that shares the same channel state.
pub struct Sender {
    /// This sender's handle to the underlying queue.
    inner: queue::Producer,
    /// Flags and counters shared with the other senders and the receiver.
    shared: Arc<ChannelShared>,
}
112
113impl Clone for Sender {
114 fn clone(&self) -> Self {
115 self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
116 Self {
117 inner: self.inner.clone(),
118 shared: Arc::clone(&self.shared),
119 }
120 }
121}
122
/// Error returned by `Sender::send` when the receiver has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelClosed;

impl std::fmt::Display for ChannelClosed {
    /// Writes the fixed message `channel disconnected`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "channel disconnected")
    }
}

impl std::error::Error for ChannelClosed {}
138
/// Reasons `Sender::try_send` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue could not provide the requested space right now.
    Full,
    /// The receiver has been dropped.
    Disconnected,
}

impl std::fmt::Display for TrySendError {
    /// Writes a short, fixed description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Full => "channel full",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TrySendError {}
158
159impl Sender {
160 #[inline]
177 pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, ChannelClosed> {
178 assert!(len > 0, "payload length must be non-zero");
182 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
183 return Err(ChannelClosed);
184 }
185
186 let backoff = Backoff::new();
187
188 loop {
189 unsafe {
194 let inner_ptr: *mut queue::Producer = &raw mut self.inner;
195 if let Ok(claim) = (*inner_ptr).try_claim(len) {
196 return Ok(std::mem::transmute::<
197 queue::WriteClaim<'_>,
198 queue::WriteClaim<'_>,
199 >(claim));
200 }
201 backoff.snooze();
203 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
204 return Err(ChannelClosed);
205 }
206 if backoff.is_completed() {
208 backoff.reset();
209 }
210 }
211 }
212 }
213
214 #[inline]
225 pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
226 assert!(len > 0, "payload length must be non-zero");
228 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
229 return Err(TrySendError::Disconnected);
230 }
231
232 match self.inner.try_claim(len) {
233 Ok(claim) => Ok(claim),
234 Err(crate::BufferFull) => Err(TrySendError::Full),
235 }
236 }
237
238 #[inline]
243 pub fn notify(&self) {
244 if self.shared.receiver_waiting.load(Ordering::Relaxed) {
245 self.shared.receiver_unparker.unpark();
246 }
247 }
248
249 #[inline]
251 pub fn capacity(&self) -> usize {
252 self.inner.capacity()
253 }
254
255 #[inline]
257 pub fn is_disconnected(&self) -> bool {
258 self.shared.receiver_disconnected.load(Ordering::Relaxed)
259 }
260}
261
262impl Drop for Sender {
263 fn drop(&mut self) {
264 let prev = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
265 if prev == 1 {
266 self.shared.receiver_unparker.unpark();
268 }
269 }
270}
271
272impl std::fmt::Debug for Sender {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 f.debug_struct("Sender")
275 .field("capacity", &self.capacity())
276 .finish_non_exhaustive()
277 }
278}
279
/// Receiving half of the channel.
pub struct Receiver {
    /// Consumer handle to the underlying queue.
    inner: queue::Consumer,
    /// Parker this receiver blocks on while waiting; senders hold the matching
    /// unparker through the shared state.
    parker: crossbeam_utils::sync::Parker,
    /// Flags and counters shared with the senders.
    shared: Arc<ChannelShared>,
}
293
/// Reasons `Receiver::recv` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// A timeout was given and no record could be claimed within it.
    Timeout,
    /// No live sender remains.
    Disconnected,
}

impl std::fmt::Display for RecvError {
    /// Writes a short, fixed description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Timeout => "receive timed out",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RecvError {}
315
316impl Receiver {
317 #[inline]
330 pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
331 if timeout == Some(Duration::ZERO) {
333 unsafe {
338 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
339 if let Some(claim) = (*inner_ptr).try_claim() {
340 return Ok(std::mem::transmute::<
341 queue::ReadClaim<'_>,
342 queue::ReadClaim<'_>,
343 >(claim));
344 }
345 }
346 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
347 return Err(RecvError::Disconnected);
348 }
349 return Err(RecvError::Timeout);
350 }
351
352 let park_timeout = timeout.unwrap_or(DEFAULT_PARK_TIMEOUT);
353 let backoff = Backoff::new();
354
355 loop {
356 unsafe {
361 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
362 if let Some(claim) = (*inner_ptr).try_claim() {
363 return Ok(std::mem::transmute::<
364 queue::ReadClaim<'_>,
365 queue::ReadClaim<'_>,
366 >(claim));
367 }
368 }
369
370 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
371 return Err(RecvError::Disconnected);
372 }
373
374 if !backoff.is_completed() {
376 backoff.snooze();
377 continue;
378 }
379
380 self.shared.receiver_waiting.store(true, Ordering::Relaxed);
382 self.parker.park_timeout(park_timeout);
383 self.shared.receiver_waiting.store(false, Ordering::Relaxed);
384
385 if timeout.is_some() {
388 unsafe {
391 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
392 if let Some(claim) = (*inner_ptr).try_claim() {
393 return Ok(std::mem::transmute::<
394 queue::ReadClaim<'_>,
395 queue::ReadClaim<'_>,
396 >(claim));
397 }
398 }
399
400 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
401 return Err(RecvError::Disconnected);
402 }
403
404 return Err(RecvError::Timeout);
405 }
406
407 backoff.reset();
409 }
410 }
411
412 #[inline]
416 pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
417 self.inner.try_claim()
418 }
419
420 #[inline]
422 pub fn capacity(&self) -> usize {
423 self.inner.capacity()
424 }
425
426 #[inline]
428 pub fn is_disconnected(&self) -> bool {
429 self.shared.sender_count.load(Ordering::Relaxed) == 0
430 }
431}
432
433impl Drop for Receiver {
434 fn drop(&mut self) {
435 self.shared
436 .receiver_disconnected
437 .store(true, Ordering::Relaxed);
438 }
439}
440
441impl std::fmt::Debug for Receiver {
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 f.debug_struct("Receiver")
444 .field("capacity", &self.capacity())
445 .finish_non_exhaustive()
446 }
447}
448
// Unit tests for the channel wrapper: single-threaded round trip, disconnect
// handling in both directions, timeout behavior, and multi-sender delivery.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // One record sent and received on the same thread.
    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);

        let payload = b"hello world";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();

        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    // Compile-level check that `Sender` implements `Clone`.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn sender_is_clone() {
        let (tx, _rx) = channel(1024);
        let _tx2 = tx.clone();
    }

    // Four sender threads, 100 records each; the receiver counts them all.
    #[test]
    fn multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES: usize = 100;

        let (tx, mut rx) = channel(4096);

        let handles: Vec<_> = (0..SENDERS)
            .map(|id| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES {
                        let payload = format!("{}:{}", id, i);
                        let mut claim = tx.send(payload.len()).unwrap();
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        tx.notify();
                    }
                })
            })
            .collect();

        // Drop the original sender so only the worker clones keep the channel open.
        drop(tx);

        let mut count = 0;
        while let Ok(_record) = rx.recv(None) {
            count += 1;
            if count == SENDERS * MESSAGES {
                break;
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(count, SENDERS * MESSAGES);
    }

    // With no sender left and nothing queued, `recv` reports `Disconnected`.
    #[test]
    fn disconnection_all_senders_dropped() {
        let (tx, mut rx) = channel(1024);

        drop(tx);

        match rx.recv(None) {
            Err(RecvError::Disconnected) => {}
            _ => panic!("expected Disconnected"),
        }
    }

    // With the receiver gone, `send` reports `ChannelClosed`.
    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);

        drop(rx);

        match tx.send(8) {
            Err(ChannelClosed) => {}
            _ => panic!("expected ChannelClosed"),
        }
    }

    // A 50 ms timed receive on an idle channel returns `Timeout` after roughly
    // that long; `_tx` is kept alive so the result is not `Disconnected`.
    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);

        let start = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(RecvError::Timeout)));
        // Loose bounds to tolerate scheduler jitter.
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    // `send(0)` must hit the non-zero length assertion.
    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.send(0);
    }

    // `try_send(0)` must hit the non-zero length assertion.
    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn try_send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.try_send(0);
    }

    // Stress test: four senders push 10_000 sequence-numbered records each while a
    // consumer thread checks per-sender ordering and the final per-sender counts.
    #[test]
    fn stress_multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES_PER_SENDER: u64 = 10_000;
        const TOTAL: u64 = SENDERS as u64 * MESSAGES_PER_SENDER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (tx, mut rx) = channel(BUFFER_SIZE);

        let handles: Vec<_> = (0..SENDERS)
            .map(|sender_id| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_SENDER {
                        // Record layout: sender id (u64 LE) followed by sequence number (u64 LE).
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(sender_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        // Scope the claim so it is gone before `notify` borrows `tx`.
                        {
                            let mut claim = tx.send(16).unwrap();
                            claim.copy_from_slice(&payload);
                            claim.commit();
                        }
                        tx.notify();
                    }
                })
            })
            .collect();

        // Only the worker clones keep the channel open from here on.
        drop(tx);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            // Next expected sequence number for each sender.
            let mut per_sender = vec![0u64; SENDERS];

            while received < TOTAL {
                match rx.recv(None) {
                    Ok(record) => {
                        let sender_id =
                            u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                        let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                        assert_eq!(
                            seq, per_sender[sender_id],
                            "sender {} out of order at {}",
                            sender_id, received
                        );
                        per_sender[sender_id] += 1;
                        received += 1;
                    }
                    // `recv(None)` has no deadline, so a timeout is not expected here.
                    Err(RecvError::Timeout) => unreachable!(),
                    Err(RecvError::Disconnected) => break,
                }
            }

            per_sender
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_sender = consumer.join().unwrap();
        for (i, &count) in per_sender.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_SENDER, "sender {} count", i);
        }
    }
}