cros_codecs/codec/av1/writer.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::io::Write;
6
7use thiserror::Error;
8
9use crate::utils::BitWriter;
10use crate::utils::BitWriterError;
11
12#[derive(Error, Debug)]
13pub enum ObuWriterError {
14    #[error(transparent)]
15    BitWriterError(#[from] BitWriterError),
16
17    #[error("attemped to write leb128 on unaligned position")]
18    UnalignedLeb128,
19}
20
21pub type ObuWriterResult<T> = std::result::Result<T, ObuWriterError>;
22
23pub struct ObuWriter<W: Write>(BitWriter<W>);
24
25impl<W: Write> ObuWriter<W> {
26    pub fn new(writer: W) -> Self {
27        Self(BitWriter::new(writer))
28    }
29
30    /// Writes fixed bit size integer. Corresponds to `f(n)` in AV1 spec defined in 4.10.2.
31    pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> ObuWriterResult<usize> {
32        self.0
33            .write_f(bits, value)
34            .map_err(ObuWriterError::BitWriterError)
35    }
36
37    /// Writes variable length unsigned n-bit number. Corresponds to `uvlc()` in AV1 spec
38    /// defined in 4.10.3.
39    pub fn write_uvlc<T: Into<u32>>(&mut self, value: T) -> ObuWriterResult<usize> {
40        let value: u32 = value.into();
41        if value == u32::MAX {
42            return self.write_f(32, 0u32);
43        }
44
45        let value = value + 1;
46        let leading_zeros = (32 - value.leading_zeros()) as usize;
47
48        Ok(self.write_f(leading_zeros - 1, 0u32)? + self.write_f(leading_zeros, value)?)
49    }
50
51    /// Writes unsigned little-endian n-byte integer. Corresponds to `le(n)` in AV1 spec
52    /// defined in 4.10.4.
53    pub fn write_le<T: Into<u32>>(&mut self, n: usize, value: T) -> ObuWriterResult<usize> {
54        let value: u32 = value.into();
55        let mut value = value.to_le();
56
57        for _ in 0..n {
58            self.write_f(4, value & 0xff)?;
59            value >>= 8;
60        }
61
62        Ok(n)
63    }
64
65    /// Writes unsigned integer represented by a variable number of little-endian bytes.
66    /// Corresponds to `leb128()` in AV1 spec defined in 4.10.4.
67    ///
68    /// Note: Despite the name, the AV1 4.10.4 limits the value to [`u32::MAX`] = (1 << 32) - 1.
69    pub fn write_leb128<T: Into<u32>>(
70        &mut self,
71        value: T,
72        min_bytes: usize,
73    ) -> ObuWriterResult<usize> {
74        if !self.aligned() {
75            return Err(ObuWriterError::UnalignedLeb128);
76        }
77
78        let value: u32 = value.into();
79        let mut value: u32 = value.to_le();
80        let mut bytes = 0;
81
82        for _ in 0..8 {
83            bytes += 1;
84
85            if value >= 0x7f || bytes < min_bytes {
86                self.write_f(8, 0x80 | (value & 0x7f))?;
87                value >>= 7;
88            } else {
89                self.write_f(8, value & 0x7f)?;
90                break;
91            }
92        }
93
94        assert!(value < 0x7f);
95
96        Ok(bytes)
97    }
98
99    pub fn write_su<T: Into<i32>>(&mut self, bits: usize, value: T) -> ObuWriterResult<usize> {
100        let mut value: i32 = value.into();
101        if value < 0 {
102            value += 1 << bits;
103        }
104
105        assert!(value >= 0);
106        self.write_f(bits, value.unsigned_abs())
107    }
108
109    pub fn aligned(&self) -> bool {
110        !self.0.has_data_pending()
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::av1::reader::Reader;

    const TEST_VECTOR: &[u32] = &[
        // some random test values
        u32::MAX,
        1,
        2,
        3,
        4,
        10,
        20,
        7312,
        8832,
        10123,
        47457,
        21390213,
        u32::MIN,
        u32::MAX - 1,
        // values on either side of leb128 7-bit group boundaries
        0x7f,
        0x80,
        0x3fff,
        0x4000,
    ];

    #[test]
    fn test_uvlc() {
        for &value in TEST_VECTOR {
            let mut buf = Vec::<u8>::new();

            ObuWriter::new(&mut buf).write_uvlc(value).unwrap();

            if value == u32::MAX {
                // Decoding u32::MAX needs a one bit after the 32 leading zeros; append a byte
                // providing one so the reader can never run off the end of the buffer.
                buf.push(0x80);
            }

            let read = Reader::new(&buf).read_uvlc().unwrap();

            assert_eq!(read, value, "failed testing {}", value);
        }
    }

    #[test]
    fn test_leb128() {
        for &value in TEST_VECTOR {
            let mut buf = Vec::<u8>::new();

            ObuWriter::new(&mut buf).write_leb128(value, 0).unwrap();
            let read = Reader::new(&buf).read_leb128().unwrap();

            assert_eq!(read, value, "failed testing {}", value);
        }
    }

    #[test]
    fn test_leb128_min_bytes() {
        for &value in TEST_VECTOR {
            for min_bytes in [1usize, 4, 8] {
                let mut buf = Vec::<u8>::new();

                let written = ObuWriter::new(&mut buf)
                    .write_leb128(value, min_bytes)
                    .unwrap();
                assert!(
                    (min_bytes..=8).contains(&written),
                    "wrote {} bytes for {} with min_bytes {}",
                    written,
                    value,
                    min_bytes
                );

                let read = Reader::new(&buf).read_leb128().unwrap();
                assert_eq!(read, value, "failed testing {} (min_bytes {})", value, min_bytes);
            }
        }
    }

    #[test]
    fn test_su() {
        let vector = TEST_VECTOR
            .iter()
            .map(|e| *e as i32)
            .chain(TEST_VECTOR.iter().map(|e| -(*e as i32)));

        for value in vector {
            let bits = 32 - value.abs().leading_zeros() as usize + 1; // For sign
            if bits >= 32 {
                // Skip too big numbers
                continue;
            }

            let mut buf = Vec::<u8>::new();

            ObuWriter::new(&mut buf).write_su(bits, value).unwrap();

            let read = Reader::new(&buf).read_su(bits as u8).unwrap();

            assert_eq!(read, value, "failed testing {}", value);
        }
    }
}
190}