Skip to main content

fray/
bitfields.rs

1use crate::{BitContainer, BitContainerFor};
2
3mod fields;
4pub use fields::{Field, FieldType};
5pub mod bitorder;
6use bitorder::BitOrder;
7mod private;
8
9/// Provides all the convenience methods for interacting with a **bitfield structure**.
10///
11/// This trait is:
12/// - **sealed** (cannot be implemented manually),
13/// - intended to be obtained via the `bitfield` macro,
14/// - the main API end-users will work with.
15///
16/// # Core methods
17///
18/// - [`new`](Self::new):
19///   Construct a new instance with an empty [`BitContainer`].
20///
21/// - [`into_inner`](Self::into_inner):
22///   Consume `self` and return the underlying raw **inner value**
23///   ([`BitFieldImpl::Container::Inner`](BitContainer::Inner)).
24///
25/// - [`get`](Self::get) / [`try_get`](Self::try_get):
26///   Read the value of a field.
27///
28/// - [`set`](Self::set) / [`try_set`](Self::try_set):
29///   Write the value of a field.
30///
31/// - [`with`](Self::with) / [`try_with`](Self::try_with):
32///   Write the value of a field and return `&mut Self` for chaining.
33///
34/// # How it works
35///
36/// Internally, the macro implements [`BitFieldImpl`] for your type, and
37/// [`BitField`] is automatically implemented via a generic blanket impl
38/// over all types that implement [`BitFieldImpl`].
39pub trait BitField: private::Sealed + BitFieldImpl {
40    ///   Consume `self` and return the underlying raw **inner value**
41    ///   ([`BitField::Container::Inner`](BitContainer::Inner)).
42    #[inline]
43    fn into_inner(self) -> <Self::Container as BitContainer>::Inner {
44        self.into().into_inner()
45    }
46
47    /// Construct a new instance with an empty [`BitContainer`].
48    #[inline]
49    fn new() -> Self {
50        Self::Container::empty().into()
51    }
52
53    #[doc(hidden)]
54    #[inline]
55    fn _set<F>(&mut self, value: F::BitsType, _: private::Token)
56    where
57        F: Field<Self>,
58        Self::Container: BitContainerFor<F::BitsType>,
59    {
60        debug_assert!(F::OFFSET + F::SIZE <= Self::Container::SIZE);
61        self.as_mut().store(value, F::OFFSET, F::SIZE);
62    }
63
64    #[doc(hidden)]
65    #[inline]
66    fn _get<F>(&self, _: private::Token) -> F::BitsType
67    where
68        F: Field<Self>,
69        Self::Container: BitContainerFor<F::BitsType>,
70    {
71        debug_assert!(F::OFFSET + F::SIZE <= Self::Container::SIZE);
72        self.as_ref().retrieve(F::OFFSET, F::SIZE)
73    }
74
75    /// Read the value of field `F`. Returns `F::Type` directly.
76    ///
77    /// Requires an infallible conversion from `F::BitsType` to `F::Type`.
78    #[inline]
79    fn get<F>(&self) -> F::Type
80    where
81        F: Field<Self>,
82        Self::Container: BitContainerFor<F::BitsType>,
83        F::BitsType: Into<F::Type>,
84    {
85        self._get::<F>(private::Token).into()
86    }
87
88    /// Read the value of field `F`.
89    ///
90    /// Returns a `Result` containing:
91    /// - `F::Type` if the conversion succeeds,
92    /// - an error (`TryInto<F::Type>::Error`) if it fails.
93    #[inline]
94    fn try_get<F>(&self) -> Result<F::Type, <F::BitsType as TryInto<F::Type>>::Error>
95    where
96        F: Field<Self>,
97        Self::Container: BitContainerFor<F::BitsType>,
98        F::BitsType: TryInto<F::Type>,
99    {
100        self._get::<F>(private::Token).try_into()
101    }
102
103    /// Write the value of field `F`.
104    ///
105    /// Requires an infallible conversion from `F::Type` to `F::BitsType`.
106    #[inline]
107    fn set<F>(&mut self, value: F::Type)
108    where
109        F: Field<Self>,
110        Self::Container: BitContainerFor<F::BitsType>,
111        F::Type: Into<F::BitsType>,
112    {
113        self._set::<F>(value.into(), private::Token);
114    }
115
116    /// Write the value of field `F`.
117    ///
118    /// Returns a `Result` containing:
119    /// - `()` if the conversion succeeds,
120    /// - an error (`TryInto<F::BitsType>::Error`) if it fails.
121    #[inline]
122    fn try_set<F>(&mut self, value: F::Type) -> Result<(), <F::Type as TryInto<F::BitsType>>::Error>
123    where
124        F: Field<Self>,
125        Self::Container: BitContainerFor<F::BitsType>,
126        F::Type: TryInto<F::BitsType>,
127    {
128        self._set::<F>(value.try_into()?, private::Token);
129        Ok(())
130    }
131
132    /// Write the value of field `F` and return `&mut Self` for chaining.
133    ///
134    /// Requires an infallible conversion from `F::Type` to `F::BitsType`.
135    #[inline]
136    fn with<F>(&mut self, value: F::Type) -> &mut Self
137    where
138        F: Field<Self>,
139        Self::Container: BitContainerFor<F::BitsType>,
140        F::Type: Into<F::BitsType>,
141    {
142        self.set::<F>(value);
143        self
144    }
145
146    /// Write the value of field `F`.
147    ///
148    /// Returns a `Result` containing:
149    /// - `&mut Self` if the conversion succeeds allowing chaning,
150    /// - an error (`TryInto<F::BitsType>::Error`) if it fails.
151    #[inline]
152    fn try_with<F>(
153        &mut self,
154        value: F::Type,
155    ) -> Result<&mut Self, <F::Type as TryInto<F::BitsType>>::Error>
156    where
157        F: Field<Self>,
158        Self::Container: BitContainerFor<F::BitsType>,
159        F::Type: TryInto<F::BitsType>,
160    {
161        self.try_set::<F>(value)?;
162        Ok(self)
163    }
164}
165
166impl<T: BitFieldImpl + private::Sealed> BitField for T {}
167
168/// Internal trait required by [`BitField`].
169///
170/// `BitFieldImpl` is **not meant to be implemented manually**.
171/// It is automatically implemented by the `bitfield` macro.
172///
173/// # Usage
174///
175/// End-users generally do **not** interact with `BitFieldImpl` directly.
176/// Instead, they use the convenience methods provided by [`BitField`].
177pub trait BitFieldImpl
178where
179    Self::Container: Into<Self>,
180    Self: Into<Self::Container>,
181    Self: AsRef<Self::Container>,
182    Self: AsMut<Self::Container>,
183{
184    /// The underlying storage type for the bitfield.
185    ///
186    /// This type defines where the raw bits are stored.
187    /// It must implement [`BitContainer`].
188    type Container: BitContainer;
189
190    /// Indicates the bit numbering order used by the bitfield.
191    ///
192    /// This associated type has no functional impact; it exists purely for
193    /// **semantic clarity**.
194    ///
195    /// Set automatically by the [`bitorder`](crate::bitfield) attribute
196    /// (`LSB0` by default).
197    type BitOrder: BitOrder;
198}
199
200impl<T: BitFieldImpl> private::Sealed for T {}
201
#[cfg(test)]
mod tests {
    use crate::iterable::BitIterableContainer;

    use super::*;

    // This module is a hand-written stand-in for macro output:
    // - not every generated item is exercised by the tests (dead_code);
    // - the DNS field names are kept upper-case as in RFC 1035
    //   (upper_case_acronyms);
    // - the uninhabited field marker types do not implement `Debug`
    //   (missing_debug_implementations).
    #[allow(dead_code)]
    #[allow(clippy::upper_case_acronyms)]
    #[allow(missing_debug_implementations)]
    mod dns_flags {
        use core::fmt::Debug;

        use super::*;
        use crate::{bitorder::LSB0, iterable::BitIterableContainer};

        // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1
        //
        // Hand-written equivalent of roughly the following declaration.
        // Offsets are assigned in declaration order starting at bit 0 (QR at
        // offset 0), and `BitOrder` is `LSB0`. This differs from the RFC's
        // on-the-wire layout, where QR is the most-significant bit; the
        // struct only exists to exercise the `BitField` API.
        //
        // #[bitfield(repr(u16))] // bitorder defaults to LSB0
        // pub struct DNSFlags {
        //     QR: bool,
        //     OPCODE: OpCode,
        //     AA: bool,
        //     TC: bool,
        //     RD: bool,
        //     RA: bool,
        //     #[bits(3)]
        //     _z: (),
        //     RCODE: Rcode,
        // }
        pub mod fields {
            macro_rules! create_field {
                ($name:ident, $bf:ty, $t: ty, $bit_t:ty, $offset:expr) => {
                    pub enum $name {}

                    impl crate::Field<$bf> for $name {
                        type Type = $t;
                        type BitsType = $bit_t;
                        const OFFSET: usize = $offset;
                    }
                };
            }
            create_field!(QR, super::DNSFlags, bool, bool, 0);
            create_field!(OPCODE, super::DNSFlags, super::OpCode, u8, 1);
            create_field!(AA, super::DNSFlags, bool, bool, 5);
            create_field!(TC, super::DNSFlags, bool, bool, 6);
            create_field!(RD, super::DNSFlags, bool, bool, 7);
            create_field!(RA, super::DNSFlags, bool, bool, 8);
            // offset = 8 + 1 (RA size) + 3 (zeroes)
            create_field!(RCODE, super::DNSFlags, super::Rcode, u8, 12);
        }

        /// 4-bit query kind; only the codes defined by RFC 1035 are modelled,
        /// any other value is rejected by `TryFrom<u8>`.
        #[repr(u8)]
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum OpCode {
            Query = 0,
            IQuery = 1,
            Status = 2,
        }
        impl FieldType for OpCode {
            const SIZE: usize = 4;

            type BitsType = u8;
        }
        impl TryFrom<u8> for OpCode {
            type Error = ();

            fn try_from(value: u8) -> Result<Self, Self::Error> {
                Ok(match value {
                    0 => Self::Query,
                    1 => Self::IQuery,
                    2 => Self::Status,
                    _ => return Err(()),
                })
            }
        }

        impl From<OpCode> for u8 {
            fn from(value: OpCode) -> Self {
                value as u8
            }
        }

        /// 4-bit response code; only the codes defined by RFC 1035 are
        /// modelled, any other value is rejected by `TryFrom<u8>`.
        #[repr(u8)]
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub enum Rcode {
            NoError = 0,
            FormatError = 1,
            ServerFailure = 2,
            NameError = 3,
            NotImplemented = 4,
            Refused = 5,
        }
        impl FieldType for Rcode {
            const SIZE: usize = 4;

            type BitsType = u8;
        }

        impl TryFrom<u8> for Rcode {
            type Error = ();

            fn try_from(value: u8) -> Result<Self, Self::Error> {
                Ok(match value {
                    0 => Self::NoError,
                    1 => Self::FormatError,
                    2 => Self::ServerFailure,
                    3 => Self::NameError,
                    4 => Self::NotImplemented,
                    5 => Self::Refused,
                    _ => return Err(()),
                })
            }
        }

        // Mirrors `From<OpCode> for u8` so `set`/`with` work for RCODE too.
        impl From<Rcode> for u8 {
            fn from(value: Rcode) -> Self {
                value as u8
            }
        }

        #[derive(Clone, Copy)]
        pub struct DNSFlags(BitIterableContainer<u16>);

        impl From<DNSFlags> for <<DNSFlags as BitFieldImpl>::Container as BitContainer>::Inner {
            fn from(value: DNSFlags) -> Self {
                value.into_inner()
            }
        }

        impl From<BitIterableContainer<u16>> for DNSFlags {
            fn from(value: BitIterableContainer<u16>) -> Self {
                Self(value)
            }
        }

        impl From<DNSFlags> for BitIterableContainer<u16> {
            fn from(value: DNSFlags) -> Self {
                value.0
            }
        }

        impl AsRef<BitIterableContainer<u16>> for DNSFlags {
            fn as_ref(&self) -> &BitIterableContainer<u16> {
                &self.0
            }
        }

        impl AsMut<BitIterableContainer<u16>> for DNSFlags {
            fn as_mut(&mut self) -> &mut BitIterableContainer<u16> {
                &mut self.0
            }
        }

        impl BitFieldImpl for DNSFlags {
            type Container = BitIterableContainer<u16>;
            type BitOrder = LSB0;
        }

        impl Debug for DNSFlags {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                use crate::debug::PrettyResult;
                f.debug_struct("DNSFlags")
                    .field("QR", &PrettyResult::from(&self.try_get::<QR>()))
                    .field("OPCODE", &PrettyResult::from(&self.try_get::<OPCODE>()))
                    .field("AA", &PrettyResult::from(&self.try_get::<AA>()))
                    .field("TC", &PrettyResult::from(&self.try_get::<TC>()))
                    .field("RD", &PrettyResult::from(&self.try_get::<RD>()))
                    .field("RA", &PrettyResult::from(&self.try_get::<RA>()))
                    .field("RCODE", &PrettyResult::from(&self.try_get::<RCODE>()))
                    .finish()
            }
        }
    }
    use dns_flags::{DNSFlags, OpCode, Rcode, fields::*};

    #[test]
    fn set_bool_basic() {
        let mut dns_flags = DNSFlags::new();

        dns_flags.set::<QR>(true);
        assert_eq!(dns_flags.into_inner(), 0b1u16);

        dns_flags.set::<QR>(true);
        assert_eq!(dns_flags.into_inner(), 0b1u16);

        dns_flags.set::<QR>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);

        dns_flags.set::<QR>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);

        dns_flags.set::<RA>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_0000_0000u16);

        dns_flags.set::<RA>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_0000_0000u16);

        dns_flags.set::<RA>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);

        dns_flags.set::<RA>(false);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
    }

    #[test]
    fn set_bool_dont_overlap() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<TC>(true);
        dns_flags.set::<RD>(true);
        dns_flags.set::<RA>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_1100_0000u16);

        dns_flags.set::<RD>(false);
        assert_eq!(dns_flags.into_inner(), 0b1_0100_0000u16);

        dns_flags.set::<RD>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_1100_0000u16);
    }

    #[test]
    fn get_bool() {
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b1u16));
        assert!(dns_flags.get::<QR>());

        let dns_flags = DNSFlags::new();
        assert!(!dns_flags.get::<RA>());
        assert!(!dns_flags.get::<QR>());

        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b1_0000_0000u16));
        assert!(dns_flags.get::<RA>());
    }

    #[test]
    fn set_custom_type() {
        let mut dns_flags = DNSFlags::new();

        dns_flags.set::<OPCODE>(OpCode::IQuery);
        assert_eq!(dns_flags.into_inner(), 0b10u16);

        dns_flags.set::<OPCODE>(OpCode::Status);
        assert_eq!(dns_flags.into_inner(), 0b100u16);

        dns_flags.set::<OPCODE>(OpCode::Query);
        assert_eq!(dns_flags.into_inner(), 0b0u16);
    }

    #[test]
    fn set_custom_type_preserves_neighbours() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<QR>(true);
        dns_flags.set::<AA>(true);
        dns_flags.set::<OPCODE>(OpCode::Status);
        assert_eq!(dns_flags.into_inner(), 0b10_0101u16);

        // Clearing the multi-bit field must not disturb QR (bit 0) or AA (bit 5).
        dns_flags.set::<OPCODE>(OpCode::Query);
        assert_eq!(dns_flags.into_inner(), 0b10_0001u16);
        assert!(dns_flags.get::<QR>());
        assert!(dns_flags.get::<AA>());
    }

    #[test]
    fn get_custom_type() {
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b10u16));
        assert_eq!(dns_flags.try_get::<OPCODE>(), Ok(OpCode::IQuery));

        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b100u16));
        assert_eq!(dns_flags.try_get::<OPCODE>(), Ok(OpCode::Status));

        let dns_flags = DNSFlags::new();
        assert_eq!(dns_flags.try_get::<OPCODE>(), Ok(OpCode::Query));
    }

    #[test]
    fn get_custom_type_rejects_unknown_codes() {
        // OPCODE bits (1..=4) hold 3, which `OpCode` does not model.
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b110u16));
        assert_eq!(dns_flags.try_get::<OPCODE>(), Err(()));

        // RCODE bits (12..=15) hold 6, which `Rcode` does not model.
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b0110_0000_0000_0000u16));
        assert_eq!(dns_flags.try_get::<RCODE>(), Err(()));
    }

    #[test]
    fn rcode_round_trip() {
        let dns_flags = DNSFlags::from(BitIterableContainer::from(0b0011_0000_0000_0000u16));
        assert_eq!(dns_flags.try_get::<RCODE>(), Ok(Rcode::NameError));

        let mut dns_flags = DNSFlags::new();
        dns_flags.set::<RCODE>(Rcode::Refused);
        assert_eq!(dns_flags.into_inner(), 0b0101_0000_0000_0000u16);
        assert_eq!(dns_flags.try_get::<RCODE>(), Ok(Rcode::Refused));
    }

    #[test]
    fn with_chains_writes() {
        let mut dns_flags = DNSFlags::new();
        dns_flags.with::<QR>(true).with::<RA>(true).with::<RD>(true);
        assert_eq!(dns_flags.into_inner(), 0b1_1000_0001u16);
    }

    #[test]
    fn try_set_and_try_with_custom_type() {
        let mut dns_flags = DNSFlags::new();
        assert!(dns_flags.try_set::<OPCODE>(OpCode::Status).is_ok());
        assert_eq!(dns_flags.into_inner(), 0b100u16);

        dns_flags
            .try_with::<OPCODE>(OpCode::IQuery)
            .expect("OpCode -> u8 conversion is infallible")
            .set::<QR>(true);
        assert_eq!(dns_flags.into_inner(), 0b11u16);
    }
}