// hickory_proto/serialize/binary/encoder.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use core::marker::PhantomData;
9
10use alloc::vec::Vec;
11
12use crate::{
13    ProtoError,
14    error::{ProtoErrorKind, ProtoResult},
15    op::Header,
16};
17
18use super::BinEncodable;
19
20// this is private to make sure there is no accidental access to the inner buffer.
21mod private {
22    use alloc::vec::Vec;
23
24    use crate::error::{ProtoErrorKind, ProtoResult};
25
26    /// A wrapper for a buffer that guarantees writes never exceed a defined set of bytes
27    pub(super) struct MaximalBuf<'a> {
28        max_size: usize,
29        buffer: &'a mut Vec<u8>,
30    }
31
32    impl<'a> MaximalBuf<'a> {
33        pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
34            MaximalBuf {
35                max_size: max_size as usize,
36                buffer,
37            }
38        }
39
40        /// Sets the maximum size to enforce
41        pub(super) fn set_max_size(&mut self, max: u16) {
42            self.max_size = max as usize;
43        }
44
45        pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
46            debug_assert!(offset <= self.buffer.len());
47            if offset + data.len() > self.max_size {
48                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
49            }
50
51            if offset == self.buffer.len() {
52                self.buffer.extend(data);
53                return Ok(());
54            }
55
56            let end = offset + data.len();
57            if end > self.buffer.len() {
58                self.buffer.resize(end, 0);
59            }
60
61            self.buffer[offset..end].copy_from_slice(data);
62            Ok(())
63        }
64
65        pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
66            let end = offset + len;
67            if end > self.max_size {
68                return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
69            }
70
71            self.buffer.resize(end, 0);
72            Ok(())
73        }
74
75        /// truncates are always safe
76        pub(super) fn truncate(&mut self, len: usize) {
77            self.buffer.truncate(len)
78        }
79
80        /// returns the length of the underlying buffer
81        pub(super) fn len(&self) -> usize {
82            self.buffer.len()
83        }
84
85        /// Immutable reads are always safe
86        pub(super) fn buffer(&'a self) -> &'a [u8] {
87            self.buffer as &'a [u8]
88        }
89
90        /// Returns a reference to the internal buffer
91        pub(super) fn into_bytes(self) -> &'a Vec<u8> {
92            self.buffer
93        }
94    }
95}
96
/// Encode DNS messages and resource record types.
pub struct BinEncoder<'a> {
    // current write position; may point into the middle of the buffer when
    // writing back to a previously reserved `Place`
    offset: usize,
    // size-capped destination buffer; all writes are bounds-checked against its max
    buffer: private::MaximalBuf<'a>,
    /// start of label pointers with their labels in fully decompressed form for easy comparison, smallvec here?
    name_pointers: Vec<(usize, Vec<u8>)>,
    // encoding mode, see `EncodeMode`
    mode: EncodeMode,
    // when true, names will be written into the buffer in canonical form
    canonical_names: bool,
}
106
107impl<'a> BinEncoder<'a> {
108    /// Create a new encoder with the Vec to fill
109    pub fn new(buf: &'a mut Vec<u8>) -> Self {
110        Self::with_offset(buf, 0, EncodeMode::Normal)
111    }
112
113    /// Specify the mode for encoding
114    ///
115    /// # Arguments
116    ///
117    /// * `mode` - In Signing mode, canonical forms of all data are encoded, otherwise format matches the source form
118    pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
119        Self::with_offset(buf, 0, mode)
120    }
121
122    /// Begins the encoder at the given offset
123    ///
124    /// This is used for pointers. If this encoder is starting at some point further in
125    ///  the sequence of bytes, for the proper offset of the pointer, the offset accounts for that
126    ///  by using the offset to add to the pointer location being written.
127    ///
128    /// # Arguments
129    ///
130    /// * `offset` - index at which to start writing into the buffer
131    pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
132        if buf.capacity() < 512 {
133            let reserve = 512 - buf.capacity();
134            buf.reserve(reserve);
135        }
136
137        BinEncoder {
138            offset: offset as usize,
139            // TODO: add max_size to signature
140            buffer: private::MaximalBuf::new(u16::MAX, buf),
141            name_pointers: Vec::new(),
142            mode,
143            canonical_names: false,
144        }
145    }
146
147    // TODO: move to constructor (kept for backward compatibility)
148    /// Sets the maximum size of the buffer
149    ///
150    /// DNS message lens must be smaller than u16::max_value due to hard limits in the protocol
151    ///
152    /// *this method will move to the constructor in a future release*
153    pub fn set_max_size(&mut self, max: u16) {
154        self.buffer.set_max_size(max);
155    }
156
157    /// Returns a reference to the internal buffer
158    pub fn into_bytes(self) -> &'a Vec<u8> {
159        self.buffer.into_bytes()
160    }
161
162    /// Returns the length of the buffer
163    pub fn len(&self) -> usize {
164        self.buffer.len()
165    }
166
167    /// Returns `true` if the buffer is empty
168    pub fn is_empty(&self) -> bool {
169        self.buffer.buffer().is_empty()
170    }
171
172    /// Returns the current offset into the buffer
173    pub fn offset(&self) -> usize {
174        self.offset
175    }
176
177    /// sets the current offset to the new offset
178    pub fn set_offset(&mut self, offset: usize) {
179        self.offset = offset;
180    }
181
182    /// Returns the current Encoding mode
183    pub fn mode(&self) -> EncodeMode {
184        self.mode
185    }
186
187    /// If set to true, then names will be written into the buffer in canonical form
188    pub fn set_canonical_names(&mut self, canonical_names: bool) {
189        self.canonical_names = canonical_names;
190    }
191
192    /// Returns true if then encoder is writing in canonical form
193    pub fn is_canonical_names(&self) -> bool {
194        self.canonical_names
195    }
196
197    /// Emit all names in canonical form, useful for <https://tools.ietf.org/html/rfc3597>
198    pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
199        &mut self,
200        f: F,
201    ) -> ProtoResult<()> {
202        let was_canonical = self.is_canonical_names();
203        self.set_canonical_names(true);
204
205        let res = f(self);
206        self.set_canonical_names(was_canonical);
207
208        res
209    }
210
211    // TODO: deprecate this...
212    /// Reserve specified additional length in the internal buffer.
213    pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
214        Ok(())
215    }
216
217    /// trims to the current offset
218    pub fn trim(&mut self) {
219        let offset = self.offset;
220        self.buffer.truncate(offset);
221        self.name_pointers.retain(|&(start, _)| start < offset);
222    }
223
224    // /// returns an error if the maximum buffer size would be exceeded with the addition number of elements
225    // ///
226    // /// and reserves the additional space in the buffer
227    // fn enforce_size(&mut self, additional: usize) -> ProtoResult<()> {
228    //     if (self.buffer.len() + additional) > self.max_size {
229    //         Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into())
230    //     } else {
231    //         self.reserve(additional);
232    //         Ok(())
233    //     }
234    // }
235
236    /// borrow a slice from the encoder
237    pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
238        assert!(start < self.offset);
239        assert!(end <= self.buffer.len());
240        &self.buffer.buffer()[start..end]
241    }
242
243    /// Stores a label pointer to an already written label
244    ///
245    /// The location is the current position in the buffer
246    ///  implicitly, it is expected that the name will be written to the stream after the current index.
247    pub fn store_label_pointer(&mut self, start: usize, end: usize) {
248        assert!(start <= (u16::MAX as usize));
249        assert!(end <= (u16::MAX as usize));
250        assert!(start <= end);
251        if self.offset < 0x3FFF_usize {
252            self.name_pointers
253                .push((start, self.slice_of(start, end).to_vec())); // the next char will be at the len() location
254        }
255    }
256
257    /// Looks up the index of an already written label
258    pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
259        let search = self.slice_of(start, end);
260
261        for (match_start, matcher) in &self.name_pointers {
262            if matcher.as_slice() == search {
263                assert!(match_start <= &(u16::MAX as usize));
264                return Some(*match_start as u16);
265            }
266        }
267
268        None
269    }
270
271    /// Emit one byte into the buffer
272    pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
273        self.buffer.write(self.offset, &[b])?;
274        self.offset += 1;
275        Ok(())
276    }
277
278    /// matches description from above.
279    ///
280    /// ```
281    /// use hickory_proto::serialize::binary::BinEncoder;
282    ///
283    /// let mut bytes: Vec<u8> = Vec::new();
284    /// {
285    ///   let mut encoder: BinEncoder = BinEncoder::new(&mut bytes);
286    ///   encoder.emit_character_data("abc");
287    /// }
288    /// assert_eq!(bytes, vec![3,b'a',b'b',b'c']);
289    /// ```
290    pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
291        let char_bytes = char_data.as_ref();
292        if char_bytes.len() > 255 {
293            return Err(ProtoErrorKind::CharacterDataTooLong {
294                max: 255,
295                len: char_bytes.len(),
296            }
297            .into());
298        }
299
300        self.emit_character_data_unrestricted(char_data)
301    }
302
303    /// Emit character data of unrestricted length
304    ///
305    /// Although character strings are typically restricted to being no longer than 255 characters,
306    /// some modern standards allow longer strings to be encoded.
307    pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
308        // first the length is written
309        let data = data.as_ref();
310        self.emit(data.len() as u8)?;
311        self.write_slice(data)
312    }
313
314    /// Emit one byte into the buffer
315    pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
316        self.emit(data)
317    }
318
319    /// Writes a u16 in network byte order to the buffer
320    pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
321        self.write_slice(&data.to_be_bytes())
322    }
323
324    /// Writes an i32 in network byte order to the buffer
325    pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
326        self.write_slice(&data.to_be_bytes())
327    }
328
329    /// Writes an u32 in network byte order to the buffer
330    pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
331        self.write_slice(&data.to_be_bytes())
332    }
333
334    fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
335        self.buffer.write(self.offset, data)?;
336        self.offset += data.len();
337        Ok(())
338    }
339
340    /// Writes the byte slice to the stream
341    pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
342        self.write_slice(data)
343    }
344
345    /// Emits all the elements of an Iterator to the encoder
346    pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
347        &mut self,
348        mut iter: I,
349    ) -> ProtoResult<usize> {
350        self.emit_iter(&mut iter)
351    }
352
353    // TODO: dedup with above emit_all
354    /// Emits all the elements of an Iterator to the encoder
355    pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
356    where
357        'e: 'r,
358        I: Iterator<Item = &'r &'e E>,
359        E: 'r + 'e + BinEncodable,
360    {
361        let mut iter = iter.cloned();
362        self.emit_iter(&mut iter)
363    }
364
365    /// emits all items in the iterator, return the number emitted
366    #[allow(clippy::needless_return)]
367    pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
368        &mut self,
369        iter: &mut I,
370    ) -> ProtoResult<usize> {
371        let mut count = 0;
372        for i in iter {
373            let rollback = self.set_rollback();
374            if let Err(e) = i.emit(self) {
375                return Err(match e.kind() {
376                    ProtoErrorKind::MaxBufferSizeExceeded(_) => {
377                        rollback.rollback(self);
378                        ProtoError::from(ProtoErrorKind::NotAllRecordsWritten { count })
379                    }
380                    _ => e,
381                });
382            }
383
384            count += 1;
385        }
386        Ok(count)
387    }
388
389    /// capture a location to write back to
390    pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
391        let index = self.offset;
392        let len = T::size_of();
393
394        // resize the buffer
395        self.buffer.reserve(self.offset, len)?;
396
397        // update the offset
398        self.offset += len;
399
400        Ok(Place {
401            start_index: index,
402            phantom: PhantomData,
403        })
404    }
405
406    /// calculates the length of data written since the place was creating
407    pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
408        (self.offset - place.start_index) - place.size_of()
409    }
410
411    /// write back to a previously captured location
412    pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
413        // preserve current index
414        let current_index = self.offset;
415
416        // reset the current index back to place before writing
417        //   this is an assert because it's programming error for it to be wrong.
418        assert!(place.start_index < current_index);
419        self.offset = place.start_index;
420
421        // emit the data to be written at this place
422        let emit_result = data.emit(self);
423
424        // double check that the current number of bytes were written
425        //   this is an assert because it's programming error for it to be wrong.
426        assert!((self.offset - place.start_index) == place.size_of());
427
428        // reset to original location
429        self.offset = current_index;
430
431        emit_result
432    }
433
434    fn set_rollback(&self) -> Rollback {
435        Rollback {
436            offset: self.offset(),
437            pointers: self.name_pointers.len(),
438        }
439    }
440}
441
/// A trait to return the size of a type as it will be encoded in DNS
///
/// it does not necessarily equal `core::mem::size_of`, though it might, especially for primitives
pub trait EncodedSize: BinEncodable {
    /// Return the size in bytes of the type as it will be encoded in DNS binary format
    fn size_of() -> usize;
}
449
450impl EncodedSize for u16 {
451    fn size_of() -> usize {
452        2
453    }
454}
455
impl EncodedSize for Header {
    fn size_of() -> usize {
        // the DNS header has a fixed encoded length, reported by `Header::len`
        Self::len()
    }
}
461
/// A marker for a location in the encoder's buffer that was reserved with
/// `BinEncoder::place` and must later be filled in with a value of type `T`
/// via `Place::replace` / `BinEncoder::emit_at`.
#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    // index in the buffer where the reserved region begins
    start_index: usize,
    // no `T` value is stored; only its encoded size matters
    phantom: PhantomData<T>,
}
468
impl<T: EncodedSize> Place<T> {
    /// Consumes the place and writes `data` back at the reserved location in `encoder`.
    pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
        encoder.emit_at(self, data)
    }

    /// Returns the encoded size in bytes reserved for `T`.
    pub fn size_of(&self) -> usize {
        T::size_of()
    }
}
478
/// A type representing a rollback point in a stream
pub(crate) struct Rollback {
    // buffer offset at the time the rollback point was captured
    offset: usize,
    // number of name pointers recorded at the time the rollback point was captured
    pointers: usize,
}
484
485impl Rollback {
486    pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
487        let Self { offset, pointers } = self;
488        encoder.set_offset(offset);
489        encoder.name_pointers.truncate(pointers);
490    }
491}
492
/// In the Verify mode there may be some things which are encoded differently, e.g. SIG0 records
///  should not be included in the additional count and not in the encoded data when in Verify
///
/// NOTE(review): the text above refers to a "Verify" mode, but only `Signing` and
///  `Normal` variants exist here — confirm whether verification reuses `Signing`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum EncodeMode {
    /// In signing mode records are written in canonical form
    Signing,
    /// Write records in standard format
    Normal,
}
502
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{
            RData, Record, RecordType,
            rdata::{CNAME, SRV},
        },
        serialize::binary::BinDecodable,
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    // Regression test: a decoded message that uses label compression must
    // re-encode without error.
    #[test]
    fn test_label_compression_regression() {
        // https://github.com/hickory-dns/hickory-dns/issues/339
        /*
        ;; QUESTION SECTION:
        ;bluedot.is.autonavi.com.gds.alibabadns.com. IN AAAA

        ;; AUTHORITY SECTION:
        gds.alibabadns.com.     1799    IN      SOA     gdsns1.alibabadns.com. none. 2015080610 1800 600 3600 360
        */
        let data = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        // decode, then round-trip back through the encoder
        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    // `EncodedSize` for u16 must report the DNS wire size: 2 bytes.
    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    // Reserve a 2-byte `Place`, emit two bytes after it, then write the value
    // back into the reserved slot and verify the final layout.
    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }

        // 2 bytes for the place + 2 emitted bytes
        assert_eq!(buf.len(), 4);

        // the written-back u16 should decode from the front of the buffer
        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("cound not read u16").unverified();

        assert_eq!(written, 4);
    }

    // Writing past the configured maximum must fail with MaxBufferSizeExceeded.
    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    // Edge case: with a maximum size of zero, even the first byte must fail.
    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    // A `place` reservation itself must respect the maximum size: the first
    // 2-byte place fits exactly, the second must fail.
    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        let error = encoder.place::<u16>().unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    // SRV targets should participate in label compression: repeated names in
    // rdata compress against earlier occurrences, shrinking the message.
    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        // name here should use compressed label from target in previous records
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com.").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com.").unwrap())),
        ));

        let bytes = msg.to_vec().unwrap();
        // label is compressed pointing to target, would be 145 otherwise
        assert_eq!(bytes.len(), 130);
        // check re-serializing
        assert!(Message::from_vec(&bytes).is_ok());
    }

    // Fuzzer-discovered input: long rdata must decode and re-encode cleanly.
    #[test]
    fn test_fuzzed() {
        const MESSAGE: &[u8] = include_bytes!("../../../tests/test-data/fuzz-long.rdata");
        let msg = Message::from_bytes(MESSAGE).unwrap();
        msg.to_bytes().unwrap();
    }
}