backdisco 0.4.0

Discover backend origins from CDN frontends using LLM-assisted pattern analysis and brute-force enumeration.
use anyhow::Result;
use rustls::pki_types::ServerName;
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpStream;
use tokio::sync::Semaphore;
use tokio_rustls::TlsConnector;
use x509_parser::prelude::*;

/// Result of SAN extraction for a single host
///
/// As populated by `extract_sans_single`: on success `error` is `None`; on
/// failure `sans` is empty and `error` describes what went wrong.
#[derive(Debug, Clone)]
pub struct SanResult {
    /// Host string exactly as it was passed in (not normalized).
    pub host: String,
    /// Names read from the leaf certificate: DNS and IP SAN entries, plus the
    /// subject Common Name when it is not already among them.
    pub sans: Vec<String>,
    /// Failure description, e.g. `"connection timeout"` when the deadline passed.
    pub error: Option<String>,
}

/// Extract SANs from certificates of multiple hosts concurrently
pub async fn extract_sans_from_hosts(
    hosts: &[String],
    concurrency: usize,
    timeout: Duration,
) -> Vec<SanResult> {
    let semaphore = Arc::new(Semaphore::new(concurrency));

    let tasks: Vec<_> = hosts
        .iter()
        .map(|host| {
            let permit = semaphore.clone();
            let host = host.clone();
            let timeout = timeout;
            tokio::spawn(async move {
                let _permit = permit.acquire().await.unwrap();
                extract_sans_single(&host, timeout).await
            })
        })
        .collect();

    let mut results = Vec::new();
    for task in tasks {
        if let Ok(result) = task.await {
            results.push(result);
        }
    }

    results
}

/// Runs SAN extraction for one host under a deadline.
///
/// Never fails: an extraction error or an elapsed `timeout` is folded into the
/// `error` field of the returned [`SanResult`], with `sans` left empty.
async fn extract_sans_single(host: &str, timeout: Duration) -> SanResult {
    let outcome = tokio::time::timeout(timeout, extract_sans_impl(host)).await;

    // Outer layer: did the deadline elapse? Inner layer: did extraction succeed?
    let (sans, error) = match outcome {
        Ok(Ok(names)) => (names, None),
        Ok(Err(err)) => (Vec::new(), Some(err.to_string())),
        Err(_elapsed) => (Vec::new(), Some("connection timeout".to_string())),
    };

    SanResult {
        host: host.to_string(),
        sans,
        error,
    }
}

/// Completes a TLS handshake with `host` and returns the names found in the
/// leaf certificate the server presents.
///
/// The certificate chain is not validated (see `CertCaptureVerifier`), so
/// untrusted or mismatched certificates are inspected as well.
///
/// # Errors
/// Fails when the TCP connection cannot be opened, the hostname is not a
/// valid TLS server name, the handshake fails, the server sends no
/// certificate, or the leaf certificate cannot be parsed.
async fn extract_sans_impl(host: &str) -> Result<Vec<String>> {
    // Reduce the input to hostname + port. Port 80 (what an "http://" input
    // resolves to) is remapped to 443 because a plaintext port has no
    // certificate to offer; any other port is used as given.
    let (hostname, port) = match normalize_host(host) {
        (name, 80) => (name, 443),
        other => other,
    };
    let addr = format!("{}:{}", hostname, port);

    // The verifier below accepts every certificate; the root store is handed
    // over only because its constructor takes one.
    let roots = rustls::RootCertStore::from_iter(
        webpki_roots::TLS_SERVER_ROOTS.iter().cloned(),
    );
    let tls_config = rustls::ClientConfig::builder()
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(CertCaptureVerifier::new(roots)))
        .with_no_client_auth();
    let connector = TlsConnector::from(Arc::new(tls_config));

    // TCP first, then the TLS handshake on top of it
    let tcp = TcpStream::connect(&addr).await?;
    let server_name = ServerName::try_from(hostname)?;
    let tls = connector.connect(server_name, tcp).await?;

    // The chain is ordered leaf-first, so the first entry is the server's own cert
    let (_, session) = tls.get_ref();
    let chain = session
        .peer_certificates()
        .ok_or_else(|| anyhow::anyhow!("No peer certificates"))?;
    let leaf = chain
        .first()
        .ok_or_else(|| anyhow::anyhow!("No certificates in chain"))?;

    extract_sans_from_der(leaf.as_ref())
}

/// Normalize a hostname/URL to extract just the hostname and port
///
/// The scheme (matched case-insensitively) selects the default port: 80 for
/// `http://`, 443 otherwise. Path, query string, and fragment are discarded.
/// An explicit `:port` suffix overrides the default; an empty one (`host:`)
/// falls back to it. A bare IPv6 literal is returned whole rather than having
/// its last group mistaken for a port. Bracketed IPv6 literals keep their
/// brackets.
///
/// Handles inputs like:
/// - "example.com" -> ("example.com", 443)
/// - "example.com:8443" -> ("example.com", 8443)
/// - "https://example.com" -> ("example.com", 443)
/// - "https://example.com:8443" -> ("example.com", 8443)
/// - "http://example.com" -> ("example.com", 80)
/// - "https://example.com/path" -> ("example.com", 443)
/// - "HTTPS://example.com#frag" -> ("example.com", 443)
/// - "example.com:" -> ("example.com", 443)
/// - "2001:db8::1" -> ("2001:db8::1", 443)
/// - "[::1]:8443" -> ("[::1]", 8443)
pub fn normalize_host(host: &str) -> (String, u16) {
    /// Removes `scheme` from the front of `input`, ignoring ASCII case.
    fn strip_scheme<'a>(input: &'a str, scheme: &str) -> Option<&'a str> {
        // `get` returns None when the input is too short or the cut would
        // land inside a multi-byte character.
        let head = input.get(..scheme.len())?;
        if head.eq_ignore_ascii_case(scheme) {
            Some(&input[scheme.len()..])
        } else {
            None
        }
    }

    let trimmed = host.trim();

    // Strip scheme and pick the matching default port
    let (rest, default_port) = if let Some(rest) = strip_scheme(trimmed, "https://") {
        (rest, 443u16)
    } else if let Some(rest) = strip_scheme(trimmed, "http://") {
        (rest, 80)
    } else {
        (trimmed, 443)
    };

    // Keep only the authority: cut at the first path, query, or fragment delimiter
    let authority = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .map_or(rest, |end| &rest[..end]);

    // A bare IPv6 literal is full of colons; none of them introduces a port
    if authority.parse::<std::net::Ipv6Addr>().is_ok() {
        return (authority.to_string(), default_port);
    }

    // Parse host:port
    if let Some((name, port_str)) = authority.rsplit_once(':') {
        if port_str.is_empty() {
            // "host:" - an empty port means "use the default"
            return (name.to_string(), default_port);
        }
        if let Ok(port) = port_str.parse::<u16>() {
            return (name.to_string(), port);
        }
    }

    (authority.to_string(), default_port)
}

/// Extract SANs from DER-encoded certificate
///
/// Collects, in order:
/// 1. every DNS name in the Subject Alternative Name extension,
/// 2. every IP address in that extension, rendered in standard textual form
///    (dotted quad for IPv4, canonical compressed form for IPv6),
/// 3. each subject Common Name that is not already in the list.
///
/// Other SAN entry kinds, and IP entries that are neither 4 nor 16 bytes
/// long, are ignored. A missing or unreadable SAN extension is not an error;
/// only the Common Name is reported in that case.
///
/// # Errors
/// Fails when `der` cannot be parsed as an X.509 certificate.
fn extract_sans_from_der(der: &[u8]) -> Result<Vec<String>> {
    let (_, cert) = X509Certificate::from_der(der)?;
    let mut sans = Vec::new();

    // Get Subject Alternative Names extension
    if let Ok(Some(san_ext)) = cert.subject_alternative_name() {
        for name in &san_ext.value.general_names {
            match name {
                GeneralName::DNSName(dns) => sans.push(dns.to_string()),
                GeneralName::IPAddress(bytes) => {
                    // The payload is the raw address; its length tells v4 from v6.
                    // The std address types produce the canonical text form.
                    if let Ok(v4) = <[u8; 4]>::try_from(*bytes) {
                        sans.push(std::net::Ipv4Addr::from(v4).to_string());
                    } else if let Ok(v6) = <[u8; 16]>::try_from(*bytes) {
                        sans.push(std::net::Ipv6Addr::from(v6).to_string());
                    }
                }
                _ => {}
            }
        }
    }

    // Also add Common Name from subject if present
    for rdn in cert.subject().iter() {
        for attr in rdn.iter() {
            if attr.attr_type() == &oid_registry::OID_X509_COMMON_NAME {
                if let Ok(cn) = attr.attr_value().as_str() {
                    // Compare as &str so no String is allocated just for the lookup
                    if !sans.iter().any(|known| known.as_str() == cn) {
                        sans.push(cn.to_string());
                    }
                }
            }
        }
    }

    Ok(sans)
}

/// Custom certificate verifier that accepts all certs (for inspection purposes)
///
/// SECURITY: its `verify_server_cert` approves every chain without checking
/// trust or hostname, so a connection made with it is not authenticated. Use
/// it only to read the certificate a server presents.
#[derive(Debug)]
struct CertCaptureVerifier {
    // Root store kept for potential future validation, currently unused
    #[allow(dead_code)]
    roots: rustls::RootCertStore,
}

impl CertCaptureVerifier {
    fn new(roots: rustls::RootCertStore) -> Self {
        Self { roots }
    }
}

impl rustls::client::danger::ServerCertVerifier for CertCaptureVerifier {
    /// Approves whatever chain the server presents, without inspecting it.
    ///
    /// No trust, expiry, or hostname check is performed; the goal is only to
    /// let the handshake finish so the certificate can be read afterwards.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let accepted = rustls::client::danger::ServerCertVerified::assertion();
        Ok(accepted)
    }

    /// Checks a TLS 1.2 handshake signature with the ring provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        let algorithms = &provider.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the ring provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        let provider = rustls::crypto::ring::default_provider();
        let algorithms = &provider.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Lists the signature schemes the ring provider is able to verify.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        let provider = rustls::crypto::ring::default_provider();
        provider
            .signature_verification_algorithms
            .supported_schemes()
    }
}

use crate::wordlist::{expand_wildcard, wildcard_to_base};

/// Result of SAN merging, including discovered wildcards
///
/// Derives `Debug` and `Clone` so it can be logged and copied like the
/// sibling `SanResult` type.
#[derive(Debug, Clone)]
pub struct SanMergeResult {
    /// Deduplicated, sorted hostnames to probe: the original targets, the
    /// concrete SAN hostnames, wildcard base domains, and expanded wildcards.
    pub targets: Vec<String>,
    /// Deduplicated, sorted wildcard SAN entries exactly as they appeared in
    /// certificates (e.g. `*.example.com`).
    pub wildcards: Vec<String>,
    /// Base domains derived from the wildcard entries.
    pub wildcard_bases: Vec<String>,
}

/// Merge discovered SANs into the target list, returning the expanded list
///
/// All hostnames are normalized (scheme/path stripped) via `normalize_host`.
/// Each wildcard SAN is recorded, contributes its base domain as a target,
/// and is expanded into candidate subdomains by `expand_wildcard`. SAN
/// entries that parse as IP addresses are skipped.
///
/// # Arguments
/// * `original_targets` - hosts supplied by the user; always kept.
/// * `san_results` - per-host extraction results whose `sans` are merged in.
/// * `max_subdomain_depth` - passed through to `expand_wildcard`.
///
/// # Returns
/// A [`SanMergeResult`] whose three lists are all deduplicated and sorted, so
/// the output is the same from run to run.
pub fn merge_sans_into_targets(
    original_targets: &[String],
    san_results: &[SanResult],
    max_subdomain_depth: usize,
) -> SanMergeResult {
    let mut all_hosts: HashSet<String> = HashSet::new();
    let mut wildcards: Vec<String> = Vec::new();
    let mut wildcard_bases: HashSet<String> = HashSet::new();

    // Normalize and add original targets
    all_hosts.extend(
        original_targets
            .iter()
            .map(|target| normalize_host(target).0)
            .filter(|hostname| !hostname.is_empty()),
    );

    // Add SANs from results
    for san in san_results.iter().flat_map(|result| result.sans.iter()) {
        // Handle wildcards specially
        if san.starts_with('*') {
            wildcards.push(san.clone());

            // A wildcard without a usable base domain is recorded but not expanded
            if let Some(base) = wildcard_to_base(san) {
                wildcard_bases.insert(base.clone());
                all_hosts.insert(base);

                // Expand wildcard to common subdomains
                all_hosts.extend(expand_wildcard(san, max_subdomain_depth));
            }
            continue;
        }

        // Skip IP addresses for now - focus on hostnames
        if san.parse::<std::net::IpAddr>().is_ok() {
            continue;
        }

        // Normalize the SAN (though they should already be clean hostnames)
        let (hostname, _) = normalize_host(san);
        if !hostname.is_empty() {
            all_hosts.insert(hostname);
        }
    }

    let mut targets: Vec<String> = all_hosts.into_iter().collect();
    targets.sort_unstable();

    wildcards.sort_unstable();
    wildcards.dedup();

    // HashSet iteration order is arbitrary; sort so callers get stable output
    let mut wildcard_bases: Vec<String> = wildcard_bases.into_iter().collect();
    wildcard_bases.sort_unstable();

    SanMergeResult {
        targets,
        wildcards,
        wildcard_bases,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every supported input form maps to the expected (hostname, port) pair.
    #[test]
    fn test_normalize_host() {
        let cases: [(&str, &str, u16); 7] = [
            ("example.com", "example.com", 443),
            ("example.com:8443", "example.com", 8443),
            ("https://example.com", "example.com", 443),
            ("http://example.com", "example.com", 80),
            ("https://example.com:8443", "example.com", 8443),
            ("https://example.com/path/to/resource", "example.com", 443),
            ("https://example.com:8443/path?query=1", "example.com", 8443),
        ];

        for &(input, want_host, want_port) in cases.iter() {
            assert_eq!(
                normalize_host(input),
                (want_host.to_string(), want_port),
                "input: {}",
                input
            );
        }
    }

    /// Merging keeps the originals, adds concrete SANs, and expands wildcards.
    #[test]
    fn test_merge_sans() {
        let original = vec!["a.example.com".to_string(), "b.example.com".to_string()];
        let san_results = vec![SanResult {
            host: "a.example.com".to_string(),
            sans: vec![
                "a.example.com".to_string(),
                "c.example.com".to_string(),
                "*.example.com".to_string(), // Now expanded instead of filtered
            ],
            error: None,
        }];

        let merged = merge_sans_into_targets(&original, &san_results, 1);
        let has_target = |name: &str| merged.targets.iter().any(|t| t == name);

        // Original targets and concrete SANs
        for name in ["a.example.com", "b.example.com", "c.example.com"].iter() {
            assert!(has_target(name), "missing target {}", name);
        }

        // Base domain of the wildcard
        assert!(has_target("example.com"));

        // Common subdomains produced by wildcard expansion
        assert!(has_target("www.example.com"));
        assert!(has_target("api.example.com"));

        // The wildcard itself is tracked separately...
        assert!(merged.wildcards.iter().any(|w| w == "*.example.com"));

        // ...and never leaks into the target list
        assert!(merged.targets.iter().all(|t| !t.starts_with('*')));
    }
}