1use super::obu::{encode_leb128, ObuHeader, ObuType};
4use super::tile_encoder::{ParallelTileEncoder, TileEncoderConfig, TileInfoBuilder};
5use crate::error::{CodecError, CodecResult};
6use crate::frame::VideoFrame;
7use crate::traits::{EncodedPacket, EncoderConfig, VideoEncoder};
8use oximedia_core::CodecId;
9
/// AV1 video encoder that emits packets made of size-delimited OBUs.
///
/// Frames are encoded as soon as they are passed to `send_frame`. The resulting
/// packets are queued until retrieved with `receive_packet`.
#[derive(Debug)]
pub struct Av1Encoder {
    /// Configuration supplied to `Av1Encoder::new`: frame dimensions, keyframe
    /// interval, thread count and bitrate mode.
    config: EncoderConfig,
    /// Number of frames encoded so far. Used as the packet timestamp and to
    /// decide keyframe placement.
    frame_count: u64,
    /// Encoded packets waiting for `receive_packet`, oldest first.
    output_queue: Vec<EncodedPacket>,
    /// Set by `flush`. Once true, `send_frame` rejects further frames.
    flushing: bool,
    /// Parallel tile encoder. It is present when tile encoding is enabled
    /// (`threads > 1` at construction, or via `set_tile_config`); `None` selects
    /// the single-tile path.
    tile_encoder: Option<ParallelTileEncoder>,
    /// Quality level derived from the bitrate mode (255 for lossless, lower for
    /// higher CRF values). Passed to the tile encoder.
    quality: u8,
}
26
27impl Av1Encoder {
28 pub fn new(config: EncoderConfig) -> CodecResult<Self> {
34 if config.width == 0 || config.height == 0 {
35 return Err(CodecError::InvalidParameter(
36 "Invalid frame dimensions".to_string(),
37 ));
38 }
39
40 if config.codec != CodecId::Av1 {
41 return Err(CodecError::InvalidParameter(
42 "Expected AV1 codec".to_string(),
43 ));
44 }
45
46 let tile_encoder = if config.threads > 1 {
48 let tile_config = TileEncoderConfig::auto(config.width, config.height, config.threads);
49 Some(ParallelTileEncoder::new(
50 tile_config,
51 config.width,
52 config.height,
53 )?)
54 } else {
55 None
56 };
57
58 let quality = Self::compute_quality(&config);
60
61 Ok(Self {
62 config,
63 frame_count: 0,
64 output_queue: Vec::new(),
65 flushing: false,
66 tile_encoder,
67 quality,
68 })
69 }
70
71 fn compute_quality(config: &EncoderConfig) -> u8 {
73 use crate::traits::BitrateMode;
74
75 match config.bitrate {
76 BitrateMode::Crf(crf) => {
77 let normalized = (crf / 51.0).clamp(0.0, 1.0);
80 (255.0 * (1.0 - normalized)) as u8
81 }
82 BitrateMode::Cbr(_) | BitrateMode::Vbr { .. } => {
83 128
85 }
86 BitrateMode::Lossless => {
87 255
89 }
90 }
91 }
92
93 fn encode_frame(&mut self, frame: &VideoFrame) {
95 let is_keyframe = self.frame_count % u64::from(self.config.keyint) == 0;
96 let mut data = Vec::new();
97
98 if is_keyframe {
99 self.write_sequence_header(&mut data);
100 }
101
102 if let Some(ref tile_encoder) = self.tile_encoder {
104 self.encode_frame_with_tiles(frame, &mut data, is_keyframe);
105 } else {
106 Self::write_frame_obu(&mut data, is_keyframe);
107 }
108
109 #[allow(clippy::cast_possible_wrap)]
110 let pts = self.frame_count as i64;
111 let dts = pts;
112
113 self.output_queue.push(EncodedPacket {
114 data,
115 pts,
116 dts,
117 keyframe: is_keyframe,
118 duration: Some(1),
119 });
120
121 self.frame_count += 1;
122 }
123
124 fn encode_frame_with_tiles(&self, frame: &VideoFrame, data: &mut Vec<u8>, is_keyframe: bool) {
126 if let Some(ref tile_encoder) = self.tile_encoder {
127 match tile_encoder.encode_frame(frame, self.quality, is_keyframe) {
129 Ok(tiles) => {
130 match tile_encoder.merge_tiles(&tiles) {
132 Ok(merged) => {
133 self.write_frame_header_with_tiles(data, is_keyframe, tile_encoder);
135 data.extend_from_slice(&merged);
137 }
138 Err(_) => {
139 Self::write_frame_obu(data, is_keyframe);
141 }
142 }
143 }
144 Err(_) => {
145 Self::write_frame_obu(data, is_keyframe);
147 }
148 }
149 }
150 }
151
152 fn write_frame_header_with_tiles(
154 &self,
155 data: &mut Vec<u8>,
156 is_keyframe: bool,
157 tile_encoder: &ParallelTileEncoder,
158 ) {
159 let header = ObuHeader {
160 obu_type: ObuType::Frame,
161 has_extension: false,
162 has_size: true,
163 temporal_id: 0,
164 spatial_id: 0,
165 };
166 data.extend(header.to_bytes());
167
168 let mut frame_header = Vec::new();
170 let frame_type = u8::from(!is_keyframe);
171 frame_header.push((frame_type << 5) | 0x10);
172
173 let tile_info = TileInfoBuilder::from_config(
175 tile_encoder.config(),
176 self.config.width,
177 self.config.height,
178 );
179
180 if tile_info.tile_count() > 1 {
182 frame_header.push(tile_info.tile_cols_log2);
183 frame_header.push(tile_info.tile_rows_log2);
184 }
185
186 let size_bytes = encode_leb128(frame_header.len() as u64);
188 data.extend(size_bytes);
189 data.extend(frame_header);
190 }
191
192 fn write_sequence_header(&self, data: &mut Vec<u8>) {
194 let header = ObuHeader {
195 obu_type: ObuType::SequenceHeader,
196 has_extension: false,
197 has_size: true,
198 temporal_id: 0,
199 spatial_id: 0,
200 };
201 data.extend(header.to_bytes());
202
203 let payload = self.build_sequence_header_payload();
204 let size_bytes = encode_leb128(payload.len() as u64);
205 data.extend(size_bytes);
206 data.extend(payload);
207 }
208
209 #[allow(clippy::cast_possible_truncation)]
211 fn build_sequence_header_payload(&self) -> Vec<u8> {
212 let mut payload = Vec::new();
213 payload.push(0x00);
214 payload.push(0x00);
215 payload.push(0x00);
216
217 let width_bits = 32 - self.config.width.leading_zeros();
218 let height_bits = 32 - self.config.height.leading_zeros();
219 payload.push(
220 ((width_bits.saturating_sub(1) as u8) << 4) | (height_bits.saturating_sub(1) as u8),
221 );
222
223 let width_minus_1 = self.config.width.saturating_sub(1);
224 let height_minus_1 = self.config.height.saturating_sub(1);
225 payload.extend(&width_minus_1.to_be_bytes()[2..]);
226 payload.extend(&height_minus_1.to_be_bytes()[2..]);
227
228 payload
229 }
230
231 fn write_frame_obu(data: &mut Vec<u8>, is_keyframe: bool) {
233 let header = ObuHeader {
234 obu_type: ObuType::Frame,
235 has_extension: false,
236 has_size: true,
237 temporal_id: 0,
238 spatial_id: 0,
239 };
240 data.extend(header.to_bytes());
241
242 let frame_data = Self::build_frame_payload(is_keyframe);
243 let size_bytes = encode_leb128(frame_data.len() as u64);
244 data.extend(size_bytes);
245 data.extend(frame_data);
246 }
247
248 fn build_frame_payload(is_keyframe: bool) -> Vec<u8> {
250 let mut payload = Vec::new();
251 let frame_type = u8::from(!is_keyframe);
252 payload.push((frame_type << 5) | 0x10);
253 payload.extend(&[0x00; 16]);
254 payload
255 }
256}
257
258impl VideoEncoder for Av1Encoder {
259 fn codec(&self) -> CodecId {
260 CodecId::Av1
261 }
262
263 fn send_frame(&mut self, frame: &VideoFrame) -> CodecResult<()> {
264 if self.flushing {
265 return Err(CodecError::InvalidParameter(
266 "Cannot send frame while flushing".to_string(),
267 ));
268 }
269
270 if frame.width != self.config.width || frame.height != self.config.height {
271 return Err(CodecError::InvalidParameter(format!(
272 "Frame dimensions {}x{} don't match encoder config {}x{}",
273 frame.width, frame.height, self.config.width, self.config.height
274 )));
275 }
276
277 self.encode_frame(frame);
278 Ok(())
279 }
280
281 fn receive_packet(&mut self) -> CodecResult<Option<EncodedPacket>> {
282 if self.output_queue.is_empty() {
283 return Ok(None);
284 }
285 Ok(Some(self.output_queue.remove(0)))
286 }
287
288 fn flush(&mut self) -> CodecResult<()> {
289 self.flushing = true;
290 Ok(())
291 }
292
293 fn config(&self) -> &EncoderConfig {
294 &self.config
295 }
296}
297
298impl Av1Encoder {
299 #[must_use]
301 pub fn tile_config(&self) -> Option<&TileEncoderConfig> {
302 self.tile_encoder.as_ref().map(|e| e.config())
303 }
304
305 #[must_use]
307 pub fn has_tile_encoding(&self) -> bool {
308 self.tile_encoder.is_some()
309 }
310
311 #[must_use]
313 pub fn tile_count(&self) -> usize {
314 self.tile_encoder.as_ref().map_or(1, |e| e.tile_count())
315 }
316
317 pub fn set_tile_config(&mut self, tile_config: TileEncoderConfig) -> CodecResult<()> {
323 tile_config.validate()?;
324
325 self.tile_encoder = Some(ParallelTileEncoder::new(
326 tile_config,
327 self.config.width,
328 self.config.height,
329 )?);
330
331 Ok(())
332 }
333
334 pub fn disable_tile_encoding(&mut self) {
336 self.tile_encoder = None;
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use oximedia_core::PixelFormat;

    #[test]
    fn test_encoder_creation() {
        let config = EncoderConfig::av1(1920, 1080);
        let encoder = Av1Encoder::new(config);
        assert!(encoder.is_ok());
    }

    #[test]
    fn test_encoder_invalid_dimensions() {
        let config = EncoderConfig::av1(0, 0);
        let encoder = Av1Encoder::new(config);
        assert!(encoder.is_err());
    }

    /// A single zero dimension must be rejected just like 0x0.
    #[test]
    fn test_encoder_single_zero_dimension() {
        assert!(Av1Encoder::new(EncoderConfig::av1(0, 240)).is_err());
        assert!(Av1Encoder::new(EncoderConfig::av1(320, 0)).is_err());
    }

    #[test]
    fn test_encoder_codec_id() {
        let config = EncoderConfig::av1(1920, 1080);
        let encoder = Av1Encoder::new(config).expect("should succeed");
        assert_eq!(encoder.codec(), CodecId::Av1);
    }

    #[test]
    fn test_encode_frame() {
        let config = EncoderConfig::av1(320, 240);
        let mut encoder = Av1Encoder::new(config).expect("should succeed");

        let mut frame = VideoFrame::new(PixelFormat::Yuv420p, 320, 240);
        frame.allocate();

        assert!(encoder.send_frame(&frame).is_ok());

        let packet = encoder.receive_packet().expect("should succeed");
        assert!(packet.is_some());
        let packet = packet.expect("should succeed");
        assert!(packet.keyframe);
        assert!(!packet.data.is_empty());
    }

    #[test]
    fn test_frame_dimension_mismatch() {
        let config = EncoderConfig::av1(320, 240);
        let mut encoder = Av1Encoder::new(config).expect("should succeed");

        let frame = VideoFrame::new(PixelFormat::Yuv420p, 640, 480);
        let result = encoder.send_frame(&frame);
        assert!(result.is_err());
    }

    /// An encoder that has not been fed any frames has nothing to return.
    #[test]
    fn test_receive_packet_empty_queue() {
        let config = EncoderConfig::av1(320, 240);
        let mut encoder = Av1Encoder::new(config).expect("should succeed");

        let packet = encoder.receive_packet().expect("should succeed");
        assert!(packet.is_none());
    }

    /// Packets come back in submission order, stamped with the frame index.
    #[test]
    fn test_packets_returned_in_order() {
        let config = EncoderConfig::av1(320, 240);
        let mut encoder = Av1Encoder::new(config).expect("should succeed");

        let mut frame = VideoFrame::new(PixelFormat::Yuv420p, 320, 240);
        frame.allocate();

        encoder.send_frame(&frame).expect("first frame should encode");
        encoder.send_frame(&frame).expect("second frame should encode");

        let first = encoder
            .receive_packet()
            .expect("should succeed")
            .expect("first packet should be queued");
        let second = encoder
            .receive_packet()
            .expect("should succeed")
            .expect("second packet should be queued");

        assert_eq!(first.pts, 0);
        assert_eq!(second.pts, 1);
        assert_eq!(second.dts, second.pts);
        assert!(encoder.receive_packet().expect("should succeed").is_none());
    }

    /// Once flushing has started, new frames are refused.
    #[test]
    fn test_send_frame_after_flush_rejected() {
        let config = EncoderConfig::av1(320, 240);
        let mut encoder = Av1Encoder::new(config).expect("should succeed");

        let mut frame = VideoFrame::new(PixelFormat::Yuv420p, 320, 240);
        frame.allocate();

        encoder.flush().expect("flush should succeed");
        assert!(encoder.send_frame(&frame).is_err());
    }

    /// Disabling tiles falls back to the single-tile path, which still encodes.
    #[test]
    fn test_disable_tile_encoding() {
        let config = EncoderConfig::av1(320, 240);
        let mut encoder = Av1Encoder::new(config).expect("should succeed");

        encoder.disable_tile_encoding();
        assert!(!encoder.has_tile_encoding());
        assert!(encoder.tile_config().is_none());
        assert_eq!(encoder.tile_count(), 1);

        let mut frame = VideoFrame::new(PixelFormat::Yuv420p, 320, 240);
        frame.allocate();
        encoder.send_frame(&frame).expect("frame should encode");
        let packet = encoder
            .receive_packet()
            .expect("should succeed")
            .expect("packet should be queued");
        assert!(!packet.data.is_empty());
    }
}