pub mod helpers;
mod query;
mod tick;
use std::{collections::HashMap, net::Ipv4Addr, time::Instant};
use crate::constants::{buffer, network_values::PROTOCOL_TCP};
use crate::stats::{
update_pairs_stats_buffer, Direction, IpPair, PairStatMap, QualityTracker, SessionStats, Speed,
StatKey, StatsMap, TcpStateTracker, TimedSpeed,
};
use ringbuf::traits::{Consumer, RingBuffer};
use ringbuf::HeapRb;
/// Aggregates traffic statistics into fixed-capacity ring buffers and keeps
/// the session-wide trackers (totals, TCP state, quality, port bookkeeping).
pub struct StatsAggregator {
/// Ring buffer of `Vec<u128>` samples, sized by the window passed to the constructor.
/// NOTE(review): not touched by the methods next to this struct; presumably
/// filled by the `tick`/`query` submodules — confirm.
speed_buffer: HeapRb<Vec<u128>>,
/// Ring buffer of stat keys; its capacity is
/// `buffer::DEFAULT_STATS_KEYS_BUFFER_SIZE` rather than the window size.
stat_keys_buffer: HeapRb<StatKey>,
/// Windowed ring buffer of stats maps. `update_tcp_states` reads the last
/// entry its iterator yields, and it is the input of `update_pairs_stats_buffer`.
stats_buffer: HeapRb<StatsMap>,
/// Windowed per-pair statistics, updated from `stats_buffer`.
pairs_buffer: HeapRb<PairStatMap>,
/// Windowed per-address speed maps, rebuilt from `pairs_buffer`.
hosts_buffer: HeapRb<HashMap<Ipv4Addr, Speed>>,
/// Windowed total speed (sum over all hosts), rebuilt from `hosts_buffer`.
total_speed_buffer: HeapRb<TimedSpeed>,
/// Starts as `None`.
/// NOTE(review): presumably the instant of the previous tick — confirm in `tick`.
last_tick_time: Option<Instant>,
/// Elapsed time handed to the pair/total speed computations; starts at `1.0`.
last_elapsed_secs: f64,
/// Session-level statistics; publicly readable, starts at its `Default` value.
pub session_stats: SessionStats,
/// Tracks TCP state per `IpPair`; the keys of its `states` decide which
/// `connection_ports` entries are kept.
tcp_state_tracker: TcpStateTracker,
/// Receives timestamp / sequence / acknowledgement updates per `IpPair`.
quality_tracker: QualityTracker,
/// First `(src_port, dst_port, direction)` recorded for each `IpPair`.
connection_ports: HashMap<IpPair, (u16, u16, Direction)>,
}
impl StatsAggregator {
/// Builds an aggregator whose windowed buffers use
/// `buffer::DEFAULT_STATS_WINDOW_SIZE` as their capacity.
fn new() -> Self {
    let window = buffer::DEFAULT_STATS_WINDOW_SIZE;
    Self::new_with_window_size(window)
}
fn new_with_window_size(window: usize) -> Self {
Self {
speed_buffer: HeapRb::new(window),
stat_keys_buffer: HeapRb::new(buffer::DEFAULT_STATS_KEYS_BUFFER_SIZE),
stats_buffer: HeapRb::new(window),
pairs_buffer: HeapRb::new(window),
hosts_buffer: HeapRb::new(window),
total_speed_buffer: HeapRb::new(window),
last_tick_time: None,
last_elapsed_secs: 1.0,
session_stats: Default::default(),
tcp_state_tracker: TcpStateTracker::new(),
quality_tracker: QualityTracker::new(),
connection_ports: HashMap::new(),
}
}
/// Delegates to the free function `update_pairs_stats_buffer` imported from
/// `crate::stats`, passing the stats buffer, the pairs buffer to update and
/// the current `last_elapsed_secs`.
fn update_pairs_stats_buffer(&mut self) {
    let elapsed_secs = self.last_elapsed_secs;
    // Inside the impl the bare name resolves to the imported free function,
    // not to this method.
    update_pairs_stats_buffer(&self.stats_buffer, &mut self.pairs_buffer, elapsed_secs);
}
/// Rebuilds `hosts_buffer` from scratch, pushing one per-address speed map
/// for every entry currently held in `pairs_buffer`.
///
/// Pairs flagged `is_local` are skipped; each remaining pair's speed is
/// accumulated under the pair's `src_ip`.
fn update_hosts_stats_buffer(&mut self) {
    self.hosts_buffer.clear();

    for pair_map in self.pairs_buffer.iter() {
        let mut per_host: HashMap<Ipv4Addr, Speed> = HashMap::new();

        for (pair, timed_speed) in pair_map.iter() {
            if pair.is_local {
                continue;
            }
            per_host
                .entry(pair.src_ip)
                .and_modify(|total| *total += timed_speed.speed)
                .or_insert(timed_speed.speed);
        }

        self.hosts_buffer.push_overwrite(per_host);
    }
}
/// Rebuilds `total_speed_buffer` from scratch: for every per-address map in
/// `hosts_buffer`, sums all speeds and pushes the sum as a `TimedSpeed`
/// built with the current `last_elapsed_secs`.
fn update_total_speed(&mut self) {
    self.total_speed_buffer.clear();

    for host_speeds in self.hosts_buffer.iter() {
        let mut total: Speed = Default::default();
        for speed in host_speeds.values() {
            total += *speed;
        }

        let sample = TimedSpeed::new(total, self.last_elapsed_secs);
        self.total_speed_buffer.push_overwrite(sample);
    }
}
fn update_tcp_states(&mut self) {
self.tcp_state_tracker.prune_stale();
let active_pairs: Vec<IpPair> = self.tcp_state_tracker.states.keys().copied().collect();
let mut stale_ports = Vec::new();
for pair in self.connection_ports.keys() {
if !active_pairs.contains(pair) {
stale_ports.push(*pair);
}
}
for pair in stale_ports {
self.connection_ports.remove(&pair);
}
if let Some(latest_stats) = self.stats_buffer.iter().last() {
for (key, value) in latest_stats.iter() {
let (mut src, mut dst) = (key.src_ip, key.dst_ip);
let is_local = key.direction == Direction::Local;
if Direction::Incoming == key.direction || (is_local && src > dst) {
(src, dst) = (dst, src);
}
let pair = IpPair {
src_ip: src,
dst_ip: dst,
is_local,
protocol: key.protocol,
};
self.connection_ports.entry(pair).or_insert((
key.src_port,
key.dst_port,
key.direction,
));
if key.protocol == PROTOCOL_TCP {
self.tcp_state_tracker.update(
pair,
key.tcp_syn,
key.tcp_ack,
key.tcp_fin,
key.tcp_rst,
);
}
if let Some(timestamp) = value.last_timestamp {
self.quality_tracker.update(
pair,
timestamp,
value.last_seq.unwrap_or(0),
value.last_ack.unwrap_or(0),
);
}
}
}
}
/// Asks the quality tracker to prune its stale entries.
fn update_quality_metrics(&mut self) {
    let tracker = &mut self.quality_tracker;
    tracker.prune_stale();
}
}
impl Default for StatsAggregator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stats::StatValues;

    // Shorthand for building an `Ipv4Addr` from four octets.
    macro_rules! ip {
        ($a:expr, $b:expr, $c:expr, $d:expr) => {
            std::net::Ipv4Addr::new($a, $b, $c, $d)
        };
    }

    /// A default-constructed aggregator reports an empty speed string.
    #[test]
    fn test_aggregator_default() {
        let aggregator = StatsAggregator::default();
        assert!(aggregator.speed_str().is_empty());
    }

    /// Ticking with an empty stats map must not panic.
    #[test]
    fn test_aggregator_tick_empty() {
        let mut aggregator = StatsAggregator::default();
        aggregator.tick(HashMap::new());
    }

    /// A single outgoing entry of size 1000 is counted as upload only.
    #[test]
    fn test_session_stats_tracking() {
        let mut aggregator = StatsAggregator::default();

        let outgoing_key = StatKey {
            src_ip: ip!(192, 168, 1, 1),
            dst_ip: ip!(93, 184, 216, 34),
            src_port: 12345,
            dst_port: 443,
            direction: Direction::Outgoing,
            // 6 is the IANA protocol number for TCP.
            protocol: 6,
            tcp_syn: false,
            tcp_ack: false,
            tcp_fin: false,
            tcp_rst: false,
        };
        let values = StatValues {
            size: 1000,
            last_timestamp: None,
            last_seq: None,
            last_ack: None,
        };

        let mut tick_stats = HashMap::new();
        tick_stats.insert(outgoing_key, values);
        aggregator.tick(tick_stats);

        assert_eq!(aggregator.session_stats.total_bits_up, 1000);
        assert_eq!(aggregator.session_stats.total_bits_down, 0);
    }
}