use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use tokio::net::UdpSocket;
use tokio::sync::broadcast;
use tracing::debug;
/// STUN servers used when no explicit list is configured, as `host:port` strings.
///
/// `ConnectivityConfig::default()` copies these into `stun_servers`.
pub const DEFAULT_STUN_SERVERS: &[&str] = &[
"stun.l.google.com:19302",
"stun1.l.google.com:19302",
"stun.cloudflare.com:3478",
"stun.stunprotocol.org:3478",
];
/// DNS servers used when no explicit list is configured, as `ip:port` strings.
///
/// `ConnectivityConfig::default()` copies these into `dns_servers`.
pub const DEFAULT_DNS_SERVERS: &[&str] = &[
"8.8.8.8:53",
"1.1.1.1:53",
"9.9.9.9:53",
];
/// Settings that control how connectivity probes are performed.
#[derive(Debug, Clone)]
pub struct ConnectivityConfig {
/// STUN servers to probe, tried in list order, as `host:port` strings.
pub stun_servers: Vec<String>,
/// DNS servers to probe, tried in list order, once every STUN probe has failed.
pub dns_servers: Vec<String>,
/// How long a single probe waits for the server's reply.
pub timeout: Duration,
/// Interval between probes.
/// NOTE(review): not read anywhere in this file — presumably used by whoever
/// schedules calls to `ConnectivityProber::probe`; confirm against the caller.
pub probe_interval: Duration,
/// Controls how many times each server is attempted before moving on to the next one.
pub retries: u32,
}
impl Default for ConnectivityConfig {
    /// Builds the stock configuration: the built-in STUN and DNS server lists,
    /// a 3 second timeout, a 30 second probe interval and 2 retries.
    fn default() -> Self {
        // Turns one of the static `&str` tables into the owned list the config stores.
        let owned = |servers: &[&str]| -> Vec<String> {
            servers.iter().map(|entry| String::from(*entry)).collect()
        };

        let stun_servers = owned(DEFAULT_STUN_SERVERS);
        let dns_servers = owned(DEFAULT_DNS_SERVERS);

        Self {
            stun_servers,
            dns_servers,
            timeout: Duration::from_secs(3),
            probe_interval: Duration::from_secs(30),
            retries: 2,
        }
    }
}
/// Outcome of probing one interface for internet connectivity.
#[derive(Debug, Clone)]
pub struct ConnectivityResult {
/// Name of the interface that was probed.
pub interface: String,
/// Local address of the probing socket; when every probe failed this is the
/// probed IP with port 0.
pub local_addr: SocketAddr,
/// `true` when a STUN or DNS probe got an acceptable reply, `false` when all probes failed.
pub has_internet: bool,
/// Address reported by the STUN server; `None` for DNS-only or failed results.
pub external_addr: Option<SocketAddr>,
/// NAT classification derived from the STUN reply; `Unknown` for DNS-only or failed results.
pub nat_type: NatType,
/// Elapsed time measured for the successful probe; `None` when all probes failed.
pub rtt: Option<Duration>,
/// Moment at which this result was created.
pub timestamp: Instant,
/// Failure description when all probes failed; `None` on success.
pub error: Option<String>,
}
/// NAT classification attached to a [`ConnectivityResult`].
///
/// Only `None`, `PortRestrictedCone` and `Unknown` are produced by the code in
/// this file; the remaining variants are never assigned here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NatType {
/// The STUN-reported address and port both equal the local socket's.
None,
/// Not assigned anywhere in this file.
FullCone,
/// Not assigned anywhere in this file.
RestrictedCone,
/// Assigned when the STUN-reported IP equals the local IP but the port differs.
PortRestrictedCone,
/// Not assigned anywhere in this file.
Symmetric,
/// Used whenever no classification could be made (differing IPs, DNS-only
/// results, failed probes).
Unknown,
}
impl std::fmt::Display for NatType {
    /// Writes the human-readable label of the NAT type.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Select the label first, then emit it with a single write.
        let label = match *self {
            NatType::None => "No NAT",
            NatType::FullCone => "Full Cone",
            NatType::RestrictedCone => "Restricted Cone",
            NatType::PortRestrictedCone => "Port Restricted Cone",
            NatType::Symmetric => "Symmetric",
            NatType::Unknown => "Unknown",
        };
        f.write_str(label)
    }
}
/// Runs connectivity probes per interface, caches the latest result for each
/// one and broadcasts every new result to subscribers.
pub struct ConnectivityProber {
/// Server lists, timeout and retry settings used by the probes.
config: ConnectivityConfig,
/// Most recent result per interface name.
results: Arc<RwLock<HashMap<String, ConnectivityResult>>>,
/// Sender side of the result broadcast; receivers are handed out by `subscribe`.
event_tx: broadcast::Sender<ConnectivityResult>,
/// Signalled by `stop`. NOTE(review): no receiver is created in this file —
/// presumably consumed by a background task defined elsewhere; confirm.
shutdown_tx: broadcast::Sender<()>,
}
impl ConnectivityProber {
/// Creates a prober with an empty result cache and fresh broadcast channels.
///
/// The result channel is created with capacity 64 and the shutdown channel
/// with capacity 1; the initial receivers are dropped immediately.
pub fn new(config: ConnectivityConfig) -> Self {
    let event_tx = broadcast::channel(64).0;
    let shutdown_tx = broadcast::channel(1).0;
    let results = Arc::new(RwLock::new(HashMap::new()));

    Self {
        config,
        results,
        event_tx,
        shutdown_tx,
    }
}
/// Returns a new receiver for the results published by `probe`.
pub fn subscribe(&self) -> broadcast::Receiver<ConnectivityResult> {
    broadcast::Sender::subscribe(&self.event_tx)
}
/// Returns a copy of the most recent result recorded for `interface`, or
/// `None` when that interface has not been probed yet.
pub fn get_result(&self, interface: &str) -> Option<ConnectivityResult> {
    let cache = self.results.read();
    match cache.get(interface) {
        Some(latest) => Some(latest.clone()),
        None => None,
    }
}
/// Returns a snapshot copy of the latest result of every probed interface,
/// keyed by interface name.
pub fn all_results(&self) -> HashMap<String, ConnectivityResult> {
    let cache = self.results.read();
    (*cache).clone()
}
/// Probes `interface` for internet connectivity and publishes the outcome.
///
/// STUN is tried first; if no STUN server answers, DNS is tried; if that fails
/// too, a failure result (`has_internet == false`, `error` set) is produced.
/// Whatever the outcome, the result is stored as the latest result for
/// `interface`, broadcast to subscribers and returned.
///
/// * `interface` - name of the interface to probe; also the cache key.
/// * `local_addr` - local IP address the probe sockets are bound to.
pub async fn probe(&self, interface: &str, local_addr: IpAddr) -> ConnectivityResult {
    let result = match self.probe_stun(interface, local_addr).await {
        Some(stun) => stun,
        None => match self.probe_dns(interface, local_addr).await {
            Some(dns) => dns,
            None => ConnectivityResult {
                interface: interface.to_string(),
                local_addr: SocketAddr::new(local_addr, 0),
                has_internet: false,
                external_addr: None,
                nat_type: NatType::Unknown,
                rtt: None,
                timestamp: Instant::now(),
                error: Some("All connectivity probes failed".to_string()),
            },
        },
    };

    // Single publish path for all three outcomes. The write guard is a
    // temporary of this statement, so it is never held across an `.await`.
    self.results
        .write()
        .insert(interface.to_string(), result.clone());
    // The send result is deliberately ignored: publishing is best-effort.
    let _ = self.event_tx.send(result.clone());
    result
}
/// Tries every configured STUN server in order and returns the first
/// successful result, or `None` when all attempts against all servers failed.
///
/// Each server is attempted `config.retries` times, but always at least once:
/// a configuration with `retries == 0` would otherwise skip probing entirely
/// and report "no connectivity" without sending a single packet.
async fn probe_stun(&self, interface: &str, local_addr: IpAddr) -> Option<ConnectivityResult> {
    let attempts = self.config.retries.max(1);
    for server in &self.config.stun_servers {
        for attempt in 1..=attempts {
            match self.do_stun_probe(interface, local_addr, server).await {
                Ok(result) => return Some(result),
                Err(e) => {
                    debug!("STUN probe to {} failed (attempt {}): {}", server, attempt, e);
                }
            }
        }
    }
    None
}
/// Performs one STUN binding transaction against `server` from a socket bound
/// to `local_addr` / `interface`.
///
/// `server` may be a `host:port` name or an `ip:port` literal; it is resolved
/// with `tokio::net::lookup_host` (bounded by `config.timeout`) and the first
/// address in the same family as `local_addr` is used. Parsing the string
/// directly as a `SocketAddr` would reject every hostname.
///
/// On success the result carries the address reported by the server, a NAT
/// classification and the round-trip time measured from just before the
/// request is sent until the receive completes.
///
/// # Errors
/// Returns a description of the failure: resolution failure or timeout, no
/// address in the right family, socket creation/send/receive errors, response
/// timeout, or an unparseable response.
async fn do_stun_probe(
    &self,
    interface: &str,
    local_addr: IpAddr,
    server: &str,
) -> Result<ConnectivityResult, String> {
    let server_addr = tokio::time::timeout(self.config.timeout, tokio::net::lookup_host(server))
        .await
        .map_err(|_| "STUN server resolution timeout".to_string())?
        .map_err(|e| format!("Invalid STUN server address: {}", e))?
        // A socket bound to an IPv4 address cannot reach an IPv6 server and
        // vice versa, so only same-family candidates are usable.
        .find(|candidate| candidate.is_ipv4() == local_addr.is_ipv4())
        .ok_or_else(|| "STUN server has no address in the local address family".to_string())?;

    let bind_addr = SocketAddr::new(local_addr, 0);
    let socket = self
        .create_bound_socket(bind_addr, interface)
        .await
        .map_err(|e| format!("Failed to create socket: {}", e))?;

    let txn_id: [u8; 12] = rand::random();
    let request = build_stun_binding_request(&txn_id);

    // Start the clock right before sending so that resolution and socket
    // setup are not counted as network round-trip time.
    let start = Instant::now();
    socket
        .send_to(&request, server_addr)
        .await
        .map_err(|e| format!("Failed to send STUN request: {}", e))?;

    let mut buf = [0u8; 1024];
    let timeout_result =
        tokio::time::timeout(self.config.timeout, socket.recv_from(&mut buf)).await;
    let rtt = start.elapsed();
    let (len, _from) = timeout_result
        .map_err(|_| "STUN response timeout".to_string())?
        .map_err(|e| format!("Failed to receive STUN response: {}", e))?;

    let external_addr = parse_stun_response(&buf[..len], &txn_id)
        .ok_or_else(|| "Failed to parse STUN response".to_string())?;
    let local = socket.local_addr().map_err(|e| e.to_string())?;

    // NOTE(review): a single binding request cannot distinguish NAT flavours;
    // the "same IP, different port" => PortRestrictedCone mapping is kept
    // as-is from the original code and should be confirmed.
    let nat_type = if external_addr.ip() != local.ip() {
        NatType::Unknown
    } else if external_addr.port() == local.port() {
        NatType::None
    } else {
        NatType::PortRestrictedCone
    };

    Ok(ConnectivityResult {
        interface: interface.to_string(),
        local_addr: local,
        has_internet: true,
        external_addr: Some(external_addr),
        nat_type,
        rtt: Some(rtt),
        timestamp: Instant::now(),
        error: None,
    })
}
/// Tries every configured DNS server in order and returns the first
/// successful result, or `None` when all attempts against all servers failed.
///
/// Each server is attempted `config.retries` times, but always at least once:
/// a configuration with `retries == 0` would otherwise skip probing entirely
/// and report "no connectivity" without sending a single packet.
async fn probe_dns(&self, interface: &str, local_addr: IpAddr) -> Option<ConnectivityResult> {
    let attempts = self.config.retries.max(1);
    for server in &self.config.dns_servers {
        for attempt in 1..=attempts {
            match self.do_dns_probe(interface, local_addr, server).await {
                Ok(result) => return Some(result),
                Err(e) => {
                    debug!("DNS probe to {} failed (attempt {}): {}", server, attempt, e);
                }
            }
        }
    }
    None
}
/// Sends one DNS query to `server` from a socket bound to `local_addr` /
/// `interface` and treats a matching reply as proof of connectivity.
///
/// A reply is accepted only when it is at least a full 12-byte DNS header,
/// echoes the query's transaction ID and has the QR (response) bit set, so a
/// stray datagram arriving on the socket is not mistaken for an answer.
/// The round-trip time is measured from just before the query is sent until
/// the receive completes.
///
/// # Errors
/// Returns a description of the failure: unparsable server address, socket
/// creation/send/receive errors, response timeout, or a reply that fails the
/// checks above.
async fn do_dns_probe(
    &self,
    interface: &str,
    local_addr: IpAddr,
    server: &str,
) -> Result<ConnectivityResult, String> {
    let server_addr: SocketAddr = server
        .parse()
        .map_err(|e| format!("Invalid DNS server address: {}", e))?;
    let bind_addr = SocketAddr::new(local_addr, 0);
    let socket = self
        .create_bound_socket(bind_addr, interface)
        .await
        .map_err(|e| format!("Failed to create socket: {}", e))?;

    let query = build_dns_query();

    // Start the clock right before sending so socket setup is not counted.
    let start = Instant::now();
    socket
        .send_to(&query, server_addr)
        .await
        .map_err(|e| format!("Failed to send DNS query: {}", e))?;

    let mut buf = [0u8; 512];
    let timeout_result =
        tokio::time::timeout(self.config.timeout, socket.recv_from(&mut buf)).await;
    let rtt = start.elapsed();
    let (len, _) = timeout_result
        .map_err(|_| "DNS response timeout".to_string())?
        .map_err(|e| format!("Failed to receive DNS response: {}", e))?;

    if len < 12 {
        return Err("DNS response too short".to_string());
    }
    // The first two header bytes are the transaction ID chosen by the query.
    if buf[..2] != query[..2] {
        return Err("DNS response transaction ID mismatch".to_string());
    }
    // Top bit of the flags byte is QR: 0 = query, 1 = response.
    if buf[2] & 0x80 == 0 {
        return Err("DNS reply is not a response".to_string());
    }

    let local = socket.local_addr().map_err(|e| e.to_string())?;
    Ok(ConnectivityResult {
        interface: interface.to_string(),
        local_addr: local,
        has_internet: true,
        external_addr: None,
        nat_type: NatType::Unknown,
        rtt: Some(rtt),
        timestamp: Instant::now(),
        error: None,
    })
}
/// Creates a non-blocking UDP socket bound to `bind_addr` and, where the
/// platform supports it, pinned to the network interface named `interface`.
///
/// * Linux: `SO_BINDTODEVICE` with the interface name.
/// * macOS: `IP_BOUND_IF` for IPv4 sockets and `IPV6_BOUND_IF` for IPv6
///   sockets, using the interface index (the IPv4 option has no effect on an
///   IPv6 socket).
/// * Other platforms: only the address binding is applied.
///
/// Failure to pin the socket to the interface is logged at debug level and is
/// not fatal; the address binding still applies.
///
/// # Errors
/// Returns the I/O error from creating, configuring or binding the socket, or
/// from registering it with tokio.
async fn create_bound_socket(
    &self,
    bind_addr: SocketAddr,
    interface: &str,
) -> std::io::Result<UdpSocket> {
    use socket2::{Domain, Protocol, Socket, Type};

    let domain = if bind_addr.is_ipv6() { Domain::IPV6 } else { Domain::IPV4 };
    let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;

    #[cfg(target_os = "linux")]
    {
        use std::ffi::CString;
        use std::os::unix::io::AsRawFd;
        // A name with an interior NUL cannot be a valid interface name; skip pinning.
        if let Ok(cname) = CString::new(interface) {
            let name = cname.as_bytes_with_nul();
            // SAFETY: the fd belongs to `socket`, which is alive for the whole
            // call; `name` points to `name.len()` readable bytes (including the
            // terminating NUL) owned by `cname`, which outlives the call.
            let ret = unsafe {
                libc::setsockopt(
                    socket.as_raw_fd(),
                    libc::SOL_SOCKET,
                    libc::SO_BINDTODEVICE,
                    name.as_ptr() as *const libc::c_void,
                    name.len() as libc::socklen_t,
                )
            };
            if ret != 0 {
                debug!("SO_BINDTODEVICE failed for {}: {}, falling back to address binding",
                    interface, std::io::Error::last_os_error());
            }
        }
    }

    #[cfg(target_os = "macos")]
    {
        if let Some(idx) = super::if_nametoindex(interface) {
            use std::os::unix::io::AsRawFd;
            // The bound-interface option lives at a different level/number for
            // each address family; pick the one matching the socket's domain.
            let (level, option, option_name) = if bind_addr.is_ipv6() {
                (libc::IPPROTO_IPV6, libc::IPV6_BOUND_IF, "IPV6_BOUND_IF")
            } else {
                (libc::IPPROTO_IP, libc::IP_BOUND_IF, "IP_BOUND_IF")
            };
            // SAFETY: the fd belongs to `socket`, which is alive for the whole
            // call; the option value points to `idx`, a live `u32` local, and
            // the length passed is exactly `size_of::<u32>()`.
            let ret = unsafe {
                libc::setsockopt(
                    socket.as_raw_fd(),
                    level,
                    option,
                    &idx as *const u32 as *const libc::c_void,
                    std::mem::size_of::<u32>() as libc::socklen_t,
                )
            };
            if ret != 0 {
                debug!("{} failed for {}: {}, falling back to address binding",
                    option_name, interface, std::io::Error::last_os_error());
            }
        }
    }

    // No interface pinning is available elsewhere; mark the parameter as used.
    #[cfg(not(any(target_os = "linux", target_os = "macos")))]
    let _ = interface;

    // tokio requires the underlying socket to be in non-blocking mode.
    socket.set_nonblocking(true)?;
    socket.bind(&bind_addr.into())?;
    UdpSocket::from_std(socket.into())
}
/// Broadcasts a shutdown signal on the shutdown channel.
///
/// The outcome of the send is intentionally discarded.
pub fn stop(&self) {
    drop(self.shutdown_tx.send(()));
}
}
/// Builds a 20-byte STUN binding request carrying no attributes.
///
/// Layout: message type `0x0001`, message length `0x0000`, the magic cookie
/// `0x2112A442`, then the 12-byte transaction ID `txn_id`.
fn build_stun_binding_request(txn_id: &[u8; 12]) -> Vec<u8> {
    // Type, length and magic cookie never change between requests.
    const FIXED_PREFIX: [u8; 8] = [0x00, 0x01, 0x00, 0x00, 0x21, 0x12, 0xa4, 0x42];

    let mut packet = Vec::with_capacity(20);
    packet.extend(FIXED_PREFIX.iter().chain(txn_id.iter()).copied());
    packet
}
/// Extracts the reflexive (externally visible) address from a STUN binding
/// success response.
///
/// The message must be at least 20 bytes, have type `0x0101`, carry the magic
/// cookie `0x2112A442` and echo `expected_txn_id`; otherwise `None` is
/// returned. Attributes are then scanned within the declared message length
/// (clamped to the bytes actually present).
///
/// XOR-MAPPED-ADDRESS (`0x0020`) is preferred wherever it appears in the
/// message; the first decodable MAPPED-ADDRESS (`0x0001`) is used only as a
/// fallback, because the plain attribute exists for backward compatibility.
///
/// Returns `None` when the header is invalid or no address attribute could be
/// decoded.
fn parse_stun_response(data: &[u8], expected_txn_id: &[u8; 12]) -> Option<SocketAddr> {
    const HEADER_LEN: usize = 20;
    const BINDING_SUCCESS: [u8; 2] = [0x01, 0x01];
    const MAGIC_COOKIE: [u8; 4] = [0x21, 0x12, 0xa4, 0x42];
    const ATTR_MAPPED_ADDRESS: u16 = 0x0001;
    const ATTR_XOR_MAPPED_ADDRESS: u16 = 0x0020;

    if data.len() < HEADER_LEN
        || data[..2] != BINDING_SUCCESS
        || data[4..8] != MAGIC_COOKIE
        || data[8..HEADER_LEN] != expected_txn_id[..]
    {
        return None;
    }

    // Never trust the declared length beyond what was actually received.
    let declared_len = usize::from(u16::from_be_bytes([data[2], data[3]]));
    let attrs = &data[HEADER_LEN..HEADER_LEN + declared_len.min(data.len() - HEADER_LEN)];

    let mut mapped_fallback: Option<SocketAddr> = None;
    let mut pos = 0;
    while pos + 4 <= attrs.len() {
        let attr_type = u16::from_be_bytes([attrs[pos], attrs[pos + 1]]);
        let attr_len = usize::from(u16::from_be_bytes([attrs[pos + 2], attrs[pos + 3]]));
        pos += 4;

        // A value running past the attribute area means the message is truncated.
        let value = match attrs.get(pos..pos + attr_len) {
            Some(value) => value,
            None => break,
        };

        match attr_type {
            ATTR_XOR_MAPPED_ADDRESS => {
                if let Some(addr) = decode_stun_address(value, Some(expected_txn_id)) {
                    return Some(addr);
                }
            }
            ATTR_MAPPED_ADDRESS => {
                if mapped_fallback.is_none() {
                    mapped_fallback = decode_stun_address(value, None);
                }
            }
            _ => {}
        }

        // Attribute values are padded to a multiple of four bytes.
        pos += (attr_len + 3) & !3;
    }
    mapped_fallback
}

/// Decodes the value of a (XOR-)MAPPED-ADDRESS attribute.
///
/// `xor_txn_id` is `Some` for XOR-MAPPED-ADDRESS: the port is XORed with the
/// top 16 bits of the magic cookie, an IPv4 address with the magic cookie, and
/// an IPv6 address with the magic cookie followed by the transaction ID.
/// Returns `None` for values that are too short or use an unknown family.
fn decode_stun_address(value: &[u8], xor_txn_id: Option<&[u8; 12]>) -> Option<SocketAddr> {
    const MAGIC_COOKIE: [u8; 4] = [0x21, 0x12, 0xa4, 0x42];
    const FAMILY_IPV4: u8 = 0x01;
    const FAMILY_IPV6: u8 = 0x02;

    // Reserved byte, family, port and at least an IPv4 address.
    if value.len() < 8 {
        return None;
    }

    let mut port = u16::from_be_bytes([value[2], value[3]]);
    if xor_txn_id.is_some() {
        port ^= 0x2112;
    }

    match value[1] {
        FAMILY_IPV4 => {
            let mut octets = [value[4], value[5], value[6], value[7]];
            if xor_txn_id.is_some() {
                for (byte, key) in octets.iter_mut().zip(MAGIC_COOKIE.iter()) {
                    *byte ^= *key;
                }
            }
            Some(SocketAddr::new(IpAddr::V4(std::net::Ipv4Addr::from(octets)), port))
        }
        FAMILY_IPV6 if value.len() >= 20 => {
            let mut octets = [0u8; 16];
            octets.copy_from_slice(&value[4..20]);
            if let Some(txn_id) = xor_txn_id {
                let key = MAGIC_COOKIE.iter().chain(txn_id.iter());
                for (byte, key_byte) in octets.iter_mut().zip(key) {
                    *byte ^= *key_byte;
                }
            }
            Some(SocketAddr::new(IpAddr::V6(std::net::Ipv6Addr::from(octets)), port))
        }
        _ => None,
    }
}
/// Builds a 17-byte DNS query with a random transaction ID.
///
/// The header sets flags `0x0100` and a question count of 1 with all other
/// counts zero; the single question has an empty (root) name, type 1 and
/// class 1.
fn build_dns_query() -> Vec<u8> {
    let id: u16 = rand::random();
    let [id_hi, id_lo] = id.to_be_bytes();

    vec![
        // Header
        id_hi, id_lo, // transaction ID
        0x01, 0x00, // flags
        0x00, 0x01, // question count
        0x00, 0x00, // answer count
        0x00, 0x00, // authority count
        0x00, 0x00, // additional count
        // Question
        0x00, // empty name: just the terminating zero-length label
        0x00, 0x01, // type
        0x00, 0x01, // class
    ]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The binding request is 20 bytes, starts with the request type and
    /// carries the magic cookie at offset 4.
    #[test]
    fn test_build_stun_request() {
        let request = build_stun_binding_request(&[0u8; 12]);
        assert_eq!(request.len(), 20);
        assert_eq!(request[0..2], [0x00, 0x01]);
        assert_eq!(request[4..8], [0x21, 0x12, 0xa4, 0x42]);
    }

    /// The DNS query is at least 17 bytes and has the top flag bit cleared.
    #[test]
    fn test_build_dns_query() {
        let query = build_dns_query();
        assert!(query.len() >= 17);
        assert_eq!(query[2] >> 7, 0);
    }

    /// `Display` yields the human-readable NAT labels.
    #[test]
    fn test_nat_type_display() {
        assert_eq!(NatType::None.to_string(), "No NAT");
        assert_eq!(NatType::Symmetric.to_string(), "Symmetric");
    }
}