1#![forbid(unsafe_code)]
8#![allow(dead_code)]
9#![allow(clippy::similar_names)]
10#![allow(clippy::cast_precision_loss)]
11#![allow(clippy::cast_sign_loss)]
12#![allow(clippy::cast_possible_truncation)]
13#![allow(clippy::trivially_copy_pass_by_ref)]
14
15use std::ops::{Add, Neg, Sub};
16
/// Largest value a motion-vector component may take after clamping:
/// 16383 full pels expressed in 1/8-pel units.
pub const MV_MAX: i32 = 16383 << 3;

/// Smallest value a motion-vector component may take after clamping:
/// -16384 full pels expressed in 1/8-pel units.
pub const MV_MIN: i32 = -(16384 << 3);

/// Search radius used for both axes by `SearchRange::default`.
pub const DEFAULT_SEARCH_RANGE: i32 = 64;
25
/// Sub-pixel precision in which a motion-vector component can be expressed.
///
/// The discriminant of every variant equals its number of fractional bits.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[repr(u8)]
pub enum MvPrecision {
    /// Whole pixels: no fractional bits.
    FullPel = 0,
    /// Half pixels: one fractional bit.
    HalfPel = 1,
    /// Quarter pixels: two fractional bits. This is the default precision.
    #[default]
    QuarterPel = 2,
    /// Eighth pixels: three fractional bits.
    EighthPel = 3,
}

impl MvPrecision {
    /// Number of fractional bits carried by a value at this precision.
    #[must_use]
    pub const fn fractional_bits(self) -> u8 {
        // The discriminants are declared as the bit counts themselves.
        self as u8
    }

    /// Number of sub-pixel steps per full pixel (`2^fractional_bits`).
    #[must_use]
    pub const fn scale(self) -> i32 {
        1_i32 << self.fractional_bits()
    }

    /// Bit mask selecting the fractional part of a value at this precision.
    #[must_use]
    pub const fn frac_mask(self) -> i32 {
        (1_i32 << self.fractional_bits()) - 1
    }

    /// Re-expresses `value`, given at this precision, at the `target` precision.
    ///
    /// Moving to a finer precision shifts left; moving to a coarser one uses
    /// an arithmetic right shift, which rounds toward negative infinity.
    #[must_use]
    pub const fn convert(self, value: i32, target: Self) -> i32 {
        let from = self.fractional_bits() as i32;
        let to = target.fractional_bits() as i32;
        if to > from {
            value << (to - from)
        } else {
            value >> (from - to)
        }
    }
}
78
79#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
84pub struct MotionVector {
85 pub dx: i32,
87 pub dy: i32,
89}
90
91impl MotionVector {
92 #[must_use]
94 pub const fn zero() -> Self {
95 Self { dx: 0, dy: 0 }
96 }
97
98 #[must_use]
100 pub const fn new(dx: i32, dy: i32) -> Self {
101 Self { dx, dy }
102 }
103
104 #[must_use]
106 pub const fn from_full_pel(dx: i32, dy: i32) -> Self {
107 Self {
108 dx: dx << 3,
109 dy: dy << 3,
110 }
111 }
112
113 #[must_use]
115 pub const fn from_precision(dx: i32, dy: i32, precision: MvPrecision) -> Self {
116 let shift = 3 - precision.fractional_bits() as i32;
117 Self {
118 dx: dx << shift,
119 dy: dy << shift,
120 }
121 }
122
123 #[must_use]
125 pub const fn is_zero(&self) -> bool {
126 self.dx == 0 && self.dy == 0
127 }
128
129 #[must_use]
131 pub const fn full_pel_x(&self) -> i32 {
132 self.dx >> 3
133 }
134
135 #[must_use]
137 pub const fn full_pel_y(&self) -> i32 {
138 self.dy >> 3
139 }
140
141 #[must_use]
143 pub const fn frac_x(&self) -> i32 {
144 self.dx & 7
145 }
146
147 #[must_use]
149 pub const fn frac_y(&self) -> i32 {
150 self.dy & 7
151 }
152
153 #[must_use]
155 pub const fn half_pel_x(&self) -> i32 {
156 (self.dx >> 2) & 1
157 }
158
159 #[must_use]
161 pub const fn half_pel_y(&self) -> i32 {
162 (self.dy >> 2) & 1
163 }
164
165 #[must_use]
167 pub const fn quarter_pel_x(&self) -> i32 {
168 (self.dx >> 1) & 3
169 }
170
171 #[must_use]
173 pub const fn quarter_pel_y(&self) -> i32 {
174 (self.dy >> 1) & 3
175 }
176
177 #[must_use]
179 pub const fn to_precision(&self, precision: MvPrecision) -> Self {
180 let shift = 3 - precision.fractional_bits() as i32;
181 let mask = !((1 << shift) - 1);
182 Self {
183 dx: self.dx & mask,
184 dy: self.dy & mask,
185 }
186 }
187
188 #[must_use]
190 pub const fn round_to_precision(&self, precision: MvPrecision) -> Self {
191 let shift = 3 - precision.fractional_bits() as i32;
192 let round = 1 << (shift - 1);
193 if shift > 0 {
194 Self {
195 dx: ((self.dx + round) >> shift) << shift,
196 dy: ((self.dy + round) >> shift) << shift,
197 }
198 } else {
199 *self
200 }
201 }
202
203 #[must_use]
205 pub fn clamp(&self) -> Self {
206 Self {
207 dx: self.dx.clamp(MV_MIN, MV_MAX),
208 dy: self.dy.clamp(MV_MIN, MV_MAX),
209 }
210 }
211
212 #[must_use]
214 pub fn clamp_to_range(&self, range: &SearchRange) -> Self {
215 Self {
216 dx: self.dx.clamp(-range.horizontal << 3, range.horizontal << 3),
217 dy: self.dy.clamp(-range.vertical << 3, range.vertical << 3),
218 }
219 }
220
221 #[must_use]
223 pub const fn magnitude_squared(&self) -> i64 {
224 (self.dx as i64) * (self.dx as i64) + (self.dy as i64) * (self.dy as i64)
225 }
226
227 #[must_use]
229 pub const fn l1_norm(&self) -> i32 {
230 self.dx.abs() + self.dy.abs()
231 }
232
233 #[must_use]
235 pub fn linf_norm(&self) -> i32 {
236 self.dx.abs().max(self.dy.abs())
237 }
238
239 #[must_use]
241 #[allow(clippy::cast_possible_truncation)]
242 pub fn scale(&self, num: i32, den: i32) -> Self {
243 if den == 0 {
244 return *self;
245 }
246 Self {
247 dx: ((i64::from(self.dx) * i64::from(num)) / i64::from(den)) as i32,
248 dy: ((i64::from(self.dy) * i64::from(num)) / i64::from(den)) as i32,
249 }
250 }
251}
252
253impl Add for MotionVector {
254 type Output = Self;
255
256 fn add(self, other: Self) -> Self {
257 Self {
258 dx: self.dx.saturating_add(other.dx),
259 dy: self.dy.saturating_add(other.dy),
260 }
261 }
262}
263
264impl Sub for MotionVector {
265 type Output = Self;
266
267 fn sub(self, other: Self) -> Self {
268 Self {
269 dx: self.dx.saturating_sub(other.dx),
270 dy: self.dy.saturating_sub(other.dy),
271 }
272 }
273}
274
275impl Neg for MotionVector {
276 type Output = Self;
277
278 fn neg(self) -> Self {
279 Self {
280 dx: self.dx.saturating_neg(),
281 dy: self.dy.saturating_neg(),
282 }
283 }
284}
285
286#[derive(Clone, Copy, Debug, PartialEq, Eq)]
288pub struct SearchRange {
289 pub horizontal: i32,
291 pub vertical: i32,
293}
294
295impl Default for SearchRange {
296 fn default() -> Self {
297 Self::new(DEFAULT_SEARCH_RANGE, DEFAULT_SEARCH_RANGE)
298 }
299}
300
301impl SearchRange {
302 #[must_use]
304 pub const fn new(horizontal: i32, vertical: i32) -> Self {
305 Self {
306 horizontal,
307 vertical,
308 }
309 }
310
311 #[must_use]
313 pub const fn symmetric(range: i32) -> Self {
314 Self::new(range, range)
315 }
316
317 #[must_use]
319 pub const fn num_positions(&self) -> u64 {
320 let w = (2 * self.horizontal + 1) as u64;
321 let h = (2 * self.vertical + 1) as u64;
322 w * h
323 }
324
325 #[must_use]
327 pub const fn contains(&self, dx: i32, dy: i32) -> bool {
328 dx >= -self.horizontal
329 && dx <= self.horizontal
330 && dy >= -self.vertical
331 && dy <= self.vertical
332 }
333
334 #[must_use]
336 pub const fn scale(&self, factor: i32) -> Self {
337 Self {
338 horizontal: self.horizontal * factor,
339 vertical: self.vertical * factor,
340 }
341 }
342
343 #[must_use]
345 pub const fn reduce(&self, factor: i32) -> Self {
346 if factor == 0 {
347 *self
348 } else {
349 Self {
350 horizontal: self.horizontal / factor,
351 vertical: self.vertical / factor,
352 }
353 }
354 }
355}
356
357#[derive(Clone, Copy, Debug, PartialEq, Eq)]
359pub struct BlockMatch {
360 pub mv: MotionVector,
362 pub sad: u32,
364 pub cost: u32,
366}
367
368impl Default for BlockMatch {
369 fn default() -> Self {
370 Self::worst()
371 }
372}
373
374impl BlockMatch {
375 #[must_use]
377 pub const fn new(mv: MotionVector, sad: u32, cost: u32) -> Self {
378 Self { mv, sad, cost }
379 }
380
381 #[must_use]
383 pub const fn zero_mv(sad: u32) -> Self {
384 Self {
385 mv: MotionVector::zero(),
386 sad,
387 cost: sad,
388 }
389 }
390
391 #[must_use]
393 pub const fn worst() -> Self {
394 Self {
395 mv: MotionVector::zero(),
396 sad: u32::MAX,
397 cost: u32::MAX,
398 }
399 }
400
401 #[must_use]
403 pub const fn is_better_than(&self, other: &Self) -> bool {
404 self.cost < other.cost
405 }
406
407 pub fn update_if_better(&mut self, other: &Self) {
409 if other.is_better_than(self) {
410 *self = *other;
411 }
412 }
413}
414
415#[derive(Clone, Copy, Debug)]
417pub struct MvCost {
418 pub lambda: f32,
420 pub mv_weight: f32,
422 pub ref_mv: MotionVector,
424}
425
426impl Default for MvCost {
427 fn default() -> Self {
428 Self::new(1.0)
429 }
430}
431
432impl MvCost {
433 #[must_use]
435 pub const fn new(lambda: f32) -> Self {
436 Self {
437 lambda,
438 mv_weight: 1.0,
439 ref_mv: MotionVector::zero(),
440 }
441 }
442
443 #[must_use]
445 pub const fn with_ref_mv(lambda: f32, ref_mv: MotionVector) -> Self {
446 Self {
447 lambda,
448 mv_weight: 1.0,
449 ref_mv,
450 }
451 }
452
453 #[must_use]
455 pub fn estimate_bits(&self, mv: &MotionVector) -> f32 {
456 let diff = *mv - self.ref_mv;
457 let dx_bits = Self::component_bits(diff.dx);
458 let dy_bits = Self::component_bits(diff.dy);
459 (dx_bits + dy_bits) * self.mv_weight
460 }
461
462 #[must_use]
464 fn component_bits(value: i32) -> f32 {
465 if value == 0 {
466 return 1.0;
467 }
468 let abs_val = value.unsigned_abs();
469 let log2_approx = 32 - abs_val.leading_zeros();
471 (2 * log2_approx + 2) as f32
472 }
473
474 #[must_use]
476 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
477 pub fn rd_cost(&self, mv: &MotionVector, sad: u32) -> u32 {
478 let bits = self.estimate_bits(mv);
479 let rate_cost = (bits * self.lambda) as u32;
480 sad.saturating_add(rate_cost)
481 }
482
483 pub fn set_ref_mv(&mut self, ref_mv: MotionVector) {
485 self.ref_mv = ref_mv;
486 }
487}
488
/// Block dimensions, named `Block<width>x<height>`, from 4x4 up to 128x128.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default, Hash)]
#[repr(u8)]
pub enum BlockSize {
    /// 4 wide, 4 high.
    Block4x4 = 0,
    /// 4 wide, 8 high.
    Block4x8 = 1,
    /// 8 wide, 4 high.
    Block8x4 = 2,
    /// 8 wide, 8 high. This is the default size.
    #[default]
    Block8x8 = 3,
    /// 8 wide, 16 high.
    Block8x16 = 4,
    /// 16 wide, 8 high.
    Block16x8 = 5,
    /// 16 wide, 16 high.
    Block16x16 = 6,
    /// 16 wide, 32 high.
    Block16x32 = 7,
    /// 32 wide, 16 high.
    Block32x16 = 8,
    /// 32 wide, 32 high.
    Block32x32 = 9,
    /// 32 wide, 64 high.
    Block32x64 = 10,
    /// 64 wide, 32 high.
    Block64x32 = 11,
    /// 64 wide, 64 high.
    Block64x64 = 12,
    /// 64 wide, 128 high.
    Block64x128 = 13,
    /// 128 wide, 64 high.
    Block128x64 = 14,
    /// 128 wide, 128 high.
    Block128x128 = 15,
}

impl BlockSize {
    /// Single source of truth for the `(width, height)` of every variant.
    const fn dimensions(self) -> (usize, usize) {
        match self {
            Self::Block4x4 => (4, 4),
            Self::Block4x8 => (4, 8),
            Self::Block8x4 => (8, 4),
            Self::Block8x8 => (8, 8),
            Self::Block8x16 => (8, 16),
            Self::Block16x8 => (16, 8),
            Self::Block16x16 => (16, 16),
            Self::Block16x32 => (16, 32),
            Self::Block32x16 => (32, 16),
            Self::Block32x32 => (32, 32),
            Self::Block32x64 => (32, 64),
            Self::Block64x32 => (64, 32),
            Self::Block64x64 => (64, 64),
            Self::Block64x128 => (64, 128),
            Self::Block128x64 => (128, 64),
            Self::Block128x128 => (128, 128),
        }
    }

    /// Block width.
    #[must_use]
    pub const fn width(&self) -> usize {
        self.dimensions().0
    }

    /// Block height.
    #[must_use]
    pub const fn height(&self) -> usize {
        self.dimensions().1
    }

    /// Number of samples in the block (`width * height`).
    #[must_use]
    pub const fn num_pixels(&self) -> usize {
        let (w, h) = self.dimensions();
        w * h
    }

    /// Returns `true` when the width equals the height.
    #[must_use]
    pub const fn is_square(&self) -> bool {
        let (w, h) = self.dimensions();
        w == h
    }

    /// Base-2 logarithm of the width.
    #[must_use]
    pub const fn width_log2(&self) -> u8 {
        // Every width is a power of two, so its log2 is its count of
        // trailing zero bits.
        self.width().trailing_zeros() as u8
    }

    /// Base-2 logarithm of the height.
    #[must_use]
    pub const fn height_log2(&self) -> u8 {
        // Every height is a power of two, so its log2 is its count of
        // trailing zero bits.
        self.height().trailing_zeros() as u8
    }
}
602}
603
604#[cfg(test)]
605mod tests {
606 use super::*;
607
608 #[test]
609 fn test_mv_precision() {
610 assert_eq!(MvPrecision::FullPel.fractional_bits(), 0);
611 assert_eq!(MvPrecision::HalfPel.fractional_bits(), 1);
612 assert_eq!(MvPrecision::QuarterPel.fractional_bits(), 2);
613 assert_eq!(MvPrecision::EighthPel.fractional_bits(), 3);
614
615 assert_eq!(MvPrecision::FullPel.scale(), 1);
616 assert_eq!(MvPrecision::QuarterPel.scale(), 4);
617 assert_eq!(MvPrecision::EighthPel.scale(), 8);
618 }
619
620 #[test]
621 fn test_mv_precision_convert() {
622 assert_eq!(MvPrecision::FullPel.convert(2, MvPrecision::QuarterPel), 8);
624 assert_eq!(MvPrecision::QuarterPel.convert(8, MvPrecision::FullPel), 2);
626 }
627
628 #[test]
629 fn test_motion_vector_creation() {
630 let mv = MotionVector::new(16, -24);
631 assert_eq!(mv.dx, 16);
632 assert_eq!(mv.dy, -24);
633
634 let mv_fp = MotionVector::from_full_pel(2, -3);
635 assert_eq!(mv_fp.dx, 16);
636 assert_eq!(mv_fp.dy, -24);
637 }
638
639 #[test]
640 fn test_motion_vector_components() {
641 let mv = MotionVector::new(27, -19); assert_eq!(mv.full_pel_x(), 3);
644 assert_eq!(mv.full_pel_y(), -3); assert_eq!(mv.frac_x(), 3);
646 assert_eq!(mv.frac_y(), -19 & 7);
647 }
648
649 #[test]
650 fn test_motion_vector_zero() {
651 let mv = MotionVector::zero();
652 assert!(mv.is_zero());
653 assert_eq!(mv.magnitude_squared(), 0);
654 }
655
656 #[test]
657 fn test_motion_vector_arithmetic() {
658 let mv1 = MotionVector::new(10, 20);
659 let mv2 = MotionVector::new(5, -10);
660
661 let sum = mv1 + mv2;
662 assert_eq!(sum.dx, 15);
663 assert_eq!(sum.dy, 10);
664
665 let diff = mv1 - mv2;
666 assert_eq!(diff.dx, 5);
667 assert_eq!(diff.dy, 30);
668
669 let neg = -mv1;
670 assert_eq!(neg.dx, -10);
671 assert_eq!(neg.dy, -20);
672 }
673
674 #[test]
675 fn test_motion_vector_magnitude() {
676 let mv = MotionVector::new(3, 4);
677 assert_eq!(mv.magnitude_squared(), 25);
678 assert_eq!(mv.l1_norm(), 7);
679 assert_eq!(mv.linf_norm(), 4);
680 }
681
682 #[test]
683 fn test_motion_vector_precision_conversion() {
684 let mv = MotionVector::new(27, 19); let qpel = mv.to_precision(MvPrecision::QuarterPel);
687 assert_eq!(qpel.dx & 1, 0); assert_eq!(qpel.dy & 1, 0);
689
690 let fpel = mv.to_precision(MvPrecision::FullPel);
691 assert_eq!(fpel.dx & 7, 0); assert_eq!(fpel.dy & 7, 0);
693 }
694
695 #[test]
696 fn test_search_range() {
697 let range = SearchRange::symmetric(32);
698 assert_eq!(range.horizontal, 32);
699 assert_eq!(range.vertical, 32);
700
701 assert!(range.contains(0, 0));
702 assert!(range.contains(32, 32));
703 assert!(range.contains(-32, -32));
704 assert!(!range.contains(33, 0));
705 }
706
707 #[test]
708 fn test_search_range_positions() {
709 let range = SearchRange::symmetric(2);
710 assert_eq!(range.num_positions(), 25);
712 }
713
714 #[test]
715 fn test_block_match() {
716 let best = BlockMatch::new(MotionVector::new(8, 16), 100, 120);
717 let worst = BlockMatch::worst();
718
719 assert!(best.is_better_than(&worst));
720 assert!(!worst.is_better_than(&best));
721 }
722
723 #[test]
724 fn test_block_match_update() {
725 let mut current = BlockMatch::worst();
726 let better = BlockMatch::new(MotionVector::new(8, 16), 100, 120);
727
728 current.update_if_better(&better);
729 assert_eq!(current.sad, 100);
730 }
731
732 #[test]
733 fn test_mv_cost() {
734 let cost = MvCost::new(1.0);
735 let mv = MotionVector::new(16, 16);
736
737 let bits = cost.estimate_bits(&mv);
738 assert!(bits > 0.0);
739
740 let rd = cost.rd_cost(&mv, 100);
741 assert!(rd >= 100);
742 }
743
744 #[test]
745 fn test_mv_cost_with_ref() {
746 let ref_mv = MotionVector::new(16, 16);
747 let cost = MvCost::with_ref_mv(1.0, ref_mv);
748
749 let same_bits = cost.estimate_bits(&ref_mv);
751
752 let diff_mv = MotionVector::new(32, 32);
754 let diff_bits = cost.estimate_bits(&diff_mv);
755
756 assert!(same_bits < diff_bits);
757 }
758
759 #[test]
760 fn test_block_size() {
761 assert_eq!(BlockSize::Block8x8.width(), 8);
762 assert_eq!(BlockSize::Block8x8.height(), 8);
763 assert_eq!(BlockSize::Block8x8.num_pixels(), 64);
764 assert!(BlockSize::Block8x8.is_square());
765
766 assert_eq!(BlockSize::Block16x8.width(), 16);
767 assert_eq!(BlockSize::Block16x8.height(), 8);
768 assert!(!BlockSize::Block16x8.is_square());
769 }
770
771 #[test]
772 fn test_block_size_log2() {
773 assert_eq!(BlockSize::Block4x4.width_log2(), 2);
774 assert_eq!(BlockSize::Block8x8.width_log2(), 3);
775 assert_eq!(BlockSize::Block16x16.width_log2(), 4);
776 assert_eq!(BlockSize::Block64x64.width_log2(), 6);
777 }
778}