use bytesize::ByteSize;
use pnet::datalink::{self, NetworkInterface};
use std::cmp::min;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use sysinfo::Networks;
use tokio::sync::Mutex;
use tracing::{info, warn};
/// Snapshot of an interface's bandwidth limits and measured throughput.
///
/// All values are expressed in bits per second.
#[derive(Clone, Debug, Default)]
pub struct NetworkStats {
    /// Upper bound of the receive bandwidth, in bits per second.
    pub max_rx_bandwidth: u64,

    /// Measured receive rate in bits per second, or `None` when no
    /// measurement is available.
    pub rx_bandwidth: Option<u64>,

    /// Upper bound of the transmit bandwidth, in bits per second.
    pub max_tx_bandwidth: u64,

    /// Measured transmit rate in bits per second, or `None` when no
    /// measurement is available.
    pub tx_bandwidth: Option<u64>,
}
/// Handle to a single network interface used for bandwidth reporting.
///
/// Clones share the same internal lock because it lives behind an `Arc`.
#[derive(Clone, Debug, Default)]
pub struct Network {
    /// Name of the monitored interface; `"unknown"` when no interface could
    /// be matched to the configured IP address.
    interface_name: String,

    /// Maximum bandwidth of the interface, in bits per second.
    bandwidth: u64,

    /// Async lock taken for the duration of a statistics sample so that
    /// concurrent samples do not overlap.
    mutex: Arc<Mutex<()>>,
}
impl Network {
    /// Nominal length of the sampling window used by [`Network::get_stats`].
    const DEFAULT_NETWORK_REFRESH_INTERVAL: Duration = Duration::from_secs(1);

    /// Creates a `Network` for the interface that owns `ip`.
    ///
    /// The resulting bandwidth (bits per second) is:
    /// * `rate_limit` when no interface has `ip` assigned (the interface name
    ///   is then `"unknown"`) or when the link speed cannot be determined;
    /// * otherwise the smaller of the link speed and `rate_limit`.
    pub fn new(ip: IpAddr, rate_limit: ByteSize) -> Network {
        let rate_limit = Self::byte_size_to_bits(rate_limit);

        let Some(interface) = Self::get_network_interface_by_ip(ip) else {
            warn!(
                "can not find interface for IP address {}, network interface unknown with bandwidth {} bps",
                ip, rate_limit
            );
            return Self::with_bandwidth("unknown".to_string(), rate_limit);
        };

        match Self::get_speed(&interface.name) {
            Some(speed) => {
                let bandwidth = min(Self::megabits_to_bits(speed), rate_limit);
                info!(
                    "network interface {} with bandwidth {} bps",
                    interface.name, bandwidth
                );
                Self::with_bandwidth(interface.name, bandwidth)
            }
            None => {
                warn!(
                    "can not get speed, network interface {} with bandwidth {} bps",
                    interface.name, rate_limit
                );
                Self::with_bandwidth(interface.name, rate_limit)
            }
        }
    }

    /// Builds a `Network` with a fresh lock for the given name and bandwidth.
    fn with_bandwidth(interface_name: String, bandwidth: u64) -> Self {
        Self {
            interface_name,
            bandwidth,
            mutex: Arc::new(Mutex::new(())),
        }
    }

    /// Samples the interface counters over roughly one refresh interval and
    /// returns the observed receive/transmit rates in bits per second.
    ///
    /// The internal lock is held for the whole sample, so concurrent callers
    /// are served one after another. When the interface is not present in the
    /// system's network list, only the maximum bandwidths are filled in and
    /// the measured rates are `None`.
    pub async fn get_stats(&self) -> NetworkStats {
        let _guard = self.mutex.lock().await;

        let mut networks = Networks::new_with_refreshed_list();
        let sampling_started = std::time::Instant::now();
        tokio::time::sleep(Self::DEFAULT_NETWORK_REFRESH_INTERVAL).await;
        networks.refresh();

        // The counters cover the real time between the two refreshes, which
        // can exceed the nominal interval when the runtime is busy. Dividing
        // by the measured time avoids overstating the rate; the clamp keeps
        // the divisor from ever dropping below the nominal interval.
        let elapsed = sampling_started
            .elapsed()
            .max(Self::DEFAULT_NETWORK_REFRESH_INTERVAL);

        let Some(network_stats) = networks.get(self.interface_name.as_str()) else {
            warn!(
                "can not find network data for interface {}",
                self.interface_name
            );
            return NetworkStats {
                max_rx_bandwidth: self.bandwidth,
                max_tx_bandwidth: self.bandwidth,
                ..Default::default()
            };
        };

        NetworkStats {
            max_rx_bandwidth: self.bandwidth,
            rx_bandwidth: Some(Self::bits_per_second(network_stats.received(), elapsed)),
            max_tx_bandwidth: self.bandwidth,
            tx_bandwidth: Some(Self::bits_per_second(
                network_stats.transmitted(),
                elapsed,
            )),
        }
    }

    /// Converts a byte count observed over `elapsed` into a rounded rate in
    /// bits per second.
    fn bits_per_second(bytes: u64, elapsed: Duration) -> u64 {
        (Self::bytes_to_bits(bytes) as f64 / elapsed.as_secs_f64()).round() as u64
    }

    /// Returns the link speed of interface `name` in megabits per second.
    ///
    /// On Linux the value is read from `/sys/class/net/<name>/speed`; `None`
    /// is returned when the file cannot be read or does not contain an
    /// unsigned integer. On other platforms this always returns `None`.
    pub fn get_speed(name: &str) -> Option<u64> {
        #[cfg(target_os = "linux")]
        {
            let speed_path = format!("/sys/class/net/{}/speed", name);
            std::fs::read_to_string(&speed_path)
                .ok()
                .and_then(|speed_str| speed_str.trim().parse::<u64>().ok())
        }
        #[cfg(not(target_os = "linux"))]
        {
            warn!("can not get interface {} speed on non-linux platform", name);
            None
        }
    }

    /// Finds the first network interface that has `ip` assigned to it.
    pub fn get_network_interface_by_ip(ip: IpAddr) -> Option<NetworkInterface> {
        datalink::interfaces()
            .into_iter()
            .find(|interface| interface.ips.iter().any(|ip_net| ip_net.ip() == ip))
    }

    /// Converts a byte size into bits, saturating at `u64::MAX` instead of
    /// overflowing for very large sizes.
    pub fn byte_size_to_bits(size: ByteSize) -> u64 {
        size.as_u64().saturating_mul(8)
    }

    /// Converts megabits into bits, saturating at `u64::MAX` instead of
    /// overflowing.
    pub fn megabits_to_bits(size: u64) -> u64 {
        size.saturating_mul(1_000_000)
    }

    /// Converts bytes into bits, saturating at `u64::MAX` instead of
    /// overflowing.
    pub fn bytes_to_bits(size: u64) -> u64 {
        size.saturating_mul(8)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytesize::ByteSize;

    /// Byte sizes are converted to bits by multiplying by eight.
    #[test]
    fn test_byte_size_to_bits() {
        let cases = [
            (ByteSize::kb(1), 8_000u64),
            (ByteSize::mb(1), 8_000_000u64),
            (ByteSize::gb(1), 8_000_000_000u64),
            (ByteSize::b(0), 0u64),
        ];

        for &(size, want) in &cases {
            let got = Network::byte_size_to_bits(size);
            assert_eq!(got, want);
        }
    }

    /// Megabits are converted to bits using the decimal factor 1_000_000.
    #[test]
    fn test_megabits_to_bits() {
        let cases = [
            (1u64, 1_000_000u64),
            (1000u64, 1_000_000_000u64),
            (0u64, 0u64),
        ];

        for &(megabits, want) in &cases {
            let got = Network::megabits_to_bits(megabits);
            assert_eq!(got, want);
        }
    }

    /// Raw byte counts are converted to bits by multiplying by eight.
    #[test]
    fn test_bytes_to_bits() {
        let cases = [(1u64, 8u64), (1000u64, 8_000u64), (0u64, 0u64)];

        for &(bytes, want) in &cases {
            let got = Network::bytes_to_bits(bytes);
            assert_eq!(got, want);
        }
    }
}