use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, SocketAddr};
use std::time::{Duration, Instant};
/// Tuning parameters for [`PortPredictor`].
#[derive(Debug, Clone)]
pub struct PortPredictorConfig {
/// Maximum number of observations retained per IP address; recording beyond this
/// drops an older retained observation.
pub max_samples: usize,
/// Maximum age of a retained observation, measured back from the `now` passed to
/// `PortPredictor::record_observation`; expired samples are pruned when a new
/// observation is recorded (not when predicting).
pub sample_ttl: Duration,
/// Minimum number of retained observations `PortPredictor::predict_ports` needs before
/// it returns anything. At least two are always required regardless of this value.
pub min_samples_for_prediction: usize,
/// NOTE(review): not read anywhere in this file — `predict_ports` always returns two
/// candidates. Presumably intended to cap the number of predicted ports; confirm intended use.
pub max_prediction_attempts: usize,
}
impl Default for PortPredictorConfig {
fn default() -> Self {
Self {
max_samples: 10,
sample_ttl: Duration::from_secs(60), min_samples_for_prediction: 2,
max_prediction_attempts: 3,
}
}
}
/// One sighting of a port for some IP address; a sample for stride prediction.
#[derive(Debug, Clone)]
struct PortObservation {
/// The observed port (taken from the recorded `SocketAddr`).
port: u16,
/// When the port was observed, exactly as supplied to `record_observation`.
observed_at: Instant,
}
/// Predicts the next ports an IP address will use from the stride between its most
/// recently observed ports.
///
/// History is tracked per IP address; the port of each recorded `SocketAddr` is the sample.
#[derive(Debug)]
pub struct PortPredictor {
/// Tuning parameters supplied at construction.
config: PortPredictorConfig,
/// Per-IP observation history, bounded in length by `config.max_samples`.
history: HashMap<IpAddr, VecDeque<PortObservation>>,
}
impl PortPredictor {
    /// Creates a predictor with the given configuration and no recorded history.
    pub fn new(config: PortPredictorConfig) -> Self {
        Self {
            config,
            history: HashMap::new(),
        }
    }

    /// Records that `addr.ip()` was seen using port `addr.port()` at time `now`.
    ///
    /// Steps, in order:
    /// 1. Observations for this IP older than `sample_ttl` relative to `now` are dropped.
    ///    Observations stamped later than `now` count as age zero and are kept.
    /// 2. If the port is already in the retained history the call is a no-op, so the
    ///    timestamp of the first sighting is kept.
    /// 3. The observation is inserted so the per-IP history stays sorted by
    ///    `observed_at` (ties keep insertion order).
    /// 4. While more than `max_samples` observations remain, the chronologically
    ///    oldest one is evicted.
    ///
    /// Observations may be recorded out of chronological order. Keeping the history
    /// sorted ensures TTL pruning and eviction act on the oldest *observed* samples
    /// rather than on the earliest *inserted* ones.
    pub fn record_observation(&mut self, addr: SocketAddr, now: Instant) {
        let ttl = self.config.sample_ttl;
        let max_samples = self.config.max_samples;
        let port = addr.port();
        let entry = self.history.entry(addr.ip()).or_default();

        // Check every sample rather than stopping at the first fresh one at the front:
        // with out-of-order recording a stale sample is not guaranteed to be at the front.
        entry.retain(|obs| now.saturating_duration_since(obs.observed_at) <= ttl);

        if entry.iter().any(|obs| obs.port == port) {
            return;
        }

        // Insert after any samples with an equal timestamp so ties preserve insertion
        // order, i.e. the deque matches a stable sort by `observed_at`.
        let index = entry.partition_point(|obs| obs.observed_at <= now);
        entry.insert(
            index,
            PortObservation {
                port,
                observed_at: now,
            },
        );

        // The front is the chronologically oldest sample.
        while entry.len() > max_samples {
            entry.pop_front();
        }
    }

    /// Predicts the next two ports `ip` is likely to use, assuming a constant stride.
    ///
    /// The stride is the `u16` wrapping difference between the two most recent
    /// observations by `observed_at`. The predictions are one and two strides past the
    /// newest port, also computed with wrapping arithmetic, so candidates may wrap past
    /// 65535 (including to port 0).
    ///
    /// Returns an empty `Vec` when `ip` has no history, has fewer than
    /// `min_samples_for_prediction` samples, or has fewer than two samples.
    ///
    /// This method takes no clock, so it does not apply `sample_ttl`; it uses whatever
    /// `record_observation` last retained. `max_prediction_attempts` is not consulted:
    /// exactly two candidates are returned.
    pub fn predict_ports(&self, ip: IpAddr) -> Vec<u16> {
        let Some(samples) = self.history.get(&ip) else {
            return Vec::new();
        };
        if samples.len() < self.config.min_samples_for_prediction {
            return Vec::new();
        }
        // `record_observation` keeps the history sorted by `observed_at`, so the newest
        // samples are at the back.
        let mut newest_first = samples.iter().rev();
        let (Some(last), Some(prev)) = (newest_first.next(), newest_first.next()) else {
            return Vec::new();
        };
        let delta = last.port.wrapping_sub(prev.port);
        let next_1 = last.port.wrapping_add(delta);
        let next_2 = next_1.wrapping_add(delta);
        vec![next_1, next_2]
    }

    /// Forgets every recorded observation for `ip`; other addresses are unaffected.
    pub fn clear(&mut self, ip: IpAddr) {
        self.history.remove(&ip);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// The address most tests record observations for.
    fn test_ip() -> IpAddr {
        Ipv4Addr::new(192, 168, 1, 1).into()
    }

    /// A predictor using the stock configuration.
    fn default_predictor() -> PortPredictor {
        PortPredictor::new(PortPredictorConfig::default())
    }

    /// Records each `(port, seconds_after_base)` pair for `ip`, in slice order.
    fn feed(predictor: &mut PortPredictor, ip: IpAddr, base: Instant, samples: &[(u16, u64)]) {
        for &(port, offset_secs) in samples {
            let at = base + Duration::from_secs(offset_secs);
            predictor.record_observation(SocketAddr::new(ip, port), at);
        }
    }

    #[test]
    fn test_linear_prediction_increment() {
        let ip = test_ip();
        let mut predictor = default_predictor();
        feed(&mut predictor, ip, Instant::now(), &[(1000, 0), (1002, 1)]);
        let predicted = predictor.predict_ports(ip);
        for port in [1004, 1006] {
            assert!(predicted.contains(&port));
        }
    }

    #[test]
    fn test_linear_prediction_decrement() {
        let ip = test_ip();
        let mut predictor = default_predictor();
        feed(&mut predictor, ip, Instant::now(), &[(2000, 0), (1990, 1)]);
        let predicted = predictor.predict_ports(ip);
        for port in [1980, 1970] {
            assert!(predicted.contains(&port));
        }
    }

    #[test]
    fn test_insufficient_samples() {
        let ip = test_ip();
        let mut predictor = default_predictor();
        feed(&mut predictor, ip, Instant::now(), &[(1000, 0)]);
        assert!(predictor.predict_ports(ip).is_empty());
    }

    #[test]
    fn test_ttl_expiry() {
        let ip = test_ip();
        let mut predictor = PortPredictor::new(PortPredictorConfig {
            sample_ttl: Duration::from_millis(100),
            ..PortPredictorConfig::default()
        });
        let start = Instant::now();
        predictor.record_observation(SocketAddr::new(ip, 1000), start);
        // 200ms later the first sample is past its 100ms TTL.
        predictor.record_observation(
            SocketAddr::new(ip, 1002),
            start + Duration::from_millis(200),
        );
        assert!(
            predictor.predict_ports(ip).is_empty(),
            "Should not predict with only 1 valid sample"
        );
    }

    #[test]
    fn default_config_values_are_conservative() {
        let defaults = PortPredictorConfig::default();
        assert_eq!(defaults.max_samples, 10);
        assert_eq!(defaults.sample_ttl, Duration::from_secs(60));
        assert_eq!(defaults.min_samples_for_prediction, 2);
        assert_eq!(defaults.max_prediction_attempts, 3);
    }

    #[test]
    fn duplicate_ports_are_ignored() {
        let ip = test_ip();
        let mut predictor = default_predictor();
        feed(&mut predictor, ip, Instant::now(), &[(1000, 0), (1000, 1), (1002, 2)]);
        assert_eq!(predictor.predict_ports(ip), vec![1004, 1006]);
    }

    #[test]
    fn max_samples_evicts_oldest_observations() {
        let ip = test_ip();
        let mut predictor = PortPredictor::new(PortPredictorConfig {
            max_samples: 2,
            ..PortPredictorConfig::default()
        });
        feed(&mut predictor, ip, Instant::now(), &[(1000, 0), (1001, 1), (1010, 2)]);
        assert_eq!(predictor.predict_ports(ip), vec![1019, 1028]);
    }

    #[test]
    fn prediction_uses_observation_time_not_insertion_order() {
        let ip = test_ip();
        let mut predictor = default_predictor();
        // Deliberately recorded out of chronological order.
        feed(&mut predictor, ip, Instant::now(), &[(1004, 2), (1000, 0), (1002, 1)]);
        assert_eq!(predictor.predict_ports(ip), vec![1006, 1008]);
    }

    #[test]
    fn clear_removes_history_for_only_requested_ip() {
        let ip = test_ip();
        let other_ip: IpAddr = Ipv4Addr::new(10, 0, 0, 1).into();
        let start = Instant::now();
        let mut predictor = default_predictor();
        feed(&mut predictor, ip, start, &[(1000, 0), (1001, 1)]);
        feed(&mut predictor, other_ip, start, &[(2000, 0), (2001, 1)]);
        predictor.clear(ip);
        assert!(predictor.predict_ports(ip).is_empty());
        assert_eq!(predictor.predict_ports(other_ip), vec![2002, 2003]);
    }

    #[test]
    fn wrapping_delta_predictions_are_supported() {
        let ip = test_ip();
        let mut predictor = default_predictor();
        feed(&mut predictor, ip, Instant::now(), &[(65_534, 0), (1, 1)]);
        assert_eq!(predictor.predict_ports(ip), vec![4, 7]);
    }
}