cros_codecs/codec/av1/
writer.rs

1// Copyright 2024 The ChromiumOS Authors
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5use std::fmt;
6use std::io::Write;
7
8use crate::bitstream_utils::BitWriter;
9use crate::bitstream_utils::BitWriterError;
10
11#[derive(Debug)]
12pub enum ObuWriterError {
13    BitWriterError(BitWriterError),
14    UnalignedLeb128,
15}
16
17impl fmt::Display for ObuWriterError {
18    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
19        match self {
20            ObuWriterError::BitWriterError(x) => write!(f, "{}", x.to_string()),
21            ObuWriterError::UnalignedLeb128 => {
22                write!(f, "attempted to write leb128 on unaligned position")
23            }
24        }
25    }
26}
27
28impl From<BitWriterError> for ObuWriterError {
29    fn from(err: BitWriterError) -> Self {
30        ObuWriterError::BitWriterError(err)
31    }
32}
33
34pub type ObuWriterResult<T> = std::result::Result<T, ObuWriterError>;
35
36pub struct ObuWriter<W: Write>(BitWriter<W>);
37
38impl<W: Write> ObuWriter<W> {
39    pub fn new(writer: W) -> Self {
40        Self(BitWriter::new(writer))
41    }
42
43    /// Writes fixed bit size integer. Corresponds to `f(n)` in AV1 spec defined in 4.10.2.
44    pub fn write_f<T: Into<u32>>(&mut self, bits: usize, value: T) -> ObuWriterResult<usize> {
45        self.0.write_f(bits, value).map_err(ObuWriterError::BitWriterError)
46    }
47
48    /// Writes variable length unsigned n-bit number. Corresponds to `uvlc()` in AV1 spec
49    /// defined in 4.10.3.
50    pub fn write_uvlc<T: Into<u32>>(&mut self, value: T) -> ObuWriterResult<usize> {
51        let value: u32 = value.into();
52        if value == u32::MAX {
53            return self.write_f(32, 0u32);
54        }
55
56        let value = value + 1;
57        let leading_zeros = (32 - value.leading_zeros()) as usize;
58
59        Ok(self.write_f(leading_zeros - 1, 0u32)? + self.write_f(leading_zeros, value)?)
60    }
61
62    /// Writes unsigned little-endian n-byte integer. Corresponds to `le(n)` in AV1 spec
63    /// defined in 4.10.4.
64    pub fn write_le<T: Into<u32>>(&mut self, n: usize, value: T) -> ObuWriterResult<usize> {
65        let value: u32 = value.into();
66        let mut value = value.to_le();
67
68        for _ in 0..n {
69            self.write_f(4, value & 0xff)?;
70            value >>= 8;
71        }
72
73        Ok(n)
74    }
75
76    /// Writes unsigned integer represented by a variable number of little-endian bytes.
77    /// Corresponds to `leb128()` in AV1 spec defined in 4.10.4.
78    ///
79    /// Note: Despite the name, the AV1 4.10.4 limits the value to [`u32::MAX`] = (1 << 32) - 1.
80    pub fn write_leb128<T: Into<u32>>(
81        &mut self,
82        value: T,
83        min_bytes: usize,
84    ) -> ObuWriterResult<usize> {
85        if !self.aligned() {
86            return Err(ObuWriterError::UnalignedLeb128);
87        }
88
89        let value: u32 = value.into();
90        let mut value: u32 = value.to_le();
91        let mut bytes = 0;
92
93        for _ in 0..8 {
94            bytes += 1;
95
96            if value >= 0x7f || bytes < min_bytes {
97                self.write_f(8, 0x80 | (value & 0x7f))?;
98                value >>= 7;
99            } else {
100                self.write_f(8, value & 0x7f)?;
101                break;
102            }
103        }
104
105        assert!(value < 0x7f);
106
107        Ok(bytes)
108    }
109
110    pub fn write_su<T: Into<i32>>(&mut self, bits: usize, value: T) -> ObuWriterResult<usize> {
111        let mut value: i32 = value.into();
112        if value < 0 {
113            value += 1 << bits;
114        }
115
116        assert!(value >= 0);
117        self.write_f(bits, value.unsigned_abs())
118    }
119
120    pub fn aligned(&self) -> bool {
121        !self.0.has_data_pending()
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::av1::reader::Reader;

    const TEST_VECTOR: &[u32] = &[
        // some random test values
        u32::MAX,
        1,
        2,
        3,
        4,
        10,
        20,
        7312,
        8832,
        10123,
        47457,
        21390213,
        u32::MIN,
        u32::MAX - 1,
    ];

    #[test]
    fn test_uvlc() {
        for &value in TEST_VECTOR {
            let mut buf = Vec::<u8>::new();

            ObuWriter::new(&mut buf).write_uvlc(value).unwrap();

            if value == u32::MAX {
                // Append a byte with its MSB set so the reader's leading-zero scan is
                // guaranteed to find a `1` bit after the 32 zeros.
                buf.push(0x80);
            }

            let read = Reader::new(&buf).read_uvlc().unwrap();

            assert_eq!(read, value, "failed testing {}", value);
        }
    }

    #[test]
    fn test_leb128() {
        for &value in TEST_VECTOR {
            let mut buf = Vec::<u8>::new();

            ObuWriter::new(&mut buf).write_leb128(value, 0).unwrap();
            let read = Reader::new(&buf).read_leb128().unwrap();

            assert_eq!(read, value, "failed testing {}", value);
        }
    }

    #[test]
    fn test_su() {
        let vector =
            TEST_VECTOR.iter().map(|e| *e as i32).chain(TEST_VECTOR.iter().map(|e| -(*e as i32)));

        for value in vector {
            // Magnitude bits plus one for the sign. `unsigned_abs` cannot overflow, unlike
            // `abs` on `i32::MIN`.
            let bits = 32 - value.unsigned_abs().leading_zeros() as usize + 1;
            if bits >= 32 {
                // Skip numbers too big to round-trip
                continue;
            }

            let mut buf = Vec::<u8>::new();

            ObuWriter::new(&mut buf).write_su(bits, value).unwrap();

            let read = Reader::new(&buf).read_su(bits).unwrap();

            assert_eq!(read, value, "failed testing {}", value);
        }
    }
}