use std::collections::{HashMap, HashSet};
use std::fmt;
use std::net::IpAddr;
use std::sync::{
atomic::{AtomicBool, Ordering},
mpsc::{self, RecvTimeoutError},
Arc, Mutex,
};
use std::time::Duration;
use pnet::packet::{ip::IpNextHeaderProtocols, ipv4::Ipv4Packet, udp::UdpPacket, Packet};
use trust_dns_resolver::proto::op::{Message, MessageType};
use trust_dns_resolver::proto::rr::RData;
use crate::{send, state::BruteForceState};
/// One DNS answer record captured for a queried name.
///
/// Within this file, values are produced by `discovered_from_record`.
#[derive(Debug, Clone)]
pub struct DiscoveredDomain {
/// Queried name the answer belongs to.
pub domain: String,
/// Record payload rendered as text. Despite the name, this is only an IP
/// address for A/AAAA records: `discovered_from_record` stores the target
/// name for CNAME/NS, "<preference> <exchange>" for MX and the
/// space-joined strings for TXT.
pub ip: String,
/// Record type label; `discovered_from_record` emits "A", "AAAA",
/// "CNAME", "NS", "MX" or "TXT".
pub record_type: String,
/// Capture time in Unix seconds (`process_dns_response` reads the UTC clock).
pub timestamp: u64,
}
/// Outcome of verifying one discovered domain.
///
/// NOTE(review): the code that fills this struct is not in this part of the
/// file; field meanings below are inferred from names and from the `Display`
/// impl — confirm against the producer.
#[derive(Debug, Clone)]
pub struct VerificationResult {
/// Domain that was verified.
pub domain: String,
/// Address associated with the domain, kept as text.
pub ip: String,
/// Presumably the status code of a plain-HTTP probe; `Display` prints "N/A" for `None`.
pub http_status: Option<u16>,
/// Presumably the status code of an HTTPS probe; `Display` prints "N/A" for `None`.
pub https_status: Option<u16>,
/// Presumably the page title; `Display` prints "N/A" for `None`.
pub title: Option<String>,
/// Presumably the `Server` response header; not rendered by `Display`.
pub server: Option<String>,
/// Liveness flag; `Display` renders it as "YES"/"NO" and the summary counts it.
pub is_alive: bool,
}
/// Aggregate figures over discovered records and verification results.
///
/// Built by `generate_summary_from_data` and rendered by `print_summary_stats`.
#[derive(Debug, Clone)]
pub struct SummaryStats {
/// Number of discovered records passed in (records, not distinct names).
pub total_domains: usize,
/// Distinct `ip` strings that parse as an IP address.
pub unique_ips: HashSet<String>,
/// IPv4 addresses grouped by /24 prefix, keyed as "a.b.c.0/24".
pub ip_ranges: HashMap<String, Vec<String>>,
/// Number of discovered records per record type label.
pub record_types: HashMap<String, usize>,
/// Number of verification results passed in.
pub verified_domains: usize,
/// Number of verification results whose `is_alive` flag is set.
pub alive_domains: usize,
}
/// Renders a verification result as one fixed-width table row:
/// domain, IP, HTTP status, HTTPS status, title and liveness.
/// Missing statuses and a missing title are shown as "N/A".
impl fmt::Display for VerificationResult {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Text for an optional status code column.
        fn status_cell(status: Option<u16>) -> String {
            match status {
                Some(code) => code.to_string(),
                None => String::from("N/A"),
            }
        }

        let title_cell = match &self.title {
            Some(title) => title.as_str(),
            None => "N/A",
        };
        let alive_cell = if self.is_alive { "YES" } else { "NO" };

        write!(
            f,
            "{:<30} {:<15} {:<6} {:<6} {:<20} {:<10}",
            self.domain,
            self.ip,
            status_cell(self.http_status),
            status_cell(self.https_status),
            title_cell,
            alive_cell
        )
    }
}
// Process-wide, mutex-guarded result stores, initialised empty on first use.
// `get_discovered_domains` and `get_verification_results` return snapshots of
// them. NOTE(review): no writer is visible in this part of the file (the DNS
// path records into `BruteForceState` instead) — confirm who populates these.
lazy_static::lazy_static! {
// Discovered DNS records.
pub static ref DISCOVERED_DOMAINS: Arc<Mutex<Vec<DiscoveredDomain>>> = Arc::new(Mutex::new(Vec::new()));
// Verification results.
pub static ref VERIFICATION_RESULTS: Arc<Mutex<Vec<VerificationResult>>> = Arc::new(Mutex::new(Vec::new()));
}
/// Receive loop for captured packets.
///
/// Prints the table header unless `silent`, then repeatedly waits up to
/// 500 ms for a packet on `dns_recv`. Each packet is parsed as
/// IPv4 -> UDP -> DNS message and handed to `process_dns_response`;
/// anything that fails to parse is dropped. The loop ends when `running`
/// is cleared or the sending side of the channel is gone.
pub fn handle_dns_packet(
    dns_recv: mpsc::Receiver<Arc<Vec<u8>>>,
    flag_id: u16,
    running: Arc<AtomicBool>,
    silent: bool,
    state: BruteForceState,
) {
    if !silent {
        println!(
            "\n{:<30} {:<45} {:<7} {:<20}",
            "域名", "IP地址", "记录类型", "时间戳"
        );
        println!("{}", "-".repeat(110));
    }

    // Short timeout so a cleared `running` flag is noticed promptly.
    let poll_interval = Duration::from_millis(500);

    while running.load(Ordering::Relaxed) {
        let raw = match dns_recv.recv_timeout(poll_interval) {
            Ok(raw) => raw,
            Err(RecvTimeoutError::Timeout) => continue,
            Err(RecvTimeoutError::Disconnected) => break,
        };

        // The flag may have been cleared while we were blocked on the channel.
        if !running.load(Ordering::Relaxed) {
            break;
        }

        let ipv4 = match Ipv4Packet::new(raw.as_ref()) {
            Some(packet) => packet,
            None => continue,
        };
        if ipv4.get_next_level_protocol() != IpNextHeaderProtocols::Udp {
            continue;
        }
        let udp = match UdpPacket::new(ipv4.payload()) {
            Some(datagram) => datagram,
            None => continue,
        };
        if let Ok(message) = Message::from_vec(udp.payload()) {
            process_dns_response(&message, flag_id, udp.get_destination(), silent, &state);
        }
    }
}
/// Handles one parsed DNS message.
///
/// Ignores anything that is not a response or whose id, divided by 100,
/// differs from `flag_id`. For a matching response, every answer that
/// `discovered_from_record` understands is stored in `state` and printed;
/// afterwards the local bookkeeping for the query is updated, whether or
/// not the response carried answers.
fn process_dns_response(
    message: &Message,
    flag_id: u16,
    destination_port: u16,
    silent: bool,
    state: &BruteForceState,
) {
    let is_response = message.message_type() == MessageType::Response;
    if !is_response || message.id() / 100 != flag_id {
        return;
    }

    let answers = message.answers();
    if !answers.is_empty() {
        // All answers are attributed to the first question in the message.
        let query_name = match message.queries().first() {
            Some(query) => normalize_domain(query.name().to_utf8()),
            None => String::new(),
        };
        let timestamp = chrono::Utc::now().timestamp() as u64;

        let records = answers
            .iter()
            .filter_map(|answer| discovered_from_record(&query_name, answer.data(), timestamp));
        for discovered in records {
            state.add_discovered_domain(discovered.clone());
            print_discovered(&discovered, silent);
        }
    }

    update_local_status(message.id(), destination_port, state);
}
/// Converts one answer's record data into a `DiscoveredDomain`.
///
/// Returns `None` when the answer has no data or the record type is not one
/// of A, AAAA, CNAME, NS, MX or TXT. The textual payload goes into the `ip`
/// field: the address for A/AAAA, the target name (trailing dots removed)
/// for CNAME/NS, "<preference> <exchange>" for MX, and the trimmed,
/// non-empty TXT strings joined by single spaces for TXT.
fn discovered_from_record(
    query_name: &str,
    data: Option<&RData>,
    timestamp: u64,
) -> Option<DiscoveredDomain> {
    let record = data?;

    let (label, value) = match record {
        RData::A(addr) => ("A", addr.to_string()),
        RData::AAAA(addr) => ("AAAA", addr.to_string()),
        RData::CNAME(target) => ("CNAME", normalize_domain(target.to_utf8())),
        RData::NS(server) => ("NS", normalize_domain(server.to_utf8())),
        RData::MX(mx) => {
            let preference = mx.preference();
            let exchange = normalize_domain(mx.exchange().to_utf8());
            ("MX", format!("{} {}", preference, exchange))
        }
        RData::TXT(txt) => {
            let mut parts = Vec::new();
            for bytes in txt.txt_data().iter() {
                // TXT payloads are raw bytes; invalid UTF-8 is replaced, not rejected.
                let text = String::from_utf8_lossy(bytes).trim().to_string();
                if !text.is_empty() {
                    parts.push(text);
                }
            }
            ("TXT", parts.join(" "))
        }
        _ => return None,
    };

    Some(DiscoveredDomain {
        domain: query_name.to_owned(),
        ip: value,
        record_type: label.to_owned(),
        timestamp,
    })
}
/// Removes every trailing '.' from a domain name (e.g. the root dot of a
/// fully-qualified name) and returns the result.
fn normalize_domain(mut domain: String) -> String {
    let kept = domain.trim_end_matches('.').len();
    domain.truncate(kept);
    domain
}
/// Prints one discovered record to stdout.
///
/// In silent mode only the domain is printed. Otherwise a table row with
/// domain, value, record type and the capture time (HH:MM:SS) is printed.
/// TXT values longer than 15 characters are shortened to their first 12
/// characters followed by "...".
fn print_discovered(discovered: &DiscoveredDomain, silent: bool) {
    if silent {
        println!("{}", discovered.domain);
        return;
    }

    // Truncate by characters, not bytes: TXT text is decoded lossily and may
    // contain multi-byte characters, and slicing a `String` at a byte offset
    // that is not a char boundary panics.
    let display_ip = if discovered.record_type == "TXT" && discovered.ip.chars().count() > 15 {
        let head: String = discovered.ip.chars().take(12).collect();
        format!("{}...", head)
    } else {
        discovered.ip.clone()
    };

    println!(
        "{:<30} {:<50} {:<10} {}",
        discovered.domain,
        display_ip,
        discovered.record_type,
        chrono::DateTime::from_timestamp(discovered.timestamp as i64, 0)
            .unwrap_or_default()
            .format("%H:%M:%S")
    );
}
/// Updates local bookkeeping after a response has been seen.
///
/// Computes a map index from `message_id % 100` and the UDP destination
/// port, removes the entry at that index from `state` (the result is
/// ignored), and pushes the index onto the state's stack.
/// NOTE(review): presumably this frees the slot for reuse by the sender —
/// confirm against `BruteForceState` and `send::generate_map_index`.
fn update_local_status(message_id: u16, destination_port: u16, state: &BruteForceState) {
    let sequence = message_id % 100;
    let slot = send::generate_map_index(sequence, destination_port);
    let _ = state.search_from_index_and_delete(slot as u32);
    state.push_to_stack(slot as usize);
}
/// Prints one verification result as a table row on stdout.
///
/// The column header and separator line are printed once per process,
/// before the first row.
pub fn print_verification_result(result: &VerificationResult) {
    use std::sync::Once;

    static HEADER_PRINTED: Once = Once::new();

    let print_header = || {
        println!(
            "\n{:<30} {:<15} {:<6} {:<6} {:<20} {:<10}",
            "域名", "IP地址", "HTTP", "HTTPS", "标题", "存活"
        );
        println!("{}", "-".repeat(90));
    };
    HEADER_PRINTED.call_once(print_header);

    println!("{}", result);
}
/// Returns a snapshot (clone) of the globally stored discovered domains.
///
/// A poisoned lock is recovered rather than treated as "no data": poisoning
/// only means another thread panicked while holding the guard, and silently
/// returning an empty list would drop every result collected so far.
pub fn get_discovered_domains() -> Vec<DiscoveredDomain> {
    DISCOVERED_DOMAINS
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clone()
}
/// Returns a snapshot (clone) of the globally stored verification results.
///
/// A poisoned lock is recovered rather than treated as "no data": poisoning
/// only means another thread panicked while holding the guard, and silently
/// returning an empty list would drop every result collected so far.
pub fn get_verification_results() -> Vec<VerificationResult> {
    VERIFICATION_RESULTS
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner())
        .clone()
}
/// Builds summary statistics from discovered records and verification results.
///
/// * `total_domains` – number of discovered records.
/// * `record_types` – number of records per record type label.
/// * `unique_ips` – distinct `ip` values that parse as an IP address
///   (CNAME/MX/TXT payloads are skipped).
/// * `ip_ranges` – distinct IPv4 addresses grouped by /24 prefix, keyed as
///   "a.b.c.0/24". Each address is listed once even if several domains
///   resolve to it, so the per-range length is a count of IPs.
/// * `verified_domains` / `alive_domains` – number of verification results
///   and how many of them are flagged alive.
pub fn generate_summary_from_data(
    discovered: &[DiscoveredDomain],
    verified: &[VerificationResult],
) -> SummaryStats {
    let mut unique_ips: HashSet<String> = HashSet::new();
    let mut record_types: HashMap<String, usize> = HashMap::new();
    let mut ip_ranges: HashMap<String, Vec<String>> = HashMap::new();

    for record in discovered {
        *record_types.entry(record.record_type.clone()).or_insert(0) += 1;

        let addr = match record.ip.parse::<IpAddr>() {
            Ok(addr) => addr,
            Err(_) => continue,
        };
        // `insert` returns false for an address already seen; skipping it
        // keeps the per-range lists free of duplicates.
        if !unique_ips.insert(record.ip.clone()) {
            continue;
        }
        if let IpAddr::V4(v4) = addr {
            let [a, b, c, _] = v4.octets();
            ip_ranges
                .entry(format!("{}.{}.{}.0/24", a, b, c))
                .or_default()
                .push(record.ip.clone());
        }
    }

    SummaryStats {
        total_domains: discovered.len(),
        unique_ips,
        ip_ranges,
        record_types,
        verified_domains: verified.len(),
        alive_domains: verified.iter().filter(|v| v.is_alive).count(),
    }
}
pub fn print_summary_stats(summary: &SummaryStats) {
println!("\n{}", "=".repeat(60));
println!(" 汇总统计");
println!("{}", "=".repeat(60));
println!("发现域名总数: {}", summary.total_domains);
println!("唯一IP数量: {}", summary.unique_ips.len());
println!("已验证域名: {}", summary.verified_domains);
println!("存活域名: {}", summary.alive_domains);
println!("\n记录类型分布:");
for (record_type, count) in &summary.record_types {
println!(" {}: {}", record_type, count);
}
println!("\nIP段分布 (前10个):");
let mut sorted_ranges: Vec<_> = summary.ip_ranges.iter().collect();
sorted_ranges.sort_by(|a, b| b.1.len().cmp(&a.1.len()));
for (range, ips) in sorted_ranges.iter().take(10) {
println!(" {}: {} 个IP", range, ips.len());
}
if summary.unique_ips.len() > 0 {
println!("\n发现的IP地址 (前20个):");
let mut sorted_ips: Vec<_> = summary.unique_ips.iter().collect();
sorted_ips.sort();
for ip in sorted_ips.iter().take(20) {
println!(" {}", ip);
}
if summary.unique_ips.len() > 20 {
println!(" ... 还有 {} 个IP", summary.unique_ips.len() - 20);
}
}
println!("{}", "=".repeat(60));
}
/// Builds summary statistics from snapshots of the global discovered-domain
/// and verification-result stores.
pub fn generate_summary() -> SummaryStats {
    generate_summary_from_data(&get_discovered_domains(), &get_verification_results())
}
/// Generates the summary from the global stores and prints it to stdout.
pub fn print_summary() {
    print_summary_stats(&generate_summary());
}