use crate::{
TransportError, VarInt,
frame::{Frame, FrameType},
};
use std::net::SocketAddr;
/// Settings that control how NAT traversal frames are handled while migrating from the
/// legacy frame encoding to the RFC-style frame types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NatMigrationConfig {
    /// When `true`, [`FrameMigrator::should_accept_frame`] accepts every frame type; when
    /// `false`, only the RFC NAT traversal frame types are accepted.
    pub accept_legacy_frames: bool,
    /// Whether RFC-style frames should be sent; reported by
    /// [`FrameMigrator::should_send_rfc_frames`].
    pub send_rfc_frames: bool,
    /// Strategy used to compute a priority for incoming `AddAddress` frames whose priority
    /// is zero.
    pub priority_strategy: PriorityCalculation,
}

impl Default for NatMigrationConfig {
    /// Transitional defaults: accept legacy frames, do not send RFC frames, and use
    /// ICE-style priorities. Same values as [`NatMigrationConfig::legacy_only`].
    fn default() -> Self {
        Self {
            accept_legacy_frames: true,
            send_rfc_frames: false,
            priority_strategy: PriorityCalculation::IceLike,
        }
    }
}

/// Strategy used by [`calculate_address_priority`] to rank a socket address.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PriorityCalculation {
    /// Priority shaped like the ICE candidate-priority formula:
    /// `(type_pref << 24) + (local_pref << 8) + (256 - component_id)`.
    IceLike,
    /// Coarse ranking by address class: loopback lowest, addresses not recognised as
    /// loopback/private/link-local highest.
    Simple,
    /// Always use the contained value, regardless of the address.
    Fixed(u32),
}

impl NatMigrationConfig {
    /// Strict RFC mode: only the RFC NAT traversal frame types are accepted and RFC frames
    /// are sent. Uses ICE-style priorities.
    pub fn rfc_compliant() -> Self {
        Self {
            accept_legacy_frames: false,
            send_rfc_frames: true,
            priority_strategy: PriorityCalculation::IceLike,
        }
    }

    /// Legacy mode: every frame type is accepted and RFC frames are not sent. Uses
    /// ICE-style priorities.
    pub fn legacy_only() -> Self {
        Self {
            accept_legacy_frames: true,
            send_rfc_frames: false,
            priority_strategy: PriorityCalculation::IceLike,
        }
    }
}
pub fn calculate_address_priority(addr: &SocketAddr, strategy: PriorityCalculation) -> u32 {
match strategy {
PriorityCalculation::Fixed(p) => p,
PriorityCalculation::Simple => simple_priority(addr),
PriorityCalculation::IceLike => ice_like_priority(addr),
}
}
/// Ranks `addr` on a coarse scale where larger values are preferred.
///
/// * IPv4: loopback 100, link-local (169.254.0.0/16) 150, private (RFC 1918) 200,
///   anything else 300.
/// * IPv6: loopback 50, unicast link-local (fe80::/10) 150, anything else 250.
fn simple_priority(addr: &SocketAddr) -> u32 {
    match addr {
        SocketAddr::V4(v4) => {
            let ip = v4.ip();
            if ip.is_loopback() {
                100
            } else if ip.is_link_local() {
                // Mirror the IPv6 branch: link-local addresses are only reachable on the
                // local segment, so they must not outrank private or public addresses.
                150
            } else if ip.is_private() {
                200
            } else {
                300
            }
        }
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            if ip.is_loopback() {
                50
            } else if ip.is_unicast_link_local() {
                150
            } else {
                250
            }
        }
    }
}
/// Computes an ICE-style priority for `addr`:
/// `(type_pref << 24) + (local_pref << 8) + (256 - component_id)`.
///
/// `type_pref` ranks the address class:
/// * IPv4: loopback 0, link-local (169.254.0.0/16) 90, private 100, other 126.
/// * IPv6: loopback 0, unicast link-local 90, other 120.
///
/// `local_pref` is 65535 for IPv4 and 65534 for IPv6, so IPv4 wins when `type_pref` is
/// equal. `component_id` is fixed at 1. The maximum result (`126 << 24` plus the rest)
/// fits comfortably in a `u32`.
fn ice_like_priority(addr: &SocketAddr) -> u32 {
    const COMPONENT_ID: u32 = 1;

    let (type_pref, local_pref): (u32, u32) = match addr {
        SocketAddr::V4(v4) => {
            let ip = v4.ip();
            let type_pref = if ip.is_loopback() {
                0
            } else if ip.is_link_local() {
                // Same preference as IPv6 link-local: usable only on the local segment.
                90
            } else if ip.is_private() {
                100
            } else {
                126
            };
            (type_pref, 65535)
        }
        SocketAddr::V6(v6) => {
            let ip = v6.ip();
            let type_pref = if ip.is_loopback() {
                0
            } else if ip.is_unicast_link_local() {
                90
            } else {
                120
            };
            (type_pref, 65534)
        }
    };

    (type_pref << 24) + (local_pref << 8) + (256 - COMPONENT_ID)
}
/// Applies a [`NatMigrationConfig`] to incoming NAT traversal frames.
#[derive(Debug, Clone)]
pub struct FrameMigrator {
    config: NatMigrationConfig,
}

impl FrameMigrator {
    /// Creates a migrator that applies `config`.
    pub fn new(config: NatMigrationConfig) -> Self {
        Self { config }
    }

    /// Returns the configured `send_rfc_frames` flag.
    pub fn should_send_rfc_frames(&self) -> bool {
        self.config.send_rfc_frames
    }

    /// Normalises an incoming frame before it is handled.
    ///
    /// An `AddAddress` frame whose priority is zero gets a priority computed from its
    /// address with the configured [`PriorityCalculation`]. Every other frame, including
    /// `AddAddress` frames with a non-zero priority, is returned unchanged.
    ///
    /// `_frame_type` and `_sender_addr` are currently unused.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`.
    pub fn process_incoming_frame(
        &self,
        _frame_type: FrameType,
        mut frame: Frame,
        _sender_addr: SocketAddr,
    ) -> Result<Frame, TransportError> {
        if let Frame::AddAddress(add) = &mut frame {
            if add.priority == VarInt::from_u32(0) {
                add.priority = VarInt::from_u32(calculate_address_priority(
                    &add.address,
                    self.config.priority_strategy,
                ));
            }
        }
        Ok(frame)
    }

    /// Returns whether a frame of `frame_type` should be accepted.
    ///
    /// With `accept_legacy_frames` enabled every frame type is accepted. Otherwise only
    /// `ADD_ADDRESS_IPV4`, `ADD_ADDRESS_IPV6`, `PUNCH_ME_NOW_IPV4`, `PUNCH_ME_NOW_IPV6`
    /// and `REMOVE_ADDRESS` are accepted.
    ///
    /// NOTE(review): in the strict branch *every* other frame type is rejected, not only
    /// legacy NAT traversal ones — confirm callers only pass NAT traversal frame types.
    pub fn should_accept_frame(&self, frame_type: FrameType) -> bool {
        self.config.accept_legacy_frames
            || matches!(
                frame_type,
                FrameType::ADD_ADDRESS_IPV4
                    | FrameType::ADD_ADDRESS_IPV6
                    | FrameType::PUNCH_ME_NOW_IPV4
                    | FrameType::PUNCH_ME_NOW_IPV6
                    | FrameType::REMOVE_ADDRESS
            )
    }
}
/// What is known about one peer's NAT traversal capabilities.
#[derive(Debug, Clone)]
pub struct PeerCapabilities {
    /// Identifier of the peer; also the key under which the entry is stored.
    pub peer_id: Vec<u8>,
    /// Whether the peer supports the RFC NAT traversal frames.
    pub supports_rfc_nat: bool,
    /// When this entry was last recorded; used for age-based eviction.
    pub discovered_at: std::time::Instant,
}

/// Remembers which peers have been marked as supporting RFC NAT traversal frames.
///
/// Entries are only removed by [`CapabilityTracker::cleanup_old_entries`].
#[derive(Debug, Default)]
pub struct CapabilityTracker {
    peers: std::collections::HashMap<Vec<u8>, PeerCapabilities>,
}

impl CapabilityTracker {
    /// Creates an empty tracker (equivalent to `CapabilityTracker::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Records `peer_id` as RFC-capable. Any previous entry is replaced, which also
    /// resets its `discovered_at` timestamp to now.
    pub fn mark_rfc_capable(&mut self, peer_id: Vec<u8>) {
        let entry = PeerCapabilities {
            peer_id: peer_id.clone(),
            supports_rfc_nat: true,
            discovered_at: std::time::Instant::now(),
        };
        self.peers.insert(peer_id, entry);
    }

    /// Returns `true` if `peer_id` has an entry marked RFC-capable; unknown peers yield
    /// `false`.
    pub fn is_rfc_capable(&self, peer_id: &[u8]) -> bool {
        self.peers
            .get(peer_id)
            .is_some_and(|cap| cap.supports_rfc_nat)
    }

    /// Drops every entry whose age is `max_age` or more. Passing `Duration::ZERO` clears
    /// all entries.
    pub fn cleanup_old_entries(&mut self, max_age: std::time::Duration) {
        // Sample the clock once so every entry is judged against the same instant.
        let now = std::time::Instant::now();
        self.peers
            .retain(|_, cap| now.duration_since(cap.discovered_at) < max_age);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_priority_calculation() {
        let public_v4: SocketAddr = "8.8.8.8:53".parse().unwrap();
        let private_v4: SocketAddr = "192.168.1.1:80".parse().unwrap();
        let loopback_v4: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let simple_pub = calculate_address_priority(&public_v4, PriorityCalculation::Simple);
        let simple_priv = calculate_address_priority(&private_v4, PriorityCalculation::Simple);
        let simple_loop = calculate_address_priority(&loopback_v4, PriorityCalculation::Simple);
        assert!(simple_pub > simple_priv);
        assert!(simple_priv > simple_loop);
        let ice_pub = calculate_address_priority(&public_v4, PriorityCalculation::IceLike);
        let ice_priv = calculate_address_priority(&private_v4, PriorityCalculation::IceLike);
        let ice_loop = calculate_address_priority(&loopback_v4, PriorityCalculation::IceLike);
        assert!(ice_pub > ice_priv);
        assert!(ice_priv > ice_loop);
        let fixed = calculate_address_priority(&public_v4, PriorityCalculation::Fixed(12345));
        assert_eq!(fixed, 12345);
    }

    /// The IPv6 branches of both strategies order global > link-local > loopback.
    #[test]
    fn test_priority_calculation_ipv6() {
        let global_v6: SocketAddr = "[2001:db8::1]:443".parse().unwrap();
        let link_local_v6: SocketAddr = "[fe80::1]:443".parse().unwrap();
        let loopback_v6: SocketAddr = "[::1]:443".parse().unwrap();
        for strategy in [PriorityCalculation::Simple, PriorityCalculation::IceLike] {
            let global = calculate_address_priority(&global_v6, strategy);
            let link_local = calculate_address_priority(&link_local_v6, strategy);
            let loopback = calculate_address_priority(&loopback_v6, strategy);
            assert!(global > link_local);
            assert!(link_local > loopback);
        }
    }

    #[test]
    fn test_migration_configs() {
        let default_config = NatMigrationConfig::default();
        assert!(default_config.accept_legacy_frames);
        assert!(!default_config.send_rfc_frames);
        let rfc_config = NatMigrationConfig::rfc_compliant();
        assert!(!rfc_config.accept_legacy_frames);
        assert!(rfc_config.send_rfc_frames);
        let legacy_config = NatMigrationConfig::legacy_only();
        assert!(legacy_config.accept_legacy_frames);
        assert!(!legacy_config.send_rfc_frames);
    }

    /// Strict and legacy migrators both accept the RFC NAT traversal frame types.
    #[test]
    fn test_frame_acceptance() {
        let strict = FrameMigrator::new(NatMigrationConfig::rfc_compliant());
        assert!(strict.should_send_rfc_frames());
        for frame_type in [
            FrameType::ADD_ADDRESS_IPV4,
            FrameType::ADD_ADDRESS_IPV6,
            FrameType::PUNCH_ME_NOW_IPV4,
            FrameType::PUNCH_ME_NOW_IPV6,
            FrameType::REMOVE_ADDRESS,
        ] {
            assert!(strict.should_accept_frame(frame_type));
        }
        let legacy = FrameMigrator::new(NatMigrationConfig::legacy_only());
        assert!(!legacy.should_send_rfc_frames());
        assert!(legacy.should_accept_frame(FrameType::REMOVE_ADDRESS));
    }

    #[test]
    fn test_capability_tracker() {
        let mut tracker = CapabilityTracker::new();
        let peer_id = vec![1, 2, 3, 4];
        assert!(!tracker.is_rfc_capable(&peer_id));
        tracker.mark_rfc_capable(peer_id.clone());
        assert!(tracker.is_rfc_capable(&peer_id));
        tracker.cleanup_old_entries(std::time::Duration::from_secs(3600));
        assert!(tracker.is_rfc_capable(&peer_id));
    }

    /// A zero `max_age` evicts every entry, since no age is strictly less than zero.
    #[test]
    fn test_capability_cleanup_evicts_expired() {
        let mut tracker = CapabilityTracker::new();
        let peer_id = vec![9, 9];
        tracker.mark_rfc_capable(peer_id.clone());
        tracker.cleanup_old_entries(std::time::Duration::ZERO);
        assert!(!tracker.is_rfc_capable(&peer_id));
    }
}