use bytes::{Buf, Bytes};
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite};
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
use super::io::{ConnCtx, with_state};
use crate::connection::RecvMode;
/// Async byte-stream adapter over a single driver connection identified by a
/// [`ConnCtx`].
///
/// Reads are served from the driver's per-connection accumulator (and, for
/// `AsyncBufRead`, from a chunk frozen out of it); writes are handed to the
/// connection through `ConnCtx::send_nowait`.
pub struct ConnStream {
    /// Handle to the underlying connection; it is `Copy` and handed out by value.
    ctx: ConnCtx,
    /// Chunk taken from the accumulator by `poll_fill_buf` that has not been
    /// fully consumed yet. Never stored while empty.
    fill_buf: Option<Bytes>,
    /// Set once `poll_close` has requested a write-side shutdown, so the
    /// shutdown is issued at most once.
    write_closed: bool,
}
impl ConnStream {
    /// Wraps `ctx` in a stream with no cached read chunk and an open write side.
    pub fn new(ctx: ConnCtx) -> Self {
        Self {
            fill_buf: None,
            write_closed: false,
            ctx,
        }
    }

    /// Returns a copy of the connection handle this stream wraps.
    pub fn conn_ctx(&self) -> ConnCtx {
        self.ctx
    }

    /// Reports whether no more received bytes can arrive for `conn_index`.
    ///
    /// A connection missing from `driver.connections` counts as closed, so
    /// readers observe EOF instead of waiting on a slot that no longer exists.
    fn is_recv_closed(driver: &mut crate::backend::Driver, conn_index: u32) -> bool {
        match driver.connections.get(conn_index) {
            Some(conn) => matches!(conn.recv_mode, RecvMode::Closed),
            None => true,
        }
    }

    /// Moves a receive buffer parked in `driver.pending_recv_bufs` for this
    /// connection into its accumulator, then queues the buffer id for
    /// replenishment. Without `has_io_uring` there is nothing to flush.
    fn flush_pending_recv(driver: &mut crate::backend::Driver, conn_index: u32) {
        #[cfg(has_io_uring)]
        {
            let parked = driver.pending_recv_bufs[conn_index as usize].take();
            if let Some(pending) = parked {
                // SAFETY: assumes `pending.ptr` addresses `pending.len` initialized
                // bytes that stay valid until `pending.bid` is replenished below —
                // NOTE(review): confirm against the code filling `pending_recv_bufs`.
                let received =
                    unsafe { std::slice::from_raw_parts(pending.ptr, pending.len as usize) };
                driver.accumulators.append(conn_index, received);
                // The bytes have been copied out, so the buffer can be recycled.
                driver.pending_replenish.push(pending.bid);
            }
        }
        #[cfg(not(has_io_uring))]
        let _ = (driver, conn_index);
    }
}
impl AsyncRead for ConnStream {
    /// Copies received bytes into `buf`, returning how many were written.
    ///
    /// Bytes left in the `poll_fill_buf` cache are drained first so data is
    /// delivered in arrival order; after that the connection's accumulator is
    /// read directly. `Ok(0)` signals EOF once the receive side is closed.
    /// When nothing is available the connection is flagged in
    /// `executor.recv_waiters` and `Pending` is returned — `_cx` is not used;
    /// NOTE(review): wake-up presumably comes from the executor via that flag.
    fn poll_read(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // An empty cached chunk is simply dropped here (leaving `None`).
        if let Some(mut cached) = self.fill_buf.take() {
            if !cached.is_empty() {
                let copied = buf.len().min(cached.len());
                cached.copy_to_slice(&mut buf[..copied]);
                if !cached.is_empty() {
                    self.fill_buf = Some(cached);
                }
                return Poll::Ready(Ok(copied));
            }
        }

        let conn_index = self.ctx.conn_index;
        with_state(|driver, executor| {
            Self::flush_pending_recv(driver, conn_index);
            let pending = driver.accumulators.data(conn_index);
            if pending.is_empty() {
                if Self::is_recv_closed(driver, conn_index) {
                    return Poll::Ready(Ok(0));
                }
                executor.recv_waiters[conn_index as usize] = true;
                return Poll::Pending;
            }
            let copied = pending.len().min(buf.len());
            buf[..copied].copy_from_slice(&pending[..copied]);
            driver.accumulators.consume(conn_index, copied);
            Poll::Ready(Ok(copied))
        })
    }
}
impl AsyncBufRead for ConnStream {
fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let this = self.get_mut();
if matches!(this.fill_buf.as_ref(), Some(b) if !b.is_empty()) {
return Poll::Ready(Ok(this.fill_buf.as_ref().unwrap().as_ref()));
}
this.fill_buf = None;
let conn_index = this.ctx.conn_index;
let frozen = with_state(|driver, executor| {
Self::flush_pending_recv(driver, conn_index);
let data = driver.accumulators.data(conn_index);
if !data.is_empty() {
return Ok(Some(driver.accumulators.take_frozen(conn_index)));
}
if Self::is_recv_closed(driver, conn_index) {
return Ok(None); }
executor.recv_waiters[conn_index as usize] = true;
Err(()) });
match frozen {
Err(()) => Poll::Pending,
Ok(None) => {
Poll::Ready(Ok(&[]))
}
Ok(Some(bytes)) => {
this.fill_buf = Some(bytes);
Poll::Ready(Ok(this.fill_buf.as_ref().unwrap().as_ref()))
}
}
}
fn consume(mut self: Pin<&mut Self>, amt: usize) {
if let Some(ref mut b) = self.fill_buf {
let advance = amt.min(b.len());
b.advance(advance);
if b.is_empty() {
self.fill_buf = None;
}
}
}
}
impl AsyncWrite for ConnStream {
    /// Hands up to one send-pool slot's worth of `buf` to the connection.
    ///
    /// Returns the number of bytes accepted. This may be less than `buf.len()`
    /// when `buf` is larger than `send_copy_pool.slot_size()`. An empty `buf`
    /// on an open stream is a no-op returning `Ok(0)`.
    ///
    /// # Errors
    ///
    /// - `BrokenPipe` once `poll_close` has requested a write-side shutdown.
    ///   Later bytes are rejected here instead of being queued behind that
    ///   shutdown and reported as written.
    /// - Any error returned by `ConnCtx::send_nowait`.
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.write_closed {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "ConnStream: write after close",
            )));
        }
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let ctx = self.ctx;
        // Each send is copied into one pool slot, so cap the write at slot size.
        let max_len = with_state(|driver, _| driver.send_copy_pool.slot_size() as usize);
        let write_len = buf.len().min(max_len);
        Poll::Ready(ctx.send_nowait(&buf[..write_len]).map(|()| write_len))
    }

    /// Always ready. `poll_write` hands data straight to `send_nowait`, and
    /// this type keeps no local write buffer to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Requests a write-side shutdown exactly once. Repeated calls are no-ops
    /// that return `Ok(())`.
    fn poll_close(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if !self.write_closed {
            self.write_closed = true;
            self.ctx.shutdown_write();
        }
        Poll::Ready(Ok(()))
    }
}
#[cfg(test)]
mod tests {
    use super::ConnStream;

    /// Keeps the per-connection stream handle compact. The test fails if a
    /// field change pushes `ConnStream` past 48 bytes.
    #[test]
    fn conn_stream_size() {
        let limit_bytes = 48;
        let actual = std::mem::size_of::<ConnStream>();
        assert!(actual <= limit_bytes);
    }
}