use super::{H2Connection, H2ErrorCode, lifecycle::StreamLifecycle};
use crate::{Buffer, Headers, headers::hpack::FieldSection};
use atomic_waker::AtomicWaker;
use futures_lite::io::{AsyncRead, AsyncWrite};
use std::{
fmt, io,
pin::Pin,
sync::{
Arc, Mutex, MutexGuard,
atomic::{AtomicBool, AtomicU64, Ordering},
},
task::{Context, Poll},
};
/// Application-facing byte stream for a single HTTP/2 stream.
///
/// Reads drain the stream's receive buffer (`StreamState::recv.buf`); writes are
/// queued into `StreamState::send.outbound`, which `H2OutboundReader` drains.
/// Dropping the transport hands the stream back to the connection (see the
/// `Drop` impl).
pub struct H2Transport {
/// Connection this stream belongs to; its outbound waker is woken whenever the
/// stream needs attention.
connection: Arc<H2Connection>,
/// HTTP/2 stream identifier.
stream_id: u32,
/// Per-stream state shared with this stream's `H2OutboundReader` (and,
/// presumably, the connection — TODO confirm).
state: Arc<StreamState>,
}
impl fmt::Debug for H2Transport {
    /// Renders only the stream id; the connection and shared stream state are
    /// elided and shown as `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("H2Transport");
        out.field("stream_id", &self.stream_id);
        out.finish_non_exhaustive()
    }
}
impl H2Transport {
    /// Creates the transport for stream `stream_id` on `connection`, backed by
    /// the shared per-stream `state`.
    pub(super) fn new(
        connection: Arc<H2Connection>,
        stream_id: u32,
        state: Arc<StreamState>,
    ) -> Self {
        Self {
            stream_id,
            state,
            connection,
        }
    }

    /// The HTTP/2 stream identifier this transport is bound to.
    pub fn stream_id(&self) -> u32 {
        self.stream_id
    }

    /// The connection this stream belongs to.
    pub fn connection(&self) -> &Arc<H2Connection> {
        &self.connection
    }
}
impl Drop for H2Transport {
    /// Hands the stream back to the connection when the transport is dropped.
    ///
    /// Behavior depends on the lifecycle at drop time:
    /// - If the stream is no longer in the connection's stream map, or it is
    ///   already reset, reset-requested or `AwaitingRelease`, only a trace is
    ///   logged.
    /// - `UpgradeOpen` transitions to `UpgradeClosing` (graceful close).
    /// - `UpgradeClosing` stays as it is, but is re-signalled.
    /// - Any other state becomes `AwaitingRelease` when both the send side
    ///   (`send.completed`) and the receive side (`recv_eof()`) are done.
    ///   Otherwise it becomes `ResetRequested(Cancel)`.
    ///
    /// Whenever follow-up work is needed, `needs_servicing` is raised and the
    /// connection's outbound waker is woken. For the upgrade states, the
    /// stream's `send.outbound_waker` is woken as well. All of this happens
    /// only after the lifecycle lock has been released.
    fn drop(&mut self) {
        let id = self.stream_id;
        // The streams-map guard is a temporary and is released at the end of
        // this statement, before the lifecycle lock is taken.
        let registered = self.connection.streams_lock().contains_key(&id);
        if !registered {
            log::trace!(
                "h2 stream {}: H2Transport dropped on already-released stream",
                id,
            );
            return;
        }

        let mut lifecycle = self.state.lifecycle_lock();
        // `None`: nothing more to do.
        // `Some(wake_send)`: raise `needs_servicing` and wake the connection;
        // if `wake_send` is true, also wake the send-side waker that
        // `H2OutboundReader::poll_read` registers, so it observes the close.
        let follow_up: Option<bool> = match &*lifecycle {
            StreamLifecycle::Reset(_) | StreamLifecycle::ResetRequested(_) => {
                log::trace!(
                    "h2 stream {}: H2Transport dropped on already-reset stream",
                    id,
                );
                None
            }
            StreamLifecycle::AwaitingRelease => {
                log::trace!(
                    "h2 stream {}: H2Transport dropped while already AwaitingRelease",
                    id,
                );
                None
            }
            StreamLifecycle::UpgradeOpen { recv_eof } => {
                let recv_eof = *recv_eof;
                log::trace!(
                    "h2 stream {}: H2Transport dropped (upgrade) — scheduling graceful close",
                    id,
                );
                *lifecycle = StreamLifecycle::UpgradeClosing {
                    recv_eof,
                    pending_trailers: None,
                };
                Some(true)
            }
            StreamLifecycle::UpgradeClosing { .. } => {
                log::trace!(
                    "h2 stream {}: H2Transport dropped — graceful close already in flight",
                    id,
                );
                Some(true)
            }
            _ => {
                let send_done = self.state.send.completed.load(Ordering::Acquire);
                let recv_done = lifecycle.recv_eof();
                if send_done && recv_done {
                    log::trace!(
                        "h2 stream {}: H2Transport dropped on wire-closed stream — releasing",
                        id,
                    );
                    *lifecycle = StreamLifecycle::AwaitingRelease;
                } else {
                    log::debug!(
                        "h2 stream {}: H2Transport dropped mid-stream — RST_STREAM(Cancel)",
                        id,
                    );
                    *lifecycle = StreamLifecycle::ResetRequested(H2ErrorCode::Cancel);
                }
                Some(false)
            }
        };
        // Release the lifecycle lock before signalling anyone.
        drop(lifecycle);

        let Some(wake_send) = follow_up else {
            return;
        };
        self.state.needs_servicing.store(true, Ordering::Release);
        if wake_send {
            self.state.send.outbound_waker.wake();
        }
        self.connection.outbound_waker().wake();
    }
}
impl AsyncRead for H2Transport {
    /// Copies buffered inbound bytes for this stream into `out`.
    ///
    /// - The first call flips `recv.is_reading` and asks the connection to
    ///   service the stream.
    /// - Every non-empty read adds the byte count to `recv.bytes_consumed`,
    ///   raises `needs_servicing` and wakes the connection.
    /// - Returns `Ok(0)` immediately for an empty `out`.
    /// - Otherwise returns `Ok(0)` only once the lifecycle reports receive EOF
    ///   *and* the receive buffer is drained.
    /// - While no data is buffered and EOF has not been seen, registers
    ///   `recv.waker` and returns `Pending`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if out.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let recv_state = &self.state.recv;
        let connection = &*self.connection;
        if !recv_state.is_reading.swap(true, Ordering::AcqRel) {
            self.state.needs_servicing.store(true, Ordering::Release);
            connection.outbound_waker().wake();
        }
        loop {
            // Sample EOF *before* looking at the buffer. If EOF was already set
            // here, any payload delivered ahead of it is already in `buf`.
            // (Assumes the connection appends a frame's bytes before marking
            // recv EOF — TODO confirm.) Sampling EOF only after finding the
            // buffer empty could report EOF while the final bytes sit unread.
            // The lifecycle guard is a temporary, so it is released before the
            // buf lock is taken (no nested locking).
            let recv_eof = self.state.lifecycle_lock().recv_eof();
            let mut recv = recv_state.buf.lock().expect("recv buf mutex poisoned");
            let take = out.len().min(recv.len());
            if take > 0 {
                out[..take].copy_from_slice(&recv[..take]);
                recv.ignore_front(take);
                drop(recv);
                recv_state
                    .bytes_consumed
                    .fetch_add(take as u64, Ordering::AcqRel);
                self.state.needs_servicing.store(true, Ordering::Release);
                connection.outbound_waker().wake();
                return Poll::Ready(Ok(take));
            }
            if recv_eof {
                // EOF was observed before the buffer was found empty: nothing
                // is left to deliver.
                return Poll::Ready(Ok(0));
            }
            // Register while still holding the buf lock, so a producer that
            // appends after our emptiness check will see this waker.
            recv_state.waker.register(cx.waker());
            drop(recv);
            if !self.state.lifecycle_lock().recv_eof() {
                return Poll::Pending;
            }
            // EOF appeared after the buffer was seen empty. Payload delivered
            // just before that EOF may have landed in between, so go round
            // again. The next pass either drains it or, now that EOF was seen
            // first, reports `Ok(0)`.
        }
    }
}
impl AsyncWrite for H2Transport {
    /// Queues as much of `buf` as fits in this stream's outbound queue.
    ///
    /// - The queue holds at most the connection's `response_buffer_max_len`
    ///   bytes.
    /// - Fails with `BrokenPipe` unless the lifecycle is `UpgradeOpen`.
    /// - While the queue is full, registers `send.outbound_write_waker` and
    ///   returns `Pending`.
    /// - After queueing, wakes `send.outbound_waker` and the connection's
    ///   outbound waker, then returns the number of bytes accepted.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // The lifecycle guard is a temporary and is released at the end of this
        // statement, before the outbound lock is taken.
        let writable = matches!(
            &*self.state.lifecycle_lock(),
            StreamLifecycle::UpgradeOpen { .. }
        );
        if !writable {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }

        let send = &self.state.send;
        let cap = self.connection.context.config.response_buffer_max_len;
        let mut outbound = send.outbound.lock().expect("outbound buf mutex poisoned");
        let room = cap.saturating_sub(outbound.len());
        if room == 0 {
            // The outbound lock is still held, so H2OutboundReader (which drains
            // under the same lock) cannot free space between the fullness check
            // and this registration.
            send.outbound_write_waker.register(cx.waker());
            return Poll::Pending;
        }

        let take = room.min(buf.len());
        log::trace!(
            "h2 stream {}: H2Transport::poll_write appending {take}/{} bytes to outbound queue",
            self.stream_id,
            buf.len(),
        );
        outbound.extend_from_slice(&buf[..take]);
        drop(outbound);

        send.outbound_waker.wake();
        self.connection.outbound_waker().wake();
        Poll::Ready(Ok(take))
    }

    /// Always ready: `poll_write` hands bytes straight to the outbound queue,
    /// so there is nothing buffered locally to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Marks the outbound side closed.
    ///
    /// An `UpgradeOpen` stream moves to `UpgradeClosing`, keeping its
    /// `recv_eof`; other states are left untouched. The send-side waker and
    /// the connection are woken either way.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        log::trace!(
            "h2 stream {}: H2Transport::poll_close marking outbound closed",
            self.stream_id,
        );
        {
            let mut lifecycle = self.state.lifecycle_lock();
            let closing = match &*lifecycle {
                StreamLifecycle::UpgradeOpen { recv_eof } => {
                    Some(StreamLifecycle::UpgradeClosing {
                        recv_eof: *recv_eof,
                        pending_trailers: None,
                    })
                }
                _ => None,
            };
            if let Some(next) = closing {
                *lifecycle = next;
            }
        }
        self.state.send.outbound_waker.wake();
        self.connection.outbound_waker().wake();
        Poll::Ready(Ok(()))
    }
}
/// State for one stream, shared via `Arc` by its `H2Transport` and
/// `H2OutboundReader`. The `pub(super)` fields suggest the connection code reads
/// it as well — TODO confirm.
#[derive(Debug, Default)]
pub(super) struct StreamState {
/// Lifecycle state machine; every access in this file goes through
/// `lifecycle_lock`.
pub(super) lifecycle: Mutex<StreamLifecycle>,
/// Receive-side (inbound) buffers and wakers.
pub(super) recv: RecvState,
/// Send-side (outbound) buffers and wakers.
pub(super) send: SendState,
/// "Please service this stream" flag. Code in this file only ever stores
/// `true`: on the first read, after each successful read, and on transport
/// drops that need follow-up. Clearing it presumably happens in the connection.
pub(super) needs_servicing: AtomicBool,
}
impl StreamState {
    /// Locks the stream's lifecycle and returns the guard.
    ///
    /// # Panics
    ///
    /// Panics if the lifecycle mutex is poisoned, i.e. a previous holder
    /// panicked while holding the lock.
    pub(super) fn lifecycle_lock(&self) -> MutexGuard<'_, StreamLifecycle> {
        let locked = self.lifecycle.lock();
        locked.expect("lifecycle mutex poisoned")
    }
}
/// Receive-side (inbound) half of a stream's shared state.
#[derive(Debug, Default)]
pub(super) struct RecvState {
/// Received bytes not yet handed out by `H2Transport::poll_read`, which copies
/// from the front and then discards what it copied.
pub(super) buf: Mutex<Buffer>,
/// Registered by `H2Transport::poll_read` while `buf` is empty; presumably
/// woken by whoever appends to `buf` or marks receive EOF (not shown here).
pub(super) waker: AtomicWaker,
/// Set by the first `H2Transport::poll_read`; that first read also raises
/// `needs_servicing` and wakes the connection.
pub(super) is_reading: AtomicBool,
/// Running total of bytes `H2Transport::poll_read` has copied out of `buf`.
/// NOTE(review): looks like input for flow-control window updates — confirm.
pub(super) bytes_consumed: AtomicU64,
/// Not used in this file; presumably trailing headers received on the stream.
pub(super) trailers: Mutex<Option<Headers>>,
/// Not used in this file; presumably the received response header block.
pub(super) response_headers: Mutex<Option<FieldSection<'static>>>,
/// Not used in this file.
pub(super) first_response_headers_seen: AtomicBool,
/// Not used in this file.
pub(super) response_headers_waker: AtomicWaker,
}
/// Send-side (outbound) half of a stream's shared state.
#[derive(Debug, Default)]
pub(super) struct SendState {
/// Read by `H2Transport`'s `Drop` impl as "send side finished"; not written in
/// this file.
pub(super) completed: AtomicBool,
/// Not used in this file; presumably the outcome of the send side.
pub(super) completion_result: Mutex<Option<io::Result<()>>>,
/// Not used in this file.
pub(super) completion_waker: AtomicWaker,
/// Bytes queued by `H2Transport::poll_write` (capped at the connection's
/// `response_buffer_max_len`) and drained by `H2OutboundReader::poll_read`.
pub(super) outbound: Mutex<Buffer>,
/// Registered by `H2OutboundReader::poll_read` while the queue is empty; woken
/// on write, on close, and when an upgraded transport is dropped.
pub(super) outbound_waker: AtomicWaker,
/// Registered by `H2Transport::poll_write` while the queue is full; woken by
/// `H2OutboundReader::poll_read` after it drains bytes.
pub(super) outbound_write_waker: AtomicWaker,
}
/// `AsyncRead` view of a stream's outbound queue (`SendState::outbound`).
///
/// Yields the bytes `H2Transport::poll_write` queued, and reports EOF once the
/// queue is empty and the lifecycle is no longer `UpgradeOpen`.
#[derive(Debug)]
pub(super) struct H2OutboundReader {
/// State shared with the stream's `H2Transport`.
state: Arc<StreamState>,
/// Stream identifier; in this file it is only used in log messages.
stream_id: u32,
}
impl H2OutboundReader {
    /// Wraps the outbound queue in `state` as a reader; `stream_id` is only
    /// used when logging.
    pub(super) fn new(state: Arc<StreamState>, stream_id: u32) -> Self {
        Self { stream_id, state }
    }
}
impl AsyncRead for H2OutboundReader {
    /// Drains bytes that `H2Transport::poll_write` queued for this stream.
    ///
    /// - Returns `Ok(0)` immediately for an empty `out`.
    /// - Also returns `Ok(0)` once the queue is empty and the lifecycle has
    ///   moved past `UpgradeOpen`.
    /// - While the queue is empty and the stream is still `UpgradeOpen`,
    ///   registers `send.outbound_waker` and returns `Pending`.
    /// - Every successful drain wakes `send.outbound_write_waker`, so a writer
    ///   parked on a full queue can continue.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if out.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let send = &self.state.send;
        let mut outbound = send.outbound.lock().expect("outbound buf mutex poisoned");
        let take = outbound.len().min(out.len());
        if take == 0 {
            // Queue is empty: register first, then consult the lifecycle. The
            // queue lock is still held, so no write can be appended between the
            // emptiness check and the EOF decision.
            send.outbound_waker.register(cx.waker());
            let still_open = matches!(
                &*self.state.lifecycle_lock(),
                StreamLifecycle::UpgradeOpen { .. }
            );
            if still_open {
                return Poll::Pending;
            }
            log::trace!(
                "h2 stream {}: H2OutboundReader::poll_read EOF (lifecycle past UpgradeOpen, queue \
                 empty)",
                self.stream_id,
            );
            return Poll::Ready(Ok(0));
        }

        out[..take].copy_from_slice(&outbound[..take]);
        outbound.ignore_front(take);
        log::trace!(
            "h2 stream {}: H2OutboundReader::poll_read drained {take} bytes",
            self.stream_id,
        );
        drop(outbound);
        send.outbound_write_waker.wake();
        Poll::Ready(Ok(take))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpContext;
    use futures_lite::{AsyncRead, AsyncWrite};
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        task::{Context, Poll, Wake, Waker},
    };

    /// Waker that records whether it has been woken at least once.
    struct FlagWaker(AtomicBool);

    impl FlagWaker {
        /// Whether `wake` has been called.
        fn fired(&self) -> bool {
            self.0.load(Ordering::Acquire)
        }
    }

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::Release);
        }
    }

    /// Returns an un-fired flag together with a `Waker` that sets it.
    fn flag_waker() -> (Arc<FlagWaker>, Waker) {
        let flag = Arc::new(FlagWaker(AtomicBool::new(false)));
        let waker = Waker::from(Arc::clone(&flag));
        (flag, waker)
    }

    /// Builds a transport/reader pair for stream 1.
    ///
    /// Both share one `StreamState` in `UpgradeOpen { recv_eof: true }`, and
    /// the outbound queue is capped at `cap` bytes.
    fn pair_with_cap(cap: usize) -> (H2Transport, H2OutboundReader) {
        let mut context = HttpContext::new();
        context.config.response_buffer_max_len = cap;
        let connection = H2Connection::new(Arc::new(context));
        let state = Arc::new(StreamState::default());
        *state.lifecycle_lock() = StreamLifecycle::UpgradeOpen { recv_eof: true };
        let reader = H2OutboundReader::new(Arc::clone(&state), 1);
        let transport = H2Transport::new(connection, 1, state);
        (transport, reader)
    }

    #[test]
    fn poll_write_caps_at_response_buffer_max_len() {
        let (mut transport, _reader) = pair_with_cap(16);
        let (_flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);
        let payload = [0u8; 32];
        let outcome = Pin::new(&mut transport).poll_write(&mut cx, &payload);
        match outcome {
            Poll::Ready(Ok(n)) => assert_eq!(n, 16, "should accept exactly cap bytes"),
            other => panic!("expected Ready(Ok(16)), got {other:?}"),
        }
    }

    #[test]
    fn poll_write_returns_pending_when_full_and_drain_wakes() {
        let (mut transport, mut reader) = pair_with_cap(8);
        let (writer_flag, writer_waker) = flag_waker();
        let mut writer_cx = Context::from_waker(&writer_waker);

        // Fill the queue exactly to its cap.
        match Pin::new(&mut transport).poll_write(&mut writer_cx, &[0u8; 8]) {
            Poll::Ready(Ok(8)) => {}
            other => panic!("expected Ready(Ok(8)), got {other:?}"),
        }

        // Any further write must park the writer.
        let extra = [0u8; 4];
        match Pin::new(&mut transport).poll_write(&mut writer_cx, &extra) {
            Poll::Pending => {}
            other => panic!("expected Pending, got {other:?}"),
        }
        assert!(!writer_flag.fired(), "writer waker should not have fired yet");

        // Draining part of the queue must wake the parked writer.
        let mut reader_cx = Context::from_waker(Waker::noop());
        let mut sink = [0u8; 4];
        match Pin::new(&mut reader).poll_read(&mut reader_cx, &mut sink) {
            Poll::Ready(Ok(4)) => {}
            other => panic!("expected Ready(Ok(4)), got {other:?}"),
        }
        assert!(writer_flag.fired(), "drain should have woken the writer");

        // The freed space is writable again.
        match Pin::new(&mut transport).poll_write(&mut writer_cx, &extra) {
            Poll::Ready(Ok(4)) => {}
            other => panic!("expected Ready(Ok(4)), got {other:?}"),
        }
    }
}