use mdns_sd::{ServiceDaemon, ServiceInfo};
use std::collections::HashMap;
use std::time::Duration;
use tracing::{debug, info, warn};
/// DNS-SD service type under which AeroSync receivers are advertised and browsed.
pub const MDNS_SERVICE_TYPE: &str = "_aerosync._tcp.local.";

/// An AeroSync receiver found by discovery.
#[derive(Debug, Clone)]
pub struct AeroSyncPeer {
    /// Host name of the peer's machine as reported by discovery.
    pub name: String,
    /// Address to connect to (usually an IP literal).
    pub host: String,
    /// TCP port the receiver listens on.
    pub port: u16,
    /// AeroSync version the peer reported, if any.
    pub version: Option<String>,
    /// Whether the peer appears to accept WebSocket connections.
    pub ws_enabled: bool,
    /// Whether the peer advertises that authentication is required.
    pub auth_required: bool,
}

impl AeroSyncPeer {
    /// Returns `host:port` as one string.
    ///
    /// The host is inserted verbatim, so an IPv6 literal such as `::1` is not
    /// wrapped in brackets (the result is `"::1:7788"`).
    pub fn addr(&self) -> String {
        let port_text = self.port.to_string();
        let mut joined = String::with_capacity(self.host.len() + 1 + port_text.len());
        joined.push_str(&self.host);
        joined.push(':');
        joined.push_str(&port_text);
        joined
    }
}
/// Keeps an mDNS advertisement alive; dropping it withdraws the service.
///
/// Returned by `AeroSyncMdns::register`, which moves the daemon that performs
/// the advertisement into this handle. Hold it for as long as the receiver
/// should remain discoverable.
pub struct MdnsHandle {
    /// Daemon that owns the registration.
    daemon: ServiceDaemon,
    /// Full DNS-SD name of the registered service, needed to unregister it.
    service_fullname: String,
}

impl Drop for MdnsHandle {
    /// Unregisters the service, then stops the daemon.
    ///
    /// Failures are logged rather than propagated, since `drop` cannot fail.
    fn drop(&mut self) {
        if let Err(e) = self.daemon.unregister(&self.service_fullname) {
            warn!("mDNS unregister failed: {}", e);
        }
        // The daemon runs a background thread; stop it explicitly instead of
        // leaving it running after the handle is gone. Shutdown is requested
        // only after the unregister so the daemon is asked to withdraw the
        // service first.
        if let Err(e) = self.daemon.shutdown() {
            warn!("mDNS daemon shutdown failed: {}", e);
        }
    }
}
pub struct AeroSyncMdns;
impl AeroSyncMdns {
pub fn register(
instance_name: &str,
port: u16,
version: &str,
ws_enabled: bool,
auth_required: bool,
) -> Result<MdnsHandle, mdns_sd::Error> {
let daemon = ServiceDaemon::new()?;
let host = hostname_or_localhost();
let mut properties = HashMap::new();
properties.insert("version".to_string(), version.to_string());
properties.insert("ws".to_string(), ws_enabled.to_string());
properties.insert("auth".to_string(), auth_required.to_string());
let service = ServiceInfo::new(
MDNS_SERVICE_TYPE,
instance_name,
&host,
(), port,
properties,
)?;
let fullname = service.get_fullname().to_string();
daemon.register(service)?;
info!(
"mDNS: broadcasting AeroSync receiver as '{}' on port {}",
instance_name, port
);
Ok(MdnsHandle {
daemon,
service_fullname: fullname,
})
}
pub async fn discover(timeout: Duration) -> Vec<AeroSyncPeer> {
let (mdns_peers, local_peers) = tokio::join!(
tokio::task::spawn_blocking(move || Self::discover_sync(timeout)),
Self::probe_localhost_ports(&[7788, 7789, 8080, 9000]),
);
let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
let mut result = Vec::new();
for peer in local_peers {
let key = peer.addr();
if seen.insert(key) {
result.push(peer);
}
}
for peer in mdns_peers.unwrap_or_default() {
let key = peer.addr();
if seen.insert(key) {
result.push(peer);
}
}
result
}
async fn probe_localhost_ports(ports: &[u16]) -> Vec<AeroSyncPeer> {
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(1))
.build()
{
Ok(c) => c,
Err(_) => return vec![],
};
let mut peers = Vec::new();
for &port in ports {
let url = format!("http://127.0.0.1:{}/health", port);
if let Ok(resp) = client.get(&url).send().await {
let is_aerosync = resp
.headers()
.get("x-aerosync")
.and_then(|v| v.to_str().ok())
.map(|v| v == "true")
.unwrap_or(false);
if is_aerosync {
let body: serde_json::Value =
resp.json().await.unwrap_or(serde_json::Value::Null);
let version = body["version"].as_str().map(|s| s.to_string());
let ws_enabled = client
.get(format!("http://127.0.0.1:{}/ws", port))
.header("Upgrade", "websocket")
.header("Connection", "Upgrade")
.header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
.header("Sec-WebSocket-Version", "13")
.send()
.await
.map(|r| r.status().as_u16() == 101 || r.status().as_u16() == 400)
.unwrap_or(false);
debug!(
"localhost probe: found AeroSync on port {} (version={:?} ws={})",
port, version, ws_enabled
);
peers.push(AeroSyncPeer {
name: hostname_or_localhost(),
host: "127.0.0.1".to_string(),
port,
version,
ws_enabled,
auth_required: false, });
}
}
}
peers
}
fn discover_sync(timeout: Duration) -> Vec<AeroSyncPeer> {
let daemon = match ServiceDaemon::new() {
Ok(d) => d,
Err(e) => {
warn!("mDNS: failed to create daemon for discovery: {}", e);
return vec![];
}
};
let receiver = match daemon.browse(MDNS_SERVICE_TYPE) {
Ok(r) => r,
Err(e) => {
warn!("mDNS: failed to browse {}: {}", MDNS_SERVICE_TYPE, e);
return vec![];
}
};
let mut peers: HashMap<String, AeroSyncPeer> = HashMap::new();
let deadline = std::time::Instant::now() + timeout;
while let Some(remaining) = deadline.checked_duration_since(std::time::Instant::now()) {
let poll = remaining.min(Duration::from_millis(200));
match receiver.recv_timeout(poll) {
Ok(mdns_sd::ServiceEvent::ServiceResolved(info)) => {
let fullname = info.get_fullname().to_string();
let name = info.get_hostname().trim_end_matches('.').to_string();
let port = info.get_port();
let host = info
.get_addresses()
.iter()
.find(|a| a.is_ipv4())
.or_else(|| info.get_addresses().iter().next())
.map(|a| a.to_string())
.unwrap_or_else(|| name.clone());
let props = info.get_properties();
let version = props.get("version").map(|v| v.val_str().to_string());
let ws_enabled = props
.get("ws")
.map(|v| v.val_str() == "true")
.unwrap_or(true);
let auth_required = props
.get("auth")
.map(|v| v.val_str() == "true")
.unwrap_or(false);
debug!(
"mDNS resolved: {} → {}:{} (version={:?} ws={} auth={})",
name, host, port, version, ws_enabled, auth_required
);
peers.insert(
fullname,
AeroSyncPeer {
name,
host,
port,
version,
ws_enabled,
auth_required,
},
);
}
Ok(mdns_sd::ServiceEvent::SearchStopped(_)) => break,
Ok(_) => {} Err(_) => {
if std::time::Instant::now() >= deadline {
break;
}
}
}
}
let _ = daemon.stop_browse(MDNS_SERVICE_TYPE);
peers.into_values().collect()
}
}
/// Returns this machine's host name as reported by the OS.
///
/// Falls back to `"localhost"` when the name cannot be read, is not valid
/// UTF-8, or is empty or whitespace-only. An empty name would otherwise give
/// peers a blank display name and produce an invalid mDNS host. Surrounding
/// whitespace is trimmed.
fn hostname_or_localhost() -> String {
    hostname::get()
        .ok()
        .and_then(|os_name| os_name.into_string().ok())
        .map(|name| name.trim().to_string())
        .filter(|name| !name.is_empty())
        .unwrap_or_else(|| "localhost".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a peer from borrowed parts so each test only spells out its data.
    fn make_peer(
        name: &str,
        host: &str,
        port: u16,
        version: Option<&str>,
        ws_enabled: bool,
        auth_required: bool,
    ) -> AeroSyncPeer {
        AeroSyncPeer {
            name: name.to_owned(),
            host: host.to_owned(),
            port,
            version: version.map(str::to_owned),
            ws_enabled,
            auth_required,
        }
    }

    #[test]
    fn test_peer_addr_format() {
        let peer = make_peer("machine-a", "192.168.1.10", 7788, Some("0.2.0"), true, false);
        assert_eq!(peer.addr(), "192.168.1.10:7788");
    }

    #[test]
    fn test_peer_addr_ipv6() {
        // Pins the current format: IPv6 hosts are not bracketed by addr().
        let peer = make_peer("machine-b", "::1", 7788, None, false, true);
        assert_eq!(peer.addr(), "::1:7788");
    }

    #[test]
    fn test_mdns_service_type_constant() {
        assert!(MDNS_SERVICE_TYPE.contains("_aerosync"));
        assert!(MDNS_SERVICE_TYPE.ends_with(".local."));
    }

    #[test]
    fn test_peer_fields() {
        let peer = make_peer("recv-1", "10.0.0.5", 8080, Some("0.2.0"), true, true);
        assert!(peer.ws_enabled);
        assert!(peer.auth_required);
        assert_eq!(peer.version.as_deref(), Some("0.2.0"));
        assert_eq!(peer.port, 8080);
    }
}