Skip to main content

oximedia_codec/intra/
filter.rs

1//! Intra edge filter implementations.
2//!
3//! Intra edge filtering smooths the neighbor samples before prediction
4//! to reduce blocking artifacts. The filter strength depends on the
5//! prediction angle and block size.
6//!
7//! # Filter Types
8//!
9//! - **Weak filter**: 3-tap [1, 2, 1] / 4
10//! - **Strong filter**: 5-tap [1, 2, 2, 2, 1] / 8
//! - **Automatic selection**: Chooses a strength from the prediction angle and block size
12//!
13//! # Application
14//!
15//! Edge filtering is typically applied to:
16//! - Directional modes at steep angles
17//! - Larger block sizes where artifacts are more visible
18
19#![forbid(unsafe_code)]
20#![allow(dead_code)]
21#![allow(clippy::cast_possible_truncation)]
22#![allow(clippy::trivially_copy_pass_by_ref)]
23#![allow(clippy::manual_div_ceil)]
24#![allow(clippy::manual_rem_euclid)]
25
26use super::{BitDepth, BlockDimensions, IntraPredContext, MAX_NEIGHBOR_SAMPLES};
27
/// Filter strength levels.
///
/// Selects which smoothing kernel [`IntraEdgeFilter`] runs over the
/// neighbor samples. The default is [`FilterStrength::None`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum FilterStrength {
    /// No filtering; the samples are left untouched.
    #[default]
    None,
    /// Weak 3-tap filter with weights `[1, 2, 1] / 4`.
    Weak,
    /// Strong 5-tap filter with weights `[1, 2, 2, 2, 1] / 8`.
    Strong,
}
39
40impl FilterStrength {
41    /// Determine filter strength based on angle and block size.
42    #[must_use]
43    pub fn from_angle_and_size(angle: i16, width: usize, height: usize) -> Self {
44        // Angles close to 45, 135, 225, or 315 degrees benefit from filtering
45        let is_steep = is_steep_angle(angle);
46
47        // Larger blocks need stronger filtering
48        let min_dim = width.min(height);
49
50        if !is_steep {
51            Self::None
52        } else if min_dim >= 16 {
53            Self::Strong
54        } else if min_dim >= 8 {
55            Self::Weak
56        } else {
57            Self::None
58        }
59    }
60}
61
/// Check if an angle is steep (close to diagonal).
///
/// The angle is given in whole degrees and may be negative or exceed a full
/// turn; it is first wrapped into `0..360`. It counts as steep when it lies
/// strictly fewer than 23 degrees away from one of the diagonals
/// 45, 135, 225 or 315.
#[must_use]
fn is_steep_angle(angle: i16) -> bool {
    const DIAGONALS: [i16; 4] = [45, 135, 225, 315];
    // For whole-degree angles, "< 23" is the same as "within 22.5 degrees".
    const TOLERANCE: i16 = 23;

    let wrapped = angle.rem_euclid(360);
    DIAGONALS
        .iter()
        .any(|&diagonal| (wrapped - diagonal).abs() < TOLERANCE)
}
72
/// Intra edge filter.
///
/// Pairs a [`FilterStrength`] with the [`BitDepth`] whose maximum sample
/// value is used as the upper clamp for filtered output.
#[derive(Clone, Copy, Debug, Default)]
pub struct IntraEdgeFilter {
    /// Filter strength.
    strength: FilterStrength,
    /// Bit depth for clamping.
    bit_depth: BitDepth,
}
81
82impl IntraEdgeFilter {
83    /// Create a new intra edge filter.
84    #[must_use]
85    pub const fn new(strength: FilterStrength, bit_depth: BitDepth) -> Self {
86        Self {
87            strength,
88            bit_depth,
89        }
90    }
91
92    /// Create with automatic strength selection.
93    #[must_use]
94    pub fn auto(angle: i16, dims: BlockDimensions, bit_depth: BitDepth) -> Self {
95        let strength = FilterStrength::from_angle_and_size(angle, dims.width, dims.height);
96        Self {
97            strength,
98            bit_depth,
99        }
100    }
101
102    /// Get the filter strength.
103    #[must_use]
104    pub const fn strength(&self) -> FilterStrength {
105        self.strength
106    }
107
108    /// Apply filter to top samples.
109    pub fn filter_top(&self, samples: &mut [u16], count: usize) {
110        match self.strength {
111            FilterStrength::None => {}
112            FilterStrength::Weak => self.apply_weak_filter(samples, count),
113            FilterStrength::Strong => self.apply_strong_filter(samples, count),
114        }
115    }
116
117    /// Apply filter to left samples.
118    pub fn filter_left(&self, samples: &mut [u16], count: usize) {
119        // Same filter, different orientation
120        match self.strength {
121            FilterStrength::None => {}
122            FilterStrength::Weak => self.apply_weak_filter(samples, count),
123            FilterStrength::Strong => self.apply_strong_filter(samples, count),
124        }
125    }
126
127    /// Apply weak 3-tap filter [1, 2, 1] / 4.
128    fn apply_weak_filter(&self, samples: &mut [u16], count: usize) {
129        if count < 3 {
130            return;
131        }
132
133        let max_val = self.bit_depth.max_value();
134        let mut filtered = [0u16; MAX_NEIGHBOR_SAMPLES];
135
136        // First sample unchanged
137        filtered[0] = samples[0];
138
139        // Apply 3-tap filter to middle samples
140        for i in 1..count.saturating_sub(1) {
141            let sum =
142                u32::from(samples[i - 1]) + 2 * u32::from(samples[i]) + u32::from(samples[i + 1]);
143            let val = (sum + 2) / 4;
144            filtered[i] = val.min(u32::from(max_val)) as u16;
145        }
146
147        // Last sample unchanged
148        if count > 1 {
149            filtered[count - 1] = samples[count - 1];
150        }
151
152        // Copy back
153        samples[..count].copy_from_slice(&filtered[..count]);
154    }
155
156    /// Apply strong 5-tap filter [1, 2, 2, 2, 1] / 8.
157    fn apply_strong_filter(&self, samples: &mut [u16], count: usize) {
158        if count < 5 {
159            // Fall back to weak filter for small arrays
160            self.apply_weak_filter(samples, count);
161            return;
162        }
163
164        let max_val = self.bit_depth.max_value();
165        let mut filtered = [0u16; MAX_NEIGHBOR_SAMPLES];
166
167        // First two samples get special treatment
168        filtered[0] = samples[0];
169        if count > 1 {
170            let sum = u32::from(samples[0]) + 2 * u32::from(samples[1]) + u32::from(samples[2]);
171            filtered[1] = ((sum + 2) / 4).min(u32::from(max_val)) as u16;
172        }
173
174        // Apply 5-tap filter to middle samples
175        for i in 2..count.saturating_sub(2) {
176            let sum = u32::from(samples[i - 2])
177                + 2 * u32::from(samples[i - 1])
178                + 2 * u32::from(samples[i])
179                + 2 * u32::from(samples[i + 1])
180                + u32::from(samples[i + 2]);
181            let val = (sum + 4) / 8;
182            filtered[i] = val.min(u32::from(max_val)) as u16;
183        }
184
185        // Last two samples get special treatment
186        if count > 2 {
187            let i = count - 2;
188            let sum =
189                u32::from(samples[i - 1]) + 2 * u32::from(samples[i]) + u32::from(samples[i + 1]);
190            filtered[i] = ((sum + 2) / 4).min(u32::from(max_val)) as u16;
191        }
192        if count > 1 {
193            filtered[count - 1] = samples[count - 1];
194        }
195
196        // Copy back
197        samples[..count].copy_from_slice(&filtered[..count]);
198    }
199}
200
201/// Apply intra filter to prediction context.
202pub fn apply_intra_filter(ctx: &mut IntraPredContext, angle: i16, dims: BlockDimensions) {
203    let filter = IntraEdgeFilter::auto(angle, dims, ctx.bit_depth());
204
205    if filter.strength() == FilterStrength::None {
206        return;
207    }
208
209    // Get mutable references to samples and filter them
210    let top_count = dims.width + dims.height;
211    let left_count = dims.height + dims.width;
212
213    ctx.filter_top_samples(|samples| {
214        filter.filter_top(samples, top_count.min(samples.len()));
215    });
216
217    ctx.filter_left_samples(|samples| {
218        filter.filter_left(samples, left_count.min(samples.len()));
219    });
220}
221
222/// Recursive intra prediction helper.
223///
224/// Applies intra prediction using a recursive filter approach
225/// that considers previously predicted samples.
226pub struct RecursiveIntraHelper {
227    bit_depth: BitDepth,
228}
229
230impl RecursiveIntraHelper {
231    /// Create a new recursive intra helper.
232    #[must_use]
233    pub const fn new(bit_depth: BitDepth) -> Self {
234        Self { bit_depth }
235    }
236
237    /// Apply recursive filtering to predicted samples.
238    ///
239    /// This smooths the prediction by considering previously predicted
240    /// samples in the current block.
241    pub fn apply_recursive_filter(
242        &self,
243        output: &mut [u16],
244        stride: usize,
245        dims: BlockDimensions,
246        filter_type: RecursiveFilterType,
247    ) {
248        match filter_type {
249            RecursiveFilterType::None => {}
250            RecursiveFilterType::Horizontal => {
251                self.filter_horizontal(output, stride, dims);
252            }
253            RecursiveFilterType::Vertical => {
254                self.filter_vertical(output, stride, dims);
255            }
256            RecursiveFilterType::Both => {
257                self.filter_horizontal(output, stride, dims);
258                self.filter_vertical(output, stride, dims);
259            }
260        }
261    }
262
263    /// Apply horizontal recursive filter.
264    fn filter_horizontal(&self, output: &mut [u16], stride: usize, dims: BlockDimensions) {
265        let max_val = self.bit_depth.max_value();
266
267        for y in 0..dims.height {
268            let row_start = y * stride;
269            for x in 1..dims.width {
270                let prev = u32::from(output[row_start + x - 1]);
271                let curr = u32::from(output[row_start + x]);
272                let filtered = (prev + curr + 1) / 2;
273                output[row_start + x] = filtered.min(u32::from(max_val)) as u16;
274            }
275        }
276    }
277
278    /// Apply vertical recursive filter.
279    fn filter_vertical(&self, output: &mut [u16], stride: usize, dims: BlockDimensions) {
280        let max_val = self.bit_depth.max_value();
281
282        for x in 0..dims.width {
283            for y in 1..dims.height {
284                let prev = u32::from(output[(y - 1) * stride + x]);
285                let curr = u32::from(output[y * stride + x]);
286                let filtered = (prev + curr + 1) / 2;
287                output[y * stride + x] = filtered.min(u32::from(max_val)) as u16;
288            }
289        }
290    }
291}
292
/// Recursive filter type.
///
/// Selects which passes [`RecursiveIntraHelper::apply_recursive_filter`]
/// runs over the predicted block. The default is
/// [`RecursiveFilterType::None`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub enum RecursiveFilterType {
    /// No recursive filtering.
    #[default]
    None,
    /// Horizontal recursive filter (each sample averaged with its left neighbor).
    Horizontal,
    /// Vertical recursive filter (each sample averaged with its upper neighbor).
    Vertical,
    /// Both horizontal and vertical; the horizontal pass runs first.
    Both,
}
306
#[cfg(test)]
mod tests {
    //! Unit tests for the edge filters and the recursive helper.
    //!
    //! The kernels are pure integer arithmetic, so every expected value is
    //! asserted exactly rather than as a range.

    use super::*;

    /// Strength depends on both the angle class and the smaller block side.
    #[test]
    fn test_filter_strength_selection() {
        // Diagonal angle, large block -> strong
        let strength = FilterStrength::from_angle_and_size(45, 16, 16);
        assert_eq!(strength, FilterStrength::Strong);

        // Diagonal angle, medium block -> weak
        let strength = FilterStrength::from_angle_and_size(45, 8, 8);
        assert_eq!(strength, FilterStrength::Weak);

        // Diagonal angle, small block -> none
        let strength = FilterStrength::from_angle_and_size(45, 4, 4);
        assert_eq!(strength, FilterStrength::None);

        // Non-diagonal angle -> none
        let strength = FilterStrength::from_angle_and_size(90, 16, 16);
        assert_eq!(strength, FilterStrength::None);
    }

    /// Diagonals (and their near neighbors) are steep; axis directions are not.
    #[test]
    fn test_is_steep_angle() {
        assert!(is_steep_angle(45));
        assert!(is_steep_angle(50));
        assert!(is_steep_angle(135));
        assert!(is_steep_angle(315));

        assert!(!is_steep_angle(0));
        assert!(!is_steep_angle(90));
        assert!(!is_steep_angle(180));
        assert!(!is_steep_angle(270));
    }

    /// The weak kernel keeps the endpoints and smooths the interior.
    #[test]
    fn test_weak_filter() {
        let filter = IntraEdgeFilter::new(FilterStrength::Weak, BitDepth::Bits8);
        let mut samples = [100u16, 150, 200, 150, 100];

        filter.apply_weak_filter(&mut samples, 5);

        // samples[1] = (100 + 2*150 + 200 + 2) / 4 = 150
        // samples[2] = (150 + 2*200 + 150 + 2) / 4 = 175
        // samples[3] = (200 + 2*150 + 100 + 2) / 4 = 150
        assert_eq!(samples, [100, 150, 175, 150, 100]);
    }

    /// The strong kernel uses 3 taps next to the ends and 5 taps elsewhere.
    #[test]
    fn test_strong_filter() {
        let filter = IntraEdgeFilter::new(FilterStrength::Strong, BitDepth::Bits8);
        let mut samples = [100u16, 110, 200, 190, 100, 110, 100];

        filter.apply_strong_filter(&mut samples, 7);

        // samples[1] = (100 + 2*110 + 200 + 2) / 4 = 130
        // samples[2] = (100 + 2*110 + 2*200 + 2*190 + 100 + 4) / 8 = 150
        // samples[3] = (110 + 2*200 + 2*190 + 2*100 + 110 + 4) / 8 = 150
        // samples[4] = (200 + 2*190 + 2*100 + 2*110 + 100 + 4) / 8 = 138
        // samples[5] = (100 + 2*110 + 100 + 2) / 4 = 105
        assert_eq!(samples, [100, 130, 150, 150, 138, 105, 100]);
    }

    /// Filtering near the top of the 8-bit range never exceeds 255.
    #[test]
    fn test_filter_clipping() {
        let filter = IntraEdgeFilter::new(FilterStrength::Weak, BitDepth::Bits8);
        let mut samples = [250u16, 255, 255, 255, 250];

        filter.apply_weak_filter(&mut samples, 5);

        // samples[1] = (250 + 2*255 + 255 + 2) / 4 = 254
        // samples[2] = (255 + 2*255 + 255 + 2) / 4 = 255
        // samples[3] = (255 + 2*255 + 250 + 2) / 4 = 254
        assert_eq!(samples, [250, 254, 255, 254, 250]);
    }

    /// Each sample is averaged with its already-filtered left neighbor.
    #[test]
    fn test_recursive_helper_horizontal() {
        let helper = RecursiveIntraHelper::new(BitDepth::Bits8);
        let mut output = vec![100u16, 200, 100, 200];
        let dims = BlockDimensions::new(4, 1);

        helper.filter_horizontal(&mut output, 4, dims);

        // (100+200+1)/2 = 150, (150+100+1)/2 = 125, (125+200+1)/2 = 163
        assert_eq!(output, vec![100, 150, 125, 163]);
    }

    /// Each sample is averaged with its already-filtered upper neighbor.
    #[test]
    fn test_recursive_helper_vertical() {
        let helper = RecursiveIntraHelper::new(BitDepth::Bits8);
        let mut output = vec![100u16, 100, 200, 200, 100, 100, 200, 200];
        let dims = BlockDimensions::new(2, 4);

        helper.filter_vertical(&mut output, 2, dims);

        // First row unchanged; each column then follows 150, 125, 163.
        assert_eq!(output, vec![100, 100, 150, 150, 125, 125, 163, 163]);
    }

    /// `auto` forwards to the angle/size based strength selection.
    #[test]
    fn test_auto_filter_creation() {
        let filter = IntraEdgeFilter::auto(45, BlockDimensions::new(16, 16), BitDepth::Bits8);
        assert_eq!(filter.strength(), FilterStrength::Strong);

        let filter = IntraEdgeFilter::auto(90, BlockDimensions::new(16, 16), BitDepth::Bits8);
        assert_eq!(filter.strength(), FilterStrength::None);
    }
}