use crate::connection_pool::ConnectionPool;
use crate::{web::Machine, wol};
use anyhow::{Context, Result};
use std::collections::VecDeque;
use std::net::{Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::io::copy_bidirectional;
use tokio::net::TcpListener;
use tokio::sync::watch;
use tokio::time::Instant;
use tracing::{debug, error, info, warn};
/// Builds the URL of a machine's turn-off endpoint,
/// e.g. `http://192.168.1.10:8080/machines/turn-off`.
fn turn_off_url(remote_ip: &str, turn_off_port: u16) -> String {
    format!("http://{remote_ip}:{turn_off_port}/machines/turn-off")
}
/// Sliding-window request counter that decides when a machine should be sent a
/// turn-off signal.
///
/// Clones share the same request history and trigger state (both live behind
/// `Arc`), so every connection task of one proxy counts against one window.
#[derive(Clone)]
struct TurnOffLimiter {
    /// Timestamps of the requests still inside `window`, oldest first.
    request_times: Arc<Mutex<VecDeque<Instant>>>,
    /// Number of requests within `window` at which the turn-off fires; 0 disables it.
    max_requests: usize,
    /// Length of the sliding window.
    window: Duration,
    /// Port of the remote turn-off HTTP endpoint.
    turn_off_port: u16,
    /// Address of the machine to turn off.
    remote_ip: Ipv4Addr,
    /// MAC address of the machine (used in log messages).
    mac: String,
    /// Set when the threshold is crossed so the signal is reported once per
    /// crossing; cleared again once the windowed count falls below the limit.
    triggered: Arc<AtomicBool>,
}

impl TurnOffLimiter {
    /// Builds a limiter from `machine.request_rate`, clamping the window to at
    /// least one minute so a zero period cannot produce an empty window.
    fn new(machine: &Machine, turn_off_port: u16) -> Self {
        let window_minutes = machine.request_rate.period_minutes.max(1);
        let window_secs = window_minutes.saturating_mul(60);
        Self {
            request_times: Arc::new(Mutex::new(VecDeque::new())),
            max_requests: machine.request_rate.max_requests as usize,
            window: Duration::from_secs(window_secs as u64),
            turn_off_port,
            remote_ip: machine.ip,
            mac: machine.mac.clone(),
            triggered: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Records one request and reports whether it just crossed the threshold.
    ///
    /// Returns `Some(count)` (the number of requests currently in the window) for
    /// exactly one request per crossing of `max_requests`, and `None` otherwise,
    /// including when the limiter is disabled (`max_requests == 0`).
    fn record_request(&self) -> Option<usize> {
        if self.max_requests == 0 {
            return None;
        }
        let now = Instant::now();
        // A poisoned lock only means another thread panicked while holding it;
        // the timestamp queue is still valid, so keep using it rather than panic.
        let mut times = self
            .request_times
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        times.push_back(now);
        // Drop timestamps that have slid out of the window.
        while let Some(oldest) = times.front() {
            if now.duration_since(*oldest) > self.window {
                times.pop_front();
            } else {
                break;
            }
        }
        let current = times.len();
        if current < self.max_requests {
            // Back under the limit: re-arm so a later burst (e.g. after the
            // machine was woken again) can trigger another turn-off.
            self.triggered.store(false, Ordering::SeqCst);
            return None;
        }
        // At or over the limit: only the request that flips the flag reports it,
        // so a sustained burst does not spam turn-off requests.
        if self.triggered.swap(true, Ordering::SeqCst) {
            None
        } else {
            Some(current)
        }
    }
}
/// Asks the machine at `remote_ip` to shut down by POSTing to its turn-off
/// endpoint (built by `turn_off_url`).
///
/// The request bypasses any configured proxy and times out after 5 seconds.
/// A reply with a non-success status is logged as an error but still yields
/// `Ok(())`.
///
/// # Errors
///
/// Returns the `reqwest` error if the HTTP client cannot be built or the
/// request cannot be sent.
pub async fn turn_off_remote_machine(
    remote_ip: &str,
    turn_off_port: u16,
) -> Result<(), reqwest::Error> {
    let endpoint = turn_off_url(remote_ip, turn_off_port);
    info!("Sending turn-off signal to {}", endpoint);

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(5))
        .no_proxy()
        .build()?;
    let reply = http.post(&endpoint).send().await?;
    let status = reply.status();

    if !status.is_success() {
        error!(
            "Failed to send turn-off signal to {}:{}, status: {}",
            remote_ip, turn_off_port, status
        );
        return Ok(());
    }

    info!(
        "Successfully sent turn-off signal to {}:{}",
        remote_ip, turn_off_port
    );
    Ok(())
}
pub async fn proxy(
local_port: u16,
remote_addr: SocketAddr,
machine: Machine,
wol_port: u16,
mut rx: watch::Receiver<bool>,
connection_pool: ConnectionPool,
) -> Result<()> {
let listen_addr = format!("0.0.0.0:{}", local_port);
let listener = TcpListener::bind(&listen_addr)
.await
.with_context(|| format!("Failed to bind TCP listener on {}", listen_addr))?;
info!(
"TCP Forwarder listening on {}, proxying to {}, rate limit: {}/{}min",
listen_addr,
remote_addr,
machine.request_rate.max_requests,
machine.request_rate.period_minutes
);
let rate_limiter = if machine.can_be_turned_off {
if let Some(port) = machine.turn_off_port {
if machine.request_rate.max_requests > 0 {
Some(TurnOffLimiter::new(&machine, port))
} else {
info!(
"Machine {} has rate limit disabled (max_requests = 0)",
machine.mac
);
None
}
} else {
debug!(
"Turn off port not configured for {}, skipping rate-based shutdown",
machine.mac
);
None
}
} else {
info!(
"Machine {} cannot be turned off automatically (feature disabled)",
machine.mac
);
None
};
loop {
tokio::select! {
result = rx.changed() => {
if result.is_err() || !*rx.borrow() {
info!("Proxy for {} on port {} cancelled.", remote_addr, local_port);
return Ok(());
}
}
result = listener.accept() => {
let (mut inbound, client_addr) = result
.context("Failed to accept incoming connection")?;
info!(
"Accepted connection from {} to forward to {}",
client_addr, remote_addr
);
let remote_addr_clone = remote_addr;
let mac_str_clone = machine.mac.clone();
let rate_limiter = rate_limiter.clone();
let connection_pool_clone = connection_pool.clone();
tokio::spawn(async move {
if let Some(limiter) = rate_limiter.clone() {
if let Some(hit_count) = limiter.record_request() {
let remote_ip = limiter.remote_ip.to_string();
let turn_off_port = limiter.turn_off_port;
let mac = limiter.mac.clone();
let window = limiter.window;
tokio::spawn(async move {
info!(
"Request limit reached for {}: {} requests within {:?}, sending turn-off signal",
mac, hit_count, window
);
if let Err(e) = turn_off_remote_machine(&remote_ip, turn_off_port).await {
error!(
"Failed to send turn-off signal for {} on {}:{}: {}",
mac, remote_ip, turn_off_port, e
);
}
});
}
}
let connect_timeout = Duration::from_millis(1000);
if !wol::tcp_check(remote_addr_clone, connect_timeout) {
info!(
"Host {} seems to be down. Sending WOL packet to MAC {}.",
remote_addr_clone, mac_str_clone
);
let mac = match wol::parse_mac(&mac_str_clone) {
Ok(m) => m,
Err(e) => {
error!("Invalid MAC for WOL on proxy: {}: {}", mac_str_clone, e);
return;
}
};
let broadcast_addr = Ipv4Addr::new(255, 255, 255, 255);
let wol_config = Default::default();
if let Err(e) = crate::wol::send_packets(&mac, broadcast_addr, wol_port, 3, &wol_config).await {
error!("Failed to send WOL packet for {}: {}", mac_str_clone, e);
return;
}
info!(
"WOL packet sent. Waiting up to 60s for {} to become reachable...",
remote_addr_clone
);
let deadline = tokio::time::Instant::now() + Duration::from_secs(60);
let mut host_up = false;
while tokio::time::Instant::now() < deadline {
if wol::tcp_check(remote_addr_clone, connect_timeout) {
info!("Host {} is now up.", remote_addr_clone);
host_up = true;
break;
}
tokio::time::sleep(Duration::from_secs(2)).await;
}
if !host_up {
warn!(
"Timeout waiting for host {} to come up. Dropping connection from {}.",
remote_addr_clone, client_addr
);
return;
}
}
let mut outbound = match connection_pool_clone.get_connection(remote_addr_clone).await {
Ok(stream) => {
debug!("Successfully obtained or created connection to {}", remote_addr_clone);
stream
}
Err(e) => {
error!("Failed to obtain connection to remote {}: {}", remote_addr_clone, e);
return;
}
};
match copy_bidirectional(&mut inbound, &mut outbound).await {
Ok(_) => {
drop(outbound);
debug!(
"Completed data transfer for {} (connection closed)",
remote_addr_clone
);
}
Err(e) => {
drop(outbound);
connection_pool_clone.remove_target(remote_addr_clone).await;
warn!(
"Error forwarding data between {} and {}: {}",
client_addr, remote_addr_clone, e
);
}
}
});
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::ErrorKind;
    use std::sync::Arc;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    use tokio::sync::Mutex;

    /// The turn-off URL is `http://<ip>:<port>/machines/turn-off`.
    #[test]
    fn turn_off_url_formats_expected_path() {
        let url = super::turn_off_url("192.168.1.10", 8080);
        assert_eq!(url, "http://192.168.1.10:8080/machines/turn-off");
    }

    /// Runs a one-shot HTTP server on loopback and checks that
    /// `turn_off_remote_machine` POSTs to the turn-off path with a Host header
    /// naming that server. Skips when the environment forbids binding sockets.
    #[tokio::test]
    async fn turn_off_remote_machine_sends_expected_request() {
        let listener = match TcpListener::bind("127.0.0.1:0").await {
            Ok(listener) => listener,
            Err(err)
                if matches!(
                    err.kind(),
                    ErrorKind::PermissionDenied | ErrorKind::AddrNotAvailable
                ) =>
            {
                eprintln!(
                    "skipping test because binding TCP sockets is not permitted: {}",
                    err
                );
                return;
            }
            Err(err) => panic!("failed to bind http test listener: {err}"),
        };
        let addr = listener.local_addr().expect("failed to read listener addr");
        let received = Arc::new(Mutex::new(None));
        let received_clone = received.clone();
        let server_task = tokio::spawn(async move {
            if let Ok((mut socket, _)) = listener.accept().await {
                // A single `read` may return only part of the request, so keep
                // reading until the blank line ending the headers (or EOF/error).
                let mut request = Vec::new();
                let mut buf = [0u8; 1024];
                loop {
                    match socket.read(&mut buf).await {
                        Ok(0) | Err(_) => break,
                        Ok(n) => {
                            request.extend_from_slice(&buf[..n]);
                            if request.windows(4).any(|w| w == b"\r\n\r\n") {
                                break;
                            }
                        }
                    }
                }
                if !request.is_empty() {
                    *received_clone.lock().await =
                        Some(String::from_utf8_lossy(&request).into_owned());
                }
                let _ = socket
                    .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")
                    .await;
            }
        });
        turn_off_remote_machine(&addr.ip().to_string(), addr.port())
            .await
            .expect("turn_off_remote_machine should succeed");
        server_task.await.expect("server task panicked");
        let request = received.lock().await.clone().expect("no request captured");
        assert!(request.starts_with("POST /machines/turn-off"));
        let host_line = request
            .lines()
            .find(|line| line.to_ascii_lowercase().starts_with("host:"))
            .unwrap_or_else(|| panic!("Host header missing in request: {request}"));
        let host_value = host_line.split_once(':').map(|(_, value)| value.trim());
        let expected_ip = addr.ip().to_string();
        let expected_with_port = format!("{}:{}", expected_ip, addr.port());
        // Clients may send the Host header with or without the port.
        assert!(
            matches!(host_value, Some(value) if value.eq_ignore_ascii_case(&expected_ip) || value.eq_ignore_ascii_case(&expected_with_port)),
            "unexpected host header: {host_line}"
        );
    }
}