use std::net::SocketAddr;
use std::sync::atomic::{AtomicU16, Ordering};
use std::time::Duration;
use hickory_proto::op::{Message, ResponseCode};
use hickory_proto::op::update_message::zone_transfer as build_axfr_query;
use hickory_proto::rr::{Name, RData};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use crate::error::{Result, ShoheError};
use crate::resolver::{DnsQuery, DnsQueryResult, DnsRecord};
use crate::resolver::standard::record_to_dns_record;
/// Upper bound on the number of records [`axfr`] accumulates before it aborts
/// the transfer with an error.
const MAX_ZONE_RECORDS: usize = 500_000;
/// Process-wide counter from which [`axfr`] takes the ID of each outgoing
/// query. Starts at 1 and is advanced with `fetch_add`, which wraps around on
/// overflow.
static QUERY_ID_CTR: AtomicU16 = AtomicU16::new(1);
pub async fn axfr(domain: &str, server: SocketAddr, timeout_secs: u64) -> Result<DnsQueryResult> {
let timeout = Duration::from_secs(timeout_secs);
let name: Name = domain
.parse()
.map_err(|e| ShoheError::Parse(format!("invalid domain '{domain}': {e}")))?;
let mut tcp = tokio::time::timeout(timeout, TcpStream::connect(server))
.await
.map_err(|_| ShoheError::Transport(format!("connection to {server} timed out")))?
.map_err(|e| ShoheError::Io(e))?;
let query_id = QUERY_ID_CTR.fetch_add(1, Ordering::Relaxed);
let mut query = build_axfr_query(name.clone(), None);
query.metadata.id = query_id;
let query_bytes = query
.to_vec()
.map_err(|e| ShoheError::DnsProto(e))?;
let msg_len = u16::try_from(query_bytes.len())
.map_err(|_| ShoheError::Transport("AXFR query too large".to_string()))?;
let mut frame = Vec::with_capacity(2 + query_bytes.len());
frame.extend_from_slice(&msg_len.to_be_bytes());
frame.extend_from_slice(&query_bytes);
tcp.write_all(&frame).await?;
let mut records: Vec<DnsRecord> = Vec::new();
let mut expected_serial: Option<u32> = None;
let start = std::time::Instant::now();
loop {
if start.elapsed() >= timeout {
return Err(ShoheError::Transport(format!("AXFR from {server} timed out")));
}
let remaining = timeout.saturating_sub(start.elapsed());
let mut len_buf = [0u8; 2];
tokio::time::timeout(remaining, tcp.read_exact(&mut len_buf))
.await
.map_err(|_| ShoheError::Transport(format!("AXFR read from {server} timed out")))?
.map_err(|e| ShoheError::Io(e))?;
let msg_len = u16::from_be_bytes(len_buf) as usize;
if msg_len == 0 {
break;
}
let mut msg_buf = vec![0u8; msg_len];
let remaining = timeout.saturating_sub(start.elapsed());
tokio::time::timeout(remaining, tcp.read_exact(&mut msg_buf))
.await
.map_err(|_| ShoheError::Transport(format!("AXFR read from {server} timed out")))?
.map_err(|e| ShoheError::Io(e))?;
let message = Message::from_vec(&msg_buf)
.map_err(|e| ShoheError::DnsProto(e.into()))?;
let rcode = message.response_code;
if rcode != ResponseCode::NoError {
return Err(ShoheError::DnsResolution(format!(
"AXFR from {server} returned error: {rcode}"
)));
}
if message.metadata.id != query_id {
continue;
}
let mut done = false;
for record in &message.answers {
if records.len() >= MAX_ZONE_RECORDS {
return Err(ShoheError::DnsResolution(format!(
"AXFR from {server} exceeded {MAX_ZONE_RECORDS} record limit"
)));
}
if let RData::SOA(soa) = &record.data {
match expected_serial {
None => {
expected_serial = Some(soa.serial);
records.push(record_to_dns_record(record));
}
Some(serial) if soa.serial == serial => {
records.push(record_to_dns_record(record));
done = true;
break;
}
_ => {
records.push(record_to_dns_record(record));
}
}
} else {
records.push(record_to_dns_record(record));
}
}
if done {
break;
}
}
if records.is_empty() {
return Err(ShoheError::DnsResolution(format!(
"AXFR from {server} returned no records"
)));
}
Ok(DnsQueryResult {
query: DnsQuery {
name: domain.to_string(),
record_type: "AXFR".to_string(),
class: "IN".to_string(),
},
answers: records,
authority: vec![],
additional: vec![],
duration_ms: 0,
server_addr: server.to_string(),
})
}