use super::{
H2Connection, H2ErrorCode,
stream_state::{StreamEvent, StreamLifecycle},
};
use crate::{
Body, Buffer, Headers,
headers::hpack::{FieldSection, PseudoHeaders},
};
use atomic_waker::AtomicWaker;
use futures_lite::io::{AsyncRead, AsyncWrite};
use std::{
collections::VecDeque,
fmt, io,
pin::Pin,
sync::{
Arc, Mutex, MutexGuard,
atomic::{AtomicBool, AtomicU64, Ordering},
},
task::{Context, Poll},
};
/// Per-stream I/O handle implementing `AsyncRead`/`AsyncWrite` over one HTTP/2 stream.
///
/// Reads drain `state.recv.buf`; writes append to `state.send.outbound`. Dropping the
/// transport asks the connection to release, gracefully close, or reset the stream
/// (see the `Drop` impl).
pub struct H2Transport {
    /// Connection this stream belongs to; used for config and to wake its outbound side.
    connection: Arc<H2Connection>,
    /// HTTP/2 stream identifier.
    stream_id: u32,
    /// Per-stream state shared with the connection.
    state: Arc<StreamState>,
}
impl fmt::Debug for H2Transport {
    /// Shows only the stream id; the connection and shared state are elided (`..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("H2Transport");
        builder.field("stream_id", &self.stream_id);
        builder.finish_non_exhaustive()
    }
}
impl H2Transport {
    /// Creates the handle for stream `stream_id` on `connection`, backed by the shared
    /// per-stream `state`.
    pub(super) fn new(
        connection: Arc<H2Connection>,
        stream_id: u32,
        state: Arc<StreamState>,
    ) -> Self {
        Self {
            stream_id,
            state,
            connection,
        }
    }

    /// The HTTP/2 stream identifier this transport reads from and writes to.
    pub fn stream_id(&self) -> u32 {
        self.stream_id
    }

    /// The connection this stream is multiplexed over.
    pub fn connection(&self) -> &Arc<H2Connection> {
        &self.connection
    }
}
impl Drop for H2Transport {
    /// Hands the stream back to the connection when the user-facing handle goes away.
    ///
    /// What happens depends on how far the stream got:
    ///
    /// - the connection no longer tracks `stream_id`: nothing to do, return immediately;
    /// - the stream lifecycle is closed: set `send.transport_dropped`;
    /// - the send half is still open and `send.submit_resolved` is set (logged as the
    ///   upgrade-tunnel case): queue a graceful `Close` via `request_close`;
    /// - anything else (send half closed while the stream is not fully closed, or the
    ///   submission not yet resolved): queue `RST_STREAM(Cancel)` via `request_reset`.
    ///
    /// Except for the early return, the stream is then flagged `needs_servicing`. Both the
    /// stream's outbound write waker and the connection's outbound waker are woken.
    fn drop(&mut self) {
        // The guard from `streams_lock()` is a temporary of the `if` condition, so it is
        // released before the body runs.
        if !self.connection.streams_lock().contains_key(&self.stream_id) {
            log::trace!(
                "h2 stream {}: H2Transport dropped on already-released stream",
                self.stream_id,
            );
            return;
        }
        // Copy the lifecycle out so its lock is not held below: `request_close` takes the
        // lifecycle lock again.
        let lifecycle = *self.state.lifecycle_lock();
        if lifecycle.is_closed() {
            log::trace!(
                "h2 stream {}: H2Transport dropped on wire-closed stream — releasing",
                self.stream_id,
            );
            self.state
                .send
                .transport_dropped
                .store(true, Ordering::Release);
        } else if !lifecycle.send_closed()
            && self.state.send.submit_resolved.load(Ordering::Acquire)
        {
            log::trace!(
                "h2 stream {}: H2Transport dropped (upgrade tunnel) — scheduling graceful close",
                self.stream_id,
            );
            self.state.request_close();
        } else {
            log::debug!(
                "h2 stream {}: H2Transport dropped mid-stream — RST_STREAM(Cancel)",
                self.stream_id,
            );
            self.state.request_reset(H2ErrorCode::Cancel);
        }
        // Publish the flag before waking so whoever is woken observes it.
        self.state.needs_servicing.store(true, Ordering::Release);
        self.state.send.outbound_write_waker.wake();
        self.connection.outbound_waker().wake();
    }
}
impl AsyncRead for H2Transport {
    /// Copies buffered inbound bytes for this stream into `out`.
    ///
    /// The first call sets `recv.is_reading`, flags the stream for servicing and wakes the
    /// connection's outbound side. Every non-empty read adds the byte count to
    /// `recv.bytes_consumed`, flags the stream for servicing and wakes the connection again.
    ///
    /// Returns `Ok(0)` when `out` is empty, or when the receive half is closed and nothing
    /// remains buffered. Otherwise returns `Pending`, with `cx`'s waker registered on
    /// `recv.waker`.
    ///
    /// # Panics
    ///
    /// If the receive-buffer or lifecycle mutex is poisoned.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        out: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if out.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let recv_state = &self.state.recv;
        let connection = &*self.connection;
        if !recv_state.is_reading.swap(true, Ordering::AcqRel) {
            self.state.needs_servicing.store(true, Ordering::Release);
            connection.outbound_waker().wake();
        }
        // Register before inspecting any state. Data or a close arriving after the checks
        // below then wakes this task instead of being missed.
        recv_state.waker.register(cx.waker());
        // Snapshot "receive half closed" *before* looking at the buffer. The alternative is
        // to check it only after finding the buffer empty. That can report EOF even though
        // the final DATA bytes were appended between the two checks, which truncates the
        // body. With this order, an empty buffer after a closed snapshot really is EOF.
        // NOTE(review): this relies on the connection appending DATA to `recv.buf` before it
        // applies the event that closes the receive half. Confirm in the frame handler.
        let recv_closed = self.state.lifecycle_lock().recv_closed();
        let mut recv = recv_state.buf.lock().expect("recv buf mutex poisoned");
        let take = out.len().min(recv.len());
        if take == 0 {
            drop(recv);
            return if recv_closed {
                Poll::Ready(Ok(0))
            } else {
                Poll::Pending
            };
        }
        out[..take].copy_from_slice(&recv[..take]);
        recv.ignore_front(take);
        drop(recv);
        recv_state
            .bytes_consumed
            .fetch_add(take as u64, Ordering::AcqRel);
        self.state.needs_servicing.store(true, Ordering::Release);
        connection.outbound_waker().wake();
        Poll::Ready(Ok(take))
    }
}
impl AsyncWrite for H2Transport {
    /// Copies as much of `buf` as fits into the stream's outbound buffer.
    ///
    /// The buffer is capped at the connection's `config.response_buffer_max_len`. Only the
    /// free space is filled, and the returned count is the number of bytes taken.
    ///
    /// - Buffer already full: the task registers on `send.outbound_write_waker` and
    ///   `Pending` is returned.
    /// - Zero-length `buf`: completes immediately with `Ok(0)`.
    ///
    /// # Errors
    ///
    /// `BrokenPipe` once the stream's send half is closed.
    ///
    /// # Panics
    ///
    /// If the outbound-buffer or lifecycle mutex is poisoned.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if self.state.lifecycle_lock().send_closed() {
            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
        }
        // Nothing to enqueue. Don't park behind a full buffer, and don't wake the
        // connection for an empty write.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let send = &self.state.send;
        let cap = self.connection.context.config.response_buffer_max_len;
        let mut outbound = send.outbound.lock().expect("outbound buf mutex poisoned");
        if outbound.len() >= cap {
            // Register while the lock is still held. Space can only be freed after the
            // registration is in place, so a wake issued after draining is not lost.
            // NOTE(review): relies on the drainer waking `outbound_write_waker` after it
            // frees space. Confirm in the connection driver.
            send.outbound_write_waker.register(cx.waker());
            return Poll::Pending;
        }
        // `len < cap` here, so the subtraction cannot underflow and `take > 0`.
        let take = (cap - outbound.len()).min(buf.len());
        log::trace!(
            "h2 stream {}: H2Transport::poll_write appending {take}/{} bytes to outbound ring",
            self.stream_id,
            buf.len(),
        );
        outbound.extend_from_slice(&buf[..take]);
        drop(outbound);
        self.connection.outbound_waker().wake();
        Poll::Ready(Ok(take))
    }

    /// Completes immediately. Bytes are handed to the connection by `poll_write`, and this
    /// does not wait for them to be sent.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Requests a graceful close of the send half (`StreamState::request_close`), flags the
    /// stream for servicing, and wakes both the write waker and the connection.
    ///
    /// Returns `Ready` without waiting for the close to be sent.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        log::trace!(
            "h2 stream {}: H2Transport::poll_close enqueuing Close",
            self.stream_id,
        );
        self.state.request_close();
        // `request_close` skips this when the send half is already closed; set it anyway.
        self.state.needs_servicing.store(true, Ordering::Release);
        self.state.send.outbound_write_waker.wake();
        self.connection.outbound_waker().wake();
        Poll::Ready(Ok(()))
    }
}
/// One item queued in `SendState::queue` for the connection to emit on a stream.
///
/// `Trailers`, `Close` and `Reset` are terminal (see `OutboundPart::is_terminal`).
#[derive(Debug)]
pub(super) enum OutboundPart {
    /// A header block: pseudo-headers plus regular header fields.
    Headers {
        pseudos: PseudoHeaders<'static>,
        headers: Headers,
    },
    /// A message body.
    Body(Body),
    /// Trailing headers; terminal.
    Trailers(Headers),
    /// Graceful end of the send half; queued by `StreamState::request_close`. Terminal.
    Close,
    /// Abort with the given error code; queued by `StreamState::request_reset`
    /// (dropping a transport mid-stream uses `Cancel`). Terminal.
    Reset(H2ErrorCode),
}
impl OutboundPart {
    /// Whether this part ends the stream's send side: trailers, a close, or a reset.
    ///
    /// The match is exhaustive on purpose, so adding a variant forces a decision here.
    pub(super) fn is_terminal(&self) -> bool {
        match self {
            Self::Trailers(_) | Self::Close | Self::Reset(_) => true,
            Self::Headers { .. } | Self::Body(_) => false,
        }
    }
}
/// State for one stream, shared (via `Arc`) between its `H2Transport` and the connection.
#[derive(Debug, Default)]
pub(super) struct StreamState {
    /// Stream lifecycle state machine; access through `lifecycle_lock` / `apply_event`.
    lifecycle: Mutex<StreamLifecycle>,
    /// Receive side: buffered inbound bytes and read-side wakers.
    pub(super) recv: RecvState,
    /// Send side: queued outbound parts, the body byte buffer and write-side wakers.
    pub(super) send: SendState,
    /// Set whenever this stream has something for the connection to act on. Examples:
    /// staged parts, a close/reset request, consumed bytes, a dropped transport.
    /// NOTE(review): presumably cleared by the connection once it services the stream.
    pub(super) needs_servicing: AtomicBool,
}
impl StreamState {
    /// Locks the stream lifecycle state machine.
    ///
    /// # Panics
    ///
    /// If the lifecycle mutex is poisoned.
    pub(super) fn lifecycle_lock(&self) -> MutexGuard<'_, StreamLifecycle> {
        let locked = self.lifecycle.lock();
        locked.expect("lifecycle mutex poisoned")
    }

    /// Feeds `event` into the lifecycle state machine while holding its lock.
    ///
    /// # Errors
    ///
    /// Whatever `StreamLifecycle::on_event` reports for a transition it rejects.
    pub(super) fn apply_event(
        &self,
        event: StreamEvent,
    ) -> Result<(), super::stream_state::StreamProtocolError> {
        self.lifecycle_lock().on_event(event)
    }

    /// Appends `parts` to the send queue in order, then flags the stream for servicing.
    pub(super) fn stage(&self, parts: impl IntoIterator<Item = OutboundPart>) {
        {
            let mut pending = self.send.queue.lock().expect("send queue mutex poisoned");
            pending.extend(parts);
        }
        self.needs_servicing.store(true, Ordering::Release);
    }

    /// Queues a graceful `Close` for the send half.
    ///
    /// Does nothing at all once the send half is closed. If the queue already ends in a
    /// terminal part, no `Close` is added, but the stream is still flagged for servicing.
    pub(super) fn request_close(&self) {
        let send_closed = self.lifecycle_lock().send_closed();
        if send_closed {
            return;
        }
        {
            let mut pending = self.send.queue.lock().expect("send queue mutex poisoned");
            let already_terminated = pending.back().is_some_and(OutboundPart::is_terminal);
            if !already_terminated {
                pending.push_back(OutboundPart::Close);
            }
        }
        self.needs_servicing.store(true, Ordering::Release);
    }

    /// Discards everything queued and replaces it with a single `Reset(code)`, then flags
    /// the stream for servicing.
    ///
    /// If the queue already ends in a `Reset`, that earlier reset (and its code) is kept
    /// and nothing changes.
    pub(super) fn request_reset(&self, code: H2ErrorCode) {
        {
            let mut pending = self.send.queue.lock().expect("send queue mutex poisoned");
            let reset_already_queued = matches!(pending.back(), Some(OutboundPart::Reset(_)));
            if reset_already_queued {
                return;
            }
            pending.clear();
            pending.push_back(OutboundPart::Reset(code));
        }
        self.needs_servicing.store(true, Ordering::Release);
    }
}
/// Receive-side portion of [`StreamState`].
#[derive(Debug, Default)]
pub(super) struct RecvState {
    /// Inbound body bytes not yet handed out by `H2Transport::poll_read`.
    pub(super) buf: Mutex<Buffer>,
    /// Waker registered by `H2Transport::poll_read`.
    pub(super) waker: AtomicWaker,
    /// Set by the first `poll_read` call, which also flags the stream for servicing.
    pub(super) is_reading: AtomicBool,
    /// Running total of bytes handed out by `poll_read`.
    pub(super) bytes_consumed: AtomicU64,
    /// NOTE(review): not touched in this file; presumably trailers received from the peer.
    pub(super) trailers: Mutex<Option<Headers>>,
    /// NOTE(review): not touched in this file; presumably the received response header block.
    pub(super) response_headers: Mutex<Option<FieldSection<'static>>>,
    /// NOTE(review): not touched in this file; presumably set once the first response
    /// headers have arrived.
    pub(super) first_response_headers_seen: AtomicBool,
    /// NOTE(review): not touched in this file; presumably woken when response headers arrive.
    pub(super) response_headers_waker: AtomicWaker,
}
/// Send-side portion of [`StreamState`].
#[derive(Debug, Default)]
pub(super) struct SendState {
    /// Outbound parts (headers, body, trailers, close/reset) waiting for the connection.
    pub(super) queue: Mutex<VecDeque<OutboundPart>>,
    /// Bytes written through `H2Transport::poll_write`, capped at the connection's
    /// `response_buffer_max_len`.
    pub(super) outbound: Mutex<Buffer>,
    /// Registered by `poll_write` when `outbound` is full. Also woken by `poll_close` and
    /// when the transport is dropped.
    pub(super) outbound_write_waker: AtomicWaker,
    /// Read when the transport is dropped. If set and the send half is still open, the drop
    /// requests a graceful close instead of a reset.
    /// NOTE(review): set outside this file.
    pub(super) submit_resolved: AtomicBool,
    /// NOTE(review): not touched in this file; presumably the final outcome of sending.
    pub(super) completion_result: Mutex<Option<io::Result<()>>>,
    /// NOTE(review): not touched in this file; presumably woken when `completion_result` is set.
    pub(super) completion_waker: AtomicWaker,
    /// Set when the transport is dropped on a stream whose lifecycle is already closed.
    pub(super) transport_dropped: AtomicBool,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpContext;

    /// Builds a transport for stream 1 on a fresh connection.
    ///
    /// The connection's outbound buffer is capped at `cap` bytes and the stream lifecycle is
    /// preset to `lifecycle`. The shared state is returned too, so tests can inspect it.
    fn transport_with(cap: usize, lifecycle: StreamLifecycle) -> (H2Transport, Arc<StreamState>) {
        let mut context = HttpContext::new();
        context.config.response_buffer_max_len = cap;
        let connection = H2Connection::new(Arc::new(context));
        let state = Arc::new(StreamState::default());
        *state.lifecycle_lock() = lifecycle;
        let transport = H2Transport::new(connection, 1, Arc::clone(&state));
        (transport, state)
    }

    #[test]
    fn request_close_enqueues_single_close() {
        let state = StreamState::default();
        *state.lifecycle_lock() = StreamLifecycle::Open;
        state.request_close();
        state.request_close();
        let queue = state.send.queue.lock().unwrap();
        assert_eq!(queue.len(), 1, "second request_close is a no-op");
        assert!(matches!(queue.front(), Some(OutboundPart::Close)));
    }

    #[test]
    fn request_close_noop_when_send_closed() {
        let state = StreamState::default();
        *state.lifecycle_lock() = StreamLifecycle::HalfClosedLocal;
        state.request_close();
        assert!(
            state.send.queue.lock().unwrap().is_empty(),
            "no terminator queued once the send half is closed"
        );
    }

    #[test]
    fn request_close_after_reset_is_noop() {
        let state = StreamState::default();
        *state.lifecycle_lock() = StreamLifecycle::Open;
        state.request_reset(H2ErrorCode::Cancel);
        state.request_close();
        let queue = state.send.queue.lock().unwrap();
        assert_eq!(queue.len(), 1, "Close is not queued behind a Reset");
        assert!(matches!(
            queue.front(),
            Some(OutboundPart::Reset(H2ErrorCode::Cancel))
        ));
    }

    #[test]
    fn request_reset_clears_queue_and_is_first_wins() {
        let state = StreamState::default();
        *state.lifecycle_lock() = StreamLifecycle::Open;
        state.stage([OutboundPart::Body(Body::default()), OutboundPart::Close]);
        state.request_reset(H2ErrorCode::Cancel);
        state.request_reset(H2ErrorCode::InternalError);
        let queue = state.send.queue.lock().unwrap();
        assert_eq!(queue.len(), 1, "queue cleared, single reset");
        assert!(
            matches!(
                queue.front(),
                Some(OutboundPart::Reset(H2ErrorCode::Cancel))
            ),
            "first reset code wins",
        );
    }

    #[test]
    fn poll_write_caps_at_response_buffer_max_len() {
        use futures_lite::AsyncWrite;
        use std::task::{Context, Poll, Waker};
        let mut context = HttpContext::new();
        context.config.response_buffer_max_len = 16;
        let connection = H2Connection::new(Arc::new(context));
        let state = Arc::new(StreamState::default());
        *state.lifecycle_lock() = StreamLifecycle::Open;
        let mut transport = H2Transport::new(connection, 1, state);
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);
        let buf = [0u8; 32];
        match Pin::new(&mut transport).poll_write(&mut cx, &buf) {
            Poll::Ready(Ok(n)) => assert_eq!(n, 16, "should accept exactly cap bytes"),
            other => panic!("expected Ready(Ok(16)), got {other:?}"),
        }
    }

    #[test]
    fn poll_write_pending_once_buffer_full() {
        use futures_lite::io::AsyncWrite;
        use std::task::{Context, Poll, Waker};
        let (mut transport, _state) = transport_with(16, StreamLifecycle::Open);
        let mut cx = Context::from_waker(Waker::noop());
        let buf = [0u8; 16];
        match Pin::new(&mut transport).poll_write(&mut cx, &buf) {
            Poll::Ready(Ok(n)) => assert_eq!(n, 16, "first write fills the buffer"),
            other => panic!("expected Ready(Ok(16)), got {other:?}"),
        }
        assert!(
            Pin::new(&mut transport).poll_write(&mut cx, &buf).is_pending(),
            "a full outbound buffer must apply backpressure"
        );
    }

    #[test]
    fn poll_write_after_send_closed_is_broken_pipe() {
        use futures_lite::io::AsyncWrite;
        use std::task::{Context, Poll, Waker};
        let (mut transport, _state) = transport_with(16, StreamLifecycle::HalfClosedLocal);
        let mut cx = Context::from_waker(Waker::noop());
        match Pin::new(&mut transport).poll_write(&mut cx, b"abc") {
            Poll::Ready(Err(err)) => assert_eq!(err.kind(), io::ErrorKind::BrokenPipe),
            other => panic!("expected Ready(Err(BrokenPipe)), got {other:?}"),
        }
    }

    #[test]
    fn poll_read_returns_buffered_bytes_and_counts_them() {
        use futures_lite::io::AsyncRead;
        use std::task::{Context, Poll, Waker};
        let (mut transport, state) = transport_with(16, StreamLifecycle::Open);
        state.recv.buf.lock().unwrap().extend_from_slice(b"hello");
        let mut cx = Context::from_waker(Waker::noop());
        let mut out = [0u8; 3];
        match Pin::new(&mut transport).poll_read(&mut cx, &mut out) {
            Poll::Ready(Ok(n)) => assert_eq!(n, 3, "read is limited by the output slice"),
            other => panic!("expected Ready(Ok(3)), got {other:?}"),
        }
        assert_eq!(&out, b"hel");
        assert_eq!(state.recv.bytes_consumed.load(Ordering::Acquire), 3);
    }
}