use crate::protocol::ConnectDatagram;
use async_std::net::SocketAddr;
use async_std::pin::Pin;
use futures::io::IoSlice;
use futures::task::{Context, Poll};
use futures::{AsyncWrite, Sink};
use log::*;
use std::error::Error;
pub use futures::{SinkExt, StreamExt};
use std::fmt::Debug;
/// Failure modes when sending datagrams through a `ConnectionWriter`.
#[derive(Debug)]
pub enum ConnectionWriteError {
    /// The writer has been closed, so no further datagrams can be sent.
    ConnectionClosed,
    /// The underlying network stream reported an I/O error while writing, flushing, or closing.
    IoError(std::io::Error),
}

impl Error for ConnectionWriteError {}

impl std::fmt::Display for ConnectionWriteError {
    /// Formats the error; I/O failures delegate to the wrapped error's own message
    /// (passing the formatter through unchanged).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let closed_msg = "cannot send message when connection is closed";
        match self {
            Self::IoError(inner) => std::fmt::Display::fmt(inner, f),
            Self::ConnectionClosed => f.write_str(closed_msg),
        }
    }
}
/// Writing side of a connection: queues encoded datagrams in memory and pushes them to the
/// underlying byte stream when the sink is flushed or closed.
pub struct ConnectionWriter {
    /// Address of the local end of the connection.
    local_addr: SocketAddr,
    /// Address of the remote peer.
    peer_addr: SocketAddr,
    /// Byte stream that encoded datagrams are written to.
    write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
    /// Encoded datagrams queued by `start_send` that have not been written out yet.
    pending_writes: Vec<Vec<u8>>,
    /// Set once the sink starts closing; `poll_ready` then reports `ConnectionClosed`.
    closed: bool,
}
impl ConnectionWriter {
pub fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
) -> Self {
Self {
local_addr,
peer_addr,
write_stream,
pending_writes: Vec::new(),
closed: false,
}
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr.clone()
}
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr.clone()
}
pub fn is_closed(&self) -> bool {
self.closed
}
pub(crate) fn write_pending_bytes(
&mut self,
cx: &mut Context<'_>,
) -> Poll<Result<(), ConnectionWriteError>> {
if self.pending_writes.len() > 0 {
let stream = self.write_stream.as_mut();
match stream.poll_flush(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(_)) => {
let stream = self.write_stream.as_mut();
let pending = self.pending_writes.split_off(0);
let writeable_vec: Vec<IoSlice> =
pending.iter().map(|p| IoSlice::new(p)).collect();
trace!("sending pending bytes to network stream");
match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(bytes_written)) => {
trace!("wrote {} bytes to network stream", bytes_written);
self.pending_writes.clear();
Poll::Ready(Ok(()))
}
Poll::Ready(Err(err)) => {
error!("Encountered error when writing to network stream");
Poll::Ready(Err(ConnectionWriteError::IoError(err)))
}
}
}
Poll::Ready(Err(err)) => {
error!("Encountered error when flushing network stream");
Poll::Ready(Err(ConnectionWriteError::IoError(err)))
}
}
} else {
Poll::Ready(Ok(()))
}
}
}
impl Sink<ConnectDatagram> for ConnectionWriter {
    type Error = ConnectionWriteError;

    /// Reports whether another datagram may be queued.
    ///
    /// Datagrams are buffered in memory by `start_send`, so this never waits. It fails only
    /// once the writer has been closed.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.is_closed() {
            trace!("connection is closed, cannot send message");
            Poll::Ready(Err(ConnectionWriteError::ConnectionClosed))
        } else {
            trace!("connection ready to send message");
            Poll::Ready(Ok(()))
        }
    }

    /// Encodes `item` and queues the bytes. Nothing is written until the sink is flushed or
    /// closed.
    ///
    /// # Errors
    ///
    /// Returns `ConnectionWriteError::ConnectionClosed` if the writer has already been closed.
    fn start_send(mut self: Pin<&mut Self>, item: ConnectDatagram) -> Result<(), Self::Error> {
        // Callers should have checked `poll_ready`, but guard anyway so a datagram is never
        // silently queued on a closed writer.
        if self.is_closed() {
            trace!("connection is closed, cannot send message");
            return Err(ConnectionWriteError::ConnectionClosed);
        }
        trace!("preparing datagram to be queued for sending");
        let buffer = item.encode();
        trace!("serialized pending message into {} bytes", buffer.len());
        self.pending_writes.push(buffer);
        Ok(())
    }

    /// Writes all queued datagrams to the network stream.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.write_pending_bytes(cx)
    }

    /// Marks the writer closed, drains any queued datagrams, then closes the underlying stream.
    ///
    /// May be polled repeatedly until it returns `Ready`. The close is logged only on the first
    /// call.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if !self.closed {
            debug!("Closing the sink for connection with {}", self.peer_addr);
            self.closed = true;
        }
        match self.write_pending_bytes(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok(())) => self
                .write_stream
                .as_mut()
                .poll_close(cx)
                .map_err(ConnectionWriteError::IoError),
        }
    }
}