Skip to main content

oximedia_codec/motion/
partition.rs

1//! Block partitioning decisions for motion estimation.
2//!
3//! This module provides:
4//! - Partition decision structures
5//! - Split decisions (16x16 -> 8x8 -> 4x4)
6//! - Skip detection for direct mode
7//! - Merge candidate generation
8//!
9//! Efficient partitioning is crucial for balancing compression
10//! efficiency with encoding complexity.
11
12#![forbid(unsafe_code)]
13#![allow(dead_code)]
14#![allow(clippy::too_many_arguments)]
15#![allow(clippy::must_use_candidate)]
16#![allow(clippy::cast_lossless)]
17#![allow(clippy::cast_precision_loss)]
18#![allow(clippy::cast_possible_truncation)]
19#![allow(clippy::cast_sign_loss)]
20#![allow(clippy::match_same_arms)]
21
22use super::types::{BlockMatch, BlockSize, MotionVector, MvCost};
23
/// Default SAD limit for skip mode: a block match whose SAD exceeds the
/// detector's threshold is never treated as a skip candidate.
pub const SKIP_THRESHOLD: u32 = 64;

/// Default multiplier applied to the summed cost of the child partitions
/// before it is compared against the cost of the unsplit block.
pub const SPLIT_THRESHOLD_RATIO: f32 = 0.8;

/// Capacity of a merge candidate list.
pub const MAX_MERGE_CANDIDATES: usize = 5;
32
/// Partition type for a block.
///
/// The variant names and discriminants line up with the AV1 partition
/// types (`PARTITION_NONE` .. `PARTITION_VERT_4`).
///
/// NOTE(review): the `*A` / `*B` variants are documented and counted here
/// as AV1 T-shaped partitions (three sub-blocks). Confirm this matches the
/// encoder's intended semantics.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum PartitionType {
    /// No partition (use current block size).
    #[default]
    None = 0,
    /// Horizontal split (top/bottom halves).
    HorizontalSplit = 1,
    /// Vertical split (left/right halves).
    VerticalSplit = 2,
    /// Quad split (4 equal quadrants).
    Split = 3,
    /// T-shaped horizontal split: top half split in two, bottom half whole.
    HorizontalA = 4,
    /// T-shaped horizontal split: top half whole, bottom half split in two.
    HorizontalB = 5,
    /// T-shaped vertical split: left half split in two, right half whole.
    VerticalA = 6,
    /// T-shaped vertical split: left half whole, right half split in two.
    VerticalB = 7,
    /// Horizontal 4-way split.
    Horizontal4 = 8,
    /// Vertical 4-way split.
    Vertical4 = 9,
}

impl PartitionType {
    /// Returns the number of sub-partitions this partition type produces.
    ///
    /// - `None`: 1
    /// - `HorizontalSplit` / `VerticalSplit`: 2
    /// - `HorizontalA` / `HorizontalB` / `VerticalA` / `VerticalB`: 3
    ///   (one half kept whole, the other half split in two)
    /// - `Split` / `Horizontal4` / `Vertical4`: 4
    #[must_use]
    pub const fn num_parts(&self) -> usize {
        match self {
            Self::None => 1,
            Self::HorizontalSplit | Self::VerticalSplit => 2,
            // T-shaped partitions consist of three blocks, not four.
            Self::HorizontalA | Self::HorizontalB | Self::VerticalA | Self::VerticalB => 3,
            Self::Split | Self::Horizontal4 | Self::Vertical4 => 4,
        }
    }

    /// Returns true for every partition type except [`PartitionType::None`].
    #[must_use]
    pub const fn is_split(&self) -> bool {
        !matches!(self, Self::None)
    }
}
82
/// Inter prediction mode chosen for a block.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum InterMode {
    /// Skip mode (use predicted MV directly).
    Skip = 0,
    /// Merge mode (copy MV from neighbor).
    Merge = 1,
    /// New MV mode (search for best MV). This is the default mode.
    #[default]
    NewMv = 2,
    /// Nearest MV mode (use nearest neighbor MV).
    NearestMv = 3,
    /// Near MV mode (use second-nearest neighbor MV).
    NearMv = 4,
    /// Zero MV mode.
    ZeroMv = 5,
}

impl InterMode {
    /// Reports whether a motion search is needed; only `NewMv` needs one.
    #[must_use]
    pub const fn requires_search(&self) -> bool {
        match self {
            Self::NewMv => true,
            Self::Skip | Self::Merge | Self::NearestMv | Self::NearMv | Self::ZeroMv => false,
        }
    }

    /// Reports whether the mode is one of `NearestMv`, `NearMv` or `ZeroMv`.
    #[must_use]
    pub const fn uses_predictor(&self) -> bool {
        match self {
            Self::NearestMv | Self::NearMv | Self::ZeroMv => true,
            Self::Skip | Self::Merge | Self::NewMv => false,
        }
    }
}
115
116/// Decision for a single partition.
117#[derive(Clone, Debug)]
118pub struct PartitionDecision {
119    /// Block size for this partition.
120    pub block_size: BlockSize,
121    /// Partition type.
122    pub partition_type: PartitionType,
123    /// Inter prediction mode.
124    pub mode: InterMode,
125    /// Motion vector.
126    pub mv: MotionVector,
127    /// Reference frame index.
128    pub ref_idx: i8,
129    /// Rate-distortion cost.
130    pub cost: u32,
131    /// Distortion (SAD/SATD).
132    pub distortion: u32,
133    /// Estimated bits for this partition.
134    pub bits: u32,
135    /// Is this a skip block?
136    pub is_skip: bool,
137    /// Merge candidate index (if merge mode).
138    pub merge_idx: u8,
139}
140
141impl Default for PartitionDecision {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl PartitionDecision {
148    /// Creates a new default decision.
149    #[must_use]
150    pub const fn new() -> Self {
151        Self {
152            block_size: BlockSize::Block8x8,
153            partition_type: PartitionType::None,
154            mode: InterMode::NewMv,
155            mv: MotionVector::zero(),
156            ref_idx: 0,
157            cost: u32::MAX,
158            distortion: u32::MAX,
159            bits: 0,
160            is_skip: false,
161            merge_idx: 0,
162        }
163    }
164
165    /// Creates a skip decision.
166    #[must_use]
167    pub const fn skip(block_size: BlockSize, mv: MotionVector, distortion: u32) -> Self {
168        Self {
169            block_size,
170            partition_type: PartitionType::None,
171            mode: InterMode::Skip,
172            mv,
173            ref_idx: 0,
174            cost: distortion,
175            distortion,
176            bits: 0,
177            is_skip: true,
178            merge_idx: 0,
179        }
180    }
181
182    /// Creates a decision from block match result.
183    #[must_use]
184    pub const fn from_match(block_size: BlockSize, block_match: &BlockMatch) -> Self {
185        Self {
186            block_size,
187            partition_type: PartitionType::None,
188            mode: InterMode::NewMv,
189            mv: block_match.mv,
190            ref_idx: 0,
191            cost: block_match.cost,
192            distortion: block_match.sad,
193            bits: 0,
194            is_skip: false,
195            merge_idx: 0,
196        }
197    }
198
199    /// Checks if this decision is better than another.
200    #[must_use]
201    pub const fn is_better_than(&self, other: &Self) -> bool {
202        self.cost < other.cost
203    }
204
205    /// Updates with a better decision.
206    pub fn update_if_better(&mut self, other: &Self) {
207        if other.is_better_than(self) {
208            *self = other.clone();
209        }
210    }
211}
212
213/// Split decision result for recursive partitioning.
214#[derive(Clone, Debug)]
215pub struct SplitDecision {
216    /// Should we split this block?
217    pub should_split: bool,
218    /// Cost of the unsplit block.
219    pub unsplit_cost: u32,
220    /// Cost of the split blocks (sum).
221    pub split_cost: u32,
222    /// Child decisions (if split).
223    pub children: Vec<PartitionDecision>,
224}
225
226impl Default for SplitDecision {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
232impl SplitDecision {
233    /// Creates a new split decision.
234    #[must_use]
235    pub const fn new() -> Self {
236        Self {
237            should_split: false,
238            unsplit_cost: u32::MAX,
239            split_cost: u32::MAX,
240            children: Vec::new(),
241        }
242    }
243
244    /// Creates a decision to not split.
245    #[must_use]
246    pub const fn no_split(cost: u32) -> Self {
247        Self {
248            should_split: false,
249            unsplit_cost: cost,
250            split_cost: u32::MAX,
251            children: Vec::new(),
252        }
253    }
254
255    /// Creates a decision to split.
256    #[must_use]
257    pub fn split(unsplit_cost: u32, split_cost: u32, children: Vec<PartitionDecision>) -> Self {
258        Self {
259            should_split: split_cost < unsplit_cost,
260            unsplit_cost,
261            split_cost,
262            children,
263        }
264    }
265}
266
267/// Skip detection for inter prediction.
268#[derive(Clone, Debug, Default)]
269pub struct SkipDetector {
270    /// SAD threshold for skip.
271    threshold: u32,
272    /// MV cost weight.
273    mv_weight: f32,
274}
275
276impl SkipDetector {
277    /// Creates a new skip detector.
278    #[must_use]
279    pub const fn new(threshold: u32) -> Self {
280        Self {
281            threshold,
282            mv_weight: 1.0,
283        }
284    }
285
286    /// Sets the MV cost weight.
287    #[must_use]
288    pub const fn with_mv_weight(mut self, weight: f32) -> Self {
289        self.mv_weight = weight;
290        self
291    }
292
293    /// Checks if a block can be skipped.
294    #[must_use]
295    pub fn can_skip(&self, block_match: &BlockMatch, predicted_mv: &MotionVector) -> bool {
296        // Skip if SAD is low enough
297        if block_match.sad > self.threshold {
298            return false;
299        }
300
301        // Skip if MV matches predicted MV well
302        let mv_diff = block_match.mv - *predicted_mv;
303        let mv_dist = mv_diff.l1_norm();
304
305        // Allow skip if MV is close to prediction
306        mv_dist < 8 // Within 1 full pixel
307    }
308
309    /// Evaluates skip mode cost.
310    #[must_use]
311    pub fn evaluate_skip(
312        &self,
313        block_match: &BlockMatch,
314        predicted_mv: &MotionVector,
315        mv_cost: &MvCost,
316    ) -> Option<PartitionDecision> {
317        if !self.can_skip(block_match, predicted_mv) {
318            return None;
319        }
320
321        // Calculate skip mode cost (no MV bits needed)
322        let skip_cost = block_match.sad + 1; // +1 for skip flag
323
324        // Compare with regular mode cost
325        let regular_cost = mv_cost.rd_cost(&block_match.mv, block_match.sad);
326
327        if skip_cost < regular_cost {
328            Some(PartitionDecision::skip(
329                BlockSize::Block8x8,
330                *predicted_mv,
331                block_match.sad,
332            ))
333        } else {
334            None
335        }
336    }
337}
338
339/// Merge candidate for merge mode.
340#[derive(Clone, Copy, Debug)]
341pub struct MergeCandidate {
342    /// Motion vector.
343    pub mv: MotionVector,
344    /// Reference frame index.
345    pub ref_idx: i8,
346    /// Source of this candidate.
347    pub source: MergeSource,
348}
349
350impl MergeCandidate {
351    /// Creates a new merge candidate.
352    #[must_use]
353    pub const fn new(mv: MotionVector, ref_idx: i8, source: MergeSource) -> Self {
354        Self {
355            mv,
356            ref_idx,
357            source,
358        }
359    }
360
361    /// Creates a zero MV candidate.
362    #[must_use]
363    pub const fn zero() -> Self {
364        Self {
365            mv: MotionVector::zero(),
366            ref_idx: 0,
367            source: MergeSource::Zero,
368        }
369    }
370}
371
/// Identifies where a merge candidate was taken from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeSource {
    /// Candidate taken from the left neighbor.
    Left,
    /// Candidate taken from the top neighbor.
    Top,
    /// Candidate taken from the top-right neighbor.
    TopRight,
    /// Candidate taken from the top-left neighbor.
    TopLeft,
    /// Co-located temporal candidate.
    CoLocated,
    /// Zero-MV candidate.
    Zero,
}
388
389/// Merge candidate list.
390#[derive(Clone, Debug)]
391pub struct MergeCandidateList {
392    /// Candidates.
393    candidates: [MergeCandidate; MAX_MERGE_CANDIDATES],
394    /// Number of valid candidates.
395    count: usize,
396}
397
398impl Default for MergeCandidateList {
399    fn default() -> Self {
400        Self::new()
401    }
402}
403
404impl MergeCandidateList {
405    /// Creates a new empty list.
406    #[must_use]
407    pub fn new() -> Self {
408        Self {
409            candidates: [MergeCandidate::zero(); MAX_MERGE_CANDIDATES],
410            count: 0,
411        }
412    }
413
414    /// Adds a candidate.
415    pub fn add(&mut self, candidate: MergeCandidate) {
416        if self.count >= MAX_MERGE_CANDIDATES {
417            return;
418        }
419
420        // Check for duplicates
421        for i in 0..self.count {
422            if self.candidates[i].mv == candidate.mv
423                && self.candidates[i].ref_idx == candidate.ref_idx
424            {
425                return;
426            }
427        }
428
429        self.candidates[self.count] = candidate;
430        self.count += 1;
431    }
432
433    /// Returns the number of candidates.
434    #[must_use]
435    pub const fn len(&self) -> usize {
436        self.count
437    }
438
439    /// Returns true if empty.
440    #[must_use]
441    pub const fn is_empty(&self) -> bool {
442        self.count == 0
443    }
444
445    /// Gets a candidate by index.
446    #[must_use]
447    pub const fn get(&self, index: usize) -> Option<&MergeCandidate> {
448        if index < self.count {
449            Some(&self.candidates[index])
450        } else {
451            None
452        }
453    }
454
455    /// Returns candidates as slice.
456    #[must_use]
457    pub fn as_slice(&self) -> &[MergeCandidate] {
458        &self.candidates[..self.count]
459    }
460}
461
462/// Partition context for decision making.
463#[derive(Clone, Debug, Default)]
464pub struct PartitionContext {
465    /// Block position (x in pixels).
466    pub x: usize,
467    /// Block position (y in pixels).
468    pub y: usize,
469    /// Frame width.
470    pub frame_width: usize,
471    /// Frame height.
472    pub frame_height: usize,
473    /// Maximum block size allowed.
474    pub max_block_size: BlockSize,
475    /// Minimum block size allowed.
476    pub min_block_size: BlockSize,
477    /// Lambda for RD.
478    pub lambda: f32,
479}
480
481impl PartitionContext {
482    /// Creates a new partition context.
483    #[must_use]
484    pub const fn new(frame_width: usize, frame_height: usize) -> Self {
485        Self {
486            x: 0,
487            y: 0,
488            frame_width,
489            frame_height,
490            max_block_size: BlockSize::Block64x64,
491            min_block_size: BlockSize::Block4x4,
492            lambda: 1.0,
493        }
494    }
495
496    /// Sets the block position.
497    #[must_use]
498    pub const fn at(mut self, x: usize, y: usize) -> Self {
499        self.x = x;
500        self.y = y;
501        self
502    }
503
504    /// Sets the block size limits.
505    #[must_use]
506    pub const fn with_size_range(mut self, min: BlockSize, max: BlockSize) -> Self {
507        self.min_block_size = min;
508        self.max_block_size = max;
509        self
510    }
511
512    /// Checks if a block size fits within frame bounds.
513    #[must_use]
514    pub fn can_use_size(&self, size: BlockSize) -> bool {
515        self.x + size.width() <= self.frame_width && self.y + size.height() <= self.frame_height
516    }
517
518    /// Returns the child block size for quad split.
519    #[must_use]
520    pub const fn child_size(size: BlockSize) -> Option<BlockSize> {
521        match size {
522            BlockSize::Block128x128 => Some(BlockSize::Block64x64),
523            BlockSize::Block64x64 => Some(BlockSize::Block32x32),
524            BlockSize::Block32x32 => Some(BlockSize::Block16x16),
525            BlockSize::Block16x16 => Some(BlockSize::Block8x8),
526            BlockSize::Block8x8 => Some(BlockSize::Block4x4),
527            _ => None,
528        }
529    }
530}
531
532/// Partition decision maker.
533#[derive(Clone, Debug, Default)]
534pub struct PartitionDecider {
535    /// Skip detector.
536    skip_detector: SkipDetector,
537    /// Cost ratio threshold for splitting.
538    split_threshold: f32,
539}
540
541impl PartitionDecider {
542    /// Creates a new partition decider.
543    #[must_use]
544    pub fn new() -> Self {
545        Self {
546            skip_detector: SkipDetector::new(SKIP_THRESHOLD),
547            split_threshold: SPLIT_THRESHOLD_RATIO,
548        }
549    }
550
551    /// Sets the split threshold.
552    #[must_use]
553    pub const fn with_split_threshold(mut self, threshold: f32) -> Self {
554        self.split_threshold = threshold;
555        self
556    }
557
558    /// Decides whether to split a block.
559    pub fn decide_split(
560        &self,
561        unsplit_result: &PartitionDecision,
562        split_results: &[PartitionDecision],
563        ctx: &PartitionContext,
564    ) -> SplitDecision {
565        let unsplit_cost = unsplit_result.cost;
566
567        // Calculate split cost (sum of children + overhead)
568        let split_overhead = (ctx.lambda * 2.0) as u32; // Bits for split flag
569        let split_cost: u32 = split_results
570            .iter()
571            .map(|r| r.cost)
572            .fold(split_overhead, u32::saturating_add);
573
574        // Apply threshold
575        let effective_split_cost = (f64::from(split_cost) * f64::from(self.split_threshold)) as u32;
576
577        if effective_split_cost < unsplit_cost {
578            SplitDecision::split(unsplit_cost, split_cost, split_results.to_vec())
579        } else {
580            SplitDecision::no_split(unsplit_cost)
581        }
582    }
583
584    /// Checks for early termination (no need to try smaller partitions).
585    #[must_use]
586    pub fn can_early_terminate(&self, result: &PartitionDecision) -> bool {
587        // Skip blocks don't need further partitioning
588        if result.is_skip {
589            return true;
590        }
591
592        // Very low distortion blocks don't need splitting
593        result.distortion < SKIP_THRESHOLD / 2
594    }
595}
596
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a decision that differs from the placeholder only in its cost.
    fn decision_with_cost(cost: u32) -> PartitionDecision {
        PartitionDecision {
            cost,
            ..PartitionDecision::new()
        }
    }

    #[test]
    fn test_partition_type_num_parts() {
        assert_eq!(PartitionType::None.num_parts(), 1);
        assert_eq!(PartitionType::HorizontalSplit.num_parts(), 2);
        assert_eq!(PartitionType::VerticalSplit.num_parts(), 2);
        assert_eq!(PartitionType::Split.num_parts(), 4);
    }

    #[test]
    fn test_partition_type_is_split() {
        assert!(!PartitionType::None.is_split());
        assert!(PartitionType::Split.is_split());
        assert!(PartitionType::HorizontalSplit.is_split());
    }

    #[test]
    fn test_inter_mode() {
        assert!(!InterMode::Skip.requires_search());
        assert!(InterMode::NewMv.requires_search());
        assert!(InterMode::NearestMv.uses_predictor());
        assert!(!InterMode::NewMv.uses_predictor());
    }

    #[test]
    fn test_partition_decision_default() {
        let decision = PartitionDecision::new();
        assert_eq!(decision.cost, u32::MAX);
        assert!(!decision.is_skip);
    }

    #[test]
    fn test_partition_decision_skip() {
        let mv = MotionVector::new(8, 16);
        let decision = PartitionDecision::skip(BlockSize::Block8x8, mv, 50);

        assert!(decision.is_skip);
        assert_eq!(decision.mode, InterMode::Skip);
        assert_eq!(decision.mv.dx, 8);
        assert_eq!(decision.distortion, 50);
    }

    #[test]
    fn test_partition_decision_from_match() {
        let block_match = BlockMatch::new(MotionVector::new(10, 20), 100, 150);
        let decision = PartitionDecision::from_match(BlockSize::Block16x16, &block_match);

        assert_eq!(decision.mv.dx, 10);
        assert_eq!(decision.distortion, 100);
        assert_eq!(decision.cost, 150);
    }

    #[test]
    fn test_partition_decision_comparison() {
        let better = decision_with_cost(100);
        let worse = decision_with_cost(200);

        assert!(better.is_better_than(&worse));
        assert!(!worse.is_better_than(&better));
    }

    #[test]
    fn test_split_decision_no_split() {
        let decision = SplitDecision::no_split(100);
        assert!(!decision.should_split);
        assert_eq!(decision.unsplit_cost, 100);
    }

    #[test]
    fn test_split_decision_split() {
        let children = vec![decision_with_cost(30), decision_with_cost(30)];

        let decision = SplitDecision::split(100, 60, children);
        assert!(decision.should_split);
    }

    #[test]
    fn test_skip_detector() {
        let detector = SkipDetector::new(100);
        let block_match = BlockMatch::new(MotionVector::new(4, 4), 50, 60);
        let predicted_mv = MotionVector::new(4, 4);

        assert!(detector.can_skip(&block_match, &predicted_mv));

        let bad_match = BlockMatch::new(MotionVector::new(100, 100), 150, 200);
        assert!(!detector.can_skip(&bad_match, &predicted_mv));
    }

    #[test]
    fn test_merge_candidate() {
        let candidate = MergeCandidate::new(MotionVector::new(10, 20), 0, MergeSource::Left);
        assert_eq!(candidate.mv.dx, 10);
        assert_eq!(candidate.source, MergeSource::Left);
    }

    #[test]
    fn test_merge_candidate_list() {
        let mut list = MergeCandidateList::new();

        list.add(MergeCandidate::new(
            MotionVector::new(10, 20),
            0,
            MergeSource::Left,
        ));
        list.add(MergeCandidate::new(
            MotionVector::new(30, 40),
            0,
            MergeSource::Top,
        ));

        assert_eq!(list.len(), 2);
        assert_eq!(list.get(0).expect("get should return value").mv.dx, 10);
    }

    #[test]
    fn test_merge_candidate_list_dedup() {
        let mut list = MergeCandidateList::new();

        list.add(MergeCandidate::new(
            MotionVector::new(10, 20),
            0,
            MergeSource::Left,
        ));
        list.add(MergeCandidate::new(
            MotionVector::new(10, 20),
            0,
            MergeSource::Top,
        ));

        // Duplicate should not be added
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_partition_context() {
        let ctx = PartitionContext::new(1920, 1080)
            .at(100, 200)
            .with_size_range(BlockSize::Block4x4, BlockSize::Block64x64);

        assert_eq!(ctx.x, 100);
        assert_eq!(ctx.y, 200);
        assert!(ctx.can_use_size(BlockSize::Block64x64));
    }

    #[test]
    fn test_partition_context_child_size() {
        assert_eq!(
            PartitionContext::child_size(BlockSize::Block64x64),
            Some(BlockSize::Block32x32)
        );
        assert_eq!(
            PartitionContext::child_size(BlockSize::Block8x8),
            Some(BlockSize::Block4x4)
        );
        assert_eq!(PartitionContext::child_size(BlockSize::Block4x4), None);
    }

    #[test]
    fn test_partition_decider() {
        let decider = PartitionDecider::new();

        let unsplit = decision_with_cost(200);
        let split_results = vec![
            decision_with_cost(40),
            decision_with_cost(40),
            decision_with_cost(40),
            decision_with_cost(40),
        ];

        let ctx = PartitionContext::new(1920, 1080);
        let decision = decider.decide_split(&unsplit, &split_results, &ctx);

        // Children sum to 4 * 40 = 160; the split-flag overhead with the
        // default lambda of 1.0 is 2, so the split cost is 162 < 200.
        // Assert each fact on its own: an `a || b` assertion would still
        // pass if the decision itself were wrong.
        assert!(decision.should_split);
        assert_eq!(decision.unsplit_cost, 200);
        assert_eq!(decision.split_cost, 162);
        assert_eq!(decision.children.len(), 4);
    }

    #[test]
    fn test_partition_decider_early_termination() {
        let decider = PartitionDecider::new();

        let skip_result = PartitionDecision {
            is_skip: true,
            distortion: 10,
            ..PartitionDecision::new()
        };

        assert!(decider.can_early_terminate(&skip_result));

        let low_distortion = PartitionDecision {
            distortion: 10,
            ..PartitionDecision::new()
        };

        assert!(decider.can_early_terminate(&low_distortion));
    }
}