awdl_frame_parser/common/
awdl_dns_name.rs1use core::{
2 fmt::{Display, Write},
3 iter::repeat,
4};
5use scroll::{
6 ctx::{MeasureWith, TryFromCtx, TryIntoCtx},
7 Pread, Pwrite, NETWORK,
8};
9
10use crate::tlvs::RawAWDLTLV;
11
12use super::{awdl_dns_compression::AWDLDnsCompression, awdl_str::AWDLStr};
13
/// Lazily decodes the labels of a DNS name from a borrowed byte buffer.
///
/// Only the raw bytes and a read cursor are stored; each label is parsed as an
/// `AWDLStr` on demand.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct ReadLabelIterator<'a> {
    /// Encoded label data to decode.
    bytes: &'a [u8],
    /// Position of the next unread byte within `bytes`.
    offset: usize,
}
impl<'bytes> ReadLabelIterator<'bytes> {
    /// Creates an iterator positioned at the start of `bytes`.
    pub const fn new(bytes: &'bytes [u8]) -> Self {
        Self { offset: 0, bytes }
    }
}
24impl MeasureWith<()> for ReadLabelIterator<'_> {
25 fn measure_with(&self, _ctx: &()) -> usize {
26 self.bytes.len()
27 }
28}
29impl<'a> Iterator for ReadLabelIterator<'a> {
30 type Item = AWDLStr<'a>;
31 fn next(&mut self) -> Option<Self::Item> {
32 self.bytes.gread(&mut self.offset).ok()
33 }
34}
35impl ExactSizeIterator for ReadLabelIterator<'_> {
36 fn len(&self) -> usize {
37 repeat(())
38 .scan(0usize, |offset, _| {
39 self.bytes.gread::<RawAWDLTLV>(offset).ok()
40 })
41 .count()
42 }
43}
44
45#[derive(Clone, Copy, Debug, Default, Hash)]
46pub struct AWDLDnsName<I> {
48 pub labels: I,
50
51 pub domain: AWDLDnsCompression,
53}
54impl<'a, I: IntoIterator<Item = AWDLStr<'a>> + Clone> Eq for AWDLDnsName<I> {}
55impl<'a, LhsIterator, RhsIterator> PartialEq<AWDLDnsName<RhsIterator>> for AWDLDnsName<LhsIterator>
56where
57 LhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
58 RhsIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
59{
60 fn eq(&self, other: &AWDLDnsName<RhsIterator>) -> bool {
61 self.labels.clone().into_iter().eq(other.labels.clone())
62 }
63}
64
65impl<'a, I> MeasureWith<()> for AWDLDnsName<I>
66where
67 I: IntoIterator<Item = AWDLStr<'a>> + Clone,
68{
69 fn measure_with(&self, ctx: &()) -> usize {
70 self.labels
71 .clone()
72 .into_iter()
73 .map(|label| label.measure_with(ctx))
74 .sum::<usize>()
75 + 2
76 }
77}
78impl<'a> TryFromCtx<'a> for AWDLDnsName<ReadLabelIterator<'a>> {
79 type Error = scroll::Error;
80 fn try_from_ctx(from: &'a [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> {
81 let mut offset = 0;
82 let label_bytes = from.gread_with(&mut offset, from.len() - 2)?;
83 let domain =
84 AWDLDnsCompression::from_bits(from.gread_with(&mut offset, NETWORK)?);
85 Ok((
86 Self {
87 labels: ReadLabelIterator::new(label_bytes),
88 domain,
89 },
90 offset,
91 ))
92 }
93}
94impl<'a, I> TryIntoCtx for AWDLDnsName<I>
95where
96 I: IntoIterator<Item = AWDLStr<'a>>,
97{
98 type Error = scroll::Error;
99 fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
100 let mut offset = 0;
101 for x in self.labels {
103 buf.gwrite(x, &mut offset)?;
104 }
105 buf.gwrite_with(self.domain.into_bits(), &mut offset, NETWORK)?;
106 Ok(offset)
107 }
108}
109impl<'a, I> Display for AWDLDnsName<I>
110where
111 I: IntoIterator<Item = AWDLStr<'a>> + Clone,
112{
113 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
114 for label in self.labels.clone() {
115 f.write_str(&label)?;
116 f.write_char('.')?;
117 }
118 f.write_str(self.domain.to_static_string())
119 }
120}
121pub type DefaultAWDLDnsName<'a> = AWDLDnsName<ReadLabelIterator<'a>>;
#[cfg(test)]
#[test]
fn test_dns_name() {
    use alloc::vec;

    // Two "awdl" labels followed by the 0xC00C domain field.
    let wire: &[u8] = &[
        0x04, b'a', b'w', b'd', b'l', 0x04, b'a', b'w', b'd', b'l', 0xc0, 0x0c,
    ];
    let parsed = wire.pread::<DefaultAWDLDnsName>(0).unwrap();
    assert_eq!(
        parsed,
        AWDLDnsName {
            labels: ["awdl".into(), "awdl".into()],
            domain: AWDLDnsCompression::Local
        }
    );

    // Serializing the parsed name must reproduce the input bytes exactly.
    let mut out = vec![0x00; parsed.measure_with(&())];
    out.pwrite(parsed, 0).unwrap();
    assert_eq!(wire, out);
}