bitfield_access/
lib.rs

1#![doc = include_str!("../README.md")]
2#![no_std]
3
4use core::{
5    fmt::{Debug, UpperHex},
6    ops::{Bound, RangeBounds},
7};
8
9use num_traits::{CheckedShr, PrimInt, Unsigned};
10
11#[inline]
12fn bitmask<T: PrimInt + Unsigned>(bit_width: usize) -> T {
13    let max_width = core::mem::size_of::<T>() * 8;
14    assert!(bit_width <= max_width);
15    if bit_width == max_width {
16        T::max_value()
17    } else {
18        T::from((1_usize << bit_width) - 1).unwrap()
19    }
20}
21
22pub trait BitfieldAccess: AsRef<[u8]> {
23    /// Read a bitfield with the given bit indices from a buffer.
24    ///
25    /// # Examples
26    ///
27    /// ```
28    /// use bitfield_access::BitfieldAccess;
29    ///
30    /// let buffer = [0x12, 0x34, 0x56, 0x78];
31    /// assert_eq!(buffer.read_field::<u8>(4..8), 0x2);
32    /// assert_eq!(buffer.read_field::<u16>(12..24), 0x456);
33    /// assert_eq!(buffer.read_field::<u8>(25..=25), 0x1);
34    /// ```
35    ///
36    /// # Panics
37    ///
38    /// Panics if the range of bits is wider than the integer type `T`
39    /// or the bit indices are out of bounds.
40    #[inline]
41    fn read_field<T>(&self, bitrange: impl RangeBounds<usize>) -> T
42    where
43        T: PrimInt + Unsigned,
44    {
45        // There's a lot of logic here, but as an inline function the bit range is
46        // typically known at compile time, reducing this to just a small handful
47        // of shifts and bitwise instructions.
48        let data = self.as_ref();
49        let start = match bitrange.start_bound() {
50            core::ops::Bound::Included(idx) => *idx,
51            core::ops::Bound::Excluded(idx) => *idx + 1,
52            core::ops::Bound::Unbounded => 0,
53        };
54        let end = match bitrange.end_bound() {
55            core::ops::Bound::Included(idx) => *idx + 1,
56            core::ops::Bound::Excluded(idx) => *idx,
57            core::ops::Bound::Unbounded => data.len() * 8,
58        };
59
60        let storage_width = 8 * core::mem::size_of::<T>();
61        let bit_width = end - start;
62        assert!(
63            bit_width <= storage_width,
64            "field width {} exceeds storage width {}",
65            bit_width,
66            storage_width
67        );
68        let first_byte = start / 8;
69        let last_byte = (end - 1) / 8;
70        let num_bytes = last_byte - first_byte + 1;
71        let offset = 7 - (end - 1) % 8;
72        let mask = bitmask(bit_width);
73
74        // build the result from the last byte (LSB) to the first
75        let mut result = T::from(data[last_byte] >> offset).unwrap();
76        for i in 1..num_bytes {
77            result = result | T::from(data[last_byte - i]).unwrap() << (8 * i - offset);
78        }
79
80        result & mask
81    }
82
83    /// Write a bitfield with the given bit indices to a buffer.
84    ///
85    /// # Examples
86    ///
87    /// ```
88    /// use bitfield_access::BitfieldAccess;
89    ///
90    /// let mut buffer = [0x12, 0x34, 0x56, 0x78];
91    /// buffer.write_field(4..8, 0xA_u8);
92    /// assert_eq!(buffer, [0x1A, 0x34, 0x56, 0x78]);
93    /// buffer.write_field(20..=27, 0xBC_u8);
94    /// assert_eq!(buffer, [0x1A, 0x34, 0x5B, 0xC8]);
95    /// ```
96    ///
97    /// # Panics
98    ///
99    /// Panics if the bit indices are out of bounds or the value is too large.
100    #[inline]
101    fn write_field<T>(&mut self, bitrange: impl RangeBounds<usize>, mut value: T)
102    where
103        Self: AsMut<[u8]>,
104        T: PrimInt + Unsigned + TryInto<u8> + UpperHex + CheckedShr,
105        <T as TryInto<u8>>::Error: Debug,
106    {
107        // There's a lot of logic here, but as an inline function the bit range is
108        // typically known at compile time, reducing this to just a small handful
109        // of shifts and bitwise instructions.
110        let data = self.as_mut();
111        let start = match bitrange.start_bound() {
112            Bound::Included(idx) => *idx,
113            Bound::Excluded(idx) => *idx + 1,
114            Bound::Unbounded => 0,
115        };
116        let mut end = match bitrange.end_bound() {
117            Bound::Included(idx) => *idx + 1,
118            Bound::Excluded(idx) => *idx,
119            Bound::Unbounded => data.len() * 8,
120        };
121        let first_byte = start / 8;
122        let last_byte = (end - 1) / 8;
123        let max_value = bitmask(end - start);
124        assert!(
125            value <= max_value,
126            "value {:#X} exceeds maximum field value {:#X}",
127            value,
128            max_value
129        );
130
131        let byte_mask = T::from(0xFF).unwrap();
132        let zero = T::from(0x0).unwrap();
133
134        // write in one-byte chunks, from the last (LSB) to the first
135        for i in (first_byte..=last_byte).rev() {
136            let bit_offset = 7 - (end - 1) % 8;
137            let bit_width = core::cmp::min(8 - bit_offset, end - start);
138            let bit_mask = bitmask::<u8>(bit_width) << bit_offset;
139            let new_bits: u8 = (value & byte_mask).try_into().unwrap();
140            data[i] = (data[i] & !bit_mask) | ((new_bits << bit_offset) & bit_mask);
141            end -= bit_width;
142            value = value.checked_shr(bit_width as u32).unwrap_or(zero);
143        }
144    }
145}
146
147impl<T> BitfieldAccess for T where T: AsRef<[u8]> {}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: four bytes with easily recognizable nibbles.
    const SAMPLE: [u8; 4] = [0x12, 0x34, 0x56, 0x78];

    #[test]
    #[inline(never)]
    fn test_read_field() {
        let bytes = SAMPLE;

        // Fields that sit entirely inside one byte.
        assert_eq!(bytes.read_field::<u8>(4..8), 0x2);
        assert_eq!(bytes.read_field::<u8>(8..16), 0x34);

        // A field that straddles byte boundaries.
        assert_eq!(bytes.read_field::<u16>(4..20), 0x2345);

        // An unbounded range covers the whole buffer.
        assert_eq!(bytes.read_field::<u32>(..), 0x12345678);

        // Single-bit fields, via exclusive and inclusive ranges.
        assert_eq!(bytes.read_field::<u8>(7..8), 0x0);
        assert_eq!(bytes.read_field::<u8>(17..=17), 0x1);
    }

    #[test]
    fn test_write_field() {
        // Fields inside one byte; the two writes accumulate on the same buffer.
        let mut bytes = SAMPLE;
        bytes.write_field::<u8>(4..8, 0xA);
        assert_eq!(bytes, [0x1A, 0x34, 0x56, 0x78]);
        bytes.write_field::<u8>(0..8, 0xBC);
        assert_eq!(bytes, [0xBC, 0x34, 0x56, 0x78]);

        // A field that straddles a byte boundary.
        let mut bytes = SAMPLE;
        bytes.write_field::<u8>(12..20, 0xBC);
        assert_eq!(bytes, [0x12, 0x3B, 0xC6, 0x78]);

        // An unbounded range overwrites the whole buffer.
        let mut bytes = SAMPLE;
        bytes.write_field::<u32>(.., 0x87654321u32);
        assert_eq!(bytes, [0x87, 0x65, 0x43, 0x21]);

        // Single-bit fields; these writes also accumulate.
        let mut bytes = SAMPLE;
        bytes.write_field::<u8>(7..8, 0x1);
        assert_eq!(bytes, [0x13, 0x34, 0x56, 0x78]);
        bytes.write_field::<u8>(8..=8, 0x1);
        assert_eq!(bytes, [0x13, 0xB4, 0x56, 0x78]);
        bytes.write_field::<u8>(30..31, 0x1);
        assert_eq!(bytes, [0x13, 0xB4, 0x56, 0x7A]);
    }
}