use std::{
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
sync::Arc,
};
use tokio::{net::UdpSocket, sync::RwLock};
use hickory_proto::{
op::{Message, MessageType, OpCode},
rr::{
rdata::{A, AAAA, PTR, SRV},
Name, RData, Record,
},
serialize::binary::BinEncodable,
};
use crate::mdns::{MDNS_IPV4_ADDR, MDNS_IPV6_ADDR, MDNS_PORT};
/// Time-to-live applied to every record in the mDNS announcement built by
/// `create_mdns_response_message` (DNS TTLs are expressed in seconds).
const RECORD_TTL: u32 = 120;
/// Looks up the OS interface index of the interface that owns `ip`.
///
/// IPv4 addresses are compared only against IPv4 entries and IPv6 addresses only
/// against IPv6 entries. The index of the *first* matching entry in `if_addrs` is
/// returned as-is, so the result is `None` either when no entry matches or when the
/// matching entry carries no index.
pub fn get_interface_index(ip: &IpAddr, if_addrs: &[if_addrs::Interface]) -> Option<u32> {
    let owns_ip = |candidate: &&if_addrs::Interface| match (&candidate.addr, ip) {
        (if_addrs::IfAddr::V4(v4), IpAddr::V4(wanted)) => v4.ip == *wanted,
        (if_addrs::IfAddr::V6(v6), IpAddr::V6(wanted)) => v6.ip == *wanted,
        _ => false,
    };

    if_addrs
        .iter()
        .find(owns_ip)
        .and_then(|matched| matched.index)
}
/// Creates the mDNS multicast socket for the interface that owns `advertised_ip`.
///
/// The socket is bound to the wildcard address on `MDNS_PORT` with address reuse
/// enabled (plus port reuse on Unix), so that other sockets on the host can also bind
/// the mDNS port. It then:
/// * enables multicast loopback,
/// * joins the mDNS group of the matching address family on the advertised interface,
/// * selects that interface for outgoing multicast.
///
/// For IPv6 the interface is identified by index, resolved via `get_interface_index`.
/// If no interface owns `advertised_ip`, index 0 is used, which leaves the choice of
/// interface to the OS. A warning is logged in that case.
///
/// # Errors
/// Returns any I/O error from:
/// * creating or configuring the socket,
/// * enumerating interfaces (IPv6 only),
/// * registering the socket with the tokio runtime.
pub async fn setup_multicast_socket(
    advertised_ip: IpAddr,
) -> Result<Arc<UdpSocket>, std::io::Error> {
    let socket = match advertised_ip {
        IpAddr::V4(ipv4) => {
            let socket = new_reusable_udp_socket(socket2::Domain::IPV4)?;
            socket.bind(&SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, MDNS_PORT).into())?;
            socket.set_multicast_loop_v4(true)?;
            socket.join_multicast_v4(&MDNS_IPV4_ADDR, &ipv4)?;
            log::debug!("Joined mDNS IPv4 multicast group on interface {}", ipv4);
            socket.set_multicast_if_v4(&ipv4)?;
            log::debug!("Set multicast IPv4 interface to {}", ipv4);
            socket
        }
        IpAddr::V6(ipv6) => {
            let socket = new_reusable_udp_socket(socket2::Domain::IPV6)?;
            socket.bind(&SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, MDNS_PORT, 0, 0).into())?;
            socket.set_multicast_loop_v6(true)?;
            let if_addrs = if_addrs::get_if_addrs()
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
            // IPv6 multicast membership is keyed by interface index, not by address.
            let if_index = get_interface_index(&advertised_ip, &if_addrs).unwrap_or_else(|| {
                // Index 0 lets the OS choose the interface, which may not be the one
                // that owns the advertised address; make that visible.
                log::warn!(
                    "No interface owns {}; using default interface (index 0) for mDNS IPv6",
                    ipv6
                );
                0
            });
            socket.join_multicast_v6(&MDNS_IPV6_ADDR, if_index)?;
            log::debug!(
                "Joined mDNS IPv6 multicast group on interface {} with index {}",
                ipv6,
                if_index
            );
            socket.set_multicast_if_v6(if_index)?;
            log::debug!("Set multicast IPv6 interface to index {}", if_index);
            socket
        }
    };
    into_tokio_socket(socket)
}

/// Creates an unbound UDP socket in `domain` with address reuse (and, on Unix, port
/// reuse) enabled.
fn new_reusable_udp_socket(domain: socket2::Domain) -> std::io::Result<socket2::Socket> {
    let socket = socket2::Socket::new(
        domain,
        socket2::Type::DGRAM,
        Some(socket2::Protocol::UDP),
    )?;
    socket.set_reuse_address(true)?;
    #[cfg(target_family = "unix")]
    socket.set_reuse_port(true)?;
    Ok(socket)
}

/// Switches `socket` to non-blocking mode and hands it to tokio as a shared `UdpSocket`.
fn into_tokio_socket(socket: socket2::Socket) -> std::io::Result<Arc<UdpSocket>> {
    // tokio requires std sockets passed to `from_std` to already be non-blocking.
    socket.set_nonblocking(true)?;
    let std_socket = std::net::UdpSocket::from(socket);
    Ok(Arc::new(UdpSocket::from_std(std_socket)?))
}
/// Multicasts an mDNS response advertising `instance_name` on `port`.
///
/// The address stored in `advertised_ip` is read once. It determines both the A/AAAA
/// record in the message and which mDNS multicast group (IPv4 or IPv6) the message is
/// sent to.
///
/// Returns the number of bytes sent. A failed send is only logged and reported as
/// `Ok(0)`; `Err` is returned solely when the message cannot be serialized.
pub async fn send_mdns_announcement(
    socket: &UdpSocket,
    instance_name: &Name,
    port: u16,
    advertised_ip: &Arc<RwLock<IpAddr>>,
) -> Result<usize, hickory_proto::ProtoError> {
    // Copy the address out; the read guard is released at the end of this block.
    let ip: IpAddr = {
        let guard = advertised_ip.read().await;
        *guard
    };

    let payload = create_mdns_response_message(instance_name, ip, port).to_bytes()?;

    let group = if ip.is_ipv4() {
        IpAddr::V4(MDNS_IPV4_ADDR)
    } else {
        IpAddr::V6(MDNS_IPV6_ADDR)
    };
    let destination = SocketAddr::new(group, MDNS_PORT);

    let bytes_sent = match socket.send_to(&payload, destination).await {
        Ok(count) => {
            log::debug!(
                "Sent mDNS announcement for {} with IP {} ({} bytes)",
                instance_name,
                ip,
                count
            );
            count
        }
        Err(send_error) => {
            // Best-effort: a dropped announcement is not fatal to the caller.
            log::warn!(
                "Failed to send mDNS announcement for {} with IP {}: {}",
                instance_name,
                ip,
                send_error
            );
            0
        }
    };
    Ok(bytes_sent)
}
/// Builds the mDNS response that advertises a single service instance.
///
/// Every record uses `RECORD_TTL`. The message contains:
/// * answer: a PTR record owned by the service type and pointing to `instance_name`.
///   The service type is the last three labels of `instance_name`, e.g.
///   `_http._tcp.local.`; this assumes `<instance>.<service>.<proto>.<domain>` naming
///   — confirm with callers.
/// * additional: an SRV record for `instance_name` with priority 0, weight 0 and
///   `port`. Its target host is `instance_name` itself.
/// * additional: an A or AAAA record for `instance_name`, chosen by the address family
///   of `interface_ip`.
///
/// The header is marked as an authoritative response with ID 0 and opcode QUERY.
pub fn create_mdns_response_message(
    instance_name: &Name,
    interface_ip: IpAddr,
    port: u16,
) -> Message {
    let service_type = instance_name.trim_to(3);
    let address_rdata = match interface_ip {
        IpAddr::V4(v4) => RData::A(A(v4)),
        IpAddr::V6(v6) => RData::AAAA(AAAA(v6)),
    };

    let ptr_record = Record::from_rdata(
        service_type,
        RECORD_TTL,
        RData::PTR(PTR(instance_name.clone())),
    );
    let srv_record = Record::from_rdata(
        instance_name.clone(),
        RECORD_TTL,
        RData::SRV(SRV::new(0, 0, port, instance_name.clone())),
    );
    let address_record = Record::from_rdata(instance_name.clone(), RECORD_TTL, address_rdata);

    let mut message = Message::new();
    message
        .set_id(0)
        .set_message_type(MessageType::Response)
        .set_op_code(OpCode::Query)
        .set_authoritative(true);
    message.add_answer(ptr_record);
    message.add_additional(srv_record);
    message.add_additional(address_record);
    message
}
pub fn extract_service_info(message: &Message) -> Option<(Name, SocketAddr)> {
let mut ptr_instance_name: Option<Name> = None;
let mut srv_target_name: Option<Name> = None;
let mut srv_port: Option<u16> = None;
let mut srv_actual_target_host: Option<Name> = None;
let mut ip_address: Option<IpAddr> = None;
let mut ip_owner_name: Option<Name> = None;
for record in message.answers().iter().chain(message.additionals()) {
match record.data() {
RData::PTR(ptr_data) => {
if ptr_instance_name.is_none() {
ptr_instance_name = Some(ptr_data.0.clone());
log::trace!("Found PTR record: {} -> {}", record.name(), ptr_data.0);
}
}
RData::SRV(srv_data) => {
if srv_target_name.is_none() && srv_port.is_none() {
srv_target_name = Some(record.name().clone());
srv_port = Some(srv_data.port());
srv_actual_target_host = Some(srv_data.target().clone());
log::trace!(
"Found SRV record: {} port {} target {}",
record.name(),
srv_data.port(),
srv_data.target()
);
}
}
RData::A(a_data) => {
if ip_address.is_none() {
ip_address = Some(IpAddr::V4(a_data.0));
ip_owner_name = Some(record.name().clone());
log::trace!("Found A record: {} -> {}", record.name(), a_data.0);
}
}
RData::AAAA(aaaa_data) => {
if ip_address.is_none() {
ip_address = Some(IpAddr::V6(aaaa_data.0));
ip_owner_name = Some(record.name().clone());
log::trace!("Found AAAA record: {} -> {}", record.name(), aaaa_data.0);
}
}
_ => {} }
}
if let (
Some(instance_name),
Some(srv_owner),
Some(port),
Some(srv_target),
Some(ip_addr_val),
Some(ip_owner),
) = (
ptr_instance_name.clone(),
srv_target_name.clone(),
srv_port,
srv_actual_target_host.clone(),
ip_address,
ip_owner_name.clone(),
) {
if srv_owner == instance_name && ip_owner == srv_target {
log::trace!(
"Successfully extracted service info: Name='{}', IP='{}', Port='{}'",
instance_name,
ip_addr_val,
port
);
return Some((instance_name, SocketAddr::new(ip_addr_val, port)));
} else {
log::warn!(
"Inconsistent records for service extraction. PTR instance: {}, SRV owner: {}, SRV target: {}, IP owner: {}",
instance_name, srv_owner, srv_target, ip_owner
);
}
} else {
log::trace!("Could not extract complete service info. Missing one or more required records. PTR: {:?}, SRV_Owner: {:?}, Port: {:?}, SRV_Target: {:?}, IP: {:?}, IP_Owner: {:?}",
ptr_instance_name.as_ref().map(|n| n.to_utf8()),
srv_target_name.as_ref().map(|n| n.to_utf8()),
srv_port,
srv_actual_target_host.as_ref().map(|n| n.to_utf8()),
ip_address,
ip_owner_name.as_ref().map(|n| n.to_utf8())
);
}
None
}