use std::collections::HashSet;
use std::net::SocketAddr;
use std::time::Duration;
use chrono::Utc;
use native_tls::TlsConnector;
use once_cell::sync::Lazy;
use regex::Regex;
use reqwest::{Client, Url};
use tokio::net::TcpStream;
use tracing::{debug, instrument};
use super::types::{CertificateInfo, DnsResolution, DomainExpiration, StatusResponse};
use crate::dns::{DnsResolver, RecordData, RecordType};
use crate::error::{Result, SeerError};
use crate::lookup::SmartLookup;
use crate::validation::{describe_reserved_ip, normalize_domain};
/// Timeout used for HTTP requests and the certificate check's TCP connect / TLS
/// handshake unless overridden with [`StatusClient::with_timeout`].
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);

/// Upper bound on the number of HTTP redirects followed before giving up.
const MAX_REDIRECTS: usize = 5;

/// Case-insensitive matcher for a `<title …>text</title>` element; capture
/// group 1 holds the (untrimmed) title text.
static TITLE_REGEX: Lazy<Regex> = Lazy::new(|| {
    let title_pattern = r"(?i)<title[^>]*>([^<]+)</title>";
    Regex::new(title_pattern).expect("Invalid regex for HTML title extraction")
});
/// Runs HTTP, TLS certificate, registration-expiry and DNS checks for a domain.
///
/// Construct with [`StatusClient::new`] (or [`Default`]) and optionally adjust
/// the timeout with [`StatusClient::with_timeout`].
#[derive(Debug, Clone)]
pub struct StatusClient {
    /// Applied to each HTTP request and to the certificate check's TCP connect
    /// and TLS handshake.
    timeout: Duration,
    /// Resolver used for the A / AAAA / CNAME / NS lookups of the DNS check.
    dns_resolver: DnsResolver,
    /// Registration-data lookup used by the expiration check.
    smart_lookup: SmartLookup,
}
impl Default for StatusClient {
fn default() -> Self {
Self::new()
}
}
impl StatusClient {
    /// Creates a client using [`DEFAULT_TIMEOUT`] and fresh DNS / registration
    /// lookup backends.
    pub fn new() -> Self {
        Self {
            timeout: DEFAULT_TIMEOUT,
            dns_resolver: DnsResolver::new(),
            smart_lookup: SmartLookup::new(),
        }
    }

    /// Returns the client with `timeout` replacing the default.
    ///
    /// The timeout is applied to each HTTP request (one per redirect hop) and,
    /// separately, to the TCP connect and the TLS handshake of the certificate
    /// check. It is not passed to the DNS or registration lookups.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Runs the HTTP, certificate, expiration and DNS checks for `domain`
    /// concurrently and merges them into one [`StatusResponse`].
    ///
    /// A failing sub-check does not fail the call: its error is recorded in
    /// `response.errors` under the check name (`"http"`, `"ssl"`,
    /// `"expiration"`, `"dns"`), and the corresponding response fields keep
    /// the values `StatusResponse::new` gave them.
    ///
    /// # Errors
    /// Returns an error only when `normalize_domain` rejects `domain`.
    #[instrument(skip(self), fields(domain = %domain))]
    pub async fn check(&self, domain: &str) -> Result<StatusResponse> {
        let domain = normalize_domain(domain)?;
        debug!("Checking status for domain: {}", domain);
        let mut response = StatusResponse::new(domain.clone());

        let (http_result, cert_result, expiry_result, dns_result) = tokio::join!(
            self.fetch_http_info(&domain),
            self.fetch_certificate_info(&domain),
            self.fetch_domain_expiration(&domain),
            self.fetch_dns_resolution(&domain)
        );

        if let Some((status, status_text, title)) =
            Self::record_outcome(&mut response, "http", http_result)
        {
            response.http_status = Some(status);
            response.http_status_text = Some(status_text);
            response.title = title;
        }
        if let Some(cert_info) = Self::record_outcome(&mut response, "ssl", cert_result) {
            response.certificate = Some(cert_info);
        }
        if let Some(expiry_info) = Self::record_outcome(&mut response, "expiration", expiry_result)
        {
            response.domain_expiration = expiry_info;
        }
        if let Some(dns_info) = Self::record_outcome(&mut response, "dns", dns_result) {
            response.dns_resolution = Some(dns_info);
        }
        Ok(response)
    }

    /// Unwraps a sub-check result, recording a failure in `response.errors`
    /// under `check` and returning `None` in that case.
    fn record_outcome<T>(response: &mut StatusResponse, check: &str, result: Result<T>) -> Option<T> {
        match result {
            Ok(value) => Some(value),
            Err(e) => {
                response.errors.push(super::types::StatusError {
                    check: check.to_string(),
                    message: e.to_string(),
                });
                None
            }
        }
    }

    /// Requests `https://{domain}` and returns `(status code, reason phrase, title)`.
    ///
    /// Redirects are followed manually, at most [`MAX_REDIRECTS`] of them. Each hop
    /// is screened by `validate_url_target`, and the request is pinned to the
    /// screened addresses with `resolve_to_addrs`. Only the redirect statuses 301,
    /// 302, 303, 307 and 308 are followed. Any other 3xx (e.g. 300, 304) is
    /// reported as the final status. A title is only extracted from successful
    /// HTML responses.
    ///
    /// # Errors
    /// `SeerError::HttpError` is returned for:
    /// - invalid or disallowed URLs;
    /// - client or transport failures;
    /// - redirect loops;
    /// - a redirect without a usable `Location` header;
    /// - exceeding [`MAX_REDIRECTS`].
    async fn fetch_http_info(&self, domain: &str) -> Result<(u16, String, Option<String>)> {
        let mut url = Url::parse(&format!("https://{}", domain))
            .map_err(|e| SeerError::HttpError(format!("invalid URL: {}", e)))?;
        let mut visited = HashSet::new();

        // The initial request plus up to MAX_REDIRECTS follow-ups.
        for _ in 0..=MAX_REDIRECTS {
            let validated_addrs = validate_url_target(&url).await?;
            if !visited.insert(url.clone()) {
                return Err(SeerError::HttpError("redirect loop detected".to_string()));
            }
            let host = url
                .host_str()
                .ok_or_else(|| SeerError::HttpError("missing URL host".to_string()))?;

            // A fresh client per hop: the address pin is per host, and automatic
            // redirects are disabled so every hop goes through validation above.
            let client = Client::builder()
                .redirect(reqwest::redirect::Policy::none())
                .user_agent(concat!("Seer/", env!("CARGO_PKG_VERSION")))
                .resolve_to_addrs(host, &validated_addrs)
                .build()
                .map_err(|e| SeerError::HttpError(format!("failed to build HTTP client: {}", e)))?;
            let response = client
                .get(url.clone())
                .timeout(self.timeout)
                .send()
                .await
                .map_err(|e| SeerError::HttpError(e.to_string()))?;

            let status = response.status();
            // Same set of statuses reqwest's own redirect policy follows. Other
            // 3xx codes (300, 304, ...) are legitimate final answers.
            if matches!(status.as_u16(), 301 | 302 | 303 | 307 | 308) {
                let location = response
                    .headers()
                    .get(reqwest::header::LOCATION)
                    .and_then(|v| v.to_str().ok())
                    .ok_or_else(|| {
                        SeerError::HttpError("redirect missing location header".to_string())
                    })?;
                url = url
                    .join(location)
                    .or_else(|_| Url::parse(location))
                    .map_err(|e| SeerError::HttpError(format!("invalid redirect URL: {}", e)))?;
                continue;
            }

            let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
            let title = if status.is_success() {
                Self::read_html_title(response).await
            } else {
                None
            };
            return Ok((status.as_u16(), status_text, title));
        }
        Err(SeerError::HttpError("too many redirects".to_string()))
    }

    /// Reads at most 64 KiB of an HTML response body and extracts its `<title>`.
    ///
    /// Best-effort: non-HTML responses yield `None`. A body read error stops
    /// reading, and the title is taken from whatever was received. The status
    /// code is already known, so the whole HTTP check is not failed over a title.
    async fn read_html_title(response: reqwest::Response) -> Option<String> {
        use futures::StreamExt;
        const MAX_TITLE_BODY: usize = 64 * 1024;

        // Media types are case-insensitive (RFC 9110 §8.3.1).
        let is_html = response
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .map(|ct| ct.to_ascii_lowercase().contains("text/html"))
            .unwrap_or(false);
        if !is_html {
            return None;
        }

        let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
        let mut stream = response.bytes_stream();
        while buf.len() < MAX_TITLE_BODY {
            match stream.next().await {
                Some(Ok(chunk)) => {
                    let take = (MAX_TITLE_BODY - buf.len()).min(chunk.len());
                    buf.extend_from_slice(&chunk[..take]);
                }
                Some(Err(e)) => {
                    debug!("stopped reading body for title extraction: {}", e);
                    break;
                }
                None => break,
            }
        }
        // Lossy decoding tolerates a multi-byte character cut at the size limit.
        extract_title(&String::from_utf8_lossy(&buf))
    }

    /// Connects to `domain:443` and returns details of the presented certificate.
    ///
    /// Every resolved address is screened with `describe_reserved_ip`; if any is
    /// reserved, the check is refused. Certificate verification is disabled for
    /// the handshake, so expired, self-signed or mismatched certificates can
    /// still be inspected. Validity and hostname matching are computed afterwards
    /// by `parse_certificate_der`.
    ///
    /// # Errors
    /// - `SeerError::Timeout` when the connect or the handshake exceeds the
    ///   configured timeout.
    /// - `SeerError::CertificateError` for DNS, TLS or parsing failures.
    async fn fetch_certificate_info(&self, domain: &str) -> Result<CertificateInfo> {
        let addr = format!("{}:443", domain);
        let socket_addrs: Vec<_> = tokio::net::lookup_host(&addr)
            .await
            .map_err(|e| SeerError::CertificateError(format!("DNS lookup failed: {}", e)))?
            .collect();
        if socket_addrs.is_empty() {
            return Err(SeerError::CertificateError(format!(
                "DNS lookup returned no addresses for {}",
                domain
            )));
        }
        for socket_addr in &socket_addrs {
            if let Some(reason) = describe_reserved_ip(&socket_addr.ip()) {
                return Err(SeerError::CertificateError(format!(
                    "cannot connect to {}: {} — {}",
                    domain,
                    socket_addr.ip(),
                    reason
                )));
            }
        }

        // Verification is intentionally off: the goal is to report on the
        // certificate, not to trust the connection.
        let connector = TlsConnector::builder()
            .danger_accept_invalid_certs(true)
            .build()
            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
        let connector = tokio_native_tls::TlsConnector::from(connector);

        // Connect only to the addresses screened above rather than resolving again.
        let stream =
            tokio::time::timeout(self.timeout, TcpStream::connect(socket_addrs.as_slice()))
                .await
                .map_err(|_| SeerError::Timeout(format!("connection to {} timed out", domain)))?
                .map_err(|e| SeerError::CertificateError(e.to_string()))?;
        let tls_stream = tokio::time::timeout(self.timeout, connector.connect(domain, stream))
            .await
            .map_err(|_| SeerError::Timeout(format!("TLS handshake with {} timed out", domain)))?
            .map_err(|e| SeerError::CertificateError(e.to_string()))?;

        let cert = tls_stream
            .get_ref()
            .peer_certificate()
            .map_err(|e| SeerError::CertificateError(e.to_string()))?
            .ok_or_else(|| SeerError::CertificateError("no certificate found".to_string()))?;
        let der = cert
            .to_der()
            .map_err(|e| SeerError::CertificateError(e.to_string()))?;
        parse_certificate_der(&der, domain)
    }

    /// Looks up registration data for `domain` and reports its expiration date.
    ///
    /// Best-effort: returns `Ok(None)` when the lookup result carries no
    /// expiration date, or when the lookup itself fails. A failure is logged at
    /// debug level instead of being added to `StatusResponse::errors`.
    /// `days_until_expiry` is negative for an already-expired domain.
    async fn fetch_domain_expiration(&self, domain: &str) -> Result<Option<DomainExpiration>> {
        let result = match self.smart_lookup.lookup(domain).await {
            Ok(result) => result,
            Err(e) => {
                debug!("domain expiration lookup failed for {}: {}", domain, e);
                return Ok(None);
            }
        };
        let (expiration_date, registrar) = result.expiration_info();
        Ok(expiration_date.map(|exp_date| DomainExpiration {
            expiration_date: exp_date,
            days_until_expiry: (exp_date - Utc::now()).num_days(),
            registrar,
        }))
    }

    /// Resolves A, AAAA, CNAME and NS records for `domain` concurrently.
    ///
    /// Each query is best-effort: a failed query contributes no records. Trailing
    /// root dots are stripped from CNAME and NS names. `resolves` is true when at
    /// least one A/AAAA record or a CNAME target was found. This function never
    /// returns `Err`.
    async fn fetch_dns_resolution(&self, domain: &str) -> Result<DnsResolution> {
        let resolver = &self.dns_resolver;
        let (a_result, aaaa_result, cname_result, ns_result) = tokio::join!(
            resolver.resolve(domain, RecordType::A, None),
            resolver.resolve(domain, RecordType::AAAA, None),
            resolver.resolve(domain, RecordType::CNAME, None),
            resolver.resolve(domain, RecordType::NS, None)
        );

        let a_records: Vec<String> = a_result
            .unwrap_or_default()
            .into_iter()
            .filter_map(|r| match r.data {
                RecordData::A { address } => Some(address),
                _ => None,
            })
            .collect();
        let aaaa_records: Vec<String> = aaaa_result
            .unwrap_or_default()
            .into_iter()
            .filter_map(|r| match r.data {
                RecordData::AAAA { address } => Some(address),
                _ => None,
            })
            .collect();
        let cname_target: Option<String> =
            cname_result.unwrap_or_default().into_iter().find_map(|r| match r.data {
                RecordData::CNAME { target } => Some(target.trim_end_matches('.').to_string()),
                _ => None,
            });
        let nameservers: Vec<String> = ns_result
            .unwrap_or_default()
            .into_iter()
            .filter_map(|r| match r.data {
                RecordData::NS { nameserver } => {
                    Some(nameserver.trim_end_matches('.').to_string())
                }
                _ => None,
            })
            .collect();

        let resolves = !a_records.is_empty() || !aaaa_records.is_empty() || cname_target.is_some();
        Ok(DnsResolution {
            a_records,
            aaaa_records,
            cname_target,
            nameservers,
            resolves,
        })
    }
}
/// Extracts the text of the first `<title>` element in `html`.
///
/// Runs of whitespace inside the title, including newlines from multi-line
/// markup, are collapsed to single spaces, as browsers do when displaying a
/// title. Leading and trailing whitespace is dropped. Returns `None` when no
/// title element is found or its text is blank. HTML entities are not decoded.
fn extract_title(html: &str) -> Option<String> {
    let raw = TITLE_REGEX.captures(html)?.get(1)?.as_str();
    let title = raw.split_whitespace().collect::<Vec<_>>().join(" ");
    (!title.is_empty()).then_some(title)
}
/// Checks that `url` is safe to request and returns the addresses to pin the
/// connection to.
///
/// Rejects:
/// - schemes other than `http` / `https`;
/// - embedded credentials;
/// - ports other than 80 and 443;
/// - any host that is, or resolves to, an address flagged by
///   `describe_reserved_ip`.
///
/// When a host resolves to several addresses, a single reserved one is enough
/// to reject it. IP-literal hosts are not resolved.
///
/// # Errors
/// Every rejection, and any DNS failure, is reported as `SeerError::HttpError`.
async fn validate_url_target(url: &Url) -> Result<Vec<SocketAddr>> {
    let scheme = url.scheme();
    if !matches!(scheme, "https" | "http") {
        return Err(SeerError::HttpError(format!(
            "unsupported URL scheme: {}",
            scheme
        )));
    }

    let carries_credentials = !url.username().is_empty() || url.password().is_some();
    if carries_credentials {
        return Err(SeerError::HttpError(
            "URL credentials are not allowed".to_string(),
        ));
    }

    let Some(host) = url.host_str() else {
        return Err(SeerError::HttpError("missing URL host".to_string()));
    };

    let port = url.port_or_known_default().unwrap_or(443);
    if !matches!(port, 80 | 443) {
        return Err(SeerError::HttpError(format!(
            "non-standard port {} is not allowed in redirects",
            port
        )));
    }

    let candidates: Vec<SocketAddr> = match host.parse::<std::net::IpAddr>() {
        Ok(literal_ip) => vec![SocketAddr::new(literal_ip, port)],
        Err(_) => {
            let lookup_target = format!("{}:{}", host, port);
            let resolved: Vec<SocketAddr> = tokio::net::lookup_host(&lookup_target)
                .await
                .map_err(|e| SeerError::HttpError(format!("DNS lookup failed: {}", e)))?
                .collect();
            if resolved.is_empty() {
                return Err(SeerError::HttpError(format!(
                    "DNS lookup returned no addresses for {}",
                    host
                )));
            }
            resolved
        }
    };

    for candidate in &candidates {
        let ip = candidate.ip();
        if let Some(reason) = describe_reserved_ip(&ip) {
            return Err(SeerError::HttpError(format!(
                "cannot connect to {}: {} — {}",
                host, ip, reason
            )));
        }
    }
    Ok(candidates)
}
/// Parses a DER-encoded certificate into a [`CertificateInfo`] report for `domain`.
///
/// The issuer and subject fall back to `"Unknown Issuer"` / `"Unknown Subject"`
/// when no usable CN or O attribute exists. `is_valid` reflects only the
/// validity window (`notBefore <= now <= notAfter`); chain trust and revocation
/// are not checked here. `hostname_verified` is the result of matching `domain`
/// against the certificate's names.
///
/// # Errors
/// Returns `SeerError::CertificateError` if the DER cannot be parsed or a
/// validity timestamp cannot be represented.
fn parse_certificate_der(der: &[u8], domain: &str) -> Result<CertificateInfo> {
    use x509_parser::prelude::*;
    let cert = match X509Certificate::from_der(der) {
        Ok((_remaining, parsed)) => parsed,
        Err(e) => {
            return Err(SeerError::CertificateError(format!(
                "failed to parse certificate: {}",
                e
            )))
        }
    };

    let issuer =
        extract_name_from_x509(cert.issuer()).unwrap_or_else(|| "Unknown Issuer".to_string());
    let subject =
        extract_name_from_x509(cert.subject()).unwrap_or_else(|| "Unknown Subject".to_string());

    let validity = cert.validity();
    let valid_from = asn1_time_to_chrono(validity.not_before)?;
    let valid_until = asn1_time_to_chrono(validity.not_after)?;
    let now = Utc::now();

    Ok(CertificateInfo {
        issuer,
        subject,
        valid_from,
        valid_until,
        days_until_expiry: (valid_until - now).num_days(),
        is_valid: (valid_from..=valid_until).contains(&now),
        hostname_verified: cert_matches_hostname(&cert, domain),
    })
}
/// Compares a hostname with a certificate name pattern (RFC 6125 §6.4.3).
///
/// Comparison is ASCII case-insensitive. A pattern of the form `*.rest` matches
/// exactly one non-empty leftmost label followed by `rest`. For example,
/// `*.example.com` matches `a.example.com`, but not `example.com`,
/// `a.b.example.com` or `.example.com`. Any other pattern must equal the host.
/// Partial-label wildcards such as `f*.example.com` are not expanded and only
/// match literally.
fn hostname_matches_pattern(host: &str, pattern: &str) -> bool {
    match pattern.strip_prefix("*.") {
        Some(rest) => {
            // A bare "*." must never match anything.
            if rest.is_empty() {
                return false;
            }
            match host.split_once('.') {
                // The wildcard has to stand in for a real (non-empty) label.
                Some((label, host_rest)) => {
                    !label.is_empty() && host_rest.eq_ignore_ascii_case(rest)
                }
                None => false,
            }
        }
        None => host.eq_ignore_ascii_case(pattern),
    }
}
/// Reports whether `host` matches a DNS identity presented by `cert`.
///
/// Follows RFC 6125 §6.4.4. When the subjectAltName extension carries any
/// dNSName entries, only those are consulted. The subject Common Name is used
/// as a fallback only when no dNSName is present, or when the SAN extension
/// cannot be read. IP-address SAN entries are not considered.
fn cert_matches_hostname(cert: &x509_parser::certificate::X509Certificate<'_>, host: &str) -> bool {
    use x509_parser::prelude::*;

    let san_dns_names: Vec<&str> = match cert.tbs_certificate.subject_alternative_name() {
        Ok(Some(san_ext)) => san_ext
            .value
            .general_names
            .iter()
            .filter_map(|name| match name {
                GeneralName::DNSName(n) => Some(*n),
                _ => None,
            })
            .collect(),
        _ => Vec::new(),
    };

    if !san_dns_names.is_empty() {
        // A certificate that presents DNS-IDs must not be matched via its CN.
        return san_dns_names
            .iter()
            .any(|pattern| hostname_matches_pattern(host, pattern));
    }

    cert.subject()
        .iter_common_name()
        .filter_map(|cn| cn.as_str().ok())
        .any(|cn| hostname_matches_pattern(host, cn))
}
/// Picks a human-readable label for an X.509 distinguished name.
///
/// Prefers the first decodable Common Name attribute. If there is none, it
/// falls back to the first decodable Organization Name attribute. Returns
/// `None` if neither exists.
fn extract_name_from_x509(name: &x509_parser::prelude::X509Name) -> Option<String> {
    use x509_parser::prelude::*;
    let preference = [
        &oid_registry::OID_X509_COMMON_NAME,
        &oid_registry::OID_X509_ORGANIZATION_NAME,
    ];
    for wanted in preference {
        let found = name
            .iter()
            .flat_map(|rdn| rdn.iter())
            .filter(|attr| attr.attr_type() == wanted)
            .find_map(|attr| extract_attr_string(attr.attr_value()));
        if found.is_some() {
            return found;
        }
    }
    None
}
/// Decodes an X.509 attribute value as text.
///
/// Tries the generic string accessor first, then an explicit UTF8String view,
/// and finally interprets the raw content bytes as UTF-8. Returns `None` if
/// none of these succeed.
fn extract_attr_string(value: &x509_parser::der_parser::asn1_rs::Any) -> Option<String> {
    value
        .as_str()
        .ok()
        .map(|text| text.to_string())
        .or_else(|| value.as_utf8string().ok().map(|utf8| utf8.string().to_string()))
        .or_else(|| std::str::from_utf8(value.data).ok().map(|text| text.to_string()))
}
/// Converts an ASN.1 certificate time to a UTC `chrono` timestamp with
/// whole-second precision.
///
/// # Errors
/// Returns `SeerError::CertificateError` when the timestamp is outside
/// `chrono`'s representable range.
fn asn1_time_to_chrono(time: x509_parser::time::ASN1Time) -> Result<chrono::DateTime<Utc>> {
    let unix_seconds = time.timestamp();
    match chrono::DateTime::from_timestamp(unix_seconds, 0) {
        Some(instant) => Ok(instant),
        None => Err(SeerError::CertificateError(
            "invalid certificate timestamp".to_string(),
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hostname_matches_pattern_exact() {
        assert!(hostname_matches_pattern("example.com", "example.com"));
        assert!(hostname_matches_pattern("EXAMPLE.COM", "example.com"));
        assert!(hostname_matches_pattern("example.com", "EXAMPLE.COM"));
        assert!(!hostname_matches_pattern("evil.com", "example.com"));
        assert!(!hostname_matches_pattern("example.com", "evil.com"));
    }

    #[test]
    fn hostname_matches_pattern_wildcard() {
        assert!(hostname_matches_pattern("a.example.com", "*.example.com"));
        assert!(hostname_matches_pattern("A.EXAMPLE.COM", "*.example.com"));
        assert!(!hostname_matches_pattern("example.com", "*.example.com"));
        assert!(!hostname_matches_pattern(
            "a.b.example.com",
            "*.example.com"
        ));
        assert!(!hostname_matches_pattern("b.other.com", "*.example.com"));
    }

    #[test]
    fn hostname_matches_pattern_wildcard_requires_dot() {
        assert!(!hostname_matches_pattern("localhost", "*.example.com"));
    }

    #[test]
    fn extract_title_is_case_insensitive_and_trimmed() {
        assert_eq!(
            extract_title("<html><head><TITLE> Example Domain </TITLE></head></html>"),
            Some("Example Domain".to_string())
        );
    }

    #[test]
    fn extract_title_allows_attributes_on_tag() {
        assert_eq!(
            extract_title(r#"<title lang="en">Hi</title>"#),
            Some("Hi".to_string())
        );
    }

    #[test]
    fn extract_title_rejects_missing_or_blank() {
        assert_eq!(extract_title("<p>no title here</p>"), None);
        assert_eq!(extract_title("<title></title>"), None);
        assert_eq!(extract_title("<title>   </title>"), None);
    }
}