awdl_frame_parser/tlvs/
mod.rs

1/// TLVs regarding the data path.
2pub mod data_path;
3/// TLVs containing data about dns services.
4pub mod dns_sd;
5/// TLVs about the synchronization and election state of the peer.
6pub mod sync_elect;
7pub mod version;
8use core::{fmt::Debug, iter::repeat, marker::PhantomData};
9
10use mac_parser::MACAddress;
11use macro_bits::serializable_enum;
12use scroll::{
13    ctx::{MeasureWith, TryFromCtx, TryIntoCtx},
14    Endian, Pread, Pwrite,
15};
16use tlv_rs::{raw_tlv::RawTLV, TLV};
17
18use crate::common::{AWDLStr, ReadLabelIterator};
19
20use self::{
21    data_path::{DataPathStateTLV, HTCapabilitiesTLV, IEEE80211ContainerTLV},
22    dns_sd::{ArpaTLV, ServiceResponseTLV},
23    sync_elect::{
24        ChannelSequenceTLV, ElectionParametersTLV, ElectionParametersV2TLV, ReadMACIterator,
25        SyncTreeTLV, SynchronizationParametersTLV,
26    },
27    version::VersionTLV,
28};
29
30serializable_enum! {
31    #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
32    /// The type of the TLV.
33    pub enum AWDLTLVType: u8 {
34        #[default]
35        /// Required for `tlv-rs`.
36        Null => 0x00,
37
38        /// The service parameters.
39        ServiceResponse => 0x02,
40
41        /// The synchronization parameters.
42        SynchronizationParameters => 0x04,
43
44        /// The election parameters.
45        ElectionParameters => 0x05,
46
47        // The service parameters.
48        //ServiceParameters => 0x06,
49
50        /// The HT capabilities.
51        HTCapabilities => 0x07,
52
53        /// The data path state.
54        DataPathState => 0x0C,
55
56        /// The hostname of the peer.
57        Arpa => 0x10,
58
59        /// The VHT capabilities.
60        IEEE80211Container => 0x11,
61
62        /// The channel sequence.
63        ChannelSequence => 0x12,
64
65        /// The synchronization tree.
66        SynchronizationTree => 0x14,
67
68        /// The actual version of the AWDL protocol, that's being used.
69        Version => 0x15,
70
71        /// The V2 Election Parameters.
72        ElectionParametersV2 => 0x18
73    }
74}
75/// A trait implemented by all AWDL TLVs.
76pub trait AwdlTlv {
77    const TLV_TYPE: AWDLTLVType;
78}
79
80pub type RawAWDLTLV<'a> = RawTLV<'a, u8, u16>;
81pub type TypedAWDLTLV<'a, Payload> = TLV<u8, u16, AWDLTLVType, Payload>;
82
83#[derive(Clone)]
84pub enum AWDLTLV<'a, MACIterator, LabelIterator> {
85    ServiceResponse(ServiceResponseTLV<'a, LabelIterator>),
86    SynchronizationParameters(SynchronizationParametersTLV),
87    ElectionParameters(ElectionParametersTLV),
88    HTCapabilities(HTCapabilitiesTLV),
89    DataPathState(DataPathStateTLV),
90    Arpa(ArpaTLV<LabelIterator>),
91    IEEE80211Container(IEEE80211ContainerTLV<'a>),
92    ChannelSequence(ChannelSequenceTLV),
93    SynchronizationTree(SyncTreeTLV<MACIterator>),
94    Version(VersionTLV),
95    ElectionParametersV2(ElectionParametersV2TLV),
96    Unknown(RawAWDLTLV<'a>),
97}
98macro_rules! comparisons {
99    ($self:expr, $other:expr, $($path:ident),*) => {
100        match ($self, $other) {
101            $(
102                (Self::$path(lhs), AWDLTLV::<'a, RhsMACIterator, RhsLabelIterator>::$path(rhs)) => lhs == rhs,
103            )*
104            _ => false,
105        }
106    };
107}
108impl<'a, LhsMACIterator, RhsMACIterator, LhsLabelIterator, RhsLabelIterator>
109    PartialEq<AWDLTLV<'a, RhsMACIterator, RhsLabelIterator>>
110    for AWDLTLV<'a, LhsMACIterator, LhsLabelIterator>
111where
112    LhsMACIterator: IntoIterator<Item = MACAddress> + Clone,
113    RhsMACIterator: IntoIterator<Item = MACAddress> + Clone,
114    LhsLabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
115    RhsLabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
116{
117    fn eq(&self, other: &AWDLTLV<'a, RhsMACIterator, RhsLabelIterator>) -> bool {
118        comparisons!(
119            self,
120            other,
121            ServiceResponse,
122            SynchronizationParameters,
123            ElectionParameters,
124            HTCapabilities,
125            DataPathState,
126            Arpa,
127            IEEE80211Container,
128            ChannelSequence,
129            SynchronizationTree,
130            Version,
131            ElectionParametersV2,
132            Unknown
133        )
134    }
135}
136impl<'a, MACIterator, LabelIterator> Eq for AWDLTLV<'a, MACIterator, LabelIterator>
137where
138    MACIterator: IntoIterator<Item = MACAddress> + Clone,
139    LabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
140{
141}
142macro_rules! debug_impls {
143    ($self:expr, $f:expr, $($path:ident),*) => {
144        match $self {
145            $(
146                Self::$path(inner) => inner.fmt($f),
147            )*
148        }
149    };
150}
151impl<'a, MACIterator, LabelIterator> Debug for AWDLTLV<'a, MACIterator, LabelIterator>
152where
153    MACIterator: IntoIterator<Item = MACAddress> + Clone + Debug,
154    LabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone + Debug,
155{
156    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
157        debug_impls!(
158            self,
159            f,
160            ServiceResponse,
161            SynchronizationParameters,
162            ElectionParameters,
163            HTCapabilities,
164            DataPathState,
165            Arpa,
166            IEEE80211Container,
167            ChannelSequence,
168            SynchronizationTree,
169            Version,
170            ElectionParametersV2,
171            Unknown
172        )
173    }
174}
175impl<'a, MACIterator, LabelIterator> AWDLTLV<'a, MACIterator, LabelIterator>
176where
177    LabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
178    <LabelIterator as IntoIterator>::IntoIter: Clone,
179    MACIterator: IntoIterator<Item = MACAddress> + Clone,
180{
181    pub const fn get_type(&self) -> AWDLTLVType {
182        match self {
183            AWDLTLV::Arpa(_) => AWDLTLVType::Arpa,
184            AWDLTLV::ChannelSequence(_) => AWDLTLVType::ChannelSequence,
185            AWDLTLV::DataPathState(_) => AWDLTLVType::DataPathState,
186            AWDLTLV::ElectionParameters(_) => AWDLTLVType::ElectionParameters,
187            AWDLTLV::ElectionParametersV2(_) => AWDLTLVType::ElectionParametersV2,
188            AWDLTLV::HTCapabilities(_) => AWDLTLVType::HTCapabilities,
189            AWDLTLV::IEEE80211Container(_) => AWDLTLVType::IEEE80211Container,
190            AWDLTLV::ServiceResponse(_) => AWDLTLVType::ServiceResponse,
191            AWDLTLV::SynchronizationParameters(_) => AWDLTLVType::SynchronizationParameters,
192            AWDLTLV::SynchronizationTree(_) => AWDLTLVType::SynchronizationTree,
193            AWDLTLV::Version(_) => AWDLTLVType::Version,
194            AWDLTLV::Unknown(raw_tlv) => AWDLTLVType::Unknown(raw_tlv.tlv_type),
195        }
196    }
197}
198macro_rules! measure_with_impls {
199    ($self:expr, $ctx:expr, $($path:ident),*) => {
200        match $self {
201            $(
202                Self::$path(inner) => inner.measure_with($ctx),
203            )*
204            Self::Unknown(raw_tlv) => raw_tlv.slice.len()
205        }
206    };
207}
208impl<'a, MACIterator, LabelIterator> MeasureWith<()> for AWDLTLV<'a, MACIterator, LabelIterator>
209where
210    MACIterator: ExactSizeIterator,
211    LabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone + Debug,
212{
213    fn measure_with(&self, ctx: &()) -> usize {
214        3 + measure_with_impls!(
215            self,
216            ctx,
217            ServiceResponse,
218            SynchronizationParameters,
219            ElectionParameters,
220            HTCapabilities,
221            DataPathState,
222            Arpa,
223            IEEE80211Container,
224            ChannelSequence,
225            SynchronizationTree,
226            Version,
227            ElectionParametersV2
228        )
229    }
230}
231macro_rules! read_impls {
232    ($self:expr, $raw_tlv:expr, $($path:ident),*) => {
233        match AWDLTLVType::from_bits($raw_tlv.tlv_type) {
234            $(
235                AWDLTLVType::$path => Self::$path($raw_tlv.slice.pread(0)?),
236            )*
237            AWDLTLVType::Unknown(tlv_type) => Self::Unknown(RawTLV {
238                tlv_type,
239                slice: $raw_tlv.slice,
240                _phantom: PhantomData,
241            }),
242            AWDLTLVType::Null => Self::Unknown(RawTLV {
243                tlv_type: 0,
244                slice: $raw_tlv.slice,
245                _phantom: PhantomData,
246            }),
247        }
248    };
249}
250impl<'a> TryFromCtx<'a> for AWDLTLV<'a, ReadMACIterator<'a>, ReadLabelIterator<'a>> {
251    type Error = scroll::Error;
252    fn try_from_ctx(from: &'a [u8], _ctx: ()) -> Result<(Self, usize), Self::Error> {
253        let (raw_tlv, len) =
254            <RawAWDLTLV<'a> as TryFromCtx<'a, Endian>>::try_from_ctx(from, Endian::Little)?;
255        Ok((
256            read_impls!(
257                self,
258                raw_tlv,
259                ServiceResponse,
260                SynchronizationParameters,
261                ElectionParameters,
262                HTCapabilities,
263                DataPathState,
264                Arpa,
265                IEEE80211Container,
266                ChannelSequence,
267                SynchronizationTree,
268                Version,
269                ElectionParametersV2
270            ),
271            len,
272        ))
273    }
274}
275macro_rules! write_impls {
276    ($self:expr, $buf:expr, $tlv_type:expr, $($path:ident),*) => {
277        match $self {
278            $(
279                Self::$path(payload) => $buf.pwrite_with(
280                    TypedAWDLTLV {
281                        tlv_type: $tlv_type,
282                        payload,
283                        _phantom: PhantomData,
284                    },
285                    0,
286                    Endian::Little,
287                ),
288            )*
289            Self::Unknown(tlv) => $buf.pwrite(tlv, 0)
290        }
291    };
292}
293impl<'a, MACIterator, LabelIterator> TryIntoCtx for AWDLTLV<'a, MACIterator, LabelIterator>
294where
295    LabelIterator: IntoIterator<Item = AWDLStr<'a>> + Clone,
296    <LabelIterator as IntoIterator>::IntoIter: Clone,
297    MACIterator: IntoIterator<Item = MACAddress> + ExactSizeIterator + Clone,
298{
299    type Error = scroll::Error;
300    fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
301        let tlv_type = self.get_type();
302        write_impls!(
303            self,
304            buf,
305            tlv_type,
306            ServiceResponse,
307            SynchronizationParameters,
308            ElectionParameters,
309            HTCapabilities,
310            DataPathState,
311            Arpa,
312            IEEE80211Container,
313            ChannelSequence,
314            SynchronizationTree,
315            Version,
316            ElectionParametersV2
317        )
318    }
319}
320
321/// Default [AWDLTLV] returned by reading.
322pub type DefaultAWDLTLV<'a> = AWDLTLV<'a, ReadMACIterator<'a>, ReadLabelIterator<'a>>;
323
/// A container for the TLVs in an action frame.
///
/// The bytes are kept as-is; TLVs are only parsed when requested through
/// [ReadTLVs::raw_tlv_iter], [ReadTLVs::get_tlvs] or [ReadTLVs::get_first_tlv].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ReadTLVs<'a> {
    /// The raw TLV bytes, beginning with the header of the first TLV.
    bytes: &'a [u8],
}
329impl<'a> ReadTLVs<'a> {
330    pub const fn new(bytes: &'a [u8]) -> Self {
331        Self { bytes }
332    }
333    /// Get an iterator over [RawAWDLTLV]'s.
334    pub fn raw_tlv_iter(&self) -> impl Iterator<Item = RawAWDLTLV<'a>> + '_ {
335        repeat(()).scan(0usize, |offset, _| {
336            self.bytes.gread::<RawAWDLTLV>(offset).ok()
337        })
338    }
339    /// Check if the TLV type matches and try to parse the TLV.
340    fn match_and_parse_tlv<Tlv: AwdlTlv + TryFromCtx<'a, Error = scroll::Error>>(
341        &self,
342        raw_tlv: RawAWDLTLV<'a>,
343    ) -> Option<Tlv> {
344        if raw_tlv.tlv_type == Tlv::TLV_TYPE.into_bits() {
345            raw_tlv.slice.pread::<Tlv>(0).ok()
346        } else {
347            None
348        }
349    }
350    /// Get an iterator over matching TLVs.
351    pub fn get_tlvs<Tlv: AwdlTlv + TryFromCtx<'a, Error = scroll::Error>>(
352        &self,
353    ) -> impl Iterator<Item = Tlv> + use<'_, 'a, Tlv> {
354        self.raw_tlv_iter()
355            .filter_map(|raw_tlv| self.match_and_parse_tlv(raw_tlv))
356    }
357    /// Get the first matching TLV.
358    pub fn get_first_tlv<Tlv: AwdlTlv + TryFromCtx<'a, Error = scroll::Error>>(
359        &self,
360    ) -> Option<Tlv> {
361        self.raw_tlv_iter()
362            .find_map(|raw_tlv| self.match_and_parse_tlv(raw_tlv))
363    }
364}
365impl MeasureWith<()> for ReadTLVs<'_> {
366    fn measure_with(&self, _ctx: &()) -> usize {
367        self.bytes.len()
368    }
369}
370impl TryIntoCtx<()> for ReadTLVs<'_> {
371    type Error = scroll::Error;
372    fn try_into_ctx(self, buf: &mut [u8], _ctx: ()) -> Result<usize, Self::Error> {
373        buf.pwrite(self.bytes, 0)
374    }
375}