zoomvtools 1.1.0

Video motion vector analysis utilities in pure Rust
Documentation
#![allow(clippy::undocumented_unsafe_blocks)]
#![allow(unsafe_op_in_unsafe_fn)]

use std::{
    arch::x86_64::*,
    mem::size_of,
    num::{NonZeroU8, NonZeroUsize},
};

use num_traits::clamp;

use crate::util::{Pixel, round_ties_to_even};
use semisafe::option::unwrap as semisafe_opt_unwrap;
use semisafe::slice::get as semisafe_get;
use semisafe::slice::get_mut as semisafe_get_mut;

/// Converts a block of DCT-domain floats back into pixels of type `T`.
///
/// The whole `size_x` x `size_y` block is first converted by the AVX2 kernel
/// that matches the byte width of `T`:
/// - 1 byte uses `float_src_to_pixels_u8`;
/// - 2 bytes uses `float_src_to_pixels_u16`.
///
/// The kernel scales each coefficient by sqrt(2)/2, shifts it right by
/// `dct_shift`, re-centres it on the mid value and clamps it to the pixel
/// range. Element 0 is then recomputed with a different rule: `src_dct[0]` is
/// scaled by 0.5 and shifted by `dct_shift0` instead.
///
/// `dst_pitch` is the row stride of `dst`, counted in elements of `T`.
///
/// NOTE(review): for 1-byte `T` the kernel hard-codes an 8-bit depth, while
/// the element-0 value uses `bits_per_sample`. Callers presumably always pass
/// 8 for `u8` — confirm.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `T` is reinterpreted as `u8` (1 byte) or `u16` (2 bytes). It is assumed
///   to share that representation — confirm against the `Pixel` impls.
/// - `dst` must hold at least `(size_y - 1) * dst_pitch + size_x` elements.
/// - `src_dct` must hold at least `size_x * size_y` values.
/// - Both length requirements are only checked with `debug_assert!`.
///
/// # Panics
/// Panics if `size_of::<T>()` is neither 1 nor 2. Debug builds also panic when
/// either buffer is too short.
#[target_feature(enable = "avx2")]
pub(super) unsafe fn float_src_to_pixels<T: Pixel>(
    dst: &mut [T],
    dst_pitch: NonZeroUsize,
    src_dct: &[f32],
    size_x: NonZeroUsize,
    size_y: NonZeroUsize,
    bits_per_sample: NonZeroU8,
    dct_shift: usize,
    dct_shift0: usize,
) {
    // The kernels below write through raw pointers with no bounds checks, so a
    // short buffer would silently corrupt memory. Catch it in debug builds.
    debug_assert!(
        dst.len() >= (size_y.get() - 1) * dst_pitch.get() + size_x.get(),
        "dst holds {} pixels, too few for a {}x{} block with pitch {}",
        dst.len(),
        size_x,
        size_y,
        dst_pitch
    );
    debug_assert!(
        src_dct.len() >= size_x.get() * size_y.get(),
        "src_dct holds {} coefficients, too few for a {}x{} block",
        src_dct.len(),
        size_x,
        size_y
    );

    match size_of::<T>() {
        1 => float_src_to_pixels_u8(
            dst.as_mut_ptr().cast(),
            dst_pitch,
            src_dct,
            size_x,
            size_y,
            bits_per_sample,
            dct_shift,
        ),
        2 => float_src_to_pixels_u16(
            dst.as_mut_ptr().cast(),
            dst_pitch,
            src_dct,
            size_x,
            size_y,
            bits_per_sample,
            dct_shift,
        ),
        _ => unreachable!("float_src_to_pixels: unsupported pixel size"),
    }

    // The kernel treated element 0 like every other coefficient. Overwrite it:
    // scale by 0.5, shift by `dct_shift0`, re-centre, and clamp to the range.
    let pixel_max = (1 << bits_per_sample.get() as usize) - 1;
    let pixel_half = 1 << (bits_per_sample.get() as usize - 1);
    let f = *semisafe_get(src_dct, 0) * 0.5;
    let integ = round_ties_to_even(f) as i32;
    *semisafe_get_mut(dst, 0) = semisafe_opt_unwrap(T::from(clamp(
        (integ >> dct_shift0) + pixel_half,
        0,
        pixel_max,
    )));
}

/// AVX2 kernel that turns a `size_x` x `size_y` block of DCT-domain floats
/// into 8-bit pixels.
///
/// Each coefficient goes through the same steps:
/// 1. scale by sqrt(2)/2 and round to an integer;
/// 2. shift right arithmetically by `dct_shift`;
/// 3. add 128;
/// 4. clamp to `0..=255`.
///
/// Whole groups of eight columns use AVX2. The last `size_x % 8` columns of
/// each row use equivalent scalar code built on `round_ties_to_even`. The
/// vector path rounds with `_mm256_cvtps_epi32`, which follows the current
/// MXCSR rounding mode (round-to-nearest-even by default).
///
/// `_bits_per_sample` is ignored: the depth is fixed at 8 bits for `u8`.
/// `dst_pitch` is the row stride of `dst`, in bytes.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `src_dct` must contain at least `size_x * size_y` values.
/// - `dst` must be valid for writes at `y * dst_pitch + x` for every
///   `y < size_y` and `x < size_x`.
#[target_feature(enable = "avx2")]
unsafe fn float_src_to_pixels_u8(
    dst: *mut u8,
    dst_pitch: NonZeroUsize,
    src_dct: &[f32],
    size_x: NonZeroUsize,
    size_y: NonZeroUsize,
    _bits_per_sample: NonZeroU8,
    dct_shift: usize,
) {
    // PERF: u8 only supports 8-bit, so the depth-derived bounds are constants.
    const MID_GREY: i32 = 1 << 7;
    const PEAK: i32 = (1 << 8) - 1;
    const LANES: usize = 8;

    let (width, height, pitch) = (size_x.get(), size_y.get(), dst_pitch.get());

    let scale = (2f32).sqrt() / 2.0;
    let scale_v = _mm256_set1_ps(scale);
    let mid_v = _mm256_set1_epi32(MID_GREY);
    let peak_v = _mm256_set1_epi32(PEAK);
    let zero_v = _mm256_setzero_si256();
    let count = _mm_cvtsi32_si128(dct_shift as i32);

    for row in 0..height {
        let first = row * width;
        let coeffs = semisafe_get(src_dct, first..first + width);
        let out = dst.add(row * pitch);

        let groups = coeffs.chunks_exact(LANES);
        let tail = groups.remainder();
        let tail_start = width - tail.len();

        for (g, group) in groups.enumerate() {
            let products = _mm256_mul_ps(_mm256_loadu_ps(group.as_ptr()), scale_v);
            let rounded = _mm256_cvtps_epi32(products);
            let centred = _mm256_add_epi32(_mm256_sra_epi32(rounded, count), mid_v);
            let bounded = _mm256_max_epi32(_mm256_min_epi32(centred, peak_v), zero_v);
            // packus_epi32 packs within each 128-bit lane, giving 64-bit quads
            // [v0..v3, 0, v4..v7, 0]; the permute moves v4..v7 next to v0..v3.
            let words = _mm256_permute4x64_epi64(
                _mm256_packus_epi32(bounded, zero_v),
                0b11_01_10_00,
            );
            // Narrow once more; the eight pixels land in the lowest 64 bits.
            let bytes = _mm256_packus_epi16(words, zero_v);
            let eight_px = _mm_cvtsi128_si64(_mm256_castsi256_si128(bytes)) as u64;
            out.add(g * LANES).cast::<u64>().write_unaligned(eight_px);
        }

        for (i, &c) in tail.iter().enumerate() {
            let rounded = round_ties_to_even(c * scale) as i32;
            let px = ((rounded >> dct_shift) + MID_GREY).clamp(0, PEAK);
            *out.add(tail_start + i) = px as u8;
        }
    }
}

/// AVX2 kernel that turns a `size_x` x `size_y` block of DCT-domain floats
/// into 16-bit-container pixels of depth `bits_per_sample`.
///
/// Each coefficient goes through the same steps:
/// 1. scale by sqrt(2)/2 and round to an integer;
/// 2. shift right arithmetically by `dct_shift`;
/// 3. add `1 << (bits_per_sample - 1)`;
/// 4. clamp to `0..=(1 << bits_per_sample) - 1`.
///
/// Whole groups of eight columns use AVX2. The last `size_x % 8` columns of
/// each row use equivalent scalar code built on `round_ties_to_even`.
/// `dst_pitch` is the row stride of `dst`, in `u16` elements.
///
/// # Safety
/// - The CPU must support AVX2.
/// - `src_dct` must contain at least `size_x * size_y` values.
/// - `dst` must be valid for writes at `y * dst_pitch + x` for every
///   `y < size_y` and `x < size_x`.
#[target_feature(enable = "avx2")]
unsafe fn float_src_to_pixels_u16(
    dst: *mut u16,
    dst_pitch: NonZeroUsize,
    src_dct: &[f32],
    size_x: NonZeroUsize,
    size_y: NonZeroUsize,
    bits_per_sample: NonZeroU8,
    dct_shift: usize,
) {
    const LANES: usize = 8;

    let (width, height, pitch) = (size_x.get(), size_y.get(), dst_pitch.get());
    let depth = bits_per_sample.get() as usize;

    let mid_grey = 1i32 << (depth - 1);
    let peak = (1i32 << depth) - 1;
    let scale = (2f32).sqrt() / 2.0;

    let scale_v = _mm256_set1_ps(scale);
    let mid_v = _mm256_set1_epi32(mid_grey);
    let peak_v = _mm256_set1_epi32(peak);
    let zero_v = _mm256_setzero_si256();
    let count = _mm_cvtsi32_si128(dct_shift as i32);

    for row in 0..height {
        let first = row * width;
        let coeffs = semisafe_get(src_dct, first..first + width);
        let out = dst.add(row * pitch);

        let groups = coeffs.chunks_exact(LANES);
        let tail = groups.remainder();
        let tail_start = width - tail.len();

        for (g, group) in groups.enumerate() {
            let products = _mm256_mul_ps(_mm256_loadu_ps(group.as_ptr()), scale_v);
            let rounded = _mm256_cvtps_epi32(products);
            let centred = _mm256_add_epi32(_mm256_sra_epi32(rounded, count), mid_v);
            let bounded = _mm256_max_epi32(_mm256_min_epi32(centred, peak_v), zero_v);
            // packus_epi32 packs within each 128-bit lane; the permute gathers
            // both lanes' results into the low 128 bits as eight u16 values.
            let words = _mm256_permute4x64_epi64(
                _mm256_packus_epi32(bounded, zero_v),
                0b11_01_10_00,
            );
            _mm_storeu_si128(
                out.add(g * LANES).cast::<__m128i>(),
                _mm256_castsi256_si128(words),
            );
        }

        for (i, &c) in tail.iter().enumerate() {
            let rounded = round_ties_to_even(c * scale) as i32;
            let px = ((rounded >> dct_shift) + mid_grey).clamp(0, peak);
            *out.add(tail_start + i) = px as u16;
        }
    }
}