// zoomvtools 2.0.0 — video motion vector analysis utilities in pure Rust.
//
// AVX2 (x86-64-v3) kernels for the `SimpleResize` two-pass resizer.
use std::{arch::x86_64::*, mem::MaybeUninit, num::NonZeroUsize};

use crate::resize::{
    SIMPLE_RESIZE_WEIGHT_HALF, SIMPLE_RESIZE_WEIGHT_MAX, SIMPLE_RESIZE_WEIGHT_SHIFT, SimpleResize,
};
use cpudetect::target_family;
use semisafe::slice::get as semisafe_get;

/// Resizes an 8-bit plane with the offsets and weights precomputed in [`SimpleResize`],
/// using AVX2 (x86-64-v3).
///
/// Every destination row is produced in two passes:
///
/// 1. **Vertical:** the source row `vertical_offsets[y]` and the row below it are blended with
///    `vertical_weights[y]` into a row buffer of `src_width` bytes, 4 pixels per kernel call.
/// 2. **Horizontal:** pairs of neighbouring buffer samples are gathered at the offsets stored in
///    `horizontal_offsets_avx2` and blended with `horizontal_weights_avx2`, 8 destination pixels
///    per kernel call.
///
/// When a width is not a multiple of the chunk size, one extra chunk aligned to the end of the
/// row is processed, overlapping pixels that were already written. Widths too small for a single
/// chunk (`src_width < 4` or `dest_width < 8`) are delegated to the scalar implementation.
///
/// `_horizontal_vectors` is ignored; the scalar fallback is always called with `false`.
///
/// # Safety
///
/// - The CPU must support the x86-64-v3 feature level (AVX2).
/// - For every destination row `y`, `src` must be readable for `src_width` bytes at source rows
///   `vertical_offsets[y]` and `vertical_offsets[y] + 1` (rows are `src_stride_bytes` apart).
/// - `dest` must be writable for `dest_height` rows of `dest_width` bytes, `dest_stride_bytes`
///   apart.
/// - The offset/weight tables in `resizer` must be consistent with its widths. This is
///   presumably guaranteed by whoever builds the `SimpleResize`; it is not checked here.
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
pub(crate) unsafe fn simple_resize_u8(
    resizer: &SimpleResize,
    mut dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    _horizontal_vectors: bool,
) {
    let src_width = resizer.src_width.get();
    let dest_width = resizer.dest_width.get();

    if src_width < 4 || dest_width < 8 {
        // SAFETY: same preconditions as this AVX2 entry point; scalar fallback uses identical API.
        return unsafe {
            super::rust::simple_resize::<u8>(
                resizer,
                dest,
                dest_stride_bytes,
                src,
                src_stride_bytes,
                false,
            );
        };
    }

    // Row buffer of vertically blended samples, addressed as raw bytes by the kernels.
    // Extra room past `src_width` is reserved because `vpgatherdd` loads a whole dword at each
    // horizontal offset (the original allowance: two additional bytes). Only the first
    // `src_width` bytes are ever written by the vertical pass, so the buffer is zero-initialised
    // rather than `set_len`-ed over uninitialised memory. This keeps the gathered pad bytes
    // defined (they are then discarded by `shuffle_mask`). The allocation happens once per call,
    // so the memset cost is negligible.
    let mut workp = vec![0i16; src_width + 2];

    // Keeps bytes 0 and 1 of every gathered dword (the two neighbouring samples), each
    // zero-extended to 16 bits; bytes 2 and 3 are dropped (index -0x80 writes zero).
    #[rustfmt::skip]
    let shuffle_mask = _mm256_set_epi8(
        -0x80, 13, -0x80, 12, -0x80, 9, -0x80, 8, -0x80, 5, -0x80, 4, -0x80, 1, -0x80, 0,
        -0x80, 13, -0x80, 12, -0x80, 9, -0x80, 8, -0x80, 5, -0x80, 4, -0x80, 1, -0x80, 0,
    );

    // Chunk sizes of the two kernels and the widths they cover without a tail chunk.
    // These do not depend on the row, so compute them once.
    let vertical_chunk = 4usize;
    let horizontal_chunk = 8usize;
    let src_width_avx2 = src_width & !(vertical_chunk - 1);
    let dest_width_avx2 = dest_width & !(horizontal_chunk - 1);

    for y in 0..resizer.dest_height.get() {
        let weight_bottom = *semisafe_get(&resizer.vertical_weights, y);
        let weight_top = SIMPLE_RESIZE_WEIGHT_MAX - weight_bottom;

        let srcp1 = src.add(*semisafe_get(&resizer.vertical_offsets, y) * src_stride_bytes.get());
        let srcp2 = srcp1.add(src_stride_bytes.get());

        // Low 16 bits weight the bottom row, high 16 bits the top row, matching the
        // (bottom, top) interleave built by the vertical kernel for `pmaddwd`.
        let dwords_weights_v = _mm_set1_epi32((weight_top << 16) | weight_bottom);

        // vertical
        for x in (0..src_width_avx2).step_by(vertical_chunk) {
            simple_resize_u8_vertical_4px(
                workp.as_mut_ptr().cast(),
                srcp1,
                srcp2,
                x,
                dwords_weights_v,
            );
        }

        if src_width_avx2 < src_width {
            // Tail: one chunk aligned to the end of the row (overlaps the previous chunk).
            simple_resize_u8_vertical_4px(
                workp.as_mut_ptr().cast(),
                srcp1,
                srcp2,
                src_width - vertical_chunk,
                dwords_weights_v,
            );
        }

        // horizontal
        for x in (0..dest_width_avx2).step_by(horizontal_chunk) {
            simple_resize_u8_horizontal_8px_avx2(
                resizer,
                dest,
                workp.as_ptr().cast(),
                x,
                shuffle_mask,
            );
        }

        if dest_width_avx2 < dest_width {
            // Tail: one chunk aligned to the end of the row (overlaps the previous chunk).
            simple_resize_u8_horizontal_8px_avx2(
                resizer,
                dest,
                workp.as_ptr().cast(),
                dest_width - horizontal_chunk,
                shuffle_mask,
            );
        }

        dest = dest.add(dest_stride_bytes.get());
    }
}

/// Blends 4 consecutive 8-bit pixels from two source rows and stores them in `workp`.
///
/// Reads 4 bytes at offset `x` from `srcp1` (top row) and `srcp2` (bottom row). It interleaves
/// them into 16-bit `(bottom, top)` pairs and multiply-adds each pair against `dwords_weights`.
/// The result is rounded by adding `SIMPLE_RESIZE_WEIGHT_HALF`, shifted right by
/// `SIMPLE_RESIZE_WEIGHT_SHIFT`, and the 4 saturated bytes are written to `workp + x`.
///
/// # Safety
///
/// `srcp1` and `srcp2` must be readable, and `workp` writable, for 4 bytes starting at offset
/// `x`; no alignment is required. The CPU must support x86-64-v3.
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
unsafe fn simple_resize_u8_vertical_4px(
    workp: *mut u8,
    srcp1: *const u8,
    srcp2: *const u8,
    x: usize,
    dwords_weights: __m128i,
) {
    let top_bytes = (srcp1.add(x) as *const i32).read_unaligned();
    let bottom_bytes = (srcp2.add(x) as *const i32).read_unaligned();

    // Byte pairs (bottom, top), then zero-extended into 16-bit lanes for pmaddwd.
    let byte_pairs = _mm_unpacklo_epi8(_mm_cvtsi32_si128(bottom_bytes), _mm_cvtsi32_si128(top_bytes));
    let word_pairs = _mm_unpacklo_epi8(byte_pairs, _mm_setzero_si128());

    let weighted = _mm_madd_epi16(word_pairs, dwords_weights);
    let rounded = _mm_add_epi32(weighted, _mm_set1_epi32(SIMPLE_RESIZE_WEIGHT_HALF));
    let scaled = _mm_srli_epi32(rounded, SIMPLE_RESIZE_WEIGHT_SHIFT);

    // 4 x i32 -> 4 x i16 -> 4 x u8 (saturating), all in the low 32 bits.
    let words = _mm_packs_epi32(scaled, scaled);
    let bytes = _mm_packus_epi16(words, words);

    (workp.add(x) as *mut i32).write_unaligned(_mm_cvtsi128_si32(bytes));
}

/// Produces 8 destination pixels of an 8-bit row from the vertically blended row buffer.
///
/// Loads the weights and offsets for destination pixels `x..x + 8` from
/// `horizontal_weights_avx2` / `horizontal_offsets_avx2` (starting at index `x`). It gathers one
/// dword per pixel from `workp` at those byte offsets and keeps its two low bytes (via
/// `shuffle_mask`) as 16-bit samples. The sample pair is blended with its weights, rounded, and
/// the 8 saturated bytes are stored at `dest + x`.
///
/// # Safety
///
/// - The CPU must support x86-64-v3.
/// - Both `resizer` tables must hold at least `x + 8` entries.
/// - Each 4-byte gather at `workp + offset` must stay inside the row buffer.
/// - `dest` must be writable for 8 bytes at offset `x`.
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
unsafe fn simple_resize_u8_horizontal_8px_avx2(
    resizer: &SimpleResize,
    dest: *mut u8,
    workp: *const u8,
    x: usize,
    shuffle_mask: __m256i,
) {
    let weights = _mm256_loadu_si256(
        semisafe_get(&resizer.horizontal_weights_avx2, x) as *const _ as *const _
    );
    let offsets = _mm256_loadu_si256(
        semisafe_get(&resizer.horizontal_offsets_avx2, x) as *const _ as *const _
    );

    // Offsets are byte indices into `workp` (scale = size_of::<u8>()).
    let gathered = _mm256_i32gather_epi32::<1>(workp.cast(), offsets);
    let sample_pairs = _mm256_shuffle_epi8(gathered, shuffle_mask);

    let blended = _mm256_srai_epi32(
        _mm256_add_epi32(
            _mm256_madd_epi16(sample_pairs, weights),
            _mm256_set1_epi32(SIMPLE_RESIZE_WEIGHT_HALF),
        ),
        SIMPLE_RESIZE_WEIGHT_SHIFT,
    );

    // 8 x i32 -> 8 x u8: pack to i16 inside each 128-bit lane, move qwords 0 and 2 into the
    // low lane, then pack to bytes with unsigned saturation.
    let words = _mm256_packs_epi32(blended, blended);
    let compacted = _mm256_permute4x64_epi64(words, 0b11101000);
    let bytes = _mm256_packus_epi16(compacted, compacted);

    _mm_storel_epi64(dest.add(x).cast(), _mm256_castsi256_si128(bytes));
}

/// Resizes a signed 16-bit plane with the offsets and weights precomputed in
/// [`SimpleResize`], using AVX2 (x86-64-v3), and clamps every output value.
///
/// Each destination row is produced by a vertical pass into a row buffer of `src_width`
/// `i16`s, then a horizontal pass that writes `dest_width` `i16`s, 8 elements per kernel call.
/// Widths that are not multiples of 8 are finished with one extra chunk aligned to the end of
/// the row. If either width is below 8, the scalar implementation is used instead.
///
/// Clamping, for destination column `c` and row `r`:
/// - `horizontal_vectors == true`: values are limited to
///   `[-c * pel, (limit_width - c) * pel - 1]`;
/// - otherwise: values are limited to `[-r * pel, (limit_height - r) * pel - 1]`.
///
/// NOTE(review): presumably this keeps motion vectors pointing inside the
/// `limit_width` x `limit_height` area — confirm against the scalar implementation.
///
/// # Safety
///
/// - The CPU must support x86-64-v3.
/// - `src` must be readable, for every destination row `y`, for `src_width` `i16`s at source rows
///   `vertical_offsets[y]` and `vertical_offsets[y] + 1`.
/// - `dest` must be writable for `dest_height` rows of `dest_width` `i16`s.
/// - Strides are given in bytes and divided by `size_of::<i16>()` (integer division).
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
pub(crate) unsafe fn simple_resize_i16(
    resizer: &SimpleResize,
    dest: *mut u8,
    dest_stride_bytes: NonZeroUsize,
    src: *const u8,
    src_stride_bytes: NonZeroUsize,
    horizontal_vectors: bool,
) {
    let lanes = 8usize;
    let src_width = resizer.src_width.get();
    let dest_width = resizer.dest_width.get();

    if src_width < lanes || dest_width < lanes {
        // SAFETY: same preconditions as this AVX2 entry point; scalar fallback uses identical API.
        return unsafe {
            super::rust::simple_resize::<i16>(
                resizer,
                dest,
                dest_stride_bytes,
                src,
                src_stride_bytes,
                horizontal_vectors,
            );
        };
    }

    // The kernels index in i16 elements, not bytes.
    let src_stride = src_stride_bytes.get() / size_of::<i16>();
    let dest_stride = dest_stride_bytes.get() / size_of::<i16>();
    let src = src.cast::<i16>();
    let mut dest = dest.cast::<i16>();

    let mut workp: Vec<MaybeUninit<i16>> = Vec::with_capacity(src_width);
    // SAFETY: we write to the work buffer before reading from it
    unsafe { workp.set_len(src_width) };

    let pel = resizer.pel as i32;
    let limit_height = resizer.limit_height.get() as i32;
    let limit_width = resizer.limit_width.get() as i32;

    // Vertical-vector bounds for row 0 (identical in every lane).
    let mut minimum = _mm256_setzero_si256();
    let mut maximum = _mm256_set1_epi32(limit_height * pel - 1);

    // How far the bounds move per 8-column chunk and per row.
    let (column_step, row_step) = if horizontal_vectors {
        (_mm256_set1_epi32(pel * lanes as i32), _mm256_setzero_si256())
    } else {
        (_mm256_setzero_si256(), _mm256_set1_epi32(pel))
    };

    // Horizontal-vector bounds for the first chunk: lane k belongs to destination column k.
    // (`setr` lists lanes in ascending order.)
    let first_chunk_min = _mm256_setr_epi32(
        0,
        -pel,
        -2 * pel,
        -3 * pel,
        -4 * pel,
        -5 * pel,
        -6 * pel,
        -7 * pel,
    );
    let first_chunk_max = _mm256_setr_epi32(
        limit_width * pel - 1,
        (limit_width - 1) * pel - 1,
        (limit_width - 2) * pel - 1,
        (limit_width - 3) * pel - 1,
        (limit_width - 4) * pel - 1,
        (limit_width - 5) * pel - 1,
        (limit_width - 6) * pel - 1,
        (limit_width - 7) * pel - 1,
    );

    // Row-independent chunking.
    let src_full = src_width & !(lanes - 1);
    let dest_full = dest_width & !(lanes - 1);
    let dest_tail = dest_width - dest_full;

    for y in 0..resizer.dest_height.get() {
        let weight_bottom = *semisafe_get(&resizer.vertical_weights, y);
        let weight_top = SIMPLE_RESIZE_WEIGHT_MAX - weight_bottom;

        let row_top = src.add(*semisafe_get(&resizer.vertical_offsets, y) * src_stride);
        let row_bottom = row_top.add(src_stride);

        // Low 16 bits weight the bottom row, high 16 bits the top row.
        let weights_v = _mm_set1_epi32((weight_top << 16) | weight_bottom);

        // vertical
        for x in (0..src_full).step_by(lanes) {
            simple_resize_i16_vertical_8px(
                workp.as_mut_ptr().cast(),
                row_top,
                row_bottom,
                x,
                weights_v,
            );
        }
        if src_full != src_width {
            simple_resize_i16_vertical_8px(
                workp.as_mut_ptr().cast(),
                row_top,
                row_bottom,
                src_width - lanes,
                weights_v,
            );
        }

        // Horizontal-vector bounds restart at column 0 on every row.
        if horizontal_vectors {
            minimum = first_chunk_min;
            maximum = first_chunk_max;
        }

        // horizontal
        for x in (0..dest_full).step_by(lanes) {
            simple_resize_i16_horizontal_8px(
                resizer,
                dest,
                workp.as_ptr().cast(),
                x,
                &mut minimum,
                &mut maximum,
                column_step,
            );
        }
        if dest_tail != 0 {
            if horizontal_vectors {
                // The tail chunk starts `lanes - dest_tail` columns before `dest_full`, so move
                // the bounds back by that many columns.
                let step_back = _mm256_set1_epi32((lanes as i32 - dest_tail as i32) * pel);
                minimum = _mm256_add_epi32(minimum, step_back);
                maximum = _mm256_add_epi32(maximum, step_back);
            }

            simple_resize_i16_horizontal_8px(
                resizer,
                dest,
                workp.as_ptr().cast(),
                dest_width - lanes,
                &mut minimum,
                &mut maximum,
                column_step,
            );
        }

        dest = dest.add(dest_stride);

        minimum = _mm256_sub_epi32(minimum, row_step);
        maximum = _mm256_sub_epi32(maximum, row_step);
    }
}

/// Blends 8 consecutive `i16` samples from two source rows and stores them in `workp`.
///
/// Samples at element offset `x` of `srcp1` (top) and `srcp2` (bottom) are interleaved into
/// `(bottom, top)` pairs, multiply-added with `dwords_weights`, rounded by
/// `SIMPLE_RESIZE_WEIGHT_HALF`, arithmetic-shifted by `SIMPLE_RESIZE_WEIGHT_SHIFT`, and
/// packed back to `i16` with signed saturation at `workp + x`.
///
/// # Safety
///
/// `srcp1`/`srcp2` must be readable, and `workp` writable, for 8 `i16`s starting at element
/// offset `x`; no alignment is required. The CPU must support x86-64-v3.
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
unsafe fn simple_resize_i16_vertical_8px(
    workp: *mut i16,
    srcp1: *const i16,
    srcp2: *const i16,
    x: usize,
    dwords_weights: __m128i,
) {
    let rounding = _mm_set1_epi32(SIMPLE_RESIZE_WEIGHT_HALF);

    let upper = _mm_loadu_si128(srcp1.add(x).cast());
    let lower = _mm_loadu_si128(srcp2.add(x).cast());

    // Samples 0..4, as (bottom, top) pairs.
    let first_half = _mm_srai_epi32(
        _mm_add_epi32(_mm_madd_epi16(_mm_unpacklo_epi16(lower, upper), dwords_weights), rounding),
        SIMPLE_RESIZE_WEIGHT_SHIFT,
    );
    // Samples 4..8, as (bottom, top) pairs.
    let second_half = _mm_srai_epi32(
        _mm_add_epi32(_mm_madd_epi16(_mm_unpackhi_epi16(lower, upper), dwords_weights), rounding),
        SIMPLE_RESIZE_WEIGHT_SHIFT,
    );

    _mm_storeu_si128(workp.add(x).cast(), _mm_packs_epi32(first_half, second_half));
}

/// Produces 8 clamped `i16` destination values from the vertically blended row buffer.
///
/// Loads weights and offsets for destination elements `x..x + 8` starting at index `x` of
/// `horizontal_weights_avx2` / `horizontal_offsets_avx2`. It gathers one dword (two adjacent
/// `i16`s) per element from `workp` and blends the pair with its weights, then rounds and
/// shifts. The result is clamped lane-wise to `[*minimum, *maximum]` and the 8 values are
/// stored at `dest + x`. Afterwards both bounds are decreased by `horizontal_step`, ready for
/// the next chunk.
///
/// # Safety
///
/// - The CPU must support x86-64-v3.
/// - Both `resizer` tables must hold at least `x + 8` entries.
/// - Each gather at element offset `offset` must keep `offset` and `offset + 1` inside `workp`.
/// - `dest` must be writable for 8 `i16`s at offset `x`.
#[target_family("x86_64_v3")]
// NOTE: Ported from C implementation
unsafe fn simple_resize_i16_horizontal_8px(
    resizer: &SimpleResize,
    dest: *mut i16,
    workp: *const i16,
    x: usize,
    minimum: &mut __m256i,
    maximum: &mut __m256i,
    horizontal_step: __m256i,
) {
    let weights = _mm256_loadu_si256(
        semisafe_get(&resizer.horizontal_weights_avx2, x) as *const _ as *const _
    );
    let offsets = _mm256_loadu_si256(
        semisafe_get(&resizer.horizontal_offsets_avx2, x) as *const _ as *const _
    );

    // Offsets are i16 element indices (scale = size_of::<i16>()).
    let sample_pairs = _mm256_i32gather_epi32::<2>(workp.cast(), offsets);
    let blended = _mm256_srai_epi32(
        _mm256_add_epi32(
            _mm256_madd_epi16(sample_pairs, weights),
            _mm256_set1_epi32(SIMPLE_RESIZE_WEIGHT_HALF),
        ),
        SIMPLE_RESIZE_WEIGHT_SHIFT,
    );

    let (lower_bound, upper_bound) = (*minimum, *maximum);
    let clamped = _mm256_max_epi32(lower_bound, _mm256_min_epi32(blended, upper_bound));

    // Move the bounds on to the next chunk.
    *minimum = _mm256_sub_epi32(lower_bound, horizontal_step);
    *maximum = _mm256_sub_epi32(upper_bound, horizontal_step);

    // 8 x i32 -> 8 x i16: pack within each 128-bit lane, then gather qwords 0 and 2 into the
    // low lane.
    let packed = _mm256_permute4x64_epi64(_mm256_packs_epi32(clamped, clamped), 0b11101000);
    _mm_storeu_si128(dest.add(x).cast(), _mm256_castsi256_si128(packed));
}