1use crate::error::{IpcError, Result};
24use crate::graceful::{GracefulChannel, ShutdownState};
25use crossbeam_channel::{self, Receiver, RecvTimeoutError, Sender, TryRecvError, TrySendError};
26use std::sync::Arc;
27use std::time::Duration;
28
29#[derive(Debug)]
34pub struct ThreadSender<T> {
35 inner: Sender<T>,
36 shutdown: Arc<ShutdownState>,
37}
38
39#[derive(Debug)]
44pub struct ThreadReceiver<T> {
45 inner: Receiver<T>,
46 shutdown: Arc<ShutdownState>,
47}
48
49impl<T> Clone for ThreadSender<T> {
50 fn clone(&self) -> Self {
51 Self {
52 inner: self.inner.clone(),
53 shutdown: Arc::clone(&self.shutdown),
54 }
55 }
56}
57
58impl<T> Clone for ThreadReceiver<T> {
59 fn clone(&self) -> Self {
60 Self {
61 inner: self.inner.clone(),
62 shutdown: Arc::clone(&self.shutdown),
63 }
64 }
65}
66
67impl<T> ThreadSender<T> {
68 pub fn send(&self, msg: T) -> Result<()> {
76 if self.shutdown.is_shutdown() {
77 return Err(IpcError::Closed);
78 }
79
80 self.inner.send(msg).map_err(|_| IpcError::Closed)
81 }
82
83 pub fn try_send(&self, msg: T) -> Result<()> {
90 if self.shutdown.is_shutdown() {
91 return Err(IpcError::Closed);
92 }
93
94 self.inner.try_send(msg).map_err(|e| match e {
95 TrySendError::Full(_) => IpcError::WouldBlock,
96 TrySendError::Disconnected(_) => IpcError::Closed,
97 })
98 }
99
100 pub fn send_timeout(&self, msg: T, timeout: Duration) -> Result<()> {
107 if self.shutdown.is_shutdown() {
108 return Err(IpcError::Closed);
109 }
110
111 self.inner.send_timeout(msg, timeout).map_err(|e| {
112 if e.is_timeout() {
113 IpcError::Timeout
114 } else {
115 IpcError::Closed
116 }
117 })
118 }
119
120 pub fn is_empty(&self) -> bool {
122 self.inner.is_empty()
123 }
124
125 pub fn is_full(&self) -> bool {
127 self.inner.is_full()
128 }
129
130 pub fn len(&self) -> usize {
132 self.inner.len()
133 }
134
135 pub fn capacity(&self) -> Option<usize> {
137 self.inner.capacity()
138 }
139
140 pub fn is_shutdown(&self) -> bool {
142 self.shutdown.is_shutdown()
143 }
144
145 pub fn shutdown(&self) {
147 self.shutdown.shutdown();
148 }
149}
150
151impl<T> ThreadReceiver<T> {
152 pub fn recv(&self) -> Result<T> {
160 if self.shutdown.is_shutdown() {
161 return self.inner.try_recv().map_err(|_| IpcError::Closed);
163 }
164
165 self.inner.recv().map_err(|_| IpcError::Closed)
166 }
167
168 pub fn try_recv(&self) -> Result<T> {
175 self.inner.try_recv().map_err(|e| match e {
176 TryRecvError::Empty => IpcError::WouldBlock,
177 TryRecvError::Disconnected => IpcError::Closed,
178 })
179 }
180
181 pub fn recv_timeout(&self, timeout: Duration) -> Result<T> {
188 if self.shutdown.is_shutdown() {
189 return self.try_recv();
190 }
191
192 self.inner.recv_timeout(timeout).map_err(|e| match e {
193 RecvTimeoutError::Timeout => IpcError::Timeout,
194 RecvTimeoutError::Disconnected => IpcError::Closed,
195 })
196 }
197
198 pub fn is_empty(&self) -> bool {
200 self.inner.is_empty()
201 }
202
203 pub fn len(&self) -> usize {
205 self.inner.len()
206 }
207
208 pub fn capacity(&self) -> Option<usize> {
210 self.inner.capacity()
211 }
212
213 pub fn is_shutdown(&self) -> bool {
215 self.shutdown.is_shutdown()
216 }
217
218 pub fn shutdown(&self) {
220 self.shutdown.shutdown();
221 }
222
223 pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
227 std::iter::from_fn(move || self.recv().ok())
228 }
229
230 pub fn try_iter(&self) -> impl Iterator<Item = T> + '_ {
234 std::iter::from_fn(move || self.try_recv().ok())
235 }
236}
237
238#[derive(Debug)]
242pub struct ThreadChannel<T> {
243 sender: ThreadSender<T>,
244 receiver: ThreadReceiver<T>,
245}
246
247impl<T> ThreadChannel<T> {
248 pub fn unbounded() -> (ThreadSender<T>, ThreadReceiver<T>) {
256 let (tx, rx) = crossbeam_channel::unbounded();
257 let shutdown = Arc::new(ShutdownState::new());
258
259 let sender = ThreadSender {
260 inner: tx,
261 shutdown: Arc::clone(&shutdown),
262 };
263
264 let receiver = ThreadReceiver {
265 inner: rx,
266 shutdown,
267 };
268
269 (sender, receiver)
270 }
271
272 pub fn bounded(capacity: usize) -> (ThreadSender<T>, ThreadReceiver<T>) {
284 let (tx, rx) = crossbeam_channel::bounded(capacity);
285 let shutdown = Arc::new(ShutdownState::new());
286
287 let sender = ThreadSender {
288 inner: tx,
289 shutdown: Arc::clone(&shutdown),
290 };
291
292 let receiver = ThreadReceiver {
293 inner: rx,
294 shutdown,
295 };
296
297 (sender, receiver)
298 }
299
300 pub fn new_unbounded() -> Self {
302 let (sender, receiver) = Self::unbounded();
303 Self { sender, receiver }
304 }
305
306 pub fn new_bounded(capacity: usize) -> Self {
308 let (sender, receiver) = Self::bounded(capacity);
309 Self { sender, receiver }
310 }
311
312 pub fn sender(&self) -> &ThreadSender<T> {
314 &self.sender
315 }
316
317 pub fn receiver(&self) -> &ThreadReceiver<T> {
319 &self.receiver
320 }
321
322 pub fn clone_sender(&self) -> ThreadSender<T> {
324 self.sender.clone()
325 }
326
327 pub fn clone_receiver(&self) -> ThreadReceiver<T> {
329 self.receiver.clone()
330 }
331
332 pub fn split(self) -> (ThreadSender<T>, ThreadReceiver<T>) {
334 (self.sender, self.receiver)
335 }
336}
337
338impl<T> GracefulChannel for ThreadChannel<T> {
339 fn shutdown(&self) {
340 self.sender.shutdown();
341 }
342
343 fn is_shutdown(&self) -> bool {
344 self.sender.is_shutdown()
345 }
346
347 fn drain(&self) -> Result<()> {
348 while self.receiver.try_recv().is_ok() {}
350 Ok(())
351 }
352
353 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
354 self.shutdown();
355 let start = std::time::Instant::now();
356
357 while !self.receiver.is_empty() {
358 if start.elapsed() >= timeout {
359 return Err(IpcError::Timeout);
360 }
361 let _ = self.receiver.try_recv();
362 std::thread::sleep(Duration::from_millis(1));
363 }
364
365 Ok(())
366 }
367}
368
369impl<T> GracefulChannel for ThreadSender<T> {
370 fn shutdown(&self) {
371 self.shutdown.shutdown();
372 }
373
374 fn is_shutdown(&self) -> bool {
375 self.shutdown.is_shutdown()
376 }
377
378 fn drain(&self) -> Result<()> {
379 self.shutdown.wait_for_drain(None)
380 }
381
382 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
383 self.shutdown();
384 self.shutdown.wait_for_drain(Some(timeout))
385 }
386}
387
388impl<T> GracefulChannel for ThreadReceiver<T> {
389 fn shutdown(&self) {
390 self.shutdown.shutdown();
391 }
392
393 fn is_shutdown(&self) -> bool {
394 self.shutdown.is_shutdown()
395 }
396
397 fn drain(&self) -> Result<()> {
398 while self.try_recv().is_ok() {}
399 Ok(())
400 }
401
402 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
403 self.shutdown();
404 let start = std::time::Instant::now();
405
406 while !self.is_empty() {
407 if start.elapsed() >= timeout {
408 return Err(IpcError::Timeout);
409 }
410 let _ = self.try_recv();
411 std::thread::sleep(Duration::from_millis(1));
412 }
413
414 Ok(())
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_unbounded_channel() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();

        tx.send(42).unwrap();
        tx.send(43).unwrap();

        assert_eq!(rx.recv().unwrap(), 42);
        assert_eq!(rx.recv().unwrap(), 43);
    }

    #[test]
    fn test_bounded_channel() {
        let (tx, rx) = ThreadChannel::<i32>::bounded(2);

        tx.send(1).unwrap();
        tx.send(2).unwrap();

        // Full: a non-blocking send must refuse rather than wait.
        assert!(matches!(tx.try_send(3), Err(IpcError::WouldBlock)));

        assert_eq!(rx.recv().unwrap(), 1);

        // One slot freed, so this send completes immediately.
        tx.send(3).unwrap();

        assert_eq!(rx.recv().unwrap(), 2);
        assert_eq!(rx.recv().unwrap(), 3);
    }

    #[test]
    fn test_multi_producer() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();
        let tx2 = tx.clone();

        let h1 = thread::spawn(move || {
            for i in 0..5 {
                tx.send(i).unwrap();
            }
        });

        let h2 = thread::spawn(move || {
            for i in 5..10 {
                tx2.send(i).unwrap();
            }
        });

        h1.join().unwrap();
        h2.join().unwrap();

        let mut received: Vec<i32> = rx.try_iter().collect();
        received.sort();

        assert_eq!(received, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_multi_consumer() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();
        let rx2 = rx.clone();

        for i in 0..10 {
            tx.send(i).unwrap();
        }
        drop(tx);

        let h1 = thread::spawn(move || {
            let mut received = Vec::new();
            while let Ok(v) = rx.recv() {
                received.push(v);
            }
            received
        });

        let h2 = thread::spawn(move || {
            let mut received = Vec::new();
            while let Ok(v) = rx2.recv() {
                received.push(v);
            }
            received
        });

        let r1 = h1.join().unwrap();
        let r2 = h2.join().unwrap();

        let mut all: Vec<i32> = r1.into_iter().chain(r2).collect();
        all.sort();

        assert_eq!(all, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_shutdown() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();

        tx.send(1).unwrap();
        tx.shutdown();

        // The flag is shared by both endpoints.
        assert!(tx.is_shutdown());
        assert!(rx.is_shutdown());

        // Every send flavour is rejected after shutdown.
        assert!(matches!(tx.send(2), Err(IpcError::Closed)));
        assert!(matches!(tx.try_send(2), Err(IpcError::Closed)));
        assert!(matches!(
            tx.send_timeout(2, Duration::from_millis(1)),
            Err(IpcError::Closed)
        ));

        // Messages queued before shutdown are still delivered...
        assert_eq!(rx.recv().unwrap(), 1);
        // ...and once drained, recv reports Closed instead of blocking even
        // though a sender is still alive.
        assert!(matches!(rx.recv(), Err(IpcError::Closed)));
    }

    #[test]
    fn test_recv_timeout() {
        let (_tx, rx) = ThreadChannel::<i32>::unbounded();

        let result = rx.recv_timeout(Duration::from_millis(50));
        assert!(matches!(result, Err(IpcError::Timeout)));
    }

    #[test]
    fn test_send_timeout() {
        let (tx, _rx) = ThreadChannel::<i32>::bounded(1);

        tx.send(1).unwrap();

        let result = tx.send_timeout(2, Duration::from_millis(50));
        assert!(matches!(result, Err(IpcError::Timeout)));
    }

    #[test]
    fn test_try_recv() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();

        assert!(matches!(rx.try_recv(), Err(IpcError::WouldBlock)));

        tx.send(42).unwrap();

        assert_eq!(rx.try_recv().unwrap(), 42);
        assert!(matches!(rx.try_recv(), Err(IpcError::WouldBlock)));
    }

    #[test]
    fn test_channel_capacity() {
        let (tx, rx) = ThreadChannel::<i32>::bounded(5);

        assert_eq!(tx.capacity(), Some(5));
        assert_eq!(rx.capacity(), Some(5));
        assert!(tx.is_empty());
        assert!(!tx.is_full());

        for i in 0..5 {
            tx.send(i).unwrap();
        }

        assert!(tx.is_full());
        assert!(!tx.is_empty());
        assert_eq!(tx.len(), 5);
    }

    #[test]
    fn test_unbounded_capacity() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();

        assert_eq!(tx.capacity(), None);
        assert_eq!(rx.capacity(), None);
        assert!(!tx.is_full());
    }

    #[test]
    fn test_graceful_channel_trait() {
        let channel = ThreadChannel::<i32>::new_unbounded();

        assert!(!channel.is_shutdown());

        channel.sender().send(1).unwrap();
        channel.sender().send(2).unwrap();

        channel.shutdown();

        assert!(channel.is_shutdown());

        channel.drain().unwrap();

        assert!(channel.receiver().is_empty());
    }

    #[test]
    fn test_iter() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();

        tx.send(1).unwrap();
        tx.send(2).unwrap();
        tx.send(3).unwrap();
        drop(tx);

        let collected: Vec<i32> = rx.iter().collect();
        assert_eq!(collected, vec![1, 2, 3]);
    }

    #[test]
    fn test_try_iter() {
        let (tx, rx) = ThreadChannel::<i32>::unbounded();

        tx.send(1).unwrap();
        tx.send(2).unwrap();

        let collected: Vec<i32> = rx.try_iter().collect();
        assert_eq!(collected, vec![1, 2]);

        // try_iter stopping on an empty channel must not close it.
        tx.send(3).unwrap();
        assert_eq!(rx.recv().unwrap(), 3);
    }
}