Skip to main content

oximedia_codec/intra/
context.rs

//! Intra prediction context.
//!
//! The prediction context manages the neighbor samples and availability
//! information needed for intra prediction. It provides a unified
//! interface for accessing the top, left, and top-left samples.
//!
//! # Sample Layout
//!
//! For a block at position (x, y):
//! ```text
//!     TL | T0 T1 T2 T3 ...
//!     ---+----------------
//!     L0 | P00 P01 P02 P03
//!     L1 | P10 P11 P12 P13
//!     L2 | P20 P21 P22 P23
//!     L3 | P30 P31 P32 P33
//! ```
//!
//! Where:
//! - TL = Top-left sample
//! - T0..Tn = Top samples
//! - L0..Ln = Left samples
//! - Pij = Predicted sample at row i, column j
24
25#![forbid(unsafe_code)]
26#![allow(dead_code)]
27#![allow(clippy::struct_excessive_bools)]
28#![allow(clippy::cast_sign_loss)]
29#![allow(clippy::cast_possible_truncation)]
30
31use super::{BitDepth, MAX_NEIGHBOR_SAMPLES};
32
33/// Top samples array type.
34pub type TopSamples = [u16; MAX_NEIGHBOR_SAMPLES];
35
36/// Left samples array type.
37pub type LeftSamples = [u16; MAX_NEIGHBOR_SAMPLES];
38
/// Flags describing which neighbor regions of a block carry usable samples.
///
/// The `Default` value has every flag cleared, identical to
/// [`NeighborAvailability::NONE`].
#[derive(Clone, Copy, Debug, Default)]
pub struct NeighborAvailability {
    /// Row directly above the block is usable.
    pub top: bool,
    /// Column directly left of the block is usable.
    pub left: bool,
    /// Corner sample above and to the left of the block is usable.
    pub top_left: bool,
    /// Samples above the block, continuing to the right of it, are usable.
    pub top_right: bool,
    /// Samples left of the block, continuing below it, are usable.
    pub bottom_left: bool,
}

impl NeighborAvailability {
    /// Every neighbor region is available.
    pub const ALL: Self = Self::uniform(true);

    /// No neighbor region is available.
    pub const NONE: Self = Self::uniform(false);

    /// Build a value whose five flags all equal `flag`.
    const fn uniform(flag: bool) -> Self {
        Self {
            top: flag,
            left: flag,
            top_left: flag,
            top_right: flag,
            bottom_left: flag,
        }
    }

    /// Returns `true` when at least one neighbor region is available.
    #[must_use]
    pub const fn any(&self) -> bool {
        !matches!(
            (
                self.top,
                self.left,
                self.top_left,
                self.top_right,
                self.bottom_left
            ),
            (false, false, false, false, false)
        )
    }

    /// Returns `true` when the row above the block is usable
    /// (the input vertical prediction needs).
    #[must_use]
    pub const fn has_top(&self) -> bool {
        let Self { top, .. } = *self;
        top
    }

    /// Returns `true` when the column left of the block is usable
    /// (the input horizontal prediction needs).
    #[must_use]
    pub const fn has_left(&self) -> bool {
        let Self { left, .. } = *self;
        left
    }
}
91
92/// Intra prediction context.
93///
94/// Holds neighbor samples and availability information for intra prediction.
95#[derive(Clone, Debug)]
96pub struct IntraPredContext {
97    /// Top row samples (above the block).
98    top: TopSamples,
99    /// Left column samples (to the left of the block).
100    left: LeftSamples,
101    /// Top-left corner sample.
102    top_left: u16,
103    /// Block width.
104    width: usize,
105    /// Block height.
106    height: usize,
107    /// Bit depth.
108    bit_depth: BitDepth,
109    /// Neighbor availability.
110    availability: NeighborAvailability,
111}
112
113impl IntraPredContext {
114    /// Create a new prediction context.
115    #[must_use]
116    pub fn new(width: usize, height: usize, bit_depth: BitDepth) -> Self {
117        let midpoint = bit_depth.midpoint();
118        Self {
119            top: [midpoint; MAX_NEIGHBOR_SAMPLES],
120            left: [midpoint; MAX_NEIGHBOR_SAMPLES],
121            top_left: midpoint,
122            width,
123            height,
124            bit_depth,
125            availability: NeighborAvailability::NONE,
126        }
127    }
128
129    /// Create with specific neighbor availability.
130    #[must_use]
131    pub fn with_availability(
132        width: usize,
133        height: usize,
134        bit_depth: BitDepth,
135        availability: NeighborAvailability,
136    ) -> Self {
137        let mut ctx = Self::new(width, height, bit_depth);
138        ctx.availability = availability;
139        ctx
140    }
141
142    /// Get the bit depth.
143    #[must_use]
144    pub const fn bit_depth(&self) -> BitDepth {
145        self.bit_depth
146    }
147
148    /// Get block width.
149    #[must_use]
150    pub const fn width(&self) -> usize {
151        self.width
152    }
153
154    /// Get block height.
155    #[must_use]
156    pub const fn height(&self) -> usize {
157        self.height
158    }
159
160    /// Check if top neighbor is available.
161    #[must_use]
162    pub const fn has_top(&self) -> bool {
163        self.availability.top
164    }
165
166    /// Check if left neighbor is available.
167    #[must_use]
168    pub const fn has_left(&self) -> bool {
169        self.availability.left
170    }
171
172    /// Check if top-left neighbor is available.
173    #[must_use]
174    pub const fn has_top_left(&self) -> bool {
175        self.availability.top_left
176    }
177
178    /// Get neighbor availability.
179    #[must_use]
180    pub const fn availability(&self) -> NeighborAvailability {
181        self.availability
182    }
183
184    /// Set neighbor availability.
185    pub fn set_availability(&mut self, has_top: bool, has_left: bool) {
186        self.availability.top = has_top;
187        self.availability.left = has_left;
188        self.availability.top_left = has_top && has_left;
189    }
190
191    /// Set full neighbor availability.
192    pub fn set_full_availability(&mut self, availability: NeighborAvailability) {
193        self.availability = availability;
194    }
195
196    /// Get top samples slice.
197    #[must_use]
198    pub fn top_samples(&self) -> &[u16] {
199        &self.top[..self.width.min(MAX_NEIGHBOR_SAMPLES)]
200    }
201
202    /// Get left samples slice.
203    #[must_use]
204    pub fn left_samples(&self) -> &[u16] {
205        &self.left[..self.height.min(MAX_NEIGHBOR_SAMPLES)]
206    }
207
208    /// Get extended top samples (including top-right).
209    #[must_use]
210    pub fn extended_top_samples(&self) -> &[u16] {
211        let count = (self.width * 2).min(MAX_NEIGHBOR_SAMPLES);
212        &self.top[..count]
213    }
214
215    /// Get extended left samples (including bottom-left).
216    #[must_use]
217    pub fn extended_left_samples(&self) -> &[u16] {
218        let count = (self.height * 2).min(MAX_NEIGHBOR_SAMPLES);
219        &self.left[..count]
220    }
221
222    /// Get top-left sample.
223    #[must_use]
224    pub const fn top_left_sample(&self) -> u16 {
225        self.top_left
226    }
227
228    /// Set a top sample.
229    pub fn set_top_sample(&mut self, idx: usize, value: u16) {
230        if idx < MAX_NEIGHBOR_SAMPLES {
231            self.top[idx] = value;
232        }
233    }
234
235    /// Set a left sample.
236    pub fn set_left_sample(&mut self, idx: usize, value: u16) {
237        if idx < MAX_NEIGHBOR_SAMPLES {
238            self.left[idx] = value;
239        }
240    }
241
242    /// Set top-left sample.
243    pub fn set_top_left_sample(&mut self, value: u16) {
244        self.top_left = value;
245    }
246
247    /// Set all top samples from a slice.
248    pub fn set_top_samples(&mut self, samples: &[u16]) {
249        let count = samples.len().min(MAX_NEIGHBOR_SAMPLES);
250        self.top[..count].copy_from_slice(&samples[..count]);
251    }
252
253    /// Set all left samples from a slice.
254    pub fn set_left_samples(&mut self, samples: &[u16]) {
255        let count = samples.len().min(MAX_NEIGHBOR_SAMPLES);
256        self.left[..count].copy_from_slice(&samples[..count]);
257    }
258
259    /// Apply a filter function to top samples.
260    pub fn filter_top_samples<F>(&mut self, filter: F)
261    where
262        F: FnOnce(&mut [u16]),
263    {
264        filter(&mut self.top);
265    }
266
267    /// Apply a filter function to left samples.
268    pub fn filter_left_samples<F>(&mut self, filter: F)
269    where
270        F: FnOnce(&mut [u16]),
271    {
272        filter(&mut self.left);
273    }
274
275    /// Reconstruct neighbors from a frame buffer.
276    ///
277    /// # Arguments
278    /// * `frame` - Frame sample buffer
279    /// * `frame_stride` - Frame row stride
280    /// * `block_x` - Block X position in samples
281    /// * `block_y` - Block Y position in samples
282    /// * `frame_width` - Frame width in samples
283    /// * `frame_height` - Frame height in samples
284    #[allow(clippy::too_many_arguments)]
285    pub fn reconstruct_neighbors(
286        &mut self,
287        frame: &[u16],
288        frame_stride: usize,
289        block_x: usize,
290        block_y: usize,
291        frame_width: usize,
292        frame_height: usize,
293    ) {
294        // Determine availability
295        let has_top = block_y > 0;
296        let has_left = block_x > 0;
297        let has_top_right = has_top && (block_x + self.width * 2 <= frame_width);
298        let has_bottom_left = has_left && (block_y + self.height * 2 <= frame_height);
299
300        self.availability = NeighborAvailability {
301            top: has_top,
302            left: has_left,
303            top_left: has_top && has_left,
304            top_right: has_top_right,
305            bottom_left: has_bottom_left,
306        };
307
308        // Copy top samples
309        if has_top {
310            let top_y = block_y - 1;
311            let top_row_start = top_y * frame_stride;
312
313            // Copy regular top samples
314            for x in 0..self.width {
315                let frame_x = block_x + x;
316                if frame_x < frame_width {
317                    self.top[x] = frame[top_row_start + frame_x];
318                }
319            }
320
321            // Copy top-right samples if available
322            if has_top_right {
323                for x in self.width..(self.width * 2) {
324                    let frame_x = block_x + x;
325                    if frame_x < frame_width {
326                        self.top[x] = frame[top_row_start + frame_x];
327                    }
328                }
329            } else {
330                // Replicate last top sample
331                let last = self.top[self.width.saturating_sub(1)];
332                for x in self.width..(self.width * 2) {
333                    self.top[x] = last;
334                }
335            }
336        }
337
338        // Copy left samples
339        if has_left {
340            let left_x = block_x - 1;
341
342            // Copy regular left samples
343            for y in 0..self.height {
344                let frame_y = block_y + y;
345                if frame_y < frame_height {
346                    self.left[y] = frame[frame_y * frame_stride + left_x];
347                }
348            }
349
350            // Copy bottom-left samples if available
351            if has_bottom_left {
352                for y in self.height..(self.height * 2) {
353                    let frame_y = block_y + y;
354                    if frame_y < frame_height {
355                        self.left[y] = frame[frame_y * frame_stride + left_x];
356                    }
357                }
358            } else {
359                // Replicate last left sample
360                let last = self.left[self.height.saturating_sub(1)];
361                for y in self.height..(self.height * 2) {
362                    self.left[y] = last;
363                }
364            }
365        }
366
367        // Copy top-left sample
368        if has_top && has_left {
369            self.top_left = frame[(block_y - 1) * frame_stride + (block_x - 1)];
370        } else if has_top {
371            self.top_left = self.top[0];
372        } else if has_left {
373            self.top_left = self.left[0];
374        }
375    }
376
377    /// Check if the block is at the frame edge.
378    #[must_use]
379    pub const fn is_at_frame_edge(&self) -> bool {
380        !self.availability.top || !self.availability.left
381    }
382
383    /// Fill unavailable samples with the midpoint value.
384    pub fn fill_unavailable(&mut self) {
385        let midpoint = self.bit_depth.midpoint();
386
387        if !self.availability.top {
388            self.top.fill(midpoint);
389        }
390
391        if !self.availability.left {
392            self.left.fill(midpoint);
393        }
394
395        if !self.availability.top_left {
396            self.top_left = midpoint;
397        }
398    }
399
400    /// Get a sample at an extended position (can be negative for top-left region).
401    #[must_use]
402    pub fn get_extended_sample(&self, x: i32, y: i32) -> u16 {
403        if x < 0 && y < 0 {
404            // Top-left region
405            self.top_left
406        } else if y < 0 {
407            // Top row
408            let idx = x as usize;
409            if idx < self.top.len() {
410                self.top[idx]
411            } else {
412                self.top[self.top.len() - 1]
413            }
414        } else if x < 0 {
415            // Left column
416            let idx = y as usize;
417            if idx < self.left.len() {
418                self.left[idx]
419            } else {
420                self.left[self.left.len() - 1]
421            }
422        } else {
423            // Should not happen for neighbor access
424            self.bit_depth.midpoint()
425        }
426    }
427}
428
429impl Default for IntraPredContext {
430    fn default() -> Self {
431        Self::new(4, 4, BitDepth::Bits8)
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a row-major frame whose sample at column `x`, row `y` is
    /// `(x + y) * 10`.
    fn gradient_frame(width: usize, height: usize) -> Vec<u16> {
        (0..height)
            .flat_map(|y| (0..width).map(move |x| ((x + y) * 10) as u16))
            .collect()
    }

    /// A fresh 4x4 context at 8-bit depth, the shape most tests use.
    fn small_ctx() -> IntraPredContext {
        IntraPredContext::new(4, 4, BitDepth::Bits8)
    }

    #[test]
    fn test_context_creation() {
        let ctx = IntraPredContext::new(8, 8, BitDepth::Bits8);
        assert_eq!(ctx.width(), 8);
        assert_eq!(ctx.height(), 8);
        assert_eq!(ctx.bit_depth(), BitDepth::Bits8);

        // Every neighbor starts at the 8-bit midpoint.
        assert_eq!(ctx.top_left_sample(), 128);
        let is_mid = |s: &u16| *s == 128;
        assert!(ctx.top_samples().iter().all(is_mid));
        assert!(ctx.left_samples().iter().all(is_mid));
    }

    #[test]
    fn test_availability() {
        let mut ctx = small_ctx();
        assert!(!ctx.has_top() && !ctx.has_left());

        ctx.set_availability(true, true);
        assert!(ctx.has_top() && ctx.has_left() && ctx.has_top_left());
    }

    #[test]
    fn test_sample_setting() {
        let mut ctx = small_ctx();
        ctx.set_top_sample(0, 100);
        ctx.set_top_sample(1, 110);
        ctx.set_left_sample(0, 90);
        ctx.set_top_left_sample(95);

        assert_eq!(ctx.top_samples()[..2], [100, 110]);
        assert_eq!(ctx.left_samples()[0], 90);
        assert_eq!(ctx.top_left_sample(), 95);
    }

    #[test]
    fn test_bulk_sample_setting() {
        let mut ctx = small_ctx();
        ctx.set_top_samples(&[100, 110, 120, 130]);
        ctx.set_left_samples(&[90, 100, 110, 120]);

        assert_eq!(ctx.top_samples()[..4], [100, 110, 120, 130]);
        assert_eq!(ctx.left_samples()[..4], [90, 100, 110, 120]);
    }

    #[test]
    fn test_reconstruct_neighbors() {
        let (frame_w, frame_h) = (16, 16);
        let frame = gradient_frame(frame_w, frame_h);
        let mut ctx = small_ctx();

        // Interior block at (4, 4).
        ctx.reconstruct_neighbors(&frame, frame_w, 4, 4, frame_w, frame_h);

        assert!(ctx.has_top() && ctx.has_left() && ctx.has_top_left());

        // Top row comes from y = 3, x = 4..8: (4 + 3) * 10, (5 + 3) * 10, ...
        assert_eq!(ctx.top_samples()[..2], [70, 80]);
        // Left column comes from x = 3, y = 4..8: (3 + 4) * 10, (3 + 5) * 10, ...
        assert_eq!(ctx.left_samples()[..2], [70, 80]);
        // Corner comes from (3, 3).
        assert_eq!(ctx.top_left_sample(), 60);
    }

    #[test]
    fn test_reconstruct_at_edge() {
        let (frame_w, frame_h) = (16, 16);
        let frame = vec![100u16; frame_w * frame_h];
        let mut ctx = small_ctx();

        // A block in the frame's top-left corner has no neighbors.
        ctx.reconstruct_neighbors(&frame, frame_w, 0, 0, frame_w, frame_h);

        assert!(!ctx.has_top());
        assert!(!ctx.has_left());
        assert!(!ctx.has_top_left());
    }

    #[test]
    fn test_extended_sample_access() {
        let mut ctx = small_ctx();
        ctx.set_top_samples(&[10, 20, 30, 40]);
        ctx.set_left_samples(&[15, 25, 35, 45]);
        ctx.set_top_left_sample(5);

        // ((x, y), expected): corner, then top row, then left column.
        let cases = [
            ((-1, -1), 5),
            ((0, -1), 10),
            ((1, -1), 20),
            ((-1, 0), 15),
            ((-1, 1), 25),
        ];
        for &((x, y), expected) in &cases {
            assert_eq!(
                ctx.get_extended_sample(x, y),
                expected,
                "sample at ({}, {})",
                x,
                y
            );
        }
    }

    #[test]
    fn test_neighbor_availability_constants() {
        let NeighborAvailability {
            top, left, top_left, ..
        } = NeighborAvailability::ALL;
        assert!(top && left && top_left);
        assert!(NeighborAvailability::ALL.any());

        let none = NeighborAvailability::NONE;
        assert!(!none.top && !none.left);
        assert!(!none.any());
    }

    #[test]
    fn test_fill_unavailable() {
        let mut ctx = small_ctx();
        ctx.set_top_samples(&[200; 4]);
        ctx.availability.top = false;

        ctx.fill_unavailable();

        // The top row is reset to the 8-bit midpoint.
        assert!(ctx.top_samples().iter().all(|&s| s == 128));
    }

    #[test]
    fn test_bit_depth_10() {
        let ctx = IntraPredContext::new(4, 4, BitDepth::Bits10);
        assert_eq!(ctx.bit_depth(), BitDepth::Bits10);
        // The 10-bit midpoint is 1 << 9.
        assert_eq!(ctx.top_left_sample(), 512);
    }

    #[test]
    fn test_extended_samples() {
        let mut ctx = small_ctx();

        // Populate the top row plus its top-right extension.
        (0..8).for_each(|i| ctx.set_top_sample(i, (i * 10) as u16));

        let extended = ctx.extended_top_samples();
        assert_eq!(extended.len(), 8);
        assert_eq!((extended[0], extended[7]), (0, 70));
    }
}