ipckit/
graceful.rs

1//! Graceful shutdown mechanism for IPC channels
2//!
3//! This module provides the `GracefulChannel` trait for implementing graceful shutdown
4//! in IPC channels, preventing errors like `EventLoopClosed` when background threads
5//! continue sending messages after the main event loop has closed.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use ipckit::{NamedPipe, GracefulChannel, GracefulNamedPipe};
11//! use std::time::Duration;
12//!
13//! fn main() -> Result<(), ipckit::IpcError> {
14//!     let pipe = NamedPipe::create("my_pipe")?;
15//!     let graceful = GracefulNamedPipe::new(pipe);
16//!
17//!     // ... use the channel ...
18//!
19//!     // Graceful shutdown
20//!     graceful.shutdown();
21//!     graceful.drain()?; // Wait for pending messages
22//!
23//!     Ok(())
24//! }
25//! ```
26
27use crate::error::{IpcError, Result};
28use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31
32/// Trait for channels that support graceful shutdown
33///
34/// This trait provides methods for signaling shutdown, checking shutdown status,
35/// and draining pending messages before closing the channel.
36pub trait GracefulChannel {
37    /// Signal the channel to shutdown
38    ///
39    /// After calling this method:
40    /// - New send operations will return `IpcError::Closed`
41    /// - Pending messages may still be processed (use `drain()` to wait)
42    /// - `is_shutdown()` will return `true`
43    fn shutdown(&self);
44
45    /// Check if the channel has been signaled to shutdown
46    fn is_shutdown(&self) -> bool;
47
48    /// Wait for all pending messages to be processed
49    ///
50    /// This method blocks until all messages that were in-flight before
51    /// `shutdown()` was called have been processed.
52    fn drain(&self) -> Result<()>;
53
54    /// Shutdown with a timeout
55    ///
56    /// Combines `shutdown()` and `drain()` with a timeout.
57    /// Returns `IpcError::Timeout` if the drain doesn't complete within the timeout.
58    fn shutdown_timeout(&self, timeout: Duration) -> Result<()>;
59}
60
/// Shutdown bookkeeping that several channel instances can share.
///
/// Holds the "shutdown requested" flag together with a counter of the
/// operations that are currently in progress.
#[derive(Debug)]
pub struct ShutdownState {
    /// Set to `true` once shutdown has been signaled.
    shutdown: AtomicBool,
    /// How many operations are currently pending.
    pending_count: AtomicUsize,
}
69
70impl Default for ShutdownState {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl ShutdownState {
77    /// Create a new shutdown state
78    pub fn new() -> Self {
79        Self {
80            shutdown: AtomicBool::new(false),
81            pending_count: AtomicUsize::new(0),
82        }
83    }
84
85    /// Signal shutdown
86    pub fn shutdown(&self) {
87        self.shutdown.store(true, Ordering::SeqCst);
88    }
89
90    /// Check if shutdown has been signaled
91    pub fn is_shutdown(&self) -> bool {
92        self.shutdown.load(Ordering::SeqCst)
93    }
94
95    /// Increment pending operation count
96    pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
97        if self.is_shutdown() {
98            return Err(IpcError::Closed);
99        }
100        self.pending_count.fetch_add(1, Ordering::SeqCst);
101
102        // Double-check after incrementing to prevent race condition
103        if self.is_shutdown() {
104            self.pending_count.fetch_sub(1, Ordering::SeqCst);
105            return Err(IpcError::Closed);
106        }
107
108        Ok(OperationGuard { state: self })
109    }
110
111    /// Get the current pending operation count
112    pub fn pending_count(&self) -> usize {
113        self.pending_count.load(Ordering::SeqCst)
114    }
115
116    /// Wait for all pending operations to complete
117    pub fn wait_for_drain(&self, timeout: Option<Duration>) -> Result<()> {
118        let start = Instant::now();
119        let sleep_duration = Duration::from_millis(1);
120
121        loop {
122            if self.pending_count() == 0 {
123                return Ok(());
124            }
125
126            if let Some(timeout) = timeout {
127                if start.elapsed() >= timeout {
128                    return Err(IpcError::Timeout);
129                }
130            }
131
132            std::thread::sleep(sleep_duration);
133        }
134    }
135}
136
137/// RAII guard for tracking pending operations
138pub struct OperationGuard<'a> {
139    state: &'a ShutdownState,
140}
141
142impl Drop for OperationGuard<'_> {
143    fn drop(&mut self) {
144        self.state.pending_count.fetch_sub(1, Ordering::SeqCst);
145    }
146}
147
148/// A wrapper that adds graceful shutdown capability to any channel
149#[derive(Debug)]
150pub struct GracefulWrapper<T> {
151    inner: T,
152    state: Arc<ShutdownState>,
153}
154
155impl<T> GracefulWrapper<T> {
156    /// Create a new graceful wrapper around a channel
157    pub fn new(inner: T) -> Self {
158        Self {
159            inner,
160            state: Arc::new(ShutdownState::new()),
161        }
162    }
163
164    /// Create a new graceful wrapper with a shared shutdown state
165    pub fn with_state(inner: T, state: Arc<ShutdownState>) -> Self {
166        Self { inner, state }
167    }
168
169    /// Get a reference to the inner channel
170    pub fn inner(&self) -> &T {
171        &self.inner
172    }
173
174    /// Get a mutable reference to the inner channel
175    pub fn inner_mut(&mut self) -> &mut T {
176        &mut self.inner
177    }
178
179    /// Get the shutdown state
180    pub fn state(&self) -> Arc<ShutdownState> {
181        Arc::clone(&self.state)
182    }
183
184    /// Consume the wrapper and return the inner channel
185    pub fn into_inner(self) -> T {
186        self.inner
187    }
188
189    /// Begin an operation, returning a guard that tracks the operation
190    pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
191        self.state.begin_operation()
192    }
193}
194
195impl<T> GracefulChannel for GracefulWrapper<T> {
196    fn shutdown(&self) {
197        self.state.shutdown();
198    }
199
200    fn is_shutdown(&self) -> bool {
201        self.state.is_shutdown()
202    }
203
204    fn drain(&self) -> Result<()> {
205        self.state.wait_for_drain(None)
206    }
207
208    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
209        self.shutdown();
210        self.state.wait_for_drain(Some(timeout))
211    }
212}
213
214impl<T: Clone> Clone for GracefulWrapper<T> {
215    fn clone(&self) -> Self {
216        Self {
217            inner: self.inner.clone(),
218            state: Arc::clone(&self.state),
219        }
220    }
221}
222
223// ============================================================================
224// GracefulNamedPipe - Named pipe with graceful shutdown
225// ============================================================================
226
227use crate::pipe::NamedPipe;
228use std::io::{Read, Write};
229
230/// Named pipe with graceful shutdown support
231pub struct GracefulNamedPipe {
232    inner: NamedPipe,
233    state: Arc<ShutdownState>,
234}
235
236impl GracefulNamedPipe {
237    /// Create a new graceful named pipe wrapper
238    pub fn new(pipe: NamedPipe) -> Self {
239        Self {
240            inner: pipe,
241            state: Arc::new(ShutdownState::new()),
242        }
243    }
244
245    /// Create a new graceful named pipe with a shared shutdown state
246    pub fn with_state(pipe: NamedPipe, state: Arc<ShutdownState>) -> Self {
247        Self { inner: pipe, state }
248    }
249
250    /// Create a new named pipe server with graceful shutdown
251    pub fn create(name: &str) -> Result<Self> {
252        let pipe = NamedPipe::create(name)?;
253        Ok(Self::new(pipe))
254    }
255
256    /// Connect to an existing named pipe with graceful shutdown
257    pub fn connect(name: &str) -> Result<Self> {
258        let pipe = NamedPipe::connect(name)?;
259        Ok(Self::new(pipe))
260    }
261
262    /// Get the pipe name
263    pub fn name(&self) -> &str {
264        self.inner.name()
265    }
266
267    /// Check if this is the server end
268    pub fn is_server(&self) -> bool {
269        self.inner.is_server()
270    }
271
272    /// Wait for a client to connect (server only)
273    pub fn wait_for_client(&mut self) -> Result<()> {
274        if self.state.is_shutdown() {
275            return Err(IpcError::Closed);
276        }
277        self.inner.wait_for_client()
278    }
279
280    /// Get the shutdown state for sharing with other channels
281    pub fn state(&self) -> Arc<ShutdownState> {
282        Arc::clone(&self.state)
283    }
284
285    /// Get a reference to the inner pipe
286    pub fn inner(&self) -> &NamedPipe {
287        &self.inner
288    }
289
290    /// Get a mutable reference to the inner pipe
291    pub fn inner_mut(&mut self) -> &mut NamedPipe {
292        &mut self.inner
293    }
294}
295
296impl GracefulChannel for GracefulNamedPipe {
297    fn shutdown(&self) {
298        self.state.shutdown();
299    }
300
301    fn is_shutdown(&self) -> bool {
302        self.state.is_shutdown()
303    }
304
305    fn drain(&self) -> Result<()> {
306        self.state.wait_for_drain(None)
307    }
308
309    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
310        self.shutdown();
311        self.state.wait_for_drain(Some(timeout))
312    }
313}
314
315impl Read for GracefulNamedPipe {
316    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
317        if self.state.is_shutdown() {
318            return Err(std::io::Error::new(
319                std::io::ErrorKind::BrokenPipe,
320                "Channel is shutdown",
321            ));
322        }
323
324        let _guard = self.state.begin_operation().map_err(|_| {
325            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel is shutdown")
326        })?;
327
328        self.inner.read(buf)
329    }
330}
331
332impl Write for GracefulNamedPipe {
333    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
334        if self.state.is_shutdown() {
335            return Err(std::io::Error::new(
336                std::io::ErrorKind::BrokenPipe,
337                "Channel is shutdown",
338            ));
339        }
340
341        let _guard = self.state.begin_operation().map_err(|_| {
342            std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel is shutdown")
343        })?;
344
345        self.inner.write(buf)
346    }
347
348    fn flush(&mut self) -> std::io::Result<()> {
349        self.inner.flush()
350    }
351}
352
353// ============================================================================
354// GracefulIpcChannel - IPC channel with graceful shutdown
355// ============================================================================
356
357use crate::channel::IpcChannel;
358use serde::{de::DeserializeOwned, Serialize};
359use std::marker::PhantomData;
360
361/// IPC channel with graceful shutdown support
362pub struct GracefulIpcChannel<T = Vec<u8>> {
363    inner: IpcChannel<T>,
364    state: Arc<ShutdownState>,
365    _marker: PhantomData<T>,
366}
367
368impl<T> GracefulIpcChannel<T> {
369    /// Create a new graceful IPC channel wrapper
370    pub fn new(channel: IpcChannel<T>) -> Self {
371        Self {
372            inner: channel,
373            state: Arc::new(ShutdownState::new()),
374            _marker: PhantomData,
375        }
376    }
377
378    /// Create a new graceful IPC channel with a shared shutdown state
379    pub fn with_state(channel: IpcChannel<T>, state: Arc<ShutdownState>) -> Self {
380        Self {
381            inner: channel,
382            state,
383            _marker: PhantomData,
384        }
385    }
386
387    /// Create a new IPC channel server with graceful shutdown
388    pub fn create(name: &str) -> Result<Self> {
389        let channel = IpcChannel::create(name)?;
390        Ok(Self::new(channel))
391    }
392
393    /// Connect to an existing IPC channel with graceful shutdown
394    pub fn connect(name: &str) -> Result<Self> {
395        let channel = IpcChannel::connect(name)?;
396        Ok(Self::new(channel))
397    }
398
399    /// Get the channel name
400    pub fn name(&self) -> &str {
401        self.inner.name()
402    }
403
404    /// Check if this is the server end
405    pub fn is_server(&self) -> bool {
406        self.inner.is_server()
407    }
408
409    /// Wait for a client to connect (server only)
410    pub fn wait_for_client(&mut self) -> Result<()> {
411        if self.state.is_shutdown() {
412            return Err(IpcError::Closed);
413        }
414        self.inner.wait_for_client()
415    }
416
417    /// Get the shutdown state for sharing with other channels
418    pub fn state(&self) -> Arc<ShutdownState> {
419        Arc::clone(&self.state)
420    }
421
422    /// Get a reference to the inner channel
423    pub fn inner(&self) -> &IpcChannel<T> {
424        &self.inner
425    }
426
427    /// Get a mutable reference to the inner channel
428    pub fn inner_mut(&mut self) -> &mut IpcChannel<T> {
429        &mut self.inner
430    }
431}
432
433impl<T> GracefulChannel for GracefulIpcChannel<T> {
434    fn shutdown(&self) {
435        self.state.shutdown();
436    }
437
438    fn is_shutdown(&self) -> bool {
439        self.state.is_shutdown()
440    }
441
442    fn drain(&self) -> Result<()> {
443        self.state.wait_for_drain(None)
444    }
445
446    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
447        self.shutdown();
448        self.state.wait_for_drain(Some(timeout))
449    }
450}
451
452impl GracefulIpcChannel<Vec<u8>> {
453    /// Send raw bytes
454    pub fn send_bytes(&mut self, data: &[u8]) -> Result<()> {
455        if self.state.is_shutdown() {
456            return Err(IpcError::Closed);
457        }
458
459        let _guard = self.state.begin_operation()?;
460        self.inner.send_bytes(data)
461    }
462
463    /// Receive raw bytes
464    pub fn recv_bytes(&mut self) -> Result<Vec<u8>> {
465        if self.state.is_shutdown() {
466            return Err(IpcError::Closed);
467        }
468
469        let _guard = self.state.begin_operation()?;
470        self.inner.recv_bytes()
471    }
472}
473
474impl<T: Serialize + DeserializeOwned> GracefulIpcChannel<T> {
475    /// Send a typed message (serialized as JSON)
476    pub fn send(&mut self, msg: &T) -> Result<()> {
477        if self.state.is_shutdown() {
478            return Err(IpcError::Closed);
479        }
480
481        let _guard = self.state.begin_operation()?;
482        self.inner.send(msg)
483    }
484
485    /// Receive a typed message (deserialized from JSON)
486    pub fn recv(&mut self) -> Result<T> {
487        if self.state.is_shutdown() {
488            return Err(IpcError::Closed);
489        }
490
491        let _guard = self.state.begin_operation()?;
492        self.inner.recv()
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// A fresh state is not shut down and has nothing pending; `shutdown()`
    /// flips the flag.
    #[test]
    fn test_shutdown_state() {
        let state = ShutdownState::new();

        assert!(!state.is_shutdown());
        assert_eq!(state.pending_count(), 0);

        state.shutdown();

        assert!(state.is_shutdown());
    }

    /// Guards increment the pending count on creation and decrement it on drop.
    #[test]
    fn test_operation_guard() {
        let state = ShutdownState::new();

        {
            let _guard = state.begin_operation().unwrap();
            assert_eq!(state.pending_count(), 1);

            {
                let _guard2 = state.begin_operation().unwrap();
                assert_eq!(state.pending_count(), 2);
            }

            assert_eq!(state.pending_count(), 1);
        }

        assert_eq!(state.pending_count(), 0);
    }

    /// No operation can be started once shutdown has been signaled.
    #[test]
    fn test_operation_after_shutdown() {
        let state = ShutdownState::new();

        state.shutdown();

        let result = state.begin_operation();
        assert!(result.is_err());
    }

    /// Draining succeeds once the in-flight operation finishes.
    #[test]
    fn test_drain() {
        let state = Arc::new(ShutdownState::new());
        let state_clone = Arc::clone(&state);
        let (started_tx, started_rx) = mpsc::channel();

        // Start a background operation
        let handle = thread::spawn(move || {
            let _guard = state_clone.begin_operation().unwrap();
            started_tx.send(()).unwrap();
            thread::sleep(Duration::from_millis(50));
        });

        // Wait until the worker actually holds its guard; a fixed sleep could
        // let `shutdown()` win the race and make `begin_operation` fail.
        started_rx.recv().unwrap();

        // Shutdown and drain
        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_secs(1)));

        handle.join().unwrap();
        assert!(result.is_ok());
    }

    /// Draining reports a timeout while an operation is still in flight.
    #[test]
    fn test_drain_timeout() {
        let state = Arc::new(ShutdownState::new());
        let state_clone = Arc::clone(&state);
        let (started_tx, started_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel::<()>();

        // Start a background operation that lasts until it is released
        let handle = thread::spawn(move || {
            let _guard = state_clone.begin_operation().unwrap();
            started_tx.send(()).unwrap();
            // Returns when the main thread sends or drops `release_tx`
            // (including on unwind), so the worker never outlives the test.
            let _ = release_rx.recv();
        });

        // Wait until the worker actually holds its guard
        started_rx.recv().unwrap();

        // Shutdown with short timeout
        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_millis(50)));

        assert!(matches!(result, Err(IpcError::Timeout)));

        // Clean up: let the worker finish right away instead of sleeping
        drop(release_tx);
        handle.join().unwrap();
        assert_eq!(state.pending_count(), 0);
    }

    /// The generic wrapper exposes its inner value and the shutdown flag.
    #[test]
    fn test_graceful_wrapper() {
        let wrapper = GracefulWrapper::new(42);

        assert!(!wrapper.is_shutdown());
        assert_eq!(*wrapper.inner(), 42);

        wrapper.shutdown();

        assert!(wrapper.is_shutdown());
    }

    /// A graceful named pipe transfers data and rejects writes after shutdown.
    #[test]
    fn test_graceful_named_pipe() {
        let name = format!("test_graceful_pipe_{}", std::process::id());

        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulNamedPipe::create(&name).unwrap();
                server.wait_for_client().ok();

                let mut buf = [0u8; 32];
                let n = server.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], b"Hello!");

                // Shutdown
                server.shutdown();
                assert!(server.is_shutdown());

                // Operations after shutdown should fail
                let result = server.write(b"test");
                assert!(result.is_err());
            }
        });

        // Give the server thread time to create the pipe
        thread::sleep(Duration::from_millis(100));

        let mut client = GracefulNamedPipe::connect(&name).unwrap();
        client.write_all(b"Hello!").unwrap();

        handle.join().unwrap();
    }

    /// A graceful IPC channel transfers bytes and rejects receives after shutdown.
    #[test]
    fn test_graceful_ipc_channel() {
        let name = format!("test_graceful_channel_{}", std::process::id());

        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulIpcChannel::<Vec<u8>>::create(&name).unwrap();
                server.wait_for_client().ok();

                let data = server.recv_bytes().unwrap();
                assert_eq!(data, b"Hello, IPC!");

                // Shutdown
                server.shutdown();
                assert!(server.is_shutdown());

                // Operations after shutdown should fail
                let result = server.recv_bytes();
                assert!(matches!(result, Err(IpcError::Closed)));
            }
        });

        // Give the server thread time to create the channel
        thread::sleep(Duration::from_millis(100));

        let mut client = GracefulIpcChannel::<Vec<u8>>::connect(&name).unwrap();
        client.send_bytes(b"Hello, IPC!").unwrap();

        handle.join().unwrap();
    }
}