use std::fmt;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_io::{AsyncRead, AsyncWrite};
use crate::transport::{RecvCompletion, Transport};
/// Byte-stream adapter over a [`Transport`], exposing it through `AsyncRead` and `AsyncWrite`.
///
/// The stream keeps just enough state to resume a partially consumed receive buffer and to
/// remember a send whose completion has not been observed yet.
pub struct AsyncRdmaStream<T: Transport> {
/// Underlying transport used for every send, receive and disconnect operation.
transport: T,
/// Partially consumed receive buffer as `(buf_idx, offset, total_len)`: `offset` bytes of the
/// `total_len` bytes in buffer `buf_idx` have already been handed to the reader.
recv_pending: Option<(usize, usize, usize)>,
/// Byte count returned by `send_copy` for a send whose completion is still being awaited.
write_pending: Option<usize>,
/// Set once the stream is considered closed; afterwards reads return `Ok(0)` and non-empty
/// writes fail with `BrokenPipe`.
eof: bool,
}
// Declared `Unpin` for every `T`: the poll methods reach the fields through `Pin::get_mut`
// and the async helpers build pins with `Pin::new`, both of which require `Self: Unpin`.
impl<T: Transport> Unpin for AsyncRdmaStream<T> {}
impl<T: Transport> fmt::Debug for AsyncRdmaStream<T> {
    /// Writes a diagnostic summary of the stream.
    ///
    /// The output lists both socket addresses and the EOF flag; the receive and send state are
    /// reported only as booleans ("is something pending?"), never as their contents.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_buffered_recv = self.recv_pending.is_some();
        let has_inflight_send = self.write_pending.is_some();

        let mut summary = f.debug_struct("AsyncRdmaStream");
        summary.field("local_addr", &self.transport.local_addr());
        summary.field("peer_addr", &self.transport.peer_addr());
        summary.field("eof", &self.eof);
        summary.field("recv_pending", &has_buffered_recv);
        summary.field("write_pending", &has_inflight_send);
        summary.finish()
    }
}
impl<T: Transport> AsyncRdmaStream<T> {
    /// Wraps `transport` in a stream with no buffered receive data, no send in flight and the
    /// EOF flag cleared.
    pub fn new(transport: T) -> Self {
        Self {
            transport,
            recv_pending: None,
            write_pending: None,
            eof: false,
        }
    }

    /// Reads into `buf` by driving `poll_read` to completion.
    ///
    /// Returns the number of bytes copied. `Ok(0)` is returned when `buf` is empty or the
    /// stream has reached EOF.
    pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await
    }

    /// Writes a prefix of `data` by driving `poll_write` to completion and returns how many
    /// bytes were accepted.
    ///
    /// Not cancel safe: if the returned future is dropped while a send is in flight, the next
    /// write reports the byte count of that earlier send instead of sending its own buffer.
    pub async fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_write(cx, data)).await
    }

    /// Writes all of `data`, calling [`write`](Self::write) until nothing is left.
    ///
    /// # Errors
    ///
    /// Propagates any error from `write`. Additionally returns `WriteZero` if a write makes no
    /// progress on a non-empty buffer, and an error if a write reports more bytes than were
    /// supplied (possible after an earlier write was cancelled mid-flight) instead of
    /// panicking on the out-of-range slice.
    pub async fn write_all(&mut self, mut data: &[u8]) -> io::Result<()> {
        while !data.is_empty() {
            let n = self.write(data).await?;
            if n == 0 {
                // Without this guard a zero-length write would spin forever.
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            data = data.get(n..).ok_or_else(|| {
                io::Error::other("write reported more bytes than were supplied")
            })?;
        }
        Ok(())
    }

    /// Closes the stream by driving `poll_close` to completion.
    pub async fn shutdown(&mut self) -> io::Result<()> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_close(cx)).await
    }

    /// Address of the remote peer as reported by the transport, if any.
    pub fn peer_addr(&self) -> Option<SocketAddr> {
        self.transport.peer_addr()
    }

    /// Local address as reported by the transport, if any.
    pub fn local_addr(&self) -> Option<SocketAddr> {
        self.transport.local_addr()
    }
}
impl<T: Transport> Drop for AsyncRdmaStream<T> {
// Disconnects the transport when the stream goes out of scope. The result is discarded
// because `drop` has no way to report it.
// NOTE(review): `poll_close` also calls `disconnect`, so a stream that was shut down
// explicitly disconnects twice — presumably `disconnect` tolerates that; confirm.
fn drop(&mut self) {
let _ = self.transport.disconnect();
}
}
impl<T: Transport> AsyncRead for AsyncRdmaStream<T> {
    /// Copies received bytes into `buf`.
    ///
    /// Order of operations:
    /// 1. An empty `buf` or a stream already at EOF yields `Ok(0)` immediately.
    /// 2. A partially consumed receive buffer left by an earlier call is drained first.
    /// 3. Otherwise the transport is polled for one receive completion.
    ///
    /// A receive buffer is reposted to the transport as soon as its last byte has been
    /// handed out; if `buf` is too small, the remainder is remembered in `recv_pending`.
    ///
    /// Transport receive errors and a detected disconnect are both reported as EOF
    /// (`Ok(0)`) and latch the `eof` flag.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        if buf.is_empty() || this.eof {
            return Poll::Ready(Ok(0));
        }

        // Pick the receive buffer to copy from: the leftover of a previous call, or a freshly
        // completed one.
        let (buf_idx, offset, total_len) = match this.recv_pending {
            Some(pending) => pending,
            None => {
                let mut completions = [RecvCompletion::default(); 1];
                loop {
                    match this.transport.poll_recv(cx, &mut completions) {
                        Poll::Pending => {
                            if this.transport.poll_disconnect(cx) {
                                this.eof = true;
                                return Poll::Ready(Ok(0));
                            }
                            return Poll::Pending;
                        }
                        // NOTE(review): the error is deliberately not surfaced here; the
                        // write side does propagate transport errors — confirm intent.
                        Poll::Ready(Err(_)) => {
                            this.eof = true;
                            return Poll::Ready(Ok(0));
                        }
                        // Ready but nothing delivered: poll again.
                        Poll::Ready(Ok(0)) => continue,
                        Poll::Ready(Ok(_)) => {
                            let c = &completions[0];
                            if c.byte_len == 0 {
                                // Hand the buffer back before returning; previously it was
                                // never reposted on this path.
                                // NOTE(review): `eof` is not latched for a zero-length
                                // message — confirm whether it is meant as an EOF marker.
                                let _ = this.transport.repost_recv(c.buf_idx);
                                return Poll::Ready(Ok(0));
                            }
                            break (c.buf_idx, 0, c.byte_len);
                        }
                    }
                }
            }
        };

        let copy_len = (total_len - offset).min(buf.len());
        buf[..copy_len]
            .copy_from_slice(&this.transport.recv_buf(buf_idx)[offset..offset + copy_len]);

        let consumed = offset + copy_len;
        if consumed < total_len {
            // Caller's buffer was too small; resume from `consumed` on the next call.
            this.recv_pending = Some((buf_idx, consumed, total_len));
        } else {
            this.recv_pending = None;
            let _ = this.transport.repost_recv(buf_idx);
        }
        Poll::Ready(Ok(copy_len))
    }
}
impl<T: Transport> AsyncWrite for AsyncRdmaStream<T> {
    /// Sends a prefix of `buf` and resolves once the transport reports the send complete.
    ///
    /// An empty `buf` resolves to `Ok(0)`. After EOF the call fails with `BrokenPipe`.
    /// If a send is already in flight (`write_pending` is set), no new send is issued and
    /// the call only waits for that send's completion, returning its stored byte count.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        /// Error returned whenever the connection is found to be closed.
        fn connection_closed() -> io::Error {
            io::Error::new(io::ErrorKind::BrokenPipe, "connection closed")
        }

        let this = self.get_mut();
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        if this.eof {
            this.write_pending = None;
            return Poll::Ready(Err(connection_closed()));
        }

        let in_flight = if let Some(len) = this.write_pending {
            len
        } else {
            let mut accepted = match this.transport.send_copy(buf) {
                Ok(n) => n,
                Err(e) => return Poll::Ready(Err(io::Error::other(e))),
            };

            if accepted == 0 {
                // Nothing was accepted: poll the receive side, then wait for a send
                // completion before trying once more.
                // NOTE(review): whatever `poll_recv` returns here is discarded — if it can
                // yield a data completion, that data is lost; confirm against `Transport`.
                let mut scratch = [RecvCompletion::default(); 1];
                let _ = this.transport.poll_recv(cx, &mut scratch);

                match this.transport.poll_send_completion(cx) {
                    Poll::Pending => {
                        if this.transport.poll_disconnect(cx) {
                            this.eof = true;
                            return Poll::Ready(Err(connection_closed()));
                        }
                        return Poll::Pending;
                    }
                    Poll::Ready(Err(e)) => {
                        this.eof = true;
                        return Poll::Ready(Err(io::Error::other(e)));
                    }
                    Poll::Ready(Ok(())) => {}
                }

                accepted = match this.transport.send_copy(buf) {
                    Ok(n) => n,
                    Err(e) => return Poll::Ready(Err(io::Error::other(e))),
                };
                if accepted == 0 {
                    let mut scratch = [RecvCompletion::default(); 1];
                    let _ = this.transport.poll_recv(cx, &mut scratch);
                    return Poll::Pending;
                }
            }

            this.write_pending = Some(accepted);
            accepted
        };

        // A send is in flight; report its length only once the completion arrives.
        match this.transport.poll_send_completion(cx) {
            Poll::Ready(outcome) => {
                this.write_pending = None;
                match outcome {
                    Ok(()) => Poll::Ready(Ok(in_flight)),
                    Err(e) => {
                        this.eof = true;
                        Poll::Ready(Err(io::Error::other(e)))
                    }
                }
            }
            Poll::Pending => {
                if !this.transport.poll_disconnect(cx) {
                    return Poll::Pending;
                }
                this.eof = true;
                this.write_pending = None;
                Poll::Ready(Err(connection_closed()))
            }
        }
    }

    /// Always ready: `poll_write` only reports success after the send completion has been
    /// observed, so this method has nothing further to wait for.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the stream, first waiting for any in-flight send to finish.
    ///
    /// Always resolves to `Ok(())`: the outcome of the awaited send completion and of
    /// `disconnect` are both ignored. If the disconnect is detected while waiting, the local
    /// `disconnect` call is skipped. On completion `eof` is set and `write_pending` cleared.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        if !this.eof {
            let send_outstanding = this.write_pending.is_some()
                && this.transport.poll_send_completion(cx).is_pending();
            if send_outstanding {
                if !this.transport.poll_disconnect(cx) {
                    return Poll::Pending;
                }
            } else {
                let _ = this.transport.disconnect();
            }
            this.eof = true;
        }
        this.write_pending = None;
        Poll::Ready(Ok(()))
    }
}