use std::net::IpAddr;
use std::sync::Arc;
use arc_swap::ArcSwap;
use hickory_proto::{
op::{Message, MessageType, OpCode, ResponseCode},
serialize::binary::{BinDecodable, BinEncodable, BinEncoder},
};
use crate::dns::acl::{Acl, AclAction};
use crate::dns::local::{LocalZoneSet, ZoneAction};
use crate::dns::RateLimiter;
use super::loader::XdpHandle;
use super::socket::{
XskSocket, create_xsk_socket, get_rx_queue_count, iface_index,
};
use super::umem::{XdpDesc, FRAME_SIZE};
/// Ethernet II header: destination MAC (6) + source MAC (6) + EtherType (2).
const ETH_HDR: usize = 6 + 6 + 2;
/// Smallest IPv4 header: IHL of 5 32-bit words, i.e. no options.
const IPV4_HDR_MIN: usize = 5 * 4;
/// Fixed IPv6 header length (extension headers are not handled in this file).
const IPV6_HDR: usize = 40;
/// UDP header: source port, destination port, length, checksum (2 bytes each).
const UDP_HDR: usize = 4 * 2;
/// EtherType value for IPv4.
const ETH_P_IP: u16 = 0x0800;
/// EtherType value for IPv6.
const ETH_P_IPV6: u16 = 0x86DD;
/// IPv4 protocol / IPv6 next-header number for UDP.
const PROTO_UDP: u8 = 0x11;
/// Loads the XDP program on `iface` and starts one AF_XDP worker thread per RX queue.
///
/// For every queue reported by `get_rx_queue_count` (at least one), an AF_XDP
/// socket is created. The first attempt uses `create_xsk_socket(.., true)`
/// (presumably zero-copy). If that fails, it retries with `false`, which is copy mode.
/// The socket is registered with the loaded program through
/// `XdpHandle::register_socket`. It is then served by a detached thread named
/// `xdp-<iface>-q<N>` that runs `xdp_worker`.
///
/// # Errors
///
/// Returns `Err` in these cases:
/// - the interface does not exist;
/// - the XDP program cannot be loaded;
/// - a socket cannot be created in either mode;
/// - registration fails;
/// - a worker thread cannot be spawned.
///
/// Workers already started for earlier queues are not stopped when a later queue fails.
pub fn start_xdp(
    iface: &str,
    zones: Arc<ArcSwap<LocalZoneSet>>,
    rate_limiter: Arc<RateLimiter>,
    acl: Arc<Acl>,
) -> Result<XdpHandle, String> {
    let ifidx = iface_index(iface).ok_or_else(|| format!("interface {iface} not found"))?;
    let mut handle = XdpHandle::load(iface)?;
    let queue_count = get_rx_queue_count(iface).max(1);
    tracing::info!(iface = %iface, queues = queue_count, "Starting XDP workers");

    for q in 0..queue_count {
        // SAFETY: NOTE(review) — the contract of `create_xsk_socket` lives in
        // `socket.rs`. `ifidx` comes from `iface_index` and `q` from the kernel's
        // queue count; confirm that is everything the function requires.
        let sock = match unsafe { create_xsk_socket(ifidx, q, true) } {
            Ok(sock) => sock,
            Err(first_err) => {
                tracing::warn!(
                    iface = %iface,
                    queue = q,
                    error = ?first_err,
                    "AF_XDP socket creation failed; retrying in copy mode"
                );
                // SAFETY: same arguments and contract as the first attempt.
                // A failure here is reported to the caller instead of panicking:
                // this function already returns `Result`.
                unsafe { create_xsk_socket(ifidx, q, false) }.map_err(|e| {
                    format!(
                        "AF_XDP socket creation failed on {iface} queue {q} even in copy mode: {e:?}"
                    )
                })?
            }
        };
        handle.register_socket(q, sock.fd)?;

        let worker_zones = Arc::clone(&zones);
        let worker_rl = Arc::clone(&rate_limiter);
        let worker_acl = Arc::clone(&acl);
        // The JoinHandle is dropped on purpose: workers run detached.
        std::thread::Builder::new()
            .name(format!("xdp-{iface}-q{q}"))
            .spawn(move || xdp_worker(sock, worker_zones, worker_rl, worker_acl))
            .map_err(|e| format!("thread spawn: {e}"))?;
    }
    Ok(handle)
}
fn xdp_worker(
mut sock: XskSocket,
zones: Arc<ArcSwap<LocalZoneSet>>,
rate_limiter: Arc<RateLimiter>,
acl: Arc<Acl>,
) {
use libc::{poll, pollfd, POLLIN};
loop {
sock.umem.reclaim_tx();
let mut pfd = pollfd { fd: sock.fd, events: POLLIN, revents: 0 };
let ret = unsafe { poll(&mut pfd, 1, 1 ) };
if ret < 0 {
break;
}
let rxds = sock.rx.consume_rx();
if rxds.is_empty() {
continue;
}
let snapshot = zones.load();
let mut tx_descs: Vec<XdpDesc> = Vec::with_capacity(rxds.len());
let mut rx_addrs: Vec<u64> = Vec::with_capacity(rxds.len());
for desc in &rxds {
rx_addrs.push(desc.addr);
if let Some(tx_addr) = sock.umem.tx_free.pop_front() {
let (rx_frame, tx_frame) = unsafe {
let rx = std::slice::from_raw_parts(
sock.umem.area.add(desc.addr as usize),
desc.len as usize,
);
let tx = std::slice::from_raw_parts_mut(
sock.umem.area.add(tx_addr as usize),
FRAME_SIZE as usize,
);
(rx, tx)
};
let src_ip = extract_src_ip(rx_frame);
if src_ip.map(|ip| !rate_limiter.check(ip)).unwrap_or(false) {
sock.umem.tx_free.push_back(tx_addr);
continue;
}
match process_packet(rx_frame, tx_frame, &snapshot, &acl, src_ip) {
Some(tx_len) => tx_descs.push(XdpDesc {
addr: tx_addr,
len: tx_len as u32,
options: 0,
}),
None => {
sock.umem.tx_free.push_back(tx_addr);
}
}
}
}
sock.umem.fill.enqueue_batch(&rx_addrs);
if !tx_descs.is_empty() {
sock.tx.enqueue_tx(&tx_descs);
if sock.tx.needs_wakeup() {
unsafe {
libc::sendto(
sock.fd,
std::ptr::null(),
0,
libc::MSG_DONTWAIT,
std::ptr::null(),
0,
);
}
}
}
}
}
#[inline]
fn extract_src_ip(rx: &[u8]) -> Option<IpAddr> {
if rx.len() < ETH_HDR { return None; }
let ethertype = u16::from_be_bytes([rx[12], rx[13]]);
match ethertype {
ETH_P_IP => {
if rx.len() < ETH_HDR + 20 { return None; }
let src: [u8; 4] = rx[ETH_HDR + 12..ETH_HDR + 16].try_into().ok()?;
Some(IpAddr::V4(std::net::Ipv4Addr::from(src)))
}
ETH_P_IPV6 => {
if rx.len() < ETH_HDR + 40 { return None; }
let src: [u8; 16] = rx[ETH_HDR + 8..ETH_HDR + 24].try_into().ok()?;
Some(IpAddr::V6(std::net::Ipv6Addr::from(src)))
}
_ => None,
}
}
/// Builds, in `tx`, a DNS reply frame for the request frame `rx`, when one is warranted.
///
/// `rx` is expected to be an untagged Ethernet II frame carrying IPv4 or IPv6
/// (no extension headers) with a UDP datagram to port 53. The reply mirrors the
/// request:
/// - MAC addresses, IP addresses and UDP ports are swapped;
/// - the rest of the IP header, including any IPv4 options, is copied;
/// - the length fields, TTL / hop limit and checksums are rewritten.
///
/// Returns the reply frame length. Returns `None` when the frame must be dropped:
/// - unsupported or malformed headers;
/// - an IPv4 fragment;
/// - a UDP length that does not fit in the frame;
/// - no answer from `answer_dns`;
/// - a reply larger than `tx`.
fn process_packet(
    rx: &[u8],
    tx: &mut [u8],
    zones: &LocalZoneSet,
    acl: &Acl,
    src_ip: Option<IpAddr>,
) -> Option<usize> {
    // Fresh TTL / hop limit for replies. Echoing the request's value, which every
    // router on the way in has already decremented, can leave the reply too few
    // hops to get back to the client.
    const REPLY_TTL: u8 = 64;

    // ---- Parse: Ethernet ----
    if rx.len() < ETH_HDR {
        return None;
    }
    let ip_off = ETH_HDR;
    let is_v6 = match u16::from_be_bytes([rx[12], rx[13]]) {
        ETH_P_IP => false,
        ETH_P_IPV6 => true,
        _ => return None,
    };

    // ---- Parse: IP ----
    // All offsets are from the start of the frame and are identical in `rx` and `tx`.
    let (ip_hdr_len, addr_len, src_ip_off, dst_ip_off, ip_len_off, ttl_off) = if is_v6 {
        if rx.len() < ip_off + IPV6_HDR || rx[ip_off + 6] != PROTO_UDP {
            return None;
        }
        (IPV6_HDR, 16, ip_off + 8, ip_off + 24, ip_off + 4, ip_off + 7)
    } else {
        if rx.len() < ip_off + IPV4_HDR_MIN || rx[ip_off + 9] != PROTO_UDP {
            return None;
        }
        // IHL is a 4-bit count of 32-bit words, so it can never exceed 60 bytes.
        let ihl = usize::from(rx[ip_off] & 0x0F) * 4;
        if ihl < IPV4_HDR_MIN {
            return None;
        }
        // Drop fragments (MF set or non-zero fragment offset). A non-first
        // fragment has no UDP header at `udp_off`, and a first fragment does not
        // hold the whole query.
        let frag = u16::from_be_bytes([rx[ip_off + 6], rx[ip_off + 7]]);
        if frag & 0x3FFF != 0 {
            return None;
        }
        (ihl, 4, ip_off + 12, ip_off + 16, ip_off + 2, ip_off + 8)
    };

    // ---- Parse: UDP ----
    let udp_off = ip_off + ip_hdr_len;
    if rx.len() < udp_off + UDP_HDR {
        return None;
    }
    let src_port = [rx[udp_off], rx[udp_off + 1]];
    let dst_port = [rx[udp_off + 2], rx[udp_off + 3]];
    if u16::from_be_bytes(dst_port) != 53 {
        return None;
    }
    // Bound the payload by the UDP length, not the frame length. Short frames
    // carry Ethernet padding after the datagram, and it must not reach the DNS parser.
    let udp_len_in = usize::from(u16::from_be_bytes([rx[udp_off + 4], rx[udp_off + 5]]));
    if udp_len_in < UDP_HDR || udp_off + udp_len_in > rx.len() {
        return None;
    }
    let dns_off = udp_off + UDP_HDR;
    let dns_in = &rx[dns_off..udp_off + udp_len_in];
    if dns_in.is_empty() {
        return None;
    }

    // ---- Answer ----
    let mut dns_out: Vec<u8> = Vec::with_capacity(512);
    if !answer_dns(dns_in, zones, acl, src_ip, &mut dns_out) {
        return None;
    }
    let udp_len = UDP_HDR + dns_out.len();
    let reply_len = udp_off + udp_len;
    if reply_len > tx.len() {
        return None;
    }
    let udp_len_be = u16::try_from(udp_len).ok()?.to_be_bytes();
    // IPv4 Total Length includes the IP header; IPv6 Payload Length does not.
    let ip_len = if is_v6 { udp_len } else { ip_hdr_len + udp_len };
    let ip_len_be = u16::try_from(ip_len).ok()?.to_be_bytes();

    // ---- Build: Ethernet (swap MACs, keep EtherType) ----
    tx[0..6].copy_from_slice(&rx[6..12]);
    tx[6..12].copy_from_slice(&rx[0..6]);
    tx[12..14].copy_from_slice(&rx[12..14]);

    // ---- Build: IP header (copy, then patch length, addresses and TTL) ----
    tx[ip_off..udp_off].copy_from_slice(&rx[ip_off..udp_off]);
    tx[ip_len_off..ip_len_off + 2].copy_from_slice(&ip_len_be);
    tx[src_ip_off..src_ip_off + addr_len].copy_from_slice(&rx[dst_ip_off..dst_ip_off + addr_len]);
    tx[dst_ip_off..dst_ip_off + addr_len].copy_from_slice(&rx[src_ip_off..src_ip_off + addr_len]);
    tx[ttl_off] = REPLY_TTL;

    // ---- Build: UDP header (ports swapped, checksum zeroed for now) and payload ----
    tx[udp_off..udp_off + 2].copy_from_slice(&dst_port);
    tx[udp_off + 2..udp_off + 4].copy_from_slice(&src_port);
    tx[udp_off + 4..udp_off + 6].copy_from_slice(&udp_len_be);
    tx[udp_off + 6..udp_off + 8].fill(0);
    // The payload must already be in `tx` when the UDP checksum is computed over it.
    tx[dns_off..reply_len].copy_from_slice(&dns_out);

    // ---- Checksums ----
    let udp_cksum = if is_v6 {
        let src: [u8; 16] = tx[src_ip_off..src_ip_off + 16].try_into().ok()?;
        let dst: [u8; 16] = tx[dst_ip_off..dst_ip_off + 16].try_into().ok()?;
        udp_checksum_v6(&src, &dst, &tx[udp_off..reply_len])
    } else {
        let src: [u8; 4] = tx[src_ip_off..src_ip_off + 4].try_into().ok()?;
        let dst: [u8; 4] = tx[dst_ip_off..dst_ip_off + 4].try_into().ok()?;
        udp_checksum_v4(&src, &dst, &tx[udp_off..reply_len])
    };
    tx[udp_off + 6..udp_off + 8].copy_from_slice(&udp_cksum.to_be_bytes());

    if !is_v6 {
        // The header checksum covers TTL and length, so compute it after all other edits.
        let cksum_off = ip_off + 10;
        tx[cksum_off..cksum_off + 2].fill(0);
        let ip_cksum = ipv4_checksum(&tx[ip_off..udp_off]);
        tx[cksum_off..cksum_off + 2].copy_from_slice(&ip_cksum.to_be_bytes());
    }

    Some(reply_len)
}
fn answer_dns(
query_bytes: &[u8],
zones: &LocalZoneSet,
acl: &Acl,
src_ip: Option<IpAddr>,
out: &mut Vec<u8>,
) -> bool {
let msg = match Message::from_bytes(query_bytes) {
Ok(m) => m,
Err(_) => return false,
};
if msg.message_type() != MessageType::Query { return false; }
if msg.op_code() != OpCode::Query { return false; }
let q = match msg.queries().first() {
Some(q) => q,
None => return false,
};
if let Some(ip) = src_ip {
match acl.check(ip) {
AclAction::Allow => {}
AclAction::Deny => return false, AclAction::Refuse => {
let mut refused = Message::new();
refused.set_id(msg.id());
refused.set_message_type(MessageType::Response);
refused.set_op_code(OpCode::Query);
refused.set_response_code(ResponseCode::Refused);
refused.set_recursion_desired(msg.recursion_desired());
refused.add_query(q.clone());
let mut enc = BinEncoder::new(out);
return refused.emit(&mut enc).is_ok();
}
}
}
let name = q.name();
let rtype = q.query_type();
if rtype == hickory_proto::rr::RecordType::ANY { return false; }
let mut resp = Message::new();
resp.set_id(msg.id());
resp.set_message_type(MessageType::Response);
resp.set_op_code(OpCode::Query);
resp.set_recursion_desired(msg.recursion_desired());
resp.set_recursion_available(false);
resp.add_query(q.clone());
match zones.find(name) {
Some(ZoneAction::Refuse) => {
resp.set_response_code(ResponseCode::Refused);
resp.set_authoritative(false);
}
Some(ZoneAction::NxDomain) => {
resp.set_response_code(ResponseCode::NXDomain);
resp.set_authoritative(true);
}
Some(ZoneAction::Static) | Some(ZoneAction::Redirect) => {
resp.set_authoritative(true);
let records = zones.local_records(name, rtype);
if !records.is_empty() {
resp.set_response_code(ResponseCode::NoError);
for r in records {
resp.add_answer(r.clone());
}
} else if zones.name_has_records(name) {
resp.set_response_code(ResponseCode::NoError);
} else {
resp.set_response_code(ResponseCode::NXDomain);
}
}
None => return false,
}
let mut enc = BinEncoder::new(out);
resp.emit(&mut enc).is_ok()
}
/// Sums `data` as big-endian 16-bit words into an unfolded 32-bit accumulator.
///
/// If the length is odd, the last byte is treated as the high byte of a word
/// padded with a zero low byte. Folding and complementing are left to
/// `fold_checksum`.
fn ones_complement_sum(data: &[u8]) -> u32 {
    let words = data.chunks_exact(2);
    let tail = match words.remainder() {
        [last] => u32::from(*last) << 8,
        _ => 0,
    };
    let body: u32 = words
        .map(|pair| u32::from(u16::from_be_bytes([pair[0], pair[1]])))
        .sum();
    body + tail
}
/// Folds a 32-bit one's-complement accumulator to 16 bits and complements it.
///
/// A result of `0x0000` is returned as `0xFFFF`, the other one's-complement
/// encoding of zero, so a computed checksum is never the all-zero value.
fn fold_checksum(s: u32) -> u16 {
    let mut acc = s;
    // Add the carries back in until the value fits in 16 bits.
    while acc > 0xFFFF {
        acc = (acc >> 16) + (acc & 0xFFFF);
    }
    match !(acc as u16) {
        0 => 0xFFFF,
        checksum => checksum,
    }
}
/// Computes the IPv4 header checksum over `header` (IHL * 4 bytes).
///
/// Expects the checksum field (bytes 10..12) to be zero already.
fn ipv4_checksum(header: &[u8]) -> u16 {
    let raw_sum = ones_complement_sum(header);
    fold_checksum(raw_sum)
}
/// Computes the UDP checksum of an IPv4 datagram.
///
/// `udp` is the UDP header plus payload, with the checksum field expected to be
/// zero. The IPv4 pseudo-header is added here: source and destination
/// addresses, a zero byte plus the protocol number, and the UDP length.
fn udp_checksum_v4(src: &[u8; 4], dst: &[u8; 4], udp: &[u8]) -> u16 {
    let pseudo_header = ones_complement_sum(src)
        + ones_complement_sum(dst)
        + u32::from(PROTO_UDP)
        + udp.len() as u32;
    fold_checksum(pseudo_header + ones_complement_sum(udp))
}
/// Computes the UDP checksum of an IPv6 datagram.
///
/// `udp` is the UDP header plus payload, with the checksum field expected to be
/// zero. The IPv6 pseudo-header is added here: source and destination
/// addresses, the 32-bit upper-layer length, and the next-header value.
fn udp_checksum_v6(src: &[u8; 16], dst: &[u8; 16], udp: &[u8]) -> u16 {
    // The IPv6 pseudo-header carries the length as a 32-bit big-endian field.
    let length_field = (udp.len() as u32).to_be_bytes();
    let pseudo_header = ones_complement_sum(src)
        + ones_complement_sum(dst)
        + ones_complement_sum(&length_field)
        + u32::from(PROTO_UDP);
    fold_checksum(pseudo_header + ones_complement_sum(udp))
}