ipckit/
thread_channel.rs

//! Thread channel for intra-process thread communication.
//!
//! This module provides a high-performance channel for communication between threads
//! within the same process, using crossbeam-channel as the underlying implementation.
//! Both halves of a channel can be cloned, giving multi-producer / multi-consumer use.
//!
//! # Example
//!
//! ```rust
//! use ipckit::ThreadChannel;
//! use std::thread;
//!
//! // Create an unbounded channel
//! let (tx, rx) = ThreadChannel::<String>::unbounded();
//!
//! thread::spawn(move || {
//!     tx.send("Hello from thread!".to_string()).unwrap();
//! });
//!
//! let msg = rx.recv().unwrap();
//! assert_eq!(msg, "Hello from thread!");
//! ```
22
23use crate::error::{IpcError, Result};
24use crate::graceful::{GracefulChannel, ShutdownState};
25use crossbeam_channel::{self, Receiver, RecvTimeoutError, Sender, TryRecvError, TrySendError};
26use std::sync::Arc;
27use std::time::Duration;
28
29/// A thread-safe channel sender for intra-process communication.
30///
31/// This is the sending half of a [`ThreadChannel`]. It can be cloned to create
32/// multiple producers that send to the same channel.
33#[derive(Debug)]
34pub struct ThreadSender<T> {
35    inner: Sender<T>,
36    shutdown: Arc<ShutdownState>,
37}
38
39/// A thread-safe channel receiver for intra-process communication.
40///
41/// This is the receiving half of a [`ThreadChannel`]. It can be cloned to create
42/// multiple consumers that receive from the same channel.
43#[derive(Debug)]
44pub struct ThreadReceiver<T> {
45    inner: Receiver<T>,
46    shutdown: Arc<ShutdownState>,
47}
48
49impl<T> Clone for ThreadSender<T> {
50    fn clone(&self) -> Self {
51        Self {
52            inner: self.inner.clone(),
53            shutdown: Arc::clone(&self.shutdown),
54        }
55    }
56}
57
58impl<T> Clone for ThreadReceiver<T> {
59    fn clone(&self) -> Self {
60        Self {
61            inner: self.inner.clone(),
62            shutdown: Arc::clone(&self.shutdown),
63        }
64    }
65}
66
67impl<T> ThreadSender<T> {
68    /// Send a message through the channel.
69    ///
70    /// This method blocks if the channel is bounded and full.
71    ///
72    /// # Errors
73    ///
74    /// Returns `IpcError::Closed` if the channel has been shutdown or all receivers have been dropped.
75    pub fn send(&self, msg: T) -> Result<()> {
76        if self.shutdown.is_shutdown() {
77            return Err(IpcError::Closed);
78        }
79
80        self.inner.send(msg).map_err(|_| IpcError::Closed)
81    }
82
83    /// Try to send a message without blocking.
84    ///
85    /// # Errors
86    ///
87    /// - `IpcError::Closed` if the channel has been shutdown or all receivers have been dropped.
88    /// - `IpcError::WouldBlock` if the channel is full (bounded channels only).
89    pub fn try_send(&self, msg: T) -> Result<()> {
90        if self.shutdown.is_shutdown() {
91            return Err(IpcError::Closed);
92        }
93
94        self.inner.try_send(msg).map_err(|e| match e {
95            TrySendError::Full(_) => IpcError::WouldBlock,
96            TrySendError::Disconnected(_) => IpcError::Closed,
97        })
98    }
99
100    /// Send a message with a timeout.
101    ///
102    /// # Errors
103    ///
104    /// - `IpcError::Closed` if the channel has been shutdown or all receivers have been dropped.
105    /// - `IpcError::Timeout` if the timeout expires before the message can be sent.
106    pub fn send_timeout(&self, msg: T, timeout: Duration) -> Result<()> {
107        if self.shutdown.is_shutdown() {
108            return Err(IpcError::Closed);
109        }
110
111        self.inner.send_timeout(msg, timeout).map_err(|e| {
112            if e.is_timeout() {
113                IpcError::Timeout
114            } else {
115                IpcError::Closed
116            }
117        })
118    }
119
120    /// Check if the channel is empty.
121    pub fn is_empty(&self) -> bool {
122        self.inner.is_empty()
123    }
124
125    /// Check if the channel is full (always false for unbounded channels).
126    pub fn is_full(&self) -> bool {
127        self.inner.is_full()
128    }
129
130    /// Get the number of messages in the channel.
131    pub fn len(&self) -> usize {
132        self.inner.len()
133    }
134
135    /// Get the capacity of the channel (None for unbounded channels).
136    pub fn capacity(&self) -> Option<usize> {
137        self.inner.capacity()
138    }
139
140    /// Check if the channel has been shutdown.
141    pub fn is_shutdown(&self) -> bool {
142        self.shutdown.is_shutdown()
143    }
144
145    /// Shutdown the channel.
146    pub fn shutdown(&self) {
147        self.shutdown.shutdown();
148    }
149}
150
151impl<T> ThreadReceiver<T> {
152    /// Receive a message from the channel.
153    ///
154    /// This method blocks until a message is available.
155    ///
156    /// # Errors
157    ///
158    /// Returns `IpcError::Closed` if the channel has been shutdown or all senders have been dropped.
159    pub fn recv(&self) -> Result<T> {
160        if self.shutdown.is_shutdown() {
161            // Try to drain remaining messages first
162            return self.inner.try_recv().map_err(|_| IpcError::Closed);
163        }
164
165        self.inner.recv().map_err(|_| IpcError::Closed)
166    }
167
168    /// Try to receive a message without blocking.
169    ///
170    /// # Errors
171    ///
172    /// - `IpcError::Closed` if the channel has been shutdown or all senders have been dropped.
173    /// - `IpcError::WouldBlock` if no message is available.
174    pub fn try_recv(&self) -> Result<T> {
175        self.inner.try_recv().map_err(|e| match e {
176            TryRecvError::Empty => IpcError::WouldBlock,
177            TryRecvError::Disconnected => IpcError::Closed,
178        })
179    }
180
181    /// Receive a message with a timeout.
182    ///
183    /// # Errors
184    ///
185    /// - `IpcError::Closed` if the channel has been shutdown or all senders have been dropped.
186    /// - `IpcError::Timeout` if the timeout expires before a message is available.
187    pub fn recv_timeout(&self, timeout: Duration) -> Result<T> {
188        if self.shutdown.is_shutdown() {
189            return self.try_recv();
190        }
191
192        self.inner.recv_timeout(timeout).map_err(|e| match e {
193            RecvTimeoutError::Timeout => IpcError::Timeout,
194            RecvTimeoutError::Disconnected => IpcError::Closed,
195        })
196    }
197
198    /// Check if the channel is empty.
199    pub fn is_empty(&self) -> bool {
200        self.inner.is_empty()
201    }
202
203    /// Get the number of messages in the channel.
204    pub fn len(&self) -> usize {
205        self.inner.len()
206    }
207
208    /// Get the capacity of the channel (None for unbounded channels).
209    pub fn capacity(&self) -> Option<usize> {
210        self.inner.capacity()
211    }
212
213    /// Check if the channel has been shutdown.
214    pub fn is_shutdown(&self) -> bool {
215        self.shutdown.is_shutdown()
216    }
217
218    /// Shutdown the channel.
219    pub fn shutdown(&self) {
220        self.shutdown.shutdown();
221    }
222
223    /// Create an iterator over received messages.
224    ///
225    /// The iterator will block waiting for messages and will stop when the channel is closed.
226    pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
227        std::iter::from_fn(move || self.recv().ok())
228    }
229
230    /// Create a non-blocking iterator over available messages.
231    ///
232    /// The iterator will return `None` when no more messages are immediately available.
233    pub fn try_iter(&self) -> impl Iterator<Item = T> + '_ {
234        std::iter::from_fn(move || self.try_recv().ok())
235    }
236}
237
238/// A bidirectional thread channel that combines both sender and receiver.
239///
240/// This is useful when you need both send and receive capabilities in one place.
241#[derive(Debug)]
242pub struct ThreadChannel<T> {
243    sender: ThreadSender<T>,
244    receiver: ThreadReceiver<T>,
245}
246
247impl<T> ThreadChannel<T> {
248    /// Create a new unbounded thread channel.
249    ///
250    /// An unbounded channel has no capacity limit and will never block on send.
251    ///
252    /// # Returns
253    ///
254    /// A tuple of (sender, receiver) for the channel.
255    pub fn unbounded() -> (ThreadSender<T>, ThreadReceiver<T>) {
256        let (tx, rx) = crossbeam_channel::unbounded();
257        let shutdown = Arc::new(ShutdownState::new());
258
259        let sender = ThreadSender {
260            inner: tx,
261            shutdown: Arc::clone(&shutdown),
262        };
263
264        let receiver = ThreadReceiver {
265            inner: rx,
266            shutdown,
267        };
268
269        (sender, receiver)
270    }
271
272    /// Create a new bounded thread channel with the specified capacity.
273    ///
274    /// A bounded channel will block on send when the channel is full.
275    ///
276    /// # Arguments
277    ///
278    /// * `capacity` - The maximum number of messages the channel can hold.
279    ///
280    /// # Returns
281    ///
282    /// A tuple of (sender, receiver) for the channel.
283    pub fn bounded(capacity: usize) -> (ThreadSender<T>, ThreadReceiver<T>) {
284        let (tx, rx) = crossbeam_channel::bounded(capacity);
285        let shutdown = Arc::new(ShutdownState::new());
286
287        let sender = ThreadSender {
288            inner: tx,
289            shutdown: Arc::clone(&shutdown),
290        };
291
292        let receiver = ThreadReceiver {
293            inner: rx,
294            shutdown,
295        };
296
297        (sender, receiver)
298    }
299
300    /// Create a new bidirectional thread channel (unbounded).
301    pub fn new_unbounded() -> Self {
302        let (sender, receiver) = Self::unbounded();
303        Self { sender, receiver }
304    }
305
306    /// Create a new bidirectional thread channel (bounded).
307    pub fn new_bounded(capacity: usize) -> Self {
308        let (sender, receiver) = Self::bounded(capacity);
309        Self { sender, receiver }
310    }
311
312    /// Get a reference to the sender.
313    pub fn sender(&self) -> &ThreadSender<T> {
314        &self.sender
315    }
316
317    /// Get a reference to the receiver.
318    pub fn receiver(&self) -> &ThreadReceiver<T> {
319        &self.receiver
320    }
321
322    /// Clone the sender.
323    pub fn clone_sender(&self) -> ThreadSender<T> {
324        self.sender.clone()
325    }
326
327    /// Clone the receiver.
328    pub fn clone_receiver(&self) -> ThreadReceiver<T> {
329        self.receiver.clone()
330    }
331
332    /// Split the channel into sender and receiver.
333    pub fn split(self) -> (ThreadSender<T>, ThreadReceiver<T>) {
334        (self.sender, self.receiver)
335    }
336}
337
338impl<T> GracefulChannel for ThreadChannel<T> {
339    fn shutdown(&self) {
340        self.sender.shutdown();
341    }
342
343    fn is_shutdown(&self) -> bool {
344        self.sender.is_shutdown()
345    }
346
347    fn drain(&self) -> Result<()> {
348        // For thread channels, drain means receiving all pending messages
349        while self.receiver.try_recv().is_ok() {}
350        Ok(())
351    }
352
353    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
354        self.shutdown();
355        let start = std::time::Instant::now();
356
357        while !self.receiver.is_empty() {
358            if start.elapsed() >= timeout {
359                return Err(IpcError::Timeout);
360            }
361            let _ = self.receiver.try_recv();
362            std::thread::sleep(Duration::from_millis(1));
363        }
364
365        Ok(())
366    }
367}
368
369impl<T> GracefulChannel for ThreadSender<T> {
370    fn shutdown(&self) {
371        self.shutdown.shutdown();
372    }
373
374    fn is_shutdown(&self) -> bool {
375        self.shutdown.is_shutdown()
376    }
377
378    fn drain(&self) -> Result<()> {
379        self.shutdown.wait_for_drain(None)
380    }
381
382    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
383        self.shutdown();
384        self.shutdown.wait_for_drain(Some(timeout))
385    }
386}
387
388impl<T> GracefulChannel for ThreadReceiver<T> {
389    fn shutdown(&self) {
390        self.shutdown.shutdown();
391    }
392
393    fn is_shutdown(&self) -> bool {
394        self.shutdown.is_shutdown()
395    }
396
397    fn drain(&self) -> Result<()> {
398        while self.try_recv().is_ok() {}
399        Ok(())
400    }
401
402    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
403        self.shutdown();
404        let start = std::time::Instant::now();
405
406        while !self.is_empty() {
407            if start.elapsed() >= timeout {
408                return Err(IpcError::Timeout);
409            }
410            let _ = self.try_recv();
411            std::thread::sleep(Duration::from_millis(1));
412        }
413
414        Ok(())
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_unbounded_channel() {
        let (sender, receiver) = ThreadChannel::<i32>::unbounded();

        sender.send(42).unwrap();
        sender.send(43).unwrap();

        // Messages come out in the order they were sent.
        let first = receiver.recv().unwrap();
        let second = receiver.recv().unwrap();
        assert_eq!((first, second), (42, 43));
    }

    #[test]
    fn test_bounded_channel() {
        let (sender, receiver) = ThreadChannel::<i32>::bounded(2);

        sender.send(1).unwrap();
        sender.send(2).unwrap();

        // Both slots are occupied, so a non-blocking send must be refused.
        let overflow = sender.try_send(3);
        assert!(matches!(overflow, Err(IpcError::WouldBlock)));

        assert_eq!(receiver.recv().unwrap(), 1);

        // The receive above freed a slot.
        sender.send(3).unwrap();

        let rest: Vec<i32> = (0..2).map(|_| receiver.recv().unwrap()).collect();
        assert_eq!(rest, vec![2, 3]);
    }

    #[test]
    fn test_multi_producer() {
        let (first_tx, receiver) = ThreadChannel::<i32>::unbounded();
        let second_tx = first_tx.clone();

        let first = thread::spawn(move || (0..5).for_each(|n| first_tx.send(n).unwrap()));
        let second = thread::spawn(move || (5..10).for_each(|n| second_tx.send(n).unwrap()));

        first.join().unwrap();
        second.join().unwrap();

        let mut received: Vec<i32> = receiver.try_iter().collect();
        received.sort_unstable();

        let expected: Vec<i32> = (0..10).collect();
        assert_eq!(received, expected);
    }

    #[test]
    fn test_multi_consumer() {
        let (sender, first_rx) = ThreadChannel::<i32>::unbounded();
        let second_rx = first_rx.clone();

        (0..10).for_each(|n| sender.send(n).unwrap());
        // Disconnect the channel so each consumer's loop ends once it is empty.
        drop(sender);

        let spawn_consumer =
            |rx: ThreadReceiver<i32>| thread::spawn(move || rx.iter().collect::<Vec<i32>>());
        let first = spawn_consumer(first_rx);
        let second = spawn_consumer(second_rx);

        let mut all = first.join().unwrap();
        all.extend(second.join().unwrap());
        all.sort_unstable();

        assert_eq!(all, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_shutdown() {
        let (sender, receiver) = ThreadChannel::<i32>::unbounded();

        sender.send(1).unwrap();
        sender.shutdown();

        // New sends are refused once the channel is shut down...
        assert!(matches!(sender.send(2), Err(IpcError::Closed)));

        // ...but the message queued beforehand is still delivered.
        assert_eq!(receiver.recv().unwrap(), 1);
    }

    #[test]
    fn test_recv_timeout() {
        // Keep the sender alive so the channel stays connected but empty.
        let (_sender, receiver) = ThreadChannel::<i32>::unbounded();

        let outcome = receiver.recv_timeout(Duration::from_millis(50));
        assert!(matches!(outcome, Err(IpcError::Timeout)));
    }

    #[test]
    fn test_send_timeout() {
        let (sender, _receiver) = ThreadChannel::<i32>::bounded(1);

        // Fill the only slot so the next send has to wait.
        sender.send(1).unwrap();

        let outcome = sender.send_timeout(2, Duration::from_millis(50));
        assert!(matches!(outcome, Err(IpcError::Timeout)));
    }

    #[test]
    fn test_try_recv() {
        let (sender, receiver) = ThreadChannel::<i32>::unbounded();
        let reports_would_block =
            |rx: &ThreadReceiver<i32>| matches!(rx.try_recv(), Err(IpcError::WouldBlock));

        assert!(reports_would_block(&receiver));

        sender.send(42).unwrap();

        assert_eq!(receiver.try_recv().unwrap(), 42);
        assert!(reports_would_block(&receiver));
    }

    #[test]
    fn test_channel_capacity() {
        let (sender, receiver) = ThreadChannel::<i32>::bounded(5);

        assert_eq!(sender.capacity(), Some(5));
        assert_eq!(receiver.capacity(), Some(5));
        assert!(sender.is_empty() && !sender.is_full());

        (0..5).for_each(|n| sender.send(n).unwrap());

        assert!(sender.is_full() && !sender.is_empty());
        assert_eq!(sender.len(), 5);
    }

    #[test]
    fn test_unbounded_capacity() {
        let (sender, receiver) = ThreadChannel::<i32>::unbounded();

        assert_eq!((sender.capacity(), receiver.capacity()), (None, None));
        // An unbounded channel can never be full.
        assert!(!sender.is_full());
    }

    #[test]
    fn test_graceful_channel_trait() {
        let channel = ThreadChannel::<i32>::new_unbounded();
        assert!(!channel.is_shutdown());

        let sender = channel.sender();
        sender.send(1).unwrap();
        sender.send(2).unwrap();

        channel.shutdown();
        assert!(channel.is_shutdown());

        // Throw away whatever was queued before the shutdown.
        channel.drain().unwrap();
        assert!(channel.receiver().is_empty());
    }

    #[test]
    fn test_iter() {
        let (sender, receiver) = ThreadChannel::<i32>::unbounded();

        (1..=3).for_each(|n| sender.send(n).unwrap());
        // Dropping the only sender disconnects the channel, which ends `iter()`.
        drop(sender);

        assert_eq!(receiver.iter().collect::<Vec<i32>>(), vec![1, 2, 3]);
    }

    #[test]
    fn test_try_iter() {
        let (sender, receiver) = ThreadChannel::<i32>::unbounded();

        sender.send(1).unwrap();
        sender.send(2).unwrap();

        let drained: Vec<i32> = receiver.try_iter().collect();
        assert_eq!(drained, vec![1, 2]);

        // `try_iter` returned without blocking, so the channel is still usable.
        sender.send(3).unwrap();
        assert_eq!(receiver.recv().unwrap(), 3);
    }
}