use std::io;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket};
use std::time::{Duration, Instant};
use crate::protocol::{
BROADCAST_PORT, PACKET_SIZE, PKT_TYPE_PEER_INFO, PROTOCOL_VERSION,
encode_discovery_request, parse_packet,
};
/// Overall time budget for one discovery round, measured from when the
/// first request burst is sent.
const TOTAL_BUDGET: Duration = Duration::from_millis(500);
/// How many times the discovery request is (re)sent.
const DISCOVERY_BURSTS: u32 = 3;
/// Spacing between request bursts: burst `k` (0-based) is due at
/// `k * BURST_INTERVAL` after the round starts.
const BURST_INTERVAL: Duration = Duration::from_millis(150);
/// A round ends early once this long passes without a new device replying.
const QUIET_WINDOW: Duration = Duration::from_millis(50);
// Every scheduled burst must be due strictly before the budget runs out;
// otherwise the trailing bursts would silently never be sent.
const _: () = assert!(
    BURST_INTERVAL.as_millis() * (DISCOVERY_BURSTS.saturating_sub(1) as u128)
        < TOTAL_BUDGET.as_millis(),
    "DISCOVERY_BURSTS * BURST_INTERVAL does not fit inside TOTAL_BUDGET"
);
/// A device that answered a discovery request.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DeviceInfo {
    /// Source IP address of the device's reply packet.
    pub ip: IpAddr,
    /// Serial number reported by the device; replies are de-duplicated on it.
    pub serial: u32,
    /// Whether the device reported the same protocol version as this build.
    pub compatible: bool,
}
/// Finds devices on the local network by sending UDP discovery requests.
///
/// By default requests are broadcast; a single explicit destination can be
/// configured instead (used by tests).
#[derive(Debug, Clone)]
pub struct DeviceDiscovery {
    /// When set, requests go only to this address instead of being broadcast.
    explicit_target: Option<SocketAddr>,
}
impl Default for DeviceDiscovery {
fn default() -> Self {
Self::new()
}
}
impl DeviceDiscovery {
    /// Creates a discoverer that broadcasts requests on every eligible local
    /// IPv4 interface plus the limited broadcast address `255.255.255.255`.
    pub fn new() -> Self {
        Self {
            explicit_target: None,
        }
    }

    /// Creates a discoverer that sends requests only to `target` instead of
    /// broadcasting. Test-only.
    #[cfg(test)]
    pub(crate) fn with_target(target: SocketAddr) -> Self {
        Self {
            explicit_target: Some(target),
        }
    }

    /// Sends discovery requests and collects the devices that answer.
    ///
    /// Up to `DISCOVERY_BURSTS` request bursts are sent, `BURST_INTERVAL`
    /// apart, to every address returned by `discovery_targets`. Replies are
    /// collected until `TOTAL_BUDGET` has elapsed or until `QUIET_WINDOW`
    /// passes without a new device answering, whichever comes first.
    /// Replies that fail to parse, are not `PKT_TYPE_PEER_INFO` packets, or
    /// repeat an already-seen serial are ignored. A device is marked
    /// `compatible` when its reported version equals `PROTOCOL_VERSION`.
    ///
    /// Send failures are deliberately ignored: an unreachable interface must
    /// not prevent discovery through the others.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be bound, switched to broadcast
    /// mode, or given a read timeout. It also returns an error if receiving
    /// fails with anything other than a timeout, a signal interruption, or a
    /// connection reset (see `is_transient_recv_error`).
    pub fn discover(&self) -> io::Result<Vec<DeviceInfo>> {
        let socket = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))?;
        socket.set_broadcast(true)?;
        let request = encode_discovery_request();
        let targets = self.discovery_targets();
        let mut devices: Vec<DeviceInfo> = Vec::new();
        let mut buf = [0u8; PACKET_SIZE];
        let start = Instant::now();
        let mut bursts_sent: u32 = 0;
        let mut last_reply: Option<Instant> = None;
        loop {
            if start.elapsed() >= TOTAL_BUDGET {
                break;
            }
            if last_reply.is_some_and(|t| t.elapsed() >= QUIET_WINDOW) {
                break;
            }
            // Burst k is due at k * BURST_INTERVAL after `start`.
            let next_burst_at = BURST_INTERVAL * bursts_sent;
            if bursts_sent < DISCOVERY_BURSTS && start.elapsed() >= next_burst_at {
                for target in &targets {
                    // Best-effort: one failing target must not stop the others.
                    let _ = socket.send_to(&request, target);
                }
                bursts_sent += 1;
            }
            let wait = Self::next_read_timeout(start.elapsed(), bursts_sent, last_reply);
            socket.set_read_timeout(Some(wait))?;
            match socket.recv_from(&mut buf) {
                Ok((n, src)) => {
                    let Some(parsed) = parse_packet(&buf[..n]) else {
                        continue;
                    };
                    if parsed.pkt_type != PKT_TYPE_PEER_INFO {
                        continue;
                    }
                    // A device may answer several bursts and several targets.
                    if devices.iter().any(|d| d.serial == parsed.serial) {
                        continue;
                    }
                    devices.push(DeviceInfo {
                        ip: src.ip(),
                        serial: parsed.serial,
                        compatible: parsed.version == PROTOCOL_VERSION,
                    });
                    // Only a *new* device restarts the quiet window.
                    last_reply = Some(Instant::now());
                }
                Err(e) if Self::is_transient_recv_error(&e) => {}
                Err(e) => return Err(e),
            }
        }
        Ok(devices)
    }

    /// Returns the addresses each discovery burst is sent to.
    ///
    /// With an explicit target, only that address is used. Otherwise the list
    /// holds the broadcast address of each eligible local IPv4 interface,
    /// followed by the limited broadcast address `255.255.255.255`. All of
    /// these use `BROADCAST_PORT`.
    fn discovery_targets(&self) -> Vec<SocketAddr> {
        if let Some(target) = self.explicit_target {
            return vec![target];
        }
        let mut targets: Vec<SocketAddr> = enumerate_broadcast_addresses()
            .into_iter()
            .map(|ip| SocketAddr::from((ip, BROADCAST_PORT)))
            .collect();
        targets.push(SocketAddr::from((Ipv4Addr::BROADCAST, BROADCAST_PORT)));
        targets
    }

    /// Computes how long the next `recv_from` may block.
    ///
    /// The result is the time until the overall deadline, the next scheduled
    /// burst (while bursts remain), or the end of the quiet window, whichever
    /// comes first. `elapsed` is the time since the round started.
    /// The result is clamped to at least 1 ms because
    /// `UdpSocket::set_read_timeout` rejects a zero duration.
    fn next_read_timeout(
        elapsed: Duration,
        bursts_sent: u32,
        last_reply: Option<Instant>,
    ) -> Duration {
        let mut wait = TOTAL_BUDGET.saturating_sub(elapsed);
        if bursts_sent < DISCOVERY_BURSTS {
            wait = wait.min((BURST_INTERVAL * bursts_sent).saturating_sub(elapsed));
        }
        if let Some(t) = last_reply {
            wait = wait.min(QUIET_WINDOW.saturating_sub(t.elapsed()));
        }
        wait.max(Duration::from_millis(1))
    }

    /// Returns `true` for `recv_from` errors that should not abort discovery.
    ///
    /// * `WouldBlock` / `TimedOut`: the read timeout expired. Unix typically
    ///   reports the former and Windows the latter.
    /// * `Interrupted`: the call was interrupted by a signal; simply retry.
    /// * `ConnectionReset`: Windows can report an ICMP "port unreachable",
    ///   caused by an earlier `send_to`, on the next receive of an
    ///   unconnected UDP socket. It says nothing about the other targets.
    fn is_transient_recv_error(e: &io::Error) -> bool {
        matches!(
            e.kind(),
            io::ErrorKind::WouldBlock
                | io::ErrorKind::TimedOut
                | io::ErrorKind::Interrupted
                | io::ErrorKind::ConnectionReset
        )
    }
}
/// Collects the IPv4 broadcast addresses of the local network interfaces.
///
/// Loopback and link-local (169.254.0.0/16) addresses are skipped, as are
/// interfaces that report no broadcast address. If the interface list cannot
/// be read, an empty list is returned rather than an error.
fn enumerate_broadcast_addresses() -> Vec<Ipv4Addr> {
    let Ok(interfaces) = if_addrs::get_if_addrs() else {
        return Vec::new();
    };
    let mut broadcasts = Vec::new();
    for iface in interfaces {
        let if_addrs::IfAddr::V4(v4) = iface.addr else {
            continue;
        };
        if v4.ip.is_loopback() || is_link_local(&v4.ip) {
            continue;
        }
        if let Some(bcast) = v4.broadcast {
            broadcasts.push(bcast);
        }
    }
    broadcasts
}
/// Returns `true` if `ip` lies in the IPv4 link-local range `169.254.0.0/16`.
///
/// Delegates to the standard library's `Ipv4Addr::is_link_local` instead of
/// re-implementing the octet check.
fn is_link_local(ip: &Ipv4Addr) -> bool {
    ip.is_link_local()
}
#[cfg(test)]
#[path = "devicediscovery_tests.rs"]
mod tests;