#![cfg(feature = "discovery")]
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use mdns_sd::{ServiceDaemon, ServiceEvent, ServiceInfo};
use tokio::sync::mpsc;
use super::{Backend, CA_SERVICE_TYPE, DiscoveryEvent};
/// mDNS service type that `MdnsBackend` browses for and `MdnsAnnouncer` registers
/// (`epics-ca` — presumably EPICS Channel Access over TCP).
const MDNS_TYPE: &str = "_epics-ca._tcp.local.";
/// Maximum number of distinct service instances the browser task tracks; resolutions
/// for additional, previously unseen instances are ignored with a warning.
const MAX_MDNS_INSTANCES: usize = 256;
/// Discovery backend that browses mDNS for `MDNS_TYPE` services and exposes the
/// resolved IOC addresses both as a polled snapshot and as an event stream.
///
/// Note: fields are dropped in declaration order after `Drop::drop` runs.
pub struct MdnsBackend {
/// Handle to the mdns-sd daemon; shut down in `Drop`.
daemon: ServiceDaemon,
/// Background task consuming browse events; aborted in `Drop`.
browser: tokio::task::JoinHandle<()>,
/// Deduplicated list of currently known IOC addresses, shared with `browser`.
snapshot: Arc<Mutex<Vec<SocketAddr>>>,
/// One-shot slot holding the event receiver until `subscribe` takes it.
event_rx: Mutex<Option<mpsc::UnboundedReceiver<DiscoveryEvent>>>,
}
impl Drop for MdnsBackend {
    /// Stops the mDNS daemon, then cancels the background browser task.
    ///
    /// A failed daemon shutdown is ignored because `drop` has no way to report it.
    fn drop(&mut self) {
        let Self { daemon, browser, .. } = self;
        let _ = daemon.shutdown();
        browser.abort();
    }
}
impl MdnsBackend {
    /// Starts browsing mDNS for `MDNS_TYPE` services.
    ///
    /// A background tokio task consumes browse events and keeps two views in sync:
    ///
    /// * the shared address snapshot returned by [`Backend::discover`]. An address
    ///   stays in it while *any* tracked instance still advertises it.
    /// * a per-instance `DiscoveryEvent` stream, handed out once by
    ///   [`Backend::subscribe`].
    ///
    /// At most `MAX_MDNS_INSTANCES` distinct instances are tracked. Resolutions for
    /// further new instances are ignored with a warning.
    ///
    /// # Errors
    ///
    /// Returns the `mdns_sd` error if the daemon cannot be created or the browse
    /// request fails. In the latter case the daemon is shut down before returning.
    ///
    /// # Panics
    ///
    /// Uses `tokio::spawn`, which panics when called outside a tokio runtime.
    pub fn new() -> Result<Self, mdns_sd::Error> {
        let daemon = ServiceDaemon::new()?;
        let receiver = match daemon.browse(MDNS_TYPE) {
            Ok(receiver) => receiver,
            Err(err) => {
                // Shut the daemon down explicitly (as the Drop impls in this module
                // do) instead of just dropping the handle on this error path.
                let _ = daemon.shutdown();
                return Err(err);
            }
        };
        let snapshot: Arc<Mutex<Vec<SocketAddr>>> = Arc::new(Mutex::new(Vec::new()));
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let snap_clone = snapshot.clone();
        let browser = tokio::spawn(async move {
            // Addresses currently advertised by each instance, keyed by full name.
            let mut known = InstanceAddrs::new();
            while let Ok(event) = receiver.recv_async().await {
                match event {
                    ServiceEvent::ServiceResolved(info) => {
                        let fullname = info.get_fullname().to_string();
                        if !known.contains_key(&fullname) && known.len() >= MAX_MDNS_INSTANCES {
                            tracing::warn!(instance = %fullname,
                                cap = MAX_MDNS_INSTANCES,
                                "mDNS: instance cap reached; ignoring new instance");
                            continue;
                        }
                        let resolved: std::collections::HashSet<SocketAddr> =
                            resolve_addresses(&info).into_iter().collect();
                        // Record the new address set before handling withdrawals so
                        // the "still advertised elsewhere" check sees current state.
                        let prev = known
                            .insert(fullname.clone(), resolved.clone())
                            .unwrap_or_default();
                        for &addr in resolved.difference(&prev) {
                            snapshot_add(&snap_clone, addr);
                            let _ = event_tx.send(DiscoveryEvent::Added {
                                instance: fullname.clone(),
                                addr,
                            });
                            tracing::info!(addr = %addr, instance = %fullname,
                                "mDNS discovered IOC");
                        }
                        for &addr in prev.difference(&resolved) {
                            snapshot_withdraw(&snap_clone, &known, addr);
                            let _ = event_tx.send(DiscoveryEvent::Removed {
                                instance: fullname.clone(),
                                addr,
                            });
                            tracing::info!(addr = %addr, instance = %fullname,
                                "mDNS IOC address withdrawn");
                        }
                    }
                    ServiceEvent::ServiceRemoved(_, fullname) => {
                        if let Some(addrs) = known.remove(&fullname) {
                            for addr in addrs {
                                snapshot_withdraw(&snap_clone, &known, addr);
                                let _ = event_tx.send(DiscoveryEvent::Removed {
                                    instance: fullname.clone(),
                                    addr,
                                });
                            }
                        }
                        tracing::info!(instance = %fullname, "mDNS IOC went away");
                    }
                    _ => {}
                }
            }
        });
        Ok(Self {
            daemon,
            browser,
            snapshot,
            event_rx: Mutex::new(Some(event_rx)),
        })
    }
}

/// Per-instance address sets tracked by the browser task, keyed by full mDNS name.
type InstanceAddrs = std::collections::HashMap<String, std::collections::HashSet<SocketAddr>>;

/// Adds `addr` to the shared snapshot unless it is already present.
/// A poisoned lock is skipped, matching the best-effort style of the browser task.
fn snapshot_add(snapshot: &Mutex<Vec<SocketAddr>>, addr: SocketAddr) {
    if let Ok(mut snap) = snapshot.lock() {
        if !snap.contains(&addr) {
            snap.push(addr);
        }
    }
}

/// Removes `addr` from the shared snapshot, unless some instance in `known` still
/// advertises it. Two instance names can resolve to the same socket address.
fn snapshot_withdraw(snapshot: &Mutex<Vec<SocketAddr>>, known: &InstanceAddrs, addr: SocketAddr) {
    if known.values().any(|addrs| addrs.contains(&addr)) {
        return;
    }
    if let Ok(mut snap) = snapshot.lock() {
        snap.retain(|a| *a != addr);
    }
}
#[async_trait::async_trait]
impl Backend for MdnsBackend {
    /// Pauses for 500 ms so the background browser can gather resolutions, then
    /// returns a copy of the current address snapshot.
    ///
    /// Returns an empty list if the snapshot lock is poisoned.
    async fn discover(&self) -> Vec<SocketAddr> {
        let settle = Duration::from_millis(500);
        tokio::time::sleep(settle).await;
        match self.snapshot.lock() {
            Ok(addrs) => addrs.clone(),
            Err(_) => Vec::new(),
        }
    }

    /// Hands out the discovery event receiver.
    ///
    /// Only the first call gets `Some`. Later calls, or any call made while the
    /// lock is poisoned, get `None`.
    fn subscribe(&self) -> Option<mpsc::UnboundedReceiver<DiscoveryEvent>> {
        let Ok(mut slot) = self.event_rx.lock() else {
            return None;
        };
        slot.take()
    }
}
/// Keeps an mDNS service registration alive for as long as it exists; dropping it
/// unregisters the service and shuts the daemon down.
pub struct MdnsAnnouncer {
/// Daemon that owns the registration.
daemon: ServiceDaemon,
/// Full mDNS name of the registered service, used to unregister it on drop.
fullname: String,
}
impl MdnsAnnouncer {
pub fn announce(
instance: &str,
tcp_port: u16,
txt: Vec<(String, String)>,
) -> Result<Self, mdns_sd::Error> {
let daemon = ServiceDaemon::new()?;
let hostname = epics_base_rs::runtime::env::hostname();
let host_target = format!("{hostname}.local.");
let ips: Vec<IpAddr> = if_addrs::get_if_addrs()
.unwrap_or_default()
.into_iter()
.filter(|iface| !iface.is_loopback())
.filter_map(|iface| match iface.ip() {
IpAddr::V4(v4) => Some(IpAddr::V4(v4)),
_ => None,
})
.collect();
let info = ServiceInfo::new(
MDNS_TYPE,
instance,
&host_target,
&ips[..],
tcp_port,
&txt[..],
)?;
let fullname = info.get_fullname().to_string();
daemon.register(info)?;
tracing::info!(instance = %instance, port = tcp_port,
"mDNS announce registered ({fullname})");
Ok(Self { daemon, fullname })
}
}
impl Drop for MdnsAnnouncer {
    /// Withdraws the registration, then stops the daemon.
    ///
    /// Both steps are best effort, because `drop` cannot surface errors.
    fn drop(&mut self) {
        let Self { daemon, fullname } = self;
        let _ = daemon.unregister(fullname.as_str());
        let _ = daemon.shutdown();
    }
}
impl MdnsBackend {
    /// Convenience entry point that forwards to [`MdnsAnnouncer::announce`].
    /// Keep the returned announcer alive for as long as the service should stay
    /// advertised.
    pub fn announce_helper(
        instance: &str,
        port: u16,
        txt: Vec<(String, String)>,
    ) -> Result<MdnsAnnouncer, mdns_sd::Error> {
        let announcer = MdnsAnnouncer::announce(instance, port, txt)?;
        Ok(announcer)
    }
}
fn resolve_addresses(info: &ServiceInfo) -> Vec<SocketAddr> {
let port = info.get_port();
info.get_addresses_v4()
.iter()
.map(|ip| SocketAddr::new(IpAddr::V4(**ip), port))
.collect()
}
// Presumably exists so the `CA_SERVICE_TYPE` import above is referenced and not
// flagged as unused. The closure is never called; `allow(dead_code)` keeps this
// placeholder quiet.
#[allow(dead_code)]
const _: fn() = || {
    let _unused = &CA_SERVICE_TYPE;
};