use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::Sink;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::Mutex;
use tokio_tungstenite::tungstenite::Message;
use crate::event_handler::{self, AckWaitError};
use crate::event_sender;
use crate::socket::{
connection_lost_error, InnerSocket, MessageFramed, SocketError, WebSocketAdapter,
MAX_MESSAGE_LEN,
};
#[cfg(unix)]
use crate::socket::PolledUnixStream;
#[cfg(windows)]
use crate::socket::PolledNamedPipe;
/// Implements [`AsyncWrite`] for a polled local-transport wrapper type.
///
/// The target type must provide a `disconnected` flag, a `last_check` value with an
/// `.elapsed()` method, an `interval` it is compared against, an `inner` writer, and a
/// `run_liveness_check(cx)` method yielding `Poll<io::Result<()>>`. Both the type and its
/// `inner` must be `Unpin` (they are projected with `Pin::new`).
///
/// - `poll_write` / `poll_flush`: fail with `connection_lost_error()` once `disconnected` is
///   set; otherwise, when `interval` has elapsed since `last_check`, run the liveness check
///   first (propagating its error or `Pending`), then delegate to `inner`.
/// - `poll_shutdown`: only the `disconnected` check, then delegate to `inner`.
///
/// NOTE(review): whether `run_liveness_check` resets `last_check` is not visible here.
macro_rules! impl_polled_async_write {
    ($ty:ty) => {
        impl AsyncWrite for $ty {
            fn poll_write(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
                buf: &[u8],
            ) -> Poll<io::Result<usize>> {
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                let check_due = self.last_check.elapsed() >= self.interval;
                // `?` forwards a failed check as `Ready(Err(..))`; a pending check parks the write.
                if check_due && self.run_liveness_check(cx)?.is_pending() {
                    return Poll::Pending;
                }
                tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, buf)
            }

            fn poll_flush(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<()>> {
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                let check_due = self.last_check.elapsed() >= self.interval;
                if check_due && self.run_liveness_check(cx)?.is_pending() {
                    return Poll::Pending;
                }
                tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.inner), cx)
            }

            fn poll_shutdown(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<io::Result<()>> {
                // Shutdown deliberately skips the liveness check.
                if self.disconnected {
                    return Poll::Ready(Err(connection_lost_error()));
                }
                tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.inner), cx)
            }
        }
    };
}

// One impl per platform-specific polled transport.
#[cfg(unix)]
impl_polled_async_write!(PolledUnixStream);
#[cfg(windows)]
impl_polled_async_write!(PolledNamedPipe);
#[cfg(any(unix, windows))]
impl<T> AsyncWrite for MessageFramed<T>
where
T: tokio::io::AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
this.write_buf.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
if this.write_buf.is_empty() {
return Poll::Ready(Ok(()));
}
let payload: Vec<u8> = std::mem::take(&mut this.write_buf);
let payload_len = payload.len();
if payload_len > MAX_MESSAGE_LEN as usize {
this.write_buf = payload;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
format!("message length {} exceeds max {}", payload_len, MAX_MESSAGE_LEN),
)));
}
let len = payload.len() as u32;
let header = len.to_be_bytes();
let mut inner = Pin::new(&mut this.inner);
let mut header_written = 0usize;
while header_written < 4 {
match inner.as_mut().poll_write(cx, &header[header_written..]) {
Poll::Ready(Ok(0)) => {
this.write_buf = payload;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame length",
)));
}
Poll::Ready(Ok(n)) => header_written += n,
Poll::Ready(Err(e)) => {
this.write_buf = payload;
return Poll::Ready(Err(e));
}
Poll::Pending => {
this.write_buf = payload;
return Poll::Pending;
}
}
}
let mut written = 0usize;
while written < payload.len() {
match inner.as_mut().poll_write(cx, &payload[written..]) {
Poll::Ready(Ok(0)) => {
this.write_buf = payload;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"failed to write frame payload",
)));
}
Poll::Ready(Ok(n)) => written += n,
Poll::Ready(Err(e)) => {
this.write_buf = payload;
return Poll::Ready(Err(e));
}
Poll::Pending => {
this.write_buf = payload;
return Poll::Pending;
}
}
}
inner.as_mut().poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
}
}
impl AsyncWrite for WebSocketAdapter {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.write_buf.extend_from_slice(buf);
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if self.write_buf.is_empty() {
return Poll::Ready(Ok(()));
}
let data = std::mem::take(&mut self.write_buf);
match self.stream.as_mut().poll_ready(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
Poll::Pending => {
self.write_buf = data;
return Poll::Pending;
}
}
match self.stream.as_mut().start_send(Message::Binary(data.into())) {
Ok(()) => {}
Err(e) => return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
}
match self.stream.as_mut().poll_flush(cx) {
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, e))),
Poll::Pending => Poll::Pending,
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.stream
.as_mut()
.poll_close(cx)
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
}
impl AsyncWrite for InnerSocket {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.get_mut() {
InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
InnerSocket::WebSocket(s) => Pin::new(s).poll_write(cx, buf),
#[cfg(unix)]
InnerSocket::Unix(s, _) => tokio::io::AsyncWrite::poll_write(Pin::new(s), cx, buf),
#[cfg(windows)]
InnerSocket::NamedPipe(s, _) => tokio::io::AsyncWrite::poll_write(Pin::new(s), cx, buf),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
InnerSocket::WebSocket(s) => Pin::new(s).poll_flush(cx),
#[cfg(unix)]
InnerSocket::Unix(s, _) => tokio::io::AsyncWrite::poll_flush(Pin::new(s), cx),
#[cfg(windows)]
InnerSocket::NamedPipe(s, _) => tokio::io::AsyncWrite::poll_flush(Pin::new(s), cx),
}
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match self.get_mut() {
InnerSocket::Closed => Poll::Ready(Err(connection_lost_error())),
InnerSocket::WebSocket(s) => Pin::new(s).poll_shutdown(cx),
#[cfg(unix)]
InnerSocket::Unix(s, _) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(s), cx),
#[cfg(windows)]
InnerSocket::NamedPipe(s, _) => tokio::io::AsyncWrite::poll_shutdown(Pin::new(s), cx),
}
}
}
/// Writes all of `msg` into `inner`, then flushes it.
///
/// For `InnerSocket::WebSocket` the adapter only buffers on write, so the flush is what
/// actually transmits the bytes.
///
/// # Errors
/// Returns `SocketError::Io` wrapping the first I/O error from the write or the flush. The
/// flush is not attempted if the write fails.
pub(crate) async fn write_message(inner: &mut InnerSocket, msg: &[u8]) -> Result<(), SocketError> {
    let written = AsyncWriteExt::write_all(inner, msg).await;
    let io_result = match written {
        Ok(()) => AsyncWriteExt::flush(inner).await,
        Err(write_err) => Err(write_err),
    };
    io_result.map_err(SocketError::Io)
}
const ACK_TIMEOUT_MS: u32 = 5000;
pub(crate) async fn send_message(
inner: &Arc<Mutex<InnerSocket>>,
msg: &[u8],
) -> Result<(), SocketError> {
if msg.len() > MAX_MESSAGE_LEN as usize {
let e = SocketError::Io(io::Error::new(
io::ErrorKind::InvalidInput,
format!(
"message length {} exceeds max {}",
msg.len(),
MAX_MESSAGE_LEN
),
));
crate::logger::log_error(&e);
return Err(e);
}
let mut guard = inner.lock().await;
if matches!(*guard, InnerSocket::Closed) {
let e = SocketError::Io(connection_lost_error());
crate::logger::log_error(&e);
return Err(e);
}
if let Err(e) = write_message(&mut *guard, msg).await {
crate::logger::log_error(&e);
return Err(e);
}
let (data_ready_name, ack_name) = match &*guard {
#[cfg(windows)]
InnerSocket::NamedPipe(_, name) => (
Some(name.clone()),
Some(event_sender::data_acked_name_from_data_ready(name)),
),
#[cfg(unix)]
InnerSocket::Unix(_, name) => (
Some(name.clone()),
Some(event_sender::data_acked_name_from_data_ready(name)),
),
_ => (None, None),
};
drop(guard);
if let (Some(dr), Some(ack)) = (data_ready_name, ack_name) {
event_sender::signal_named_event(&dr);
match tokio::task::spawn_blocking(move || event_handler::wait_for_ack(&ack, ACK_TIMEOUT_MS)).await {
Ok(Ok(())) => {}
Ok(Err(AckWaitError::Timeout)) | Ok(Err(AckWaitError::CreateOpenFailed)) | Err(_) => {
let e = SocketError::RecipientAckTimeout;
crate::logger::log_error(&e);
return Err(e);
}
}
}
Ok(())
}