use super::H2Connection;
#[cfg(feature = "unstable")]
use crate::h2::transport::H2Transport;
use crate::{Body, Headers, h2::transport::StreamState, headers::hpack::PseudoHeaders};
use std::{
future::Future,
io,
pin::Pin,
sync::{Arc, atomic::Ordering},
task::{Context, Poll},
};
/// Future returned by `submit_send`, `submit_upgrade` and `open_stream`; it
/// resolves to the send outcome that the connection driver records in the
/// stream's shared state.
///
/// When there is no stream state to wait on (`stream` is `None`), polling
/// resolves immediately to `io::ErrorKind::NotConnected`.
#[must_use = "futures do nothing unless awaited"]
#[derive(Debug)]
pub struct SubmitSend {
/// Identifier of the stream being sent on. Within this file it is only read
/// for log messages while polling.
pub(super) stream_id: u32,
/// Shared state of the stream whose send completion this future awaits.
/// `None` when the stream was not present in the connection's stream table
/// at submission time.
pub(super) stream: Option<Arc<StreamState>>,
}
impl Future for SubmitSend {
type Output = io::Result<()>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Some(state) = &self.stream else {
log::debug!("h2 stream {}: submit_send on closed stream", self.stream_id);
return Poll::Ready(Err(io::ErrorKind::NotConnected.into()));
};
let stream_id = self.stream_id;
let try_take = || -> Option<io::Result<()>> {
state.send.completed.load(Ordering::Acquire).then(|| {
state
.send
.completion_result
.lock()
.expect("completion_result mutex poisoned")
.take()
.unwrap_or_else(|| {
log::error!(
"h2 stream {stream_id}: completed without a completion_result — \
driver should write the result before flipping completed"
);
Ok(())
})
})
};
if let Some(result) = try_take() {
return Poll::Ready(result);
}
state.send.completion_waker.register(cx.waker());
if let Some(result) = try_take() {
return Poll::Ready(result);
}
Poll::Pending
}
}
impl H2Connection {
pub(crate) fn submit_send(
&self,
stream_id: u32,
pseudos: PseudoHeaders<'static>,
headers: Headers,
body: Option<Body>,
) -> SubmitSend {
let stream = self.streams_lock().get(&stream_id).cloned();
if let Some(state) = &stream {
*state
.send
.submission
.lock()
.expect("send submission mutex poisoned") =
Some(crate::h2::transport::Submission {
pseudos,
headers,
body,
is_upgrade: false,
});
state.needs_servicing.store(true, Ordering::Release);
self.outbound_waker.wake();
}
SubmitSend { stream_id, stream }
}
pub(crate) fn submit_upgrade(
&self,
stream_id: u32,
pseudos: PseudoHeaders<'static>,
headers: Headers,
) -> SubmitSend {
let stream = self.streams_lock().get(&stream_id).cloned();
if let Some(state) = &stream {
let reader = crate::h2::transport::H2OutboundReader::new(state.clone(), stream_id);
let body = Body::new_streaming(reader, None);
*state
.send
.submission
.lock()
.expect("send submission mutex poisoned") =
Some(crate::h2::transport::Submission {
pseudos,
headers,
body: Some(body),
is_upgrade: true,
});
log::trace!("h2 stream {stream_id}: submit_upgrade — submission staged");
state.needs_servicing.store(true, Ordering::Release);
self.outbound_waker.wake();
}
SubmitSend { stream_id, stream }
}
#[cfg(feature = "unstable")]
pub fn open_stream(
self: &Arc<Self>,
pseudos: PseudoHeaders<'static>,
headers: Headers,
body: Option<Body>,
) -> Option<(u32, SubmitSend, H2Transport)> {
self.open_stream_inner(pseudos, headers, body, false)
.map(|(id, state, transport)| {
(
id,
SubmitSend {
stream_id: id,
stream: Some(state),
},
transport,
)
})
}
#[cfg(feature = "unstable")]
pub fn open_connect_stream(
self: &Arc<Self>,
pseudos: PseudoHeaders<'static>,
headers: Headers,
) -> Option<(u32, H2Transport)> {
let (id, _state, transport) = self.open_stream_inner(pseudos, headers, None, true)?;
Some((id, transport))
}
#[cfg(feature = "unstable")]
fn open_stream_inner(
self: &Arc<Self>,
pseudos: PseudoHeaders<'static>,
headers: Headers,
body: Option<Body>,
is_upgrade: bool,
) -> Option<(u32, Arc<StreamState>, H2Transport)> {
if !self.swansong.state().is_running() {
return None;
}
let stream_id = self
.next_client_stream_id
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| {
(n < (1u32 << 31)).then_some(n + 2)
})
.ok()?;
let state = Arc::new(StreamState::default());
let body = if is_upgrade {
let reader = crate::h2::transport::H2OutboundReader::new(state.clone(), stream_id);
Some(Body::new_streaming(reader, None))
} else {
body
};
*state
.send
.submission
.lock()
.expect("send submission mutex poisoned") = Some(crate::h2::transport::Submission {
pseudos,
headers,
body,
is_upgrade,
});
state.needs_servicing.store(true, Ordering::Release);
self.streams_lock().insert(stream_id, state.clone());
log::trace!("h2 client: open_stream allocated stream {stream_id} (upgrade={is_upgrade})");
self.outbound_waker.wake();
let transport = H2Transport::new(Arc::clone(self), stream_id, state.clone());
Some((stream_id, state, transport))
}
}