use super::H2Connection;
#[cfg(feature = "unstable")]
use crate::h2::transport::H2Transport;
use crate::{
Body, Headers,
h2::transport::{OutboundPart, StreamState},
headers::hpack::PseudoHeaders,
};
use std::{
future::Future,
io,
pin::Pin,
sync::{Arc, atomic::Ordering},
task::{Context, Poll},
};
/// Future returned by `H2Connection::submit_send`, `submit_upgrade`, `open_stream` and
/// `open_connect_stream`.
///
/// The outbound parts are staged on the stream eagerly, when this value is created; awaiting
/// it only observes the outcome that the connection driver publishes for the submission (see
/// its `Future` implementation for the resolution rules).
#[must_use = "futures do nothing unless awaited"]
#[derive(Debug)]
pub struct SubmitSend {
    /// Identifier of the stream this submission targets; also used in log messages.
    pub(super) stream_id: u32,
    /// Shared per-stream state, or `None` when no stream with `stream_id` was found at
    /// submission time (polling then yields `NotConnected`).
    pub(super) stream: Option<Arc<StreamState>>,
}
impl Future for SubmitSend {
    type Output = io::Result<()>;

    /// Waits for the connection driver to publish the outcome of this submission.
    ///
    /// Each poll resolves, in priority order, to:
    /// 1. `NotConnected` if the stream did not exist when the submission was made;
    /// 2. the stored completion result once `submit_resolved` has been flipped;
    /// 3. `ConnectionAborted` if the stream's lifecycle reports a reset.
    ///
    /// Otherwise the completion waker is registered and checks 2–3 are repeated before
    /// returning `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let stream_id = self.stream_id;
        let Some(state) = self.stream.as_ref() else {
            log::debug!("h2 stream {}: submit_send on closed stream", stream_id);
            return Poll::Ready(Err(io::ErrorKind::NotConnected.into()));
        };

        // Takes the driver's published result once `submit_resolved` is observed. A missing
        // result at that point is a driver ordering bug: it is logged and reported as success.
        let resolved = || -> Option<io::Result<()>> {
            if !state.send.submit_resolved.load(Ordering::Acquire) {
                return None;
            }
            let taken = state
                .send
                .completion_result
                .lock()
                .expect("completion_result mutex poisoned")
                .take();
            Some(taken.unwrap_or_else(|| {
                log::error!(
                    "h2 stream {stream_id}: submit_resolved without a completion_result — \
                     driver should write the result before flipping submit_resolved"
                );
                Ok(())
            }))
        };

        // A published result wins over a reset; a reset with no result aborts the submission.
        let outcome = || {
            resolved().or_else(|| {
                state
                    .lifecycle_lock()
                    .is_reset()
                    .then(|| Err(io::ErrorKind::ConnectionAborted.into()))
            })
        };

        if let Some(result) = outcome() {
            return Poll::Ready(result);
        }

        // Re-check after registering: a resolution or reset that happened before the waker
        // was registered may not have woken this task, so inspect the state once more.
        state.send.completion_waker.register(cx.waker());
        outcome().map_or(Poll::Pending, Poll::Ready)
    }
}
impl H2Connection {
    /// Stages a complete message on an existing stream: HEADERS, the optional body, and a
    /// trailing `OutboundPart::Close`, then wakes the outbound driver.
    ///
    /// If `stream_id` is not in the stream table nothing is staged, and the returned
    /// [`SubmitSend`] resolves to `NotConnected`.
    pub(crate) fn submit_send(
        &self,
        stream_id: u32,
        pseudos: PseudoHeaders<'static>,
        headers: Headers,
        body: Option<Body>,
    ) -> SubmitSend {
        let stream = self.streams_lock().get(&stream_id).cloned();
        if let Some(state) = &stream {
            state.stage(submission_parts(pseudos, headers, body, true));
            self.outbound_waker.wake();
        }
        SubmitSend { stream_id, stream }
    }

    /// Like [`Self::submit_send`], but does not stage the trailing `OutboundPart::Close`, so
    /// the submission itself does not close the stream's send side.
    ///
    /// If `stream_id` is not in the stream table nothing is staged, and the returned
    /// [`SubmitSend`] resolves to `NotConnected`.
    pub(crate) fn submit_upgrade(
        &self,
        stream_id: u32,
        pseudos: PseudoHeaders<'static>,
        headers: Headers,
        body: Option<Body>,
    ) -> SubmitSend {
        let stream = self.streams_lock().get(&stream_id).cloned();
        if let Some(state) = &stream {
            state.stage(submission_parts(pseudos, headers, body, false));
            log::trace!("h2 stream {stream_id}: submit_upgrade — parts staged");
            self.outbound_waker.wake();
        }
        SubmitSend { stream_id, stream }
    }

    /// Stages `trailers` as the stream's terminating HEADERS part.
    ///
    /// Wakes both the stream's outbound write waker and the connection's outbound driver.
    ///
    /// # Errors
    ///
    /// Returns `NotConnected` if `stream_id` is not present in the stream table.
    pub(crate) fn submit_trailers(&self, stream_id: u32, trailers: Headers) -> io::Result<()> {
        let stream = self
            .streams_lock()
            .get(&stream_id)
            .cloned()
            .ok_or(io::ErrorKind::NotConnected)?;
        stream.stage([OutboundPart::Trailers(trailers)]);
        stream.send.outbound_write_waker.wake();
        self.outbound_waker.wake();
        log::trace!("h2 stream {stream_id}: submit_trailers staged trailing HEADERS terminator");
        Ok(())
    }

    /// Opens a new client-initiated stream and stages a complete message on it (HEADERS,
    /// optional body, and a trailing `OutboundPart::Close`).
    ///
    /// Returns the allocated stream id, a [`SubmitSend`] for the submission's outcome, and an
    /// `H2Transport` bound to the stream. Returns `None` when the connection is no longer
    /// running or the stream-identifier space is exhausted.
    #[cfg(feature = "unstable")]
    pub fn open_stream(
        self: &Arc<Self>,
        pseudos: PseudoHeaders<'static>,
        headers: Headers,
        body: Option<Body>,
    ) -> Option<(u32, SubmitSend, H2Transport)> {
        self.open_stream_inner(pseudos, headers, body, false)
    }

    /// Like [`Self::open_stream`], but does not stage the trailing `OutboundPart::Close`.
    /// The stream is presumably driven further through the returned `H2Transport`
    /// (e.g. CONNECT tunnels).
    #[cfg(feature = "unstable")]
    pub fn open_connect_stream(
        self: &Arc<Self>,
        pseudos: PseudoHeaders<'static>,
        headers: Headers,
        body: Option<Body>,
    ) -> Option<(u32, SubmitSend, H2Transport)> {
        self.open_stream_inner(pseudos, headers, body, true)
    }

    /// Shared implementation of [`Self::open_stream`] and [`Self::open_connect_stream`].
    ///
    /// `is_upgrade` suppresses the trailing `OutboundPart::Close`.
    #[cfg(feature = "unstable")]
    fn open_stream_inner(
        self: &Arc<Self>,
        pseudos: PseudoHeaders<'static>,
        headers: Headers,
        body: Option<Body>,
        is_upgrade: bool,
    ) -> Option<(u32, SubmitSend, H2Transport)> {
        // Largest valid stream identifier: identifiers are 31 bits wide.
        const MAX_STREAM_ID: u32 = (1 << 31) - 1;

        if !self.swansong.state().is_running() {
            return None;
        }

        // Parts are staged before the stream is inserted into the table, so anything that
        // finds the stream there already sees its staged HEADERS.
        let state = Arc::new(StreamState::default());
        state.stage(submission_parts(pseudos, headers, body, !is_upgrade));

        let stream_id = {
            // Allocate while holding the stream-table lock so that, for streams opened here,
            // insertion order follows identifier order.
            let mut streams = self.streams_lock();
            let stream_id = self
                .next_client_stream_id
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
                    (n <= MAX_STREAM_ID).then_some(n + 2)
                })
                .ok()?;
            streams.insert(stream_id, Arc::clone(&state));
            stream_id
        };

        log::trace!("h2 client: open_stream allocated stream {stream_id} (upgrade={is_upgrade})");
        self.outbound_waker.wake();

        let transport = H2Transport::new(Arc::clone(self), stream_id, Arc::clone(&state));
        let submit = SubmitSend {
            stream_id,
            stream: Some(state),
        };
        Some((stream_id, submit, transport))
    }
}
/// Builds the ordered outbound parts for a submission.
///
/// The result always starts with an `OutboundPart::Headers` part. An `OutboundPart::Body`
/// follows when `body` is `Some`, and the list ends with `OutboundPart::Close` when `close`
/// is `true`.
fn submission_parts(
    pseudos: PseudoHeaders<'static>,
    headers: Headers,
    body: Option<Body>,
    close: bool,
) -> Vec<OutboundPart> {
    let head = OutboundPart::Headers { pseudos, headers };
    let payload = body.map(OutboundPart::Body);
    let terminator = close.then_some(OutboundPart::Close);
    std::iter::once(head)
        .chain(payload)
        .chain(terminator)
        .collect()
}