use crate::{ArcHandler, Runtime};
use futures_lite::io::{AsyncRead, AsyncWrite};
use std::{
io,
net::IpAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use trillium::{Handler, Transport, Upgrade};
use trillium_http::{
HttpContext,
h2::{H2Connection, H2Transport},
};
/// The HTTP/2 client connection preface: the fixed 24-octet sequence an HTTP/2
/// client sends before anything else on a connection (RFC 9113 §3.4).
///
/// NOTE(review): presumably compared against the first bytes read from an
/// accepted connection to decide whether to serve it as HTTP/2 — confirm at the
/// call site.
pub(crate) const CLIENT_PREFACE: &[u8; 24] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Serve one HTTP/2 connection over `transport` until the connection ends.
///
/// An [`H2Connection`] is built from `context` and driven over `transport`.
/// Every inbound stream the driver yields is handed to a newly spawned task on
/// `runtime`; the loop does not wait for those tasks. Each task:
///
/// 1. stamps the stream's conn with `peer_ip` and `is_secure`,
/// 2. runs `handler` (`run`, then `before_send`),
/// 3. if the finished conn asks for an upgrade, passes it to
///    [`Handler::upgrade`] when `handler.has_upgrade` accepts it, and logs an
///    error otherwise.
///
/// The function returns once the driver is exhausted or reports a
/// connection-level error. Connection and stream errors are logged at debug
/// level rather than returned.
pub(crate) async fn run_h2<T>(
    transport: T,
    context: Arc<HttpContext>,
    handler: ArcHandler<impl Handler>,
    runtime: Runtime,
    peer_ip: Option<IpAddr>,
    is_secure: bool,
) where
    T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
{
    let h2 = H2Connection::new(context);
    let mut driver = h2.clone().run(transport);

    while let Some(next) = driver.next().await {
        // A connection-level error ends the whole connection, not just a stream.
        let conn = match next {
            Ok(conn) => conn,
            Err(e) => {
                log::debug!("h2 connection error: {e}");
                break;
            }
        };

        let stream_id = conn.h2_stream_id();
        log::trace!("run_h2: spawning handler task for stream {stream_id:?}");

        let task_handler = handler.clone();
        runtime.spawn(async move {
            // The closure below takes ownership of its own handler clone, so
            // `task_handler` remains available for the upgrade step afterwards.
            let per_stream = task_handler.clone();
            let outcome = H2Connection::process_inbound(conn, |mut stream_conn| async move {
                stream_conn.set_peer_ip(peer_ip);
                stream_conn.set_secure(is_secure);
                let conn = per_stream.run(stream_conn.into()).await;
                let conn = per_stream.before_send(conn).await;
                conn.into_inner::<H2Transport>()
            })
            .await;

            let conn = match outcome {
                Ok(conn) => conn,
                Err(e) => {
                    log::debug!("h2 stream error: {e}");
                    return;
                }
            };

            if !conn.should_upgrade() {
                return;
            }

            let upgrade = Upgrade::from(conn);
            if task_handler.has_upgrade(&upgrade) {
                log::debug!("upgrading h2 stream");
                task_handler.upgrade(upgrade).await;
            } else {
                log::error!("h2 upgrade specified but no upgrade handler provided");
            }
        });
    }

    log::trace!("run_h2: driver exhausted, connection done");
}
/// An I/O adapter that hands a buffer of previously captured bytes (`prefix`)
/// back to readers before any read reaches the wrapped `inner` value.
#[derive(Debug)]
pub(crate) struct Prefixed<T> {
    /// Bytes to yield before reads are passed through to `inner`. The buffer is
    /// replaced with an empty `Vec` once it has been fully delivered.
    prefix: Vec<u8>,
    /// Number of bytes of `prefix` already handed out to readers.
    offset: usize,
    /// The wrapped reader/writer.
    inner: T,
}

impl<T> Prefixed<T> {
    /// Wrap `inner` so that reads first return every byte of `prefix`, in
    /// order, starting from the beginning.
    pub(crate) fn new(prefix: Vec<u8>, inner: T) -> Self {
        Prefixed {
            inner,
            prefix,
            offset: 0,
        }
    }
}
/// Reads replay the stored prefix first, then pass through to `inner`.
impl<T: AsyncRead + Unpin> AsyncRead for Prefixed<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let me = self.get_mut();

        // Nothing left to replay (or the prefix was already released, leaving
        // `offset` past the now-empty buffer): plain pass-through read.
        let unread = me.prefix.len().saturating_sub(me.offset);
        if unread == 0 {
            return Pin::new(&mut me.inner).poll_read(cx, buf);
        }

        // Serve as much of the remaining prefix as fits. Bytes from `inner` are
        // never mixed into the same read, even if `buf` has spare room.
        let start = me.offset;
        let count = unread.min(buf.len());
        buf[..count].copy_from_slice(&me.prefix[start..start + count]);
        me.offset = start + count;

        // Free the prefix allocation once every byte has been delivered.
        if me.offset >= me.prefix.len() {
            me.prefix = Vec::new();
        }

        Poll::Ready(Ok(count))
    }
}
/// Writes bypass the prefix entirely: each operation is forwarded unchanged to
/// the wrapped `inner` value.
impl<T: AsyncWrite + Unpin> AsyncWrite for Prefixed<T> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // `Prefixed<T>` is `Unpin` whenever `T` is, so the pin can be
        // dereferenced mutably and `inner` re-pinned directly.
        Pin::new(&mut self.inner).poll_write(cx, buf)
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Pin::new(&mut self.inner).poll_close(cx)
    }
}
impl<T: Transport> Transport for Prefixed<T> {
fn set_linger(&mut self, linger: Option<std::time::Duration>) -> io::Result<()> {
self.inner.set_linger(linger)
}
fn set_nodelay(&mut self, nodelay: bool) -> io::Result<()> {
self.inner.set_nodelay(nodelay)
}
fn set_ip_ttl(&mut self, ttl: u32) -> io::Result<()> {
self.inner.set_ip_ttl(ttl)
}
fn peer_addr(&self) -> io::Result<Option<std::net::SocketAddr>> {
self.inner.peer_addr()
}
fn negotiated_alpn(&self) -> Option<std::borrow::Cow<'_, [u8]>> {
self.inner.negotiated_alpn()
}
}