1#![forbid(unsafe_code)]
14#![allow(dead_code)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_sign_loss)]
17#![allow(clippy::needless_range_loop)]
18#![allow(clippy::similar_names)]
19#![allow(clippy::unused_self)]
20#![allow(clippy::trivially_copy_pass_by_ref)]
21#![allow(clippy::match_same_arms)]
22
23use super::{BitDepth, BlockDimensions, IntraPredContext, IntraPredictor};
24
/// Selects which neighbouring edges a DC prediction averages.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq)]
pub enum DcMode {
    /// Average the top row and the left column together (the default mode).
    #[default]
    Both,
    /// Average the top row only.
    TopOnly,
    /// Average the left column only.
    LeftOnly,
    /// No usable neighbours: the bit-depth midpoint is used instead.
    NoNeighbors,
    /// Gradient variant; its base DC value is computed the same way as `Both`.
    WithGradient,
}
40
41#[derive(Clone, Copy, Debug, Default)]
43pub struct DcPredictor {
44 bit_depth: BitDepth,
46}
47
48impl DcPredictor {
49 #[must_use]
51 pub const fn new(bit_depth: BitDepth) -> Self {
52 Self { bit_depth }
53 }
54
55 fn dc_top_only(top: &[u16], width: usize) -> u16 {
57 if width == 0 {
58 return 128;
59 }
60
61 let sum: u32 = top.iter().take(width).map(|&s| u32::from(s)).sum();
62 let avg = (sum + (width as u32 / 2)) / width as u32;
63 avg as u16
64 }
65
66 fn dc_left_only(left: &[u16], height: usize) -> u16 {
68 if height == 0 {
69 return 128;
70 }
71
72 let sum: u32 = left.iter().take(height).map(|&s| u32::from(s)).sum();
73 let avg = (sum + (height as u32 / 2)) / height as u32;
74 avg as u16
75 }
76
77 fn dc_both(top: &[u16], left: &[u16], width: usize, height: usize) -> u16 {
79 if width == 0 && height == 0 {
80 return 128;
81 }
82
83 let top_sum: u32 = top.iter().take(width).map(|&s| u32::from(s)).sum();
84 let left_sum: u32 = left.iter().take(height).map(|&s| u32::from(s)).sum();
85
86 let total = width + height;
87 let sum = top_sum + left_sum;
88 let avg = (sum + (total as u32 / 2)) / total as u32;
89 avg as u16
90 }
91
92 pub fn predict_dc(
94 &self,
95 ctx: &IntraPredContext,
96 output: &mut [u16],
97 stride: usize,
98 dims: BlockDimensions,
99 ) {
100 let mode = self.determine_mode(ctx);
101 let dc_value = self.calculate_dc(ctx, dims, mode);
102
103 for y in 0..dims.height {
105 let row_start = y * stride;
106 for x in 0..dims.width {
107 output[row_start + x] = dc_value;
108 }
109 }
110 }
111
112 pub fn predict_dc_gradient(
114 &self,
115 ctx: &IntraPredContext,
116 output: &mut [u16],
117 stride: usize,
118 dims: BlockDimensions,
119 ) {
120 let base_dc = self.calculate_dc(ctx, dims, DcMode::Both);
121 let top = ctx.top_samples();
122 let left = ctx.left_samples();
123 let max_val = self.bit_depth.max_value();
124
125 let top_left = ctx.top_left_sample();
127
128 for y in 0..dims.height {
129 let row_start = y * stride;
130 let left_diff = i32::from(left[y]) - i32::from(top_left);
131
132 for x in 0..dims.width {
133 let top_diff = i32::from(top[x]) - i32::from(top_left);
134
135 let gradient = (top_diff + left_diff) / 2;
137 let pred = i32::from(base_dc) + gradient;
138
139 let clamped = pred.clamp(0, i32::from(max_val));
141 output[row_start + x] = clamped as u16;
142 }
143 }
144 }
145
146 fn determine_mode(&self, ctx: &IntraPredContext) -> DcMode {
148 let has_top = ctx.has_top();
149 let has_left = ctx.has_left();
150
151 match (has_top, has_left) {
152 (true, true) => DcMode::Both,
153 (true, false) => DcMode::TopOnly,
154 (false, true) => DcMode::LeftOnly,
155 (false, false) => DcMode::NoNeighbors,
156 }
157 }
158
159 fn calculate_dc(&self, ctx: &IntraPredContext, dims: BlockDimensions, mode: DcMode) -> u16 {
161 match mode {
162 DcMode::Both => Self::dc_both(
163 ctx.top_samples(),
164 ctx.left_samples(),
165 dims.width,
166 dims.height,
167 ),
168 DcMode::TopOnly => Self::dc_top_only(ctx.top_samples(), dims.width),
169 DcMode::LeftOnly => Self::dc_left_only(ctx.left_samples(), dims.height),
170 DcMode::NoNeighbors => self.bit_depth.midpoint(),
171 DcMode::WithGradient => Self::dc_both(
172 ctx.top_samples(),
173 ctx.left_samples(),
174 dims.width,
175 dims.height,
176 ),
177 }
178 }
179}
180
181impl IntraPredictor for DcPredictor {
182 fn predict(
183 &self,
184 ctx: &IntraPredContext,
185 output: &mut [u16],
186 stride: usize,
187 dims: BlockDimensions,
188 ) {
189 self.predict_dc(ctx, output, stride, dims);
190 }
191}
192
193#[derive(Clone, Copy, Debug, Default)]
195pub struct DcTopPredictor {
196 bit_depth: BitDepth,
197}
198
199impl DcTopPredictor {
200 #[must_use]
202 pub const fn new(bit_depth: BitDepth) -> Self {
203 Self { bit_depth }
204 }
205}
206
207impl IntraPredictor for DcTopPredictor {
208 fn predict(
209 &self,
210 ctx: &IntraPredContext,
211 output: &mut [u16],
212 stride: usize,
213 dims: BlockDimensions,
214 ) {
215 let dc_value = if ctx.has_top() {
216 DcPredictor::dc_top_only(ctx.top_samples(), dims.width)
217 } else {
218 self.bit_depth.midpoint()
219 };
220
221 for y in 0..dims.height {
222 let row_start = y * stride;
223 for x in 0..dims.width {
224 output[row_start + x] = dc_value;
225 }
226 }
227 }
228}
229
230#[derive(Clone, Copy, Debug, Default)]
232pub struct DcLeftPredictor {
233 bit_depth: BitDepth,
234}
235
236impl DcLeftPredictor {
237 #[must_use]
239 pub const fn new(bit_depth: BitDepth) -> Self {
240 Self { bit_depth }
241 }
242}
243
244impl IntraPredictor for DcLeftPredictor {
245 fn predict(
246 &self,
247 ctx: &IntraPredContext,
248 output: &mut [u16],
249 stride: usize,
250 dims: BlockDimensions,
251 ) {
252 let dc_value = if ctx.has_left() {
253 DcPredictor::dc_left_only(ctx.left_samples(), dims.height)
254 } else {
255 self.bit_depth.midpoint()
256 };
257
258 for y in 0..dims.height {
259 let row_start = y * stride;
260 for x in 0..dims.width {
261 output[row_start + x] = dc_value;
262 }
263 }
264 }
265}
266
267#[derive(Clone, Copy, Debug, Default)]
269pub struct Dc128Predictor {
270 bit_depth: BitDepth,
271}
272
273impl Dc128Predictor {
274 #[must_use]
276 pub const fn new(bit_depth: BitDepth) -> Self {
277 Self { bit_depth }
278 }
279}
280
281impl IntraPredictor for Dc128Predictor {
282 fn predict(
283 &self,
284 _ctx: &IntraPredContext,
285 output: &mut [u16],
286 stride: usize,
287 dims: BlockDimensions,
288 ) {
289 let dc_value = self.bit_depth.midpoint();
290
291 for y in 0..dims.height {
292 let row_start = y * stride;
293 for x in 0..dims.width {
294 output[row_start + x] = dc_value;
295 }
296 }
297 }
298}
299
300#[derive(Clone, Copy, Debug, Default)]
302pub struct DcGradientPredictor {
303 bit_depth: BitDepth,
304}
305
306impl DcGradientPredictor {
307 #[must_use]
309 pub const fn new(bit_depth: BitDepth) -> Self {
310 Self { bit_depth }
311 }
312}
313
314impl IntraPredictor for DcGradientPredictor {
315 fn predict(
316 &self,
317 ctx: &IntraPredContext,
318 output: &mut [u16],
319 stride: usize,
320 dims: BlockDimensions,
321 ) {
322 let predictor = DcPredictor::new(self.bit_depth);
323 predictor.predict_dc_gradient(ctx, output, stride, dims);
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::intra::context::IntraPredContext;

    /// Builds an 8x8, 8-bit context with the following neighbours:
    /// top row 100, 110, ..., 170 (sum 1080); left column 80, 90, ..., 150
    /// (sum 920); top-left sample 90. Both edges are marked available.
    fn create_test_context() -> IntraPredContext {
        let mut ctx = IntraPredContext::new(8, 8, BitDepth::Bits8);

        for i in 0..8 {
            ctx.set_top_sample(i, 100 + (i as u16 * 10));
        }

        for i in 0..8 {
            ctx.set_left_sample(i, 80 + (i as u16 * 10));
        }

        ctx.set_top_left_sample(90);
        ctx.set_availability(true, true);

        ctx
    }

    #[test]
    fn test_dc_top_only() {
        // (460 + 2) / 4
        let top = [100u16, 110, 120, 130];
        let dc = DcPredictor::dc_top_only(&top, 4);
        assert_eq!(dc, 115);
    }

    #[test]
    fn test_dc_left_only() {
        // (380 + 2) / 4
        let left = [80u16, 90, 100, 110];
        let dc = DcPredictor::dc_left_only(&left, 4);
        assert_eq!(dc, 95);
    }

    #[test]
    fn test_dc_both() {
        // (840 + 4) / 8
        let top = [100u16, 110, 120, 130];
        let left = [80u16, 90, 100, 110];
        let dc = DcPredictor::dc_both(&top, &left, 4, 4);
        assert_eq!(dc, 105);
    }

    #[test]
    fn test_dc_predictor_both() {
        let ctx = create_test_context();
        let predictor = DcPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut output = vec![0u16; 64];

        predictor.predict(&ctx, &mut output, 8, dims);

        let dc_value = output[0];
        assert!(output.iter().all(|&v| v == dc_value));

        // (1080 + 920 + 8) / 16
        assert_eq!(dc_value, 125);
    }

    #[test]
    fn test_dc_predictor_follows_availability() {
        let predictor = DcPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut ctx = create_test_context();

        // Top only: (1080 + 4) / 8.
        ctx.set_availability(true, false);
        let mut output = vec![0u16; 64];
        predictor.predict(&ctx, &mut output, 8, dims);
        assert!(output.iter().all(|&v| v == 135));

        // Left only: (920 + 4) / 8.
        ctx.set_availability(false, true);
        let mut output = vec![0u16; 64];
        predictor.predict(&ctx, &mut output, 8, dims);
        assert!(output.iter().all(|&v| v == 115));

        // No neighbours: 8-bit midpoint.
        ctx.set_availability(false, false);
        let mut output = vec![0u16; 64];
        predictor.predict(&ctx, &mut output, 8, dims);
        assert!(output.iter().all(|&v| v == 128));
    }

    #[test]
    fn test_dc_predictor_no_neighbors_10bit() {
        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits10);
        ctx.set_availability(false, false);

        let predictor = DcPredictor::new(BitDepth::Bits10);
        let mut output = vec![0u16; 16];
        predictor.predict(&ctx, &mut output, 4, BlockDimensions::new(4, 4));

        assert!(output.iter().all(|&v| v == 512));
    }

    #[test]
    fn test_dc_respects_stride() {
        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits8);
        for i in 0..4 {
            ctx.set_top_sample(i, 100 + (i as u16 * 10));
            ctx.set_left_sample(i, 80 + (i as u16 * 10));
        }
        ctx.set_availability(true, true);

        let predictor = DcPredictor::new(BitDepth::Bits8);
        let stride = 6;
        let mut output = vec![0u16; stride * 4];
        predictor.predict(&ctx, &mut output, stride, BlockDimensions::new(4, 4));

        // The first 4 samples of each row are predicted; the 2 padding samples
        // past the block width must be left untouched.
        for (y, row) in output.chunks(stride).enumerate() {
            assert_eq!(&row[..4], &[105u16; 4], "row {y} inside block");
            assert_eq!(&row[4..], &[0u16; 2], "row {y} padding");
        }
    }

    #[test]
    fn test_dc_128_predictor() {
        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits8);
        ctx.set_availability(false, false);

        let predictor = Dc128Predictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(4, 4);
        let mut output = vec![0u16; 16];

        predictor.predict(&ctx, &mut output, 4, dims);

        assert!(output.iter().all(|&v| v == 128));
    }

    #[test]
    fn test_dc_top_predictor() {
        let ctx = create_test_context();
        let predictor = DcTopPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut output = vec![0u16; 64];

        predictor.predict(&ctx, &mut output, 8, dims);

        assert!(output.iter().all(|&v| v == 135));
    }

    #[test]
    fn test_dc_top_predictor_without_top_uses_midpoint() {
        let mut ctx = create_test_context();
        ctx.set_availability(false, true);

        let predictor = DcTopPredictor::new(BitDepth::Bits8);
        let mut output = vec![0u16; 64];
        predictor.predict(&ctx, &mut output, 8, BlockDimensions::new(8, 8));

        assert!(output.iter().all(|&v| v == 128));
    }

    #[test]
    fn test_dc_left_predictor() {
        let ctx = create_test_context();
        let predictor = DcLeftPredictor::new(BitDepth::Bits8);
        let dims = BlockDimensions::new(8, 8);
        let mut output = vec![0u16; 64];

        predictor.predict(&ctx, &mut output, 8, dims);

        assert!(output.iter().all(|&v| v == 115));
    }

    #[test]
    fn test_dc_left_predictor_without_left_uses_midpoint() {
        let mut ctx = create_test_context();
        ctx.set_availability(true, false);

        let predictor = DcLeftPredictor::new(BitDepth::Bits8);
        let mut output = vec![0u16; 64];
        predictor.predict(&ctx, &mut output, 8, BlockDimensions::new(8, 8));

        assert!(output.iter().all(|&v| v == 128));
    }

    #[test]
    fn test_dc_gradient_predictor() {
        let ctx = create_test_context();
        let predictor = DcGradientPredictor::new(BitDepth::Bits8);
        let mut output = vec![0u16; 64];

        predictor.predict(&ctx, &mut output, 8, BlockDimensions::new(8, 8));

        // Base DC is 125.
        // top_diff = (100 + 10x) - 90 and left_diff = (80 + 10y) - 90,
        // so the halved gradient is exactly 5 * (x + y).
        for y in 0..8 {
            for x in 0..8 {
                assert_eq!(output[y * 8 + x], 125 + 5 * (x + y) as u16, "pixel ({x}, {y})");
            }
        }
    }

    #[test]
    fn test_dc_gradient_clamps_to_bit_depth() {
        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits8);
        for i in 0..4 {
            ctx.set_top_sample(i, 255);
            ctx.set_left_sample(i, 255);
        }
        ctx.set_top_left_sample(0);
        ctx.set_availability(true, true);

        let predictor = DcGradientPredictor::new(BitDepth::Bits8);
        let mut output = vec![0u16; 16];
        predictor.predict(&ctx, &mut output, 4, BlockDimensions::new(4, 4));

        // The unclamped value is 255 + (255 + 255) / 2 = 510, which must be
        // capped at the bit depth's maximum.
        let expected = 510.min(i32::from(BitDepth::Bits8.max_value()));
        assert!(output.iter().all(|&v| i32::from(v) == expected));
    }

    #[test]
    fn test_dc_mode_determination() {
        let predictor = DcPredictor::new(BitDepth::Bits8);

        let mut ctx = IntraPredContext::new(4, 4, BitDepth::Bits8);

        ctx.set_availability(true, true);
        assert_eq!(predictor.determine_mode(&ctx), DcMode::Both);

        ctx.set_availability(true, false);
        assert_eq!(predictor.determine_mode(&ctx), DcMode::TopOnly);

        ctx.set_availability(false, true);
        assert_eq!(predictor.determine_mode(&ctx), DcMode::LeftOnly);

        ctx.set_availability(false, false);
        assert_eq!(predictor.determine_mode(&ctx), DcMode::NoNeighbors);
    }

    #[test]
    fn test_bit_depth_10() {
        let predictor = Dc128Predictor::new(BitDepth::Bits10);
        let ctx = IntraPredContext::new(4, 4, BitDepth::Bits10);
        let dims = BlockDimensions::new(4, 4);
        let mut output = vec![0u16; 16];

        predictor.predict(&ctx, &mut output, 4, dims);

        assert!(output.iter().all(|&v| v == 512));
    }
}