1use std::sync::Arc;
43use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
44use std::time::Duration;
45
46use crossbeam_utils::Backoff;
47
48use crate::queue::mpsc as queue;
49
/// Longest single park used by `Receiver::recv` when it is called with no
/// timeout (`None`); after each park the receiver wakes and polls the queue
/// again.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_millis(100);
52
53pub fn channel(capacity: usize) -> (Sender, Receiver) {
61 let (producer, consumer) = queue::new(capacity);
62
63 let parker = crossbeam_utils::sync::Parker::new();
64 let unparker = parker.unparker().clone();
65
66 let shared = Arc::new(ChannelShared {
67 receiver_waiting: AtomicBool::new(false),
68 receiver_unparker: unparker,
69 sender_count: AtomicUsize::new(1),
70 receiver_disconnected: AtomicBool::new(false),
71 });
72
73 (
74 Sender {
75 inner: producer,
76 shared: Arc::clone(&shared),
77 },
78 Receiver {
79 inner: consumer,
80 parker,
81 shared,
82 },
83 )
84}
85
86struct ChannelShared {
88 receiver_waiting: AtomicBool,
90 receiver_unparker: crossbeam_utils::sync::Unparker,
92 sender_count: AtomicUsize,
94 receiver_disconnected: AtomicBool,
96}
97
98pub struct Sender {
109 inner: queue::Producer,
110 shared: Arc<ChannelShared>,
111}
112
113impl Clone for Sender {
114 fn clone(&self) -> Self {
115 self.shared.sender_count.fetch_add(1, Ordering::Relaxed);
116 Self {
117 inner: self.inner.clone(),
118 shared: Arc::clone(&self.shared),
119 }
120 }
121}
122
/// Error returned by `Sender::send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The `Receiver` has been dropped.
    Disconnected,
    /// A claim of zero bytes was requested.
    ZeroLength,
}

impl std::fmt::Display for SendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            SendError::Disconnected => "channel disconnected",
            SendError::ZeroLength => "payload length must be non-zero",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SendError {}
142
/// Error returned by `Sender::try_send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue had no room for the requested claim at the time of the call.
    Full,
    /// The `Receiver` has been dropped.
    Disconnected,
    /// A claim of zero bytes was requested.
    ZeroLength,
}

impl std::fmt::Display for TrySendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            TrySendError::Full => "channel full",
            TrySendError::Disconnected => "channel disconnected",
            TrySendError::ZeroLength => "payload length must be non-zero",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TrySendError {}
165
166impl Sender {
167 #[inline]
181 pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, SendError> {
182 if len == 0 {
184 return Err(SendError::ZeroLength);
185 }
186 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
187 return Err(SendError::Disconnected);
188 }
189
190 let backoff = Backoff::new();
191
192 loop {
193 unsafe {
198 let inner_ptr: *mut queue::Producer = &raw mut self.inner;
199 match (*inner_ptr).try_claim(len) {
200 Ok(claim) => {
201 return Ok(std::mem::transmute::<
202 queue::WriteClaim<'_>,
203 queue::WriteClaim<'_>,
204 >(claim));
205 }
206 Err(crate::TryClaimError::Full) => {
207 backoff.snooze();
208 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
209 return Err(SendError::Disconnected);
210 }
211 if backoff.is_completed() {
213 backoff.reset();
214 }
215 }
216 Err(crate::TryClaimError::ZeroLength) => return Err(SendError::ZeroLength),
217 }
218 }
219 }
220 }
221
222 #[inline]
230 pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
231 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
232 return Err(TrySendError::Disconnected);
233 }
234
235 match self.inner.try_claim(len) {
236 Ok(claim) => Ok(claim),
237 Err(crate::TryClaimError::Full) => Err(TrySendError::Full),
238 Err(crate::TryClaimError::ZeroLength) => Err(TrySendError::ZeroLength),
239 }
240 }
241
242 #[inline]
247 pub fn notify(&self) {
248 if self.shared.receiver_waiting.load(Ordering::Relaxed) {
249 self.shared.receiver_unparker.unpark();
250 }
251 }
252
253 #[inline]
255 pub fn capacity(&self) -> usize {
256 self.inner.capacity()
257 }
258
259 #[inline]
261 pub fn is_disconnected(&self) -> bool {
262 self.shared.receiver_disconnected.load(Ordering::Relaxed)
263 }
264}
265
266impl Drop for Sender {
267 fn drop(&mut self) {
268 let prev = self.shared.sender_count.fetch_sub(1, Ordering::Relaxed);
269 if prev == 1 {
270 self.shared.receiver_unparker.unpark();
272 }
273 }
274}
275
276impl std::fmt::Debug for Sender {
277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
278 f.debug_struct("Sender")
279 .field("capacity", &self.capacity())
280 .finish_non_exhaustive()
281 }
282}
283
284pub struct Receiver {
293 inner: queue::Consumer,
294 parker: crossbeam_utils::sync::Parker,
295 shared: Arc<ChannelShared>,
296}
297
/// Error returned by `Receiver::recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// No record became available within the timeout passed to `recv`. For
    /// `Some(Duration::ZERO)`, this means the single poll found nothing.
    Timeout,
    /// Every `Sender` has been dropped.
    Disconnected,
}

impl std::fmt::Display for RecvError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            RecvError::Timeout => "receive timed out",
            RecvError::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RecvError {}
319
320impl Receiver {
321 #[inline]
334 pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
335 if timeout == Some(Duration::ZERO) {
337 unsafe {
342 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
343 if let Some(claim) = (*inner_ptr).try_claim() {
344 return Ok(std::mem::transmute::<
345 queue::ReadClaim<'_>,
346 queue::ReadClaim<'_>,
347 >(claim));
348 }
349 }
350 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
351 return Err(RecvError::Disconnected);
352 }
353 return Err(RecvError::Timeout);
354 }
355
356 let park_timeout = timeout.unwrap_or(DEFAULT_PARK_TIMEOUT);
357 let backoff = Backoff::new();
358
359 loop {
360 unsafe {
365 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
366 if let Some(claim) = (*inner_ptr).try_claim() {
367 return Ok(std::mem::transmute::<
368 queue::ReadClaim<'_>,
369 queue::ReadClaim<'_>,
370 >(claim));
371 }
372 }
373
374 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
375 return Err(RecvError::Disconnected);
376 }
377
378 if !backoff.is_completed() {
380 backoff.snooze();
381 continue;
382 }
383
384 self.shared.receiver_waiting.store(true, Ordering::Relaxed);
386 self.parker.park_timeout(park_timeout);
387 self.shared.receiver_waiting.store(false, Ordering::Relaxed);
388
389 if timeout.is_some() {
392 unsafe {
395 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
396 if let Some(claim) = (*inner_ptr).try_claim() {
397 return Ok(std::mem::transmute::<
398 queue::ReadClaim<'_>,
399 queue::ReadClaim<'_>,
400 >(claim));
401 }
402 }
403
404 if self.shared.sender_count.load(Ordering::Relaxed) == 0 {
405 return Err(RecvError::Disconnected);
406 }
407
408 return Err(RecvError::Timeout);
409 }
410
411 backoff.reset();
413 }
414 }
415
416 #[inline]
420 pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
421 self.inner.try_claim()
422 }
423
424 #[inline]
426 pub fn capacity(&self) -> usize {
427 self.inner.capacity()
428 }
429
430 #[inline]
432 pub fn is_disconnected(&self) -> bool {
433 self.shared.sender_count.load(Ordering::Relaxed) == 0
434 }
435}
436
437impl Drop for Receiver {
438 fn drop(&mut self) {
439 self.shared
440 .receiver_disconnected
441 .store(true, Ordering::Relaxed);
442 }
443}
444
445impl std::fmt::Debug for Receiver {
446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
447 f.debug_struct("Receiver")
448 .field("capacity", &self.capacity())
449 .finish_non_exhaustive()
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    // Single-record round trip: claim, fill, commit, notify, then receive and
    // compare the bytes.
    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);

        let payload = b"hello world";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();

        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    // Compile/run smoke test: `Sender` implements `Clone`.
    #[test]
    fn sender_is_clone() {
        let (tx, _rx) = channel(1024);
        let _tx2 = tx.clone();
    }

    // Several cloned senders on their own threads; the receiver must see
    // every record.
    #[test]
    fn multiple_senders() {
        let (tx, mut rx) = channel(4096);

        const SENDERS: usize = 4;
        const MESSAGES: usize = 100;

        let handles: Vec<_> = (0..SENDERS)
            .map(|id| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES {
                        let payload = format!("{}:{}", id, i);
                        let mut claim = tx.send(payload.len()).unwrap();
                        claim.copy_from_slice(payload.as_bytes());
                        claim.commit();
                        tx.notify();
                    }
                })
            })
            .collect();

        // Drop the original handle so only the spawned clones keep the
        // sender side alive.
        drop(tx);

        let mut count = 0;
        while let Ok(_record) = rx.recv(None) {
            count += 1;
            if count == SENDERS * MESSAGES {
                break;
            }
        }

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(count, SENDERS * MESSAGES);
    }

    // With no senders left and nothing queued, `recv` reports disconnection.
    #[test]
    fn disconnection_all_senders_dropped() {
        let (tx, mut rx) = channel(1024);

        drop(tx);

        match rx.recv(None) {
            Err(RecvError::Disconnected) => {}
            _ => panic!("expected Disconnected"),
        }
    }

    // After the receiver is dropped, `send` reports disconnection.
    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);

        drop(rx);

        match tx.send(8) {
            Err(SendError::Disconnected) => {}
            _ => panic!("expected Disconnected"),
        }
    }

    // A bounded `recv` on an idle channel times out close to the requested
    // duration. The bounds are loose to tolerate scheduler jitter.
    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);

        let start = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(RecvError::Timeout)));
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    // Zero-length claims are rejected by both the blocking and non-blocking
    // send paths.
    #[test]
    fn zero_len_error() {
        let (mut tx, _rx) = channel(1024);
        assert!(matches!(tx.send(0), Err(SendError::ZeroLength)));
        assert!(matches!(tx.try_send(0), Err(TrySendError::ZeroLength)));
    }

    // Higher-volume multi-sender run that also checks per-sender FIFO order
    // and exact per-sender counts.
    #[test]
    fn stress_multiple_senders() {
        const SENDERS: usize = 4;
        const MESSAGES_PER_SENDER: u64 = 10_000;
        const TOTAL: u64 = SENDERS as u64 * MESSAGES_PER_SENDER;
        const BUFFER_SIZE: usize = 64 * 1024;

        let (tx, mut rx) = channel(BUFFER_SIZE);

        let handles: Vec<_> = (0..SENDERS)
            .map(|sender_id| {
                let mut tx = tx.clone();
                thread::spawn(move || {
                    for i in 0..MESSAGES_PER_SENDER {
                        // Record layout: bytes 0..8 = sender id (LE),
                        // bytes 8..16 = per-sender sequence number (LE).
                        let mut payload = [0u8; 16];
                        payload[..8].copy_from_slice(&(sender_id as u64).to_le_bytes());
                        payload[8..].copy_from_slice(&i.to_le_bytes());

                        // The inner scope ends the claim (which mutably
                        // borrows `tx`) before `tx.notify()` is called.
                        {
                            let mut claim = tx.send(16).unwrap();
                            claim.copy_from_slice(&payload);
                            claim.commit();
                        }
                        tx.notify();
                    }
                })
            })
            .collect();

        // Only the spawned clones keep the sender side alive from here on.
        drop(tx);

        let consumer = thread::spawn(move || {
            let mut received = 0u64;
            // Next expected sequence number for each sender.
            let mut per_sender = vec![0u64; SENDERS];

            while received < TOTAL {
                match rx.recv(None) {
                    Ok(record) => {
                        let sender_id =
                            u64::from_le_bytes(record[..8].try_into().unwrap()) as usize;
                        let seq = u64::from_le_bytes(record[8..].try_into().unwrap());

                        // Each sender's records must arrive in the order it
                        // committed them.
                        assert_eq!(
                            seq, per_sender[sender_id],
                            "sender {} out of order at {}",
                            sender_id, received
                        );
                        per_sender[sender_id] += 1;
                        received += 1;
                    }
                    // `recv(None)` has no deadline, so it does not time out.
                    Err(RecvError::Timeout) => unreachable!(),
                    Err(RecvError::Disconnected) => break,
                }
            }

            per_sender
        });

        for h in handles {
            h.join().unwrap();
        }

        let per_sender = consumer.join().unwrap();
        for (i, &count) in per_sender.iter().enumerate() {
            assert_eq!(count, MESSAGES_PER_SENDER, "sender {} count", i);
        }
    }
}