use std::net::SocketAddr;
use std::time::{Duration, Instant};
/// The path over which a connection was ultimately established.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionMethod {
    /// Direct connection over IPv4.
    DirectIPv4,
    /// Direct connection over IPv6.
    DirectIPv6,
    /// Hole-punched connection arranged through `coordinator`.
    HolePunched { coordinator: SocketAddr },
    /// Connection carried through the relay at `relay`.
    Relayed { relay: SocketAddr },
}

impl std::fmt::Display for ConnectionMethod {
    /// Writes a short human-readable label such as `"Relayed via 5.6.7.8:9001"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DirectIPv4 => f.write_str("Direct IPv4"),
            Self::DirectIPv6 => f.write_str("Direct IPv6"),
            Self::HolePunched { coordinator } => write!(f, "Hole-punched via {}", coordinator),
            Self::Relayed { relay } => write!(f, "Relayed via {}", relay),
        }
    }
}
/// One failed attempt, logged as the strategy falls through its stages.
#[derive(Debug, Clone)]
pub struct ConnectionAttemptError {
    /// The stage this failure is attributed to.
    pub method: AttemptedMethod,
    /// Human-readable description of what went wrong.
    pub error: String,
    /// When the failure was recorded (not when the attempt began).
    pub timestamp: Instant,
}

/// Identifies which stage an attempt error belongs to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AttemptedMethod {
    /// Direct IPv4 attempt.
    DirectIPv4,
    /// Direct IPv6 attempt.
    DirectIPv6,
    /// A hole-punch attempt during the given round.
    HolePunch { round: u32 },
    /// A relay attempt; which relay was tried is not recorded here.
    Relay,
}
/// Position of a `ConnectionStrategy` in its fallback sequence.
///
/// In-progress stages carry the `Instant` at which they were entered; the two
/// terminal stages (`Connected`, `Failed`) carry the outcome instead.
#[derive(Debug, Clone)]
pub enum ConnectionStage {
    /// Attempting a direct IPv4 connection.
    DirectIPv4 { started: Instant },
    /// Attempting a direct IPv6 connection.
    DirectIPv6 { started: Instant },
    /// Hole punching through `coordinator`; `started` is reset for every round.
    HolePunching { coordinator: SocketAddr, round: u32, started: Instant },
    /// Relaying through `relay_addr`, entry `relay_index` of the configured relay list.
    Relay { relay_addr: SocketAddr, relay_index: usize, started: Instant },
    /// Terminal: a connection was established via `via`.
    Connected { via: ConnectionMethod },
    /// Terminal: the strategy gave up; holds every error recorded along the way.
    Failed { errors: Vec<ConnectionAttemptError> },
}
/// Tunables for `ConnectionStrategy`: per-stage time budgets, which fallback
/// stages are enabled, and the servers those stages use.
#[derive(Debug, Clone)]
pub struct StrategyConfig {
    /// Time budget for the direct IPv4 stage.
    pub ipv4_timeout: Duration,
    /// Time budget for the direct IPv6 stage.
    pub ipv6_timeout: Duration,
    /// Time budget for each hole-punch round.
    pub holepunch_timeout: Duration,
    /// Time budget for each relay attempt.
    pub relay_timeout: Duration,
    /// Maximum number of hole-punch rounds.
    pub max_holepunch_rounds: u32,
    /// When `false`, the IPv6 stage is skipped.
    pub ipv6_enabled: bool,
    /// When `false`, the relay stage is skipped and the strategy fails instead.
    pub relay_enabled: bool,
    /// Hole-punch coordinator; without one, hole punching is skipped.
    pub coordinator: Option<SocketAddr>,
    /// Relay servers, tried in list order.
    pub relay_addrs: Vec<SocketAddr>,
}

impl Default for StrategyConfig {
    /// 5 s per direct stage, 15 s for hole punching, 30 s for relaying, up to
    /// 3 hole-punch rounds, IPv6 and relaying enabled, and no servers set.
    fn default() -> Self {
        let direct_timeout = Duration::from_secs(5);
        Self {
            ipv4_timeout: direct_timeout,
            ipv6_timeout: direct_timeout,
            holepunch_timeout: Duration::from_secs(15),
            relay_timeout: Duration::from_secs(30),
            max_holepunch_rounds: 3,
            ipv6_enabled: true,
            relay_enabled: true,
            coordinator: None,
            relay_addrs: Vec::new(),
        }
    }
}

impl StrategyConfig {
    /// Same as [`StrategyConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the direct IPv4 time budget.
    pub fn with_ipv4_timeout(self, timeout: Duration) -> Self {
        Self { ipv4_timeout: timeout, ..self }
    }

    /// Sets the direct IPv6 time budget.
    pub fn with_ipv6_timeout(self, timeout: Duration) -> Self {
        Self { ipv6_timeout: timeout, ..self }
    }

    /// Sets the per-round hole-punch time budget.
    pub fn with_holepunch_timeout(self, timeout: Duration) -> Self {
        Self { holepunch_timeout: timeout, ..self }
    }

    /// Sets the relay time budget.
    pub fn with_relay_timeout(self, timeout: Duration) -> Self {
        Self { relay_timeout: timeout, ..self }
    }

    /// Sets the maximum number of hole-punch rounds.
    pub fn with_max_holepunch_rounds(self, rounds: u32) -> Self {
        Self { max_holepunch_rounds: rounds, ..self }
    }

    /// Enables or disables the IPv6 stage.
    pub fn with_ipv6_enabled(self, enabled: bool) -> Self {
        Self { ipv6_enabled: enabled, ..self }
    }

    /// Enables or disables the relay stage.
    pub fn with_relay_enabled(self, enabled: bool) -> Self {
        Self { relay_enabled: enabled, ..self }
    }

    /// Sets the hole-punch coordinator.
    pub fn with_coordinator(self, addr: SocketAddr) -> Self {
        Self { coordinator: Some(addr), ..self }
    }

    /// Appends one relay to the end of the relay list.
    pub fn with_relay(mut self, addr: SocketAddr) -> Self {
        self.relay_addrs.push(addr);
        self
    }

    /// Replaces the entire relay list; relays added earlier are discarded.
    pub fn with_relays(self, addrs: Vec<SocketAddr>) -> Self {
        Self { relay_addrs: addrs, ..self }
    }
}
/// State machine that walks a connection attempt through direct IPv4, direct
/// IPv6, hole punching and relaying, skipping stages the config disables.
///
/// It only tracks state: performing each attempt and enforcing the configured
/// time budgets is left to the caller.
#[derive(Debug)]
pub struct ConnectionStrategy {
    /// Current position in the fallback sequence.
    stage: ConnectionStage,
    /// Configuration given at construction; the coordinator can be replaced later.
    config: StrategyConfig,
    /// Errors recorded so far; moved into `ConnectionStage::Failed` when the strategy fails.
    errors: Vec<ConnectionAttemptError>,
}
impl ConnectionStrategy {
pub fn new(config: StrategyConfig) -> Self {
Self {
stage: ConnectionStage::DirectIPv4 {
started: Instant::now(),
},
config,
errors: Vec::new(),
}
}
pub fn current_stage(&self) -> &ConnectionStage {
&self.stage
}
pub fn config(&self) -> &StrategyConfig {
&self.config
}
pub fn ipv4_timeout(&self) -> Duration {
self.config.ipv4_timeout
}
pub fn ipv6_timeout(&self) -> Duration {
self.config.ipv6_timeout
}
pub fn holepunch_timeout(&self) -> Duration {
self.config.holepunch_timeout
}
pub fn relay_timeout(&self) -> Duration {
self.config.relay_timeout
}
pub fn transition_to_ipv6(&mut self, error: impl Into<String>) {
self.errors.push(ConnectionAttemptError {
method: AttemptedMethod::DirectIPv4,
error: error.into(),
timestamp: Instant::now(),
});
if self.config.ipv6_enabled {
self.stage = ConnectionStage::DirectIPv6 {
started: Instant::now(),
};
} else {
self.transition_to_holepunch_internal();
}
}
pub fn transition_to_holepunch(&mut self, error: impl Into<String>) {
self.errors.push(ConnectionAttemptError {
method: AttemptedMethod::DirectIPv6,
error: error.into(),
timestamp: Instant::now(),
});
self.transition_to_holepunch_internal();
}
fn transition_to_holepunch_internal(&mut self) {
if let Some(coordinator) = self.config.coordinator {
self.stage = ConnectionStage::HolePunching {
coordinator,
round: 1,
started: Instant::now(),
};
} else {
self.transition_to_relay_internal();
}
}
pub fn record_holepunch_error(&mut self, round: u32, error: impl Into<String>) {
self.errors.push(ConnectionAttemptError {
method: AttemptedMethod::HolePunch { round },
error: error.into(),
timestamp: Instant::now(),
});
}
pub fn should_retry_holepunch(&self) -> bool {
if let ConnectionStage::HolePunching { round, .. } = &self.stage {
*round < self.config.max_holepunch_rounds
} else {
false
}
}
pub fn set_coordinator(&mut self, coordinator: SocketAddr) {
if let ConnectionStage::HolePunching { coordinator: c, .. } = &mut self.stage {
*c = coordinator;
}
self.config.coordinator = Some(coordinator);
}
pub fn increment_round(&mut self) {
if let ConnectionStage::HolePunching {
coordinator, round, ..
} = &self.stage
{
self.stage = ConnectionStage::HolePunching {
coordinator: *coordinator,
round: round + 1,
started: Instant::now(),
};
}
}
pub fn transition_to_relay(&mut self, error: impl Into<String>) {
if let ConnectionStage::HolePunching { round, .. } = &self.stage {
self.errors.push(ConnectionAttemptError {
method: AttemptedMethod::HolePunch { round: *round },
error: error.into(),
timestamp: Instant::now(),
});
}
self.transition_to_relay_internal();
}
pub fn transition_to_next_relay(&mut self, error: impl Into<String>) {
if let ConnectionStage::Relay { relay_index, .. } = &self.stage {
self.errors.push(ConnectionAttemptError {
method: AttemptedMethod::Relay,
error: error.into(),
timestamp: Instant::now(),
});
let next_index = relay_index + 1;
if next_index < self.config.relay_addrs.len() {
self.stage = ConnectionStage::Relay {
relay_addr: self.config.relay_addrs[next_index],
relay_index: next_index,
started: Instant::now(),
};
} else {
self.stage = ConnectionStage::Failed {
errors: std::mem::take(&mut self.errors),
};
}
}
}
fn transition_to_relay_internal(&mut self) {
if self.config.relay_enabled && !self.config.relay_addrs.is_empty() {
self.stage = ConnectionStage::Relay {
relay_addr: self.config.relay_addrs[0],
relay_index: 0,
started: Instant::now(),
};
} else if !self.config.relay_enabled {
self.transition_to_failed("Relay disabled and all other methods failed");
} else {
self.transition_to_failed("No relay servers configured");
}
}
pub fn transition_to_failed(&mut self, error: impl Into<String>) {
if let ConnectionStage::Relay { .. } = &self.stage {
self.errors.push(ConnectionAttemptError {
method: AttemptedMethod::Relay,
error: error.into(),
timestamp: Instant::now(),
});
}
self.stage = ConnectionStage::Failed {
errors: std::mem::take(&mut self.errors),
};
}
pub fn mark_connected(&mut self, method: ConnectionMethod) {
self.stage = ConnectionStage::Connected { via: method };
}
pub fn is_terminal(&self) -> bool {
matches!(
self.stage,
ConnectionStage::Connected { .. } | ConnectionStage::Failed { .. }
)
}
pub fn errors(&self) -> &[ConnectionAttemptError] {
&self.errors
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a socket-address fixture.
    fn addr(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// Builds a strategy from `config` and fails both direct stages.
    fn after_direct_failures(config: StrategyConfig) -> ConnectionStrategy {
        let mut strategy = ConnectionStrategy::new(config);
        strategy.transition_to_ipv6("IPv4 failed");
        strategy.transition_to_holepunch("IPv6 failed");
        strategy
    }

    /// `(port, index)` of the relay in use; panics outside the relay stage.
    fn relay_position(strategy: &ConnectionStrategy) -> (u16, usize) {
        match strategy.current_stage() {
            ConnectionStage::Relay {
                relay_addr,
                relay_index,
                ..
            } => (relay_addr.port(), *relay_index),
            _ => panic!("Expected Relay stage"),
        }
    }

    /// Current hole-punch round; panics outside the hole-punch stage.
    fn holepunch_round(strategy: &ConnectionStrategy) -> u32 {
        match strategy.current_stage() {
            ConnectionStage::HolePunching { round, .. } => *round,
            _ => panic!("Expected HolePunching stage"),
        }
    }

    /// Number of errors carried by the failed stage; panics if not failed.
    fn failure_count(strategy: &ConnectionStrategy) -> usize {
        match strategy.current_stage() {
            ConnectionStage::Failed { errors } => errors.len(),
            _ => panic!("Expected Failed stage"),
        }
    }

    #[test]
    fn test_default_config() {
        let config = StrategyConfig::default();
        assert_eq!(config.ipv4_timeout, Duration::from_secs(5));
        assert_eq!(config.ipv6_timeout, Duration::from_secs(5));
        assert_eq!(config.holepunch_timeout, Duration::from_secs(15));
        assert_eq!(config.relay_timeout, Duration::from_secs(30));
        assert_eq!(config.max_holepunch_rounds, 3);
        assert!(config.ipv6_enabled);
        assert!(config.relay_enabled);
    }

    #[test]
    fn test_config_builder() {
        let config = StrategyConfig::new()
            .with_ipv4_timeout(Duration::from_secs(3))
            .with_ipv6_timeout(Duration::from_secs(3))
            .with_max_holepunch_rounds(5)
            .with_ipv6_enabled(false);
        assert_eq!(config.ipv4_timeout, Duration::from_secs(3));
        assert_eq!(config.max_holepunch_rounds, 5);
        assert!(!config.ipv6_enabled);
    }

    #[test]
    fn test_initial_stage() {
        let strategy = ConnectionStrategy::new(StrategyConfig::default());
        assert!(matches!(
            strategy.current_stage(),
            ConnectionStage::DirectIPv4 { .. }
        ));
    }

    #[test]
    fn test_transition_ipv4_to_ipv6() {
        let mut strategy = ConnectionStrategy::new(StrategyConfig::default());
        strategy.transition_to_ipv6("Connection refused");
        assert!(matches!(
            strategy.current_stage(),
            ConnectionStage::DirectIPv6 { .. }
        ));
        let errors = strategy.errors();
        assert_eq!(errors.len(), 1);
        assert_eq!(errors[0].method, AttemptedMethod::DirectIPv4);
    }

    #[test]
    fn test_skip_ipv6_when_disabled() {
        let config = StrategyConfig::new()
            .with_ipv6_enabled(false)
            .with_coordinator(addr("127.0.0.1:9000"));
        let mut strategy = ConnectionStrategy::new(config);
        strategy.transition_to_ipv6("Connection refused");
        assert_eq!(holepunch_round(&strategy), 1);
    }

    #[test]
    fn test_transition_to_holepunch() {
        let strategy =
            after_direct_failures(StrategyConfig::new().with_coordinator(addr("127.0.0.1:9000")));
        assert!(matches!(
            strategy.current_stage(),
            ConnectionStage::HolePunching {
                round: 1,
                coordinator,
                ..
            } if coordinator.port() == 9000
        ));
    }

    #[test]
    fn test_holepunch_rounds() {
        let config = StrategyConfig::new()
            .with_coordinator(addr("127.0.0.1:9000"))
            .with_max_holepunch_rounds(3);
        let mut strategy = after_direct_failures(config);
        for failed_round in 1..=2u32 {
            assert!(strategy.should_retry_holepunch());
            strategy.record_holepunch_error(failed_round, format!("Round {} failed", failed_round));
            strategy.increment_round();
            assert_eq!(holepunch_round(&strategy), failed_round + 1);
        }
        assert!(!strategy.should_retry_holepunch());
    }

    #[test]
    fn test_transition_to_relay() {
        let config = StrategyConfig::new()
            .with_coordinator(addr("127.0.0.1:9000"))
            .with_relay(addr("127.0.0.1:9001"));
        let mut strategy = after_direct_failures(config);
        strategy.transition_to_relay("Holepunch failed");
        assert_eq!(relay_position(&strategy), (9001, 0));
    }

    #[test]
    fn test_transition_to_failed() {
        let config = StrategyConfig::new()
            .with_coordinator(addr("127.0.0.1:9000"))
            .with_relay(addr("127.0.0.1:9001"));
        let mut strategy = after_direct_failures(config);
        strategy.transition_to_relay("Holepunch failed");
        strategy.transition_to_failed("Relay failed");
        assert_eq!(failure_count(&strategy), 4);
    }

    #[test]
    fn test_mark_connected() {
        let mut strategy = ConnectionStrategy::new(StrategyConfig::default());
        strategy.mark_connected(ConnectionMethod::DirectIPv4);
        match strategy.current_stage() {
            ConnectionStage::Connected { via } => assert_eq!(*via, ConnectionMethod::DirectIPv4),
            _ => panic!("Expected Connected stage"),
        }
        assert!(strategy.is_terminal());
    }

    #[test]
    fn test_connection_method_display() {
        let cases = [
            (ConnectionMethod::DirectIPv4, "Direct IPv4"),
            (ConnectionMethod::DirectIPv6, "Direct IPv6"),
            (
                ConnectionMethod::HolePunched {
                    coordinator: addr("1.2.3.4:9000"),
                },
                "Hole-punched via 1.2.3.4:9000",
            ),
            (
                ConnectionMethod::Relayed {
                    relay: addr("5.6.7.8:9001"),
                },
                "Relayed via 5.6.7.8:9001",
            ),
        ];
        for (method, expected) in cases.iter() {
            assert_eq!(method.to_string(), *expected);
        }
    }

    #[test]
    fn test_no_coordinator_skips_to_relay() {
        let strategy =
            after_direct_failures(StrategyConfig::new().with_relay(addr("127.0.0.1:9001")));
        assert!(matches!(
            strategy.current_stage(),
            ConnectionStage::Relay { .. }
        ));
    }

    #[test]
    fn test_no_relay_fails() {
        let config = StrategyConfig::new()
            .with_coordinator(addr("127.0.0.1:9000"))
            .with_relay_enabled(false);
        let mut strategy = after_direct_failures(config);
        strategy.transition_to_relay("Holepunch failed");
        assert!(matches!(
            strategy.current_stage(),
            ConnectionStage::Failed { .. }
        ));
    }

    #[test]
    fn test_multi_relay_fallback() {
        let config = StrategyConfig::new()
            .with_coordinator(addr("127.0.0.1:9000"))
            .with_relay(addr("127.0.0.1:9001"))
            .with_relay(addr("127.0.0.1:9002"))
            .with_relay(addr("127.0.0.1:9003"));
        let mut strategy = after_direct_failures(config);
        strategy.transition_to_relay("Holepunch failed");
        assert_eq!(relay_position(&strategy), (9001, 0));
        strategy.transition_to_next_relay("Relay 1 failed");
        assert_eq!(relay_position(&strategy), (9002, 1));
        strategy.transition_to_next_relay("Relay 2 failed");
        assert_eq!(relay_position(&strategy), (9003, 2));
        strategy.transition_to_next_relay("Relay 3 failed");
        assert_eq!(failure_count(&strategy), 6);
    }

    #[test]
    fn test_with_relays_vec() {
        let config = StrategyConfig::new()
            .with_relays(vec![addr("127.0.0.1:9001"), addr("127.0.0.1:9002")]);
        assert_eq!(config.relay_addrs.len(), 2);
    }

    #[test]
    fn test_single_relay_still_works() {
        let mut strategy =
            after_direct_failures(StrategyConfig::new().with_relay(addr("127.0.0.1:9001")));
        strategy.transition_to_relay("Holepunch failed");
        assert_eq!(relay_position(&strategy).0, 9001);
        strategy.transition_to_next_relay("Relay failed");
        assert!(matches!(
            strategy.current_stage(),
            ConnectionStage::Failed { .. }
        ));
    }
}