#![forbid(unsafe_code)]
#![allow(dead_code)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::unused_self)]
#![allow(clippy::missing_errors_doc)]
#![allow(clippy::match_same_arms)]
use super::block::{BlockContextManager, BlockModeInfo, BlockSize};
use super::cdef::CdefParams;
use super::coeff_decode::CoeffDecoder;
use super::frame_header::{FrameHeader, FrameType as Av1FrameType};
use super::loop_filter::LoopFilterParams;
use super::obu::{ObuIterator, ObuType};
use super::prediction::PredictionEngine;
use super::quantization::QuantizationParams;
use super::sequence::SequenceHeader;
use super::symbols::SymbolDecoder;
use super::tile::TileInfo;
use super::transform::{Transform2D, TxType};
use crate::error::{CodecError, CodecResult};
use crate::frame::{FrameType, VideoFrame};
use crate::reconstruct::{DecoderPipeline, FrameContext, PipelineConfig, ReferenceFrameManager};
use crate::traits::{DecoderConfig, VideoDecoder};
use oximedia_core::{CodecId, PixelFormat, Rational, Timestamp};
#[derive(Clone, Debug, Default)]
#[allow(dead_code)]
struct DecoderState {
frame_header: Option<FrameHeader>,
loop_filter: LoopFilterParams,
cdef: CdefParams,
quantization: QuantizationParams,
tile_info: Option<TileInfo>,
frame_is_intra: bool,
}
impl DecoderState {
fn new() -> Self {
Self::default()
}
fn reset(&mut self) {
self.frame_header = None;
self.tile_info = None;
}
}
#[derive(Debug)]
pub struct Av1Decoder {
config: DecoderConfig,
sequence_header: Option<SequenceHeader>,
output_queue: Vec<VideoFrame>,
flushing: bool,
frame_count: u64,
state: DecoderState,
pipeline: Option<DecoderPipeline>,
ref_manager: ReferenceFrameManager,
prediction: Option<PredictionEngine>,
block_context: Option<BlockContextManager>,
}
impl Av1Decoder {
pub fn new(config: DecoderConfig) -> CodecResult<Self> {
let mut decoder = Self {
config,
sequence_header: None,
output_queue: Vec::new(),
flushing: false,
frame_count: 0,
state: DecoderState::new(),
pipeline: None,
ref_manager: ReferenceFrameManager::new(),
prediction: None,
block_context: None,
};
if let Some(extradata) = decoder.config.extradata.clone() {
decoder.parse_extradata(&extradata)?;
}
Ok(decoder)
}
fn parse_extradata(&mut self, data: &[u8]) -> CodecResult<()> {
for obu_result in ObuIterator::new(data) {
let (header, payload) = obu_result?;
if header.obu_type == ObuType::SequenceHeader {
self.sequence_header = Some(SequenceHeader::parse(payload)?);
break;
}
}
Ok(())
}
#[allow(clippy::too_many_lines)]
fn decode_temporal_unit(&mut self, data: &[u8], pts: i64) -> CodecResult<()> {
self.state.reset();
for obu_result in ObuIterator::new(data) {
let (header, payload) = obu_result?;
match header.obu_type {
ObuType::SequenceHeader => {
self.sequence_header = Some(SequenceHeader::parse(payload)?);
}
ObuType::FrameHeader | ObuType::Frame => {
if let Some(ref seq) = self.sequence_header {
let frame_header = FrameHeader::parse(payload, seq)?;
self.state.frame_is_intra = frame_header.frame_is_intra;
self.state.loop_filter = frame_header.loop_filter.clone();
self.state.cdef = frame_header.cdef.clone();
self.state.quantization = frame_header.quantization.clone();
self.state.tile_info = Some(frame_header.tile_info.clone());
self.state.frame_header = Some(frame_header.clone());
let format = Self::determine_pixel_format(seq);
let width = frame_header.frame_size.upscaled_width;
let height = frame_header.frame_size.frame_height;
let mut frame = VideoFrame::new(
format,
if width > 0 {
width
} else {
seq.max_frame_width()
},
if height > 0 {
height
} else {
seq.max_frame_height()
},
);
frame.allocate();
frame.timestamp = Timestamp::new(pts, Rational::new(1, 1000));
frame.frame_type = match frame_header.frame_type {
Av1FrameType::KeyFrame => FrameType::Key,
Av1FrameType::InterFrame => FrameType::Inter,
Av1FrameType::IntraOnlyFrame => FrameType::Key, Av1FrameType::SwitchFrame => FrameType::Inter, };
self.output_queue.push(frame);
self.frame_count += 1;
}
}
ObuType::TileGroup => {
}
_ => {}
}
}
Ok(())
}
fn determine_pixel_format(seq: &SequenceHeader) -> PixelFormat {
let cc = &seq.color_config;
if cc.mono_chrome {
return PixelFormat::Gray8;
}
match (cc.bit_depth, cc.subsampling_x, cc.subsampling_y) {
(8, true, false) => PixelFormat::Yuv422p,
(8, false, false) => PixelFormat::Yuv444p,
(10, true, true) => PixelFormat::Yuv420p10le,
(12, true, true) => PixelFormat::Yuv420p12le,
_ => PixelFormat::Yuv420p,
}
}
#[must_use]
#[allow(dead_code)]
pub fn current_frame_header(&self) -> Option<&FrameHeader> {
self.state.frame_header.as_ref()
}
#[must_use]
#[allow(dead_code)]
pub fn current_sequence_header(&self) -> Option<&SequenceHeader> {
self.sequence_header.as_ref()
}
#[must_use]
#[allow(dead_code)]
pub fn loop_filter_params(&self) -> &LoopFilterParams {
&self.state.loop_filter
}
#[must_use]
#[allow(dead_code)]
pub fn cdef_params(&self) -> &CdefParams {
&self.state.cdef
}
#[must_use]
#[allow(dead_code)]
pub fn quantization_params(&self) -> &QuantizationParams {
&self.state.quantization
}
#[must_use]
#[allow(dead_code)]
pub fn tile_info(&self) -> Option<&TileInfo> {
self.state.tile_info.as_ref()
}
#[must_use]
#[allow(dead_code)]
pub const fn frame_count(&self) -> u64 {
self.frame_count
}
fn initialize_pipeline(&mut self, seq: &SequenceHeader) -> CodecResult<()> {
let width = seq.max_frame_width();
let height = seq.max_frame_height();
let bit_depth = seq.color_config.bit_depth;
let pipeline_config = PipelineConfig::new(width, height)
.with_bit_depth(bit_depth)
.with_all_filters();
self.pipeline = Some(
DecoderPipeline::new(pipeline_config)
.map_err(|e| CodecError::Internal(format!("Pipeline creation failed: {e:?}")))?,
);
self.prediction = Some(PredictionEngine::new(width, height, bit_depth));
self.block_context = Some(BlockContextManager::new(
width / 4,
seq.color_config.subsampling_x,
seq.color_config.subsampling_y,
));
Ok(())
}
fn decode_frame_with_pipeline(
&mut self,
frame_header: &FrameHeader,
tile_data: &[u8],
) -> CodecResult<VideoFrame> {
if self.pipeline.is_none() {
if let Some(seq) = self.sequence_header.clone() {
self.initialize_pipeline(&seq)?;
} else {
return Err(CodecError::InvalidData("No sequence header".to_string()));
}
}
let mut frame_ctx = FrameContext::new(
frame_header.frame_size.upscaled_width,
frame_header.frame_size.frame_height,
);
frame_ctx.decode_order = self.frame_count;
frame_ctx.display_order = self.frame_count;
frame_ctx.is_keyframe = matches!(frame_header.frame_type, Av1FrameType::KeyFrame);
frame_ctx.show_frame = frame_header.show_frame;
frame_ctx.bit_depth = frame_header.quantization.base_q_idx as u8;
self.decode_tiles(tile_data, frame_header, &frame_ctx)?;
if let Some(ref mut pipeline) = self.pipeline {
let buffer = pipeline
.process_frame(tile_data, &frame_ctx)
.map_err(|e| CodecError::Internal(format!("Pipeline processing failed: {e:?}")))?;
let format = self
.sequence_header
.as_ref()
.map(Self::determine_pixel_format)
.unwrap_or(PixelFormat::Yuv420p);
let mut frame = VideoFrame::new(
format,
frame_header.frame_size.upscaled_width,
frame_header.frame_size.frame_height,
);
frame.allocate();
self.copy_buffer_to_frame(&buffer, &mut frame)?;
frame.frame_type = match frame_header.frame_type {
Av1FrameType::KeyFrame => FrameType::Key,
_ => FrameType::Inter,
};
Ok(frame)
} else {
Err(CodecError::Internal("Pipeline not initialized".to_string()))
}
}
fn decode_tiles(
&mut self,
tile_data: &[u8],
frame_header: &FrameHeader,
_frame_ctx: &FrameContext,
) -> CodecResult<()> {
let frame_is_intra = frame_header.frame_is_intra;
let mut symbol_decoder = SymbolDecoder::new(tile_data.to_vec(), frame_is_intra);
if let Some(mut block_ctx) = self.block_context.take() {
self.decode_superblocks(&mut symbol_decoder, frame_header, &mut block_ctx)?;
self.block_context = Some(block_ctx);
}
Ok(())
}
fn decode_superblocks(
&mut self,
symbol_decoder: &mut SymbolDecoder,
frame_header: &FrameHeader,
block_ctx: &mut BlockContextManager,
) -> CodecResult<()> {
let sb_size = BlockSize::Block64x64; let frame_width = frame_header.frame_size.upscaled_width;
let frame_height = frame_header.frame_size.frame_height;
let sb_cols = (frame_width + sb_size.width() - 1) / sb_size.width();
let sb_rows = (frame_height + sb_size.height() - 1) / sb_size.height();
for sb_row in 0..sb_rows {
block_ctx.reset_left_context();
for sb_col in 0..sb_cols {
let mi_row = sb_row * (sb_size.height() / 4);
let mi_col = sb_col * (sb_size.width() / 4);
block_ctx.set_position(mi_row, mi_col, sb_size);
self.decode_partition_tree(
symbol_decoder,
block_ctx,
mi_row,
mi_col,
sb_size,
&frame_header.quantization,
)?;
}
}
Ok(())
}
fn decode_partition_tree(
&mut self,
symbol_decoder: &mut SymbolDecoder,
block_ctx: &mut BlockContextManager,
mi_row: u32,
mi_col: u32,
bsize: BlockSize,
quant_params: &QuantizationParams,
) -> CodecResult<()> {
let partition_ctx = block_ctx.get_partition_context(bsize);
let partition = symbol_decoder.read_partition(bsize, partition_ctx)?;
if partition.is_leaf() {
self.decode_block(
symbol_decoder,
block_ctx,
mi_row,
mi_col,
bsize,
quant_params,
)?;
} else {
self.decode_block(
symbol_decoder,
block_ctx,
mi_row,
mi_col,
bsize,
quant_params,
)?;
}
Ok(())
}
fn decode_block(
&mut self,
symbol_decoder: &mut SymbolDecoder,
block_ctx: &mut BlockContextManager,
mi_row: u32,
mi_col: u32,
bsize: BlockSize,
quant_params: &QuantizationParams,
) -> CodecResult<()> {
let skip_ctx = 0; let mode_ctx = 0;
let mode_info = symbol_decoder.decode_block_mode(bsize, skip_ctx, mode_ctx)?;
block_ctx.mode_info = mode_info.clone();
block_ctx.update_context(bsize);
if !mode_info.skip && symbol_decoder.has_more_data() {
self.decode_block_coefficients(&mode_info, mi_row, mi_col, quant_params)?;
}
Ok(())
}
fn decode_block_coefficients(
&mut self,
mode_info: &BlockModeInfo,
_mi_row: u32,
_mi_col: u32,
quant_params: &QuantizationParams,
) -> CodecResult<()> {
let tx_size = mode_info.tx_size;
let tx_type = TxType::DctDct;
for plane in 0..3 {
let coeff_data = vec![0u8; 128]; let mut coeff_decoder = CoeffDecoder::new(coeff_data, quant_params.clone(), 8);
let coeff_buffer =
coeff_decoder.decode_coefficients(tx_size, tx_type, plane, mode_info.skip)?;
let mut transform = Transform2D::new(tx_size, tx_type);
let mut residual = vec![0i32; tx_size.area() as usize];
transform.inverse(coeff_buffer.as_slice(), &mut residual);
}
Ok(())
}
fn copy_buffer_to_frame(
&self,
_buffer: &crate::reconstruct::FrameBuffer,
_frame: &mut VideoFrame,
) -> CodecResult<()> {
Ok(())
}
}
impl VideoDecoder for Av1Decoder {
fn codec(&self) -> CodecId {
CodecId::Av1
}
fn send_packet(&mut self, data: &[u8], pts: i64) -> CodecResult<()> {
if self.flushing {
return Err(CodecError::InvalidParameter(
"Cannot send packet while flushing".to_string(),
));
}
self.decode_temporal_unit(data, pts)
}
fn receive_frame(&mut self) -> CodecResult<Option<VideoFrame>> {
if self.output_queue.is_empty() {
if self.flushing {
return Err(CodecError::Eof);
}
return Ok(None);
}
Ok(Some(self.output_queue.remove(0)))
}
fn flush(&mut self) -> CodecResult<()> {
self.flushing = true;
Ok(())
}
fn reset(&mut self) {
self.output_queue.clear();
self.flushing = false;
self.frame_count = 0;
self.state.reset();
}
fn output_format(&self) -> Option<PixelFormat> {
self.sequence_header
.as_ref()
.map(Self::determine_pixel_format)
}
fn dimensions(&self) -> Option<(u32, u32)> {
self.sequence_header
.as_ref()
.map(|seq| (seq.max_frame_width(), seq.max_frame_height()))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_decoder_creation() {
let config = DecoderConfig::default();
let decoder = Av1Decoder::new(config);
assert!(decoder.is_ok());
}
#[test]
fn test_decoder_codec_id() {
let config = DecoderConfig::default();
let decoder = Av1Decoder::new(config).expect("should succeed");
assert_eq!(decoder.codec(), CodecId::Av1);
}
#[test]
fn test_decoder_flush() {
let config = DecoderConfig::default();
let mut decoder = Av1Decoder::new(config).expect("should succeed");
assert!(decoder.flush().is_ok());
}
#[test]
fn test_send_while_flushing() {
let config = DecoderConfig::default();
let mut decoder = Av1Decoder::new(config).expect("should succeed");
decoder.flush().expect("should succeed");
let result = decoder.send_packet(&[], 0);
assert!(result.is_err());
}
#[test]
fn test_decoder_reset() {
let config = DecoderConfig::default();
let mut decoder = Av1Decoder::new(config).expect("should succeed");
decoder.flush().expect("should succeed");
decoder.reset();
assert_eq!(decoder.frame_count(), 0);
assert!(decoder.send_packet(&[], 0).is_ok());
}
#[test]
fn test_initial_state() {
let config = DecoderConfig::default();
let decoder = Av1Decoder::new(config).expect("should succeed");
assert!(decoder.current_frame_header().is_none());
assert!(decoder.current_sequence_header().is_none());
assert!(decoder.tile_info().is_none());
}
#[test]
fn test_loop_filter_params() {
let config = DecoderConfig::default();
let decoder = Av1Decoder::new(config).expect("should succeed");
let lf = decoder.loop_filter_params();
assert!(!lf.is_enabled());
}
#[test]
fn test_cdef_params() {
let config = DecoderConfig::default();
let decoder = Av1Decoder::new(config).expect("should succeed");
let cdef = decoder.cdef_params();
assert!(!cdef.is_enabled());
}
#[test]
fn test_quantization_params() {
let config = DecoderConfig::default();
let decoder = Av1Decoder::new(config).expect("should succeed");
let qp = decoder.quantization_params();
assert_eq!(qp.base_q_idx, 0);
}
}