use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use crate::data::flow::{Direction, FlowKey};
use crate::data::history::FlowHistory;
/// Thread-safe handle to shared per-flow traffic statistics.
///
/// Cloning is cheap: it only clones the inner `Arc`, so every clone reads and
/// writes the same state, which is guarded by a single `Mutex`.
#[derive(Clone)]
pub struct FlowTracker {
    /// Shared, lock-protected tracker state.
    inner: Arc<Mutex<FlowTrackerInner>>,
}
/// Mutable state shared by every clone of a [`FlowTracker`].
///
/// The struct is private to this module, so its fields are private as well;
/// they are only ever accessed through the tracker's mutex guard.
struct FlowTrackerInner {
    /// Per-flow history, keyed by flow identity.
    flows: HashMap<FlowKey, FlowHistory>,
    /// When the last rotation interval was closed (construction time initially).
    last_rotation: Instant,
    /// Cumulative bytes sent across all flows; not reset by rotation.
    total_sent: u64,
    /// Cumulative bytes received across all flows; not reset by rotation.
    total_recv: u64,
    /// Largest per-interval sent byte count seen at a rotation.
    peak_sent: f64,
    /// Largest per-interval received byte count seen at a rotation.
    peak_recv: f64,
    /// Bytes sent since the last rotation; zeroed when an interval is closed.
    current_sent: u64,
    /// Bytes received since the last rotation; zeroed when an interval is closed.
    current_recv: u64,
}
/// Point-in-time view of a single flow, as returned by [`FlowTracker::snapshot`].
#[derive(Debug, Clone, PartialEq)]
pub struct FlowSnapshot {
    /// Identity of the flow.
    pub key: FlowKey,
    /// The flow history's `avg_sent_2s()` (presumably bytes/s over ~2 s — confirm in `FlowHistory`).
    pub sent_2s: f64,
    /// The flow history's `avg_sent_10s()`.
    pub sent_10s: f64,
    /// The flow history's `avg_sent_40s()`.
    pub sent_40s: f64,
    /// The flow history's `avg_recv_2s()`.
    pub recv_2s: f64,
    /// The flow history's `avg_recv_10s()`.
    pub recv_10s: f64,
    /// The flow history's `avg_recv_40s()`.
    pub recv_40s: f64,
    /// Total bytes sent on this flow, as tracked by its history.
    pub total_sent: u64,
    /// Total bytes received on this flow, as tracked by its history.
    pub total_recv: u64,
    /// Owning process name, if it has been attached via `set_process_info`.
    pub process_name: Option<String>,
    /// Owning process id, if it has been attached via `set_process_info`.
    pub pid: Option<u32>,
    /// Element-wise sum of the flow's sent and received history samples.
    pub history: Vec<u64>,
}
/// Aggregate statistics over all tracked flows, as returned by `FlowTracker::snapshot`.
///
/// `Default` yields all-zero statistics, which is the state of a fresh tracker.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TotalStats {
    /// Sum of every flow's `sent_2s` value.
    pub sent_2s: f64,
    /// Sum of every flow's `sent_10s` value.
    pub sent_10s: f64,
    /// Sum of every flow's `sent_40s` value.
    pub sent_40s: f64,
    /// Sum of every flow's `recv_2s` value.
    pub recv_2s: f64,
    /// Sum of every flow's `recv_10s` value.
    pub recv_10s: f64,
    /// Sum of every flow's `recv_40s` value.
    pub recv_40s: f64,
    /// Bytes sent across all flows since the tracker was created.
    pub cumulative_sent: u64,
    /// Bytes received across all flows since the tracker was created.
    pub cumulative_recv: u64,
    /// Largest sent byte count observed in a single rotation interval.
    pub peak_sent: f64,
    /// Largest received byte count observed in a single rotation interval.
    pub peak_recv: f64,
}
impl Default for FlowTracker {
fn default() -> Self {
Self::new()
}
}
impl FlowTracker {
pub fn new() -> Self {
FlowTracker {
inner: Arc::new(Mutex::new(FlowTrackerInner {
flows: HashMap::new(),
last_rotation: Instant::now(),
total_sent: 0,
total_recv: 0,
peak_sent: 0.0,
peak_recv: 0.0,
current_sent: 0,
current_recv: 0,
})),
}
}
pub fn record(&self, key: FlowKey, direction: Direction, bytes: u64) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
let history = inner.flows.entry(key).or_default();
match direction {
Direction::Sent => {
history.add_sent(bytes);
inner.total_sent += bytes;
inner.current_sent += bytes;
}
Direction::Received => {
history.add_recv(bytes);
inner.total_recv += bytes;
inner.current_recv += bytes;
}
}
}
pub fn set_process_info(&self, key: &FlowKey, pid: u32, name: String) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
if let Some(history) = inner.flows.get_mut(key) {
history.pid = Some(pid);
history.process_name = Some(name);
}
}
pub fn maybe_rotate(&self) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
let elapsed = inner.last_rotation.elapsed();
if elapsed.as_secs() >= 1 {
let sent_rate = inner.current_sent as f64;
let recv_rate = inner.current_recv as f64;
if sent_rate > inner.peak_sent {
inner.peak_sent = sent_rate;
}
if recv_rate > inner.peak_recv {
inner.peak_recv = recv_rate;
}
inner.current_sent = 0;
inner.current_recv = 0;
for history in inner.flows.values_mut() {
history.rotate();
}
inner.last_rotation = Instant::now();
let now = Instant::now();
inner
.flows
.retain(|_, h| now.duration_since(h.last_seen).as_secs() < 60);
}
}
pub fn snapshot(&self) -> (Vec<FlowSnapshot>, TotalStats) {
let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
let snapshots: Vec<FlowSnapshot> = inner
.flows
.iter()
.map(|(key, h)| {
let history: Vec<u64> = h
.sent
.iter()
.zip(h.recv.iter())
.map(|(&s, &r)| s + r)
.collect();
FlowSnapshot {
key: *key,
sent_2s: h.avg_sent_2s(),
sent_10s: h.avg_sent_10s(),
sent_40s: h.avg_sent_40s(),
recv_2s: h.avg_recv_2s(),
recv_10s: h.avg_recv_10s(),
recv_40s: h.avg_recv_40s(),
total_sent: h.total_sent,
total_recv: h.total_recv,
process_name: h.process_name.clone(),
pid: h.pid,
history,
}
})
.collect();
let totals = TotalStats {
sent_2s: snapshots.iter().map(|f| f.sent_2s).sum(),
sent_10s: snapshots.iter().map(|f| f.sent_10s).sum(),
sent_40s: snapshots.iter().map(|f| f.sent_40s).sum(),
recv_2s: snapshots.iter().map(|f| f.recv_2s).sum(),
recv_10s: snapshots.iter().map(|f| f.recv_10s).sum(),
recv_40s: snapshots.iter().map(|f| f.recv_40s).sum(),
cumulative_sent: inner.total_sent,
cumulative_recv: inner.total_recv,
peak_sent: inner.peak_sent,
peak_recv: inner.peak_recv,
};
(snapshots, totals)
}
pub fn flow_keys(&self) -> Vec<FlowKey> {
let inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.flows.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::flow::Protocol;
    use std::time::{Duration, Instant};

    /// Builds a TCP flow key `10.0.0.1:port -> 10.0.0.2:80`.
    fn test_key(port: u16) -> FlowKey {
        FlowKey {
            src: "10.0.0.1".parse().unwrap(),
            dst: "10.0.0.2".parse().unwrap(),
            src_port: port,
            dst_port: 80,
            protocol: Protocol::Tcp,
        }
    }

    /// Returns an instant `secs` seconds in the past.
    fn secs_ago(secs: u64) -> Instant {
        Instant::now()
            .checked_sub(Duration::from_secs(secs))
            .expect("monotonic clock must be older than the requested offset")
    }

    /// Backdates the rotation mark so the next `maybe_rotate` call is due.
    fn make_rotation_due(t: &FlowTracker) {
        let mut inner = t.inner.lock().unwrap_or_else(|e| e.into_inner());
        inner.last_rotation = secs_ago(2);
    }

    #[test]
    fn new_tracker_empty_snapshot() {
        let t = FlowTracker::new();
        let (flows, totals) = t.snapshot();
        assert!(flows.is_empty());
        assert_eq!(totals.cumulative_sent, 0);
        assert_eq!(totals.cumulative_recv, 0);
    }

    #[test]
    fn record_sent_packet() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Sent, 1500);
        let (flows, totals) = t.snapshot();
        assert_eq!(flows.len(), 1);
        assert_eq!(totals.cumulative_sent, 1500);
        assert_eq!(totals.cumulative_recv, 0);
    }

    #[test]
    fn record_recv_packet() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Received, 500);
        let (_, totals) = t.snapshot();
        assert_eq!(totals.cumulative_recv, 500);
    }

    #[test]
    fn multiple_flows() {
        let t = FlowTracker::new();
        t.record(test_key(5000), Direction::Sent, 100);
        t.record(test_key(5001), Direction::Sent, 200);
        t.record(test_key(5002), Direction::Sent, 300);
        let (flows, totals) = t.snapshot();
        assert_eq!(flows.len(), 3);
        assert_eq!(totals.cumulative_sent, 600);
    }

    #[test]
    fn flow_keys_returns_all() {
        let t = FlowTracker::new();
        t.record(test_key(1), Direction::Sent, 10);
        t.record(test_key(2), Direction::Sent, 20);
        assert_eq!(t.flow_keys().len(), 2);
    }

    #[test]
    fn set_process_info() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Sent, 100);
        t.set_process_info(&key, 1234, "curl".to_string());
        let (flows, _) = t.snapshot();
        assert_eq!(flows[0].pid, Some(1234));
        assert_eq!(flows[0].process_name.as_deref(), Some("curl"));
    }

    #[test]
    fn set_process_info_nonexistent_key_no_panic() {
        let t = FlowTracker::new();
        let key = test_key(9999);
        t.set_process_info(&key, 1234, "ghost".to_string());
        let (flows, _) = t.snapshot();
        assert!(flows.is_empty());
    }

    #[test]
    fn default_trait() {
        let t = FlowTracker::default();
        let (flows, totals) = t.snapshot();
        assert!(flows.is_empty());
        assert_eq!(totals.cumulative_sent, 0);
    }

    #[test]
    fn record_both_directions() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Sent, 100);
        t.record(key, Direction::Received, 200);
        let (flows, totals) = t.snapshot();
        assert_eq!(flows.len(), 1);
        assert_eq!(totals.cumulative_sent, 100);
        assert_eq!(totals.cumulative_recv, 200);
    }

    #[test]
    fn record_same_flow_accumulates() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Sent, 100);
        t.record(key, Direction::Sent, 200);
        t.record(key, Direction::Sent, 300);
        let (_, totals) = t.snapshot();
        assert_eq!(totals.cumulative_sent, 600);
    }

    #[test]
    fn flow_keys_empty_tracker() {
        let t = FlowTracker::new();
        assert!(t.flow_keys().is_empty());
    }

    #[test]
    fn snapshot_totals_sum_flow_rates() {
        let t = FlowTracker::new();
        t.record(test_key(1), Direction::Sent, 100);
        t.record(test_key(2), Direction::Sent, 200);
        let (_, totals) = t.snapshot();
        assert_eq!(totals.cumulative_sent, 300);
        assert_eq!(totals.cumulative_recv, 0);
    }

    #[test]
    fn maybe_rotate_does_not_panic_empty() {
        let t = FlowTracker::new();
        t.maybe_rotate();
    }

    #[test]
    fn maybe_rotate_is_noop_before_interval() {
        let t = FlowTracker::new();
        t.record(test_key(1), Direction::Sent, 5000);
        {
            let mut inner = t.inner.lock().unwrap_or_else(|e| e.into_inner());
            inner.last_rotation = Instant::now();
        }
        t.maybe_rotate();
        let (_, totals) = t.snapshot();
        assert_eq!(totals.peak_sent, 0.0);
    }

    #[test]
    fn clone_shares_state() {
        let t = FlowTracker::new();
        let t2 = t.clone();
        t.record(test_key(5000), Direction::Sent, 100);
        let (flows, _) = t2.snapshot();
        assert_eq!(flows.len(), 1);
    }

    #[test]
    fn concurrent_access_no_panic() {
        let t = FlowTracker::new();
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let t = t.clone();
                std::thread::spawn(move || {
                    for j in 0..100 {
                        t.record(test_key(i * 100 + j), Direction::Sent, 10);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        let (flows, totals) = t.snapshot();
        assert_eq!(flows.len(), 1000);
        assert_eq!(totals.cumulative_sent, 10_000);
    }

    #[test]
    fn mutex_poison_recovery() {
        let t = FlowTracker::new();
        let t2 = t.clone();
        let h = std::thread::spawn(move || {
            let _inner = t2.inner.lock().unwrap();
            panic!("intentional poison");
        });
        let _ = h.join();
        t.record(test_key(1), Direction::Sent, 42);
        let (flows, _) = t.snapshot();
        assert_eq!(flows.len(), 1);
        assert_eq!(t.flow_keys().len(), 1);
        t.set_process_info(&test_key(1), 99, "recovered".into());
        t.maybe_rotate();
    }

    #[test]
    fn peak_tracking_works() {
        let t = FlowTracker::new();
        t.record(test_key(1), Direction::Sent, 5000);
        t.record(test_key(1), Direction::Received, 3000);
        make_rotation_due(&t);
        t.maybe_rotate();
        let (_, totals) = t.snapshot();
        assert!(totals.peak_sent >= 5000.0);
        assert!(totals.peak_recv >= 3000.0);
    }

    #[test]
    fn peak_survives_quieter_interval() {
        let t = FlowTracker::new();
        let key = test_key(1);
        t.record(key, Direction::Sent, 5000);
        make_rotation_due(&t);
        t.maybe_rotate();
        t.record(key, Direction::Sent, 100);
        make_rotation_due(&t);
        t.maybe_rotate();
        let (_, totals) = t.snapshot();
        assert_eq!(totals.peak_sent, 5000.0);
        assert_eq!(totals.cumulative_sent, 5100);
    }

    #[test]
    fn maybe_rotate_evicts_idle_flows() {
        let t = FlowTracker::new();
        t.record(test_key(1), Direction::Sent, 10);
        t.record(test_key(2), Direction::Sent, 10);
        {
            let mut inner = t.inner.lock().unwrap_or_else(|e| e.into_inner());
            inner
                .flows
                .get_mut(&test_key(1))
                .expect("flow was just recorded")
                .last_seen = secs_ago(61);
        }
        make_rotation_due(&t);
        t.maybe_rotate();
        assert_eq!(t.flow_keys(), vec![test_key(2)]);
    }

    #[test]
    fn process_info_overwrites() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Sent, 100);
        t.set_process_info(&key, 1, "old".into());
        t.set_process_info(&key, 2, "new".into());
        let (flows, _) = t.snapshot();
        assert_eq!(flows[0].pid, Some(2));
        assert_eq!(flows[0].process_name.as_deref(), Some("new"));
    }

    #[test]
    fn snapshot_includes_total_sent_recv_per_flow() {
        let t = FlowTracker::new();
        let key = test_key(5000);
        t.record(key, Direction::Sent, 100);
        t.record(key, Direction::Received, 50);
        let (flows, _) = t.snapshot();
        assert_eq!(flows[0].total_sent, 100);
        assert_eq!(flows[0].total_recv, 50);
    }
}