ipcez 0.1.0

Rust library for ipcez.
Documentation
//! `AsyncWrite` implementations for the socket transport types (WebSocket, Unix socket, named pipe).
//! Also provides message sending: write and flush one message, signal "data ready", and, on
//! local transports, wait for the recipient's ack.

use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::Sink;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;

use crate::event_handler::{self, AckWaitError};
use crate::event_sender;
use crate::socket::{
    connection_lost_error, InnerSocket, MessageFramed, SocketError, WebSocketAdapter,
    MAX_MESSAGE_LEN,
};
#[cfg(unix)]
use crate::socket::PolledUnixStream;
#[cfg(windows)]
use crate::socket::PolledNamedPipe;

/// Implements tokio's `AsyncWrite` for a "polled" transport wrapper `$ty`.
///
/// The expansion reads the `disconnected`, `last_check`, `interval` and `inner` fields of `$ty`
/// and calls its `run_liveness_check(cx)` method. For `poll_write` and `poll_flush`:
/// 1. once `disconnected` is set, fail with `connection_lost_error()`;
/// 2. if at least `interval` has elapsed since `last_check`, poll the liveness check first and
///    propagate its error or `Pending`;
/// 3. otherwise (or after a successful check) forward the call to `inner`.
///
/// `poll_shutdown` skips the liveness check: it fails when disconnected and otherwise forwards.
macro_rules! impl_polled_async_write {
    ($ty:ty) => {
        impl AsyncWrite for $ty {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                let check_due = self.last_check.elapsed() >= self.interval;
                if check_due {
                    match self.run_liveness_check(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                        Poll::Ready(Ok(())) => (),
                    }
                }
                let inner = Pin::new(&mut self.inner);
                tokio::io::AsyncWrite::poll_write(inner, cx, buf)
            }

            fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                let check_due = self.last_check.elapsed() >= self.interval;
                if check_due {
                    match self.run_liveness_check(cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                        Poll::Ready(Ok(())) => (),
                    }
                }
                let inner = Pin::new(&mut self.inner);
                tokio::io::AsyncWrite::poll_flush(inner, cx)
            }

            fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
                if self.disconnected {
                    Poll::Ready(Err(connection_lost_error()))
                } else {
                    let inner = Pin::new(&mut self.inner);
                    tokio::io::AsyncWrite::poll_shutdown(inner, cx)
                }
            }
        }
    };
}

// Named-pipe transport wrapper: only built on Windows.
#[cfg(windows)]
impl_polled_async_write!(PolledNamedPipe);

// Unix-domain-socket transport wrapper: only built on Unix.
#[cfg(unix)]
impl_polled_async_write!(PolledUnixStream);

/// `AsyncWrite` for the length-prefixed framing used by the local (Unix socket / named pipe)
/// transports.
///
/// `poll_write` only appends to `write_buf`; `poll_flush` emits everything buffered as a single
/// frame (a 4-byte big-endian payload length followed by the payload) and then flushes `inner`.
#[cfg(any(unix, windows))]
impl<T> AsyncWrite for MessageFramed<T>
where
    T: tokio::io::AsyncRead + AsyncWrite + Unpin,
{
    /// Buffers `buf` as part of the next frame. Always ready; always accepts all of `buf`.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.write_buf.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// Writes the buffered bytes as one frame, then flushes `inner`.
    ///
    /// With an empty buffer `inner` is still flushed. If `inner` left its flush `Pending` right
    /// after a frame was written, the next call completes that flush instead of reporting
    /// success early.
    ///
    /// # Errors
    /// - `InvalidInput` if the buffered payload is longer than `MAX_MESSAGE_LEN`;
    /// - `WriteZero` if `inner` accepts zero bytes of the header or payload;
    /// - any error returned by `inner`.
    ///
    /// Whenever the frame is not fully written (error or `Pending`), the payload is left in
    /// `write_buf` so that a later flush retries it.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Polls `inner` until every byte of `bytes` is written; `what` is the `WriteZero` message.
        fn poll_write_all<W: AsyncWrite>(
            mut inner: Pin<&mut W>,
            cx: &mut Context<'_>,
            bytes: &[u8],
            what: &'static str,
        ) -> Poll<io::Result<()>> {
            let mut written = 0usize;
            while written < bytes.len() {
                match inner.as_mut().poll_write(cx, &bytes[written..]) {
                    Poll::Ready(Ok(0)) => {
                        return Poll::Ready(Err(io::Error::new(io::ErrorKind::WriteZero, what)));
                    }
                    Poll::Ready(Ok(n)) => written += n,
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    Poll::Pending => return Poll::Pending,
                }
            }
            Poll::Ready(Ok(()))
        }

        let this = self.get_mut();
        if this.write_buf.is_empty() {
            // Nothing to frame, but `inner` may still hold unflushed data from an earlier frame.
            return Pin::new(&mut this.inner).poll_flush(cx);
        }
        let payload_len = this.write_buf.len();
        if payload_len > MAX_MESSAGE_LEN as usize {
            // The oversized payload is left in `write_buf`.
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("message length {} exceeds max {}", payload_len, MAX_MESSAGE_LEN),
            )));
        }
        let payload = std::mem::take(&mut this.write_buf);
        // Lossless as long as MAX_MESSAGE_LEN itself fits in a u32 (the wire header is 4 bytes).
        let header = (payload_len as u32).to_be_bytes();
        let mut inner = Pin::new(&mut this.inner);
        let frame = match poll_write_all(
            inner.as_mut(),
            cx,
            &header,
            "failed to write frame length",
        ) {
            Poll::Ready(Ok(())) => poll_write_all(
                inner.as_mut(),
                cx,
                &payload,
                "failed to write frame payload",
            ),
            unfinished => unfinished,
        };
        match frame {
            Poll::Ready(Ok(())) => inner.as_mut().poll_flush(cx),
            unfinished => {
                // NOTE(review): if `inner` had already accepted part of the header or payload,
                // the retry re-sends the whole frame from its header and the peer will see a
                // corrupt stream. Resuming correctly needs a write-offset field on
                // `MessageFramed` (declared in `crate::socket`) - TODO.
                this.write_buf = payload;
                unfinished
            }
        }
    }

    /// Shuts down `inner` directly; bytes still in `write_buf` are not flushed first.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
    }
}

/// `AsyncWrite` over a WebSocket sink: written bytes are buffered in `write_buf` and sent as a
/// single `Message::Binary` when flushed.
impl AsyncWrite for WebSocketAdapter {
    /// Buffers `buf` for the next binary message. Always ready; always accepts all of `buf`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.write_buf.extend_from_slice(buf);
        Poll::Ready(Ok(buf.len()))
    }

    /// Sends the buffered bytes (if any) as one binary message, then flushes the sink.
    ///
    /// The sink is flushed even when nothing is buffered. If the sink's flush was left `Pending`
    /// after `start_send`, the next call keeps driving it instead of reporting success with the
    /// message still queued in the sink.
    ///
    /// The buffer is only taken once the sink reports ready. A `Pending` or an error from
    /// `poll_ready` therefore leaves the bytes in `write_buf`.
    ///
    /// # Errors
    /// Sink errors from `poll_ready`, `start_send` or `poll_flush`, wrapped as
    /// `io::ErrorKind::Other`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if !self.write_buf.is_empty() {
            match self.stream.as_mut().poll_ready(cx) {
                Poll::Ready(Ok(())) => {}
                Poll::Ready(Err(e)) => {
                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
                }
                Poll::Pending => return Poll::Pending,
            }
            // The sink is ready: hand the whole buffer over as one binary message.
            let data = std::mem::take(&mut self.write_buf);
            if let Err(e) = self.stream.as_mut().start_send(Message::Binary(data.into())) {
                return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)));
            }
        }
        self.stream
            .as_mut()
            .poll_flush(cx)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
    }

    /// Closes the underlying sink; bytes still in `write_buf` are not sent first.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.stream
            .as_mut()
            .poll_close(cx)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
    }
}

impl AsyncWrite for InnerSocket {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match self.get_mut() {
            InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
            InnerSocket::WebSocket(s) => Pin::new(s).poll_write(cx, buf),
            #[cfg(unix)]
            InnerSocket::Unix(s, _) => tokio::io::AsyncWrite::poll_write(Pin::new(s), cx, buf),
            #[cfg(windows)]
            InnerSocket::NamedPipe(s, _) => tokio::io::AsyncWrite::poll_write(Pin::new(s), cx, buf),
        }
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
            InnerSocket::WebSocket(s) => Pin::new(s).poll_flush(cx),
            #[cfg(unix)]
            InnerSocket::Unix(s, _) => tokio::io::AsyncWrite::poll_flush(Pin::new(s), cx),
            #[cfg(windows)]
            InnerSocket::NamedPipe(s, _) => tokio::io::AsyncWrite::poll_flush(Pin::new(s), cx),
        }
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match self.get_mut() {
            InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
            InnerSocket::WebSocket(s) => Pin::new(s).poll_shutdown(cx),
            #[cfg(unix)]
            InnerSocket::Unix(s, _) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(s), cx),
            #[cfg(windows)]
            InnerSocket::NamedPipe(s, _) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(s), cx),
        }
    }
}

/// Writes all of `msg` to `inner`, then flushes it.
///
/// An `io::Error` from either step is returned as `SocketError::Io`. If the write fails, the
/// flush is not attempted.
pub(crate) async fn write_message(inner: &mut InnerSocket, msg: &[u8]) -> Result<(), SocketError> {
    if let Err(err) = AsyncWriteExt::write_all(inner, msg).await {
        return Err(SocketError::Io(err));
    }
    AsyncWriteExt::flush(inner).await.map_err(SocketError::Io)
}

const ACK_TIMEOUT_MS: u32 = 5000;

/// Sends one message: validates length, writes and flushes, signals "data ready", then (local only) waits up to 5s for recipient "data acked".
pub(crate) async fn send_message(
    inner: &Arc<Mutex<InnerSocket>>,
    msg: &[u8],
) -> Result<(), SocketError> {
    if msg.len() > MAX_MESSAGE_LEN as usize {
        let e = SocketError::Io(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "message length {} exceeds max {}",
                msg.len(),
                MAX_MESSAGE_LEN
            ),
        ));
        crate::logger::log_error(&e);
        return Err(e);
    }
    let mut guard = inner.lock().await;
    if matches!(*guard, InnerSocket::Closed) {
        let e = SocketError::Io(connection_lost_error());
        crate::logger::log_error(&e);
        return Err(e);
    }
    if let Err(e) = write_message(&mut *guard, msg).await {
        crate::logger::log_error(&e);
        return Err(e);
    }
    let (data_ready_name, ack_name) = match &*guard {
        #[cfg(windows)]
        InnerSocket::NamedPipe(_, name) => (
            Some(name.clone()),
            Some(event_sender::data_acked_name_from_data_ready(name)),
        ),
        #[cfg(unix)]
        InnerSocket::Unix(_, name) => (
            Some(name.clone()),
            Some(event_sender::data_acked_name_from_data_ready(name)),
        ),
        _ => (None, None),
    };
    drop(guard);
    if let (Some(dr), Some(ack)) = (data_ready_name, ack_name) {
        event_sender::signal_named_event(&dr);
        match tokio::task::spawn_blocking(move || event_handler::wait_for_ack(&ack, ACK_TIMEOUT_MS)).await {
            Ok(Ok(())) => {}
            Ok(Err(AckWaitError::Timeout)) | Ok(Err(AckWaitError::CreateOpenFailed)) | Err(_) => {
                let e = SocketError::RecipientAckTimeout;
                crate::logger::log_error(&e);
                return Err(e);
            }
        }
    }
    Ok(())
}