1use std::sync::Arc;
35use std::sync::atomic::{AtomicBool, Ordering};
36use std::time::Duration;
37
38use crossbeam_utils::Backoff;
39
40use crate::queue::spsc as queue;
41
/// Park duration used by `Receiver::recv` when the caller supplies no timeout.
///
/// In that mode the receiver does not sleep indefinitely: each time this interval elapses it
/// wakes up and polls the queue again.
const DEFAULT_PARK_TIMEOUT: Duration = Duration::from_millis(100);
46
47pub fn channel(capacity: usize) -> (Sender, Receiver) {
55 let (producer, consumer) = queue::new(capacity);
56
57 let shared = Arc::new(ChannelShared {
58 receiver_waiting: AtomicBool::new(false),
59 sender_disconnected: AtomicBool::new(false),
60 receiver_disconnected: AtomicBool::new(false),
61 });
62
63 let parker = crossbeam_utils::sync::Parker::new();
64 let unparker = parker.unparker().clone();
65
66 (
67 Sender {
68 inner: producer,
69 receiver_unparker: unparker,
70 shared: Arc::clone(&shared),
71 },
72 Receiver {
73 inner: consumer,
74 parker,
75 shared,
76 },
77 )
78}
79
/// Coordination flags shared (through an `Arc`) by the two halves of a channel.
struct ChannelShared {
    /// Raised by the receiver around its park call; `Sender::notify` only unparks while it is set.
    receiver_waiting: AtomicBool,
    /// Set when the `Sender` is dropped.
    sender_disconnected: AtomicBool,
    /// Set when the `Receiver` is dropped.
    receiver_disconnected: AtomicBool,
}
89
/// Sending half of a channel created by [`channel`].
pub struct Sender {
    /// Producer side of the underlying queue; write claims are obtained from it.
    inner: queue::Producer,
    /// Handle used to wake the receiver's parker.
    receiver_unparker: crossbeam_utils::sync::Unparker,
    /// Flags shared with the receiving half.
    shared: Arc<ChannelShared>,
}
103
/// Error returned by `Sender::send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SendError {
    /// The receiving half of the channel has been dropped.
    Disconnected,
    /// A zero-length payload was requested.
    ZeroLength,
}

impl std::fmt::Display for SendError {
    /// Writes the fixed human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Disconnected => "channel disconnected",
            Self::ZeroLength => "payload length must be non-zero",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SendError {}
123
/// Error returned by `Sender::try_send`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrySendError {
    /// The queue currently has no room for the requested payload.
    Full,
    /// The receiving half of the channel has been dropped.
    Disconnected,
    /// A zero-length payload was requested.
    ZeroLength,
}

impl std::fmt::Display for TrySendError {
    /// Writes the fixed human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Full => "channel full",
            Self::Disconnected => "channel disconnected",
            Self::ZeroLength => "payload length must be non-zero",
        };
        f.write_str(message)
    }
}

impl std::error::Error for TrySendError {}
146
147impl Sender {
148 #[inline]
162 pub fn send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, SendError> {
163 if len == 0 {
165 return Err(SendError::ZeroLength);
166 }
167 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
168 return Err(SendError::Disconnected);
169 }
170
171 let backoff = Backoff::new();
172
173 loop {
174 unsafe {
179 let inner_ptr: *mut queue::Producer = &raw mut self.inner;
180 match (*inner_ptr).try_claim(len) {
181 Ok(claim) => {
182 return Ok(std::mem::transmute::<
183 queue::WriteClaim<'_>,
184 queue::WriteClaim<'_>,
185 >(claim));
186 }
187 Err(crate::TryClaimError::Full) => {
188 backoff.snooze();
189 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
190 return Err(SendError::Disconnected);
191 }
192 if backoff.is_completed() {
194 backoff.reset();
195 }
196 }
197 Err(crate::TryClaimError::ZeroLength) => return Err(SendError::ZeroLength),
198 }
199 }
200 }
201 }
202
203 #[inline]
211 pub fn try_send(&mut self, len: usize) -> Result<queue::WriteClaim<'_>, TrySendError> {
212 if self.shared.receiver_disconnected.load(Ordering::Relaxed) {
213 return Err(TrySendError::Disconnected);
214 }
215
216 match self.inner.try_claim(len) {
217 Ok(claim) => Ok(claim),
218 Err(crate::TryClaimError::Full) => Err(TrySendError::Full),
219 Err(crate::TryClaimError::ZeroLength) => Err(TrySendError::ZeroLength),
220 }
221 }
222
223 #[inline]
228 pub fn notify(&self) {
229 if self.shared.receiver_waiting.load(Ordering::Relaxed) {
230 self.receiver_unparker.unpark();
231 }
232 }
233
234 #[inline]
236 pub fn capacity(&self) -> usize {
237 self.inner.capacity()
238 }
239
240 #[inline]
242 pub fn is_disconnected(&self) -> bool {
243 self.shared.receiver_disconnected.load(Ordering::Relaxed)
244 }
245}
246
247impl Drop for Sender {
248 fn drop(&mut self) {
249 self.shared
250 .sender_disconnected
251 .store(true, Ordering::Relaxed);
252 self.receiver_unparker.unpark();
254 }
255}
256
257impl std::fmt::Debug for Sender {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("Sender")
260 .field("capacity", &self.capacity())
261 .finish_non_exhaustive()
262 }
263}
264
/// Receiving half of a channel created by [`channel`].
pub struct Receiver {
    /// Consumer side of the underlying queue; read claims are obtained from it.
    inner: queue::Consumer,
    /// Parker the receiver sleeps on while waiting for data; the sender holds its unparker.
    parker: crossbeam_utils::sync::Parker,
    /// Flags shared with the sending half.
    shared: Arc<ChannelShared>,
}
278
/// Error returned by `Receiver::recv`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecvError {
    /// No record became available within the requested timeout.
    Timeout,
    /// The sending half of the channel has been dropped.
    Disconnected,
}

impl std::fmt::Display for RecvError {
    /// Writes the fixed human-readable message for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Timeout => "receive timed out",
            Self::Disconnected => "channel disconnected",
        };
        f.write_str(message)
    }
}

impl std::error::Error for RecvError {}
300
301impl Receiver {
302 #[inline]
315 pub fn recv(&mut self, timeout: Option<Duration>) -> Result<queue::ReadClaim<'_>, RecvError> {
316 if timeout == Some(Duration::ZERO) {
318 unsafe {
323 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
324 if let Some(claim) = (*inner_ptr).try_claim() {
325 return Ok(std::mem::transmute::<
326 queue::ReadClaim<'_>,
327 queue::ReadClaim<'_>,
328 >(claim));
329 }
330 }
331 if self.shared.sender_disconnected.load(Ordering::Relaxed) {
332 return Err(RecvError::Disconnected);
333 }
334 return Err(RecvError::Timeout);
335 }
336
337 let park_timeout = timeout.unwrap_or(DEFAULT_PARK_TIMEOUT);
338 let backoff = Backoff::new();
339
340 loop {
341 unsafe {
346 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
347 if let Some(claim) = (*inner_ptr).try_claim() {
348 return Ok(std::mem::transmute::<
349 queue::ReadClaim<'_>,
350 queue::ReadClaim<'_>,
351 >(claim));
352 }
353 }
354
355 if self.shared.sender_disconnected.load(Ordering::Relaxed) {
356 return Err(RecvError::Disconnected);
357 }
358
359 if !backoff.is_completed() {
361 backoff.snooze();
362 continue;
363 }
364
365 self.shared.receiver_waiting.store(true, Ordering::Relaxed);
367 self.parker.park_timeout(park_timeout);
368 self.shared.receiver_waiting.store(false, Ordering::Relaxed);
369
370 if timeout.is_some() {
373 unsafe {
376 let inner_ptr: *mut queue::Consumer = &raw mut self.inner;
377 if let Some(claim) = (*inner_ptr).try_claim() {
378 return Ok(std::mem::transmute::<
379 queue::ReadClaim<'_>,
380 queue::ReadClaim<'_>,
381 >(claim));
382 }
383 }
384
385 if self.shared.sender_disconnected.load(Ordering::Relaxed) {
386 return Err(RecvError::Disconnected);
387 }
388
389 return Err(RecvError::Timeout);
390 }
391
392 backoff.reset();
394 }
395 }
396
397 #[inline]
401 pub fn try_recv(&mut self) -> Option<queue::ReadClaim<'_>> {
402 self.inner.try_claim()
403 }
404
405 #[inline]
407 pub fn capacity(&self) -> usize {
408 self.inner.capacity()
409 }
410
411 #[inline]
413 pub fn is_disconnected(&self) -> bool {
414 self.shared.sender_disconnected.load(Ordering::Relaxed)
415 }
416}
417
418impl Drop for Receiver {
419 fn drop(&mut self) {
420 self.shared
421 .receiver_disconnected
422 .store(true, Ordering::Relaxed);
423 }
424}
425
426impl std::fmt::Debug for Receiver {
427 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428 f.debug_struct("Receiver")
429 .field("capacity", &self.capacity())
430 .finish_non_exhaustive()
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Claims space through the blocking `send` path, copies `payload` in and commits it.
    fn publish(tx: &mut Sender, payload: &[u8]) {
        let mut claim = tx.send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();
    }

    /// A committed record is delivered by a blocking receive.
    #[test]
    fn basic_send_recv() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"hello world";

        publish(&mut tx, payload);
        tx.notify();

        let record = rx.recv(None).unwrap();
        assert_eq!(&*record, payload);
    }

    /// The non-blocking paths see an empty queue, then one record, then an empty queue again.
    #[test]
    fn try_send_try_recv() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"test";

        assert!(rx.try_recv().is_none());

        let mut claim = tx.try_send(payload.len()).unwrap();
        claim.copy_from_slice(payload);
        claim.commit();

        {
            // Keep the read claim in its own scope so it is gone before the next poll.
            let record = rx.try_recv().unwrap();
            assert_eq!(&*record, payload);
        }

        assert!(rx.try_recv().is_none());
    }

    /// A thousand records cross threads in order.
    #[test]
    fn cross_thread() {
        let (mut tx, mut rx) = channel(4096);

        let producer = thread::spawn(move || {
            for i in 0..1000u64 {
                let payload = i.to_le_bytes();
                publish(&mut tx, &payload);
                tx.notify();
            }
        });

        let consumer = thread::spawn(move || {
            for expected in 0..1000u64 {
                let record = rx.recv(None).unwrap();
                let value = u64::from_le_bytes((*record).try_into().unwrap());
                assert_eq!(value, expected);
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    }

    /// Receiving on an empty channel whose sender is gone reports a disconnect.
    #[test]
    fn disconnection_sender_dropped() {
        let (tx, mut rx) = channel(1024);

        drop(tx);

        assert!(
            matches!(rx.recv(None), Err(RecvError::Disconnected)),
            "expected Disconnected"
        );
    }

    /// Sending after the receiver is gone reports a disconnect.
    #[test]
    fn disconnection_receiver_dropped() {
        let (mut tx, rx) = channel(1024);

        drop(rx);

        assert!(
            matches!(tx.send(8), Err(SendError::Disconnected)),
            "expected Disconnected"
        );
    }

    /// A timed receive on an idle channel times out after roughly the requested duration.
    #[test]
    fn recv_timeout_works() {
        let (_tx, mut rx) = channel(1024);

        let started = std::time::Instant::now();
        let result = rx.recv(Some(Duration::from_millis(50)));
        let elapsed = started.elapsed();

        assert!(matches!(result, Err(RecvError::Timeout)));
        assert!(elapsed >= Duration::from_millis(40));
        assert!(elapsed < Duration::from_millis(200));
    }

    /// A timed receive returns immediately available data instead of timing out.
    #[test]
    fn recv_timeout_with_data() {
        let (mut tx, mut rx) = channel(1024);
        let payload = b"data";

        publish(&mut tx, payload);
        tx.notify();

        let result = rx.recv(Some(Duration::from_secs(1)));
        assert!(result.is_ok());
        let record = result.unwrap();
        assert_eq!(&*record, payload);
    }

    /// Filling the queue without draining it eventually yields `Full`.
    #[test]
    fn try_send_returns_full() {
        let (mut tx, _rx) = channel(64);

        let mut committed = 0;
        loop {
            let mut claim = match tx.try_send(8) {
                Ok(claim) => claim,
                Err(TrySendError::Full) => break,
                Err(e) => panic!("unexpected error: {:?}", e),
            };
            claim.copy_from_slice(b"12345678");
            claim.commit();
            committed += 1;
        }

        assert!(committed > 0);
    }

    /// Zero-length requests are rejected by both send paths.
    #[test]
    fn zero_len_error() {
        let (mut tx, _rx) = channel(1024);

        assert!(matches!(tx.send(0), Err(SendError::ZeroLength)));
        assert!(matches!(tx.try_send(0), Err(TrySendError::ZeroLength)));
    }
}