// awdl_frame_parser/common/awdl_dns_name.rs

1use core::{
2    fmt::{Display, Write},
3    iter::repeat,
4};
5use scroll::{
6    ctx::{MeasureWith, TryFromCtx, TryIntoCtx},
7    Pread, Pwrite, NETWORK,
8};
9
10use crate::tlvs::RawAWDLTLV;
11
12use super::{awdl_dns_compression::AWDLDnsCompression, awdl_str::AWDLStr};
13
/// Iterator over the [`AWDLStr`] labels stored back-to-back in a byte slice.
///
/// Labels are decoded on demand, one per call to `next`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct ReadLabelIterator<'a> {
    /// The still-encoded label bytes.
    bytes: &'a [u8],
    /// Position within `bytes` of the next label to decode.
    offset: usize,
}
impl<'a> ReadLabelIterator<'a> {
    /// Creates an iterator positioned at the very start of `bytes`.
    pub const fn new(bytes: &'a [u8]) -> Self {
        Self { offset: 0, bytes }
    }
}
24impl MeasureWith<()> for ReadLabelIterator<'_> {
25    fn measure_with(&self, _ctx: &()) -> usize {
26        self.bytes.len()
27    }
28}
29impl<'a> Iterator for ReadLabelIterator<'a> {
30    type Item = AWDLStr<'a>;
31    fn next(&mut self) -> Option<Self::Item> {
32        self.bytes.gread(&mut self.offset).ok()
33    }
34}
35impl ExactSizeIterator for ReadLabelIterator<'_> {
36    fn len(&self) -> usize {
37        repeat(())
38            .scan(0usize, |offset, _| {
39                self.bytes.gread::<RawAWDLTLV>(offset).ok()
40            })
41            .count()
42    }
43}
44
45#[derive(Clone, Copy, Debug, Default, Hash)]
46/// A hostname combined with the [domain](AWDLDnsCompression).
47pub struct AWDLDnsName<I> {
48    /// The labels of the peer.
49    pub labels: I,
50
51    /// The domain in [compressed form](AWDLDnsCompression).
52    pub domain: AWDLDnsCompression,
53}
54impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Clone> Eq for AWDLDnsName<I> {}
55impl<'a, LhsIterator, RhsIterator> PartialEq<AWDLDnsName<RhsIterator>> for AWDLDnsName<LhsIterator>
56where
57    LhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
58    RhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
59{
60    fn eq(&self, other: &AWDLDnsName<RhsIterator>) -> bool {
61        self.labels.clone().into_iter().eq(other.labels.clone())
62    }
63}
64
65impl<'a, I> MeasureWith<()> for AWDLDnsName<I>
66where
67    I: IntoIterator<Item = AWDLStr<'a>> + Clone,
68{
69    fn measure_with(&self, ctx: &()) -> usize {
70        self.labels
71            .clone()
72            .into_iter()
73            .map(|label| label.measure_with(ctx))
74            .sum::<usize>()
75            + 2
76    }
77}
78impl<'a> TryFromCtx<'a> for AWDLDnsName<ReadLabelIterator<'a>> {
79    type Error = scroll::Error;
80    fn try_from_ctx(from: &'a [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> {
81        let mut offset = 0;
82        let label_bytes = from.gread_with(&mut offset, from.len() - 2)?;
83        let domain =
84            AWDLDnsCompression::from_bits(from.gread_with(&mut offset, NETWORK)?);
85        Ok((
86            Self {
87                labels: ReadLabelIterator::new(label_bytes),
88                domain,
89            },
90            offset,
91        ))
92    }
93}
94impl<'a, I> TryIntoCtx for AWDLDnsName<I>
95where
96    I: IntoIterator<Item = AWDLStr<'a>>,
97{
98    type Error = scroll::Error;
99    fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
100        let mut offset = 0;
101        // Using for loop, because of ? operator.
102        for x in self.labels {
103            buf.gwrite(x, &mut offset)?;
104        }
105        buf.gwrite_with(self.domain.into_bits(), &mut offset, NETWORK)?;
106        Ok(offset)
107    }
108}
109impl<'a, I> Display for AWDLDnsName<I>
110where
111    I: IntoIterator<Item = AWDLStr<'a>> + Clone,
112{
113    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114        for label in self.labels.clone() {
115            f.write_str(&label)?;
116            f.write_char('.')?;
117        }
118        f.write_str(self.domain.to_static_string())
119    }
120}
121/// The default awdl dns name returned by reading.
122pub type DefaultAWDLDnsName<'a> = AWDLDnsName<ReadLabelIterator<'a>>;
#[cfg(test)]
#[test]
fn test_dns_name() {
    use alloc::vec;

    // Two length-prefixed "awdl" labels followed by the domain bytes 0xC0 0x0C.
    let encoded: &[u8] = &[
        0x04, b'a', b'w', b'd', b'l', 0x04, b'a', b'w', b'd', b'l', 0xc0, 0x0c,
    ];

    // Parsing must yield the expected labels and domain.
    let parsed = encoded.pread::<DefaultAWDLDnsName>(0).unwrap();
    assert_eq!(
        parsed,
        AWDLDnsName {
            labels: ["awdl".into(), "awdl".into()],
            domain: AWDLDnsCompression::Local
        }
    );

    // Serializing the parsed name must reproduce the input byte-for-byte.
    let mut written = vec![0x00; parsed.measure_with(&())];
    written.pwrite(parsed, 0).unwrap();
    assert_eq!(encoded, written);
}