use std::net::IpAddr;
use std::net::Ipv4Addr;
use bytes::Bytes;
use crate::firewall::PacketInfo;
/// Length in bytes of the IPv4 header written by `build_udp_reply` (no options, IHL = 5).
const IPV4_HEADER_LEN: usize = 20;
/// Length in bytes of a UDP header (source port, destination port, length, checksum).
const UDP_HEADER_LEN: usize = 8;
/// Largest total packet size, in bytes, that `build_udp_reply` will produce;
/// larger replies are rejected with `None`.
const MTU: usize = 1280;
/// Builds a complete IPv4/UDP packet that answers `query` with `dns_payload`.
///
/// The reply flows in the opposite direction of the query: its source
/// address/port are the query's destination, and its destination address/port
/// are the query's source. The IPv4 header is `IPV4_HEADER_LEN` bytes with no
/// options, TTL 64, protocol 17 (UDP) and zero TOS/identification/flags/fragment
/// offset. Both the IPv4 header checksum and the UDP checksum are filled in.
///
/// Returns `None` when either address in `query` is not IPv4, or when the whole
/// packet would be larger than `MTU` bytes.
pub fn build_udp_reply(query: &PacketInfo, dns_payload: &[u8]) -> Option<Bytes> {
    let (app_ip, magic_ip) = match (query.src_ip, query.dst_ip) {
        (IpAddr::V4(src), IpAddr::V4(dst)) => (src, dst),
        _ => return None,
    };

    let udp_len = UDP_HEADER_LEN + dns_payload.len();
    let total = IPV4_HEADER_LEN + udp_len;
    if total > MTU {
        return None;
    }

    let mut packet = vec![0u8; total];
    let (ip_hdr, udp) = packet.split_at_mut(IPV4_HEADER_LEN);

    // IPv4 header. Fields not written here stay zero from the allocation,
    // including the checksum field while the header sum is computed.
    ip_hdr[0] = 0x45; // version 4, header length 5 x 32-bit words
    ip_hdr[2..4].copy_from_slice(&(total as u16).to_be_bytes()); // total <= MTU, fits in u16
    ip_hdr[8] = 64; // TTL
    ip_hdr[9] = 17; // protocol: UDP
    ip_hdr[12..16].copy_from_slice(&magic_ip.octets());
    ip_hdr[16..20].copy_from_slice(&app_ip.octets());
    let ip_csum = ones_complement_sum(ip_hdr);
    ip_hdr[10..12].copy_from_slice(&ip_csum.to_be_bytes());

    // UDP header and payload, with ports swapped relative to the query.
    udp[0..2].copy_from_slice(&query.dst_port.to_be_bytes());
    udp[2..4].copy_from_slice(&query.src_port.to_be_bytes());
    udp[4..6].copy_from_slice(&(udp_len as u16).to_be_bytes());
    udp[UDP_HEADER_LEN..].copy_from_slice(dns_payload);
    // A transmitted UDP checksum of zero means "no checksum" (RFC 768), so a
    // computed zero is sent as 0xffff instead.
    let udp_csum = match udp_checksum(&magic_ip, &app_ip, udp) {
        0 => 0xffff,
        other => other,
    };
    udp[6..8].copy_from_slice(&udp_csum.to_be_bytes());

    Some(Bytes::from(packet))
}
/// Computes the Internet checksum (RFC 1071) of `bytes`: the one's complement
/// of the one's-complement sum of its big-endian 16-bit words.
///
/// An odd trailing byte is treated as the high byte of a final word whose low
/// byte is zero. An empty slice yields `0xffff`.
fn ones_complement_sum(bytes: &[u8]) -> u16 {
    let words = bytes.chunks_exact(2);
    let tail = words.remainder();
    let mut acc: u32 = words
        .map(|w| u32::from(u16::from_be_bytes([w[0], w[1]])))
        .sum();
    if let [last] = tail {
        acc += u32::from(*last) << 8;
    }
    // Fold carries back into the low 16 bits (end-around carry).
    while acc > 0xffff {
        acc = (acc & 0xffff) + (acc >> 16);
    }
    !(acc as u16)
}
/// Computes the UDP checksum (RFC 768) of `udp_segment` carried in an IPv4
/// datagram from `src` to `dst`.
///
/// The one's-complement sum covers the IPv4 pseudo-header first: the source
/// address, the destination address, a zero byte followed by protocol 17, and
/// the UDP length (taken as `udp_segment.len()`). It then covers every byte of
/// `udp_segment`, padding an odd trailing byte with a zero low byte.
///
/// The segment's own checksum field is included in the sum, so it should be
/// zero when computing a value to store. A computed result of `0` is returned
/// unchanged; substituting `0xffff` is left to the caller.
fn udp_checksum(src: &Ipv4Addr, dst: &Ipv4Addr, udp_segment: &[u8]) -> u16 {
    let be_word = |hi: u8, lo: u8| u32::from(u16::from_be_bytes([hi, lo]));

    // Pseudo-header.
    let [s0, s1, s2, s3] = src.octets();
    let [d0, d1, d2, d3] = dst.octets();
    let mut acc = be_word(s0, s1) + be_word(s2, s3) + be_word(d0, d1) + be_word(d2, d3);
    acc += 17; // zero byte + IP protocol number for UDP
    acc += udp_segment.len() as u32;

    // UDP header and payload.
    let mut pairs = udp_segment.chunks_exact(2);
    for pair in pairs.by_ref() {
        acc += be_word(pair[0], pair[1]);
    }
    if let [odd] = pairs.remainder() {
        acc += be_word(*odd, 0);
    }

    // Fold carries back into the low 16 bits (end-around carry).
    while acc > 0xffff {
        acc = (acc & 0xffff) + (acc >> 16);
    }
    !(acc as u16)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    const APP_IP: Ipv4Addr = Ipv4Addr::new(100, 64, 0, 5);
    const MAGIC_IP: Ipv4Addr = Ipv4Addr::new(100, 100, 100, 53);
    const DNS_ANSWER: &[u8] = b"\x12\x34\x81\x80\x00\x00\x00\x00\x00\x00\x00\x00";

    /// A UDP query from `APP_IP:51000` to `MAGIC_IP:53`.
    fn udp_query() -> crate::firewall::PacketInfo {
        crate::firewall::PacketInfo {
            src_ip: IpAddr::V4(APP_IP),
            dst_ip: IpAddr::V4(MAGIC_IP),
            protocol: 17,
            src_port: 51000,
            dst_port: 53,
            tcp_flags: 0,
            icmp_type: 0,
            icmp_id: 0,
        }
    }

    /// Folded one's-complement sum of `data` as big-endian 16-bit words, with
    /// an odd trailing byte zero-padded.
    ///
    /// It is written independently of the helpers under test so it can
    /// cross-check them. A region that carries a correct checksum sums to
    /// `0xffff`.
    fn folded_sum(data: &[u8]) -> u16 {
        let mut sum: u32 = data
            .chunks(2)
            .map(|c| u32::from(u16::from_be_bytes([c[0], *c.get(1).unwrap_or(&0)])))
            .sum();
        while sum >> 16 != 0 {
            sum = (sum & 0xffff) + (sum >> 16);
        }
        sum as u16
    }

    /// True when an IPv4 header's stored checksum is valid.
    fn ipv4_checksum_ok(hdr: &[u8]) -> bool {
        folded_sum(hdr) == 0xffff
    }

    /// True when a UDP segment's stored checksum is valid for its IPv4
    /// pseudo-header (src, dst, zero byte, protocol 17, UDP length).
    fn udp_checksum_ok(src: Ipv4Addr, dst: Ipv4Addr, segment: &[u8]) -> bool {
        let mut buf = Vec::with_capacity(12 + segment.len());
        buf.extend_from_slice(&src.octets());
        buf.extend_from_slice(&dst.octets());
        buf.extend_from_slice(&[0, 17]);
        buf.extend_from_slice(&(segment.len() as u16).to_be_bytes());
        buf.extend_from_slice(segment);
        folded_sum(&buf) == 0xffff
    }

    #[test]
    fn build_udp_reply_swaps_and_checksums() {
        let query = udp_query();
        let pkt = build_udp_reply(&query, DNS_ANSWER).expect("v4 reply");
        assert_eq!(pkt.len(), 20 + 8 + DNS_ANSWER.len());

        let info = crate::firewall::parse_packet_info(&pkt).expect("parses");
        assert_eq!(info.src_ip, query.dst_ip);
        assert_eq!(info.dst_ip, query.src_ip);
        assert_eq!(info.src_port, 53);
        assert_eq!(info.dst_port, 51000);

        assert_eq!(pkt[9], 17, "IPv4 protocol field must be UDP");
        assert!(ipv4_checksum_ok(&pkt[..20]));

        let udp = &pkt[20..];
        assert_eq!(usize::from(u16::from_be_bytes([udp[4], udp[5]])), udp.len());
        assert!(udp_checksum_ok(MAGIC_IP, APP_IP, udp));
        assert_eq!(&pkt[28..], DNS_ANSWER);
    }

    #[test]
    fn build_udp_reply_rejects_oversize() {
        let query = udp_query();
        // 20-byte IPv4 header + 8-byte UDP header + payload must fit in 1280 bytes.
        let max_payload = 1280 - 20 - 8;
        let payload = vec![0u8; max_payload + 1];

        let fits = build_udp_reply(&query, &payload[..max_payload]).expect("exactly MTU");
        assert_eq!(fits.len(), 1280);
        assert!(build_udp_reply(&query, &payload).is_none());

        let big = vec![0u8; 1300];
        assert!(build_udp_reply(&query, &big).is_none());
    }

    #[test]
    fn build_udp_reply_rejects_non_ipv4() {
        let query = crate::firewall::PacketInfo {
            src_ip: IpAddr::V6(Ipv6Addr::LOCALHOST),
            ..udp_query()
        };
        assert!(build_udp_reply(&query, DNS_ANSWER).is_none());
    }
}