Skip to main content

oximedia_codec/av1/
decoder.rs

1//! AV1 decoder implementation.
2//!
3//! This module provides a complete AV1 decoder that uses the frame header,
4//! loop filter, CDEF, quantization, and tile parsing infrastructure.
5
6#![forbid(unsafe_code)]
7#![allow(dead_code)]
8#![allow(clippy::doc_markdown)]
9#![allow(clippy::unused_self)]
10#![allow(clippy::missing_errors_doc)]
11#![allow(clippy::match_same_arms)]
12
13use super::block::{BlockContextManager, BlockModeInfo, BlockSize};
14use super::cdef::CdefParams;
15use super::coeff_decode::CoeffDecoder;
16use super::frame_header::{FrameHeader, FrameType as Av1FrameType};
17use super::loop_filter::LoopFilterParams;
18use super::obu::{ObuIterator, ObuType};
19use super::prediction::PredictionEngine;
20use super::quantization::QuantizationParams;
21use super::sequence::SequenceHeader;
22use super::symbols::SymbolDecoder;
23use super::tile::TileInfo;
24use super::transform::{Transform2D, TxType};
25use crate::error::{CodecError, CodecResult};
26use crate::frame::{FrameType, VideoFrame};
27use crate::reconstruct::{DecoderPipeline, FrameContext, PipelineConfig, ReferenceFrameManager};
28use crate::traits::{DecoderConfig, VideoDecoder};
29use oximedia_core::{CodecId, PixelFormat, Rational, Timestamp};
30
/// AV1 decoder state.
///
/// Holds the per-frame parameters parsed from the most recent frame header.
/// Note that `reset` (see the impl below) clears only `frame_header` and
/// `tile_info`; filter/quantization parameters are simply overwritten when
/// the next frame header is parsed.
#[derive(Clone, Debug, Default)]
#[allow(dead_code)]
struct DecoderState {
    /// Current frame header (if parsed).
    frame_header: Option<FrameHeader>,
    /// Current loop filter parameters (from the last parsed frame header).
    loop_filter: LoopFilterParams,
    /// Current CDEF parameters (from the last parsed frame header).
    cdef: CdefParams,
    /// Current quantization parameters (from the last parsed frame header).
    quantization: QuantizationParams,
    /// Current tile info, if a frame header has been parsed.
    tile_info: Option<TileInfo>,
    /// Frame is intra-only (no inter prediction from references).
    frame_is_intra: bool,
}
48
49impl DecoderState {
50    /// Create a new decoder state.
51    fn new() -> Self {
52        Self::default()
53    }
54
55    /// Reset state for a new frame.
56    fn reset(&mut self) {
57        self.frame_header = None;
58        self.tile_info = None;
59    }
60}
61
/// AV1 decoder.
///
/// Implements the push/pull `VideoDecoder` interface: packets are sent via
/// `send_packet` and decoded frames are drained via `receive_frame`.
#[derive(Debug)]
pub struct Av1Decoder {
    /// Decoder configuration (may carry extradata with a sequence header).
    config: DecoderConfig,
    /// Current sequence header, once one has been parsed.
    sequence_header: Option<SequenceHeader>,
    /// Decoded frame output queue (drained front-first by `receive_frame`).
    output_queue: Vec<VideoFrame>,
    /// Decoder is in flush mode; no further packets accepted.
    flushing: bool,
    /// Number of frames decoded so far.
    frame_count: u64,
    /// Per-frame decoder state.
    state: DecoderState,
    /// Reconstruction pipeline (lazily created from the sequence header).
    pipeline: Option<DecoderPipeline>,
    /// Reference frame manager.
    ref_manager: ReferenceFrameManager,
    /// Prediction engine (lazily created alongside the pipeline).
    prediction: Option<PredictionEngine>,
    /// Block context manager (lazily created alongside the pipeline).
    block_context: Option<BlockContextManager>,
}
86
87impl Av1Decoder {
88    /// Create a new AV1 decoder.
89    ///
90    /// # Errors
91    ///
92    /// Returns error if decoder initialization fails.
93    pub fn new(config: DecoderConfig) -> CodecResult<Self> {
94        let mut decoder = Self {
95            config,
96            sequence_header: None,
97            output_queue: Vec::new(),
98            flushing: false,
99            frame_count: 0,
100            state: DecoderState::new(),
101            pipeline: None,
102            ref_manager: ReferenceFrameManager::new(),
103            prediction: None,
104            block_context: None,
105        };
106
107        if let Some(extradata) = decoder.config.extradata.clone() {
108            decoder.parse_extradata(&extradata)?;
109        }
110
111        Ok(decoder)
112    }
113
114    /// Parse codec extradata.
115    fn parse_extradata(&mut self, data: &[u8]) -> CodecResult<()> {
116        for obu_result in ObuIterator::new(data) {
117            let (header, payload) = obu_result?;
118            if header.obu_type == ObuType::SequenceHeader {
119                self.sequence_header = Some(SequenceHeader::parse(payload)?);
120                break;
121            }
122        }
123        Ok(())
124    }
125
126    /// Decode a temporal unit.
127    #[allow(clippy::too_many_lines)]
128    fn decode_temporal_unit(&mut self, data: &[u8], pts: i64) -> CodecResult<()> {
129        // Reset state for new frame
130        self.state.reset();
131
132        for obu_result in ObuIterator::new(data) {
133            let (header, payload) = obu_result?;
134
135            match header.obu_type {
136                ObuType::SequenceHeader => {
137                    self.sequence_header = Some(SequenceHeader::parse(payload)?);
138                }
139                ObuType::FrameHeader | ObuType::Frame => {
140                    if let Some(ref seq) = self.sequence_header {
141                        // Parse frame header using the new infrastructure
142                        let frame_header = FrameHeader::parse(payload, seq)?;
143
144                        // Store parsed state
145                        self.state.frame_is_intra = frame_header.frame_is_intra;
146                        self.state.loop_filter = frame_header.loop_filter.clone();
147                        self.state.cdef = frame_header.cdef.clone();
148                        self.state.quantization = frame_header.quantization.clone();
149                        self.state.tile_info = Some(frame_header.tile_info.clone());
150                        self.state.frame_header = Some(frame_header.clone());
151
152                        // Create output frame
153                        let format = Self::determine_pixel_format(seq);
154                        let width = frame_header.frame_size.upscaled_width;
155                        let height = frame_header.frame_size.frame_height;
156
157                        let mut frame = VideoFrame::new(
158                            format,
159                            if width > 0 {
160                                width
161                            } else {
162                                seq.max_frame_width()
163                            },
164                            if height > 0 {
165                                height
166                            } else {
167                                seq.max_frame_height()
168                            },
169                        );
170                        frame.allocate();
171                        frame.timestamp = Timestamp::new(pts, Rational::new(1, 1000));
172
173                        // Determine frame type from AV1 frame type
174                        frame.frame_type = match frame_header.frame_type {
175                            Av1FrameType::KeyFrame => FrameType::Key,
176                            Av1FrameType::InterFrame => FrameType::Inter,
177                            Av1FrameType::IntraOnlyFrame => FrameType::Key, // Treat as key for display
178                            Av1FrameType::SwitchFrame => FrameType::Inter,  // Switch frame is inter
179                        };
180
181                        self.output_queue.push(frame);
182                        self.frame_count += 1;
183                    }
184                }
185                ObuType::TileGroup => {
186                    // Tile group data would be processed here
187                    // For now, we've already created the frame in FrameHeader handling
188                }
189                _ => {}
190            }
191        }
192
193        Ok(())
194    }
195
196    /// Determine pixel format from sequence header.
197    fn determine_pixel_format(seq: &SequenceHeader) -> PixelFormat {
198        let cc = &seq.color_config;
199        if cc.mono_chrome {
200            return PixelFormat::Gray8;
201        }
202        match (cc.bit_depth, cc.subsampling_x, cc.subsampling_y) {
203            (8, true, false) => PixelFormat::Yuv422p,
204            (8, false, false) => PixelFormat::Yuv444p,
205            (10, true, true) => PixelFormat::Yuv420p10le,
206            (12, true, true) => PixelFormat::Yuv420p12le,
207            // Default to YUV420p for 8-bit 4:2:0 and any other unhandled cases
208            _ => PixelFormat::Yuv420p,
209        }
210    }
211
212    /// Get the current frame header if available.
213    #[must_use]
214    #[allow(dead_code)]
215    pub fn current_frame_header(&self) -> Option<&FrameHeader> {
216        self.state.frame_header.as_ref()
217    }
218
219    /// Get the current sequence header if available.
220    #[must_use]
221    #[allow(dead_code)]
222    pub fn current_sequence_header(&self) -> Option<&SequenceHeader> {
223        self.sequence_header.as_ref()
224    }
225
226    /// Get the current loop filter parameters.
227    #[must_use]
228    #[allow(dead_code)]
229    pub fn loop_filter_params(&self) -> &LoopFilterParams {
230        &self.state.loop_filter
231    }
232
233    /// Get the current CDEF parameters.
234    #[must_use]
235    #[allow(dead_code)]
236    pub fn cdef_params(&self) -> &CdefParams {
237        &self.state.cdef
238    }
239
240    /// Get the current quantization parameters.
241    #[must_use]
242    #[allow(dead_code)]
243    pub fn quantization_params(&self) -> &QuantizationParams {
244        &self.state.quantization
245    }
246
247    /// Get the current tile info if available.
248    #[must_use]
249    #[allow(dead_code)]
250    pub fn tile_info(&self) -> Option<&TileInfo> {
251        self.state.tile_info.as_ref()
252    }
253
254    /// Get decoded frame count.
255    #[must_use]
256    #[allow(dead_code)]
257    pub const fn frame_count(&self) -> u64 {
258        self.frame_count
259    }
260
261    /// Initialize pipeline from sequence header.
262    fn initialize_pipeline(&mut self, seq: &SequenceHeader) -> CodecResult<()> {
263        let width = seq.max_frame_width();
264        let height = seq.max_frame_height();
265        let bit_depth = seq.color_config.bit_depth;
266
267        // Create pipeline config
268        let pipeline_config = PipelineConfig::new(width, height)
269            .with_bit_depth(bit_depth)
270            .with_all_filters();
271
272        // Create pipeline
273        self.pipeline = Some(
274            DecoderPipeline::new(pipeline_config)
275                .map_err(|e| CodecError::Internal(format!("Pipeline creation failed: {e:?}")))?,
276        );
277
278        // Create prediction engine
279        self.prediction = Some(PredictionEngine::new(width, height, bit_depth));
280
281        // Create block context manager
282        self.block_context = Some(BlockContextManager::new(
283            width / 4,
284            seq.color_config.subsampling_x,
285            seq.color_config.subsampling_y,
286        ));
287
288        Ok(())
289    }
290
291    /// Decode a frame with full pipeline.
292    fn decode_frame_with_pipeline(
293        &mut self,
294        frame_header: &FrameHeader,
295        tile_data: &[u8],
296    ) -> CodecResult<VideoFrame> {
297        // Ensure pipeline is initialized
298        if self.pipeline.is_none() {
299            if let Some(seq) = self.sequence_header.clone() {
300                self.initialize_pipeline(&seq)?;
301            } else {
302                return Err(CodecError::InvalidData("No sequence header".to_string()));
303            }
304        }
305
306        // Create frame context
307        let mut frame_ctx = FrameContext::new(
308            frame_header.frame_size.upscaled_width,
309            frame_header.frame_size.frame_height,
310        );
311        frame_ctx.decode_order = self.frame_count;
312        frame_ctx.display_order = self.frame_count;
313        frame_ctx.is_keyframe = matches!(frame_header.frame_type, Av1FrameType::KeyFrame);
314        frame_ctx.show_frame = frame_header.show_frame;
315        frame_ctx.bit_depth = frame_header.quantization.base_q_idx as u8;
316
317        // Decode tiles and reconstruct
318        self.decode_tiles(tile_data, frame_header, &frame_ctx)?;
319
320        // Get output from pipeline
321        if let Some(ref mut pipeline) = self.pipeline {
322            let buffer = pipeline
323                .process_frame(tile_data, &frame_ctx)
324                .map_err(|e| CodecError::Internal(format!("Pipeline processing failed: {e:?}")))?;
325
326            // Convert buffer to VideoFrame
327            let format = self
328                .sequence_header
329                .as_ref()
330                .map(Self::determine_pixel_format)
331                .unwrap_or(PixelFormat::Yuv420p);
332
333            let mut frame = VideoFrame::new(
334                format,
335                frame_header.frame_size.upscaled_width,
336                frame_header.frame_size.frame_height,
337            );
338            frame.allocate();
339
340            // Copy buffer data to frame (simplified)
341            self.copy_buffer_to_frame(&buffer, &mut frame)?;
342
343            // Set frame metadata
344            frame.frame_type = match frame_header.frame_type {
345                Av1FrameType::KeyFrame => FrameType::Key,
346                _ => FrameType::Inter,
347            };
348
349            Ok(frame)
350        } else {
351            Err(CodecError::Internal("Pipeline not initialized".to_string()))
352        }
353    }
354
355    /// Decode tiles using symbol decoder.
356    fn decode_tiles(
357        &mut self,
358        tile_data: &[u8],
359        frame_header: &FrameHeader,
360        _frame_ctx: &FrameContext,
361    ) -> CodecResult<()> {
362        let frame_is_intra = frame_header.frame_is_intra;
363
364        // Create symbol decoder
365        let mut symbol_decoder = SymbolDecoder::new(tile_data.to_vec(), frame_is_intra);
366
367        // Take ownership temporarily to avoid borrow issues
368        if let Some(mut block_ctx) = self.block_context.take() {
369            self.decode_superblocks(&mut symbol_decoder, frame_header, &mut block_ctx)?;
370            // Put it back
371            self.block_context = Some(block_ctx);
372        }
373
374        Ok(())
375    }
376
377    /// Decode superblocks.
378    fn decode_superblocks(
379        &mut self,
380        symbol_decoder: &mut SymbolDecoder,
381        frame_header: &FrameHeader,
382        block_ctx: &mut BlockContextManager,
383    ) -> CodecResult<()> {
384        let sb_size = BlockSize::Block64x64; // or 128x128 based on sequence header
385        let frame_width = frame_header.frame_size.upscaled_width;
386        let frame_height = frame_header.frame_size.frame_height;
387
388        let sb_cols = (frame_width + sb_size.width() - 1) / sb_size.width();
389        let sb_rows = (frame_height + sb_size.height() - 1) / sb_size.height();
390
391        for sb_row in 0..sb_rows {
392            block_ctx.reset_left_context();
393
394            for sb_col in 0..sb_cols {
395                let mi_row = sb_row * (sb_size.height() / 4);
396                let mi_col = sb_col * (sb_size.width() / 4);
397
398                block_ctx.set_position(mi_row, mi_col, sb_size);
399
400                // Decode partition tree
401                self.decode_partition_tree(
402                    symbol_decoder,
403                    block_ctx,
404                    mi_row,
405                    mi_col,
406                    sb_size,
407                    &frame_header.quantization,
408                )?;
409            }
410        }
411
412        Ok(())
413    }
414
415    /// Decode partition tree recursively.
416    fn decode_partition_tree(
417        &mut self,
418        symbol_decoder: &mut SymbolDecoder,
419        block_ctx: &mut BlockContextManager,
420        mi_row: u32,
421        mi_col: u32,
422        bsize: BlockSize,
423        quant_params: &QuantizationParams,
424    ) -> CodecResult<()> {
425        // Read partition
426        let partition_ctx = block_ctx.get_partition_context(bsize);
427        let partition = symbol_decoder.read_partition(bsize, partition_ctx)?;
428
429        if partition.is_leaf() {
430            // Leaf block: decode mode and coefficients
431            self.decode_block(
432                symbol_decoder,
433                block_ctx,
434                mi_row,
435                mi_col,
436                bsize,
437                quant_params,
438            )?;
439        } else {
440            // Recursive partition: decode sub-blocks
441            // (Simplified: just decode as leaf for now)
442            self.decode_block(
443                symbol_decoder,
444                block_ctx,
445                mi_row,
446                mi_col,
447                bsize,
448                quant_params,
449            )?;
450        }
451
452        Ok(())
453    }
454
455    /// Decode a single block.
456    fn decode_block(
457        &mut self,
458        symbol_decoder: &mut SymbolDecoder,
459        block_ctx: &mut BlockContextManager,
460        mi_row: u32,
461        mi_col: u32,
462        bsize: BlockSize,
463        quant_params: &QuantizationParams,
464    ) -> CodecResult<()> {
465        let skip_ctx = 0; // Would compute from neighbors
466        let mode_ctx = 0;
467
468        // Decode block mode
469        let mode_info = symbol_decoder.decode_block_mode(bsize, skip_ctx, mode_ctx)?;
470
471        // Store mode info in context
472        block_ctx.mode_info = mode_info.clone();
473        block_ctx.update_context(bsize);
474
475        // Decode coefficients if not skipped
476        if !mode_info.skip && symbol_decoder.has_more_data() {
477            self.decode_block_coefficients(&mode_info, mi_row, mi_col, quant_params)?;
478        }
479
480        Ok(())
481    }
482
483    /// Decode coefficients for a block.
484    fn decode_block_coefficients(
485        &mut self,
486        mode_info: &BlockModeInfo,
487        _mi_row: u32,
488        _mi_col: u32,
489        quant_params: &QuantizationParams,
490    ) -> CodecResult<()> {
491        let tx_size = mode_info.tx_size;
492        // Default to DCT_DCT for transform type
493        let tx_type = TxType::DctDct;
494
495        // Decode and transform coefficients for each plane
496        for plane in 0..3 {
497            // Create coefficient decoder (would use actual tile data)
498            let coeff_data = vec![0u8; 128]; // Placeholder
499            let mut coeff_decoder = CoeffDecoder::new(coeff_data, quant_params.clone(), 8);
500
501            // Decode coefficients
502            let coeff_buffer =
503                coeff_decoder.decode_coefficients(tx_size, tx_type, plane, mode_info.skip)?;
504
505            // Apply inverse transform
506            let mut transform = Transform2D::new(tx_size, tx_type);
507            let mut residual = vec![0i32; tx_size.area() as usize];
508            transform.inverse(coeff_buffer.as_slice(), &mut residual);
509
510            // Add to prediction (would integrate with prediction engine)
511        }
512
513        Ok(())
514    }
515
516    /// Copy buffer to frame.
517    fn copy_buffer_to_frame(
518        &self,
519        _buffer: &crate::reconstruct::FrameBuffer,
520        _frame: &mut VideoFrame,
521    ) -> CodecResult<()> {
522        // Simplified: buffer copying would happen here
523        Ok(())
524    }
525}
526
527impl VideoDecoder for Av1Decoder {
528    fn codec(&self) -> CodecId {
529        CodecId::Av1
530    }
531
532    fn send_packet(&mut self, data: &[u8], pts: i64) -> CodecResult<()> {
533        if self.flushing {
534            return Err(CodecError::InvalidParameter(
535                "Cannot send packet while flushing".to_string(),
536            ));
537        }
538        self.decode_temporal_unit(data, pts)
539    }
540
541    fn receive_frame(&mut self) -> CodecResult<Option<VideoFrame>> {
542        if self.output_queue.is_empty() {
543            if self.flushing {
544                return Err(CodecError::Eof);
545            }
546            return Ok(None);
547        }
548        Ok(Some(self.output_queue.remove(0)))
549    }
550
551    fn flush(&mut self) -> CodecResult<()> {
552        self.flushing = true;
553        Ok(())
554    }
555
556    fn reset(&mut self) {
557        self.output_queue.clear();
558        self.flushing = false;
559        self.frame_count = 0;
560        self.state.reset();
561    }
562
563    fn output_format(&self) -> Option<PixelFormat> {
564        self.sequence_header
565            .as_ref()
566            .map(Self::determine_pixel_format)
567    }
568
569    fn dimensions(&self) -> Option<(u32, u32)> {
570        self.sequence_header
571            .as_ref()
572            .map(|seq| (seq.max_frame_width(), seq.max_frame_height()))
573    }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a decoder from the default configuration, panicking on failure.
    fn make_decoder() -> Av1Decoder {
        Av1Decoder::new(DecoderConfig::default()).expect("should succeed")
    }

    #[test]
    fn test_decoder_creation() {
        assert!(Av1Decoder::new(DecoderConfig::default()).is_ok());
    }

    #[test]
    fn test_decoder_codec_id() {
        assert_eq!(make_decoder().codec(), CodecId::Av1);
    }

    #[test]
    fn test_decoder_flush() {
        let mut decoder = make_decoder();
        assert!(decoder.flush().is_ok());
    }

    #[test]
    fn test_send_while_flushing() {
        let mut decoder = make_decoder();
        decoder.flush().expect("should succeed");
        // Packets must be rejected once the decoder is draining.
        assert!(decoder.send_packet(&[], 0).is_err());
    }

    #[test]
    fn test_decoder_reset() {
        let mut decoder = make_decoder();
        decoder.flush().expect("should succeed");
        decoder.reset();
        // Reset clears the frame counter and re-enables input.
        assert_eq!(decoder.frame_count(), 0);
        assert!(decoder.send_packet(&[], 0).is_ok());
    }

    #[test]
    fn test_initial_state() {
        let decoder = make_decoder();
        assert!(decoder.current_frame_header().is_none());
        assert!(decoder.current_sequence_header().is_none());
        assert!(decoder.tile_info().is_none());
    }

    #[test]
    fn test_loop_filter_params() {
        let decoder = make_decoder();
        assert!(!decoder.loop_filter_params().is_enabled());
    }

    #[test]
    fn test_cdef_params() {
        let decoder = make_decoder();
        assert!(!decoder.cdef_params().is_enabled());
    }

    #[test]
    fn test_quantization_params() {
        let decoder = make_decoder();
        assert_eq!(decoder.quantization_params().base_q_idx, 0);
    }
}