1#![forbid(unsafe_code)]
26#![allow(dead_code)]
27#![allow(clippy::struct_excessive_bools)]
28#![allow(clippy::cast_sign_loss)]
29#![allow(clippy::cast_possible_truncation)]
30
31use super::{BitDepth, MAX_NEIGHBOR_SAMPLES};
32
/// Backing storage for the reference samples above a block
/// (`MAX_NEIGHBOR_SAMPLES` entries: the top row followed by any top-right extension).
pub type TopSamples = [u16; MAX_NEIGHBOR_SAMPLES];

/// Backing storage for the reference samples left of a block
/// (`MAX_NEIGHBOR_SAMPLES` entries: the left column followed by any bottom-left extension).
pub type LeftSamples = [u16; MAX_NEIGHBOR_SAMPLES];
38
/// Which neighbouring sample sets are available for intra prediction of a block.
///
/// `Default` yields every flag `false`, i.e. the same value as
/// [`NeighborAvailability::NONE`]. The type is a plain set of flags, so it also
/// supports equality comparison and hashing.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct NeighborAvailability {
    /// The row directly above the block.
    pub top: bool,
    /// The column directly left of the block.
    pub left: bool,
    /// The single sample diagonally above-left of the block.
    pub top_left: bool,
    /// The samples above the block that extend past its right edge.
    pub top_right: bool,
    /// The samples left of the block that extend past its bottom edge.
    pub bottom_left: bool,
}

impl NeighborAvailability {
    /// Every neighbour is available.
    pub const ALL: Self = Self {
        top: true,
        left: true,
        top_left: true,
        top_right: true,
        bottom_left: true,
    };

    /// No neighbour is available.
    pub const NONE: Self = Self {
        top: false,
        left: false,
        top_left: false,
        top_right: false,
        bottom_left: false,
    };

    /// Returns `true` if at least one neighbour flag is set.
    #[must_use]
    pub const fn any(&self) -> bool {
        self.top || self.left || self.top_left || self.top_right || self.bottom_left
    }

    /// Returns the `top` flag.
    #[must_use]
    pub const fn has_top(&self) -> bool {
        self.top
    }

    /// Returns the `left` flag.
    #[must_use]
    pub const fn has_left(&self) -> bool {
        self.left
    }
}
91
/// Reference samples and availability flags used to intra-predict one
/// `width` x `height` block.
#[derive(Clone, Debug)]
pub struct IntraPredContext {
    /// Samples above the block; `reconstruct_neighbors` stores the frame sample at
    /// column `block_x + i` of row `block_y - 1` in entry `i`. Entries from `width`
    /// up to `2 * width` hold the top-right extension.
    top: TopSamples,
    /// Samples left of the block; `reconstruct_neighbors` stores the frame sample at
    /// row `block_y + i` of column `block_x - 1` in entry `i`. Entries from `height`
    /// up to `2 * height` hold the bottom-left extension.
    left: LeftSamples,
    /// The sample diagonally above-left of the block.
    top_left: u16,
    /// Block width in samples.
    width: usize,
    /// Block height in samples.
    height: usize,
    /// Sample bit depth; its midpoint is the fill value for unavailable samples.
    bit_depth: BitDepth,
    /// Which of the neighbour sets above are valid.
    availability: NeighborAvailability,
}
112
113impl IntraPredContext {
114 #[must_use]
116 pub fn new(width: usize, height: usize, bit_depth: BitDepth) -> Self {
117 let midpoint = bit_depth.midpoint();
118 Self {
119 top: [midpoint; MAX_NEIGHBOR_SAMPLES],
120 left: [midpoint; MAX_NEIGHBOR_SAMPLES],
121 top_left: midpoint,
122 width,
123 height,
124 bit_depth,
125 availability: NeighborAvailability::NONE,
126 }
127 }
128
129 #[must_use]
131 pub fn with_availability(
132 width: usize,
133 height: usize,
134 bit_depth: BitDepth,
135 availability: NeighborAvailability,
136 ) -> Self {
137 let mut ctx = Self::new(width, height, bit_depth);
138 ctx.availability = availability;
139 ctx
140 }
141
142 #[must_use]
144 pub const fn bit_depth(&self) -> BitDepth {
145 self.bit_depth
146 }
147
148 #[must_use]
150 pub const fn width(&self) -> usize {
151 self.width
152 }
153
154 #[must_use]
156 pub const fn height(&self) -> usize {
157 self.height
158 }
159
160 #[must_use]
162 pub const fn has_top(&self) -> bool {
163 self.availability.top
164 }
165
166 #[must_use]
168 pub const fn has_left(&self) -> bool {
169 self.availability.left
170 }
171
172 #[must_use]
174 pub const fn has_top_left(&self) -> bool {
175 self.availability.top_left
176 }
177
178 #[must_use]
180 pub const fn availability(&self) -> NeighborAvailability {
181 self.availability
182 }
183
184 pub fn set_availability(&mut self, has_top: bool, has_left: bool) {
186 self.availability.top = has_top;
187 self.availability.left = has_left;
188 self.availability.top_left = has_top && has_left;
189 }
190
191 pub fn set_full_availability(&mut self, availability: NeighborAvailability) {
193 self.availability = availability;
194 }
195
196 #[must_use]
198 pub fn top_samples(&self) -> &[u16] {
199 &self.top[..self.width.min(MAX_NEIGHBOR_SAMPLES)]
200 }
201
202 #[must_use]
204 pub fn left_samples(&self) -> &[u16] {
205 &self.left[..self.height.min(MAX_NEIGHBOR_SAMPLES)]
206 }
207
208 #[must_use]
210 pub fn extended_top_samples(&self) -> &[u16] {
211 let count = (self.width * 2).min(MAX_NEIGHBOR_SAMPLES);
212 &self.top[..count]
213 }
214
215 #[must_use]
217 pub fn extended_left_samples(&self) -> &[u16] {
218 let count = (self.height * 2).min(MAX_NEIGHBOR_SAMPLES);
219 &self.left[..count]
220 }
221
222 #[must_use]
224 pub const fn top_left_sample(&self) -> u16 {
225 self.top_left
226 }
227
228 pub fn set_top_sample(&mut self, idx: usize, value: u16) {
230 if idx < MAX_NEIGHBOR_SAMPLES {
231 self.top[idx] = value;
232 }
233 }
234
235 pub fn set_left_sample(&mut self, idx: usize, value: u16) {
237 if idx < MAX_NEIGHBOR_SAMPLES {
238 self.left[idx] = value;
239 }
240 }
241
242 pub fn set_top_left_sample(&mut self, value: u16) {
244 self.top_left = value;
245 }
246
247 pub fn set_top_samples(&mut self, samples: &[u16]) {
249 let count = samples.len().min(MAX_NEIGHBOR_SAMPLES);
250 self.top[..count].copy_from_slice(&samples[..count]);
251 }
252
253 pub fn set_left_samples(&mut self, samples: &[u16]) {
255 let count = samples.len().min(MAX_NEIGHBOR_SAMPLES);
256 self.left[..count].copy_from_slice(&samples[..count]);
257 }
258
259 pub fn filter_top_samples<F>(&mut self, filter: F)
261 where
262 F: FnOnce(&mut [u16]),
263 {
264 filter(&mut self.top);
265 }
266
267 pub fn filter_left_samples<F>(&mut self, filter: F)
269 where
270 F: FnOnce(&mut [u16]),
271 {
272 filter(&mut self.left);
273 }
274
275 #[allow(clippy::too_many_arguments)]
285 pub fn reconstruct_neighbors(
286 &mut self,
287 frame: &[u16],
288 frame_stride: usize,
289 block_x: usize,
290 block_y: usize,
291 frame_width: usize,
292 frame_height: usize,
293 ) {
294 let has_top = block_y > 0;
296 let has_left = block_x > 0;
297 let has_top_right = has_top && (block_x + self.width * 2 <= frame_width);
298 let has_bottom_left = has_left && (block_y + self.height * 2 <= frame_height);
299
300 self.availability = NeighborAvailability {
301 top: has_top,
302 left: has_left,
303 top_left: has_top && has_left,
304 top_right: has_top_right,
305 bottom_left: has_bottom_left,
306 };
307
308 if has_top {
310 let top_y = block_y - 1;
311 let top_row_start = top_y * frame_stride;
312
313 for x in 0..self.width {
315 let frame_x = block_x + x;
316 if frame_x < frame_width {
317 self.top[x] = frame[top_row_start + frame_x];
318 }
319 }
320
321 if has_top_right {
323 for x in self.width..(self.width * 2) {
324 let frame_x = block_x + x;
325 if frame_x < frame_width {
326 self.top[x] = frame[top_row_start + frame_x];
327 }
328 }
329 } else {
330 let last = self.top[self.width.saturating_sub(1)];
332 for x in self.width..(self.width * 2) {
333 self.top[x] = last;
334 }
335 }
336 }
337
338 if has_left {
340 let left_x = block_x - 1;
341
342 for y in 0..self.height {
344 let frame_y = block_y + y;
345 if frame_y < frame_height {
346 self.left[y] = frame[frame_y * frame_stride + left_x];
347 }
348 }
349
350 if has_bottom_left {
352 for y in self.height..(self.height * 2) {
353 let frame_y = block_y + y;
354 if frame_y < frame_height {
355 self.left[y] = frame[frame_y * frame_stride + left_x];
356 }
357 }
358 } else {
359 let last = self.left[self.height.saturating_sub(1)];
361 for y in self.height..(self.height * 2) {
362 self.left[y] = last;
363 }
364 }
365 }
366
367 if has_top && has_left {
369 self.top_left = frame[(block_y - 1) * frame_stride + (block_x - 1)];
370 } else if has_top {
371 self.top_left = self.top[0];
372 } else if has_left {
373 self.top_left = self.left[0];
374 }
375 }
376
377 #[must_use]
379 pub const fn is_at_frame_edge(&self) -> bool {
380 !self.availability.top || !self.availability.left
381 }
382
383 pub fn fill_unavailable(&mut self) {
385 let midpoint = self.bit_depth.midpoint();
386
387 if !self.availability.top {
388 self.top.fill(midpoint);
389 }
390
391 if !self.availability.left {
392 self.left.fill(midpoint);
393 }
394
395 if !self.availability.top_left {
396 self.top_left = midpoint;
397 }
398 }
399
400 #[must_use]
402 pub fn get_extended_sample(&self, x: i32, y: i32) -> u16 {
403 if x < 0 && y < 0 {
404 self.top_left
406 } else if y < 0 {
407 let idx = x as usize;
409 if idx < self.top.len() {
410 self.top[idx]
411 } else {
412 self.top[self.top.len() - 1]
413 }
414 } else if x < 0 {
415 let idx = y as usize;
417 if idx < self.left.len() {
418 self.left[idx]
419 } else {
420 self.left[self.left.len() - 1]
421 }
422 } else {
423 self.bit_depth.midpoint()
425 }
426 }
427}
428
429impl Default for IntraPredContext {
430 fn default() -> Self {
431 Self::new(4, 4, BitDepth::Bits8)
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh 4x4, 8-bit context, the starting point for most tests below.
    fn small_ctx() -> IntraPredContext {
        IntraPredContext::new(4, 4, BitDepth::Bits8)
    }

    #[test]
    fn test_context_creation() {
        let ctx = IntraPredContext::new(8, 8, BitDepth::Bits8);
        assert_eq!(ctx.width(), 8);
        assert_eq!(ctx.height(), 8);
        assert_eq!(ctx.bit_depth(), BitDepth::Bits8);

        // Every sample starts at the 8-bit midpoint.
        assert_eq!(ctx.top_left_sample(), 128);
        assert!(ctx.top_samples().iter().all(|&s| s == 128));
        assert!(ctx.left_samples().iter().all(|&s| s == 128));
    }

    #[test]
    fn test_availability() {
        let mut ctx = small_ctx();

        assert!(!ctx.has_top());
        assert!(!ctx.has_left());

        ctx.set_availability(true, true);
        assert!(ctx.has_top());
        assert!(ctx.has_left());
        assert!(ctx.has_top_left());
    }

    #[test]
    fn test_sample_setting() {
        let mut ctx = small_ctx();

        ctx.set_top_sample(0, 100);
        ctx.set_top_sample(1, 110);
        ctx.set_left_sample(0, 90);
        ctx.set_top_left_sample(95);

        let top = ctx.top_samples();
        assert_eq!(top[0], 100);
        assert_eq!(top[1], 110);
        assert_eq!(ctx.left_samples()[0], 90);
        assert_eq!(ctx.top_left_sample(), 95);
    }

    #[test]
    fn test_bulk_sample_setting() {
        let mut ctx = small_ctx();

        ctx.set_top_samples(&[100u16, 110, 120, 130]);
        ctx.set_left_samples(&[90u16, 100, 110, 120]);

        assert_eq!(ctx.top_samples()[..4], [100, 110, 120, 130]);
        assert_eq!(ctx.left_samples()[..4], [90, 100, 110, 120]);
    }

    #[test]
    fn test_reconstruct_neighbors() {
        let (frame_width, frame_height) = (16, 16);
        // Sample at (x, y) is (x + y) * 10.
        let frame: Vec<u16> = (0..frame_width * frame_height)
            .map(|i| (((i % frame_width) + (i / frame_width)) * 10) as u16)
            .collect();

        let mut ctx = small_ctx();
        ctx.reconstruct_neighbors(&frame, frame_width, 4, 4, frame_width, frame_height);

        assert!(ctx.has_top());
        assert!(ctx.has_left());
        assert!(ctx.has_top_left());

        // Row above: (4, 3) -> 70, (5, 3) -> 80.
        assert_eq!(ctx.top_samples()[0], 70);
        assert_eq!(ctx.top_samples()[1], 80);

        // Column to the left: (3, 4) -> 70, (3, 5) -> 80.
        assert_eq!(ctx.left_samples()[0], 70);
        assert_eq!(ctx.left_samples()[1], 80);

        // Corner: (3, 3) -> 60.
        assert_eq!(ctx.top_left_sample(), 60);
    }

    #[test]
    fn test_reconstruct_at_edge() {
        let (frame_width, frame_height) = (16, 16);
        let frame = vec![100u16; frame_width * frame_height];

        let mut ctx = small_ctx();
        ctx.reconstruct_neighbors(&frame, frame_width, 0, 0, frame_width, frame_height);

        assert!(!ctx.has_top());
        assert!(!ctx.has_left());
        assert!(!ctx.has_top_left());
    }

    #[test]
    fn test_extended_sample_access() {
        let mut ctx = small_ctx();

        ctx.set_top_samples(&[10, 20, 30, 40]);
        ctx.set_left_samples(&[15, 25, 35, 45]);
        ctx.set_top_left_sample(5);

        assert_eq!(ctx.get_extended_sample(-1, -1), 5);

        assert_eq!(ctx.get_extended_sample(0, -1), 10);
        assert_eq!(ctx.get_extended_sample(1, -1), 20);

        assert_eq!(ctx.get_extended_sample(-1, 0), 15);
        assert_eq!(ctx.get_extended_sample(-1, 1), 25);
    }

    #[test]
    fn test_neighbor_availability_constants() {
        let all = NeighborAvailability::ALL;
        assert!(all.top && all.left && all.top_left);
        assert!(all.any());

        let none = NeighborAvailability::NONE;
        assert!(!none.top && !none.left);
        assert!(!none.any());
    }

    #[test]
    fn test_fill_unavailable() {
        let mut ctx = small_ctx();
        ctx.set_top_samples(&[200, 200, 200, 200]);
        ctx.availability.top = false;

        ctx.fill_unavailable();

        assert!(ctx.top_samples().iter().all(|&s| s == 128));
    }

    #[test]
    fn test_bit_depth_10() {
        let ctx = IntraPredContext::new(4, 4, BitDepth::Bits10);
        assert_eq!(ctx.bit_depth(), BitDepth::Bits10);
        assert_eq!(ctx.top_left_sample(), 512);
    }

    #[test]
    fn test_extended_samples() {
        let mut ctx = small_ctx();

        for idx in 0..8usize {
            ctx.set_top_sample(idx, (idx * 10) as u16);
        }

        let extended = ctx.extended_top_samples();
        assert_eq!(extended.len(), 8);
        assert_eq!(extended.first().copied(), Some(0));
        assert_eq!(extended[7], 70);
    }
}