#![cfg_attr(not(feature = "std"), no_std)]
#![warn(clippy::large_futures)]
#![allow(async_fn_in_trait)]
#![allow(clippy::uninlined_format_args)]
#![allow(unknown_lints)]
use core::cmp::Ordering;
use core::fmt::Display;
use core::ops::RangeBounds;
use domain::base::header::Flags;
use domain::base::iana::{Opcode, Rcode};
use domain::base::message::ShortMessage;
use domain::base::message_builder::PushError;
use domain::base::name::{FromStrError, Label, ToLabelIter};
use domain::base::rdata::ComposeRecordData;
use domain::base::wire::{Composer, ParseError};
use domain::base::{
Message, MessageBuilder, ParsedName, Question, Record, RecordData, Rtype, ToName,
};
use domain::dep::octseq::{FreezeBuilder, FromBuilder, Octets, OctetsBuilder, ShortBuf, Truncate};
use domain::rdata::AllRecordData;
pub(crate) mod fmt;
#[cfg(feature = "io")]
pub mod buf; pub mod domain {
pub use domain::*;
}
pub mod host;
#[cfg(feature = "io")]
pub mod io;
/// The name `_services._dns-sd._udp.local.`, used by DNS-SD for service type
/// enumeration (RFC 6763 §9).
pub const DNS_SD_OWNER: NameSlice<'static> =
    NameSlice::new(&["_services", "_dns-sd", "_udp", "local"]);
/// Errors produced while parsing or composing mDNS messages.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
pub enum MdnsError {
    /// The output buffer was too small for the message being composed.
    ShortBuf,
    /// An incoming message (or a name) could not be parsed.
    InvalidMessage,
}

impl Display for MdnsError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self {
            Self::ShortBuf => "ShortBuf",
            Self::InvalidMessage => "InvalidMessage",
        };

        f.write_str(name)
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for MdnsError {
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::InvalidMessage => defmt::write!(f, "InvalidMessage"),
            Self::ShortBuf => defmt::write!(f, "ShortBuf"),
        }
    }
}

impl core::error::Error for MdnsError {}
impl From<ShortBuf> for MdnsError {
fn from(_: ShortBuf) -> Self {
Self::ShortBuf
}
}
impl From<PushError> for MdnsError {
fn from(_: PushError) -> Self {
Self::ShortBuf
}
}
impl From<FromStrError> for MdnsError {
fn from(_: FromStrError) -> Self {
Self::InvalidMessage
}
}
impl From<ShortMessage> for MdnsError {
fn from(_: ShortMessage) -> Self {
Self::InvalidMessage
}
}
impl From<ParseError> for MdnsError {
fn from(_: ParseError) -> Self {
Self::InvalidMessage
}
}
/// A DNS name given as a borrowed slice of label strings, most specific label first and
/// without the root label (e.g. `&["foo", "local"]` for `foo.local.`).
#[derive(Debug, Clone)]
pub struct NameSlice<'a>(&'a [&'a str]);

impl<'a> NameSlice<'a> {
    /// Wraps `labels` as a name.
    pub const fn new(labels: &'a [&'a str]) -> Self {
        Self(labels)
    }
}

impl core::fmt::Display for NameSlice<'_> {
    /// Writes every label followed by a dot, e.g. `foo.local.`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        self.0.iter().try_for_each(|label| write!(f, "{}.", label))
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for NameSlice<'_> {
    fn format(&self, f: defmt::Formatter<'_>) {
        for segment in self.0.iter() {
            defmt::write!(f, "{}.", segment);
        }
    }
}
impl ToName for NameSlice<'_> {}

/// Iterator over the labels of a [`NameSlice`], followed by the root label.
///
/// The iterator is double-ended. `front` is the position of the next label returned by
/// [`Iterator::next`]. `back` is one past the position of the next label returned by
/// [`DoubleEndedIterator::next_back`]. Position `name.0.len()` denotes the root label.
/// Keeping two cursors lets reverse iteration work on a fresh iterator, and lets mixed
/// forward/backward iteration yield every label exactly once. `domain` relies on reverse
/// iteration for things like `name_cmp` and `ends_with`.
///
/// # Panics
///
/// A label string that `Label::from_slice` rejects (for example one longer than 63 bytes)
/// goes through `unwrap!` when it is yielded, and therefore presumably panics.
#[derive(Clone)]
pub struct NameSliceIter<'a> {
    name: &'a NameSlice<'a>,
    front: usize,
    back: usize,
}

impl<'a> NameSliceIter<'a> {
    /// Creates an iterator over every label of `name`, plus the trailing root label.
    fn new(name: &'a NameSlice<'a>) -> Self {
        Self {
            name,
            front: 0,
            back: name.0.len() + 1,
        }
    }

    /// Returns the label at `index`; `index == name.0.len()` yields the root label.
    fn label(&self, index: usize) -> &'a Label {
        let labels = self.name.0;

        if index == labels.len() {
            Label::root()
        } else {
            unwrap!(Label::from_slice(labels[index].as_bytes()), "Unreachable")
        }
    }
}

impl<'a> Iterator for NameSliceIter<'a> {
    type Item = &'a Label;

    fn next(&mut self) -> Option<Self::Item> {
        match self.front.cmp(&self.back) {
            Ordering::Less => {
                let label = self.label(self.front);
                self.front += 1;
                Some(label)
            }
            Ordering::Equal | Ordering::Greater => None,
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.back - self.front;
        (remaining, Some(remaining))
    }
}

impl DoubleEndedIterator for NameSliceIter<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        match self.front.cmp(&self.back) {
            Ordering::Less => {
                self.back -= 1;
                Some(self.label(self.back))
            }
            Ordering::Equal | Ordering::Greater => None,
        }
    }
}

impl ToLabelIter for NameSlice<'_> {
    type LabelIter<'t>
        = NameSliceIter<'t>
    where
        Self: 't;

    fn iter_labels(&self) -> Self::LabelIter<'_> {
        NameSliceIter::new(self)
    }
}
/// TXT record data given as borrowed `(key, value)` pairs.
#[derive(Debug, Clone)]
pub struct Txt<'a>(&'a [(&'a str, &'a str)]);

impl<'a> Txt<'a> {
    /// Wraps the given `(key, value)` pairs.
    pub const fn new(txt: &'a [(&'a str, &'a str)]) -> Self {
        Self(txt)
    }
}

impl core::fmt::Display for Txt<'_> {
    /// Formats as `Txt [k1=v1, k2=v2]`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Txt [")?;

        for (position, (key, value)) in self.0.iter().enumerate() {
            match position {
                0 => write!(f, "{}={}", key, value)?,
                _ => write!(f, ", {}={}", key, value)?,
            }
        }

        write!(f, "]")
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for Txt<'_> {
    fn format(&self, f: defmt::Formatter<'_>) {
        defmt::write!(f, "Txt [");

        for (position, (key, value)) in self.0.iter().enumerate() {
            match position {
                0 => defmt::write!(f, "{}={}", key, value),
                _ => defmt::write!(f, ", {}={}", key, value),
            }
        }

        defmt::write!(f, "]");
    }
}
impl RecordData for Txt<'_> {
    fn rtype(&self) -> Rtype {
        Rtype::TXT
    }
}

impl ComposeRecordData for Txt<'_> {
    /// The length is not precomputed; it is left to the message builder.
    fn rdlen(&self, _compress: bool) -> Option<u16> {
        None
    }

    /// Writes the TXT RDATA.
    ///
    /// Each `(key, value)` pair becomes one length-prefixed `key=value` character-string.
    /// No pairs at all produce a single empty string, DNS-SD's encoding of an empty TXT
    /// record.
    ///
    /// The length prefix of a character-string is a single octet, so an entry longer than
    /// 255 bytes is truncated to its first 255 bytes. Writing the whole entry would leave a
    /// prefix that disagrees with the payload and corrupt everything after it.
    fn compose_rdata<Target: Composer + ?Sized>(
        &self,
        target: &mut Target,
    ) -> Result<(), Target::AppendError> {
        if self.0.is_empty() {
            return target.append_slice(&[0]);
        }

        for (k, v) in self.0 {
            let full_len = k.len() + 1 + v.len();
            // Saturate rather than wrap: `as u8` would silently produce `full_len % 256`.
            let len = u8::try_from(full_len).unwrap_or(u8::MAX);
            target.append_slice(&[len])?;

            // Emit exactly `len` payload bytes taken from `key`, `=`, `value` in order.
            let mut remaining = usize::from(len);
            for part in [k.as_bytes(), &b"="[..], v.as_bytes()] {
                let take = part.len().min(remaining);
                target.append_slice(&part[..take])?;
                remaining -= take;
            }
        }

        Ok(())
    }

    fn compose_canonical_rdata<Target: Composer + ?Sized>(
        &self,
        target: &mut Target,
    ) -> Result<(), Target::AppendError> {
        self.compose_rdata(target)
    }
}
/// Record data that is either a `T` (`This`) or a `U` (`Next`).
///
/// Lets one record type carry data from two different record-data types. Formatting
/// delegates to whichever variant is present.
#[derive(Debug, Clone)]
pub enum RecordDataChain<T, U> {
    This(T),
    Next(U),
}

impl<T, U> core::fmt::Display for RecordDataChain<T, U>
where
    T: core::fmt::Display,
    U: core::fmt::Display,
{
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let inner: &dyn core::fmt::Display = match self {
            Self::This(this) => this,
            Self::Next(next) => next,
        };

        write!(f, "{}", inner)
    }
}

#[cfg(feature = "defmt")]
impl<T, U> defmt::Format for RecordDataChain<T, U>
where
    T: defmt::Format,
    U: defmt::Format,
{
    fn format(&self, f: defmt::Formatter<'_>) {
        match self {
            Self::Next(next) => defmt::write!(f, "{}", next),
            Self::This(this) => defmt::write!(f, "{}", this),
        }
    }
}
/// Reports the record type of whichever variant is present.
impl<T, U> RecordData for RecordDataChain<T, U>
where
    T: RecordData,
    U: RecordData,
{
    fn rtype(&self) -> Rtype {
        match self {
            Self::This(this) => this.rtype(),
            Self::Next(next) => next.rtype(),
        }
    }
}

/// Composes whichever variant is present; every method forwards to it unchanged.
impl<T, U> ComposeRecordData for RecordDataChain<T, U>
where
    T: ComposeRecordData,
    U: ComposeRecordData,
{
    fn rdlen(&self, compress: bool) -> Option<u16> {
        match self {
            Self::This(this) => this.rdlen(compress),
            Self::Next(next) => next.rdlen(compress),
        }
    }

    fn compose_rdata<Target: Composer + ?Sized>(
        &self,
        target: &mut Target,
    ) -> Result<(), Target::AppendError> {
        match self {
            Self::This(this) => this.compose_rdata(target),
            Self::Next(next) => next.compose_rdata(target),
        }
    }

    fn compose_canonical_rdata<Target: Composer + ?Sized>(
        &self,
        target: &mut Target,
    ) -> Result<(), Target::AppendError> {
        match self {
            Self::This(this) => this.compose_canonical_rdata(target),
            Self::Next(next) => next.compose_canonical_rdata(target),
        }
    }
}
/// A fixed-capacity octets builder writing into a borrowed byte slice.
///
/// Field `0` is the backing storage and field `1` is the number of bytes written so far.
/// Only `0[..1]` is content. Appending never reallocates; it fails with [`ShortBuf`] once
/// the backing slice is full.
pub struct Buf<'a>(pub &'a mut [u8], pub usize);

impl<'a> Buf<'a> {
    /// Creates an empty builder that writes into `buf` from its start.
    pub fn new(buf: &'a mut [u8]) -> Self {
        Self(buf, 0)
    }
}

impl FreezeBuilder for Buf<'_> {
    type Octets = Self;

    /// The builder already is an octets sequence, so freezing is a no-op.
    fn freeze(self) -> Self {
        self
    }
}

impl Octets for Buf<'_> {
    type Range<'r>
        = &'r [u8]
    where
        Self: 'r;

    /// Returns `range` of the written content (not of the whole backing slice).
    fn range(&self, range: impl RangeBounds<usize>) -> Self::Range<'_> {
        self.0[..self.1].range(range)
    }
}

impl<'a> FromBuilder for Buf<'a> {
    type Builder = Buf<'a>;

    /// Converts a builder into octets holding exactly the bytes written to it, the same
    /// as [`FreezeBuilder::freeze`].
    fn from_builder(builder: Self::Builder) -> Self {
        // The builder's content must be kept. Returning the unused tail of the backing
        // buffer would yield an empty sequence and lose everything that was built.
        builder
    }
}

impl Composer for Buf<'_> {}

impl OctetsBuilder for Buf<'_> {
    type AppendError = ShortBuf;

    /// Appends `slice` after the current content.
    ///
    /// Fails with [`ShortBuf`], leaving the builder unchanged, if it does not fit.
    fn append_slice(&mut self, slice: &[u8]) -> Result<(), Self::AppendError> {
        let end = self.1.checked_add(slice.len()).ok_or(ShortBuf)?;
        let dst = self.0.get_mut(self.1..end).ok_or(ShortBuf)?;

        dst.copy_from_slice(slice);
        self.1 = end;

        Ok(())
    }
}

impl Truncate for Buf<'_> {
    /// Shortens the content to `len` bytes; a `len` not smaller than the current length
    /// does nothing.
    fn truncate(&mut self, len: usize) {
        // Never grow: that would expose stale bytes, or make the `[..self.1]` slicing in
        // `as_ref`/`as_mut`/`range` panic if `len` exceeds the backing slice.
        if len < self.1 {
            self.1 = len;
        }
    }
}

impl AsMut<[u8]> for Buf<'_> {
    fn as_mut(&mut self) -> &mut [u8] {
        &mut self.0[..self.1]
    }
}

impl AsRef<[u8]> for Buf<'_> {
    fn as_ref(&self) -> &[u8] {
        &self.0[..self.1]
    }
}
/// An input for an [`MdnsHandler`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum MdnsRequest<'a> {
    /// No message was received. `HostAnswersMdnsHandler` answers this with all of its
    /// records.
    None,
    /// A received DNS message.
    Request {
        /// NOTE(review): presumably marks a legacy unicast query (RFC 6762 §6.7).
        /// `HostAnswersMdnsHandler` echoes the query ID and questions for these, and
        /// `PeerAnswersMdnsHandler` ignores them.
        legacy: bool,
        /// NOTE(review): presumably whether the message arrived via multicast; not read by
        /// the handlers in this file.
        multicast: bool,
        /// The raw DNS message.
        data: &'a [u8],
    },
}
/// The outcome of [`MdnsHandler::handle`].
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum MdnsResponse<'a> {
    /// Nothing should be sent.
    None,
    /// A reply should be sent.
    Reply {
        /// The encoded reply message; the handlers in this file take it from the start of
        /// the response buffer they were given.
        data: &'a [u8],
        /// NOTE(review): presumably asks the caller to delay sending the reply; the
        /// handlers in this file always set it to `false`.
        delay: bool,
    },
}
/// Processes incoming mDNS traffic and optionally produces a reply.
pub trait MdnsHandler {
    /// Handles `request`, composing any reply into `response_buf`.
    ///
    /// Returns [`MdnsResponse::Reply`] with the reply bytes, borrowed from `response_buf`,
    /// or [`MdnsResponse::None`] if nothing should be sent.
    fn handle<'a>(
        &mut self,
        request: MdnsRequest<'_>,
        response_buf: &'a mut [u8],
    ) -> Result<MdnsResponse<'a>, MdnsError>;
}

/// A mutable reference to a handler is itself a handler.
impl<H> MdnsHandler for &mut H
where
    H: MdnsHandler,
{
    fn handle<'a>(
        &mut self,
        request: MdnsRequest<'_>,
        response_buf: &'a mut [u8],
    ) -> Result<MdnsResponse<'a>, MdnsError> {
        H::handle(&mut **self, request, response_buf)
    }
}
pub struct NoHandler;
impl NoHandler {
pub fn chain<T>(self, handler: T) -> ChainedHandler<T, Self> {
ChainedHandler::new(handler, self)
}
}
impl MdnsHandler for NoHandler {
fn handle<'a>(
&mut self,
_request: MdnsRequest<'_>,
_response_buf: &'a mut [u8],
) -> Result<MdnsResponse<'a>, MdnsError> {
Ok(MdnsResponse::None)
}
}
/// An [`MdnsHandler`] that tries `first` and falls back to `second`.
///
/// `second` only sees the request when `first` returns [`MdnsResponse::None`]. An error
/// from `first` is returned immediately.
pub struct ChainedHandler<T, U> {
    first: T,
    second: U,
}

impl<T, U> ChainedHandler<T, U> {
    /// Creates a chain that consults `first`, then `second`.
    pub const fn new(first: T, second: U) -> Self {
        Self { first, second }
    }

    /// Puts `handler` in front of this chain.
    pub fn chain<V>(self, handler: V) -> ChainedHandler<V, Self> {
        ChainedHandler::new(handler, self)
    }
}

impl<T, U> MdnsHandler for ChainedHandler<T, U>
where
    T: MdnsHandler,
    U: MdnsHandler,
{
    fn handle<'a>(
        &mut self,
        request: MdnsRequest<'_>,
        response_buf: &'a mut [u8],
    ) -> Result<MdnsResponse<'a>, MdnsError> {
        // Keep only the length and delay flag of `first`'s reply, so its borrow of
        // `response_buf` ends here. This assumes `first` wrote its reply at the start of
        // the buffer.
        let first_reply = match self.first.handle(request.clone(), response_buf)? {
            MdnsResponse::Reply { data, delay } => Some((data.len(), delay)),
            MdnsResponse::None => None,
        };

        if let Some((len, delay)) = first_reply {
            Ok(MdnsResponse::Reply {
                data: &response_buf[..len],
                delay,
            })
        } else {
            self.second.handle(request, response_buf)
        }
    }
}
/// A resource record describing the local host or one of its services.
///
/// The data is either a [`Txt`] built from key/value pairs, or any `domain` record data
/// whose names are [`NameSlice`]s.
pub type HostAnswer<'a> =
    Record<NameSlice<'a>, RecordDataChain<Txt<'a>, AllRecordData<&'a [u8], NameSlice<'a>>>>;

/// A source of records describing the local host and its services.
pub trait HostAnswers {
    /// Passes each record to `f`, propagating any error `f` returns.
    fn visit<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnMut(HostAnswer) -> Result<(), E>,
        E: From<MdnsError>;
}

/// A shared reference to an answer source is itself an answer source.
impl<A> HostAnswers for &A
where
    A: HostAnswers,
{
    fn visit<F, E>(&self, visitor: F) -> Result<(), E>
    where
        F: FnMut(HostAnswer) -> Result<(), E>,
        E: From<MdnsError>,
    {
        A::visit(*self, visitor)
    }
}

/// A mutable reference to an answer source is itself an answer source.
impl<A> HostAnswers for &mut A
where
    A: HostAnswers,
{
    fn visit<F, E>(&self, visitor: F) -> Result<(), E>
    where
        F: FnMut(HostAnswer) -> Result<(), E>,
        E: From<MdnsError>,
    {
        A::visit(&**self, visitor)
    }
}
/// A question about a host or service, as placed into an outgoing mDNS query.
pub type HostQuestion<'a> = Question<NameSlice<'a>>;

/// A source of questions for outgoing mDNS queries.
pub trait HostQuestions {
    /// Passes each question to `f`, propagating any error `f` returns.
    fn visit<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnMut(HostQuestion) -> Result<(), E>,
        E: From<MdnsError>;

    /// Composes a query message with ID `id`, containing every visited question, into
    /// `buf`.
    ///
    /// Returns the length of the composed message. Returns `0` if there were no
    /// questions, meaning there is nothing worth sending.
    ///
    /// # Errors
    ///
    /// Returns [`MdnsError::ShortBuf`] if the message does not fit into `buf`. Also
    /// returns any error produced while visiting the questions.
    fn query(&self, id: u16, buf: &mut [u8]) -> Result<usize, MdnsError> {
        let buf = Buf(buf, 0);
        let mut mb = MessageBuilder::from_target(buf)?;
        set_header(&mut mb, id, false);

        let mut qb = mb.question();
        let mut pushed = false;

        self.visit(|question| {
            qb.push(question)?;
            pushed = true;
            Ok::<_, MdnsError>(())
        })?;

        let buf = qb.finish();

        Ok(if pushed { buf.1 } else { 0 })
    }
}

impl<T> HostQuestions for &T
where
    T: HostQuestions,
{
    fn visit<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnMut(HostQuestion) -> Result<(), E>,
        E: From<MdnsError>,
    {
        (*self).visit(f)
    }

    // Forward `query` as well, so an implementor's override of it is not bypassed
    // when the questions are accessed through a reference.
    fn query(&self, id: u16, buf: &mut [u8]) -> Result<usize, MdnsError> {
        (*self).query(id, buf)
    }
}

impl<T> HostQuestions for &mut T
where
    T: HostQuestions,
{
    fn visit<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnMut(HostQuestion) -> Result<(), E>,
        E: From<MdnsError>,
    {
        (**self).visit(f)
    }

    // Forward `query` as well, so an implementor's override of it is not bypassed
    // when the questions are accessed through a reference.
    fn query(&self, id: u16, buf: &mut [u8]) -> Result<usize, MdnsError> {
        (**self).query(id, buf)
    }
}
pub struct NoHostQuestions;
impl NoHostQuestions {
pub fn chain<T>(self, questions: T) -> ChainedHostQuestions<T, Self> {
ChainedHostQuestions::new(questions, self)
}
}
impl HostQuestions for NoHostQuestions {
fn visit<F, E>(&self, _f: F) -> Result<(), E>
where
F: FnMut(HostQuestion) -> Result<(), E>,
{
Ok(())
}
}
/// A [`HostQuestions`] source that yields the questions of `first`, then those of
/// `second`.
pub struct ChainedHostQuestions<T, U> {
    first: T,
    second: U,
}

impl<T, U> ChainedHostQuestions<T, U> {
    /// Combines two question sources; `first` is visited before `second`.
    pub const fn new(first: T, second: U) -> Self {
        Self { first, second }
    }

    /// Puts `questions` in front of this chain.
    pub fn chain<V>(self, questions: V) -> ChainedHostQuestions<V, Self> {
        ChainedHostQuestions::new(questions, self)
    }
}

impl<T, U> HostQuestions for ChainedHostQuestions<T, U>
where
    T: HostQuestions,
    U: HostQuestions,
{
    /// Visits `first`, then `second`. An error from `first` stops the traversal.
    fn visit<F, E>(&self, mut visitor: F) -> Result<(), E>
    where
        F: FnMut(HostQuestion) -> Result<(), E>,
        E: From<MdnsError>,
    {
        self.first
            .visit(&mut visitor)
            .and_then(|()| self.second.visit(visitor))
    }
}
pub struct NoHostAnswers;
impl NoHostAnswers {
pub fn chain<T>(self, answers: T) -> ChainedHostAnswers<T, Self> {
ChainedHostAnswers::new(answers, self)
}
}
impl HostAnswers for NoHostAnswers {
fn visit<F, E>(&self, _f: F) -> Result<(), E>
where
F: FnMut(HostAnswer) -> Result<(), E>,
{
Ok(())
}
}
/// A [`HostAnswers`] source that yields the records of `first`, then those of `second`.
pub struct ChainedHostAnswers<T, U> {
    first: T,
    second: U,
}

impl<T, U> ChainedHostAnswers<T, U> {
    /// Combines two answer sources; `first` is visited before `second`.
    pub const fn new(first: T, second: U) -> Self {
        Self { first, second }
    }

    /// Puts `answers` in front of this chain.
    pub fn chain<V>(self, answers: V) -> ChainedHostAnswers<V, Self> {
        ChainedHostAnswers {
            first: answers,
            second: self,
        }
    }
}

impl<T, U> HostAnswers for ChainedHostAnswers<T, U>
where
    T: HostAnswers,
    U: HostAnswers,
{
    /// Visits `first`, then `second`. An error from `first` stops the traversal.
    fn visit<F, E>(&self, mut visitor: F) -> Result<(), E>
    where
        F: FnMut(HostAnswer) -> Result<(), E>,
        E: From<MdnsError>,
    {
        self.first
            .visit(&mut visitor)
            .and_then(|()| self.second.visit(visitor))
    }
}
/// An [`MdnsHandler`] that answers queries from the records of a [`HostAnswers`] source.
pub struct HostAnswersMdnsHandler<T> {
    answers: T,
}

impl<T> HostAnswersMdnsHandler<T> {
    /// Creates a handler that answers from `answers`.
    pub const fn new(answers: T) -> Self {
        Self { answers }
    }
}

impl<T> MdnsHandler for HostAnswersMdnsHandler<T>
where
    T: HostAnswers,
{
    /// Composes a response into `response_buf`.
    ///
    /// For [`MdnsRequest::Request`]:
    /// - Messages that are not plain queries are ignored: opcode other than QUERY, rcode
    ///   other than NOERROR, or the QR bit set.
    /// - Every record whose owner name equals a question's name goes into the answer
    ///   section. The question type is not checked.
    /// - Legacy requests get their query ID and question section echoed.
    /// - If an SRV record was answered, the host's A/AAAA records are added to the
    ///   additional section.
    /// - If a PTR record not owned by [`DNS_SD_OWNER`] was answered, the host's SRV/TXT
    ///   and A/AAAA records are added (cf. RFC 6763 §12).
    ///
    /// For [`MdnsRequest::None`], every record is placed in the answer section.
    ///
    /// Returns [`MdnsResponse::None`] when no record was placed in the answer section.
    ///
    /// # Errors
    ///
    /// - [`MdnsError::InvalidMessage`] if the request cannot be parsed.
    /// - [`MdnsError::ShortBuf`] if the header, questions or answers do not fit into
    ///   `response_buf`. Additional records that do not fit are skipped instead.
    /// - Errors from the [`HostAnswers`] source are propagated.
    fn handle<'a>(
        &mut self,
        request: MdnsRequest<'_>,
        response_buf: &'a mut [u8],
    ) -> Result<MdnsResponse<'a>, MdnsError> {
        let buf = Buf(response_buf, 0);
        let mut mb = MessageBuilder::from_target(buf)?;
        let mut pushed = false;

        let buf = if let MdnsRequest::Request { legacy, data, .. } = request {
            let message = Message::from_octets(data)?;

            if !matches!(message.header().opcode(), Opcode::QUERY)
                || !matches!(message.header().rcode(), Rcode::NOERROR)
                || message.header().qr()
            {
                return Ok(MdnsResponse::None);
            }

            let mut ab = if legacy {
                // Legacy replies echo the query ID and the question section.
                set_header(&mut mb, message.header().id(), true);
                let mut qb = mb.question();
                for question in message.question() {
                    qb.push(question?)?;
                }
                qb.answer()
            } else {
                set_header(&mut mb, 0, true);
                mb.answer()
            };

            // Which record kinds belong in the additional section. These flags are derived
            // only from records actually placed in the answer section.
            let mut additional_a = false;
            let mut additional_srv_txt = false;

            for question in message.question() {
                let question = question?;

                self.answers.visit(|answer| {
                    if question.qname().name_eq(&answer.owner()) {
                        let data = answer.data();
                        let is_srv =
                            matches!(data, RecordDataChain::Next(AllRecordData::Srv(_)));
                        let is_service_ptr =
                            matches!(data, RecordDataChain::Next(AllRecordData::Ptr(_)))
                                && !answer.owner().name_eq(&DNS_SD_OWNER);

                        debug!(
                            "Answering question [{}] with: [{}]",
                            debug2format!(question),
                            debug2format!(answer)
                        );

                        ab.push(answer)?;
                        pushed = true;

                        additional_a |= is_srv || is_service_ptr;
                        additional_srv_txt |= is_service_ptr;
                    }

                    Ok::<_, MdnsError>(())
                })?;
            }

            if additional_a || additional_srv_txt {
                let mut aa = ab.additional();

                self.answers.visit(|answer| {
                    let wanted = match answer.data() {
                        RecordDataChain::Next(AllRecordData::A(_))
                        | RecordDataChain::Next(AllRecordData::Aaaa(_)) => additional_a,
                        RecordDataChain::Next(AllRecordData::Srv(_))
                        | RecordDataChain::Next(AllRecordData::Txt(_))
                        | RecordDataChain::This(Txt(_)) => additional_srv_txt,
                        _ => false,
                    };

                    if wanted {
                        debug!("Additional answer: [{}]", debug2format!(answer));

                        // Additional records are optional. If one does not fit, skip it
                        // rather than failing the whole response. The message builder
                        // rolls back a failed push, so the message stays well-formed.
                        if aa.push(answer).is_err() {
                            debug!("Additional answer skipped: response buffer is full");
                        }
                    }

                    Ok::<_, MdnsError>(())
                })?;

                aa.finish()
            } else {
                ab.finish()
            }
        } else {
            set_header(&mut mb, 0, true);

            let mut ab = mb.answer();
            self.answers.visit(|answer| {
                ab.push(answer)?;
                pushed = true;
                Ok::<_, MdnsError>(())
            })?;

            ab.finish()
        };

        if pushed {
            Ok(MdnsResponse::Reply {
                data: &buf.0[..buf.1],
                delay: false,
            })
        } else {
            Ok(MdnsResponse::None)
        }
    }
}
/// A resource record parsed from a peer's mDNS response.
pub type PeerAnswer<'a> =
    Record<ParsedName<&'a [u8]>, AllRecordData<&'a [u8], ParsedName<&'a [u8]>>>;

/// A consumer of the records found in mDNS responses from peers.
pub trait PeerAnswers {
    /// Receives the answer-section and additional-section records of one response.
    fn answers<'a, T, A>(&self, answers: T, additional: A) -> Result<(), MdnsError>
    where
        T: IntoIterator<Item = Result<PeerAnswer<'a>, MdnsError>> + Clone + 'a,
        A: IntoIterator<Item = Result<PeerAnswer<'a>, MdnsError>> + Clone + 'a;
}

/// A shared reference to a consumer is itself a consumer.
impl<P> PeerAnswers for &P
where
    P: PeerAnswers,
{
    fn answers<'a, I, J>(&self, answers: I, additional: J) -> Result<(), MdnsError>
    where
        I: IntoIterator<Item = Result<PeerAnswer<'a>, MdnsError>> + Clone + 'a,
        J: IntoIterator<Item = Result<PeerAnswer<'a>, MdnsError>> + Clone + 'a,
    {
        P::answers(*self, answers, additional)
    }
}

/// A mutable reference to a consumer is itself a consumer.
impl<P> PeerAnswers for &mut P
where
    P: PeerAnswers,
{
    fn answers<'a, I, J>(&self, answers: I, additional: J) -> Result<(), MdnsError>
    where
        I: IntoIterator<Item = Result<PeerAnswer<'a>, MdnsError>> + Clone + 'a,
        J: IntoIterator<Item = Result<PeerAnswer<'a>, MdnsError>> + Clone + 'a,
    {
        P::answers(&**self, answers, additional)
    }
}
/// An [`MdnsHandler`] that passes the records of peers' mDNS responses to a
/// [`PeerAnswers`] consumer. It never replies.
pub struct PeerAnswersMdnsHandler<T> {
    answers: T,
}

impl<T> PeerAnswersMdnsHandler<T> {
    /// Creates a handler that reports peer records to `answers`.
    pub const fn new(answers: T) -> Self {
        Self { answers }
    }
}

impl<T> MdnsHandler for PeerAnswersMdnsHandler<T>
where
    T: PeerAnswers,
{
    /// Feeds the answer and additional records of a non-legacy response to the consumer.
    ///
    /// Anything else is ignored: no message, legacy messages, or messages that are not
    /// NOERROR QUERY responses. Records that fail to parse reach the consumer as
    /// [`MdnsError::InvalidMessage`] items. The result is always [`MdnsResponse::None`].
    fn handle<'a>(
        &mut self,
        request: MdnsRequest<'_>,
        _response_buf: &'a mut [u8],
    ) -> Result<MdnsResponse<'a>, MdnsError> {
        let MdnsRequest::Request {
            data,
            legacy: false,
            ..
        } = request
        else {
            return Ok(MdnsResponse::None);
        };

        let message = Message::from_octets(data)?;
        let header = message.header();

        let is_response = matches!(header.opcode(), Opcode::QUERY)
            && matches!(header.rcode(), Rcode::NOERROR)
            && header.qr();
        if !is_response {
            return Ok(MdnsResponse::None);
        }

        let answers = message.answer()?.filter_map(|record| {
            record
                .and_then(|record| record.into_record::<AllRecordData<_, _>>())
                .map_err(|_| MdnsError::InvalidMessage)
                .transpose()
        });
        let additional = message.additional()?.filter_map(|record| {
            record
                .and_then(|record| record.into_record::<AllRecordData<_, _>>())
                .map_err(|_| MdnsError::InvalidMessage)
                .transpose()
        });

        self.answers.answers(answers, additional)?;

        Ok(MdnsResponse::None)
    }
}
/// Fills in the header of an mDNS message under construction.
///
/// Sets the message ID to `id`, the opcode to QUERY and the rcode to NOERROR. The flags
/// are reset to `Flags::new()`, except that QR and AA are both set to `response`.
pub fn set_header<T: Composer>(answer: &mut MessageBuilder<T>, id: u16, response: bool) {
    let mut flags = Flags::new();
    flags.qr = response;
    flags.aa = response;

    let header = answer.header_mut();
    header.set_id(id);
    header.set_opcode(Opcode::QUERY);
    header.set_rcode(Rcode::NOERROR);
    header.set_flags(flags);
}