use std::{
collections::{HashMap, HashSet},
net::{IpAddr, SocketAddr},
sync::Arc,
};
use hickory_proto::{
op::{Message, MessageType, ResponseCode},
rr::{Name, RecordType},
serialize::binary::{BinDecodable, BinEncodable},
};
use tokio::{
net::UdpSocket,
sync::{mpsc, RwLock},
};
use crate::mdns::{MDNS_IPV4_ADDR, MDNS_IPV6_ADDR, MDNS_PORT};
use super::utils::{create_mdns_response_message, extract_service_info};
/// Size of the receive buffer for one mDNS datagram.
///
/// RFC 6762 §17 permits Multicast DNS packets of up to 9000 bytes (including IP/UDP headers).
/// UDP `recv_from` silently truncates datagrams larger than the buffer, so a smaller buffer
/// would turn legal large packets into parse failures.
const BUFFER_SIZE: usize = 9000;
pub async fn server_task(
socket: Arc<UdpSocket>,
notifier_tx: mpsc::Sender<(Name, SocketAddr)>,
registered_services: Arc<RwLock<HashMap<Name, HashMap<Name, u16>>>>,
service_cache: Arc<RwLock<HashMap<Name, SocketAddr>>>,
follow_services: Arc<RwLock<HashSet<Name>>>,
advertised_ip: Arc<RwLock<IpAddr>>,
) {
let mut buf = [0u8; BUFFER_SIZE];
loop {
let (len, src_addr) = match socket.recv_from(&mut buf).await {
Ok(v) => v,
Err(e) => {
if e.kind() == std::io::ErrorKind::ConnectionReset
|| e.kind() == std::io::ErrorKind::BrokenPipe
{
log::warn!("Socket connection error ({}). Task for {:?} might need to be restarted or interface is down.", e, socket.local_addr().ok());
break;
} else {
log::error!(
"Failed to receive data on mDNS socket {:?}: {}",
socket.local_addr().ok(),
e
);
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
continue;
}
};
if len == 0 {
continue;
}
let message = match Message::from_bytes(&buf[..len]) {
Ok(msg) => msg,
Err(e) => {
log::warn!(
"Failed to parse mDNS message from bytes ({} bytes received) from {}: {}",
len,
src_addr,
e
);
continue;
}
};
if message.response_code() != ResponseCode::NoError {
log::debug!(
"Ignoring mDNS message with response code: {:?}",
message.response_code()
);
continue;
}
match message.message_type() {
MessageType::Query => {
handle_query(
message,
&socket,
®istered_services,
&advertised_ip,
src_addr,
)
.await;
}
MessageType::Response => {
handle_response(
message,
¬ifier_tx,
®istered_services,
&service_cache,
&follow_services,
)
.await;
}
}
}
}
async fn handle_query(
query_message: Message,
socket: &UdpSocket,
registered_services: &Arc<RwLock<HashMap<Name, HashMap<Name, u16>>>>,
advertised_ip: &Arc<RwLock<IpAddr>>,
src_addr: SocketAddr,
) {
let interface_ip = *advertised_ip.read().await;
for query in query_message.queries() {
log::trace!(
"Received mDNS query for service name: {}, type: {:?}, class: {:?}",
query.name(),
query.query_type(),
query.query_class()
);
let query_name_str = query.name().to_utf8();
let services_guard = registered_services.read().await;
if query.query_type() == RecordType::PTR || query.query_type() == RecordType::ANY {
if let Some(instances_map) = services_guard.get(query.name()) {
for (instance_name, &port) in instances_map.iter() {
log::debug!(
"Responding to PTR/ANY query for service type {} with instance: {} at {}",
query_name_str,
instance_name,
port
);
let response_message =
create_mdns_response_message(instance_name, interface_ip, port);
let bytes = response_message.to_bytes();
match bytes {
Ok(bytes) => {
if query.mdns_unicast_response {
if let Err(e) = socket.send_to(&bytes, src_addr).await {
log::error!(
"Failed to send response for instance {}: {}",
instance_name,
e
);
}
} else {
let multicast_addr: SocketAddr = match interface_ip {
IpAddr::V4(_) => {
SocketAddr::new(IpAddr::V4(MDNS_IPV4_ADDR), MDNS_PORT)
}
IpAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(MDNS_IPV6_ADDR), MDNS_PORT)
}
};
if let Err(e) = socket.send_to(&bytes, multicast_addr).await {
log::error!(
"Failed to send multicast response for instance {}: {}",
instance_name,
e
);
}
}
}
Err(e) => {
log::error!(
"Failed to serialize response for instance {}: {}",
instance_name,
e
);
}
}
}
}
}
if query.name().num_labels() > 3 {
let service_type_key = query.name().trim_to(3);
if let Some(instances_map) = services_guard.get(&service_type_key) {
if let Some((registered_instance_name, &port)) =
instances_map.get_key_value(query.name())
{
log::debug!(
"Responding to specific query for registered service instance: {} at {}",
registered_instance_name,
port
);
let response_message =
create_mdns_response_message(registered_instance_name, interface_ip, port);
let bytes = response_message.to_bytes();
match bytes {
Ok(bytes) => {
if query.mdns_unicast_response {
if let Err(e) = socket.send_to(&bytes, src_addr).await {
log::error!(
"Failed to send response for instance {}: {}",
registered_instance_name,
e
);
}
} else {
let multicast_addr: SocketAddr = match interface_ip {
IpAddr::V4(_) => {
SocketAddr::new(IpAddr::V4(MDNS_IPV4_ADDR), MDNS_PORT)
}
IpAddr::V6(_) => {
SocketAddr::new(IpAddr::V6(MDNS_IPV6_ADDR), MDNS_PORT)
}
};
if let Err(e) = socket.send_to(&bytes, multicast_addr).await {
log::error!(
"Failed to send multicast response for instance {}: {}",
registered_instance_name,
e
);
}
}
}
Err(e) => {
log::error!(
"Failed to serialize response for instance {}: {}",
registered_instance_name,
e
);
}
}
}
}
}
}
}
/// Processes an incoming mDNS response and records services this host is following.
///
/// If `extract_service_info` yields an (instance name, address) pair, the service type is
/// taken as `trim_to(3)` of the instance name. The discovery is accepted only when two checks
/// pass:
/// * that service type is in `follow_services`;
/// * the instance is not itself present in `registered_services`, i.e. registered on this host.
///
/// An accepted discovery is written into `service_cache`. When the entry is new or its address
/// changed, the pair is also sent on `notifier_tx`. A failed send (closed channel) is logged.
async fn handle_response(
    response_message: Message,
    notifier_tx: &mpsc::Sender<(Name, SocketAddr)>,
    registered_services: &Arc<RwLock<HashMap<Name, HashMap<Name, u16>>>>,
    service_cache: &Arc<RwLock<HashMap<Name, SocketAddr>>>,
    follow_services: &Arc<RwLock<HashSet<Name>>>,
) {
    let (discovered_instance_name, discovered_addr) =
        match extract_service_info(&response_message) {
            Some(found) => found,
            None => {
                log::trace!("Response message did not contain complete service info or was not a service announcement.");
                return;
            }
        };
    log::trace!(
        "Potential service discovered in response: {} at {}",
        discovered_instance_name,
        discovered_addr
    );

    let service_type_name = discovered_instance_name.trim_to(3);
    let is_following: bool = {
        let follow_guard = follow_services.read().await;
        let registered_services_guard = registered_services.read().await;
        follow_guard.contains(&service_type_name)
            && !registered_services_guard
                .get(&service_type_name)
                .map_or(false, |instances| {
                    instances.contains_key(&discovered_instance_name)
                })
    };

    if !is_following {
        log::trace!(
            "Ignoring discovered service {} of type {} because not followed.",
            discovered_instance_name,
            service_type_name
        );
        return;
    }

    // Scope the write guard so it is dropped before the channel send below. `send` waits while
    // the bounded channel is full; holding the cache lock during that wait could deadlock
    // against a consumer that reads the cache.
    let old_value = {
        let mut cache_guard = service_cache.write().await;
        cache_guard.insert(discovered_instance_name.clone(), discovered_addr)
    };

    if old_value == Some(discovered_addr) {
        log::trace!(
            "Service cache already up-to-date for: {} at {}",
            discovered_instance_name,
            discovered_addr
        );
        return;
    }

    log::debug!(
        "Service cache updated for: {} at {} (was {:?})",
        discovered_instance_name,
        discovered_addr,
        old_value
    );
    if let Err(e) = notifier_tx
        .send((discovered_instance_name.clone(), discovered_addr))
        .await
    {
        log::error!(
            "Failed to send notification for discovered service {}: {}",
            discovered_instance_name,
            e
        );
    }
}