use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::Stream;
use futures_util::StreamExt;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
use crate::event_handler;
use crate::event_sender;
use crate::socket::{
connection_lost_error, InnerSocket, MessageFramed, SocketError, WebSocketAdapter,
MAX_MESSAGE_LEN,
};
#[cfg(unix)]
use crate::socket::PolledUnixStream;
#[cfg(windows)]
use crate::socket::PolledNamedPipe;
/// Generates an [`AsyncRead`] impl for a polled transport wrapper.
///
/// The generated `poll_read` proceeds in this order:
/// 1. fails with `connection_lost_error()` once `disconnected` is set;
/// 2. completes immediately when the caller's buffer has no free space;
/// 3. returns a single byte previously stashed in `peek`, if there is one;
/// 4. when at least `interval` has elapsed since `last_check`, runs
///    `run_liveness_check(cx)` and propagates its error or `Pending`;
/// 5. otherwise delegates to the wrapped `inner` reader.
///
/// The target type must provide the fields `disconnected`, `peek`, `last_check`,
/// `interval`, `inner`, and a `run_liveness_check(cx)` method callable through
/// `Pin<&mut Self>`.
macro_rules! impl_polled_async_read {
    ($ty:ty) => {
        impl AsyncRead for $ty {
            fn poll_read(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &mut ReadBuf<'_>,
            ) -> Poll<io::Result<()>> {
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                if buf.remaining() == 0 {
                    return Poll::Ready(Ok(()));
                }
                // A byte held back earlier is delivered on its own before any new I/O.
                if let Some(byte) = self.peek.take() {
                    buf.put_slice(&[byte]);
                    return Poll::Ready(Ok(()));
                }
                let check_due = self.last_check.elapsed() >= self.interval;
                if check_due {
                    match self.run_liveness_check(cx) {
                        Poll::Ready(Ok(())) => {}
                        // Either an error or Pending: hand it straight back to the caller.
                        not_ready_or_failed => return not_ready_or_failed,
                    }
                }
                tokio::io::AsyncRead::poll_read(Pin::new(&mut self.inner), cx, buf)
            }
        }
    };
}

// Each platform gets the polled-read behaviour for its local IPC transport.
#[cfg(unix)]
impl_polled_async_read!(PolledUnixStream);
#[cfg(windows)]
impl_polled_async_read!(PolledNamedPipe);
#[cfg(any(unix, windows))]
impl<T> AsyncRead for MessageFramed<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if buf.remaining() == 0 {
return Poll::Ready(Ok(()));
}
let this = self.as_mut().get_mut();
let mut inner = Pin::new(&mut this.inner);
loop {
if let Some(ref mut read_buf) = this.read_buf {
let len: usize = read_buf.len();
if this.payload_filled == len && this.read_cursor < len {
let to_copy = (len - this.read_cursor).min(buf.remaining());
buf.put_slice(&read_buf[this.read_cursor..this.read_cursor + to_copy]);
this.read_cursor += to_copy;
if this.read_cursor >= len {
this.read_buf = None;
this.read_cursor = 0;
this.payload_filled = 0;
}
return Poll::Ready(Ok(()));
}
if this.read_cursor == len {
this.read_buf = None;
this.read_cursor = 0;
this.payload_filled = 0;
continue;
}
if this.payload_filled < len {
let mut payload_read_buf = ReadBuf::new(&mut read_buf[this.payload_filled..]);
match inner.as_mut().poll_read(cx, &mut payload_read_buf) {
Poll::Ready(Ok(())) => {
let n = payload_read_buf.filled().len();
if n == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"connection closed while reading frame payload",
)));
}
this.payload_filled += n;
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
continue;
}
}
match MessageFramed::<T>::poll_read_fill_length(
&mut inner,
cx,
&mut this.length_buf,
&mut this.length_filled,
) {
Poll::Ready(Ok(frame_len)) => {
this.length_filled = 0;
if frame_len == 0 {
this.read_buf = Some(Vec::new());
this.payload_filled = 0;
this.read_cursor = 0;
continue;
}
this.read_buf = Some(vec![0u8; frame_len as usize]);
this.payload_filled = 0;
this.read_cursor = 0;
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
}
}
impl AsyncRead for WebSocketAdapter {
    /// Exposes the payloads of incoming WebSocket messages as a byte stream.
    ///
    /// Message handling:
    /// - Binary and text messages are buffered one at a time in `read_buf` and
    ///   copied out, possibly across several calls, starting at `read_cursor`.
    /// - Empty messages and non-data messages (any variant other than `Binary`
    ///   or `Text`) are skipped.
    /// - A stream error is surfaced as `io::ErrorKind::Other`.
    /// - The end of the message stream is reported as EOF: `Ok(())` with
    ///   nothing written.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }
        // The type is `Unpin` (fields are mutated through the pin), so take a
        // plain `&mut`. Disjoint field borrows then let us copy straight out of
        // `read_buf` without cloning the slice first.
        let this = self.get_mut();
        loop {
            if let Some(data) = this.read_buf.as_ref() {
                let len = data.len();
                let start = this.read_cursor;
                let n = (len - start).min(buf.remaining());
                if n > 0 {
                    buf.put_slice(&data[start..start + n]);
                    this.read_cursor += n;
                    if this.read_cursor >= len {
                        this.read_buf = None;
                        this.read_cursor = 0;
                    }
                    return Poll::Ready(Ok(()));
                }
            }
            match this.stream.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(Message::Binary(data)))) => {
                    this.read_buf = Some(data.to_vec());
                    this.read_cursor = 0;
                }
                Poll::Ready(Some(Ok(Message::Text(text)))) => {
                    this.read_buf = Some(text.as_bytes().to_vec());
                    this.read_cursor = 0;
                }
                // Control messages (e.g. ping/pong/close) carry no application payload.
                Poll::Ready(Some(Ok(_))) => continue,
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e)))
                }
                Poll::Ready(None) => return Poll::Ready(Ok(())),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
impl AsyncRead for InnerSocket {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.get_mut() {
InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
InnerSocket::WebSocket(s) => Pin::new(s).poll_read(cx, buf),
#[cfg(unix)]
InnerSocket::Unix(s, _) => tokio::io::AsyncRead::poll_read(Pin::new(s), cx, buf),
#[cfg(windows)]
InnerSocket::NamedPipe(s, _) => tokio::io::AsyncRead::poll_read(Pin::new(s), cx, buf),
}
}
}
async fn read_one_framed_message<S>(s: &mut S) -> Result<Vec<u8>, SocketError>
where
S: AsyncReadExt + Unpin + ?Sized,
{
let mut len_buf = [0u8; 4];
AsyncReadExt::read_exact(s, &mut len_buf)
.await
.map_err(SocketError::Io)?;
let len = u32::from_be_bytes(len_buf);
if len > MAX_MESSAGE_LEN {
return Err(SocketError::Io(io::Error::new(
io::ErrorKind::InvalidData,
format!("frame length {} exceeds max {}", len, MAX_MESSAGE_LEN),
)));
}
let mut buf = vec![0u8; len as usize];
AsyncReadExt::read_exact(s, &mut buf)
.await
.map_err(SocketError::Io)?;
Ok(buf)
}
/// Reads the next message from `inner`.
///
/// Per transport:
/// - WebSocket: returns the bytes of the next non-empty binary/text message,
///   at most `MAX_MESSAGE_LEN` bytes per call. Any remainder of a longer
///   message is returned by later calls.
/// - Unix socket / named pipe: reads one length-prefixed frame via
///   `read_one_framed_message`.
///
/// # Errors
/// - `SocketError::Io` with `connection_lost_error()` if the socket is
///   `Closed` or the WebSocket message stream has ended.
/// - Otherwise, any I/O or framing error from the underlying transport.
pub(crate) async fn read_one_message(inner: &mut InnerSocket) -> Result<Vec<u8>, SocketError> {
    match inner {
        InnerSocket::Closed => Err(SocketError::Io(connection_lost_error())),
        InnerSocket::WebSocket(a) => {
            let mut buf = vec![0u8; MAX_MESSAGE_LEN as usize];
            let n = AsyncReadExt::read(&mut *a, &mut buf)
                .await
                .map_err(SocketError::Io)?;
            if n == 0 {
                // The WebSocket adapter skips empty messages, so a zero-byte read
                // only happens once its message stream has ended. Returning
                // Ok(empty) here would make a read-until-error loop spin forever.
                return Err(SocketError::Io(connection_lost_error()));
            }
            buf.truncate(n);
            // Don't let every returned message keep the full
            // MAX_MESSAGE_LEN-sized scratch capacity alive.
            buf.shrink_to_fit();
            Ok(buf)
        }
        #[cfg(unix)]
        InnerSocket::Unix(s, _) => read_one_framed_message(&mut *s).await,
        #[cfg(windows)]
        InnerSocket::NamedPipe(s, _) => read_one_framed_message(&mut *s).await,
    }
}
pub(crate) fn spawn_message_handler<F, Fut>(arc: Arc<Mutex<InnerSocket>>, mut callback: F)
where
F: FnMut(Result<Vec<u8>, SocketError>) -> Fut + Send + 'static,
Fut: std::future::Future<Output = ()> + Send,
{
tokio::spawn(async move {
let event_name: Option<String> = {
let guard = arc.lock().await;
match &*guard {
#[cfg(windows)]
InnerSocket::NamedPipe(_, name) => Some(name.clone()),
#[cfg(unix)]
InnerSocket::Unix(_, name) => Some(name.clone()),
_ => None,
}
};
if let Some(ref name) = event_name {
if let Some(mut stream) = event_handler::named_event_stream(name) {
while let Some(()) = stream.next().await {
let result = {
let mut guard = arc.lock().await;
read_one_message(&mut *guard).await
};
if result.is_ok() {
let ack_name = event_sender::data_acked_name_from_data_ready(name);
event_sender::signal_named_event(&ack_name);
}
let is_err = result.is_err();
if let Err(ref e) = result {
crate::logger::log_error(e);
}
callback(result).await;
if is_err {
return;
}
}
} else {
let e = SocketError::Io(io::Error::new(
io::ErrorKind::Other,
"failed to create or open data-ready event/semaphore for message handler",
));
crate::logger::log_error(&e);
callback(Err(e)).await;
}
return;
}
loop {
let result = {
let mut guard = arc.lock().await;
read_one_message(&mut *guard).await
};
let is_err = result.is_err();
if let Err(ref e) = result {
crate::logger::log_error(e);
}
callback(result).await;
if is_err {
break;
}
}
});
}