use crate::{
errors::{CatBridgeError, NetworkError},
mion::proto::{
DEFAULT_MION_CONTROL_PORT, MION_ANNOUNCE_TIMEOUT_SECONDS,
control::{MionIdentity, MionIdentityAnnouncement},
},
};
use bytes::{Bytes, BytesMut};
use fnv::FnvHashSet;
use futures::stream::{StreamExt, unfold};
use mac_address::MacAddress;
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
use std::{
fmt::{Display, Formatter, Result as FmtResult},
hash::BuildHasherDefault,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
};
use tokio::{
net::UdpSocket,
sync::mpsc::{UnboundedReceiver, unbounded_channel},
task::JoinSet,
time::{Duration, Instant, sleep},
};
use tracing::{debug, error, warn};
/// Discover the MIONs reachable via IPv4 broadcast and collect their distinct identities.
///
/// Convenience wrapper around [`discover_and_collect_bridges_with_logging_hooks`] that
/// installs a hook which ignores every interface.
///
/// # Errors
///
/// Propagates any error returned by [`discover_and_collect_bridges_with_logging_hooks`].
pub async fn discover_and_collect_bridges(
	fetch_detailed_info: bool,
	early_timeout: Option<Duration>,
	override_control_port: Option<u16>,
) -> Result<Vec<MionIdentity>, CatBridgeError> {
	let collected = discover_and_collect_bridges_with_logging_hooks(
		fetch_detailed_info,
		early_timeout,
		override_control_port,
		noop_logger_interface,
	)
	.await?;
	Ok(collected)
}
pub async fn discover_and_collect_bridges_with_logging_hooks<InterfaceLoggingHook>(
fetch_detailed_info: bool,
early_timeout: Option<Duration>,
override_control_port: Option<u16>,
interface_logging_hook: InterfaceLoggingHook,
) -> Result<Vec<MionIdentity>, CatBridgeError>
where
InterfaceLoggingHook: Fn(&'_ Addr) + Clone + Send + 'static,
{
let mut recv_channel = discover_bridges_with_logging_hooks(
fetch_detailed_info,
override_control_port,
interface_logging_hook,
)
.await?;
let mut results = Vec::new();
loop {
tokio::select! {
opt = recv_channel.recv() => {
let Some(identity) = opt else {
break;
};
if !results.contains(&identity) {
results.push(identity);
}
}
() = sleep(early_timeout.unwrap_or(Duration::from_secs(MION_ANNOUNCE_TIMEOUT_SECONDS * 2))) => {
break;
}
}
}
Ok(results)
}
/// Start a broadcast discovery and return a channel that yields every parsed MION reply.
///
/// Convenience wrapper around [`discover_bridges_with_logging_hooks`] that installs a hook
/// which ignores every interface.
///
/// # Errors
///
/// Propagates any error returned by [`discover_bridges_with_logging_hooks`].
pub async fn discover_bridges(
	fetch_detailed_info: bool,
	override_control_port: Option<u16>,
) -> Result<UnboundedReceiver<MionIdentity>, CatBridgeError> {
	let receiver = discover_bridges_with_logging_hooks(
		fetch_detailed_info,
		override_control_port,
		noop_logger_interface,
	)
	.await?;
	Ok(receiver)
}
/// Broadcast a MION identity announcement on every local IPv4 address and stream back the
/// replies.
///
/// One task per local IPv4 address binds a UDP socket on that address and enables
/// broadcast. The socket uses the control port (`override_control_port`, or
/// `DEFAULT_MION_CONTROL_PORT`). The task then sends the announcement to the interface's
/// broadcast address. Addresses without a broadcast address are skipped, after
/// `interface_logging_hook` has been called for them. A background task reads the
/// remaining sockets, parses every reply into a [`MionIdentity`], and forwards it over the
/// returned channel.
///
/// The background task stops, closing the channel, when any of these happens:
/// - `MION_ANNOUNCE_TIMEOUT_SECONDS` have elapsed since the broadcasts were sent;
/// - every listening socket has stopped producing data (a receive error, or there were no
///   usable sockets at all);
/// - the returned receiver is dropped.
///
/// It skips replies from one of our own local addresses, from IPv6 peers, or that fail to
/// parse.
///
/// # Errors
///
/// - The local network interfaces cannot be listed.
/// - A per-interface task cannot be spawned, or fails to join.
/// - Binding, enabling broadcast, or sending fails on any interface.
pub async fn discover_bridges_with_logging_hooks<InterfaceLoggingHook>(
	fetch_detailed_info: bool,
	override_control_port: Option<u16>,
	interface_logging_hook: InterfaceLoggingHook,
) -> Result<UnboundedReceiver<MionIdentity>, CatBridgeError>
where
	InterfaceLoggingHook: Fn(&'_ Addr) + Clone + Send + 'static,
{
	let to_broadcast = Bytes::from(MionIdentityAnnouncement::new(fetch_detailed_info));

	// Broadcast from every local IPv4 address concurrently.
	let mut tasks = JoinSet::new();
	for (interface_addr, interface_ipv4) in get_all_broadcast_addresses()? {
		let broadcast_messaged_cloned = to_broadcast.clone();
		let cloned_iface_hook = interface_logging_hook.clone();
		tasks
			.build_task()
			.name(&format!("cat_dev::discover_mion::{interface_ipv4}"))
			.spawn(async move {
				broadcast_to_mions_on_interface(
					override_control_port,
					broadcast_messaged_cloned,
					interface_addr,
					interface_ipv4,
					cloned_iface_hook,
				)
				.await
			})
			.map_err(CatBridgeError::SpawnFailure)?;
	}

	// Collect the sockets we broadcast from; replies are read from them below. Any failure
	// aborts the whole discovery.
	let mut listening_sockets = Vec::with_capacity(tasks.len());
	while let Some(joined) = tasks.join_next().await {
		let joined_result = match joined {
			Ok(data) => data,
			Err(cause) => {
				tasks.abort_all();
				return Err(CatBridgeError::JoinFailure(cause));
			}
		};
		let opt_socket = match joined_result {
			Ok(optional_socket) => optional_socket,
			Err(cause) => {
				tasks.abort_all();
				return Err(cause.into());
			}
		};
		// `None` means the address had no broadcast address and was skipped.
		if let Some(socket) = opt_socket {
			listening_sockets.push(socket);
		}
	}

	// Remember our own addresses so that our broadcast seen by our own sockets is ignored.
	let mut our_addresses = FnvHashSet::with_capacity_and_hasher(
		listening_sockets.len(),
		BuildHasherDefault::default(),
	);
	for sock in &listening_sockets {
		if let Ok(our_addr) = sock.local_addr() {
			our_addresses.insert(our_addr.ip());
		}
	}

	let streams = listening_sockets
		.into_iter()
		.map(|socket| Box::pin(unfold(socket, unfold_socket)))
		.collect::<Vec<_>>();
	let mut single_stream = futures::stream::select_all(streams);
	let timeout_at = Instant::now() + Duration::from_secs(MION_ANNOUNCE_TIMEOUT_SECONDS);
	let (send, recv) = unbounded_channel::<MionIdentity>();
	tokio::task::spawn(async move {
		loop {
			tokio::select! {
				opt = single_stream.next() => {
					// `select_all` yields `None` once every socket stream has ended, or
					// immediately when there were no sockets, and keeps yielding it. Using
					// `continue` here would busy-spin until the timeout, so stop instead.
					// Dropping `send` closes the channel.
					let Some((read_data_len, from, mut buff)) = opt else {
						break;
					};
					buff.truncate(read_data_len);
					let frozen = buff.freeze();
					let from_ip = from.ip();
					if our_addresses.contains(&from_ip) {
						debug!("broadcast saw our own message");
						continue;
					}
					let ip_address = match from_ip {
						IpAddr::V4(v4) => v4,
						IpAddr::V6(v6) => {
							debug!(%v6, "broadcast packet from IPv6, ignoring, can't be announcement");
							continue;
						},
					};
					let Ok(identity) = MionIdentity::try_from((ip_address, frozen.clone())) else {
						warn!(%from, packet = %format!("{frozen:02x?}"), "could not parse packet from MION");
						continue;
					};
					if let Err(_closed) = send.send(identity) {
						break;
					}
				}
				// The receiver was dropped. Release the control-port sockets now instead of
				// holding them until the timeout, which would make a follow-up scan fail to
				// bind.
				() = send.closed() => {
					break;
				}
				() = tokio::time::sleep_until(timeout_at) => {
					break;
				}
			}
		}
	});
	Ok(recv)
}
/// Locate a single MION by IP address, MAC address, or name.
///
/// Convenience wrapper around [`find_mion_with_logging_hooks`] that installs a hook which
/// ignores every interface.
///
/// # Errors
///
/// Propagates any error returned by [`find_mion_with_logging_hooks`].
pub async fn find_mion(
	find_by: MionFindBy,
	find_detailed: bool,
	early_scan_timeout: Option<Duration>,
	override_control_port: Option<u16>,
) -> Result<Option<MionIdentity>, CatBridgeError> {
	let found = find_mion_with_logging_hooks(
		find_by,
		find_detailed,
		early_scan_timeout,
		override_control_port,
		noop_logger_interface,
	)
	.await?;
	Ok(found)
}
pub async fn find_mion_with_logging_hooks<InterfaceLoggingHook>(
find_by: MionFindBy,
find_detailed_info: bool,
early_scan_timeout: Option<Duration>,
override_control_port: Option<u16>,
interface_logging_hook: InterfaceLoggingHook,
) -> Result<Option<MionIdentity>, CatBridgeError>
where
InterfaceLoggingHook: Fn(&'_ Addr) + Clone + Send + 'static,
{
let port = override_control_port.unwrap_or(DEFAULT_MION_CONTROL_PORT);
let (find_by_mac, find_by_name) = match find_by {
MionFindBy::Ip(ipv4) => {
let local_socket = UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port))
.await
.map_err(|_| NetworkError::BindFailure)?;
local_socket
.connect(SocketAddrV4::new(ipv4, port))
.await
.map_err(NetworkError::IO)?;
local_socket
.send(&Bytes::from(MionIdentityAnnouncement::new(
find_detailed_info,
)))
.await
.map_err(NetworkError::IO)?;
let mut buff = BytesMut::zeroed(8192);
tokio::select! {
result = local_socket.recv(&mut buff) => {
let actual_size = result.map_err(NetworkError::IO)?;
buff.truncate(actual_size);
}
() = sleep(Duration::from_secs(MION_ANNOUNCE_TIMEOUT_SECONDS)) => {
return Ok(None);
}
}
return Ok(Some(MionIdentity::try_from((ipv4, buff.freeze()))?));
}
MionFindBy::MacAddress(mac) => (Some(mac), None),
MionFindBy::Name(name) => (None, Some(name)),
};
let mut recv_channel = discover_bridges_with_logging_hooks(
find_detailed_info,
override_control_port,
interface_logging_hook,
)
.await?;
loop {
tokio::select! {
opt = recv_channel.recv() => {
let Some(identity) = opt else {
break;
};
if let Some(filter_mac) = find_by_mac.as_ref() && *filter_mac == identity.mac_address() {
return Ok(Some(identity));
}
if let Some(filter_name) = find_by_name.as_ref() && filter_name == identity.name() {
return Ok(Some(identity));
}
}
() = sleep(early_scan_timeout.unwrap_or(Duration::from_secs(MION_ANNOUNCE_TIMEOUT_SECONDS * 2))) => {
break;
}
}
}
Ok(None)
}
/// The key used to look up a single MION with [`find_mion`].
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum MionFindBy {
	/// Query this IPv4 address directly; no broadcast scan is performed.
	Ip(Ipv4Addr),
	/// Run a broadcast scan and match on the MAC address each MION reports.
	MacAddress(MacAddress),
	/// Run a broadcast scan and match on the name each MION reports.
	Name(String),
}

impl MionFindBy {
	/// Interpret `value` as an IPv4 address if it parses as one, otherwise as a name.
	///
	/// Unlike the `From<String>` conversion, this never yields [`MionFindBy::MacAddress`].
	#[must_use]
	pub fn from_name_or_ip(value: String) -> Self {
		match value.parse::<Ipv4Addr>() {
			Ok(ipv4) => Self::Ip(ipv4),
			Err(_) => Self::Name(value),
		}
	}

	/// Whether searching by this key requires a full broadcast scan.
	///
	/// Only [`MionFindBy::Ip`] can be resolved by querying a single address directly.
	#[must_use]
	pub const fn will_cause_full_scan(&self) -> bool {
		!matches!(self, Self::Ip(_))
	}
}
impl From<String> for MionFindBy {
	/// Classify a user-supplied search key. A string that parses as a MAC address wins; any
	/// other string falls back to [`MionFindBy::from_name_or_ip`], which tries an IPv4
	/// address and then a name.
	fn from(value: String) -> Self {
		let parsed_mac = MacAddress::try_from(value.as_str());
		match parsed_mac {
			Ok(mac) => Self::MacAddress(mac),
			Err(_) => Self::from_name_or_ip(value),
		}
	}
}
impl Display for MionFindBy {
	/// Writes the raw search key: the IPv4 address, the MAC address, or the name verbatim.
	fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
		match self {
			Self::Name(name) => fmt.write_str(name),
			Self::MacAddress(mac) => write!(fmt, "{mac}"),
			Self::Ip(ip) => write!(fmt, "{ip}"),
		}
	}
}
/// List every local IPv4 address together with its interface address record.
///
/// Despite the name, this does not resolve broadcast addresses itself. It returns each
/// IPv4 `Addr` of every interface so callers can ask it for a broadcast address later.
/// IPv6 addresses are skipped (with a debug log) because they cannot be broadcast to.
///
/// # Errors
///
/// Returns [`NetworkError::ListInterfacesFailure`] if the interfaces cannot be listed.
pub fn get_all_broadcast_addresses() -> Result<Vec<(Addr, Ipv4Addr)>, NetworkError> {
	let interfaces = NetworkInterface::show().map_err(|cause| {
		error!(?cause, "could not list network interfaces on this device");
		NetworkError::ListInterfacesFailure(cause)
	})?;

	let mut ipv4_addresses = Vec::<(Addr, Ipv4Addr)>::new();
	for iface in &interfaces {
		for local_address in &iface.addr {
			match local_address.ip() {
				IpAddr::V4(v4) => ipv4_addresses.push((*local_address, v4)),
				IpAddr::V6(_) => {
					debug!(?iface, ?local_address, "cannot broadcast to IPv6 addresses");
				}
			}
		}
	}
	Ok(ipv4_addresses)
}
/// Broadcast `body_to_broadcast` from one local IPv4 address.
///
/// `interface_hook` is always called first with `interface_addr`. If the address has no
/// broadcast address, nothing is sent and `Ok(None)` is returned. Otherwise the function:
/// 1. binds a UDP socket to `interface_ipv4` on the control port (`override_control_port`,
///    or `DEFAULT_MION_CONTROL_PORT`);
/// 2. enables broadcast on it;
/// 3. sends the body to the broadcast address on that same port.
///
/// The socket is returned so the caller can read replies from it.
///
/// # Errors
///
/// - [`NetworkError::BindFailure`] if the socket cannot be bound.
/// - [`NetworkError::SetBroadcastFailure`] if broadcast cannot be enabled.
/// - [`NetworkError::IO`] if sending fails.
async fn broadcast_to_mions_on_interface<InterfaceLoggingHook>(
	override_control_port: Option<u16>,
	body_to_broadcast: Bytes,
	interface_addr: Addr,
	interface_ipv4: Ipv4Addr,
	interface_hook: InterfaceLoggingHook,
) -> Result<Option<UdpSocket>, NetworkError>
where
	InterfaceLoggingHook: Fn(&'_ Addr),
{
	interface_hook(&interface_addr);

	let broadcast_address = match interface_addr.broadcast() {
		Some(address) => address,
		None => {
			debug!(
				?interface_addr,
				?interface_ipv4,
				"failed to get broadcast address"
			);
			return Ok(None);
		}
	};
	debug!(
		?interface_addr,
		?interface_ipv4,
		"actually broadcasting to interface"
	);

	let control_port = override_control_port.unwrap_or(DEFAULT_MION_CONTROL_PORT);
	let bind_address = SocketAddr::V4(SocketAddrV4::new(interface_ipv4, control_port));
	let socket = UdpSocket::bind(bind_address)
		.await
		.map_err(|_| NetworkError::BindFailure)?;
	socket
		.set_broadcast(true)
		.map_err(|_| NetworkError::SetBroadcastFailure)?;

	let destination = SocketAddr::new(broadcast_address, control_port);
	socket
		.send_to(&body_to_broadcast, destination)
		.await
		.map_err(NetworkError::IO)?;
	Ok(Some(socket))
}
/// Step function for `futures::stream::unfold`: read one datagram from `sock`.
///
/// Yields `(bytes_read, sender, buffer)` and hands the socket back for the next read. The
/// buffer is a fresh 1024-byte zeroed `BytesMut` that has not been truncated to
/// `bytes_read`. A receive error is logged and ends the stream (`None`).
async fn unfold_socket(sock: UdpSocket) -> Option<((usize, SocketAddr, BytesMut), UdpSocket)> {
	let mut buffer = BytesMut::zeroed(1024);
	let received = sock.recv_from(&mut buffer).await;
	match received {
		Ok((length, source)) => Some(((length, source, buffer), sock)),
		Err(_) => {
			warn!("failed to receive data from broadcast socket");
			None
		}
	}
}
/// Interface logging hook that does nothing; installed by the wrapper functions that do
/// not accept a hook.
#[inline]
fn noop_logger_interface(_interface: &Addr) {}
#[cfg(test)]
mod unit_tests {
	use super::*;

	/// Discovery needs at least one local IPv4 address to broadcast from.
	#[test]
	pub fn can_list_at_least_one_interface() {
		let addresses =
			get_all_broadcast_addresses().expect("Failed to list all broadcast addresses!");
		assert!(
			!addresses.is_empty(),
			"Failed to list all broadcast addresses... for some reason your PC isn't compatible to scan devices... perhaps you don't have a private IPv4 address?",
		);
	}

	/// A name no real MION can have must never be found. This covers basic and detailed
	/// scans, with and without an early scan timeout.
	#[tokio::test]
	pub async fn cant_find_nonexisting_device() {
		let scan_configurations = [
			(false, None),
			(true, None),
			(true, Some(Duration::from_secs(3))),
		];
		for (find_detailed, early_scan_timeout) in scan_configurations {
			let found = find_mion(
				MionFindBy::Name("𩸽".to_owned()),
				find_detailed,
				early_scan_timeout,
				None,
			)
			.await
			.expect("Failed to scan to find a specific mion");
			assert!(found.is_none(), "Somehow found a MION that can't exist?");
		}
	}
}