1use core::marker::PhantomData;
9
10use alloc::vec::Vec;
11
12use crate::{
13 ProtoError,
14 error::{ProtoErrorKind, ProtoResult},
15 op::Header,
16};
17
18use super::BinEncodable;
19
20mod private {
22 use alloc::vec::Vec;
23
24 use crate::error::{ProtoErrorKind, ProtoResult};
25
26 pub(super) struct MaximalBuf<'a> {
28 max_size: usize,
29 buffer: &'a mut Vec<u8>,
30 }
31
32 impl<'a> MaximalBuf<'a> {
33 pub(super) fn new(max_size: u16, buffer: &'a mut Vec<u8>) -> Self {
34 MaximalBuf {
35 max_size: max_size as usize,
36 buffer,
37 }
38 }
39
40 pub(super) fn set_max_size(&mut self, max: u16) {
42 self.max_size = max as usize;
43 }
44
45 pub(super) fn write(&mut self, offset: usize, data: &[u8]) -> ProtoResult<()> {
46 debug_assert!(offset <= self.buffer.len());
47 if offset + data.len() > self.max_size {
48 return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
49 }
50
51 if offset == self.buffer.len() {
52 self.buffer.extend(data);
53 return Ok(());
54 }
55
56 let end = offset + data.len();
57 if end > self.buffer.len() {
58 self.buffer.resize(end, 0);
59 }
60
61 self.buffer[offset..end].copy_from_slice(data);
62 Ok(())
63 }
64
65 pub(super) fn reserve(&mut self, offset: usize, len: usize) -> ProtoResult<()> {
66 let end = offset + len;
67 if end > self.max_size {
68 return Err(ProtoErrorKind::MaxBufferSizeExceeded(self.max_size).into());
69 }
70
71 self.buffer.resize(end, 0);
72 Ok(())
73 }
74
75 pub(super) fn truncate(&mut self, len: usize) {
77 self.buffer.truncate(len)
78 }
79
80 pub(super) fn len(&self) -> usize {
82 self.buffer.len()
83 }
84
85 pub(super) fn buffer(&'a self) -> &'a [u8] {
87 self.buffer as &'a [u8]
88 }
89
90 pub(super) fn into_bytes(self) -> &'a Vec<u8> {
92 self.buffer
93 }
94 }
95}
96
/// Binary encoder that serializes `BinEncodable` values into a caller-supplied byte vector.
pub struct BinEncoder<'a> {
    /// Position in the buffer at which the next write takes place.
    offset: usize,
    /// Size-limited wrapper around the caller's output vector.
    buffer: private::MaximalBuf<'a>,
    /// Recorded `(start offset, label bytes)` pairs, searched by `get_label_pointer`.
    name_pointers: Vec<(usize, Vec<u8>)>,
    /// Encoding mode this encoder was constructed with.
    mode: EncodeMode,
    /// Flag toggled by `set_canonical_names` / `with_canonical_names`.
    canonical_names: bool,
}
106
107impl<'a> BinEncoder<'a> {
108 pub fn new(buf: &'a mut Vec<u8>) -> Self {
110 Self::with_offset(buf, 0, EncodeMode::Normal)
111 }
112
113 pub fn with_mode(buf: &'a mut Vec<u8>, mode: EncodeMode) -> Self {
119 Self::with_offset(buf, 0, mode)
120 }
121
122 pub fn with_offset(buf: &'a mut Vec<u8>, offset: u32, mode: EncodeMode) -> Self {
132 if buf.capacity() < 512 {
133 let reserve = 512 - buf.capacity();
134 buf.reserve(reserve);
135 }
136
137 BinEncoder {
138 offset: offset as usize,
139 buffer: private::MaximalBuf::new(u16::MAX, buf),
141 name_pointers: Vec::new(),
142 mode,
143 canonical_names: false,
144 }
145 }
146
147 pub fn set_max_size(&mut self, max: u16) {
154 self.buffer.set_max_size(max);
155 }
156
157 pub fn into_bytes(self) -> &'a Vec<u8> {
159 self.buffer.into_bytes()
160 }
161
162 pub fn len(&self) -> usize {
164 self.buffer.len()
165 }
166
167 pub fn is_empty(&self) -> bool {
169 self.buffer.buffer().is_empty()
170 }
171
172 pub fn offset(&self) -> usize {
174 self.offset
175 }
176
177 pub fn set_offset(&mut self, offset: usize) {
179 self.offset = offset;
180 }
181
182 pub fn mode(&self) -> EncodeMode {
184 self.mode
185 }
186
187 pub fn set_canonical_names(&mut self, canonical_names: bool) {
189 self.canonical_names = canonical_names;
190 }
191
192 pub fn is_canonical_names(&self) -> bool {
194 self.canonical_names
195 }
196
197 pub fn with_canonical_names<F: FnOnce(&mut Self) -> ProtoResult<()>>(
199 &mut self,
200 f: F,
201 ) -> ProtoResult<()> {
202 let was_canonical = self.is_canonical_names();
203 self.set_canonical_names(true);
204
205 let res = f(self);
206 self.set_canonical_names(was_canonical);
207
208 res
209 }
210
211 pub fn reserve(&mut self, _additional: usize) -> ProtoResult<()> {
214 Ok(())
215 }
216
217 pub fn trim(&mut self) {
219 let offset = self.offset;
220 self.buffer.truncate(offset);
221 self.name_pointers.retain(|&(start, _)| start < offset);
222 }
223
224 pub fn slice_of(&self, start: usize, end: usize) -> &[u8] {
238 assert!(start < self.offset);
239 assert!(end <= self.buffer.len());
240 &self.buffer.buffer()[start..end]
241 }
242
243 pub fn store_label_pointer(&mut self, start: usize, end: usize) {
248 assert!(start <= (u16::MAX as usize));
249 assert!(end <= (u16::MAX as usize));
250 assert!(start <= end);
251 if self.offset < 0x3FFF_usize {
252 self.name_pointers
253 .push((start, self.slice_of(start, end).to_vec())); }
255 }
256
257 pub fn get_label_pointer(&self, start: usize, end: usize) -> Option<u16> {
259 let search = self.slice_of(start, end);
260
261 for (match_start, matcher) in &self.name_pointers {
262 if matcher.as_slice() == search {
263 assert!(match_start <= &(u16::MAX as usize));
264 return Some(*match_start as u16);
265 }
266 }
267
268 None
269 }
270
271 pub fn emit(&mut self, b: u8) -> ProtoResult<()> {
273 self.buffer.write(self.offset, &[b])?;
274 self.offset += 1;
275 Ok(())
276 }
277
278 pub fn emit_character_data<S: AsRef<[u8]>>(&mut self, char_data: S) -> ProtoResult<()> {
291 let char_bytes = char_data.as_ref();
292 if char_bytes.len() > 255 {
293 return Err(ProtoErrorKind::CharacterDataTooLong {
294 max: 255,
295 len: char_bytes.len(),
296 }
297 .into());
298 }
299
300 self.emit_character_data_unrestricted(char_data)
301 }
302
303 pub fn emit_character_data_unrestricted<S: AsRef<[u8]>>(&mut self, data: S) -> ProtoResult<()> {
308 let data = data.as_ref();
310 self.emit(data.len() as u8)?;
311 self.write_slice(data)
312 }
313
314 pub fn emit_u8(&mut self, data: u8) -> ProtoResult<()> {
316 self.emit(data)
317 }
318
319 pub fn emit_u16(&mut self, data: u16) -> ProtoResult<()> {
321 self.write_slice(&data.to_be_bytes())
322 }
323
324 pub fn emit_i32(&mut self, data: i32) -> ProtoResult<()> {
326 self.write_slice(&data.to_be_bytes())
327 }
328
329 pub fn emit_u32(&mut self, data: u32) -> ProtoResult<()> {
331 self.write_slice(&data.to_be_bytes())
332 }
333
334 fn write_slice(&mut self, data: &[u8]) -> ProtoResult<()> {
335 self.buffer.write(self.offset, data)?;
336 self.offset += data.len();
337 Ok(())
338 }
339
340 pub fn emit_vec(&mut self, data: &[u8]) -> ProtoResult<()> {
342 self.write_slice(data)
343 }
344
345 pub fn emit_all<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
347 &mut self,
348 mut iter: I,
349 ) -> ProtoResult<usize> {
350 self.emit_iter(&mut iter)
351 }
352
353 pub fn emit_all_refs<'r, 'e, I, E>(&mut self, iter: I) -> ProtoResult<usize>
356 where
357 'e: 'r,
358 I: Iterator<Item = &'r &'e E>,
359 E: 'r + 'e + BinEncodable,
360 {
361 let mut iter = iter.cloned();
362 self.emit_iter(&mut iter)
363 }
364
365 #[allow(clippy::needless_return)]
367 pub fn emit_iter<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable>(
368 &mut self,
369 iter: &mut I,
370 ) -> ProtoResult<usize> {
371 let mut count = 0;
372 for i in iter {
373 let rollback = self.set_rollback();
374 if let Err(e) = i.emit(self) {
375 return Err(match e.kind() {
376 ProtoErrorKind::MaxBufferSizeExceeded(_) => {
377 rollback.rollback(self);
378 ProtoError::from(ProtoErrorKind::NotAllRecordsWritten { count })
379 }
380 _ => e,
381 });
382 }
383
384 count += 1;
385 }
386 Ok(count)
387 }
388
389 pub fn place<T: EncodedSize>(&mut self) -> ProtoResult<Place<T>> {
391 let index = self.offset;
392 let len = T::size_of();
393
394 self.buffer.reserve(self.offset, len)?;
396
397 self.offset += len;
399
400 Ok(Place {
401 start_index: index,
402 phantom: PhantomData,
403 })
404 }
405
406 pub fn len_since_place<T: EncodedSize>(&self, place: &Place<T>) -> usize {
408 (self.offset - place.start_index) - place.size_of()
409 }
410
411 pub fn emit_at<T: EncodedSize>(&mut self, place: Place<T>, data: T) -> ProtoResult<()> {
413 let current_index = self.offset;
415
416 assert!(place.start_index < current_index);
419 self.offset = place.start_index;
420
421 let emit_result = data.emit(self);
423
424 assert!((self.offset - place.start_index) == place.size_of());
427
428 self.offset = current_index;
430
431 emit_result
432 }
433
434 fn set_rollback(&self) -> Rollback {
435 Rollback {
436 offset: self.offset(),
437 pointers: self.name_pointers.len(),
438 }
439 }
440}
441
/// Types whose encoded form occupies a fixed number of bytes that is known without a value.
///
/// `BinEncoder::place` uses this to reserve space before the value itself is available.
pub trait EncodedSize: BinEncodable {
    /// Number of bytes the encoded value occupies.
    fn size_of() -> usize;
}
449
450impl EncodedSize for u16 {
451 fn size_of() -> usize {
452 2
453 }
454}
455
456impl EncodedSize for Header {
457 fn size_of() -> usize {
458 Self::len()
459 }
460}
461
/// Handle to a region of the encoder's buffer reserved by `BinEncoder::place`.
///
/// The region is `T::size_of()` bytes long and is meant to be filled in later via
/// `Place::replace`.
#[derive(Debug)]
#[must_use = "data must be written back to the place"]
pub struct Place<T: EncodedSize> {
    /// Buffer offset at which the reserved region begins.
    start_index: usize,
    /// Records the type that will be written into the region; holds no data.
    phantom: PhantomData<T>,
}
468
469impl<T: EncodedSize> Place<T> {
470 pub fn replace(self, encoder: &mut BinEncoder<'_>, data: T) -> ProtoResult<()> {
471 encoder.emit_at(self, data)
472 }
473
474 pub fn size_of(&self) -> usize {
475 T::size_of()
476 }
477}
478
/// Snapshot of encoder state, used to undo a partially written item.
pub(crate) struct Rollback {
    /// Encoder offset at the time the snapshot was taken.
    offset: usize,
    /// Number of stored label pointers at the time the snapshot was taken.
    pointers: usize,
}
484
485impl Rollback {
486 pub(crate) fn rollback(self, encoder: &mut BinEncoder<'_>) {
487 let Self { offset, pointers } = self;
488 encoder.set_offset(offset);
489 encoder.name_pointers.truncate(pointers);
490 }
491}
492
/// Mode an encoder operates in, exposed to `BinEncodable` implementations through
/// `BinEncoder::mode`.
///
/// This file only stores and returns the mode; how each variant changes the output is
/// decided by the individual `BinEncodable` implementations.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum EncodeMode {
    /// Encoding performed for signing purposes.
    Signing,
    /// Regular encoding; the mode used by `BinEncoder::new`.
    Normal,
}
502
// Unit tests for the encoder: placeholders, size limits and whole-message round trips.
#[cfg(test)]
mod tests {
    use core::str::FromStr;

    use super::*;
    use crate::{
        op::{Message, Query},
        rr::{
            RData, Record, RecordType,
            rdata::{CNAME, SRV},
        },
        serialize::binary::BinDecodable,
    };
    use crate::{rr::Name, serialize::binary::BinDecoder};

    /// Decoding this fixed message and encoding it again must both succeed.
    #[test]
    fn test_label_compression_regression() {
        let data = vec![
            154, 50, 129, 128, 0, 1, 0, 0, 0, 1, 0, 1, 7, 98, 108, 117, 101, 100, 111, 116, 2, 105,
            115, 8, 97, 117, 116, 111, 110, 97, 118, 105, 3, 99, 111, 109, 3, 103, 100, 115, 10,
            97, 108, 105, 98, 97, 98, 97, 100, 110, 115, 3, 99, 111, 109, 0, 0, 28, 0, 1, 192, 36,
            0, 6, 0, 1, 0, 0, 7, 7, 0, 35, 6, 103, 100, 115, 110, 115, 49, 192, 40, 4, 110, 111,
            110, 101, 0, 120, 27, 176, 162, 0, 0, 7, 8, 0, 0, 2, 88, 0, 0, 14, 16, 0, 0, 1, 104, 0,
            0, 41, 2, 0, 0, 0, 0, 0, 0, 0,
        ];

        let msg = Message::from_vec(&data).unwrap();
        msg.to_bytes().unwrap();
    }

    /// The encoded size of a `u16` is two bytes.
    #[test]
    fn test_size_of() {
        assert_eq!(u16::size_of(), 2);
    }

    /// Reserves a `u16` place, writes two bytes after it, then fills the place in.
    #[test]
    fn test_place() {
        let mut buf = vec![];
        {
            let mut encoder = BinEncoder::new(&mut buf);
            let place = encoder.place::<u16>().unwrap();
            assert_eq!(place.size_of(), 2);
            assert_eq!(encoder.len_since_place(&place), 0);

            encoder.emit(42_u8).expect("failed 0");
            assert_eq!(encoder.len_since_place(&place), 1);

            encoder.emit(48_u8).expect("failed 1");
            assert_eq!(encoder.len_since_place(&place), 2);

            place
                .replace(&mut encoder, 4_u16)
                .expect("failed to replace");
            drop(encoder);
        }

        // Two reserved bytes plus the two bytes emitted afterwards.
        assert_eq!(buf.len(), 4);

        let mut decoder = BinDecoder::new(&buf);
        let written = decoder.read_u16().expect("cound not read u16").unverified();

        // The value written through the place sits at the front of the buffer.
        assert_eq!(written, 4);
    }

    /// With a limit of five bytes, the sixth single-byte write must be rejected.
    #[test]
    fn test_max_size() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(5);
        encoder.emit(0).expect("failed to write");
        encoder.emit(1).expect("failed to write");
        encoder.emit(2).expect("failed to write");
        encoder.emit(3).expect("failed to write");
        encoder.emit(4).expect("failed to write");
        let error = encoder.emit(5).unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    /// With a limit of zero, even the first write must be rejected.
    #[test]
    fn test_max_size_0() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(0);
        let error = encoder.emit(0).unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    /// Reserving a place is subject to the size limit as well.
    #[test]
    fn test_max_size_place() {
        let mut buf = vec![];
        let mut encoder = BinEncoder::new(&mut buf);

        encoder.set_max_size(2);
        let place = encoder.place::<u16>().expect("place failed");
        place.replace(&mut encoder, 16).expect("placeback failed");

        // The first place used up the whole two-byte budget.
        let error = encoder.place::<u16>().unwrap_err();

        match error.kind() {
            ProtoErrorKind::MaxBufferSizeExceeded(_) => (),
            _ => panic!(),
        }
    }

    /// Encodes a message with repeated names and checks the exact encoded length, then
    /// verifies the bytes can be parsed again.
    #[test]
    fn test_target_compression() {
        let mut msg = Message::new();
        msg.add_query(Query::query(
            Name::from_str("www.google.com.").unwrap(),
            RecordType::A,
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_additional(Record::from_rdata(
            Name::from_str("www.google.com.").unwrap(),
            0,
            RData::SRV(SRV::new(
                0,
                0,
                0,
                Name::from_str("www.compressme.com.").unwrap(),
            )),
        ))
        .add_answer(Record::from_rdata(
            Name::from_str("www.compressme.com.").unwrap(),
            0,
            RData::CNAME(CNAME(Name::from_str("www.foo.com.").unwrap())),
        ));

        let bytes = msg.to_vec().unwrap();
        assert_eq!(bytes.len(), 130);
        assert!(Message::from_vec(&bytes).is_ok());
    }

    /// Decoding the bundled fixture and encoding it again must both succeed.
    #[test]
    fn test_fuzzed() {
        const MESSAGE: &[u8] = include_bytes!("../../../tests/test-data/fuzz-long.rdata");
        let msg = Message::from_bytes(MESSAGE).unwrap();
        msg.to_bytes().unwrap();
    }
}