Skip to main content

oximedia_codec/motion/
cache.rs

1//! Motion vector caching for video encoding.
2//!
3//! This module provides:
4//! - MV cache for storing search results
5//! - Reference frame MV storage
6//! - Co-located MV lookup for temporal prediction
7//!
8//! Caching motion vectors improves encoding speed by:
9//! - Providing better starting points for subsequent searches
10//! - Enabling fast MV predictor derivation
11//! - Supporting temporal prediction
12
13#![forbid(unsafe_code)]
14#![allow(dead_code)]
15#![allow(clippy::manual_div_ceil)]
16#![allow(clippy::must_use_candidate)]
17#![allow(clippy::cast_possible_truncation)]
18#![allow(clippy::cast_sign_loss)]
19#![allow(clippy::cast_possible_wrap)]
20
21use super::types::{BlockSize, MotionVector};
22
23/// Cache entry for a motion vector.
24#[derive(Clone, Copy, Debug, Default)]
25pub struct MvCacheEntry {
26    /// Motion vector.
27    pub mv: MotionVector,
28    /// Reference frame index.
29    pub ref_idx: i8,
30    /// SAD value.
31    pub sad: u32,
32    /// Is this entry valid?
33    pub valid: bool,
34    /// Is this an inter block?
35    pub is_inter: bool,
36}
37
38impl MvCacheEntry {
39    /// Creates an invalid entry.
40    #[must_use]
41    pub const fn invalid() -> Self {
42        Self {
43            mv: MotionVector::zero(),
44            ref_idx: -1,
45            sad: u32::MAX,
46            valid: false,
47            is_inter: false,
48        }
49    }
50
51    /// Creates a valid inter entry.
52    #[must_use]
53    pub const fn inter(mv: MotionVector, ref_idx: i8, sad: u32) -> Self {
54        Self {
55            mv,
56            ref_idx,
57            sad,
58            valid: true,
59            is_inter: true,
60        }
61    }
62
63    /// Creates an intra entry (no MV).
64    #[must_use]
65    pub const fn intra() -> Self {
66        Self {
67            mv: MotionVector::zero(),
68            ref_idx: -1,
69            sad: 0,
70            valid: true,
71            is_inter: false,
72        }
73    }
74}
75
76/// Motion vector cache for a frame.
77///
78/// Stores motion vectors in a grid aligned to 4x4 blocks (MI units).
79#[derive(Clone, Debug)]
80pub struct MvCache {
81    /// Cache data.
82    data: Vec<MvCacheEntry>,
83    /// Width in MI units (4x4 blocks).
84    mi_cols: usize,
85    /// Height in MI units (4x4 blocks).
86    mi_rows: usize,
87    /// Number of reference frames supported.
88    num_refs: usize,
89}
90
91impl Default for MvCache {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
97impl MvCache {
98    /// Creates a new empty cache.
99    #[must_use]
100    pub const fn new() -> Self {
101        Self {
102            data: Vec::new(),
103            mi_cols: 0,
104            mi_rows: 0,
105            num_refs: 3,
106        }
107    }
108
109    /// Allocates cache for a frame.
110    pub fn allocate(&mut self, width: usize, height: usize, num_refs: usize) {
111        self.mi_cols = (width + 3) / 4;
112        self.mi_rows = (height + 3) / 4;
113        self.num_refs = num_refs;
114
115        let size = self.mi_cols * self.mi_rows * num_refs;
116        self.data = vec![MvCacheEntry::invalid(); size];
117    }
118
119    /// Clears all entries.
120    pub fn clear(&mut self) {
121        self.data.fill(MvCacheEntry::invalid());
122    }
123
124    /// Returns the width in MI units.
125    #[must_use]
126    pub const fn mi_cols(&self) -> usize {
127        self.mi_cols
128    }
129
130    /// Returns the height in MI units.
131    #[must_use]
132    pub const fn mi_rows(&self) -> usize {
133        self.mi_rows
134    }
135
136    /// Calculates the index for a position.
137    fn index(&self, mi_row: usize, mi_col: usize, ref_idx: usize) -> Option<usize> {
138        if mi_row >= self.mi_rows || mi_col >= self.mi_cols || ref_idx >= self.num_refs {
139            return None;
140        }
141        Some((mi_row * self.mi_cols + mi_col) * self.num_refs + ref_idx)
142    }
143
144    /// Gets an entry.
145    #[must_use]
146    pub fn get(&self, mi_row: usize, mi_col: usize, ref_idx: usize) -> Option<&MvCacheEntry> {
147        let idx = self.index(mi_row, mi_col, ref_idx)?;
148        self.data.get(idx)
149    }
150
151    /// Gets a mutable entry.
152    pub fn get_mut(
153        &mut self,
154        mi_row: usize,
155        mi_col: usize,
156        ref_idx: usize,
157    ) -> Option<&mut MvCacheEntry> {
158        let idx = self.index(mi_row, mi_col, ref_idx)?;
159        self.data.get_mut(idx)
160    }
161
162    /// Sets an entry.
163    pub fn set(&mut self, mi_row: usize, mi_col: usize, ref_idx: usize, entry: MvCacheEntry) {
164        if let Some(idx) = self.index(mi_row, mi_col, ref_idx) {
165            if idx < self.data.len() {
166                self.data[idx] = entry;
167            }
168        }
169    }
170
171    /// Fills a block region with an entry.
172    pub fn fill_block(
173        &mut self,
174        mi_row: usize,
175        mi_col: usize,
176        block_size: BlockSize,
177        ref_idx: usize,
178        entry: MvCacheEntry,
179    ) {
180        let mi_width = block_size.width() / 4;
181        let mi_height = block_size.height() / 4;
182
183        for row in mi_row..mi_row + mi_height {
184            for col in mi_col..mi_col + mi_width {
185                self.set(row, col, ref_idx, entry);
186            }
187        }
188    }
189
190    /// Gets the left neighbor entry.
191    #[must_use]
192    pub fn get_left(&self, mi_row: usize, mi_col: usize, ref_idx: usize) -> Option<&MvCacheEntry> {
193        if mi_col == 0 {
194            return None;
195        }
196        self.get(mi_row, mi_col - 1, ref_idx)
197    }
198
199    /// Gets the top neighbor entry.
200    #[must_use]
201    pub fn get_top(&self, mi_row: usize, mi_col: usize, ref_idx: usize) -> Option<&MvCacheEntry> {
202        if mi_row == 0 {
203            return None;
204        }
205        self.get(mi_row - 1, mi_col, ref_idx)
206    }
207
208    /// Gets the top-right neighbor entry.
209    #[must_use]
210    pub fn get_top_right(
211        &self,
212        mi_row: usize,
213        mi_col: usize,
214        block_size: BlockSize,
215        ref_idx: usize,
216    ) -> Option<&MvCacheEntry> {
217        if mi_row == 0 {
218            return None;
219        }
220        let mi_width = block_size.width() / 4;
221        let tr_col = mi_col + mi_width;
222        if tr_col >= self.mi_cols {
223            return None;
224        }
225        self.get(mi_row - 1, tr_col, ref_idx)
226    }
227
228    /// Gets the top-left neighbor entry.
229    #[must_use]
230    pub fn get_top_left(
231        &self,
232        mi_row: usize,
233        mi_col: usize,
234        ref_idx: usize,
235    ) -> Option<&MvCacheEntry> {
236        if mi_row == 0 || mi_col == 0 {
237            return None;
238        }
239        self.get(mi_row - 1, mi_col - 1, ref_idx)
240    }
241}
242
243/// Reference frame motion vector storage.
244///
245/// Stores MVs for decoded reference frames to enable temporal prediction.
246#[derive(Clone, Debug)]
247pub struct RefFrameMvs {
248    /// MV data for each reference frame.
249    frames: Vec<MvCache>,
250    /// Maximum number of reference frames.
251    max_refs: usize,
252}
253
254impl Default for RefFrameMvs {
255    fn default() -> Self {
256        Self::new()
257    }
258}
259
260impl RefFrameMvs {
261    /// Creates new reference frame MV storage.
262    #[must_use]
263    pub const fn new() -> Self {
264        Self {
265            frames: Vec::new(),
266            max_refs: 8,
267        }
268    }
269
270    /// Sets the maximum number of references.
271    #[must_use]
272    pub fn with_max_refs(mut self, max: usize) -> Self {
273        self.max_refs = max;
274        self
275    }
276
277    /// Allocates storage for reference frames.
278    pub fn allocate(&mut self, width: usize, height: usize, num_refs: usize) {
279        self.frames.clear();
280        for _ in 0..num_refs.min(self.max_refs) {
281            let mut cache = MvCache::new();
282            cache.allocate(width, height, 1);
283            self.frames.push(cache);
284        }
285    }
286
287    /// Gets the MV cache for a reference frame.
288    #[must_use]
289    pub fn get_frame(&self, frame_idx: usize) -> Option<&MvCache> {
290        self.frames.get(frame_idx)
291    }
292
293    /// Gets mutable MV cache for a reference frame.
294    pub fn get_frame_mut(&mut self, frame_idx: usize) -> Option<&mut MvCache> {
295        self.frames.get_mut(frame_idx)
296    }
297
298    /// Stores MVs from current frame as reference.
299    pub fn store_frame(&mut self, frame_idx: usize, source: &MvCache) {
300        if frame_idx < self.frames.len() {
301            self.frames[frame_idx] = source.clone();
302        }
303    }
304
305    /// Gets co-located MV from reference frame.
306    #[must_use]
307    pub fn get_co_located(
308        &self,
309        frame_idx: usize,
310        mi_row: usize,
311        mi_col: usize,
312    ) -> Option<MvCacheEntry> {
313        let frame = self.frames.get(frame_idx)?;
314        frame.get(mi_row, mi_col, 0).copied()
315    }
316
317    /// Shifts reference frames (for new frame insertion).
318    pub fn shift_frames(&mut self) {
319        if self.frames.len() > 1 {
320            self.frames.rotate_right(1);
321        }
322    }
323}
324
325/// Co-located MV lookup helper.
326#[derive(Clone, Debug, Default)]
327pub struct CoLocatedMvLookup {
328    /// Reference frame MVs.
329    ref_mvs: RefFrameMvs,
330    /// Temporal distance to each reference.
331    temporal_distances: Vec<i32>,
332}
333
334impl CoLocatedMvLookup {
335    /// Creates a new lookup helper.
336    #[must_use]
337    pub fn new() -> Self {
338        Self {
339            ref_mvs: RefFrameMvs::new(),
340            temporal_distances: Vec::new(),
341        }
342    }
343
344    /// Allocates storage.
345    pub fn allocate(&mut self, width: usize, height: usize, num_refs: usize) {
346        self.ref_mvs.allocate(width, height, num_refs);
347        self.temporal_distances = vec![1; num_refs];
348    }
349
350    /// Sets temporal distances.
351    pub fn set_temporal_distances(&mut self, distances: &[i32]) {
352        self.temporal_distances = distances.to_vec();
353    }
354
355    /// Gets co-located MV with temporal scaling.
356    #[must_use]
357    #[allow(clippy::cast_possible_truncation)]
358    pub fn get_scaled_co_located(
359        &self,
360        frame_idx: usize,
361        mi_row: usize,
362        mi_col: usize,
363        target_dist: i32,
364    ) -> Option<MotionVector> {
365        let entry = self.ref_mvs.get_co_located(frame_idx, mi_row, mi_col)?;
366
367        if !entry.valid || !entry.is_inter {
368            return None;
369        }
370
371        let src_dist = self.temporal_distances.get(frame_idx).copied().unwrap_or(1);
372
373        if src_dist == target_dist || src_dist == 0 {
374            return Some(entry.mv);
375        }
376
377        // Scale MV for different temporal distance
378        let scale_x = (i64::from(entry.mv.dx) * i64::from(target_dist)) / i64::from(src_dist);
379        let scale_y = (i64::from(entry.mv.dy) * i64::from(target_dist)) / i64::from(src_dist);
380
381        Some(MotionVector::new(scale_x as i32, scale_y as i32))
382    }
383
384    /// Gets underlying reference MVs.
385    #[must_use]
386    pub fn ref_mvs(&self) -> &RefFrameMvs {
387        &self.ref_mvs
388    }
389
390    /// Gets mutable reference MVs.
391    pub fn ref_mvs_mut(&mut self) -> &mut RefFrameMvs {
392        &mut self.ref_mvs
393    }
394}
395
396/// Search result cache for avoiding redundant searches.
397#[derive(Clone, Debug)]
398pub struct SearchResultCache {
399    /// Cached results.
400    results: Vec<Option<(MotionVector, u32)>>,
401    /// Width in blocks.
402    width: usize,
403    /// Height in blocks.
404    height: usize,
405    /// Block size for this cache.
406    block_size: BlockSize,
407}
408
409impl Default for SearchResultCache {
410    fn default() -> Self {
411        Self::new()
412    }
413}
414
415impl SearchResultCache {
416    /// Creates a new cache.
417    #[must_use]
418    pub const fn new() -> Self {
419        Self {
420            results: Vec::new(),
421            width: 0,
422            height: 0,
423            block_size: BlockSize::Block8x8,
424        }
425    }
426
427    /// Allocates cache for a frame.
428    pub fn allocate(&mut self, frame_width: usize, frame_height: usize, block_size: BlockSize) {
429        self.block_size = block_size;
430        self.width = (frame_width + block_size.width() - 1) / block_size.width();
431        self.height = (frame_height + block_size.height() - 1) / block_size.height();
432        self.results = vec![None; self.width * self.height];
433    }
434
435    /// Clears the cache.
436    pub fn clear(&mut self) {
437        self.results.fill(None);
438    }
439
440    /// Gets a cached result.
441    #[must_use]
442    pub fn get(&self, block_x: usize, block_y: usize) -> Option<(MotionVector, u32)> {
443        let bx = block_x / self.block_size.width();
444        let by = block_y / self.block_size.height();
445
446        if bx >= self.width || by >= self.height {
447            return None;
448        }
449
450        self.results[by * self.width + bx]
451    }
452
453    /// Stores a result.
454    pub fn store(&mut self, block_x: usize, block_y: usize, mv: MotionVector, sad: u32) {
455        let bx = block_x / self.block_size.width();
456        let by = block_y / self.block_size.height();
457
458        if bx < self.width && by < self.height {
459            self.results[by * self.width + bx] = Some((mv, sad));
460        }
461    }
462
463    /// Checks if a result is cached.
464    #[must_use]
465    pub fn has(&self, block_x: usize, block_y: usize) -> bool {
466        self.get(block_x, block_y).is_some()
467    }
468}
469
470/// Combined cache manager.
471#[derive(Clone, Debug, Default)]
472pub struct CacheManager {
473    /// Current frame MV cache.
474    pub current_frame: MvCache,
475    /// Reference frame MVs.
476    pub ref_frames: RefFrameMvs,
477    /// Search result cache.
478    pub search_cache: SearchResultCache,
479    /// Co-located lookup.
480    pub co_located: CoLocatedMvLookup,
481}
482
483impl CacheManager {
484    /// Creates a new cache manager.
485    #[must_use]
486    pub fn new() -> Self {
487        Self {
488            current_frame: MvCache::new(),
489            ref_frames: RefFrameMvs::new(),
490            search_cache: SearchResultCache::new(),
491            co_located: CoLocatedMvLookup::new(),
492        }
493    }
494
495    /// Allocates all caches.
496    pub fn allocate(&mut self, width: usize, height: usize, num_refs: usize) {
497        self.current_frame.allocate(width, height, num_refs);
498        self.ref_frames.allocate(width, height, num_refs);
499        self.search_cache
500            .allocate(width, height, BlockSize::Block8x8);
501        self.co_located.allocate(width, height, num_refs);
502    }
503
504    /// Clears all caches for a new frame.
505    pub fn new_frame(&mut self) {
506        self.current_frame.clear();
507        self.search_cache.clear();
508    }
509
510    /// Stores current frame as reference.
511    pub fn finalize_frame(&mut self, frame_idx: usize) {
512        self.ref_frames.store_frame(frame_idx, &self.current_frame);
513    }
514}
515
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 64x64 MV cache (16x16 MI units) with three reference slots.
    fn cache_64x64() -> MvCache {
        let mut cache = MvCache::new();
        cache.allocate(64, 64, 3);
        cache
    }

    #[test]
    fn test_mv_cache_entry_invalid() {
        let entry = MvCacheEntry::invalid();
        assert!(!entry.valid);
        assert!(!entry.is_inter);
        assert_eq!(entry.ref_idx, -1);
        assert_eq!(entry.sad, u32::MAX);
    }

    #[test]
    fn test_mv_cache_entry_inter() {
        let mv = MotionVector::new(10, 20);
        let entry = MvCacheEntry::inter(mv, 0, 100);

        assert!(entry.valid);
        assert!(entry.is_inter);
        assert_eq!(entry.mv.dx, 10);
        assert_eq!(entry.sad, 100);
    }

    #[test]
    fn test_mv_cache_entry_intra() {
        let entry = MvCacheEntry::intra();
        assert!(entry.valid);
        assert!(!entry.is_inter);
        assert_eq!(entry.ref_idx, -1);
        assert_eq!(entry.sad, 0);
    }

    #[test]
    fn test_mv_cache_allocate() {
        let cache = cache_64x64();
        assert_eq!(cache.mi_cols(), 16); // 64/4
        assert_eq!(cache.mi_rows(), 16);

        // Dimensions that are not multiples of 4 round up.
        let mut odd = MvCache::new();
        odd.allocate(65, 1, 1);
        assert_eq!(odd.mi_cols(), 17);
        assert_eq!(odd.mi_rows(), 1);
    }

    #[test]
    fn test_mv_cache_get_set() {
        let mut cache = cache_64x64();

        let entry = MvCacheEntry::inter(MotionVector::new(10, 20), 0, 100);
        cache.set(5, 5, 0, entry);

        let retrieved = cache.get(5, 5, 0).expect("get should return value");
        assert!(retrieved.valid);
        assert_eq!(retrieved.mv.dx, 10);

        // Out-of-range row, column and reference slot all yield `None`.
        assert!(cache.get(16, 0, 0).is_none());
        assert!(cache.get(0, 16, 0).is_none());
        assert!(cache.get(0, 0, 3).is_none());
    }

    #[test]
    fn test_mv_cache_fill_block() {
        let mut cache = cache_64x64();

        let entry = MvCacheEntry::inter(MotionVector::new(10, 20), 0, 100);
        cache.fill_block(0, 0, BlockSize::Block8x8, 0, entry);

        // 8x8 block = 2x2 MI units
        assert!(cache.get(0, 0, 0).expect("get should return value").valid);
        assert!(cache.get(0, 1, 0).expect("get should return value").valid);
        assert!(cache.get(1, 0, 0).expect("get should return value").valid);
        assert!(cache.get(1, 1, 0).expect("get should return value").valid);
        // The neighbouring position outside the block is untouched.
        assert!(!cache.get(0, 2, 0).expect("get should return value").valid);
    }

    #[test]
    fn test_mv_cache_fill_block_clips_at_frame_edge() {
        let mut cache = cache_64x64();

        let entry = MvCacheEntry::inter(MotionVector::new(3, 4), 0, 7);
        // Only the bottom-right MI unit of this 2x2 block lies inside the frame.
        cache.fill_block(15, 15, BlockSize::Block8x8, 0, entry);

        assert!(cache.get(15, 15, 0).expect("get should return value").valid);
        assert!(cache.get(16, 15, 0).is_none());
        assert!(cache.get(15, 16, 0).is_none());
    }

    #[test]
    fn test_mv_cache_neighbors() {
        let mut cache = cache_64x64();

        // Set some entries
        let left = MvCacheEntry::inter(MotionVector::new(1, 1), 0, 10);
        let top = MvCacheEntry::inter(MotionVector::new(2, 2), 0, 20);

        cache.set(5, 4, 0, left);
        cache.set(4, 5, 0, top);

        let got_left = cache.get_left(5, 5, 0).expect("should succeed");
        assert_eq!(got_left.mv.dx, 1);

        let got_top = cache.get_top(5, 5, 0).expect("should succeed");
        assert_eq!(got_top.mv.dx, 2);
    }

    #[test]
    fn test_mv_cache_neighbors_at_boundaries() {
        let mut cache = cache_64x64();

        assert!(cache.get_left(0, 0, 0).is_none());
        assert!(cache.get_top(0, 0, 0).is_none());
        assert!(cache.get_top_left(0, 5, 0).is_none());
        assert!(cache.get_top_left(5, 0, 0).is_none());
        assert!(cache.get_top_right(0, 0, BlockSize::Block8x8, 0).is_none());
        // Top-right column (14 + 2 = 16) falls outside the 16-column frame.
        assert!(cache.get_top_right(1, 14, BlockSize::Block8x8, 0).is_none());

        let top_right = MvCacheEntry::inter(MotionVector::new(9, 9), 0, 1);
        cache.set(0, 2, 0, top_right);
        let got = cache
            .get_top_right(1, 0, BlockSize::Block8x8, 0)
            .expect("top-right neighbour should exist");
        assert_eq!(got.mv.dx, 9);
    }

    #[test]
    fn test_ref_frame_mvs() {
        let mut ref_mvs = RefFrameMvs::new();
        ref_mvs.allocate(64, 64, 3);

        // Store frame
        let mut cache = MvCache::new();
        cache.allocate(64, 64, 1);
        let entry = MvCacheEntry::inter(MotionVector::new(10, 20), 0, 100);
        cache.set(5, 5, 0, entry);

        ref_mvs.store_frame(0, &cache);

        // Retrieve co-located
        let co_loc = ref_mvs.get_co_located(0, 5, 5).expect("should succeed");
        assert!(co_loc.valid);
        assert_eq!(co_loc.mv.dx, 10);
    }

    #[test]
    fn test_ref_frame_mvs_shift_empty_is_noop() {
        let mut ref_mvs = RefFrameMvs::new();
        ref_mvs.shift_frames();
        assert!(ref_mvs.get_frame(0).is_none());
    }

    #[test]
    fn test_co_located_lookup_scaling() {
        let mut lookup = CoLocatedMvLookup::new();
        lookup.allocate(64, 64, 3);
        lookup.set_temporal_distances(&[1, 2, 4]);

        // Store a co-located MV; fail loudly if the slot was not allocated.
        let frame = lookup
            .ref_mvs_mut()
            .get_frame_mut(0)
            .expect("frame 0 should be allocated");
        frame.set(5, 5, 0, MvCacheEntry::inter(MotionVector::new(100, 200), 0, 50));

        // Get scaled for different target distance
        let mv = lookup
            .get_scaled_co_located(0, 5, 5, 2)
            .expect("co-located MV should exist");
        assert_eq!(mv.dx, 200); // Scaled by 2
        assert_eq!(mv.dy, 400);
    }

    #[test]
    fn test_co_located_lookup_unscaled_intra_and_truncation() {
        let mut lookup = CoLocatedMvLookup::new();
        lookup.allocate(64, 64, 3);
        lookup.set_temporal_distances(&[1, 2, 4]);

        {
            let frames = lookup.ref_mvs_mut();
            frames
                .get_frame_mut(1)
                .expect("frame 1 should be allocated")
                .set(5, 5, 0, MvCacheEntry::inter(MotionVector::new(7, -3), 0, 5));
            frames
                .get_frame_mut(2)
                .expect("frame 2 should be allocated")
                .set(5, 5, 0, MvCacheEntry::inter(MotionVector::new(-6, 6), 0, 5));
            frames
                .get_frame_mut(0)
                .expect("frame 0 should be allocated")
                .set(2, 2, 0, MvCacheEntry::intra());
        }

        // Same distance: returned unscaled.
        let same = lookup.get_scaled_co_located(1, 5, 5, 2).expect("should succeed");
        assert_eq!(same.dx, 7);
        assert_eq!(same.dy, -3);

        // 4 -> 1: integer division truncates toward zero.
        let shrunk = lookup.get_scaled_co_located(2, 5, 5, 1).expect("should succeed");
        assert_eq!(shrunk.dx, -1);
        assert_eq!(shrunk.dy, 1);

        // Intra entries and never-written positions yield nothing.
        assert!(lookup.get_scaled_co_located(0, 2, 2, 2).is_none());
        assert!(lookup.get_scaled_co_located(0, 9, 9, 2).is_none());
    }

    #[test]
    fn test_search_result_cache() {
        let mut cache = SearchResultCache::new();
        cache.allocate(64, 64, BlockSize::Block8x8);

        // Store result
        let mv = MotionVector::new(10, 20);
        cache.store(0, 0, mv, 100);

        // Retrieve
        let (cached_mv, sad) = cache.get(0, 0).expect("result should be cached");
        assert_eq!(cached_mv.dx, 10);
        assert_eq!(sad, 100);

        // Any pixel inside the same 8x8 block maps to the same slot.
        assert!(cache.has(7, 7));
        assert!(!cache.has(8, 0));
    }

    #[test]
    fn test_search_result_cache_ignores_out_of_frame() {
        let mut cache = SearchResultCache::new();
        cache.allocate(64, 64, BlockSize::Block8x8);

        cache.store(64, 0, MotionVector::new(1, 1), 1);
        assert!(!cache.has(64, 0));
        assert!(cache.get(0, 64).is_none());
    }

    #[test]
    fn test_search_result_cache_clear() {
        let mut cache = SearchResultCache::new();
        cache.allocate(64, 64, BlockSize::Block8x8);

        cache.store(0, 0, MotionVector::new(10, 20), 100);
        assert!(cache.has(0, 0));

        cache.clear();
        assert!(!cache.has(0, 0));
    }

    #[test]
    fn test_cache_manager() {
        let mut manager = CacheManager::new();
        manager.allocate(64, 64, 3);

        // Store in current frame
        let entry = MvCacheEntry::inter(MotionVector::new(10, 20), 0, 100);
        manager.current_frame.set(5, 5, 0, entry);

        // Store search result
        manager
            .search_cache
            .store(0, 0, MotionVector::new(5, 10), 50);

        assert!(
            manager
                .current_frame
                .get(5, 5, 0)
                .expect("get should return value")
                .valid
        );
        assert!(manager.search_cache.has(0, 0));
    }

    #[test]
    fn test_cache_manager_new_frame() {
        let mut manager = CacheManager::new();
        manager.allocate(64, 64, 3);

        // Store data
        let entry = MvCacheEntry::inter(MotionVector::new(10, 20), 0, 100);
        manager.current_frame.set(5, 5, 0, entry);
        manager
            .search_cache
            .store(0, 0, MotionVector::new(5, 10), 50);

        // New frame clears caches
        manager.new_frame();

        assert!(
            !manager
                .current_frame
                .get(5, 5, 0)
                .expect("get should return value")
                .valid
        );
        assert!(!manager.search_cache.has(0, 0));
    }
}