1use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::time::Duration;
37
38use crossbeam_utils::Backoff;
39
40use crate::queue::spsc as queue;
41
/// How long the receiver parks per wait in [`Receiver::recv`] when the caller
/// passes no timeout; once it elapses the receiver re-checks the queue and the
/// sender's state before waiting again.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_millis(100);
46
47pub fn channel(capacity: usize) -> (Sender, Receiver) {
55 let (producer, consumer) = queue::new(capacity);
56
57 let shared = Arc::new(ChannelShared {
58 receiver_waiting: AtomicBool::new(false),
59 sender_disconnected: AtomicBool::new(false),
60 receiver_disconnected: AtomicBool::new(false),
61 });
62
63 let parker = crossbeam_utils::sync::Parker::new();
64 let unparker = parker.unparker().clone();
65
66 (
67 Sender {
68 inner: producer,
69 receiver_unparker: unparker,
70 shared: Arc::clone(&shared),
71 },
72 Receiver {
73 inner: consumer,
74 parker,
75 shared,
76 },
77 )
78}
79
/// Status flags shared by both halves of a channel.
struct ChannelShared {
    /// Set to `true` when the [`Sender`] is dropped.
    sender_disconnected: AtomicBool,
    /// Set to `true` when the [`Receiver`] is dropped; checked by the sender
    /// before and while claiming space.
    receiver_disconnected: AtomicBool,
    /// Raised by the receiver around its park in [`Receiver::recv`]; the
    /// sender's `notify` reads it to decide whether an unpark is needed.
    receiver_waiting: AtomicBool,
}
89
90pub struct Sender {
99 inner: queue::Producer,
100 receiver_unparker: crossbeam_utils::sync::Unparker,
101 shared: Arc<ChannelShared>,
102}
103
/// Error returned by [`Sender::send`] when the [`Receiver`] has been dropped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelClosed;

impl std::fmt::Display for ChannelClosed {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "channel disconnected")
    }
}

impl std::error::Error for ChannelClosed {}
119
/// Error returned by [`Sender::try_send`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue reported it had no room for the requested payload length.
    Full,
    /// The [`Receiver`] has been dropped.
    Disconnected,
}

impl std::fmt::Display for TrySendError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Full => "channel full",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TrySendError {}
139
140impl Sender {
141 #[inline]
158 pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, ChannelClosed> {
159 assert!(len > 0, "payload length must be non-zero");
163 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
164 return Err(ChannelClosed);
165 }
166
167 let backoff = Backoff::new();
168
169 loop {
170 unsafe {
175 let inner_ptr: *mut queue::Producer = &raw mut self.inner;
176 if let Ok(claim) = (*inner_ptr).try_claim(len) {
177 return Ok(std::mem::transmute::<
178 queue::WriteClaim<'_>,
179 queue::WriteClaim<'_>,
180 >(claim));
181 }
182 backoff.snooze();
184 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
185 return Err(ChannelClosed);
186 }
187 if backoff.is_completed() {
189 backoff.reset();
190 }
191 }
192 }
193 }
194
195 #[inline]
206 pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
207 assert!(len > 0, "payload length must be non-zero");
209 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
210 return Err(TrySendError::Disconnected);
211 }
212
213 match self.inner.try_claim(len) {
214 Ok(claim) => Ok(claim),
215 Err(crate::BufferFull) => Err(TrySendError::Full),
216 }
217 }
218
219 #[inline]
224 pub fn notify(&self) {
225 if self.shared.receiver_waiting.load(Ordering::Relaxed) {
226 self.receiver_unparker.unpark();
227 }
228 }
229
230 #[inline]
232 pub fn capacity(&self) -> usize {
233 self.inner.capacity()
234 }
235
236 #[inline]
238 pub fn is_disconnected(&self) -> bool {
239 self.shared.receiver_disconnected.load(Ordering::Relaxed)
240 }
241}
242
243impl Drop for Sender {
244 fn drop(&mut self) {
245 self.shared
246 .sender_disconnected
247 .store(true, Ordering::Relaxed);
248 self.receiver_unparker.unpark();
250 }
251}
252
253impl std::fmt::Debug for Sender {
254 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
255 f.debug_struct("Sender")
256 .field("capacity", &self.capacity())
257 .finish_non_exhaustive()
258 }
259}
260
261pub struct Receiver {
270 inner: queue::Consumer,
271 parker: crossbeam_utils::sync::Parker,
272 shared: Arc<ChannelShared>,
273}
274
275#[derive(Debug, Clone, Copy, PartialEq, Eq)]
277pub enum RecvError {
278 Timeout,
282 Disconnected,
284}
285
286impl std::fmt::Display for RecvError {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match self {
289 Self::Timeout => write!(f, "receive timed out"),
290 Self::Disconnected => write!(f, "channel disconnected"),
291 }
292 }
293}
294
295impl std::error::Error for RecvError {}
296
297impl Receiver {
298 #[inline]
311 pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
312 if timeout == Some(Duration::ZERO) {
314 unsafe {
319 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
320 if let Some(claim) = (*inner_ptr).try_claim() {
321 return Ok(std::mem::transmute::<
322 queue::ReadClaim<'_>,
323 queue::ReadClaim<'_>,
324 >(claim));
325 }
326 }
327 if self.shared.sender_disconnected.load(Ordering::Relaxed) {
328 return Err(RecvError::Disconnected);
329 }
330 return Err(RecvError::Timeout);
331 }
332
333 let park_timeout = timeout.unwrap_or(DEFAULT_PARK_TIMEOUT);
334 let backoff = Backoff::new();
335
336 loop {
337 unsafe {
342 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
343 if let Some(claim) = (*inner_ptr).try_claim() {
344 return Ok(std::mem::transmute::<
345 queue::ReadClaim<'_>,
346 queue::ReadClaim<'_>,
347 >(claim));
348 }
349 }
350
351 if self.shared.sender_disconnected.load(Ordering::Relaxed) {
352 return Err(RecvError::Disconnected);
353 }
354
355 if !backoff.is_completed() {
357 backoff.snooze();
358 continue;
359 }
360
361 self.shared.receiver_waiting.store(true, Ordering::Relaxed);
363 self.parker.park_timeout(park_timeout);
364 self.shared.receiver_waiting.store(false, Ordering::Relaxed);
365
366 if timeout.is_some() {
369 unsafe {
372 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
373 if let Some(claim) = (*inner_ptr).try_claim() {
374 return Ok(std::mem::transmute::<
375 queue::ReadClaim<'_>,
376 queue::ReadClaim<'_>,
377 >(claim));
378 }
379 }
380
381 if self.shared.sender_disconnected.load(Ordering::Relaxed) {
382 return Err(RecvError::Disconnected);
383 }
384
385 return Err(RecvError::Timeout);
386 }
387
388 backoff.reset();
390 }
391 }
392
393 #[inline]
397 pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
398 self.inner.try_claim()
399 }
400
401 #[inline]
403 pub fn capacity(&self) -> usize {
404 self.inner.capacity()
405 }
406
407 #[inline]
409 pub fn is_disconnected(&self) -> bool {
410 self.shared.sender_disconnected.load(Ordering::Relaxed)
411 }
412}
413
414impl Drop for Receiver {
415 fn drop(&mut self) {
416 self.shared
417 .receiver_disconnected
418 .store(true, Ordering::Relaxed);
419 }
420}
421
422impl std::fmt::Debug for Receiver {
423 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
424 f.debug_struct("Receiver")
425 .field("capacity", &self.capacity())
426 .finish_non_exhaustive()
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);

        let payload = b"hello world";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();

        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    #[test]
    fn try_send_try_recv() {
        let (mut tx, mut rx) = channel(1024);

        assert!(rx.try_recv().is_none());

        let payload = b"test";
        let mut claim = tx.try_send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        {
            let record = rx.try_recv().unwrap();
            assert_eq!(&*record, payload);
        }
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn cross_thread() {
        let (mut tx, mut rx) = channel(4096);

        let producer = thread::spawn(move || {
            for i in 0..1000u64 {
                let payload = i.to_le_bytes();
                {
                    let mut claim = tx.send(payload.len()).unwrap();
                    claim.copy_from_slice(&payload);
                    claim.commit();
                }
                tx.notify();
            }
            // `tx` is dropped here, disconnecting the channel.
        });

        let consumer = thread::spawn(move || {
            for i in 0..1000u64 {
                let record = rx.recv(None).unwrap();
                let value = u64::from_le_bytes((*record).try_into().unwrap());
                assert_eq!(value, i);
            }
            // Every record has been consumed; the only thing left is the disconnect.
            assert!(matches!(rx.recv(None), Err(RecvError::Disconnected)));
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    #[test]
    fn disconnection_sender_dropped() {
        let (tx, mut rx) = channel(1024);

        drop(tx);

        match rx.recv(None) {
            Err(RecvError::Disconnected) => {}
            _ => panic!("expected Disconnected"),
        }
    }

    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);

        drop(rx);

        match tx.send(8) {
            Err(ChannelClosed) => {}
            _ => panic!("expected ChannelClosed"),
        }
    }

    #[test]
    fn pending_record_survives_sender_drop() {
        let (mut tx, mut rx) = channel(1024);

        let payload = b"last words";
        {
            let mut claim = tx.send(payload.len()).unwrap();
            claim.copy_from_slice(payload);
            claim.commit();
        }
        drop(tx);

        assert!(rx.is_disconnected());
        {
            let record = rx.recv(None).unwrap();
            assert_eq!(&*record, payload);
        }
        assert!(matches!(rx.recv(None), Err(RecvError::Disconnected)));
    }

    #[test]
    fn is_disconnected_reflects_peer_drop() {
        let (tx, rx) = channel(1024);

        assert!(!tx.is_disconnected());
        assert!(!rx.is_disconnected());

        drop(rx);
        assert!(tx.is_disconnected());
    }

    #[test]
    fn recv_zero_timeout_does_not_block() {
        let (tx, mut rx) = channel(1024);

        assert!(matches!(
            rx.recv(Some(Duration::ZERO)),
            Err(RecvError::Timeout)
        ));

        drop(tx);
        assert!(matches!(
            rx.recv(Some(Duration::ZERO)),
            Err(RecvError::Disconnected)
        ));
    }

    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);

        let start = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = start.elapsed();

        assert!(matches!(result, Err(RecvError::Timeout)));
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    #[test]
    fn recv_timeout_with_data() {
        let (mut tx, mut rx) = channel(1024);

        let payload = b"data";
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
        tx.notify();

        let result = rx.recv(Some(Duration::from_secs(1)));
        assert!(result.is_ok());
        assert_eq!(&*result.unwrap(), payload);
    }

    #[test]
    fn try_send_returns_full() {
        let (mut tx, _rx) = channel(64);

        let mut count = 0;
        loop {
            match tx.try_send(8) {
                Ok(mut claim) => {
                    claim.copy_from_slice(b"12345678");
                    claim.commit();
                    count += 1;
                }
                Err(TrySendError::Full) => break,
                Err(e) => panic!("unexpected error: {:?}", e),
            }
        }

        assert!(count > 0);
    }

    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.send(0);
    }

    #[test]
    #[should_panic(expected = "payload length must be non-zero")]
    fn try_send_zero_panics() {
        let (mut tx, _rx) = channel(1024);
        let _ = tx.try_send(0);
    }
}