use anyhow::Result;
use rustls::pki_types::ServerName;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Semaphore;
use tokio_rustls::TlsConnector;
use x509_parser::prelude::*;
/// Outcome of a certificate-name lookup for a single host.
#[derive(Debug, Clone)]
pub struct SanResult {
/// The host string this result belongs to, copied from the caller's input.
pub host: String,
/// Names extracted from the leaf certificate; empty when the lookup failed.
pub sans: Vec<String>,
/// `None` on success, otherwise a human-readable description of the failure.
pub error: Option<String>,
}
pub async fn extract_sans_from_hosts(
hosts: &[String],
concurrency: usize,
timeout: Duration,
) -> Vec<SanResult> {
let semaphore = Arc::new(Semaphore::new(concurrency));
let tasks: Vec<_> = hosts
.iter()
.map(|host| {
let permit = semaphore.clone();
let host = host.clone();
let timeout = timeout;
tokio::spawn(async move {
let _permit = permit.acquire().await.unwrap();
extract_sans_single(&host, timeout).await
})
})
.collect();
let mut results = Vec::new();
for task in tasks {
if let Ok(result) = task.await {
results.push(result);
}
}
results
}
/// Runs one certificate lookup for `host`, giving up after `timeout`.
///
/// Never fails: every outcome is folded into a `SanResult`. A lookup error is
/// stored as its display string, and an elapsed timeout is reported as
/// `"connection timeout"`.
async fn extract_sans_single(host: &str, timeout: Duration) -> SanResult {
    let attempt = tokio::time::timeout(timeout, extract_sans_impl(host)).await;

    let (sans, error) = match attempt {
        Ok(Ok(names)) => (names, None),
        Ok(Err(failure)) => (Vec::new(), Some(failure.to_string())),
        Err(_elapsed) => (Vec::new(), Some("connection timeout".to_string())),
    };

    SanResult {
        host: host.to_string(),
        sans,
        error,
    }
}
/// Opens a TLS connection to `host` and returns the names found in the leaf
/// certificate the server presents.
///
/// The handshake uses `CertCaptureVerifier`, so the certificate chain is not
/// validated. `host` is passed through `normalize_host`; a resulting port of
/// 80 is replaced by 443 before connecting.
///
/// # Errors
///
/// Fails when the TCP connection, server-name parsing, TLS handshake, or
/// certificate parsing fails, or when the server presents no certificate.
async fn extract_sans_impl(host: &str) -> Result<Vec<String>> {
    let trust_anchors =
        rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let verifier = Arc::new(CertCaptureVerifier::new(trust_anchors));
    let client_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(verifier)
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(client_config));

    let (hostname, parsed_port) = normalize_host(host);
    // Port 80 cannot serve a TLS handshake here; probe 443 instead.
    let port = match parsed_port {
        80 => 443,
        other => other,
    };

    let address = format!("{}:{}", hostname, port);
    let tcp = TcpStream::connect(&address).await?;
    let sni = ServerName::try_from(hostname)?;
    let tls = connector.connect(sni, tcp).await?;

    let (_, session) = tls.get_ref();
    let chain = session
        .peer_certificates()
        .ok_or_else(|| anyhow::anyhow!("No peer certificates"))?;
    // The first certificate in the chain is the server's own (leaf) certificate.
    let leaf = match chain.first() {
        Some(cert) => cert,
        None => anyhow::bail!("No certificates in chain"),
    };

    extract_sans_from_der(leaf.as_ref())
}
/// Reduces a host, `host:port` pair, or `http(s)://` URL to `(hostname, port)`.
///
/// * Surrounding whitespace and a leading `https://` / `http://` are removed.
/// * Everything from the first `/`, `?` or `#` onwards is discarded.
/// * A trailing `:port` that parses as `u16` is split off and returned.
/// * An empty port (`"example.com:"`) is treated as "no port given".
/// * A bare IPv6 literal such as `::1` is returned whole instead of having its
///   last group mistaken for a port. Bracketed literals keep their brackets.
///
/// When no port is present the default is 80 for `http://` input and 443 for
/// everything else. If the text after the last `:` is not a valid port, the
/// input is returned unsplit with the default port.
pub fn normalize_host(host: &str) -> (String, u16) {
    let trimmed = host.trim();

    let (rest, default_port) = if let Some(after) = trimmed.strip_prefix("https://") {
        (after, 443u16)
    } else if let Some(after) = trimmed.strip_prefix("http://") {
        (after, 80u16)
    } else {
        (trimmed, 443u16)
    };

    // The authority ends where the path, query or fragment begins.
    let authority = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .map_or(rest, |end| &rest[..end]);

    // An unbracketed IPv6 address consists of colon-separated groups; splitting
    // at the last colon would turn its final group into a bogus port.
    if authority.parse::<std::net::Ipv6Addr>().is_ok() {
        return (authority.to_string(), default_port);
    }

    if let Some((name, port_text)) = authority.rsplit_once(':') {
        if port_text.is_empty() {
            return (name.to_string(), default_port);
        }
        if let Ok(port) = port_text.parse::<u16>() {
            return (name.to_string(), port);
        }
    }

    (authority.to_string(), default_port)
}
/// Parses a DER-encoded certificate and lists the names it is issued for.
///
/// The result contains, in order:
/// 1. every DNS and IP entry of the Subject Alternative Name extension
///    (IPv4 as dotted decimal, IPv6 as eight uncompressed 4-digit hex groups);
/// 2. each subject Common Name that is not already in the list.
///
/// Other SAN entry types are skipped, as are IP entries whose length is
/// neither 4 nor 16 bytes. A missing or unreadable SAN extension is not an
/// error.
///
/// # Errors
///
/// Fails only when `der` cannot be parsed as an X.509 certificate.
fn extract_sans_from_der(der: &[u8]) -> Result<Vec<String>> {
    let (_, cert) = X509Certificate::from_der(der)?;
    let mut names: Vec<String> = Vec::new();

    if let Ok(Some(extension)) = cert.subject_alternative_name() {
        for entry in &extension.value.general_names {
            match entry {
                GeneralName::DNSName(dns) => names.push(dns.to_string()),
                GeneralName::IPAddress(octets) => match octets.len() {
                    4 => names.push(format!(
                        "{}.{}.{}.{}",
                        octets[0], octets[1], octets[2], octets[3]
                    )),
                    16 => {
                        let groups: Vec<String> = octets
                            .chunks(2)
                            .map(|pair| format!("{:02x}{:02x}", pair[0], pair[1]))
                            .collect();
                        names.push(groups.join(":"));
                    }
                    _ => {}
                },
                _ => {}
            }
        }
    }

    for rdn in cert.subject().iter() {
        for attribute in rdn.iter() {
            if attribute.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
                if let Ok(common_name) = attribute.attr_value().as_str() {
                    let already_listed = names
                        .iter()
                        .any(|existing| existing.as_str() == common_name);
                    if !already_listed {
                        names.push(common_name.to_string());
                    }
                }
            }
        }
    }

    Ok(names)
}
/// Certificate verifier used to read certificates from servers regardless of
/// whether their chain would validate.
///
/// Its `ServerCertVerifier::verify_server_cert` implementation accepts every
/// certificate unconditionally, so connections made with it are NOT
/// authenticated.
#[derive(Debug)]
struct CertCaptureVerifier {
// Stored but not read anywhere in the visible code, hence the lint allowance.
#[allow(dead_code)]
roots: rustls::RootCertStore,
}
impl CertCaptureVerifier {
fn new(roots: rustls::RootCertStore) -> Self {
Self { roots }
}
}
/// Verifier behaviour: the certificate chain and hostname are never checked,
/// but handshake signatures are still verified with the `ring` provider's
/// algorithms.
impl rustls::client::danger::ServerCertVerifier for CertCaptureVerifier {
    /// Accepts any server certificate without inspecting it.
    ///
    /// Every argument is ignored. This is what allows certificates to be read
    /// from servers whose chain would not validate, and it also means the
    /// peer's identity is never established.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let accepted = rustls::client::danger::ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Checks a TLS 1.2 handshake signature against `cert`.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &provider.signature_verification_algorithms,
        )
    }

    /// Checks a TLS 1.3 handshake signature against `cert`.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &provider.signature_verification_algorithms,
        )
    }

    /// Lists the signature schemes the `ring` provider can verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let provider = rustls::crypto::ring::default_provider();
        provider
            .signature_verification_algorithms
            .supported_schemes()
    }
}
use crate::wordlist::{expand_wildcard, wildcard_to_base};
/// Combined view of the original scan targets and the names discovered in
/// certificates, as produced by `merge_sans_into_targets`.
#[derive(Debug, Clone)]
pub struct SanMergeResult {
    /// De-duplicated, sorted hostnames to scan.
    pub targets: Vec<String>,
    /// De-duplicated, sorted wildcard names (those starting with `*`) seen in
    /// certificates.
    pub wildcards: Vec<String>,
    /// De-duplicated base domains derived from the wildcard names.
    pub wildcard_bases: Vec<String>,
}
/// Merges the hostnames found in certificates into the original target list.
///
/// * Every original target is reduced to its hostname with `normalize_host`.
/// * A SAN starting with `*` is recorded as a wildcard. When
///   `wildcard_to_base` yields a base domain for it, that base and every name
///   produced by `expand_wildcard(san, max_subdomain_depth)` become targets.
/// * A SAN that parses as an IP address is ignored.
/// * Any other SAN is normalized like a target and added.
///
/// All three lists in the result are de-duplicated and sorted, so the output
/// is the same on every run for the same input.
pub fn merge_sans_into_targets(
    original_targets: &[String],
    san_results: &[SanResult],
    max_subdomain_depth: usize,
) -> SanMergeResult {
    let mut all_hosts: HashSet<String> = HashSet::new();
    let mut wildcards: Vec<String> = Vec::new();
    let mut wildcard_bases: HashSet<String> = HashSet::new();

    for target in original_targets {
        let (hostname, _) = normalize_host(target);
        if !hostname.is_empty() {
            all_hosts.insert(hostname);
        }
    }

    for san in san_results.iter().flat_map(|result| result.sans.iter()) {
        if san.starts_with('*') {
            wildcards.push(san.clone());
            if let Some(base) = wildcard_to_base(san) {
                for host in expand_wildcard(san, max_subdomain_depth) {
                    all_hosts.insert(host);
                }
                wildcard_bases.insert(base.clone());
                all_hosts.insert(base);
            }
            continue;
        }

        // IP SANs are not hostnames and are left out of the target list.
        if san.parse::<std::net::IpAddr>().is_ok() {
            continue;
        }

        let (hostname, _) = normalize_host(san);
        if !hostname.is_empty() {
            all_hosts.insert(hostname);
        }
    }

    let mut targets: Vec<String> = all_hosts.into_iter().collect();
    targets.sort_unstable();

    wildcards.sort_unstable();
    wildcards.dedup();

    // HashSet iteration order differs between runs; sort so callers get a
    // stable list, matching `targets` and `wildcards`.
    let mut wildcard_bases: Vec<String> = wildcard_bases.into_iter().collect();
    wildcard_bases.sort_unstable();

    SanMergeResult {
        targets,
        wildcards,
        wildcard_bases,
    }
}
// Unit tests for host normalization and SAN merging.
#[cfg(test)]
mod tests {
use super::*;
// Covers bare hosts, explicit ports, both URL schemes, and path/query stripping.
#[test]
fn test_normalize_host() {
assert_eq!(normalize_host("example.com"), ("example.com".to_string(), 443));
assert_eq!(normalize_host("example.com:8443"), ("example.com".to_string(), 8443));
assert_eq!(normalize_host("https://example.com"), ("example.com".to_string(), 443));
assert_eq!(normalize_host("http://example.com"), ("example.com".to_string(), 80));
assert_eq!(normalize_host("https://example.com:8443"), ("example.com".to_string(), 8443));
assert_eq!(normalize_host("https://example.com/path/to/resource"), ("example.com".to_string(), 443));
assert_eq!(normalize_host("https://example.com:8443/path?query=1"), ("example.com".to_string(), 8443));
}
// Checks that original targets, plain SANs, the wildcard base, and wildcard
// expansions all end up in `targets`, while the wildcard itself does not.
#[test]
fn test_merge_sans() {
let original = vec!["a.example.com".to_string(), "b.example.com".to_string()];
let san_results = vec![
SanResult {
host: "a.example.com".to_string(),
sans: vec![
"a.example.com".to_string(),
"c.example.com".to_string(),
"*.example.com".to_string(), ],
error: None,
},
];
let result = merge_sans_into_targets(&original, &san_results, 1);
assert!(result.targets.contains(&"a.example.com".to_string()));
assert!(result.targets.contains(&"b.example.com".to_string()));
assert!(result.targets.contains(&"c.example.com".to_string()));
assert!(result.targets.contains(&"example.com".to_string()));
// These two rely on `expand_wildcard` producing `www` and `api` labels.
assert!(result.targets.contains(&"www.example.com".to_string()));
assert!(result.targets.contains(&"api.example.com".to_string()));
assert!(result.wildcards.contains(&"*.example.com".to_string()));
assert!(!result.targets.iter().any(|s| s.starts_with('*')));
}
}