use anyhow::{ensure, Error};
use byteorder::{BigEndian, ByteOrder};
use std::net::IpAddr;
const EDNS_CLIENT_SUBNET: u16 = 8;
const FAMILY_IPV4: u16 = 1;
const FAMILY_IPV6: u16 = 2;
/// Determines the client address for a request.
///
/// Candidates are tried in this order, and the first one that parses as an IP address wins:
///
/// 1. the first comma-separated entry of `x-forwarded-for` (surrounding whitespace trimmed),
/// 2. the `x-real-ip` header value, parsed as-is,
/// 3. the IP of `remote_addr`.
///
/// A header that is missing, is not valid visible ASCII, or does not parse is skipped rather
/// than treated as an error. Returns `None` when no candidate yields an address.
///
/// The headers are taken at face value: nothing in this function checks that they were set
/// by a trusted proxy.
pub fn extract_client_ip(
    headers: &hyper::HeaderMap,
    remote_addr: Option<std::net::SocketAddr>,
) -> Option<IpAddr> {
    let forwarded_for = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .and_then(|entry| entry.trim().parse::<IpAddr>().ok());

    // Evaluated lazily so the header is only inspected when the previous source failed.
    let real_ip = || {
        headers
            .get("x-real-ip")
            .and_then(|value| value.to_str().ok())
            .and_then(|text| text.parse::<IpAddr>().ok())
    };

    forwarded_for
        .or_else(real_ip)
        .or_else(|| remote_addr.map(|peer| peer.ip()))
}
/// Builds the payload of an EDNS Client Subnet option (without the option code/length header).
///
/// Layout: FAMILY (2 bytes, big-endian), SOURCE PREFIX-LENGTH (1 byte), SCOPE PREFIX-LENGTH
/// (1 byte, always 0 here), followed by the leading `ceil(prefix / 8)` octets of the address.
///
/// `prefix_v4` is used when `client_ip` is IPv4 and `prefix_v6` when it is IPv6; the other
/// value is ignored. A prefix larger than the address width (32 or 128 bits) is clamped to
/// that width so the advertised prefix never exceeds the address bytes that follow it.
///
/// Address bits beyond the prefix are zeroed, as RFC 7871 requires the address to be
/// truncated to the prefix length and padded with zero bits up to the end of the last octet.
pub fn build_ecs_option(client_ip: IpAddr, prefix_v4: u8, prefix_v6: u8) -> Vec<u8> {
    match client_ip {
        IpAddr::V4(addr) => encode_ecs_payload(FAMILY_IPV4, prefix_v4.min(32), &addr.octets()),
        IpAddr::V6(addr) => encode_ecs_payload(FAMILY_IPV6, prefix_v6.min(128), &addr.octets()),
    }
}

/// Serializes one ECS payload from an address family, an already-clamped prefix length and
/// the full address octets.
fn encode_ecs_payload(family: u16, prefix_len: u8, octets: &[u8]) -> Vec<u8> {
    // `min` keeps the slice below in bounds even if a caller passes an unclamped prefix.
    let byte_count = usize::from(prefix_len.div_ceil(8)).min(octets.len());

    let mut payload = Vec::with_capacity(4 + byte_count);
    payload.extend_from_slice(&family.to_be_bytes());
    payload.push(prefix_len);
    payload.push(0); // scope prefix-length: zero in queries
    payload.extend_from_slice(&octets[..byte_count]);

    // When the prefix ends inside an octet, clear the host bits of that last octet.
    let used_bits = prefix_len % 8;
    if used_bits != 0 && byte_count > 0 {
        if let Some(last) = payload.last_mut() {
            *last &= 0xFFu8 << (8 - used_bits);
        }
    }
    payload
}
/// Inserts an EDNS Client Subnet option for `client_ip` into the OPT record of a DNS packet.
///
/// The option is appended to the end of the existing OPT RDATA and the record's RDLENGTH is
/// increased accordingly; any bytes following the OPT record are preserved after it. If the
/// OPT record already carries a Client Subnet option, the packet is left untouched and
/// `Ok(())` is returned.
///
/// `prefix_v4` / `prefix_v6` are forwarded to [`build_ecs_option`].
///
/// # Errors
///
/// Returns an error when:
/// - the packet is shorter than 12 bytes or longer than 4096 bytes,
/// - the question count is not exactly 1, or the question section is truncated,
/// - the additional section is empty or holds no OPT record with a root owner name,
/// - a record before the OPT record, or the OPT RDATA itself, extends past the packet end,
/// - a `crate::dns` helper fails while walking the packet.
///
/// The packet is not modified on any error path.
pub fn add_ecs_to_packet(
    packet: &mut Vec<u8>,
    client_ip: IpAddr,
    prefix_v4: u8,
    prefix_v6: u8,
) -> Result<(), Error> {
    use crate::dns;

    let packet_len = packet.len();
    ensure!(packet_len >= 12, "DNS packet too short");
    ensure!(packet_len <= 4096, "Packet too large");
    ensure!(dns::qdcount(packet) == 1, "No question");

    // Step over the single question: its name, then QTYPE + QCLASS (4 bytes).
    // Written as an addition so an offset past the end cannot underflow the comparison.
    let mut offset = dns::skip_name(packet, 12)?;
    ensure!(offset + 4 <= packet_len, "Short packet");
    offset += 4;

    let ancount = dns::ancount(packet);
    let nscount = BigEndian::read_u16(&packet[8..10]);
    let arcount = dns::arcount(packet);
    ensure!(arcount > 0, "No EDNS OPT record found");

    // Step over the answer and authority records to reach the additional section.
    offset = dns::traverse_rrs(packet, offset, ancount as usize + nscount as usize, |_| {
        Ok(())
    })?;

    // Look for a record with a root owner name (single zero byte) and type OPT.
    // Fixed part of such a record: name(1) + type(2) + class(2) + ttl(4) + rdlength(2) = 11.
    let mut opt_rdlength_offset = None;
    for _ in 0..arcount {
        let has_root_owner = offset + 11 <= packet_len && packet[offset] == 0;
        if has_root_owner
            && BigEndian::read_u16(&packet[offset + 1..offset + 3]) == dns::DNS_TYPE_OPT
        {
            opt_rdlength_offset = Some(offset + 9);
            break;
        }

        // Not the OPT record: skip name, fixed fields (10 bytes) and RDATA.
        offset = dns::skip_name(packet, offset)?;
        ensure!(offset + 10 <= packet_len, "Incomplete RR");
        let rdlen = BigEndian::read_u16(&packet[offset + 8..offset + 10]) as usize;
        offset += 10 + rdlen;
        ensure!(offset <= packet_len, "Incomplete RR");
    }
    let rdlength_offset =
        opt_rdlength_offset.ok_or_else(|| anyhow::anyhow!("No EDNS OPT record found"))?;

    let rdlength = BigEndian::read_u16(&packet[rdlength_offset..rdlength_offset + 2]) as usize;
    let rdata_start = rdlength_offset + 2;
    let rdata_end = rdata_start + rdlength;
    // Without this check a bogus RDLENGTH would be bumped and the option appended at the
    // packet end, yielding a corrupted packet instead of an error.
    ensure!(rdata_end <= packet_len, "OPT record RDATA exceeds packet");

    // Walk the existing options (code: u16, length: u16, data); never override a Client
    // Subnet option that is already present. Trailing bytes that do not form a complete
    // option inside the RDATA end the scan.
    let mut cursor = rdata_start;
    while cursor + 4 <= rdata_end {
        let opt_code = BigEndian::read_u16(&packet[cursor..cursor + 2]);
        let opt_len = BigEndian::read_u16(&packet[cursor + 2..cursor + 4]) as usize;
        if cursor + 4 + opt_len > rdata_end {
            break;
        }
        if opt_code == EDNS_CLIENT_SUBNET {
            return Ok(());
        }
        cursor += 4 + opt_len;
    }

    let ecs_data = build_ecs_option(client_ip, prefix_v4, prefix_v6);
    let mut ecs_option = Vec::with_capacity(4 + ecs_data.len());
    ecs_option.extend_from_slice(&EDNS_CLIENT_SUBNET.to_be_bytes());
    // The payload is at most 4 + 16 bytes, so the cast cannot truncate.
    ecs_option.extend_from_slice(&(ecs_data.len() as u16).to_be_bytes());
    ecs_option.extend_from_slice(&ecs_data);

    // rdata_end <= packet_len <= 4096 was checked above, so the sum fits in a u16.
    let new_rdlength = (rdlength + ecs_option.len()) as u16;
    BigEndian::write_u16(
        &mut packet[rdlength_offset..rdlength_offset + 2],
        new_rdlength,
    );

    // Insert the option right after the existing RDATA, keeping whatever follows the record.
    packet.splice(rdata_end..rdata_end, ecs_option);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Wire bytes of a query for `example.com` (type 1, class 1): a 12-byte header with one
    /// question and no other records, followed by that question.
    fn example_query() -> Vec<u8> {
        let mut query = vec![
            0x00, 0x00, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        query.push(0x07);
        query.extend_from_slice(b"example");
        query.push(0x03);
        query.extend_from_slice(b"com");
        query.extend_from_slice(&[0x00, 0x00, 0x01, 0x00, 0x01]);
        query
    }

    /// Prefixes an ECS payload with the option code and its length.
    fn frame_ecs_option(payload: &[u8]) -> Vec<u8> {
        let mut framed = Vec::with_capacity(4 + payload.len());
        framed.extend_from_slice(&EDNS_CLIENT_SUBNET.to_be_bytes());
        framed.extend_from_slice(&(payload.len() as u16).to_be_bytes());
        framed.extend_from_slice(payload);
        framed
    }

    #[test]
    fn test_extract_client_ip() {
        use std::net::SocketAddr;

        let mut headers = hyper::HeaderMap::new();

        // First entry of x-forwarded-for is used.
        headers.insert("x-forwarded-for", "192.168.1.1, 10.0.0.1".parse().unwrap());
        let from_xff = extract_client_ip(&headers, None);
        assert_eq!(from_xff, Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));

        // x-real-ip is used when x-forwarded-for is absent.
        headers.clear();
        headers.insert("x-real-ip", "10.0.0.2".parse().unwrap());
        let from_real_ip = extract_client_ip(&headers, None);
        assert_eq!(from_real_ip, Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2))));

        // With no headers the socket address is used.
        headers.clear();
        let remote: SocketAddr = "203.0.113.45:12345".parse().unwrap();
        let from_socket = extract_client_ip(&headers, Some(remote));
        assert_eq!(
            from_socket,
            Some(IpAddr::V4(Ipv4Addr::new(203, 0, 113, 45)))
        );

        // A header takes precedence over the socket address.
        headers.insert("x-forwarded-for", "192.168.1.5".parse().unwrap());
        let header_wins = extract_client_ip(&headers, Some(remote));
        assert_eq!(
            header_wins,
            Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 5)))
        );
    }

    #[test]
    fn test_build_ecs_option() {
        let v4 = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));
        let v4_payload = build_ecs_option(v4, 24, 56);
        assert_eq!(v4_payload[0..2], [0, 1]);
        assert_eq!(v4_payload[2], 24);
        assert_eq!(v4_payload[3], 0);
        assert_eq!(v4_payload[4..7], [192, 168, 1]);

        let v6 = IpAddr::V6(Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1));
        let v6_payload = build_ecs_option(v6, 24, 56);
        assert_eq!(v6_payload[0..2], [0, 2]);
        assert_eq!(v6_payload[2], 56);
        assert_eq!(v6_payload[3], 0);
        assert_eq!(v6_payload.len(), 4 + 7);
    }

    #[test]
    fn test_add_ecs_to_packet() {
        use crate::dns;

        let mut packet = example_query();
        dns::set_edns_max_payload_size(&mut packet, 4096).unwrap();
        let original_len = packet.len();

        let client_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));
        let result = add_ecs_to_packet(&mut packet, client_ip, 24, 56);
        assert!(
            result.is_ok(),
            "Failed to add ECS to packet: {:?}",
            result.err()
        );

        // The additional-record count in the header stays at one.
        assert_eq!(packet[10..12], [0x00, 0x01]);
        assert!(
            packet.len() > original_len,
            "Packet should be longer after adding ECS"
        );
    }

    #[test]
    fn test_ecs_not_overwritten_if_exists() {
        use crate::dns;

        let mut packet = example_query();
        dns::set_edns_max_payload_size(&mut packet, 4096).unwrap();

        // Hand-append a client-supplied ECS option; the last two bytes of the packet are
        // treated as the OPT record's RDLENGTH field.
        let client_provided_ecs =
            build_ecs_option(IpAddr::V4(Ipv4Addr::new(10, 20, 30, 40)), 24, 56);
        let ecs_option = frame_ecs_option(&client_provided_ecs);
        let opt_rdlength_offset = packet.len() - 2;
        let rdlength = BigEndian::read_u16(&packet[opt_rdlength_offset..]) as usize;
        BigEndian::write_u16(
            &mut packet[opt_rdlength_offset..],
            (rdlength + ecs_option.len()) as u16,
        );
        packet.extend_from_slice(&ecs_option);

        let packet_before = packet.clone();
        let server_ip = IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100));
        let result = add_ecs_to_packet(&mut packet, server_ip, 24, 56);

        assert!(result.is_ok(), "Should succeed but not modify packet");
        assert_eq!(
            packet, packet_before,
            "Packet should not be modified when ECS already exists"
        );
    }
}