use super::stun::{get_public_ip_stun_with_fallback_and_cache, StunError};
use super::stun_cache::StunCache;
use super::PublicIpError;
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
/// Client that discovers the host's public IP address over STUN.
///
/// Holds the list of configured STUN servers, a per-lookup timeout, and a
/// [`StunCache`] shared between clones: cloning a `StunClient` clones the `Arc`,
/// so every clone observes the same cache instance.
#[derive(Debug, Clone)]
pub struct StunClient {
    /// Cache shared by all clones of this client, guarded by an async `RwLock`.
    cache: Arc<RwLock<StunCache>>,
    /// Configured STUN server addresses, e.g. `stun.l.google.com:19302`.
    servers: Vec<String>,
    /// Timeout handed to the STUN lookup performed by `get_public_ip`.
    timeout: Duration,
}
impl StunClient {
    /// Timeout used by every constructor until overridden with [`StunClient::with_timeout`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_millis(500);

    /// STUN servers configured by [`StunClient::new`].
    const DEFAULT_SERVERS: &'static [&'static str] =
        &["stun.l.google.com:19302", "stun1.l.google.com:19302"];

    /// Creates a client configured with Google's public STUN servers, a fresh
    /// cache, and the default timeout.
    pub fn new() -> Self {
        let servers = Self::DEFAULT_SERVERS
            .iter()
            .map(|&server| String::from(server))
            .collect();
        Self::with_servers(servers)
    }

    /// Creates a client for the given STUN servers with a fresh `StunCache`
    /// and the default timeout.
    pub fn with_servers(servers: Vec<String>) -> Self {
        // Delegate so the default timeout is defined in exactly one place.
        Self::with_cache(StunCache::new(), servers)
    }

    /// Returns this client with its lookup timeout replaced by `timeout`.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Creates a client around an existing `cache` and the given servers,
    /// using the default timeout.
    pub fn with_cache(cache: StunCache, servers: Vec<String>) -> Self {
        Self {
            cache: Arc::new(RwLock::new(cache)),
            servers,
            timeout: Self::DEFAULT_TIMEOUT,
        }
    }

    /// Discovers this host's public IP address via STUN.
    ///
    /// Delegates to `get_public_ip_stun_with_fallback_and_cache` with this
    /// client's timeout and shared cache, then maps [`StunError`] into
    /// [`PublicIpError`].
    ///
    /// NOTE(review): `self.servers` is not passed to the helper. Whether the
    /// configured servers are used at all depends on that helper or the cache.
    /// Verify this; if they are not used, `with_servers` has no effect on
    /// discovery.
    ///
    /// # Errors
    ///
    /// - `PublicIpError::Timeout` when the STUN lookup reports a timeout.
    /// - `PublicIpError::HttpError` carrying the message of an underlying I/O
    ///   error. I/O failures are reported through this variant.
    /// - `PublicIpError::ParseError` when the response is invalid or contains
    ///   no mapped address.
    pub async fn get_public_ip(&self) -> Result<IpAddr, PublicIpError> {
        get_public_ip_stun_with_fallback_and_cache(self.timeout, &self.cache)
            .await
            .map_err(|e| match e {
                StunError::Timeout => PublicIpError::Timeout,
                StunError::IoError(err) => PublicIpError::HttpError(err.to_string()),
                StunError::InvalidResponse => {
                    PublicIpError::ParseError("Invalid STUN response".to_string())
                }
                StunError::NoMappedAddress => {
                    PublicIpError::ParseError("No mapped address in STUN response".to_string())
                }
            })
    }

    /// Returns the STUN servers this client was configured with.
    pub fn servers(&self) -> &[String] {
        &self.servers
    }

    /// Clears the shared cache for this client and all of its clones.
    pub async fn clear_cache(&self) {
        // The write lock waits until no read guard (e.g. one held by
        // `prewarm_cache`) is outstanding, then clears through the temporary
        // guard, which is dropped at the end of the statement.
        self.cache.write().await.clear();
    }

    /// Best-effort warm-up: asks the cache for the addresses of every
    /// configured server. Individual lookup failures are ignored.
    ///
    /// Presumably this populates the cache so later lookups are cheaper (TODO:
    /// confirm against `StunCache::get_stun_server_addrs`).
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is part of the public API.
    pub async fn prewarm_cache(&self) -> Result<(), PublicIpError> {
        let cache = self.cache.read().await;
        for server in &self.servers {
            // Deliberately ignored: one unreachable server must not abort warm-up.
            let _ = cache.get_stun_server_addrs(server).await;
        }
        Ok(())
    }

    /// Returns whether `cache.get_stun_server_addrs(server)` succeeds.
    ///
    /// NOTE(review): if that call resolves and stores the addresses on a miss
    /// (not visible here — verify), this reports `true` for any resolvable
    /// server, not only ones that were already cached, and may do network I/O.
    pub async fn is_server_cached(&self, server: &str) -> bool {
        let cache = self.cache.read().await;
        cache.get_stun_server_addrs(server).await.is_ok()
    }
}
impl Default for StunClient {
fn default() -> Self {
Self::new()
}
}
/// Summary statistics about a STUN client's cache.
///
/// NOTE(review): nothing in this file constructs or returns this type yet.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Presumably the number of STUN servers with entries in the cache — confirm
    /// against the producer of this value.
    pub servers_cached: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_stun_client_default() {
        let client = StunClient::new();
        assert_eq!(client.servers().len(), 2);
        assert!(client.servers()[0].contains("google.com"));
        // `timeout` is private but visible to this child module; pin the default.
        assert_eq!(client.timeout, Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_stun_client_custom_servers() {
        let servers = vec![
            "stun.example.com:3478".to_string(),
            "stun2.example.com:3478".to_string(),
        ];
        let client = StunClient::with_servers(servers.clone());
        assert_eq!(client.servers(), &servers);
        assert_eq!(client.timeout, Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_stun_client_with_cache() {
        let servers = vec!["stun.example.com:3478".to_string()];
        let client = StunClient::with_cache(StunCache::new(), servers.clone());
        assert_eq!(client.servers(), &servers);
        assert_eq!(client.timeout, Duration::from_millis(500));
    }

    #[tokio::test]
    async fn test_stun_client_timeout() {
        let client = StunClient::new().with_timeout(Duration::from_secs(2));
        assert_eq!(client.servers().len(), 2);
        // Previously this test never checked the value it configured.
        assert_eq!(client.timeout, Duration::from_secs(2));
    }

    #[tokio::test]
    async fn test_cache_operations() {
        let client = StunClient::new();
        client.clear_cache().await;
        // `prewarm_cache` is best-effort and ignores per-server failures.
        assert!(client.prewarm_cache().await.is_ok());
    }

    #[tokio::test]
    async fn test_public_ip_detection() {
        let client = StunClient::new();
        match client.get_public_ip().await {
            Ok(ip) => {
                assert!(!ip.is_unspecified());
                assert!(!ip.is_loopback());
            }
            Err(_) => {
                // Network access (or the STUN servers) may be unavailable in
                // the test environment; only a successful result is checked.
            }
        }
    }
}