use std::net::IpAddr;
use std::time::Duration;
/// Errors that can occur while determining this host's public IP address.
#[derive(Debug, thiserror::Error)]
pub enum PublicIpError {
/// Fetching a response from a provider failed (transport/HTTP error, failure reading the
/// response body, or the blocking request task failing to join); holds the underlying
/// error message.
#[error("HTTP request failed: {0}")]
HttpError(String),
/// The provider's trimmed response body did not parse as an IP address; holds the parse
/// error followed by the offending text.
#[error("Failed to parse IP address: {0}")]
ParseError(String),
/// A request to a provider exceeded its configured timeout (`ureq::Error::Timeout`).
#[error("Request timed out")]
Timeout,
/// Every provider was tried by `get_public_ip` and none returned an address.
#[error("All public IP providers failed")]
AllProvidersFailed,
/// Not constructed in this module.
/// NOTE(review): `get_public_ip_from_provider` currently returns IPv6 addresses as-is;
/// confirm where (and whether) this variant is produced.
#[error("IPv6 is not yet supported")]
UnsupportedIpVersion,
/// An ASN lookup failed; holds a description of the failure.
/// Not constructed in this module; presumably produced by code elsewhere, so verify there.
#[error("ASN lookup failed: {0}")]
AsnLookupFailed(String),
/// Catch-all carrying a free-form message that is displayed verbatim.
/// Not constructed in this module.
#[error("{0}")]
Other(String),
}
/// An HTTP endpoint whose response body is parsed as the caller's IP address
/// by `get_public_ip_from_provider`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PublicIpProvider {
    /// `https://checkip.amazonaws.com`, the default provider.
    #[default]
    AwsCheckIp,
    /// `https://api.ipify.org`
    Ipify,
    /// `https://icanhazip.com`
    ICanHazIp,
}

impl PublicIpProvider {
    /// Every provider, in the order `get_public_ip` falls back through them.
    const ALL: &'static [Self] = &[Self::AwsCheckIp, Self::Ipify, Self::ICanHazIp];

    /// Returns the URL that is queried for this provider.
    pub fn url(&self) -> &'static str {
        match *self {
            Self::AwsCheckIp => "https://checkip.amazonaws.com",
            Self::Ipify => "https://api.ipify.org",
            Self::ICanHazIp => "https://icanhazip.com",
        }
    }

    /// Returns every known provider, starting with the default one.
    pub fn all() -> &'static [PublicIpProvider] {
        Self::ALL
    }
}
pub async fn get_public_ip_from_provider(
provider: PublicIpProvider,
timeout: Duration,
) -> Result<IpAddr, PublicIpError> {
let url = provider.url().to_string();
let ip_str = tokio::task::spawn_blocking(move || {
let agent: ureq::Agent = ureq::Agent::config_builder()
.timeout_global(Some(timeout))
.build()
.into();
let body = agent
.get(&url)
.call()
.map_err(|e| match e {
ureq::Error::Timeout(_) => PublicIpError::Timeout,
other => PublicIpError::HttpError(other.to_string()),
})?
.body_mut()
.read_to_string()
.map_err(|e| PublicIpError::HttpError(e.to_string()))?;
Ok::<String, PublicIpError>(body)
})
.await
.map_err(|e| PublicIpError::HttpError(e.to_string()))??;
ip_str
.trim()
.parse::<IpAddr>()
.map_err(|e| PublicIpError::ParseError(format!("{e}: {}", ip_str.trim())))
}
/// Per-provider timeout used by [`get_public_ip`].
const DEFAULT_PROVIDER_TIMEOUT: Duration = Duration::from_secs(5);

/// Detects this host's public IP address, trying `preferred_provider` first and then the
/// remaining providers in [`PublicIpProvider::all`] order.
///
/// Each provider gets a 5-second timeout; use [`get_public_ip_with_timeout`] to choose a
/// different one.
///
/// # Errors
/// Returns [`PublicIpError::AllProvidersFailed`] if every provider fails.
pub async fn get_public_ip(preferred_provider: PublicIpProvider) -> Result<IpAddr, PublicIpError> {
    get_public_ip_with_timeout(preferred_provider, DEFAULT_PROVIDER_TIMEOUT).await
}

/// Like [`get_public_ip`], but with a caller-chosen timeout applied to each provider attempt.
///
/// Providers are tried sequentially, so the worst-case total wait is roughly
/// `timeout` multiplied by the number of providers.
///
/// # Errors
/// Returns [`PublicIpError::AllProvidersFailed`] if every provider fails.
pub async fn get_public_ip_with_timeout(
    preferred_provider: PublicIpProvider,
    timeout: Duration,
) -> Result<IpAddr, PublicIpError> {
    let fallbacks = PublicIpProvider::all()
        .iter()
        .copied()
        .filter(move |&provider| provider != preferred_provider);

    for provider in std::iter::once(preferred_provider).chain(fallbacks) {
        // Per-provider failures are deliberately ignored: failover moves on to the next one.
        if let Ok(ip) = get_public_ip_from_provider(provider, timeout).await {
            return Ok(ip);
        }
    }
    Err(PublicIpError::AllProvidersFailed)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_urls() {
        assert_eq!(
            PublicIpProvider::AwsCheckIp.url(),
            "https://checkip.amazonaws.com"
        );
        assert_eq!(PublicIpProvider::Ipify.url(), "https://api.ipify.org");
        assert_eq!(PublicIpProvider::ICanHazIp.url(), "https://icanhazip.com");
    }

    #[test]
    fn test_provider_all() {
        let providers = PublicIpProvider::all();
        assert_eq!(providers.len(), 3);
        assert!(providers.contains(&PublicIpProvider::AwsCheckIp));
        assert!(providers.contains(&PublicIpProvider::Ipify));
        assert!(providers.contains(&PublicIpProvider::ICanHazIp));
    }

    #[test]
    fn test_default_provider() {
        assert_eq!(PublicIpProvider::default(), PublicIpProvider::AwsCheckIp);
    }

    /// Network-dependent: tolerates failure (test environments may lack outbound
    /// access) but validates any address that is returned.
    #[tokio::test]
    async fn test_get_public_ip() {
        match get_public_ip(PublicIpProvider::default()).await {
            Ok(IpAddr::V4(ipv4)) => {
                assert!(!ipv4.is_private());
                assert!(!ipv4.is_loopback());
                assert!(!ipv4.is_link_local());
            }
            Ok(IpAddr::V6(ipv6)) => {
                assert!(!ipv6.is_loopback());
            }
            Err(e) => {
                eprintln!(
                    "Public IP detection failed (expected in some test environments): {}",
                    e
                );
            }
        }
    }

    /// Network-dependent: each provider is queried individually and failures are only logged.
    #[tokio::test]
    async fn test_get_public_ip_from_each_provider() {
        let timeout = Duration::from_secs(10);
        for provider in PublicIpProvider::all() {
            match get_public_ip_from_provider(*provider, timeout).await {
                Ok(ip) => {
                    eprintln!("Provider {} returned IP: {}", provider.url(), ip);
                    // A genuine public address is never unspecified or loopback.
                    assert!(!ip.is_unspecified(), "{} returned {}", provider.url(), ip);
                    assert!(!ip.is_loopback(), "{} returned {}", provider.url(), ip);
                }
                Err(e) => {
                    eprintln!("Provider {} failed: {}", provider.url(), e);
                }
            }
        }
    }

    /// Failover must either succeed or report `AllProvidersFailed`, never a per-provider error.
    #[tokio::test]
    async fn test_provider_failover() {
        match get_public_ip(PublicIpProvider::ICanHazIp).await {
            Ok(_) | Err(PublicIpError::AllProvidersFailed) => {}
            Err(e) => panic!("Unexpected error type: {}", e),
        }
    }

    /// A 1 ms budget cannot complete an HTTPS round trip, so the call must fail with a
    /// timeout, or an HTTP error when the network is unavailable altogether.
    #[tokio::test]
    async fn test_timeout_handling() {
        let very_short_timeout = Duration::from_millis(1);
        let result =
            get_public_ip_from_provider(PublicIpProvider::AwsCheckIp, very_short_timeout).await;
        match result {
            Err(PublicIpError::Timeout | PublicIpError::HttpError(_)) => {}
            Err(e) => panic!("Unexpected error type: {}", e),
            Ok(ip) => panic!("Expected a 1ms request to fail, got {}", ip),
        }
    }

    #[test]
    fn test_error_display() {
        let errors = vec![
            PublicIpError::HttpError("connection failed".to_string()),
            PublicIpError::ParseError("invalid IP".to_string()),
            PublicIpError::Timeout,
            PublicIpError::AllProvidersFailed,
            PublicIpError::UnsupportedIpVersion,
            PublicIpError::AsnLookupFailed("lookup failed".to_string()),
            // Without this entry the `Other` arm below would never run.
            PublicIpError::Other("something else".to_string()),
        ];
        for error in errors {
            let error_str = error.to_string();
            assert!(!error_str.is_empty());
            match error {
                PublicIpError::HttpError(msg) => assert!(error_str.contains(&msg)),
                PublicIpError::ParseError(msg) => assert!(error_str.contains(&msg)),
                PublicIpError::Timeout => assert!(error_str.contains("timed out")),
                PublicIpError::AllProvidersFailed => assert!(error_str.contains("All")),
                PublicIpError::UnsupportedIpVersion => assert!(error_str.contains("IPv6")),
                PublicIpError::AsnLookupFailed(msg) => assert!(error_str.contains(&msg)),
                PublicIpError::Other(msg) => assert_eq!(error_str, msg),
            }
        }
    }

    #[test]
    fn test_provider_equality() {
        assert_eq!(PublicIpProvider::AwsCheckIp, PublicIpProvider::AwsCheckIp);
        assert_ne!(PublicIpProvider::AwsCheckIp, PublicIpProvider::Ipify);
        // Skipping the preferred provider must leave exactly the other two as fallbacks.
        let preferred = PublicIpProvider::Ipify;
        let tried_count = PublicIpProvider::all()
            .iter()
            .filter(|&&provider| provider != preferred)
            .count();
        assert_eq!(tried_count, 2);
    }
}