1use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
44use std::time::Duration;
45
46use crossbeam_utils::Backoff;
47
48use crate::queue::mpsc as queue;
49
/// Longest single park used by `Receiver::recv` when the caller passes no timeout.
///
/// `recv(None)` substitutes this value for the missing timeout, so a receiver that waits
/// indefinitely still wakes up and re-polls the queue at this interval.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_millis(100);
52
53pub fn channel(capacity: usize) -> (Sender, Receiver) {
61 let (producer, consumer) = queue::new(capacity);
62
63 let parker = crossbeam_utils::sync::Parker::new();
64 let unparker = parker.unparker().clone();
65
66 let shared = Arc::new(ChannelShared {
67 receiver_waiting: AtomicBool::new(false),
68 receiver_unparker: unparker,
69 sender_count: AtomicUsize::new(1),
70 receiver_disconnected: AtomicBool::new(false),
71 });
72
73 (
74 Sender {
75 inner: producer,
76 shared: Arc::clone(&shared),
77 },
78 Receiver {
79 inner: consumer,
80 parker,
81 shared,
82 },
83 )
84}
85
86struct ChannelShared {
88 receiver_waiting: AtomicBool,
90 receiver_unparker: crossbeam_utils::sync::Unparker,
92 sender_count: AtomicUsize,
94 receiver_disconnected: AtomicBool,
96}
97
98pub struct Sender {
109 inner: queue::Producer,
110 shared: Arc<ChannelShared>,
111}
112
113impl Clone for Sender {
114 fn clone(&self) -> Self {
115 self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
116 Self {
117 inner: self.inner.clone(),
118 shared: Arc::clone(&self.shared),
119 }
120 }
121}
122
/// Reasons a blocking `Sender::send` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The receiver has been dropped.
    Disconnected,
    /// The requested payload length was zero.
    ZeroLength,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Disconnected => "channel disconnected",
            Self::ZeroLength => "payload length must be non-zero",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SendError {}
142
/// Reasons a non-blocking `Sender::try_send` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue reported that it has no room for the requested length right now.
    Full,
    /// The receiver has been dropped.
    Disconnected,
    /// The requested payload length was zero.
    ZeroLength,
}

impl std::fmt::Display for TrySendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Full => "channel full",
            Self::Disconnected => "channel disconnected",
            Self::ZeroLength => "payload length must be non-zero",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TrySendError {}
165
166impl Sender {
167 #[inline]
181 pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, SendError> {
182 if len == 0 {
184 return Err(SendError::ZeroLength);
185 }
186 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
187 return Err(SendError::Disconnected);
188 }
189
190 let backoff = Backoff::new();
191
192 loop {
193 unsafe {
198 let inner_ptr: *mut queue::Producer = &raw mut self.inner;
199 match (*inner_ptr).try_claim(len) {
200 Ok(claim) => {
201 return Ok(std::mem::transmute::<
202 queue::WriteClaim<'_>,
203 queue::WriteClaim<'_>,
204 >(claim));
205 }
206 Err(crate::TryClaimError::Full) => {
207 backoff.snooze();
208 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
209 return Err(SendError::Disconnected);
210 }
211 if backoff.is_completed() {
213 backoff.reset();
214 }
215 }
216 Err(crate::TryClaimError::ZeroLength) => return Err(SendError::ZeroLength),
217 }
218 }
219 }
220 }
221
222 #[inline]
230 pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
231 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
232 return Err(TrySendError::Disconnected);
233 }
234
235 match self.inner.try_claim(len) {
236 Ok(claim) => Ok(claim),
237 Err(crate::TryClaimError::Full) => Err(TrySendError::Full),
238 Err(crate::TryClaimError::ZeroLength) => Err(TrySendError::ZeroLength),
239 }
240 }
241
242 #[inline]
247 pub fn notify(&self) {
248 if self.shared.receiver_waiting.load(Ordering::Relaxed) {
249 self.shared.receiver_unparker.unpark();
250 }
251 }
252
253 #[inline]
255 pub fn capacity(&self) -> usize {
256 self.inner.capacity()
257 }
258
259 #[inline]
261 pub fn is_disconnected(&self) -> bool {
262 self.shared.receiver_disconnected.load(Ordering::Relaxed)
263 }
264}
265
266impl Drop for Sender {
267 fn drop(&mut self) {
268 let prev = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
269 if prev == 1 {
270 self.shared.receiver_unparker.unpark();
272 }
273 }
274}
275
276impl std::fmt::Debug for Sender {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 f.debug_struct("Sender")
279 .field("capacity", &self.capacity())
280 .finish_non_exhaustive()
281 }
282}
283
284pub struct Receiver {
293 inner: queue::Consumer,
294 parker: crossbeam_utils::sync::Parker,
295 shared: Arc<ChannelShared>,
296}
297
/// Reasons `Receiver::recv` can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// No record became available before the wait ended.
    Timeout,
    /// Every sender has been dropped.
    Disconnected,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Timeout => "receive timed out",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RecvError {}
319
320impl Receiver {
321 #[inline]
334 pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
335 if timeout == Some(Duration::ZERO) {
337 unsafe {
342 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
343 if let Some(claim) = (*inner_ptr).try_claim() {
344 return Ok(std::mem::transmute::<
345 queue::ReadClaim<'_>,
346 queue::ReadClaim<'_>,
347 >(claim));
348 }
349 }
350 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
351 return Err(RecvError::Disconnected);
352 }
353 return Err(RecvError::Timeout);
354 }
355
356 let park_timeout = timeout.unwrap_or(DEFAULT_PARK_TIMEOUT);
357 let backoff = Backoff::new();
358
359 loop {
360 unsafe {
365 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
366 if let Some(claim) = (*inner_ptr).try_claim() {
367 return Ok(std::mem::transmute::<
368 queue::ReadClaim<'_>,
369 queue::ReadClaim<'_>,
370 >(claim));
371 }
372 }
373
374 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
375 return Err(RecvError::Disconnected);
376 }
377
378 if !backoff.is_completed() {
380 backoff.snooze();
381 continue;
382 }
383
384 self.shared.receiver_waiting.store(true, Ordering::Relaxed);
386 self.parker.park_timeout(park_timeout);
387 self.shared.receiver_waiting.store(false, Ordering::Relaxed);
388
389 if timeout.is_some() {
392 unsafe {
395 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
396 if let Some(claim) = (*inner_ptr).try_claim() {
397 return Ok(std::mem::transmute::<
398 queue::ReadClaim<'_>,
399 queue::ReadClaim<'_>,
400 >(claim));
401 }
402 }
403
404 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
405 return Err(RecvError::Disconnected);
406 }
407
408 return Err(RecvError::Timeout);
409 }
410
411 backoff.reset();
413 }
414 }
415
416 #[inline]
420 pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
421 self.inner.try_claim()
422 }
423
424 #[inline]
426 pub fn capacity(&self) -> usize {
427 self.inner.capacity()
428 }
429
430 #[inline]
432 pub fn is_disconnected(&self) -> bool {
433 self.shared.sender_count.load(Ordering::Relaxed) == 0
434 }
435}
436
437impl Drop for Receiver {
438 fn drop(&mut self) {
439 self.shared
440 .receiver_disconnected
441 .store(true, Ordering::Relaxed);
442 }
443}
444
445impl std::fmt::Debug for Receiver {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 f.debug_struct("Receiver")
448 .field("capacity", &self.capacity())
449 .finish_non_exhaustive()
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Claims room for `bytes`, copies them in, commits the claim and wakes the receiver.
    fn publish(tx: &mut Sender, bytes: &[u8]) {
        let mut claim = tx.send(bytes.len()).unwrap();
        claim.copy_from_slice(bytes);
        claim.commit();
        tx.notify();
    }

    /// A single record travels from sender to receiver unchanged.
    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);

        let payload = b"hello world";
        publish(&mut tx, payload);

        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    /// `Sender` implements `Clone`.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn sender_is_clone() {
        let (tx, _rx) = channel(1024);
        let _tx2 = tx.clone();
    }

    /// Several cloned senders on separate threads deliver every record.
    #[test]
    fn multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES: usize = 100;
        const EXPECTED: usize = SENDERS * MESSAGES;

        let (tx, mut rx) = channel(4096);

        let mut workers = Vec::with_capacity(SENDERS);
        for id in 0..SENDERS {
            let mut tx = tx.clone();
            workers.push(thread::spawn(move || {
                for i in 0..MESSAGES {
                    let text = format!("{}:{}", id, i);
                    publish(&mut tx, text.as_bytes());
                }
            }));
        }

        // Only the worker clones keep the channel open from here on.
        drop(tx);

        let mut count = 0;
        loop {
            match rx.recv(None) {
                Ok(_record) => {
                    count += 1;
                    if count == EXPECTED {
                        break;
                    }
                }
                Err(_) => break,
            }
        }

        for worker in workers {
            worker.join().unwrap();
        }

        assert_eq!(count, EXPECTED);
    }

    /// With no sender left, `recv` reports `Disconnected`.
    #[test]
    fn disconnection_all_senders_dropped() {
        let (tx, mut rx) = channel(1024);

        drop(tx);

        let outcome = rx.recv(None);
        if !matches!(outcome, Err(RecvError::Disconnected)) {
            panic!("expected Disconnected");
        }
    }

    /// With the receiver gone, `send` reports `Disconnected`.
    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);

        drop(rx);

        let outcome = tx.send(8);
        if !matches!(outcome, Err(SendError::Disconnected)) {
            panic!("expected Disconnected");
        }
    }

    /// A timed `recv` on an idle channel returns `Timeout` after roughly the requested time.
    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);

        let started = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = started.elapsed();

        assert!(matches!(result, Err(RecvError::Timeout)));
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    /// Zero-length payloads are rejected by both send flavours.
    #[test]
    fn zero_len_error() {
        let (mut tx, _rx) = channel(1024);
        assert!(matches!(tx.send(0), Err(SendError::ZeroLength)));
        assert!(matches!(tx.try_send(0), Err(TrySendError::ZeroLength)));
    }

    /// Many records from several senders arrive complete and in per-sender order.
    #[test]
    fn stress_multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES_PER_SENDER: u64 = 10_000;
        const TOTAL: u64 = SENDERS as u64 * MESSAGES_PER_SENDER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (tx, mut rx) = channel(BUFFER_SIZE);

        let mut producers = Vec::with_capacity(SENDERS);
        for sender_id in 0..SENDERS {
            let mut tx = tx.clone();
            producers.push(thread::spawn(move || {
                for i in 0..MESSAGES_PER_SENDER {
                    // Record layout: sender id (u64 LE) followed by sequence number (u64 LE).
                    let mut payload = [0u8; 16];
                    payload[..8].copy_from_slice(&(sender_id as u64).to_le_bytes());
                    payload[8..].copy_from_slice(&i.to_le_bytes());

                    publish(&mut tx, &payload);
                }
            }));
        }

        drop(tx);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            let mut per_sender = vec![0u64; SENDERS];

            while received < TOTAL {
                let record = match rx.recv(None) {
                    Ok(record) => record,
                    Err(RecvError::Timeout) => unreachable!(),
                    Err(RecvError::Disconnected) => break,
                };

                let sender_id = u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                assert_eq!(
                    seq, per_sender[sender_id],
                    "sender {} out of order at {}",
                    sender_id, received
                );
                per_sender[sender_id] += 1;
                received += 1;
            }

            per_sender
        });

        for producer in producers {
            producer.join().unwrap();
        }

        let per_sender = consumer.join().unwrap();
        for (i, &count) in per_sender.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_SENDER, "sender {} count", i);
        }
    }
}