use crate::error::{AppError, Result};
use std::net::TcpListener;
use std::time::Duration;
use tokio::net::{TcpSocket, TcpStream};
use tokio::time::timeout;
/// Reports whether `host:port` can currently be bound by this process.
///
/// `host` must be an IP literal. IPv6 literals are accepted both bare (`::1`)
/// and bracketed (`[::1]`). The port is attached with [`std::net::SocketAddr::new`]
/// rather than by string formatting, because `format!("{host}:{port}")` turns a
/// bare IPv6 address into an unparseable string such as `::1:8080`.
///
/// A fresh TCP socket of the matching family is bound to the address:
/// `Ok(true)` if the bind succeeds, `Ok(false)` if it fails with
/// [`std::io::ErrorKind::AddrInUse`]. The probe socket is dropped on return.
///
/// # Errors
///
/// Returns [`AppError::Io`] if `host` is not a valid IP address, if the socket
/// cannot be created, or if the bind fails for any reason other than the
/// address already being in use.
pub async fn is_port_available(host: &str, port: u16) -> Result<bool> {
    // Accept URL-style "[::1]" as well as bare "::1".
    let ip_text = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);
    let ip: std::net::IpAddr = ip_text.parse().map_err(|e| {
        AppError::Io(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Invalid address {}:{}: {}", host, port, e),
        ))
    })?;
    let addr = std::net::SocketAddr::new(ip, port);

    // The socket family must match the address family or `bind` fails.
    let socket = match addr {
        std::net::SocketAddr::V4(_) => TcpSocket::new_v4(),
        std::net::SocketAddr::V6(_) => TcpSocket::new_v6(),
    }
    .map_err(|e| {
        AppError::Io(std::io::Error::other(format!(
            "Failed to create socket: {}",
            e
        )))
    })?;

    match socket.bind(addr) {
        Ok(()) => Ok(true),
        Err(e) if e.kind() == std::io::ErrorKind::AddrInUse => Ok(false),
        Err(e) => Err(AppError::Io(e)),
    }
}
/// Blocking variant of [`is_port_available`], built on [`std::net::TcpListener`].
///
/// `host` and `port` are joined as `"{host}:{port}"` and passed to
/// `TcpListener::bind`, so the string goes through std's `ToSocketAddrs`
/// resolution. Returns `Ok(true)` if the bind succeeds (the listener is dropped
/// immediately) and `Ok(false)` if it fails with
/// [`std::io::ErrorKind::AddrInUse`].
///
/// # Errors
///
/// Any other bind failure is returned as [`AppError::Io`].
// `dead_code` is allowed because, presumably, no non-test code in this crate
// calls this yet (TODO confirm); it is kept as a helper for blocking callers.
#[allow(dead_code)]
pub fn is_port_available_sync(host: &str, port: u16) -> Result<bool> {
    let endpoint = format!("{host}:{port}");
    TcpListener::bind(endpoint)
        .map(|_listener| true)
        .or_else(|err| match err.kind() {
            std::io::ErrorKind::AddrInUse => Ok(false),
            _ => Err(AppError::Io(err)),
        })
}
/// Returns the subset of `endpoints` whose port is already in use, preserving
/// input order.
///
/// Endpoints are probed one after another with [`is_port_available`].
///
/// # Errors
///
/// The first probe that returns an error stops the scan, and that error is
/// propagated unchanged.
pub async fn check_ports(endpoints: &[(String, u16)]) -> Result<Vec<(String, u16)>> {
    let mut in_use = Vec::new();
    for endpoint in endpoints {
        let (host, port) = endpoint;
        let free = is_port_available(host, *port).await?;
        if free {
            continue;
        }
        in_use.push(endpoint.clone());
    }
    Ok(in_use)
}
/// Probes whether something accepts TCP connections on `host:port`.
///
/// A connection attempt is made with a two-second deadline. Returns `Ok(true)`
/// if `TcpStream::connect` succeeds in time. Returns `Ok(false)` if the connect
/// fails or the deadline expires. Any established stream is dropped
/// immediately.
///
/// This function never returns `Err`; the `Result` is kept for signature
/// compatibility.
pub async fn test_port_connection(host: &str, port: u16) -> Result<bool> {
    let target = format!("{host}:{port}");
    let attempt = timeout(Duration::from_secs(2), TcpStream::connect(&target)).await;
    Ok(matches!(attempt, Ok(Ok(_))))
}
/// Heuristically checks that a tunnel endpoint at `host:port` is alive.
///
/// Connects with a two-second deadline, then writes one probe byte (`b"X"`)
/// with a one-second deadline. The result is:
///
/// * `Ok(false)` if the connect fails or times out, or if the write fails with
///   `ConnectionReset`, `ConnectionAborted` or `BrokenPipe`, meaning the peer
///   has already torn the connection down;
/// * `Ok(true)` otherwise: the write succeeded, failed with some other error,
///   or did not complete before the deadline.
///
/// The probe byte is delivered to whatever service is listening. A successful
/// write typically only means the byte was accepted into the local send buffer,
/// so this does not prove the far side of the tunnel is reachable.
///
/// This function never returns `Err`; the `Result` is kept for signature
/// compatibility.
pub async fn test_tunnel_connection(host: &str, port: u16) -> Result<bool> {
    use std::io::ErrorKind;
    use tokio::io::AsyncWriteExt;

    let addr = format!("{}:{}", host, port);
    // Connect error and connect timeout are both "not alive".
    let Ok(Ok(mut stream)) = timeout(Duration::from_secs(2), TcpStream::connect(&addr)).await
    else {
        return Ok(false);
    };

    match timeout(Duration::from_secs(1), stream.write_all(b"X")).await {
        // Only errors that mean the peer terminated the connection count as dead.
        // `ConnectionAborted` belongs with reset/broken-pipe (e.g. WSAECONNABORTED
        // on Windows).
        Ok(Err(e)) => Ok(!matches!(
            e.kind(),
            ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted | ErrorKind::BrokenPipe
        )),
        // The write succeeded, or is still pending when the deadline expires
        // (presumably a full send buffer): treat the tunnel as alive.
        Ok(Ok(())) | Err(_) => Ok(true),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Binds an OS-assigned ephemeral port on IPv4 loopback and returns the
    /// listener with its port, so tests never depend on a guessed port number.
    async fn occupied_loopback_port() -> (tokio::net::TcpListener, u16) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    #[tokio::test]
    async fn test_port_check() {
        // Let the OS pick a port that is free right now, then release it.
        let (listener, port) = occupied_loopback_port().await;
        drop(listener);
        assert!(is_port_available("127.0.0.1", port).await.unwrap());
    }

    #[tokio::test]
    async fn occupied_port_is_detected_on_specific_host() {
        let (_listener, port) = occupied_loopback_port().await;
        assert!(!is_port_available("127.0.0.1", port).await.unwrap());
    }

    #[tokio::test]
    async fn sync_check_detects_occupied_port() {
        let (_listener, port) = occupied_loopback_port().await;
        assert!(!is_port_available_sync("127.0.0.1", port).unwrap());
    }

    #[tokio::test]
    async fn check_ports_reports_occupied_endpoint() {
        let (_listener, port) = occupied_loopback_port().await;
        let endpoints = vec![("127.0.0.1".to_string(), port)];
        assert_eq!(check_ports(&endpoints).await.unwrap(), endpoints);
    }

    #[tokio::test]
    async fn connection_probe_reaches_listening_port() {
        // The kernel completes the handshake via the backlog; no accept needed.
        let (_listener, port) = occupied_loopback_port().await;
        assert!(test_port_connection("127.0.0.1", port).await.unwrap());
    }
}