1#![forbid(unsafe_code)]
20#![allow(dead_code)]
21#![allow(clippy::cast_possible_truncation)]
22#![allow(clippy::trivially_copy_pass_by_ref)]
23#![allow(clippy::manual_div_ceil)]
24#![allow(clippy::manual_rem_euclid)]
25
26use super::{BitDepth, BlockDimensions, IntraPredContext, MAX_NEIGHBOR_SAMPLES};
27
/// How aggressively the reconstructed neighbour edge is smoothed before
/// directional intra prediction.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum FilterStrength {
    /// The edge is left untouched (this is the `Default`).
    #[default]
    None,
    /// Three-tap `[1, 2, 1] / 4` smoothing of every interior sample.
    Weak,
    /// Five-tap `[1, 2, 2, 2, 1] / 8` smoothing of the interior, with the
    /// three-tap kernel next to the ends; edges shorter than five samples
    /// fall back to `Weak`.
    Strong,
}
39
40impl FilterStrength {
41 #[must_use]
43 pub fn from_angle_and_size(angle: i16, width: usize, height: usize) -> Self {
44 let is_steep = is_steep_angle(angle);
46
47 let min_dim = width.min(height);
49
50 if !is_steep {
51 Self::None
52 } else if min_dim >= 16 {
53 Self::Strong
54 } else if min_dim >= 8 {
55 Self::Weak
56 } else {
57 Self::None
58 }
59 }
60}
61
/// Returns `true` when `angle` (in degrees) lies within 22 degrees of one of
/// the four diagonals: 45, 135, 225 or 315.
///
/// Despite the name, this tests closeness to a diagonal direction. Any `i16`
/// is accepted: the angle is first wrapped into `0..360`, so negative angles
/// and angles of 360 or more work. No wrap-around distance is needed because
/// every diagonal is at least 45 degrees away from 0/360.
#[must_use]
fn is_steep_angle(angle: i16) -> bool {
    const DIAGONALS: [i16; 4] = [45, 135, 225, 315];
    // Inclusive distance from a diagonal that still counts as "steep".
    const TOLERANCE: i16 = 22;

    let angle = angle.rem_euclid(360);
    DIAGONALS
        .iter()
        .any(|&diagonal| (angle - diagonal).abs() <= TOLERANCE)
}
72
73#[derive(Clone, Copy, Debug, Default)]
75pub struct IntraEdgeFilter {
76 strength: FilterStrength,
78 bit_depth: BitDepth,
80}
81
82impl IntraEdgeFilter {
83 #[must_use]
85 pub const fn new(strength: FilterStrength, bit_depth: BitDepth) -> Self {
86 Self {
87 strength,
88 bit_depth,
89 }
90 }
91
92 #[must_use]
94 pub fn auto(angle: i16, dims: BlockDimensions, bit_depth: BitDepth) -> Self {
95 let strength = FilterStrength::from_angle_and_size(angle, dims.width, dims.height);
96 Self {
97 strength,
98 bit_depth,
99 }
100 }
101
102 #[must_use]
104 pub const fn strength(&self) -> FilterStrength {
105 self.strength
106 }
107
108 pub fn filter_top(&self, samples: &mut [u16], count: usize) {
110 match self.strength {
111 FilterStrength::None => {}
112 FilterStrength::Weak => self.apply_weak_filter(samples, count),
113 FilterStrength::Strong => self.apply_strong_filter(samples, count),
114 }
115 }
116
117 pub fn filter_left(&self, samples: &mut [u16], count: usize) {
119 match self.strength {
121 FilterStrength::None => {}
122 FilterStrength::Weak => self.apply_weak_filter(samples, count),
123 FilterStrength::Strong => self.apply_strong_filter(samples, count),
124 }
125 }
126
127 fn apply_weak_filter(&self, samples: &mut [u16], count: usize) {
129 if count < 3 {
130 return;
131 }
132
133 let max_val = self.bit_depth.max_value();
134 let mut filtered = [0u16; MAX_NEIGHBOR_SAMPLES];
135
136 filtered[0] = samples[0];
138
139 for i in 1..count.saturating_sub(1) {
141 let sum =
142 u32::from(samples[i - 1]) + 2 * u32::from(samples[i]) + u32::from(samples[i + 1]);
143 let val = (sum + 2) / 4;
144 filtered[i] = val.min(u32::from(max_val)) as u16;
145 }
146
147 if count > 1 {
149 filtered[count - 1] = samples[count - 1];
150 }
151
152 samples[..count].copy_from_slice(&filtered[..count]);
154 }
155
156 fn apply_strong_filter(&self, samples: &mut [u16], count: usize) {
158 if count < 5 {
159 self.apply_weak_filter(samples, count);
161 return;
162 }
163
164 let max_val = self.bit_depth.max_value();
165 let mut filtered = [0u16; MAX_NEIGHBOR_SAMPLES];
166
167 filtered[0] = samples[0];
169 if count > 1 {
170 let sum = u32::from(samples[0]) + 2 * u32::from(samples[1]) + u32::from(samples[2]);
171 filtered[1] = ((sum + 2) / 4).min(u32::from(max_val)) as u16;
172 }
173
174 for i in 2..count.saturating_sub(2) {
176 let sum = u32::from(samples[i - 2])
177 + 2 * u32::from(samples[i - 1])
178 + 2 * u32::from(samples[i])
179 + 2 * u32::from(samples[i + 1])
180 + u32::from(samples[i + 2]);
181 let val = (sum + 4) / 8;
182 filtered[i] = val.min(u32::from(max_val)) as u16;
183 }
184
185 if count > 2 {
187 let i = count - 2;
188 let sum =
189 u32::from(samples[i - 1]) + 2 * u32::from(samples[i]) + u32::from(samples[i + 1]);
190 filtered[i] = ((sum + 2) / 4).min(u32::from(max_val)) as u16;
191 }
192 if count > 1 {
193 filtered[count - 1] = samples[count - 1];
194 }
195
196 samples[..count].copy_from_slice(&filtered[..count]);
198 }
199}
200
201pub fn apply_intra_filter(ctx: &mut IntraPredContext, angle: i16, dims: BlockDimensions) {
203 let filter = IntraEdgeFilter::auto(angle, dims, ctx.bit_depth());
204
205 if filter.strength() == FilterStrength::None {
206 return;
207 }
208
209 let top_count = dims.width + dims.height;
211 let left_count = dims.height + dims.width;
212
213 ctx.filter_top_samples(|samples| {
214 filter.filter_top(samples, top_count.min(samples.len()));
215 });
216
217 ctx.filter_left_samples(|samples| {
218 filter.filter_left(samples, left_count.min(samples.len()));
219 });
220}
221
222pub struct RecursiveIntraHelper {
227 bit_depth: BitDepth,
228}
229
230impl RecursiveIntraHelper {
231 #[must_use]
233 pub const fn new(bit_depth: BitDepth) -> Self {
234 Self { bit_depth }
235 }
236
237 pub fn apply_recursive_filter(
242 &self,
243 output: &mut [u16],
244 stride: usize,
245 dims: BlockDimensions,
246 filter_type: RecursiveFilterType,
247 ) {
248 match filter_type {
249 RecursiveFilterType::None => {}
250 RecursiveFilterType::Horizontal => {
251 self.filter_horizontal(output, stride, dims);
252 }
253 RecursiveFilterType::Vertical => {
254 self.filter_vertical(output, stride, dims);
255 }
256 RecursiveFilterType::Both => {
257 self.filter_horizontal(output, stride, dims);
258 self.filter_vertical(output, stride, dims);
259 }
260 }
261 }
262
263 fn filter_horizontal(&self, output: &mut [u16], stride: usize, dims: BlockDimensions) {
265 let max_val = self.bit_depth.max_value();
266
267 for y in 0..dims.height {
268 let row_start = y * stride;
269 for x in 1..dims.width {
270 let prev = u32::from(output[row_start + x - 1]);
271 let curr = u32::from(output[row_start + x]);
272 let filtered = (prev + curr + 1) / 2;
273 output[row_start + x] = filtered.min(u32::from(max_val)) as u16;
274 }
275 }
276 }
277
278 fn filter_vertical(&self, output: &mut [u16], stride: usize, dims: BlockDimensions) {
280 let max_val = self.bit_depth.max_value();
281
282 for x in 0..dims.width {
283 for y in 1..dims.height {
284 let prev = u32::from(output[(y - 1) * stride + x]);
285 let curr = u32::from(output[y * stride + x]);
286 let filtered = (prev + curr + 1) / 2;
287 output[y * stride + x] = filtered.min(u32::from(max_val)) as u16;
288 }
289 }
290 }
291}
292
/// Which recursive smoothing pass(es) `RecursiveIntraHelper` applies to a
/// predicted block.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum RecursiveFilterType {
    /// No smoothing (this is the `Default`).
    #[default]
    None,
    /// Left-to-right pass along each row.
    Horizontal,
    /// Top-to-bottom pass down each column.
    Vertical,
    /// Horizontal pass followed by the vertical pass.
    Both,
}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_strength_selection() {
        assert_eq!(FilterStrength::from_angle_and_size(45, 16, 16), FilterStrength::Strong);
        assert_eq!(FilterStrength::from_angle_and_size(45, 8, 8), FilterStrength::Weak);
        assert_eq!(FilterStrength::from_angle_and_size(45, 4, 4), FilterStrength::None);
        assert_eq!(FilterStrength::from_angle_and_size(90, 16, 16), FilterStrength::None);

        // The shorter block side decides the strength.
        assert_eq!(FilterStrength::from_angle_and_size(45, 32, 8), FilterStrength::Weak);
        assert_eq!(FilterStrength::from_angle_and_size(45, 16, 4), FilterStrength::None);
    }

    #[test]
    fn test_is_steep_angle() {
        assert!(is_steep_angle(45));
        assert!(is_steep_angle(50));
        assert!(is_steep_angle(135));
        assert!(is_steep_angle(315));

        // Angles are wrapped into 0..360 first.
        assert!(is_steep_angle(-45));
        assert!(is_steep_angle(405));

        // The tolerance is 22 degrees either side of a diagonal.
        assert!(is_steep_angle(67));
        assert!(!is_steep_angle(68));

        assert!(!is_steep_angle(0));
        assert!(!is_steep_angle(90));
        assert!(!is_steep_angle(180));
        assert!(!is_steep_angle(270));
    }

    #[test]
    fn test_weak_filter() {
        let filter = IntraEdgeFilter::new(FilterStrength::Weak, BitDepth::Bits8);
        let mut samples = [100u16, 150, 200, 150, 100];

        filter.apply_weak_filter(&mut samples, 5);

        // The endpoints are kept; the interior gets (a + 2b + c + 2) / 4.
        assert_eq!(samples, [100, 150, 175, 150, 100]);
    }

    #[test]
    fn test_weak_filter_short_edge_is_untouched() {
        let filter = IntraEdgeFilter::new(FilterStrength::Weak, BitDepth::Bits8);
        let mut samples = [10u16, 200];

        filter.apply_weak_filter(&mut samples, 2);

        assert_eq!(samples, [10, 200]);
    }

    #[test]
    fn test_weak_filter_respects_count() {
        let filter = IntraEdgeFilter::new(FilterStrength::Weak, BitDepth::Bits8);
        let mut samples = [100u16, 200, 100, 50];

        filter.apply_weak_filter(&mut samples, 3);

        // Only index 1 is interior to the 3-sample range; index 3 is untouched.
        assert_eq!(samples, [100, 150, 100, 50]);
    }

    #[test]
    fn test_strong_filter() {
        let filter = IntraEdgeFilter::new(FilterStrength::Strong, BitDepth::Bits8);
        let mut samples = [100u16, 110, 200, 190, 100, 110, 100];

        filter.apply_strong_filter(&mut samples, 7);

        // The endpoints are kept, indices 1 and 5 use the 3-tap kernel, and
        // indices 2..=4 use the 5-tap [1, 2, 2, 2, 1] / 8 kernel.
        assert_eq!(samples, [100, 130, 150, 150, 138, 105, 100]);
    }

    #[test]
    fn test_strong_filter_short_edge_falls_back_to_weak() {
        let filter = IntraEdgeFilter::new(FilterStrength::Strong, BitDepth::Bits8);
        let mut samples = [100u16, 200, 100, 200];

        filter.apply_strong_filter(&mut samples, 4);

        assert_eq!(samples, [100, 150, 150, 200]);
    }

    #[test]
    fn test_filter_clipping() {
        let filter = IntraEdgeFilter::new(FilterStrength::Weak, BitDepth::Bits8);
        let mut samples = [250u16, 255, 255, 255, 250];

        filter.apply_weak_filter(&mut samples, 5);

        assert_eq!(samples, [250, 254, 255, 254, 250]);

        // Inputs above the bit-depth maximum are clamped on output.
        let mut over = [300u16, 300, 300];
        filter.apply_weak_filter(&mut over, 3);
        assert_eq!(over, [300, 255, 300]);
    }

    #[test]
    fn test_none_strength_leaves_edges_untouched() {
        let filter = IntraEdgeFilter::new(FilterStrength::None, BitDepth::Bits8);
        let mut samples = [0u16, 255, 0, 255, 0];

        filter.filter_top(&mut samples, 5);
        filter.filter_left(&mut samples, 5);

        assert_eq!(samples, [0, 255, 0, 255, 0]);
    }

    #[test]
    fn test_recursive_helper_horizontal() {
        let helper = RecursiveIntraHelper::new(BitDepth::Bits8);
        let mut output = vec![100u16, 200, 100, 200];
        let dims = BlockDimensions::new(4, 1);

        helper.filter_horizontal(&mut output, 4, dims);

        // Each sample is averaged with its already-filtered left neighbour.
        assert_eq!(output, [100, 150, 125, 163]);
    }

    #[test]
    fn test_recursive_helper_vertical() {
        let helper = RecursiveIntraHelper::new(BitDepth::Bits8);
        let mut output = vec![100u16, 100, 200, 200, 100, 100, 200, 200];
        let dims = BlockDimensions::new(2, 4);

        helper.filter_vertical(&mut output, 2, dims);

        // Each sample is averaged with the already-filtered sample above it.
        assert_eq!(output, [100, 100, 150, 150, 125, 125, 163, 163]);
    }

    #[test]
    fn test_recursive_filter_dispatch() {
        let helper = RecursiveIntraHelper::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(2, 2);

        let mut untouched = vec![100u16, 200, 200, 100];
        helper.apply_recursive_filter(&mut untouched, 2, dims, RecursiveFilterType::None);
        assert_eq!(untouched, [100, 200, 200, 100]);

        // Horizontal first ([100, 150, 200, 150]), then vertical.
        let mut both = vec![100u16, 200, 200, 100];
        helper.apply_recursive_filter(&mut both, 2, dims, RecursiveFilterType::Both);
        assert_eq!(both, [100, 150, 150, 150]);
    }

    #[test]
    fn test_auto_filter_creation() {
        let filter = IntraEdgeFilter::auto(45, BlockDimensions::new(16, 16), BitDepth::Bits8);
        assert_eq!(filter.strength(), FilterStrength::Strong);

        let filter = IntraEdgeFilter::auto(45, BlockDimensions::new(8, 8), BitDepth::Bits8);
        assert_eq!(filter.strength(), FilterStrength::Weak);

        let filter = IntraEdgeFilter::auto(90, BlockDimensions::new(16, 16), BitDepth::Bits8);
        assert_eq!(filter.strength(), FilterStrength::None);
    }
}