zoomvtools 2.0.0

Video motion vector analysis utilities in pure Rust
Documentation
#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]

#[cfg(test)]
use std::mem::size_of;
use std::{arch::x86_64::*, num::NonZeroUsize};

use crate::degrain::MAX_REFS_SIZE;
#[cfg(test)]
use crate::util::Pixel;
use cpudetect::target_family;
use semisafe::slice::get as semisafe_get;
use semisafe::slice::get_mut as semisafe_get_mut;

/// Weighted average of a `WIDTH`x`HEIGHT` block of 8-bit pixels with `2 * RADIUS` reference
/// blocks, written to `dest`.
///
/// Each output pixel is
/// `clamp((src * w_src + sum_r(refs[r] * w_refs[r]) + 128) >> 8, 0, 255)`;
/// adding 128 before the arithmetic shift makes it a rounded division by 256. The sum is
/// accumulated in wrapping `i32` lanes.
///
/// # Safety
///
/// - The CPU must support AVX2 (the `_mm256_*` intrinsics used here).
/// - `2 * RADIUS` must not exceed `MAX_REFS_SIZE`, and `refs`, `refs_strides_bytes` and `w_refs`
///   must each hold at least `2 * RADIUS` entries.
/// - `WIDTH` must be 2, 4, or a multiple of 8.
/// - `src` and every `refs[r]` must be valid for reads, and `dest` valid for writes, of `WIDTH`
///   bytes on each of `HEIGHT` rows, consecutive rows being their respective strides apart.
///   No alignment is required (all loads/stores are unaligned).
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
pub(crate) unsafe fn degrain_u8<const RADIUS: usize, const WIDTH: usize, const HEIGHT: usize>(
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    refs: &[*const u8],
    refs_strides_bytes: &[NonZeroUsize],
    w_src: i32,
    w_refs: &[i32],
) {
    // Blocks of 8+ pixels are processed in 8-pixel chunks with full 8-byte loads and stores, so a
    // width that is not a multiple of 8 would touch memory past the end of every row.
    debug_assert!(
        WIDTH % 8 == 0 || WIDTH == 4 || WIDTH == 2,
        "unsupported degrain block width"
    );

    // Local copies of the reference row pointers and strides, advanced row by row below.
    let mut refs_rows = [std::ptr::null(); MAX_REFS_SIZE];
    let mut refs_strides = [0_usize; MAX_REFS_SIZE];
    for r in 0..(RADIUS * 2) {
        *semisafe_get_mut(&mut refs_rows, r) = *semisafe_get(refs, r);
        *semisafe_get_mut(&mut refs_strides, r) = semisafe_get(refs_strides_bytes, r).get();
    }

    let mut src_row = src;
    let mut dest_row = dest;
    // Rounding bias for the final `>> 8`.
    let bias = _mm256_set1_epi32(128);
    let w_src_vec = _mm256_set1_epi32(w_src);

    for _y in 0..HEIGHT {
        if WIDTH >= 8 {
            for x in (0..WIDTH).step_by(8) {
                // Widen 8 source bytes to eight i32 lanes and seed the weighted sum.
                let src_8 = _mm_loadl_epi64(src_row.add(x).cast());
                let src_i32 = _mm256_cvtepu8_epi32(src_8);
                let mut sum = _mm256_add_epi32(bias, _mm256_mullo_epi32(src_i32, w_src_vec));

                for r in 0..(RADIUS * 2) {
                    let ref_row = *semisafe_get(&refs_rows, r);
                    let ref_8 = _mm_loadl_epi64(ref_row.add(x).cast());
                    let ref_i32 = _mm256_cvtepu8_epi32(ref_8);
                    let w_ref_vec = _mm256_set1_epi32(*semisafe_get(w_refs, r));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(ref_i32, w_ref_vec));
                }

                let shifted = _mm256_srai_epi32(sum, 8);
                // `packus_epi32` packs within each 128-bit lane, leaving the eight u16 results in
                // qwords 0 and 2; the 0xd8 permute (qword order 0, 2, 1, 3) gathers them into the
                // low lane before the final saturating narrow to bytes.
                let packed16 = _mm256_packus_epi32(shifted, _mm256_setzero_si256());
                let packed16 = _mm256_permute4x64_epi64(packed16, 0xd8);
                let packed8 = _mm256_packus_epi16(packed16, _mm256_setzero_si256());
                _mm_storel_epi64(
                    dest_row.add(x) as *mut __m128i,
                    _mm256_castsi256_si128(packed8),
                );
            }
        } else {
            // Narrow blocks (4 or 2 pixels) fit in the low 32 bits of one SSE register.
            let src_4 = match WIDTH {
                4 => _mm_cvtsi32_si128((src_row as *const u32).read_unaligned() as i32),
                2 => _mm_cvtsi32_si128((src_row as *const u16).read_unaligned() as i32),
                _ => unreachable!(),
            };
            let src_i32 = _mm_cvtepu8_epi32(src_4);
            let mut sum = _mm_add_epi32(
                _mm_set1_epi32(128),
                _mm_mullo_epi32(src_i32, _mm_set1_epi32(w_src)),
            );

            for r in 0..(RADIUS * 2) {
                let ref_row = *semisafe_get(&refs_rows, r);
                let ref_4 = match WIDTH {
                    4 => _mm_cvtsi32_si128((ref_row as *const u32).read_unaligned() as i32),
                    2 => _mm_cvtsi32_si128((ref_row as *const u16).read_unaligned() as i32),
                    _ => unreachable!(),
                };
                let ref_i32 = _mm_cvtepu8_epi32(ref_4);
                let w_ref_vec = _mm_set1_epi32(*semisafe_get(w_refs, r));
                sum = _mm_add_epi32(sum, _mm_mullo_epi32(ref_i32, w_ref_vec));
            }

            let shifted = _mm_srai_epi32(sum, 8);
            let packed16 = _mm_packus_epi32(shifted, _mm_setzero_si128());
            let packed8 = _mm_packus_epi16(packed16, _mm_setzero_si128());
            let out = _mm_cvtsi128_si32(packed8) as u32;

            match WIDTH {
                4 => (dest_row as *mut u32).write_unaligned(out),
                2 => (dest_row as *mut u16).write_unaligned(out as u16),
                _ => unreachable!(),
            }
        }

        // Step every row pointer to the next row. `wrapping_add` is deliberate: after the last
        // row the advanced pointers may lie beyond the end of their buffers. They are never
        // dereferenced there, but computing them with `add` would be undefined behavior.
        dest_row = dest_row.wrapping_add(dest_stride_bytes.get());
        src_row = src_row.wrapping_add(src_stride_bytes.get());
        for r in 0..(RADIUS * 2) {
            let ref_row = semisafe_get_mut(&mut refs_rows, r);
            *ref_row = ref_row.wrapping_add(*semisafe_get(&refs_strides, r));
        }
    }
}

/// Weighted average of a `WIDTH`x`HEIGHT` block of 16-bit pixels with `2 * RADIUS` reference
/// blocks, written to `dest`.
///
/// Each output pixel is
/// `clamp((src * w_src + sum_r(refs[r] * w_refs[r]) + 128) >> 8, 0, 65535)`;
/// adding 128 before the arithmetic shift makes it a rounded division by 256. The sum is
/// accumulated in wrapping `i32` lanes, so the weights must be small enough not to overflow it.
///
/// # Safety
///
/// - The CPU must support AVX2 (the `_mm256_*` intrinsics used here).
/// - `2 * RADIUS` must not exceed `MAX_REFS_SIZE`, and `refs`, `refs_strides_bytes` and `w_refs`
///   must each hold at least `2 * RADIUS` entries.
/// - `WIDTH` must be 2, 4, or a multiple of 8.
/// - `src` and every `refs[r]` must be valid for reads, and `dest` valid for writes, of
///   `WIDTH * 2` bytes on each of `HEIGHT` rows, consecutive rows being their respective strides
///   (in bytes) apart. No alignment is required (all loads/stores are unaligned).
#[target_family("x86_64_v3")]
// NOTE: Custom implementation. No C implementation exists.
pub(crate) unsafe fn degrain_u16<const RADIUS: usize, const WIDTH: usize, const HEIGHT: usize>(
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    refs: &[*const u8],
    refs_strides_bytes: &[NonZeroUsize],
    w_src: i32,
    w_refs: &[i32],
) {
    // Blocks of 8+ pixels are processed in 8-pixel (16-byte) chunks with full loads and stores,
    // so a width that is not a multiple of 8 would touch memory past the end of every row.
    debug_assert!(
        WIDTH % 8 == 0 || WIDTH == 4 || WIDTH == 2,
        "unsupported degrain block width"
    );

    // Local copies of the reference row pointers and strides, advanced row by row below.
    let mut refs_rows = [std::ptr::null(); MAX_REFS_SIZE];
    let mut refs_strides = [0_usize; MAX_REFS_SIZE];
    for r in 0..(RADIUS * 2) {
        *semisafe_get_mut(&mut refs_rows, r) = *semisafe_get(refs, r);
        *semisafe_get_mut(&mut refs_strides, r) = semisafe_get(refs_strides_bytes, r).get();
    }

    let mut src_row = src;
    let mut dest_row = dest;
    // Rounding bias for the final `>> 8`.
    let bias = _mm256_set1_epi32(128);
    let w_src_vec = _mm256_set1_epi32(w_src);

    for _y in 0..HEIGHT {
        if WIDTH >= 8 {
            for x in (0..WIDTH).step_by(8) {
                // Byte offset of pixel `x` within the row.
                let offset = x * size_of::<u16>();
                let src_8 = _mm_loadu_si128(src_row.add(offset).cast());
                let src_i32 = _mm256_cvtepu16_epi32(src_8);
                let mut sum = _mm256_add_epi32(bias, _mm256_mullo_epi32(src_i32, w_src_vec));

                for r in 0..(RADIUS * 2) {
                    let ref_row = *semisafe_get(&refs_rows, r);
                    let ref_8 = _mm_loadu_si128(ref_row.add(offset).cast());
                    let ref_i32 = _mm256_cvtepu16_epi32(ref_8);
                    let w_ref_vec = _mm256_set1_epi32(*semisafe_get(w_refs, r));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(ref_i32, w_ref_vec));
                }

                let shifted = _mm256_srai_epi32(sum, 8);
                // `packus_epi32` packs within each 128-bit lane, leaving the eight u16 results in
                // qwords 0 and 2; the 0xd8 permute (qword order 0, 2, 1, 3) gathers them into the
                // low lane, which is then stored.
                let packed = _mm256_packus_epi32(shifted, _mm256_setzero_si256());
                let packed = _mm256_permute4x64_epi64(packed, 0xd8);
                _mm_storeu_si128(
                    dest_row.add(offset) as *mut __m128i,
                    _mm256_castsi256_si128(packed),
                );
            }
        } else {
            // Narrow blocks (4 or 2 pixels, i.e. 8 or 4 bytes) fit in one SSE register.
            let src_4 = match WIDTH {
                4 => _mm_loadl_epi64(src_row.cast()),
                2 => _mm_cvtsi32_si128((src_row as *const u32).read_unaligned() as i32),
                _ => unreachable!(),
            };
            let src_i32 = _mm_cvtepu16_epi32(src_4);
            let mut sum = _mm_add_epi32(
                _mm_set1_epi32(128),
                _mm_mullo_epi32(src_i32, _mm_set1_epi32(w_src)),
            );

            for r in 0..(RADIUS * 2) {
                let ref_row = *semisafe_get(&refs_rows, r);
                let ref_4 = match WIDTH {
                    4 => _mm_loadl_epi64(ref_row.cast()),
                    2 => _mm_cvtsi32_si128((ref_row as *const u32).read_unaligned() as i32),
                    _ => unreachable!(),
                };
                let ref_i32 = _mm_cvtepu16_epi32(ref_4);
                let w_ref_vec = _mm_set1_epi32(*semisafe_get(w_refs, r));
                sum = _mm_add_epi32(sum, _mm_mullo_epi32(ref_i32, w_ref_vec));
            }

            let shifted = _mm_srai_epi32(sum, 8);
            let packed = _mm_packus_epi32(shifted, _mm_setzero_si128());
            match WIDTH {
                4 => (dest_row as *mut u64).write_unaligned(_mm_cvtsi128_si64(packed) as u64),
                2 => (dest_row as *mut u32).write_unaligned(_mm_cvtsi128_si32(packed) as u32),
                _ => unreachable!(),
            }
        }

        // Step every row pointer to the next row. `wrapping_add` is deliberate: after the last
        // row the advanced pointers may lie beyond the end of their buffers. They are never
        // dereferenced there, but computing them with `add` would be undefined behavior.
        dest_row = dest_row.wrapping_add(dest_stride_bytes.get());
        src_row = src_row.wrapping_add(src_stride_bytes.get());
        for r in 0..(RADIUS * 2) {
            let ref_row = semisafe_get_mut(&mut refs_rows, r);
            *ref_row = ref_row.wrapping_add(*semisafe_get(&refs_strides, r));
        }
    }
}

/// Test helper that runs the AVX2 degrain kernel chosen by `select_degrain_avx2` for
/// `width`x`height` and radius `refs.len() / 2`, on slice-backed buffers.
///
/// Strides are given in pixels and converted to bytes here. `dest` is written using the same
/// stride as `src` (`src_stride_pixels`).
///
/// # Safety
///
/// The CPU must support AVX2, and `src`, `dest` and every entry of `refs` must be large enough to
/// hold a `width`x`height` block at their respective strides; buffer sizes are not checked here.
///
/// # Panics
///
/// Panics if `refs.len() / 2` is not in `1..=6`, or if `refs_strides_pixels` or `w_refs` have
/// fewer entries than the references actually used.
#[cfg(test)]
#[target_family("x86_64_v3")]
pub(super) unsafe fn degrain_test<T: Pixel>(
    dest: &mut Vec<T>,
    width: NonZeroUsize,
    height: NonZeroUsize,
    src: &[T],
    src_stride_pixels: NonZeroUsize,
    refs: &[&[T]],
    refs_strides_pixels: &[NonZeroUsize],
    w_src: i32,
    w_refs: &[i32],
) {
    let radius = refs.len() / 2;

    // The kernel reads one stride and one weight per used reference without bounds checks.
    assert!(
        refs_strides_pixels.len() >= 2 * radius && w_refs.len() >= 2 * radius,
        "every used reference needs a stride and a weight"
    );

    // A pixel type always occupies at least one byte.
    let pixel_bytes =
        NonZeroUsize::new(size_of::<T>()).expect("pixel type must not be zero-sized");
    // Non-zero times non-zero (saturating) stays non-zero, so no `unsafe` is needed.
    let stride_bytes = src_stride_pixels.saturating_mul(pixel_bytes);

    let refs_ptrs = refs
        .iter()
        .map(|ref_| ref_.as_ptr().cast())
        .collect::<Box<[_]>>();
    let refs_strides_bytes = refs_strides_pixels
        .iter()
        .map(|stride| stride.saturating_mul(pixel_bytes))
        .collect::<Box<[_]>>();

    let func = match radius {
        1 => super::select_degrain_avx2::<T, 1>(width, height),
        2 => super::select_degrain_avx2::<T, 2>(width, height),
        3 => super::select_degrain_avx2::<T, 3>(width, height),
        4 => super::select_degrain_avx2::<T, 4>(width, height),
        5 => super::select_degrain_avx2::<T, 5>(width, height),
        6 => super::select_degrain_avx2::<T, 6>(width, height),
        _ => unreachable!("unsupported degrain radius"),
    };

    // SAFETY: pointers and strides are derived from caller-provided slices (whose sizes the caller
    // guarantees), the stride/weight slices were checked above, and the function was selected for
    // the requested radius/width/height with AVX2 enabled.
    unsafe {
        func(
            dest.as_mut_ptr().cast(),
            stride_bytes,
            src.as_ptr().cast(),
            stride_bytes,
            &refs_ptrs,
            &refs_strides_bytes,
            w_src,
            w_refs,
        );
    }
}