use std::{
io,
net::{IpAddr, Ipv4Addr, SocketAddr},
time::{Duration, Instant},
};
use sha2::{Digest, Sha256};
/// mDNS service type under which peers register themselves and browse for each other.
const MDNS_SERVICE_TYPE: &str = "_hayate._udp.local.";
/// UDP port that discovery datagrams are sent to and that listeners bind.
const UDP_DISCOVERY_PORT: u16 = 50002;
/// Description of a peer found during discovery.
///
/// NOTE(review): nothing in this part of the file constructs this type, so the
/// field notes below are inferred from the field names and types only — confirm
/// against the code that fills them in.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DiscoveredPeer {
/// Presumably a human-readable name or label for the peer — TODO confirm.
pub name: String,
/// Socket address associated with the peer.
pub addr: SocketAddr,
/// Presumably the peer's operating-system name — TODO confirm.
pub os: String,
/// Presumably a measured round-trip time in milliseconds, `None` when not
/// measured — TODO confirm.
pub rtt_ms: Option<f64>,
}
/// Derives a short channel identifier from a pass-phrase.
///
/// The phrase's UTF-8 bytes are hashed with SHA-256 and the first four bytes
/// of the digest are returned hex-encoded.
#[must_use]
pub fn derive_channel_id(phrase: &str) -> String {
    let digest = Sha256::digest(phrase.as_bytes());
    let prefix = &digest[..4];
    hex::encode(prefix)
}
/// Handle tied to a running broadcaster.
///
/// Dropping the guard sends on `cancel_tx` and then shuts down the mDNS daemon
/// held in `mdns_handle` (see the `Drop` implementation).
pub struct BroadcasterGuard {
/// mDNS daemon handle; taken (left as `None`) and shut down when the guard is dropped.
mdns_handle: Option<mdns_sd::ServiceDaemon>,
/// Sender used to signal cancellation; taken and fired once when the guard is dropped.
cancel_tx: Option<flume::Sender<()>>,
}
impl std::fmt::Debug for BroadcasterGuard {
    /// Formats the guard by reporting only whether each handle is still
    /// present, rather than printing the handles themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_mdns = self.mdns_handle.is_some();
        let has_cancel = self.cancel_tx.is_some();

        let mut out = f.debug_struct("BroadcasterGuard");
        out.field("mdns_handle", &has_mdns);
        out.field("cancel_tx", &has_cancel);
        out.finish()
    }
}
impl BroadcasterGuard {
    /// Wraps a cancellation sender and an mDNS daemon handle in a guard that
    /// releases both when dropped.
    #[must_use]
    pub(crate) fn new(cancel_tx: flume::Sender<()>, mdns: mdns_sd::ServiceDaemon) -> Self {
        let mdns_handle = Some(mdns);
        let cancel_tx = Some(cancel_tx);
        Self {
            mdns_handle,
            cancel_tx,
        }
    }
}
impl Drop for BroadcasterGuard {
    /// Signals cancellation first, then shuts the mDNS daemon down.
    ///
    /// Both steps are best-effort: failures are ignored because there is
    /// nothing useful to do with them while dropping.
    fn drop(&mut self) {
        let cancel = self.cancel_tx.take();
        let daemon = self.mdns_handle.take();

        if let Some(cancel) = cancel {
            let _ = cancel.send(());
        }
        if let Some(daemon) = daemon {
            let _ = daemon.shutdown();
        }
    }
}
/// Starts advertising this peer over both mDNS and UDP broadcast.
///
/// An mDNS service of type [`MDNS_SERVICE_TYPE`] named `hayate-{channel_id}`
/// is registered with the TXT properties `chid`, `os` and `port`. The
/// advertised address is the primary local IPv4 address, falling back to
/// `127.0.0.1` when none is found. A detached compio task is then spawned that
/// runs the UDP broadcast loop; when that loop ends (cancelled or failed) the
/// task also shuts the mDNS daemon down.
///
/// The returned [`BroadcasterGuard`] stops both mechanisms when dropped.
///
/// # Errors
///
/// Returns an [`io::Error`] wrapping the mDNS error when the daemon cannot be
/// created, the service description is rejected, or registration fails. In the
/// latter two cases the already-started daemon is shut down before returning
/// so that it is not left running with no handle to stop it.
pub fn start_broadcaster_hybrid(
    channel_id: &str,
    port: u16,
    os_name: &str,
) -> Result<BroadcasterGuard, io::Error> {
    let (cancel_tx, cancel_rx) = flume::bounded::<()>(1);
    let mdns = mdns_sd::ServiceDaemon::new().map_err(io::Error::other)?;

    let instance_name = format!("hayate-{channel_id}");
    let host_name = format!("hayate-{channel_id}.local.");
    let port_text = port.to_string();
    let txt_props: &[(&str, &str)] = &[
        ("chid", channel_id),
        ("os", os_name),
        ("port", port_text.as_str()),
    ];
    let ip_str = crate::local_addr::primary_local_ipv4()
        .map_or_else(|| "127.0.0.1".to_owned(), |ip| ip.to_string());

    let registration = mdns_sd::ServiceInfo::new(
        MDNS_SERVICE_TYPE,
        &instance_name,
        &host_name,
        ip_str.as_str(),
        port,
        txt_props,
    )
    .and_then(|service| mdns.register(service));
    if let Err(err) = registration {
        // The daemon is already running at this point; stop it explicitly
        // instead of abandoning it on the error path.
        let _ = mdns.shutdown();
        return Err(io::Error::other(err));
    }

    let cid = channel_id.to_owned();
    let os = os_name.to_owned();
    let mdns_clone = mdns.clone();
    compio::runtime::spawn(async move {
        // Best-effort: a failing UDP loop only ends the advertisement.
        let _ = udp_broadcast_loop(&cid, port, &os, cancel_rx).await;
        let _ = mdns_clone.shutdown();
    })
    .detach();

    Ok(BroadcasterGuard::new(cancel_tx, mdns))
}
/// Runs the UDP broadcast loop for `channel_id`/`port`, reporting the
/// compile-time target OS (`std::env::consts::OS`) as the peer's OS name.
///
/// Completes when the loop stops after `cancel_rx` fires.
///
/// # Errors
///
/// Propagates any error returned by the UDP broadcast loop.
pub async fn start_broadcaster(
    channel_id: &str,
    port: u16,
    cancel_rx: flume::Receiver<()>,
) -> Result<(), io::Error> {
    let host_os = std::env::consts::OS;
    udp_broadcast_loop(channel_id, port, host_os, cancel_rx).await
}
/// Periodically announces this peer over UDP until cancelled.
///
/// Each round sends the datagram `HAYATE_PEER:v2:{channel_id}:{os}:{port}` to
/// the IPv4 broadcast address and to loopback, both on
/// [`UDP_DISCOVERY_PORT`], then waits up to 800 ms. Send failures are ignored.
/// The loop ends as soon as the receive on `cancel_rx` completes, whether with
/// a message or with an error.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound or broadcast cannot be
/// enabled on it.
async fn udp_broadcast_loop(
    channel_id: &str,
    port: u16,
    os: &str,
    cancel_rx: flume::Receiver<()>,
) -> Result<(), io::Error> {
    use futures_util::future::{Either, select};

    let socket = compio::net::UdpSocket::bind("0.0.0.0:0").await?;
    socket.set_broadcast(true)?;

    let payload = format!("HAYATE_PEER:v2:{channel_id}:{os}:{port}").into_bytes();
    let destinations = [
        SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), UDP_DISCOVERY_PORT),
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), UDP_DISCOVERY_PORT),
    ];

    loop {
        for destination in destinations {
            // The send takes ownership of its buffer, hence the clone.
            let _ = socket.send_to(payload.clone(), destination).await;
        }

        let pause = std::pin::pin!(compio::time::sleep(Duration::from_millis(800)));
        let cancelled = std::pin::pin!(cancel_rx.recv_async());
        let cancel_won = matches!(select(pause, cancelled).await, Either::Right(_));
        if cancel_won {
            return Ok(());
        }
    }
}
pub fn listen_for_broadcast(
target_phrase: Option<&str>,
timeout: Duration,
) -> Result<Option<(String, SocketAddr, String)>, io::Error> {
let target_channel_id = target_phrase.map(derive_channel_id);
let (found_tx, found_rx) = flume::bounded::<(String, SocketAddr, String)>(1);
let target_cid_mdns = target_channel_id.clone();
let found_tx_mdns = found_tx.clone();
let mdns_task = std::thread::spawn(move || {
let Ok(mdns) = mdns_sd::ServiceDaemon::new() else {
return;
};
let Ok(receiver) = mdns.browse(MDNS_SERVICE_TYPE) else {
return;
};
let deadline = Instant::now() + timeout;
while let Ok(event) = receiver.recv_timeout(Duration::from_millis(200)) {
if Instant::now() > deadline {
break;
}
if let mdns_sd::ServiceEvent::ServiceResolved(info) = event {
for addr in info.get_addresses_v4() {
let remote_chid = info
.get_property_val_str("chid")
.unwrap_or_default()
.to_owned();
let remote_os = info
.get_property_val_str("os")
.unwrap_or("unknown")
.to_owned();
let remote_port = info.get_port();
let matches = match &target_cid_mdns {
Some(expected) => remote_chid == *expected,
None => true,
};
if matches {
let peer_addr = SocketAddr::new(IpAddr::V4(addr), remote_port);
let _ = found_tx_mdns
.send((format!("mDNS:{remote_chid}"), peer_addr, remote_os));
let _ = mdns.shutdown();
return;
}
}
}
}
let _ = mdns.shutdown();
});
let target_cid_udp = target_channel_id.clone();
let found_tx_udp = found_tx;
let udp_task = std::thread::spawn(move || -> Result<(), io::Error> {
let std_socket = socket2::Socket::new(
socket2::Domain::IPV4,
socket2::Type::DGRAM,
Some(socket2::Protocol::UDP),
)?;
std_socket.set_reuse_address(true)?;
#[cfg(not(windows))]
std_socket.set_reuse_port(true)?;
let listen_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), UDP_DISCOVERY_PORT);
std_socket.bind(&socket2::SockAddr::from(listen_addr))?;
std_socket
.set_read_timeout(Some(Duration::from_millis(200)))
.unwrap_or(());
let socket: std::net::UdpSocket = std_socket.into();
let mut buf = [0u8; 1024];
let deadline = Instant::now() + timeout;
while Instant::now() <= deadline {
match socket.recv_from(&mut buf) {
Ok((n, src_addr)) => {
let data = &buf[..n];
if let Ok(text) = std::str::from_utf8(data)
&& let Some(result) =
parse_udp_packet(text, target_cid_udp.as_deref(), src_addr)
{
let _ = found_tx_udp.send(result);
return Ok(());
}
}
Err(ref e)
if e.kind() == io::ErrorKind::WouldBlock
|| e.kind() == io::ErrorKind::TimedOut => {}
Err(_) => break,
}
}
Ok(())
});
let adjusted_timeout = timeout
.checked_add(Duration::from_secs(2))
.unwrap_or(timeout);
if let Ok(result) = found_rx.recv_timeout(adjusted_timeout) {
let _ = mdns_task.join();
let _ = udp_task.join();
Ok(Some(result))
} else {
let _ = mdns_task.join();
let _ = udp_task.join();
Ok(None)
}
}
/// Parses a discovery datagram of the form
/// `HAYATE_PEER:v2:{channel_id}:{os}:{port}` or the older
/// `HAYATE_PEER:{channel_id}:{os}:{port}`.
///
/// Returns `None` when the prefix is wrong, a field is missing, the port is
/// not a valid `u16`, or `target_channel_id` is `Some` and differs from the
/// packet's channel id. Fields after the port are ignored.
///
/// On success yields `("UDP:{channel_id}", addr, os)`, where `addr` combines
/// the sender's IP (`src_addr`) with the port advertised in the packet.
fn parse_udp_packet(
    text: &str,
    target_channel_id: Option<&str>,
    src_addr: SocketAddr,
) -> Option<(String, SocketAddr, String)> {
    let mut fields = text.split(':');

    let magic = fields.next()?;
    if magic != "HAYATE_PEER" {
        return None;
    }

    // A literal "v2" marker shifts the channel id one field to the right.
    let channel_id = match fields.next()? {
        "v2" => fields.next()?,
        legacy_id => legacy_id,
    };
    let os = fields.next()?;
    let port_field = fields.next()?;

    if target_channel_id.is_some_and(|expected| channel_id != expected) {
        return None;
    }

    let port: u16 = port_field.parse().ok()?;
    let label = format!("UDP:{channel_id}");
    let peer_addr = SocketAddr::new(src_addr.ip(), port);
    Some((label, peer_addr, os.to_owned()))
}
/// Asynchronously waits for a matching UDP discovery datagram.
///
/// Binds a reusable IPv4 socket on [`UDP_DISCOVERY_PORT`] and reads datagrams
/// until one parses as a discovery packet for the requested channel, or until
/// `timeout` elapses.
///
/// * `target_phrase` — when `Some`, only packets whose channel id equals
///   `derive_channel_id(phrase)` are accepted; when `None`, any valid packet is.
///
/// Returns `Ok(Some((label, addr, os)))` for the first matching packet and
/// `Ok(None)` when the timeout expires first.
///
/// # Errors
///
/// Returns an error if the socket cannot be created, configured or bound, or
/// if receiving a datagram fails.
pub async fn listen_for_broadcast_udp(
    target_phrase: Option<&str>,
    timeout: Duration,
) -> Result<Option<(String, SocketAddr, String)>, io::Error> {
    let wanted_channel = target_phrase.map(derive_channel_id);

    let raw = socket2::Socket::new(
        socket2::Domain::IPV4,
        socket2::Type::DGRAM,
        Some(socket2::Protocol::UDP),
    )?;
    raw.set_reuse_address(true)?;
    #[cfg(not(windows))]
    raw.set_reuse_port(true)?;
    let bind_to = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), UDP_DISCOVERY_PORT);
    raw.bind(&socket2::SockAddr::from(bind_to))?;
    let socket = compio::net::UdpSocket::from_std(raw.into())?;

    let search = async move {
        let mut scratch = vec![0u8; 1024];
        loop {
            // The receive call takes the buffer by value and hands it back.
            let compio::BufResult(outcome, returned) = socket.recv_from(scratch).await;
            scratch = returned;

            let (len, sender) = match outcome {
                Ok(received) => received,
                Err(err) => return Err(err),
            };
            let hit = std::str::from_utf8(&scratch[..len])
                .ok()
                .and_then(|text| parse_udp_packet(text, wanted_channel.as_deref(), sender));
            if let Some(peer) = hit {
                return Ok(Some(peer));
            }
        }
    };

    // An elapsed timeout simply means no peer was seen in time.
    compio::time::timeout(timeout, search)
        .await
        .unwrap_or(Ok(None))
}