use std::{io, net::SocketAddr, sync::Arc};
use bytes::{Bytes, BytesMut};
use drain::Watch;
use futures_util::lock::Mutex;
use tracing::{debug, warn};
use trust_dns_proto::{
error::ProtoError,
quic::{DoqErrorCode, QuicStream},
rr::Record,
};
use crate::{
authority::MessageResponse,
proto::quic::QuicStreams,
server::{
request_handler::RequestHandler, response_handler::ResponseHandler, server_future,
Protocol, ResponseInfo,
},
};
pub(crate) async fn quic_handler<T>(
handler: Arc<T>,
mut quic_streams: QuicStreams,
src_addr: SocketAddr,
_dns_hostname: Option<Arc<str>>,
shutdown: Watch,
) -> Result<(), ProtoError>
where
T: RequestHandler,
{
let mut max_requests = 100u32;
loop {
let mut request_stream = tokio::select! {
result = quic_streams.next() => match result {
Some(Ok(next_request)) => next_request,
Some(Err(err)) => {
warn!("error accepting request {}: {}", src_addr, err);
return Err(err);
}
None => {
break;
}
},
_ = shutdown.clone().signaled() => {
break;
},
};
let request = request_stream.receive_bytes().await?;
debug!(
"Received bytes {} from {src_addr} {request:?}",
request.len()
);
let handler = handler.clone();
let stream = Arc::new(Mutex::new(request_stream));
let responder = QuicResponseHandle(stream.clone());
handle_request(request, src_addr, handler, responder).await;
max_requests -= 1;
if max_requests == 0 {
warn!("exceeded request count, shutting down quic conn: {src_addr}");
stream.lock().await.stop(DoqErrorCode::NoError)?;
break;
}
}
Ok(())
}
/// Forwards one raw request received over QUIC to the shared server request
/// pipeline.
///
/// `bytes` is the request as read from the stream, `src_addr` is the peer
/// address, `handler` is the request handler to invoke, and `responder` is the
/// handle through which the response is sent. The request is tagged with
/// [`Protocol::Quic`] before being passed to `server_future::handle_request`.
async fn handle_request<T: RequestHandler>(
    bytes: BytesMut,
    src_addr: SocketAddr,
    handler: Arc<T>,
    responder: QuicResponseHandle,
) {
    let protocol = Protocol::Quic;
    server_future::handle_request(&bytes, src_addr, protocol, handler, responder).await;
}
/// Response handle for a single QUIC request stream.
///
/// Wraps the stream in an `Arc<Mutex<_>>` so that clones of the handle refer
/// to the same underlying [`QuicStream`] and access to it is serialized.
#[derive(Clone)]
struct QuicResponseHandle(Arc<Mutex<QuicStream>>);
#[async_trait::async_trait]
impl ResponseHandler for QuicResponseHandle {
    /// Encodes `response` and writes it to the wrapped QUIC stream.
    ///
    /// The message ID is forced to `0` before encoding, the encoded bytes are
    /// sent on the stream, and the stream is then finished. On success the
    /// [`ResponseInfo`] produced while encoding is returned.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if encoding the response fails, or if sending
    /// the bytes or finishing the stream fails.
    async fn send_response<'a>(
        &mut self,
        mut response: MessageResponse<
            '_,
            'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
            impl Iterator<Item = &'a Record> + Send + 'a,
        >,
    ) -> io::Result<ResponseInfo> {
        use crate::proto::serialize::binary::BinEncoder;

        // NOTE(review): DNS-over-QUIC (RFC 9250) appears to require a message
        // ID of 0 on the wire — confirm against the spec.
        response.header_mut().set_id(0);

        let mut wire = Vec::with_capacity(512);
        // The encoder is confined to its own scope so that it (and its mutable
        // borrow of `wire`) is gone before the first `.await` below.
        let info = {
            let mut encoder = BinEncoder::new(&mut wire);
            response.destructive_emit(&mut encoder)?
        };

        let payload = Bytes::from(wire);
        debug!("sending quic response: {}", payload.len());

        // Keep the stream locked for both the send and the finish so the two
        // steps are not interleaved with other users of the shared stream.
        let mut stream = self.0.lock().await;
        stream.send_bytes(payload).await?;
        stream.finish().await?;

        Ok(info)
    }
}