Skip to main content

oximedia_codec/av1/
parallel_tile_decoder.rs

1//! Parallel tile decoding for AV1 frames using rayon.
2//!
3//! AV1 divides frames into independent rectangular tiles that can be decoded
4//! concurrently.  This module provides [`ParallelTileDecoder`] which splits
5//! a raw frame buffer into [`TileJob`]s, decodes them in parallel via rayon,
6//! and re-assembles the results into a planar YUV 4:2:0 output buffer.
7//!
8//! # Structural Implementation
9//!
10//! A full AV1 tile decode requires entropy decoding, prediction, inverse
11//! transforms, and loop filtering — all of which depend on codec state that
12//! lives in the outer decoder.  This module provides the *structural*
13//! scaffolding: correct splitting, parallel dispatch, and frame assembly.
14//! Each tile's "decode" step currently copies the raw tile bytes as a
15//! stand-in for the real decode pass.
16//!
17//! # Example
18//!
19//! ```rust
20//! use oximedia_codec::av1::{ParallelTileDecoder, TileJob};
21//!
22//! let decoder = ParallelTileDecoder::new(1920, 1080, 4, 2);
23//! let frame_data = vec![0u8; 1920 * 1080]; // synthetic luma
24//! let tiles = decoder.split_into_tiles(&frame_data);
25//! assert_eq!(tiles.len(), 8);
26//! let output = decoder.decode_tiles_parallel(tiles).expect("decode");
27//! assert_eq!(output.len(), 1920 * 1080 * 3 / 2);
28//! ```
29
30#![forbid(unsafe_code)]
31#![allow(clippy::cast_possible_truncation)]
32#![allow(clippy::missing_errors_doc)]
33
34use rayon::prelude::*;
35
36use crate::error::{CodecError, CodecResult};
37
38// =============================================================================
39// Public types
40// =============================================================================
41
/// A single tile extracted from a frame, ready for independent decoding.
///
/// Produced by [`ParallelTileDecoder::split_into_tiles`] and consumed by
/// [`ParallelTileDecoder::decode_tiles_parallel`] and
/// [`ParallelTileDecoder::assemble_frame`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TileJob {
    /// Tile row index (0-based, top to bottom).
    pub tile_row: u32,
    /// Tile column index (0-based, left to right).
    pub tile_col: u32,
    /// Raw tile data: the luma bytes of this tile region.  When produced by
    /// `split_into_tiles` this is row-major with `tile_size.0` bytes per row
    /// and `tile_size.1` rows.
    pub tile_data: Vec<u8>,
    /// Pixel offset of this tile's top-left corner in the frame: `(x, y)`.
    pub tile_offset: (u32, u32),
    /// Pixel dimensions of this tile: `(width, height)`.
    pub tile_size: (u32, u32),
}
56
/// Parallel tile decoder for AV1 frames.
///
/// Divides a frame into a grid of `tile_cols × tile_rows` tiles, decodes them
/// concurrently with rayon, and assembles the results into a planar YUV 4:2:0
/// buffer.  The struct only stores the frame and grid geometry; it holds no
/// decoding state between calls.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ParallelTileDecoder {
    /// Frame width in pixels.
    pub frame_width: u32,
    /// Frame height in pixels.
    pub frame_height: u32,
    /// Number of tile columns.
    pub tile_cols: u32,
    /// Number of tile rows.
    pub tile_rows: u32,
}
73
74// =============================================================================
75// Implementation
76// =============================================================================
77
78impl ParallelTileDecoder {
79    /// Create a new `ParallelTileDecoder`.
80    ///
81    /// # Panics
82    ///
83    /// Does not panic; use [`Self::decode_tiles_parallel`] error return for
84    /// invalid configurations.
85    pub fn new(frame_width: u32, frame_height: u32, tile_cols: u32, tile_rows: u32) -> Self {
86        Self {
87            frame_width,
88            frame_height,
89            tile_cols,
90            tile_rows,
91        }
92    }
93
94    /// Split `frame_data` (luma plane bytes, row-major) into [`TileJob`]s.
95    ///
96    /// The frame is divided into a `tile_cols × tile_rows` grid.  The last
97    /// tile column and row absorb any remainder pixels.
98    ///
99    /// `frame_data` is interpreted as a contiguous luma plane of
100    /// `frame_width × frame_height` bytes.  If `frame_data` is shorter than
101    /// the expected luma plane size the available bytes are distributed
102    /// proportionally across tiles.
103    pub fn split_into_tiles(&self, frame_data: &[u8]) -> Vec<TileJob> {
104        if self.tile_cols == 0 || self.tile_rows == 0 {
105            return Vec::new();
106        }
107
108        let base_tile_w = self.frame_width / self.tile_cols;
109        let base_tile_h = self.frame_height / self.tile_rows;
110        let rem_w = self.frame_width % self.tile_cols;
111        let rem_h = self.frame_height % self.tile_rows;
112
113        let total_tiles = (self.tile_rows * self.tile_cols) as usize;
114        let mut jobs = Vec::with_capacity(total_tiles);
115
116        for row in 0..self.tile_rows {
117            let tile_h = if row == self.tile_rows - 1 {
118                base_tile_h + rem_h
119            } else {
120                base_tile_h
121            };
122            let y_offset = row * base_tile_h;
123
124            for col in 0..self.tile_cols {
125                let tile_w = if col == self.tile_cols - 1 {
126                    base_tile_w + rem_w
127                } else {
128                    base_tile_w
129                };
130                let x_offset = col * base_tile_w;
131
132                // Extract the luma bytes belonging to this tile region.
133                let tile_bytes = Self::extract_tile_bytes(
134                    frame_data,
135                    self.frame_width,
136                    x_offset,
137                    y_offset,
138                    tile_w,
139                    tile_h,
140                );
141
142                jobs.push(TileJob {
143                    tile_row: row,
144                    tile_col: col,
145                    tile_data: tile_bytes,
146                    tile_offset: (x_offset, y_offset),
147                    tile_size: (tile_w, tile_h),
148                });
149            }
150        }
151
152        jobs
153    }
154
155    /// Decode all tiles in parallel using rayon and assemble the output frame.
156    ///
157    /// Returns a planar YUV 4:2:0 buffer of length
158    /// `frame_width × frame_height × 3 / 2`.
159    ///
160    /// # Errors
161    ///
162    /// Returns `CodecError::InvalidParameter` when the tile grid dimensions
163    /// are zero, or `CodecError::InvalidBitstream` when an individual tile
164    /// fails to decode.
165    pub fn decode_tiles_parallel(&self, tiles: Vec<TileJob>) -> CodecResult<Vec<u8>> {
166        if self.tile_cols == 0 || self.tile_rows == 0 {
167            return Err(CodecError::InvalidParameter(
168                "tile_cols and tile_rows must be non-zero".to_string(),
169            ));
170        }
171
172        // Decode tiles in parallel; collect (job, decoded_bytes) pairs.
173        let results: Result<Vec<(TileJob, Vec<u8>)>, CodecError> = tiles
174            .into_par_iter()
175            .map(|job| {
176                let decoded = Self::decode_single_tile(&job)?;
177                Ok((job, decoded))
178            })
179            .collect();
180
181        let tile_outputs = results?;
182        Ok(self.assemble_frame(&tile_outputs))
183    }
184
185    /// Assemble decoded tile outputs into a full planar YUV 4:2:0 frame.
186    ///
187    /// The luma (`Y`) plane is filled from each tile's decoded bytes.
188    /// The chroma (`Cb`, `Cr`) planes are zeroed (neutral grey), which is
189    /// appropriate for a structural pass that does not yet decode chroma.
190    ///
191    /// Returns a buffer of `frame_width × frame_height × 3 / 2` bytes:
192    /// - bytes `[0 .. W*H)` — luma
193    /// - bytes `[W*H .. W*H + W*H/4)` — Cb (zeroed)
194    /// - bytes `[W*H + W*H/4 .. W*H*3/2)` — Cr (zeroed)
195    pub fn assemble_frame(&self, tile_outputs: &[(TileJob, Vec<u8>)]) -> Vec<u8> {
196        let luma_size = (self.frame_width * self.frame_height) as usize;
197        let chroma_size = luma_size / 4;
198        let total_size = luma_size + 2 * chroma_size;
199
200        let mut frame = vec![0u8; total_size];
201        let luma_plane = &mut frame[..luma_size];
202
203        for (job, decoded) in tile_outputs {
204            let (x_off, y_off) = job.tile_offset;
205            let (tile_w, tile_h) = job.tile_size;
206
207            // Copy decoded luma bytes row by row into the correct frame region.
208            for row in 0..tile_h {
209                let src_row_start = (row * tile_w) as usize;
210                let src_row_end = src_row_start + tile_w as usize;
211
212                let dst_row_start = ((y_off + row) * self.frame_width + x_off) as usize;
213                let dst_row_end = dst_row_start + tile_w as usize;
214
215                // Guard against decoded buffer being shorter than expected
216                // (e.g. structural stub returning fewer bytes than tile area).
217                let src_available = decoded.len().saturating_sub(src_row_start);
218                if src_available == 0 {
219                    continue;
220                }
221                let copy_len = (src_row_end - src_row_start).min(src_available);
222
223                if dst_row_end <= luma_plane.len() {
224                    luma_plane[dst_row_start..dst_row_start + copy_len]
225                        .copy_from_slice(&decoded[src_row_start..src_row_start + copy_len]);
226                }
227            }
228        }
229
230        // Chroma planes remain zero-initialised (neutral 4:2:0 chroma).
231        frame
232    }
233
234    // ------------------------------------------------------------------
235    // Private helpers
236    // ------------------------------------------------------------------
237
238    /// Extract the luma bytes for a tile region from a row-major frame buffer.
239    fn extract_tile_bytes(
240        frame_data: &[u8],
241        frame_width: u32,
242        x_offset: u32,
243        y_offset: u32,
244        tile_w: u32,
245        tile_h: u32,
246    ) -> Vec<u8> {
247        let mut bytes = Vec::with_capacity((tile_w * tile_h) as usize);
248
249        for row in 0..tile_h {
250            let src_start = ((y_offset + row) * frame_width + x_offset) as usize;
251            let src_end = src_start + tile_w as usize;
252
253            if src_start >= frame_data.len() {
254                // Pad remaining rows with zeros when input is shorter.
255                bytes.extend(std::iter::repeat(0u8).take(tile_w as usize));
256            } else {
257                let available_end = src_end.min(frame_data.len());
258                bytes.extend_from_slice(&frame_data[src_start..available_end]);
259                if available_end < src_end {
260                    bytes.extend(std::iter::repeat(0u8).take(src_end - available_end));
261                }
262            }
263        }
264
265        bytes
266    }
267
268    /// Structural single-tile decode.
269    ///
270    /// For now this copies the input tile bytes as the "decoded" output.  A
271    /// real implementation would invoke the AV1 entropy / prediction /
272    /// transform pipeline here.
273    fn decode_single_tile(job: &TileJob) -> CodecResult<Vec<u8>> {
274        // Structural pass: validate minimum size and return a copy.
275        let (tile_w, tile_h) = job.tile_size;
276        if tile_w == 0 || tile_h == 0 {
277            return Err(CodecError::InvalidBitstream(format!(
278                "Tile ({}, {}) has zero dimension: {}×{}",
279                job.tile_col, job.tile_row, tile_w, tile_h
280            )));
281        }
282        // Return the raw tile bytes as decoded output.
283        Ok(job.tile_data.clone())
284    }
285}
286
287// =============================================================================
288// Tests
289// =============================================================================
290
#[cfg(test)]
mod tests {
    use super::*;

    fn make_decoder_4x2() -> ParallelTileDecoder {
        ParallelTileDecoder::new(1920, 1080, 4, 2)
    }

    // -----------------------------------------------------------------------
    // Constructor
    // -----------------------------------------------------------------------

    #[test]
    fn test_new_stores_dimensions() {
        let dec = ParallelTileDecoder::new(3840, 2160, 8, 4);
        assert_eq!(dec.frame_width, 3840);
        assert_eq!(dec.frame_height, 2160);
        assert_eq!(dec.tile_cols, 8);
        assert_eq!(dec.tile_rows, 4);
    }

    // -----------------------------------------------------------------------
    // split_into_tiles
    // -----------------------------------------------------------------------

    #[test]
    fn test_split_tile_count() {
        let dec = make_decoder_4x2();
        let frame = vec![0u8; 1920 * 1080];
        let tiles = dec.split_into_tiles(&frame);
        assert_eq!(tiles.len(), 8, "4 cols × 2 rows = 8 tiles");
    }

    #[test]
    fn test_split_tile_offsets() {
        let dec = make_decoder_4x2();
        let frame = vec![0u8; 1920 * 1080];
        let tiles = dec.split_into_tiles(&frame);

        // First tile is always at (0, 0)
        let first = tiles
            .iter()
            .find(|t| t.tile_row == 0 && t.tile_col == 0)
            .expect("tile (0,0)");
        assert_eq!(first.tile_offset, (0, 0));

        // Second column of first row
        let t01 = tiles
            .iter()
            .find(|t| t.tile_row == 0 && t.tile_col == 1)
            .expect("tile (0,1)");
        assert_eq!(t01.tile_offset.1, 0, "y offset of row 0 should be 0");
        assert!(t01.tile_offset.0 > 0, "x offset of col 1 should be > 0");

        // Second row
        let t10 = tiles
            .iter()
            .find(|t| t.tile_row == 1 && t.tile_col == 0)
            .expect("tile (1,0)");
        assert!(t10.tile_offset.1 > 0, "y offset of row 1 should be > 0");
    }

    #[test]
    fn test_split_tile_sizes_sum_to_frame() {
        // Use a size where width and height are exact multiples.
        let dec = ParallelTileDecoder::new(800, 600, 4, 3);
        let frame = vec![0u8; 800 * 600];
        let tiles = dec.split_into_tiles(&frame);
        assert_eq!(tiles.len(), 12);

        // All tiles in col 0 should have the same width.
        let col0_widths: Vec<u32> = tiles
            .iter()
            .filter(|t| t.tile_col == 0)
            .map(|t| t.tile_size.0)
            .collect();
        assert!(col0_widths.iter().all(|&w| w == col0_widths[0]));

        // Widths of all tiles in row 0 should sum to frame_width.
        let row0_width_sum: u32 = tiles
            .iter()
            .filter(|t| t.tile_row == 0)
            .map(|t| t.tile_size.0)
            .sum();
        assert_eq!(row0_width_sum, 800);

        // Heights of all tiles in col 0 should sum to frame_height.
        let col0_height_sum: u32 = tiles
            .iter()
            .filter(|t| t.tile_col == 0)
            .map(|t| t.tile_size.1)
            .sum();
        assert_eq!(col0_height_sum, 600);
    }

    #[test]
    fn test_split_handles_non_divisible_dimensions() {
        // 1000 / 3 = 333 rem 1; 700 / 2 = 350 rem 0
        let dec = ParallelTileDecoder::new(1000, 700, 3, 2);
        let frame = vec![0u8; 1000 * 700];
        let tiles = dec.split_into_tiles(&frame);
        assert_eq!(tiles.len(), 6);

        let row0_width_sum: u32 = tiles
            .iter()
            .filter(|t| t.tile_row == 0)
            .map(|t| t.tile_size.0)
            .sum();
        assert_eq!(row0_width_sum, 1000, "widths must cover full frame width");

        let col0_height_sum: u32 = tiles
            .iter()
            .filter(|t| t.tile_col == 0)
            .map(|t| t.tile_size.1)
            .sum();
        assert_eq!(col0_height_sum, 700, "heights must cover full frame height");
    }

    #[test]
    fn test_split_zero_cols_returns_empty() {
        let dec = ParallelTileDecoder::new(1920, 1080, 0, 2);
        let frame = vec![0u8; 100];
        let tiles = dec.split_into_tiles(&frame);
        assert!(tiles.is_empty());
    }

    #[test]
    fn test_split_tile_data_length_matches_tile_area() {
        let dec = ParallelTileDecoder::new(400, 300, 2, 2);
        let frame = vec![0xAAu8; 400 * 300];
        let tiles = dec.split_into_tiles(&frame);
        for tile in &tiles {
            let expected_len = (tile.tile_size.0 * tile.tile_size.1) as usize;
            assert_eq!(
                tile.tile_data.len(),
                expected_len,
                "tile ({},{}) data length mismatch",
                tile.tile_row,
                tile.tile_col
            );
        }
    }

    #[test]
    fn test_split_preserves_pixel_values() {
        // Mark each pixel with a unique value to verify correct extraction.
        let width = 4u32;
        let height = 4u32;
        let frame: Vec<u8> = (0..(width * height) as u8).collect();

        let dec = ParallelTileDecoder::new(width, height, 2, 2);
        let tiles = dec.split_into_tiles(&frame);

        // Top-left tile (row 0, col 0) covers pixel rows 0..2 and columns 0..2.
        let tl = tiles
            .iter()
            .find(|t| t.tile_row == 0 && t.tile_col == 0)
            .expect("tl");
        assert_eq!(tl.tile_data[0], frame[0]); // pixel row 0, column 0
        assert_eq!(tl.tile_data[1], frame[1]); // pixel row 0, column 1
        assert_eq!(tl.tile_data[2], frame[width as usize]); // pixel row 1, column 0
    }

    #[test]
    fn test_split_zero_pads_short_input() {
        // Only the first 6 bytes of a 4×4 luma plane are supplied.
        let dec = ParallelTileDecoder::new(4, 4, 2, 2);
        let frame: Vec<u8> = (1..=6).collect();
        let tiles = dec.split_into_tiles(&frame);
        assert_eq!(tiles.len(), 4);
        for tile in &tiles {
            assert_eq!(tile.tile_data.len(), 4, "every tile is 2×2 bytes");
        }

        let find = |row: u32, col: u32| {
            tiles
                .iter()
                .find(|t| t.tile_row == row && t.tile_col == col)
                .expect("tile")
        };
        assert_eq!(find(0, 0).tile_data, vec![1, 2, 5, 6]);
        assert_eq!(find(0, 1).tile_data, vec![3, 4, 0, 0]);
        assert_eq!(find(1, 0).tile_data, vec![0; 4]);
        assert_eq!(find(1, 1).tile_data, vec![0; 4]);
    }

    // -----------------------------------------------------------------------
    // decode_tiles_parallel
    // -----------------------------------------------------------------------

    #[test]
    fn test_decode_tiles_parallel_output_size() {
        let dec = make_decoder_4x2();
        let frame = vec![0u8; 1920 * 1080];
        let tiles = dec.split_into_tiles(&frame);
        let output = dec.decode_tiles_parallel(tiles).expect("decode");
        assert_eq!(output.len(), 1920 * 1080 * 3 / 2, "YUV 4:2:0 output size");
    }

    #[test]
    fn test_decode_tiles_parallel_single_tile() {
        let dec = ParallelTileDecoder::new(320, 240, 1, 1);
        let frame = vec![0x7Fu8; 320 * 240];
        let tiles = dec.split_into_tiles(&frame);
        let output = dec.decode_tiles_parallel(tiles).expect("decode");
        assert_eq!(output.len(), 320 * 240 * 3 / 2);
        // Luma plane should match the input (structural pass copies bytes).
        assert!(output[..320 * 240].iter().all(|&b| b == 0x7F));
    }

    #[test]
    fn test_decode_tiles_parallel_zero_cols_errors() {
        let dec = ParallelTileDecoder::new(640, 480, 0, 2);
        let result = dec.decode_tiles_parallel(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn test_decode_tiles_parallel_zero_width_tile_errors() {
        // More tile columns than pixels: every column but the last gets zero width.
        let dec = ParallelTileDecoder::new(2, 2, 4, 1);
        let tiles = dec.split_into_tiles(&[0u8; 4]);
        assert!(tiles.iter().any(|t| t.tile_size.0 == 0));
        let result = dec.decode_tiles_parallel(tiles);
        assert!(matches!(result, Err(CodecError::InvalidBitstream(_))));
    }

    #[test]
    fn test_decode_tiles_parallel_preserves_content() {
        // Encode row index into luma bytes; verify round-trip through decode+assemble.
        let width = 64u32;
        let height = 32u32;
        let mut frame = vec![0u8; (width * height) as usize];
        for y in 0..height {
            for x in 0..width {
                frame[(y * width + x) as usize] = (y % 256) as u8;
            }
        }

        let dec = ParallelTileDecoder::new(width, height, 4, 2);
        let tiles = dec.split_into_tiles(&frame);
        let output = dec.decode_tiles_parallel(tiles).expect("decode");

        // Verify luma plane content matches original.
        for y in 0..height {
            for x in 0..width {
                let idx = (y * width + x) as usize;
                assert_eq!(output[idx], frame[idx], "luma mismatch at ({x},{y})");
            }
        }
    }

    #[test]
    fn test_decode_round_trip_non_divisible_dimensions() {
        // 67 / 4 = 16 rem 3 and 35 / 3 = 11 rem 2: the last column and row are
        // larger than the others, exercising the remainder paths end to end.
        let width = 67u32;
        let height = 35u32;
        let frame: Vec<u8> = (0..width * height).map(|i| (i % 251) as u8).collect();

        let dec = ParallelTileDecoder::new(width, height, 4, 3);
        let tiles = dec.split_into_tiles(&frame);
        assert_eq!(tiles.len(), 12);
        let output = dec.decode_tiles_parallel(tiles).expect("decode");

        let luma_size = (width * height) as usize;
        assert_eq!(&output[..luma_size], &frame[..], "luma plane must round-trip");
    }

    // -----------------------------------------------------------------------
    // assemble_frame
    // -----------------------------------------------------------------------

    #[test]
    fn test_assemble_frame_output_size() {
        let dec = ParallelTileDecoder::new(640, 480, 2, 2);
        let tile_outputs: Vec<(TileJob, Vec<u8>)> = Vec::new();
        let frame = dec.assemble_frame(&tile_outputs);
        assert_eq!(frame.len(), 640 * 480 * 3 / 2);
    }

    #[test]
    fn test_assemble_frame_chroma_zeroed() {
        let dec = ParallelTileDecoder::new(64, 32, 1, 1);
        let luma_size = 64 * 32;
        let job = TileJob {
            tile_row: 0,
            tile_col: 0,
            tile_data: vec![0xFFu8; luma_size],
            tile_offset: (0, 0),
            tile_size: (64, 32),
        };
        let tile_outputs = vec![(job, vec![0xFFu8; luma_size])];
        let frame = dec.assemble_frame(&tile_outputs);

        // Luma should be all 0xFF.
        assert!(frame[..luma_size].iter().all(|&b| b == 0xFF));
        // Chroma should be all zeros.
        assert!(frame[luma_size..].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_assemble_frame_single_tile_full_coverage() {
        let dec = ParallelTileDecoder::new(8, 4, 1, 1);
        let tile_bytes: Vec<u8> = (0..32u8).collect();
        let job = TileJob {
            tile_row: 0,
            tile_col: 0,
            tile_data: tile_bytes.clone(),
            tile_offset: (0, 0),
            tile_size: (8, 4),
        };
        let tile_outputs = vec![(job, tile_bytes.clone())];
        let frame = dec.assemble_frame(&tile_outputs);

        for (i, &expected) in tile_bytes.iter().enumerate() {
            assert_eq!(frame[i], expected, "luma byte {i} mismatch");
        }
    }
}