use bytesize::ByteSize;
use local_ip_address::{local_ip, local_ipv6};
use pnet::datalink::{self, NetworkInterface};
use std::cmp::min;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use sysinfo::Networks;
use tokio::sync::Mutex;
use tracing::{info, warn};
#[cfg(target_os = "linux")]
use std::{io, mem, os::unix::io::RawFd};
/// Renders `ip` and `port` as a `host:port` string.
///
/// IPv6 hosts are wrapped in square brackets (e.g. `[::1]:80`) so the colon that separates
/// the port cannot be confused with the colons inside the address; IPv4 hosts are written
/// as-is (e.g. `127.0.0.1:80`).
pub fn format_socket_addr(ip: IpAddr, port: u16) -> String {
    let host = match ip {
        IpAddr::V6(addr) => format!("[{}]", addr),
        IpAddr::V4(addr) => addr.to_string(),
    };
    format!("{}:{}", host, port)
}
/// Builds a URL of the form `scheme://host:port` from its parts.
///
/// The `host:port` authority is produced by `format_socket_addr`, so IPv6 hosts come out
/// bracketed (e.g. `http://[::1]:80`).
pub fn format_url(scheme: &str, ip: IpAddr, port: u16) -> String {
    let authority = format_socket_addr(ip, port);
    format!("{scheme}://{authority}")
}
/// Returns a local IP address for this host, if one can be determined.
///
/// `local_ip()` is tried first; only if it fails is `local_ipv6()` consulted. Returns `None`
/// when both lookups fail.
pub fn preferred_local_ip() -> Option<IpAddr> {
    match local_ip() {
        Ok(ip) => Some(ip),
        Err(_) => local_ipv6().ok(),
    }
}
/// A local network interface together with the bandwidth budget derived for it.
///
/// Normally built with `Interface::new`; the derived `Default` yields an empty name and a
/// bandwidth of zero.
#[derive(Debug, Clone, Default)]
pub struct Interface {
    /// Interface name as reported by the OS, or `"unknown"` when no interface owns the IP.
    pub name: String,
    /// Usable bandwidth in bits per second.
    pub bandwidth: u64,
    /// Lock held while throughput is being sampled. Because it sits behind an `Arc`, clones
    /// of an `Interface` share the same lock and their samples are serialized.
    network_data_mutex: Arc<Mutex<()>>,
}
/// One throughput sample for an `Interface`, with all rates expressed in bits per second.
///
/// The `max_*` fields carry the interface's bandwidth budget. The measured `rx_bandwidth`
/// and `tx_bandwidth` are `None` when no traffic counters could be found for the interface.
#[derive(Debug, Clone, Default)]
pub struct NetworkData {
    /// Upper bound for receive throughput (the interface's bandwidth).
    pub max_rx_bandwidth: u64,
    /// Measured receive throughput, if available.
    pub rx_bandwidth: Option<u64>,
    /// Upper bound for transmit throughput (the interface's bandwidth).
    pub max_tx_bandwidth: u64,
    /// Measured transmit throughput, if available.
    pub tx_bandwidth: Option<u64>,
}
impl Interface {
    /// Minimum sampling window used by `get_network_data` to measure throughput.
    const DEFAULT_NETWORKS_REFRESH_INTERVAL: Duration = Duration::from_secs(1);

    /// Resolves the network interface that owns `ip` and computes its usable bandwidth.
    ///
    /// The bandwidth, in bits per second, is the smaller of two values: the link speed the OS
    /// reports for the interface, and `rate_limit`.
    ///
    /// `rate_limit` alone is used in two cases:
    /// - No interface owns `ip`. The result is then named `"unknown"`.
    /// - The link speed cannot be read.
    pub fn new(ip: IpAddr, rate_limit: ByteSize) -> Interface {
        let rate_limit = Self::byte_size_to_bits(rate_limit);
        let Some(interface) = Self::get_network_interface_by_ip(ip) else {
            warn!(
                "can not find interface for IP address {}, network interface unknown with bandwidth {} bps",
                ip, rate_limit
            );
            return Self::with_bandwidth("unknown".to_string(), rate_limit);
        };

        let bandwidth = match Self::get_speed(&interface.name) {
            Some(speed) => {
                let bandwidth = min(Self::megabits_to_bits(speed), rate_limit);
                info!(
                    "network interface {} with bandwidth {} bps",
                    interface.name, bandwidth
                );
                bandwidth
            }
            None => {
                warn!(
                    "can not get speed, network interface {} with bandwidth {} bps",
                    interface.name, rate_limit
                );
                rate_limit
            }
        };
        Self::with_bandwidth(interface.name, bandwidth)
    }

    /// Builds an `Interface` owning a fresh, unshared sampling lock.
    fn with_bandwidth(name: String, bandwidth: u64) -> Interface {
        Interface {
            name,
            bandwidth,
            network_data_mutex: Arc::new(Mutex::new(())),
        }
    }

    /// Samples this interface's receive/transmit throughput, in bits per second.
    ///
    /// Samples are serialized via `network_data_mutex`, which clones of this `Interface` also
    /// share. Each call refreshes the counters twice, at least
    /// `DEFAULT_NETWORKS_REFRESH_INTERVAL` apart. Rates are computed over the window actually
    /// measured between the refreshes.
    ///
    /// If the interface is missing from the refreshed list:
    /// - only the `max_*` fields are populated;
    /// - `rx_bandwidth` and `tx_bandwidth` are `None`.
    pub async fn get_network_data(&self) -> NetworkData {
        let _guard = self.network_data_mutex.lock().await;
        let mut networks = Networks::new_with_refreshed_list();
        // Measure the real window instead of assuming the sleep lasted exactly the interval.
        // Timer overshoot, scheduler delay and the refresh itself all lengthen it, and dividing
        // by the nominal interval would over-report the rate. `tokio::time::Instant` follows
        // the same clock as `tokio::time::sleep`, which also covers paused test time.
        let sample_start = tokio::time::Instant::now();
        tokio::time::sleep(Self::DEFAULT_NETWORKS_REFRESH_INTERVAL).await;
        networks.refresh();
        // Clamp to the interval so the divisor can never be zero or below the sleep length.
        let window_secs = sample_start
            .elapsed()
            .max(Self::DEFAULT_NETWORKS_REFRESH_INTERVAL)
            .as_secs_f64();

        let Some(network_data) = networks.get(self.name.as_str()) else {
            warn!("can not find network data for interface {}", self.name);
            return NetworkData {
                max_rx_bandwidth: self.bandwidth,
                max_tx_bandwidth: self.bandwidth,
                ..Default::default()
            };
        };

        // Converts a byte delta accumulated over the window into bits per second.
        let bits_per_second =
            |bytes: u64| -> u64 { (Self::bytes_to_bits(bytes) as f64 / window_secs).round() as u64 };

        NetworkData {
            max_rx_bandwidth: self.bandwidth,
            rx_bandwidth: Some(bits_per_second(network_data.received())),
            max_tx_bandwidth: self.bandwidth,
            tx_bandwidth: Some(bits_per_second(network_data.transmitted())),
        }
    }

    /// Returns the link speed of interface `name` in megabits per second, if it is known.
    ///
    /// On Linux this reads `/sys/class/net/<name>/speed`. `None` is returned in two cases:
    /// the file cannot be read, or its contents do not parse as a `u64` (e.g. `-1`).
    /// On other platforms it logs a warning and always returns `None`.
    pub fn get_speed(name: &str) -> Option<u64> {
        #[cfg(target_os = "linux")]
        {
            let speed_path = format!("/sys/class/net/{}/speed", name);
            std::fs::read_to_string(&speed_path)
                .ok()
                .and_then(|speed_str| speed_str.trim().parse::<u64>().ok())
        }
        #[cfg(not(target_os = "linux"))]
        {
            warn!("can not get interface {} speed on non-linux platform", name);
            None
        }
    }

    /// Returns the first system interface that has `ip` among its assigned addresses.
    pub fn get_network_interface_by_ip(ip: IpAddr) -> Option<NetworkInterface> {
        datalink::interfaces()
            .into_iter()
            .find(|interface| interface.ips.iter().any(|ip_net| ip_net.ip() == ip))
    }

    /// Converts a byte size to bits, saturating at `u64::MAX` instead of overflowing.
    ///
    /// Saturation keeps a very large ("unlimited") rate limit from panicking in debug builds
    /// or wrapping to a tiny value in release builds.
    pub fn byte_size_to_bits(size: ByteSize) -> u64 {
        size.as_u64().saturating_mul(8)
    }

    /// Converts megabits to bits (1 Mb = 1,000,000 b), saturating at `u64::MAX`.
    pub fn megabits_to_bits(size: u64) -> u64 {
        size.saturating_mul(1_000_000)
    }

    /// Converts bytes to bits, saturating at `u64::MAX`.
    pub fn bytes_to_bits(size: u64) -> u64 {
        size.saturating_mul(8)
    }
}
/// Enables `TCP_FASTOPEN_CONNECT` on the socket `fd`.
///
/// # Errors
///
/// Returns the OS error reported by `setsockopt` (e.g. when `fd` is not a TCP socket or the
/// kernel does not support the option).
#[cfg(target_os = "linux")]
pub fn set_tcp_fastopen_connect(fd: RawFd) -> io::Result<()> {
    set_tcp_int_option(fd, libc::TCP_FASTOPEN_CONNECT, 1)
}

/// Enables `TCP_FASTOPEN` on the socket `fd`, with a pending-request queue length of 1024.
///
/// # Errors
///
/// Returns the OS error reported by `setsockopt` (e.g. when `fd` is not a TCP socket or the
/// kernel does not support the option).
#[cfg(target_os = "linux")]
pub fn set_tcp_fastopen(fd: RawFd) -> io::Result<()> {
    set_tcp_int_option(fd, libc::TCP_FASTOPEN, 1024)
}

/// Sets the `IPPROTO_TCP`-level socket option `option` on `fd` to the integer `value`.
#[cfg(target_os = "linux")]
fn set_tcp_int_option(fd: RawFd, option: libc::c_int, value: libc::c_int) -> io::Result<()> {
    // SAFETY: `optval` points at `value`, a live, properly aligned `c_int` that outlives the
    // call, and `optlen` is exactly its size. The kernel only reads through the pointer. An
    // invalid or non-socket `fd` is reported as an error return, not undefined behavior.
    let ret = unsafe {
        libc::setsockopt(
            fd,
            libc::IPPROTO_TCP,
            option,
            &value as *const libc::c_int as *const libc::c_void,
            mem::size_of_val(&value) as libc::socklen_t,
        )
    };
    if ret != 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent module's imports (`ByteSize`, `IpAddr`, ...).
    use super::*;
    use std::str::FromStr;

    #[test]
    fn test_byte_size_to_bits() {
        let test_cases = [
            (ByteSize::kb(1), 8_000u64),
            (ByteSize::mb(1), 8_000_000u64),
            (ByteSize::gb(1), 8_000_000_000u64),
            (ByteSize::b(0), 0u64),
        ];
        for (input, expected) in test_cases {
            let result = Interface::byte_size_to_bits(input);
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_megabits_to_bits() {
        let test_cases = [
            (1u64, 1_000_000u64),
            (1000u64, 1_000_000_000u64),
            (0u64, 0u64),
        ];
        for (input, expected) in test_cases {
            let result = Interface::megabits_to_bits(input);
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_bytes_to_bits() {
        let test_cases = [(1u64, 8u64), (1000u64, 8_000u64), (0u64, 0u64)];
        for (input, expected) in test_cases {
            let result = Interface::bytes_to_bits(input);
            assert_eq!(result, expected);
        }
    }

    #[test]
    fn test_format_socket_addr_ipv4() {
        assert_eq!(
            format_socket_addr(IpAddr::from_str("127.0.0.1").unwrap(), 80),
            "127.0.0.1:80"
        );
        assert_eq!(
            format_socket_addr(IpAddr::from_str("192.168.1.1").unwrap(), 8080),
            "192.168.1.1:8080"
        );
    }

    #[test]
    fn test_format_socket_addr_ipv6() {
        assert_eq!(
            format_socket_addr(IpAddr::from_str("::1").unwrap(), 80),
            "[::1]:80"
        );
        assert_eq!(
            format_socket_addr(IpAddr::from_str("2001:db8::1").unwrap(), 8080),
            "[2001:db8::1]:8080"
        );
    }

    #[test]
    fn test_format_url_ipv4() {
        assert_eq!(
            format_url("http", IpAddr::from_str("127.0.0.1").unwrap(), 80),
            "http://127.0.0.1:80"
        );
        assert_eq!(
            format_url("https", IpAddr::from_str("192.168.1.1").unwrap(), 443),
            "https://192.168.1.1:443"
        );
    }

    #[test]
    fn test_format_url_ipv6() {
        assert_eq!(
            format_url("http", IpAddr::from_str("::1").unwrap(), 80),
            "http://[::1]:80"
        );
        assert_eq!(
            format_url("https", IpAddr::from_str("2001:db8::1").unwrap(), 443),
            "https://[2001:db8::1]:443"
        );
    }

    #[test]
    fn test_preferred_local_ip() {
        // Hermetic CI sandboxes and network-less containers may have no usable local address,
        // so a `None` result is tolerated rather than failing the suite. Any address that is
        // found must be a concrete one.
        if let Some(ip) = preferred_local_ip() {
            assert!(!ip.is_unspecified(), "unexpected unspecified address: {ip}");
        }
    }
}