use crate::compress::*;
use crate::constants::*;
use crate::dns_sector::*;
use crate::errors::*;
use crate::parsed_packet::*;
use byteorder::{BigEndian, ByteOrder};
use std::marker;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
/// Read-only view of the current record: the packet buffer plus the record's position in it.
#[derive(Copy, Clone, Debug)]
pub struct RRRaw<'t> {
    /// The whole packet the record lives in.
    pub packet: &'t [u8],
    /// Offset of the record's owner name within `packet`.
    pub offset: usize,
    /// Offset just past the owner name; the record's fixed fields (type, class, ...) are read
    /// at `DNS_RR_*_OFFSET` positions relative to this offset.
    pub name_end: usize,
}
/// Mutable counterpart of `RRRaw`: the packet buffer plus the current record's position.
pub struct RRRawMut<'t> {
    /// The whole packet the record lives in, borrowed mutably for in-place edits.
    pub packet: &'t mut [u8],
    /// Offset of the record's owner name within `packet`.
    pub offset: usize,
    /// Offset just past the owner name.
    pub name_end: usize,
}
pub trait DNSIterable {
fn next(self) -> Option<Self>
where
Self: marker::Sized;
fn offset(&self) -> Option<usize>;
fn offset_next(&self) -> usize;
fn set_offset(&mut self, offset: usize);
fn set_offset_next(&mut self, offset: usize);
fn invalidate(&mut self);
fn is_tombstone(&self) -> bool {
self.offset().is_none()
}
fn recompute_rr(&mut self);
fn recompute_sections(&mut self);
fn raw(&self) -> RRRaw<'_>;
fn raw_mut(&mut self) -> RRRawMut<'_>;
fn parsed_packet(&self) -> &ParsedPacket;
fn parsed_packet_mut(&mut self) -> &mut ParsedPacket;
#[inline]
fn packet(&self) -> &[u8] {
let raw = self.raw();
raw.packet
}
#[inline]
fn name_slice(&self) -> &[u8] {
let raw = self.raw();
&raw.packet[raw.offset..raw.name_end]
}
#[inline]
fn rdata_slice(&self) -> &[u8] {
let raw = self.raw();
&raw.packet[raw.name_end..]
}
#[inline]
fn name_slice_mut(&mut self) -> &mut [u8] {
let raw = self.raw_mut();
&mut raw.packet[raw.offset..raw.name_end]
}
#[inline]
fn rdata_slice_mut(&mut self) -> &mut [u8] {
let raw = self.raw_mut();
&mut raw.packet[raw.name_end..]
}
fn uncompress(&mut self) -> Result<(), Error> {
if !self.parsed_packet().maybe_compressed {
return Ok(());
}
let (uncompressed, new_offset_next) = {
let ref_offset_next = self.offset_next();
let compressed = self.raw_mut().packet;
Compress::uncompress_with_previous_offset(compressed, ref_offset_next)?
};
self.parsed_packet_mut().packet = Some(uncompressed);
self.set_offset_next(new_offset_next);
self.recompute_sections();
self.recompute_rr();
Ok(())
}
}
pub trait TypedIterable {
fn name(&self) -> Vec<u8>
where
Self: DNSIterable,
{
let raw = self.raw();
let offset = raw.offset;
if raw.name_end <= offset {
return Vec::new();
}
let packet = raw.packet;
let mut name = Compress::raw_name_to_str(packet, offset);
name.make_ascii_lowercase();
name
}
fn copy_raw_name(&self, name: &mut Vec<u8>) -> usize
where
Self: DNSIterable,
{
let raw = self.raw();
if raw.name_end <= raw.offset {
return 0;
}
Compress::copy_uncompressed_name(name, raw.packet, raw.offset).name_len
}
fn current_section(&self) -> Result<Section, Error>
where
Self: DNSIterable,
{
let offset = self.offset();
let parsed_packet = self.parsed_packet();
if offset < parsed_packet.offset_question {
bail!(DSError::InternalError("name before the question section"));
}
let mut section = Section::Question;
if parsed_packet.offset_answers.is_some() && offset >= parsed_packet.offset_answers {
section = Section::Answer;
}
if parsed_packet.offset_nameservers.is_some() && offset >= parsed_packet.offset_nameservers
{
section = Section::NameServers
}
if parsed_packet.offset_additional.is_some() && offset >= parsed_packet.offset_additional {
section = Section::Additional;
}
Ok(section)
}
fn resize_rr(&mut self, shift: isize) -> Result<(), Error>
where
Self: DNSIterable,
{
{
if shift == 0 {
return Ok(());
}
let offset = self.offset().ok_or(DSError::VoidRecord)?;
let packet = &mut self.parsed_packet_mut().packet_mut();
let packet_len = packet.len();
if shift > 0 {
let new_packet_len = packet_len + shift as usize;
if new_packet_len > 0xffff {
bail!(DSError::PacketTooLarge);
}
packet.resize(new_packet_len, 0);
debug_assert_eq!(
new_packet_len,
(offset as isize + shift) as usize + (packet_len - offset) as usize
);
packet.copy_within(offset..offset + packet_len, offset + shift as usize);
} else if shift < 0 {
let shift = (-shift) as usize;
assert!(packet_len >= shift);
packet.copy_within(offset + shift.., offset);
packet.truncate(packet_len - shift);
}
}
let new_offset_next = (self.offset_next() as isize + shift) as usize;
self.set_offset_next(new_offset_next);
let section = self.current_section()?;
let parsed_packet = self.parsed_packet_mut();
if section == Section::NameServers
|| section == Section::Answer
|| section == Section::Question
{
parsed_packet.offset_additional = parsed_packet
.offset_additional
.map(|x| (x as isize + shift) as usize)
}
if section == Section::Answer || section == Section::Question {
parsed_packet.offset_nameservers = parsed_packet
.offset_nameservers
.map(|x| (x as isize + shift) as usize)
}
if section == Section::Question {
parsed_packet.offset_answers = parsed_packet
.offset_answers
.map(|x| (x as isize + shift) as usize)
}
Ok(())
}
fn set_raw_name(&mut self, name: &[u8]) -> Result<(), Error>
where
Self: DNSIterable,
{
let new_name_len = DNSSector::check_uncompressed_name(name, 0)?;
let name = &name[..new_name_len];
if self.parsed_packet().maybe_compressed {
let (uncompressed, new_offset) = {
let ref_offset = self.offset().ok_or(DSError::VoidRecord)?;
let compressed = self.raw_mut().packet;
Compress::uncompress_with_previous_offset(compressed, ref_offset)?
};
self.parsed_packet_mut().packet = Some(uncompressed);
self.set_offset(new_offset);
self.recompute_rr(); self.recompute_sections();
}
let offset = self.offset().ok_or(DSError::VoidRecord)?;
debug_assert_eq!(self.parsed_packet().maybe_compressed, false);
let current_name_len = Compress::raw_name_len(self.name_slice());
let shift = new_name_len as isize - current_name_len as isize;
self.resize_rr(shift)?;
{
let packet = &mut self.parsed_packet_mut().packet_mut();
packet[offset..offset + new_name_len].copy_from_slice(name);
}
self.recompute_rr();
Ok(())
}
fn delete(&mut self) -> Result<(), Error>
where
Self: DNSIterable,
{
self.offset().ok_or(DSError::VoidRecord)?;
let section = self.current_section()?;
if self.parsed_packet().maybe_compressed {
let (uncompressed, new_offset) = {
let ref_offset = self.offset().expect("delete() called on a tombstone");
let compressed = self.raw_mut().packet;
Compress::uncompress_with_previous_offset(compressed, ref_offset)?
};
self.parsed_packet_mut().packet = Some(uncompressed);
self.set_offset(new_offset);
self.recompute_rr(); self.recompute_sections();
}
let rr_len = self.offset_next()
- self
.offset()
.expect("Deleting record with no known offset after optional decompression");
assert!(rr_len > 0);
self.resize_rr(-(rr_len as isize))?;
let offset = self.offset().unwrap();
self.set_offset_next(offset);
self.invalidate();
let parsed_packet = self.parsed_packet_mut();
let rrcount = parsed_packet.rrcount_dec(section)?;
if rrcount <= 0 {
let offset = match section {
Section::Question => &mut parsed_packet.offset_question,
Section::Answer => &mut parsed_packet.offset_answers,
Section::NameServers => &mut parsed_packet.offset_nameservers,
Section::Additional => &mut parsed_packet.offset_additional,
_ => panic!("delete() cannot be used to delete EDNS pseudo-records"),
};
*offset = None;
}
Ok(())
}
#[inline]
fn rr_type(&self) -> u16
where
Self: DNSIterable,
{
BigEndian::read_u16(&self.rdata_slice()[DNS_RR_TYPE_OFFSET..])
}
#[inline]
fn rr_class(&self) -> u16
where
Self: DNSIterable,
{
BigEndian::read_u16(&self.rdata_slice()[DNS_RR_CLASS_OFFSET..])
}
}
/// Record data as returned by `RdataIterable::rr_rd`.
#[derive(Copy, Clone, Debug)]
pub enum RawRRData<'t> {
    /// Decoded address of an `A` or `AAAA` record.
    IpAddr(IpAddr),
    /// Raw rdata bytes (exactly `rdlen` bytes) for records that are not `A`/`AAAA`.
    Data(&'t [u8]),
}
pub trait RdataIterable {
#[inline]
fn rr_ttl(&self) -> u32
where
Self: DNSIterable + TypedIterable,
{
BigEndian::read_u32(&self.rdata_slice()[DNS_RR_TTL_OFFSET..])
}
fn set_rr_ttl(&mut self, ttl: u32)
where
Self: DNSIterable + TypedIterable,
{
BigEndian::write_u32(&mut self.rdata_slice_mut()[DNS_RR_TTL_OFFSET..], ttl);
}
#[inline]
fn rr_rdlen(&self) -> usize
where
Self: DNSIterable + TypedIterable,
{
BigEndian::read_u16(&self.rdata_slice()[DNS_RR_RDLEN_OFFSET..]) as usize
}
fn rr_rd(&self) -> Result<RawRRData<'_>, Error>
where
Self: DNSIterable + TypedIterable,
{
if let Ok(ip_addr) = self.rr_ip() {
return Ok(RawRRData::IpAddr(ip_addr));
}
let rdata_len = self.rr_rdlen();
let rdata = &self.rdata_slice()[DNS_RR_HEADER_SIZE..DNS_RR_HEADER_SIZE + rdata_len];
Ok(RawRRData::Data(rdata))
}
fn rr_ip(&self) -> Result<IpAddr, Error>
where
Self: DNSIterable + TypedIterable,
{
match self.rr_type() {
x if x == Type::A.into() => {
let rdata = self.rdata_slice();
assert!(rdata.len() >= DNS_RR_HEADER_SIZE + 4);
let mut ip = [0u8; 4];
ip.copy_from_slice(&rdata[DNS_RR_HEADER_SIZE..DNS_RR_HEADER_SIZE + 4]);
Ok(IpAddr::V4(Ipv4Addr::from(ip)))
}
x if x == Type::AAAA.into() => {
let rdata = self.rdata_slice();
assert!(rdata.len() >= DNS_RR_HEADER_SIZE + 16);
let mut ip = [0u8; 16];
ip.copy_from_slice(&rdata[DNS_RR_HEADER_SIZE..DNS_RR_HEADER_SIZE + 16]);
Ok(IpAddr::V6(Ipv6Addr::from(ip)))
}
_ => bail!(DSError::PropertyNotFound),
}
}
fn set_rr_ip(&mut self, ip: &IpAddr) -> Result<(), Error>
where
Self: DNSIterable + TypedIterable,
{
match self.rr_type() {
x if x == Type::A.into() => match *ip {
IpAddr::V4(ip) => {
let rdata = self.rdata_slice_mut();
assert!(rdata.len() >= DNS_RR_HEADER_SIZE + 4);
rdata[DNS_RR_HEADER_SIZE..DNS_RR_HEADER_SIZE + 4].copy_from_slice(&ip.octets());
Ok(())
}
_ => bail!(DSError::WrongAddressFamily),
},
x if x == Type::AAAA.into() => match *ip {
IpAddr::V6(ip) => {
let rdata = self.rdata_slice_mut();
assert!(rdata.len() >= DNS_RR_HEADER_SIZE + 16);
rdata[DNS_RR_HEADER_SIZE..DNS_RR_HEADER_SIZE + 16]
.copy_from_slice(&ip.octets());
Ok(())
}
_ => bail!(DSError::WrongAddressFamily),
},
_ => bail!(DSError::PropertyNotFound),
}
}
}
/// Cursor state for walking the records of one section of a parsed packet.
#[derive(Debug)]
pub struct RRIterator<'t> {
    /// Packet being walked; borrowed mutably so records can be edited in place.
    pub parsed_packet: &'t mut ParsedPacket,
    /// Section this iterator was created for.
    pub section: Section,
    /// Offset of the current record, or `None` before iteration has started.
    pub offset: Option<usize>,
    /// Offset just past the current record (see `recompute`).
    pub offset_next: usize,
    /// Offset just past the current record's owner name (see `recompute`).
    pub name_end: usize,
    /// Records still to visit; presumably maintained by the iteration code elsewhere — confirm.
    pub rrs_left: u16,
}
impl<'t> RRIterator<'t> {
    /// Creates an iterator over `section` of `parsed_packet` with no current record yet.
    pub fn new(parsed_packet: &'t mut ParsedPacket, section: Section) -> Self {
        Self {
            parsed_packet,
            section,
            offset: None,
            name_end: 0,
            offset_next: 0,
            rrs_left: 0,
        }
    }

    /// Recomputes `name_end` and `offset_next` from the current `offset`.
    ///
    /// # Panics
    ///
    /// If iteration has not started (`offset` is `None`).
    pub fn recompute(&mut self) {
        let offset = self
            .offset
            .expect("recompute() called prior to iterating over RRs");
        let packet = self.parsed_packet.packet();
        let name_end = Self::skip_name(packet, offset);
        let offset_next = Self::skip_rdata(packet, name_end);
        self.name_end = name_end;
        self.offset_next = offset_next;
    }

    /// Returns the offset just past the name that starts at `offset`.
    ///
    /// A byte with both top bits set is a two-byte compression pointer and ends the name;
    /// otherwise length-prefixed labels are skipped up to and including the empty label.
    ///
    /// # Panics
    ///
    /// If the name runs past the end of `packet`; the checks also require at least one byte
    /// to follow the name.
    pub fn skip_name(packet: &[u8], mut offset: usize) -> usize {
        let packet_len = packet.len();
        loop {
            let len = packet[offset];
            if len & 0xc0 == 0xc0 {
                assert!(packet_len - offset > 2);
                return offset + 2;
            }
            let label_len = len as usize;
            assert!(label_len < packet_len - offset - 1);
            offset += label_len + 1;
            if label_len == 0 {
                return offset;
            }
        }
    }

    /// Reads the rdlen field of a record whose name ends at `offset`.
    #[inline]
    fn rr_rdlen(packet: &[u8], offset: usize) -> usize {
        let field = &packet[offset + DNS_RR_RDLEN_OFFSET..];
        usize::from(BigEndian::read_u16(field))
    }

    /// Returns the offset just past a record whose name ends at `offset`.
    #[inline]
    pub fn skip_rdata(packet: &[u8], offset: usize) -> usize {
        offset + DNS_RR_HEADER_SIZE + Self::rr_rdlen(packet, offset)
    }

    /// Returns the offset just past the whole record (name and data) starting at `offset`.
    #[inline]
    pub fn skip_rr(packet: &[u8], offset: usize) -> usize {
        let name_end = Self::skip_name(packet, offset);
        Self::skip_rdata(packet, name_end)
    }

    /// Reads the rdlen field of an EDNS pseudo-record starting at `offset`.
    #[inline]
    fn edns_rr_rdlen(packet: &[u8], offset: usize) -> usize {
        let field = &packet[offset + DNS_EDNS_RR_RDLEN_OFFSET..];
        usize::from(BigEndian::read_u16(field))
    }

    /// Returns the offset just past the EDNS pseudo-record starting at `offset`.
    pub fn edns_skip_rr(packet: &[u8], offset: usize) -> usize {
        offset + (DNS_EDNS_RR_HEADER_SIZE + Self::edns_rr_rdlen(packet, offset))
    }
}