use crate::error::{IpcError, Result};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// A channel that supports cooperative, graceful shutdown.
///
/// The implementations in this module keep a count of in-flight operations so
/// that shutting down stops new work from starting while letting work already
/// in progress run to completion.
pub trait GracefulChannel {
    /// Requests shutdown: operations started afterwards are rejected.
    fn shutdown(&self);

    /// Returns `true` once [`shutdown`](Self::shutdown) has been requested.
    fn is_shutdown(&self) -> bool;

    /// Blocks until no tracked operation is in flight.
    ///
    /// In this module's implementations this does not request shutdown itself.
    fn drain(&self) -> Result<()>;

    /// Requests shutdown, then waits at most `timeout` for in-flight operations
    /// to finish.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()>;
}
/// Shutdown flag plus a counter of in-flight operations.
///
/// Typically shared between channel wrappers through an `Arc`.
#[derive(Debug)]
pub struct ShutdownState {
    /// Set once shutdown has been requested; never cleared.
    shutdown: AtomicBool,
    /// Number of operations currently holding an [`OperationGuard`].
    pending_count: AtomicUsize,
}
impl Default for ShutdownState {
fn default() -> Self {
Self::new()
}
}
impl ShutdownState {
    /// Creates a state that accepts operations and has none in flight.
    pub fn new() -> Self {
        let shutdown = AtomicBool::new(false);
        let pending_count = AtomicUsize::new(0);
        Self {
            shutdown,
            pending_count,
        }
    }

    /// Marks the state as shut down so later [`begin_operation`](Self::begin_operation)
    /// calls fail. Operations already in flight are not interrupted.
    pub fn shutdown(&self) {
        self.shutdown.store(true, Ordering::SeqCst);
    }

    /// Returns `true` once [`shutdown`](Self::shutdown) has been called.
    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(Ordering::SeqCst)
    }

    /// Registers a new in-flight operation.
    ///
    /// The returned guard keeps the operation counted until it is dropped.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Closed`] if shutdown has been requested.
    pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
        if self.is_shutdown() {
            return Err(IpcError::Closed);
        }
        self.pending_count.fetch_add(1, Ordering::SeqCst);
        // From here the guard owns the increment: if the re-check fails, dropping
        // the guard on return performs the matching decrement.
        let guard = OperationGuard { state: self };
        // Re-check after publishing the increment, so that a concurrent
        // `shutdown` + `wait_for_drain` either observes our increment or we
        // observe its flag (all four accesses are SeqCst).
        if self.is_shutdown() {
            return Err(IpcError::Closed);
        }
        Ok(guard)
    }

    /// Current number of in-flight operations.
    pub fn pending_count(&self) -> usize {
        self.pending_count.load(Ordering::SeqCst)
    }

    /// Polls (every 1 ms) until no operation is in flight.
    ///
    /// With `Some(timeout)`, gives up once `timeout` has elapsed since the call.
    /// With `None`, waits indefinitely. Does not request shutdown by itself.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Timeout`] if operations are still pending when the
    /// timeout expires.
    pub fn wait_for_drain(&self, timeout: Option<Duration>) -> Result<()> {
        const POLL_INTERVAL: Duration = Duration::from_millis(1);
        let started = Instant::now();
        while self.pending_count() != 0 {
            let expired = timeout.is_some_and(|limit| started.elapsed() >= limit);
            if expired {
                return Err(IpcError::Timeout);
            }
            std::thread::sleep(POLL_INTERVAL);
        }
        Ok(())
    }
}
/// Token for one in-flight operation registered with a [`ShutdownState`].
///
/// Obtained from [`ShutdownState::begin_operation`]. While it is alive the
/// operation is counted in [`ShutdownState::pending_count`]; dropping it
/// unregisters the operation. Writing `state.begin_operation()?;` without
/// binding the guard would end the operation immediately, hence `#[must_use]`.
#[must_use = "the operation is only tracked while the guard is alive; bind it (e.g. `let _guard = ...`)"]
#[derive(Debug)]
pub struct OperationGuard<'a> {
    /// State whose pending counter this guard decrements on drop.
    state: &'a ShutdownState,
}
impl Drop for OperationGuard<'_> {
    /// Unregisters the operation this guard stands for.
    fn drop(&mut self) {
        let counter = &self.state.pending_count;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
/// Pairs an arbitrary value with a (possibly shared) [`ShutdownState`].
///
/// Access to `inner` is not intercepted: callers opt in to shutdown tracking by
/// holding the guard from `begin_operation` around their work.
#[derive(Debug)]
pub struct GracefulWrapper<T> {
    /// The wrapped value.
    inner: T,
    /// Shutdown flag and in-flight counter.
    state: Arc<ShutdownState>,
}
impl<T> GracefulWrapper<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
state: Arc::new(ShutdownState::new()),
}
}
pub fn with_state(inner: T, state: Arc<ShutdownState>) -> Self {
Self { inner, state }
}
pub fn inner(&self) -> &T {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner
}
pub fn state(&self) -> Arc<ShutdownState> {
Arc::clone(&self.state)
}
pub fn into_inner(self) -> T {
self.inner
}
pub fn begin_operation(&self) -> Result<OperationGuard<'_>> {
self.state.begin_operation()
}
}
impl<T> GracefulChannel for GracefulWrapper<T> {
    fn is_shutdown(&self) -> bool {
        self.state.is_shutdown()
    }

    fn shutdown(&self) {
        self.state.shutdown();
    }

    /// Waits without a deadline for tracked operations; does not request shutdown.
    fn drain(&self) -> Result<()> {
        self.state.wait_for_drain(None)
    }

    /// Requests shutdown, then waits up to `timeout` for tracked operations.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
        let state = &self.state;
        state.shutdown();
        state.wait_for_drain(Some(timeout))
    }
}
impl<T: Clone> Clone for GracefulWrapper<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
state: Arc::clone(&self.state),
}
}
}
use crate::pipe::NamedPipe;
use std::io::{Read, Write};
/// A [`NamedPipe`] whose reads and writes are tracked by a [`ShutdownState`].
pub struct GracefulNamedPipe {
    /// The underlying pipe.
    inner: NamedPipe,
    /// Shutdown flag and in-flight counter, possibly shared.
    state: Arc<ShutdownState>,
}
impl GracefulNamedPipe {
pub fn new(pipe: NamedPipe) -> Self {
Self {
inner: pipe,
state: Arc::new(ShutdownState::new()),
}
}
pub fn with_state(pipe: NamedPipe, state: Arc<ShutdownState>) -> Self {
Self { inner: pipe, state }
}
pub fn create(name: &str) -> Result<Self> {
let pipe = NamedPipe::create(name)?;
Ok(Self::new(pipe))
}
pub fn connect(name: &str) -> Result<Self> {
let pipe = NamedPipe::connect(name)?;
Ok(Self::new(pipe))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn is_server(&self) -> bool {
self.inner.is_server()
}
pub fn wait_for_client(&mut self) -> Result<()> {
if self.state.is_shutdown() {
return Err(IpcError::Closed);
}
self.inner.wait_for_client()
}
pub fn state(&self) -> Arc<ShutdownState> {
Arc::clone(&self.state)
}
pub fn inner(&self) -> &NamedPipe {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut NamedPipe {
&mut self.inner
}
}
impl GracefulChannel for GracefulNamedPipe {
    fn is_shutdown(&self) -> bool {
        self.state.is_shutdown()
    }

    fn shutdown(&self) {
        self.state.shutdown();
    }

    /// Waits without a deadline for in-flight reads/writes; does not request shutdown.
    fn drain(&self) -> Result<()> {
        self.state.wait_for_drain(None)
    }

    /// Requests shutdown, then waits up to `timeout` for in-flight reads/writes.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
        let state = &self.state;
        state.shutdown();
        state.wait_for_drain(Some(timeout))
    }
}
impl Read for GracefulNamedPipe {
    /// Reads from the pipe while counted as an in-flight operation, so `drain`
    /// waits for the read to return.
    ///
    /// Fails with `BrokenPipe` once shutdown has been requested.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // `begin_operation` already rejects calls after shutdown.
        let Ok(_in_flight) = self.state.begin_operation() else {
            return Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "Channel is shutdown",
            ));
        };
        self.inner.read(buf)
    }
}
impl Write for GracefulNamedPipe {
    /// Writes to the pipe while counted as an in-flight operation, so `drain`
    /// waits for the write to return.
    ///
    /// Fails with `BrokenPipe` once shutdown has been requested.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        // `begin_operation` already rejects calls after shutdown.
        let Ok(_in_flight) = self.state.begin_operation() else {
            return Err(std::io::Error::new(
                std::io::ErrorKind::BrokenPipe,
                "Channel is shutdown",
            ));
        };
        self.inner.write(buf)
    }

    /// Flushes the pipe. Not gated on shutdown and not counted as an operation.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}
use crossbeam_channel as cb;
use std::thread::ThreadId;
/// Outcome reported back to a submitter: `Ok(())`, or a boxed thread-safe error.
type BoxResult = std::result::Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;
/// A type-erased unit of work queued for the affinity thread.
struct WorkItem {
    /// The job to run; its outcome is sent back through `reply`.
    func: Box<dyn FnOnce() -> BoxResult + Send>,
    /// Channel on which the pumping thread reports the job's outcome.
    reply: cb::Sender<BoxResult>,
}
/// State shared by every clone of a [`ReentrantDispatch`].
struct DispatchState {
    /// Thread whose submissions run inline; `None` until one is bound.
    affinity_thread: parking_lot::RwLock<Option<ThreadId>>,
    /// Producer side of the work queue.
    tx: cb::Sender<WorkItem>,
    /// Count of submitted items not yet taken off the queue by `pump`.
    pending: AtomicUsize,
}
/// Marshals closures onto a single "affinity" thread.
///
/// Clones share the same queue, affinity binding and pending counter. Work
/// submitted from the affinity thread runs inline; work from any other thread is
/// queued until some thread calls `pump`.
#[derive(Clone)]
pub struct ReentrantDispatch {
    /// Queue sender, affinity binding and pending counter.
    state: Arc<DispatchState>,
    /// Consumer side of the work queue, shared by all clones.
    rx: Arc<cb::Receiver<WorkItem>>,
}
impl ReentrantDispatch {
    /// Creates a dispatcher with an empty, unbounded queue and no affinity thread.
    ///
    /// Until [`bind_current_thread`](Self::bind_current_thread) is called, no thread
    /// is the affinity thread, so every submission is queued and must be executed
    /// by a call to [`pump`](Self::pump).
    pub fn new() -> Self {
        let (tx, rx) = cb::unbounded();
        Self {
            state: Arc::new(DispatchState {
                affinity_thread: parking_lot::RwLock::new(None),
                tx,
                pending: AtomicUsize::new(0),
            }),
            rx: Arc::new(rx),
        }
    }

    /// Makes the calling thread the affinity thread, replacing any previous binding.
    pub fn bind_current_thread(&self) {
        *self.state.affinity_thread.write() = Some(std::thread::current().id());
    }

    /// Returns `true` if the calling thread is the currently bound affinity thread.
    pub fn is_affinity_thread(&self) -> bool {
        self.state
            .affinity_thread
            .read()
            .is_some_and(|id| id == std::thread::current().id())
    }

    /// Runs `f` on the affinity thread and returns its result.
    ///
    /// When called from the affinity thread, `f` runs inline immediately, so work
    /// already executing there can submit more work without waiting on itself.
    /// From any other thread, `f` is queued and this call blocks until a
    /// [`pump`](Self::pump) executes it.
    ///
    /// # Errors
    ///
    /// * [`IpcError::Closed`] if the work could not be queued, or its reply
    ///   channel was dropped before a result was sent.
    /// * [`IpcError::Other`] if `f` panicked while being executed by `pump`.
    ///
    /// # Panics
    ///
    /// When run inline on the affinity thread, a panic in `f` propagates to the caller.
    pub fn submit_reentrant<F, R>(&self, f: F) -> Result<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        if self.is_affinity_thread() {
            return Ok(f());
        }
        let (reply_tx, reply_rx) = cb::bounded::<BoxResult>(1);
        // The queue is type-erased, so the typed return value travels through this
        // shared slot; `reply` only signals completion or failure.
        let slot: Arc<parking_lot::Mutex<Option<R>>> = Arc::new(parking_lot::Mutex::new(None));
        let slot_write = Arc::clone(&slot);
        let item = WorkItem {
            func: Box::new(move || {
                let result = f();
                *slot_write.lock() = Some(result);
                Ok(())
            }),
            reply: reply_tx,
        };
        // Count the item before `pump` can see it, since `pump` decrements on dequeue.
        self.state.pending.fetch_add(1, Ordering::Relaxed);
        if self.state.tx.send(item).is_err() {
            // The item never entered the queue, so `pump` will never decrement for it.
            self.state.pending.fetch_sub(1, Ordering::Relaxed);
            return Err(IpcError::Closed);
        }
        reply_rx
            .recv()
            .map_err(|_| IpcError::Closed)?
            .map_err(|e| IpcError::Other(e.to_string()))?;
        let value = slot
            .lock()
            .take()
            .expect("affinity thread must have set the value");
        Ok(value)
    }

    /// Executes queued work on the calling thread and returns how many items ran.
    ///
    /// Stops as soon as the queue is empty or `budget` has elapsed. The budget is
    /// only checked between items, so one long-running item can overrun it. A
    /// panic inside a queued item is caught and reported to that item's submitter
    /// as an error instead of unwinding through the pumping thread.
    ///
    /// NOTE(review): this does not check that the caller is the bound affinity thread.
    pub fn pump(&self, budget: Duration) -> usize {
        let start = Instant::now();
        let mut count = 0;
        while start.elapsed() < budget {
            let item = match self.rx.try_recv() {
                Ok(item) => item,
                Err(cb::TryRecvError::Empty | cb::TryRecvError::Disconnected) => break,
            };
            self.state.pending.fetch_sub(1, Ordering::Relaxed);
            let result = Self::run_catching_panic(item.func);
            // The submitter may have stopped waiting; an undeliverable reply is fine.
            let _ = item.reply.send(result);
            count += 1;
        }
        count
    }

    /// Number of submitted items not yet taken off the queue by `pump`.
    pub fn pending_count(&self) -> usize {
        self.state.pending.load(Ordering::Relaxed)
    }

    /// Runs one queued job, turning a panic into an `Err` so that a misbehaving
    /// submission cannot unwind through (and take down) the pumping thread.
    fn run_catching_panic(func: Box<dyn FnOnce() -> BoxResult + Send>) -> BoxResult {
        // AssertUnwindSafe: the job is consumed by the call, and a panicking job never
        // fills its result slot, so the submitter only ever sees the error below.
        match std::panic::catch_unwind(std::panic::AssertUnwindSafe(func)) {
            Ok(result) => result,
            Err(payload) => {
                let message = payload
                    .downcast_ref::<&str>()
                    .map(|s| (*s).to_owned())
                    .or_else(|| payload.downcast_ref::<String>().cloned())
                    .unwrap_or_else(|| String::from("non-string panic payload"));
                Err(format!("dispatched work panicked: {message}").into())
            }
        }
    }
}
impl Default for ReentrantDispatch {
fn default() -> Self {
Self::new()
}
}
use crate::channel::IpcChannel;
use serde::{de::DeserializeOwned, Serialize};
use std::marker::PhantomData;
/// An [`IpcChannel`] with shutdown tracking and an affinity-thread dispatcher.
pub struct GracefulIpcChannel<T = Vec<u8>> {
    /// The underlying channel.
    inner: IpcChannel<T>,
    /// Shutdown flag and in-flight counter, possibly shared.
    state: Arc<ShutdownState>,
    /// Dispatcher used by `submit_reentrant` / `pump_pending`.
    dispatch: ReentrantDispatch,
    /// NOTE(review): `inner` already mentions `T`, so this marker appears redundant.
    _marker: PhantomData<T>,
}
impl<T> GracefulIpcChannel<T> {
pub fn new(channel: IpcChannel<T>) -> Self {
Self {
inner: channel,
state: Arc::new(ShutdownState::new()),
dispatch: ReentrantDispatch::new(),
_marker: PhantomData,
}
}
pub fn with_state(channel: IpcChannel<T>, state: Arc<ShutdownState>) -> Self {
Self {
inner: channel,
state,
dispatch: ReentrantDispatch::new(),
_marker: PhantomData,
}
}
pub fn create(name: &str) -> Result<Self> {
let channel = IpcChannel::create(name)?;
Ok(Self::new(channel))
}
pub fn connect(name: &str) -> Result<Self> {
let channel = IpcChannel::connect(name)?;
Ok(Self::new(channel))
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn is_server(&self) -> bool {
self.inner.is_server()
}
pub fn wait_for_client(&mut self) -> Result<()> {
if self.state.is_shutdown() {
return Err(IpcError::Closed);
}
self.inner.wait_for_client()
}
pub fn state(&self) -> Arc<ShutdownState> {
Arc::clone(&self.state)
}
pub fn inner(&self) -> &IpcChannel<T> {
&self.inner
}
pub fn inner_mut(&mut self) -> &mut IpcChannel<T> {
&mut self.inner
}
pub fn bind_affinity_thread(&self) {
self.dispatch.bind_current_thread();
}
pub fn submit_reentrant<F, R>(&self, f: F) -> Result<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
if self.state.is_shutdown() {
return Err(IpcError::Closed);
}
self.dispatch.submit_reentrant(f)
}
pub fn pump_pending(&self, budget: Duration) -> usize {
self.dispatch.pump(budget)
}
}
impl<T> GracefulChannel for GracefulIpcChannel<T> {
    fn is_shutdown(&self) -> bool {
        self.state.is_shutdown()
    }

    fn shutdown(&self) {
        self.state.shutdown();
    }

    /// Waits without a deadline for in-flight sends/receives; does not request shutdown.
    fn drain(&self) -> Result<()> {
        self.state.wait_for_drain(None)
    }

    /// Requests shutdown, then waits up to `timeout` for in-flight sends/receives.
    fn shutdown_timeout(&self, timeout: Duration) -> Result<()> {
        let state = &self.state;
        state.shutdown();
        state.wait_for_drain(Some(timeout))
    }
}
impl GracefulIpcChannel<Vec<u8>> {
    /// Sends raw bytes while counted as an in-flight operation.
    ///
    /// # Errors
    ///
    /// [`IpcError::Closed`] after shutdown; otherwise whatever the channel returns.
    pub fn send_bytes(&mut self, data: &[u8]) -> Result<()> {
        // `begin_operation` itself rejects calls after shutdown with `Closed`.
        let _in_flight = self.state.begin_operation()?;
        self.inner.send_bytes(data)
    }

    /// Receives raw bytes while counted as an in-flight operation.
    ///
    /// # Errors
    ///
    /// [`IpcError::Closed`] after shutdown; otherwise whatever the channel returns.
    pub fn recv_bytes(&mut self) -> Result<Vec<u8>> {
        let _in_flight = self.state.begin_operation()?;
        self.inner.recv_bytes()
    }
}
impl<T: Serialize + DeserializeOwned> GracefulIpcChannel<T> {
    /// Sends a serialized message while counted as an in-flight operation.
    ///
    /// # Errors
    ///
    /// [`IpcError::Closed`] after shutdown; otherwise whatever the channel returns.
    pub fn send(&mut self, msg: &T) -> Result<()> {
        // `begin_operation` itself rejects calls after shutdown with `Closed`.
        let _in_flight = self.state.begin_operation()?;
        self.inner.send(msg)
    }

    /// Receives and deserializes a message while counted as an in-flight operation.
    ///
    /// # Errors
    ///
    /// [`IpcError::Closed`] after shutdown; otherwise whatever the channel returns.
    pub fn recv(&mut self) -> Result<T> {
        let _in_flight = self.state.begin_operation()?;
        self.inner.recv()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Pumps `dispatch` until `expected` items have executed or `deadline` elapses,
    /// returning the total executed.
    ///
    /// `ReentrantDispatch::pump` returns as soon as the queue is empty, so a single
    /// call after a fixed sleep races with the submitting threads; polling removes
    /// that timing dependency.
    fn pump_until(dispatch: &ReentrantDispatch, expected: usize, deadline: Duration) -> usize {
        let start = Instant::now();
        let mut processed = 0;
        while processed < expected && start.elapsed() < deadline {
            processed += dispatch.pump(Duration::from_millis(10));
            if processed < expected {
                thread::sleep(Duration::from_millis(1));
            }
        }
        processed
    }

    #[test]
    fn test_shutdown_state() {
        let state = ShutdownState::new();
        assert!(!state.is_shutdown());
        assert_eq!(state.pending_count(), 0);
        state.shutdown();
        assert!(state.is_shutdown());
    }

    #[test]
    fn test_operation_guard() {
        let state = ShutdownState::new();
        {
            let _guard = state.begin_operation().unwrap();
            assert_eq!(state.pending_count(), 1);
            {
                let _guard2 = state.begin_operation().unwrap();
                assert_eq!(state.pending_count(), 2);
            }
            assert_eq!(state.pending_count(), 1);
        }
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_operation_after_shutdown() {
        let state = ShutdownState::new();
        state.shutdown();
        let result = state.begin_operation();
        assert!(result.is_err());
    }

    #[test]
    fn test_drain() {
        let state = Arc::new(ShutdownState::new());
        let state_clone = Arc::clone(&state);
        let (started_tx, started_rx) = mpsc::channel();
        let handle = thread::spawn(move || {
            let _guard = state_clone.begin_operation().unwrap();
            // Signal only after the operation is registered, so `shutdown` below
            // cannot race ahead of `begin_operation`.
            started_tx.send(()).unwrap();
            thread::sleep(Duration::from_millis(50));
        });
        started_rx.recv().unwrap();
        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_secs(1)));
        handle.join().unwrap();
        assert!(result.is_ok());
    }

    #[test]
    fn test_drain_timeout() {
        let state = Arc::new(ShutdownState::new());
        let state_clone = Arc::clone(&state);
        let (started_tx, started_rx) = mpsc::channel();
        let (release_tx, release_rx) = mpsc::channel::<()>();
        let handle = thread::spawn(move || {
            let _guard = state_clone.begin_operation().unwrap();
            started_tx.send(()).unwrap();
            // Hold the operation open until the test has observed the timeout
            // (or the test panicked and dropped `release_tx`).
            let _ = release_rx.recv();
        });
        started_rx.recv().unwrap();
        state.shutdown();
        let result = state.wait_for_drain(Some(Duration::from_millis(50)));
        assert!(matches!(result, Err(IpcError::Timeout)));
        release_tx.send(()).unwrap();
        handle.join().unwrap();
        assert_eq!(state.pending_count(), 0);
    }

    #[test]
    fn test_graceful_wrapper() {
        let wrapper = GracefulWrapper::new(42);
        assert!(!wrapper.is_shutdown());
        assert_eq!(*wrapper.inner(), 42);
        wrapper.shutdown();
        assert!(wrapper.is_shutdown());
    }

    #[test]
    fn test_graceful_named_pipe() {
        let name = format!("test_graceful_pipe_{}", std::process::id());
        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulNamedPipe::create(&name).unwrap();
                server.wait_for_client().ok();
                let mut buf = [0u8; 32];
                let n = server.read(&mut buf).unwrap();
                assert_eq!(&buf[..n], b"Hello!");
                server.shutdown();
                assert!(server.is_shutdown());
                let result = server.write(b"test");
                assert!(result.is_err());
            }
        });
        thread::sleep(Duration::from_millis(100));
        let mut client = GracefulNamedPipe::connect(&name).unwrap();
        client.write_all(b"Hello!").unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_graceful_ipc_channel() {
        let name = format!("test_graceful_channel_{}", std::process::id());
        let handle = thread::spawn({
            let name = name.clone();
            move || {
                let mut server = GracefulIpcChannel::<Vec<u8>>::create(&name).unwrap();
                server.wait_for_client().ok();
                let data = server.recv_bytes().unwrap();
                assert_eq!(data, b"Hello, IPC!");
                server.shutdown();
                assert!(server.is_shutdown());
                let result = server.recv_bytes();
                assert!(matches!(result, Err(IpcError::Closed)));
            }
        });
        thread::sleep(Duration::from_millis(100));
        let mut client = GracefulIpcChannel::<Vec<u8>>::connect(&name).unwrap();
        client.send_bytes(b"Hello, IPC!").unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn test_reentrant_dispatch_inline_on_affinity_thread() {
        let dispatch = ReentrantDispatch::new();
        dispatch.bind_current_thread();
        let result: i32 = dispatch.submit_reentrant(|| 42).unwrap();
        assert_eq!(result, 42);
        assert_eq!(dispatch.pending_count(), 0);
    }

    #[test]
    fn test_reentrant_dispatch_cross_thread() {
        let dispatch = ReentrantDispatch::new();
        dispatch.bind_current_thread();
        let dispatch_worker = dispatch.clone();
        let handle = thread::spawn(move || {
            let result: u64 = dispatch_worker
                .submit_reentrant(|| 99_u64)
                .expect("submit failed");
            assert_eq!(result, 99);
        });
        let processed = pump_until(&dispatch, 1, Duration::from_secs(5));
        assert_eq!(processed, 1);
        handle.join().unwrap();
    }

    #[test]
    fn test_reentrant_dispatch_multiple_submissions() {
        let dispatch = ReentrantDispatch::new();
        dispatch.bind_current_thread();
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let d = dispatch.clone();
                let c = Arc::clone(&counter);
                thread::spawn(move || {
                    d.submit_reentrant(move || {
                        c.fetch_add(1, Ordering::SeqCst);
                    })
                    .unwrap();
                })
            })
            .collect();
        let processed = pump_until(&dispatch, 5, Duration::from_secs(5));
        assert_eq!(processed, 5);
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn test_graceful_channel_submit_reentrant() {
        let name = format!("test_reentrant_channel_{}", std::process::id());
        let server = GracefulIpcChannel::<Vec<u8>>::create(&name).unwrap();
        server.bind_affinity_thread();
        let result: &'static str = server.submit_reentrant(|| "hello from affinity").unwrap();
        assert_eq!(result, "hello from affinity");
    }

    #[test]
    fn test_graceful_channel_submit_reentrant_after_shutdown() {
        let name = format!("test_reentrant_shutdown_{}", std::process::id());
        let channel = GracefulIpcChannel::<Vec<u8>>::create(&name).unwrap();
        channel.shutdown();
        let result = channel.submit_reentrant(|| ());
        assert!(matches!(result, Err(IpcError::Closed)));
    }
}