zoomvtools 2.0.0

Video motion vector analysis utilities in pure Rust
Documentation
#[cfg(test)]
mod tests;

#[cfg(all(target_arch = "x86_64", feature = "simd"))]
mod avx2;
mod rust;

use std::num::{NonZeroU8, NonZeroUsize};
use std::sync::Arc;

use anyhow::Result;
use rustdct::{DctPlanner, TransformType2And3};
use semisafe::slice::get as semisafe_get;
use semisafe::slice::get_mut as semisafe_get_mut;

use crate::util::Pixel;

/// Reusable state for running a separable 2D DCT-II over fixed-size blocks.
///
/// All plans and working buffers are created once in [`DctHelper::new`] and reused by every
/// subsequent transform, which is why the transform methods take `&mut self`.
pub struct DctHelper {
    /// Number of samples in one row of the block (length of the horizontal DCT).
    size_x: NonZeroUsize,
    /// Number of rows in the block (length of the vertical DCT).
    size_y: NonZeroUsize,
    /// Sample bit depth; only forwarded to the pixel conversion backends.
    bits_per_sample: NonZeroU8,
    /// Smallest shift `s` such that `1 << s >= size_x * size_y`; forwarded to the pixel
    /// conversion backends.
    dct_shift: usize,
    /// `dct_shift + 2`; forwarded to the pixel conversion backends.
    dct_shift0: usize,

    /// DCT-II plan of length `size_x`, applied to each row.
    row_dct: Arc<dyn TransformType2And3<f32>>,
    /// DCT-II plan of length `size_y`, applied to each column.
    col_dct: Arc<dyn TransformType2And3<f32>>,
    /// Scratch space for `row_dct`, sized by the plan's `get_scratch_len()`.
    row_scratch: Box<[f32]>,
    /// Scratch space for `col_dct`, sized by the plan's `get_scratch_len()`.
    col_scratch: Box<[f32]>,
    /// Temporary holding one column (`size_y` values) during the vertical pass.
    column: Box<[f32]>,
    /// Input block converted to `f32`, stored as `size_y` tightly packed rows of `size_x`.
    src: Box<[f32]>,
    /// Transformed coefficients, same layout as `src`.
    src_dct: Box<[f32]>,
}

impl DctHelper {
    #[inline]
    pub fn new(
        size_x: NonZeroUsize,
        size_y: NonZeroUsize,
        bits_per_sample: NonZeroU8,
    ) -> Result<Self> {
        let size_2d = size_y.saturating_mul(size_x);
        let mut cur_size = 1usize;
        let mut dct_shift = 0usize;
        while cur_size < size_2d.get() {
            dct_shift += 1;
            cur_size <<= 1;
        }
        let dct_shift0 = dct_shift + 2;

        let mut planner = DctPlanner::new();
        let row_dct = planner.plan_dct2(size_x.get());
        let col_dct = planner.plan_dct2(size_y.get());
        let src = vec![0.0; size_2d.get()].into_boxed_slice();
        let src_dct = vec![0.0; size_2d.get()].into_boxed_slice();
        let this = DctHelper {
            size_x,
            size_y,
            bits_per_sample,
            dct_shift,
            dct_shift0,
            row_scratch: vec![0.0; row_dct.get_scratch_len()].into_boxed_slice(),
            col_scratch: vec![0.0; col_dct.get_scratch_len()].into_boxed_slice(),
            column: vec![0.0; size_y.get()].into_boxed_slice(),
            row_dct,
            col_dct,
            src,
            src_dct,
        };
        Ok(this)
    }

    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(skip_all, name = "dct::bytes_2d")
    )]
    #[inline]
    pub fn bytes_2d<T: Pixel>(
        &mut self,
        src_plane: &[T],
        src_pitch: NonZeroUsize,
        dct_plane: &mut [T],
        dct_pitch: NonZeroUsize,
    ) -> Result<()> {
        self.pixels_to_float_src(src_plane, src_pitch);
        self.transform_src_to_dct();
        self.float_src_to_pixels(dct_plane, dct_pitch);

        Ok(())
    }

    #[inline]
    fn transform_src_to_dct(&mut self) {
        let size_x = self.size_x.get();

        for (src_row, dst_row) in self
            .src
            .chunks_exact(size_x)
            .zip(self.src_dct.chunks_exact_mut(size_x))
        {
            dst_row.copy_from_slice(src_row);
            self.row_dct
                .process_dct2_with_scratch(dst_row, &mut self.row_scratch);
        }

        for x in 0..size_x {
            for (value, row) in self
                .column
                .iter_mut()
                .zip(self.src_dct.chunks_exact(size_x))
            {
                *value = *semisafe_get(row, x);
            }

            self.col_dct
                .process_dct2_with_scratch(&mut self.column, &mut self.col_scratch);

            for (&value, row) in self
                .column
                .iter()
                .zip(self.src_dct.chunks_exact_mut(size_x))
            {
                *semisafe_get_mut(row, x) = value;
            }
        }

        // COMPAT: C MVTools uses FFTW's 2D REDFT10 entry point, which produces coefficients at
        // 4x the amplitude of rustdct's raw separable DCT-II output for the same input block.
        // Match the historical coefficient amplitude so existing pixel conversion and SAD code
        // keep the same behavior.
        for value in &mut self.src_dct {
            *value *= 4.0;
        }
    }

    #[inline]
    pub fn pixels_to_float_src<T: Pixel>(&mut self, src_plane: &[T], src_pitch: NonZeroUsize) {
        for j in 0..(self.size_y.get()) {
            let f_src = semisafe_get_mut(
                semisafe_get_mut(&mut self.src, j * self.size_x.get()..),
                ..self.size_x.get(),
            );
            let p_src = semisafe_get(
                semisafe_get(src_plane, j * src_pitch.get()..),
                ..self.size_x.get(),
            );
            for (f, p) in f_src.iter_mut().zip(p_src.iter()) {
                *f = p.to_f32().expect("fits in f32");
            }
        }
    }

    #[inline]
    pub fn float_src_to_pixels<T: Pixel>(&self, dst: &mut [T], dst_pitch: NonZeroUsize) {
        #[cfg(all(target_arch = "x86_64", feature = "simd"))]
        if cpudetect::x86_64::is_x86_64_v3_compatible() {
            // PERF: 95% faster than scalar on width >= 8
            // SAFETY: The cpudetect-backed runtime gate above guarantees AVX2 support.
            unsafe {
                avx2::float_src_to_pixels(
                    dst,
                    dst_pitch,
                    &self.src_dct,
                    self.size_x,
                    self.size_y,
                    self.bits_per_sample,
                    self.dct_shift,
                    self.dct_shift0,
                );
            }
            return;
        }

        rust::float_src_to_pixels(
            dst,
            dst_pitch,
            &self.src_dct,
            self.size_x,
            self.size_y,
            self.bits_per_sample,
            self.dct_shift,
            self.dct_shift0,
        );
    }

    #[cfg(feature = "bench")]
    #[inline]
    pub fn src_dct_mut(&mut self) -> &mut [f32] {
        &mut self.src_dct
    }
}