ipcez 0.1.0

Rust library for ipcez.
Documentation
//! AsyncRead implementations for socket transport types (WebSocket, Unix, named pipe).
//! Also drives the incoming message loop and signals "data acked" after each successful read (local transports).

use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::Stream;
use futures_util::StreamExt;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;

use crate::event_handler;
use crate::event_sender;
use crate::socket::{
    connection_lost_error, InnerSocket, MessageFramed, SocketError, WebSocketAdapter,
    MAX_MESSAGE_LEN,
};
#[cfg(unix)]
use crate::socket::PolledUnixStream;
#[cfg(windows)]
use crate::socket::PolledNamedPipe;

/// Generates an `AsyncRead` impl for a "polled" transport wrapper.
///
/// The generated `poll_read` proceeds in this order:
/// 1. fails with `connection_lost_error()` once `disconnected` is set;
/// 2. completes immediately when the caller's buffer has no room left;
/// 3. hands out a byte previously stashed in `peek`, on its own;
/// 4. if `interval` has elapsed since `last_check`, calls `run_liveness_check(cx)`
///    first and returns its result unless it is `Poll::Ready(Ok(()))`;
/// 5. otherwise delegates to the wrapped `inner` reader.
///
/// The target type must provide the fields `disconnected`, `peek`, `last_check`,
/// `interval` and `inner`, and a `run_liveness_check(cx)` method whose result is a
/// `Poll<io::Result<()>>`.
macro_rules! impl_polled_async_read {
    ($target:ty) => {
        impl AsyncRead for $target {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                // A transport already marked dead never yields more data.
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                // Zero-capacity read: trivially complete.
                if buf.remaining() == 0 {
                    return Poll::Ready(Ok(()));
                }
                // A byte consumed by an earlier peek is delivered before anything new.
                if let Some(peeked) = self.peek.take() {
                    buf.put_slice(&[peeked]);
                    return Poll::Ready(Ok(()));
                }
                let check_due = self.last_check.elapsed() >= self.interval;
                if check_due {
                    // Only a successful check lets the read continue; an error or
                    // `Pending` is handed back to the caller unchanged.
                    match self.run_liveness_check(cx) {
                        Poll::Ready(Ok(())) => {}
                        other => return other,
                    }
                }
                tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf)
            }
        }
    };
}

#[cfg(unix)]
impl_polled_async_read!(PolledUnixStream);

#[cfg(windows)]
impl_polled_async_read!(PolledNamedPipe);

#[cfg(any(unix, windows))]
impl<T> AsyncRead for MessageFramed<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        let this = self.as_mut().get_mut();
        let mut inner = Pin::new(&mut this.inner);

        loop {
            if let Some(ref mut read_buf) = this.read_buf {
                let len: usize = read_buf.len();
                if this.payload_filled == len && this.read_cursor < len {
                    let to_copy = (len - this.read_cursor).min(buf.remaining());
                    buf.put_slice(&read_buf[this.read_cursor..this.read_cursor + to_copy]);
                    this.read_cursor += to_copy;
                    if this.read_cursor >= len {
                        this.read_buf = None;
                        this.read_cursor = 0;
                        this.payload_filled = 0;
                    }
                    return Poll::Ready(Ok(()));
                }
                if this.read_cursor == len {
                    this.read_buf = None;
                    this.read_cursor = 0;
                    this.payload_filled = 0;
                    continue;
                }
                if this.payload_filled < len {
                    let mut payload_read_buf = ReadBuf::new(&mut read_buf[this.payload_filled..]);
                    match inner.as_mut().poll_read(cx, &mut payload_read_buf) {
                        Poll::Ready(Ok(())) => {
                            let n = payload_read_buf.filled().len();
                            if n == 0 {
                                return Poll::Ready(Err(io::Error::new(
                                    io::ErrorKind::UnexpectedEof,
                                    "connection closed while reading frame payload",
                                )));
                            }
                            this.payload_filled += n;
                        }
                        Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                        Poll::Pending => return Poll::Pending,
                    }
                    continue;
                }
            }

            match MessageFramed::<T>::poll_read_fill_length(
                &mut inner,
                cx,
                &mut this.length_buf,
                &mut this.length_filled,
            ) {
                Poll::Ready(Ok(frame_len)) => {
                    this.length_filled = 0;
                    if frame_len == 0 {
                        this.read_buf = Some(Vec::new());
                        this.payload_filled = 0;
                        this.read_cursor = 0;
                        continue;
                    }
                    this.read_buf = Some(vec![0u8; frame_len as usize]);
                    this.payload_filled = 0;
                    this.read_cursor = 0;
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

impl AsyncRead for WebSocketAdapter {
    /// Presents incoming WebSocket messages as a byte stream.
    ///
    /// Binary messages contribute their payload and text messages their UTF-8 bytes; every
    /// other message kind is skipped. One call copies at most the unread remainder of the
    /// current message into `buf`, so a message larger than `buf` is delivered across several
    /// calls. Empty messages produce no bytes and are passed over. When the underlying stream
    /// ends, the call completes with `Ok(())` and writes nothing (EOF).
    ///
    /// # Errors
    /// Stream errors are wrapped as `io::ErrorKind::Other`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        // Work on `&mut Self` so `read_buf` can be borrowed while `read_cursor` is updated.
        // This lets the payload be copied straight into `buf` with no intermediate `Vec`.
        // (`DerefMut` on the pin already required `Self: Unpin`.)
        let this = self.get_mut();
        loop {
            if let Some(data) = &this.read_buf {
                let unread = &data[this.read_cursor..];
                let to_copy = unread.len().min(buf.remaining());
                if to_copy > 0 {
                    buf.put_slice(&unread[..to_copy]);
                    this.read_cursor += to_copy;
                    if this.read_cursor >= data.len() {
                        // Message fully delivered; the next call fetches a new one.
                        this.read_buf = None;
                        this.read_cursor = 0;
                    }
                    return Poll::Ready(Ok(()));
                }
            }
            match this.stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(Message::Binary(data)))) => {
                    this.read_buf = Some(data.to_vec());
                    this.read_cursor = 0;
                }
                Poll::Ready(Some(Ok(Message::Text(t)))) => {
                    this.read_buf = Some(t.as_bytes().to_vec());
                    this.read_cursor = 0;
                }
                Poll::Ready(Some(Ok(_))) => continue,
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)))
                }
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

impl AsyncRead for InnerSocket {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match self.get_mut() {
            InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
            InnerSocket::WebSocket(s) => Pin::new(s).poll_read(cx, buf),
            #[cfg(unix)]
            InnerSocket::Unix(s, _) => tokio::io::AsyncRead::poll_read(Pin::new(s), cx, buf),
            #[cfg(windows)]
            InnerSocket::NamedPipe(s, _) => tokio::io::AsyncRead::poll_read(Pin::new(s), cx, buf),
        }
    }
}

/// Reads one length-prefixed frame (4-byte big-endian length + payload). Used for Unix and named pipe.
async fn read_one_framed_message<S>(s: &mut S) -> Result<Vec<u8>, SocketError>
where
    S: AsyncReadExt + Unpin + ?Sized,
{
    let mut len_buf = [0u8; 4];
    AsyncReadExt::read_exact(s, &mut len_buf)
        .await
        .map_err(SocketError::Io)?;
    let len = u32::from_be_bytes(len_buf);
    if len > MAX_MESSAGE_LEN {
        return Err(SocketError::Io(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("frame length {} exceeds max {}", len, MAX_MESSAGE_LEN),
        )));
    }
    let mut buf = vec![0u8; len as usize];
    AsyncReadExt::read_exact(s, &mut buf)
        .await
        .map_err(SocketError::Io)?;
    Ok(buf)
}

/// Reads exactly one message from the inner transport.
///
/// - WebSocket: one read from the adapter, i.e. the next binary/text message. A message
///   longer than `MAX_MESSAGE_LEN` is returned in `MAX_MESSAGE_LEN`-sized pieces over
///   successive calls, because the adapter copies at most what fits in the read buffer.
/// - Unix socket / named pipe: one length-prefixed frame (see `read_one_framed_message`).
///
/// # Errors
/// Returns `SocketError::Io(connection_lost_error())` when the socket is `Closed` or the
/// WebSocket stream has ended. Otherwise returns any I/O error from the transport.
pub(crate) async fn read_one_message(inner: &mut InnerSocket) -> Result<Vec<u8>, SocketError> {
    match inner {
        InnerSocket::Closed => Err(SocketError::Io(connection_lost_error())),
        InnerSocket::WebSocket(a) => {
            let mut buf = vec![0u8; MAX_MESSAGE_LEN as usize];
            let n = AsyncReadExt::read(&mut *a, &mut buf)
                .await
                .map_err(SocketError::Io)?;
            // The adapter skips empty messages, so a zero-byte read into a non-empty buffer
            // means the WebSocket stream has ended. Returning `Ok(vec![])` here would let the
            // message loop keep re-reading a finished stream. Report it as a lost connection.
            if n == 0 {
                return Err(SocketError::Io(connection_lost_error()));
            }
            buf.truncate(n);
            // Drop the unused tail of the MAX_MESSAGE_LEN scratch allocation so the caller
            // does not hold onto it for every message.
            buf.shrink_to_fit();
            Ok(buf)
        }
        #[cfg(unix)]
        InnerSocket::Unix(s, _) => read_one_framed_message(&mut *s).await,
        #[cfg(windows)]
        InnerSocket::NamedPipe(s, _) => read_one_framed_message(&mut *s).await,
    }
}

/// Spawns a background task that reads incoming messages from `arc` and passes each result
/// to `callback`.
///
/// Unix sockets and named pipes (the socket carries an event name):
/// - The task waits on the named "data ready" event stream.
/// - Each notification triggers one message read.
/// - On a successful read, the matching "data acked" event is signalled (so the sender can
///   complete) before `callback` runs.
/// - If the event stream cannot be opened, `callback` receives a single `Err` and the task
///   ends.
/// - If the event stream itself ends, the task ends without calling `callback` again.
///
/// Other transports: messages are read back-to-back.
///
/// In every mode an error is logged, handed to `callback`, and then the task stops.
///
/// NOTE(review): the socket mutex is held while `read_one_message` is awaited. On the
/// non-event path that await lasts until the next message arrives, so anything else locking
/// the same mutex presumably waits that long. Confirm this is intended.
pub(crate) fn spawn_message_handler<F, Fut>(arc: Arc<Mutex<InnerSocket>>, mut callback: F)
where
    F: FnMut(Result<Vec<u8>, SocketError>) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = ()> + Send,
{
    tokio::spawn(async move {
        // Only local transports carry an event name. The guard is released at the end of
        // this statement.
        let event_name: Option<String> = match &*arc.lock().await {
            #[cfg(windows)]
            InnerSocket::NamedPipe(_, name) => Some(name.clone()),
            #[cfg(unix)]
            InnerSocket::Unix(_, name) => Some(name.clone()),
            _ => None,
        };

        let name = match event_name {
            Some(name) => name,
            None => {
                // Nothing to wait on: read back-to-back until the first error.
                loop {
                    let result = lock_and_read_logged(&arc).await;
                    let failed = result.is_err();
                    callback(result).await;
                    if failed {
                        return;
                    }
                }
            }
        };

        let mut data_ready = match event_handler::named_event_stream(&name) {
            Some(events) => events,
            None => {
                let e = SocketError::Io(io::Error::new(
                    io::ErrorKind::Other,
                    "failed to create or open data-ready event/semaphore for message handler",
                ));
                crate::logger::log_error(&e);
                callback(Err(e)).await;
                return;
            }
        };

        // One message is read per "data ready" notification.
        while data_ready.next().await.is_some() {
            let result = lock_and_read_logged(&arc).await;
            if result.is_ok() {
                // Release the writer before the (possibly slow) callback runs.
                let ack_name = event_sender::data_acked_name_from_data_ready(&name);
                event_sender::signal_named_event(&ack_name);
            }
            let failed = result.is_err();
            callback(result).await;
            if failed {
                return;
            }
        }
    });
}

/// Locks the shared socket for the duration of one `read_one_message` call, releases it,
/// and logs the error if the read failed.
async fn lock_and_read_logged(arc: &Mutex<InnerSocket>) -> Result<Vec<u8>, SocketError> {
    let result = {
        let mut guard = arc.lock().await;
        read_one_message(&mut *guard).await
    };
    if let Err(ref e) = result {
        crate::logger::log_error(e);
    }
    result
}