use crate::error::SmuxError;
use crate::frame::{Cmd, Frame, UpdHeader, INITIAL_PEER_WINDOW};
use crate::session::Session;
use bytes::Bytes;
use futures::future::BoxFuture;
use std::io;
use std::pin::Pin;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::{Mutex, Notify};
use tokio::time::{sleep_until, Instant as TokioInstant};
/// Slot holding an in-flight boxed I/O future, so that a `poll_*` call that returned
/// `Pending` can resume the same operation on its next poll instead of restarting it.
type PendingIo<T> = Arc<Mutex<Option<BoxFuture<'static, io::Result<T>>>>>;

/// Public handle to one multiplexed stream.
///
/// All real work is delegated to the shared `StreamInner`; the extra slots exist only to
/// drive that async API from the poll-based `AsyncRead` / `AsyncWrite` implementations.
pub struct Stream {
    /// Shared per-stream state that performs the actual reads, writes and close.
    pub(crate) inner: Arc<StreamInner>,
    /// Read started by `poll_read`; resolves to the scratch buffer and bytes filled.
    read_future: PendingIo<(Vec<u8>, usize)>,
    /// Write started by `poll_write`; resolves to the number of bytes written.
    write_future: PendingIo<usize>,
    /// Close started by `poll_shutdown`.
    shutdown_future: PendingIo<()>,
}
impl Stream {
    /// Wraps shared stream state in a new handle with no `poll_*` operation in flight.
    pub(crate) fn new(inner: Arc<StreamInner>) -> Self {
        Self {
            inner,
            read_future: Arc::new(Mutex::new(None)),
            write_future: Arc::new(Mutex::new(None)),
            shutdown_future: Arc::new(Mutex::new(None)),
        }
    }

    /// Returns this stream's id.
    pub fn id(&self) -> u32 {
        self.inner.id
    }

    /// Closes the stream by delegating to the shared inner state.
    ///
    /// # Errors
    /// Fails with `BrokenPipe` if the stream was already closed, or with the session's
    /// error if the FIN frame could not be written.
    pub async fn close(&self) -> io::Result<()> {
        self.inner.close().await
    }

    /// Reads up to `buf.len()` bytes, waiting until data is available.
    ///
    /// Unlike `AsyncRead`, end of stream surfaces here as an `UnexpectedEof` error rather
    /// than `Ok(0)` (see how `copy_to` treats it).
    pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.read(buf).await
    }

    /// Writes up to `buf.len()` bytes; may return a short count.
    pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write(buf).await
    }

    /// Writes all of `buf`, retrying short writes.
    ///
    /// # Errors
    /// Propagates the first write error. Fails with `WriteZero` if a write makes no
    /// progress, instead of looping forever on a zero-length result.
    pub async fn write_all(&self, buf: &[u8]) -> io::Result<()> {
        let mut remaining = buf;
        while !remaining.is_empty() {
            match self.write(remaining).await? {
                0 => {
                    return Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                }
                n => remaining = &remaining[n..],
            }
        }
        Ok(())
    }

    /// Sets (or, with `None`, clears) the read deadline.
    pub async fn set_read_deadline(&self, deadline: Option<Instant>) {
        self.inner.set_read_deadline(deadline).await;
    }

    /// Sets (or, with `None`, clears) the write deadline.
    pub async fn set_write_deadline(&self, deadline: Option<Instant>) {
        self.inner.set_write_deadline(deadline).await;
    }

    /// Sets both the read and the write deadline to `deadline`.
    pub async fn set_deadline(&self, deadline: Option<Instant>) {
        self.inner.set_read_deadline(deadline).await;
        self.inner.set_write_deadline(deadline).await;
    }

    /// Returns the `Notify` that `close` signals with `notify_waiters`.
    ///
    /// Only tasks already waiting on it at close time are woken.
    pub fn get_die_notifier(&self) -> Arc<Notify> {
        self.inner.get_die_notifier()
    }

    /// Returns `true` once this stream has been closed locally via `close`.
    pub async fn is_closed(&self) -> bool {
        self.inner.is_closed().await
    }

    /// Delegates to the owning session's `local_addr`.
    pub fn local_addr(&self) -> Option<&str> {
        self.inner.session.local_addr()
    }

    /// Delegates to the owning session's `remote_addr`.
    pub fn remote_addr(&self) -> Option<&str> {
        self.inner.session.remote_addr()
    }

    /// Copies everything read from this stream into `writer` until end of stream, then
    /// flushes `writer`.
    ///
    /// Both `Ok(0)` and an `UnexpectedEof` error from `read` count as end of stream. Any
    /// other read or write error is returned.
    ///
    /// Returns the total number of bytes copied.
    pub async fn copy_to<W>(&self, writer: &mut W) -> io::Result<u64>
    where
        W: tokio::io::AsyncWrite + Unpin,
    {
        use tokio::io::AsyncWriteExt;
        let mut total = 0u64;
        let mut buf = vec![0u8; 8192];
        loop {
            match self.read(&mut buf).await {
                Ok(0) => break,
                Ok(n) => {
                    writer.write_all(&buf[..n]).await?;
                    total += n as u64;
                }
                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
                Err(e) => return Err(e),
            }
        }
        writer.flush().await?;
        Ok(total)
    }
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let remaining = buf.remaining();
if remaining == 0 {
return Poll::Ready(Ok(()));
}
let read_future = self.read_future.clone();
let mut future_lock = match read_future.try_lock() {
Ok(lock) => lock,
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
if future_lock.is_none() {
let inner = Arc::clone(&self.inner);
let fut = Box::pin(async move {
let mut temp_buf = vec![0u8; remaining];
let n = inner.read(&mut temp_buf).await?;
Ok::<(Vec<u8>, usize), io::Error>((temp_buf, n))
});
*future_lock = Some(fut);
}
if let Some(mut fut) = future_lock.take() {
match fut.as_mut().poll(cx) {
Poll::Ready(Ok((temp_buf, n))) => {
buf.put_slice(&temp_buf[..n]);
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => {
Poll::Ready(Err(e))
}
Poll::Pending => {
*future_lock = Some(fut);
Poll::Pending
}
}
} else {
Poll::Ready(Ok(()))
}
}
}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
let write_future = self.write_future.clone();
let mut future_lock = match write_future.try_lock() {
Ok(lock) => lock,
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
if future_lock.is_none() {
let inner = Arc::clone(&self.inner);
let data = buf.to_vec();
let fut = Box::pin(async move {
inner.write(&data).await
});
*future_lock = Some(fut);
}
if let Some(mut fut) = future_lock.take() {
match fut.as_mut().poll(cx) {
Poll::Ready(result) => {
Poll::Ready(result)
}
Poll::Pending => {
*future_lock = Some(fut);
Poll::Pending
}
}
} else {
Poll::Ready(Ok(0))
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let shutdown_future = self.shutdown_future.clone();
let mut future_lock = match shutdown_future.try_lock() {
Ok(lock) => lock,
Err(_) => {
cx.waker().wake_by_ref();
return Poll::Pending;
}
};
if future_lock.is_none() {
let inner = Arc::clone(&self.inner);
let fut = Box::pin(async move {
inner.close().await
});
*future_lock = Some(fut);
}
if let Some(mut fut) = future_lock.take() {
match fut.as_mut().poll(cx) {
Poll::Ready(result) => {
Poll::Ready(result)
}
Poll::Pending => {
*future_lock = Some(fut);
Poll::Pending
}
}
} else {
Poll::Ready(Ok(()))
}
}
}
/// A boolean flag shared behind an async mutex.
type SharedFlag = Arc<Mutex<bool>>;

/// An optional deadline shared behind an async mutex.
type SharedDeadline = Arc<Mutex<Option<Instant>>>;

/// Shared state of one multiplexed stream.
///
/// Fields are kept in their original order, which is also their drop order.
pub(crate) struct StreamInner {
    /// Stream id; sent as `sid` in every frame this stream writes.
    pub id: u32,
    /// Owning session, through which every frame is written.
    pub session: Arc<Session>,
    /// Received payload chunks: `push_bytes` appends, reads consume from the front.
    buffers: Arc<Mutex<Vec<Bytes>>>,
    /// Largest payload put into a single PSH frame.
    frame_size: usize,
    /// Signalled when data is pushed or `fin` is called; awaited by readers.
    read_event: Arc<Notify>,
    /// Set by `close`.
    die: SharedFlag,
    /// Set by `fin`.
    fin_event: SharedFlag,
    /// Read deadline; `None` means wait indefinitely.
    read_deadline: SharedDeadline,
    /// Write deadline; `None` means wait indefinitely.
    write_deadline: SharedDeadline,
    /// Signalled with `notify_waiters` by `close`; handed out by `get_die_notifier`.
    die_notify: Arc<Notify>,
    /// Protocol v2: total bytes read, reported to the peer in window updates.
    num_read: AtomicU32,
    /// Protocol v2: total payload bytes sent in PSH frames.
    num_written: AtomicU32,
    /// Protocol v2: bytes read since the last window update was triggered.
    incr: AtomicU32,
    /// Protocol v2: last `consumed` value received through `update`.
    peer_consumed: AtomicU32,
    /// Protocol v2: last window size received through `update`.
    peer_window: AtomicU32,
    /// Signalled by `update`; awaited by writers when the peer window is exhausted.
    update_event: Arc<Notify>,
}
impl StreamInner {
pub fn new(id: u32, frame_size: usize, session: Arc<Session>) -> Self {
Self {
id,
session,
buffers: Arc::new(Mutex::new(Vec::new())),
frame_size,
read_event: Arc::new(Notify::new()),
die: Arc::new(Mutex::new(false)),
fin_event: Arc::new(Mutex::new(false)),
read_deadline: Arc::new(Mutex::new(None)),
write_deadline: Arc::new(Mutex::new(None)),
die_notify: Arc::new(Notify::new()),
num_read: AtomicU32::new(0),
num_written: AtomicU32::new(0),
incr: AtomicU32::new(0),
peer_consumed: AtomicU32::new(0),
peer_window: AtomicU32::new(INITIAL_PEER_WINDOW),
update_event: Arc::new(Notify::new()),
}
}
pub async fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
loop {
match self.try_read(buf).await {
Ok(n) if n > 0 => return Ok(n),
Err(e) if e.kind() != io::ErrorKind::WouldBlock => return Err(e),
_ => {}
}
if let Err(e) = self.wait_read().await {
return Err(e);
}
}
}
async fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
if self.session.config.version == 2 {
return self.try_read_v2(buf).await;
}
if buf.is_empty() {
return Ok(0);
}
let mut buffers = self.buffers.lock().await;
let mut n = 0;
if let Some(first_buf) = buffers.first_mut() {
n = buf.len().min(first_buf.len());
buf[..n].copy_from_slice(&first_buf[..n]);
*first_buf = first_buf.slice(n..);
if first_buf.is_empty() {
buffers.remove(0);
}
}
if n > 0 {
self.session.return_tokens(n);
return Ok(n);
}
if *self.die.lock().await {
return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
async fn try_read_v2(&self, buf: &mut [u8]) -> io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
let mut notify_consumed = 0u32;
let mut n = 0;
{
let mut buffers = self.buffers.lock().await;
if let Some(first_buf) = buffers.first_mut() {
n = buf.len().min(first_buf.len());
buf[..n].copy_from_slice(&first_buf[..n]);
*first_buf = first_buf.slice(n..);
if first_buf.is_empty() {
buffers.remove(0);
}
}
let n_u32 = n as u32;
self.num_read.fetch_add(n_u32, Ordering::Relaxed);
let incr = self.incr.fetch_add(n_u32, Ordering::Relaxed) + n_u32;
if incr >= (self.session.config.max_stream_buffer / 2) as u32
|| self.num_read.load(Ordering::Relaxed) == n_u32
{
notify_consumed = self.num_read.load(Ordering::Relaxed);
self.incr.store(0, Ordering::Relaxed);
}
}
if n > 0 {
self.session.return_tokens(n);
if notify_consumed > 0 {
if let Err(e) = self.send_window_update(notify_consumed).await {
return Err(io::Error::new(io::ErrorKind::Other, e));
}
}
return Ok(n);
}
if *self.die.lock().await {
return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
}
Err(io::Error::from(io::ErrorKind::WouldBlock))
}
async fn wait_read(&self) -> io::Result<()> {
loop {
{
let buffers = self.buffers.lock().await;
if !buffers.is_empty() {
return Ok(());
}
}
if *self.fin_event.lock().await {
let buffers = self.buffers.lock().await;
if buffers.is_empty() {
return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
}
return Ok(());
}
if *self.die.lock().await {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
let deadline_opt = *self.read_deadline.lock().await;
if let Some(deadline) = deadline_opt {
let now = Instant::now();
if now >= deadline {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
let tokio_deadline = TokioInstant::from_std(deadline);
tokio::select! {
_ = self.read_event.notified() => {
continue;
}
_ = sleep_until(tokio_deadline) => {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
} else {
self.read_event.notified().await;
}
}
}
pub async fn write(&self, buf: &[u8]) -> io::Result<usize> {
if self.session.config.version == 2 {
return self.write_v2(buf).await;
}
if *self.fin_event.lock().await || *self.die.lock().await {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
if let Some(deadline) = *self.write_deadline.lock().await {
if Instant::now() >= deadline {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
let mut sent = 0;
let mut remaining = buf;
while !remaining.is_empty() {
if let Some(deadline) = *self.write_deadline.lock().await {
if Instant::now() >= deadline {
if sent > 0 {
return Ok(sent);
}
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
let size = remaining.len().min(self.frame_size);
let frame = Frame {
ver: self.session.config.version,
cmd: Cmd::Psh,
sid: self.id,
data: Bytes::copy_from_slice(&remaining[..size]),
};
match self.session.write_frame(frame, false).await {
Ok(n) => {
sent += n;
remaining = &remaining[size..];
}
Err(e) => {
if sent > 0 {
return Ok(sent);
}
return Err(io::Error::from(e));
}
}
}
Ok(sent)
}
async fn write_v2(&self, buf: &[u8]) -> io::Result<usize> {
if buf.is_empty() {
return Ok(0);
}
if *self.fin_event.lock().await || *self.die.lock().await {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
if let Some(deadline) = *self.write_deadline.lock().await {
if Instant::now() >= deadline {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
let mut sent = 0;
let mut remaining = buf;
while !remaining.is_empty() {
if let Some(deadline) = *self.write_deadline.lock().await {
if Instant::now() >= deadline {
if sent > 0 {
return Ok(sent);
}
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
let num_written = self.num_written.load(Ordering::Relaxed);
let peer_consumed = self.peer_consumed.load(Ordering::Relaxed);
let peer_window = self.peer_window.load(Ordering::Relaxed);
let inflight = num_written.wrapping_sub(peer_consumed) as i32;
if inflight < 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
SmuxError::Consumed,
));
}
let win = peer_window as i32 - inflight;
if win <= 0 {
if sent > 0 {
return Ok(sent);
}
if let Some(deadline) = *self.write_deadline.lock().await {
let now = Instant::now();
if now >= deadline {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
let tokio_deadline = TokioInstant::from_std(deadline);
tokio::select! {
_ = self.update_event.notified() => {
continue;
}
_ = sleep_until(tokio_deadline) => {
return Err(io::Error::from(io::ErrorKind::TimedOut));
}
}
} else {
self.update_event.notified().await;
continue;
}
}
let n = remaining.len().min(win as usize).min(self.frame_size);
let frame = Frame {
ver: self.session.config.version,
cmd: Cmd::Psh,
sid: self.id,
data: Bytes::copy_from_slice(&remaining[..n]),
};
match self.session.write_frame(frame, false).await {
Ok(written) => {
self.num_written.fetch_add(n as u32, Ordering::Relaxed);
sent += written;
remaining = &remaining[n..];
}
Err(e) => {
if sent > 0 {
return Ok(sent);
}
return Err(io::Error::from(e));
}
}
}
Ok(sent)
}
async fn send_window_update(&self, consumed: u32) -> Result<(), SmuxError> {
let upd = UpdHeader {
consumed,
window: self.session.config.max_stream_buffer as u32,
};
let frame = Frame {
ver: self.session.config.version,
cmd: Cmd::Upd,
sid: self.id,
data: upd.encode(),
};
self.session
.write_frame(frame, true)
.await
.map_err(|_| SmuxError::Io(io::ErrorKind::Other))?;
Ok(())
}
pub async fn push_bytes(&self, data: Bytes) {
let mut buffers = self.buffers.lock().await;
buffers.push(data);
self.read_event.notify_one();
}
pub fn fin(&self) {
let fin_event = Arc::clone(&self.fin_event);
let read_event = Arc::clone(&self.read_event);
tokio::spawn(async move {
*fin_event.lock().await = true;
read_event.notify_one();
});
}
pub async fn update(&self, consumed: u32, window: u32) {
self.peer_consumed.store(consumed, Ordering::Relaxed);
self.peer_window.store(window, Ordering::Relaxed);
self.update_event.notify_one();
}
pub async fn close(&self) -> io::Result<()> {
if *self.die.lock().await {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
*self.die.lock().await = true;
self.die_notify.notify_waiters();
let frame = Frame {
ver: self.session.config.version,
cmd: Cmd::Fin,
sid: self.id,
data: Bytes::new(),
};
self.session.write_frame(frame, false).await?;
self.session.stream_closed(self.id).await;
Ok(())
}
pub async fn recycle_tokens(&self) -> usize {
let buffers = self.buffers.lock().await;
let total: usize = buffers.iter().map(|b| b.len()).sum();
total
}
pub async fn set_read_deadline(&self, deadline: Option<Instant>) {
*self.read_deadline.lock().await = deadline;
}
pub async fn set_write_deadline(&self, deadline: Option<Instant>) {
*self.write_deadline.lock().await = deadline;
}
pub fn get_die_notifier(&self) -> Arc<Notify> {
Arc::clone(&self.die_notify)
}
pub async fn is_closed(&self) -> bool {
*self.die.lock().await
}
pub(crate) async fn has_buffered_data(&self) -> bool {
let buffers = self.buffers.lock().await;
!buffers.is_empty()
}
pub(crate) async fn is_fin(&self) -> bool {
*self.fin_event.lock().await
}
pub(crate) fn get_read_event_notifier(&self) -> Arc<Notify> {
Arc::clone(&self.read_event)
}
}