use std::collections::HashSet;
use std::time::Duration;
use once_cell::sync::Lazy;
use regex::Regex;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;
use tracing::{debug, instrument, warn};
use super::parser::WhoisResponse;
use super::servers::{get_tld, get_whois_server};
use crate::cache::TtlCache;
use crate::error::{Result, SeerError};
use crate::retry::{RetryExecutor, RetryPolicy};
use crate::validation::normalize_domain;
/// Regexes that locate a referral to a more specific (typically registrar)
/// WHOIS server inside a raw response; `extract_referral` tries them in order.
///
/// The separator after each field name is `[ \t]*` rather than `\s*`: in the
/// `regex` crate `\s` also matches `\n`, so an empty `Whois Server:` field
/// would otherwise capture the *following* line as the server name.
static REFERRAL_PATTERNS: Lazy<Vec<Regex>> = Lazy::new(|| {
    vec![
        Regex::new(r"(?i)Registrar WHOIS Server:[ \t]*(.+)")
            .expect("Invalid regex pattern for Registrar WHOIS Server"),
        Regex::new(r"(?i)Whois Server:[ \t]*(.+)")
            .expect("Invalid regex pattern for Whois Server"),
        Regex::new(r"(?i)ReferralServer:[ \t]*whois://(.+)")
            .expect("Invalid regex pattern for ReferralServer"),
    ]
});
/// TCP port of the WHOIS protocol.
const WHOIS_PORT: u16 = 43;

/// Timeout installed by `WhoisClient::new`; bounds connecting, sending the
/// query, and the overall read of a reply.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(15);

/// Largest reply (in bytes) accepted from a single server: 1 MiB.
const MAX_RESPONSE_SIZE: usize = 1 << 20;

/// Referral depth at which the chain is cut off; depths `0..MAX_REFERRAL_DEPTH`
/// are queried, so at most this many servers are contacted per lookup.
const MAX_REFERRAL_DEPTH: u8 = 3;

/// IANA root WHOIS server, asked for TLDs missing from the built-in table.
const IANA_WHOIS_SERVER: &str = "whois.iana.org";

/// Lifetime of a TLD -> server mapping learned from IANA (24 hours).
const SERVER_CACHE_TTL: Duration = Duration::from_secs(86_400);

/// Process-wide cache of TLD -> WHOIS server mappings discovered via IANA.
static DISCOVERED_SERVERS: Lazy<TtlCache<String, String>> =
    Lazy::new(|| TtlCache::new(SERVER_CACHE_TTL));
/// Asynchronous WHOIS client that queries a TLD's registry server and follows
/// registrar referrals when the registry answer lacks core data.
///
/// Build with `WhoisClient::new()` (or `Default`) and adjust with the
/// `with_timeout` / `with_retry_policy` / `without_retries` builder methods.
#[derive(Debug, Clone)]
pub struct WhoisClient {
    /// Timeout for connecting, sending the query, and reading the reply.
    timeout: Duration,
    /// Retry policy applied to every individual server query.
    retry_policy: RetryPolicy,
}
impl Default for WhoisClient {
fn default() -> Self {
Self::new()
}
}
impl WhoisClient {
pub fn new() -> Self {
Self {
timeout: DEFAULT_TIMEOUT,
retry_policy: RetryPolicy::default(),
}
}
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.timeout = timeout;
self
}
pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self {
self.retry_policy = policy;
self
}
pub fn without_retries(mut self) -> Self {
self.retry_policy = RetryPolicy::no_retry();
self
}
#[instrument(skip(self), fields(domain = %domain))]
pub async fn lookup(&self, domain: &str) -> Result<WhoisResponse> {
let start = std::time::Instant::now();
let domain = normalize_domain(domain)?;
let tld = get_tld(&domain).ok_or_else(|| SeerError::InvalidDomain(domain.clone()))?;
let whois_server = if let Some(server) = get_whois_server(tld) {
server.to_string()
} else {
let tld_lower = tld.to_lowercase();
if let Some(server) = DISCOVERED_SERVERS.get(&tld_lower) {
debug!(tld = %tld, server = %server, "Using cached WHOIS server");
server
} else {
debug!(tld = %tld, "Querying IANA for WHOIS server");
let server = self.discover_whois_server_with_retry(tld).await?;
DISCOVERED_SERVERS.insert(tld_lower, server.clone());
server
}
};
let mut visited = HashSet::new();
let result = self
.lookup_with_referrals(&domain, &whois_server, 0, &mut visited)
.await;
let elapsed_ms = start.elapsed().as_millis();
debug!(
domain = %domain,
elapsed_ms = elapsed_ms,
"WHOIS lookup complete"
);
result
}
fn lookup_with_referrals<'a>(
&'a self,
domain: &'a str,
whois_server: &'a str,
depth: u8,
visited: &'a mut HashSet<String>,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<WhoisResponse>> + Send + 'a>>
{
Box::pin(async move {
if depth >= MAX_REFERRAL_DEPTH {
warn!(depth = depth, server = %whois_server, "Max referral depth reached before query, aborting referral chain");
return Err(SeerError::WhoisError(
"max WHOIS referral depth exceeded".to_string(),
));
}
let server_lower = whois_server.to_lowercase();
if visited.contains(&server_lower) {
warn!(server = %whois_server, "Circular WHOIS referral detected");
return Err(SeerError::WhoisError(
"circular WHOIS referral detected".to_string(),
));
}
visited.insert(server_lower);
debug!(whois_server = %whois_server, depth = depth, "Querying WHOIS server");
let raw_response = self.query_server_with_retry(whois_server, domain).await?;
let current_response = WhoisResponse::parse(domain, whois_server, &raw_response);
if current_response.has_core_data() {
debug!(
server = %whois_server,
"Registry response has core data, skipping registrar referral"
);
return Ok(current_response);
}
if let Some(referral) = extract_referral(&raw_response) {
if referral != whois_server && !visited.contains(&referral.to_lowercase()) {
debug!(
referral_depth = depth,
"Registry response lacks core data, following referral to {}", referral
);
match self
.lookup_with_referrals(domain, &referral, depth + 1, visited)
.await
{
Ok(referral_response) => {
if referral_response.is_available()
|| referral_response.indicates_not_found()
{
debug!(
referral = %referral,
"Referral server indicates domain not found, using registry response"
);
return Ok(current_response);
}
return Ok(referral_response);
}
Err(e) => {
debug!(referral = %referral, error = %e, "Referral lookup failed, using registry response");
return Ok(current_response);
}
}
}
}
Ok(current_response)
})
}
#[instrument(skip(self), fields(domain = %domain, server = %server))]
pub async fn lookup_with_server(&self, domain: &str, server: &str) -> Result<WhoisResponse> {
if server.contains('\r') || server.contains('\n') {
return Err(SeerError::WhoisError(format!(
"invalid WHOIS server: contains illegal characters: {}",
server.replace('\r', "\\r").replace('\n', "\\n")
)));
}
if !is_safe_whois_server(server) {
return Err(SeerError::WhoisError(format!(
"invalid WHOIS server: {}",
server
)));
}
let domain = normalize_domain(domain)?;
let raw_response = self.query_server_with_retry(server, &domain).await?;
Ok(WhoisResponse::parse(&domain, server, &raw_response))
}
async fn query_server_with_retry(&self, server: &str, query: &str) -> Result<String> {
let executor = RetryExecutor::new(self.retry_policy.clone());
let server = server.to_string();
let query = query.to_string();
let timeout_duration = self.timeout;
executor
.execute(|| {
let server = server.clone();
let query = query.clone();
async move { query_server_internal(&server, &query, timeout_duration).await }
})
.await
}
async fn discover_whois_server_with_retry(&self, tld: &str) -> Result<String> {
let response = self.query_server_with_retry(IANA_WHOIS_SERVER, tld).await?;
if let Some(server) = extract_iana_whois_server(&response) {
if !is_safe_whois_server(&server) {
warn!(server = %server, "IANA returned unsafe WHOIS server, rejecting");
return Err(SeerError::WhoisError(format!(
"IANA returned unsafe WHOIS server: {}",
server
)));
}
crate::net::validate_public_host(&server, WHOIS_PORT).await?;
return Ok(server);
}
if let Some(url) = extract_iana_registration_url(&response) {
return Err(SeerError::WhoisServerNotFound(format!(
"No WHOIS server for '.{}' - check whois directly via: {}",
tld, url
)));
}
Err(SeerError::WhoisServerNotFound(format!(
"No WHOIS server found for TLD '{}'",
tld
)))
}
}
async fn query_server_internal(
server: &str,
query: &str,
timeout_duration: Duration,
) -> Result<String> {
if query.bytes().any(|b| b == 0 || b == b'\r' || b == b'\n') {
return Err(SeerError::WhoisError(
"query string must not contain CR/LF/NUL".into(),
));
}
crate::net::validate_public_host(server, WHOIS_PORT).await?;
let addr = format!("{}:{}", server, WHOIS_PORT);
debug!("WHOIS query to {}", server);
let mut stream = timeout(timeout_duration, TcpStream::connect(&addr))
.await
.map_err(|_| SeerError::Timeout(format!("connection to {} timed out", server)))?
.map_err(|e| SeerError::WhoisError(format!("failed to connect to {}: {}", server, e)))?;
let query_bytes = format!("{}\r\n", query);
timeout(timeout_duration, stream.write_all(query_bytes.as_bytes()))
.await
.map_err(|_| SeerError::Timeout("write timed out".to_string()))?
.map_err(|e| SeerError::WhoisError(format!("failed to send query: {}", e)))?;
let mut response = Vec::new();
let mut buf = [0u8; 4096];
let deadline = tokio::time::Instant::now() + timeout_duration;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
if !response.is_empty() {
tracing::warn!(
"WHOIS read deadline reached with partial response ({} bytes)",
response.len()
);
break;
}
return Err(SeerError::Timeout("read timed out".to_string()));
}
let read_result = timeout(remaining, stream.read(&mut buf)).await;
match read_result {
Ok(Ok(0)) => break, Ok(Ok(n)) => {
response.extend_from_slice(&buf[..n]);
if response.len() > MAX_RESPONSE_SIZE {
return Err(SeerError::WhoisError("response too large".to_string()));
}
}
Ok(Err(e)) => {
return Err(SeerError::WhoisError(format!("read error: {}", e)));
}
Err(_) => {
if !response.is_empty() {
tracing::warn!(
"WHOIS per-read timeout with partial response ({} bytes)",
response.len()
);
break;
}
return Err(SeerError::Timeout("read timed out".to_string()));
}
}
}
let _ = stream.shutdown().await;
Ok(String::from_utf8(response)
.unwrap_or_else(|e| e.into_bytes().iter().map(|&c| c as char).collect()))
}
/// Extracts the WHOIS server from an IANA TLD record.
///
/// Looks for lines that, after trimming, begin with `whois:` (ASCII
/// case-insensitive) and returns the first non-empty value, trimmed and
/// lowercased. Returns `None` when no such line carries a value.
fn extract_iana_whois_server(response: &str) -> Option<String> {
    const FIELD: &str = "whois:";
    response
        .lines()
        .map(str::trim)
        .filter_map(|line| {
            // `get` yields `None` when the line is too short or byte 6 is not a
            // char boundary, so the slice below cannot panic.
            let name = line.get(..FIELD.len())?;
            if name.eq_ignore_ascii_case(FIELD) {
                Some(line[FIELD.len()..].trim())
            } else {
                None
            }
        })
        .find(|value| !value.is_empty())
        .map(str::to_lowercase)
}
/// Extracts the first registration URL mentioned in an IANA `remarks:` line.
///
/// A line qualifies when, after trimming, it begins with `remarks:` (ASCII
/// case-insensitive). The URL starts at the first `http` in the remark and
/// runs to the next whitespace. Lines without `http` are skipped.
///
/// The field name is compared on the original bytes. Comparing a
/// `to_lowercase()` copy and then slicing the original at a fixed byte offset
/// is wrong for non-ASCII input, because lowercasing can change lengths. For
/// example, KELVIN SIGN (U+212A, 3 bytes) lowercases to ASCII `k`.
fn extract_iana_registration_url(response: &str) -> Option<String> {
    const FIELD: &str = "remarks:";
    response.lines().map(str::trim).find_map(|line| {
        // `get` fails on short lines or a non-boundary offset instead of panicking.
        let name = line.get(..FIELD.len())?;
        if !name.eq_ignore_ascii_case(FIELD) {
            return None;
        }
        let remarks = line[FIELD.len()..].trim();
        let url = &remarks[remarks.find("http")?..];
        let url_end = url.find(char::is_whitespace).unwrap_or(url.len());
        Some(url[..url_end].to_string())
    })
}
/// Returns `true` if `server` is syntactically acceptable as a WHOIS host.
///
/// Accepted: a multi-label DNS name made only of ASCII letters, digits, `-`
/// and `.`, optionally written with a single trailing dot (absolute form).
/// Every label must be 1-63 bytes and must not start or end with `-`. The name
/// without its trailing dot must be at most 253 bytes. An IPv4 literal is
/// accepted only if `is_private_or_reserved_ip` does not flag it.
///
/// Rejected, among others: empty input, single-label names such as `localhost`
/// (including `localhost.`), empty labels (`.`, `a..b`), ports (`host:43`), and
/// anything containing whitespace or CR/LF.
///
/// This is a syntactic filter only; no DNS resolution happens here.
fn is_safe_whois_server(server: &str) -> bool {
    const MAX_NAME_LEN: usize = 253;
    const MAX_LABEL_LEN: usize = 63;

    // Normalize one trailing dot first. Otherwise `localhost.` would satisfy
    // "contains a dot" and `127.0.0.1.` would not parse as an IP literal, so
    // both would slip past the checks below.
    let host = server.strip_suffix('.').unwrap_or(server);

    if host.is_empty() || host.len() > MAX_NAME_LEN || !host.contains('.') {
        return false;
    }
    if !host
        .chars()
        .all(|c| c.is_ascii_alphanumeric() || c == '.' || c == '-')
    {
        return false;
    }
    let labels_well_formed = host.split('.').all(|label| {
        !label.is_empty()
            && label.len() <= MAX_LABEL_LEN
            && !label.starts_with('-')
            && !label.ends_with('-')
    });
    if !labels_well_formed {
        return false;
    }
    if let Ok(ip) = host.parse::<std::net::IpAddr>() {
        return !crate::validation::is_private_or_reserved_ip(&ip);
    }
    true
}
/// Extracts the first usable referral WHOIS server from a raw response.
///
/// Patterns in `REFERRAL_PATTERNS` are tried in priority order. Within each
/// pattern, every occurrence is examined, not just the first. An empty or
/// unusable value (such as a URL) on an earlier line therefore does not hide a
/// valid hostname later in the response. The returned name is trimmed,
/// lowercased, and has passed `is_safe_whois_server`.
fn extract_referral(response: &str) -> Option<String> {
    REFERRAL_PATTERNS.iter().find_map(|re| {
        re.captures_iter(response)
            .filter_map(|caps| caps.get(1))
            .map(|m| m.as_str().trim().to_lowercase())
            .find(|server| is_safe_whois_server(server))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sanity check for the `normalize_domain` helper used by the client:
    // lowercasing, stripping of scheme / "www." / path, and rejection of a
    // single-label name.
    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("example.com").unwrap(), "example.com");
        assert_eq!(normalize_domain("EXAMPLE.COM").unwrap(), "example.com");
        assert_eq!(
            normalize_domain("https://www.example.com/path").unwrap(),
            "example.com"
        );
        assert!(normalize_domain("invalid").is_err());
    }

    // Expects the retry policy installed by `WhoisClient::new` to allow 3 attempts.
    #[test]
    fn test_default_client_has_retry_policy() {
        let client = WhoisClient::new();
        assert_eq!(client.retry_policy.max_attempts, 3);
    }

    // `without_retries` must leave exactly one attempt.
    #[test]
    fn test_client_without_retries() {
        let client = WhoisClient::new().without_retries();
        assert_eq!(client.retry_policy.max_attempts, 1);
    }

    // A caller-supplied policy replaces the default unchanged.
    #[test]
    fn test_client_custom_retry_policy() {
        let policy = RetryPolicy::new().with_max_attempts(5);
        let client = WhoisClient::new().with_retry_policy(policy);
        assert_eq!(client.retry_policy.max_attempts, 5);
    }

    // The CR/LF/NUL guard in `query_server_internal` runs before host
    // validation or any connect, so these calls fail without network access.
    #[tokio::test]
    async fn query_server_internal_rejects_crlf_in_query() {
        let err = query_server_internal(
            "whois.example.com",
            "example.com\r\nWHOIS evil.example",
            Duration::from_secs(1),
        )
        .await
        .expect_err("CRLF in query must be rejected");
        match err {
            SeerError::WhoisError(msg) => {
                assert!(msg.contains("CR/LF/NUL"), "unexpected message: {msg}");
            }
            other => panic!("expected WhoisError, got {other:?}"),
        }
        // Bare LF and NUL are rejected too, not only the CRLF pair.
        assert!(
            query_server_internal("whois.example.com", "a\nb", Duration::from_secs(1))
                .await
                .is_err()
        );
        assert!(
            query_server_internal("whois.example.com", "a\0b", Duration::from_secs(1))
                .await
                .is_err()
        );
    }

    // A loopback address coming back from IANA's `whois:` field must be caught
    // by `is_safe_whois_server`. The same check rejects CRLF injection and
    // explicit ports while accepting a normal hostname.
    #[test]
    fn iana_discovery_rejects_unsafe_server() {
        let synthetic = "refer: whois.iana.org\nwhois: 127.0.0.1\n";
        let extracted = extract_iana_whois_server(synthetic);
        assert_eq!(extracted.as_deref(), Some("127.0.0.1"));
        assert!(
            !is_safe_whois_server(&extracted.unwrap()),
            "is_safe_whois_server must reject 127.0.0.1"
        );
        assert!(!is_safe_whois_server("whois.evil.com\r\nevil"));
        assert!(!is_safe_whois_server("whois.evil.com:4444"));
        assert!(is_safe_whois_server("whois.nic.xyz"));
    }

    // The following tests expect `crate::net::validate_public_host` (called at
    // the top of `query_server_internal`) to refuse non-public targets with
    // `SeerError::InvalidInput` before any connection is attempted.
    #[tokio::test]
    async fn whois_refuses_loopback_server() {
        let err = query_server_internal("127.0.0.1", "example.com", Duration::from_secs(1))
            .await
            .expect_err("loopback server must be rejected");
        match err {
            SeerError::InvalidInput(msg) => {
                assert!(
                    msg.contains("reserved") || msg.contains("127.0.0.1"),
                    "unexpected message: {msg}"
                );
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn whois_refuses_rfc1918_server() {
        let err = query_server_internal("10.0.0.1", "example.com", Duration::from_secs(1))
            .await
            .expect_err("RFC1918 server must be rejected");
        match err {
            SeerError::InvalidInput(msg) => {
                assert!(
                    msg.contains("reserved") || msg.contains("10.0.0.1"),
                    "unexpected message: {msg}"
                );
            }
            other => panic!("expected InvalidInput, got {other:?}"),
        }
    }

    // 169.254.169.254 is the common cloud instance-metadata address.
    #[tokio::test]
    async fn whois_refuses_link_local_metadata_server() {
        let err = query_server_internal("169.254.169.254", "example.com", Duration::from_secs(1))
            .await
            .expect_err("link-local metadata server must be rejected");
        assert!(matches!(err, SeerError::InvalidInput(_)));
    }
}