Skip to main content

irox_bits/
bitstream.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2025 IROX Contributors
3//
4
5use crate::{Bits, BitsError, BitsErrorKind, BitsWrapper, MutBits};
6use core::cmp::Ordering;
7
8pub struct BitStreamEncoder<'a, T: MutBits> {
9    delegate: BitsWrapper<'a, T>,
10    buf: u32,
11    remaining: u8,
12}
13impl<T: MutBits> Drop for BitStreamEncoder<'_, T> {
14    fn drop(&mut self) {
15        let [a, b, c, d] = self.buf.to_be_bytes();
16        if self.remaining < 8 {
17            // write 4
18            let _ = self.delegate.write_all_bytes(&[a, b, c, d]);
19        } else if self.remaining < 16 {
20            // write 3
21            let _ = self.delegate.write_all_bytes(&[a, b, c]);
22        } else if self.remaining < 24 {
23            // write 2
24            let _ = self.delegate.write_all_bytes(&[a, b]);
25        } else if self.remaining < 32 {
26            // write 1
27            let _ = self.delegate.write_u8(a);
28        }
29    }
30}
31impl<'a, T: MutBits> BitStreamEncoder<'a, T> {
32    pub fn new(delegate: BitsWrapper<'a, T>) -> Self {
33        Self {
34            delegate,
35            buf: 0u32,
36            remaining: 32u8,
37        }
38    }
39    pub fn write_u8_bits(&mut self, val: u8, num_bits: u8) -> Result<(), BitsError> {
40        self.write_u32_bits(val as u32, num_bits)
41    }
42    pub fn write_u16_bits(&mut self, val: u16, num_bits: u8) -> Result<(), BitsError> {
43        self.write_u32_bits(val as u32, num_bits)
44    }
45    pub fn write_u32_bits(&mut self, val: u32, mut num_bits: u8) -> Result<(), BitsError> {
46        if num_bits > 32 {
47            return Err(BitsErrorKind::InvalidInput.into());
48        }
49        while num_bits > 0 {
50            match num_bits.cmp(&self.remaining) {
51                Ordering::Less => {
52                    let shift = self.remaining - num_bits;
53                    let mask = (1u32 << num_bits) - 1;
54                    self.buf |= (val & mask) << shift;
55                    self.remaining -= num_bits;
56                    num_bits = 0;
57                }
58                Ordering::Equal => {
59                    let mask = (1u32 << num_bits) - 1;
60                    self.buf |= val & mask;
61                    num_bits = 0;
62                    self.delegate.write_be_u32(self.buf)?;
63                    self.remaining = 32;
64                    self.buf = 0;
65                }
66                Ordering::Greater => {
67                    let touse = self.remaining;
68                    let shift = num_bits - self.remaining;
69                    let mask = (1u32 << touse) - 1;
70                    self.buf |= (val >> shift) & mask;
71                    self.delegate.write_be_u32(self.buf)?;
72                    self.remaining = 32;
73                    self.buf = 0;
74                    num_bits -= touse;
75                }
76            }
77        }
78        Ok(())
79    }
80}
81
82pub struct BitStreamDecoder<'a, T: Bits> {
83    delegate: BitsWrapper<'a, T>,
84    buf: u32,
85    used: u8,
86}
87impl<'a, T: Bits> BitStreamDecoder<'a, T> {
88    pub fn new(delegate: BitsWrapper<'a, T>) -> Self {
89        Self {
90            delegate,
91            buf: 0,
92            used: 0,
93        }
94    }
95    pub fn read_u32_bits(&mut self, num_bits: u8) -> Result<u32, BitsError> {
96        if num_bits > 32 {
97            return Err(BitsErrorKind::InvalidInput.into());
98        }
99        loop {
100            match self.used.cmp(&num_bits) {
101                Ordering::Less => {
102                    // used < numbits - add more.
103                    let v = self.delegate.read_u8()?;
104                    self.buf = (self.buf << 8) | v as u32;
105                    self.used += 8;
106                }
107                Ordering::Equal => {
108                    let mask = (1u32 << num_bits) - 1;
109                    self.used = 0;
110                    let b = self.buf & mask;
111                    self.buf = 0;
112                    return Ok(b);
113                }
114                Ordering::Greater => {
115                    let rem = self.used - num_bits;
116                    let mask = (1u32 << num_bits) - 1;
117                    let b = (self.buf >> rem) & mask;
118                    self.used -= num_bits;
119                    return Ok(b);
120                }
121            }
122        }
123    }
124    pub fn read_le_u32_bits(&mut self, num_bits: u8) -> Result<u32, BitsError> {
125        if num_bits > 32 {
126            return Err(BitsErrorKind::InvalidInput.into());
127        }
128        loop {
129            match self.used.cmp(&num_bits) {
130                Ordering::Less => {
131                    // used < numbits - add more.
132                    let v = self.delegate.read_u8()?;
133                    self.buf |= (v as u32) << self.used;
134                    self.used += 8;
135                }
136                Ordering::Equal => {
137                    let mask = (1u32 << num_bits) - 1;
138                    self.used = 0;
139                    let b = self.buf & mask;
140                    self.buf = 0;
141                    return Ok(b);
142                }
143                Ordering::Greater => {
144                    let mask = (1u32 << num_bits) - 1;
145                    let b = self.buf & mask;
146                    self.buf >>= num_bits;
147                    self.used -= num_bits;
148                    return Ok(b);
149                }
150            }
151        }
152    }
153    pub fn peek_le_u32_bits(&mut self, num_bits: u8) -> Result<u32, BitsError> {
154        if num_bits > 32 {
155            return Err(BitsErrorKind::InvalidInput.into());
156        }
157        loop {
158            match self.used.cmp(&num_bits) {
159                Ordering::Less => {
160                    // used < numbits - add more.
161                    let v = self.delegate.read_u8()?;
162                    self.buf |= (v as u32) << self.used;
163                    self.used += 8;
164                }
165                _ => {
166                    let mask = (1u32 << num_bits) - 1;
167                    let b = self.buf & mask;
168                    return Ok(b);
169                }
170            }
171        }
172    }
173    pub fn delegate(&mut self) -> &mut BitsWrapper<'a, T> {
174        self.buf = 0;
175        self.used = 0;
176        &mut self.delegate
177    }
178}
#[cfg(all(test, feature = "std"))]
mod test {
    use crate::{BitStreamDecoder, BitStreamEncoder, BitsError, BitsWrapper};

    /// Nibble-sized reads followed by a 16-bit read, all byte-aligned overall.
    #[test]
    pub fn test_dec() -> Result<(), BitsError> {
        let input: Vec<u8> = vec![0xAB, 0xCD, 0xAB, 0xCD];
        let mut dec = BitStreamDecoder::new(BitsWrapper::Owned(input));
        for &nibble in &[0xAu32, 0xB, 0xC, 0xD] {
            assert_eq!(nibble, dec.read_u32_bits(4)?);
        }
        assert_eq!(0xABCD, dec.read_u32_bits(16)?);
        Ok(())
    }

    /// Seven 9-bit codes that straddle byte boundaries.
    #[test]
    pub fn test_dec2() -> Result<(), BitsError> {
        let input: Vec<u8> = vec![0x03, 0xC0, 0x81, 0x00, 0x88, 0x10, 0x1A, 0x02];
        let mut dec = BitStreamDecoder::new(BitsWrapper::Owned(input));
        for &code in &[7u32, 258, 8, 8, 258, 6, 257] {
            assert_eq!(code, dec.read_u32_bits(9)?);
        }

        Ok(())
    }

    /// Mirror of `test_dec`: only the low `num_bits` of each value are emitted.
    #[test]
    pub fn test_enc() -> Result<(), BitsError> {
        let mut out = Vec::<u8>::new();
        {
            let mut enc = BitStreamEncoder::new(BitsWrapper::Borrowed(&mut out));
            let fields: [(u16, u8); 5] = [
                (0xAAAA, 4),
                (0xBBBB, 4),
                (0xCCCC, 4),
                (0xDDDD, 4),
                (0xABCD, 16),
            ];
            for &(val, bits) in &fields {
                enc.write_u16_bits(val, bits)?;
            }
            // `enc` is dropped here, which flushes anything still buffered.
        }
        assert_eq!(out, [0xAB, 0xCD, 0xAB, 0xCD]);
        Ok(())
    }

    /// Mirror of `test_dec2`: 63 bits are written, so the final byte is padded
    /// with a single zero bit when the encoder is dropped.
    #[test]
    pub fn test_enc2() -> Result<(), BitsError> {
        let mut out = Vec::<u8>::new();
        {
            let mut enc = BitStreamEncoder::new(BitsWrapper::Borrowed(&mut out));
            for &code in &[7u16, 258, 8, 8, 258, 6, 257] {
                enc.write_u16_bits(code, 9)?;
            }
            // `enc` is dropped here, which flushes the partially filled word.
        }
        assert_eq!(out, [0x03, 0xC0, 0x81, 0x00, 0x88, 0x10, 0x1A, 0x02]);
        // Layout of the seven 9-bit codes within the output (one trailing pad bit):
        // 0x007  0b000000111
        // 0x102             100000010
        // 0x008                      000001000
        // 0x008                               000001000
        // 0x102                                        100000010
        // 0x006                                                 000000110
        // 0x101                                                          100000001
        // 0x03 = 0b00000011
        // 0xC0 =           11000000
        // 0x81 =                   10000001
        // 0x00 =                           00000000
        // 0x88 =                                   10001000
        // 0x10 =                                           00010000
        // 0x1A =                                                   00011010
        // 0x02 =                                                           00000010

        Ok(())
    }
}