use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::{Duration, Instant};
use bytes::BytesMut;
use log::{trace, warn};
use shared::{TaggedBytesMut, TransportContext, TransportMessage, TransportProtocol};
use crate::config::{DEFAULT_QUERY_INTERVAL, MAX_MESSAGE_RECORDS, MdnsConfig, RESPONSE_TTL};
use crate::message::header::Header;
use crate::message::name::Name;
use crate::message::parser::Parser;
use crate::message::question::Question;
use crate::message::resource::a::AResource;
use crate::message::resource::{Resource, ResourceHeader};
use crate::message::{DNSCLASS_INET, DnsType, Message};
use shared::error::{Error, Result};
/// IPv4 multicast group address (224.0.0.251) that mDNS packets are sent to.
pub const MDNS_MULTICAST_IPV4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// Port number (5353) used for mDNS traffic, both as local and as destination port.
pub const MDNS_PORT: u16 = 5353;
/// Destination socket address of every outgoing mDNS packet:
/// [`MDNS_MULTICAST_IPV4`] combined with [`MDNS_PORT`].
pub const MDNS_DEST_ADDR: SocketAddr = SocketAddr::new(IpAddr::V4(MDNS_MULTICAST_IPV4), MDNS_PORT);
/// Identifier of a pending query, as returned by [`Mdns::query`].
pub type QueryId = u64;
/// Bookkeeping for one outstanding mDNS query.
#[derive(Debug, Clone)]
pub struct Query {
/// Identifier handed back to the caller by [`Mdns::query`].
pub id: QueryId,
/// Queried host name, normalized so that it always ends with a `.`.
pub name_with_suffix: String,
/// Instant at which the query was created; the query timeout is measured from here.
pub start_time: Instant,
/// Instant at (or after) which the question is sent again.
pub next_retry: Instant,
}
/// Events produced by [`Mdns`] and drained through `poll_event`.
#[derive(Debug)]
pub enum MdnsEvent {
/// A pending query was answered. Carries the query id and the resolved
/// address: the address from the A record body when present, otherwise the
/// source IP of the packet that carried the answer.
QueryAnswered(QueryId, IpAddr),
/// A pending query was dropped because it stayed unanswered for at least
/// the configured query timeout.
QueryTimeout(QueryId),
}
/// Sans-io mDNS state machine: sends A-record questions for names requested
/// via [`Mdns::query`], answers questions for its configured local names, and
/// queues outgoing packets and events for the caller to drain.
pub struct Mdns {
// Configuration passed to `new`; `local_ip` is read when answering questions.
config: MdnsConfig,
// Names this instance answers for, each normalized to end with a `.`.
local_names: Vec<String>,
// Queries that have been neither answered, cancelled nor timed out.
queries: Vec<Query>,
// Id assigned to the next query; starts at 1 and is incremented per query.
next_query_id: QueryId,
// Delay between re-sends of an unanswered question
// (`DEFAULT_QUERY_INTERVAL` when the configured value is zero).
query_interval: Duration,
// Maximum lifetime of a query; `None` disables timeout processing.
query_timeout: Option<Duration>,
// Outgoing packets waiting to be drained by `poll_write`.
write_outs: VecDeque<TaggedBytesMut>,
// Events waiting to be drained by `poll_event`.
event_outs: VecDeque<MdnsEvent>,
// Next deadline as computed by `update_next_timeout`; `None` when no query is pending.
next_timeout: Option<Instant>,
// Set by `close`; once set, `handle_read` and `handle_timeout` return an error.
closed: bool,
}
impl Mdns {
pub fn new(config: MdnsConfig) -> Self {
let local_names = config
.local_names
.iter()
.map(|name| {
if name.ends_with('.') {
name.clone()
} else {
format!("{name}.")
}
})
.collect();
let query_interval = if config.query_interval == Duration::ZERO {
DEFAULT_QUERY_INTERVAL
} else {
config.query_interval
};
let query_timeout = config.query_timeout;
Self {
config,
local_names,
queries: Vec::new(),
next_query_id: 1,
query_interval,
query_timeout,
write_outs: VecDeque::new(),
event_outs: VecDeque::new(),
next_timeout: None,
closed: false,
}
}
/// Starts resolving `name` and returns the id of the new pending query.
///
/// The name is normalized to end with a `.`. The query is registered with
/// its first retry scheduled one `query_interval` from now, the question is
/// queued for sending immediately, and the next timeout is recomputed.
pub fn query(&mut self, name: &str) -> QueryId {
    // Build the fully-qualified form: the name itself plus a trailing dot if absent.
    let mut fqdn = String::with_capacity(name.len() + 1);
    fqdn.push_str(name);
    if !name.ends_with('.') {
        fqdn.push('.');
    }

    // Reserve the id for this query and advance the counter.
    let id = self.next_query_id;
    self.next_query_id += 1;

    let now = Instant::now();
    let next_retry = now + self.query_interval;
    self.queries.push(Query {
        id,
        name_with_suffix: fqdn.clone(),
        start_time: now,
        next_retry,
    });

    self.send_question(&fqdn, now);
    self.update_next_timeout();
    id
}
/// Drops the pending query identified by `query_id` (a no-op when no such
/// query exists) and recomputes the next timeout. No event is emitted.
pub fn cancel_query(&mut self, query_id: QueryId) {
    self.queries.retain(|pending| pending.id != query_id);
    self.update_next_timeout();
}
/// Returns `true` while the query identified by `query_id` is still in the
/// pending list.
pub fn is_query_pending(&self, query_id: QueryId) -> bool {
    self.queries
        .iter()
        .map(|pending| pending.id)
        .any(|id| id == query_id)
}
/// Returns the number of queries currently in the pending list.
pub fn pending_query_count(&self) -> usize {
self.queries.len()
}
fn send_question(&mut self, name: &str, now: Instant) {
let packed_name = match Name::new(name) {
Ok(pn) => pn,
Err(err) => {
log::warn!("Failed to construct mDNS packet: {err}");
return;
}
};
let raw_query = {
let mut msg = Message {
header: Header::default(),
questions: vec![Question {
typ: DnsType::A,
class: DNSCLASS_INET,
name: packed_name,
}],
..Default::default()
};
match msg.pack() {
Ok(v) => v,
Err(err) => {
log::error!("Failed to construct mDNS packet {err}");
return;
}
}
};
log::trace!("Queuing mDNS query for {name}");
self.write_outs.push_back(TransportMessage {
now,
transport: TransportContext {
local_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), MDNS_PORT),
peer_addr: MDNS_DEST_ADDR,
transport_protocol: TransportProtocol::UDP,
ecn: None,
},
message: BytesMut::from(&raw_query[..]),
});
}
/// Queues an authoritative response carrying one A record that maps `name`
/// to `local_ip`, addressed to the mDNS multicast destination and stamped
/// with `now`.
///
/// Nothing is queued (a log line is emitted instead) when the name cannot be
/// encoded, when `local_ip` is an IPv6 address, or when packing fails.
fn send_answer(&mut self, local_ip: IpAddr, name: &str, now: Instant) {
    let packed_name = match Name::new(name) {
        Ok(packed) => packed,
        Err(err) => {
            log::warn!("Failed to pack name for answer: {err}");
            return;
        }
    };

    // An A record can only hold four octets, so IPv6 addresses are rejected here.
    let IpAddr::V4(v4) = local_ip else {
        log::warn!("Cannot send IPv6 address in A record");
        return;
    };

    let record = Resource {
        header: ResourceHeader {
            typ: DnsType::A,
            class: DNSCLASS_INET,
            name: packed_name,
            ttl: RESPONSE_TTL,
            ..Default::default()
        },
        body: Some(Box::new(AResource { a: v4.octets() })),
    };
    let mut msg = Message {
        header: Header {
            response: true,
            authoritative: true,
            ..Default::default()
        },
        answers: vec![record],
        ..Default::default()
    };
    let raw_answer = match msg.pack() {
        Ok(bytes) => bytes,
        Err(err) => {
            log::error!("Failed to pack answer: {err}");
            return;
        }
    };

    log::trace!("mDNS Queuing answer for {name} -> {local_ip}");

    // Answers leave from the advertised local IP on the mDNS port.
    let transport = TransportContext {
        local_addr: SocketAddr::new(local_ip, MDNS_PORT),
        peer_addr: MDNS_DEST_ADDR,
        transport_protocol: TransportProtocol::UDP,
        ecn: None,
    };
    self.write_outs.push_back(TransportMessage {
        now,
        transport,
        message: BytesMut::from(&raw_answer[..]),
    });
}
/// Parses one received packet and processes its question section followed
/// by its answer section. A packet whose header cannot be parsed is logged
/// and ignored.
fn process_message(&mut self, msg: &TaggedBytesMut) {
    let mut parser = Parser::default();
    if let Some(err) = parser.start(&msg.message).err() {
        log::error!("Failed to parse mDNS packet: {err}");
        return;
    }

    // The sender's address is what answers are checked against.
    let sender = msg.transport.peer_addr;
    let received_at = msg.now;
    self.process_questions(&mut parser, sender, received_at);
    self.process_answers(&mut parser, sender);
}
fn process_questions(&mut self, parser: &mut Parser<'_>, _src: SocketAddr, now: Instant) {
let mut names_to_answer: Vec<String> = Vec::new();
for _ in 0..=MAX_MESSAGE_RECORDS {
let q = match parser.question() {
Ok(q) => q,
Err(err) => {
if err == Error::ErrSectionDone {
break;
}
log::error!("Failed to parse question: {err}");
return;
}
};
for local_name in &self.local_names {
if *local_name == q.name.data {
names_to_answer.push(q.name.data.clone());
break;
}
}
}
let _ = parser.skip_all_questions();
if let Some(local_ip) = self.config.local_ip {
for name in names_to_answer {
log::trace!(
"mDNS Found question for local name: {}, responding with {}",
name,
local_ip
);
self.send_answer(local_ip, &name, now);
}
} else if !names_to_answer.is_empty() {
log::warn!("Received questions for local names but no local_addr configured");
}
}
/// Reads the answer section of a received packet and resolves every pending
/// query whose name matches an A/AAAA answer.
///
/// For each matching query a [`MdnsEvent::QueryAnswered`] event is queued and
/// the query is removed from the pending list. The reported address is the
/// one in the A record body; when the body is not an `AResource` the source
/// IP of the packet (`src`) is used instead.
///
/// Records of any other type are skipped explicitly. Reading only the header
/// does not consume a record (the body must be read or skipped afterwards),
/// so without the skip the parser would stay on the same record and every
/// later answer in the packet would be lost.
///
/// Query and answer names are compared ASCII case-insensitively, as DNS/mDNS
/// name comparison requires (RFC 6762 §16).
///
/// Processing stops at the first parse error or at the end of the section;
/// at most `MAX_MESSAGE_RECORDS + 1` records are inspected.
fn process_answers(&mut self, parser: &mut Parser<'_>, src: SocketAddr) {
    for _ in 0..=MAX_MESSAGE_RECORDS {
        let answer_header = match parser.answer_header() {
            Ok(a) => a,
            Err(err) => {
                if err != Error::ErrSectionDone {
                    log::warn!("Failed to parse answer header: {err}");
                }
                return;
            }
        };
        if answer_header.typ != DnsType::A && answer_header.typ != DnsType::Aaaa {
            // Consume the uninteresting record so the next iteration sees
            // the following one instead of this same header again.
            // NOTE(review): relies on `Parser::skip_answer`, the
            // answer-section counterpart of `skip_all_questions` — confirm
            // it is exposed by this crate's parser.
            if let Err(err) = parser.skip_answer() {
                log::warn!("Failed to skip non-address answer: {err}");
                return;
            }
            continue;
        }
        let answer_resource = match parser.answer() {
            Ok(a) => a,
            Err(err) => {
                if err != Error::ErrSectionDone {
                    log::warn!("Failed to parse answer: {err}");
                }
                return;
            }
        };
        let local_ip = if let Some(body) = answer_resource.body
            && let Some(a) = body.as_any().downcast_ref::<AResource>()
        {
            let local_ip = Ipv4Addr::from_octets(a.a).into();
            if local_ip != src.ip() {
                warn!(
                    "mDNS answers with different local ip on AResource {} vs src ip {} on Socket for query {}",
                    local_ip,
                    src.ip(),
                    answer_header.name.data
                );
            } else {
                trace!(
                    "mDNS answers with the local ip {} on AResource and Socket for query {}",
                    local_ip, answer_header.name.data
                );
            }
            local_ip
        } else {
            warn!(
                "mDNS answers without AResource, fallback to use src ip {} on Socket for local ip for query {}",
                src.ip(),
                answer_header.name.data
            );
            src.ip()
        };
        // Collect the matches first: the pending list cannot be modified
        // while it is being iterated.
        let mut matched_query_ids = HashMap::new();
        for query in &self.queries {
            if query
                .name_with_suffix
                .eq_ignore_ascii_case(&answer_header.name.data)
            {
                matched_query_ids.insert(query.id, local_ip);
            }
        }
        for (query_id, local_ip) in matched_query_ids {
            self.event_outs
                .push_back(MdnsEvent::QueryAnswered(query_id, local_ip));
            self.queries.retain(|q| q.id != query_id);
        }
    }
}
/// Recomputes the next instant at which `handle_timeout` has work to do.
///
/// For every pending query this is the earlier of its next retry and, when a
/// query timeout is configured, its expiry (`start_time + query_timeout`).
/// Including the expiry lets a `QueryTimeout` fire when the timeout elapses
/// instead of being noticed only on the next retry tick, which could be up
/// to one `query_interval` late. The result is `None` when no query is
/// pending.
fn update_next_timeout(&mut self) {
    let query_timeout = self.query_timeout;
    self.next_timeout = self
        .queries
        .iter()
        .map(|q| {
            // `checked_add` guards against a timeout so large that the
            // expiry is not representable; such a query never expires.
            let expiry = query_timeout.and_then(|timeout| q.start_time.checked_add(timeout));
            match expiry {
                Some(expiry) => expiry.min(q.next_retry),
                None => q.next_retry,
            }
        })
        .min();
}
}
impl sansio::Protocol<TaggedBytesMut, (), ()> for Mdns {
// Nothing is surfaced on the read side: `poll_read` always returns `None`.
type Rout = ();
// Outgoing packets queued by `send_question` / `send_answer`.
type Wout = TaggedBytesMut;
// Query results and timeouts.
type Eout = MdnsEvent;
type Error = Error;
// Timestamps are monotonic `Instant`s.
type Time = Instant;
fn handle_read(&mut self, msg: TaggedBytesMut) -> Result<()> {
if self.closed {
return Err(Error::ErrConnectionClosed);
}
self.process_message(&msg);
self.update_next_timeout();
Ok(())
}
/// Always returns `None`: received packets never produce read-side output;
/// results are reported through `poll_event` instead.
fn poll_read(&mut self) -> Option<Self::Rout> {
None
}
/// No-op: there is no write-side input; always succeeds.
fn handle_write(&mut self, _msg: ()) -> Result<()> {
Ok(())
}
/// Removes and returns the oldest queued outgoing packet, or `None` when
/// the queue is empty.
fn poll_write(&mut self) -> Option<Self::Wout> {
self.write_outs.pop_front()
}
/// No-op: there are no incoming events; always succeeds.
fn handle_event(&mut self, _evt: ()) -> Result<()> {
Ok(())
}
/// Removes and returns the oldest queued [`MdnsEvent`], or `None` when the
/// queue is empty.
fn poll_event(&mut self) -> Option<Self::Eout> {
self.event_outs.pop_front()
}
/// Advances timers to `now`.
///
/// Does nothing unless the stored next timeout has been reached. Otherwise:
/// 1. when a query timeout is configured, every query pending for at least
///    that long is removed and a [`MdnsEvent::QueryTimeout`] is queued for
///    it, in pending-list order;
/// 2. every remaining query whose retry time has passed gets its question
///    re-sent and its next retry scheduled one `query_interval` after `now`;
/// 3. the next timeout is recomputed.
///
/// # Errors
/// Returns `Error::ErrConnectionClosed` once `close` has been called.
fn handle_timeout(&mut self, now: Self::Time) -> Result<()> {
    if self.closed {
        return Err(Error::ErrConnectionClosed);
    }

    // Nothing is due yet (or nothing is scheduled at all).
    let due = matches!(self.next_timeout, Some(deadline) if deadline <= now);
    if !due {
        return Ok(());
    }

    if let Some(timeout_duration) = self.query_timeout {
        // Drop expired queries in a single pass; `retain` visits elements in
        // order, so events are queued in pending-list order.
        let event_outs = &mut self.event_outs;
        self.queries.retain(|query| {
            let expired = now.duration_since(query.start_time) >= timeout_duration;
            if expired {
                log::debug!(
                    "mDNS Query {} timed out after {:?}",
                    query.id,
                    timeout_duration
                );
                event_outs.push_back(MdnsEvent::QueryTimeout(query.id));
            }
            !expired
        });
    }

    // Reschedule every due query, remembering its name so the question can
    // be re-sent once the mutable borrow of the pending list has ended.
    let interval = self.query_interval;
    let names_to_query: Vec<String> = self
        .queries
        .iter_mut()
        .filter(|query| query.next_retry <= now)
        .map(|query| {
            query.next_retry = now + interval;
            query.name_with_suffix.clone()
        })
        .collect();
    for name in &names_to_query {
        self.send_question(name, now);
    }

    self.update_next_timeout();
    Ok(())
}
/// Returns the stored next deadline at which `handle_timeout` should be
/// called, or `None` when nothing is scheduled.
fn poll_timeout(&mut self) -> Option<Self::Time> {
self.next_timeout
}
/// Shuts the state machine down: discards all pending queries, queued
/// packets and queued events, clears the next timeout and marks the
/// instance closed so that later `handle_read` / `handle_timeout` calls
/// fail. Always succeeds.
fn close(&mut self) -> Result<()> {
    // Throw away everything that is still queued or scheduled.
    self.event_outs.clear();
    self.write_outs.clear();
    self.queries.clear();
    self.next_timeout = None;

    self.closed = true;
    Ok(())
}
}
#[cfg(test)]
mod mdns_test;