1#![forbid(unsafe_code)]
24#![allow(dead_code)]
25#![allow(clippy::cast_possible_truncation)]
26#![allow(clippy::cast_precision_loss)]
27#![allow(clippy::cast_sign_loss)]
28#![allow(clippy::similar_names)]
29#![allow(clippy::too_many_arguments)]
30#![allow(clippy::struct_excessive_bools)]
31
32use super::block::{BlockSize, InterMode, IntraMode, PartitionType};
33use super::transform::{TxSize, TxType};
34use crate::motion::{
35 BlockMatch, DiamondSearch, MotionSearch, MotionVector, SearchConfig, SearchContext,
36};
37
/// Upper bound on the number of intra modes evaluated per block; the mode list
/// built by `ModeDecision::get_intra_modes_to_test` is truncated to this length.
const MAX_INTRA_CANDIDATES: usize = 8;

/// Upper bound on inter candidates per block.
/// NOTE(review): not referenced by the code visible in this module section.
const MAX_INTER_CANDIDATES: usize = 4;

/// Multiplier applied to the running best intra cost when deciding whether
/// `ModeDecision::decide_intra_mode` may stop testing further modes.
const EARLY_TERM_THRESHOLD: f32 = 1.2;

/// Default value of `ModeDecisionConfig::split_threshold`.
const SPLIT_THRESHOLD_BASE: f32 = 0.95;
53
54#[derive(Clone, Debug)]
60pub struct ModeDecisionConfig {
61 pub rd_optimization: bool,
63 pub lambda: f32,
65 pub split_threshold: f32,
67 pub early_termination: bool,
69 pub max_partition_depth: u8,
71 pub tx_size_rdo: bool,
73 pub use_satd: bool,
75 pub preset_speed: u8,
77}
78
79impl Default for ModeDecisionConfig {
80 fn default() -> Self {
81 Self {
82 rd_optimization: true,
83 lambda: 1.0,
84 split_threshold: SPLIT_THRESHOLD_BASE,
85 early_termination: true,
86 max_partition_depth: 4,
87 tx_size_rdo: true,
88 use_satd: true,
89 preset_speed: 5, }
91 }
92}
93
94impl ModeDecisionConfig {
95 #[must_use]
97 pub fn from_qp(qp: u8) -> Self {
98 let lambda = compute_lambda_from_qp(qp);
99 Self {
100 lambda,
101 ..Default::default()
102 }
103 }
104
105 #[must_use]
107 pub fn fast() -> Self {
108 Self {
109 rd_optimization: false,
110 early_termination: true,
111 max_partition_depth: 3,
112 tx_size_rdo: false,
113 use_satd: false,
114 preset_speed: 8,
115 ..Default::default()
116 }
117 }
118
119 #[must_use]
121 pub fn slow() -> Self {
122 Self {
123 rd_optimization: true,
124 early_termination: false,
125 max_partition_depth: 4,
126 tx_size_rdo: true,
127 use_satd: true,
128 preset_speed: 2,
129 ..Default::default()
130 }
131 }
132}
133
134#[derive(Clone, Debug)]
140pub struct ModeCandidate {
141 pub block_size: BlockSize,
143 pub partition: PartitionType,
145 pub pred_mode: PredictionMode,
147 pub tx_size: TxSize,
149 pub tx_type: TxType,
151 pub cost: f32,
153 pub distortion: f32,
155 pub rate: u32,
157 pub mv: Option<MotionVector>,
159 pub skip: bool,
161}
162
163impl ModeCandidate {
164 #[must_use]
166 pub fn new(block_size: BlockSize, pred_mode: PredictionMode) -> Self {
167 Self {
168 block_size,
169 partition: PartitionType::None,
170 pred_mode,
171 tx_size: TxSize::Tx4x4,
172 tx_type: TxType::DctDct,
173 cost: f32::MAX,
174 distortion: 0.0,
175 rate: 0,
176 mv: None,
177 skip: false,
178 }
179 }
180
181 #[must_use]
183 pub const fn is_intra(&self) -> bool {
184 matches!(self.pred_mode, PredictionMode::Intra(_))
185 }
186
187 #[must_use]
189 pub const fn is_inter(&self) -> bool {
190 matches!(self.pred_mode, PredictionMode::Inter(_))
191 }
192}
193
194#[derive(Clone, Copy, Debug, PartialEq, Eq)]
196pub enum PredictionMode {
197 Intra(IntraMode),
199 Inter(InterMode),
201}
202
203#[derive(Clone, Debug)]
209pub struct ModeDecision {
210 config: ModeDecisionConfig,
212 best_cost: f32,
214}
215
216impl ModeDecision {
217 #[must_use]
219 pub fn new(config: ModeDecisionConfig) -> Self {
220 Self {
221 config,
222 best_cost: f32::MAX,
223 }
224 }
225
226 #[must_use]
228 pub fn with_defaults() -> Self {
229 Self::new(ModeDecisionConfig::default())
230 }
231
232 pub fn set_lambda(&mut self, lambda: f32) {
234 self.config.lambda = lambda;
235 }
236
237 pub fn reset(&mut self) {
239 self.best_cost = f32::MAX;
240 }
241
242 #[allow(clippy::unused_self)]
246 pub fn decide_partition(
247 &self,
248 _src: &[u8],
249 _src_stride: usize,
250 block_size: BlockSize,
251 _depth: u8,
252 ) -> PartitionType {
253 if block_size.width() >= 64 {
255 PartitionType::Split
256 } else if block_size.width() >= 32 {
257 PartitionType::None
259 } else {
260 PartitionType::None
261 }
262 }
263
264 pub fn decide_intra_mode(
268 &mut self,
269 src: &[u8],
270 src_stride: usize,
271 recon_left: &[u8],
272 recon_above: &[u8],
273 block_size: BlockSize,
274 ) -> ModeCandidate {
275 let mut best_candidate =
276 ModeCandidate::new(block_size, PredictionMode::Intra(IntraMode::DcPred));
277 let mut best_cost = f32::MAX;
278
279 let modes_to_test = self.get_intra_modes_to_test(block_size);
281
282 for mode in modes_to_test {
283 let candidate = self.evaluate_intra_mode(
284 src,
285 src_stride,
286 recon_left,
287 recon_above,
288 block_size,
289 mode,
290 );
291
292 if candidate.cost < best_cost {
293 best_cost = candidate.cost;
294 best_candidate = candidate;
295
296 if self.config.early_termination
298 && best_cost < self.best_cost * EARLY_TERM_THRESHOLD
299 {
300 break;
301 }
302 }
303 }
304
305 self.best_cost = self.best_cost.min(best_cost);
306 best_candidate
307 }
308
309 #[allow(clippy::too_many_arguments)]
313 pub fn decide_inter_mode(
314 &mut self,
315 src: &[u8],
316 src_stride: usize,
317 ref_frame: &[u8],
318 ref_stride: usize,
319 block_size: BlockSize,
320 x: u32,
321 y: u32,
322 frame_width: u32,
323 frame_height: u32,
324 ) -> ModeCandidate {
325 let mv = self.search_motion(
327 src,
328 src_stride,
329 ref_frame,
330 ref_stride,
331 block_size,
332 x,
333 y,
334 frame_width,
335 frame_height,
336 );
337
338 let mut candidate = ModeCandidate::new(block_size, PredictionMode::Inter(InterMode::NewMv));
340 candidate.mv = Some(mv.mv);
341
342 let distortion = self
344 .compute_inter_distortion(src, src_stride, ref_frame, ref_stride, block_size, &mv.mv);
345
346 let rate = self.estimate_inter_rate(block_size, &mv.mv);
348
349 candidate.distortion = distortion as f32;
350 candidate.rate = rate;
351 candidate.cost = distortion as f32 + self.config.lambda * rate as f32;
352
353 candidate
354 }
355
356 pub fn compute_rd_cost(&self, candidate: &ModeCandidate) -> f32 {
358 if self.config.rd_optimization {
359 candidate.distortion + self.config.lambda * candidate.rate as f32
360 } else {
361 candidate.distortion
363 }
364 }
365
366 fn get_intra_modes_to_test(&self, block_size: BlockSize) -> Vec<IntraMode> {
372 let mut modes = vec![
373 IntraMode::DcPred,
374 IntraMode::VPred,
375 IntraMode::HPred,
376 IntraMode::PaethPred,
377 ];
378
379 if self.config.preset_speed <= 5 {
380 modes.extend_from_slice(&[
382 IntraMode::D45Pred,
383 IntraMode::D135Pred,
384 IntraMode::SmoothPred,
385 ]);
386 }
387
388 if self.config.preset_speed <= 3 && block_size.width() >= 8 {
389 modes.extend_from_slice(&[
391 IntraMode::D67Pred,
392 IntraMode::D113Pred,
393 IntraMode::D157Pred,
394 IntraMode::D203Pred,
395 IntraMode::SmoothVPred,
396 IntraMode::SmoothHPred,
397 ]);
398 }
399
400 modes.truncate(MAX_INTRA_CANDIDATES);
401 modes
402 }
403
404 fn evaluate_intra_mode(
406 &self,
407 src: &[u8],
408 src_stride: usize,
409 _recon_left: &[u8],
410 _recon_above: &[u8],
411 block_size: BlockSize,
412 mode: IntraMode,
413 ) -> ModeCandidate {
414 let mut candidate = ModeCandidate::new(block_size, PredictionMode::Intra(mode));
415
416 let pred = self.generate_intra_prediction(block_size, mode);
418
419 let distortion = self.compute_sse(
421 src,
422 src_stride,
423 &pred,
424 block_size.width() as usize,
425 block_size,
426 );
427
428 let rate = self.estimate_intra_rate(block_size, mode);
430
431 candidate.distortion = distortion as f32;
432 candidate.rate = rate;
433 candidate.cost = distortion as f32 + self.config.lambda * rate as f32;
434 candidate.tx_size = self.select_tx_size(block_size);
435
436 candidate
437 }
438
439 fn generate_intra_prediction(&self, block_size: BlockSize, mode: IntraMode) -> Vec<u8> {
441 let size = (block_size.width() * block_size.height()) as usize;
442 let pred_value = match mode {
443 IntraMode::DcPred => 128,
444 IntraMode::VPred => 128,
445 IntraMode::HPred => 128,
446 _ => 128,
447 };
448 vec![pred_value; size]
449 }
450
451 fn compute_sse(
453 &self,
454 src: &[u8],
455 src_stride: usize,
456 pred: &[u8],
457 pred_stride: usize,
458 block_size: BlockSize,
459 ) -> u64 {
460 let w = block_size.width() as usize;
461 let h = block_size.height() as usize;
462 let mut sse = 0u64;
463
464 for y in 0..h {
465 for x in 0..w {
466 if y * src_stride + x < src.len() && y * pred_stride + x < pred.len() {
467 let diff =
468 i32::from(src[y * src_stride + x]) - i32::from(pred[y * pred_stride + x]);
469 sse += (diff * diff) as u64;
470 }
471 }
472 }
473
474 sse
475 }
476
477 fn estimate_intra_rate(&self, block_size: BlockSize, _mode: IntraMode) -> u32 {
479 let base_rate = 8; let coeff_rate = (block_size.area() / 4) as u32; base_rate + coeff_rate
483 }
484
485 fn estimate_inter_rate(&self, block_size: BlockSize, mv: &MotionVector) -> u32 {
487 let mv_rate = (mv.dx.abs() + mv.dy.abs()) as u32 / 4 + 4;
489
490 let coeff_rate = (block_size.area() / 8) as u32;
492
493 mv_rate + coeff_rate + 4 }
495
496 #[allow(clippy::too_many_arguments)]
498 fn search_motion(
499 &self,
500 src: &[u8],
501 src_stride: usize,
502 ref_frame: &[u8],
503 ref_stride: usize,
504 block_size: BlockSize,
505 x: u32,
506 y: u32,
507 frame_width: u32,
508 frame_height: u32,
509 ) -> BlockMatch {
510 let ctx = SearchContext::new(
512 src,
513 src_stride,
514 ref_frame,
515 ref_stride,
516 crate::motion::BlockSize::Block8x8, x as usize,
518 y as usize,
519 frame_width as usize,
520 frame_height as usize,
521 );
522
523 let search_range = if self.config.preset_speed >= 8 {
525 16 } else if self.config.preset_speed >= 5 {
527 32 } else {
529 64 };
531
532 let search_config =
533 SearchConfig::default().range(crate::motion::SearchRange::symmetric(search_range));
534
535 let searcher = DiamondSearch::new();
537 searcher.search(&ctx, &search_config)
538 }
539
540 fn compute_inter_distortion(
542 &self,
543 src: &[u8],
544 src_stride: usize,
545 ref_frame: &[u8],
546 ref_stride: usize,
547 block_size: BlockSize,
548 mv: &MotionVector,
549 ) -> u64 {
550 let w = block_size.width() as usize;
551 let h = block_size.height() as usize;
552
553 let ref_x = mv.full_pel_x() as usize;
554 let ref_y = mv.full_pel_y() as usize;
555
556 let mut sad = 0u64;
557
558 for y in 0..h {
559 for x in 0..w {
560 let src_idx = y * src_stride + x;
561 let ref_idx = (y + ref_y) * ref_stride + (x + ref_x);
562
563 if src_idx < src.len() && ref_idx < ref_frame.len() {
564 let diff = i32::from(src[src_idx]).abs_diff(i32::from(ref_frame[ref_idx]));
565 sad += u64::from(diff);
566 }
567 }
568 }
569
570 if self.config.use_satd {
571 (sad * 12) / 10
573 } else {
574 sad
575 }
576 }
577
578 fn select_tx_size(&self, block_size: BlockSize) -> TxSize {
580 if self.config.tx_size_rdo {
581 block_size.max_tx_size()
583 } else {
584 block_size.max_tx_size()
586 }
587 }
588}
589
/// Maps a quantizer index to the Lagrangian multiplier used in RD costs.
///
/// `lambda = 0.85 * 2^((qp - 12) / 3)`: lambda equals 0.85 at `qp == 12` and
/// doubles every three QP steps. [`compute_qp_from_lambda`] maps it back.
#[must_use]
pub fn compute_lambda_from_qp(qp: u8) -> f32 {
    const SCALE: f32 = 0.85;
    const QP_OFFSET: f32 = 12.0;
    const QP_PER_DOUBLING: f32 = 3.0;

    let exponent = (f32::from(qp) - QP_OFFSET) / QP_PER_DOUBLING;
    SCALE * 2.0_f32.powf(exponent)
}
602
/// Maps a Lagrangian multiplier back to the nearest quantizer index.
///
/// Inverts `lambda = 0.85 * 2^((qp - 12) / 3)`, i.e. computes
/// `qp = 12 + 3 * log2(lambda / 0.85)`. The result is rounded to the nearest
/// integer and clamped to `0..=255`. A zero, negative or NaN `lambda` yields 0.
#[must_use]
pub fn compute_qp_from_lambda(lambda: f32) -> u8 {
    let qp = 12.0 + 3.0 * (lambda / 0.85).log2();
    // Round rather than truncate: float error in the log can leave an exact QP
    // at e.g. 27.9999, which a bare `as u8` would floor to 27. A NaN survives
    // `clamp` and the saturating `as` cast turns it into 0.
    qp.round().clamp(0.0, 255.0) as u8
}
609
610#[derive(Clone, Debug)]
616pub struct PartitionDecision {
617 pub partition: PartitionType,
619 pub cost: f32,
621 pub recurse: bool,
623}
624
625impl PartitionDecision {
626 #[must_use]
628 pub const fn no_split(cost: f32) -> Self {
629 Self {
630 partition: PartitionType::None,
631 cost,
632 recurse: false,
633 }
634 }
635
636 #[must_use]
638 pub const fn split(cost: f32) -> Self {
639 Self {
640 partition: PartitionType::Split,
641 cost,
642 recurse: true,
643 }
644 }
645}
646
647#[derive(Clone, Debug, Default)]
653pub struct RateEstimator {
654 pub intra_mode_bits: [f32; 13],
656 pub inter_mode_bits: [f32; 4],
658 pub partition_bits: [f32; 10],
660}
661
662impl RateEstimator {
663 #[must_use]
665 pub fn new() -> Self {
666 Self {
667 intra_mode_bits: [
668 3.0, 3.5, 3.5, 4.0, 4.0, 4.5, 4.5, 4.5, 4.5, 4.0, 4.5, 4.5, 4.0,
669 ],
670 inter_mode_bits: [2.0, 2.5, 3.0, 3.5],
671 partition_bits: [2.0, 3.0, 3.0, 3.5, 4.5, 4.5, 4.5, 4.5, 4.0, 4.0],
672 }
673 }
674
675 #[must_use]
677 pub fn intra_mode_rate(&self, mode: IntraMode) -> f32 {
678 self.intra_mode_bits[mode as usize]
679 }
680
681 #[must_use]
683 pub fn inter_mode_rate(&self, mode: InterMode) -> f32 {
684 self.inter_mode_bits[mode as usize]
685 }
686
687 #[must_use]
689 pub fn partition_rate(&self, partition: PartitionType) -> f32 {
690 self.partition_bits[partition as usize]
691 }
692}
693
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ModeDecision` from the default configuration.
    fn default_md() -> ModeDecision {
        ModeDecision::new(ModeDecisionConfig::default())
    }

    /// Builds an unscored 8x8 DC intra candidate.
    fn dc_8x8() -> ModeCandidate {
        ModeCandidate::new(
            BlockSize::Block8x8,
            PredictionMode::Intra(IntraMode::DcPred),
        )
    }

    #[test]
    fn test_mode_decision_config_default() {
        let cfg = ModeDecisionConfig::default();
        assert_eq!(cfg.max_partition_depth, 4);
        assert!(cfg.early_termination);
        assert!(cfg.rd_optimization);
    }

    #[test]
    fn test_mode_decision_config_from_qp() {
        let lambda = ModeDecisionConfig::from_qp(28).lambda;
        assert!(lambda > 0.0);
        assert!(lambda < 100.0);
    }

    #[test]
    fn test_mode_decision_config_presets() {
        let (fast, slow) = (ModeDecisionConfig::fast(), ModeDecisionConfig::slow());

        assert!(!fast.rd_optimization);
        assert_eq!(fast.preset_speed, 8);

        assert!(slow.rd_optimization);
        assert_eq!(slow.preset_speed, 2);
    }

    #[test]
    fn test_mode_candidate_creation() {
        let candidate = dc_8x8();
        assert_eq!(candidate.block_size, BlockSize::Block8x8);
        assert!(candidate.is_intra());
        assert!(!candidate.is_inter());
        assert_eq!(candidate.cost, f32::MAX);
    }

    #[test]
    fn test_mode_decision_creation() {
        assert_eq!(default_md().best_cost, f32::MAX);
    }

    #[test]
    fn test_lambda_computation() {
        let lambda = compute_lambda_from_qp(28);
        assert!(lambda > 0.0);
        assert!(lambda < 100.0);

        // The inverse mapping should land within one step of the original QP.
        let qp = compute_qp_from_lambda(lambda);
        assert!((i32::from(qp) - 28).abs() <= 1);
    }

    #[test]
    fn test_lambda_qp_roundtrip() {
        for qp in [0u8, 10, 20, 30, 40, 50, 63] {
            let back = compute_qp_from_lambda(compute_lambda_from_qp(qp));
            assert!(
                (i32::from(back) - i32::from(qp)).abs() <= 2,
                "qp {} came back as {}",
                qp,
                back
            );
        }
    }

    #[test]
    fn test_partition_decision() {
        let whole = PartitionDecision::no_split(100.0);
        assert_eq!(whole.partition, PartitionType::None);
        assert!(!whole.recurse);

        let divided = PartitionDecision::split(150.0);
        assert_eq!(divided.partition, PartitionType::Split);
        assert!(divided.recurse);
    }

    #[test]
    fn test_rate_estimator() {
        let estimator = RateEstimator::new();

        let dc = estimator.intra_mode_rate(IntraMode::DcPred);
        assert!(dc > 0.0);
        assert!(dc < 10.0);

        assert!(estimator.inter_mode_rate(InterMode::NewMv) > 0.0);
    }

    #[test]
    fn test_intra_modes_to_test_fast() {
        let md = ModeDecision::new(ModeDecisionConfig::fast());
        let modes = md.get_intra_modes_to_test(BlockSize::Block8x8);

        assert!(!modes.is_empty());
        assert!(modes.len() <= MAX_INTRA_CANDIDATES);
        assert!(modes.contains(&IntraMode::DcPred));
    }

    #[test]
    fn test_intra_modes_to_test_slow() {
        let md = ModeDecision::new(ModeDecisionConfig::slow());
        let modes = md.get_intra_modes_to_test(BlockSize::Block16x16);

        assert!(modes.len() > 4);
        for required in [IntraMode::DcPred, IntraMode::D45Pred] {
            assert!(modes.contains(&required));
        }
    }

    #[test]
    fn test_intra_prediction_generation() {
        let pred = default_md().generate_intra_prediction(BlockSize::Block8x8, IntraMode::DcPred);
        assert_eq!(pred.len(), 64);
        assert!(pred.iter().all(|&p| p == 128));
    }

    #[test]
    fn test_sse_computation() {
        let md = default_md();
        let src = [100u8; 64];

        assert_eq!(md.compute_sse(&src, 8, &[100u8; 64], 8, BlockSize::Block8x8), 0);

        // 64 pixels, each off by 10 -> 64 * 10^2.
        let sse = md.compute_sse(&src, 8, &[110u8; 64], 8, BlockSize::Block8x8);
        assert!(sse > 0);
        assert_eq!(sse, 6400);
    }

    #[test]
    fn test_decide_partition_simple() {
        let md = default_md();
        let src = vec![128u8; 64 * 64];

        assert_eq!(
            md.decide_partition(&src, 64, BlockSize::Block64x64, 0),
            PartitionType::Split
        );
        assert_eq!(
            md.decide_partition(&src, 8, BlockSize::Block8x8, 2),
            PartitionType::None
        );
    }

    #[test]
    fn test_tx_size_selection() {
        let md = default_md();
        assert_eq!(md.select_tx_size(BlockSize::Block8x8), TxSize::Tx8x8);
        assert_eq!(md.select_tx_size(BlockSize::Block16x16), TxSize::Tx16x16);
    }

    #[test]
    fn test_rate_estimation() {
        let md = default_md();

        let intra = md.estimate_intra_rate(BlockSize::Block8x8, IntraMode::DcPred);
        assert!(intra > 0);
        assert!(intra < 1000);

        let inter = md.estimate_inter_rate(BlockSize::Block8x8, &MotionVector::new(4, 4));
        assert!(inter > 0);
    }

    #[test]
    fn test_rd_cost_computation() {
        let md = ModeDecision::new(ModeDecisionConfig::from_qp(28));
        let candidate = ModeCandidate {
            distortion: 1000.0,
            rate: 100,
            ..dc_8x8()
        };

        let cost = md.compute_rd_cost(&candidate);
        assert!(cost > candidate.distortion);
        assert!(cost < 10000.0);
    }

    #[test]
    fn test_prediction_mode() {
        assert!(matches!(
            PredictionMode::Intra(IntraMode::DcPred),
            PredictionMode::Intra(_)
        ));
        assert!(matches!(
            PredictionMode::Inter(InterMode::NewMv),
            PredictionMode::Inter(_)
        ));
    }
}