use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use h3::server::RequestResolver;
use http::{Request, Response};
use http_body_util::BodyExt;
use tracing::{debug, info};
use super::AppState;
use crate::{Body, ProxyError};
/// Serves HTTP/3 over QUIC on `addr` until shutdown is requested or the
/// endpoint stops yielding connections.
///
/// The provided rustls configuration is cloned and its ALPN list is replaced
/// with `h3` only. Every accepted connection is handled on its own Tokio task,
/// which holds the guard returned by `state.shutdown.track_conn()` for as long
/// as the task runs.
///
/// # Errors
///
/// Returns [`ProxyError::Internal`] if the rustls configuration cannot be
/// converted into a QUIC crypto configuration or the endpoint cannot be bound.
pub async fn run_h3_server(
    addr: SocketAddr,
    tls_config: Arc<rustls::ServerConfig>,
    state: Arc<AppState>,
) -> Result<(), ProxyError> {
    // Upper bound on how long we block in `accept` before re-checking the
    // shutdown flag; without it an idle server would never observe shutdown.
    const SHUTDOWN_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);

    let mut rustls_config = (*tls_config).clone();
    rustls_config.alpn_protocols = vec![b"h3".to_vec()];
    let quic_server_config = quinn::crypto::rustls::QuicServerConfig::try_from(rustls_config)
        .map_err(|e| ProxyError::Internal(format!("failed to create QUIC server config: {e}")))?;
    let server_config = quinn::ServerConfig::with_crypto(Arc::new(quic_server_config));
    let endpoint = quinn::Endpoint::server(server_config, addr).map_err(|e| {
        ProxyError::Internal(format!("failed to bind QUIC endpoint on {addr}: {e}"))
    })?;
    info!(%addr, "listening for HTTP/3 (QUIC) connections");

    loop {
        if state.shutdown.is_shutdown() {
            info!("HTTP/3 accept loop stopping (shutdown)");
            break;
        }
        // NOTE(review): relies on quinn's `accept` future being cancel-safe
        // (pending attempts stay queued until a poll returns them) — confirm
        // for the pinned quinn version.
        let accepted = tokio::time::timeout(SHUTDOWN_POLL_INTERVAL, endpoint.accept()).await;
        let incoming = match accepted {
            Ok(Some(incoming)) => incoming,
            Ok(None) => {
                info!("QUIC endpoint closed");
                break;
            }
            // Timed out waiting for a client: loop back to re-check shutdown.
            Err(_elapsed) => continue,
        };
        let state = Arc::clone(&state);
        let conn_guard = state.shutdown.track_conn();
        tokio::spawn(async move {
            // Keep the connection registered for the whole lifetime of this task.
            let _conn_guard = conn_guard;
            let client_addr = incoming.remote_address();
            if let Err(e) = handle_h3_connection(incoming, client_addr, state).await {
                debug!(client = %client_addr, "HTTP/3 connection error: {e}");
            }
        });
    }

    endpoint.close(quinn::VarInt::from_u32(0), b"server shutting down");
    // Give peers a good-faith chance to receive CONNECTION_CLOSE before we
    // return, instead of leaving them to wait out their idle timeout.
    endpoint.wait_idle().await;
    Ok(())
}
async fn handle_h3_connection(
incoming: quinn::Incoming,
client_addr: SocketAddr,
state: Arc<AppState>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let quinn_conn = incoming.await?;
debug!(client = %client_addr, "QUIC connection established");
let h3_conn = h3_quinn::Connection::new(quinn_conn);
let mut server_conn: h3::server::Connection<h3_quinn::Connection, Bytes> =
h3::server::Connection::new(h3_conn).await?;
loop {
let resolver: Option<RequestResolver<h3_quinn::Connection, Bytes>> =
match server_conn.accept().await {
Ok(resolver) => resolver,
Err(e) => {
debug!(client = %client_addr, "H3 accept error: {e}");
return Err(Box::new(e));
}
};
let Some(resolver) = resolver else {
debug!(client = %client_addr, "H3 connection closing gracefully");
break;
};
let state = Arc::clone(&state);
tokio::spawn(async move {
if let Err(e) = handle_h3_request(resolver, client_addr, state).await {
debug!(client = %client_addr, "H3 request error: {e}");
}
});
}
Ok(())
}
async fn handle_h3_request(
resolver: RequestResolver<h3_quinn::Connection, Bytes>,
client_addr: SocketAddr,
state: Arc<AppState>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let (req, mut stream) = resolver.resolve_request().await?;
debug!(
client = %client_addr,
method = %req.method(),
uri = %req.uri(),
"HTTP/3 request received"
);
let mut body_bytes = Vec::new();
while let Some(chunk) = stream.recv_data().await? {
use bytes::Buf;
body_bytes.extend_from_slice(chunk.chunk());
}
let body_data = Bytes::from(body_bytes);
let response = route_h3_via_salvo(req, body_data, client_addr, &state).await;
let (resp_parts, resp_body) = response.into_parts();
let resp_head = Response::from_parts(resp_parts, ());
stream.send_response(resp_head).await?;
let collected = resp_body
.collect()
.await
.map_err(|e| std::io::Error::other(format!("body collect error: {e}")))?;
let resp_bytes = collected.to_bytes();
if !resp_bytes.is_empty() {
stream.send_data(resp_bytes).await?;
}
stream.finish().await?;
Ok(())
}
async fn route_h3_via_salvo(
req: Request<()>,
body_data: Bytes,
client_addr: SocketAddr,
state: &AppState,
) -> Response<Body> {
use salvo::http::ReqBody;
let service = state.service.load();
let (parts, _) = req.into_parts();
let req_body: ReqBody = ReqBody::Once(body_data);
let hyper_req = hyper::Request::from_parts(parts, req_body);
let local_addr: SocketAddr = ([0, 0, 0, 0], 443).into();
let https_port = state.config.load().global.https_addr.port();
let alt_svc_h3 = format!("h3=\":{https_port}\"; ma=2592000").parse().ok();
let handler = service.hyper_handler(
local_addr.into(),
client_addr.into(),
salvo::http::uri::Scheme::HTTPS,
None,
alt_svc_h3,
);
use hyper::service::Service as HyperService;
let hyper_resp = match handler.call(hyper_req).await {
Ok(resp) => resp,
Err(_) => hyper::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.body(salvo::http::ResBody::None)
.unwrap(),
};
let (parts, res_body) = hyper_resp.into_parts();
let body: Body = res_body
.map_err(|e| -> crate::BoxError { Box::new(e) })
.boxed();
Response::from_parts(parts, body)
}