Skip to main content

oximedia_codec/av1/
mode_decision.rs

1//! AV1 mode decision with rate-distortion optimization.
2//!
3//! This module implements comprehensive mode decision for AV1 encoding:
4//!
5//! - Partition decision (split vs non-split)
//! - Intra mode RDO over candidates drawn from AV1's 13 intra prediction modes
7//! - Inter mode RDO with motion estimation
8//! - Transform size selection
9//! - Rate-distortion cost computation
10//!
11//! # Rate-Distortion Optimization
12//!
13//! The encoder selects the best coding mode by minimizing:
14//! ```text
15//! Cost = Distortion + λ * Rate
16//! ```
17//!
18//! Where:
19//! - Distortion = SSE (sum of squared errors) or SATD
20//! - λ = lagrangian multiplier derived from QP
21//! - Rate = estimated bits for mode/MV/residual
22
23#![forbid(unsafe_code)]
24#![allow(dead_code)]
25#![allow(clippy::cast_possible_truncation)]
26#![allow(clippy::cast_precision_loss)]
27#![allow(clippy::cast_sign_loss)]
28#![allow(clippy::similar_names)]
29#![allow(clippy::too_many_arguments)]
30#![allow(clippy::struct_excessive_bools)]
31
32use super::block::{BlockSize, InterMode, IntraMode, PartitionType};
33use super::transform::{TxSize, TxType};
34use crate::motion::{
35    BlockMatch, DiamondSearch, MotionSearch, MotionVector, SearchConfig, SearchContext,
36};
37
38// =============================================================================
39// Constants
40// =============================================================================
41
/// Maximum number of intra mode candidates to test.
///
/// The candidate list built by `ModeDecision::get_intra_modes_to_test` is
/// truncated to this length.
const MAX_INTRA_CANDIDATES: usize = 8;

/// Maximum number of inter mode candidates.
///
/// NOTE(review): not referenced anywhere in the visible part of this file.
const MAX_INTER_CANDIDATES: usize = 4;

/// Threshold for early termination in mode decision.
///
/// `ModeDecision::decide_intra_mode` stops testing further modes once the best
/// candidate cost is below the engine's running best cost times this factor.
const EARLY_TERM_THRESHOLD: f32 = 1.2;

/// Partition split threshold multiplier.
///
/// Used as the default value of `ModeDecisionConfig::split_threshold`.
const SPLIT_THRESHOLD_BASE: f32 = 0.95;
53
54// =============================================================================
55// Mode Decision Configuration
56// =============================================================================
57
/// Tunable settings that steer the mode decision search.
///
/// Build one with [`Default`], or with the `from_qp`, `fast` and `slow`
/// constructors, then hand it to the mode decision engine.
#[derive(Debug, Clone)]
pub struct ModeDecisionConfig {
    /// Whether the rate term contributes to the RD cost (`distortion + lambda * rate`).
    pub rd_optimization: bool,
    /// Lagrangian multiplier `lambda` weighting rate against distortion.
    pub lambda: f32,
    /// Threshold applied to the split-vs-no-split partition decision.
    pub split_threshold: f32,
    /// Whether the mode search may stop early once a good enough candidate is found.
    pub early_termination: bool,
    /// Deepest partition level the encoder may recurse to.
    pub max_partition_depth: u8,
    /// Whether the transform size is chosen by RDO.
    pub tx_size_rdo: bool,
    /// Prefer SATD over SAD as the motion estimation metric.
    pub use_satd: bool,
    /// Encoder speed preset; lower values search more thoroughly.
    pub preset_speed: u8,
}
78
79impl Default for ModeDecisionConfig {
80    fn default() -> Self {
81        Self {
82            rd_optimization: true,
83            lambda: 1.0,
84            split_threshold: SPLIT_THRESHOLD_BASE,
85            early_termination: true,
86            max_partition_depth: 4,
87            tx_size_rdo: true,
88            use_satd: true,
89            preset_speed: 5, // Medium
90        }
91    }
92}
93
94impl ModeDecisionConfig {
95    /// Create config from QP value.
96    #[must_use]
97    pub fn from_qp(qp: u8) -> Self {
98        let lambda = compute_lambda_from_qp(qp);
99        Self {
100            lambda,
101            ..Default::default()
102        }
103    }
104
105    /// Create config for fast preset.
106    #[must_use]
107    pub fn fast() -> Self {
108        Self {
109            rd_optimization: false,
110            early_termination: true,
111            max_partition_depth: 3,
112            tx_size_rdo: false,
113            use_satd: false,
114            preset_speed: 8,
115            ..Default::default()
116        }
117    }
118
119    /// Create config for slow preset.
120    #[must_use]
121    pub fn slow() -> Self {
122        Self {
123            rd_optimization: true,
124            early_termination: false,
125            max_partition_depth: 4,
126            tx_size_rdo: true,
127            use_satd: true,
128            preset_speed: 2,
129            ..Default::default()
130        }
131    }
132}
133
134// =============================================================================
135// Mode Candidate
136// =============================================================================
137
138/// Mode decision candidate.
139#[derive(Clone, Debug)]
140pub struct ModeCandidate {
141    /// Block size for this candidate.
142    pub block_size: BlockSize,
143    /// Partition type.
144    pub partition: PartitionType,
145    /// Prediction mode (intra or inter).
146    pub pred_mode: PredictionMode,
147    /// Transform size.
148    pub tx_size: TxSize,
149    /// Transform type.
150    pub tx_type: TxType,
151    /// Rate-distortion cost.
152    pub cost: f32,
153    /// Distortion (SSE or SATD).
154    pub distortion: f32,
155    /// Rate in bits.
156    pub rate: u32,
157    /// Motion vector (for inter modes).
158    pub mv: Option<MotionVector>,
159    /// Skip residual flag.
160    pub skip: bool,
161}
162
163impl ModeCandidate {
164    /// Create a new mode candidate.
165    #[must_use]
166    pub fn new(block_size: BlockSize, pred_mode: PredictionMode) -> Self {
167        Self {
168            block_size,
169            partition: PartitionType::None,
170            pred_mode,
171            tx_size: TxSize::Tx4x4,
172            tx_type: TxType::DctDct,
173            cost: f32::MAX,
174            distortion: 0.0,
175            rate: 0,
176            mv: None,
177            skip: false,
178        }
179    }
180
181    /// Check if this is an intra candidate.
182    #[must_use]
183    pub const fn is_intra(&self) -> bool {
184        matches!(self.pred_mode, PredictionMode::Intra(_))
185    }
186
187    /// Check if this is an inter candidate.
188    #[must_use]
189    pub const fn is_inter(&self) -> bool {
190        matches!(self.pred_mode, PredictionMode::Inter(_))
191    }
192}
193
194/// Prediction mode (intra or inter).
195#[derive(Clone, Copy, Debug, PartialEq, Eq)]
196pub enum PredictionMode {
197    /// Intra prediction.
198    Intra(IntraMode),
199    /// Inter prediction.
200    Inter(InterMode),
201}
202
203// =============================================================================
204// Mode Decision Engine
205// =============================================================================
206
207/// Mode decision engine for AV1 encoding.
208#[derive(Clone, Debug)]
209pub struct ModeDecision {
210    /// Configuration.
211    config: ModeDecisionConfig,
212    /// Best cost found so far (for early termination).
213    best_cost: f32,
214}
215
216impl ModeDecision {
217    /// Create a new mode decision engine.
218    #[must_use]
219    pub fn new(config: ModeDecisionConfig) -> Self {
220        Self {
221            config,
222            best_cost: f32::MAX,
223        }
224    }
225
226    /// Create with default configuration.
227    #[must_use]
228    pub fn with_defaults() -> Self {
229        Self::new(ModeDecisionConfig::default())
230    }
231
232    /// Set lambda value.
233    pub fn set_lambda(&mut self, lambda: f32) {
234        self.config.lambda = lambda;
235    }
236
237    /// Reset for new frame.
238    pub fn reset(&mut self) {
239        self.best_cost = f32::MAX;
240    }
241
242    /// Decide partition for a block.
243    ///
244    /// Returns the best partition type based on RD cost.
245    #[allow(clippy::unused_self)]
246    pub fn decide_partition(
247        &self,
248        _src: &[u8],
249        _src_stride: usize,
250        block_size: BlockSize,
251        _depth: u8,
252    ) -> PartitionType {
253        // Simple heuristic: larger blocks prefer splitting
254        if block_size.width() >= 64 {
255            PartitionType::Split
256        } else if block_size.width() >= 32 {
257            // Consider split based on content complexity
258            PartitionType::None
259        } else {
260            PartitionType::None
261        }
262    }
263
264    /// Decide best intra mode for a block.
265    ///
266    /// Tests multiple intra modes and returns the one with lowest RD cost.
267    pub fn decide_intra_mode(
268        &mut self,
269        src: &[u8],
270        src_stride: usize,
271        recon_left: &[u8],
272        recon_above: &[u8],
273        block_size: BlockSize,
274    ) -> ModeCandidate {
275        let mut best_candidate =
276            ModeCandidate::new(block_size, PredictionMode::Intra(IntraMode::DcPred));
277        let mut best_cost = f32::MAX;
278
279        // Test common intra modes
280        let modes_to_test = self.get_intra_modes_to_test(block_size);
281
282        for mode in modes_to_test {
283            let candidate = self.evaluate_intra_mode(
284                src,
285                src_stride,
286                recon_left,
287                recon_above,
288                block_size,
289                mode,
290            );
291
292            if candidate.cost < best_cost {
293                best_cost = candidate.cost;
294                best_candidate = candidate;
295
296                // Early termination
297                if self.config.early_termination
298                    && best_cost < self.best_cost * EARLY_TERM_THRESHOLD
299                {
300                    break;
301                }
302            }
303        }
304
305        self.best_cost = self.best_cost.min(best_cost);
306        best_candidate
307    }
308
309    /// Decide best inter mode for a block.
310    ///
311    /// Performs motion search and evaluates inter prediction modes.
312    #[allow(clippy::too_many_arguments)]
313    pub fn decide_inter_mode(
314        &mut self,
315        src: &[u8],
316        src_stride: usize,
317        ref_frame: &[u8],
318        ref_stride: usize,
319        block_size: BlockSize,
320        x: u32,
321        y: u32,
322        frame_width: u32,
323        frame_height: u32,
324    ) -> ModeCandidate {
325        // Perform motion search
326        let mv = self.search_motion(
327            src,
328            src_stride,
329            ref_frame,
330            ref_stride,
331            block_size,
332            x,
333            y,
334            frame_width,
335            frame_height,
336        );
337
338        // Evaluate inter mode with found MV
339        let mut candidate = ModeCandidate::new(block_size, PredictionMode::Inter(InterMode::NewMv));
340        candidate.mv = Some(mv.mv);
341
342        // Compute distortion
343        let distortion = self
344            .compute_inter_distortion(src, src_stride, ref_frame, ref_stride, block_size, &mv.mv);
345
346        // Estimate rate (simplified)
347        let rate = self.estimate_inter_rate(block_size, &mv.mv);
348
349        candidate.distortion = distortion as f32;
350        candidate.rate = rate;
351        candidate.cost = distortion as f32 + self.config.lambda * rate as f32;
352
353        candidate
354    }
355
356    /// Compute RD cost for a mode candidate.
357    pub fn compute_rd_cost(&self, candidate: &ModeCandidate) -> f32 {
358        if self.config.rd_optimization {
359            candidate.distortion + self.config.lambda * candidate.rate as f32
360        } else {
361            // Fast mode: just use distortion
362            candidate.distortion
363        }
364    }
365
366    // =========================================================================
367    // Internal Helper Methods
368    // =========================================================================
369
370    /// Get list of intra modes to test based on block size and preset.
371    fn get_intra_modes_to_test(&self, block_size: BlockSize) -> Vec<IntraMode> {
372        let mut modes = vec![
373            IntraMode::DcPred,
374            IntraMode::VPred,
375            IntraMode::HPred,
376            IntraMode::PaethPred,
377        ];
378
379        if self.config.preset_speed <= 5 {
380            // Medium and slower: test more modes
381            modes.extend_from_slice(&[
382                IntraMode::D45Pred,
383                IntraMode::D135Pred,
384                IntraMode::SmoothPred,
385            ]);
386        }
387
388        if self.config.preset_speed <= 3 && block_size.width() >= 8 {
389            // Slow: test all directional modes
390            modes.extend_from_slice(&[
391                IntraMode::D67Pred,
392                IntraMode::D113Pred,
393                IntraMode::D157Pred,
394                IntraMode::D203Pred,
395                IntraMode::SmoothVPred,
396                IntraMode::SmoothHPred,
397            ]);
398        }
399
400        modes.truncate(MAX_INTRA_CANDIDATES);
401        modes
402    }
403
404    /// Evaluate a specific intra mode.
405    fn evaluate_intra_mode(
406        &self,
407        src: &[u8],
408        src_stride: usize,
409        _recon_left: &[u8],
410        _recon_above: &[u8],
411        block_size: BlockSize,
412        mode: IntraMode,
413    ) -> ModeCandidate {
414        let mut candidate = ModeCandidate::new(block_size, PredictionMode::Intra(mode));
415
416        // Generate prediction (simplified - uses DC for all modes in this implementation)
417        let pred = self.generate_intra_prediction(block_size, mode);
418
419        // Compute distortion (SSE)
420        let distortion = self.compute_sse(
421            src,
422            src_stride,
423            &pred,
424            block_size.width() as usize,
425            block_size,
426        );
427
428        // Estimate rate
429        let rate = self.estimate_intra_rate(block_size, mode);
430
431        candidate.distortion = distortion as f32;
432        candidate.rate = rate;
433        candidate.cost = distortion as f32 + self.config.lambda * rate as f32;
434        candidate.tx_size = self.select_tx_size(block_size);
435
436        candidate
437    }
438
439    /// Generate intra prediction (simplified implementation).
440    fn generate_intra_prediction(&self, block_size: BlockSize, mode: IntraMode) -> Vec<u8> {
441        let size = (block_size.width() * block_size.height()) as usize;
442        let pred_value = match mode {
443            IntraMode::DcPred => 128,
444            IntraMode::VPred => 128,
445            IntraMode::HPred => 128,
446            _ => 128,
447        };
448        vec![pred_value; size]
449    }
450
451    /// Compute sum of squared errors.
452    fn compute_sse(
453        &self,
454        src: &[u8],
455        src_stride: usize,
456        pred: &[u8],
457        pred_stride: usize,
458        block_size: BlockSize,
459    ) -> u64 {
460        let w = block_size.width() as usize;
461        let h = block_size.height() as usize;
462        let mut sse = 0u64;
463
464        for y in 0..h {
465            for x in 0..w {
466                if y * src_stride + x < src.len() && y * pred_stride + x < pred.len() {
467                    let diff =
468                        i32::from(src[y * src_stride + x]) - i32::from(pred[y * pred_stride + x]);
469                    sse += (diff * diff) as u64;
470                }
471            }
472        }
473
474        sse
475    }
476
477    /// Estimate rate for intra mode.
478    fn estimate_intra_rate(&self, block_size: BlockSize, _mode: IntraMode) -> u32 {
479        // Simplified rate estimation
480        let base_rate = 8; // Mode overhead
481        let coeff_rate = (block_size.area() / 4) as u32; // Rough coefficient bits
482        base_rate + coeff_rate
483    }
484
485    /// Estimate rate for inter mode.
486    fn estimate_inter_rate(&self, block_size: BlockSize, mv: &MotionVector) -> u32 {
487        // MV rate (simplified)
488        let mv_rate = (mv.dx.abs() + mv.dy.abs()) as u32 / 4 + 4;
489
490        // Coefficient rate
491        let coeff_rate = (block_size.area() / 8) as u32;
492
493        mv_rate + coeff_rate + 4 // Mode overhead
494    }
495
496    /// Search motion for inter prediction.
497    #[allow(clippy::too_many_arguments)]
498    fn search_motion(
499        &self,
500        src: &[u8],
501        src_stride: usize,
502        ref_frame: &[u8],
503        ref_stride: usize,
504        block_size: BlockSize,
505        x: u32,
506        y: u32,
507        frame_width: u32,
508        frame_height: u32,
509    ) -> BlockMatch {
510        // Create search context
511        let ctx = SearchContext::new(
512            src,
513            src_stride,
514            ref_frame,
515            ref_stride,
516            crate::motion::BlockSize::Block8x8, // Convert to motion BlockSize
517            x as usize,
518            y as usize,
519            frame_width as usize,
520            frame_height as usize,
521        );
522
523        // Configure search range based on preset
524        let search_range = if self.config.preset_speed >= 8 {
525            16 // Fast: small range
526        } else if self.config.preset_speed >= 5 {
527            32 // Medium
528        } else {
529            64 // Slow: large range
530        };
531
532        let search_config =
533            SearchConfig::default().range(crate::motion::SearchRange::symmetric(search_range));
534
535        // Perform diamond search
536        let searcher = DiamondSearch::new();
537        searcher.search(&ctx, &search_config)
538    }
539
540    /// Compute inter prediction distortion.
541    fn compute_inter_distortion(
542        &self,
543        src: &[u8],
544        src_stride: usize,
545        ref_frame: &[u8],
546        ref_stride: usize,
547        block_size: BlockSize,
548        mv: &MotionVector,
549    ) -> u64 {
550        let w = block_size.width() as usize;
551        let h = block_size.height() as usize;
552
553        let ref_x = mv.full_pel_x() as usize;
554        let ref_y = mv.full_pel_y() as usize;
555
556        let mut sad = 0u64;
557
558        for y in 0..h {
559            for x in 0..w {
560                let src_idx = y * src_stride + x;
561                let ref_idx = (y + ref_y) * ref_stride + (x + ref_x);
562
563                if src_idx < src.len() && ref_idx < ref_frame.len() {
564                    let diff = i32::from(src[src_idx]).abs_diff(i32::from(ref_frame[ref_idx]));
565                    sad += u64::from(diff);
566                }
567            }
568        }
569
570        if self.config.use_satd {
571            // Convert SAD to approximate SATD
572            (sad * 12) / 10
573        } else {
574            sad
575        }
576    }
577
578    /// Select transform size for block.
579    fn select_tx_size(&self, block_size: BlockSize) -> TxSize {
580        if self.config.tx_size_rdo {
581            // RDO-based selection (simplified: use max)
582            block_size.max_tx_size()
583        } else {
584            // Fast: use max transform size
585            block_size.max_tx_size()
586        }
587    }
588}
589
590// =============================================================================
591// Lagrangian Multiplier Computation
592// =============================================================================
593
/// Derive the Lagrangian multiplier λ from a quantization parameter.
///
/// Evaluates `λ = 0.85 * 2^((QP - 12) / 3)`: λ equals 0.85 at QP 12 and doubles
/// every three QP steps.
#[must_use]
pub fn compute_lambda_from_qp(qp: u8) -> f32 {
    const LAMBDA_SCALE: f32 = 0.85;
    const QP_ANCHOR: f32 = 12.0;
    const QP_PER_DOUBLING: f32 = 3.0;

    let exponent = (f32::from(qp) - QP_ANCHOR) / QP_PER_DOUBLING;
    LAMBDA_SCALE * 2.0_f32.powf(exponent)
}
602
/// Compute QP from lambda (inverse of [`compute_lambda_from_qp`]).
///
/// Solves `λ = 0.85 * 2^((QP - 12) / 3)` for QP, rounds to the nearest integer
/// and clamps the result to `0..=255`.
///
/// The result is rounded rather than truncated: floating-point error in the
/// logarithm can leave the exact inverse a hair below an integer, and
/// truncation would then report a QP one step too low.
///
/// Degenerate inputs are mapped to the nearest bound: zero, negative or NaN
/// `lambda` yields 0 and positive infinity yields 255.
#[must_use]
pub fn compute_qp_from_lambda(lambda: f32) -> u8 {
    let qp = 12.0 + 3.0 * (lambda / 0.85).log2();
    // A NaN survives `clamp` and the saturating `as` cast turns it into 0.
    qp.round().clamp(0.0, 255.0) as u8
}
609
610// =============================================================================
611// Partition Decision Helper
612// =============================================================================
613
614/// Partition decision result.
615#[derive(Clone, Debug)]
616pub struct PartitionDecision {
617    /// Selected partition type.
618    pub partition: PartitionType,
619    /// RD cost of this partition.
620    pub cost: f32,
621    /// Whether to recurse into sub-partitions.
622    pub recurse: bool,
623}
624
625impl PartitionDecision {
626    /// Create a decision to not split.
627    #[must_use]
628    pub const fn no_split(cost: f32) -> Self {
629        Self {
630            partition: PartitionType::None,
631            cost,
632            recurse: false,
633        }
634    }
635
636    /// Create a decision to split.
637    #[must_use]
638    pub const fn split(cost: f32) -> Self {
639        Self {
640            partition: PartitionType::Split,
641            cost,
642            recurse: true,
643        }
644    }
645}
646
647// =============================================================================
648// Rate Estimation Tables
649// =============================================================================
650
651/// Rate estimation context.
652#[derive(Clone, Debug, Default)]
653pub struct RateEstimator {
654    /// Intra mode rate table.
655    pub intra_mode_bits: [f32; 13],
656    /// Inter mode rate table.
657    pub inter_mode_bits: [f32; 4],
658    /// Partition rate table.
659    pub partition_bits: [f32; 10],
660}
661
662impl RateEstimator {
663    /// Create new rate estimator with default tables.
664    #[must_use]
665    pub fn new() -> Self {
666        Self {
667            intra_mode_bits: [
668                3.0, 3.5, 3.5, 4.0, 4.0, 4.5, 4.5, 4.5, 4.5, 4.0, 4.5, 4.5, 4.0,
669            ],
670            inter_mode_bits: [2.0, 2.5, 3.0, 3.5],
671            partition_bits: [2.0, 3.0, 3.0, 3.5, 4.5, 4.5, 4.5, 4.5, 4.0, 4.0],
672        }
673    }
674
675    /// Get intra mode rate.
676    #[must_use]
677    pub fn intra_mode_rate(&self, mode: IntraMode) -> f32 {
678        self.intra_mode_bits[mode as usize]
679    }
680
681    /// Get inter mode rate.
682    #[must_use]
683    pub fn inter_mode_rate(&self, mode: InterMode) -> f32 {
684        self.inter_mode_bits[mode as usize]
685    }
686
687    /// Get partition rate.
688    #[must_use]
689    pub fn partition_rate(&self, partition: PartitionType) -> f32 {
690        self.partition_bits[partition as usize]
691    }
692}
693
694// =============================================================================
695// Tests
696// =============================================================================
697
#[cfg(test)]
mod tests {
    // Unit tests for the mode decision module. They cover configuration
    // presets, candidate construction, the lambda/QP conversions, and the
    // engine's private helpers (reachable here because this is a child module).
    use super::*;

    // Default configuration enables RDO and early termination at depth 4.
    #[test]
    fn test_mode_decision_config_default() {
        let config = ModeDecisionConfig::default();
        assert!(config.rd_optimization);
        assert!(config.early_termination);
        assert_eq!(config.max_partition_depth, 4);
    }

    // A QP-derived lambda must be positive and of sane magnitude.
    #[test]
    fn test_mode_decision_config_from_qp() {
        let config = ModeDecisionConfig::from_qp(28);
        assert!(config.lambda > 0.0);
        assert!(config.lambda < 100.0);
    }

    // Fast and slow presets differ in RDO usage and speed level.
    #[test]
    fn test_mode_decision_config_presets() {
        let fast = ModeDecisionConfig::fast();
        assert!(!fast.rd_optimization);
        assert_eq!(fast.preset_speed, 8);

        let slow = ModeDecisionConfig::slow();
        assert!(slow.rd_optimization);
        assert_eq!(slow.preset_speed, 2);
    }

    // A fresh candidate is intra (as requested) and starts at the sentinel cost.
    #[test]
    fn test_mode_candidate_creation() {
        let candidate = ModeCandidate::new(
            BlockSize::Block8x8,
            PredictionMode::Intra(IntraMode::DcPred),
        );
        assert_eq!(candidate.block_size, BlockSize::Block8x8);
        assert!(candidate.is_intra());
        assert!(!candidate.is_inter());
        assert_eq!(candidate.cost, f32::MAX);
    }

    // A fresh engine has not seen any cost yet.
    #[test]
    fn test_mode_decision_creation() {
        let config = ModeDecisionConfig::default();
        let md = ModeDecision::new(config);
        assert_eq!(md.best_cost, f32::MAX);
    }

    // Lambda for QP 28 is sane and converts back to within one QP step.
    #[test]
    fn test_lambda_computation() {
        let lambda = compute_lambda_from_qp(28);
        assert!(lambda > 0.0);
        assert!(lambda < 100.0);

        // Test inverse
        let qp = compute_qp_from_lambda(lambda);
        assert!((qp as i32 - 28).abs() <= 1);
    }

    // QP -> lambda -> QP stays within two steps across the sampled range.
    #[test]
    fn test_lambda_qp_roundtrip() {
        for qp in [0, 10, 20, 30, 40, 50, 63] {
            let lambda = compute_lambda_from_qp(qp);
            let qp_back = compute_qp_from_lambda(lambda);
            assert!((qp_back as i32 - qp as i32).abs() <= 2);
        }
    }

    // Factory functions set the partition type and recursion flag together.
    #[test]
    fn test_partition_decision() {
        let no_split = PartitionDecision::no_split(100.0);
        assert_eq!(no_split.partition, PartitionType::None);
        assert!(!no_split.recurse);

        let split = PartitionDecision::split(150.0);
        assert_eq!(split.partition, PartitionType::Split);
        assert!(split.recurse);
    }

    // Table lookups return positive, plausible bit counts.
    #[test]
    fn test_rate_estimator() {
        let estimator = RateEstimator::new();

        let dc_rate = estimator.intra_mode_rate(IntraMode::DcPred);
        assert!(dc_rate > 0.0);
        assert!(dc_rate < 10.0);

        let inter_rate = estimator.inter_mode_rate(InterMode::NewMv);
        assert!(inter_rate > 0.0);
    }

    // The fast preset still tests a non-empty, capped list that includes DC.
    #[test]
    fn test_intra_modes_to_test_fast() {
        let config = ModeDecisionConfig::fast();
        let md = ModeDecision::new(config);
        let modes = md.get_intra_modes_to_test(BlockSize::Block8x8);

        assert!(!modes.is_empty());
        assert!(modes.len() <= MAX_INTRA_CANDIDATES);
        assert!(modes.contains(&IntraMode::DcPred));
    }

    // The slow preset extends the list beyond the four base modes.
    #[test]
    fn test_intra_modes_to_test_slow() {
        let config = ModeDecisionConfig::slow();
        let md = ModeDecision::new(config);
        let modes = md.get_intra_modes_to_test(BlockSize::Block16x16);

        assert!(modes.len() > 4);
        assert!(modes.contains(&IntraMode::DcPred));
        assert!(modes.contains(&IntraMode::D45Pred));
    }

    // The placeholder predictor yields a flat 128 block of width * height samples.
    #[test]
    fn test_intra_prediction_generation() {
        let config = ModeDecisionConfig::default();
        let md = ModeDecision::new(config);

        let pred = md.generate_intra_prediction(BlockSize::Block8x8, IntraMode::DcPred);
        assert_eq!(pred.len(), 64);
        assert!(pred.iter().all(|&p| p == 128));
    }

    // SSE is zero for identical blocks and 64 * 10^2 for a uniform offset of 10.
    #[test]
    fn test_sse_computation() {
        let config = ModeDecisionConfig::default();
        let md = ModeDecision::new(config);

        let src = vec![100u8; 64];
        let pred = vec![100u8; 64];

        let sse = md.compute_sse(&src, 8, &pred, 8, BlockSize::Block8x8);
        assert_eq!(sse, 0); // Identical blocks

        let pred2 = vec![110u8; 64];
        let sse2 = md.compute_sse(&src, 8, &pred2, 8, BlockSize::Block8x8);
        assert!(sse2 > 0); // Different blocks
        assert_eq!(sse2, 6400); // 64 * (10^2)
    }

    // 64x64 blocks split; 8x8 blocks do not.
    #[test]
    fn test_decide_partition_simple() {
        let config = ModeDecisionConfig::default();
        let md = ModeDecision::new(config);

        let src = vec![128u8; 64 * 64];
        let partition = md.decide_partition(&src, 64, BlockSize::Block64x64, 0);

        // Large blocks should split
        assert_eq!(partition, PartitionType::Split);

        let partition_small = md.decide_partition(&src, 8, BlockSize::Block8x8, 2);
        assert_eq!(partition_small, PartitionType::None);
    }

    // The selected transform size equals the block size for square blocks.
    #[test]
    fn test_tx_size_selection() {
        let config = ModeDecisionConfig::default();
        let md = ModeDecision::new(config);

        let tx_size = md.select_tx_size(BlockSize::Block8x8);
        assert_eq!(tx_size, TxSize::Tx8x8);

        let tx_size_16 = md.select_tx_size(BlockSize::Block16x16);
        assert_eq!(tx_size_16, TxSize::Tx16x16);
    }

    // Rate estimates are positive and bounded for a small block.
    #[test]
    fn test_rate_estimation() {
        let config = ModeDecisionConfig::default();
        let md = ModeDecision::new(config);

        let rate = md.estimate_intra_rate(BlockSize::Block8x8, IntraMode::DcPred);
        assert!(rate > 0);
        assert!(rate < 1000);

        let mv = MotionVector::new(4, 4);
        let inter_rate = md.estimate_inter_rate(BlockSize::Block8x8, &mv);
        assert!(inter_rate > 0);
    }

    // With RDO enabled the cost exceeds the distortion by the weighted rate.
    #[test]
    fn test_rd_cost_computation() {
        let config = ModeDecisionConfig::from_qp(28);
        let md = ModeDecision::new(config);

        let mut candidate = ModeCandidate::new(
            BlockSize::Block8x8,
            PredictionMode::Intra(IntraMode::DcPred),
        );
        candidate.distortion = 1000.0;
        candidate.rate = 100;

        let cost = md.compute_rd_cost(&candidate);
        assert!(cost > candidate.distortion);
        assert!(cost < 10000.0);
    }

    // Both enum variants can be constructed and pattern-matched.
    #[test]
    fn test_prediction_mode() {
        let intra = PredictionMode::Intra(IntraMode::DcPred);
        assert!(matches!(intra, PredictionMode::Intra(_)));

        let inter = PredictionMode::Inter(InterMode::NewMv);
        assert!(matches!(inter, PredictionMode::Inter(_)));
    }
}