use mdns_sd::{ResolvedService, ServiceDaemon, ServiceEvent as MdnsEvent, ServiceInfo};
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::Duration;
use tokio::sync::oneshot;
use koi_common::types::ServiceRecord;
use crate::error::{MdnsError, Result};
const RESOLVE_TIMEOUT: Duration = Duration::from_secs(5);
enum MdnsOp {
Register(Box<ServiceInfo>),
Unregister(String), Browse {
service_type: String,
reply: oneshot::Sender<std::result::Result<mdns_sd::Receiver<MdnsEvent>, String>>,
},
StopBrowse(String),
Shutdown {
reply: oneshot::Sender<std::result::Result<(), String>>,
},
}
pub(crate) struct MdnsDaemon {
op_tx: Mutex<std::sync::mpsc::SyncSender<MdnsOp>>,
}
impl MdnsDaemon {
pub fn new() -> Result<Self> {
let daemon = ServiceDaemon::new().map_err(|e| MdnsError::Daemon(e.to_string()))?;
let (op_tx, op_rx) = std::sync::mpsc::sync_channel(256);
std::thread::Builder::new()
.name("koi-mdns-ops".into())
.spawn(move || worker_loop(daemon, op_rx))
.map_err(|e| MdnsError::Daemon(format!("Failed to spawn mDNS worker: {e}")))?;
Ok(Self {
op_tx: Mutex::new(op_tx),
})
}
fn send(&self, op: MdnsOp) -> Result<()> {
self.op_tx
.lock()
.unwrap_or_else(|e| e.into_inner())
.try_send(op)
.map_err(|e| match e {
std::sync::mpsc::TrySendError::Full(_) => {
MdnsError::Daemon("mDNS worker queue full".into())
}
std::sync::mpsc::TrySendError::Disconnected(_) => {
MdnsError::Daemon("mDNS worker stopped".into())
}
})
}
pub async fn browse(&self, service_type: &str) -> Result<mdns_sd::Receiver<MdnsEvent>> {
let (tx, rx) = oneshot::channel();
self.send(MdnsOp::Browse {
service_type: service_type.to_string(),
reply: tx,
})?;
rx.await
.map_err(|_| MdnsError::Daemon("mDNS worker dropped reply".into()))?
.map_err(MdnsError::Daemon)
}
pub fn register(
&self,
name: &str,
service_type: &str,
port: u16,
ip: Option<&str>,
txt: &HashMap<String, String>,
) -> Result<()> {
let hostname = hostname::get()
.unwrap_or_else(|_| "localhost".into())
.to_string_lossy()
.to_string();
let host = format!("{hostname}.local.");
let properties: Vec<(&str, &str)> =
txt.iter().map(|(k, v)| (k.as_str(), v.as_str())).collect();
let ip_str = ip.unwrap_or("");
let service_info =
ServiceInfo::new(service_type, name, &host, ip_str, port, &properties[..])
.map_err(|e| MdnsError::Daemon(e.to_string()))?;
let mut service_info = if ip.is_none() {
service_info.enable_addr_auto()
} else {
service_info
};
service_info.set_requires_probe(false);
let fullname = service_info.get_fullname().to_string();
tracing::debug!(fullname, ?ip, "Queued mDNS register");
self.send(MdnsOp::Register(Box::new(service_info)))
}
pub fn unregister(&self, name: &str, service_type: &str) -> Result<()> {
let fullname = format!("{name}.{service_type}");
self.send(MdnsOp::Unregister(fullname))
}
pub async fn resolve(&self, instance: &str) -> Result<ServiceRecord> {
let parts: Vec<&str> = instance.splitn(2, '.').collect();
if parts.len() < 2 {
return Err(MdnsError::ResolveTimeout(format!(
"Invalid instance name: {instance}"
)));
}
let service_type = parts[1];
let receiver = self.browse(service_type).await?;
let target_name = parts[0];
let deadline = tokio::time::Instant::now() + RESOLVE_TIMEOUT;
loop {
tokio::select! {
result = receiver.recv_async() => {
match result {
Ok(MdnsEvent::ServiceResolved(resolved)) => {
let record = resolved_to_record(&resolved);
if record.name == target_name || resolved.get_fullname() == instance {
let _ = self.stop_browse(service_type);
return Ok(record);
}
}
Ok(_) => continue,
Err(_) => {
let _ = self.stop_browse(service_type);
break;
}
}
}
_ = tokio::time::sleep_until(deadline) => {
let _ = self.stop_browse(service_type);
return Err(MdnsError::ResolveTimeout(format!(
"Could not resolve {instance} within {RESOLVE_TIMEOUT:?}"
)));
}
}
}
Err(MdnsError::ResolveTimeout(format!(
"Could not resolve {instance}"
)))
}
pub fn stop_browse(&self, service_type: &str) -> Result<()> {
self.send(MdnsOp::StopBrowse(service_type.to_string()))
}
pub async fn shutdown(&self) -> Result<()> {
let (tx, rx) = oneshot::channel();
self.send(MdnsOp::Shutdown { reply: tx })?;
rx.await
.map_err(|_| MdnsError::Daemon("mDNS worker dropped reply".into()))?
.map_err(MdnsError::Daemon)
}
}
fn worker_loop(daemon: ServiceDaemon, rx: std::sync::mpsc::Receiver<MdnsOp>) {
tracing::debug!("mDNS worker thread started");
while let Ok(op) = rx.recv() {
match op {
MdnsOp::Register(info) => {
let fullname = info.get_fullname().to_string();
if let Err(e) = daemon.register(*info) {
tracing::warn!(fullname, error = %e, "mDNS register failed");
}
}
MdnsOp::Unregister(fullname) => {
if let Err(e) = daemon.unregister(&fullname) {
tracing::warn!(fullname, error = %e, "mDNS unregister failed");
}
}
MdnsOp::Browse {
service_type,
reply,
} => {
let result = daemon.browse(&service_type).map_err(|e| e.to_string());
let _ = reply.send(result);
}
MdnsOp::StopBrowse(service_type) => {
if let Err(e) = daemon.stop_browse(&service_type) {
tracing::debug!(service_type, error = %e, "mDNS stop_browse failed");
}
}
MdnsOp::Shutdown { reply } => {
let result = daemon.shutdown().map(|_| ()).map_err(|e| e.to_string());
let _ = reply.send(result);
break;
}
}
}
tracing::debug!("mDNS worker thread stopped");
}
pub(crate) fn resolved_to_record(resolved: &ResolvedService) -> ServiceRecord {
let fullname = resolved.get_fullname();
let name = fullname
.find("._")
.map(|i| &fullname[..i])
.unwrap_or(fullname)
.to_string();
let service_type = resolved.ty_domain.clone();
let service_type = service_type
.trim_end_matches('.')
.trim_end_matches(".local")
.to_string();
let host = resolved.get_hostname().to_string();
let host = if host.is_empty() { None } else { Some(host) };
let addresses = resolved.get_addresses();
let ip = addresses
.iter()
.find(|a| a.is_ipv4())
.or_else(|| addresses.iter().next())
.map(|a| a.to_ip_addr());
let ip = ip.map(|addr| {
if addr.is_loopback() {
lan_ip().unwrap_or(addr).to_string()
} else {
addr.to_string()
}
});
if addresses.len() > 1 {
tracing::trace!(
name,
count = addresses.len(),
selected = ?ip,
"Multiple IPs found, using first"
);
}
let txt: HashMap<String, String> = resolved
.get_properties()
.iter()
.map(|p| (p.key().to_string(), p.val_str().to_string()))
.collect();
ServiceRecord {
name,
service_type,
host,
ip,
port: Some(resolved.get_port()),
txt,
}
}
fn lan_ip() -> Option<std::net::IpAddr> {
if_addrs::get_if_addrs()
.unwrap_or_default()
.into_iter()
.filter(|iface| !iface.is_loopback())
.filter_map(|iface| match iface.addr.ip() {
std::net::IpAddr::V4(v4) if !v4.is_link_local() => Some(std::net::IpAddr::V4(v4)),
_ => None,
})
.next()
}