Skip to main content

oximedia_bitstream/write/
counter.rs

1// Copyright 2017 Brian Langenberger
2// Copyright 2024-2026 COOLJAPAN OU (Team Kitasan)
3//
4// Licensed under the Apache License, Version 2.0 or the MIT license,
5// at your option. See the LICENSE-APACHE / LICENSE-MIT files for details.
6
7//! Bit-counting writers — `Overflowed`, the `Counter` trait, and the
8//! `BitsWritten` / deprecated `BitCounter` accumulators.
9
10use core::{convert::TryFrom, fmt};
11use std::io;
12
13use super::{
14    BitCount, BitWrite, Checkable, Endianness, Numeric, PhantomData, Primitive, SignedBitCount,
15    SignedInteger, UnsignedInteger,
16};
17
/// An error returned if performing math operations would overflow
#[derive(Copy, Clone, Debug)]
pub struct Overflowed;

impl fmt::Display for Overflowed {
    /// Writes the fixed overflow message to `f`.
    ///
    /// The message is formatted through `str`'s own `Display`
    /// implementation, so width/fill/alignment flags on the formatter
    /// are honored.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Fully-qualified call: works without a `Display` import in scope.
        fmt::Display::fmt("overflow occurred in counter", f)
    }
}

impl core::error::Error for Overflowed {}
29
30impl From<Overflowed> for io::Error {
31    fn from(Overflowed: Overflowed) -> Self {
32        io::Error::new(
33            #[cfg(feature = "std")]
34            {
35                io::ErrorKind::StorageFull
36            },
37            #[cfg(not(feature = "std"))]
38            {
39                io::ErrorKind::Other
40            },
41            "bitstream accumulator overflow",
42        )
43    }
44}
45
/// A common trait for integer types for performing math operations
/// which may check for overflow.
///
/// Implementors must be constructible from a `u8` infallibly and from
/// `u32` / `usize` fallibly, as required by the supertrait bounds.
pub trait Counter: Default + Sized + From<u8> + TryFrom<u32> + TryFrom<usize> {
    /// Adds `rhs` to `self` in place, returning `Overflowed` if the
    /// result is too large
    fn checked_add_assign(&mut self, rhs: Self) -> Result<(), Overflowed>;

    /// Multiplies `self` by `rhs`, returning the product or
    /// `Overflowed` if the result is too large
    fn checked_mul(self, rhs: Self) -> Result<Self, Overflowed>;

    /// Returns `true` if the number of bits written is divisible by 8
    fn byte_aligned(&self) -> bool;
}
58
59macro_rules! define_counter {
60    ($t:ty) => {
61        impl Counter for $t {
62            fn checked_add_assign(&mut self, rhs: Self) -> Result<(), Overflowed> {
63                *self = <$t>::checked_add(*self, rhs).ok_or(Overflowed)?;
64                Ok(())
65            }
66
67            fn checked_mul(self, rhs: Self) -> Result<Self, Overflowed> {
68                <$t>::checked_mul(self, rhs).ok_or(Overflowed)
69            }
70
71            fn byte_aligned(&self) -> bool {
72                self % 8 == 0
73            }
74        }
75    };
76}
77
78define_counter!(u8);
79define_counter!(u16);
80define_counter!(u32);
81define_counter!(u64);
82define_counter!(u128);
83
/// A writer that generates no output and only keeps a running total
/// of the number of bits written to it.
///
/// # Example
/// ```
/// use oximedia_bitstream::{BitWrite, BitsWritten};
/// let mut writer: BitsWritten<u32> = BitsWritten::new();
/// writer.write_var(1, 0b1u8).unwrap();
/// writer.write_var(2, 0b01u8).unwrap();
/// writer.write_var(5, 0b10111u8).unwrap();
/// assert_eq!(writer.written(), 8);
/// ```
#[derive(Default)]
pub struct BitsWritten<N> {
    // running total of bits written so far
    bits: N,
}

impl<N: Default> BitsWritten<N> {
    /// Creates a new `BitsWritten` whose total starts at `N`'s default value
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }
}

impl<N: Copy> BitsWritten<N> {
    /// Returns a copy of the number of bits written so far
    #[inline]
    pub fn written(&self) -> N {
        let Self { bits } = self;
        *bits
    }
}

impl<N> BitsWritten<N> {
    /// Consumes the counter and returns the number of bits written
    #[inline]
    pub fn into_written(self) -> N {
        let Self { bits } = self;
        bits
    }
}
123
124impl<N: Counter> BitWrite for BitsWritten<N> {
125    #[inline]
126    fn write_bit(&mut self, _bit: bool) -> io::Result<()> {
127        self.bits.checked_add_assign(1u8.into())?;
128        Ok(())
129    }
130
131    #[inline]
132    fn write_const<const BITS: u32, const VALUE: u32>(&mut self) -> io::Result<()> {
133        const {
134            assert!(
135                BITS == 0 || VALUE <= (u32::ALL >> (u32::BITS_SIZE - BITS)),
136                "excessive value for bits written"
137            );
138        }
139
140        self.bits
141            .checked_add_assign(BITS.try_into().map_err(|_| Overflowed)?)?;
142        Ok(())
143    }
144
145    #[inline]
146    fn write_unsigned<const BITS: u32, U>(&mut self, value: U) -> io::Result<()>
147    where
148        U: UnsignedInteger,
149    {
150        const {
151            assert!(BITS <= U::BITS_SIZE, "excessive bits for type written");
152        }
153
154        if BITS == 0 {
155            Ok(())
156        } else if value <= (U::ALL >> (U::BITS_SIZE - BITS)) {
157            self.bits
158                .checked_add_assign(BITS.try_into().map_err(|_| Overflowed)?)?;
159            Ok(())
160        } else {
161            Err(io::Error::new(
162                io::ErrorKind::InvalidInput,
163                "excessive value for bits written",
164            ))
165        }
166    }
167
168    #[inline]
169    fn write_signed<const BITS: u32, S>(&mut self, value: S) -> io::Result<()>
170    where
171        S: SignedInteger,
172    {
173        let SignedBitCount {
174            bits: BitCount { bits },
175            unsigned,
176        } = const {
177            assert!(BITS <= S::BITS_SIZE, "excessive bits for type written");
178            let count = BitCount::<BITS>::new::<BITS>().signed_count();
179            match count {
180                Some(c) => c,
181                None => panic!("signed writes need at least 1 bit for sign"),
182            }
183        };
184
185        // doesn't matter which side the sign is on
186        // so long as it's added to the bit count
187        self.bits.checked_add_assign(1u8.into())?;
188
189        self.write_unsigned_counted(
190            unsigned,
191            if value.is_negative() {
192                value.as_negative(bits)
193            } else {
194                value.as_non_negative()
195            },
196        )
197    }
198
199    #[inline]
200    fn write_unsigned_counted<const MAX: u32, U>(
201        &mut self,
202        BitCount { bits }: BitCount<MAX>,
203        value: U,
204    ) -> io::Result<()>
205    where
206        U: UnsignedInteger,
207    {
208        if MAX <= U::BITS_SIZE || bits <= U::BITS_SIZE {
209            if bits == 0 {
210                Ok(())
211            } else if value <= U::ALL >> (U::BITS_SIZE - bits) {
212                self.bits
213                    .checked_add_assign(bits.try_into().map_err(|_| Overflowed)?)?;
214                Ok(())
215            } else {
216                Err(io::Error::new(
217                    io::ErrorKind::InvalidInput,
218                    "excessive value for bits written",
219                ))
220            }
221        } else {
222            Err(io::Error::new(
223                io::ErrorKind::InvalidInput,
224                "excessive bits for type written",
225            ))
226        }
227    }
228
229    #[inline]
230    fn write_signed_counted<const MAX: u32, S>(
231        &mut self,
232        bits: impl TryInto<SignedBitCount<MAX>>,
233        value: S,
234    ) -> io::Result<()>
235    where
236        S: SignedInteger,
237    {
238        let SignedBitCount {
239            bits: BitCount { bits },
240            unsigned,
241        } = bits.try_into().map_err(|_| {
242            io::Error::new(
243                io::ErrorKind::InvalidInput,
244                "signed writes need at least 1 bit for sign",
245            )
246        })?;
247
248        if MAX <= S::BITS_SIZE || bits <= S::BITS_SIZE {
249            // doesn't matter which side the sign is on
250            // so long as it's added to the bit count
251            self.bits.checked_add_assign(1u8.into())?;
252
253            self.write_unsigned_counted(
254                unsigned,
255                if value.is_negative() {
256                    value.as_negative(bits)
257                } else {
258                    value.as_non_negative()
259                },
260            )
261        } else {
262            Err(io::Error::new(
263                io::ErrorKind::InvalidInput,
264                "excessive bits for type written",
265            ))
266        }
267    }
268
269    #[inline]
270    fn write_from<V>(&mut self, _: V) -> io::Result<()>
271    where
272        V: Primitive,
273    {
274        self.bits.checked_add_assign(
275            N::try_from(core::mem::size_of::<V>())
276                .map_err(|_| Overflowed)?
277                .checked_mul(8u8.into())?,
278        )?;
279        Ok(())
280    }
281
282    #[inline]
283    fn write_as_from<F, V>(&mut self, _: V) -> io::Result<()>
284    where
285        F: Endianness,
286        V: Primitive,
287    {
288        self.bits.checked_add_assign(
289            N::try_from(core::mem::size_of::<V>())
290                .map_err(|_| Overflowed)?
291                .checked_mul(8u8.into())?,
292        )?;
293        Ok(())
294    }
295
296    #[inline]
297    fn pad(&mut self, bits: u32) -> io::Result<()> {
298        self.bits
299            .checked_add_assign(bits.try_into().map_err(|_| Overflowed)?)?;
300        Ok(())
301    }
302
303    #[inline]
304    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
305        self.bits.checked_add_assign(
306            N::try_from(buf.len())
307                .map_err(|_| Overflowed)?
308                .checked_mul(8u8.into())?,
309        )?;
310        Ok(())
311    }
312
313    fn write_unary<const STOP_BIT: u8>(&mut self, value: u32) -> io::Result<()> {
314        const {
315            assert!(matches!(STOP_BIT, 0 | 1), "stop bit must be 0 or 1");
316        }
317
318        self.bits
319            .checked_add_assign(value.try_into().map_err(|_| Overflowed)?)?;
320        self.bits.checked_add_assign(1u8.into())?;
321        Ok(())
322    }
323
324    fn write_checked<C: Checkable>(&mut self, value: C) -> io::Result<()> {
325        Ok(self
326            .bits
327            .checked_add_assign(value.written_bits().try_into().map_err(|_| Overflowed)?)?)
328    }
329
330    #[inline]
331    fn byte_aligned(&self) -> bool {
332        self.bits.byte_aligned()
333    }
334}
335
336/// For counting the number of bits written but generating no output.
337///
338/// # Example
339/// ```
340/// use oximedia_bitstream::{BigEndian, BitWrite, BitCounter};
341/// let mut writer: BitCounter<u32, BigEndian> = BitCounter::new();
342/// writer.write_var(1, 0b1u8).unwrap();
343/// writer.write_var(2, 0b01u8).unwrap();
344/// writer.write_var(5, 0b10111u8).unwrap();
345/// assert_eq!(writer.written(), 8);
346/// ```
347#[derive(Default)]
348#[deprecated(since = "4.0.0", note = "use of BitsWritten is preferred")]
349pub struct BitCounter<N, E: Endianness> {
350    bits: BitsWritten<N>,
351    phantom: PhantomData<E>,
352}
353
354#[allow(deprecated)]
355impl<N: Default, E: Endianness> BitCounter<N, E> {
356    /// Creates new counter
357    #[inline]
358    pub fn new() -> Self {
359        BitCounter {
360            bits: BitsWritten::new(),
361            phantom: PhantomData,
362        }
363    }
364}
365
366#[allow(deprecated)]
367impl<N: Copy, E: Endianness> BitCounter<N, E> {
368    /// Returns number of bits written
369    #[inline]
370    pub fn written(&self) -> N {
371        self.bits.written()
372    }
373}
374
375#[allow(deprecated)]
376impl<N, E: Endianness> BitCounter<N, E> {
377    /// Returns number of bits written
378    #[inline]
379    pub fn into_written(self) -> N {
380        self.bits.into_written()
381    }
382}
383
384#[allow(deprecated)]
385impl<N, E> BitWrite for BitCounter<N, E>
386where
387    E: Endianness,
388    N: Counter,
389{
390    #[inline]
391    fn write_bit(&mut self, bit: bool) -> io::Result<()> {
392        BitWrite::write_bit(&mut self.bits, bit)
393    }
394
395    #[inline]
396    fn write_const<const BITS: u32, const VALUE: u32>(&mut self) -> io::Result<()> {
397        BitWrite::write_const::<BITS, VALUE>(&mut self.bits)
398    }
399
400    #[inline]
401    fn write_unsigned<const BITS: u32, U>(&mut self, value: U) -> io::Result<()>
402    where
403        U: UnsignedInteger,
404    {
405        BitWrite::write_unsigned::<BITS, U>(&mut self.bits, value)
406    }
407
408    #[inline]
409    fn write_signed<const BITS: u32, S>(&mut self, value: S) -> io::Result<()>
410    where
411        S: SignedInteger,
412    {
413        BitWrite::write_signed::<BITS, S>(&mut self.bits, value)
414    }
415
416    #[inline]
417    fn write_unsigned_counted<const MAX: u32, U>(
418        &mut self,
419        count: BitCount<MAX>,
420        value: U,
421    ) -> io::Result<()>
422    where
423        U: UnsignedInteger,
424    {
425        BitWrite::write_unsigned_counted::<MAX, U>(&mut self.bits, count, value)
426    }
427
428    #[inline]
429    fn write_signed_counted<const MAX: u32, S>(
430        &mut self,
431        bits: impl TryInto<SignedBitCount<MAX>>,
432        value: S,
433    ) -> io::Result<()>
434    where
435        S: SignedInteger,
436    {
437        BitWrite::write_signed_counted::<MAX, S>(&mut self.bits, bits, value)
438    }
439
440    #[inline]
441    fn write_from<V>(&mut self, value: V) -> io::Result<()>
442    where
443        V: Primitive,
444    {
445        BitWrite::write_from(&mut self.bits, value)
446    }
447
448    #[inline]
449    fn write_as_from<F, V>(&mut self, value: V) -> io::Result<()>
450    where
451        F: Endianness,
452        V: Primitive,
453    {
454        BitWrite::write_as_from::<F, V>(&mut self.bits, value)
455    }
456
457    #[inline]
458    fn pad(&mut self, bits: u32) -> io::Result<()> {
459        BitWrite::pad(&mut self.bits, bits)
460    }
461
462    #[inline]
463    fn write_bytes(&mut self, buf: &[u8]) -> io::Result<()> {
464        BitWrite::write_bytes(&mut self.bits, buf)
465    }
466
467    fn write_unary<const STOP_BIT: u8>(&mut self, value: u32) -> io::Result<()> {
468        BitWrite::write_unary::<STOP_BIT>(&mut self.bits, value)
469    }
470
471    #[inline]
472    fn byte_aligned(&self) -> bool {
473        BitWrite::byte_aligned(&self.bits)
474    }
475}