use std::time::Duration;
use mdns_sd::{ServiceDaemon, ServiceInfo};
use crate::SERVICE_TYPE;
use crate::error::DiscoveryError;
use crate::txt::TxtRecord;
/// Upper bound on how long teardown (`Advertiser::stop` and `Drop`) blocks
/// waiting on the channel returned by `ServiceDaemon::unregister` before
/// going on to shut the daemon down.
const UNREGISTER_TIMEOUT: Duration = Duration::from_secs(1);
/// Parameters for a single mDNS service advertisement, consumed by
/// `Advertiser::announce`.
#[derive(Debug, Clone)]
pub struct AdvertiseOptions {
/// Port passed to the service registration (and logged on success).
pub port: u16,
/// Instance name passed to `ServiceInfo::new` for this service.
pub instance_name: String,
/// Hostname to advertise. When empty, `local_hostname()` is used instead.
/// Surrounding whitespace and a trailing `.local` / `.local.` suffix are
/// stripped before `.local.` is appended.
pub hostname: String,
/// TXT record; converted with `to_properties()` and attached to the
/// registration as its properties.
pub txt: TxtRecord,
}
/// Handle for a registered mDNS advertisement.
///
/// Created by `Advertiser::announce`. The advertisement is withdrawn either
/// explicitly through `stop` or on a best-effort basis when the value is
/// dropped.
pub struct Advertiser {
// `Some` while the advertisement is live; `stop` and `Drop` `take()` it so
// teardown runs at most once.
daemon: Option<ServiceDaemon>,
// Full service name captured from `ServiceInfo::get_fullname()` at
// registration time; passed to `unregister` during teardown.
full_name: String,
}
impl Advertiser {
    /// Registers the service described by `opts` with a freshly started mDNS
    /// daemon and returns a handle that keeps the advertisement alive.
    ///
    /// The advertised host name is `opts.hostname` (or [`local_hostname`] when
    /// that is empty) normalized to the `<name>.local.` form. Addresses are
    /// filled in automatically via `enable_addr_auto()`.
    ///
    /// # Errors
    ///
    /// Returns a [`DiscoveryError`] when the TXT record cannot be converted to
    /// properties, when `ServiceInfo::new` rejects the description, when the
    /// daemon cannot be started, or when registration fails. No daemon is left
    /// running on any of these paths.
    pub fn announce(opts: AdvertiseOptions) -> Result<Self, DiscoveryError> {
        // Run every fallible step that does not need the daemon *before*
        // starting it, so an early `?` cannot abandon a running daemon.
        let props = opts.txt.to_properties()?;
        let host = if opts.hostname.is_empty() {
            Self::mdns_host(&local_hostname())
        } else {
            Self::mdns_host(&opts.hostname)
        };
        let info = ServiceInfo::new(
            SERVICE_TYPE,
            &opts.instance_name,
            &host,
            "",
            opts.port,
            Some(props),
        )?
        .enable_addr_auto();
        let full_name = info.get_fullname().to_string();

        let daemon = ServiceDaemon::new()?;
        let registered = daemon.register(info);
        if registered.is_err() {
            // Same teardown as `stop`: shut the daemon down explicitly rather
            // than relying on the handle being dropped.
            let _ = daemon.shutdown();
        }
        registered?;

        tracing::info!(
            service = %full_name,
            port = opts.port,
            "rtl_tcp mDNS advertisement registered"
        );
        Ok(Self {
            daemon: Some(daemon),
            full_name,
        })
    }

    /// Withdraws the advertisement and shuts the daemon down.
    ///
    /// Waits up to [`UNREGISTER_TIMEOUT`] for the unregister request to be
    /// acknowledged. A timeout while waiting and any error from the daemon
    /// shutdown are ignored. Returns `Ok(())` immediately if the daemon has
    /// already been taken.
    ///
    /// # Errors
    ///
    /// Returns a [`DiscoveryError`] if the unregister request itself could not
    /// be issued; the daemon is still shut down in that case.
    pub fn stop(mut self) -> Result<(), DiscoveryError> {
        // Taking the daemon leaves `None` behind, so the `Drop` impl that runs
        // when `self` goes out of scope has nothing left to do.
        let Some(daemon) = self.daemon.take() else {
            return Ok(());
        };
        let unregister_result = daemon.unregister(&self.full_name);
        if let Ok(rx) = &unregister_result {
            let _ = rx.recv_timeout(UNREGISTER_TIMEOUT);
        }
        // Shut down before propagating the unregister error so the daemon is
        // stopped on both paths.
        let _ = daemon.shutdown();
        unregister_result?;
        Ok(())
    }

    /// Normalizes `raw` into the `<name>.local.` host form used for the
    /// registration.
    ///
    /// Strips surrounding whitespace, any trailing `.local.` / `.local`
    /// suffix, and then any remaining trailing dots, so an input such as
    /// `"myhost."` yields `"myhost.local."` rather than `"myhost..local."`.
    /// An input with nothing left after stripping becomes `"localhost.local."`.
    fn mdns_host(raw: &str) -> String {
        let name = raw
            .trim()
            .trim_end_matches(".local.")
            .trim_end_matches(".local")
            .trim_end_matches('.');
        if name.is_empty() {
            String::from("localhost.local.")
        } else {
            format!("{name}.local.")
        }
    }
}
/// Best-effort teardown for an `Advertiser` that was never passed to `stop`.
impl Drop for Advertiser {
    /// Unregisters the service, waits up to `UNREGISTER_TIMEOUT` on the
    /// returned channel, then shuts the daemon down. Every failure is
    /// ignored, and nothing happens if the daemon was already taken.
    fn drop(&mut self) {
        if let Some(daemon) = self.daemon.take() {
            if let Ok(ack) = daemon.unregister(&self.full_name) {
                let _ = ack.recv_timeout(UNREGISTER_TIMEOUT);
            }
            let _ = daemon.shutdown();
        }
    }
}
/// Returns this machine's hostname as reported by `gethostname(2)`, with
/// surrounding whitespace and any trailing `.local` / `.local.` suffix removed.
///
/// Falls back to `"localhost"` (logging a warning) when the call fails or the
/// name is not valid UTF-8, and also when nothing is left after stripping.
#[cfg(unix)]
// The only unsafe code here is the single FFI call to `gethostname`.
#[allow(unsafe_code)]
pub fn local_hostname() -> String {
    const FALLBACK: &str = "localhost";
    const CAPACITY: usize = 256;

    let mut raw = [0u8; CAPACITY];
    // SAFETY: `raw` is a live, writable buffer and the length passed is
    // exactly its size in bytes, so `gethostname` cannot write out of bounds.
    let status = unsafe { libc::gethostname(raw.as_mut_ptr().cast::<libc::c_char>(), raw.len()) };
    if status != 0 {
        tracing::warn!("gethostname() failed, using 'localhost' as nickname default");
        return FALLBACK.to_owned();
    }

    // Bytes before the first NUL; the whole buffer when no NUL is present.
    let bytes = raw.split(|&byte| byte == 0).next().unwrap_or(&raw[..]);
    let name = match std::str::from_utf8(bytes) {
        Ok(text) => text,
        Err(_) => {
            tracing::warn!("gethostname() returned non-UTF-8 bytes, using 'localhost'");
            return FALLBACK.to_owned();
        }
    };

    let stripped = name
        .trim()
        .trim_end_matches(".local.")
        .trim_end_matches(".local");
    match stripped {
        "" => FALLBACK.to_owned(),
        host => host.to_owned(),
    }
}
/// Non-Unix variant of `local_hostname`: performs no lookup and always
/// returns `"localhost"`.
#[cfg(not(unix))]
pub fn local_hostname() -> String {
    "localhost".to_owned()
}