use crate::stats::tcp::TcpState;
use crate::stats::{IpPair, Speed};
use std::collections::HashMap;
/// Remembers the latest [`TcpState`] observed for each [`IpPair`], together
/// with the timestamps needed to report a connection's age and to discard
/// entries that have not been updated for longer than `timeout`.
///
/// All three maps are written together by `update` and cleared together by
/// `prune_stale`.
pub struct TcpStateTracker {
/// State produced by the most recent `update` call for each pair.
pub(crate) states: HashMap<IpPair, TcpState>,
/// Instant of the most recent `update` call for each pair; `prune_stale`
/// compares it against `timeout`.
last_update: HashMap<IpPair, std::time::Instant>,
/// Instant of the first `update` call for each pair since it was last
/// pruned; `connection_age` is measured from it.
first_seen: HashMap<IpPair, std::time::Instant>,
/// How long a pair may go without an `update` before `prune_stale` drops it.
timeout: std::time::Duration,
}
/// `Default` yields the same tracker as [`TcpStateTracker::new`].
impl Default for TcpStateTracker {
/// Delegates to `new()` so both construction paths share one configuration.
fn default() -> Self {
Self::new()
}
}
impl TcpStateTracker {
    /// Idle timeout used by [`TcpStateTracker::new`]: 300 seconds.
    const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);

    /// Creates an empty tracker with the default idle timeout of 300 seconds.
    pub fn new() -> Self {
        // Single construction path: `with_timeout` owns the struct literal so
        // the two constructors cannot drift apart.
        Self::with_timeout(Self::DEFAULT_TIMEOUT)
    }

    /// Creates an empty tracker whose entries are dropped by
    /// [`prune_stale`](Self::prune_stale) once they have gone without an
    /// [`update`](Self::update) for longer than `timeout`.
    pub fn with_timeout(timeout: std::time::Duration) -> Self {
        Self {
            states: HashMap::new(),
            last_update: HashMap::new(),
            first_seen: HashMap::new(),
            timeout,
        }
    }

    /// Feeds the flags of one observed segment into the state machine for `pair`.
    ///
    /// A pair that is not tracked yet starts from [`TcpState::Closed`]. The
    /// resulting state is stored, the pair's last-update time is refreshed, and
    /// its first-seen time is recorded if this is the first update for it.
    pub fn update(&mut self, pair: IpPair, syn: bool, ack: bool, fin: bool, rst: bool) {
        let previous = self.states.get(&pair).copied().unwrap_or(TcpState::Closed);
        let next = previous.transition(syn, ack, fin, rst);

        // One timestamp for both maps keeps first-seen and last-update equal
        // on a pair's very first update.
        let now = std::time::Instant::now();
        self.first_seen.entry(pair).or_insert(now);
        self.states.insert(pair, next);
        self.last_update.insert(pair, now);
    }

    /// Returns the currently stored state for `pair`, or `None` if it is not tracked.
    pub fn get(&self, pair: &IpPair) -> Option<TcpState> {
        self.states.get(pair).copied()
    }

    /// Joins the tracked states with `speeds`.
    ///
    /// Only pairs present in both the tracker and `speeds` are returned. The
    /// order of the result follows `HashMap` iteration order and is therefore
    /// unspecified.
    pub fn connections_with_state(
        &self,
        speeds: &HashMap<IpPair, Speed>,
    ) -> Vec<(IpPair, TcpState, Speed)> {
        self.states
            .iter()
            .filter_map(|(pair, state)| speeds.get(pair).map(|speed| (*pair, *state, *speed)))
            .collect()
    }

    /// Returns how long ago `pair` was first seen, or `None` if it is not tracked.
    pub fn connection_age(&self, pair: &IpPair) -> Option<std::time::Duration> {
        self.first_seen.get(pair).map(std::time::Instant::elapsed)
    }

    /// Removes every pair whose last update is strictly older than the
    /// configured timeout from all internal maps.
    pub fn prune_stale(&mut self) {
        let now = std::time::Instant::now();

        // Borrow the fields separately so the `retain` closure can mutate the
        // sibling maps while `last_update` itself is being filtered.
        let Self {
            states,
            last_update,
            first_seen,
            timeout,
        } = self;
        let timeout = *timeout;

        // Single pass, no temporary key list: `last_update` decides which
        // pairs are stale and the other two maps are cleaned alongside it.
        last_update.retain(|pair, last| {
            let stale = now.duration_since(*last) > timeout;
            if stale {
                states.remove(pair);
                first_seen.remove(pair);
            }
            !stale
        });
    }

    /// Number of pairs that currently have a stored state.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// Returns `true` when no pair has a stored state.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::time::Duration;

    /// Builds a non-local pair `192.168.a.a -> 192.168.b.b` with protocol `6`.
    fn ip_pair(a: u8, b: u8) -> IpPair {
        let src_ip = Ipv4Addr::new(192, 168, a, a);
        let dst_ip = Ipv4Addr::new(192, 168, b, b);
        IpPair {
            src_ip,
            dst_ip,
            is_local: false,
            protocol: 6,
        }
    }

    /// A SYN followed by an ACK must move a fresh pair through `SynSent`
    /// into `Established`.
    #[test]
    fn test_tracker_update() {
        let endpoints = ip_pair(1, 2);
        let mut tracker = TcpStateTracker::new();

        // (syn, ack, fin, rst) flags paired with the state expected afterwards.
        let steps = [
            ((true, false, false, false), TcpState::SynSent),
            ((false, true, false, false), TcpState::Established),
        ];
        for ((syn, ack, fin, rst), expected) in steps {
            tracker.update(endpoints, syn, ack, fin, rst);
            assert_eq!(tracker.get(&endpoints), Some(expected));
        }
    }

    /// An entry left untouched for longer than the timeout is removed by
    /// `prune_stale`.
    #[test]
    fn test_tracker_prune_stale() {
        let timeout = Duration::from_millis(100);
        let endpoints = ip_pair(1, 2);
        let mut tracker = TcpStateTracker::with_timeout(timeout);

        tracker.update(endpoints, true, false, false, false);
        assert_eq!(tracker.len(), 1);

        // Wait 150 ms, comfortably past the 100 ms timeout, before pruning.
        std::thread::sleep(timeout + Duration::from_millis(50));
        tracker.prune_stale();
        assert_eq!(tracker.len(), 0);
    }
}