Skip to main content

oximedia_codec/intra/
smooth.rs

1//! Smooth prediction implementations (AV1).
2//!
3//! Smooth prediction uses weighted interpolation between neighboring samples
4//! to create gradual transitions. Three variants are provided:
5//!
6//! - **SMOOTH**: Bilinear interpolation between all neighbors
7//! - **SMOOTH_V**: Vertical interpolation (top to bottom-left)
8//! - **SMOOTH_H**: Horizontal interpolation (left to top-right)
9//!
10//! Weight tables define the blending factors at each position.
11
12#![forbid(unsafe_code)]
13#![allow(dead_code)]
14#![allow(clippy::similar_names)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::doc_markdown)]
17#![allow(clippy::needless_range_loop)]
18
19use super::{BlockDimensions, IntraPredContext, IntraPredictor};
20
/// Smooth weight tables for different block sizes.
///
/// Weights are in 1/256ths. Each table starts at 255 for the sample nearest
/// the reference edge (top row or left column) and decreases monotonically
/// toward the opposite edge of the block.
pub mod weights {
    /// Weights for 4-sample blocks.
    pub const SMOOTH_WEIGHTS_4: [u16; 4] = [255, 149, 85, 64];

    /// Weights for 8-sample blocks.
    pub const SMOOTH_WEIGHTS_8: [u16; 8] = [255, 197, 146, 105, 73, 50, 37, 32];

    /// Weights for 16-sample blocks.
    pub const SMOOTH_WEIGHTS_16: [u16; 16] = [
        255, 225, 196, 170, 145, 123, 102, 84, 68, 54, 43, 33, 26, 20, 17, 16,
    ];

    /// Weights for 32-sample blocks.
    pub const SMOOTH_WEIGHTS_32: [u16; 32] = [
        255, 240, 225, 210, 196, 182, 169, 157, 145, 133, 122, 111, 101, 92, 83, 74, 66, 59, 52,
        45, 39, 34, 29, 25, 21, 17, 14, 12, 10, 9, 8, 8,
    ];

    /// Weights for 64-sample blocks.
    pub const SMOOTH_WEIGHTS_64: [u16; 64] = [
        255, 248, 240, 233, 225, 218, 210, 203, 196, 189, 182, 176, 169, 163, 156, 150, 144, 138,
        133, 127, 121, 116, 111, 106, 101, 96, 91, 86, 82, 77, 73, 69, 65, 61, 57, 54, 50, 47, 44,
        41, 38, 35, 32, 29, 27, 25, 22, 20, 18, 16, 15, 13, 12, 10, 9, 8, 7, 6, 6, 5, 5, 4, 4, 4,
    ];

    /// Get weight table for a given size.
    ///
    /// Sizes 4, 8, 16, 32 and 64 return their exact table. Any other size
    /// falls back to the nearest smaller table: every size below 8
    /// (including sizes under 4) gets the 4-entry table, and every size
    /// above 64 gets the 64-entry table. The returned slice's length can
    /// therefore differ from `size`; [`interpolate_weight`] resamples it in
    /// that case.
    #[must_use]
    pub fn get_weights(size: usize) -> &'static [u16] {
        match size {
            0..=7 => &SMOOTH_WEIGHTS_4,
            8..=15 => &SMOOTH_WEIGHTS_8,
            16..=31 => &SMOOTH_WEIGHTS_16,
            32..=63 => &SMOOTH_WEIGHTS_32,
            _ => &SMOOTH_WEIGHTS_64,
        }
    }

    /// Weight at position `idx` along a `size`-sample edge.
    ///
    /// When `size` has an exact table (4, 8, 16, 32, 64) this is a plain
    /// lookup. Otherwise `idx` is mapped proportionally onto the table from
    /// [`get_weights`], and the two neighbouring entries are linearly
    /// interpolated with an 8-bit fraction, rounded to nearest. Positions
    /// past the end of the table are clamped to its last entry.
    ///
    /// # Panics
    ///
    /// Panics if `size` is 0 (division by zero), or if `size` has an exact
    /// table and `idx >= size`.
    #[must_use]
    pub fn interpolate_weight(size: usize, idx: usize) -> u16 {
        let table = get_weights(size);
        let table_size = table.len();

        if size == table_size {
            return table[idx];
        }

        // Position of `idx` in table coordinates: an integer index plus an
        // 8-bit fraction in 0..256 toward the next entry.
        let scaled = idx * table_size;
        let scaled_idx = scaled / size;
        let frac = (scaled % size) * 256 / size;

        let last = table_size - 1;
        let w0 = u32::from(table[scaled_idx.min(last)]);
        let w1 = u32::from(table[(scaled_idx + 1).min(last)]);

        // `frac < 256`, so this cast is lossless.
        let frac = frac as u32;
        ((w0 * (256 - frac) + w1 * frac + 128) / 256) as u16
    }
}
101
102/// Smooth predictor (bilinear interpolation).
103#[derive(Clone, Copy, Debug, Default)]
104pub struct SmoothPredictor;
105
106impl SmoothPredictor {
107    /// Create a new smooth predictor.
108    #[must_use]
109    pub const fn new() -> Self {
110        Self
111    }
112
113    /// Perform smooth prediction.
114    pub fn predict_smooth(
115        ctx: &IntraPredContext,
116        output: &mut [u16],
117        stride: usize,
118        dims: BlockDimensions,
119    ) {
120        let top = ctx.top_samples();
121        let left = ctx.left_samples();
122
123        // Get bottom-left and top-right samples for interpolation
124        let bottom_left = left[dims.height.saturating_sub(1)];
125        let top_right = top[dims.width.saturating_sub(1)];
126
127        let weights_x = weights::get_weights(dims.width);
128        let weights_y = weights::get_weights(dims.height);
129
130        for y in 0..dims.height {
131            let row_start = y * stride;
132            let weight_y = if y < weights_y.len() {
133                weights_y[y]
134            } else {
135                weights::interpolate_weight(dims.height, y)
136            };
137
138            for x in 0..dims.width {
139                let weight_x = if x < weights_x.len() {
140                    weights_x[x]
141                } else {
142                    weights::interpolate_weight(dims.width, x)
143                };
144
145                // Bilinear interpolation
146                // pred = (weight_y * top[x] + (256 - weight_y) * bottom_left
147                //       + weight_x * left[y] + (256 - weight_x) * top_right + 256) / 512
148                let top_sample = u32::from(top[x]);
149                let left_sample = u32::from(left[y]);
150                let bl = u32::from(bottom_left);
151                let tr = u32::from(top_right);
152
153                let wy = u32::from(weight_y);
154                let wx = u32::from(weight_x);
155
156                let vertical = wy * top_sample + (256 - wy) * bl;
157                let horizontal = wx * left_sample + (256 - wx) * tr;
158
159                let pred = (vertical + horizontal + 256) / 512;
160                output[row_start + x] = pred as u16;
161            }
162        }
163    }
164}
165
166impl IntraPredictor for SmoothPredictor {
167    fn predict(
168        &self,
169        ctx: &IntraPredContext,
170        output: &mut [u16],
171        stride: usize,
172        dims: BlockDimensions,
173    ) {
174        Self::predict_smooth(ctx, output, stride, dims);
175    }
176}
177
178/// Smooth-V predictor (vertical smooth).
179#[derive(Clone, Copy, Debug, Default)]
180pub struct SmoothVPredictor;
181
182impl SmoothVPredictor {
183    /// Create a new smooth-V predictor.
184    #[must_use]
185    pub const fn new() -> Self {
186        Self
187    }
188
189    /// Perform smooth-V prediction.
190    pub fn predict_smooth_v(
191        ctx: &IntraPredContext,
192        output: &mut [u16],
193        stride: usize,
194        dims: BlockDimensions,
195    ) {
196        let top = ctx.top_samples();
197        let left = ctx.left_samples();
198
199        // Bottom-left sample for vertical interpolation
200        let bottom_left = left[dims.height.saturating_sub(1)];
201
202        let weights_y = weights::get_weights(dims.height);
203
204        for y in 0..dims.height {
205            let row_start = y * stride;
206            let weight_y = if y < weights_y.len() {
207                weights_y[y]
208            } else {
209                weights::interpolate_weight(dims.height, y)
210            };
211
212            for x in 0..dims.width {
213                // Vertical interpolation only
214                // pred = (weight_y * top[x] + (256 - weight_y) * bottom_left + 128) / 256
215                let top_sample = u32::from(top[x]);
216                let bl = u32::from(bottom_left);
217                let wy = u32::from(weight_y);
218
219                let pred = (wy * top_sample + (256 - wy) * bl + 128) / 256;
220                output[row_start + x] = pred as u16;
221            }
222        }
223    }
224}
225
226impl IntraPredictor for SmoothVPredictor {
227    fn predict(
228        &self,
229        ctx: &IntraPredContext,
230        output: &mut [u16],
231        stride: usize,
232        dims: BlockDimensions,
233    ) {
234        Self::predict_smooth_v(ctx, output, stride, dims);
235    }
236}
237
238/// Smooth-H predictor (horizontal smooth).
239#[derive(Clone, Copy, Debug, Default)]
240pub struct SmoothHPredictor;
241
242impl SmoothHPredictor {
243    /// Create a new smooth-H predictor.
244    #[must_use]
245    pub const fn new() -> Self {
246        Self
247    }
248
249    /// Perform smooth-H prediction.
250    pub fn predict_smooth_h(
251        ctx: &IntraPredContext,
252        output: &mut [u16],
253        stride: usize,
254        dims: BlockDimensions,
255    ) {
256        let top = ctx.top_samples();
257        let left = ctx.left_samples();
258
259        // Top-right sample for horizontal interpolation
260        let top_right = top[dims.width.saturating_sub(1)];
261
262        let weights_x = weights::get_weights(dims.width);
263
264        for y in 0..dims.height {
265            let row_start = y * stride;
266            let left_sample = u32::from(left[y]);
267
268            for x in 0..dims.width {
269                let weight_x = if x < weights_x.len() {
270                    weights_x[x]
271                } else {
272                    weights::interpolate_weight(dims.width, x)
273                };
274
275                // Horizontal interpolation only
276                // pred = (weight_x * left[y] + (256 - weight_x) * top_right + 128) / 256
277                let tr = u32::from(top_right);
278                let wx = u32::from(weight_x);
279
280                let pred = (wx * left_sample + (256 - wx) * tr + 128) / 256;
281                output[row_start + x] = pred as u16;
282            }
283        }
284    }
285}
286
287impl IntraPredictor for SmoothHPredictor {
288    fn predict(
289        &self,
290        ctx: &IntraPredContext,
291        output: &mut [u16],
292        stride: usize,
293        dims: BlockDimensions,
294    ) {
295        Self::predict_smooth_h(ctx, output, stride, dims);
296    }
297}
298
/// Blend the four reference samples used by SMOOTH prediction.
///
/// `weight_y` weighs `top` against `bottom_left`, and `weight_x` weighs
/// `left` against `top_right`. Both weights are in 1/256ths. The two blends
/// therefore sum to a factor of 512, which is removed with round-half-up
/// division.
///
/// Weights above 256 make `256 - weight` underflow, which panics in debug
/// builds.
#[inline]
pub fn bilinear_interpolate(
    top: u16,
    left: u16,
    bottom_left: u16,
    top_right: u16,
    weight_x: u16,
    weight_y: u16,
) -> u16 {
    // One 1/256-weighted blend between a near and a far reference.
    let blend = |weight: u16, near: u16, far: u16| -> u32 {
        let w = u32::from(weight);
        w * u32::from(near) + (256 - w) * u32::from(far)
    };

    let total = blend(weight_y, top, bottom_left) + blend(weight_x, left, top_right);
    ((total + 256) >> 9) as u16
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intra::context::IntraPredContext;
    use crate::intra::BitDepth;

    /// 8x8 context whose top row is flat at 200 and whose left column is
    /// flat at 100, with both edges marked available.
    fn create_test_context() -> IntraPredContext {
        let mut ctx = IntraPredContext::new(8, 8, BitDepth::Bits8);

        // Top row: all 200.
        for idx in 0..16 {
            ctx.set_top_sample(idx, 200);
        }

        // Left column: all 100.
        for idx in 0..16 {
            ctx.set_left_sample(idx, 100);
        }

        ctx.set_top_left_sample(150);
        ctx.set_availability(true, true);

        ctx
    }

    #[test]
    fn test_smooth_weights() {
        // The 4- and 8-entry tables have one weight per sample and fall off.
        for size in [4usize, 8] {
            let table = weights::get_weights(size);
            assert_eq!(table.len(), size);
            assert!(table[0] > table[size - 1]);
        }

        let table16 = weights::get_weights(16);
        assert_eq!(table16.len(), 16);
        assert_eq!(table16[0], 255);
    }

    #[test]
    fn test_smooth_prediction() {
        let ctx = create_test_context();
        let mut output = vec![0u16; 16];
        SmoothPredictor::new().predict(&ctx, &mut output, 4, BlockDimensions::new(4, 4));

        // Every sample blends 100 (left) and 200 (top), so stays between them.
        for val in output.iter().copied() {
            assert!((100..=200).contains(&val), "Value {} out of range", val);
        }

        // The far corner must not exceed the near corner by more than 50.
        assert!(output[0] >= output[15] - 50);
    }

    #[test]
    fn test_smooth_v_prediction() {
        let ctx = create_test_context();
        let mut output = vec![0u16; 16];
        SmoothVPredictor::new().predict(&ctx, &mut output, 4, BlockDimensions::new(4, 4));

        // A flat top row makes every output row flat.
        for (y, row) in output.chunks_exact(4).enumerate() {
            for &sample in &row[1..] {
                assert_eq!(sample, row[0], "Row {} not uniform", y);
            }
        }

        // Rows fade from the top (200) toward the bottom-left (100).
        assert!(output[0] > output[12]);
    }

    #[test]
    fn test_smooth_h_prediction() {
        let ctx = create_test_context();
        let mut output = vec![0u16; 16];
        SmoothHPredictor::new().predict(&ctx, &mut output, 4, BlockDimensions::new(4, 4));

        // A flat left column makes every output column flat.
        for x in 0..4 {
            let head = output[x];
            for sample in output.iter().skip(x).step_by(4).skip(1) {
                assert_eq!(*sample, head, "Column {} not uniform", x);
            }
        }

        // Columns fade from the left (100) toward the top-right (200).
        assert!(output[0] < output[3]);
    }

    #[test]
    fn test_bilinear_interpolate() {
        // Identical references reproduce themselves.
        assert_eq!(bilinear_interpolate(100, 100, 100, 100, 128, 128), 100);

        // Mixed references land near the 150 midpoint.
        let mixed = bilinear_interpolate(200, 100, 100, 200, 128, 128);
        assert!((140..=160).contains(&mixed));
    }

    #[test]
    fn test_weight_interpolation() {
        // A 6-sample edge has no table of its own and must be resampled.
        let edge = weights::interpolate_weight(6, 0);
        let far = weights::interpolate_weight(6, 5);
        assert!(edge > 200); // Should be high at edge
        assert!(far < 100); // Should be lower at center
    }

    #[test]
    fn test_smooth_rectangular_block() {
        let ctx = create_test_context();
        let mut output = vec![0u16; 32];
        SmoothPredictor::new().predict(&ctx, &mut output, 8, BlockDimensions::new(8, 4));

        // All values should be in valid range
        assert!(output.iter().all(|val| (100..=200).contains(val)));
    }
}