use crate::error::{IpcError, Result};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// A channel that can be shut down and drained of in-flight operations.
///
/// Implementations in this module delegate to a shared [`ShutdownState`]:
/// shutting down refuses new operations, while operations that already
/// started are allowed to finish.
pub trait GracefulChannel {
    /// Stops admitting new operations. Operations already in flight are not interrupted.
    fn shutdown(&self);
    /// Returns `true` once [`shutdown`](Self::shutdown) has been called.
    fn is_shutdown(&self) -> bool;
    /// Blocks until no operations are in flight, with no timeout.
    ///
    /// NOTE(review): the implementations in this module do not shut down
    /// first, so new operations may still be admitted while draining.
    fn drain(&self) -> Result<()>;
    /// Shuts down, then waits at most `timeout` for in-flight operations to finish.
    ///
    /// The implementations in this module return `IpcError::Timeout` if
    /// operations are still pending when the timeout elapses.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()>;
}
/// Lock-free shutdown flag plus a counter of in-flight operations.
///
/// Shared (typically via `Arc`) between a channel wrapper and anyone who
/// needs to shut it down or wait for it to drain.
#[derive(Debug)]
pub struct ShutdownState {
    // Set to `true` by `shutdown()`; no code in this module ever clears it.
    shutdown: AtomicBool,
    // Number of live `OperationGuard`s (plus, transiently, a registration
    // that `begin_operation` is about to roll back).
    pending_count: AtomicUsize,
}
impl Default for ShutdownState {
fn default() -> Self {
Self::new()
}
}
impl ShutdownState {
pub fn new() -> Self {
Self {
shutdown: AtomicBool::new(false),
pending_count: AtomicUsize::new(0),
}
}
pub fn shutdown(&self) {
self.shutdown.store(true, Ordering::SeqCst);
}
pub fn is_shutdown(&self) -> bool {
self.shutdown.load(Ordering::SeqCst)
}
pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
if self.is_shutdown() {
return Err(IpcError::Closed);
}
self.pending_count.fetch_add(1, Ordering::SeqCst);
if self.is_shutdown() {
self.pending_count.fetch_sub(1, Ordering::SeqCst);
return Err(IpcError::Closed);
}
Ok(OperationGuard { state: self })
}
pub fn pending_count(&self) -> usize {
self.pending_count.load(Ordering::SeqCst)
}
pub fn wait_for_drain(&self, timeout: Option<Duration>) -> Result<()> {
let start = Instant::now();
let sleep_duration = Duration::from_millis(1);
loop {
if self.pending_count() == 0 {
return Ok(());
}
if let Some(timeout) = timeout {
if start.elapsed() >= timeout {
return Err(IpcError::Timeout);
}
}
std::thread::sleep(sleep_duration);
}
}
}
/// RAII registration of one in-flight operation on a [`ShutdownState`].
///
/// Obtained from [`ShutdownState::begin_operation`]. While the guard is
/// alive the operation counts as pending; dropping it decrements the
/// counter. Bind it to a named variable (`let _guard = ...`, not `let _ = ...`)
/// for the whole operation. A guard that is discarded immediately stops
/// tracking the operation at once, which is why the type is `#[must_use]`.
#[must_use = "the operation is only tracked while the guard is alive; bind it with `let _guard = ...`"]
#[derive(Debug)]
pub struct OperationGuard<'a> {
    // The state whose pending counter this guard holds one unit of.
    state: &'a ShutdownState,
}
impl Drop for OperationGuard<'_> {
    /// Deregisters the operation by decrementing the owning state's pending counter.
    fn drop(&mut self) {
        let counter = &self.state.pending_count;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Generic wrapper that pairs any value with a (possibly shared) [`ShutdownState`].
///
/// The wrapper does not intercept access to `inner`. Callers are expected to
/// bracket their own work with [`GracefulWrapper::begin_operation`].
#[derive(Debug)]
pub struct GracefulWrapper<T> {
    // The wrapped value.
    inner: T,
    // Shutdown flag and in-flight counter; clones of the wrapper share it.
    state: Arc<ShutdownState>,
}
impl<T> GracefulWrapper<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
state: Arc::new(ShutdownState::new()),
}
}
pub fn with_state(inner: T, state: Arc<ShutdownState>) -> Self {
Self { inner, state }
}
pub fn inner(&self) -> &T {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner
}
pub fn state(&self) -> Arc<ShutdownState> {
Arc::clone(&self.state)
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
self.state.begin_operation()
}
}
impl<T> GracefulChannel for GracefulWrapper<T> {
    /// Delegates to the shared [`ShutdownState::shutdown`].
    fn shutdown(&self) {
        let state = &self.state;
        state.shutdown();
    }

    /// Delegates to the shared [`ShutdownState::is_shutdown`].
    fn is_shutdown(&self) -> bool {
        let state = &self.state;
        state.is_shutdown()
    }

    /// Waits, without a timeout, for in-flight operations to finish.
    /// Does not shut down first.
    fn drain(&self) -> Result<()> {
        let no_limit: Option<Duration> = None;
        self.state.wait_for_drain(no_limit)
    }

    /// Stops admitting operations, then waits up to `timeout` for in-flight ones.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
        let state = &self.state;
        state.shutdown();
        state.wait_for_drain(Some(timeout))
    }
}
impl<T: Clone> Clone for GracefulWrapper<T> {
    /// Clones the wrapped value but *shares* the shutdown state, so shutting
    /// down either copy shuts down both.
    fn clone(&self) -> Self {
        let inner = T::clone(&self.inner);
        let state = Arc::clone(&self.state);
        Self::with_state(inner, state)
    }
}
use crate::pipe::NamedPipe;
use std::io::{Read, Write};
/// A [`NamedPipe`] whose reads and writes are refused after shutdown and are
/// counted as in-flight operations while they run, so a drain waits for them.
pub struct GracefulNamedPipe {
    // The underlying pipe.
    inner: NamedPipe,
    // Shutdown flag and in-flight counter; may be shared via `state()`.
    state: Arc<ShutdownState>,
}
impl GracefulNamedPipe {
pub fn new(pipe: NamedPipe) -> Self {
Self {
inner: pipe,
state: Arc::new(ShutdownState::new()),
}
}
pub fn with_state(pipe: NamedPipe, state: Arc<ShutdownState>) -> Self {
Self { inner: pipe, state }
}
pub fn create(name: &str) -> Result<Self> {
let pipe = NamedPipe::create(name)?;
Ok(Self::new(pipe))
}
pub fn connect(name: &str) -> Result<Self> {
let pipe = NamedPipe::connect(name)?;
Ok(Self::new(pipe))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn is_server(&self) -> bool {
self.inner.is_server()
}
pub fn wait_for_client(&mut self) -> Result<()> {
if self.state.is_shutdown() {
return Err(IpcError::Closed);
}
self.inner.wait_for_client()
}
pub fn state(&self) -> Arc<ShutdownState> {
Arc::clone(&self.state)
}
pub fn inner(&self) -> &NamedPipe {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut NamedPipe {
&mut self.inner
}
}
impl GracefulChannel for GracefulNamedPipe {
    /// Delegates to the shared [`ShutdownState::shutdown`].
    fn shutdown(&self) {
        let state = &self.state;
        state.shutdown();
    }

    /// Delegates to the shared [`ShutdownState::is_shutdown`].
    fn is_shutdown(&self) -> bool {
        let state = &self.state;
        state.is_shutdown()
    }

    /// Waits, without a timeout, for in-flight reads/writes to finish.
    /// Does not shut down first.
    fn drain(&self) -> Result<()> {
        let no_limit: Option<Duration> = None;
        self.state.wait_for_drain(no_limit)
    }

    /// Stops admitting reads/writes, then waits up to `timeout` for in-flight ones.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
        let state = &self.state;
        state.shutdown();
        state.wait_for_drain(Some(timeout))
    }
}
impl Read for GracefulNamedPipe {
    /// Reads from the underlying pipe while registered as an in-flight
    /// operation, so a drain on the shared state waits for the read to return.
    ///
    /// # Errors
    /// Returns `ErrorKind::BrokenPipe` ("Channel is shutdown") once the
    /// channel has been shut down; otherwise forwards the pipe's result.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // `begin_operation` already refuses after shutdown, so it doubles as
        // the shutdown check; its only failure maps to BrokenPipe.
        let registration = self.state.begin_operation();
        let _guard = match registration {
            Ok(guard) => guard,
            Err(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Channel is shutdown",
                ))
            }
        };
        self.inner.read(buf)
    }
}
impl Write for GracefulNamedPipe {
    /// Writes to the underlying pipe while registered as an in-flight
    /// operation, so a drain on the shared state waits for the write to return.
    ///
    /// # Errors
    /// Returns `ErrorKind::BrokenPipe` ("Channel is shutdown") once the
    /// channel has been shut down; otherwise forwards the pipe's result.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // `begin_operation` already refuses after shutdown, so it doubles as
        // the shutdown check; its only failure maps to BrokenPipe.
        let registration = self.state.begin_operation();
        let _guard = match registration {
            Ok(guard) => guard,
            Err(_) => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    "Channel is shutdown",
                ))
            }
        };
        self.inner.write(buf)
    }

    /// Flushes the underlying pipe. This is neither refused after shutdown
    /// nor tracked as an in-flight operation.
    fn flush(&mut self) -> std::io::Result<()> {
        let pipe = &mut self.inner;
        pipe.flush()
    }
}
use crate::channel::IpcChannel;
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
/// An [`IpcChannel`] whose sends and receives are refused after shutdown and
/// are counted as in-flight operations while they run.
///
/// `T` is the channel's message type; it defaults to raw bytes (`Vec<u8>`).
pub struct GracefulIpcChannel<T = Vec<u8>> {
    // The underlying typed channel.
    inner: IpcChannel<T>,
    // Shutdown flag and in-flight counter; may be shared via `state()`.
    state: Arc<ShutdownState>,
    // Zero-sized marker tying the wrapper to `T`; carries no data.
    _marker: PhantomData<T>,
}
impl<T> GracefulIpcChannel<T> {
pub fn new(channel: IpcChannel<T>) -> Self {
Self {
inner: channel,
state: Arc::new(ShutdownState::new()),
_marker: PhantomData,
}
}
pub fn with_state(channel: IpcChannel<T>, state: Arc<ShutdownState>) -> Self {
Self {
inner: channel,
state,
_marker: PhantomData,
}
}
pub fn create(name: &str) -> Result<Self> {
let channel = IpcChannel::create(name)?;
Ok(Self::new(channel))
}
pub fn connect(name: &str) -> Result<Self> {
let channel = IpcChannel::connect(name)?;
Ok(Self::new(channel))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn is_server(&self) -> bool {
self.inner.is_server()
}
pub fn wait_for_client(&mut self) -> Result<()> {
if self.state.is_shutdown() {
return Err(IpcError::Closed);
}
self.inner.wait_for_client()
}
pub fn state(&self) -> Arc<ShutdownState> {
Arc::clone(&self.state)
}
pub fn inner(&self) -> &IpcChannel<T> {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut IpcChannel<T> {
&mut self.inner
}
}
impl<T> GracefulChannel for GracefulIpcChannel<T> {
    /// Delegates to the shared [`ShutdownState::shutdown`].
    fn shutdown(&self) {
        let state = &self.state;
        state.shutdown();
    }

    /// Delegates to the shared [`ShutdownState::is_shutdown`].
    fn is_shutdown(&self) -> bool {
        let state = &self.state;
        state.is_shutdown()
    }

    /// Waits, without a timeout, for in-flight sends/receives to finish.
    /// Does not shut down first.
    fn drain(&self) -> Result<()> {
        let no_limit: Option<Duration> = None;
        self.state.wait_for_drain(no_limit)
    }

    /// Stops admitting sends/receives, then waits up to `timeout` for in-flight ones.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
        let state = &self.state;
        state.shutdown();
        state.wait_for_drain(Some(timeout))
    }
}
impl GracefulIpcChannel<Vec<u8>> {
    /// Sends raw bytes while registered as an in-flight operation.
    ///
    /// # Errors
    /// `IpcError::Closed` after shutdown (reported by `begin_operation`,
    /// which performs the shutdown check); otherwise the channel's own result.
    pub fn send_bytes(&mut self, data: &[u8]) -> Result<()> {
        let _op = self.state.begin_operation()?;
        self.inner.send_bytes(data)
    }

    /// Receives raw bytes while registered as an in-flight operation.
    ///
    /// # Errors
    /// `IpcError::Closed` after shutdown (reported by `begin_operation`,
    /// which performs the shutdown check); otherwise the channel's own result.
    pub fn recv_bytes(&mut self) -> Result<Vec<u8>> {
        let _op = self.state.begin_operation()?;
        self.inner.recv_bytes()
    }
}
impl<T: Serialize + DeserializeOwned> GracefulIpcChannel<T> {
    /// Sends one typed message while registered as an in-flight operation.
    ///
    /// # Errors
    /// `IpcError::Closed` after shutdown (reported by `begin_operation`,
    /// which performs the shutdown check); otherwise the channel's own result.
    pub fn send(&mut self, msg: &T) -> Result<()> {
        let _op = self.state.begin_operation()?;
        self.inner.send(msg)
    }

    /// Receives one typed message while registered as an in-flight operation.
    ///
    /// # Errors
    /// `IpcError::Closed` after shutdown (reported by `begin_operation`,
    /// which performs the shutdown check); otherwise the channel's own result.
    pub fn recv(&mut self) -> Result<T> {
        let _op = self.state.begin_operation()?;
        self.inner.recv()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_shutdown_state() {
        let state = ShutdownState::new();
        assert!(!state.is_shutdown());
        assert_eq!(state.pending_count(), 0);
        state.shutdown();
        assert!(state.is_shutdown());
    }

    #[test]
    fn test_operation_guard() {
        let state = ShutdownState::new();
        {
            let _guard = state.begin_operation().unwrap();
            assert_eq!(state.pending_count(), 1);
            {
                let _guard2 = state.begin_operation().unwrap();
                assert_eq!(state.pending_count(), 2);
            }
            assert_eq!(state.pending_count(), 1);
        }
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_operation_after_shutdown() {
        let state = ShutdownState::new();
        state.shutdown();
        assert!(matches!(state.begin_operation(), Err(IpcError::Closed)));
        // A refused registration must not leave the counter incremented.
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_drain() {
        let state = Arc::new(ShutdownState::new());
        let (started_tx, started_rx) = mpsc::channel();
        let worker = {
            let state = Arc::clone(&state);
            thread::spawn(move || {
                let _guard = state.begin_operation().unwrap();
                started_tx.send(()).unwrap();
                thread::sleep(Duration::from_millis(50));
            })
        };
        // Shut down only after the worker's operation is registered; a fixed
        // sleep here would race with thread startup and could make the
        // worker's `begin_operation` fail.
        started_rx.recv().unwrap();
        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_secs(1)));
        worker.join().unwrap();
        assert!(result.is_ok());
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_drain_timeout() {
        let state = Arc::new(ShutdownState::new());
        let (started_tx, started_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let worker = {
            let state = Arc::clone(&state);
            thread::spawn(move || {
                let _guard = state.begin_operation().unwrap();
                started_tx.send(()).unwrap();
                // Hold the operation open until released. If the main thread
                // panics, the sender is dropped and `recv` returns `Err`, so
                // this never hangs.
                let _ = release_rx.recv();
            })
        };
        started_rx.recv().unwrap();
        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_millis(50)));
        assert!(matches!(result, Err(IpcError::Timeout)));
        // Release the worker instead of waiting out a long sleep.
        release_tx.send(()).unwrap();
        worker.join().unwrap();
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_graceful_wrapper() {
        let wrapper = GracefulWrapper::new(42);
        assert!(!wrapper.is_shutdown());
        assert_eq!(*wrapper.inner(), 42);
        wrapper.shutdown();
        assert!(wrapper.is_shutdown());
    }

    #[test]
    fn test_graceful_named_pipe() {
        let name = format!("test_graceful_pipe_{}", std::process::id());
        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulNamedPipe::create(&name).unwrap();
                server.wait_for_client().ok();
                let mut buf = [0u8; 32];
                let n = server.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], b"Hello!");
                server.shutdown();
                assert!(server.is_shutdown());
                let result = server.write(b"test");
                assert!(result.is_err());
            }
        });
        thread::sleep(Duration::from_millis(100));
        let mut client = GracefulNamedPipe::connect(&name).unwrap();
        client.write_all(b"Hello!").unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_graceful_ipc_channel() {
        let name = format!("test_graceful_channel_{}", std::process::id());
        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulIpcChannel::<Vec<u8>>::create(&name).unwrap();
                server.wait_for_client().ok();
                let data = server.recv_bytes().unwrap();
                assert_eq!(data, b"Hello, IPC!");
                server.shutdown();
                assert!(server.is_shutdown());
                let result = server.recv_bytes();
                assert!(matches!(result, Err(IpcError::Closed)));
            }
        });
        thread::sleep(Duration::from_millis(100));
        let mut client = GracefulIpcChannel::<Vec<u8>>::connect(&name).unwrap();
        client.send_bytes(b"Hello, IPC!").unwrap();
        handle.join().unwrap();
    }
}