use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use super::flow_hash::{FlowId, EcmpPathEnumerator};
use super::nat::{NatId, NatProbe, NatProbeResponse, IpIdMarker, UplinkNatState};
/// One probed TTL along a discovered path.
///
/// Built either from a probe/response pair (`Hop::from_response`) or as a
/// placeholder for a TTL that produced no reply (`Hop::non_responding`).
#[derive(Debug, Clone)]
pub struct Hop {
/// TTL the probe was sent with.
pub ttl: u8,
/// Address of the responder; `None` when the TTL did not respond.
pub addr: Option<SocketAddr>,
/// Round-trip time as computed by `NatProbeResponse::rtt` from the probe's
/// send time; `None` when the TTL did not respond.
pub rtt: Option<Duration>,
/// NAT identifier derived from the response and the probe's UDP checksum
/// (`NatId::default()` for non-responding hops).
pub nat_id: NatId,
/// Flow hash of the probe's UDP flow; `PathDiscovery::record_hop` uses it to
/// pick the path this hop belongs to.
pub flow_hash: u16,
/// True when `nat_id` differs from the NAT identifier of the previous hop.
pub nat_detected: bool,
/// ICMP type of the response, if any.
pub icmp_type: Option<u8>,
/// True when `NatProbeResponse::is_destination` reported the destination itself.
pub is_last: bool,
/// When the response was received (creation time for non-responding hops).
pub discovered_at: Instant,
}
impl Hop {
    /// Builds a hop from a sent probe and the response it elicited.
    ///
    /// The NAT identifier is derived from the response and the probe's UDP
    /// checksum. The hop is flagged `nat_detected` only when `prev_nat_id` is
    /// known and differs from that identifier.
    pub fn from_response(
        probe: &NatProbe,
        response: &NatProbeResponse,
        prev_nat_id: Option<NatId>,
    ) -> Self {
        let current_nat = response.nat_id(probe.marker.udp_checksum);
        let changed_nat = matches!(prev_nat_id, Some(previous) if previous != current_nat);
        let probe_flow = FlowId::from_udp(probe.src_addr, probe.dst_addr);

        Self {
            ttl: probe.marker.ttl,
            addr: Some(response.responder_addr),
            rtt: Some(response.rtt(probe.sent_at)),
            nat_id: current_nat,
            flow_hash: probe_flow.flow_hash(),
            nat_detected: changed_nat,
            icmp_type: Some(response.icmp_type),
            is_last: response.is_destination(),
            discovered_at: response.received_at,
        }
    }

    /// Placeholder for a TTL that produced no reply.
    ///
    /// It has no address, RTT or ICMP type, uses the default NAT id, and is
    /// timestamped with the current instant.
    pub fn non_responding(ttl: u8, flow_hash: u16) -> Self {
        Self {
            ttl,
            flow_hash,
            addr: None,
            rtt: None,
            icmp_type: None,
            nat_id: NatId::default(),
            nat_detected: false,
            is_last: false,
            discovered_at: Instant::now(),
        }
    }
}
/// A path to one destination as observed through a single ECMP flow, built
/// hop by hop via `DiscoveredPath::add_hop`.
#[derive(Debug, Clone)]
pub struct DiscoveredPath {
/// Flow this path was created for.
/// NOTE(review): paths created by `PathDiscovery::record_hop` store a
/// placeholder flow (destination IP on both ends, ports 0, protocol 17);
/// only `flow_hash` is meaningful for those.
pub flow_id: FlowId,
/// Flow hash identifying the path; initialised from `flow_id` but may be
/// overwritten by the creator (as `PathDiscovery::record_hop` does).
pub flow_hash: u16,
/// Recorded hops, kept sorted by TTL by `add_hop`.
pub hops: Vec<Hop>,
/// Set once a hop marked `is_last` (a destination response) has been added.
pub reaches_destination: bool,
/// RTT of the most recently added `is_last` hop, if any.
pub total_rtt: Option<Duration>,
/// Number of stored hops flagged `nat_detected`.
pub nat_count: usize,
/// Heuristic quality in `[0, 1]` (higher is better), recomputed by `add_hop`.
pub quality_score: f64,
/// Time of creation or of the last change made by `add_hop`.
pub last_updated: Instant,
}
impl DiscoveredPath {
    /// Creates an empty, incomplete path for `flow_id`, caching its flow hash.
    pub fn new(flow_id: FlowId) -> Self {
        Self {
            flow_hash: flow_id.flow_hash(),
            flow_id,
            hops: Vec::new(),
            reaches_destination: false,
            total_rtt: None,
            nat_count: 0,
            quality_score: 0.0,
            last_updated: Instant::now(),
        }
    }

    /// Records `hop`, keeping at most one hop per TTL and `hops` sorted by TTL.
    ///
    /// If a hop with the same TTL is already stored, it is replaced. The
    /// exception is a non-responding hop (`addr == None`), which never
    /// overwrites a hop that did respond; such a hop is ignored and the path is
    /// left untouched. With retries, a timeout placeholder and a late real reply
    /// can both arrive for one TTL. This rule stops the placeholder from
    /// shadowing the reply in [`Self::hop_at_ttl`] or skewing the response ratio.
    ///
    /// `nat_count` tracks the stored hops flagged `nat_detected`. Once an
    /// `is_last` hop is recorded, `reaches_destination` stays `true` and
    /// `total_rtt` holds the RTT of the most recently recorded `is_last` hop.
    pub fn add_hop(&mut self, hop: Hop) {
        let existing = self.hops.iter().position(|h| h.ttl == hop.ttl);
        if let Some(idx) = existing {
            if hop.addr.is_none() && self.hops[idx].addr.is_some() {
                // A timeout placeholder must not erase a real reply.
                return;
            }
        }

        if hop.nat_detected {
            self.nat_count += 1;
        }
        if hop.is_last {
            self.reaches_destination = true;
            self.total_rtt = hop.rtt;
        }

        match existing {
            Some(idx) => {
                // Same TTL, so the in-place replacement keeps `hops` sorted.
                let replaced = std::mem::replace(&mut self.hops[idx], hop);
                if replaced.nat_detected {
                    self.nat_count = self.nat_count.saturating_sub(1);
                }
            }
            None => {
                self.hops.push(hop);
                // `hops` is public, so re-sort rather than assume it is still ordered.
                self.hops.sort_by_key(|h| h.ttl);
            }
        }

        self.update_quality();
        self.last_updated = Instant::now();
    }

    /// Recomputes `quality_score`; an empty path scores `0`.
    ///
    /// Otherwise the score is the product of these factors, clamped to `[0, 1]`:
    /// * the fraction of stored hops that responded;
    /// * `1 / (1 + rtt_ms / 100)` when `total_rtt` is known;
    /// * `1 / (1 + 0.1 * nat_count)`;
    /// * a `1.2` bonus when the destination was reached.
    ///
    /// NOTE(review): a path that has not reached the destination has no
    /// `total_rtt` and so takes no latency penalty. It can therefore outscore
    /// complete paths.
    fn update_quality(&mut self) {
        if self.hops.is_empty() {
            self.quality_score = 0.0;
            return;
        }
        let mut score = 1.0;
        let responding = self.hops.iter().filter(|h| h.addr.is_some()).count();
        let response_ratio = responding as f64 / self.hops.len() as f64;
        score *= response_ratio;
        if let Some(rtt) = self.total_rtt {
            let rtt_ms = rtt.as_secs_f64() * 1000.0;
            score *= 1.0 / (1.0 + rtt_ms / 100.0);
        }
        score *= 1.0 / (1.0 + self.nat_count as f64 * 0.1);
        if self.reaches_destination {
            score *= 1.2;
        }
        self.quality_score = score.clamp(0.0, 1.0);
    }

    /// Number of stored hops.
    pub fn length(&self) -> usize {
        self.hops.len()
    }

    /// True once a destination (`is_last`) hop has been recorded.
    pub fn is_complete(&self) -> bool {
        self.reaches_destination
    }

    /// Returns the stored hop for `ttl`, if any.
    pub fn hop_at_ttl(&self, ttl: u8) -> Option<&Hop> {
        self.hops.iter().find(|h| h.ttl == ttl)
    }
}
/// Diversity summary for the set of paths discovered to one destination,
/// produced by `PathDiversity::from_paths`.
#[derive(Debug, Default, Clone)]
pub struct PathDiversity {
/// Number of paths considered (the slice length; not deduplicated).
pub unique_paths: usize,
/// Number of distinct first-hop responder addresses across the paths.
pub unique_first_hops: usize,
/// Number of distinct responder addresses of intermediate
/// (non-first, non-destination) hops.
pub unique_intermediate_hops: usize,
/// Combined diversity score in `[0, 1]`.
pub diversity_score: f64,
/// Suggested number of paths to use: 1 to 4 for a non-empty path set,
/// 0 in the `Default` value.
pub recommended_paths: usize,
}
impl PathDiversity {
    /// Summarises how diverse a set of paths to one destination is.
    ///
    /// For each path, only hops that responded (`addr.is_some()`) are considered:
    /// * the first responding hop (lowest TTL) contributes to `unique_first_hops`;
    /// * later responding hops not marked `is_last` contribute to
    ///   `unique_intermediate_hops`.
    ///
    /// Using the first *responding* hop, rather than the first stored hop, means
    /// a router that stays silent at the lowest TTL no longer hides the path
    /// from first-hop diversity.
    ///
    /// `diversity_score = 0.6 * path_diversity + 0.4 * hop_diversity`, clamped
    /// to `[0, 1]`, where:
    /// * `path_diversity` is `first_hops / paths` (0 for a single path);
    /// * `hop_diversity` is `intermediate_hops / (3 * paths)`, or a neutral
    ///   0.5 when no intermediate hop responded.
    ///
    /// Returns `PathDiversity::default()` for an empty slice.
    pub fn from_paths(paths: &[DiscoveredPath]) -> Self {
        if paths.is_empty() {
            return Self::default();
        }
        let unique_paths = paths.len();

        let mut first_hops = std::collections::HashSet::new();
        let mut intermediate_hops = std::collections::HashSet::new();
        for path in paths {
            let mut responders = path
                .hops
                .iter()
                .filter_map(|h| h.addr.map(|addr| (addr, h.is_last)));
            if let Some((first, _)) = responders.next() {
                first_hops.insert(first);
            }
            intermediate_hops.extend(
                responders
                    .filter(|&(_, is_last)| !is_last)
                    .map(|(addr, _)| addr),
            );
        }
        let unique_first_hops = first_hops.len();
        let unique_intermediate_hops = intermediate_hops.len();

        let path_diversity = if unique_paths > 1 {
            (unique_first_hops as f64 / unique_paths as f64).min(1.0)
        } else {
            0.0
        };
        let hop_diversity = if intermediate_hops.is_empty() {
            0.5
        } else {
            (unique_intermediate_hops as f64 / (unique_paths * 3) as f64).min(1.0)
        };
        let diversity_score = (path_diversity * 0.6 + hop_diversity * 0.4).clamp(0.0, 1.0);

        let recommended_paths = match unique_first_hops {
            0 => 1,
            1 => unique_paths.min(2),
            2..=3 => unique_first_hops,
            // Four or more distinct first hops: cap the recommendation at four.
            _ => 4,
        };

        Self {
            unique_paths,
            unique_first_hops,
            unique_intermediate_hops,
            diversity_score,
            recommended_paths,
        }
    }
}
/// Tunables for `PathDiscovery`.
///
/// Every field has a serde default (the matching `default_*` function), so a
/// partial or empty config deserialises. Durations use the `humantime_serde`
/// format (e.g. `"3s"`, `"50ms"`).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PathDiscoveryConfig {
/// Lowest TTL probed by `generate_probes` (default 1).
#[serde(default = "default_min_ttl")]
pub min_ttl: u8,
/// Highest TTL probed by `generate_probes`, inclusive (default 32).
#[serde(default = "default_max_ttl")]
pub max_ttl: u8,
/// Flow-variant count passed to the ECMP enumerator (default 8).
#[serde(default = "default_num_paths")]
pub num_paths: u16,
/// Base source port passed to the ECMP enumerator (default 33434).
#[serde(default = "default_base_src_port")]
pub base_src_port: u16,
/// When true (default), a flow's source port identifies it in the probe
/// marker; otherwise its destination port does. Also passed to the enumerator.
#[serde(default = "default_use_src_port")]
pub use_src_port: bool,
/// Default 3s. Not read in this file; presumably how long to wait for a
/// probe reply — TODO confirm with the prober.
#[serde(default = "default_probe_timeout", with = "humantime_serde")]
pub probe_timeout: Duration,
/// Default 50ms. Not read in this file; presumably the pause between probe
/// sends — TODO confirm with the prober.
#[serde(default = "default_probe_delay", with = "humantime_serde")]
pub probe_delay: Duration,
/// Default 2. Not read in this file; presumably the per-probe retry count —
/// TODO confirm with the prober.
#[serde(default = "default_retries")]
pub retries: u8,
}
/// Serde default for `min_ttl`: start probing at the first router.
fn default_min_ttl() -> u8 {
    1
}

/// Serde default for `max_ttl`: probe up to TTL 32.
fn default_max_ttl() -> u8 {
    32
}

/// Serde default for `num_paths`: enumerate eight flow variants.
fn default_num_paths() -> u16 {
    8
}

/// Serde default for `base_src_port`: the classic traceroute base port.
fn default_base_src_port() -> u16 {
    33_434
}

/// Serde default for `use_src_port`.
fn default_use_src_port() -> bool {
    true
}

/// Serde default for `probe_timeout`: three seconds.
fn default_probe_timeout() -> Duration {
    Duration::from_millis(3_000)
}

/// Serde default for `probe_delay`: fifty milliseconds.
fn default_probe_delay() -> Duration {
    Duration::from_micros(50_000)
}

/// Serde default for `retries`.
fn default_retries() -> u8 {
    2
}
impl Default for PathDiscoveryConfig {
    /// Same values serde fills in for an empty document: each field comes from
    /// its `default_*` helper.
    fn default() -> Self {
        Self {
            retries: default_retries(),
            probe_delay: default_probe_delay(),
            probe_timeout: default_probe_timeout(),
            use_src_port: default_use_src_port(),
            base_src_port: default_base_src_port(),
            num_paths: default_num_paths(),
            max_ttl: default_max_ttl(),
            min_ttl: default_min_ttl(),
        }
    }
}
/// Multi-path (ECMP) route discovery state.
///
/// Holds discovered paths per destination, NAT state per uplink id, and a
/// per-destination `PathDiversity` cache that is dropped whenever that
/// destination's paths change. Each map sits behind a `parking_lot::RwLock`,
/// so all methods take `&self`.
#[derive(Debug)]
pub struct PathDiscovery {
/// Discovery parameters (TTL range, flow count, ports, timing).
config: PathDiscoveryConfig,
/// Discovered paths, keyed by destination address.
paths: RwLock<HashMap<SocketAddr, Vec<DiscoveredPath>>>,
/// NAT state keyed by uplink id.
nat_states: RwLock<HashMap<u16, UplinkNatState>>,
/// Memoised `PathDiversity` per destination.
diversity_cache: RwLock<HashMap<SocketAddr, PathDiversity>>,
}
impl PathDiscovery {
    /// Creates a discovery engine using `config`, with empty path, NAT-state
    /// and diversity caches.
    pub fn new(config: PathDiscoveryConfig) -> Self {
        Self {
            config,
            paths: RwLock::new(HashMap::new()),
            nat_states: RwLock::new(HashMap::new()),
            diversity_cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a copy of the NAT state recorded for `uplink_id`, or
    /// `UplinkNatState::default()` when none has been recorded.
    pub fn get_nat_state(&self, uplink_id: u16) -> UplinkNatState {
        let states = self.nat_states.read();
        states.get(&uplink_id).cloned().unwrap_or_default()
    }

    /// Mutates the NAT state for `uplink_id` in place via `f`. A default state
    /// is inserted first if the uplink has none.
    ///
    /// `f` runs while the NAT-state write lock is held. It must not call
    /// [`Self::get_nat_state`] or [`Self::update_nat_state`] on this instance,
    /// because that would try to take the same lock again.
    pub fn update_nat_state<F>(&self, uplink_id: u16, f: F)
    where
        F: FnOnce(&mut UplinkNatState),
    {
        let mut states = self.nat_states.write();
        let state = states.entry(uplink_id).or_default();
        f(state);
    }

    /// Builds the probe schedule for `src_addr -> dst_addr`.
    ///
    /// An [`EcmpPathEnumerator`] is seeded with the base UDP flow and the
    /// configured `base_src_port`, `num_paths` and `use_src_port`. For every
    /// flow it yields, one `(ttl, flow, marker)` entry is produced per TTL in
    /// `min_ttl..=max_ttl`, grouped by flow. Each marker is built from the TTL
    /// and from the flow's source port (when `use_src_port`) or destination port
    /// (otherwise). The result is empty when `min_ttl > max_ttl`.
    pub fn generate_probes(
        &self,
        src_addr: SocketAddr,
        dst_addr: SocketAddr,
    ) -> Vec<(u8, FlowId, IpIdMarker)> {
        let base_flow = FlowId::from_udp(src_addr, dst_addr);
        let enumerator = EcmpPathEnumerator::new(
            base_flow,
            self.config.base_src_port,
            self.config.num_paths,
            self.config.use_src_port,
        );
        let mut probes = Vec::new();
        for flow in enumerator.flows() {
            // Port that identifies this flow variant inside the IP-ID marker.
            let varied_port = if self.config.use_src_port {
                flow.src_port
            } else {
                flow.dst_port
            };
            for ttl in self.config.min_ttl..=self.config.max_ttl {
                let marker = IpIdMarker::from_probe(ttl, varied_port, self.config.use_src_port);
                probes.push((ttl, flow, marker));
            }
        }
        probes
    }

    /// Adds `hop` to the path for `dst_addr` whose `flow_hash` matches the
    /// hop's, creating that path if needed. The destination's cached diversity
    /// is then invalidated.
    pub fn record_hop(&self, dst_addr: SocketAddr, hop: Hop) {
        let mut paths = self.paths.write();
        let dest_paths = paths.entry(dst_addr).or_default();
        let existing = dest_paths
            .iter_mut()
            .find(|p| p.flow_hash == hop.flow_hash);
        match existing {
            Some(path) => path.add_hop(hop),
            None => {
                // NOTE(review): a `Hop` carries only its flow hash, not the full
                // 5-tuple, so the stored `flow_id` is a placeholder (destination
                // IP on both ends, ports 0, UDP). The real hash is applied below.
                let placeholder = FlowId::new(dst_addr.ip(), dst_addr.ip(), 0, 0, 17);
                let mut path = DiscoveredPath::new(placeholder);
                path.flow_hash = hop.flow_hash;
                path.add_hop(hop);
                dest_paths.push(path);
            }
        }
        // Invalidate while the paths write lock is still held, so no reader can
        // cache a diversity computed from the pre-update paths after this point.
        self.diversity_cache.write().remove(&dst_addr);
    }

    /// Returns a snapshot (clone) of the paths recorded for `dst_addr`.
    pub fn get_paths(&self, dst_addr: SocketAddr) -> Vec<DiscoveredPath> {
        self.paths.read().get(&dst_addr).cloned().unwrap_or_default()
    }

    /// Returns the path diversity for `dst_addr`, computing and caching it on
    /// a cache miss.
    pub fn get_diversity(&self, dst_addr: SocketAddr) -> PathDiversity {
        if let Some(cached) = self.diversity_cache.read().get(&dst_addr) {
            return cached.clone();
        }
        // Compute from the live map and publish while still holding the paths
        // read lock. `record_hop` needs the write lock to change paths and only
        // invalidates the cache after that, so it cannot interleave and leave a
        // stale entry behind. Lock order (paths, then cache) matches `record_hop`.
        let paths = self.paths.read();
        let diversity = PathDiversity::from_paths(
            paths.get(&dst_addr).map(Vec::as_slice).unwrap_or_default(),
        );
        self.diversity_cache.write().insert(dst_addr, diversity.clone());
        diversity
    }

    /// Returns up to `count` paths to `dst_addr`, ordered by descending
    /// `quality_score`. Ties keep their recording order.
    pub fn get_best_paths(&self, dst_addr: SocketAddr, count: usize) -> Vec<DiscoveredPath> {
        let mut paths = self.get_paths(dst_addr);
        paths.sort_by(|a, b| b.quality_score.total_cmp(&a.quality_score));
        paths.truncate(count);
        paths
    }

    /// Flow hashes of the paths recorded for `dst_addr`, in recording order.
    pub fn get_ecmp_flow_hashes(&self, dst_addr: SocketAddr) -> Vec<u16> {
        let paths = self.paths.read();
        paths
            .get(&dst_addr)
            .map(|p| p.iter().map(|path| path.flow_hash).collect())
            .unwrap_or_default()
    }

    /// True if any recorded path to `dst_addr` has at least one NAT transition.
    pub fn has_nat(&self, dst_addr: SocketAddr) -> bool {
        self.paths
            .read()
            .get(&dst_addr)
            .is_some_and(|paths| paths.iter().any(|p| p.nat_count > 0))
    }

    /// Forgets all paths and the cached diversity for `dst_addr`.
    pub fn clear_paths(&self, dst_addr: SocketAddr) {
        self.paths.write().remove(&dst_addr);
        self.diversity_cache.write().remove(&dst_addr);
    }

    /// Forgets all paths and cached diversities. NAT state is retained.
    pub fn clear_all(&self) {
        self.paths.write().clear();
        self.diversity_cache.write().clear();
    }

    /// The configuration this engine was created with.
    pub fn config(&self) -> &PathDiscoveryConfig {
        &self.config
    }

    /// Cleans up old paths and the diversity cache.
    ///
    /// Drops paths not updated within `max_age`, removes destinations left with
    /// no paths, and clears the whole diversity cache.
    pub fn cleanup(&self, max_age: Duration) {
        let now = Instant::now();
        self.paths.write().retain(|_, paths| {
            paths.retain(|p| now.duration_since(p.last_updated) < max_age);
            !paths.is_empty()
        });
        self.diversity_cache.write().clear();
    }
}
impl Default for PathDiscovery {
fn default() -> Self {
Self::new(PathDiscoveryConfig::default())
}
}
/// Maps ECMP flow hashes to uplinks and keeps preferred flow hashes per
/// destination.
#[derive(Debug)]
pub struct EcmpFlowSelector {
/// Flow hash -> uplink id.
hash_to_uplink: RwLock<HashMap<u16, u16>>,
/// Destination -> preferred flow hashes; `get_preferred` returns the first.
preferred_hashes: RwLock<HashMap<SocketAddr, Vec<u16>>>,
}
impl Default for EcmpFlowSelector {
fn default() -> Self {
Self::new()
}
}
impl EcmpFlowSelector {
    /// Number of consecutive candidate source ports that
    /// [`Self::suggest_port_for_path`] examines.
    pub const DEFAULT_PORT_SEARCH_SPAN: usize = 1000;

    /// Creates a selector with no hash-to-uplink mappings and no preferences.
    pub fn new() -> Self {
        Self {
            hash_to_uplink: RwLock::new(HashMap::new()),
            preferred_hashes: RwLock::new(HashMap::new()),
        }
    }

    /// Maps `flow_hash` to `uplink_id`, replacing any previous mapping.
    pub fn set_mapping(&self, flow_hash: u16, uplink_id: u16) {
        self.hash_to_uplink.write().insert(flow_hash, uplink_id);
    }

    /// Returns the uplink mapped to `flow_hash`, if any.
    pub fn get_uplink(&self, flow_hash: u16) -> Option<u16> {
        self.hash_to_uplink.read().get(&flow_hash).copied()
    }

    /// Stores the preferred flow hashes for `dst`, replacing any previous list.
    pub fn set_preferred(&self, dst: SocketAddr, hashes: Vec<u16>) {
        self.preferred_hashes.write().insert(dst, hashes);
    }

    /// Returns the first preferred flow hash for `dst`. Returns `None` when no
    /// list is stored or the stored list is empty.
    pub fn get_preferred(&self, dst: SocketAddr) -> Option<u16> {
        self.preferred_hashes
            .read()
            .get(&dst)
            .and_then(|h| h.first().copied())
    }

    /// Finds a source port that makes `base_flow` hash to `target_hash`.
    ///
    /// Starts at `base_flow.src_port` and tries successive ports (wrapping past
    /// 65535), for [`Self::DEFAULT_PORT_SEARCH_SPAN`] candidates. See
    /// [`Self::suggest_port_for_path_within`] for a configurable span.
    pub fn suggest_port_for_path(&self, base_flow: FlowId, target_hash: u16) -> Option<u16> {
        self.suggest_port_for_path_within(base_flow, target_hash, Self::DEFAULT_PORT_SEARCH_SPAN)
    }

    /// Like [`Self::suggest_port_for_path`], but examines up to `max_candidates`
    /// consecutive ports starting at `base_flow.src_port`. The span is
    /// effectively capped at the 65 536 distinct `u16` values.
    ///
    /// Port 0 is never suggested. Binding a socket to port 0 asks the OS for an
    /// ephemeral port, so the resulting flow would not carry the target hash.
    ///
    /// NOTE(review): if `flow_hash` is roughly uniform over `u16`, a 1000-port
    /// window finds a given target only about 1.5% of the time. Callers that
    /// need a hit should pass a larger span.
    pub fn suggest_port_for_path_within(
        &self,
        base_flow: FlowId,
        target_hash: u16,
        max_candidates: usize,
    ) -> Option<u16> {
        (0..=u16::MAX)
            .take(max_candidates)
            .map(|offset| base_flow.src_port.wrapping_add(offset))
            .filter(|&port| port != 0)
            .find(|&port| base_flow.with_src_port(port).flow_hash() == target_hash)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// UDP flow 192.168.1.1:12345 -> 8.8.8.8:53 shared by the path tests.
    fn dns_flow() -> FlowId {
        let client = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1));
        let resolver = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));
        FlowId::new(client, resolver, 12345, 53, 17)
    }

    #[test]
    fn test_discovered_path() {
        let flow = dns_flow();
        let mut path = DiscoveredPath::new(flow);
        assert!(!path.is_complete());
        assert_eq!(path.length(), 0);

        let gateway_hop = Hop {
            ttl: 1,
            addr: Some(SocketAddr::from(([192, 168, 1, 254], 0))),
            rtt: Some(Duration::from_millis(5)),
            nat_id: NatId(0),
            flow_hash: flow.flow_hash(),
            nat_detected: false,
            icmp_type: Some(11),
            is_last: false,
            discovered_at: Instant::now(),
        };
        path.add_hop(gateway_hop);

        assert_eq!(path.length(), 1);
        assert!(!path.is_complete());
    }

    #[test]
    fn test_path_diversity() {
        let first = dns_flow();
        let second = first.with_src_port(12346);
        let candidates = [DiscoveredPath::new(first), DiscoveredPath::new(second)];

        let summary = PathDiversity::from_paths(&candidates);
        assert_eq!(summary.unique_paths, 2);
    }

    #[test]
    fn test_path_discovery_config() {
        let PathDiscoveryConfig {
            min_ttl,
            max_ttl,
            num_paths,
            ..
        } = PathDiscoveryConfig::default();
        assert_eq!(min_ttl, 1);
        assert_eq!(max_ttl, 32);
        assert_eq!(num_paths, 8);
    }

    #[test]
    fn test_ecmp_flow_selector() {
        let selector = EcmpFlowSelector::new();
        selector.set_mapping(0x1234, 1);
        assert_eq!(selector.get_uplink(0x1234), Some(1));
        assert!(selector.get_uplink(0x5678).is_none());
    }

    #[test]
    fn test_hop_creation() {
        let placeholder = Hop::non_responding(5, 0x1234);
        assert_eq!(placeholder.ttl, 5);
        assert_eq!(placeholder.flow_hash, 0x1234);
        assert!(placeholder.addr.is_none());
    }
}