#[cfg(feature = "logging")]
use crate::log::{debug, trace};
use crate::{
dns_cache::{current_time_millis, DnsCache},
dns_parser::{
ip_address_rr_type, DnsAddress, DnsEntryExt, DnsIncoming, DnsOutgoing, DnsPointer,
DnsRecordBox, DnsRecordExt, DnsSrv, DnsTxt, InterfaceId, RRType, ScopedIp,
CLASS_CACHE_FLUSH, CLASS_IN, FLAGS_AA, FLAGS_QR_QUERY, FLAGS_QR_RESPONSE, MAX_MSG_ABSOLUTE,
},
error::{e_fmt, Error, Result},
service_info::{DnsRegistry, MyIntf, Probe, ServiceInfo, ServiceStatus},
Receiver, ResolvedService, TxtProperties,
};
use flume::{bounded, Sender, TrySendError};
use if_addrs::{IfAddr, Interface};
use mio::{event::Source, net::UdpSocket as MioUdpSocket, Interest, Poll, Registry, Token};
use socket2::Domain;
use socket_pktinfo::PktInfoUdpSocket;
use std::{
cmp::{self, Reverse},
collections::{hash_map::Entry, BinaryHeap, HashMap, HashSet},
fmt, io,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket},
str, thread,
time::Duration,
vec,
};
/// Default upper bound on service name length (see `ServiceDaemon::set_service_name_len_max`).
pub const SERVICE_NAME_LEN_MAX_DEFAULT: u8 = 15;
/// Default interval, in seconds, between scans for IP / interface changes.
pub const IP_CHECK_INTERVAL_IN_SECS_DEFAULT: u32 = 5;
/// Default verification timeout; presumably intended for `ServiceDaemon::verify` (not referenced in this view).
pub const VERIFY_TIMEOUT_DEFAULT: Duration = Duration::from_millis(10_000);

/// Well-known mDNS UDP port.
const MDNS_PORT: u16 = 5353;
/// IPv4 mDNS multicast group, 224.0.0.251.
const GROUP_ADDR_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// IPv6 link-local mDNS multicast group, ff02::fb.
const GROUP_ADDR_V6: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0xfb);
/// Loopback address (127.0.0.1) used for the daemon's internal signal socket.
const LOOPBACK_V4: Ipv4Addr = Ipv4Addr::LOCALHOST;
/// A 500 ms wait; not referenced in this view, presumably used when resolving instances.
const RESOLVE_WAIT_IN_MILLIS: u64 = 500;
/// Outcome of `ServiceDaemon::unregister`, delivered on the receiver it returns.
#[derive(Debug)]
pub enum UnregisterStatus {
    /// The request succeeded (the unregister handler itself is not visible in this view).
    OK,
    /// Presumably: no registered service matched the (lower-cased) full name — TODO confirm.
    NotFound,
}
/// Lifecycle state of the daemon, as reported by `ServiceDaemon::status` / `shutdown`.
#[derive(Debug, PartialEq, Clone, Eq)]
#[non_exhaustive]
pub enum DaemonStatus {
    /// The daemon thread is processing commands.
    Running,
    /// The daemon received `Exit`, or its command channel is disconnected.
    Shutdown,
}
/// Internal metric counters maintained by the daemon thread.
///
/// The `Display` form is a stable kebab-case label; presumably it is the key recorded in
/// [`Metrics`] by `increase_counter` (not visible in this view).
#[derive(Hash, Eq, PartialEq)]
enum Counter {
    Register,
    RegisterResend,
    Unregister,
    UnregisterResend,
    Browse,
    ResolveHostname,
    Respond,
    CacheRefreshPTR,
    CacheRefreshSrvTxt,
    CacheRefreshAddr,
    KnownAnswerSuppression,
    CachedPTR,
    CachedSRV,
    CachedAddr,
    CachedTxt,
    CachedNSec,
    CachedSubtype,
    DnsRegistryProbe,
    DnsRegistryActive,
    DnsRegistryTimer,
    DnsRegistryNameChange,
    Timer,
}

impl fmt::Display for Counter {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Map each variant to its fixed label, then emit it verbatim (no padding applied).
        let label = match self {
            Self::Register => "register",
            Self::RegisterResend => "register-resend",
            Self::Unregister => "unregister",
            Self::UnregisterResend => "unregister-resend",
            Self::Browse => "browse",
            Self::ResolveHostname => "resolve-hostname",
            Self::Respond => "respond",
            Self::CacheRefreshPTR => "cache-refresh-ptr",
            Self::CacheRefreshSrvTxt => "cache-refresh-srv-txt",
            Self::CacheRefreshAddr => "cache-refresh-addr",
            Self::KnownAnswerSuppression => "known-answer-suppression",
            Self::CachedPTR => "cached-ptr",
            Self::CachedSRV => "cached-srv",
            Self::CachedAddr => "cached-addr",
            Self::CachedTxt => "cached-txt",
            Self::CachedNSec => "cached-nsec",
            Self::CachedSubtype => "cached-subtype",
            Self::DnsRegistryProbe => "dns-registry-probe",
            Self::DnsRegistryActive => "dns-registry-active",
            Self::DnsRegistryTimer => "dns-registry-timer",
            Self::DnsRegistryNameChange => "dns-registry-name-change",
            Self::Timer => "timer",
        };
        f.write_str(label)
    }
}
/// One UDP socket seen through two handles.
///
/// `pktinfo` is used for socket options, multicast membership and sending. `mio` is a
/// clone of the same underlying socket, created via `try_clone_std`, and is registered
/// with the `mio::Poll` for readiness.
struct MyUdpSocket {
    /// Packet-info-capable handle (options, multicast, send).
    pktinfo: PktInfoUdpSocket,
    /// `mio` handle of the same socket, used only for poll registration.
    mio: MioUdpSocket,
}
impl MyUdpSocket {
    /// Wraps `pktinfo` and derives a `mio` handle from a cloned std socket.
    ///
    /// # Errors
    /// Returns the I/O error from `try_clone_std` if the socket cannot be duplicated.
    pub fn new(pktinfo: PktInfoUdpSocket) -> io::Result<Self> {
        let cloned = pktinfo.try_clone_std()?;
        Ok(Self {
            mio: MioUdpSocket::from_std(cloned),
            pktinfo,
        })
    }
}
impl Source for MyUdpSocket {
fn register(
&mut self,
registry: &Registry,
token: Token,
interests: Interest,
) -> io::Result<()> {
self.mio.register(registry, token, interests)
}
fn reregister(
&mut self,
registry: &Registry,
token: Token,
interests: Interest,
) -> io::Result<()> {
self.mio.reregister(registry, token, interests)
}
fn deregister(&mut self, registry: &Registry) -> std::io::Result<()> {
self.mio.deregister(registry)
}
}
/// Metric name → value, as returned by `ServiceDaemon::get_metrics`.
pub type Metrics = HashMap<String, i64>;

// `mio` tokens identifying which socket produced a poll event.
/// Token of the shared IPv4 mDNS socket.
const IPV4_SOCK_EVENT_KEY: usize = 4;
/// Token of the shared IPv6 mDNS socket.
const IPV6_SOCK_EVENT_KEY: usize = 6;
/// Token of the internal loopback signal socket.
const SIGNAL_SOCK_EVENT_KEY: usize = usize::MAX - 1;
/// Client handle to the mDNS daemon thread.
///
/// All public operations are turned into `Command`s that are sent over a bounded channel.
/// After each command, a datagram is sent to the daemon's loopback signal socket so its
/// poll loop wakes up. Clones share the same channel and signal address.
#[derive(Clone)]
pub struct ServiceDaemon {
    /// Command channel into the daemon thread.
    sender: Sender<Command>,
    /// Loopback address of the daemon's signal socket (ephemeral port chosen at startup).
    signal_addr: SocketAddr,
}
impl ServiceDaemon {
pub fn new() -> Result<Self> {
let signal_addr = SocketAddrV4::new(LOOPBACK_V4, 0);
let signal_sock = UdpSocket::bind(signal_addr)
.map_err(|e| e_fmt!("failed to create signal_sock for daemon: {}", e))?;
let signal_addr = signal_sock
.local_addr()
.map_err(|e| e_fmt!("failed to get signal sock addr: {}", e))?;
signal_sock
.set_nonblocking(true)
.map_err(|e| e_fmt!("failed to set nonblocking for signal socket: {}", e))?;
let poller = Poll::new().map_err(|e| e_fmt!("failed to create mio Poll: {e}"))?;
let (sender, receiver) = bounded(100);
let mio_sock = MioUdpSocket::from_std(signal_sock);
thread::Builder::new()
.name("mDNS_daemon".to_string())
.spawn(move || Self::daemon_thread(mio_sock, poller, receiver))
.map_err(|e| e_fmt!("thread builder failed to spawn: {}", e))?;
Ok(Self {
sender,
signal_addr,
})
}
fn send_cmd(&self, cmd: Command) -> Result<()> {
let cmd_name = cmd.to_string();
self.sender.try_send(cmd).map_err(|e| match e {
TrySendError::Full(_) => Error::Again,
e => e_fmt!("flume::channel::send failed: {}", e),
})?;
let addr = SocketAddrV4::new(LOOPBACK_V4, 0);
let socket = UdpSocket::bind(addr)
.map_err(|e| e_fmt!("Failed to create socket to send signal: {}", e))?;
socket
.send_to(cmd_name.as_bytes(), self.signal_addr)
.map_err(|e| {
e_fmt!(
"signal socket send_to {} ({}) failed: {}",
self.signal_addr,
cmd_name,
e
)
})?;
Ok(())
}
pub fn browse(&self, service_type: &str) -> Result<Receiver<ServiceEvent>> {
check_domain_suffix(service_type)?;
let (resp_s, resp_r) = bounded(10);
self.send_cmd(Command::Browse(service_type.to_string(), 1, false, resp_s))?;
Ok(resp_r)
}
pub fn browse_cache(&self, service_type: &str) -> Result<Receiver<ServiceEvent>> {
check_domain_suffix(service_type)?;
let (resp_s, resp_r) = bounded(10);
self.send_cmd(Command::Browse(service_type.to_string(), 1, true, resp_s))?;
Ok(resp_r)
}
pub fn stop_browse(&self, ty_domain: &str) -> Result<()> {
self.send_cmd(Command::StopBrowse(ty_domain.to_string()))
}
pub fn resolve_hostname(
&self,
hostname: &str,
timeout: Option<u64>,
) -> Result<Receiver<HostnameResolutionEvent>> {
check_hostname(hostname)?;
let (resp_s, resp_r) = bounded(10);
self.send_cmd(Command::ResolveHostname(
hostname.to_string(),
1,
resp_s,
timeout,
))?;
Ok(resp_r)
}
pub fn stop_resolve_hostname(&self, hostname: &str) -> Result<()> {
self.send_cmd(Command::StopResolveHostname(hostname.to_string()))
}
pub fn register(&self, service_info: ServiceInfo) -> Result<()> {
check_service_name(service_info.get_fullname())?;
check_hostname(service_info.get_hostname())?;
self.send_cmd(Command::Register(service_info))
}
pub fn unregister(&self, fullname: &str) -> Result<Receiver<UnregisterStatus>> {
let (resp_s, resp_r) = bounded(1);
self.send_cmd(Command::Unregister(fullname.to_lowercase(), resp_s))?;
Ok(resp_r)
}
pub fn monitor(&self) -> Result<Receiver<DaemonEvent>> {
let (resp_s, resp_r) = bounded(100);
self.send_cmd(Command::Monitor(resp_s))?;
Ok(resp_r)
}
pub fn shutdown(&self) -> Result<Receiver<DaemonStatus>> {
let (resp_s, resp_r) = bounded(1);
self.send_cmd(Command::Exit(resp_s))?;
Ok(resp_r)
}
pub fn status(&self) -> Result<Receiver<DaemonStatus>> {
let (resp_s, resp_r) = bounded(1);
if self.sender.is_disconnected() {
resp_s
.send(DaemonStatus::Shutdown)
.map_err(|e| e_fmt!("failed to send daemon status to the client: {}", e))?;
} else {
self.send_cmd(Command::GetStatus(resp_s))?;
}
Ok(resp_r)
}
pub fn get_metrics(&self) -> Result<Receiver<Metrics>> {
let (resp_s, resp_r) = bounded(1);
self.send_cmd(Command::GetMetrics(resp_s))?;
Ok(resp_r)
}
pub fn set_service_name_len_max(&self, len_max: u8) -> Result<()> {
const SERVICE_NAME_LEN_MAX_LIMIT: u8 = 30;
if len_max > SERVICE_NAME_LEN_MAX_LIMIT {
return Err(Error::Msg(format!(
"service name length max {len_max} is too large"
)));
}
self.send_cmd(Command::SetOption(DaemonOption::ServiceNameLenMax(len_max)))
}
pub fn set_ip_check_interval(&self, interval_in_secs: u32) -> Result<()> {
let interval_in_millis = interval_in_secs as u64 * 1000;
self.send_cmd(Command::SetOption(DaemonOption::IpCheckInterval(
interval_in_millis,
)))
}
pub fn get_ip_check_interval(&self) -> Result<u32> {
let (resp_s, resp_r) = bounded(1);
self.send_cmd(Command::GetOption(resp_s))?;
let option = resp_r
.recv_timeout(Duration::from_secs(10))
.map_err(|e| e_fmt!("failed to receive ip check interval: {}", e))?;
let ip_check_interval_in_secs = option.ip_check_interval / 1000;
Ok(ip_check_interval_in_secs as u32)
}
pub fn enable_interface(&self, if_kind: impl IntoIfKindVec) -> Result<()> {
let if_kind_vec = if_kind.into_vec();
self.send_cmd(Command::SetOption(DaemonOption::EnableInterface(
if_kind_vec.kinds,
)))
}
pub fn disable_interface(&self, if_kind: impl IntoIfKindVec) -> Result<()> {
let if_kind_vec = if_kind.into_vec();
self.send_cmd(Command::SetOption(DaemonOption::DisableInterface(
if_kind_vec.kinds,
)))
}
pub fn accept_unsolicited(&self, accept: bool) -> Result<()> {
self.send_cmd(Command::SetOption(DaemonOption::AcceptUnsolicited(accept)))
}
#[cfg(test)]
pub fn test_down_interface(&self, ifname: &str) -> Result<()> {
self.send_cmd(Command::SetOption(DaemonOption::TestDownInterface(
ifname.to_string(),
)))
}
#[cfg(test)]
pub fn test_up_interface(&self, ifname: &str) -> Result<()> {
self.send_cmd(Command::SetOption(DaemonOption::TestUpInterface(
ifname.to_string(),
)))
}
pub fn set_multicast_loop_v4(&self, on: bool) -> Result<()> {
self.send_cmd(Command::SetOption(DaemonOption::MulticastLoopV4(on)))
}
pub fn set_multicast_loop_v6(&self, on: bool) -> Result<()> {
self.send_cmd(Command::SetOption(DaemonOption::MulticastLoopV6(on)))
}
pub fn verify(&self, instance_fullname: String, timeout: Duration) -> Result<()> {
self.send_cmd(Command::Verify(instance_fullname, timeout))
}
fn daemon_thread(signal_sock: MioUdpSocket, poller: Poll, receiver: Receiver<Command>) {
let mut zc = Zeroconf::new(signal_sock, poller);
if let Some(cmd) = zc.run(receiver) {
match cmd {
Command::Exit(resp_s) => {
if let Err(e) = resp_s.send(DaemonStatus::Shutdown) {
debug!("exit: failed to send response of shutdown: {}", e);
}
}
_ => {
debug!("Unexpected command: {:?}", cmd);
}
}
}
}
}
/// Creates a non-blocking mDNS socket for one interface address.
///
/// The socket is bound to the wildcard address on `MDNS_PORT`, joins the mDNS multicast
/// group on `intf`, and uses `intf` as its outgoing multicast interface. Both families use
/// TTL / hop limit 255, as mDNS requires (RFC 6762 §11). When `should_loop` is false,
/// multicast loopback is disabled for either family. The IPv6 path previously ignored
/// `should_loop` and left the hop limit at the OS default.
/// For IPv4, an empty DNS message is multicast once as a send test.
///
/// # Errors
/// Returns an error if any socket creation, option, join, or test-send step fails.
fn _new_socket_bind(intf: &Interface, should_loop: bool) -> Result<MyUdpSocket> {
    let intf_ip = &intf.ip();
    match intf_ip {
        IpAddr::V4(ip) => {
            let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), MDNS_PORT);
            let sock = new_socket(addr.into(), true)?;
            sock.join_multicast_v4(&GROUP_ADDR_V4, ip)
                .map_err(|e| e_fmt!("join multicast group on addr {}: {}", intf_ip, e))?;
            sock.set_multicast_if_v4(ip)
                .map_err(|e| e_fmt!("set multicast_if on addr {}: {}", ip, e))?;
            sock.set_multicast_ttl_v4(255)
                .map_err(|e| e_fmt!("set set_multicast_ttl_v4 on addr {}: {}", ip, e))?;
            if !should_loop {
                sock.set_multicast_loop_v4(false)
                    .map_err(|e| e_fmt!("failed to set multicast loop v4 for {ip}: {e}"))?;
            }
            // Send an empty message once to verify multicast sending works on this address.
            let multicast_addr = SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT).into();
            let test_packets = DnsOutgoing::new(0).to_data_on_wire();
            for packet in test_packets {
                sock.send_to(&packet, &multicast_addr)
                    .map_err(|e| e_fmt!("send multicast packet on addr {}: {}", ip, e))?;
            }
            MyUdpSocket::new(sock)
                .map_err(|e| e_fmt!("failed to create MySocket for interface {}: {e}", intf.name))
        }
        IpAddr::V6(ip) => {
            let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), MDNS_PORT, 0, 0);
            let sock = new_socket(addr.into(), true)?;
            // IPv6 multicast membership and egress are selected by interface index (0 = any).
            let if_index = intf.index.unwrap_or(0);
            sock.join_multicast_v6(&GROUP_ADDR_V6, if_index)
                .map_err(|e| e_fmt!("join multicast group on addr {}: {}", ip, e))?;
            sock.set_multicast_if_v6(if_index)
                .map_err(|e| e_fmt!("set multicast_if on addr {}: {}", ip, e))?;
            // Mirror the IPv4 branch (TTL 255) and `Zeroconf::new` (hops 255).
            sock.set_multicast_hops_v6(255)
                .map_err(|e| e_fmt!("set set_multicast_hops_v6 on addr {}: {}", ip, e))?;
            // Honour `should_loop` for IPv6 too, not just IPv4.
            if !should_loop {
                sock.set_multicast_loop_v6(false)
                    .map_err(|e| e_fmt!("failed to set multicast loop v6 for {ip}: {e}"))?;
            }
            MyUdpSocket::new(sock)
                .map_err(|e| e_fmt!("failed to create MySocket for interface {}: {e}", intf.name))
        }
    }
}
/// Creates a packet-info UDP socket bound to `addr`.
///
/// Address reuse is enabled (and port reuse on Unix) so the mDNS port can be shared.
/// When `non_block` is true, the socket is switched to non-blocking mode before binding.
///
/// # Errors
/// Returns an error if creating, configuring, or binding the socket fails.
fn new_socket(addr: SocketAddr, non_block: bool) -> Result<PktInfoUdpSocket> {
    let domain = if addr.is_ipv4() {
        socket2::Domain::IPV4
    } else {
        socket2::Domain::IPV6
    };
    let sock = PktInfoUdpSocket::new(domain).map_err(|e| e_fmt!("create socket failed: {}", e))?;

    if let Err(e) = sock.set_reuse_address(true) {
        return Err(e_fmt!("set ReuseAddr failed: {}", e));
    }

    #[cfg(unix)]
    {
        if let Err(e) = sock.set_reuse_port(true) {
            return Err(e_fmt!("set ReusePort failed: {}", e));
        }
    }

    if non_block {
        sock.set_nonblocking(true)
            .map_err(|e| e_fmt!("set O_NONBLOCK: {}", e))?;
    }

    sock.bind(&addr.into())
        .map_err(|e| e_fmt!("socket bind to {} failed: {}", &addr, e))?;
    trace!("new socket bind to {}", &addr);
    Ok(sock)
}
/// A command scheduled for re-execution (e.g. a retransmission) at a later time.
struct ReRun {
    /// When to run, in milliseconds on the `current_time_millis()` clock.
    next_time: u64,
    /// The command to execute once `next_time` has been reached.
    command: Command,
}
/// Describes which interface addresses to enable or disable.
///
/// Each entry from the interface list carries a single address, so the IP-family variants
/// match per address rather than per physical interface.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum IfKind {
    /// Every interface address.
    All,
    /// Any IPv4 interface address.
    IPv4,
    /// Any IPv6 interface address.
    IPv6,
    /// Addresses on the interface with this name.
    Name(String),
    /// The interface entry with exactly this IP address.
    Addr(IpAddr),
    /// IPv4 addresses on loopback interfaces.
    LoopbackV4,
    /// IPv6 addresses on loopback interfaces.
    LoopbackV6,
}
impl IfKind {
    /// Returns `true` when the interface entry `intf` is covered by this selector.
    fn matches(&self, intf: &Interface) -> bool {
        let ip = intf.ip();
        match self {
            Self::All => true,
            Self::IPv4 => ip.is_ipv4(),
            Self::IPv6 => ip.is_ipv6(),
            Self::Name(ifname) => *ifname == intf.name,
            Self::Addr(addr) => *addr == ip,
            Self::LoopbackV4 => ip.is_ipv4() && intf.is_loopback(),
            Self::LoopbackV6 => ip.is_ipv6() && intf.is_loopback(),
        }
    }
}
/// A string slice names an interface, e.g. `"en0"`.
impl From<&str> for IfKind {
    fn from(name: &str) -> Self {
        Self::Name(name.to_owned())
    }
}

/// A borrowed `String` names an interface; same as the `&str` conversion.
impl From<&String> for IfKind {
    fn from(name: &String) -> Self {
        Self::from(name.as_str())
    }
}

/// An IP address selects the interface entry carrying that address.
impl From<IpAddr> for IfKind {
    fn from(addr: IpAddr) -> Self {
        Self::Addr(addr)
    }
}
/// A list of [`IfKind`] selectors, produced by [`IntoIfKindVec::into_vec`].
pub struct IfKindVec {
    /// Selectors in the order given by the caller.
    kinds: Vec<IfKind>,
}
/// Conversion into an [`IfKindVec`], so `enable_interface` / `disable_interface` accept
/// either a single selector (anything `Into<IfKind>`) or a `Vec` of them.
pub trait IntoIfKindVec {
    /// Converts `self` into a list of interface selectors.
    fn into_vec(self) -> IfKindVec;
}
/// A single value convertible into an [`IfKind`] becomes a one-element list.
impl<T: Into<IfKind>> IntoIfKindVec for T {
    fn into_vec(self) -> IfKindVec {
        IfKindVec {
            kinds: vec![self.into()],
        }
    }
}
/// A vector of convertible values becomes a list of [`IfKind`]s, preserving order.
impl<T: Into<IfKind>> IntoIfKindVec for Vec<T> {
    fn into_vec(self) -> IfKindVec {
        IfKindVec {
            kinds: self.into_iter().map(Into::into).collect(),
        }
    }
}
/// One enable/disable rule. Rules are evaluated in insertion order, and the last rule
/// matching an interface entry decides whether it is used (unmatched entries are used).
struct IfSelection {
    /// Which interface entries this rule applies to.
    if_kind: IfKind,
    /// `true` to enable matching entries, `false` to disable them.
    selected: bool,
}
/// State owned by the daemon thread.
struct Zeroconf {
    /// Tracked interfaces keyed by interface index (0 when the OS gave none).
    my_intfs: HashMap<u32, MyIntf>,
    /// Shared IPv4 mDNS socket (wildcard address, `MDNS_PORT`).
    ipv4_sock: MyUdpSocket,
    /// Shared IPv6 mDNS socket (wildcard address, `MDNS_PORT`).
    ipv6_sock: MyUdpSocket,
    /// Services registered by this daemon, keyed by lower-cased full name.
    my_services: HashMap<String, ServiceInfo>,
    /// Cache of records received from the network.
    cache: DnsCache,
    /// Per-interface DNS registries (probing/announcing state), keyed by interface index.
    dns_registry_map: HashMap<u32, DnsRegistry>,
    /// Browse listeners; presumably keyed by service type — TODO confirm.
    service_queriers: HashMap<String, Sender<ServiceEvent>>,
    /// Hostname resolution listeners: hostname → (listener, optional deadline in ms).
    hostname_resolvers: HashMap<String, (Sender<HostnameResolutionEvent>, Option<u64>)>,
    /// Commands scheduled for later re-execution.
    retransmissions: Vec<ReRun>,
    /// Metric counters.
    counters: Metrics,
    /// `mio` poller for the signal and mDNS sockets.
    poller: Poll,
    /// Subscribers to `DaemonEvent`s; disconnected ones are pruned on send.
    monitors: Vec<Sender<DaemonEvent>>,
    /// Maximum allowed service name length.
    service_name_len_max: u8,
    /// Interval between IP-change scans, in milliseconds (0 disables the scan).
    ip_check_interval: u64,
    /// Interface enable/disable rules, applied in order.
    if_selections: Vec<IfSelection>,
    /// Loopback socket whose readability wakes the event loop.
    signal_sock: MioUdpSocket,
    /// Min-heap (via `Reverse`) of wake-up times in milliseconds.
    timers: BinaryHeap<Reverse<u64>>,
    /// Current daemon status.
    status: DaemonStatus,
    /// Not used in this view; presumably instances awaiting resolution — TODO confirm.
    pending_resolves: HashSet<String>,
    /// Not used in this view; presumably instances already resolved — TODO confirm.
    resolved: HashSet<String>,
    /// Last requested IPv4 multicast loopback setting.
    multicast_loop_v4: bool,
    /// Last requested IPv6 multicast loopback setting.
    multicast_loop_v6: bool,
    /// Whether unsolicited responses are accepted (consumer not visible here).
    accept_unsolicited: bool,
    /// Test-only: interface names that `check_ip_changes` treats as absent.
    #[cfg(test)]
    test_down_interfaces: HashSet<String>,
}
/// Joins the mDNS multicast group for `intf` on `my_sock`.
///
/// IPv4 joins by interface address; IPv6 joins by interface index (0 if unknown).
///
/// # Errors
/// Returns an error if the OS rejects the join (for example when already joined).
fn join_multicast_group(my_sock: &PktInfoUdpSocket, intf: &Interface) -> Result<()> {
    match intf.ip() {
        IpAddr::V4(v4) => {
            debug!("join multicast group V4 on addr {}", v4);
            my_sock
                .join_multicast_v4(&GROUP_ADDR_V4, &v4)
                .map_err(|e| e_fmt!("PKT join multicast group on addr {}: {}", v4, e))
        }
        IpAddr::V6(v6) => {
            let if_index = intf.index.unwrap_or(0);
            debug!(
                "join multicast group V6 on addr {} with index {}",
                v6, if_index
            );
            my_sock
                .join_multicast_v6(&GROUP_ADDR_V6, if_index)
                .map_err(|e| e_fmt!("PKT join multicast group on addr {}: {}", v6, e))
        }
    }
}
impl Zeroconf {
/// Builds the daemon-thread state.
///
/// Creates the shared IPv4 and IPv6 mDNS sockets. Both are bound to the wildcard address
/// on `MDNS_PORT`, are non-blocking, and use multicast TTL / hop limit 255. Every address
/// from `my_ip_interfaces(false)` joins the mDNS group on the matching socket. Each one is
/// recorded in `my_intfs` keyed by interface index (0 if unknown), with a fresh
/// `DnsRegistry` per index. Loopback interfaces start out disabled in `if_selections`.
///
/// # Panics
/// Panics (via `unwrap`) if either socket cannot be created or configured. This runs on
/// the daemon thread, so the panic ends that thread.
/// NOTE(review): consider propagating these errors instead.
fn new(signal_sock: MioUdpSocket, poller: Poll) -> Self {
    // NOTE(review): the `false` argument presumably means "exclude loopback" — confirm.
    let my_ifaddrs = my_ip_interfaces(false);
    let mut my_intfs = HashMap::new();
    let mut dns_registry_map = HashMap::new();
    // Shared IPv4 socket: 0.0.0.0:5353, non-blocking, multicast TTL 255.
    let addr = SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), MDNS_PORT);
    let sock = new_socket(addr.into(), true).unwrap();
    sock.set_multicast_ttl_v4(255)
        .map_err(|e| e_fmt!("set set_multicast_ttl_v4 on addr: {}", e))
        .unwrap();
    let ipv4_sock = MyUdpSocket::new(sock).unwrap();
    // Shared IPv6 socket: [::]:5353, non-blocking, multicast hop limit 255.
    let addr = SocketAddrV6::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), MDNS_PORT, 0, 0);
    let sock = new_socket(addr.into(), true).unwrap();
    sock.set_multicast_hops_v6(255)
        .map_err(|e| e_fmt!("set set_multicast_hops_v6: {}", e))
        .unwrap();
    let ipv6_sock = MyUdpSocket::new(sock).unwrap();
    for intf in my_ifaddrs {
        let sock = if intf.ip().is_ipv4() {
            &ipv4_sock
        } else {
            &ipv6_sock
        };
        // NOTE(review): despite "Skipped." in the log, the address is still recorded below.
        // This is likely deliberate: IPv6 joins are per interface index, so a second IPv6
        // address on the same index would fail to join again — confirm before changing.
        if let Err(e) = join_multicast_group(&sock.pktinfo, &intf) {
            debug!(
                "config socket to join multicast: {}: {e}. Skipped.",
                &intf.ip()
            );
        }
        let if_index = intf.index.unwrap_or(0);
        dns_registry_map
            .entry(if_index)
            .or_insert_with(DnsRegistry::new);
        // Multiple addresses of one interface share a single `MyIntf` entry.
        my_intfs
            .entry(if_index)
            .and_modify(|v: &mut MyIntf| {
                v.addrs.insert(intf.addr.clone());
            })
            .or_insert(MyIntf {
                name: intf.name.clone(),
                index: if_index,
                addrs: HashSet::from([intf.addr]),
            });
    }
    let monitors = Vec::new();
    let service_name_len_max = SERVICE_NAME_LEN_MAX_DEFAULT;
    // Stored in milliseconds.
    let ip_check_interval = IP_CHECK_INTERVAL_IN_SECS_DEFAULT as u64 * 1000;
    let timers = BinaryHeap::new();
    // Loopback interfaces are disabled by default.
    let if_selections = vec![
        IfSelection {
            if_kind: IfKind::LoopbackV4,
            selected: false,
        },
        IfSelection {
            if_kind: IfKind::LoopbackV6,
            selected: false,
        },
    ];
    let status = DaemonStatus::Running;
    Self {
        my_intfs,
        ipv4_sock,
        ipv6_sock,
        my_services: HashMap::new(),
        cache: DnsCache::new(),
        dns_registry_map,
        hostname_resolvers: HashMap::new(),
        service_queriers: HashMap::new(),
        retransmissions: Vec::new(),
        counters: HashMap::new(),
        poller,
        monitors,
        service_name_len_max,
        ip_check_interval,
        if_selections,
        signal_sock,
        timers,
        status,
        pending_resolves: HashSet::new(),
        resolved: HashSet::new(),
        multicast_loop_v4: true,
        multicast_loop_v6: true,
        accept_unsolicited: false,
        #[cfg(test)]
        test_down_interfaces: HashSet::new(),
    }
}
fn run(&mut self, receiver: Receiver<Command>) -> Option<Command> {
if let Err(e) = self.poller.registry().register(
&mut self.signal_sock,
mio::Token(SIGNAL_SOCK_EVENT_KEY),
mio::Interest::READABLE,
) {
debug!("failed to add signal socket to the poller: {}", e);
return None;
}
if let Err(e) = self.poller.registry().register(
&mut self.ipv4_sock,
mio::Token(IPV4_SOCK_EVENT_KEY),
mio::Interest::READABLE,
) {
debug!("failed to register ipv4 socket: {}", e);
return None;
}
if let Err(e) = self.poller.registry().register(
&mut self.ipv6_sock,
mio::Token(IPV6_SOCK_EVENT_KEY),
mio::Interest::READABLE,
) {
debug!("failed to register ipv6 socket: {}", e);
return None;
}
let mut next_ip_check = if self.ip_check_interval > 0 {
current_time_millis() + self.ip_check_interval
} else {
0
};
if next_ip_check > 0 {
self.add_timer(next_ip_check);
}
let mut events = mio::Events::with_capacity(1024);
loop {
let now = current_time_millis();
let earliest_timer = self.peek_earliest_timer();
let timeout = earliest_timer.map(|timer| {
let millis = if timer > now { timer - now } else { 1 };
Duration::from_millis(millis)
});
events.clear();
match self.poller.poll(&mut events, timeout) {
Ok(_) => self.handle_poller_events(&events),
Err(e) => debug!("failed to select from sockets: {}", e),
}
let now = current_time_millis();
self.pop_timers_till(now);
for hostname in self
.hostname_resolvers
.clone()
.into_iter()
.filter(|(_, (_, timeout))| timeout.map(|t| now >= t).unwrap_or(false))
.map(|(hostname, _)| hostname)
{
trace!("hostname resolver timeout for {}", &hostname);
call_hostname_resolution_listener(
&self.hostname_resolvers,
&hostname,
HostnameResolutionEvent::SearchTimeout(hostname.to_owned()),
);
call_hostname_resolution_listener(
&self.hostname_resolvers,
&hostname,
HostnameResolutionEvent::SearchStopped(hostname.to_owned()),
);
self.hostname_resolvers.remove(&hostname);
}
while let Ok(command) = receiver.try_recv() {
if matches!(command, Command::Exit(_)) {
self.status = DaemonStatus::Shutdown;
return Some(command);
}
self.exec_command(command, false);
}
let mut i = 0;
while i < self.retransmissions.len() {
if now >= self.retransmissions[i].next_time {
let rerun = self.retransmissions.remove(i);
self.exec_command(rerun.command, true);
} else {
i += 1;
}
}
self.refresh_active_services();
let mut query_count = 0;
for (hostname, _sender) in self.hostname_resolvers.iter() {
for (hostname, ip_addr) in
self.cache.refresh_due_hostname_resolutions(hostname).iter()
{
self.send_query(hostname, ip_address_rr_type(&ip_addr.to_ip_addr()));
query_count += 1;
}
}
self.increase_counter(Counter::CacheRefreshAddr, query_count);
let now = current_time_millis();
let expired_services = self.cache.evict_expired_services(now);
if !expired_services.is_empty() {
debug!(
"run: send {} service removal to listeners",
expired_services.len()
);
self.notify_service_removal(expired_services);
}
let expired_addrs = self.cache.evict_expired_addr(now);
for (hostname, addrs) in expired_addrs {
call_hostname_resolution_listener(
&self.hostname_resolvers,
&hostname,
HostnameResolutionEvent::AddressesRemoved(hostname.clone(), addrs),
);
let instances = self.cache.get_instances_on_host(&hostname);
let instance_set: HashSet<String> = instances.into_iter().collect();
self.resolve_updated_instances(&instance_set);
}
self.probing_handler();
if now >= next_ip_check && next_ip_check > 0 {
next_ip_check = now + self.ip_check_interval;
self.add_timer(next_ip_check);
self.check_ip_changes();
}
}
}
/// Applies one `DaemonOption` received via `Command::SetOption`.
///
/// Plain values (name-length limit, IP-check interval in milliseconds) are stored directly.
/// Interface, multicast-loop and unsolicited-acceptance options go through their setters.
/// The test-only options add or remove names in `test_down_interfaces`.
fn process_set_option(&mut self, daemon_opt: DaemonOption) {
    match daemon_opt {
        DaemonOption::ServiceNameLenMax(len_max) => {
            self.service_name_len_max = len_max;
        }
        DaemonOption::IpCheckInterval(millis) => {
            self.ip_check_interval = millis;
        }
        DaemonOption::EnableInterface(kinds) => {
            self.enable_interface(kinds);
        }
        DaemonOption::DisableInterface(kinds) => {
            self.disable_interface(kinds);
        }
        DaemonOption::MulticastLoopV4(enabled) => {
            self.set_multicast_loop_v4(enabled);
        }
        DaemonOption::MulticastLoopV6(enabled) => {
            self.set_multicast_loop_v6(enabled);
        }
        DaemonOption::AcceptUnsolicited(flag) => {
            self.set_accept_unsolicited(flag);
        }
        #[cfg(test)]
        DaemonOption::TestDownInterface(name) => {
            self.test_down_interfaces.insert(name);
        }
        #[cfg(test)]
        DaemonOption::TestUpInterface(name) => {
            self.test_down_interfaces.remove(&name);
        }
    }
}
/// Appends "enabled" rules for `kinds` and re-applies all rules to the current interfaces.
///
/// Later rules override earlier ones for the entries they match.
fn enable_interface(&mut self, kinds: Vec<IfKind>) {
    debug!("enable_interface: {:?}", kinds);
    let rules = kinds.into_iter().map(|if_kind| IfSelection {
        if_kind,
        selected: true,
    });
    self.if_selections.extend(rules);
    let interfaces = my_ip_interfaces(true);
    self.apply_intf_selections(interfaces);
}
/// Appends "disabled" rules for `kinds` and re-applies all rules to the current interfaces.
///
/// Later rules override earlier ones for the entries they match.
fn disable_interface(&mut self, kinds: Vec<IfKind>) {
    debug!("disable_interface: {:?}", kinds);
    let rules = kinds.into_iter().map(|if_kind| IfSelection {
        if_kind,
        selected: false,
    });
    self.if_selections.extend(rules);
    let interfaces = my_ip_interfaces(true);
    self.apply_intf_selections(interfaces);
}
fn set_multicast_loop_v4(&mut self, on: bool) {
self.multicast_loop_v4 = on;
self.ipv4_sock
.pktinfo
.set_multicast_loop_v4(on)
.map_err(|e| e_fmt!("failed to set multicast loop v4: {}", e))
.unwrap();
}
fn set_multicast_loop_v6(&mut self, on: bool) {
self.multicast_loop_v6 = on;
self.ipv6_sock
.pktinfo
.set_multicast_loop_v6(on)
.map_err(|e| e_fmt!("failed to set multicast loop v6: {}", e))
.unwrap();
}
/// Records whether unsolicited responses should be accepted.
/// The flag is stored in `accept_unsolicited`; its consumer is not visible in this view.
fn set_accept_unsolicited(&mut self, accept: bool) {
    self.accept_unsolicited = accept;
}
/// Sends a clone of `event` to every monitor without blocking.
///
/// Monitors whose receiver is gone are dropped. A monitor whose queue is merely full is
/// kept; that event is skipped for it after being logged.
fn notify_monitors(&mut self, event: DaemonEvent) {
    self.monitors
        .retain(|monitor| match monitor.try_send(event.clone()) {
            Ok(()) => true,
            Err(err) => {
                debug!("notify_monitors: try_send: {}", &err);
                !matches!(err, TrySendError::Disconnected(_))
            }
        });
}
/// Removes `addr` from every registered service whose addresses are managed automatically.
fn del_addr_in_my_services(&mut self, addr: &IpAddr) {
    let auto_services = self
        .my_services
        .values_mut()
        .filter(|info| info.is_addr_auto());
    for info in auto_services {
        info.remove_ipaddr(addr);
    }
}
/// Schedules a wake-up at `next_time` (milliseconds). `Reverse` turns the max-heap into a min-heap.
fn add_timer(&mut self, next_time: u64) {
    let entry = Reverse(next_time);
    self.timers.push(entry);
}
fn peek_earliest_timer(&self) -> Option<u64> {
self.timers.peek().map(|Reverse(v)| *v)
}
/// Removes and returns the earliest scheduled wake-up time, if any.
fn _pop_earliest_timer(&mut self) -> Option<u64> {
    let Reverse(earliest) = self.timers.pop()?;
    Some(earliest)
}
/// Discards every scheduled wake-up at or before `now`.
fn pop_timers_till(&mut self, now: u64) {
    while self
        .timers
        .peek()
        .map_or(false, |Reverse(deadline)| *deadline <= now)
    {
        self.timers.pop();
    }
}
/// Returns the IPs of the interface entries enabled by `if_selections`.
///
/// For each entry, the last matching rule decides; entries matched by no rule are enabled.
fn selected_addrs(&self, interfaces: Vec<Interface>) -> HashSet<IpAddr> {
    let is_enabled = |intf: &Interface| {
        self.if_selections
            .iter()
            .rev()
            .find(|rule| rule.if_kind.matches(intf))
            .map_or(true, |rule| rule.selected)
    };
    interfaces
        .iter()
        .filter(|intf| is_enabled(intf))
        .map(|intf| intf.addr.ip())
        .collect()
}
/// Adds each enabled interface entry and removes each disabled one.
///
/// For each entry, the last matching rule in `if_selections` decides; entries matched by
/// no rule are enabled.
fn apply_intf_selections(&mut self, interfaces: Vec<Interface>) {
    // Decide every entry up front: the add/del calls below need `&mut self`, which
    // cannot coexist with a borrow of `self.if_selections`.
    let decisions: Vec<bool> = interfaces
        .iter()
        .map(|intf| {
            self.if_selections
                .iter()
                .rev()
                .find(|rule| rule.if_kind.matches(intf))
                .map_or(true, |rule| rule.selected)
        })
        .collect();
    for (intf, enabled) in interfaces.into_iter().zip(decisions) {
        if enabled {
            self.add_interface(intf);
        } else {
            self.del_interface(&intf);
        }
    }
}
/// Tells monitors that `ip` is gone and removes it from auto-address services.
/// The two steps touch unrelated state (`monitors` vs `my_services`), so their order does not matter.
fn del_ip(&mut self, ip: IpAddr) {
    self.notify_monitors(DaemonEvent::IpDel(ip));
    self.del_addr_in_my_services(&ip);
}
/// Re-scans the host's interfaces and reconciles `my_intfs` with them.
///
/// - Addresses that no longer exist are dropped from their `MyIntf` and reported via `del_ip`.
/// - Interfaces left with no addresses, or missing from the scan entirely, are removed.
///   For each, the shared socket leaves the mDNS group, and cached records received on
///   that interface are removed, with service-removal notifications.
/// - Finally, the scan result goes through `apply_intf_selections`, which adds new
///   interfaces and addresses.
///
/// NOTE(review): unlike `del_interface`, this does not remove `dns_registry_map` entries
/// of removed interfaces — confirm whether that is intentional.
fn check_ip_changes(&mut self) {
    // NOTE(review): the `true` argument presumably means "include loopback" — confirm.
    let my_ifaddrs = my_ip_interfaces(true);
    // Test builds hide interfaces marked down via `DaemonOption::TestDownInterface`.
    #[cfg(test)]
    let my_ifaddrs: Vec<_> = my_ifaddrs
        .into_iter()
        .filter(|intf| !self.test_down_interfaces.contains(&intf.name))
        .collect();
    // Current addresses grouped by interface index (0 if unknown).
    let ifaddrs_map: HashMap<u32, Vec<&IfAddr>> =
        my_ifaddrs.iter().fold(HashMap::new(), |mut acc, intf| {
            let if_index = intf.index.unwrap_or(0);
            acc.entry(if_index).or_default().push(&intf.addr);
            acc
        });
    // Collected while `my_intfs` is mutably borrowed, acted on afterwards.
    let mut deleted_intfs = Vec::new();
    let mut deleted_ips = Vec::new();
    for (if_index, my_intf) in self.my_intfs.iter_mut() {
        // One removed address per family, used below to leave the multicast group.
        let mut last_ipv4 = None;
        let mut last_ipv6 = None;
        if let Some(current_addrs) = ifaddrs_map.get(if_index) {
            // Interface still present: keep only addresses that still exist.
            my_intf.addrs.retain(|addr| {
                if current_addrs.contains(&addr) {
                    true
                } else {
                    match addr.ip() {
                        IpAddr::V4(ipv4) => last_ipv4 = Some(ipv4),
                        IpAddr::V6(ipv6) => last_ipv6 = Some(ipv6),
                    }
                    deleted_ips.push(addr.ip());
                    false
                }
            });
            if my_intf.addrs.is_empty() {
                deleted_intfs.push((*if_index, last_ipv4, last_ipv6))
            }
        } else {
            // Interface disappeared: every address it had is deleted.
            debug!(
                "check_ip_changes: interface {} ({}) no longer exists, removing",
                my_intf.name, if_index
            );
            for addr in my_intf.addrs.iter() {
                match addr.ip() {
                    IpAddr::V4(ipv4) => last_ipv4 = Some(ipv4),
                    IpAddr::V6(ipv6) => last_ipv6 = Some(ipv6),
                }
                deleted_ips.push(addr.ip())
            }
            deleted_intfs.push((*if_index, last_ipv4, last_ipv6));
        }
    }
    if !deleted_ips.is_empty() || !deleted_intfs.is_empty() {
        debug!(
            "check_ip_changes: {} deleted ips {} deleted intfs",
            deleted_ips.len(),
            deleted_intfs.len()
        );
    }
    for ip in deleted_ips {
        self.del_ip(ip);
    }
    for (if_index, last_ipv4, last_ipv6) in deleted_intfs {
        let Some(my_intf) = self.my_intfs.remove(&if_index) else {
            continue;
        };
        // IPv4 membership is left by interface address.
        if let Some(ipv4) = last_ipv4 {
            debug!("leave multicast for {ipv4}");
            if let Err(e) = self
                .ipv4_sock
                .pktinfo
                .leave_multicast_v4(&GROUP_ADDR_V4, &ipv4)
            {
                debug!("leave multicast group for addr {ipv4}: {e}");
            }
        }
        // IPv6 membership is left by interface index.
        if let Some(ipv6) = last_ipv6 {
            debug!("leave multicast for {ipv6}");
            if let Err(e) = self
                .ipv6_sock
                .pktinfo
                .leave_multicast_v6(&GROUP_ADDR_V6, my_intf.index)
            {
                debug!("leave multicast group for IPv6: {ipv6}: {e}");
            }
        }
        // Drop cached records received on this interface and notify listeners.
        let intf_id = InterfaceId {
            name: my_intf.name.to_string(),
            index: my_intf.index,
        };
        let removed_instances = self.cache.remove_records_on_intf(intf_id);
        self.notify_service_removal(removed_instances);
    }
    // Add any new interfaces/addresses permitted by the selection rules.
    self.apply_intf_selections(my_ifaddrs);
}
/// Stops using one interface address (`intf.addr`) on the interface with `intf`'s index.
///
/// If this removes the last address of its family on the interface, the matching shared
/// socket leaves the mDNS group. If no addresses remain, the interface, its DNS registry,
/// and cached addresses on it are dropped. When an address was actually removed, monitors
/// get `DaemonEvent::IpDel` and auto-address services lose that IP.
fn del_interface(&mut self, intf: &Interface) {
    let if_index = intf.index.unwrap_or(0);
    trace!(
        "del_interface: {} ({if_index}) addr {}",
        intf.name,
        intf.ip()
    );
    let Some(my_intf) = self.my_intfs.get_mut(&if_index) else {
        debug!("del_interface: interface {} not found", intf.name);
        return;
    };
    let mut ip_removed = false;
    if my_intf.addrs.remove(&intf.addr) {
        ip_removed = true;
        match intf.addr.ip() {
            IpAddr::V4(ipv4) => {
                // Leave only when `next_ifaddr_v4` finds nothing
                // (presumably: no other IPv4 address remains on this interface).
                if my_intf.next_ifaddr_v4().is_none() {
                    if let Err(e) = self
                        .ipv4_sock
                        .pktinfo
                        .leave_multicast_v4(&GROUP_ADDR_V4, &ipv4)
                    {
                        debug!("leave multicast group for addr {ipv4}: {e}");
                    }
                }
            }
            IpAddr::V6(ipv6) => {
                // IPv6 membership is per interface index, so leave only when no IPv6
                // address remains (presumed meaning of `next_ifaddr_v6().is_none()`).
                if my_intf.next_ifaddr_v6().is_none() {
                    if let Err(e) = self
                        .ipv6_sock
                        .pktinfo
                        .leave_multicast_v6(&GROUP_ADDR_V6, if_index)
                    {
                        debug!("leave multicast group for addr {ipv6}: {e}");
                    }
                }
            }
        }
        // Last address gone: forget the interface entirely.
        if my_intf.addrs.is_empty() {
            debug!("del_interface: removing interface {}", intf.name);
            self.my_intfs.remove(&if_index);
            self.dns_registry_map.remove(&if_index);
            self.cache.remove_addrs_on_disabled_intf(if_index);
        }
    }
    if ip_removed {
        self.notify_monitors(DaemonEvent::IpDel(intf.ip()));
        self.del_addr_in_my_services(&intf.ip());
    }
}
fn add_interface(&mut self, intf: Interface) {
let sock = if intf.ip().is_ipv4() {
&self.ipv4_sock
} else {
&self.ipv6_sock
};
let if_index = intf.index.unwrap_or(0);
let mut new_addr = false;
match self.my_intfs.entry(if_index) {
Entry::Occupied(mut entry) => {
let my_intf = entry.get_mut();
if !my_intf.addrs.contains(&intf.addr) {
if let Err(e) = join_multicast_group(&sock.pktinfo, &intf) {
debug!("add_interface: socket_config {}: {e}", &intf.name);
}
my_intf.addrs.insert(intf.addr.clone());
new_addr = true;
}
}
Entry::Vacant(entry) => {
if let Err(e) = join_multicast_group(&sock.pktinfo, &intf) {
debug!("add_interface: socket_config {}: {e}. Skipped.", &intf.name);
return;
}
new_addr = true;
let new_intf = MyIntf {
name: intf.name.clone(),
index: if_index,
addrs: HashSet::from([intf.addr.clone()]),
};
entry.insert(new_intf);
}
}
if !new_addr {
trace!("add_interface: interface {} already exists", &intf.name);
return;
}
debug!("add new interface {}: {}", intf.name, intf.ip());
let Some(my_intf) = self.my_intfs.get(&if_index) else {
debug!("add_interface: cannot find if_index {if_index}");
return;
};
let dns_registry = match self.dns_registry_map.get_mut(&if_index) {
Some(registry) => registry,
None => self
.dns_registry_map
.entry(if_index)
.or_insert_with(DnsRegistry::new),
};
for (_, service_info) in self.my_services.iter_mut() {
if service_info.is_addr_auto() {
let new_ip = intf.ip();
service_info.insert_ipaddr(new_ip);
if announce_service_on_intf(dns_registry, service_info, my_intf, &sock.pktinfo) {
debug!(
"Announce service {} on {}",
service_info.get_fullname(),
intf.ip()
);
service_info.set_status(if_index, ServiceStatus::Announced);
} else {
for timer in dns_registry.new_timers.drain(..) {
self.timers.push(Reverse(timer));
}
service_info.set_status(if_index, ServiceStatus::Probing);
}
}
}
let mut browse_reruns = Vec::new();
let mut i = 0;
while i < self.retransmissions.len() {
if matches!(self.retransmissions[i].command, Command::Browse(..)) {
browse_reruns.push(self.retransmissions.remove(i));
} else {
i += 1;
}
}
for rerun in browse_reruns {
self.exec_command(rerun.command, true);
}
self.notify_monitors(DaemonEvent::IpAdd(intf.ip()));
}
/// Registers `info` with the daemon and sends its initial announcement.
///
/// A service type that is too long is rejected, and the error goes to monitors. Services
/// with automatic addresses get every currently selected interface IP. If the unsolicited
/// announcement went out on any address, monitors receive `DaemonEvent::Announce`. The
/// service is stored under its lower-cased full name.
fn register_service(&mut self, mut info: ServiceInfo) {
    match check_service_name_length(info.get_type(), self.service_name_len_max) {
        Ok(_) => {}
        Err(e) => {
            debug!("check_service_name_length: {}", &e);
            self.notify_monitors(DaemonEvent::Error(e));
            return;
        }
    }

    if info.is_addr_auto() {
        let interfaces = my_ip_interfaces(true);
        for addr in self.selected_addrs(interfaces) {
            info.insert_ipaddr(addr);
        }
    }

    debug!("register service {:?}", &info);
    let announced_on = self.send_unsolicited_response(&mut info);
    if !announced_on.is_empty() {
        let fullname = info.get_fullname().to_string();
        let addrs_text = format!("{:?}", &announced_on);
        self.notify_monitors(DaemonEvent::Announce(fullname, addrs_text));
    }

    self.my_services
        .insert(info.get_fullname().to_lowercase(), info);
}
/// Announces `info` on every tracked interface, over both the IPv4 and IPv6 sockets.
///
/// For each interface, the service is marked `Announced` if either family announced it.
/// Otherwise, pending timers from the interface's `DnsRegistry::new_timers` move into the
/// daemon's timer heap and the service is marked `Probing`. Each interface where an
/// announcement went out gets a `RegisterResend` retransmission 1000 ms later.
///
/// Returns the interface addresses (of the announcing family) the announcement used.
fn send_unsolicited_response(&mut self, info: &mut ServiceInfo) -> Vec<IpAddr> {
    let mut outgoing_addrs = Vec::new();
    let mut outgoing_intfs = HashSet::new();
    for (if_index, intf) in self.my_intfs.iter() {
        // A single lookup that creates the per-interface registry on first use.
        let dns_registry = self
            .dns_registry_map
            .entry(*if_index)
            .or_insert_with(DnsRegistry::new);
        let mut announced = false;
        if announce_service_on_intf(dns_registry, info, intf, &self.ipv4_sock.pktinfo) {
            for addr in intf.addrs.iter().filter(|a| a.ip().is_ipv4()) {
                outgoing_addrs.push(addr.ip());
            }
            outgoing_intfs.insert(intf.index);
            debug!(
                "Announce service IPv4 {} on {}",
                info.get_fullname(),
                intf.name
            );
            announced = true;
        }
        if announce_service_on_intf(dns_registry, info, intf, &self.ipv6_sock.pktinfo) {
            for addr in intf.addrs.iter().filter(|a| a.ip().is_ipv6()) {
                outgoing_addrs.push(addr.ip());
            }
            outgoing_intfs.insert(intf.index);
            debug!(
                "Announce service IPv6 {} on {}",
                info.get_fullname(),
                intf.name
            );
            announced = true;
        }
        if announced {
            info.set_status(intf.index, ServiceStatus::Announced);
        } else {
            // Not announced yet: hand the registry's new timers to the event loop.
            for timer in dns_registry.new_timers.drain(..) {
                self.timers.push(Reverse(timer));
            }
            info.set_status(*if_index, ServiceStatus::Probing);
        }
    }
    let next_time = current_time_millis() + 1000;
    for if_index in outgoing_intfs {
        self.add_retransmission(
            next_time,
            Command::RegisterResend(info.get_fullname().to_string(), if_index),
        );
    }
    outgoing_addrs
}
/// Drives service probing on every interface that has a `DnsRegistry`.
///
/// For each interface, `check_probing` yields the probe questions due now and the
/// probes that have expired. Questions, when present, are sent on both the IPv4 and
/// IPv6 sockets. `handle_expired_probes` turns the expired probes into the names of
/// services that were waiting on them.
///
/// Each such service not already `Announced` on this interface is announced. If the
/// announcement went out on either socket:
/// - a `RegisterResend` is scheduled 1000 ms later;
/// - monitors get `DaemonEvent::Announce`, using the renamed fullname/hostname from
///   `name_changes` when present;
/// - the status is set to `Announced`.
fn probing_handler(&mut self) {
let now = current_time_millis();
for (if_index, intf) in self.my_intfs.iter() {
let Some(dns_registry) = self.dns_registry_map.get_mut(if_index) else {
continue;
};
// NOTE(review): the daemon timer queue is passed in; presumably `check_probing`
// schedules the next probe round there — confirm.
let (out, expired_probes) = check_probing(dns_registry, &mut self.timers, now);
if !out.questions().is_empty() {
trace!("sending out probing of questions: {:?}", out.questions());
send_dns_outgoing(&out, intf, &self.ipv4_sock.pktinfo);
send_dns_outgoing(&out, intf, &self.ipv6_sock.pktinfo);
}
let waiting_services =
handle_expired_probes(expired_probes, &intf.name, dns_registry, &mut self.monitors);
for service_name in waiting_services {
// `my_services` is keyed by lowercased fullname.
if let Some(info) = self.my_services.get_mut(&service_name.to_lowercase()) {
if info.get_status(*if_index) == ServiceStatus::Announced {
debug!("service {} already announced", info.get_fullname());
continue;
}
let announced_v4 =
announce_service_on_intf(dns_registry, info, intf, &self.ipv4_sock.pktinfo);
let announced_v6 =
announce_service_on_intf(dns_registry, info, intf, &self.ipv6_sock.pktinfo);
if announced_v4 || announced_v6 {
// Push the retransmission and its timer directly: `add_retransmission`
// takes `&mut self`, which is unavailable while `dns_registry`, `intf` and
// `info` still borrow fields of `self`.
let next_time = now + 1000;
let command =
Command::RegisterResend(info.get_fullname().to_string(), *if_index);
self.retransmissions.push(ReRun { next_time, command });
self.timers.push(Reverse(next_time));
// Report the name after any conflict-driven rename.
let fullname = match dns_registry.name_changes.get(&service_name) {
Some(new_name) => new_name.to_string(),
None => service_name.to_string(),
};
let mut hostname = info.get_hostname();
if let Some(new_name) = dns_registry.name_changes.get(hostname) {
hostname = new_name;
}
debug!("wake up: announce service {} on {}", fullname, intf.name);
notify_monitors(
&mut self.monitors,
DaemonEvent::Announce(fullname, format!("{}:{}", hostname, &intf.name)),
);
info.set_status(*if_index, ServiceStatus::Announced);
}
}
}
}
}
/// Builds and multicasts a "goodbye" packet for `info` on `intf` using `sock`.
///
/// The packet contains these records, each created with a TTL of 0:
/// - the service PTR, plus the subtype PTR when present;
/// - SRV and TXT;
/// - the service's addresses on `intf` for the socket's IP family.
///
/// Returns the first encoded packet so the caller can schedule a resend. Returns an
/// empty vector when the service has no address of this family on `intf`, or when
/// nothing was sent.
fn unregister_service(
    &self,
    info: &ServiceInfo,
    intf: &MyIntf,
    sock: &PktInfoUdpSocket,
) -> Vec<u8> {
    let is_ipv4 = sock.domain() == Domain::IPV4;

    let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
    out.add_answer_at_time(
        DnsPointer::new(
            info.get_type(),
            RRType::PTR,
            CLASS_IN,
            0,
            info.get_fullname().to_string(),
        ),
        0,
    );

    if let Some(sub) = info.get_subtype() {
        trace!("Adding subdomain {}", sub);
        out.add_answer_at_time(
            DnsPointer::new(
                sub,
                RRType::PTR,
                CLASS_IN,
                0,
                info.get_fullname().to_string(),
            ),
            0,
        );
    }

    out.add_answer_at_time(
        DnsSrv::new(
            info.get_fullname(),
            CLASS_IN | CLASS_CACHE_FLUSH,
            0,
            info.get_priority(),
            info.get_weight(),
            info.get_port(),
            info.get_hostname().to_string(),
        ),
        0,
    );
    out.add_answer_at_time(
        DnsTxt::new(
            info.get_fullname(),
            CLASS_IN | CLASS_CACHE_FLUSH,
            0,
            info.generate_txt(),
        ),
        0,
    );

    let if_addrs = if is_ipv4 {
        info.get_addrs_on_my_intf_v4(intf)
    } else {
        info.get_addrs_on_my_intf_v6(intf)
    };

    // No address of this family on the interface: nothing to announce.
    if if_addrs.is_empty() {
        return vec![];
    }

    for address in if_addrs {
        out.add_answer_at_time(
            DnsAddress::new(
                info.get_hostname(),
                ip_address_rr_type(&address),
                CLASS_IN | CLASS_CACHE_FLUSH,
                0,
                address,
                intf.into(),
            ),
            0,
        );
    }

    // `send_dns_outgoing` returns the encoded packets. Taking the first one with
    // `next()` instead of `remove(0)` avoids a panic on an empty result. Callers
    // already treat an empty packet as "nothing to resend".
    send_dns_outgoing(&out, intf, sock)
        .into_iter()
        .next()
        .unwrap_or_default()
}
/// Registers `listener` as the resolver for `hostname`, keyed by its lowercased form.
///
/// When `timeout` (milliseconds) is given, the absolute deadline is stored with the
/// listener and a daemon timer is armed for it.
fn add_hostname_resolver(
    &mut self,
    hostname: String,
    listener: Sender<HostnameResolutionEvent>,
    timeout: Option<u64>,
) {
    let deadline = timeout.map(|t| current_time_millis() + t);
    let key = hostname.to_lowercase();
    self.hostname_resolvers.insert(key, (listener, deadline));
    if let Some(at) = deadline {
        self.add_timer(at);
    }
}
/// Sends a single-question query for `name` / `qtype` (see `send_query_vec`).
fn send_query(&self, name: &str, qtype: RRType) {
    let single_question = [(name, qtype)];
    self.send_query_vec(&single_question);
}
/// Sends one query message containing all `questions` on every interface, over both
/// the IPv4 and IPv6 sockets.
///
/// For each question, the cache's known answers are attached with their TTL updated
/// to `now`, so responders can suppress answers we already have.
fn send_query_vec(&self, questions: &[(&str, RRType)]) {
    trace!("Sending query questions: {:?}", questions);
    let now = current_time_millis();
    let mut out = DnsOutgoing::new(FLAGS_QR_QUERY);

    for &(name, qtype) in questions {
        out.add_question(name, qtype);
        for cached in self.cache.get_known_answers(name, qtype, now) {
            trace!("add known answer: {:?}", cached.record);
            let mut fresh = cached.record.clone();
            fresh.get_record_mut().update_ttl(now);
            out.add_answer_box(fresh);
        }
    }

    let sockets = [&self.ipv4_sock.pktinfo, &self.ipv6_sock.pktinfo];
    for intf in self.my_intfs.values() {
        for sock in sockets {
            send_dns_outgoing(&out, intf, sock);
        }
    }
}
fn handle_read(&mut self, event_key: usize) -> bool {
let sock = match event_key {
IPV4_SOCK_EVENT_KEY => &mut self.ipv4_sock,
IPV6_SOCK_EVENT_KEY => &mut self.ipv6_sock,
_ => {
debug!("handle_read: unknown token {}", event_key);
return false;
}
};
let mut buf = vec![0u8; MAX_MSG_ABSOLUTE];
let (sz, pktinfo) = match sock.pktinfo.recv(&mut buf) {
Ok(sz) => sz,
Err(e) => {
if e.kind() != std::io::ErrorKind::WouldBlock {
debug!("listening socket read failed: {}", e);
}
return false;
}
};
let pkt_if_index = pktinfo.if_index as u32;
let Some(my_intf) = self.my_intfs.get(&pkt_if_index) else {
debug!(
"handle_read: no interface found for pktinfo if_index: {}",
pktinfo.if_index
);
return true; };
buf.truncate(sz);
match DnsIncoming::new(buf, my_intf.into()) {
Ok(msg) => {
if msg.is_query() {
self.handle_query(msg, pkt_if_index, event_key == IPV4_SOCK_EVENT_KEY);
} else if msg.is_response() {
self.handle_response(msg, pkt_if_index);
} else {
debug!("Invalid message: not query and not response");
}
}
Err(e) => debug!("Invalid incoming DNS message: {}", e),
}
true
}
/// Sends whatever query is still needed to resolve service `instance`.
///
/// - Without a cached SRV record, an ANY query for the instance is sent.
/// - Otherwise, A and AAAA queries are sent for the first SRV host that has no
///   cached address.
///
/// Returns `true` when a query was sent. Returns `false` for invalid instance names,
/// or when every SRV host already has cached addresses.
fn query_unresolved(&mut self, instance: &str) -> bool {
    if !valid_instance_name(instance) {
        trace!("instance name {} not valid", instance);
        return false;
    }

    let Some(records) = self.cache.get_srv(instance) else {
        self.send_query(instance, RRType::ANY);
        return true;
    };

    let host_without_addr = records
        .iter()
        .filter_map(|r| r.record.any().downcast_ref::<DnsSrv>())
        .map(|srv| srv.host())
        .find(|host| self.cache.get_addr(host).is_none());

    match host_without_addr {
        Some(host) => {
            self.send_query_vec(&[(host, RRType::A), (host, RRType::AAAA)]);
            true
        }
        None => false,
    }
}
/// Answers a new browse for `ty_domain` from the cache.
///
/// For every cached PTR of that type that does not expire soon:
/// - `ServiceFound` is sent to `sender`;
/// - if the instance fully resolves from the cache, `ServiceResolved` follows and
///   the instance is marked resolved;
/// - otherwise a pending resolve is scheduled.
///
/// Instances that fail to resolve with an error, or whose `ServiceFound` cannot be
/// delivered, are skipped.
fn query_cache_for_service(
    &mut self,
    ty_domain: &str,
    sender: &Sender<ServiceEvent>,
    now: u64,
) {
    let mut resolved: HashSet<String> = HashSet::new();
    let mut unresolved: HashSet<String> = HashSet::new();

    if let Some(records) = self.cache.get_ptr(ty_domain) {
        let fresh_ptrs = records
            .iter()
            .filter(|r| !r.record.expires_soon(now))
            .filter_map(|r| r.record.any().downcast_ref::<DnsPointer>());

        for ptr in fresh_ptrs {
            let instance = ptr.alias();
            let resolved_event = match self.resolve_service_from_cache(ty_domain, instance) {
                Err(err) => {
                    debug!("Error while resolving service from cache: {}", err);
                    continue;
                }
                Ok(service) if service.is_valid() => {
                    debug!("Resolved service from cache: {}", instance);
                    Some(ServiceEvent::ServiceResolved(Box::new(service)))
                }
                Ok(_) => {
                    debug!("Resolved service is not valid: {}", instance);
                    None
                }
            };

            let found = ServiceEvent::ServiceFound(ty_domain.to_string(), instance.to_string());
            if let Err(e) = sender.send(found) {
                debug!("failed to send service found: {}", e);
                continue;
            }
            debug!("sent service found {}", instance);

            match resolved_event {
                Some(event) => {
                    resolved.insert(instance.to_string());
                    match sender.send(event) {
                        Ok(()) => debug!("sent service resolved: {}", instance),
                        Err(e) => debug!("failed to send service resolved: {}", e),
                    }
                }
                None => {
                    unresolved.insert(instance.to_string());
                }
            }
        }
    }

    for instance in resolved {
        self.pending_resolves.remove(&instance);
        self.resolved.insert(instance);
    }
    for instance in unresolved {
        self.add_pending_resolve(instance);
    }
}
/// Sends `AddressesFound` to `sender` for every name/address group the cache has
/// for `hostname`; send failures are only logged.
fn query_cache_for_hostname(
    &mut self,
    hostname: &str,
    sender: Sender<HostnameResolutionEvent>,
) {
    let cached = self.cache.get_addresses_for_host(hostname);
    cached.into_iter().for_each(|(name, addresses)| {
        let event = HostnameResolutionEvent::AddressesFound(name, addresses);
        if let Err(e) = sender.send(event) {
            debug!("failed to send hostname addresses found: {}", e);
        } else {
            trace!("sent hostname addresses found");
        }
    });
}
/// Schedules the first resolve attempt for `instance` after `RESOLVE_WAIT_IN_MILLIS`,
/// unless one is already pending.
fn add_pending_resolve(&mut self, instance: String) {
    if self.pending_resolves.contains(&instance) {
        return;
    }
    let first_try_at = current_time_millis() + RESOLVE_WAIT_IN_MILLIS;
    self.add_retransmission(first_try_at, Command::Resolve(instance.clone(), 1));
    self.pending_resolves.insert(instance);
}
fn resolve_service_from_cache(
&self,
ty_domain: &str,
fullname: &str,
) -> Result<ResolvedService> {
let now = current_time_millis();
let mut resolved_service = ResolvedService {
ty_domain: ty_domain.to_string(),
sub_ty_domain: None,
fullname: fullname.to_string(),
host: String::new(),
port: 0,
addresses: HashSet::new(),
txt_properties: TxtProperties::new(),
};
if let Some(subtype) = self.cache.get_subtype(fullname) {
trace!(
"ty_domain: {} found subtype {} for instance: {}",
ty_domain,
subtype,
fullname
);
if resolved_service.sub_ty_domain.is_none() {
resolved_service.sub_ty_domain = Some(subtype.to_string());
}
}
if let Some(records) = self.cache.get_srv(fullname) {
if let Some(answer) = records.iter().find(|r| !r.record.expires_soon(now)) {
if let Some(dns_srv) = answer.record.any().downcast_ref::<DnsSrv>() {
resolved_service.host = dns_srv.host().to_string();
resolved_service.port = dns_srv.port();
}
}
}
if let Some(records) = self.cache.get_txt(fullname) {
if let Some(record) = records.iter().find(|r| !r.record.expires_soon(now)) {
if let Some(dns_txt) = record.record.any().downcast_ref::<DnsTxt>() {
resolved_service.txt_properties = dns_txt.text().into();
}
}
}
if let Some(records) = self.cache.get_addr(&resolved_service.host) {
for answer in records.iter() {
if let Some(dns_a) = answer.record.any().downcast_ref::<DnsAddress>() {
if dns_a.expires_soon(now) {
trace!(
"Addr expired or expires soon: {}",
dns_a.address().to_ip_addr()
);
} else {
resolved_service.addresses.insert(dns_a.address());
}
}
}
}
Ok(resolved_service)
}
/// Handles one batch of poller readiness events.
///
/// A signal-socket event drains the signal socket and re-registers it. For any other
/// token, the matching socket is read until `handle_read` returns `false`. The IPv4
/// and IPv6 sockets are then re-registered for readability.
fn handle_poller_events(&mut self, events: &mio::Events) {
    for ev in events.iter() {
        let token = ev.token();
        let key = token.0;
        trace!("event received with key {:?}", token);

        if key == SIGNAL_SOCK_EVENT_KEY {
            self.signal_sock_drain();
            let rearm = self.poller.registry().reregister(
                &mut self.signal_sock,
                token,
                mio::Interest::READABLE,
            );
            if let Err(e) = rearm {
                debug!("failed to modify poller for signal socket: {}", e);
            }
            continue;
        }

        // Drain everything currently readable on this socket.
        while self.handle_read(key) {}

        if key == IPV4_SOCK_EVENT_KEY {
            let rearm = self.poller.registry().reregister(
                &mut self.ipv4_sock,
                token,
                mio::Interest::READABLE,
            );
            if let Err(e) = rearm {
                debug!("modify poller for IPv4 socket: {}", e);
            }
        } else if key == IPV6_SOCK_EVENT_KEY {
            let rearm = self.poller.registry().reregister(
                &mut self.ipv6_sock,
                token,
                mio::Interest::READABLE,
            );
            if let Err(e) = rearm {
                debug!("modify poller for IPv6 socket: {}", e);
            }
        }
    }
}
/// Processes an mDNS response message received on interface `if_index`.
///
/// 1. Expired records are dropped from the answer, authority and additional
///    sections. When `cache.remove` reports that an expired PTR record was removed,
///    the listener browsing that service type gets `ServiceRemoved`.
/// 2. `conflict_handler` checks the answers against our own ongoing probes.
/// 3. The message counts as "for us" in any of these cases:
///    - it has no PTR answers;
///    - a PTR answer names a type in `service_queriers`;
///    - an A/AAAA answer names a hostname in `hostname_resolvers`;
///    - `accept_unsolicited` is set.
/// 4. Every remaining record goes to `cache.add_or_update`, and expire/refresh times
///    of returned records become daemon timers. Records reported as changed are
///    collected by name; changed PTR records with TTL > 1 also emit `ServiceFound`.
/// 5. Hostname resolvers are notified for changed A/AAAA names. Every service
///    instance touched by a change is then re-resolved from the cache.
///
/// Steps 3-5 are skipped when `if_index` is not one of our interfaces.
fn handle_response(&mut self, mut msg: DnsIncoming, if_index: u32) {
let now = current_time_millis();
// `retain` predicate: keep unexpired records; drop expired ones, purging them
// from the cache and reporting removed PTRs to the matching browser.
let mut record_predicate = |record: &DnsRecordBox| {
if !record.get_record().is_expired(now) {
return true;
}
debug!("record is expired, removing it from cache.");
if self.cache.remove(record) {
if let Some(dns_ptr) = record.any().downcast_ref::<DnsPointer>() {
call_service_listener(
&self.service_queriers,
dns_ptr.get_name(),
ServiceEvent::ServiceRemoved(
dns_ptr.get_name().to_string(),
dns_ptr.alias().to_string(),
),
);
}
}
false
};
msg.answers_mut().retain(&mut record_predicate);
msg.authorities_mut().retain(&mut record_predicate);
msg.additionals_mut().retain(&mut record_predicate);
self.conflict_handler(&msg, if_index);
// Assume the message is for us unless it carries PTR answers for types nobody is
// browsing. One matching PTR, or an A/AAAA answer for a hostname being resolved
// (compared lowercased), is enough to accept it.
let mut is_for_us = true;
for answer in msg.answers() {
if answer.get_type() == RRType::PTR {
if self.service_queriers.contains_key(answer.get_name()) {
is_for_us = true;
break; } else {
is_for_us = false;
}
} else if answer.get_type() == RRType::A || answer.get_type() == RRType::AAAA {
let answer_lowercase = answer.get_name().to_lowercase();
if self.hostname_resolvers.contains_key(&answer_lowercase) {
is_for_us = true;
break; }
}
}
if self.accept_unsolicited {
is_for_us = true;
}
// A changed record: its type and the name it concerns. For PTRs the name is the
// instance alias; for other types it is the record's own name.
struct InstanceChange {
ty: RRType, name: String, }
let mut changes = Vec::new();
let mut timers = Vec::new();
let Some(my_intf) = self.my_intfs.get(&if_index) else {
return;
};
// NOTE(review): `is_for_us` is forwarded to the cache; presumably it controls how
// records not requested by us are stored — see `DnsCache::add_or_update`.
for record in msg.all_records() {
match self
.cache
.add_or_update(my_intf, record, &mut timers, is_for_us)
{
// Record reported as changed (`true`).
Some((dns_record, true)) => {
timers.push(dns_record.record.get_record().get_expire_time());
timers.push(dns_record.record.get_record().get_refresh_time());
let ty = dns_record.record.get_type();
let name = dns_record.record.get_name();
// A TTL of 1 or less is not treated as a new service announcement here.
if ty == RRType::PTR && dns_record.record.get_record().get_ttl() > 1 {
if self.service_queriers.contains_key(name) {
timers.push(dns_record.record.get_record().get_refresh_time());
}
if let Some(dns_ptr) = dns_record.record.any().downcast_ref::<DnsPointer>()
{
debug!("calling listener with service found: {name}");
call_service_listener(
&self.service_queriers,
name,
ServiceEvent::ServiceFound(
name.to_string(),
dns_ptr.alias().to_string(),
),
);
changes.push(InstanceChange {
ty,
name: dns_ptr.alias().to_string(),
});
}
} else {
changes.push(InstanceChange {
ty,
name: name.to_string(),
});
}
}
// Record already known: only its timers are (re)scheduled.
Some((dns_record, false)) => {
timers.push(dns_record.record.get_record().get_expire_time());
timers.push(dns_record.record.get_record().get_refresh_time());
}
_ => {}
}
}
for t in timers {
self.add_timer(t);
}
// Hostname resolvers: report current addresses for every changed A/AAAA name.
for change in changes
.iter()
.filter(|change| change.ty == RRType::A || change.ty == RRType::AAAA)
{
let addr_map = self.cache.get_addresses_for_host(&change.name);
for (name, addresses) in addr_map {
call_hostname_resolution_listener(
&self.hostname_resolvers,
&change.name,
HostnameResolutionEvent::AddressesFound(name, addresses),
)
}
}
// Map each change to the service instances it affects: PTR/SRV/TXT changes name
// the instance directly; address changes affect every instance on that host.
let mut updated_instances = HashSet::new();
for update in changes {
match update.ty {
RRType::PTR | RRType::SRV | RRType::TXT => {
updated_instances.insert(update.name);
}
RRType::A | RRType::AAAA => {
let instances = self.cache.get_instances_on_host(&update.name);
updated_instances.extend(instances);
}
_ => {}
}
}
self.resolve_updated_instances(&updated_instances);
}
/// Detects conflicts between answers in `msg` and our ongoing probes on interface
/// `if_index`, and renames our conflicting records.
///
/// Only answers whose name is currently being probed in this interface's
/// `DnsRegistry` are considered. Two kinds of A/AAAA answer are not treated as
/// conflicts:
/// - an answer whose interface id is not `if_index`;
/// - an answer that matches one of our probe records by type, class and rdata.
///
/// A probe record with the same type and class as an answer but different rdata is
/// a conflict. It is then:
/// - removed from the probe and renamed (`hostname_change` for A/AAAA,
///   `name_change` otherwise);
/// - recorded in `name_changes` under its original name;
/// - inserted into the probe for the new name. That probe is created if needed,
///   starting 0-250 ms from now, and inherits the old probe's waiting services.
fn conflict_handler(&mut self, msg: &DnsIncoming, if_index: u32) {
let Some(my_intf) = self.my_intfs.get(&if_index) else {
debug!("handle_response: no intf found for index {if_index}");
return;
};
let Some(dns_registry) = self.dns_registry_map.get_mut(&if_index) else {
return;
};
for answer in msg.answers().iter() {
let mut new_records = Vec::new();
let name = answer.get_name();
let Some(probe) = dns_registry.probing.get_mut(name) else {
continue;
};
if answer.get_type() == RRType::A || answer.get_type() == RRType::AAAA {
if let Some(answer_addr) = answer.any().downcast_ref::<DnsAddress>() {
if answer_addr.interface_id.index != if_index {
debug!(
"conflict handler: answer addr {:?} not in the subnet of intf {}",
answer_addr, my_intf.name
);
continue;
}
}
// An identical address record from a peer is not a conflict.
let any_match = probe.records.iter().any(|r| {
r.get_type() == answer.get_type()
&& r.get_class() == answer.get_class()
&& r.rrdata_match(answer.as_ref())
});
if any_match {
continue; }
}
// Pull every conflicting record out of the probe and queue a renamed copy.
probe.records.retain(|record| {
if record.get_type() == answer.get_type()
&& record.get_class() == answer.get_class()
&& !record.rrdata_match(answer.as_ref())
{
debug!(
"found conflict name: '{name}' record: {}: {} PEER: {}",
record.get_type(),
record.rdata_print(),
answer.rdata_print()
);
let mut new_record = record.clone();
let new_name = match record.get_type() {
RRType::A => hostname_change(name),
RRType::AAAA => hostname_change(name),
_ => name_change(name),
};
new_record.get_record_mut().set_new_name(new_name);
new_records.push(new_record);
return false; }
true
});
// Randomized start (0-250 ms) for any probe created below.
let create_time = current_time_millis() + fastrand::u64(0..250);
// Cloned now: `probe` borrows `dns_registry`, which is mutated below.
let waiting_services = probe.waiting_services.clone();
for record in new_records {
// NOTE(review): presumably `update_hostname` re-points other records of the
// old name at the new one; a timer is armed when it reports `true` — confirm.
if dns_registry.update_hostname(name, record.get_name(), create_time) {
self.timers.push(Reverse(create_time));
}
dns_registry.name_changes.insert(
record.get_record().get_original_name().to_string(),
record.get_name().to_string(),
);
let new_probe = match dns_registry.probing.get_mut(record.get_name()) {
Some(p) => p,
None => {
let new_probe = dns_registry
.probing
.entry(record.get_name().to_string())
.or_insert_with(|| {
debug!("conflict handler: new probe of {}", record.get_name());
Probe::new(create_time)
});
self.timers.push(Reverse(new_probe.next_send));
new_probe
}
};
debug!(
"insert record with new name '{}' {} into probe",
record.get_name(),
record.get_type()
);
new_probe.insert_record(record);
new_probe.waiting_services.extend(waiting_services.clone());
}
}
}
/// Re-resolves, from the cache, every browsed instance listed in `updated_instances`.
///
/// Only PTR records that do not expire soon, under types someone is browsing, are
/// considered.
/// - An instance that resolves to a valid service is reported with
///   `ServiceResolved` and marked resolved.
/// - An instance that resolves but is not (or is no longer) valid is scheduled for a
///   pending resolve. If it had been resolved before, listeners get
///   `ServiceRemoved`.
fn resolve_updated_instances(&mut self, updated_instances: &HashSet<String>) {
    let mut newly_resolved: HashSet<String> = HashSet::new();
    let mut unresolved: HashSet<String> = HashSet::new();
    let mut removed_instances: HashMap<String, HashSet<String>> = HashMap::new();
    let now = current_time_millis();

    for (ty_domain, records) in self.cache.all_ptr().iter() {
        if !self.service_queriers.contains_key(ty_domain) {
            continue;
        }
        for record in records.iter().filter(|r| !r.record.expires_soon(now)) {
            let Some(dns_ptr) = record.record.any().downcast_ref::<DnsPointer>() else {
                continue;
            };
            let instance = dns_ptr.alias();
            if !updated_instances.contains(instance) {
                continue;
            }
            let Ok(service) = self.resolve_service_from_cache(ty_domain, instance) else {
                continue;
            };
            debug!("resolve_updated_instances: from cache: {}", instance);

            if service.is_valid() {
                debug!("call queriers to resolve {}", instance);
                newly_resolved.insert(instance.to_string());
                let event = ServiceEvent::ServiceResolved(Box::new(service));
                call_service_listener(&self.service_queriers, ty_domain, event);
            } else {
                debug!("Resolved service is not valid: {}", instance);
                // Previously resolved but now incomplete: report it as removed.
                if self.resolved.remove(instance) {
                    removed_instances
                        .entry(ty_domain.to_string())
                        .or_default()
                        .insert(instance.to_string());
                }
                unresolved.insert(instance.to_string());
            }
        }
    }

    for instance in newly_resolved {
        self.pending_resolves.remove(&instance);
        self.resolved.insert(instance);
    }
    for instance in unresolved {
        self.add_pending_resolve(instance);
    }

    if !removed_instances.is_empty() {
        debug!(
            "resolve_updated_instances: removed {}",
            &removed_instances.len()
        );
        self.notify_service_removal(removed_instances);
    }
}
fn handle_query(&mut self, msg: DnsIncoming, if_index: u32, is_ipv4: bool) {
let sock = if is_ipv4 {
&self.ipv4_sock
} else {
&self.ipv6_sock
};
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
const META_QUERY: &str = "_services._dns-sd._udp.local.";
let Some(dns_registry) = self.dns_registry_map.get_mut(&if_index) else {
debug!("missing dns registry for intf {}", if_index);
return;
};
let Some(intf) = self.my_intfs.get(&if_index) else {
return;
};
for question in msg.questions().iter() {
let qtype = question.entry_type();
if qtype == RRType::PTR {
for service in self.my_services.values() {
if service.get_status(if_index) != ServiceStatus::Announced {
continue;
}
if question.entry_name() == service.get_type()
|| service
.get_subtype()
.as_ref()
.is_some_and(|v| v == question.entry_name())
{
add_answer_with_additionals(
&mut out,
&msg,
service,
intf,
dns_registry,
is_ipv4,
);
} else if question.entry_name() == META_QUERY {
let ptr_added = out.add_answer(
&msg,
DnsPointer::new(
question.entry_name(),
RRType::PTR,
CLASS_IN,
service.get_other_ttl(),
service.get_type().to_string(),
),
);
if !ptr_added {
trace!("answer was not added for meta-query {:?}", &question);
}
}
}
} else {
if qtype == RRType::ANY && msg.num_authorities() > 0 {
let probe_name = question.entry_name();
if let Some(probe) = dns_registry.probing.get_mut(probe_name) {
let now = current_time_millis();
if probe.start_time < now {
let incoming_records: Vec<_> = msg
.authorities()
.iter()
.filter(|r| r.get_name() == probe_name)
.collect();
probe.tiebreaking(&incoming_records, now, probe_name);
}
}
}
if qtype == RRType::A || qtype == RRType::AAAA || qtype == RRType::ANY {
for service in self.my_services.values() {
if service.get_status(if_index) != ServiceStatus::Announced {
continue;
}
let service_hostname =
match dns_registry.name_changes.get(service.get_hostname()) {
Some(new_name) => new_name,
None => service.get_hostname(),
};
if service_hostname.to_lowercase() == question.entry_name().to_lowercase() {
let intf_addrs = if is_ipv4 {
service.get_addrs_on_my_intf_v4(intf)
} else {
service.get_addrs_on_my_intf_v6(intf)
};
if intf_addrs.is_empty()
&& (qtype == RRType::A || qtype == RRType::AAAA)
{
let t = match qtype {
RRType::A => "TYPE_A",
RRType::AAAA => "TYPE_AAAA",
_ => "invalid_type",
};
trace!(
"Cannot find valid addrs for {} response on intf {:?}",
t,
&intf
);
return;
}
for address in intf_addrs {
out.add_answer(
&msg,
DnsAddress::new(
service_hostname,
ip_address_rr_type(&address),
CLASS_IN | CLASS_CACHE_FLUSH,
service.get_host_ttl(),
address,
intf.into(),
),
);
}
}
}
}
let query_name = question.entry_name().to_lowercase();
let service_opt = self
.my_services
.iter()
.find(|(k, _v)| {
let service_name = match dns_registry.name_changes.get(k.as_str()) {
Some(new_name) => new_name,
None => k,
};
service_name == &query_name
})
.map(|(_, v)| v);
let Some(service) = service_opt else {
continue;
};
if service.get_status(if_index) != ServiceStatus::Announced {
continue;
}
let intf_addrs = if is_ipv4 {
service.get_addrs_on_my_intf_v4(intf)
} else {
service.get_addrs_on_my_intf_v6(intf)
};
if intf_addrs.is_empty() {
debug!(
"Cannot find valid addrs for TYPE_SRV response on intf {:?}",
&intf
);
continue;
}
add_answer_of_service(
&mut out,
&msg,
question.entry_name(),
service,
qtype,
intf_addrs,
);
}
}
if !out.answers_count() > 0 {
out.set_id(msg.id());
send_dns_outgoing(&out, intf, &sock.pktinfo);
let if_name = intf.name.clone();
self.increase_counter(Counter::Respond, 1);
self.notify_monitors(DaemonEvent::Respond(if_name));
}
self.increase_counter(Counter::KnownAnswerSuppression, out.known_answer_count());
}
/// Adds `count` to the metric named after `counter`, starting from 0 if absent.
fn increase_counter(&mut self, counter: Counter, count: i64) {
    *self.counters.entry(counter.to_string()).or_insert(0) += count;
}
/// Overwrites the metric named after `counter` with `count`.
fn set_counter(&mut self, counter: Counter, count: i64) {
    self.counters.insert(counter.to_string(), count);
}
/// Reads and discards everything queued on the signal socket, until `recv` fails
/// (e.g. nothing is left to read).
fn signal_sock_drain(&self) {
    let mut signal_buf = [0u8; 1024];
    loop {
        match self.signal_sock.recv(&mut signal_buf) {
            Ok(sz) => {
                trace!(
                    "signal socket recvd: {}",
                    String::from_utf8_lossy(&signal_buf[..sz])
                );
            }
            Err(_) => break,
        }
    }
}
/// Queues `command` to be re-run at `next_time` (ms) and arms a timer for it.
fn add_retransmission(&mut self, next_time: u64, command: Command) {
    let rerun = ReRun { next_time, command };
    self.retransmissions.push(rerun);
    self.add_timer(next_time);
}
/// Sends `ServiceRemoved` to each browser for every instance of its type listed in
/// `expired` (keyed by service type); send failures are only logged.
fn notify_service_removal(&self, expired: HashMap<String, HashSet<String>>) {
    for (ty_domain, sender) in &self.service_queriers {
        let Some(instances) = expired.get(ty_domain) else {
            continue;
        };
        for instance_name in instances {
            let event =
                ServiceEvent::ServiceRemoved(ty_domain.to_string(), instance_name.to_string());
            if let Err(e) = sender.send(event) {
                debug!("Failed to send event: {}", e);
            } else {
                debug!("notify_service_removal: sent ServiceRemoved to listener of {ty_domain}: {instance_name}");
            }
        }
    }
}
/// Dispatches one `Command` to its handler.
///
/// `repeating` is `true` when the command is being re-run rather than issued fresh
/// (for example, browse commands re-executed after an interface change). It is
/// forwarded to the handlers that behave differently on repeats: browse, hostname
/// resolution, unregister and verify. Commands without an arm here are logged and
/// ignored.
fn exec_command(&mut self, command: Command, repeating: bool) {
trace!("exec_command: {:?} repeating: {}", &command, repeating);
match command {
Command::Browse(ty, next_delay, cache_only, listener) => {
self.exec_command_browse(repeating, ty, next_delay, cache_only, listener);
}
Command::ResolveHostname(hostname, next_delay, listener, timeout) => {
self.exec_command_resolve_hostname(
repeating, hostname, next_delay, listener, timeout,
);
}
Command::Register(service_info) => {
self.register_service(service_info);
self.increase_counter(Counter::Register, 1);
}
Command::RegisterResend(fullname, intf) => {
trace!("register-resend service: {fullname} on {}", &intf);
self.exec_command_register_resend(fullname, intf);
}
Command::Unregister(fullname, resp_s) => {
trace!("unregister service {} repeat {}", &fullname, &repeating);
self.exec_command_unregister(repeating, fullname, resp_s);
}
Command::UnregisterResend(packet, if_index, is_ipv4) => {
self.exec_command_unregister_resend(packet, if_index, is_ipv4);
}
Command::StopBrowse(ty_domain) => self.exec_command_stop_browse(ty_domain),
// Hostname resolvers are keyed by lowercased hostname.
Command::StopResolveHostname(hostname) => {
self.exec_command_stop_resolve_hostname(hostname.to_lowercase())
}
Command::Resolve(instance, try_count) => self.exec_command_resolve(instance, try_count),
Command::GetMetrics(resp_s) => self.exec_command_get_metrics(resp_s),
Command::GetStatus(resp_s) => match resp_s.send(self.status.clone()) {
Ok(()) => trace!("Sent status to the client"),
Err(e) => debug!("Failed to send status: {}", e),
},
Command::Monitor(resp_s) => {
self.monitors.push(resp_s);
}
Command::SetOption(daemon_opt) => {
self.process_set_option(daemon_opt);
}
Command::GetOption(resp_s) => {
let val = DaemonOptionVal {
_service_name_len_max: self.service_name_len_max,
ip_check_interval: self.ip_check_interval,
};
if let Err(e) = resp_s.send(val) {
debug!("Failed to send options: {}", e);
}
}
Command::Verify(instance_fullname, timeout) => {
self.exec_command_verify(instance_fullname, timeout, repeating);
}
_ => {
debug!("unexpected command: {:?}", &command);
}
}
}
/// Refreshes the gauge-style metrics and sends a snapshot of all counters to `resp_s`.
///
/// The gauges are cache sizes per record kind, the number of pending daemon timers,
/// and per-registry totals (probes, active records, new timers, name changes) summed
/// over all interfaces.
fn exec_command_get_metrics(&mut self, resp_s: Sender<HashMap<String, i64>>) {
    let cache_gauges = [
        (Counter::CachedPTR, self.cache.ptr_count() as i64),
        (Counter::CachedSRV, self.cache.srv_count() as i64),
        (Counter::CachedAddr, self.cache.addr_count() as i64),
        (Counter::CachedTxt, self.cache.txt_count() as i64),
        (Counter::CachedNSec, self.cache.nsec_count() as i64),
        (Counter::CachedSubtype, self.cache.subtype_count() as i64),
        (Counter::Timer, self.timers.len() as i64),
    ];
    for (counter, value) in cache_gauges {
        self.set_counter(counter, value);
    }

    // One pass over all registries, accumulating the four totals together.
    let (probing, active, new_timers, name_changes) = self.dns_registry_map.values().fold(
        (0usize, 0usize, 0usize, 0usize),
        |(p, a, t, n), reg| {
            (
                p + reg.probing.len(),
                a + reg.active.values().map(|recs| recs.len()).sum::<usize>(),
                t + reg.new_timers.len(),
                n + reg.name_changes.len(),
            )
        },
    );
    self.set_counter(Counter::DnsRegistryProbe, probing as i64);
    self.set_counter(Counter::DnsRegistryActive, active as i64);
    self.set_counter(Counter::DnsRegistryTimer, new_timers as i64);
    self.set_counter(Counter::DnsRegistryNameChange, name_changes as i64);

    if let Err(e) = resp_s.send(self.counters.clone()) {
        debug!("Failed to send metrics: {}", e);
    }
}
/// Starts, or continues, browsing for services of type `ty`.
///
/// 1. `SearchStarted` (listing our interfaces) is sent to `listener`. If that fails,
///    because the receiver is gone, nothing else happens.
/// 2. On the first run (`repeating == false`), the listener is registered in
///    `service_queriers` and answered from the cache.
/// 3. With `cache_only`, the search ends with `SearchStopped`.
/// 4. Otherwise, a PTR query is sent and the browse is re-scheduled `next_delay`
///    seconds later, with the delay doubled for the next round and capped at one
///    hour.
fn exec_command_browse(
    &mut self,
    repeating: bool,
    ty: String,
    next_delay: u32,
    cache_only: bool,
    listener: Sender<ServiceEvent>,
) {
    let pretty_addrs: Vec<String> = self
        .my_intfs
        .iter()
        .map(|(if_index, itf)| format!("{} ({if_index})", itf.name))
        .collect();

    if let Err(e) = listener.send(ServiceEvent::SearchStarted(format!(
        "{ty} on {} interfaces [{}]",
        pretty_addrs.len(),
        pretty_addrs.join(", ")
    ))) {
        debug!(
            "Failed to send SearchStarted({})(repeating:{}): {}",
            &ty, repeating, e
        );
        return;
    }

    let now = current_time_millis();
    if !repeating {
        self.service_queriers.insert(ty.clone(), listener.clone());
        self.query_cache_for_service(&ty, &listener, now);
    }

    if cache_only {
        match listener.send(ServiceEvent::SearchStopped(ty.clone())) {
            Ok(()) => debug!("SearchStopped sent for {}", &ty),
            Err(e) => debug!("Failed to send SearchStopped: {}", e),
        }
        return;
    }

    self.send_query(&ty, RRType::PTR);
    self.increase_counter(Counter::Browse, 1);

    // Widen to u64 before converting seconds to milliseconds, so a large
    // `next_delay` cannot overflow `u32` (same as `exec_command_resolve_hostname`).
    let next_time = now + u64::from(next_delay) * 1000;
    let max_delay = 60 * 60;
    // Saturate rather than overflow when doubling a very large delay.
    let delay = cmp::min(next_delay.saturating_mul(2), max_delay);
    self.add_retransmission(next_time, Command::Browse(ty, delay, cache_only, listener));
}
/// Starts, or continues, resolving `hostname` to addresses.
///
/// 1. `SearchStarted` is sent to `listener`. If that fails, nothing else happens.
/// 2. On the first run, the resolver is registered (with its optional timeout in
///    ms) and answered from the cache.
/// 3. A and AAAA queries are sent.
/// 4. A retry is scheduled `next_delay` seconds later, with the delay doubled for
///    the next round and capped at one hour, unless that retry would fall at or
///    after the resolver's deadline.
fn exec_command_resolve_hostname(
    &mut self,
    repeating: bool,
    hostname: String,
    next_delay: u32,
    listener: Sender<HostnameResolutionEvent>,
    timeout: Option<u64>,
) {
    let addr_list: Vec<_> = self.my_intfs.iter().collect();
    if let Err(e) = listener.send(HostnameResolutionEvent::SearchStarted(format!(
        "{} on addrs {:?}",
        &hostname, &addr_list
    ))) {
        debug!(
            "Failed to send ResolveStarted({})(repeating:{}): {}",
            &hostname, repeating, e
        );
        return;
    }

    if !repeating {
        self.add_hostname_resolver(hostname.to_owned(), listener.clone(), timeout);
        // Also send what we already have in the cache.
        self.query_cache_for_hostname(&hostname, listener.clone());
    }

    self.send_query_vec(&[(&hostname, RRType::A), (&hostname, RRType::AAAA)]);
    self.increase_counter(Counter::ResolveHostname, 1);

    let now = current_time_millis();
    let next_time = now + u64::from(next_delay) * 1000;
    let max_delay = 60 * 60;
    // Saturate rather than overflow when doubling a very large delay.
    let delay = cmp::min(next_delay.saturating_mul(2), max_delay);

    // `add_hostname_resolver` keys resolvers by the lowercased hostname. Looking up
    // the raw name missed mixed-case hostnames, so their deadline was ignored and
    // retries never stopped.
    let resolver_key = hostname.to_lowercase();
    if self
        .hostname_resolvers
        .get(&resolver_key)
        .and_then(|(_sender, timeout)| *timeout)
        .map(|timeout| next_time < timeout)
        .unwrap_or(true)
    {
        self.add_retransmission(
            next_time,
            Command::ResolveHostname(hostname, delay, listener, None),
        );
    }
}
/// Performs resolve attempt number `try_count` for service `instance`.
///
/// If a query still had to be sent, another attempt is scheduled after
/// `RESOLVE_WAIT_IN_MILLIS`, for at most 3 attempts in total.
fn exec_command_resolve(&mut self, instance: String, try_count: u16) {
    const MAX_TRY: u16 = 3;
    let still_pending = self.query_unresolved(&instance);
    if !still_pending || try_count >= MAX_TRY {
        return;
    }
    let retry_at = current_time_millis() + RESOLVE_WAIT_IN_MILLIS;
    self.add_retransmission(retry_at, Command::Resolve(instance, try_count + 1));
}
/// Unregisters our service `fullname` and reports the outcome on `resp_s`.
///
/// When found, the service is removed from `my_services` and a goodbye packet is
/// sent on every interface, over IPv4 then IPv6. On a fresh (non-repeating) request,
/// each non-empty packet is queued for one resend 120 ms later. The reply is
/// `UnregisterStatus::OK`, or `NotFound` if no such service exists.
fn exec_command_unregister(
    &mut self,
    repeating: bool,
    fullname: String,
    resp_s: Sender<UnregisterStatus>,
) {
    let response = if let Some((_k, info)) = self.my_services.remove_entry(&fullname) {
        let mut timers = Vec::new();
        for (if_index, intf) in self.my_intfs.iter() {
            let sockets = [
                (&self.ipv4_sock.pktinfo, true),
                (&self.ipv6_sock.pktinfo, false),
            ];
            for (sock, is_ipv4) in sockets {
                let packet = self.unregister_service(&info, intf, sock);
                if repeating || packet.is_empty() {
                    continue;
                }
                let next_time = current_time_millis() + 120;
                self.retransmissions.push(ReRun {
                    next_time,
                    command: Command::UnregisterResend(packet, *if_index, is_ipv4),
                });
                timers.push(next_time);
            }
        }

        for t in timers {
            self.add_timer(t);
        }

        self.increase_counter(Counter::Unregister, 1);
        UnregisterStatus::OK
    } else {
        debug!("unregister: cannot find such service {}", &fullname);
        UnregisterStatus::NotFound
    };

    if let Err(e) = resp_s.send(response) {
        debug!("unregister: failed to send response: {}", e);
    }
}
/// Re-multicasts a previously built goodbye `packet` on interface `if_index`, over
/// IPv4 or IPv6 as selected by `is_ipv4`.
///
/// Nothing is sent if the interface is gone or has no address of that family.
fn exec_command_unregister_resend(&mut self, packet: Vec<u8>, if_index: u32, is_ipv4: bool) {
    let Some(intf) = self.my_intfs.get(&if_index) else {
        return;
    };

    let (sock, maybe_addr) = if is_ipv4 {
        (&self.ipv4_sock.pktinfo, intf.next_ifaddr_v4())
    } else {
        (&self.ipv6_sock.pktinfo, intf.next_ifaddr_v6())
    };
    let Some(if_addr) = maybe_addr else {
        return;
    };

    debug!("UnregisterResend from {:?}", if_addr);
    multicast_on_intf(&packet[..], &intf.name, intf.index, if_addr, sock);

    self.increase_counter(Counter::UnregisterResend, 1);
}
fn exec_command_stop_browse(&mut self, ty_domain: String) {
match self.service_queriers.remove_entry(&ty_domain) {
None => debug!("StopBrowse: cannot find querier for {}", &ty_domain),
Some((ty, sender)) => {
trace!("StopBrowse: removed queryer for {}", &ty);
let mut i = 0;
while i < self.retransmissions.len() {
if let Command::Browse(t, _, _, _) = &self.retransmissions[i].command {
if t == &ty {
self.retransmissions.remove(i);
trace!("StopBrowse: removed retransmission for {}", &ty);
continue;
}
}
i += 1;
}
self.cache.remove_service_type(&ty_domain);
match sender.send(ServiceEvent::SearchStopped(ty_domain)) {
Ok(()) => trace!("Sent SearchStopped to the listener"),
Err(e) => debug!("Failed to send SearchStopped: {}", e),
}
}
}
}
/// Stops the hostname resolution registered for `hostname`.
///
/// The resolver is looked up by the lower-cased hostname, which is the same key
/// `call_hostname_resolution_listener` uses. When a resolver is found:
/// - its pending retransmissions are dropped;
/// - the listener receives `HostnameResolutionEvent::SearchStopped` carrying
///   `hostname` exactly as given.
///
/// Unknown hostnames are ignored.
fn exec_command_stop_resolve_hostname(&mut self, hostname: String) {
    let Some((host, (sender, _timeout))) =
        self.hostname_resolvers.remove_entry(&hostname.to_lowercase())
    else {
        return;
    };
    trace!("StopResolve: removed queryer for {}", &host);

    // `Command::ResolveHostname` is the command variant that carries a hostname.
    // The previous loop matched only `Command::Resolve`; that match is kept for
    // compatibility.
    // NOTE(review): assumes periodic hostname queries are re-queued as
    // `Command::ResolveHostname` retransmissions — confirm in `exec_command`.
    let host_key = host.to_lowercase();
    self.retransmissions.retain(|rerun| {
        let queued_name = match &rerun.command {
            Command::ResolveHostname(name, ..) | Command::Resolve(name, _) => name,
            _ => return true,
        };
        if queued_name.to_lowercase() != host_key {
            return true;
        }
        trace!("StopResolve: removed retransmission for {}", &host);
        false
    });

    match sender.send(HostnameResolutionEvent::SearchStopped(hostname)) {
        Ok(()) => trace!("Sent SearchStopped to the listener"),
        Err(e) => debug!("Failed to send SearchStopped: {}", e),
    }
}
/// Re-announces the registered service `fullname` on interface `if_index`.
///
/// The announcement is attempted on both the IPv4 and the IPv6 socket. If
/// either one was sent:
/// - monitors receive a `DaemonEvent::Announce` with the (possibly renamed)
///   service name and `hostname:interface`;
/// - the service's status on that interface becomes `Announced`.
///
/// The `RegisterResend` counter is bumped whenever the service, the DNS
/// registry and the interface all exist.
fn exec_command_register_resend(&mut self, fullname: String, if_index: u32) {
    let info = match self.my_services.get_mut(&fullname) {
        Some(found) => found,
        None => {
            trace!("announce: cannot find such service {}", &fullname);
            return;
        }
    };
    let (dns_registry, intf) = match (
        self.dns_registry_map.get_mut(&if_index),
        self.my_intfs.get(&if_index),
    ) {
        (Some(registry), Some(intf)) => (registry, intf),
        _ => return,
    };

    let sent_on_v4 = announce_service_on_intf(dns_registry, info, intf, &self.ipv4_sock.pktinfo);
    let sent_on_v6 = announce_service_on_intf(dns_registry, info, intf, &self.ipv6_sock.pktinfo);

    if !sent_on_v4 && !sent_on_v6 {
        debug!("register-resend should not fail");
    } else {
        // Prefer names assigned after probing conflicts, if any.
        let hostname: &str = dns_registry
            .name_changes
            .get(info.get_hostname())
            .map_or(info.get_hostname(), String::as_str);
        let service_name = dns_registry
            .name_changes
            .get(&fullname)
            .cloned()
            .unwrap_or(fullname);
        debug!("resend: announce service {service_name} on {}", intf.name);
        let announce =
            DaemonEvent::Announce(service_name, format!("{}:{}", hostname, &intf.name));
        notify_monitors(&mut self.monitors, announce);
        info.set_status(if_index, ServiceStatus::Announced);
    }
    self.increase_counter(Counter::RegisterResend, 1);
}
/// Sends verification queries for the cached records of `instance`.
///
/// When `repeating` is false, the records are given an expiry of
/// `now + timeout` (saturating). In that case, and only if there are records to
/// query, a timer is armed for the expiry and a follow-up `Command::Verify` is
/// queued as a retransmission one second later. When `repeating` is true, the
/// cache is asked with `None` as the expiry and nothing is re-queued.
fn exec_command_verify(&mut self, instance: String, timeout: Duration, repeating: bool) {
    let now = current_time_millis();
    let expire_at = if repeating {
        None
    } else {
        // `as_millis()` is u128: clamp rather than silently truncate. Saturate
        // the addition so a huge caller-supplied timeout cannot overflow.
        let timeout_ms = u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX);
        Some(now.saturating_add(timeout_ms))
    };

    let record_vec = self.cache.service_verify_queries(&instance, expire_at);
    if record_vec.is_empty() {
        return;
    }
    let query_vec: Vec<(&str, RRType)> = record_vec
        .iter()
        .map(|(record, rr_type)| (record.as_str(), *rr_type))
        .collect();
    self.send_query_vec(&query_vec);

    if let Some(new_expire) = expire_at {
        self.add_timer(new_expire);
        self.add_retransmission(now + 1000, Command::Verify(instance, timeout));
    }
}
/// Sends refresh queries for every browsed service type whose cached records
/// are due.
///
/// For each key of `service_queriers`, the cache is asked which PTR, SRV/TXT
/// and host address records need refreshing, and the matching queries are
/// sent. The timers the cache returns are collected into a set, so each
/// distinct value is armed once. The cache-refresh counters are bumped at the
/// end.
fn refresh_active_services(&mut self) {
    let mut ptr_queries = 0;
    let mut srv_txt_queries = 0;
    let mut addr_queries = 0;
    let mut pending_timers = HashSet::new();

    for ty_domain in self.service_queriers.keys() {
        // PTR of the service type itself.
        let due_ptr_timers = self.cache.refresh_due_ptr(ty_domain);
        if !due_ptr_timers.is_empty() {
            trace!("sending refresh query for PTR: {}", ty_domain);
            self.send_query(ty_domain, RRType::PTR);
            ptr_queries += 1;
            pending_timers.extend(due_ptr_timers);
        }

        // SRV/TXT of the instances of this type: one query message per instance.
        let (due_instances, srv_txt_timers) = self.cache.refresh_due_srv_txt(ty_domain);
        for (instance, rr_types) in due_instances {
            trace!("sending refresh query for: {}", &instance);
            let queries: Vec<_> = rr_types
                .into_iter()
                .map(|rr| (instance.as_str(), rr))
                .collect();
            self.send_query_vec(&queries);
            srv_txt_queries += 1;
        }
        pending_timers.extend(srv_txt_timers);

        // A and AAAA of the hosts behind those instances.
        let (due_hosts, host_timers) = self.cache.refresh_due_hosts(ty_domain);
        for hostname in due_hosts.iter() {
            trace!("sending refresh queries for A and AAAA: {}", hostname);
            self.send_query_vec(&[(hostname, RRType::A), (hostname, RRType::AAAA)]);
            addr_queries += 2;
        }
        pending_timers.extend(host_timers);
    }

    pending_timers
        .into_iter()
        .for_each(|timer| self.add_timer(timer));

    self.increase_counter(Counter::CacheRefreshPTR, ptr_queries);
    self.increase_counter(Counter::CacheRefreshSrvTxt, srv_txt_queries);
    self.increase_counter(Counter::CacheRefreshAddr, addr_queries);
}
}
/// Adds SRV and/or TXT answers for the service instance `entry_name` to `out`.
///
/// `qtype` selects which records are added:
/// - `SRV`: the SRV record;
/// - `TXT`: the TXT record;
/// - `ANY`: both.
///
/// Only for an explicit SRV query are the addresses in `intf_addrs` also
/// attached as additional A/AAAA records. `ANY` does not add them here.
fn add_answer_of_service(
    out: &mut DnsOutgoing,
    msg: &DnsIncoming,
    entry_name: &str,
    service: &ServiceInfo,
    qtype: RRType,
    intf_addrs: Vec<IpAddr>,
) {
    let wants_any = qtype == RRType::ANY;
    let wants_srv = qtype == RRType::SRV;

    if wants_srv || wants_any {
        let srv = DnsSrv::new(
            entry_name,
            CLASS_IN | CLASS_CACHE_FLUSH,
            service.get_host_ttl(),
            service.get_priority(),
            service.get_weight(),
            service.get_port(),
            service.get_hostname().to_string(),
        );
        out.add_answer(msg, srv);
    }

    if qtype == RRType::TXT || wants_any {
        let txt = DnsTxt::new(
            entry_name,
            CLASS_IN | CLASS_CACHE_FLUSH,
            service.get_other_ttl(),
            service.generate_txt(),
        );
        out.add_answer(msg, txt);
    }

    if !wants_srv {
        return;
    }
    for address in intf_addrs {
        let rr_type = ip_address_rr_type(&address);
        out.add_additional_answer(DnsAddress::new(
            service.get_hostname(),
            rr_type,
            CLASS_IN | CLASS_CACHE_FLUSH,
            service.get_host_ttl(),
            address,
            InterfaceId::default(),
        ));
    }
}
/// Events delivered on a service browser's channel.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ServiceEvent {
/// Browsing started. NOTE(review): payload is presumably the service type — confirm.
SearchStarted(String),
/// An instance was found. NOTE(review): presumably `(service type, instance fullname)`,
/// like `ServiceRemoved` — confirm.
ServiceFound(String, String),
/// An instance was resolved; the boxed `ResolvedService` exposes at least `fullname`.
ServiceResolved(Box<ResolvedService>),
/// An instance went away. The tests read the fields as `(service type, fullname)`.
ServiceRemoved(String, String),
/// Browsing stopped; carries the service type that was passed to stop-browse.
SearchStopped(String),
}
/// Events delivered on a hostname resolver's channel.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum HostnameResolutionEvent {
/// Resolution started. NOTE(review): payload is presumably the hostname — confirm.
SearchStarted(String),
/// `(hostname, addresses)`: addresses found for the hostname.
AddressesFound(String, HashSet<ScopedIp>),
/// `(hostname, addresses)`: previously found addresses that were removed.
AddressesRemoved(String, HashSet<ScopedIp>),
/// The resolution timed out. NOTE(review): payload is presumably the hostname — confirm.
SearchTimeout(String),
/// Resolution stopped; carries the hostname as passed to stop-resolve.
SearchStopped(String),
}
/// Events delivered to daemon monitors (see `notify_monitors`).
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum DaemonEvent {
/// `(service name, "hostname:interface name")` after a service was announced.
Announce(String, String),
/// An error reported by the daemon. NOTE(review): emitting sites are not visible here.
Error(Error),
/// NOTE(review): presumably an IP address that appeared on an interface — confirm.
IpAdd(IpAddr),
/// NOTE(review): presumably an IP address that disappeared from an interface — confirm.
IpDel(IpAddr),
/// A record name was changed when a probe finished with a new name.
NameChange(DnsNameChange),
/// NOTE(review): payload meaning not visible here — confirm.
Respond(String),
}
/// Details of a record renamed after probing, sent with `DaemonEvent::NameChange`.
#[derive(Clone, Debug)]
pub struct DnsNameChange {
/// The record's original name.
pub original: String,
/// The name the record now uses.
pub new_name: String,
/// Type of the record that was renamed.
pub rr_type: RRType,
/// Name of the interface on which the probe ran.
pub intf_name: String,
}
/// Commands handled by the daemon. Some variants are also queued as delayed
/// retransmissions (`ReRun`) and executed again later.
#[derive(Debug)]
enum Command {
/// Browse a service type; the first field is the type name.
/// NOTE(review): meaning of the `u32` and `bool` is not visible here — confirm.
Browse(String, u32, bool, Sender<ServiceEvent>),
/// Resolve a hostname to addresses.
/// NOTE(review): `u32` and `Option<u64>` (presumably a timeout) are not visible here — confirm.
ResolveHostname(String, u32, Sender<HostnameResolutionEvent>, Option<u64>),
/// Register a service.
Register(ServiceInfo),
/// Unregister a service by fullname; the outcome is reported on the channel.
Unregister(String, Sender<UnregisterStatus>),
/// Presumably `(service fullname, interface index)` for a re-announcement.
RegisterResend(String, u32),
/// `(prebuilt packet, interface index, is_ipv4)` for re-sending an unregister packet.
UnregisterResend(Vec<u8>, u32, bool),
/// Stop browsing the given service type.
StopBrowse(String),
/// Stop resolving the given hostname.
StopResolveHostname(String),
/// NOTE(review): presumably `(name, try count)` for re-resolution — confirm.
Resolve(String, u16),
/// Request metrics on the channel.
GetMetrics(Sender<Metrics>),
/// Request daemon status on the channel.
GetStatus(Sender<DaemonStatus>),
/// Register a monitor channel for `DaemonEvent`s.
Monitor(Sender<DaemonEvent>),
/// Change a daemon option.
SetOption(DaemonOption),
/// Request current option values on the channel.
GetOption(Sender<DaemonOptionVal>),
/// `(instance, timeout)`: verify the cached records of an instance.
Verify(String, Duration),
/// Shut the daemon down; the final status is sent on the channel.
Exit(Sender<DaemonStatus>),
}
/// Human-readable command name, used in logs. Payloads are not printed.
impl fmt::Display for Command {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Browse(..) => "Command Browse",
            Self::ResolveHostname(..) => "Command ResolveHostname",
            Self::Exit(..) => "Command Exit",
            Self::GetStatus(..) => "Command GetStatus",
            Self::GetMetrics(..) => "Command GetMetrics",
            Self::Monitor(..) => "Command Monitor",
            Self::Register(..) => "Command Register",
            Self::RegisterResend(..) => "Command RegisterResend",
            Self::SetOption(..) => "Command SetOption",
            Self::GetOption(..) => "Command GetOption",
            Self::StopBrowse(..) => "Command StopBrowse",
            Self::StopResolveHostname(..) => "Command StopResolveHostname",
            Self::Unregister(..) => "Command Unregister",
            Self::UnregisterResend(..) => "Command UnregisterResend",
            Self::Resolve(..) => "Command Resolve",
            Self::Verify(..) => "Command VerifyResource",
        };
        f.write_str(label)
    }
}
/// Snapshot of daemon option values, returned for `Command::GetOption`.
struct DaemonOptionVal {
// Underscore prefix: presumably not read anywhere yet.
_service_name_len_max: u8,
// NOTE(review): unit presumably seconds (cf. IP_CHECK_INTERVAL_IN_SECS_DEFAULT) — confirm.
ip_check_interval: u64,
}
/// Daemon settings that can be changed at runtime via `Command::SetOption`.
#[derive(Debug)]
enum DaemonOption {
/// Presumably the limit passed to `check_service_name_length` — confirm.
ServiceNameLenMax(u8),
/// NOTE(review): interval for IP checks, presumably in seconds — confirm.
IpCheckInterval(u64),
/// Enable the listed interface kinds.
EnableInterface(Vec<IfKind>),
/// Disable the listed interface kinds.
DisableInterface(Vec<IfKind>),
/// Toggle IPv4 multicast loopback.
MulticastLoopV4(bool),
/// Toggle IPv6 multicast loopback.
MulticastLoopV6(bool),
/// NOTE(review): presumably whether unsolicited responses are cached — confirm.
AcceptUnsolicited(bool),
// Test-only: simulate the named interface going down.
#[cfg(test)]
TestDownInterface(String),
// Test-only: simulate the named interface coming back up.
#[cfg(test)]
TestUpInterface(String),
}
/// Byte length of the `._tcp.local.` / `._udp.local.` suffix (both are equally long).
const DOMAIN_LEN: usize = "._tcp.local.".len();

/// Checks the length of the service name inside `ty_domain`.
///
/// The service name is everything except the leading `_` and the domain
/// suffix. It must be non-empty and at most `limit` bytes long.
fn check_service_name_length(ty_domain: &str, limit: u8) -> Result<()> {
    // `+ 1` accounts for the leading underscore.
    let name_len = ty_domain
        .len()
        .checked_sub(DOMAIN_LEN + 1)
        .filter(|&len| len > 0);
    let Some(service_name_len) = name_len else {
        return Err(e_fmt!("Service type name cannot be empty: {}", ty_domain));
    };
    if service_name_len > usize::from(limit) {
        return Err(e_fmt!("Service name length must be <= {} bytes", limit));
    }
    Ok(())
}
/// Checks that `name` ends with `._tcp.local.` or `._udp.local.`.
fn check_domain_suffix(name: &str) -> Result<()> {
    let has_known_suffix = ["._tcp.local.", "._udp.local."]
        .iter()
        .any(|&suffix| name.ends_with(suffix));
    if has_known_suffix {
        return Ok(());
    }
    Err(e_fmt!(
        "mDNS service {} must end with '._tcp.local.' or '._udp.local.'",
        name
    ))
}
/// Validates the service label of a service type such as `_http._tcp.local.`.
///
/// Requires the `._tcp.local.` / `._udp.local.` suffix. The label just before
/// that suffix must start with `_`. The rest of the label must not contain
/// `--`, must not start or end with `-`, and must contain at least one ASCII
/// letter.
///
/// # Errors
/// Returns an error describing the first rule that is violated.
fn check_service_name(fullname: &str) -> Result<()> {
    check_domain_suffix(fullname)?;

    // `check_domain_suffix` guarantees the ASCII suffix is present, so this
    // slice is in range and on a char boundary.
    let name = fullname[..fullname.len() - DOMAIN_LEN]
        .rsplit('.')
        .next()
        .ok_or_else(|| e_fmt!("No service name"))?;

    // `strip_prefix` replaces `&name[0..1]`. That slice panicked on an empty
    // label (e.g. "._tcp.local.") or on a multi-byte first character.
    let Some(name) = name.strip_prefix('_') else {
        return Err(e_fmt!("Service name must start with '_'"));
    };
    if name.contains("--") {
        return Err(e_fmt!("Service name must not contain '--'"));
    }
    if name.starts_with('-') || name.ends_with('-') {
        // The message previously contained an unformatted C-style "%s".
        return Err(e_fmt!(
            "Service name ({}) may not start or end with '-'",
            name
        ));
    }
    if !name.chars().any(|c| c.is_ascii_alphabetic()) {
        return Err(e_fmt!(
            "Service name must contain at least one letter (eg: 'A-Za-z')"
        ));
    }
    Ok(())
}
/// Validates a service hostname.
///
/// The hostname must end with `.local.`, must have a non-empty part before
/// that suffix, and must be at most 255 bytes long.
fn check_hostname(hostname: &str) -> Result<()> {
    match hostname.strip_suffix(".local.") {
        None => Err(e_fmt!("Hostname must end with '.local.': {hostname}")),
        Some("") => Err(e_fmt!(
            "The part of the hostname before '.local.' cannot be empty"
        )),
        Some(_) if hostname.len() > 255 => {
            Err(e_fmt!("Hostname length must be <= 255 bytes"))
        }
        Some(_) => Ok(()),
    }
}
/// Sends `event` to the listener registered for `ty_domain`, if there is one.
///
/// Delivery failures are logged and otherwise ignored.
fn call_service_listener(
    listeners_map: &HashMap<String, Sender<ServiceEvent>>,
    ty_domain: &str,
    event: ServiceEvent,
) {
    let Some(listener) = listeners_map.get(ty_domain) else {
        return;
    };
    let outcome = listener.send(event);
    match outcome {
        Err(e) => debug!("Failed to send event: {}", e),
        Ok(()) => trace!("Sent event to listener successfully"),
    }
}
/// Sends `event` to the resolver listening for `hostname`, if there is one.
///
/// The map is looked up by the lower-cased hostname. Resolvers are presumably
/// registered under lower-cased keys. Delivery failures are logged and
/// otherwise ignored.
fn call_hostname_resolution_listener(
    listeners_map: &HashMap<String, (Sender<HostnameResolutionEvent>, Option<u64>)>,
    hostname: &str,
    event: HostnameResolutionEvent,
) {
    let key = hostname.to_lowercase();
    let Some((listener, _deadline)) = listeners_map.get(&key) else {
        return;
    };
    match listener.send(event) {
        Err(e) => debug!("Failed to send event: {}", e),
        Ok(()) => trace!("Sent event to listener successfully"),
    }
}
fn my_ip_interfaces(with_loopback: bool) -> Vec<Interface> {
if_addrs::get_if_addrs()
.unwrap_or_default()
.into_iter()
.filter(|i| i.is_oper_up() && (!i.is_loopback() || with_loopback))
.collect()
}
/// Multicasts `out` on `my_intf`, using the interface's next address of the
/// socket's family (IPv4 or IPv6).
///
/// Returns the result of `send_dns_outgoing_impl`, or an empty list when the
/// interface has no address of that family.
fn send_dns_outgoing(out: &DnsOutgoing, my_intf: &MyIntf, sock: &PktInfoUdpSocket) -> Vec<Vec<u8>> {
    let next_addr = if sock.domain() == Domain::IPV4 {
        my_intf.next_ifaddr_v4()
    } else {
        my_intf.next_ifaddr_v6()
    };
    match next_addr {
        Some(if_addr) => {
            send_dns_outgoing_impl(out, &my_intf.name, my_intf.index, if_addr, sock)
        }
        None => vec![],
    }
}
/// Serializes `out` and multicasts every resulting packet on one interface.
///
/// Nothing is sent in two cases:
/// - `out` is a response with no answers and no additionals;
/// - the socket's multicast interface cannot be set. It is set by address for
///   IPv4 and by `if_index` for IPv6.
///
/// Returns the wire packets that were passed to `multicast_on_intf`, or an
/// empty list when nothing was sent.
fn send_dns_outgoing_impl(
    out: &DnsOutgoing,
    if_name: &str,
    if_index: u32,
    if_addr: &IfAddr,
    sock: &PktInfoUdpSocket,
) -> Vec<Vec<u8>> {
    let is_query = out.is_query();
    if !is_query && out.answers_count() == 0 && out.additionals().is_empty() {
        return vec![];
    }
    let qtype = if is_query { "query" } else { "response" };

    trace!(
        "send {}: {} questions {} answers {} authorities {} additional",
        qtype,
        out.questions().len(),
        out.answers_count(),
        out.authorities().len(),
        out.additionals().len()
    );

    let multicast_if_ready = match if_addr.ip() {
        IpAddr::V4(ipv4) => match sock.set_multicast_if_v4(&ipv4) {
            Ok(_) => true,
            Err(e) => {
                debug!(
                    "send_dns_outgoing: failed to set multicast interface for IPv4 {}: {}",
                    ipv4, e
                );
                false
            }
        },
        IpAddr::V6(ipv6) => match sock.set_multicast_if_v6(if_index) {
            Ok(_) => true,
            Err(e) => {
                debug!(
                    "send_dns_outgoing: failed to set multicast interface for IPv6 {}: {}",
                    ipv6, e
                );
                false
            }
        },
    };
    if !multicast_if_ready {
        return vec![];
    }

    let wire_packets = out.to_data_on_wire();
    for packet in &wire_packets {
        multicast_on_intf(packet, if_name, if_index, if_addr, sock);
    }
    wire_packets
}
fn multicast_on_intf(
packet: &[u8],
if_name: &str,
if_index: u32,
if_addr: &IfAddr,
socket: &PktInfoUdpSocket,
) {
if packet.len() > MAX_MSG_ABSOLUTE {
debug!("Drop over-sized packet ({})", packet.len());
return;
}
let addr: SocketAddr = match if_addr {
if_addrs::IfAddr::V4(_) => SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT).into(),
if_addrs::IfAddr::V6(_) => {
let mut sock = SocketAddrV6::new(GROUP_ADDR_V6, MDNS_PORT, 0, 0);
sock.set_scope_id(if_index); sock.into()
}
};
let sock_addr = addr.into();
match socket.send_to(packet, &sock_addr) {
Ok(sz) => trace!(
"sent out {} bytes on interface {} (idx {}) addr {}",
sz,
if_name,
if_index,
if_addr.ip()
),
Err(e) => trace!("Failed to send to {} via {:?}: {}", addr, &if_name, e),
}
}
/// Returns true when `name` has at least five dot-separated labels (empty
/// labels count).
///
/// For example, `instance._service._tcp.local.` has four dots and therefore
/// five labels, the last one empty.
fn valid_instance_name(name: &str) -> bool {
    // n dots always split into n + 1 labels.
    name.matches('.').count() >= 4
}
/// Delivers a clone of `event` to every monitor without blocking.
///
/// A monitor whose send fails for any other reason (for example a full
/// channel) misses this event but stays registered. A monitor whose channel is
/// `Disconnected` is removed from `monitors`.
fn notify_monitors(monitors: &mut Vec<Sender<DaemonEvent>>, event: DaemonEvent) {
    monitors.retain(|monitor| match monitor.try_send(event.clone()) {
        Ok(_) => true,
        Err(err) => {
            debug!("notify_monitors: try_send: {}", &err);
            !matches!(err, TrySendError::Disconnected(_))
        }
    });
}
/// Builds the announcement message for `info` on `intf`, for one address
/// family (`is_ipv4`).
///
/// The message carries:
/// - the service-type PTR, and the subtype PTR if there is one;
/// - the SRV and TXT records;
/// - one address record per address of that family on the interface.
///
/// Names found in `dns_registry.name_changes` are used for the PTR target and
/// the SRV target hostname. They are also attached to SRV/TXT/address records
/// via `set_new_name`.
///
/// Returns `None` in two cases:
/// - the service has no address of this family on `intf`;
/// - any SRV/TXT/address record requires probing that `is_probing_done`
///   reports as not done.
///
/// NOTE(review): `is_probing_done` takes `&mut DnsRegistry` and is called for
/// every record even after one is found pending, so it may update the registry.
/// Confirm before adding any short-circuit.
fn prepare_announce(
info: &ServiceInfo,
intf: &MyIntf,
dns_registry: &mut DnsRegistry,
is_ipv4: bool,
) -> Option<DnsOutgoing> {
let intf_addrs = if is_ipv4 {
info.get_addrs_on_my_intf_v4(intf)
} else {
info.get_addrs_on_my_intf_v6(intf)
};
if intf_addrs.is_empty() {
debug!(
"prepare_announce (ipv4: {is_ipv4}): no valid addrs on interface {}",
&intf.name
);
return None;
}
// Instance name, possibly renamed after a probe conflict.
let service_fullname = match dns_registry.name_changes.get(info.get_fullname()) {
Some(new_name) => new_name,
None => info.get_fullname(),
};
debug!(
"prepare to announce service {service_fullname} on {:?}",
&intf_addrs
);
// Counts records held back because probing is not done yet.
let mut probing_count = 0;
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
// Randomized by up to 250 ms; passed to every `is_probing_done` call below.
let create_time = current_time_millis() + fastrand::u64(0..250);
// PTR records are added unconditionally (not gated on probing).
out.add_answer_at_time(
DnsPointer::new(
info.get_type(),
RRType::PTR,
CLASS_IN,
info.get_other_ttl(),
service_fullname.to_string(),
),
0,
);
if let Some(sub) = info.get_subtype() {
trace!("Adding subdomain {}", sub);
out.add_answer_at_time(
DnsPointer::new(
sub,
RRType::PTR,
CLASS_IN,
info.get_other_ttl(),
service_fullname.to_string(),
),
0,
);
}
// The SRV target uses the renamed hostname, if there is one.
let hostname = match dns_registry.name_changes.get(info.get_hostname()) {
Some(new_name) => new_name.to_string(),
None => info.get_hostname().to_string(),
};
// The SRV record is keyed by the registered fullname; a rename is attached via `set_new_name`.
let mut srv = DnsSrv::new(
info.get_fullname(),
CLASS_IN | CLASS_CACHE_FLUSH,
info.get_host_ttl(),
info.get_priority(),
info.get_weight(),
info.get_port(),
hostname,
);
if let Some(new_name) = dns_registry.name_changes.get(info.get_fullname()) {
srv.get_record_mut().set_new_name(new_name.to_string());
}
if !info.requires_probe()
|| dns_registry.is_probing_done(&srv, info.get_fullname(), create_time)
{
out.add_answer_at_time(srv, 0);
} else {
probing_count += 1;
}
let mut txt = DnsTxt::new(
info.get_fullname(),
CLASS_IN | CLASS_CACHE_FLUSH,
info.get_other_ttl(),
info.generate_txt(),
);
if let Some(new_name) = dns_registry.name_changes.get(info.get_fullname()) {
txt.get_record_mut().set_new_name(new_name.to_string());
}
if !info.requires_probe()
|| dns_registry.is_probing_done(&txt, info.get_fullname(), create_time)
{
out.add_answer_at_time(txt, 0);
} else {
probing_count += 1;
}
// Address records use the registered hostname, with any rename attached via `set_new_name`.
let hostname = info.get_hostname();
for address in intf_addrs {
let mut dns_addr = DnsAddress::new(
hostname,
ip_address_rr_type(&address),
CLASS_IN | CLASS_CACHE_FLUSH,
info.get_host_ttl(),
address,
intf.into(),
);
if let Some(new_name) = dns_registry.name_changes.get(hostname) {
dns_addr.get_record_mut().set_new_name(new_name.to_string());
}
if !info.requires_probe()
|| dns_registry.is_probing_done(&dns_addr, info.get_fullname(), create_time)
{
out.add_answer_at_time(dns_addr, 0);
} else {
probing_count += 1;
}
}
// Announce only when every record is cleared to go.
if probing_count > 0 {
return None;
}
Some(out)
}
/// Builds the announcement for `info` on `intf`, for the socket's address
/// family, and multicasts it.
///
/// Returns false when `prepare_announce` produced nothing. That happens when
/// there is no address of that family on the interface, or when some records
/// are still probing.
fn announce_service_on_intf(
    dns_registry: &mut DnsRegistry,
    info: &ServiceInfo,
    intf: &MyIntf,
    sock: &PktInfoUdpSocket,
) -> bool {
    let is_ipv4 = sock.domain() == Domain::IPV4;
    match prepare_announce(info, intf, dns_registry, is_ipv4) {
        Some(announcement) => {
            send_dns_outgoing(&announcement, intf, sock);
            true
        }
        None => false,
    }
}
/// Produces an alternative instance name after a name conflict.
///
/// Only the first dot-separated label changes:
/// - a trailing ` (N)` is bumped to ` (N+1)`;
/// - otherwise ` (2)` is appended.
///
/// For example, `foo.local.` becomes `foo (2).local.` and `foo (2).local.`
/// becomes `foo (3).local.`. If `N + 1` would overflow `u32`, ` (2)` is
/// appended instead of panicking (debug) or wrapping to 0 (release).
fn name_change(original: &str) -> String {
    let mut parts: Vec<_> = original.split('.').collect();
    // `split` always yields at least one item; this fallback is defensive.
    let Some(first_part) = parts.get_mut(0) else {
        return format!("{original} (2)");
    };
    // Bump only when the label ends with ')' and the text between the last
    // " (" and that ')' parses as a number.
    let bumped = first_part
        .strip_suffix(')')
        .and_then(|rest| rest.rsplit_once(" ("))
        .and_then(|(base, digits)| {
            let next = digits.parse::<u32>().ok()?.checked_add(1)?;
            Some(format!("{base} ({next})"))
        });
    let new_name = bumped.unwrap_or_else(|| format!("{first_part} (2)"));
    *first_part = &new_name;
    parts.join(".")
}
/// Produces an alternative hostname after a name conflict.
///
/// Only the first dot-separated label changes:
/// - a trailing `-N` is bumped to `-(N+1)`;
/// - otherwise `-2` is appended.
///
/// For example, `foo.local.` becomes `foo-2.local.` and `foo-9` becomes
/// `foo-10`. If `N + 1` would overflow `u32`, `-2` is appended instead of
/// panicking (debug) or wrapping to 0 (release).
fn hostname_change(original: &str) -> String {
    let mut parts: Vec<_> = original.split('.').collect();
    // `split` always yields at least one item; this fallback is defensive.
    let Some(first_part) = parts.get_mut(0) else {
        return format!("{original}-2");
    };
    let bumped = first_part.rsplit_once('-').and_then(|(base, digits)| {
        let next = digits.parse::<u32>().ok()?.checked_add(1)?;
        Some(format!("{base}-{next}"))
    });
    let new_name = bumped.unwrap_or_else(|| format!("{first_part}-2"));
    *first_part = &new_name;
    parts.join(".")
}
/// Answers a PTR query for `service` on `intf`, and adds the records a
/// resolver needs next as additionals:
/// - the subtype PTR, if there is one;
/// - the SRV and TXT records;
/// - one A/AAAA record per interface address of the requested family.
///
/// Names renamed during probing (`dns_registry.name_changes`) are used in
/// place of the registered instance name and hostname.
///
/// Nothing is added when the service has no address of this family on `intf`.
/// If `out` rejects the PTR answer, no additionals are added either.
fn add_answer_with_additionals(
    out: &mut DnsOutgoing,
    msg: &DnsIncoming,
    service: &ServiceInfo,
    intf: &MyIntf,
    dns_registry: &DnsRegistry,
    is_ipv4: bool,
) {
    let intf_addrs = if is_ipv4 {
        service.get_addrs_on_my_intf_v4(intf)
    } else {
        service.get_addrs_on_my_intf_v6(intf)
    };
    if intf_addrs.is_empty() {
        trace!("No addrs on LAN of intf {:?}", intf);
        return;
    }

    let service_fullname: &str = dns_registry
        .name_changes
        .get(service.get_fullname())
        .map_or(service.get_fullname(), String::as_str);
    let hostname: &str = dns_registry
        .name_changes
        .get(service.get_hostname())
        .map_or(service.get_hostname(), String::as_str);

    let ptr = DnsPointer::new(
        service.get_type(),
        RRType::PTR,
        CLASS_IN,
        service.get_other_ttl(),
        service_fullname.to_string(),
    );
    if !out.add_answer(msg, ptr) {
        trace!("answer was not added for msg {:?}", msg);
        return;
    }

    if let Some(sub) = service.get_subtype() {
        trace!("Adding subdomain {}", sub);
        out.add_additional_answer(DnsPointer::new(
            sub,
            RRType::PTR,
            CLASS_IN,
            service.get_other_ttl(),
            service_fullname.to_string(),
        ));
    }

    let srv = DnsSrv::new(
        service_fullname,
        CLASS_IN | CLASS_CACHE_FLUSH,
        service.get_host_ttl(),
        service.get_priority(),
        service.get_weight(),
        service.get_port(),
        hostname.to_string(),
    );
    out.add_additional_answer(srv);

    let txt = DnsTxt::new(
        service_fullname,
        CLASS_IN | CLASS_CACHE_FLUSH,
        service.get_other_ttl(),
        service.generate_txt(),
    );
    out.add_additional_answer(txt);

    for address in intf_addrs {
        let rr_type = ip_address_rr_type(&address);
        out.add_additional_answer(DnsAddress::new(
            hostname,
            rr_type,
            CLASS_IN | CLASS_CACHE_FLUSH,
            service.get_host_ttl(),
            address,
            intf.into(),
        ));
    }
}
/// Builds one probe query covering every probe in `dns_registry` that is due
/// at `now`.
///
/// Due probes that have expired are not queried; their names are returned so
/// the caller can finish them. Every other due probe:
/// - adds an ANY question for its name;
/// - adds its records as authority entries;
/// - is advanced with `update_next_send(now)`;
/// - pushes its new send time onto `timers`.
fn check_probing(
    dns_registry: &mut DnsRegistry,
    timers: &mut BinaryHeap<Reverse<u64>>,
    now: u64,
) -> (DnsOutgoing, Vec<String>) {
    let mut query = DnsOutgoing::new(FLAGS_QR_QUERY);
    let mut expired = Vec::new();

    let due_probes = dns_registry
        .probing
        .iter_mut()
        .filter(|(_, probe)| now >= probe.next_send);
    for (name, probe) in due_probes {
        if probe.expired(now) {
            expired.push(name.clone());
            continue;
        }
        query.add_question(name, RRType::ANY);
        for record in probe.records.iter() {
            query.add_authority(record.clone());
        }
        probe.update_next_send(now);
        timers.push(Reverse(probe.next_send));
    }

    (query, expired)
}
/// Finishes the probes named in `expired_probes` on interface `intf_name`.
///
/// For each probe that is still registered:
/// - every record that carries a new name records the rename in
///   `dns_registry.name_changes` (keyed by the probe name), and monitors
///   receive a `DaemonEvent::NameChange`;
/// - if the probe has records, they are moved to `dns_registry.active` (merged
///   with any existing entry) and the probe's waiting services are collected.
///
/// Returns the union of the collected waiting services.
fn handle_expired_probes(
    expired_probes: Vec<String>,
    intf_name: &str,
    dns_registry: &mut DnsRegistry,
    monitors: &mut Vec<Sender<DaemonEvent>>,
) -> HashSet<String> {
    let mut ready_services = HashSet::new();

    for name in expired_probes {
        let probe = match dns_registry.probing.remove(&name) {
            Some(finished) => finished,
            None => continue,
        };

        for rec in probe.records.iter() {
            let dns_record = rec.get_record();
            let Some(renamed) = dns_record.get_new_name() else {
                continue;
            };
            dns_registry
                .name_changes
                .insert(name.clone(), renamed.to_string());
            let change = DnsNameChange {
                original: dns_record.get_original_name().to_string(),
                new_name: renamed.to_string(),
                rr_type: rec.get_type(),
                intf_name: intf_name.to_string(),
            };
            debug!("Name change event: {:?}", &change);
            notify_monitors(monitors, DaemonEvent::NameChange(change));
        }

        debug!(
            "probe of '{name}' finished: move {} records to active. ({} waiting services)",
            probe.records.len(),
            probe.waiting_services.len(),
        );

        if probe.records.is_empty() {
            continue;
        }
        match dns_registry.active.get_mut(&name) {
            Some(existing) => existing.extend(probe.records),
            None => {
                dns_registry.active.insert(name, probe.records);
            }
        }
        ready_services.extend(probe.waiting_services);
    }

    ready_services
}
#[cfg(test)]
mod tests {
use super::{
_new_socket_bind, check_domain_suffix, check_service_name_length, hostname_change,
my_ip_interfaces, name_change, send_dns_outgoing_impl, valid_instance_name,
HostnameResolutionEvent, ServiceDaemon, ServiceEvent, ServiceInfo, GROUP_ADDR_V4,
MDNS_PORT,
};
use crate::{
dns_parser::{
DnsIncoming, DnsOutgoing, DnsPointer, InterfaceId, RRType, ScopedIp, CLASS_IN,
FLAGS_AA, FLAGS_QR_RESPONSE,
},
service_daemon::{add_answer_of_service, check_hostname},
};
use std::{
net::{SocketAddr, SocketAddrV4},
time::{Duration, SystemTime},
};
use test_log::test;
#[test]
fn test_socketaddr_print() {
    // The IPv4 mDNS group renders in the usual `ip:port` form.
    let group = SocketAddr::from(SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT));
    assert_eq!(group.to_string(), "224.0.0.251:5353");
}
#[test]
fn test_instance_name() {
    // Five or more dot-separated labels (empty labels included) are accepted.
    for accepted in ["my-laser._printer._tcp.local.", "my-laser.._printer._tcp.local."] {
        assert!(valid_instance_name(accepted));
    }
    // A bare service type has only four labels.
    assert!(!valid_instance_name("_printer._tcp.local."));
}
#[test]
fn test_check_service_name_length() {
    // "_tcp" is shorter than the domain suffix alone, so it must be rejected.
    let err = check_service_name_length("_tcp", 100).unwrap_err();
    println!("{}", err);
}
#[test]
fn test_check_hostname() {
    let exactly_255 = "A".repeat(255 - ".local.".len()) + ".local.";
    let exactly_256 = "A".repeat(256 - ".local.".len()) + ".local.";

    // Accepted: a normal name and one at the 255-byte limit.
    for good in ["my_host.local.", exactly_255.as_str()] {
        assert!(check_hostname(good).is_ok());
    }

    // Rejected: missing trailing dot, empty host part, and one byte too long.
    for bad in ["my_host.local", ".local.", exactly_256.as_str()] {
        let err = check_hostname(bad).unwrap_err();
        println!("{}", err);
    }
}
#[test]
fn test_check_domain_suffix() {
    let rejected = [
        "_missing_dot._tcp.local",
        "_missing_bar.tcp.local.",
        "_mis_spell._tpp.local.",
        "_mis_spell._upp.local.",
    ];
    for name in rejected {
        assert!(check_domain_suffix(name).is_err());
    }

    for name in ["_has_dot._tcp.local.", "_goodname._udp.local."] {
        assert!(check_domain_suffix(name).is_ok());
    }
}
// Registers a service and resolves it by browsing, then stops browsing. Next it
// multicasts a PTR record for the service type with TTL 0 (as an additional
// answer) on every non-loopback interface. Finally it browses again and
// expects the service to be resolved once more.
#[test]
fn test_service_with_temporarily_invalidated_ptr() {
let d = ServiceDaemon::new().expect("Failed to create daemon");
let service = "_test_inval_ptr._udp.local.";
let host_name = "my_host_tmp_invalidated_ptr.local.";
let intfs: Vec<_> = my_ip_interfaces(false);
let intf_ips: Vec<_> = intfs.iter().map(|intf| intf.ip()).collect();
let port = 5201;
let my_service =
ServiceInfo::new(service, "my_instance", host_name, &intf_ips[..], port, None)
.expect("invalid service info")
.enable_addr_auto();
let result = d.register(my_service.clone());
assert!(result.is_ok());
// First browse: wait (up to 2 s per event) until the service resolves.
let browse_chan = d.browse(service).unwrap();
let timeout = Duration::from_secs(2);
let mut resolved = false;
while let Ok(event) = browse_chan.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceResolved(info) => {
resolved = true;
println!("Resolved a service of {}", &info.fullname);
break;
}
e => {
println!("Received event {:?}", e);
}
}
}
assert!(resolved);
println!("Stopping browse of {}", service);
d.stop_browse(service).unwrap();
let mut stopped = false;
while let Ok(event) = browse_chan.recv_timeout(timeout) {
match event {
ServiceEvent::SearchStopped(_) => {
stopped = true;
println!("Stopped browsing service");
break;
}
e => {
println!("Received event {:?}", e);
}
}
}
assert!(stopped);
// A TTL-0 PTR for the service type/instance.
let invalidate_ptr_packet = DnsPointer::new(
my_service.get_type(),
RRType::PTR,
CLASS_IN,
0,
my_service.get_fullname().to_string(),
);
let mut packet_buffer = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
packet_buffer.add_additional_answer(invalidate_ptr_packet);
// Send it on every interface from a separately bound socket.
for intf in intfs {
let sock = _new_socket_bind(&intf, true).unwrap();
send_dns_outgoing_impl(
&packet_buffer,
&intf.name,
intf.index.unwrap_or(0),
&intf.addr,
&sock.pktinfo,
);
}
println!(
"Sent PTR record invalidation. Starting second browse for {}",
service
);
// Second browse: the service must still resolve.
let browse_chan = d.browse(service).unwrap();
resolved = false;
while let Ok(event) = browse_chan.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceResolved(info) => {
resolved = true;
println!("Resolved a service of {}", &info.fullname);
break;
}
e => {
println!("Received event {:?}", e);
}
}
}
assert!(resolved);
d.shutdown().unwrap();
}
// Registers a service whose host TTL is 3 s and resolves it from a separate
// client daemon. It then shuts the server down and waits for `ServiceRemoved`.
// NOTE(review): the final loop never asserts that `ServiceRemoved` arrived, so
// this test also passes if the removal event is missing.
#[test]
fn test_expired_srv() {
let service_type = "_expired-srv._udp.local.";
let instance = "test_instance";
let host_name = "expired_srv_host.local.";
let mut my_service = ServiceInfo::new(service_type, instance, host_name, "", 5023, None)
.unwrap()
.enable_addr_auto();
let new_ttl = 3; my_service._set_host_ttl(new_ttl);
let mdns_server = ServiceDaemon::new().expect("Failed to create mdns server");
let result = mdns_server.register(my_service);
assert!(result.is_ok());
let mdns_client = ServiceDaemon::new().expect("Failed to create mdns client");
let browse_chan = mdns_client.browse(service_type).unwrap();
let timeout = Duration::from_secs(2);
let mut resolved = false;
while let Ok(event) = browse_chan.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceResolved(info) => {
resolved = true;
println!("Resolved a service of {}", &info.fullname);
break;
}
_ => {}
}
}
assert!(resolved);
mdns_server.shutdown().unwrap();
// Each wait is bounded by the TTL; the loop ends on the first gap that long.
let expire_timeout = Duration::from_secs(new_ttl as u64);
while let Ok(event) = browse_chan.recv_timeout(expire_timeout) {
match event {
ServiceEvent::ServiceRemoved(service_type, full_name) => {
println!("Service removed: {}: {}", &service_type, &full_name);
break;
}
_ => {}
}
}
}
// Registers a service whose address TTL is 2 s, on the first IPv4 interface
// address. The test then:
// - resolves the hostname from a client daemon;
// - shuts the server down;
// - expects `AddressesRemoved` for that address within TTL + 1 s.
#[test]
fn test_hostname_resolution_address_removed() {
let server = ServiceDaemon::new().expect("Failed to create server");
// NOTE(review): this hostname ends in "._tcp.local." rather than a plain ".local." host.
let hostname = "addr_remove_host._tcp.local.";
let service_ip_addr: ScopedIp = my_ip_interfaces(false)
.iter()
.find(|iface| iface.ip().is_ipv4())
.map(|iface| iface.ip().into())
.unwrap();
let mut my_service = ServiceInfo::new(
"_host_res_test._tcp.local.",
"my_instance",
hostname,
&service_ip_addr.to_ip_addr(),
1234,
None,
)
.expect("invalid service info");
let addr_ttl = 2;
my_service._set_host_ttl(addr_ttl);
server.register(my_service).unwrap();
let client = ServiceDaemon::new().expect("Failed to create client");
let event_receiver = client.resolve_hostname(hostname, None).unwrap();
// Blocks (no timeout) until the address is found or the search ends.
let resolved = loop {
match event_receiver.recv() {
Ok(HostnameResolutionEvent::AddressesFound(found_hostname, addresses)) => {
assert!(found_hostname == hostname);
assert!(addresses.contains(&service_ip_addr));
println!("address found: {:?}", &addresses);
break true;
}
Ok(HostnameResolutionEvent::SearchStopped(_)) => break false,
Ok(_event) => {}
Err(_) => break false,
}
};
assert!(resolved);
server.shutdown().unwrap();
let timeout = Duration::from_secs(addr_ttl as u64 + 1);
let removed = loop {
match event_receiver.recv_timeout(timeout) {
Ok(HostnameResolutionEvent::AddressesRemoved(removed_host, addresses)) => {
assert!(removed_host == hostname);
assert!(addresses.contains(&service_ip_addr));
println!(
"address removed: hostname: {} addresses: {:?}",
&hostname, &addresses
);
break true;
}
Ok(_event) => {}
Err(_) => {
break false;
}
}
};
assert!(removed);
client.shutdown().unwrap();
}
// Registers a service whose non-host record TTL is 3 s and browses it from a
// client. After resolution, the client drains events until none arrive for 90%
// of the TTL. It then checks that exactly one PTR refresh and one SRV/TXT
// refresh were counted in the client's metrics.
#[test]
fn test_refresh_ptr() {
let service_type = "_refresh-ptr._udp.local.";
let instance = "test_instance";
let host_name = "refresh_ptr_host.local.";
let service_ip_addr = my_ip_interfaces(false)
.iter()
.find(|iface| iface.ip().is_ipv4())
.map(|iface| iface.ip())
.unwrap();
let mut my_service = ServiceInfo::new(
service_type,
instance,
host_name,
&service_ip_addr,
5023,
None,
)
.unwrap();
let new_ttl = 3; my_service._set_other_ttl(new_ttl);
let mdns_server = ServiceDaemon::new().expect("Failed to create mdns server");
let result = mdns_server.register(my_service);
assert!(result.is_ok());
let mdns_client = ServiceDaemon::new().expect("Failed to create mdns client");
let browse_chan = mdns_client.browse(service_type).unwrap();
let timeout = Duration::from_millis(1500); let mut resolved = false;
while let Ok(event) = browse_chan.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceResolved(info) => {
resolved = true;
println!("Resolved a service of {}", &info.fullname);
break;
}
_ => {}
}
}
assert!(resolved);
// 90% of the TTL, in milliseconds.
let timeout = Duration::from_millis(new_ttl as u64 * 1000 * 90 / 100);
while let Ok(event) = browse_chan.recv_timeout(timeout) {
println!("event: {:?}", &event);
}
let metrics_chan = mdns_client.get_metrics().unwrap();
let metrics = metrics_chan.recv_timeout(timeout).unwrap();
let ptr_refresh_counter = metrics["cache-refresh-ptr"];
assert_eq!(ptr_refresh_counter, 1);
let srvtxt_refresh_counter = metrics["cache-refresh-srv-txt"];
assert_eq!(srvtxt_refresh_counter, 1);
mdns_server.shutdown().unwrap();
mdns_client.shutdown().unwrap();
}
#[test]
fn test_name_change() {
    let cases = [
        ("foo.local.", "foo (2).local."),
        ("foo (2).local.", "foo (3).local."),
        ("foo (9).local.", "foo (10).local."),
        ("foo", "foo (2)"),
        ("foo (2)", "foo (3)"),
        ("", " (2)"),
        // Malformed or non-trailing "(N)" suffixes get " (2)" appended.
        ("foo (abc)", "foo (abc) (2)"),
        ("foo (2", "foo (2 (2)"),
        ("foo (2) extra", "foo (2) extra (2)"),
    ];
    for (input, expected) in cases {
        assert_eq!(name_change(input), expected);
    }
}
#[test]
fn test_hostname_change() {
    let cases = [
        ("foo.local.", "foo-2.local."),
        ("foo", "foo-2"),
        ("foo-2.local.", "foo-3.local."),
        ("foo-9", "foo-10"),
        ("test-42.domain.", "test-43.domain."),
    ];
    for (input, expected) in cases {
        assert_eq!(hostname_change(input), expected);
    }
}
// Checks that a TXT answer produced by `add_answer_of_service` carries the
// service's "other" TTL. A `DnsIncoming` is first obtained by serializing the
// still-empty outgoing message; the TXT answer is then added to that same message.
#[test]
fn test_add_answer_txt_ttl() {
let service_type = "_test_add_answer._udp.local.";
let instance = "test_instance";
let host_name = "add_answer_host.local.";
let service_intf = my_ip_interfaces(false)
.into_iter()
.find(|iface| iface.ip().is_ipv4())
.unwrap();
let service_ip_addr = service_intf.ip();
let my_service = ServiceInfo::new(
service_type,
instance,
host_name,
&service_ip_addr,
5023,
None,
)
.unwrap();
let mut out = DnsOutgoing::new(FLAGS_QR_RESPONSE | FLAGS_AA);
// Serialized before any answer is added; used only as the "incoming" message.
let mut dummy_data = out.to_data_on_wire();
let interface_id = InterfaceId::from(&service_intf);
let incoming = DnsIncoming::new(dummy_data.pop().unwrap(), interface_id).unwrap();
let if_addrs = vec![service_intf.ip()];
add_answer_of_service(
&mut out,
&incoming,
instance,
&my_service,
RRType::TXT,
if_addrs,
);
assert!(
out.answers_count() > 0,
"No answers added to the outgoing message"
);
let answer = out._answers().first().unwrap();
assert_eq!(answer.0.get_type(), RRType::TXT);
assert_eq!(answer.0.get_record().get_ttl(), my_service.get_other_ttl());
}
// The server registers a service on the first IPv4 interface and a client
// resolves it. The client then takes that interface down (test-only option)
// and expects `ServiceRemoved`. Finally it brings the interface back up and
// expects the service to resolve again.
#[test]
fn test_interface_flip() {
let ty_domain = "_intf-flip._udp.local.";
let host_name = "intf_flip.local.";
let now = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap();
// A time-based instance name avoids clashing with earlier runs.
let instance_name = now.as_micros().to_string(); let port = 5200;
let (ip_addr1, intf_name) = my_ip_interfaces(false)
.iter()
.find(|iface| iface.ip().is_ipv4())
.map(|iface| (iface.ip(), iface.name.clone()))
.unwrap();
println!("Using interface {} with IP {}", intf_name, ip_addr1);
let service1 =
ServiceInfo::new(ty_domain, &instance_name, host_name, &ip_addr1, port, None)
.expect("valid service info");
let server1 = ServiceDaemon::new().expect("failed to start server");
server1
.register(service1)
.expect("Failed to register service1");
std::thread::sleep(Duration::from_secs(2));
let client = ServiceDaemon::new().expect("failed to start client");
let receiver = client.browse(ty_domain).unwrap();
let timeout = Duration::from_secs(3);
let mut got_data = false;
while let Ok(event) = receiver.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceResolved(_) => {
println!("Received ServiceResolved event");
got_data = true;
break;
}
_ => {}
}
}
assert!(got_data, "Should receive ServiceResolved event");
// Check IPs every second so the interface change is noticed quickly.
client.set_ip_check_interval(1).unwrap();
println!("Shutting down interface {}", &intf_name);
client.test_down_interface(&intf_name).unwrap();
let mut got_removed = false;
while let Ok(event) = receiver.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceRemoved(ty_domain, instance) => {
got_removed = true;
println!("removed: {ty_domain} : {instance}");
break;
}
_ => {}
}
}
assert!(got_removed, "Should receive ServiceRemoved event");
println!("Bringing up interface {}", &intf_name);
client.test_up_interface(&intf_name).unwrap();
let mut got_data = false;
while let Ok(event) = receiver.recv_timeout(timeout) {
match event {
ServiceEvent::ServiceResolved(resolved) => {
got_data = true;
println!("Received ServiceResolved: {:?}", resolved);
break;
}
_ => {}
}
}
assert!(
got_data,
"Should receive ServiceResolved event after interface is back up"
);
server1.shutdown().unwrap();
client.shutdown().unwrap();
}
/// Checks that a client which starts `browse_cache` before the server
/// registers the service still gets a `ServiceResolved` event for it.
#[test]
fn test_cache_only() {
    let service_type = "_cache_only._udp.local.";
    let instance = "test_instance";
    let host_name = "cache_only_host.local.";

    // First IPv4 address among this host's interfaces.
    let ip = my_ip_interfaces(false)
        .iter()
        .map(|intf| intf.ip())
        .find(|addr| addr.is_ipv4())
        .unwrap();
    let mut service =
        ServiceInfo::new(service_type, instance, host_name, &ip, 5023, None).unwrap();
    service._set_other_ttl(3);

    // The browse is started first; the server only appears after a short delay.
    let client = ServiceDaemon::new().expect("Failed to create mdns client");
    let events = client.browse_cache(service_type).unwrap();
    std::thread::sleep(Duration::from_secs(2));

    let server = ServiceDaemon::new().expect("Failed to create mdns server");
    let result = server.register(service);
    assert!(result.is_ok());

    // Consume events until the first resolution, or until none arrives in time.
    let timeout = Duration::from_millis(1500);
    let resolved = std::iter::from_fn(|| events.recv_timeout(timeout).ok()).any(|event| {
        if let ServiceEvent::ServiceResolved(info) = event {
            println!("Resolved a service of {}", &info.get_fullname());
            true
        } else {
            false
        }
    });
    assert!(resolved);

    server.shutdown().unwrap();
    client.shutdown().unwrap();
}
/// Verifies that a client with `accept_unsolicited(true)` resolves a service via
/// `browse_cache` even though the browse is only started after the server has
/// registered the service.
///
/// This test deliberately uses a service type and host name distinct from
/// `test_cache_only`. The test harness runs tests concurrently by default, and
/// shared names would let that test's daemons announce, answer for, or send
/// goodbyes for the same records. This test could then pass or fail for reasons
/// unrelated to caching unsolicited announcements.
#[test]
fn test_cache_only_unsolicited() {
    let service_type = "_c_unsolicit._udp.local.";
    let instance = "test_instance";
    let host_name = "c_unsolicit_host.local.";
    let service_ip_addr = my_ip_interfaces(false)
        .iter()
        .find(|iface| iface.ip().is_ipv4())
        .map(|iface| iface.ip())
        .unwrap();
    let mut my_service = ServiceInfo::new(
        service_type,
        instance,
        host_name,
        &service_ip_addr,
        5023,
        None,
    )
    .unwrap();
    let new_ttl = 3;
    my_service._set_other_ttl(new_ttl);

    let mdns_server = ServiceDaemon::new().expect("Failed to create mdns server");
    let result = mdns_server.register(my_service);
    assert!(result.is_ok());

    let mdns_client = ServiceDaemon::new().expect("Failed to create mdns client");
    mdns_client.accept_unsolicited(true).unwrap();
    // Leave time for the server's records to reach the client before the browse starts.
    std::thread::sleep(Duration::from_secs(2));
    let browse_chan = mdns_client.browse_cache(service_type).unwrap();

    let timeout = Duration::from_millis(1500);
    let mut resolved = false;
    while let Ok(event) = browse_chan.recv_timeout(timeout) {
        if let ServiceEvent::ServiceResolved(info) = event {
            resolved = true;
            println!("Resolved a service of {}", &info.get_fullname());
            break;
        }
    }
    assert!(
        resolved,
        "service should be resolved from unsolicited announcements"
    );

    mdns_server.shutdown().unwrap();
    mdns_client.shutdown().unwrap();
}
}