use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
/// Tuning parameters for the port predictor.
#[derive(Debug, Clone)]
pub struct PortPredictorConfig {
    /// Upper bound on the number of observations retained per IP address.
    /// When recording a new observation pushes the history past this size,
    /// the oldest entry is dropped.
    pub max_samples: usize,
    /// Maximum age of an observation. Older entries are evicted from the
    /// front of an IP's history when a new observation for that IP is
    /// recorded.
    pub sample_ttl: Duration,
    /// Minimum number of retained observations required before any
    /// prediction is produced for an IP address.
    pub min_samples_for_prediction: usize,
    /// Presumably an upper bound on how many predicted ports should be
    /// attempted (inferred from the name; TODO confirm with the consumers
    /// of this configuration).
    pub max_prediction_attempts: usize,
}
impl Default for PortPredictorConfig {
    /// Returns the default configuration: up to 10 samples per IP, a 60 second
    /// sample lifetime, at least 2 samples before predicting, and a maximum
    /// of 3 prediction attempts.
    fn default() -> Self {
        let sample_ttl = Duration::from_secs(60);
        let (max_samples, min_samples_for_prediction, max_prediction_attempts) = (10, 2, 3);

        Self {
            max_samples,
            sample_ttl,
            min_samples_for_prediction,
            max_prediction_attempts,
        }
    }
}
/// A single externally observed port together with the instant at which it
/// was recorded.
#[derive(Debug, Clone)]
struct PortObservation {
    /// The observed port number.
    port: u16,
    /// Caller-supplied timestamp of the observation; used for age-based
    /// eviction and for ordering samples when predicting.
    observed_at: Instant,
}
/// Tracks recently observed ports per IP address and extrapolates the ports
/// likely to be used next.
#[derive(Debug)]
pub struct PortPredictor {
    /// Limits applied when recording observations and producing predictions.
    config: PortPredictorConfig,
    /// Observations per IP address, with the oldest recorded entry at the
    /// front of each queue.
    history: HashMap<IpAddr, VecDeque<PortObservation>>,
}
impl PortPredictor {
pub fn new(config: PortPredictorConfig) -> Self {
Self {
config,
history: HashMap::new(),
}
}
pub fn record_observation(&mut self, addr: SocketAddr, now: Instant) {
let entry = self.history.entry(addr.ip()).or_default();
while let Some(obs) = entry.front() {
if now.duration_since(obs.observed_at) > self.config.sample_ttl {
entry.pop_front();
} else {
break;
}
}
if entry.iter().any(|obs| obs.port == addr.port()) {
return;
}
entry.push_back(PortObservation {
port: addr.port(),
observed_at: now,
});
if entry.len() > self.config.max_samples {
entry.pop_front();
}
}
pub fn predict_ports(&self, ip: IpAddr) -> Vec<u16> {
let Some(samples) = self.history.get(&ip) else {
return Vec::new();
};
if samples.len() < self.config.min_samples_for_prediction {
return Vec::new();
}
let mut predictions = Vec::new();
let mut sorted_observations: Vec<_> = samples.iter().collect();
sorted_observations.sort_by_key(|o| o.observed_at);
let count = sorted_observations.len();
if count >= 2 {
let last = sorted_observations[count - 1];
let prev = sorted_observations[count - 2];
let delta = last.port.wrapping_sub(prev.port);
let next_1 = last.port.wrapping_add(delta);
predictions.push(next_1);
let next_2 = next_1.wrapping_add(delta);
predictions.push(next_2);
}
predictions
}
pub fn clear(&mut self, ip: IpAddr) {
self.history.remove(&ip);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Address of the single peer used by every test.
    fn test_ip() -> IpAddr {
        Ipv4Addr::new(192, 168, 1, 1).into()
    }

    /// Builds a predictor from `config` and records each `(port, offset)`
    /// pair for [`test_ip`], timestamped `offset` after a common start time.
    fn predictor_after(config: PortPredictorConfig, observations: &[(u16, Duration)]) -> PortPredictor {
        let mut predictor = PortPredictor::new(config);
        let start = Instant::now();
        for &(port, offset) in observations {
            predictor.record_observation(SocketAddr::new(test_ip(), port), start + offset);
        }
        predictor
    }

    /// An increasing port sequence is extrapolated upwards by the same stride.
    #[test]
    fn test_linear_prediction_increment() {
        let predictor = predictor_after(
            PortPredictorConfig::default(),
            &[(1000, Duration::ZERO), (1002, Duration::from_secs(1))],
        );

        let predicted = predictor.predict_ports(test_ip());

        assert!(predicted.contains(&1004));
        assert!(predicted.contains(&1006));
    }

    /// A decreasing port sequence is extrapolated downwards by the same stride.
    #[test]
    fn test_linear_prediction_decrement() {
        let predictor = predictor_after(
            PortPredictorConfig::default(),
            &[(2000, Duration::ZERO), (1990, Duration::from_secs(1))],
        );

        let predicted = predictor.predict_ports(test_ip());

        assert!(predicted.contains(&1980));
        assert!(predicted.contains(&1970));
    }

    /// A single observation is below the default minimum sample count.
    #[test]
    fn test_insufficient_samples() {
        let predictor = predictor_after(PortPredictorConfig::default(), &[(1000, Duration::ZERO)]);

        assert!(predictor.predict_ports(test_ip()).is_empty());
    }

    /// An observation older than the TTL is evicted when the next one is
    /// recorded, leaving too few samples to predict from.
    #[test]
    fn test_ttl_expiry() {
        let config = PortPredictorConfig {
            sample_ttl: Duration::from_millis(100),
            ..PortPredictorConfig::default()
        };
        let predictor = predictor_after(
            config,
            &[(1000, Duration::ZERO), (1002, Duration::from_millis(200))],
        );

        let predicted = predictor.predict_ports(test_ip());

        assert!(
            predicted.is_empty(),
            "Should not predict with only 1 valid sample"
        );
    }
}