use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
use dns_protocol::{Cursor, Deserialize, Label, Message, Question, ResourceType};
use super::{
Srv, Txt,
error::{ProtoError, proto_error_parse},
};
/// One decoded resource record yielded by [`Endpoint::recv`].
///
/// All borrowed fields share the lifetime `'a`; `recv` instantiates it with the
/// `'innards` lifetime of the [`Message`] being decoded.
#[derive(Debug, Clone, Copy)]
pub enum Response<'a> {
/// An `A` record.
A {
/// Owner name of the record.
name: Label<'a>,
/// IPv4 address built from the record's 4 data bytes.
addr: Ipv4Addr,
},
/// An `AAAA` record.
AAAA {
/// Owner name of the record.
name: Label<'a>,
/// IPv6 address built from the record's 16 data bytes.
addr: Ipv6Addr,
/// Scope id of the IPv6 socket address the message arrived from. `recv`
/// sets this only when `addr` is link-local (unicast `fe80::/10` or
/// link-scoped multicast); it is `None` otherwise, including for messages
/// received from an IPv4 source.
zone: Option<u32>,
},
/// A `PTR` record, carrying the name decoded from the record data. The
/// record's own owner name is not retained.
Ptr(Label<'a>),
/// A `TXT` record.
Txt {
/// Owner name of the record.
name: Label<'a>,
/// Built from the raw record data by `Txt::from_bytes`.
txt: Txt<'a, 'a>,
},
/// An `SRV` record.
Srv {
/// Owner name of the record.
name: Label<'a>,
/// Decoded from the raw record data by `Srv::from_bytes`.
srv: Srv<'a>,
},
}
/// Link-local classification helpers for [`Ipv6Addr`].
///
/// `Endpoint::recv` uses these to decide whether an AAAA record should carry
/// the sender's scope id as its zone.
trait Ipv6AddrExt {
    /// Returns `true` for unicast link-local addresses, i.e. `fe80::/10`.
    fn is_unicast_link_local(&self) -> bool;

    /// Returns `true` for multicast addresses (`ff00::/8`) whose 4-bit scope
    /// field equals 2 (link-local), whatever the flag bits are, e.g. `ff02::1`
    /// or `ff12::fb`.
    fn is_multicast_link_local(&self) -> bool;
}

impl Ipv6AddrExt for Ipv6Addr {
    #[inline]
    fn is_unicast_link_local(&self) -> bool {
        // The top 10 bits of the first 16-bit group must be 1111_1110_10.
        const PREFIX_MASK: u16 = 0xffc0;
        const LINK_LOCAL_PREFIX: u16 = 0xfe80;
        (self.segments()[0] & PREFIX_MASK) == LINK_LOCAL_PREFIX
    }

    #[inline]
    fn is_multicast_link_local(&self) -> bool {
        // High byte 0xff marks multicast; the low nibble of the next byte is
        // the scope. The flag nibble in between is masked out.
        const MULTICAST_SCOPE_MASK: u16 = 0xff0f;
        const LINK_SCOPED_MULTICAST: u16 = 0xff02;
        (self.segments()[0] & MULTICAST_SCOPE_MASK) == LINK_SCOPED_MULTICAST
    }
}
/// Unit type that groups the mDNS query-building and response-decoding
/// associated functions; it holds no state.
pub struct Endpoint;
impl Endpoint {
    /// Builds a `PTR` question for `name`.
    ///
    /// The question class is `IN` (1). When `unicast_response` is `true`, the
    /// top bit of the class field is also set (the mDNS "unicast-response"
    /// bit, RFC 6762 §5.4), giving `0x8001`; otherwise the class is plain `1`.
    pub fn prepare_question(name: Label<'_>, unicast_response: bool) -> Question<'_> {
        // QCLASS value for the Internet class.
        const CLASS_IN: u16 = 1;
        // Most significant bit of QCLASS, used to request a unicast reply.
        const UNICAST_RESPONSE_BIT: u16 = 1 << 15;

        let mut qclass = CLASS_IN;
        if unicast_response {
            qclass |= UNICAST_RESPONSE_BIT;
        }
        Question::new(name, ResourceType::Ptr, qclass)
    }

    /// Decodes the answer and additional sections of `msg` into [`Response`]s.
    ///
    /// Records are visited answers first, then additional records. A, AAAA,
    /// PTR, SRV and TXT records each yield exactly one item; records of any
    /// other type are skipped silently.
    ///
    /// `from` is the address the message came from. When it is an IPv6
    /// address and an AAAA record holds a link-local address (unicast
    /// `fe80::/10` or link-scoped multicast), its scope id becomes that
    /// record's `zone`.
    ///
    /// # Errors
    ///
    /// An item is `Err` when an A/AAAA record's data is not exactly 4/16 bytes,
    /// or when decoding PTR or SRV data fails. An `Err` item does not end the
    /// iterator. With the `tracing` feature enabled, malformed A/AAAA data is
    /// also logged at error level.
    pub fn recv<'innards>(
        from: SocketAddr,
        msg: &Message<'_, 'innards>,
    ) -> impl Iterator<Item = Result<Response<'innards>, ProtoError>> {
        // Only an IPv6 sender has a scope id that can qualify a link-local
        // address.
        let source_scope = match from {
            SocketAddr::V6(v6) => Some(v6.scope_id()),
            SocketAddr::V4(_) => None,
        };

        let records = msg.answers().iter().chain(msg.additional().iter());

        records.filter_map(move |record| {
            let owner = record.name();
            let rdata = record.data();

            let decoded = match record.ty() {
                ResourceType::A => <[u8; 4]>::try_from(rdata)
                    .map(|octets| Response::A {
                        name: owner,
                        addr: Ipv4Addr::from(octets),
                    })
                    .map_err(|_| {
                        #[cfg(feature = "tracing")]
                        tracing::error!("mdns endpoint: invalid A record data");
                        proto_error_parse("A")
                    }),
                ResourceType::AAAA => <[u8; 16]>::try_from(rdata)
                    .map(|octets| {
                        let addr = Ipv6Addr::from(octets);
                        // Explicit trait paths keep these calls on the local
                        // helpers rather than any inherent std method.
                        let needs_zone = Ipv6AddrExt::is_unicast_link_local(&addr)
                            || Ipv6AddrExt::is_multicast_link_local(&addr);
                        Response::AAAA {
                            name: owner,
                            addr,
                            zone: if needs_zone { source_scope } else { None },
                        }
                    })
                    .map_err(|_| {
                        #[cfg(feature = "tracing")]
                        tracing::error!("mdns endpoint: invalid AAAA record data");
                        proto_error_parse("AAAA")
                    }),
                ResourceType::Ptr => {
                    // NOTE(review): this cursor spans only the record's own data,
                    // so a DNS name-compression pointer (an offset into the whole
                    // message, RFC 1035 §4.1.4) in the PTR target may not resolve
                    // as intended — confirm against how `Label` decodes pointers.
                    let mut target = Label::default();
                    target
                        .deserialize(Cursor::new(rdata))
                        .map(|_| Response::Ptr(target))
                }
                ResourceType::Srv => {
                    Srv::from_bytes(rdata).map(|srv| Response::Srv { name: owner, srv })
                }
                ResourceType::Txt => Ok(Response::Txt {
                    name: owner,
                    txt: Txt::from_bytes(rdata),
                }),
                // Other record types are not surfaced to the caller.
                _ => return None,
            };

            Some(decoded)
        })
    }
}