Skip to main content

oximedia_codec/intra/
dc.rs

1//! DC prediction implementations.
2//!
3//! DC prediction uses the average of neighboring samples. Several variants
4//! exist depending on which neighbors are available:
5//!
6//! - **Both neighbors**: Average of top and left samples
7//! - **Top only**: Average of top samples
8//! - **Left only**: Average of left samples
9//! - **No neighbors**: Use midpoint value (128 for 8-bit)
10//!
11//! DC prediction is the simplest and most common intra mode.
12
13#![forbid(unsafe_code)]
14#![allow(dead_code)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_sign_loss)]
17#![allow(clippy::needless_range_loop)]
18#![allow(clippy::similar_names)]
19#![allow(clippy::unused_self)]
20#![allow(clippy::trivially_copy_pass_by_ref)]
21#![allow(clippy::match_same_arms)]
22
23use super::{BitDepth, BlockDimensions, IntraPredContext, IntraPredictor};
24
/// Identifies which neighboring samples contribute to a DC prediction.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DcMode {
    /// Top and left neighbors are averaged together.
    #[default]
    Both,
    /// Only the top neighbors are averaged.
    TopOnly,
    /// Only the left neighbors are averaged.
    LeftOnly,
    /// Neither neighbor is available, so the midpoint value is used.
    NoNeighbors,
    /// DC combined with a position-dependent gradient.
    WithGradient,
}
40
41/// DC predictor implementation.
42#[derive(Clone, Copy, Debug, Default)]
43pub struct DcPredictor {
44    /// Bit depth for sample values.
45    bit_depth: BitDepth,
46}
47
48impl DcPredictor {
49    /// Create a new DC predictor.
50    #[must_use]
51    pub const fn new(bit_depth: BitDepth) -> Self {
52        Self { bit_depth }
53    }
54
55    /// Calculate DC value from top samples only.
56    fn dc_top_only(top: &[u16], width: usize) -> u16 {
57        if width == 0 {
58            return 128;
59        }
60
61        let sum: u32 = top.iter().take(width).map(|&s| u32::from(s)).sum();
62        let avg = (sum + (width as u32 / 2)) / width as u32;
63        avg as u16
64    }
65
66    /// Calculate DC value from left samples only.
67    fn dc_left_only(left: &[u16], height: usize) -> u16 {
68        if height == 0 {
69            return 128;
70        }
71
72        let sum: u32 = left.iter().take(height).map(|&s| u32::from(s)).sum();
73        let avg = (sum + (height as u32 / 2)) / height as u32;
74        avg as u16
75    }
76
77    /// Calculate DC value from both neighbors.
78    fn dc_both(top: &[u16], left: &[u16], width: usize, height: usize) -> u16 {
79        if width == 0 && height == 0 {
80            return 128;
81        }
82
83        let top_sum: u32 = top.iter().take(width).map(|&s| u32::from(s)).sum();
84        let left_sum: u32 = left.iter().take(height).map(|&s| u32::from(s)).sum();
85
86        let total = width + height;
87        let sum = top_sum + left_sum;
88        let avg = (sum + (total as u32 / 2)) / total as u32;
89        avg as u16
90    }
91
92    /// Predict with DC mode.
93    pub fn predict_dc(
94        &self,
95        ctx: &IntraPredContext,
96        output: &mut [u16],
97        stride: usize,
98        dims: BlockDimensions,
99    ) {
100        let mode = self.determine_mode(ctx);
101        let dc_value = self.calculate_dc(ctx, dims, mode);
102
103        // Fill block with DC value
104        for y in 0..dims.height {
105            let row_start = y * stride;
106            for x in 0..dims.width {
107                output[row_start + x] = dc_value;
108            }
109        }
110    }
111
112    /// Predict with DC and gradient adjustment.
113    pub fn predict_dc_gradient(
114        &self,
115        ctx: &IntraPredContext,
116        output: &mut [u16],
117        stride: usize,
118        dims: BlockDimensions,
119    ) {
120        let base_dc = self.calculate_dc(ctx, dims, DcMode::Both);
121        let top = ctx.top_samples();
122        let left = ctx.left_samples();
123        let max_val = self.bit_depth.max_value();
124
125        // Calculate gradients
126        let top_left = ctx.top_left_sample();
127
128        for y in 0..dims.height {
129            let row_start = y * stride;
130            let left_diff = i32::from(left[y]) - i32::from(top_left);
131
132            for x in 0..dims.width {
133                let top_diff = i32::from(top[x]) - i32::from(top_left);
134
135                // Add gradient to base DC
136                let gradient = (top_diff + left_diff) / 2;
137                let pred = i32::from(base_dc) + gradient;
138
139                // Clamp to valid range
140                let clamped = pred.clamp(0, i32::from(max_val));
141                output[row_start + x] = clamped as u16;
142            }
143        }
144    }
145
146    /// Determine which DC mode to use based on neighbor availability.
147    fn determine_mode(&self, ctx: &IntraPredContext) -> DcMode {
148        let has_top = ctx.has_top();
149        let has_left = ctx.has_left();
150
151        match (has_top, has_left) {
152            (true, true) => DcMode::Both,
153            (true, false) => DcMode::TopOnly,
154            (false, true) => DcMode::LeftOnly,
155            (false, false) => DcMode::NoNeighbors,
156        }
157    }
158
159    /// Calculate DC value based on mode.
160    fn calculate_dc(&self, ctx: &IntraPredContext, dims: BlockDimensions, mode: DcMode) -> u16 {
161        match mode {
162            DcMode::Both => Self::dc_both(
163                ctx.top_samples(),
164                ctx.left_samples(),
165                dims.width,
166                dims.height,
167            ),
168            DcMode::TopOnly => Self::dc_top_only(ctx.top_samples(), dims.width),
169            DcMode::LeftOnly => Self::dc_left_only(ctx.left_samples(), dims.height),
170            DcMode::NoNeighbors => self.bit_depth.midpoint(),
171            DcMode::WithGradient => Self::dc_both(
172                ctx.top_samples(),
173                ctx.left_samples(),
174                dims.width,
175                dims.height,
176            ),
177        }
178    }
179}
180
181impl IntraPredictor for DcPredictor {
182    fn predict(
183        &self,
184        ctx: &IntraPredContext,
185        output: &mut [u16],
186        stride: usize,
187        dims: BlockDimensions,
188    ) {
189        self.predict_dc(ctx, output, stride, dims);
190    }
191}
192
193/// Top-only DC predictor.
194#[derive(Clone, Copy, Debug, Default)]
195pub struct DcTopPredictor {
196    bit_depth: BitDepth,
197}
198
199impl DcTopPredictor {
200    /// Create a new top-only DC predictor.
201    #[must_use]
202    pub const fn new(bit_depth: BitDepth) -> Self {
203        Self { bit_depth }
204    }
205}
206
207impl IntraPredictor for DcTopPredictor {
208    fn predict(
209        &self,
210        ctx: &IntraPredContext,
211        output: &mut [u16],
212        stride: usize,
213        dims: BlockDimensions,
214    ) {
215        let dc_value = if ctx.has_top() {
216            DcPredictor::dc_top_only(ctx.top_samples(), dims.width)
217        } else {
218            self.bit_depth.midpoint()
219        };
220
221        for y in 0..dims.height {
222            let row_start = y * stride;
223            for x in 0..dims.width {
224                output[row_start + x] = dc_value;
225            }
226        }
227    }
228}
229
230/// Left-only DC predictor.
231#[derive(Clone, Copy, Debug, Default)]
232pub struct DcLeftPredictor {
233    bit_depth: BitDepth,
234}
235
236impl DcLeftPredictor {
237    /// Create a new left-only DC predictor.
238    #[must_use]
239    pub const fn new(bit_depth: BitDepth) -> Self {
240        Self { bit_depth }
241    }
242}
243
244impl IntraPredictor for DcLeftPredictor {
245    fn predict(
246        &self,
247        ctx: &IntraPredContext,
248        output: &mut [u16],
249        stride: usize,
250        dims: BlockDimensions,
251    ) {
252        let dc_value = if ctx.has_left() {
253            DcPredictor::dc_left_only(ctx.left_samples(), dims.height)
254        } else {
255            self.bit_depth.midpoint()
256        };
257
258        for y in 0..dims.height {
259            let row_start = y * stride;
260            for x in 0..dims.width {
261                output[row_start + x] = dc_value;
262            }
263        }
264    }
265}
266
267/// No-neighbors DC predictor (uses midpoint value).
268#[derive(Clone, Copy, Debug, Default)]
269pub struct Dc128Predictor {
270    bit_depth: BitDepth,
271}
272
273impl Dc128Predictor {
274    /// Create a new 128 (midpoint) DC predictor.
275    #[must_use]
276    pub const fn new(bit_depth: BitDepth) -> Self {
277        Self { bit_depth }
278    }
279}
280
281impl IntraPredictor for Dc128Predictor {
282    fn predict(
283        &self,
284        _ctx: &IntraPredContext,
285        output: &mut [u16],
286        stride: usize,
287        dims: BlockDimensions,
288    ) {
289        let dc_value = self.bit_depth.midpoint();
290
291        for y in 0..dims.height {
292            let row_start = y * stride;
293            for x in 0..dims.width {
294                output[row_start + x] = dc_value;
295            }
296        }
297    }
298}
299
300/// DC predictor with gradient adjustment.
301#[derive(Clone, Copy, Debug, Default)]
302pub struct DcGradientPredictor {
303    bit_depth: BitDepth,
304}
305
306impl DcGradientPredictor {
307    /// Create a new gradient DC predictor.
308    #[must_use]
309    pub const fn new(bit_depth: BitDepth) -> Self {
310        Self { bit_depth }
311    }
312}
313
314impl IntraPredictor for DcGradientPredictor {
315    fn predict(
316        &self,
317        ctx: &IntraPredContext,
318        output: &mut [u16],
319        stride: usize,
320        dims: BlockDimensions,
321    ) {
322        let predictor = DcPredictor::new(self.bit_depth);
323        predictor.predict_dc_gradient(ctx, output, stride, dims);
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intra::context::IntraPredContext;

    /// Shared fixture: a context created with `(8, 8, BitDepth::Bits8)` whose top
    /// samples ramp 100..=170 and left samples ramp 80..=150 in steps of 10, with a
    /// top-left sample of 90 and both neighbors flagged as available.
    fn create_test_context() -> IntraPredContext {
        let mut fixture = IntraPredContext::new(8, 8, BitDepth::Bits8);

        // Top samples: [100, 110, 120, 130, 140, 150, 160, 170]
        for col in 0..8 {
            fixture.set_top_sample(col, 100 + (col as u16 * 10));
        }

        // Left samples: [80, 90, 100, 110, 120, 130, 140, 150]
        for row in 0..8 {
            fixture.set_left_sample(row, 80 + (row as u16 * 10));
        }

        fixture.set_top_left_sample(90);
        fixture.set_availability(true, true);

        fixture
    }

    #[test]
    fn test_dc_top_only() {
        // (100 + 110 + 120 + 130) / 4 = 460 / 4 = 115
        let samples = [100u16, 110, 120, 130];
        assert_eq!(DcPredictor::dc_top_only(&samples, 4), 115);
    }

    #[test]
    fn test_dc_left_only() {
        // (80 + 90 + 100 + 110) / 4 = 380 / 4 = 95
        let samples = [80u16, 90, 100, 110];
        assert_eq!(DcPredictor::dc_left_only(&samples, 4), 95);
    }

    #[test]
    fn test_dc_both() {
        // (460 + 380) / 8 = 840 / 8 = 105
        let above = [100u16, 110, 120, 130];
        let beside = [80u16, 90, 100, 110];
        assert_eq!(DcPredictor::dc_both(&above, &beside, 4, 4), 105);
    }

    #[test]
    fn test_dc_predictor_both() {
        let ctx = create_test_context();
        let predictor = DcPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut block = vec![0u16; 64];

        predictor.predict(&ctx, &mut block, 8, dims);

        // The whole block must carry a single DC value.
        let first = block[0];
        assert!(block.iter().all(|&sample| sample == first));

        // Top sum:  100+110+120+130+140+150+160+170 = 1080
        // Left sum: 80+90+100+110+120+130+140+150   = 920
        // Average:  (1080 + 920) / 16 = 125
        assert_eq!(first, 125);
    }

    #[test]
    fn test_dc_128_predictor() {
        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits8);
        ctx.set_availability(false, false);

        let predictor = Dc128Predictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(4, 4);
        let mut block = vec![0u16; 16];

        predictor.predict(&ctx, &mut block, 4, dims);

        // Every sample is the 8-bit midpoint.
        assert!(block.iter().all(|&sample| sample == 128));
    }

    #[test]
    fn test_dc_top_predictor() {
        let ctx = create_test_context();
        let predictor = DcTopPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut block = vec![0u16; 64];

        predictor.predict(&ctx, &mut block, 8, dims);

        // Top sum: 1080 / 8 = 135
        assert!(block.iter().all(|&sample| sample == 135));
    }

    #[test]
    fn test_dc_left_predictor() {
        let ctx = create_test_context();
        let predictor = DcLeftPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut block = vec![0u16; 64];

        predictor.predict(&ctx, &mut block, 8, dims);

        // Left sum: 920 / 8 = 115
        assert!(block.iter().all(|&sample| sample == 115));
    }

    #[test]
    fn test_dc_mode_determination() {
        let predictor = DcPredictor::new(BitDepth::Bits8);
        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits8);

        // (top available, left available) -> expected mode, checked in this order.
        let cases = [
            (true, true, DcMode::Both),
            (true, false, DcMode::TopOnly),
            (false, true, DcMode::LeftOnly),
            (false, false, DcMode::NoNeighbors),
        ];

        for (top, left, expected) in cases {
            ctx.set_availability(top, left);
            assert_eq!(predictor.determine_mode(&ctx), expected);
        }
    }

    #[test]
    fn test_bit_depth_10() {
        let predictor = Dc128Predictor::new(BitDepth::Bits10);
        let ctx = IntraPredContext::new(4, 4, BitDepth::Bits10);
        let dims = BlockDimensions::new(4, 4);
        let mut block = vec![0u16; 16];

        predictor.predict(&ctx, &mut block, 4, dims);

        // Midpoint for 10-bit is 512
        assert!(block.iter().all(|&sample| sample == 512));
    }
}