antidns/
packet_buffer.rs

1use snafu::{ensure, ResultExt, Snafu};
2use std::string::FromUtf8Error;
3
/// Upper bound on compression-pointer jumps followed while decoding a single
/// name; a name needing more jumps than this is rejected with `TooManyJumps`.
const MAX_JUMP_INSTRUCTIONS: i32 = 5;
5
6#[derive(Debug, Snafu)]
7pub enum BufferError {
8    #[snafu(display("unexpected end of buffer"))]
9    EndOfBuffer,
10
11    #[snafu(display("limit of {} jumps exceeded", limit))]
12    TooManyJumps {
13        limit: i32,
14    },
15
16    #[snafu(display("single label exceeds 63 characters in length"))]
17    LabelTooLong,
18
19    UnicodeError {
20        source: FromUtf8Error,
21    },
22}
23
24type Result<T> = std::result::Result<T, BufferError>;
25
/// Fixed-size 512-byte packet buffer paired with a read/write cursor.
pub struct BytePacketBuffer {
    /// Raw packet bytes.
    pub buf: [u8; 512],
    /// Cursor: byte offset at which the next cursor-based read or write happens.
    pub pos: usize,
}
30
31impl BytePacketBuffer {
32    pub fn new() -> BytePacketBuffer {
33        BytePacketBuffer {
34            buf: [0; 512],
35            pos: 0,
36        }
37    }
38
39    pub fn pos(&self) -> usize {
40        self.pos
41    }
42
43    pub fn step(&mut self, steps: usize) -> Result<()> {
44        self.pos += steps;
45
46        Ok(())
47    }
48
49    fn seek(&mut self, pos: usize) {
50        self.pos = pos;
51    }
52
53    pub fn read(&mut self) -> Result<u8> {
54        ensure!(self.pos < 512, EndOfBufferSnafu);
55        let res = self.buf[self.pos];
56        self.pos += 1;
57
58        Ok(res)
59    }
60
61    fn get(&mut self, pos: usize) -> Result<u8> {
62        ensure!(pos < 512, EndOfBufferSnafu);
63        Ok(self.buf[pos])
64    }
65
66    pub fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
67        ensure!(start + len < 512, EndOfBufferSnafu);
68        Ok(&self.buf[start..start + len as usize])
69    }
70
71    pub fn read_u16(&mut self) -> Result<u16> {
72        let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
73        Ok(res)
74    }
75
76    pub fn read_u32(&mut self) -> Result<u32> {
77        let res = ((self.read()? as u32) << 24)
78            | ((self.read()? as u32) << 16)
79            | ((self.read()? as u32) << 8)
80            | (self.read()? as u32);
81
82        Ok(res)
83    }
84
85    pub fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
86        let mut pos = self.pos();
87        let mut jumped = false;
88
89        let mut delim = "";
90        let mut jumps_performed = 0;
91        loop {
92            // Dns Packets are untrusted data, so we need to be paranoid. Someone
93            // can craft a packet with a cycle in the jump instructions. This guards
94            // against such packets.
95            ensure!(
96                jumps_performed <= MAX_JUMP_INSTRUCTIONS,
97                TooManyJumpsSnafu {
98                    limit: MAX_JUMP_INSTRUCTIONS
99                }
100            );
101
102            let len = self.get(pos)?;
103
104            // A two byte sequence, where the two highest bits of the first byte is
105            // set, represents a offset relative to the start of the buffer. We
106            // handle this by jumping to the offset, setting a flag to indicate
107            // that we shouldn't update the shared buffer position once done.
108            if (len & 0xC0) == 0xC0 {
109                // When a jump is performed, we only modify the shared buffer
110                // position once, and avoid making the change later on.
111                if !jumped {
112                    self.seek(pos + 2);
113                }
114
115                let b2 = self.get(pos + 1)? as u16;
116                let offset = (((len as u16) ^ 0xC0) << 8) | b2;
117                pos = offset as usize;
118                jumped = true;
119                jumps_performed += 1;
120                continue;
121            }
122
123            pos += 1;
124
125            // Names are terminated by an empty label of length 0
126            if len == 0 {
127                break;
128            }
129
130            outstr.push_str(delim);
131
132            let str_buffer = self.get_range(pos, len as usize)?;
133            outstr.push_str(
134                &String::from_utf8(str_buffer.to_vec())
135                    .context(UnicodeSnafu)?
136                    .to_lowercase(),
137            );
138
139            delim = ".";
140
141            pos += len as usize;
142        }
143
144        if !jumped {
145            self.seek(pos);
146        }
147
148        Ok(())
149    }
150
151    pub fn write(&mut self, val: u8) -> Result<()> {
152        ensure!(self.pos < 512, EndOfBufferSnafu);
153        self.buf[self.pos] = val;
154        self.pos += 1;
155        Ok(())
156    }
157
158    pub fn write_u8(&mut self, val: u8) -> Result<()> {
159        self.write(val)?;
160
161        Ok(())
162    }
163
164    pub fn write_u16(&mut self, val: u16) -> Result<()> {
165        self.write((val >> 8) as u8)?;
166        self.write((val & 0xFF) as u8)?;
167
168        Ok(())
169    }
170
171    pub fn write_u32(&mut self, val: u32) -> Result<()> {
172        self.write(((val >> 24) & 0xFF) as u8)?;
173        self.write(((val >> 16) & 0xFF) as u8)?;
174        self.write(((val >> 8) & 0xFF) as u8)?;
175        self.write((val & 0xFF) as u8)?;
176
177        Ok(())
178    }
179
180    pub fn write_qname(&mut self, qname: &str) -> Result<()> {
181        for label in qname.split('.') {
182            let len = label.len();
183            ensure!(len <= 0x34, LabelTooLongSnafu);
184
185            self.write_u8(len as u8)?;
186            for b in label.as_bytes() {
187                self.write_u8(*b)?;
188            }
189        }
190
191        self.write_u8(0)?;
192
193        Ok(())
194    }
195
196    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
197        let l = bytes.len();
198        ensure!(self.pos + l < 512, EndOfBufferSnafu);
199
200        let byte_slice = &mut self.buf[self.pos..self.pos + l];
201        byte_slice.copy_from_slice(bytes);
202        Ok(())
203    }
204
205    fn set(&mut self, pos: usize, val: u8) {
206        self.buf[pos] = val;
207    }
208
209    pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
210        self.set(pos, (val >> 8) as u8);
211        self.set(pos + 1, (val & 0xFF) as u8);
212
213        Ok(())
214    }
215}