use std::net::SocketAddr;
use hickory_proto::ProtoError;
use crate::{
authority::MessageRequest,
proto::{
op::{Header, LowerQuery, ResponseCode},
xfer::Protocol,
},
server::ResponseHandler,
};
/// An incoming request message together with the transport metadata it arrived with.
///
/// `Request` dereferences (mutably and immutably) to the wrapped [`MessageRequest`], so
/// message accessors can be called directly on a `Request`.
#[derive(Debug)]
pub struct Request {
    /// The parsed request message.
    message: MessageRequest,
    /// Source socket address of the request.
    src: SocketAddr,
    /// Transport protocol the request was received over.
    protocol: Protocol,
}

impl Request {
    /// Wraps `message` with the source address and protocol it was received on.
    pub fn new(message: MessageRequest, src: SocketAddr, protocol: Protocol) -> Self {
        Self {
            protocol,
            src,
            message,
        }
    }

    /// Builds a borrowed [`RequestInfo`] view of this request.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ProtoError`] `raw_queries().try_as_query()` reports for the
    /// wrapped message.
    pub fn request_info(&self) -> Result<RequestInfo<'_>, ProtoError> {
        let header = self.message.header();
        let query = self.message.raw_queries().try_as_query()?;
        let info = RequestInfo {
            src: self.src,
            protocol: self.protocol,
            header,
            query,
        };
        Ok(info)
    }

    /// Source socket address of the request.
    pub fn src(&self) -> SocketAddr {
        let Self { src, .. } = self;
        *src
    }

    /// Transport protocol the request was received over.
    pub fn protocol(&self) -> Protocol {
        let Self { protocol, .. } = self;
        *protocol
    }
}

impl std::ops::Deref for Request {
    type Target = MessageRequest;

    /// Borrows the wrapped request message.
    fn deref(&self) -> &MessageRequest {
        &self.message
    }
}

impl std::ops::DerefMut for Request {
    /// Mutably borrows the wrapped request message.
    fn deref_mut(&mut self) -> &mut MessageRequest {
        &mut self.message
    }
}
/// A borrowed summary of a request: where it came from, which protocol carried it,
/// and the header and query it contains.
///
/// Because the struct is `#[non_exhaustive]`, code outside this crate cannot build
/// it with a struct literal; use [`RequestInfo::new`] instead.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct RequestInfo<'a> {
    /// Source socket address of the request.
    pub src: SocketAddr,
    /// Transport protocol the request was received over.
    pub protocol: Protocol,
    /// Header of the request message.
    pub header: &'a Header,
    /// The request's query.
    pub query: &'a LowerQuery,
}

impl<'a> RequestInfo<'a> {
    /// Builds a `RequestInfo` from its parts.
    ///
    /// This constructor exists because `#[non_exhaustive]` forbids struct-literal
    /// construction from other crates.
    pub fn new(
        src: SocketAddr,
        protocol: Protocol,
        header: &'a Header,
        query: &'a LowerQuery,
    ) -> Self {
        Self {
            src,
            protocol,
            header,
            query,
        }
    }
}
/// Information about the response to a request, as returned by
/// [`RequestHandler::handle_request`].
///
/// A `#[repr(transparent)]` newtype over a [`Header`]. It dereferences to that header,
/// so header accessors can be called on it directly.
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct ResponseInfo(Header);

impl ResponseInfo {
    /// Builds a `ResponseInfo` from `Header::new()` with the response code set to
    /// `ServFail`.
    ///
    /// NOTE(review): every other header field keeps its `Header::new()` value; confirm
    /// that callers do not expect it to mirror the failed request's header.
    pub(crate) fn serve_failed() -> Self {
        let mut servfail = Header::new();
        servfail.set_response_code(ResponseCode::ServFail);
        ResponseInfo(servfail)
    }
}

impl From<Header> for ResponseInfo {
    /// Wraps `header` unchanged.
    fn from(header: Header) -> Self {
        ResponseInfo(header)
    }
}

impl std::ops::Deref for ResponseInfo {
    type Target = Header;

    /// Borrows the wrapped header.
    fn deref(&self) -> &Header {
        let Self(header) = self;
        header
    }
}
/// Asynchronous handler for incoming [`Request`]s.
///
/// Implementors must be `Send + Sync + Unpin + 'static`. The trait is expanded by
/// `async_trait`, so `handle_request` is written as an `async fn`.
#[async_trait::async_trait]
pub trait RequestHandler: Send + Sync + Unpin + 'static {
    /// Handles `request`, given a `response_handle` of any [`ResponseHandler`] type,
    /// and returns a [`ResponseInfo`] describing the response.
    ///
    /// NOTE(review): presumably the reply is sent through `response_handle`; confirm
    /// against implementors.
    async fn handle_request<R>(
        &self,
        request: &mut Request,
        response_handle: R,
    ) -> ResponseInfo
    where
        R: ResponseHandler;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::proto::op::{Header, Query};

    /// Cloning a `RequestInfo` must preserve every field, not only the header.
    #[test]
    fn request_info_clone() {
        let header = Header::new();
        let lower_query = Query::new().into();
        let src: SocketAddr = "127.0.0.1:3000".parse().unwrap();
        let origin = RequestInfo::new(src, Protocol::Udp, &header, &lower_query);

        let cloned = origin.clone();

        assert_eq!(origin.src, cloned.src);
        assert_eq!(origin.protocol, cloned.protocol);
        assert_eq!(origin.header, cloned.header);
        assert_eq!(origin.query, cloned.query);
        // The borrowed fields are references, so a clone must point at the same data.
        assert!(std::ptr::eq(origin.header, cloned.header));
        assert!(std::ptr::eq(origin.query, cloned.query));
    }
}