zoomvtools 2.0.0

Video motion vector analysis utilities in pure Rust
Documentation
#![allow(clippy::undocumented_unsafe_blocks)]

use std::{arch::x86_64::*, num::NonZeroUsize};

use cpudetect::target_family;

use crate::util::Pixel;

#[target_family("x86_64_v3")]
/// Halves a plane in both dimensions by averaging every 2x2 block of `src`
/// into one pixel of `dest`, dispatching on the byte size of `T` to the
/// 8-bit or 16-bit kernel.
///
/// * `dest` / `src` - destination and source planes, stored row by row.
/// * `dest_pitch` / `src_pitch` - distance between the starts of consecutive
///   rows, counted in pixels of type `T`.
/// * `dest_width` / `dest_height` - size of the destination plane in pixels;
///   `2 * dest_width` by `2 * dest_height` source pixels are consumed.
///
/// # Safety
///
/// The kernels access both buffers through raw pointers without per-access
/// bounds checks, so the caller must guarantee that:
///
/// * `src` holds at least `2 * dest_height` rows of `src_pitch` pixels and
///   `dest` holds at least `dest_height` rows of `dest_pitch` pixels;
/// * `dest_width <= dest_pitch` and `2 * dest_width <= src_pitch`;
/// * the running CPU supports the AVX2 instructions used by the kernels.
///
/// These size requirements are only verified in debug builds.
///
/// # Panics
///
/// Panics if `size_of::<T>()` is neither 1 nor 2, and, in debug builds, if
/// any of the size requirements above is violated.
pub(super) unsafe fn reduce_average<T: Pixel>(
    dest: &mut [T],
    src: &[T],
    dest_pitch: NonZeroUsize,
    src_pitch: NonZeroUsize,
    dest_width: NonZeroUsize,
    dest_height: NonZeroUsize,
) {
    // Check the array bounds once, up front: the kernels below work on raw
    // pointers and perform no checks of their own.
    debug_assert!(src.len() >= src_pitch.get() * dest_height.get() * 2);
    debug_assert!(dest.len() >= dest_pitch.get() * dest_height.get());
    // Every row access must also stay inside its own row. Together with the
    // two length checks above this keeps the last row inside the buffers.
    debug_assert!(dest_width.get() <= dest_pitch.get());
    debug_assert!(dest_width.get() * 2 <= src_pitch.get());

    match size_of::<T>() {
        // NOTE(review): reinterpreting `T` as `u8`/`u16` presumes `Pixel` is
        // only implemented for those types - confirm against the trait.
        1 => unsafe {
            reduce_average_u8(
                dest.as_mut_ptr() as *mut u8,
                src.as_ptr() as *const u8,
                dest_pitch,
                src_pitch,
                dest_width,
                dest_height,
            );
        },
        2 => unsafe {
            reduce_average_u16(
                dest.as_mut_ptr() as *mut u16,
                src.as_ptr() as *const u16,
                dest_pitch,
                src_pitch,
                dest_width,
                dest_height,
            );
        },
        _ => unreachable!(),
    }
}

#[target_family("x86_64_v3")]
// NOTE: Custom implementation. No C implementation exists.
/// 8-bit kernel: writes to each destination pixel the rounded average
/// `(a + b + c + d + 2) / 4` of the matching 2x2 block of source pixels.
///
/// Each row is handled in up to three stages: 32 destination pixels per
/// iteration, then at most one step of 16 pixels, then one pixel at a time.
///
/// # Safety
///
/// `src` must be readable for `2 * dest_height` rows of `src_pitch` bytes and
/// `dest` writable for `dest_height` rows of `dest_pitch` bytes; within each
/// row `2 * dest_width` source bytes are read and `dest_width` destination
/// bytes are written. The CPU must support the AVX2 intrinsics used here.
unsafe fn reduce_average_u8(
    dest: *mut u8,
    src: *const u8,
    dest_pitch: NonZeroUsize,
    src_pitch: NonZeroUsize,
    dest_width: NonZeroUsize,
    dest_height: NonZeroUsize,
) {
    // Destination pixels produced per full-width and per half-width step.
    const WIDE_STEP: usize = 32;
    const NARROW_STEP: usize = 16;

    let (out_w, out_h) = (dest_width.get(), dest_height.get());
    let (out_stride, in_stride) = (dest_pitch.get(), src_pitch.get());

    // A multiplier of 1 makes `maddubs` a plain sum of horizontally adjacent bytes.
    let unit = _mm256_set1_epi8(1);
    // Adding 2 before shifting right by 2 rounds the division by 4 to nearest.
    let half = _mm256_set1_epi16(2);

    // First column that the full-width loop can no longer cover.
    let wide_end = out_w - out_w % WIDE_STEP;

    for row in 0..out_h {
        let out = dest.add(row * out_stride);
        let top = src.add(row * 2 * in_stride);
        let bottom = src.add((row * 2 + 1) * in_stride);

        let mut col = 0;

        // Full-width stage: 64 source bytes per row -> 32 destination bytes.
        while col < wide_end {
            let top_px = top.add(col * 2);
            let bottom_px = bottom.add(col * 2);

            let top_a = _mm256_loadu_si256(top_px as *const __m256i);
            let top_b = _mm256_loadu_si256(top_px.add(32) as *const __m256i);
            let bottom_a = _mm256_loadu_si256(bottom_px as *const __m256i);
            let bottom_b = _mm256_loadu_si256(bottom_px.add(32) as *const __m256i);

            // 16-bit sums of each 2x2 block.
            let sum_a = _mm256_add_epi16(
                _mm256_maddubs_epi16(top_a, unit),
                _mm256_maddubs_epi16(bottom_a, unit),
            );
            let sum_b = _mm256_add_epi16(
                _mm256_maddubs_epi16(top_b, unit),
                _mm256_maddubs_epi16(bottom_b, unit),
            );

            let avg_a = _mm256_srli_epi16::<2>(_mm256_add_epi16(sum_a, half));
            let avg_b = _mm256_srli_epi16::<2>(_mm256_add_epi16(sum_b, half));

            // The pack works per 128-bit lane, so swap the two middle
            // 64-bit quarters to restore left-to-right pixel order.
            let narrowed = _mm256_packus_epi16(avg_a, avg_b);
            let ordered = _mm256_permute4x64_epi64::<0b11_01_10_00>(narrowed);

            _mm256_storeu_si256(out.add(col) as *mut __m256i, ordered);

            col += WIDE_STEP;
        }

        // Half-width stage: at most once per row, 32 source bytes -> 16 destination bytes.
        if out_w - col >= NARROW_STEP {
            let top_v = _mm256_loadu_si256(top.add(col * 2) as *const __m256i);
            let bottom_v = _mm256_loadu_si256(bottom.add(col * 2) as *const __m256i);

            let sum = _mm256_add_epi16(
                _mm256_maddubs_epi16(top_v, unit),
                _mm256_maddubs_epi16(bottom_v, unit),
            );
            let avg = _mm256_srli_epi16::<2>(_mm256_add_epi16(sum, half));

            // Pack the vector with itself; after the permute the low
            // 128 bits hold all 16 results in order.
            let narrowed = _mm256_packus_epi16(avg, avg);
            let ordered = _mm256_permute4x64_epi64::<0b11_01_10_00>(narrowed);

            _mm_storeu_si128(
                out.add(col) as *mut __m128i,
                _mm256_castsi256_si128(ordered),
            );

            col += NARROW_STEP;
        }

        // Scalar stage for whatever is left of the row.
        for col in col..out_w {
            let px = col * 2;
            let sum = u16::from(*top.add(px))
                + u16::from(*top.add(px + 1))
                + u16::from(*bottom.add(px))
                + u16::from(*bottom.add(px + 1));

            *out.add(col) = ((sum + 2) >> 2) as u8;
        }
    }
}

#[target_family("x86_64_v3")]
// NOTE: Custom implementation. No C implementation exists.
/// 16-bit kernel: writes to each destination pixel the rounded average
/// `(a + b + c + d + 2) / 4` of the matching 2x2 block of source pixels.
///
/// Each row is handled in up to three stages: 16 destination pixels per
/// iteration, then at most one step of 8 pixels, then one pixel at a time.
///
/// # Safety
///
/// `src` must be readable for `2 * dest_height` rows of `src_pitch` elements
/// and `dest` writable for `dest_height` rows of `dest_pitch` elements; within
/// each row `2 * dest_width` source elements are read and `dest_width`
/// destination elements are written. The CPU must support the AVX2
/// intrinsics used here.
unsafe fn reduce_average_u16(
    dest: *mut u16,
    src: *const u16,
    dest_pitch: NonZeroUsize,
    src_pitch: NonZeroUsize,
    dest_width: NonZeroUsize,
    dest_height: NonZeroUsize,
) {
    // Destination pixels produced per full-width and per half-width step.
    const WIDE_STEP: usize = 16;
    const NARROW_STEP: usize = 8;

    let (out_w, out_h) = (dest_width.get(), dest_height.get());
    let (out_stride, in_stride) = (dest_pitch.get(), src_pitch.get());

    // `madd` multiplies signed 16-bit values, so flip the top bit first:
    // an unsigned value `v` is then read as the signed value `v - 32768`.
    let flip_sign = _mm256_set1_epi16(i16::MIN);
    // A multiplier of 1 makes `madd` a plain 32-bit sum of adjacent pairs.
    let unit = _mm256_set1_epi16(1);
    // Undo the four 32768 offsets carried by each 2x2 sum and add the
    // rounding term of 2 in the same step.
    let unbias_and_round = _mm256_set1_epi32(4 * 32_768 + 2);

    // First column that the full-width loop can no longer cover.
    let wide_end = out_w - out_w % WIDE_STEP;

    for row in 0..out_h {
        let out = dest.add(row * out_stride);
        let top = src.add(row * 2 * in_stride);
        let bottom = src.add((row * 2 + 1) * in_stride);

        let mut col = 0;

        // Full-width stage: 32 source pixels per row -> 16 destination pixels.
        while col < wide_end {
            let top_px = top.add(col * 2);
            let bottom_px = bottom.add(col * 2);

            let top_a = _mm256_xor_si256(
                _mm256_loadu_si256(top_px as *const __m256i),
                flip_sign,
            );
            let top_b = _mm256_xor_si256(
                _mm256_loadu_si256(top_px.add(16) as *const __m256i),
                flip_sign,
            );
            let bottom_a = _mm256_xor_si256(
                _mm256_loadu_si256(bottom_px as *const __m256i),
                flip_sign,
            );
            let bottom_b = _mm256_xor_si256(
                _mm256_loadu_si256(bottom_px.add(16) as *const __m256i),
                flip_sign,
            );

            // 32-bit sums of each 2x2 block, already unbiased and rounded.
            let block_a = _mm256_add_epi32(
                _mm256_madd_epi16(top_a, unit),
                _mm256_madd_epi16(bottom_a, unit),
            );
            let block_b = _mm256_add_epi32(
                _mm256_madd_epi16(top_b, unit),
                _mm256_madd_epi16(bottom_b, unit),
            );
            let sum_a = _mm256_add_epi32(block_a, unbias_and_round);
            let sum_b = _mm256_add_epi32(block_b, unbias_and_round);

            let avg_a = _mm256_srli_epi32::<2>(sum_a);
            let avg_b = _mm256_srli_epi32::<2>(sum_b);

            // The pack works per 128-bit lane, so swap the two middle
            // 64-bit quarters to restore left-to-right pixel order.
            let narrowed = _mm256_packus_epi32(avg_a, avg_b);
            let ordered = _mm256_permute4x64_epi64::<0b11_01_10_00>(narrowed);

            _mm256_storeu_si256(out.add(col) as *mut __m256i, ordered);

            col += WIDE_STEP;
        }

        // Half-width stage: at most once per row, 16 source pixels -> 8 destination pixels.
        if out_w - col >= NARROW_STEP {
            let top_v = _mm256_xor_si256(
                _mm256_loadu_si256(top.add(col * 2) as *const __m256i),
                flip_sign,
            );
            let bottom_v = _mm256_xor_si256(
                _mm256_loadu_si256(bottom.add(col * 2) as *const __m256i),
                flip_sign,
            );

            let block = _mm256_add_epi32(
                _mm256_madd_epi16(top_v, unit),
                _mm256_madd_epi16(bottom_v, unit),
            );
            let avg = _mm256_srli_epi32::<2>(_mm256_add_epi32(block, unbias_and_round));

            // Pack the vector with itself; after the permute the low
            // 128 bits hold all 8 results in order.
            let narrowed = _mm256_packus_epi32(avg, avg);
            let ordered = _mm256_permute4x64_epi64::<0b11_01_10_00>(narrowed);

            _mm_storeu_si128(
                out.add(col) as *mut __m128i,
                _mm256_castsi256_si128(ordered),
            );

            col += NARROW_STEP;
        }

        // Scalar stage for whatever is left of the row.
        for col in col..out_w {
            let px = col * 2;
            let sum = u32::from(*top.add(px))
                + u32::from(*top.add(px + 1))
                + u32::from(*bottom.add(px))
                + u32::from(*bottom.add(px + 1));

            *out.add(col) = ((sum + 2) >> 2) as u16;
        }
    }
}