use anyhow::{Context, Result};
use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4, TcpStream};
use std::time::{Duration, Instant};
use tokio::net::UdpSocket;
use tracing::{debug, info, instrument, warn};
/// Sends `count` Wake-on-LAN magic packets for `mac` to `bcast:port` over UDP.
///
/// A fresh UDP socket is bound to an ephemeral port on all IPv4 interfaces and
/// switched into broadcast mode before anything is sent. Consecutive packets are
/// separated by `config.wol_packet_sleeptime()`. No pause is taken after the final
/// packet, so the function returns as soon as the last send completes. A `count`
/// of zero sends nothing.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound, if broadcast cannot be
/// enabled, or if any individual send fails. The remaining packets are not sent.
#[instrument(name = "send_wol_packets", skip(mac, config))]
pub async fn send_packets(
    mac: &[u8; 6],
    bcast: Ipv4Addr,
    port: u16,
    count: u32,
    config: &crate::config::Config,
) -> Result<()> {
    let packet = build_magic_packet(mac);
    debug!(
        "Built WOL magic packet for MAC {:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
        mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]
    );
    let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))
        .await
        .context("Failed to bind UDP socket")?;
    sock.set_broadcast(true)
        .context("Failed to enable broadcast on socket")?;
    let addr = SocketAddrV4::new(bcast, port);
    info!("Sending {} WOL packets to {}:{}", count, bcast, port);
    for i in 0..count {
        // Pace the packets. Sleep only *between* sends so the caller isn't
        // delayed by an extra pause after the last packet has gone out.
        if i > 0 {
            tokio::time::sleep(config.wol_packet_sleeptime()).await;
        }
        debug!("Sending WOL packet {}/{}", i + 1, count);
        sock.send_to(&packet, addr)
            .await
            .context("Failed to send WOL packet")?;
    }
    info!(
        "Successfully sent {} WOL packets to {}:{}",
        count, bcast, port
    );
    Ok(())
}
/// Waits for `ip:check_tcp_port` to accept TCP connections.
///
/// Calls [`tcp_check`] repeatedly, bounding each attempt by `connect_timeout_ms`,
/// until an attempt succeeds or `wait_secs` seconds have elapsed since the call
/// started. At least one attempt is always made, even when `wait_secs` is 0.
///
/// Between failed attempts the thread sleeps for `interval_ms`, clamped so it never
/// sleeps past the deadline. A final attempt is made once the deadline is reached.
/// The total time is therefore bounded by roughly `wait_secs` plus one connect
/// timeout.
///
/// This function blocks the calling thread (blocking connects and
/// `std::thread::sleep`). From async code, run it on a blocking thread.
///
/// Returns `true` as soon as the port is reachable. Returns `false` if the deadline
/// passes first. `_config` is currently unused.
#[instrument(name = "check_host_reachability", skip(ip))]
pub fn check_host(
    ip: IpAddr,
    check_tcp_port: u16,
    wait_secs: u64,
    interval_ms: u64,
    connect_timeout_ms: u64,
    _config: &crate::config::Config,
) -> bool {
    let poll_every = Duration::from_millis(interval_ms);
    let connect_timeout = Duration::from_millis(connect_timeout_ms);
    // `checked_add` avoids a panic for absurdly large `wait_secs`. An
    // unrepresentable deadline is treated as "wait indefinitely".
    let deadline = Instant::now().checked_add(Duration::from_secs(wait_secs));
    let target = SocketAddr::new(ip, check_tcp_port);
    info!(
        "Waiting up to {}s for {}:{} ...",
        wait_secs, ip, check_tcp_port
    );
    loop {
        debug!("Checking if {}:{} is reachable", ip, check_tcp_port);
        if tcp_check(target, connect_timeout) {
            info!("Host {}:{} is UP ✅", ip, check_tcp_port);
            return true;
        }
        let now = Instant::now();
        let pause = match deadline {
            Some(deadline) if now >= deadline => {
                warn!("TIMEOUT ❌ waiting for {}:{}", ip, check_tcp_port);
                return false;
            }
            // Never sleep past the deadline. The next iteration then makes one
            // last attempt exactly at the deadline.
            Some(deadline) => poll_every.min(deadline.saturating_duration_since(now)),
            None => poll_every,
        };
        debug!(
            "Host {}:{} not reachable, waiting {:?} before next check",
            ip, check_tcp_port, pause
        );
        std::thread::sleep(pause);
    }
}
/// Makes a single TCP connection attempt to `addr`, giving up after `timeout`.
///
/// Returns `true` only if the connection was established. Any failure
/// (refused, unreachable, timed out, invalid timeout) yields `false`.
/// `TcpStream::connect_timeout` rejects a zero `timeout` with an error, so a
/// zero timeout always reports `false`. A successful stream is dropped
/// immediately, which closes the connection.
pub fn tcp_check(addr: SocketAddr, timeout: Duration) -> bool {
    let attempt = TcpStream::connect_timeout(&addr, timeout);
    // Only the success/failure outcome matters; the error detail is discarded.
    attempt.is_ok()
}
/// Parses a textual MAC address into its six raw bytes.
///
/// Accepts the common notations case-insensitively:
/// - colon- or dash-separated pairs (`aa:bb:cc:dd:ee:ff`, `AA-BB-CC-DD-EE-FF`)
/// - Cisco-style dotted groups (`aabb.ccdd.eeff`)
/// - bare hex (`AABBCCDDEEFF`)
///
/// The separators `:`, `-` and `.`, plus any whitespace, are ignored. Exactly 12
/// hex digits must remain.
///
/// # Errors
///
/// Returns an error if the input contains any other character. Such characters
/// are rejected rather than silently dropped, so e.g. `ab:cd:ef:gh:01:23:45`
/// cannot be misread as a valid address. Also returns an error if the number of
/// hex digits is not exactly 12.
#[instrument(name = "parse_mac", skip(s))]
pub fn parse_mac(s: &str) -> Result<[u8; 6]> {
    let is_separator = |c: char| matches!(c, ':' | '-' | '.') || c.is_whitespace();
    if let Some(bad) = s
        .chars()
        .find(|&c| !c.is_ascii_hexdigit() && !is_separator(c))
    {
        anyhow::bail!("invalid character {:?} in MAC address", bad);
    }
    let nibbles: Vec<u8> = s
        .chars()
        .filter_map(|c| c.to_digit(16))
        // `to_digit(16)` only yields 0..=15, so narrowing to `u8` is lossless.
        .map(|d| d as u8)
        .collect();
    if nibbles.len() != 12 {
        anyhow::bail!("expected 12 hex digits, got {}", nibbles.len());
    }
    let mut mac = [0u8; 6];
    // Each consecutive pair of nibbles forms one byte, high nibble first.
    for (byte, pair) in mac.iter_mut().zip(nibbles.chunks_exact(2)) {
        *byte = (pair[0] << 4) | pair[1];
    }
    debug!(
        "Successfully parsed MAC address: {:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}",
        mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]
    );
    Ok(mac)
}
/// Builds the 102-byte Wake-on-LAN magic packet for `mac`.
///
/// The packet is a synchronisation header of six `0xFF` bytes, followed by the
/// six MAC bytes repeated sixteen times.
fn build_magic_packet(mac: &[u8; 6]) -> [u8; 102] {
    // Start fully filled with 0xFF so the header is already in place; everything
    // after the first six bytes is then overwritten with MAC repetitions.
    let mut packet = [0xFFu8; 102];
    let (_sync_header, repetitions) = packet.split_at_mut(6);
    repetitions
        .chunks_exact_mut(mac.len())
        .for_each(|slot| slot.copy_from_slice(mac));
    packet
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use std::io::ErrorKind;
    use std::net::TcpListener;
    use std::time::Duration;

    /// Binds a TCP listener on an ephemeral loopback port.
    ///
    /// Returns `None` when the environment forbids binding sockets, after printing
    /// a skip notice naming `test_name`, so callers can skip gracefully. Panics on
    /// any other bind error.
    fn bind_loopback(test_name: &str) -> Option<TcpListener> {
        match TcpListener::bind("127.0.0.1:0") {
            Ok(listener) => Some(listener),
            Err(err) if err.kind() == ErrorKind::PermissionDenied => {
                eprintln!("skipping {}: {}", test_name, err);
                None
            }
            Err(err) => panic!("failed to bind listener: {err}"),
        }
    }

    #[test]
    fn magic_packet_has_sync_stream_and_repeated_mac() {
        let mac = [0xDE, 0xAD, 0xBE, 0xEF, 0xFE, 0xED];
        let packet = build_magic_packet(&mac);
        assert_eq!(packet.len(), 102, "magic packet must be 102 bytes");
        assert!(
            packet.iter().take(6).all(|&b| b == 0xFF),
            "packet must start with six 0xFF bytes"
        );
        for (idx, chunk) in packet[6..].chunks_exact(6).enumerate() {
            assert_eq!(chunk, mac, "MAC repetition {} does not match", idx + 1);
        }
    }

    #[test]
    fn parse_mac_accepts_common_formats() {
        let expected = [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF];
        let inputs = [
            "aa:bb:cc:dd:ee:ff",
            "AA-BB-CC-DD-EE-FF",
            "aabb.ccdd.eeff",
            "AABBCCDDEEFF",
        ];
        for input in inputs {
            let mac = parse_mac(input).expect("MAC should parse");
            assert_eq!(
                mac, expected,
                "parsed MAC did not match for input '{}'",
                input
            );
        }
    }

    #[test]
    fn parse_mac_rejects_invalid_input() {
        let invalid_inputs = [
            "",
            "1234567890ABCD",
            "zz:zz:zz:zz:zz:zz",
            "aa-bb-cc-dd-ee",
            "aa:bb:cc:dd:ee:ff:11",
        ];
        for input in invalid_inputs {
            assert!(
                parse_mac(input).is_err(),
                "expected error for input '{}'",
                input
            );
        }
    }

    #[test]
    fn tcp_check_reports_true_when_server_listening() {
        let listener = match bind_loopback("tcp_check_reports_true_when_server_listening") {
            Some(listener) => listener,
            None => return,
        };
        let addr = listener.local_addr().expect("failed to get addr");
        assert!(tcp_check(addr, Duration::from_millis(100)));
        drop(listener);
    }

    #[test]
    fn check_host_returns_false_when_unreachable() {
        // Reserve a free ephemeral port, then release it so nothing is listening
        // there. A hard-coded port (e.g. 65000) might be in use on the test machine.
        let listener = match bind_loopback("check_host_returns_false_when_unreachable") {
            Some(listener) => listener,
            None => return,
        };
        let addr = listener.local_addr().expect("failed to get addr");
        drop(listener);
        let config = Config::default();
        let result = check_host(addr.ip(), addr.port(), 0, 10, 10, &config);
        assert!(!result);
    }

    #[test]
    fn check_host_returns_true_when_host_up() {
        let listener = match bind_loopback("check_host_returns_true_when_host_up") {
            Some(listener) => listener,
            None => return,
        };
        let addr = listener.local_addr().expect("failed to get addr");
        let config = Config::default();
        let is_up = check_host(addr.ip(), addr.port(), 1, 10, 50, &config);
        assert!(is_up);
        drop(listener);
    }
}