1#![forbid(unsafe_code)]
7#![allow(dead_code)]
8#![allow(clippy::doc_markdown)]
9#![allow(clippy::unused_self)]
10#![allow(clippy::missing_errors_doc)]
11#![allow(clippy::match_same_arms)]
12
13use super::block::{BlockContextManager, BlockModeInfo, BlockSize};
14use super::cdef::CdefParams;
15use super::coeff_decode::CoeffDecoder;
16use super::frame_header::{FrameHeader, FrameType as Av1FrameType};
17use super::loop_filter::LoopFilterParams;
18use super::obu::{ObuIterator, ObuType};
19use super::prediction::PredictionEngine;
20use super::quantization::QuantizationParams;
21use super::sequence::SequenceHeader;
22use super::symbols::SymbolDecoder;
23use super::tile::TileInfo;
24use super::transform::{Transform2D, TxType};
25use crate::error::{CodecError, CodecResult};
26use crate::frame::{FrameType, VideoFrame};
27use crate::reconstruct::{DecoderPipeline, FrameContext, PipelineConfig, ReferenceFrameManager};
28use crate::traits::{DecoderConfig, VideoDecoder};
29use oximedia_core::{CodecId, PixelFormat, Rational, Timestamp};
30
/// Per-frame decoding state, refreshed for every temporal unit.
#[derive(Clone, Debug, Default)]
#[allow(dead_code)]
struct DecoderState {
    /// Most recently parsed frame header, if any.
    frame_header: Option<FrameHeader>,
    /// Loop-filter parameters copied from the current frame header.
    loop_filter: LoopFilterParams,
    /// CDEF parameters copied from the current frame header.
    cdef: CdefParams,
    /// Quantization parameters copied from the current frame header.
    quantization: QuantizationParams,
    /// Tile layout of the current frame, if a frame header was seen.
    tile_info: Option<TileInfo>,
    /// Whether the current frame is intra-only.
    frame_is_intra: bool,
}
48
49impl DecoderState {
50 fn new() -> Self {
52 Self::default()
53 }
54
55 fn reset(&mut self) {
57 self.frame_header = None;
58 self.tile_info = None;
59 }
60}
61
/// AV1 video decoder built on OBU parsing plus a reconstruction pipeline.
#[derive(Debug)]
pub struct Av1Decoder {
    /// Caller-supplied configuration (may carry sequence-header extradata).
    config: DecoderConfig,
    /// Last sequence header seen in extradata or the bitstream.
    sequence_header: Option<SequenceHeader>,
    /// Decoded frames waiting to be returned by `receive_frame`.
    output_queue: Vec<VideoFrame>,
    /// Set once `flush` is called; rejects further input packets.
    flushing: bool,
    /// Number of frames produced so far (also used as decode order).
    frame_count: u64,
    /// Per-frame parsing state.
    state: DecoderState,
    /// Reconstruction pipeline, created lazily from the sequence header.
    pipeline: Option<DecoderPipeline>,
    /// Reference-frame bookkeeping.
    ref_manager: ReferenceFrameManager,
    /// Prediction engine, created alongside the pipeline.
    prediction: Option<PredictionEngine>,
    /// Above/left context tracking for block decoding.
    block_context: Option<BlockContextManager>,
}
86
87impl Av1Decoder {
88 pub fn new(config: DecoderConfig) -> CodecResult<Self> {
94 let mut decoder = Self {
95 config,
96 sequence_header: None,
97 output_queue: Vec::new(),
98 flushing: false,
99 frame_count: 0,
100 state: DecoderState::new(),
101 pipeline: None,
102 ref_manager: ReferenceFrameManager::new(),
103 prediction: None,
104 block_context: None,
105 };
106
107 if let Some(extradata) = decoder.config.extradata.clone() {
108 decoder.parse_extradata(&extradata)?;
109 }
110
111 Ok(decoder)
112 }
113
114 fn parse_extradata(&mut self, data: &[u8]) -> CodecResult<()> {
116 for obu_result in ObuIterator::new(data) {
117 let (header, payload) = obu_result?;
118 if header.obu_type == ObuType::SequenceHeader {
119 self.sequence_header = Some(SequenceHeader::parse(payload)?);
120 break;
121 }
122 }
123 Ok(())
124 }
125
126 #[allow(clippy::too_many_lines)]
128 fn decode_temporal_unit(&mut self, data: &[u8], pts: i64) -> CodecResult<()> {
129 self.state.reset();
131
132 for obu_result in ObuIterator::new(data) {
133 let (header, payload) = obu_result?;
134
135 match header.obu_type {
136 ObuType::SequenceHeader => {
137 self.sequence_header = Some(SequenceHeader::parse(payload)?);
138 }
139 ObuType::FrameHeader | ObuType::Frame => {
140 if let Some(ref seq) = self.sequence_header {
141 let frame_header = FrameHeader::parse(payload, seq)?;
143
144 self.state.frame_is_intra = frame_header.frame_is_intra;
146 self.state.loop_filter = frame_header.loop_filter.clone();
147 self.state.cdef = frame_header.cdef.clone();
148 self.state.quantization = frame_header.quantization.clone();
149 self.state.tile_info = Some(frame_header.tile_info.clone());
150 self.state.frame_header = Some(frame_header.clone());
151
152 let format = Self::determine_pixel_format(seq);
154 let width = frame_header.frame_size.upscaled_width;
155 let height = frame_header.frame_size.frame_height;
156
157 let mut frame = VideoFrame::new(
158 format,
159 if width > 0 {
160 width
161 } else {
162 seq.max_frame_width()
163 },
164 if height > 0 {
165 height
166 } else {
167 seq.max_frame_height()
168 },
169 );
170 frame.allocate();
171 frame.timestamp = Timestamp::new(pts, Rational::new(1, 1000));
172
173 frame.frame_type = match frame_header.frame_type {
175 Av1FrameType::KeyFrame => FrameType::Key,
176 Av1FrameType::InterFrame => FrameType::Inter,
177 Av1FrameType::IntraOnlyFrame => FrameType::Key, Av1FrameType::SwitchFrame => FrameType::Inter, };
180
181 self.output_queue.push(frame);
182 self.frame_count += 1;
183 }
184 }
185 ObuType::TileGroup => {
186 }
189 _ => {}
190 }
191 }
192
193 Ok(())
194 }
195
196 fn determine_pixel_format(seq: &SequenceHeader) -> PixelFormat {
198 let cc = &seq.color_config;
199 if cc.mono_chrome {
200 return PixelFormat::Gray8;
201 }
202 match (cc.bit_depth, cc.subsampling_x, cc.subsampling_y) {
203 (8, true, false) => PixelFormat::Yuv422p,
204 (8, false, false) => PixelFormat::Yuv444p,
205 (10, true, true) => PixelFormat::Yuv420p10le,
206 (12, true, true) => PixelFormat::Yuv420p12le,
207 _ => PixelFormat::Yuv420p,
209 }
210 }
211
    /// Returns the header of the most recently parsed frame, if any.
    #[must_use]
    #[allow(dead_code)]
    pub fn current_frame_header(&self) -> Option<&FrameHeader> {
        self.state.frame_header.as_ref()
    }
218
    /// Returns the cached sequence header, if one has been parsed.
    #[must_use]
    #[allow(dead_code)]
    pub fn current_sequence_header(&self) -> Option<&SequenceHeader> {
        self.sequence_header.as_ref()
    }
225
    /// Loop-filter parameters of the current frame (defaults before any
    /// frame header has been parsed).
    #[must_use]
    #[allow(dead_code)]
    pub fn loop_filter_params(&self) -> &LoopFilterParams {
        &self.state.loop_filter
    }
232
    /// CDEF parameters of the current frame (defaults before any frame
    /// header has been parsed).
    #[must_use]
    #[allow(dead_code)]
    pub fn cdef_params(&self) -> &CdefParams {
        &self.state.cdef
    }
239
    /// Quantization parameters of the current frame (defaults before any
    /// frame header has been parsed).
    #[must_use]
    #[allow(dead_code)]
    pub fn quantization_params(&self) -> &QuantizationParams {
        &self.state.quantization
    }
246
    /// Tile layout of the current frame, if a frame header has been parsed.
    #[must_use]
    #[allow(dead_code)]
    pub fn tile_info(&self) -> Option<&TileInfo> {
        self.state.tile_info.as_ref()
    }
253
    /// Total number of frames produced since construction or `reset`.
    #[must_use]
    #[allow(dead_code)]
    pub const fn frame_count(&self) -> u64 {
        self.frame_count
    }
260
261 fn initialize_pipeline(&mut self, seq: &SequenceHeader) -> CodecResult<()> {
263 let width = seq.max_frame_width();
264 let height = seq.max_frame_height();
265 let bit_depth = seq.color_config.bit_depth;
266
267 let pipeline_config = PipelineConfig::new(width, height)
269 .with_bit_depth(bit_depth)
270 .with_all_filters();
271
272 self.pipeline = Some(
274 DecoderPipeline::new(pipeline_config)
275 .map_err(|e| CodecError::Internal(format!("Pipeline creation failed: {e:?}")))?,
276 );
277
278 self.prediction = Some(PredictionEngine::new(width, height, bit_depth));
280
281 self.block_context = Some(BlockContextManager::new(
283 width / 4,
284 seq.color_config.subsampling_x,
285 seq.color_config.subsampling_y,
286 ));
287
288 Ok(())
289 }
290
291 fn decode_frame_with_pipeline(
293 &mut self,
294 frame_header: &FrameHeader,
295 tile_data: &[u8],
296 ) -> CodecResult<VideoFrame> {
297 if self.pipeline.is_none() {
299 if let Some(seq) = self.sequence_header.clone() {
300 self.initialize_pipeline(&seq)?;
301 } else {
302 return Err(CodecError::InvalidData("No sequence header".to_string()));
303 }
304 }
305
306 let mut frame_ctx = FrameContext::new(
308 frame_header.frame_size.upscaled_width,
309 frame_header.frame_size.frame_height,
310 );
311 frame_ctx.decode_order = self.frame_count;
312 frame_ctx.display_order = self.frame_count;
313 frame_ctx.is_keyframe = matches!(frame_header.frame_type, Av1FrameType::KeyFrame);
314 frame_ctx.show_frame = frame_header.show_frame;
315 frame_ctx.bit_depth = frame_header.quantization.base_q_idx as u8;
316
317 self.decode_tiles(tile_data, frame_header, &frame_ctx)?;
319
320 if let Some(ref mut pipeline) = self.pipeline {
322 let buffer = pipeline
323 .process_frame(tile_data, &frame_ctx)
324 .map_err(|e| CodecError::Internal(format!("Pipeline processing failed: {e:?}")))?;
325
326 let format = self
328 .sequence_header
329 .as_ref()
330 .map(Self::determine_pixel_format)
331 .unwrap_or(PixelFormat::Yuv420p);
332
333 let mut frame = VideoFrame::new(
334 format,
335 frame_header.frame_size.upscaled_width,
336 frame_header.frame_size.frame_height,
337 );
338 frame.allocate();
339
340 self.copy_buffer_to_frame(&buffer, &mut frame)?;
342
343 frame.frame_type = match frame_header.frame_type {
345 Av1FrameType::KeyFrame => FrameType::Key,
346 _ => FrameType::Inter,
347 };
348
349 Ok(frame)
350 } else {
351 Err(CodecError::Internal("Pipeline not initialized".to_string()))
352 }
353 }
354
355 fn decode_tiles(
357 &mut self,
358 tile_data: &[u8],
359 frame_header: &FrameHeader,
360 _frame_ctx: &FrameContext,
361 ) -> CodecResult<()> {
362 let frame_is_intra = frame_header.frame_is_intra;
363
364 let mut symbol_decoder = SymbolDecoder::new(tile_data.to_vec(), frame_is_intra);
366
367 if let Some(mut block_ctx) = self.block_context.take() {
369 self.decode_superblocks(&mut symbol_decoder, frame_header, &mut block_ctx)?;
370 self.block_context = Some(block_ctx);
372 }
373
374 Ok(())
375 }
376
377 fn decode_superblocks(
379 &mut self,
380 symbol_decoder: &mut SymbolDecoder,
381 frame_header: &FrameHeader,
382 block_ctx: &mut BlockContextManager,
383 ) -> CodecResult<()> {
384 let sb_size = BlockSize::Block64x64; let frame_width = frame_header.frame_size.upscaled_width;
386 let frame_height = frame_header.frame_size.frame_height;
387
388 let sb_cols = (frame_width + sb_size.width() - 1) / sb_size.width();
389 let sb_rows = (frame_height + sb_size.height() - 1) / sb_size.height();
390
391 for sb_row in 0..sb_rows {
392 block_ctx.reset_left_context();
393
394 for sb_col in 0..sb_cols {
395 let mi_row = sb_row * (sb_size.height() / 4);
396 let mi_col = sb_col * (sb_size.width() / 4);
397
398 block_ctx.set_position(mi_row, mi_col, sb_size);
399
400 self.decode_partition_tree(
402 symbol_decoder,
403 block_ctx,
404 mi_row,
405 mi_col,
406 sb_size,
407 &frame_header.quantization,
408 )?;
409 }
410 }
411
412 Ok(())
413 }
414
415 fn decode_partition_tree(
417 &mut self,
418 symbol_decoder: &mut SymbolDecoder,
419 block_ctx: &mut BlockContextManager,
420 mi_row: u32,
421 mi_col: u32,
422 bsize: BlockSize,
423 quant_params: &QuantizationParams,
424 ) -> CodecResult<()> {
425 let partition_ctx = block_ctx.get_partition_context(bsize);
427 let partition = symbol_decoder.read_partition(bsize, partition_ctx)?;
428
429 if partition.is_leaf() {
430 self.decode_block(
432 symbol_decoder,
433 block_ctx,
434 mi_row,
435 mi_col,
436 bsize,
437 quant_params,
438 )?;
439 } else {
440 self.decode_block(
443 symbol_decoder,
444 block_ctx,
445 mi_row,
446 mi_col,
447 bsize,
448 quant_params,
449 )?;
450 }
451
452 Ok(())
453 }
454
    /// Decodes one block's mode info, updates the spatial context, and
    /// decodes its coefficients when the block is not skipped.
    fn decode_block(
        &mut self,
        symbol_decoder: &mut SymbolDecoder,
        block_ctx: &mut BlockContextManager,
        mi_row: u32,
        mi_col: u32,
        bsize: BlockSize,
        quant_params: &QuantizationParams,
    ) -> CodecResult<()> {
        // Placeholder contexts: real values should be derived from the
        // above/left neighbors — TODO wire these through `block_ctx`.
        let skip_ctx = 0;
        let mode_ctx = 0;

        let mode_info = symbol_decoder.decode_block_mode(bsize, skip_ctx, mode_ctx)?;

        // Store the mode info for neighbor-context derivation before
        // decoding coefficients (clone needed: `mode_info` is used below).
        block_ctx.mode_info = mode_info.clone();
        block_ctx.update_context(bsize);

        if !mode_info.skip && symbol_decoder.has_more_data() {
            self.decode_block_coefficients(&mode_info, mi_row, mi_col, quant_params)?;
        }

        Ok(())
    }
482
    /// Decodes and inverse-transforms the coefficients of one block.
    ///
    /// NOTE(review): this is placeholder plumbing — it feeds a zeroed
    /// 128-byte buffer to the coefficient decoder instead of real tile
    /// data, always assumes DCT_DCT and 8-bit depth, iterates all three
    /// planes even for monochrome streams, and discards the residual.
    /// Confirm before relying on its output.
    fn decode_block_coefficients(
        &mut self,
        mode_info: &BlockModeInfo,
        _mi_row: u32,
        _mi_col: u32,
        quant_params: &QuantizationParams,
    ) -> CodecResult<()> {
        let tx_size = mode_info.tx_size;
        // Transform-type selection from the prediction mode is not
        // implemented yet; DCT in both directions is assumed.
        let tx_type = TxType::DctDct;

        for plane in 0..3 {
            // Placeholder input: real coefficient bits come from the tile.
            let coeff_data = vec![0u8; 128];
            let mut coeff_decoder = CoeffDecoder::new(coeff_data, quant_params.clone(), 8);

            let coeff_buffer =
                coeff_decoder.decode_coefficients(tx_size, tx_type, plane, mode_info.skip)?;

            let mut transform = Transform2D::new(tx_size, tx_type);
            let mut residual = vec![0i32; tx_size.area() as usize];
            // The residual is computed but not yet applied to prediction.
            transform.inverse(coeff_buffer.as_slice(), &mut residual);
        }

        Ok(())
    }
515
    /// Copies a reconstructed pipeline buffer into an output frame.
    ///
    /// NOTE(review): currently a stub — the frame planes keep whatever
    /// `allocate()` produced; plane copying is not implemented yet.
    fn copy_buffer_to_frame(
        &self,
        _buffer: &crate::reconstruct::FrameBuffer,
        _frame: &mut VideoFrame,
    ) -> CodecResult<()> {
        Ok(())
    }
525}
526
527impl VideoDecoder for Av1Decoder {
528 fn codec(&self) -> CodecId {
529 CodecId::Av1
530 }
531
532 fn send_packet(&mut self, data: &[u8], pts: i64) -> CodecResult<()> {
533 if self.flushing {
534 return Err(CodecError::InvalidParameter(
535 "Cannot send packet while flushing".to_string(),
536 ));
537 }
538 self.decode_temporal_unit(data, pts)
539 }
540
541 fn receive_frame(&mut self) -> CodecResult<Option<VideoFrame>> {
542 if self.output_queue.is_empty() {
543 if self.flushing {
544 return Err(CodecError::Eof);
545 }
546 return Ok(None);
547 }
548 Ok(Some(self.output_queue.remove(0)))
549 }
550
551 fn flush(&mut self) -> CodecResult<()> {
552 self.flushing = true;
553 Ok(())
554 }
555
556 fn reset(&mut self) {
557 self.output_queue.clear();
558 self.flushing = false;
559 self.frame_count = 0;
560 self.state.reset();
561 }
562
563 fn output_format(&self) -> Option<PixelFormat> {
564 self.sequence_header
565 .as_ref()
566 .map(Self::determine_pixel_format)
567 }
568
569 fn dimensions(&self) -> Option<(u32, u32)> {
570 self.sequence_header
571 .as_ref()
572 .map(|seq| (seq.max_frame_width(), seq.max_frame_height()))
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a decoder with default settings, panicking on failure.
    fn make_decoder() -> Av1Decoder {
        Av1Decoder::new(DecoderConfig::default()).expect("should succeed")
    }

    #[test]
    fn test_decoder_creation() {
        assert!(Av1Decoder::new(DecoderConfig::default()).is_ok());
    }

    #[test]
    fn test_decoder_codec_id() {
        assert_eq!(make_decoder().codec(), CodecId::Av1);
    }

    #[test]
    fn test_decoder_flush() {
        let mut decoder = make_decoder();
        assert!(decoder.flush().is_ok());
    }

    #[test]
    fn test_send_while_flushing() {
        let mut decoder = make_decoder();
        decoder.flush().expect("should succeed");
        assert!(decoder.send_packet(&[], 0).is_err());
    }

    #[test]
    fn test_decoder_reset() {
        let mut decoder = make_decoder();
        decoder.flush().expect("should succeed");
        decoder.reset();
        assert_eq!(decoder.frame_count(), 0);
        // After reset, packets are accepted again.
        assert!(decoder.send_packet(&[], 0).is_ok());
    }

    #[test]
    fn test_initial_state() {
        let decoder = make_decoder();
        assert!(decoder.current_frame_header().is_none());
        assert!(decoder.current_sequence_header().is_none());
        assert!(decoder.tile_info().is_none());
    }

    #[test]
    fn test_loop_filter_params() {
        assert!(!make_decoder().loop_filter_params().is_enabled());
    }

    #[test]
    fn test_cdef_params() {
        assert!(!make_decoder().cdef_params().is_enabled());
    }

    #[test]
    fn test_quantization_params() {
        assert_eq!(make_decoder().quantization_params().base_q_idx, 0);
    }
}