use super::Conn;
use crate::h3::H3ClientState;
use futures_lite::AsyncWriteExt;
use std::{
borrow::Cow,
io::{self, ErrorKind},
};
use trillium_http::{
BufWriter, Error, KnownHeaderName, Method, ProtocolSession, ReceivedBodyState, Result, Version,
h3::{Frame, FrameStream, H3Connection, H3Error},
headers::qpack::{FieldSection, PseudoHeaders},
};
/// Converts an [`H3Error`] into an [`io::Error`].
///
/// - [`H3Error::Io`] is unwrapped and returned as-is.
/// - [`H3Error::Protocol`] becomes an [`ErrorKind::InvalidData`] error whose message is the
///   protocol error code rendered with `to_string()`.
/// - Every other variant is wrapped with [`io::Error::other`].
fn h3_to_io(e: H3Error) -> io::Error {
    match e {
        H3Error::Io(inner) => inner,
        H3Error::Protocol(error_code) => {
            let message = error_code.to_string();
            io::Error::new(ErrorKind::InvalidData, message)
        }
        unexpected => io::Error::other(unexpected),
    }
}
impl Conn {
pub(super) async fn try_exec_h3(&mut self) -> Result<bool> {
let Some(h3) = self.h3_client_state.clone() else {
return Ok(false);
};
let origin = self.url.origin();
let (host, port) = if self.http_version == Version::Http3 {
let host = self
.url
.host_str()
.ok_or(Error::UnexpectedUriFormat)?
.to_string();
let port = self
.url
.port_or_known_default()
.ok_or(Error::UnexpectedUriFormat)?;
(host, port)
} else if let Some(entry) = h3.alt_svc.get(&origin)
&& entry.is_usable()
{
(entry.host.clone(), entry.port)
} else {
return Ok(false);
};
let entry = match h3
.get_or_create_quic_conn(&origin, &host, port, &self.config, &self.context)
.await
{
Ok(entry) => entry,
Err(e) => {
log::debug!("H3 connect to {host}:{port} failed: {e}, falling back to H1");
h3.mark_broken(&origin);
return Ok(false);
}
};
if self.protocol.is_some() {
let Some(settings) = entry.h3.peer_settings_ready().await else {
return Err(Error::Closed);
};
if !settings.enable_connect_protocol() {
return Err(Error::ExtendedConnectUnsupported);
}
#[cfg(feature = "webtransport")]
if self.protocol.as_deref() == Some("webtransport") {
if !settings.enable_webtransport() || !settings.h3_datagram() {
return Err(Error::ExtendedConnectUnsupported);
}
let _ = entry.dispatcher.get_or_init(
trillium_server_common::h3::web_transport::WebTransportDispatcher::new,
);
}
}
let (stream_id, transport) = match entry.quic_conn.open_bidi().await {
Ok(t) => t,
Err(e) => {
log::debug!("H3 open_bidi failed: {e}, falling back to H1");
h3.mark_broken(&origin);
return Ok(false);
}
};
self.transport = Some(transport);
self.http_version = Version::Http3;
self.finalize_headers_h3()?;
self.protocol_session = ProtocolSession::Http3 {
connection: entry.h3.clone(),
stream_id,
};
#[cfg(feature = "webtransport")]
if self.protocol.as_deref() == Some("webtransport") {
self.wt_pool_entry = Some(entry.clone());
}
self.send_h3_request().await?;
self.recv_h3_response_headers().await?;
self.update_alt_svc_from_response(&h3);
Ok(true)
}
async fn send_h3_request(&mut self) -> Result<()> {
let Some((h3, stream_id)) = self.protocol_session.as_h3_borrowed() else {
return Err(Error::Closed);
};
let mut pseudo_headers = PseudoHeaders::default()
.with_method(self.method)
.with_authority(
self.authority
.as_deref()
.ok_or(Error::UnexpectedUriFormat)?,
);
if self.method != Method::Connect {
pseudo_headers
.set_path(Some(
self.path.as_deref().ok_or(Error::UnexpectedUriFormat)?,
))
.set_scheme(Some(
self.scheme.as_deref().ok_or(Error::UnexpectedUriFormat)?,
));
}
if let Some(protocol) = &self.protocol {
pseudo_headers.set_protocol(Some(protocol.as_ref()));
if self.method == Method::Connect {
pseudo_headers
.set_path(Some(
self.path.as_deref().ok_or(Error::UnexpectedUriFormat)?,
))
.set_scheme(Some(
self.scheme.as_deref().ok_or(Error::UnexpectedUriFormat)?,
));
}
}
let transport = self.transport.as_mut().ok_or(Error::Closed)?;
let max_buf = self.context.config().response_buffer_max_len();
let mut bufwriter = BufWriter::new_with_buffer(
Vec::with_capacity(self.context.config().response_buffer_len()),
transport,
max_buf,
);
let initial_cap = self.context.config().request_buffer_initial_len();
let max_peer_field_section_size = None;
let field_section = FieldSection::new(pseudo_headers, &self.request_headers);
log::trace!("sending headers:\n{field_section}");
encode_field_section_h3(
h3,
&field_section,
max_peer_field_section_size,
initial_cap,
bufwriter.buffer_mut(),
stream_id,
)?;
let copy_loops_per_yield = self.context.config().copy_loops_per_yield();
if let Some(body) = self.request_body.take() {
let mut body = body.into_h3();
bufwriter.copy_from(&mut body, copy_loops_per_yield).await?;
self.request_trailers = body.trailers();
if let Some(trailers) = &self.request_trailers {
let field_section = FieldSection::new(PseudoHeaders::default(), trailers);
log::trace!("sending trailers:\n{field_section}");
encode_field_section_h3(
h3,
&field_section,
max_peer_field_section_size,
initial_cap,
bufwriter.buffer_mut(),
stream_id,
)?;
}
}
bufwriter.flush().await?;
bufwriter.close().await?;
Ok(())
}
pub(crate) fn finalize_headers_h3(&mut self) -> Result<()> {
if self.headers_finalized {
return Ok(());
}
let authority = self
.request_headers
.remove(KnownHeaderName::Host)
.and_then(|h| h.first().map(|v| Cow::Owned(v.to_string())))
.or_else(|| {
let host = self.url.host_str()?;
Some(Cow::Owned(self.url.port().map_or_else(
|| host.to_string(),
|port| format!("{host}:{port}"),
)))
})
.ok_or(Error::UnexpectedUriFormat)?;
self.authority = Some(authority);
if let Some(target) = &self.request_target
&& self.method == Method::Options
{
self.scheme = Some(Cow::Owned(self.url.scheme().to_string()));
self.path = Some(target.clone());
} else if self.method == Method::Connect && self.protocol.is_none() {
self.scheme = None;
self.path = None;
} else {
self.scheme = Some(Cow::Owned(self.url.scheme().to_string()));
self.path = Some(Cow::Owned({
let mut path = self.url.path().to_string();
if let Some(query) = self.url.query() {
path.push('?');
path.push_str(query);
}
path
}));
}
if let Some(len) = self.body_len()
&& len > 0
{
self.request_headers
.insert(KnownHeaderName::ContentLength, len);
}
self.request_headers.remove_all([
KnownHeaderName::Connection,
KnownHeaderName::TransferEncoding,
KnownHeaderName::KeepAlive,
KnownHeaderName::ProxyConnection,
KnownHeaderName::Upgrade,
KnownHeaderName::Expect,
]);
self.headers_finalized = true;
Ok(())
}
async fn recv_h3_response_headers(&mut self) -> Result<()> {
let Some((h3, stream_id)) = self.protocol_session.as_h3_borrowed() else {
return Err(Error::Closed);
};
let transport = self.transport.as_mut().ok_or(Error::Closed)?;
let mut frame_stream = FrameStream::new(transport, &mut self.buffer);
let field_section = loop {
let Some(mut frame) = frame_stream
.next()
.await
.map_err(|e| Error::Io(h3_to_io(e)))?
else {
return Err(Error::Closed);
};
if matches!(frame.frame(), Frame::Headers(_)) {
let encoded = frame.buffer_payload().await?;
break h3
.decode_field_section(encoded, stream_id)
.await
.map_err(|_| Error::InvalidHead)?;
}
};
log::trace!("received:\n{field_section}");
self.status = field_section.pseudo_headers().status();
self.response_headers = field_section.into_headers().into_owned();
self.response_body_state = ReceivedBodyState::new_h3();
Ok(())
}
pub(super) fn update_alt_svc_from_response(&self, h3: &H3ClientState) {
if let Some(alt_svc) = self.response_headers.get_str(KnownHeaderName::AltSvc) {
h3.update_alt_svc(alt_svc, &self.url);
}
}
}
/// QPACK-encodes `field_section` and appends it to `buffer` as a complete HTTP/3 HEADERS
/// frame: the frame header followed by the encoded field block.
///
/// Existing contents of `buffer` are preserved, because the frame is written after them. This
/// matters for trailers: `buffer` may still hold queued, not-yet-flushed request body bytes
/// when they are encoded.
///
/// # Errors
///
/// - If [`H3Connection::encode_field_section`] fails, the error is logged and returned
///   wrapped with [`io::Error::other`].
/// - If `max_peer_field_section_size` is `Some(max)` and the encoded field block is longer
///   than `max` bytes, returns an [`ErrorKind::InvalidData`] error and leaves `buffer`
///   unchanged.
fn encode_field_section_h3(
    h3: &H3Connection,
    field_section: &FieldSection<'_>,
    max_peer_field_section_size: Option<u64>,
    initial_cap: usize,
    buffer: &mut Vec<u8>,
    stream_id: u64,
) -> io::Result<()> {
    let mut field_section_buf = Vec::with_capacity(initial_cap);
    h3.encode_field_section(field_section, &mut field_section_buf, stream_id)
        .map_err(|error| {
            log::error!("encode error: {error:?}");
            io::Error::other(error)
        })?;

    let size = field_section_buf.len() as u64;
    if let Some(max_size) = max_peer_field_section_size
        && size > max_size
    {
        return Err(io::Error::new(
            ErrorKind::InvalidData,
            format!("field section would be longer than peer allows ({size} > {max_size})"),
        ));
    }

    let frame = Frame::Headers(size);
    let frame_header_len = frame.encoded_len();
    // Reserve zeroed space for the frame header at the *end* of the buffer and encode into
    // that slice only. Resizing the whole buffer to `frame_header_len` would truncate or
    // overwrite bytes already queued in it.
    let start = buffer.len();
    buffer.reserve(frame_header_len + field_section_buf.len());
    buffer.resize(start + frame_header_len, 0);
    frame.encode(&mut buffer[start..]);
    buffer.extend_from_slice(&field_section_buf);
    Ok(())
}