use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// One cached resolution result: the addresses stored for a server and the
/// moment they were recorded.
#[derive(Clone, Debug)]
struct CacheEntry {
    /// Socket addresses stored for the server.
    addresses: Vec<SocketAddr>,
    /// Instant at which this entry was recorded; lookups compare its age
    /// against `CACHE_TTL` to decide whether the entry is still usable.
    resolved_at: Instant,
}
/// Maximum age of a cache entry before a lookup resolves the server again
/// (one hour).
const CACHE_TTL: Duration = Duration::from_secs(60 * 60);
/// Cache of resolved STUN server addresses, keyed by the server string that
/// was passed to the lookup.
#[derive(Debug)]
pub struct StunCache {
    /// Server string -> cached entry. Guarded by a `std::sync::Mutex`, so all
    /// methods can take `&self`.
    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
}
impl Default for StunCache {
fn default() -> Self {
Self::new()
}
}
impl StunCache {
pub fn new() -> Self {
Self {
cache: Arc::new(Mutex::new(HashMap::new())),
}
}
pub async fn get_stun_server_addrs(&self, server: &str) -> std::io::Result<Vec<SocketAddr>> {
{
let cache = self.cache.lock().expect("mutex poisoned");
if let Some(entry) = cache.get(server) {
if entry.resolved_at.elapsed() < CACHE_TTL {
return Ok(entry.addresses.clone());
}
}
}
let addresses: Vec<SocketAddr> = tokio::net::lookup_host(server).await?.collect();
if addresses.is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"No addresses found for STUN server",
));
}
{
let mut cache = self.cache.lock().expect("mutex poisoned");
cache.insert(
server.to_string(),
CacheEntry {
addresses: addresses.clone(),
resolved_at: Instant::now(),
},
);
}
Ok(addresses)
}
pub fn clear(&self) {
let mut cache = self.cache.lock().expect("mutex poisoned");
cache.clear();
}
}
/// Resolves a fixed list of STUN servers through `cache` so that their
/// addresses are already cached when they are requested later.
///
/// This is best-effort: a failed resolution is ignored and the remaining
/// servers are still attempted. Servers are resolved one after another; the
/// read lock on `cache` is re-acquired for each server and is held for the
/// duration of that server's lookup.
pub async fn prewarm_stun_cache_with_cache(cache: &Arc<tokio::sync::RwLock<StunCache>>) {
    // Static string slices: the list is only read, so no heap-allocated
    // `Vec<String>` is needed.
    const PREWARM_SERVERS: [&str; 2] = ["stun.l.google.com:19302", "stun1.l.google.com:19302"];

    for &server in &PREWARM_SERVERS {
        let cache_read = cache.read().await;
        // Deliberately discard the result: pre-warming must never fail the caller.
        let _ = cache_read.get_stun_server_addrs(server).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Guards against accidental changes to the cache lifetime.
    #[test]
    fn test_cache_ttl_constant() {
        assert_eq!(CACHE_TTL.as_secs(), 3600, "CACHE_TTL should be 1 hour");
    }

    /// `clear()` must drop every entry. Needs no network access.
    #[test]
    fn test_stun_cache_clear() {
        let cache = StunCache::new();
        cache.cache.lock().unwrap().insert(
            "clear-test.invalid:3478".to_string(),
            CacheEntry {
                addresses: vec!["127.0.0.1:3478".parse().unwrap()],
                resolved_at: Instant::now(),
            },
        );
        assert_eq!(cache.cache.lock().unwrap().len(), 1);

        cache.clear();
        assert!(
            cache.cache.lock().unwrap().is_empty(),
            "clear() should remove all entries"
        );
    }

    /// Resolves a real STUN host twice and checks that the result was cached.
    /// Skips silently when DNS is unavailable in the test environment.
    #[tokio::test]
    async fn test_stun_cache() {
        let cache = StunCache::new();
        let test_server = "stun.l.google.com:19302";

        let addrs1 = match cache.get_stun_server_addrs(test_server).await {
            Ok(addrs) => addrs,
            Err(e) => {
                eprintln!("DNS resolution failed in test environment: {}", e);
                return;
            }
        };
        assert!(!addrs1.is_empty(), "Should resolve at least one address");
        assert!(
            cache.cache.lock().unwrap().contains_key(test_server),
            "Successful resolution should populate the cache"
        );

        let addrs2 = cache.get_stun_server_addrs(test_server).await.unwrap();
        assert_eq!(
            addrs1, addrs2,
            "Second call should return identical cached result"
        );
    }

    /// A name under the reserved `.invalid` TLD must surface an error.
    #[tokio::test]
    async fn test_stun_cache_error_handling() {
        let cache = StunCache::new();
        let invalid_server = "this.definitely.does.not.exist.invalid:12345";

        let result = cache.get_stun_server_addrs(invalid_server).await;
        assert!(result.is_err(), "Invalid server should return error");
    }

    /// A fresh entry is served from the cache; an expired (or missing) entry
    /// forces a new resolution.
    #[tokio::test]
    async fn test_stun_cache_ttl() {
        let cache = StunCache::new();
        // `.invalid` is reserved and never resolves, so the only way the first
        // lookup can succeed is via the cache, and the re-resolution attempt
        // after expiry fails deterministically.
        let test_server = "stun-cache-ttl-test.invalid:3478";
        let test_addresses: Vec<SocketAddr> = vec!["127.0.0.1:3478".parse().unwrap()];

        cache.cache.lock().unwrap().insert(
            test_server.to_string(),
            CacheEntry {
                addresses: test_addresses.clone(),
                resolved_at: Instant::now(),
            },
        );

        let result = cache.get_stun_server_addrs(test_server).await;
        assert!(result.is_ok(), "Should return cached entry");
        assert_eq!(
            result.unwrap(),
            test_addresses,
            "Should return correct addresses"
        );

        {
            let mut cache_lock = cache.cache.lock().unwrap();
            // `checked_sub` returns `None` when the instant cannot be moved
            // that far back; removing the entry exercises the same
            // re-resolution path in that case.
            match Instant::now().checked_sub(CACHE_TTL + Duration::from_secs(1)) {
                Some(expired_time) => {
                    cache_lock.insert(
                        test_server.to_string(),
                        CacheEntry {
                            addresses: test_addresses.clone(),
                            resolved_at: expired_time,
                        },
                    );
                }
                None => {
                    cache_lock.remove(test_server);
                }
            }
        }

        let result = cache.get_stun_server_addrs(test_server).await;
        assert!(
            result.is_err(),
            "Should attempt re-resolution for expired/missing entry"
        );
    }
}