use super::{H2Connection, H2ErrorCode};
use crate::{
Body, Buffer, Headers,
headers::hpack::{FieldSection, PseudoHeaders},
};
use atomic_waker::AtomicWaker;
use futures_lite::io::{AsyncRead, AsyncWrite};
use std::{
fmt, io,
pin::Pin,
sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicU64, Ordering},
},
task::{Context, Poll},
};
/// Per-stream I/O handle for one HTTP/2 stream.
///
/// Reading (`AsyncRead`) yields bytes buffered in the shared `StreamState::recv`;
/// writing (`AsyncWrite`) appends bytes to `StreamState::send.outbound`, which an
/// `H2OutboundReader` over the same `StreamState` drains. Dropping the handle
/// before the stream is finished resets it (see the `Drop` impl).
pub struct H2Transport {
    // Connection this stream is multiplexed over; woken via `outbound_waker()`
    // whenever this handle produces work for it.
    connection: Arc<H2Connection>,
    // HTTP/2 stream identifier.
    stream_id: u32,
    // State shared with the `H2OutboundReader` for this stream
    // (NOTE(review): presumably also held by the connection — confirm).
    state: Arc<StreamState>,
}
impl fmt::Debug for H2Transport {
    /// Formats the transport by its stream id only; the connection and the
    /// shared stream state are deliberately elided (rendered as `..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("H2Transport");
        out.field("stream_id", &self.stream_id);
        out.finish_non_exhaustive()
    }
}
impl H2Transport {
pub(super) fn new(
connection: Arc<H2Connection>,
stream_id: u32,
state: Arc<StreamState>,
) -> Self {
Self {
connection,
stream_id,
state,
}
}
pub fn stream_id(&self) -> u32 {
self.stream_id
}
pub fn connection(&self) -> &Arc<H2Connection> {
&self.connection
}
}
impl Drop for H2Transport {
    /// Tidies up the stream when its handle goes away.
    ///
    /// - If the connection no longer tracks this stream id, there is nothing to do.
    /// - If the send side has completed *and* the receive side reached EOF, the
    ///   stream is released via `release_stream`.
    /// - If `poll_close` was already requested, the stream is left alone
    ///   (NOTE(review): presumably the connection finishes sending the queued
    ///   bytes — confirm).
    /// - Otherwise the stream is reset with `H2ErrorCode::Cancel` via `stream_error`.
    fn drop(&mut self) {
        let id = self.stream_id;
        // The lock guard is a temporary of this statement and is released here.
        let still_tracked = self.connection.streams_lock().contains_key(&id);
        if !still_tracked {
            return;
        }

        let send = &self.state.send;
        let send_done = send.completed.load(Ordering::Acquire);
        let recv_done = self.state.recv.eof.load(Ordering::Acquire);

        if send_done && recv_done {
            log::trace!(
                "h2 stream {}: H2Transport dropped on wire-closed stream — releasing",
                id,
            );
            self.connection.release_stream(id);
        } else if !send.outbound_close_requested.load(Ordering::Acquire) {
            log::debug!(
                "h2 stream {}: H2Transport dropped mid-stream — RST_STREAM(Cancel) \
                 (send_done={send_done}, recv_done={recv_done})",
                id,
            );
            self.connection.stream_error(id, H2ErrorCode::Cancel);
        }
    }
}
impl AsyncRead for H2Transport {
    /// Reads body bytes that have been buffered for this stream.
    ///
    /// Returns `Ok(0)` immediately for an empty `out`, and `Ok(0)` once the
    /// receive side is at EOF and every buffered byte has been consumed.
    /// Otherwise it copies up to `out.len()` bytes, or parks until woken.
    ///
    /// The first call on a stream sets `is_reading`. That call and every call
    /// that consumes bytes also set `needs_servicing` and wake the connection.
    /// Consumed byte counts are added to `bytes_consumed`
    /// (NOTE(review): presumably so the connection can replenish the peer's
    /// flow-control window — confirm).
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if out.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let recv_state = &self.state.recv;
        let connection = &*self.connection;
        // First read on this stream: flag it and nudge the connection.
        if !recv_state.is_reading.swap(true, Ordering::AcqRel) {
            self.state.needs_servicing.store(true, Ordering::Release);
            connection.outbound_waker().wake();
        }
        let mut recv = recv_state.buf.lock().expect("recv buf mutex poisoned");
        let take = out.len().min(recv.len());
        if take > 0 {
            out[..take].copy_from_slice(&recv[..take]);
            recv.ignore_front(take);
            drop(recv);
            recv_state
                .bytes_consumed
                .fetch_add(take as u64, Ordering::AcqRel);
            self.state.needs_servicing.store(true, Ordering::Release);
            connection.outbound_waker().wake();
            return Poll::Ready(Ok(take));
        }
        // Register *before* checking `eof`. Checking first and registering
        // afterwards leaves a window: an `eof` store plus wake landing between
        // the two would go unseen, and this task would park forever. With
        // registration first, such a wake-up always reaches this task.
        // (Same protocol as `H2OutboundReader::poll_read`.)
        recv_state.waker.register(cx.waker());
        if recv_state.eof.load(Ordering::Acquire) {
            return Poll::Ready(Ok(0));
        }
        Poll::Pending
    }
}
impl AsyncWrite for H2Transport {
    /// Appends as much of `buf` as fits into this stream's outbound queue.
    ///
    /// The queue is capped at the connection's `response_buffer_max_len`. When
    /// it is full, the task parks on `outbound_write_waker` until a reader
    /// drains it. Returns `Err(BrokenPipe)` once `poll_close` has been
    /// called, and `Ok(0)` right away for an empty `buf`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let send = &self.state.send;
        if send.outbound_close_requested.load(Ordering::Acquire) {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        // Nothing to queue: don't park on a full queue and don't wake anyone.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let cap = self.connection.context.config.response_buffer_max_len;
        let mut outbound = send.outbound.lock().expect("outbound buf mutex poisoned");
        if outbound.len() >= cap {
            // The queue cannot shrink while we hold its lock, so any drain
            // happens after this registration. The drain in
            // `H2OutboundReader::poll_read` wakes `outbound_write_waker` after
            // shrinking the queue, so that wake-up reaches this task. A
            // re-check of the length here could never observe a change.
            send.outbound_write_waker.register(cx.waker());
            return Poll::Pending;
        }
        let take = (cap - outbound.len()).min(buf.len());
        log::trace!(
            "h2 stream {}: H2Transport::poll_write appending {take}/{} bytes to outbound queue",
            self.stream_id,
            buf.len(),
        );
        outbound.extend_from_slice(&buf[..take]);
        drop(outbound);
        // Wake the stream's outbound reader and the connection so the queued
        // bytes get picked up.
        send.outbound_waker.wake();
        self.connection.outbound_waker().wake();
        Poll::Ready(Ok(take))
    }

    /// Completes immediately; it does not wait for the outbound queue to drain.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Marks the outbound side closed and wakes the outbound reader and the
    /// connection.
    ///
    /// Completes immediately without waiting for queued bytes. Later
    /// `poll_write` calls fail with `BrokenPipe`, and `H2OutboundReader`
    /// reports EOF once the queue is empty.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        log::trace!(
            "h2 stream {}: H2Transport::poll_close marking outbound closed",
            self.stream_id,
        );
        self.state
            .send
            .outbound_close_requested
            .store(true, Ordering::Release);
        self.state.send.outbound_waker.wake();
        self.connection.outbound_waker().wake();
        Poll::Ready(Ok(()))
    }
}
/// Per-stream state shared via `Arc` by `H2Transport` and `H2OutboundReader`
/// (NOTE(review): presumably also by the owning connection — confirm).
#[derive(Debug, Default)]
pub(super) struct StreamState {
    /// Receive-direction buffer, flags and wakers.
    pub(super) recv: RecvState,
    /// Send-direction queue, flags and wakers.
    pub(super) send: SendState,
    /// NOTE(review): not used in this file; presumably an error code queued for
    /// the connection to send as RST_STREAM — confirm.
    pub(super) pending_reset: Mutex<Option<H2ErrorCode>>,
    /// NOTE(review): not used in this file; presumably asks the connection to
    /// release the stream later — confirm.
    pub(super) pending_release: AtomicBool,
    /// Set by `H2Transport::poll_read` (first read, and after consuming bytes)
    /// just before waking the connection's outbound waker.
    /// NOTE(review): presumably tells the connection this stream needs attention.
    pub(super) needs_servicing: AtomicBool,
}
/// Receive-direction state of a stream.
#[derive(Debug, Default)]
pub(super) struct RecvState {
    /// Received body bytes not yet read; drained by `H2Transport::poll_read`.
    pub(super) buf: Mutex<Buffer>,
    /// End-of-stream flag for the receive side. Once set and `buf` is empty,
    /// `poll_read` returns `Ok(0)`; `H2Transport`'s `Drop` treats it as "receive done".
    pub(super) eof: AtomicBool,
    /// Waker of the task parked in `H2Transport::poll_read` waiting for data or EOF.
    pub(super) waker: AtomicWaker,
    /// Set by the first `H2Transport::poll_read` call on this stream.
    pub(super) is_reading: AtomicBool,
    /// Running total of bytes handed out by `H2Transport::poll_read`.
    /// NOTE(review): presumably drives flow-control window updates — confirm.
    pub(super) bytes_consumed: AtomicU64,
    /// NOTE(review): not used in this file; presumably trailers received on
    /// this stream.
    pub(super) trailers: Mutex<Option<Headers>>,
    /// NOTE(review): not used in this file; presumably the decoded response
    /// header block.
    pub(super) response_headers: Mutex<Option<FieldSection<'static>>>,
    /// NOTE(review): not used in this file; presumably set when the first
    /// response header block arrives.
    pub(super) first_response_headers_seen: AtomicBool,
    /// NOTE(review): not used in this file; presumably woken when response
    /// headers become available.
    pub(super) response_headers_waker: AtomicWaker,
}
/// Send-direction state of a stream.
#[derive(Debug, Default)]
pub(super) struct SendState {
    /// NOTE(review): not used in this file; presumably the head (and body) to
    /// be sent for this stream — confirm.
    pub(super) submission: Mutex<Option<Submission>>,
    /// Whether the send side has finished; `H2Transport`'s `Drop` reads it as
    /// "send done".
    pub(super) completed: AtomicBool,
    /// NOTE(review): not used in this file; presumably the outcome recorded
    /// when sending completes.
    pub(super) completion_result: Mutex<Option<io::Result<()>>>,
    /// NOTE(review): not used in this file; presumably woken when `completed`
    /// is set.
    pub(super) completion_waker: AtomicWaker,
    /// Bytes written through `H2Transport::poll_write` (kept at or below
    /// `response_buffer_max_len`), drained by `H2OutboundReader::poll_read`.
    pub(super) outbound: Mutex<Buffer>,
    /// Set by `H2Transport::poll_close`. Later writes fail with `BrokenPipe`,
    /// and `H2OutboundReader` reports EOF once `outbound` is empty.
    pub(super) outbound_close_requested: AtomicBool,
    /// Waker of the `H2OutboundReader` task waiting for queued bytes or close.
    pub(super) outbound_waker: AtomicWaker,
    /// Waker of a writer parked in `H2Transport::poll_write` on a full queue;
    /// woken by `H2OutboundReader::poll_read` after draining.
    pub(super) outbound_write_waker: AtomicWaker,
}
/// Read half of a stream's outbound queue.
///
/// Yields the bytes `H2Transport` wrote into `SendState::outbound`, and EOF
/// once the transport has been closed and the queue is empty.
#[derive(Debug)]
pub(super) struct H2OutboundReader {
    // State shared with the stream's `H2Transport`.
    state: Arc<StreamState>,
    // Stream identifier, used in trace logging.
    stream_id: u32,
}
impl H2OutboundReader {
pub(super) fn new(state: Arc<StreamState>, stream_id: u32) -> Self {
Self { state, stream_id }
}
}
impl AsyncRead for H2OutboundReader {
    /// Drains bytes queued by `H2Transport::poll_write`.
    ///
    /// Copies up to `out.len()` bytes out of `SendState::outbound` and wakes a
    /// writer parked on `outbound_write_waker`. When the queue is empty, it
    /// reports EOF (`Ok(0)`) if `poll_close` was requested, and otherwise parks
    /// on `outbound_waker`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if out.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let send = &self.state.send;
        let mut queue = send.outbound.lock().expect("outbound buf mutex poisoned");
        let take = out.len().min(queue.len());

        if take == 0 {
            // Register before inspecting the close flag so that a concurrent
            // `poll_close` (store, then wake) cannot slip past unnoticed.
            send.outbound_waker.register(cx.waker());
            if !send.outbound_close_requested.load(Ordering::Acquire) {
                return Poll::Pending;
            }
            log::trace!(
                "h2 stream {}: H2OutboundReader::poll_read EOF (close_requested + empty)",
                self.stream_id,
            );
            return Poll::Ready(Ok(0));
        }

        out[..take].copy_from_slice(&queue[..take]);
        queue.ignore_front(take);
        log::trace!(
            "h2 stream {}: H2OutboundReader::poll_read drained {take} bytes",
            self.stream_id,
        );
        // Release the lock before waking so the writer can take it right away.
        drop(queue);
        send.outbound_write_waker.wake();
        Poll::Ready(Ok(take))
    }
}
/// NOTE(review): presumably a message head (plus optional body) queued for
/// sending on a stream — confirm against the connection code.
#[derive(Debug)]
pub(super) struct Submission {
    /// Pseudo-header fields; combined with `headers` by `field_section`.
    pub(super) pseudos: PseudoHeaders<'static>,
    /// Regular header fields; combined with `pseudos` by `field_section`.
    pub(super) headers: Headers,
    /// Optional message body.
    pub(super) body: Option<Body>,
    /// NOTE(review): presumably marks an upgrade request/response — confirm.
    pub(super) is_upgrade: bool,
}
impl Submission {
pub(super) fn field_section(&self) -> FieldSection<'_> {
FieldSection::new(self.pseudos.clone(), &self.headers)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpContext;
    use futures_lite::{AsyncRead, AsyncWrite};
    use std::{
        sync::{
            Arc,
            atomic::{AtomicBool, Ordering},
        },
        task::{Context, Poll, Wake, Waker},
    };

    /// Waker that only records whether it has been woken.
    struct FlagWaker(AtomicBool);

    impl FlagWaker {
        fn fired(&self) -> bool {
            self.0.load(Ordering::Acquire)
        }
    }

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::Release);
        }
    }

    /// Builds a transport/outbound-reader pair for stream 1 whose outbound
    /// queue is capped at `cap` bytes.
    fn pair_with_cap(cap: usize) -> (H2Transport, H2OutboundReader) {
        let mut context = HttpContext::new();
        context.config.response_buffer_max_len = cap;
        let connection = H2Connection::new(Arc::new(context));
        let state = Arc::new(StreamState::default());
        (
            H2Transport::new(Arc::clone(&connection), 1, Arc::clone(&state)),
            H2OutboundReader::new(state, 1),
        )
    }

    /// Polls a single write of `bytes` on `transport`.
    fn write_once(
        transport: &mut H2Transport,
        cx: &mut Context<'_>,
        bytes: &[u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(transport).poll_write(cx, bytes)
    }

    #[test]
    fn poll_write_caps_at_response_buffer_max_len() {
        let (mut transport, _reader) = pair_with_cap(16);
        let waker = Waker::from(Arc::new(FlagWaker(AtomicBool::new(false))));
        let mut cx = Context::from_waker(&waker);
        match write_once(&mut transport, &mut cx, &[0u8; 32]) {
            Poll::Ready(Ok(accepted)) => {
                assert_eq!(accepted, 16, "should accept exactly cap bytes")
            }
            other => panic!("expected Ready(Ok(16)), got {other:?}"),
        }
    }

    #[test]
    fn poll_write_returns_pending_when_full_and_drain_wakes() {
        let (mut transport, mut reader) = pair_with_cap(8);
        let flag = Arc::new(FlagWaker(AtomicBool::new(false)));
        let writer_waker = Waker::from(Arc::clone(&flag));
        let mut writer_cx = Context::from_waker(&writer_waker);
        let extra = [0u8; 4];

        // Fill the queue to exactly its capacity.
        match write_once(&mut transport, &mut writer_cx, &[0u8; 8]) {
            Poll::Ready(Ok(8)) => {}
            other => panic!("expected Ready(Ok(8)), got {other:?}"),
        }
        // A full queue parks the writer without waking it.
        match write_once(&mut transport, &mut writer_cx, &extra) {
            Poll::Pending => {}
            other => panic!("expected Pending, got {other:?}"),
        }
        assert!(!flag.fired(), "writer waker should not have fired yet");

        // Draining part of the queue must wake the parked writer.
        let mut reader_cx = Context::from_waker(Waker::noop());
        let mut sink = [0u8; 4];
        match Pin::new(&mut reader).poll_read(&mut reader_cx, &mut sink) {
            Poll::Ready(Ok(4)) => {}
            other => panic!("expected Ready(Ok(4)), got {other:?}"),
        }
        assert!(flag.fired(), "drain should have woken the writer");

        // The freed space now accepts the pending bytes.
        match write_once(&mut transport, &mut writer_cx, &extra) {
            Poll::Ready(Ok(4)) => {}
            other => panic!("expected Ready(Ok(4)), got {other:?}"),
        }
    }
}