use crate::{Node, Nodes, BROADCAST_INTERVAL, MAX_SILENT_INTERVALS};
use if_addrs::get_if_addrs;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::{broadcast, oneshot, watch};
use tokio::time::sleep;
use tokio::time::{Duration, Instant};
use tracing::{debug, error, info, warn};
pub async fn discover(
broadcast_port: u16,
) -> Result<
(
oneshot::Receiver<()>,
oneshot::Receiver<()>,
watch::Sender<()>,
Arc<Nodes>,
),
Box<dyn std::error::Error + Send + Sync>,
> {
let nodes = Arc::new(Nodes::new());
let (shutdown_tx, mut shutdown_rx) = watch::channel(());
let (up_tx, up_rx) = oneshot::channel();
let (fin_tx, fin_rx) = oneshot::channel();
let own_ip = get_own_private_ip().unwrap_or(Ipv4Addr::new(127, 0, 0, 1));
info!("Own IP address: {}", own_ip);
let socket = Arc::new(
UdpSocket::bind(("0.0.0.0", broadcast_port))
.await
.expect("Failed to bind socket"),
);
socket.set_broadcast(true).expect("Failed to set broadcast");
let ip_str = own_ip.to_string();
let octets: Vec<&str> = ip_str.split('.').collect();
if octets.len() != 4 {
return Err("Invalid IP address format".into());
}
let broadcast_ip = format!("{}.{}.{}.255", octets[0], octets[1], octets[2]);
let _ = up_tx.send(());
let nodes_clone = Arc::clone(&nodes);
let socket_clone = Arc::clone(&socket);
let mut shutdown_clone = shutdown_rx.clone();
tokio::spawn(async move {
loop {
tokio::select! {
_ = shutdown_clone.changed() => {
info!("Shutdown signal received, stopping broadcast task");
break;
}
_ = sleep(BROADCAST_INTERVAL) => {
nodes_clone.reap();
match socket_clone
.send_to(&own_ip.octets(), (broadcast_ip.as_str(), broadcast_port))
.await
{
Ok(_) => {}
Err(e) => {
error!("Failed to send broadcast: {}", e);
}
}
}
}
}
});
let nodes_clone = Arc::clone(&nodes);
tokio::spawn(async move {
let mut buffer = [0; 1024];
loop {
tokio::select! {
_ = shutdown_rx.changed() => {
info!("Shutdown signal received, stopping receive task");
break;
}
result = socket.recv_from(&mut buffer) => {
match result {
Ok((_, src_addr)) => {
if let Some(discovered_ip) = extract_private_ip(&src_addr) {
if discovered_ip != own_ip {
if !nodes_clone.test(&discovered_ip) {
info!("Discovered new node: {}", discovered_ip);
}
let is_self = own_ip == discovered_ip;
nodes_clone.add(discovered_ip, None, None, is_self);
};
} else {
warn!("Received broadcast from non-private IP: {}", src_addr.ip());
}
}
Err(e) => {
warn!("Error receiving broadcast: {}", e);
}
}
}
}
}
});
Ok((up_rx, fin_rx, shutdown_tx, Arc::clone(&nodes)))
}
/// Returns the first IPv4 address among the host's network interfaces that
/// lies in the `10.x.x.x` private range.
///
/// Yields `None` when no interface carries such an address, or when the
/// interface list cannot be obtained (in which case a warning is logged).
pub fn get_own_private_ip() -> Option<Ipv4Addr> {
    let interfaces = get_if_addrs()
        .map_err(|err| warn!("Failed to get network interfaces: {}", err))
        .ok()?;

    interfaces
        .into_iter()
        .find_map(|interface| match interface.ip() {
            IpAddr::V4(v4) if v4.is_private() && matches!(v4.octets(), [10, ..]) => Some(v4),
            IpAddr::V4(_) | IpAddr::V6(_) => None,
        })
}
/// Extracts the IPv4 address of `addr` when it belongs to the `10.x.x.x`
/// private range.
///
/// Any other IPv4 address, and every IPv6 address, yields `None`.
fn extract_private_ip(addr: &SocketAddr) -> Option<Ipv4Addr> {
    match addr.ip() {
        IpAddr::V4(candidate) if candidate.is_private() && candidate.octets()[0] == 10 => {
            Some(candidate)
        }
        IpAddr::V4(_) | IpAddr::V6(_) => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Shadows the async `sleep` pulled in by the glob import above; these
    // tests are synchronous.
    use std::thread::sleep;

    const LOOPBACK: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
    const LAN_PEER: Ipv4Addr = Ipv4Addr::new(192, 168, 0, 1);

    // NOTE(review): the `Nodes` calls below follow the call shapes used by
    // `discover` in this file (`Nodes::new()`, `add(ip, None, None, is_self)`,
    // `test(&ip)`); confirm against the `Nodes` definition.

    /// The outcome depends on the interfaces of the machine running the test,
    /// so only the invariant of a returned address is checked.
    #[test]
    fn test_get_own_private_ip() {
        if let Some(ip) = get_own_private_ip() {
            assert!(ip.is_private());
            assert_eq!(ip.octets()[0], 10);
        }
    }

    /// A node that was added is reported by `test`; an unknown one is not.
    #[test]
    fn test_nodes_add_and_test() {
        let nodes = Nodes::new();
        nodes.add(LOOPBACK, None, None, false);
        assert!(nodes.test(&LOOPBACK));
        assert!(!nodes.test(&LAN_PEER));
    }

    /// `all` returns every node that was added.
    #[test]
    fn test_nodes_all() {
        let nodes = Nodes::new();
        nodes.add(LOOPBACK, None, None, false);
        nodes.add(LAN_PEER, None, None, false);
        let all_nodes: Vec<Node> = nodes.all();
        assert_eq!(all_nodes.len(), 2);
        assert!(all_nodes.iter().any(|node| node.ip == LOOPBACK));
        assert!(all_nodes.iter().any(|node| node.ip == LAN_PEER));
    }

    /// Nodes that stay silent for more than `MAX_SILENT_INTERVALS` broadcast
    /// intervals are removed by `reap`.
    ///
    /// This test really sleeps for that long, so it is slow by construction.
    #[test]
    fn test_nodes_reap() {
        let nodes = Nodes::new();
        nodes.add(LOOPBACK, None, None, false);
        nodes.add(LAN_PEER, None, None, false);
        sleep(Duration::from_secs(
            (MAX_SILENT_INTERVALS + 1) * BROADCAST_INTERVAL.as_secs(),
        ));
        nodes.reap();
        assert_eq!(nodes.all().len(), 0);
    }
}