Skip to main content

oximedia_codec/av1/
encoder.rs

1//! AV1 encoder implementation.
2
3use super::obu::{encode_leb128, ObuHeader, ObuType};
4use super::tile_encoder::{ParallelTileEncoder, TileEncoderConfig, TileInfoBuilder};
5use crate::error::{CodecError, CodecResult};
6use crate::frame::VideoFrame;
7use crate::traits::{EncodedPacket, EncoderConfig, VideoEncoder};
8use oximedia_core::CodecId;
9
/// AV1 encoder.
///
/// Holds the encoder configuration, a running frame counter, the queue of
/// packets waiting to be retrieved, and an optional parallel tile encoder.
#[derive(Debug)]
pub struct Av1Encoder {
    /// Encoder configuration.
    config: EncoderConfig,
    /// Frame counter; incremented once per encoded frame and also used as
    /// the packet timestamp.
    frame_count: u64,
    /// Pending output packets, in the order they were encoded.
    output_queue: Vec<EncodedPacket>,
    /// Encoder is in flush mode; once set, further frames are rejected.
    flushing: bool,
    /// Tile encoder for parallel encoding; `None` means single-threaded mode.
    tile_encoder: Option<ParallelTileEncoder>,
    /// Quality parameter (0-255), derived from the configured bitrate mode.
    quality: u8,
}
26
27impl Av1Encoder {
28    /// Create a new AV1 encoder.
29    ///
30    /// # Errors
31    ///
32    /// Returns error if encoder initialization fails.
33    pub fn new(config: EncoderConfig) -> CodecResult<Self> {
34        if config.width == 0 || config.height == 0 {
35            return Err(CodecError::InvalidParameter(
36                "Invalid frame dimensions".to_string(),
37            ));
38        }
39
40        if config.codec != CodecId::Av1 {
41            return Err(CodecError::InvalidParameter(
42                "Expected AV1 codec".to_string(),
43            ));
44        }
45
46        // Initialize tile encoder if threads > 1
47        let tile_encoder = if config.threads > 1 {
48            let tile_config = TileEncoderConfig::auto(config.width, config.height, config.threads);
49            Some(ParallelTileEncoder::new(
50                tile_config,
51                config.width,
52                config.height,
53            )?)
54        } else {
55            None
56        };
57
58        // Compute quality from bitrate mode
59        let quality = Self::compute_quality(&config);
60
61        Ok(Self {
62            config,
63            frame_count: 0,
64            output_queue: Vec::new(),
65            flushing: false,
66            tile_encoder,
67            quality,
68        })
69    }
70
71    /// Compute quality parameter from encoder config.
72    fn compute_quality(config: &EncoderConfig) -> u8 {
73        use crate::traits::BitrateMode;
74
75        match config.bitrate {
76            BitrateMode::Crf(crf) => {
77                // Map CRF (0-51) to quality (255-0)
78                // Lower CRF = higher quality
79                let normalized = (crf / 51.0).clamp(0.0, 1.0);
80                (255.0 * (1.0 - normalized)) as u8
81            }
82            BitrateMode::Cbr(_) | BitrateMode::Vbr { .. } => {
83                // Default medium quality for bitrate modes
84                128
85            }
86            BitrateMode::Lossless => {
87                // Maximum quality for lossless
88                255
89            }
90        }
91    }
92
93    /// Encode a single frame.
94    fn encode_frame(&mut self, frame: &VideoFrame) {
95        let is_keyframe = self.frame_count % u64::from(self.config.keyint) == 0;
96        let mut data = Vec::new();
97
98        if is_keyframe {
99            self.write_sequence_header(&mut data);
100        }
101
102        // Use tile encoding if available, otherwise single-threaded
103        if let Some(ref tile_encoder) = self.tile_encoder {
104            self.encode_frame_with_tiles(frame, &mut data, is_keyframe);
105        } else {
106            Self::write_frame_obu(&mut data, is_keyframe);
107        }
108
109        #[allow(clippy::cast_possible_wrap)]
110        let pts = self.frame_count as i64;
111        let dts = pts;
112
113        self.output_queue.push(EncodedPacket {
114            data,
115            pts,
116            dts,
117            keyframe: is_keyframe,
118            duration: Some(1),
119        });
120
121        self.frame_count += 1;
122    }
123
124    /// Encode frame using parallel tile encoding.
125    fn encode_frame_with_tiles(&self, frame: &VideoFrame, data: &mut Vec<u8>, is_keyframe: bool) {
126        if let Some(ref tile_encoder) = self.tile_encoder {
127            // Encode tiles in parallel
128            match tile_encoder.encode_frame(frame, self.quality, is_keyframe) {
129                Ok(tiles) => {
130                    // Merge tiles into bitstream
131                    match tile_encoder.merge_tiles(&tiles) {
132                        Ok(merged) => {
133                            // Write frame header with tile info
134                            self.write_frame_header_with_tiles(data, is_keyframe, tile_encoder);
135                            // Append tile data
136                            data.extend_from_slice(&merged);
137                        }
138                        Err(_) => {
139                            // Fall back to simple encoding on error
140                            Self::write_frame_obu(data, is_keyframe);
141                        }
142                    }
143                }
144                Err(_) => {
145                    // Fall back to simple encoding on error
146                    Self::write_frame_obu(data, is_keyframe);
147                }
148            }
149        }
150    }
151
152    /// Write frame header with tile information.
153    fn write_frame_header_with_tiles(
154        &self,
155        data: &mut Vec<u8>,
156        is_keyframe: bool,
157        tile_encoder: &ParallelTileEncoder,
158    ) {
159        let header = ObuHeader {
160            obu_type: ObuType::Frame,
161            has_extension: false,
162            has_size: true,
163            temporal_id: 0,
164            spatial_id: 0,
165        };
166        data.extend(header.to_bytes());
167
168        // Build frame header with tile info
169        let mut frame_header = Vec::new();
170        let frame_type = u8::from(!is_keyframe);
171        frame_header.push((frame_type << 5) | 0x10);
172
173        // Write tile info
174        let tile_info = TileInfoBuilder::from_config(
175            tile_encoder.config(),
176            self.config.width,
177            self.config.height,
178        );
179
180        // Encode tile configuration (simplified)
181        if tile_info.tile_count() > 1 {
182            frame_header.push(tile_info.tile_cols_log2);
183            frame_header.push(tile_info.tile_rows_log2);
184        }
185
186        // Write size and header (placeholder for real implementation)
187        let size_bytes = encode_leb128(frame_header.len() as u64);
188        data.extend(size_bytes);
189        data.extend(frame_header);
190    }
191
192    /// Write sequence header OBU.
193    fn write_sequence_header(&self, data: &mut Vec<u8>) {
194        let header = ObuHeader {
195            obu_type: ObuType::SequenceHeader,
196            has_extension: false,
197            has_size: true,
198            temporal_id: 0,
199            spatial_id: 0,
200        };
201        data.extend(header.to_bytes());
202
203        let payload = self.build_sequence_header_payload();
204        let size_bytes = encode_leb128(payload.len() as u64);
205        data.extend(size_bytes);
206        data.extend(payload);
207    }
208
209    /// Build sequence header payload.
210    #[allow(clippy::cast_possible_truncation)]
211    fn build_sequence_header_payload(&self) -> Vec<u8> {
212        let mut payload = Vec::new();
213        payload.push(0x00);
214        payload.push(0x00);
215        payload.push(0x00);
216
217        let width_bits = 32 - self.config.width.leading_zeros();
218        let height_bits = 32 - self.config.height.leading_zeros();
219        payload.push(
220            ((width_bits.saturating_sub(1) as u8) << 4) | (height_bits.saturating_sub(1) as u8),
221        );
222
223        let width_minus_1 = self.config.width.saturating_sub(1);
224        let height_minus_1 = self.config.height.saturating_sub(1);
225        payload.extend(&width_minus_1.to_be_bytes()[2..]);
226        payload.extend(&height_minus_1.to_be_bytes()[2..]);
227
228        payload
229    }
230
231    /// Write frame OBU.
232    fn write_frame_obu(data: &mut Vec<u8>, is_keyframe: bool) {
233        let header = ObuHeader {
234            obu_type: ObuType::Frame,
235            has_extension: false,
236            has_size: true,
237            temporal_id: 0,
238            spatial_id: 0,
239        };
240        data.extend(header.to_bytes());
241
242        let frame_data = Self::build_frame_payload(is_keyframe);
243        let size_bytes = encode_leb128(frame_data.len() as u64);
244        data.extend(size_bytes);
245        data.extend(frame_data);
246    }
247
248    /// Build frame payload.
249    fn build_frame_payload(is_keyframe: bool) -> Vec<u8> {
250        let mut payload = Vec::new();
251        let frame_type = u8::from(!is_keyframe);
252        payload.push((frame_type << 5) | 0x10);
253        payload.extend(&[0x00; 16]);
254        payload
255    }
256}
257
258impl VideoEncoder for Av1Encoder {
259    fn codec(&self) -> CodecId {
260        CodecId::Av1
261    }
262
263    fn send_frame(&mut self, frame: &VideoFrame) -> CodecResult<()> {
264        if self.flushing {
265            return Err(CodecError::InvalidParameter(
266                "Cannot send frame while flushing".to_string(),
267            ));
268        }
269
270        if frame.width != self.config.width || frame.height != self.config.height {
271            return Err(CodecError::InvalidParameter(format!(
272                "Frame dimensions {}x{} don't match encoder config {}x{}",
273                frame.width, frame.height, self.config.width, self.config.height
274            )));
275        }
276
277        self.encode_frame(frame);
278        Ok(())
279    }
280
281    fn receive_packet(&mut self) -> CodecResult<Option<EncodedPacket>> {
282        if self.output_queue.is_empty() {
283            return Ok(None);
284        }
285        Ok(Some(self.output_queue.remove(0)))
286    }
287
288    fn flush(&mut self) -> CodecResult<()> {
289        self.flushing = true;
290        Ok(())
291    }
292
293    fn config(&self) -> &EncoderConfig {
294        &self.config
295    }
296}
297
298impl Av1Encoder {
299    /// Get tile encoder configuration if enabled.
300    #[must_use]
301    pub fn tile_config(&self) -> Option<&TileEncoderConfig> {
302        self.tile_encoder.as_ref().map(|e| e.config())
303    }
304
305    /// Check if parallel tile encoding is enabled.
306    #[must_use]
307    pub fn has_tile_encoding(&self) -> bool {
308        self.tile_encoder.is_some()
309    }
310
311    /// Get number of tiles being used.
312    #[must_use]
313    pub fn tile_count(&self) -> usize {
314        self.tile_encoder.as_ref().map_or(1, |e| e.tile_count())
315    }
316
317    /// Enable or reconfigure tile encoding.
318    ///
319    /// # Errors
320    ///
321    /// Returns error if tile configuration is invalid.
322    pub fn set_tile_config(&mut self, tile_config: TileEncoderConfig) -> CodecResult<()> {
323        tile_config.validate()?;
324
325        self.tile_encoder = Some(ParallelTileEncoder::new(
326            tile_config,
327            self.config.width,
328            self.config.height,
329        )?);
330
331        Ok(())
332    }
333
334    /// Disable tile encoding (single-threaded mode).
335    pub fn disable_tile_encoding(&mut self) {
336        self.tile_encoder = None;
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_core::PixelFormat;

    /// A valid 1080p AV1 configuration yields an encoder.
    #[test]
    fn test_encoder_creation() {
        assert!(Av1Encoder::new(EncoderConfig::av1(1920, 1080)).is_ok());
    }

    /// Zero-sized frames are rejected at construction.
    #[test]
    fn test_encoder_invalid_dimensions() {
        assert!(Av1Encoder::new(EncoderConfig::av1(0, 0)).is_err());
    }

    /// The encoder reports the AV1 codec id.
    #[test]
    fn test_encoder_codec_id() {
        let encoder = Av1Encoder::new(EncoderConfig::av1(1920, 1080)).expect("should succeed");
        assert_eq!(encoder.codec(), CodecId::Av1);
    }

    /// The first frame sent comes back as a non-empty keyframe packet.
    #[test]
    fn test_encode_frame() {
        let mut encoder = Av1Encoder::new(EncoderConfig::av1(320, 240)).expect("should succeed");

        let mut frame = VideoFrame::new(PixelFormat::Yuv420p, 320, 240);
        frame.allocate();
        assert!(encoder.send_frame(&frame).is_ok());

        let packet = encoder
            .receive_packet()
            .expect("should succeed")
            .expect("should succeed");
        assert!(packet.keyframe);
        assert!(!packet.data.is_empty());
    }

    /// A frame whose size differs from the configuration is refused.
    #[test]
    fn test_frame_dimension_mismatch() {
        let mut encoder = Av1Encoder::new(EncoderConfig::av1(320, 240)).expect("should succeed");

        let frame = VideoFrame::new(PixelFormat::Yuv420p, 640, 480);
        assert!(encoder.send_frame(&frame).is_err());
    }
}
392}