use crate::stats::{Direction, Speed, StatKey, StatValues};
use ringbuf::traits::Consumer;
use ringbuf::HeapRb;
use std::collections::HashMap;
/// Traffic speed accumulated for a single port number.
///
/// Produced by [`top_ports_per_host`] and [`top_ports_all`], which sum the
/// per-flow speeds of every flow attributed to `port`.
#[derive(Clone, Debug, Default)]
pub struct PortStats {
    /// Port number the traffic was attributed to.
    pub port: u16,
    /// Summed input/output speed for `port`.
    pub speed: Speed,
}
/// Returns up to `n` ports carrying the most traffic that involves `host_ip`.
///
/// Every snapshot in `stats_buffer` is scanned. Only flows whose source or
/// destination address equals `host_ip` are counted.
///
/// Each counted flow is attributed to the port on the *other* side of the
/// conversation:
/// - `src_port` when `host_ip` is the destination;
/// - otherwise `dst_port`.
///
/// Speeds for the same port are summed across all flows and all snapshots.
///
/// The result is ordered by `speed.input + speed.output`, highest first. Ties
/// are broken by ascending port number, so the output does not depend on the
/// (unspecified) iteration order of the intermediate `HashMap`.
pub fn top_ports_per_host(
    stats_buffer: &HeapRb<HashMap<StatKey, StatValues>>,
    host_ip: std::net::Ipv4Addr,
    n: usize,
) -> Vec<PortStats> {
    let mut port_map: HashMap<u16, Speed> = HashMap::new();
    for stats in stats_buffer.iter() {
        for (key, value) in stats.iter() {
            // Attribute the flow to the peer's port; skip flows not touching host_ip.
            // When both ends are host_ip, the destination check wins (src_port).
            let port = if key.dst_ip == host_ip {
                key.src_port
            } else if key.src_ip == host_ip {
                key.dst_port
            } else {
                continue;
            };
            let mut speed = Speed::default();
            match key.direction {
                Direction::Incoming => speed.input = value.size,
                Direction::Outgoing | Direction::Internet => speed.output = value.size,
                // Local traffic is counted toward both input and output.
                Direction::Local => {
                    speed.input += value.size;
                    speed.output += value.size;
                }
                Direction::None => {}
            }
            port_map
                .entry(port)
                .and_modify(|s| *s += speed)
                .or_insert(speed);
        }
    }
    let total = |p: &PortStats| p.speed.input + p.speed.output;
    let mut result: Vec<PortStats> = port_map
        .into_iter()
        .map(|(port, speed)| PortStats { port, speed })
        .collect();
    // Ports are unique map keys, so (total desc, port asc) is a strict total
    // order and the unstable sort yields a deterministic ranking.
    result.sort_unstable_by(|a, b| total(b).cmp(&total(a)).then_with(|| a.port.cmp(&b.port)));
    result.truncate(n);
    result
}
/// Returns up to `n` ports carrying the most traffic across all hosts.
///
/// Every flow in every snapshot of `stats_buffer` is attributed to its
/// `dst_port`, and speeds for the same port are summed.
///
/// The result is ordered by `speed.input + speed.output`, highest first. Ties
/// are broken by ascending port number, so the output does not depend on the
/// (unspecified) iteration order of the intermediate `HashMap`.
pub fn top_ports_all(
    stats_buffer: &HeapRb<HashMap<StatKey, StatValues>>,
    n: usize,
) -> Vec<PortStats> {
    let mut port_map: HashMap<u16, Speed> = HashMap::new();
    for stats in stats_buffer.iter() {
        for (key, value) in stats.iter() {
            let mut speed = Speed::default();
            match key.direction {
                Direction::Incoming => speed.input = value.size,
                Direction::Outgoing | Direction::Internet => speed.output = value.size,
                // Local traffic is counted toward both input and output.
                Direction::Local => {
                    speed.input += value.size;
                    speed.output += value.size;
                }
                Direction::None => {}
            }
            // NOTE(review): keyed by dst_port for every direction. For incoming
            // flows this may be a local/ephemeral port rather than the service
            // port that top_ports_per_host would report — confirm intended.
            port_map
                .entry(key.dst_port)
                .and_modify(|s| *s += speed)
                .or_insert(speed);
        }
    }
    let total = |p: &PortStats| p.speed.input + p.speed.output;
    let mut result: Vec<PortStats> = port_map
        .into_iter()
        .map(|(port, speed)| PortStats { port, speed })
        .collect();
    // Ports are unique map keys, so (total desc, port asc) is a strict total
    // order and the unstable sort yields a deterministic ranking.
    result.sort_unstable_by(|a, b| total(b).cmp(&total(a)).then_with(|| a.port.cmp(&b.port)));
    result.truncate(n);
    result
}
pub fn format_port_stats(port_stats: &PortStats) -> String {
format!("{}: {}", port_stats.port, port_stats.speed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ringbuf::HeapRb;

    /// Fresh, empty ring buffer with room for two snapshots.
    fn empty_buffer() -> HeapRb<HashMap<StatKey, StatValues>> {
        HeapRb::new(2)
    }

    #[test]
    fn test_top_ports_per_host_empty() {
        let host = std::net::Ipv4Addr::new(192, 168, 1, 1);
        let ranked = top_ports_per_host(&empty_buffer(), host, 10);
        assert!(ranked.is_empty());
    }

    #[test]
    fn test_top_ports_all_empty() {
        let ranked = top_ports_all(&empty_buffer(), 10);
        assert!(ranked.is_empty());
    }

    #[test]
    fn test_format_port_stats() {
        let rendered = format_port_stats(&PortStats {
            port: 443,
            speed: Speed {
                input: 16_777_216,
                output: 8_388_608,
            },
        });
        for needle in ["443:", "2.00 MiB/s", "1.00 MiB/s"] {
            assert!(
                rendered.contains(needle),
                "{needle:?} not found in {rendered:?}"
            );
        }
    }

    #[test]
    fn test_port_stats_default() {
        let PortStats { port, speed } = PortStats::default();
        assert_eq!(port, 0);
        assert_eq!((speed.input, speed.output), (0, 0));
    }
}