use std::{io, net::SocketAddr, sync::Arc};
use bytes::{Buf, Bytes};
use futures_util::lock::Mutex;
use h3::server::RequestStream;
use h3_quinn::BidiStream;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, warn};
use crate::{
access::AccessControl,
authority::MessageResponse,
server::{
ResponseInfo,
request_handler::RequestHandler,
response_handler::{ResponseHandler, encode_fallback_servfail_response},
},
};
use hickory_proto::{
ProtoError,
h3::{H3Error, h3_server::H3Connection},
http::Version,
rr::Record,
xfer::Protocol,
};
pub(crate) async fn h3_handler<T>(
access: Arc<AccessControl>,
handler: Arc<T>,
mut connection: H3Connection,
src_addr: SocketAddr,
_dns_hostname: Option<Arc<str>>,
shutdown: CancellationToken,
) -> Result<(), ProtoError>
where
T: RequestHandler,
{
let mut max_requests = 100u32;
loop {
let (_, mut stream) = tokio::select! {
result = connection.accept() => match result {
Some(Ok(next_request)) => next_request,
Some(Err(err)) => {
warn!("error accepting request {}: {}", src_addr, err);
return Err(err);
}
None => {
break;
}
},
_ = shutdown.cancelled() => {
break;
},
};
let request = match stream
.recv_data()
.await
.map_err(|e| ProtoError::from(format!("h3 stream receive data failed: {e}")))?
{
Some(mut request) => request.copy_to_bytes(request.remaining()),
None => continue,
};
debug!(
"Received bytes {} from {src_addr} {request:?}",
request.remaining()
);
let handler = handler.clone();
let access = access.clone();
let stream = Arc::new(Mutex::new(stream));
let responder = H3ResponseHandle(stream.clone());
tokio::spawn(async move {
super::handle_request(&request, src_addr, Protocol::H3, access, handler, responder)
.await
});
max_requests -= 1;
if max_requests == 0 {
warn!("exceeded request count, shutting down h3 conn: {src_addr}");
connection.shutdown().await?;
break;
}
}
Ok(())
}
/// An HTTP/3 request stream shared behind an async mutex.
type SharedRequestStream = Arc<Mutex<RequestStream<BidiStream<Bytes>, Bytes>>>;

/// Response handler bound to a single HTTP/3 request stream.
///
/// Cloning the handle clones the `Arc`, so every clone writes to the same underlying stream.
/// Writers serialize on the mutex (`lock().await`).
#[derive(Clone)]
struct H3ResponseHandle(SharedRequestStream);
/// Writes DNS responses back to the client over the wrapped HTTP/3 request stream.
#[async_trait::async_trait]
impl ResponseHandler for H3ResponseHandle {
    /// Encodes `response` and sends it on the request stream as a complete HTTP/3 response.
    ///
    /// If encoding fails, the error is logged and a fallback produced by
    /// `encode_fallback_servfail_response` (given the original message id) is sent instead.
    /// The response head is built by `response::new` for HTTP/3 with the encoded length.
    /// The head, the body, and the stream finish are then written while holding the stream
    /// lock.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the fallback encoding fails;
    /// - the HTTP response head cannot be built;
    /// - sending the head, the body, or the finish fails.
    async fn send_response<'a>(
        &mut self,
        response: MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> io::Result<ResponseInfo> {
        use crate::proto::http::response;
        use crate::proto::serialize::binary::BinEncoder;

        let message_id = response.header().id();
        let mut wire = Vec::with_capacity(512);

        // Encode the message; if that fails, log it and encode the fallback response instead.
        let emitted = response.destructive_emit(&mut BinEncoder::new(&mut wire));
        let info = match emitted {
            Ok(info) => info,
            Err(error) => {
                error!(%error, "error encoding message");
                encode_fallback_servfail_response(message_id, &mut wire)?
            }
        };

        let body = Bytes::from(wire);
        let head = response::new(Version::Http3, body.len())?;
        debug!("sending response: {:#?}", head);

        let mut stream = self.0.lock().await;
        stream.send_response(head).await.map_err(H3Error::from)?;
        stream.send_data(body).await.map_err(H3Error::from)?;
        stream.finish().await.map_err(H3Error::from)?;

        Ok(info)
    }
}