use std::collections::HashMap;
use std::time::Duration;
use super::{DiscoveredPeer, Discovery};
use crate::error::Error;
/// Service type browsed when no override is supplied through the builder.
pub const DEFAULT_SERVICE_TYPE: &str = "_cognitum._tcp.local.";
/// Default time budget for a single browse pass (two seconds).
pub const DEFAULT_BROWSE_DURATION: Duration = Duration::from_secs(2);
/// Port used in a peer URL when the resolved service supplies no usable port.
pub const DEFAULT_PORT: u16 = 8443;
/// mDNS-backed peer discovery.
///
/// Holds the configuration for one browse pass: which service type to look
/// for, how long to listen, and how to turn a resolved service into a URL.
#[derive(Debug, Clone)]
pub struct MdnsDiscovery {
/// Service type handed to the mDNS browse call.
service_type: String,
/// Upper bound on how long a single `discover` call keeps listening.
browse_duration: Duration,
/// URL scheme placed in front of every discovered peer address.
scheme: &'static str,
/// Port used when a resolved service does not provide one.
default_port: u16,
}
impl Default for MdnsDiscovery {
    /// Returns a discovery configured with the module-level defaults and the
    /// `https` scheme.
    fn default() -> Self {
        let service_type = String::from(DEFAULT_SERVICE_TYPE);
        let browse_duration = DEFAULT_BROWSE_DURATION;
        let default_port = DEFAULT_PORT;
        Self {
            service_type,
            browse_duration,
            scheme: "https",
            default_port,
        }
    }
}
impl MdnsDiscovery {
pub fn new() -> Self {
Self::default()
}
pub fn builder() -> MdnsDiscoveryBuilder {
MdnsDiscoveryBuilder::default()
}
pub fn service_type(&self) -> &str {
&self.service_type
}
pub fn browse_duration(&self) -> Duration {
self.browse_duration
}
}
/// Builder for [`MdnsDiscovery`].
///
/// Carries the same four settings as the type it builds; each setter
/// replaces one of them and `build` moves them into an `MdnsDiscovery`.
#[derive(Debug, Clone)]
pub struct MdnsDiscoveryBuilder {
/// Service type the built discovery will browse for.
service_type: String,
/// Time budget the built discovery will use per browse pass.
browse_duration: Duration,
/// URL scheme the built discovery will put in peer URLs.
scheme: &'static str,
/// Fallback port the built discovery will use.
default_port: u16,
}
impl Default for MdnsDiscoveryBuilder {
fn default() -> Self {
let defaults = MdnsDiscovery::default();
Self {
service_type: defaults.service_type,
browse_duration: defaults.browse_duration,
scheme: defaults.scheme,
default_port: defaults.default_port,
}
}
}
impl MdnsDiscoveryBuilder {
    /// Replaces the service type to browse for.
    pub fn service_type(self, st: impl Into<String>) -> Self {
        Self {
            service_type: st.into(),
            ..self
        }
    }

    /// Replaces the time budget for a single browse pass.
    pub fn browse_duration(self, d: Duration) -> Self {
        Self {
            browse_duration: d,
            ..self
        }
    }

    /// Replaces the URL scheme used when composing peer URLs.
    pub fn scheme(self, scheme: &'static str) -> Self {
        Self { scheme, ..self }
    }

    /// Replaces the fallback port used when a service supplies none.
    pub fn default_port(self, port: u16) -> Self {
        Self {
            default_port: port,
            ..self
        }
    }

    /// Consumes the builder and produces the configured [`MdnsDiscovery`].
    pub fn build(self) -> MdnsDiscovery {
        let Self {
            service_type,
            browse_duration,
            scheme,
            default_port,
        } = self;
        MdnsDiscovery {
            service_type,
            browse_duration,
            scheme,
            default_port,
        }
    }
}
#[async_trait::async_trait]
impl Discovery for MdnsDiscovery {
async fn discover(&self) -> Result<Vec<DiscoveredPeer>, Error> {
let service_type = self.service_type.clone();
let budget = self.browse_duration;
let scheme = self.scheme;
let default_port = self.default_port;
let handle = tokio::task::spawn_blocking(move || -> Result<Vec<DiscoveredPeer>, Error> {
let daemon = mdns_sd::ServiceDaemon::new().map_err(|e| Error::Api {
code: 0,
message: format!("mdns: failed to start ServiceDaemon: {e}"),
})?;
let receiver = daemon.browse(&service_type).map_err(|e| Error::Api {
code: 0,
message: format!("mdns: browse failed for `{service_type}`: {e}"),
})?;
let mut seen: HashMap<String, DiscoveredPeer> = HashMap::new();
let started = std::time::Instant::now();
loop {
let elapsed = started.elapsed();
if elapsed >= budget {
break;
}
let remaining = budget - elapsed;
match receiver.recv_timeout(remaining) {
Ok(mdns_sd::ServiceEvent::ServiceResolved(info)) => {
let latency = elapsed.as_millis().min(u128::from(u32::MAX)) as u32;
let fullname = info.get_fullname().to_owned();
let peer = resolve_info_to_peer(&info, scheme, default_port, latency);
if let Some(p) = peer {
seen.insert(fullname, p);
}
}
Ok(_) => {}
Err(_) => break,
}
}
let _ = daemon.shutdown();
let mut out: Vec<DiscoveredPeer> = seen.into_values().collect();
out.sort_by(|a, b| a.url.cmp(&b.url));
Ok(out)
})
.await
.map_err(|e| Error::Api {
code: 0,
message: format!("mdns: discover task panicked: {e}"),
})?;
handle
}
}
/// Converts one resolved mDNS service into a [`DiscoveredPeer`].
///
/// * **Host** – the first address that renders as IPv4 wins; otherwise the
///   first address of any other kind; otherwise the service hostname with its
///   trailing dot removed. A host containing `:` is wrapped in `[...]` so it
///   can be followed by a port.
/// * **Port** – the `port` TXT property if it parses as a non-zero `u16`,
///   else the record's port if non-zero, else `default_port`. Zero is treated
///   as "not set" for both sources.
/// * **Metadata** – the `id` TXT property becomes the device id and the `fp`
///   TXT property, normalised, becomes the TLS fingerprint.
///
/// `observed_ms` is stored on the peer as its latency.
///
/// Returns `None` when no non-empty host can be determined.
fn resolve_info_to_peer(
    info: &mdns_sd::ResolvedService,
    scheme: &str,
    default_port: u16,
    observed_ms: u32,
) -> Option<DiscoveredPeer> {
    // NOTE(review): if `get_addresses` yields an unordered collection, the
    // choice among several addresses of the same family is arbitrary — confirm.
    let mut ipv4: Option<String> = None;
    let mut other: Option<String> = None;
    for addr in info.get_addresses() {
        let rendered = addr.to_string();
        // Classify on the rendered text: dotted with no colon means IPv4.
        if rendered.contains('.') && !rendered.contains(':') {
            ipv4 = Some(rendered);
            break;
        }
        other.get_or_insert(rendered);
    }

    let host: String = ipv4
        .or(other)
        .unwrap_or_else(|| info.get_hostname().trim_end_matches('.').to_owned());
    if host.is_empty() {
        return None;
    }

    // Port 0 is never connectable, so treat it as absent from either source
    // (previously only the record port was filtered this way).
    let txt_port = info
        .get_property_val_str("port")
        .and_then(|raw| raw.trim().parse::<u16>().ok())
        .filter(|&p| p != 0);
    let record_port = Some(info.get_port()).filter(|&p| p != 0);
    let port = txt_port.or(record_port).unwrap_or(default_port);

    // A bare IPv6 literal must be bracketed before a `:port` suffix.
    let host_literal = if host.contains(':') && !host.starts_with('[') {
        format!("[{host}]")
    } else {
        host
    };
    let url = format!("{scheme}://{host_literal}:{port}");

    let device_id = info.get_property_val_str("id").map(str::to_owned);
    let fp = info
        .get_property_val_str("fp")
        .map(super::normalize_fingerprint);

    Some(
        DiscoveredPeer::new(url)
            .with_latency_ms(observed_ms)
            .set_device_id(device_id)
            .set_tls_fingerprint(fp),
    )
}
impl DiscoveredPeer {
    /// Overwrites the device id (including clearing it with `None`) and
    /// hands the peer back for chaining.
    fn set_device_id(self, id: Option<String>) -> Self {
        let mut peer = self;
        peer.device_id = id;
        peer
    }

    /// Overwrites the TLS fingerprint (including clearing it with `None`)
    /// and hands the peer back for chaining. The value is stored as given.
    fn set_tls_fingerprint(self, fp: Option<String>) -> Self {
        let mut peer = self;
        peer.tls_fingerprint = fp;
        peer
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A builder with no overrides must reproduce every module default,
    // including the scheme and fallback port that have no public getter.
    #[test]
    fn builder_defaults_match_adr() {
        let d = MdnsDiscovery::builder().build();
        assert_eq!(d.service_type(), DEFAULT_SERVICE_TYPE);
        assert_eq!(d.browse_duration(), DEFAULT_BROWSE_DURATION);
        assert_eq!(d.scheme, "https");
        assert_eq!(d.default_port, DEFAULT_PORT);
    }

    // Every setter on the builder must reach the built value; all four
    // overridden settings are checked, not just the two with getters.
    #[test]
    fn builder_overrides_are_applied() {
        let d = MdnsDiscovery::builder()
            .service_type("_cognitum-dev._tcp.local.")
            .browse_duration(Duration::from_millis(500))
            .scheme("http")
            .default_port(18080)
            .build();
        assert_eq!(d.service_type(), "_cognitum-dev._tcp.local.");
        assert_eq!(d.browse_duration(), Duration::from_millis(500));
        assert_eq!(d.scheme, "http");
        assert_eq!(d.default_port, 18080);
    }

    // Browsing for a service type nobody advertises should yield no peers.
    // A discovery error is deliberately mapped to an empty list so the test
    // also passes where the mDNS daemon cannot start.
    #[tokio::test]
    async fn discover_on_empty_network_returns_empty_fast() {
        let d = MdnsDiscovery::builder()
            .service_type("_cognitum-nope._tcp.local.")
            .browse_duration(Duration::from_millis(150))
            .build();
        let peers = d.discover().await.unwrap_or_default();
        assert!(peers.is_empty());
    }

    // Pins the normalised fingerprint form: lower-case hex with the
    // `sha256:` prefix, colons, and surrounding whitespace removed.
    #[test]
    fn fingerprint_normalisation_strips_prefix_and_colons() {
        use super::super::normalize_fingerprint;
        let got = normalize_fingerprint("sha256:AA:BB:CC:DD");
        assert_eq!(got, "aabbccdd");
        let got2 = normalize_fingerprint(" SHA256:AaBb ");
        assert_eq!(got2, "aabb");
        let got3 = normalize_fingerprint("AABB");
        assert_eq!(got3, "aabb");
    }

    // The public `with_tls_fingerprint` builder method is expected to store
    // the fingerprint in normalised form.
    #[test]
    fn set_tls_fingerprint_round_trip_on_discovered_peer() {
        let p = DiscoveredPeer::new("https://seed:8443").with_tls_fingerprint("sha256:AA:BB");
        assert_eq!(p.tls_fingerprint.as_deref(), Some("aabb"));
    }
}