use std::{net::SocketAddr, sync::Arc};
use bytes::{Buf, Bytes};
use futures_util::lock::Mutex;
use h3::server::RequestStream;
use h3_quinn::BidiStream;
use rustls::server::ResolvesServerCert;
use tokio::{net, task::JoinSet};
use tracing::{debug, warn};
use super::{
ResponseInfo, ServerContext, reap_tasks, request_handler::RequestHandler,
response_handler::ResponseHandler, sanitize_src_address,
};
use crate::{
net::{
NetError,
h3::h3_server::{H3Connection, H3Server},
http::{self, Version},
xfer::Protocol,
},
proto::rr::Record,
zone_handler::MessageResponse,
};
/// Serves DNS-over-HTTP/3 on `socket`.
///
/// Builds an [`H3Server`] bound to the given UDP socket, using `server_cert_resolver` to pick
/// TLS certificates, and then hands it to [`handle_h3_with_server`] for the accept loop.
///
/// # Errors
///
/// Returns an error if the H3 server cannot be created on the socket, or whatever error the
/// accept loop itself reports.
pub(super) async fn handle_h3(
    socket: net::UdpSocket,
    server_cert_resolver: Arc<dyn ResolvesServerCert>,
    dns_hostname: Option<String>,
    cx: Arc<ServerContext<impl RequestHandler>>,
) -> Result<(), NetError> {
    debug!("registered h3: {:?}", socket);
    let server = H3Server::with_socket(socket, server_cert_resolver)?;
    handle_h3_with_server(server, dns_hostname, cx).await
}
pub(super) async fn handle_h3_with_server(
mut server: H3Server,
dns_hostname: Option<String>,
cx: Arc<ServerContext<impl RequestHandler>>,
) -> Result<(), NetError> {
let dns_hostname = dns_hostname.map(|n| n.into());
let mut inner_join_set = JoinSet::new();
loop {
let shutdown = cx.shutdown.clone();
let (streams, src_addr) = tokio::select! {
result = server.accept() => match result {
Ok(Some(c)) => c,
Ok(None) => continue,
Err(error) => {
debug!(%error, "error receiving h3 connection");
continue;
}
},
_ = shutdown.cancelled() => {
break;
},
};
if let Err(error) = sanitize_src_address(src_addr) {
warn!(
%error, %src_addr,
"address can not be responded to",
);
continue;
}
let cx = cx.clone();
let dns_hostname = dns_hostname.clone();
inner_join_set.spawn(async move {
debug!("starting h3 stream request from: {src_addr}");
let result = h3_handler(streams, src_addr, dns_hostname, cx).await;
if let Err(error) = result {
warn!(%error, %src_addr, "h3 stream processing failed")
}
});
reap_tasks(&mut inner_join_set);
}
Ok(())
}
pub(crate) async fn h3_handler(
mut connection: H3Connection,
src_addr: SocketAddr,
_dns_hostname: Option<Arc<str>>,
cx: Arc<ServerContext<impl RequestHandler>>,
) -> Result<(), NetError> {
let mut max_requests = 100u32;
loop {
let (_, mut stream) = tokio::select! {
result = connection.accept() => match result {
Some(Ok(next_request)) => next_request,
Some(Err(err)) => {
warn!("error accepting request {}: {}", src_addr, err);
return Err(err);
}
None => {
break;
}
},
_ = cx.shutdown.cancelled() => {
break;
},
};
let request = match stream
.recv_data()
.await
.map_err(|e| NetError::from(format!("h3 stream receive data failed: {e}")))?
{
Some(mut request) => request.copy_to_bytes(request.remaining()),
None => continue,
};
debug!(
"Received bytes {} from {src_addr} {request:?}",
request.remaining()
);
let cx = cx.clone();
let stream = Arc::new(Mutex::new(stream));
let responder = H3ResponseHandle(stream.clone());
tokio::spawn(async move {
cx.handle_request(request, src_addr, Protocol::H3, responder)
.await
});
max_requests -= 1;
if max_requests == 0 {
warn!("exceeded request count, shutting down h3 conn: {src_addr}");
connection.shutdown().await?;
break;
}
}
Ok(())
}
/// Responder that writes a DNS answer back onto the HTTP/3 request stream it was built from.
///
/// The stream lives behind an async mutex inside an `Arc`, so cloned handles share the same
/// stream.
#[derive(Clone)]
struct H3ResponseHandle(Arc<Mutex<RequestStream<BidiStream<Bytes>, Bytes>>>);

#[async_trait::async_trait]
impl ResponseHandler for H3ResponseHandle {
    /// Encodes `response` for H3 and sends it as one HTTP/3 response. It writes a response
    /// head sized to the encoded body, then the body, then finishes the stream. The stream
    /// lock is held across all three writes.
    async fn send_response<'a>(
        &mut self,
        response: MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> Result<ResponseInfo, NetError> {
        let (info, encoded) = response.encode(Protocol::H3)?;
        let body = Bytes::from(encoded);
        let head = http::response(Version::Http3, body.len())?;
        debug!("sending response: {:#?}", head);

        {
            let mut guard = self.0.lock().await;
            guard.send_response(head).await?;
            guard.send_data(body).await?;
            guard.finish().await?;
        }

        Ok(info)
    }
}