1use crate::error::{IpcError, Result};
28use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
29use std::sync::Arc;
30use std::time::{Duration, Instant};
31
32pub trait GracefulChannel {
37 fn shutdown(&self);
44
45 fn is_shutdown(&self) -> bool;
47
48 fn drain(&self) -> Result<()>;
53
54 fn shutdown_timeout(&self, timeout: Duration) -> Result<()>;
59}
60
/// Shutdown flag plus a counter of in-flight operations.
///
/// Both fields are atomics, so one instance can be shared between threads behind an `Arc`.
#[derive(Debug)]
pub struct ShutdownState {
    /// Becomes `true` once shutdown is requested; this module only ever stores `true` into it.
    shutdown: AtomicBool,
    /// Number of operations currently registered through `begin_operation`.
    pending_count: AtomicUsize,
}
69
70impl Default for ShutdownState {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl ShutdownState {
77 pub fn new() -> Self {
79 Self {
80 shutdown: AtomicBool::new(false),
81 pending_count: AtomicUsize::new(0),
82 }
83 }
84
85 pub fn shutdown(&self) {
87 self.shutdown.store(true, Ordering::SeqCst);
88 }
89
90 pub fn is_shutdown(&self) -> bool {
92 self.shutdown.load(Ordering::SeqCst)
93 }
94
95 pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
97 if self.is_shutdown() {
98 return Err(IpcError::Closed);
99 }
100 self.pending_count.fetch_add(1, Ordering::SeqCst);
101
102 if self.is_shutdown() {
104 self.pending_count.fetch_sub(1, Ordering::SeqCst);
105 return Err(IpcError::Closed);
106 }
107
108 Ok(OperationGuard { state: self })
109 }
110
111 pub fn pending_count(&self) -> usize {
113 self.pending_count.load(Ordering::SeqCst)
114 }
115
116 pub fn wait_for_drain(&self, timeout: Option<Duration>) -> Result<()> {
118 let start = Instant::now();
119 let sleep_duration = Duration::from_millis(1);
120
121 loop {
122 if self.pending_count() == 0 {
123 return Ok(());
124 }
125
126 if let Some(timeout) = timeout {
127 if start.elapsed() >= timeout {
128 return Err(IpcError::Timeout);
129 }
130 }
131
132 std::thread::sleep(sleep_duration);
133 }
134 }
135}
136
137pub struct OperationGuard<'a> {
139 state: &'a ShutdownState,
140}
141
142impl Drop for OperationGuard<'_> {
143 fn drop(&mut self) {
144 self.state.pending_count.fetch_sub(1, Ordering::SeqCst);
145 }
146}
147
148#[derive(Debug)]
150pub struct GracefulWrapper<T> {
151 inner: T,
152 state: Arc<ShutdownState>,
153}
154
155impl<T> GracefulWrapper<T> {
156 pub fn new(inner: T) -> Self {
158 Self {
159 inner,
160 state: Arc::new(ShutdownState::new()),
161 }
162 }
163
164 pub fn with_state(inner: T, state: Arc<ShutdownState>) -> Self {
166 Self { inner, state }
167 }
168
169 pub fn inner(&self) -> &T {
171 &self.inner
172 }
173
174 pub fn inner_mut(&mut self) -> &mut T {
176 &mut self.inner
177 }
178
179 pub fn state(&self) -> Arc<ShutdownState> {
181 Arc::clone(&self.state)
182 }
183
184 pub fn into_inner(self) -> T {
186 self.inner
187 }
188
189 pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
191 self.state.begin_operation()
192 }
193}
194
195impl<T> GracefulChannel for GracefulWrapper<T> {
196 fn shutdown(&self) {
197 self.state.shutdown();
198 }
199
200 fn is_shutdown(&self) -> bool {
201 self.state.is_shutdown()
202 }
203
204 fn drain(&self) -> Result<()> {
205 self.state.wait_for_drain(None)
206 }
207
208 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
209 self.shutdown();
210 self.state.wait_for_drain(Some(timeout))
211 }
212}
213
214impl<T: Clone> Clone for GracefulWrapper<T> {
215 fn clone(&self) -> Self {
216 Self {
217 inner: self.inner.clone(),
218 state: Arc::clone(&self.state),
219 }
220 }
221}
222
223use crate::pipe::NamedPipe;
228use std::io::{Read, Write};
229
230pub struct GracefulNamedPipe {
232 inner: NamedPipe,
233 state: Arc<ShutdownState>,
234}
235
236impl GracefulNamedPipe {
237 pub fn new(pipe: NamedPipe) -> Self {
239 Self {
240 inner: pipe,
241 state: Arc::new(ShutdownState::new()),
242 }
243 }
244
245 pub fn with_state(pipe: NamedPipe, state: Arc<ShutdownState>) -> Self {
247 Self { inner: pipe, state }
248 }
249
250 pub fn create(name: &str) -> Result<Self> {
252 let pipe = NamedPipe::create(name)?;
253 Ok(Self::new(pipe))
254 }
255
256 pub fn connect(name: &str) -> Result<Self> {
258 let pipe = NamedPipe::connect(name)?;
259 Ok(Self::new(pipe))
260 }
261
262 pub fn name(&self) -> &str {
264 self.inner.name()
265 }
266
267 pub fn is_server(&self) -> bool {
269 self.inner.is_server()
270 }
271
272 pub fn wait_for_client(&mut self) -> Result<()> {
274 if self.state.is_shutdown() {
275 return Err(IpcError::Closed);
276 }
277 self.inner.wait_for_client()
278 }
279
280 pub fn state(&self) -> Arc<ShutdownState> {
282 Arc::clone(&self.state)
283 }
284
285 pub fn inner(&self) -> &NamedPipe {
287 &self.inner
288 }
289
290 pub fn inner_mut(&mut self) -> &mut NamedPipe {
292 &mut self.inner
293 }
294}
295
296impl GracefulChannel for GracefulNamedPipe {
297 fn shutdown(&self) {
298 self.state.shutdown();
299 }
300
301 fn is_shutdown(&self) -> bool {
302 self.state.is_shutdown()
303 }
304
305 fn drain(&self) -> Result<()> {
306 self.state.wait_for_drain(None)
307 }
308
309 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
310 self.shutdown();
311 self.state.wait_for_drain(Some(timeout))
312 }
313}
314
315impl Read for GracefulNamedPipe {
316 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
317 if self.state.is_shutdown() {
318 return Err(std::io::Error::new(
319 std::io::ErrorKind::BrokenPipe,
320 "Channel is shutdown",
321 ));
322 }
323
324 let _guard = self.state.begin_operation().map_err(|_| {
325 std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel is shutdown")
326 })?;
327
328 self.inner.read(buf)
329 }
330}
331
332impl Write for GracefulNamedPipe {
333 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
334 if self.state.is_shutdown() {
335 return Err(std::io::Error::new(
336 std::io::ErrorKind::BrokenPipe,
337 "Channel is shutdown",
338 ));
339 }
340
341 let _guard = self.state.begin_operation().map_err(|_| {
342 std::io::Error::new(std::io::ErrorKind::BrokenPipe, "Channel is shutdown")
343 })?;
344
345 self.inner.write(buf)
346 }
347
348 fn flush(&mut self) -> std::io::Result<()> {
349 self.inner.flush()
350 }
351}
352
353use crate::channel::IpcChannel;
358use serde::{de::DeserializeOwned, Serialize};
359use std::marker::PhantomData;
360
361pub struct GracefulIpcChannel<T = Vec<u8>> {
363 inner: IpcChannel<T>,
364 state: Arc<ShutdownState>,
365 _marker: PhantomData<T>,
366}
367
368impl<T> GracefulIpcChannel<T> {
369 pub fn new(channel: IpcChannel<T>) -> Self {
371 Self {
372 inner: channel,
373 state: Arc::new(ShutdownState::new()),
374 _marker: PhantomData,
375 }
376 }
377
378 pub fn with_state(channel: IpcChannel<T>, state: Arc<ShutdownState>) -> Self {
380 Self {
381 inner: channel,
382 state,
383 _marker: PhantomData,
384 }
385 }
386
387 pub fn create(name: &str) -> Result<Self> {
389 let channel = IpcChannel::create(name)?;
390 Ok(Self::new(channel))
391 }
392
393 pub fn connect(name: &str) -> Result<Self> {
395 let channel = IpcChannel::connect(name)?;
396 Ok(Self::new(channel))
397 }
398
399 pub fn name(&self) -> &str {
401 self.inner.name()
402 }
403
404 pub fn is_server(&self) -> bool {
406 self.inner.is_server()
407 }
408
409 pub fn wait_for_client(&mut self) -> Result<()> {
411 if self.state.is_shutdown() {
412 return Err(IpcError::Closed);
413 }
414 self.inner.wait_for_client()
415 }
416
417 pub fn state(&self) -> Arc<ShutdownState> {
419 Arc::clone(&self.state)
420 }
421
422 pub fn inner(&self) -> &IpcChannel<T> {
424 &self.inner
425 }
426
427 pub fn inner_mut(&mut self) -> &mut IpcChannel<T> {
429 &mut self.inner
430 }
431}
432
433impl<T> GracefulChannel for GracefulIpcChannel<T> {
434 fn shutdown(&self) {
435 self.state.shutdown();
436 }
437
438 fn is_shutdown(&self) -> bool {
439 self.state.is_shutdown()
440 }
441
442 fn drain(&self) -> Result<()> {
443 self.state.wait_for_drain(None)
444 }
445
446 fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
447 self.shutdown();
448 self.state.wait_for_drain(Some(timeout))
449 }
450}
451
452impl GracefulIpcChannel<Vec<u8>> {
453 pub fn send_bytes(&mut self, data: &[u8]) -> Result<()> {
455 if self.state.is_shutdown() {
456 return Err(IpcError::Closed);
457 }
458
459 let _guard = self.state.begin_operation()?;
460 self.inner.send_bytes(data)
461 }
462
463 pub fn recv_bytes(&mut self) -> Result<Vec<u8>> {
465 if self.state.is_shutdown() {
466 return Err(IpcError::Closed);
467 }
468
469 let _guard = self.state.begin_operation()?;
470 self.inner.recv_bytes()
471 }
472}
473
474impl<T: Serialize + DeserializeOwned> GracefulIpcChannel<T> {
475 pub fn send(&mut self, msg: &T) -> Result<()> {
477 if self.state.is_shutdown() {
478 return Err(IpcError::Closed);
479 }
480
481 let _guard = self.state.begin_operation()?;
482 self.inner.send(msg)
483 }
484
485 pub fn recv(&mut self) -> Result<T> {
487 if self.state.is_shutdown() {
488 return Err(IpcError::Closed);
489 }
490
491 let _guard = self.state.begin_operation()?;
492 self.inner.recv()
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_shutdown_state() {
        let state = ShutdownState::new();

        assert!(!state.is_shutdown());
        assert_eq!(state.pending_count(), 0);

        state.shutdown();

        assert!(state.is_shutdown());
    }

    #[test]
    fn test_operation_guard() {
        let state = ShutdownState::new();

        {
            let _guard = state.begin_operation().unwrap();
            assert_eq!(state.pending_count(), 1);

            {
                let _guard2 = state.begin_operation().unwrap();
                assert_eq!(state.pending_count(), 2);
            }

            assert_eq!(state.pending_count(), 1);
        }

        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_operation_after_shutdown() {
        let state = ShutdownState::new();

        state.shutdown();

        let result = state.begin_operation();
        assert!(matches!(result, Err(IpcError::Closed)));
        // A rejected attempt must not leave a stray increment behind.
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_drain() {
        let state = Arc::new(ShutdownState::new());
        let worker_state = Arc::clone(&state);
        let (started_tx, started_rx) = mpsc::channel();

        let handle = thread::spawn(move || {
            let _guard = worker_state.begin_operation().unwrap();
            started_tx.send(()).unwrap();
            thread::sleep(Duration::from_millis(50));
        });

        // Wait until the worker really holds its guard; a fixed sleep here would let `shutdown`
        // race ahead of `begin_operation` on a loaded machine and make the worker panic.
        started_rx.recv().unwrap();

        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_secs(1)));

        handle.join().unwrap();
        assert!(result.is_ok());
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_drain_timeout() {
        let state = Arc::new(ShutdownState::new());
        let worker_state = Arc::clone(&state);
        let (started_tx, started_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel::<()>();

        let handle = thread::spawn(move || {
            let _guard = worker_state.begin_operation().unwrap();
            started_tx.send(()).unwrap();
            // Hold the guard until told to stop. If the main thread panics, `release_tx` is
            // dropped and `recv` returns `Err`, which also releases the guard.
            let _ = release_rx.recv();
        });

        started_rx.recv().unwrap();

        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_millis(50)));
        assert!(matches!(result, Err(IpcError::Timeout)));

        // Release the worker promptly instead of joining a long sleep.
        release_tx.send(()).unwrap();
        handle.join().unwrap();
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_graceful_wrapper() {
        let wrapper = GracefulWrapper::new(42);

        assert!(!wrapper.is_shutdown());
        assert_eq!(*wrapper.inner(), 42);

        wrapper.shutdown();

        assert!(wrapper.is_shutdown());
    }

    #[test]
    fn test_graceful_named_pipe() {
        let name = format!("test_graceful_pipe_{}", std::process::id());

        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulNamedPipe::create(&name).unwrap();
                server.wait_for_client().ok();

                let mut buf = [0u8; 32];
                let n = server.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], b"Hello!");

                server.shutdown();
                assert!(server.is_shutdown());

                let result = server.write(b"test");
                assert!(result.is_err());
            }
        });

        thread::sleep(Duration::from_millis(100));

        let mut client = GracefulNamedPipe::connect(&name).unwrap();
        client.write_all(b"Hello!").unwrap();

        handle.join().unwrap();
    }

    #[test]
    fn test_graceful_ipc_channel() {
        let name = format!("test_graceful_channel_{}", std::process::id());

        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulIpcChannel::<Vec<u8>>::create(&name).unwrap();
                server.wait_for_client().ok();

                let data = server.recv_bytes().unwrap();
                assert_eq!(data, b"Hello, IPC!");

                server.shutdown();
                assert!(server.is_shutdown());

                let result = server.recv_bytes();
                assert!(matches!(result, Err(IpcError::Closed)));
            }
        });

        thread::sleep(Duration::from_millis(100));

        let mut client = GracefulIpcChannel::<Vec<u8>>::connect(&name).unwrap();
        client.send_bytes(b"Hello, IPC!").unwrap();

        handle.join().unwrap();
    }
}