rsubdomain 1.2.11

A high-performance subdomain brute-force tool written in Rust
Documentation
use std::collections::{HashMap, HashSet};
use std::fmt;
use std::net::IpAddr;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    mpsc::{self, RecvTimeoutError},
    Arc, Mutex,
};
use std::time::Duration;

use pnet::packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, udp::UdpPacket, Packet};
use trust_dns_resolver::proto::op::{Message, MessageType};
use trust_dns_resolver::proto::rr::RData;

use crate::{send, state::BruteForceState};

/// A single DNS record discovered for a queried domain.
#[derive(Debug, Clone)]
pub struct DiscoveredDomain {
    /// Domain name the record was returned for.
    pub domain: String,
    /// Record payload rendered as text. Despite the name this is not always
    /// an IP address: as filled in by `discovered_from_record` in this file it
    /// holds the address for A/AAAA records, the target name for CNAME/NS,
    /// `"<preference> <exchange>"` for MX, and the joined text for TXT.
    pub ip: String,
    /// Record type label; `discovered_from_record` in this file produces
    /// "A", "AAAA", "CNAME", "NS", "MX" or "TXT".
    pub record_type: String,
    /// Time the record was observed, as Unix seconds (`print_discovered`
    /// interprets it that way when formatting).
    pub timestamp: u64,
}

/// Outcome of verifying a discovered domain.
///
/// NOTE(review): the code that produces these values is not in this file, so
/// the field descriptions below are inferred from names and from the
/// `Display` impl — confirm against the producer.
#[derive(Debug, Clone)]
pub struct VerificationResult {
    /// Domain that was verified.
    pub domain: String,
    /// Address associated with the domain, as text.
    pub ip: String,
    /// Presumably the status code obtained over plain HTTP; `None` is
    /// displayed as "N/A".
    pub http_status: Option<u16>,
    /// Presumably the status code obtained over HTTPS; `None` is displayed
    /// as "N/A".
    pub https_status: Option<u16>,
    /// Presumably the page title; `None` is displayed as "N/A".
    pub title: Option<String>,
    /// Presumably the server banner/header. Not part of the `Display` output.
    pub server: Option<String>,
    /// Whether the domain is considered alive; displayed as "YES"/"NO".
    pub is_alive: bool,
}

/// Aggregate statistics over discovered and verified domains.
///
/// The field descriptions reflect how `generate_summary_from_data` in this
/// file fills the struct.
#[derive(Debug, Clone)]
pub struct SummaryStats {
    /// Number of discovered records (one per `DiscoveredDomain` entry).
    pub total_domains: usize,
    /// Distinct `ip` strings that parse as an IPv4 or IPv6 address.
    pub unique_ips: HashSet<String>,
    /// IPv4 addresses grouped by their `a.b.c.0/24` range. An address is
    /// pushed once per discovered record, so duplicates are retained.
    pub ip_ranges: HashMap<String, Vec<String>>,
    /// Number of discovered records per record type label.
    pub record_types: HashMap<String, usize>,
    /// Number of verification results.
    pub verified_domains: usize,
    /// Number of verification results with `is_alive` set.
    pub alive_domains: usize,
}

impl fmt::Display for VerificationResult {
    /// Renders the result as one left-aligned, fixed-width table row:
    /// domain, IP, HTTP status, HTTPS status, title and liveness.
    ///
    /// Missing status codes and a missing title are shown as "N/A"; the
    /// `server` field is not part of the row.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Shared rendering for both optional status-code columns.
        let status_cell = |status: Option<u16>| match status {
            Some(code) => code.to_string(),
            None => "N/A".to_string(),
        };

        let http_cell = status_cell(self.http_status);
        let https_cell = status_cell(self.https_status);
        let title_cell = match &self.title {
            Some(title) => title.as_str(),
            None => "N/A",
        };
        let alive_cell = if self.is_alive { "YES" } else { "NO" };

        write!(
            f,
            "{:<30} {:<15} {:<6} {:<6} {:<20} {:<10}",
            self.domain, self.ip, http_cell, https_cell, title_cell, alive_cell
        )
    }
}

// Process-wide result collectors, lazily initialised on first access.
// `get_discovered_domains` / `get_verification_results` return snapshots of
// them. NOTE(review): no code visible in this file pushes into either vector
// (discoveries go through `BruteForceState::add_discovered_domain`), so
// writers, if any, live elsewhere in the crate — confirm before relying on
// these for summaries.
lazy_static::lazy_static! {
    pub static ref DISCOVERED_DOMAINS: Arc<Mutex<Vec<DiscoveredDomain>>> = Arc::new(Mutex::new(Vec::new()));
    pub static ref VERIFICATION_RESULTS: Arc<Mutex<Vec<VerificationResult>>> = Arc::new(Mutex::new(Vec::new()));
}

pub fn handle_dns_packet(
    dns_recv: mpsc::Receiver<Arc<Vec<u8>>>,
    flag_id: u16,
    running: Arc<AtomicBool>,
    silent: bool,
    state: BruteForceState,
) {
    // 打印表格头部
    if !silent {
        println!(
            "\n{:<30} {:<45} {:<7} {:<20}",
            "域名", "IP地址", "记录类型", "时间戳"
        );
        println!("{}", "-".repeat(110));
    }

    while running.load(Ordering::Relaxed) {
        match dns_recv.recv_timeout(Duration::from_millis(500)) {
            Ok(ipv4_packet) => {
                // 在处理数据包前再次检查运行状态
                if !running.load(Ordering::Relaxed) {
                    break;
                }

                if let Some(ipv4) = Ipv4Packet::new(ipv4_packet.as_ref()) {
                    if ipv4.get_next_level_protocol() == IpNextHeaderProtocols::Udp {
                        if let Some(udp) = UdpPacket::new(ipv4.payload()) {
                            if let Ok(message) = Message::from_vec(udp.payload()) {
                                process_dns_response(
                                    &message,
                                    flag_id,
                                    udp.get_destination(),
                                    silent,
                                    &state,
                                );
                            }
                        }
                    }
                }
            }
            Err(RecvTimeoutError::Timeout) => {
                // 超时是正常的,继续循环检查running标志
                continue;
            }
            Err(RecvTimeoutError::Disconnected) => {
                // 通道已断开,退出循环
                break;
            }
        }
    }
}

/// Handles one decoded DNS message.
///
/// Messages that are not responses, or whose transaction ID divided by 100
/// does not equal `flag_id`, are ignored. For a matching response every
/// answer record that `discovered_from_record` understands is stored in
/// `state` and printed; afterwards the local bookkeeping for the transaction
/// is updated, whether or not the response carried any answers.
fn process_dns_response(
    message: &Message,
    flag_id: u16,
    destination_port: u16,
    silent: bool,
    state: &BruteForceState,
) {
    if message.message_type() != MessageType::Response {
        return;
    }

    // The upper part of the transaction ID identifies this run.
    let message_id = message.id();
    if message_id / 100 != flag_id {
        return;
    }

    let answers = message.answers();
    if !answers.is_empty() {
        // Fall back to an empty name when the response echoes no question.
        let query_name = match message.queries().first() {
            Some(query) => normalize_domain(query.name().to_utf8()),
            None => String::new(),
        };
        let timestamp = chrono::Utc::now().timestamp() as u64;

        answers
            .iter()
            .filter_map(|answer| discovered_from_record(&query_name, answer.data(), timestamp))
            .for_each(|discovered| {
                state.add_discovered_domain(discovered.clone());
                print_discovered(&discovered, silent);
            });
    }

    // Update the local status for this transaction.
    update_local_status(message_id, destination_port, state);
}

/// Converts one answer record into a `DiscoveredDomain` for `query_name`.
///
/// Returns `None` when the record has no data or is of a type other than
/// A, AAAA, CNAME, NS, MX or TXT. The textual payload is stored in the `ip`
/// field: the address for A/AAAA, the target name without trailing dot for
/// CNAME/NS, `"<preference> <exchange>"` for MX, and the non-empty, trimmed
/// TXT strings joined by single spaces for TXT.
fn discovered_from_record(
    query_name: &str,
    data: Option<&RData>,
    timestamp: u64,
) -> Option<DiscoveredDomain> {
    let record = data?;

    let (type_label, payload): (&str, String) = match record {
        RData::A(address) => ("A", address.to_string()),
        RData::AAAA(address) => ("AAAA", address.to_string()),
        RData::CNAME(target) => ("CNAME", normalize_domain(target.to_utf8())),
        RData::NS(server) => ("NS", normalize_domain(server.to_utf8())),
        RData::MX(mx) => {
            let exchange = normalize_domain(mx.exchange().to_utf8());
            ("MX", format!("{} {}", mx.preference(), exchange))
        }
        RData::TXT(txt) => {
            // Decode each character-string leniently and drop blank ones.
            let mut parts: Vec<String> = Vec::new();
            for chunk in txt.txt_data().iter() {
                let text = String::from_utf8_lossy(chunk).trim().to_string();
                if !text.is_empty() {
                    parts.push(text);
                }
            }
            ("TXT", parts.join(" "))
        }
        _ => return None,
    };

    Some(DiscoveredDomain {
        domain: query_name.to_owned(),
        ip: payload,
        record_type: type_label.to_owned(),
        timestamp,
    })
}

/// Strips every trailing `.` from a domain name (e.g. the root dot of a
/// fully qualified name) and returns the result.
fn normalize_domain(mut domain: String) -> String {
    // Length of the prefix that remains once trailing dots are removed.
    let kept_len = domain.trim_end_matches('.').len();
    domain.truncate(kept_len);
    domain
}

/// Prints one discovered record to stdout.
///
/// In silent mode only the domain name is printed. Otherwise a table row is
/// printed with the domain, the record payload, the record type and the
/// observation time formatted as `HH:MM:SS`. TXT payloads longer than 15
/// bytes are shortened to at most 12 bytes followed by `...`.
fn print_discovered(discovered: &DiscoveredDomain, silent: bool) {
    if silent {
        println!("{}", discovered.domain);
        return;
    }

    let shortened;
    let display_ip: &str = if discovered.record_type == "TXT" && discovered.ip.len() > 15 {
        // TXT text comes from a lossy UTF-8 decode of remote data, so byte 12
        // may fall inside a multi-byte character. Slicing there would panic;
        // back off to the nearest preceding character boundary (index 0 is
        // always a boundary, so the loop terminates).
        let mut cut = 12;
        while !discovered.ip.is_char_boundary(cut) {
            cut -= 1;
        }
        shortened = format!("{}...", &discovered.ip[..cut]);
        shortened.as_str()
    } else {
        discovered.ip.as_str()
    };

    println!(
        "{:<30} {:<50} {:<10} {}",
        discovered.domain,
        display_ip,
        discovered.record_type,
        chrono::DateTime::from_timestamp(discovered.timestamp as i64, 0)
            .unwrap_or_default()
            .format("%H:%M:%S")
    );
}

/// Updates the local bookkeeping for an answered transaction.
///
/// An index is derived from `message_id % 100` and the UDP destination port
/// via `send::generate_map_index`; the matching entry is removed from `state`
/// (the lookup result is deliberately ignored) and the index is pushed onto
/// the state's stack.
/// NOTE(review): the exact meaning of the index and the stack is defined by
/// `send` / `BruteForceState`, which are not visible here — presumably the
/// index is being released for reuse; confirm.
fn update_local_status(message_id: u16, destination_port: u16, state: &BruteForceState) {
    let low_id_part = message_id % 100;
    let index = send::generate_map_index(low_id_part, destination_port);

    let (lookup_key, stack_slot) = (index as u32, index as usize);
    let _ = state.search_from_index_and_delete(lookup_key);
    state.push_to_stack(stack_slot);
}

/// Prints a verification result as soon as it is available.
///
/// The table header is printed once per process, before the first row; every
/// call then prints `result` using its `Display` implementation.
pub fn print_verification_result(result: &VerificationResult) {
    use std::sync::Once;

    static HEADER_PRINTED: Once = Once::new();

    /// Writes the column titles and the separator line.
    fn print_header() {
        println!(
            "\n{:<30} {:<15} {:<6} {:<6} {:<20} {:<10}",
            "域名", "IP地址", "HTTP", "HTTPS", "标题", "存活"
        );
        println!("{}", "-".repeat(90));
    }

    HEADER_PRINTED.call_once(print_header);

    println!("{}", result);
}

/// Returns a snapshot of the globally collected discovered domains.
///
/// If the global mutex is poisoned an empty list is returned instead.
pub fn get_discovered_domains() -> Vec<DiscoveredDomain> {
    match DISCOVERED_DOMAINS.lock() {
        Ok(guard) => guard.to_vec(),
        Err(_) => Vec::new(),
    }
}

/// Returns a snapshot of the globally collected verification results.
///
/// If the global mutex is poisoned an empty list is returned instead.
pub fn get_verification_results() -> Vec<VerificationResult> {
    match VERIFICATION_RESULTS.lock() {
        Ok(guard) => guard.to_vec(),
        Err(_) => Vec::new(),
    }
}

/// Builds summary statistics from discovered records and verification results.
///
/// Every discovered record is counted under its record type. Records whose
/// `ip` field parses as an IP address contribute to the set of unique IPs;
/// IPv4 addresses are additionally grouped under their `a.b.c.0/24` range
/// (once per record, so repeated addresses appear repeatedly in a group).
/// Payloads that are not IP addresses (names, MX, TXT) only affect the
/// per-type counts.
pub fn generate_summary_from_data(
    discovered: &[DiscoveredDomain],
    verified: &[VerificationResult],
) -> SummaryStats {
    let mut unique_ips: HashSet<String> = HashSet::new();
    let mut ip_ranges: HashMap<String, Vec<String>> = HashMap::new();
    let mut record_types: HashMap<String, usize> = HashMap::new();

    for entry in discovered {
        *record_types.entry(entry.record_type.clone()).or_default() += 1;

        match entry.ip.parse::<IpAddr>() {
            Ok(IpAddr::V4(v4)) => {
                unique_ips.insert(entry.ip.clone());

                // Group the address under its /24 network.
                let [a, b, c, _] = v4.octets();
                let range = format!("{}.{}.{}.0/24", a, b, c);
                ip_ranges.entry(range).or_default().push(entry.ip.clone());
            }
            Ok(IpAddr::V6(_)) => {
                unique_ips.insert(entry.ip.clone());
            }
            // Not an address (e.g. a CNAME target or TXT text).
            Err(_) => {}
        }
    }

    let alive_domains = verified.iter().filter(|result| result.is_alive).count();

    SummaryStats {
        total_domains: discovered.len(),
        unique_ips,
        ip_ranges,
        record_types,
        verified_domains: verified.len(),
        alive_domains,
    }
}

/// Prints a human-readable summary report to stdout.
///
/// The report contains the overall counters, the number of records per
/// record type, the ten `/24` ranges with the most entries and, if any IPs
/// were found, the first twenty unique IPs in lexicographic order.
///
/// All listings are printed in a deterministic order: record types by name,
/// ranges by descending entry count with the range string as tie-breaker.
pub fn print_summary_stats(summary: &SummaryStats) {
    let banner = "=".repeat(60);

    println!("\n{}", banner);
    println!("                    汇总统计");
    println!("{}", banner);

    println!("发现域名总数: {}", summary.total_domains);
    println!("唯一IP数量: {}", summary.unique_ips.len());
    println!("已验证域名: {}", summary.verified_domains);
    println!("存活域名: {}", summary.alive_domains);

    println!("\n记录类型分布:");
    // HashMap iteration order is arbitrary; sort so that the report is
    // identical across runs on the same data.
    let mut sorted_types: Vec<_> = summary.record_types.iter().collect();
    sorted_types.sort_unstable();
    for (record_type, count) in sorted_types {
        println!("  {}: {}", record_type, count);
    }

    println!("\nIP段分布 (前10个):");
    let mut sorted_ranges: Vec<_> = summary.ip_ranges.iter().collect();
    // Largest ranges first; equal-sized ranges are ordered by name so that
    // the "top 10" cut-off does not depend on hash order.
    sorted_ranges.sort_unstable_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(b.0)));

    for (range, ips) in sorted_ranges.iter().take(10) {
        println!("  {}: {} 个IP", range, ips.len());
    }

    if !summary.unique_ips.is_empty() {
        println!("\n发现的IP地址 (前20个):");
        let mut sorted_ips: Vec<_> = summary.unique_ips.iter().collect();
        sorted_ips.sort_unstable();
        for ip in sorted_ips.iter().take(20) {
            println!("  {}", ip);
        }
        if summary.unique_ips.len() > 20 {
            println!("  ... 还有 {} 个IP", summary.unique_ips.len() - 20);
        }
    }

    println!("{}", banner);
}

/// Builds summary statistics from the global result collectors
/// (kept for compatibility with the old interface).
pub fn generate_summary() -> SummaryStats {
    generate_summary_from_data(&get_discovered_domains(), &get_verification_results())
}

/// Prints the summary built from the global result collectors
/// (kept for compatibility with the old interface).
pub fn print_summary() {
    print_summary_stats(&generate_summary());
}