#![allow(deprecated)] #![cfg_attr(not(feature = "unchecked"), forbid(unsafe_code))]
#![cfg_attr(feature = "unchecked", deny(unsafe_code))]
#![allow(unused_imports)]
#![allow(dead_code)]
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
#[cfg(target_arch = "x86_64")]
use crate::src::cpu::summon_avx2;
use archmage::{Desktop64, Server64, SimdToken, arcane, rite};
use std::cmp;
use std::ffi::c_int;
use std::ffi::c_uint;
use std::slice;
#[cfg(target_arch = "x86_64")]
use crate::src::safe_simd::partial_simd;
#[cfg(target_arch = "x86_64")]
use crate::src::safe_simd::pixel_access::{
Flex, loadi64, loadu_128, loadu_256, loadu_512, storeu_128, storeu_256, storeu_512,
};
use crate::include::common::bitdepth::AsPrimitive;
use crate::include::common::bitdepth::BitDepth;
use crate::include::common::bitdepth::BitDepth8;
use crate::include::common::bitdepth::BitDepth16;
use crate::include::common::bitdepth::DynPixel;
use crate::include::common::bitdepth::LeftPixelRow;
use crate::include::common::intops::iclip;
use crate::include::dav1d::picture::PicOffset;
use crate::src::align::AlignedVec64;
use crate::src::disjoint_mut::DisjointMut;
use crate::src::ffi_safe::FFISafe;
use crate::src::looprestoration::{LooprestorationParams, LrEdgeFlags, padding};
use crate::src::tables::dav1d_sgr_x_by_x;
/// C-style pointer-difference type, kept under its C name so the
/// `extern "C"` signatures in this file can spell their stride parameter the
/// same way.
#[allow(non_camel_case_types)]
type ptrdiff_t = isize;
/// Distance, in elements, between consecutive rows of the scratch buffers used
/// by the filters in this file: `256 * 3 / 2 + 3 + 3` = 390.
///
/// NOTE(review): reads as the widest unit (`256 * 3 / 2` = 384) plus a
/// 3-element border on each side — confirm against `padding`.
const REST_UNIT_STRIDE: usize = 256 * 3 / 2 + 3 + 3;
/// Applies the separable 7-tap Wiener filter to a `w` x `h` block of 8-bit
/// pixels with AVX2 intrinsics and writes the result back through `p`.
///
/// The filter runs as two passes over stack scratch buffers whose rows are
/// `REST_UNIT_STRIDE` elements apart:
///
/// 1. **Horizontal** — `h + 6` rows of `tmp` are filtered with
///    `params.filter[0]`, rounded by 3 bits, clamped to `0..=8191` and stored
///    as `u16` in `hor`.
/// 2. **Vertical** — every output pixel combines 7 consecutive rows of `hor`
///    with `params.filter[1]`, is rounded by 11 bits and clamped to `0..=255`.
///
/// Both passes handle 8 pixels per SIMD iteration and finish each row with a
/// scalar loop that performs the same arithmetic.
///
/// # Parameters
/// - `_token`: `Desktop64` token; not read by the body.
///   NOTE(review): presumably the capability proof `#[arcane]` needs — confirm.
/// - `p`: destination picture position; forwarded to `padding` and then used
///   for the write-back via `with_pixel_guard_mut`.
/// - `left`, `lpf`, `lpf_off`, `edges`: forwarded unchanged to `padding`.
/// - `w`, `h`: block size in pixels. The scratch buffers hold 70 rows, so
///   `h + 6 <= 70` and `w + 6 <= REST_UNIT_STRIDE` are required.
/// - `params`: supplies the horizontal (`filter[0]`) and vertical
///   (`filter[1]`) coefficients.
///
/// # Panics
/// Panics on out-of-range slice indexing if `w`/`h` exceed the scratch-buffer
/// limits above. The horizontal SIMD loop loads 16 bytes per tap, so on the
/// last scratch row it additionally relies on `w <= 383`.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter7_8bpc_avx2_inner(
    _token: Desktop64,
    p: PicOffset,
    left: &[LeftPixelRow<u8>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
) {
    // Source scratch, populated by `padding`; read below as `h + 6` rows of
    // `w + 6` pixels.
    let mut tmp = [0u8; (64 + 3 + 3) * REST_UNIT_STRIDE];
    padding::<BitDepth8>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    // Output of the horizontal pass, consumed by the vertical pass.
    let mut hor = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let filter = &params.filter;

    // ---- Horizontal pass ----
    let round_bits_h = 3i32;
    let rounding_off_h = 1i32 << (round_bits_h - 1);
    let clip_limit = 1i32 << (8 + 1 + 7 - round_bits_h);
    let hf0 = _mm256_set1_epi32(filter[0][0] as i32);
    let hf1 = _mm256_set1_epi32(filter[0][1] as i32);
    let hf2 = _mm256_set1_epi32(filter[0][2] as i32);
    // The centre tap carries the extra `pixel * 128` term that the scalar
    // tail adds as a separate statement.
    let hf3 = _mm256_set1_epi32(filter[0][3] as i32 + 128);
    let hf4 = _mm256_set1_epi32(filter[0][4] as i32);
    let hf5 = _mm256_set1_epi32(filter[0][5] as i32);
    let hf6 = _mm256_set1_epi32(filter[0][6] as i32);
    let dc_offset = _mm256_set1_epi32(1i32 << 14);
    let h_rounding = _mm256_set1_epi32(rounding_off_h);
    let h_clip_max = _mm256_set1_epi32(clip_limit - 1);
    let h_zero = _mm256_setzero_si256();
    for row in 0..(h + 6) {
        let tmp_row = &tmp[row * REST_UNIT_STRIDE..];
        let hor_row = &mut hor[row * REST_UNIT_STRIDE..row * REST_UNIT_STRIDE + w];
        let mut x = 0;
        while x + 8 <= w {
            // Each load fetches 16 bytes, but `_mm256_cvtepu8_epi32` widens
            // only the low 8 of them to 32-bit lanes.
            let p0 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x..x + 16], [u8; 16]));
            let p1 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x + 1..x + 17], [u8; 16]));
            let p2 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x + 2..x + 18], [u8; 16]));
            let p3 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x + 3..x + 19], [u8; 16]));
            let p4 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x + 4..x + 20], [u8; 16]));
            let p5 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x + 5..x + 21], [u8; 16]));
            let p6 = _mm256_cvtepu8_epi32(loadu_128!(&tmp_row[x + 6..x + 22], [u8; 16]));
            let mut sum = dc_offset;
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p0, hf0));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p1, hf1));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p2, hf2));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p3, hf3));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p4, hf4));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p5, hf5));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p6, hf6));
            sum = _mm256_add_epi32(sum, h_rounding);
            sum = _mm256_srai_epi32::<3>(sum);
            sum = _mm256_max_epi32(sum, h_zero);
            sum = _mm256_min_epi32(sum, h_clip_max);
            // `packus` works per 128-bit lane, so gather the low 64 bits of
            // each lane to obtain the 8 results in order.
            let packed = _mm256_packus_epi32(sum, sum);
            let lo = _mm256_castsi256_si128(packed);
            let hi = _mm256_extracti128_si256(packed, 1);
            let combined = _mm_unpacklo_epi64(lo, hi);
            storeu_128!(&mut hor_row[x..x + 8], [u16; 8], combined);
            x += 8;
        }
        // Scalar tail for the last `w % 8` columns.
        while x < w {
            let mut sum = 1i32 << 14;
            sum += tmp_row[x + 3] as i32 * 128;
            for k in 0..7 {
                sum += tmp_row[x + k] as i32 * filter[0][k] as i32;
            }
            hor_row[x] = iclip((sum + rounding_off_h) >> round_bits_h, 0, clip_limit - 1) as u16;
            x += 1;
        }
    }

    // ---- Vertical pass ----
    let round_bits_v = 11i32;
    let rounding_off_v = 1i32 << (round_bits_v - 1);
    let round_offset = 1i32 << (8 + round_bits_v - 1);
    let vf0 = _mm256_set1_epi32(filter[1][0] as i32);
    let vf1 = _mm256_set1_epi32(filter[1][1] as i32);
    let vf2 = _mm256_set1_epi32(filter[1][2] as i32);
    let vf3 = _mm256_set1_epi32(filter[1][3] as i32);
    let vf4 = _mm256_set1_epi32(filter[1][4] as i32);
    let vf5 = _mm256_set1_epi32(filter[1][5] as i32);
    let vf6 = _mm256_set1_epi32(filter[1][6] as i32);
    let v_round_offset = _mm256_set1_epi32(-round_offset);
    let v_rounding = _mm256_set1_epi32(rounding_off_v);
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth8, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            for j in 0..h {
                // Index of the first output pixel of row `j` inside `bytes`.
                let row_off = (offset as isize + j as isize * stride) as usize;
                let mut i = 0usize;
                while i + 8 <= w {
                    let row0 = &hor[(j + 0) * REST_UNIT_STRIDE + i..];
                    let row1 = &hor[(j + 1) * REST_UNIT_STRIDE + i..];
                    let row2 = &hor[(j + 2) * REST_UNIT_STRIDE + i..];
                    let row3 = &hor[(j + 3) * REST_UNIT_STRIDE + i..];
                    let row4 = &hor[(j + 4) * REST_UNIT_STRIDE + i..];
                    let row5 = &hor[(j + 5) * REST_UNIT_STRIDE + i..];
                    let row6 = &hor[(j + 6) * REST_UNIT_STRIDE + i..];
                    let r0 = loadu_128!(&row0[..8], [u16; 8]);
                    let r1 = loadu_128!(&row1[..8], [u16; 8]);
                    let r2 = loadu_128!(&row2[..8], [u16; 8]);
                    let r3 = loadu_128!(&row3[..8], [u16; 8]);
                    let r4 = loadu_128!(&row4[..8], [u16; 8]);
                    let r5 = loadu_128!(&row5[..8], [u16; 8]);
                    let r6 = loadu_128!(&row6[..8], [u16; 8]);
                    let r0_lo = _mm256_cvtepu16_epi32(r0);
                    let r1_lo = _mm256_cvtepu16_epi32(r1);
                    let r2_lo = _mm256_cvtepu16_epi32(r2);
                    let r3_lo = _mm256_cvtepu16_epi32(r3);
                    let r4_lo = _mm256_cvtepu16_epi32(r4);
                    let r5_lo = _mm256_cvtepu16_epi32(r5);
                    let r6_lo = _mm256_cvtepu16_epi32(r6);
                    let mut sum = v_round_offset;
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r0_lo, vf0));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r1_lo, vf1));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r2_lo, vf2));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r3_lo, vf3));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r4_lo, vf4));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r5_lo, vf5));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r6_lo, vf6));
                    sum = _mm256_add_epi32(sum, v_rounding);
                    sum = _mm256_srai_epi32::<11>(sum);
                    // Two unsigned-saturating packs (i32 -> u16 -> u8) clamp
                    // the result to 0..=255.
                    let sum16 = _mm256_packus_epi32(sum, sum);
                    let sum16_lo = _mm256_castsi256_si128(sum16);
                    let sum16_hi = _mm256_extracti128_si256(sum16, 1);
                    let sum16_combined = _mm_unpacklo_epi64(sum16_lo, sum16_hi);
                    let sum8 = _mm_packus_epi16(sum16_combined, sum16_combined);
                    let dst_arr: &mut [u8; 8] = (&mut bytes[row_off + i..row_off + i + 8])
                        .try_into()
                        .unwrap();
                    partial_simd::mm_storel_epi64(dst_arr, sum8);
                    i += 8;
                }
                // Scalar tail for the last `w % 8` columns.
                while i < w {
                    let mut sum = -round_offset;
                    for k in 0..7 {
                        sum += hor[(j + k) * REST_UNIT_STRIDE + i] as i32 * filter[1][k] as i32;
                    }
                    bytes[row_off + i] =
                        iclip((sum + rounding_off_v) >> round_bits_v, 0, 255) as u8;
                    i += 1;
                }
            }
        },
    );
}
/// 5-tap Wiener filter for 8-bit pixels on AVX2.
///
/// There is no dedicated 5-tap kernel: every argument is handed, unchanged
/// and in the same order, to `wiener_filter7_8bpc_avx2_inner`.
///
/// NOTE(review): presumably 5-tap coefficient sets carry zero outer taps so
/// the 7-tap kernel gives the same result — confirm with the caller.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter5_8bpc_avx2_inner(
    _token: Desktop64,
    p: PicOffset,
    left: &[LeftPixelRow<u8>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
) {
    wiener_filter7_8bpc_avx2_inner(
        _token, p, left, lpf, lpf_off, w, h, params, edges,
    );
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter7_8bpc_avx512_inner(
_token: Server64,
p: PicOffset,
left: &[LeftPixelRow<u8>],
lpf: &DisjointMut<AlignedVec64<u8>>,
lpf_off: isize,
w: usize,
h: usize,
params: &LooprestorationParams,
edges: LrEdgeFlags,
) {
let mut tmp = [0u8; (64 + 3 + 3) * REST_UNIT_STRIDE];
padding::<BitDepth8>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
let mut hor = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
let filter = ¶ms.filter;
let round_bits_h = 3i32;
let rounding_off_h = 1i32 << (round_bits_h - 1);
let clip_limit = 1i32 << (8 + 1 + 7 - round_bits_h);
let hf0 = _mm512_set1_epi32(filter[0][0] as i32);
let hf1 = _mm512_set1_epi32(filter[0][1] as i32);
let hf2 = _mm512_set1_epi32(filter[0][2] as i32);
let hf3 = _mm512_set1_epi32(filter[0][3] as i32 + 128); let hf4 = _mm512_set1_epi32(filter[0][4] as i32);
let hf5 = _mm512_set1_epi32(filter[0][5] as i32);
let hf6 = _mm512_set1_epi32(filter[0][6] as i32);
let dc_offset = _mm512_set1_epi32(1i32 << 14);
let h_rounding = _mm512_set1_epi32(rounding_off_h);
let h_clip_max = _mm512_set1_epi32(clip_limit - 1);
let h_zero = _mm512_setzero_si512();
for row in 0..(h + 6) {
let tmp_row = &tmp[row * REST_UNIT_STRIDE..];
let hor_row = &mut hor[row * REST_UNIT_STRIDE..row * REST_UNIT_STRIDE + w];
let mut x = 0;
while x + 16 <= w {
let p0 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x..x + 16], [u8; 16]));
let p1 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x + 1..x + 17], [u8; 16]));
let p2 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x + 2..x + 18], [u8; 16]));
let p3 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x + 3..x + 19], [u8; 16]));
let p4 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x + 4..x + 20], [u8; 16]));
let p5 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x + 5..x + 21], [u8; 16]));
let p6 = _mm512_cvtepu8_epi32(loadu_128!(&tmp_row[x + 6..x + 22], [u8; 16]));
let mut sum = dc_offset;
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p0, hf0));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p1, hf1));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p2, hf2));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p3, hf3));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p4, hf4));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p5, hf5));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p6, hf6));
sum = _mm512_add_epi32(sum, h_rounding);
sum = _mm512_srai_epi32::<3>(sum); sum = _mm512_max_epi32(sum, h_zero);
sum = _mm512_min_epi32(sum, h_clip_max);
let result_u16: __m256i = _mm512_cvtusepi32_epi16(sum);
storeu_256!(&mut hor_row[x..x + 16], [u16; 16], result_u16);
x += 16;
}
while x < w {
let mut sum = 1i32 << 14;
sum += tmp_row[x + 3] as i32 * 128;
for k in 0..7 {
sum += tmp_row[x + k] as i32 * filter[0][k] as i32;
}
hor_row[x] = iclip((sum + rounding_off_h) >> round_bits_h, 0, clip_limit - 1) as u16;
x += 1;
}
}
let round_bits_v = 11i32;
let rounding_off_v = 1i32 << (round_bits_v - 1);
let round_offset = 1i32 << (8 + round_bits_v - 1);
let vf0 = _mm512_set1_epi32(filter[1][0] as i32);
let vf1 = _mm512_set1_epi32(filter[1][1] as i32);
let vf2 = _mm512_set1_epi32(filter[1][2] as i32);
let vf3 = _mm512_set1_epi32(filter[1][3] as i32);
let vf4 = _mm512_set1_epi32(filter[1][4] as i32);
let vf5 = _mm512_set1_epi32(filter[1][5] as i32);
let vf6 = _mm512_set1_epi32(filter[1][6] as i32);
let v_round_offset = _mm512_set1_epi32(-round_offset);
let v_rounding = _mm512_set1_epi32(rounding_off_v);
let zero_512 = _mm512_setzero_si512();
crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth8, _>(
&p,
w,
h,
|bytes, offset, stride| {
for j in 0..h {
let row_off = (offset as isize + j as isize * stride) as usize;
let mut i = 0usize;
while i + 16 <= w {
let r0 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 0) * REST_UNIT_STRIDE + i..(j + 0) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let r1 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 1) * REST_UNIT_STRIDE + i..(j + 1) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let r2 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 2) * REST_UNIT_STRIDE + i..(j + 2) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let r3 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 3) * REST_UNIT_STRIDE + i..(j + 3) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let r4 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 4) * REST_UNIT_STRIDE + i..(j + 4) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let r5 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 5) * REST_UNIT_STRIDE + i..(j + 5) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let r6 = _mm512_cvtepu16_epi32(loadu_256!(
&hor[(j + 6) * REST_UNIT_STRIDE + i..(j + 6) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
));
let mut sum = v_round_offset;
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r0, vf0));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r1, vf1));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r2, vf2));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r3, vf3));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r4, vf4));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r5, vf5));
sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r6, vf6));
sum = _mm512_add_epi32(sum, v_rounding);
sum = _mm512_srai_epi32::<11>(sum);
let clamped = _mm512_max_epi32(sum, zero_512);
let result_u8: __m128i = _mm512_cvtusepi32_epi8(clamped);
storeu_128!(
&mut bytes[row_off + i..row_off + i + 16],
[u8; 16],
result_u8
);
i += 16;
}
while i + 8 <= w {
let r0 = loadu_128!(
&hor[(j + 0) * REST_UNIT_STRIDE + i..(j + 0) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let r1 = loadu_128!(
&hor[(j + 1) * REST_UNIT_STRIDE + i..(j + 1) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let r2 = loadu_128!(
&hor[(j + 2) * REST_UNIT_STRIDE + i..(j + 2) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let r3 = loadu_128!(
&hor[(j + 3) * REST_UNIT_STRIDE + i..(j + 3) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let r4 = loadu_128!(
&hor[(j + 4) * REST_UNIT_STRIDE + i..(j + 4) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let r5 = loadu_128!(
&hor[(j + 5) * REST_UNIT_STRIDE + i..(j + 5) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let r6 = loadu_128!(
&hor[(j + 6) * REST_UNIT_STRIDE + i..(j + 6) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
);
let vf0_256 = _mm256_set1_epi32(filter[1][0] as i32);
let vf1_256 = _mm256_set1_epi32(filter[1][1] as i32);
let vf2_256 = _mm256_set1_epi32(filter[1][2] as i32);
let vf3_256 = _mm256_set1_epi32(filter[1][3] as i32);
let vf4_256 = _mm256_set1_epi32(filter[1][4] as i32);
let vf5_256 = _mm256_set1_epi32(filter[1][5] as i32);
let vf6_256 = _mm256_set1_epi32(filter[1][6] as i32);
let r0_lo = _mm256_cvtepu16_epi32(r0);
let r1_lo = _mm256_cvtepu16_epi32(r1);
let r2_lo = _mm256_cvtepu16_epi32(r2);
let r3_lo = _mm256_cvtepu16_epi32(r3);
let r4_lo = _mm256_cvtepu16_epi32(r4);
let r5_lo = _mm256_cvtepu16_epi32(r5);
let r6_lo = _mm256_cvtepu16_epi32(r6);
let mut sum = _mm256_set1_epi32(-round_offset);
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r0_lo, vf0_256));
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r1_lo, vf1_256));
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r2_lo, vf2_256));
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r3_lo, vf3_256));
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r4_lo, vf4_256));
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r5_lo, vf5_256));
sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r6_lo, vf6_256));
sum = _mm256_add_epi32(sum, _mm256_set1_epi32(rounding_off_v));
sum = _mm256_srai_epi32::<11>(sum);
let sum16 = _mm256_packus_epi32(sum, sum);
let sum16_lo = _mm256_castsi256_si128(sum16);
let sum16_hi = _mm256_extracti128_si256(sum16, 1);
let sum16_combined = _mm_unpacklo_epi64(sum16_lo, sum16_hi);
let sum8 = _mm_packus_epi16(sum16_combined, sum16_combined);
let dst_arr: &mut [u8; 8] = (&mut bytes[row_off + i..row_off + i + 8])
.try_into()
.unwrap();
partial_simd::mm_storel_epi64(dst_arr, sum8);
i += 8;
}
while i < w {
let mut sum = -round_offset;
for k in 0..7 {
sum += hor[(j + k) * REST_UNIT_STRIDE + i] as i32 * filter[1][k] as i32;
}
bytes[row_off + i] =
iclip((sum + rounding_off_v) >> round_bits_v, 0, 255) as u8;
i += 1;
}
}
},
); }
/// 5-tap Wiener filter for 8-bit pixels on AVX-512.
///
/// There is no dedicated 5-tap kernel: every argument is handed, unchanged
/// and in the same order, to `wiener_filter7_8bpc_avx512_inner`.
///
/// NOTE(review): presumably 5-tap coefficient sets carry zero outer taps so
/// the 7-tap kernel gives the same result — confirm with the caller.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter5_8bpc_avx512_inner(
    _token: Server64,
    p: PicOffset,
    left: &[LeftPixelRow<u8>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
) {
    wiener_filter7_8bpc_avx512_inner(
        _token, p, left, lpf, lpf_off, w, h, params, edges,
    );
}
/// Applies the separable 7-tap Wiener filter to a `w` x `h` block of 16-bit
/// pixels (10- or 12-bit content) with AVX-512 intrinsics and writes the
/// result back through `p`.
///
/// The filter runs as two passes over stack scratch buffers whose rows are
/// `REST_UNIT_STRIDE` elements apart:
///
/// 1. **Horizontal** — `h + 6` rows of `tmp` are filtered with
///    `params.filter[0]`, rounded by 3 bits (5 for 12-bit), clamped to
///    `0..clip_limit` and stored as `i32` in `hor`.
/// 2. **Vertical** — every output pixel combines 7 consecutive rows of `hor`
///    with `params.filter[1]`, is rounded by 11 bits (9 for 12-bit) and
///    clamped to `0..=bitdepth_max`.
///
/// Both passes handle 16 pixels per SIMD iteration and finish each row with a
/// scalar loop that performs the same arithmetic.
///
/// # Parameters
/// - `_token`: `Server64` token; not read by the body.
///   NOTE(review): presumably the capability proof `#[arcane]` needs — confirm.
/// - `p`: destination picture position; forwarded to `padding` and then used
///   for the write-back via `with_pixel_guard_mut`.
/// - `left`, `lpf`, `lpf_off`, `edges`: forwarded unchanged to `padding`.
/// - `w`, `h`: block size in pixels. The scratch buffers hold 70 rows, so
///   `h + 6 <= 70` and `w + 6 <= REST_UNIT_STRIDE` are required.
/// - `params`: supplies the horizontal (`filter[0]`) and vertical
///   (`filter[1]`) coefficients.
/// - `bitdepth_max`: largest pixel value; `1023` selects the 10-bit constants,
///   any other value is treated as 12-bit.
///
/// # Panics
/// Panics on out-of-range slice indexing if `w`/`h` exceed the limits above,
/// and if the destination bytes cannot be reinterpreted as `u16`.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter7_16bpc_avx512_inner(
    _token: Server64,
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    let bitdepth = if bitdepth_max == 1023 { 10 } else { 12 };
    // Source scratch, populated by `padding`; read below as `h + 6` rows of
    // `w + 6` pixels.
    let mut tmp = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
    padding::<BitDepth16>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    // Output of the horizontal pass, consumed by the vertical pass.
    let mut hor = [0i32; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let filter = &params.filter;

    // ---- Horizontal pass ----
    let round_bits_h = if bitdepth == 12 { 5 } else { 3 };
    let rounding_off_h = 1i32 << (round_bits_h - 1);
    let clip_limit = 1i32 << (bitdepth + 1 + 7 - round_bits_h);
    let hf0 = _mm512_set1_epi32(filter[0][0] as i32);
    let hf1 = _mm512_set1_epi32(filter[0][1] as i32);
    let hf2 = _mm512_set1_epi32(filter[0][2] as i32);
    let hf3 = _mm512_set1_epi32(filter[0][3] as i32);
    let hf4 = _mm512_set1_epi32(filter[0][4] as i32);
    let hf5 = _mm512_set1_epi32(filter[0][5] as i32);
    let hf6 = _mm512_set1_epi32(filter[0][6] as i32);
    let dc_offset_h = _mm512_set1_epi32(1i32 << (bitdepth + 6));
    let h_rounding = _mm512_set1_epi32(rounding_off_h);
    let h_clip_max = _mm512_set1_epi32(clip_limit - 1);
    let h_zero = _mm512_setzero_si512();
    for row in 0..(h + 6) {
        let tmp_row = &tmp[row * REST_UNIT_STRIDE..];
        let hor_row = &mut hor[row * REST_UNIT_STRIDE..row * REST_UNIT_STRIDE + w];
        let mut x = 0;
        while x + 16 <= w {
            let p0 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x..x + 16], [u16; 16]));
            let p1 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x + 1..x + 17], [u16; 16]));
            let p2 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x + 2..x + 18], [u16; 16]));
            let p3 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x + 3..x + 19], [u16; 16]));
            let p4 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x + 4..x + 20], [u16; 16]));
            let p5 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x + 5..x + 21], [u16; 16]));
            let p6 = _mm512_cvtepu16_epi32(loadu_256!(&tmp_row[x + 6..x + 22], [u16; 16]));
            let mut sum = dc_offset_h;
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p0, hf0));
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p1, hf1));
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p2, hf2));
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p3, hf3));
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p4, hf4));
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p5, hf5));
            sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(p6, hf6));
            sum = _mm512_add_epi32(sum, h_rounding);
            // The shift count is a const generic, so each bit depth needs its
            // own call.
            let shifted = if bitdepth == 12 {
                _mm512_srai_epi32::<5>(sum)
            } else {
                _mm512_srai_epi32::<3>(sum)
            };
            let clamped = _mm512_max_epi32(shifted, h_zero);
            let clamped = _mm512_min_epi32(clamped, h_clip_max);
            storeu_512!(&mut hor_row[x..x + 16], [i32; 16], clamped);
            x += 16;
        }
        // Scalar tail for the last `w % 16` columns.
        while x < w {
            let mut sum = 1i32 << (bitdepth + 6);
            for k in 0..7 {
                sum += tmp_row[x + k] as i32 * filter[0][k] as i32;
            }
            hor_row[x] = iclip((sum + rounding_off_h) >> round_bits_h, 0, clip_limit - 1);
            x += 1;
        }
    }

    // ---- Vertical pass ----
    let round_bits_v = if bitdepth == 12 { 9 } else { 11 };
    let rounding_off_v = 1i32 << (round_bits_v - 1);
    let round_offset = 1i32 << (bitdepth + round_bits_v - 1);
    let vf0 = _mm512_set1_epi32(filter[1][0] as i32);
    let vf1 = _mm512_set1_epi32(filter[1][1] as i32);
    let vf2 = _mm512_set1_epi32(filter[1][2] as i32);
    let vf3 = _mm512_set1_epi32(filter[1][3] as i32);
    let vf4 = _mm512_set1_epi32(filter[1][4] as i32);
    let vf5 = _mm512_set1_epi32(filter[1][5] as i32);
    let vf6 = _mm512_set1_epi32(filter[1][6] as i32);
    let v_round_offset = _mm512_set1_epi32(-round_offset);
    let v_rounding = _mm512_set1_epi32(rounding_off_v);
    let zero_512 = _mm512_setzero_si512();
    let max_512 = _mm512_set1_epi32(bitdepth_max);
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth16, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            // View the destination bytes as `u16` pixels.
            let p_u16: &mut [u16] = zerocopy::FromBytes::mut_from_bytes(&mut bytes[..])
                .expect("bytes alignment/size mismatch for u16 reinterpretation");
            for j in 0..h {
                // `offset`/`stride` are treated as byte quantities here; halve
                // them to index `p_u16`.
                let row_off = (offset as isize + j as isize * stride) as usize / 2;
                let mut i = 0usize;
                while i + 16 <= w {
                    let r0 = loadu_512!(
                        &hor[(j + 0) * REST_UNIT_STRIDE + i..(j + 0) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let r1 = loadu_512!(
                        &hor[(j + 1) * REST_UNIT_STRIDE + i..(j + 1) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let r2 = loadu_512!(
                        &hor[(j + 2) * REST_UNIT_STRIDE + i..(j + 2) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let r3 = loadu_512!(
                        &hor[(j + 3) * REST_UNIT_STRIDE + i..(j + 3) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let r4 = loadu_512!(
                        &hor[(j + 4) * REST_UNIT_STRIDE + i..(j + 4) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let r5 = loadu_512!(
                        &hor[(j + 5) * REST_UNIT_STRIDE + i..(j + 5) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let r6 = loadu_512!(
                        &hor[(j + 6) * REST_UNIT_STRIDE + i..(j + 6) * REST_UNIT_STRIDE + i + 16],
                        [i32; 16]
                    );
                    let mut sum = v_round_offset;
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r0, vf0));
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r1, vf1));
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r2, vf2));
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r3, vf3));
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r4, vf4));
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r5, vf5));
                    sum = _mm512_add_epi32(sum, _mm512_mullo_epi32(r6, vf6));
                    sum = _mm512_add_epi32(sum, v_rounding);
                    let shifted = if bitdepth == 12 {
                        _mm512_srai_epi32::<9>(sum)
                    } else {
                        _mm512_srai_epi32::<11>(sum)
                    };
                    let clamped = _mm512_min_epi32(_mm512_max_epi32(shifted, zero_512), max_512);
                    let result_u16: __m256i = _mm512_cvtusepi32_epi16(clamped);
                    storeu_256!(
                        &mut p_u16[row_off + i..row_off + i + 16],
                        [u16; 16],
                        result_u16
                    );
                    i += 16;
                }
                // Scalar tail for the last `w % 16` columns.
                while i < w {
                    let mut sum = -round_offset;
                    for k in 0..7 {
                        sum += hor[(j + k) * REST_UNIT_STRIDE + i] * filter[1][k] as i32;
                    }
                    p_u16[row_off + i] =
                        iclip((sum + rounding_off_v) >> round_bits_v, 0, bitdepth_max) as u16;
                    i += 1;
                }
            }
        },
    );
}
/// 5-tap Wiener filter for 16-bit pixels on AVX-512.
///
/// There is no dedicated 5-tap kernel: every argument, including
/// `bitdepth_max`, is handed unchanged and in the same order to
/// `wiener_filter7_16bpc_avx512_inner`.
///
/// NOTE(review): presumably 5-tap coefficient sets carry zero outer taps so
/// the 7-tap kernel gives the same result — confirm with the caller.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter5_16bpc_avx512_inner(
    _token: Server64,
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    wiener_filter7_16bpc_avx512_inner(
        _token, p, left, lpf, lpf_off, w, h, params, edges, bitdepth_max,
    );
}
/// Applies the separable 7-tap Wiener filter to a `w` x `h` block of 16-bit
/// pixels (10- or 12-bit content) with AVX2 intrinsics and writes the result
/// back through `p`.
///
/// The filter runs as two passes over stack scratch buffers whose rows are
/// `REST_UNIT_STRIDE` elements apart:
///
/// 1. **Horizontal** — `h + 6` rows of `tmp` are filtered with
///    `params.filter[0]`, rounded by 3 bits (5 for 12-bit), clamped to
///    `0..clip_limit` and stored as `i32` in `hor`.
/// 2. **Vertical** — every output pixel combines 7 consecutive rows of `hor`
///    with `params.filter[1]`, is rounded by 11 bits (9 for 12-bit) and
///    clamped to `0..=bitdepth_max`.
///
/// Both passes handle 8 pixels per SIMD iteration and finish each row with a
/// scalar loop that performs the same arithmetic.
///
/// # Parameters
/// - `_token`: `Desktop64` token; not read by the body.
///   NOTE(review): presumably the capability proof `#[arcane]` needs — confirm.
/// - `p`: destination picture position; forwarded to `padding` and then used
///   for the write-back via `with_pixel_guard_mut`.
/// - `left`, `lpf`, `lpf_off`, `edges`: forwarded unchanged to `padding`.
/// - `w`, `h`: block size in pixels. The scratch buffers hold 70 rows, so
///   `h + 6 <= 70` and `w + 6 <= REST_UNIT_STRIDE` are required.
/// - `params`: supplies the horizontal (`filter[0]`) and vertical
///   (`filter[1]`) coefficients.
/// - `bitdepth_max`: largest pixel value; `1023` selects the 10-bit constants,
///   any other value is treated as 12-bit.
///
/// # Panics
/// Panics on out-of-range slice indexing if `w`/`h` exceed the limits above,
/// and if the destination bytes cannot be reinterpreted as `u16`.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter7_16bpc_avx2_inner(
    _token: Desktop64,
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    let bitdepth = if bitdepth_max == 1023 { 10 } else { 12 };
    // Source scratch, populated by `padding`; read below as `h + 6` rows of
    // `w + 6` pixels.
    let mut tmp = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
    padding::<BitDepth16>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    // Output of the horizontal pass, consumed by the vertical pass.
    let mut hor = [0i32; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let filter = &params.filter;

    // ---- Horizontal pass ----
    let round_bits_h = if bitdepth == 12 { 5 } else { 3 };
    let rounding_off_h = 1i32 << (round_bits_h - 1);
    let clip_limit = 1i32 << (bitdepth + 1 + 7 - round_bits_h);
    let hf0 = _mm256_set1_epi32(filter[0][0] as i32);
    let hf1 = _mm256_set1_epi32(filter[0][1] as i32);
    let hf2 = _mm256_set1_epi32(filter[0][2] as i32);
    let hf3 = _mm256_set1_epi32(filter[0][3] as i32);
    let hf4 = _mm256_set1_epi32(filter[0][4] as i32);
    let hf5 = _mm256_set1_epi32(filter[0][5] as i32);
    let hf6 = _mm256_set1_epi32(filter[0][6] as i32);
    let dc_offset_h = _mm256_set1_epi32(1i32 << (bitdepth + 6));
    let h_rounding = _mm256_set1_epi32(rounding_off_h);
    let h_clip_max = _mm256_set1_epi32(clip_limit - 1);
    let h_zero = _mm256_setzero_si256();
    for row in 0..(h + 6) {
        let tmp_row = &tmp[row * REST_UNIT_STRIDE..];
        let hor_row = &mut hor[row * REST_UNIT_STRIDE..row * REST_UNIT_STRIDE + w];
        let mut x = 0;
        while x + 8 <= w {
            let p0 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x..x + 8], [u16; 8]));
            let p1 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x + 1..x + 9], [u16; 8]));
            let p2 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x + 2..x + 10], [u16; 8]));
            let p3 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x + 3..x + 11], [u16; 8]));
            let p4 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x + 4..x + 12], [u16; 8]));
            let p5 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x + 5..x + 13], [u16; 8]));
            let p6 = _mm256_cvtepu16_epi32(loadu_128!(&tmp_row[x + 6..x + 14], [u16; 8]));
            let mut sum = dc_offset_h;
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p0, hf0));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p1, hf1));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p2, hf2));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p3, hf3));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p4, hf4));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p5, hf5));
            sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(p6, hf6));
            sum = _mm256_add_epi32(sum, h_rounding);
            // The shift count is a const generic, so each bit depth needs its
            // own call.
            let shifted = if bitdepth == 12 {
                _mm256_srai_epi32::<5>(sum)
            } else {
                _mm256_srai_epi32::<3>(sum)
            };
            let clamped = _mm256_max_epi32(shifted, h_zero);
            let clamped = _mm256_min_epi32(clamped, h_clip_max);
            storeu_256!(&mut hor_row[x..x + 8], [i32; 8], clamped);
            x += 8;
        }
        // Scalar tail for the last `w % 8` columns.
        while x < w {
            let mut sum = 1i32 << (bitdepth + 6);
            for k in 0..7 {
                sum += tmp_row[x + k] as i32 * filter[0][k] as i32;
            }
            hor_row[x] = iclip((sum + rounding_off_h) >> round_bits_h, 0, clip_limit - 1);
            x += 1;
        }
    }

    // ---- Vertical pass ----
    let round_bits_v = if bitdepth == 12 { 9 } else { 11 };
    let rounding_off_v = 1i32 << (round_bits_v - 1);
    let round_offset = 1i32 << (bitdepth + round_bits_v - 1);
    let vf0 = _mm256_set1_epi32(filter[1][0] as i32);
    let vf1 = _mm256_set1_epi32(filter[1][1] as i32);
    let vf2 = _mm256_set1_epi32(filter[1][2] as i32);
    let vf3 = _mm256_set1_epi32(filter[1][3] as i32);
    let vf4 = _mm256_set1_epi32(filter[1][4] as i32);
    let vf5 = _mm256_set1_epi32(filter[1][5] as i32);
    let vf6 = _mm256_set1_epi32(filter[1][6] as i32);
    let v_round_offset = _mm256_set1_epi32(-round_offset);
    let v_rounding = _mm256_set1_epi32(rounding_off_v);
    let v_max = _mm256_set1_epi32(bitdepth_max);
    let v_zero = _mm256_setzero_si256();
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth16, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            // View the destination bytes as `u16` pixels.
            let p_u16: &mut [u16] = zerocopy::FromBytes::mut_from_bytes(&mut bytes[..])
                .expect("bytes alignment/size mismatch for u16 reinterpretation");
            for j in 0..h {
                // `offset`/`stride` are treated as byte quantities here; halve
                // them to index `p_u16`.
                let row_off = (offset as isize + j as isize * stride) as usize / 2;
                let mut i = 0usize;
                while i + 8 <= w {
                    let r0 = loadu_256!(
                        &hor[(j + 0) * REST_UNIT_STRIDE + i..(j + 0) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let r1 = loadu_256!(
                        &hor[(j + 1) * REST_UNIT_STRIDE + i..(j + 1) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let r2 = loadu_256!(
                        &hor[(j + 2) * REST_UNIT_STRIDE + i..(j + 2) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let r3 = loadu_256!(
                        &hor[(j + 3) * REST_UNIT_STRIDE + i..(j + 3) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let r4 = loadu_256!(
                        &hor[(j + 4) * REST_UNIT_STRIDE + i..(j + 4) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let r5 = loadu_256!(
                        &hor[(j + 5) * REST_UNIT_STRIDE + i..(j + 5) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let r6 = loadu_256!(
                        &hor[(j + 6) * REST_UNIT_STRIDE + i..(j + 6) * REST_UNIT_STRIDE + i + 8],
                        [i32; 8]
                    );
                    let mut sum = v_round_offset;
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r0, vf0));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r1, vf1));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r2, vf2));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r3, vf3));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r4, vf4));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r5, vf5));
                    sum = _mm256_add_epi32(sum, _mm256_mullo_epi32(r6, vf6));
                    sum = _mm256_add_epi32(sum, v_rounding);
                    let shifted = if bitdepth == 12 {
                        _mm256_srai_epi32::<9>(sum)
                    } else {
                        _mm256_srai_epi32::<11>(sum)
                    };
                    let clamped = _mm256_min_epi32(_mm256_max_epi32(shifted, v_zero), v_max);
                    // `packus` works per 128-bit lane, so gather the low 64
                    // bits of each lane to obtain the 8 results in order.
                    let packed = _mm256_packus_epi32(clamped, clamped);
                    let lo = _mm256_castsi256_si128(packed);
                    let hi = _mm256_extracti128_si256(packed, 1);
                    let combined = _mm_unpacklo_epi64(lo, hi);
                    storeu_128!(&mut p_u16[row_off + i..row_off + i + 8], [u16; 8], combined);
                    i += 8;
                }
                // Scalar tail for the last `w % 8` columns.
                while i < w {
                    let mut sum = -round_offset;
                    for k in 0..7 {
                        sum += hor[(j + k) * REST_UNIT_STRIDE + i] * filter[1][k] as i32;
                    }
                    p_u16[row_off + i] =
                        iclip((sum + rounding_off_v) >> round_bits_v, 0, bitdepth_max) as u16;
                    i += 1;
                }
            }
        },
    );
}
/// 5-tap Wiener filter for 16-bit pixels on AVX2.
///
/// There is no dedicated 5-tap kernel: every argument, including
/// `bitdepth_max`, is handed unchanged and in the same order to
/// `wiener_filter7_16bpc_avx2_inner`.
///
/// NOTE(review): presumably 5-tap coefficient sets carry zero outer taps so
/// the 7-tap kernel gives the same result — confirm with the caller.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn wiener_filter5_16bpc_avx2_inner(
    _token: Desktop64,
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    wiener_filter7_16bpc_avx2_inner(
        _token, p, left, lpf, lpf_off, w, h, params, edges, bitdepth_max,
    );
}
/// Returns the signed address difference `ptr - start`, where `start` is the
/// pointer reported by `lpf.as_mut_ptr()`.
///
/// The subtraction is done on raw addresses, so the result is in bytes.
fn reconstruct_lpf_offset(lpf: &DisjointMut<AlignedVec64<u8>>, ptr: *const u8) -> isize {
    let start_addr = lpf.as_mut_ptr() as isize;
    let ptr_addr = ptr as isize;
    ptr_addr - start_addr
}
/// `extern "C"` entry point for the 7-tap, 8-bit Wiener filter (AVX2).
///
/// Converts the raw arguments into safe Rust values and calls
/// `wiener_filter7_8bpc_avx2_inner`. `_p_ptr`, `_stride` and `_bitdepth_max`
/// are ignored; the destination is taken from `p` instead.
///
/// # Safety
/// - `p` and `lpf` must be pointers that `FFISafe::get` can turn back into
///   valid references.
/// - `left` must point to at least `h` readable `LeftPixelRow<u8>` values.
/// - `w` and `h` are cast with `as usize`, so they must be non-negative.
/// - The function is compiled with `avx2` enabled; the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn wiener_filter7_8bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    _bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: relies on the caller passing a valid `FFISafe<PicOffset>`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: relies on the caller passing a valid `FFISafe<DisjointMut<..>>`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset(lpf_buf, lpf_ptr.cast::<u8>());
    // SAFETY: relies on `left` pointing to `height` initialised rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u8>>(), height) };
    // SAFETY: this function is only compiled with `avx2` enabled.
    // NOTE(review): confirm that is sufficient for forging a `Desktop64` token.
    let token = unsafe { Desktop64::forge_token_dangerously() };
    wiener_filter7_8bpc_avx2_inner(
        token, pic, left_rows, lpf_buf, lpf_off, width, height, params, edges,
    );
}
/// `extern "C"` entry point for the 5-tap, 8-bit Wiener filter (AVX2).
///
/// Converts the raw arguments into safe Rust values and calls
/// `wiener_filter5_8bpc_avx2_inner`. `_p_ptr`, `_stride` and `_bitdepth_max`
/// are ignored; the destination is taken from `p` instead.
///
/// # Safety
/// - `p` and `lpf` must be pointers that `FFISafe::get` can turn back into
///   valid references.
/// - `left` must point to at least `h` readable `LeftPixelRow<u8>` values.
/// - `w` and `h` are cast with `as usize`, so they must be non-negative.
/// - The function is compiled with `avx2` enabled; the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn wiener_filter5_8bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    _bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: relies on the caller passing a valid `FFISafe<PicOffset>`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: relies on the caller passing a valid `FFISafe<DisjointMut<..>>`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset(lpf_buf, lpf_ptr.cast::<u8>());
    // SAFETY: relies on `left` pointing to `height` initialised rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u8>>(), height) };
    // SAFETY: this function is only compiled with `avx2` enabled.
    // NOTE(review): confirm that is sufficient for forging a `Desktop64` token.
    let token = unsafe { Desktop64::forge_token_dangerously() };
    wiener_filter5_8bpc_avx2_inner(
        token, pic, left_rows, lpf_buf, lpf_off, width, height, params, edges,
    );
}
/// Returns the offset of `ptr` from the start of `lpf`'s storage, counted in
/// `u16` elements rather than bytes.
///
/// `lpf` is a byte buffer, so its base pointer is reinterpreted as
/// `*const u16` and the address difference is divided by 2.
///
/// NOTE(review): assumes `ptr` lies an even number of bytes from the base;
/// an odd distance would be truncated toward zero by the division.
fn reconstruct_lpf_offset_16bpc(lpf: &DisjointMut<AlignedVec64<u8>>, ptr: *const u16) -> isize {
    let base = lpf.as_mut_ptr().cast::<u16>();
    // The parentheses are required: `/` binds tighter than `-`, so without
    // them only the base address would be halved.
    (ptr as isize - base as isize) / 2
}
/// `extern "C"` entry point for the 7-tap, 16-bit Wiener filter (AVX2).
///
/// Converts the raw arguments into safe Rust values and calls
/// `wiener_filter7_16bpc_avx2_inner`. `_p_ptr` and `_stride` are ignored; the
/// destination is taken from `p` instead.
///
/// # Safety
/// - `p` and `lpf` must be pointers that `FFISafe::get` can turn back into
///   valid references.
/// - `left` must point to at least `h` readable `LeftPixelRow<u16>` values.
/// - `w` and `h` are cast with `as usize`, so they must be non-negative.
/// - The function is compiled with `avx2` enabled; the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn wiener_filter7_16bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: relies on the caller passing a valid `FFISafe<PicOffset>`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: relies on the caller passing a valid `FFISafe<DisjointMut<..>>`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset_16bpc(lpf_buf, lpf_ptr.cast::<u16>());
    // SAFETY: relies on `left` pointing to `height` initialised rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u16>>(), height) };
    // SAFETY: this function is only compiled with `avx2` enabled.
    // NOTE(review): confirm that is sufficient for forging a `Desktop64` token.
    let token = unsafe { Desktop64::forge_token_dangerously() };
    wiener_filter7_16bpc_avx2_inner(
        token, pic, left_rows, lpf_buf, lpf_off, width, height, params, edges, bitdepth_max,
    );
}
/// `extern "C"` entry point for the 5-tap, 16-bit Wiener filter (AVX2).
///
/// Converts the raw arguments into safe Rust values and calls
/// `wiener_filter5_16bpc_avx2_inner`. `_p_ptr` and `_stride` are ignored; the
/// destination is taken from `p` instead.
///
/// # Safety
/// - `p` and `lpf` must be pointers that `FFISafe::get` can turn back into
///   valid references.
/// - `left` must point to at least `h` readable `LeftPixelRow<u16>` values.
/// - `w` and `h` are cast with `as usize`, so they must be non-negative.
/// - The function is compiled with `avx2` enabled; the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn wiener_filter5_16bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: relies on the caller passing a valid `FFISafe<PicOffset>`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: relies on the caller passing a valid `FFISafe<DisjointMut<..>>`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset_16bpc(lpf_buf, lpf_ptr.cast::<u16>());
    // SAFETY: relies on `left` pointing to `height` initialised rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u16>>(), height) };
    // SAFETY: this function is only compiled with `avx2` enabled.
    // NOTE(review): confirm that is sufficient for forging a `Desktop64` token.
    let token = unsafe { Desktop64::forge_token_dangerously() };
    wiener_filter5_16bpc_avx2_inner(
        token, pic, left_rows, lpf_buf, lpf_off, width, height, params, edges, bitdepth_max,
    );
}
/// Row stride, in elements, of the self-guided filter's `dst` scratch buffer
/// (`256 * 3 / 2` = 384).
const MAX_RESTORATION_WIDTH: usize = 256 * 3 / 2;
/// Scalar 5x5 box sum for 8-bit pixels.
///
/// Fills `sum` with the sum, and `sumsq` with the sum of squares, of `src`
/// pixels over a 5x5 window. The work is split into a vertical pass (five
/// consecutive rows per column) followed by an in-place horizontal pass (five
/// consecutive columns per row). All buffers use `REST_UNIT_STRIDE` elements
/// per row.
///
/// After the call, for `row` in `1..h - 3` and `x` in `2..w - 2`, the entry at
/// `row * REST_UNIT_STRIDE + x` covers `src` rows `row - 1..=row + 3` and
/// columns `x - 2..=x + 2`.
#[inline(always)]
fn boxsum5_8bpc(
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Vertical pass: slide a five-row window down each column.
    for col in 0..w {
        let px = |row: usize| src[row * REST_UNIT_STRIDE + col] as i32;
        // The four most recent rows and their squares; the fifth is read per step.
        let mut win = [px(0), px(1), px(2), px(3)];
        let mut win_sq = [
            win[0] * win[0],
            win[1] * win[1],
            win[2] * win[2],
            win[3] * win[3],
        ];
        for step in 2..h - 2 {
            let newest = px(step + 2);
            let newest_sq = newest * newest;
            // Output row 0 is never written; results start at row 1.
            let out = (step - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (win[0] + win[1] + win[2] + win[3] + newest) as i16;
            sumsq[out] = win_sq[0] + win_sq[1] + win_sq[2] + win_sq[3] + newest_sq;
            win = [win[1], win[2], win[3], newest];
            win_sq = [win_sq[1], win_sq[2], win_sq[3], newest_sq];
        }
    }

    // Horizontal pass, in place: the window caches the four values to the left
    // of the newest column, so overwriting `sum[x]` never corrupts later reads.
    for step in 2..h - 2 {
        let row_off = (step - 1) * REST_UNIT_STRIDE;
        let mut s0 = sum[row_off];
        let mut q0 = sumsq[row_off];
        let mut s1 = sum[row_off + 1];
        let mut q1 = sumsq[row_off + 1];
        let mut s2 = sum[row_off + 2];
        let mut q2 = sumsq[row_off + 2];
        let mut s3 = sum[row_off + 3];
        let mut q3 = sumsq[row_off + 3];
        // The first two and last two columns are not stored.
        for x in 2..w - 2 {
            let s4 = sum[row_off + x + 2];
            let q4 = sumsq[row_off + x + 2];
            sum[row_off + x] = s0 + s1 + s2 + s3 + s4;
            sumsq[row_off + x] = q0 + q1 + q2 + q3 + q4;
            (s0, s1, s2, s3) = (s1, s2, s3, s4);
            (q0, q1, q2, q3) = (q1, q2, q3, q4);
        }
    }
}
/// Scalar 3x3 box sum for 8-bit pixels.
///
/// Fills `sum` with the sum, and `sumsq` with the sum of squares, of `src`
/// pixels over a 3x3 window, using a vertical pass followed by an in-place
/// horizontal pass. The first row of `src` is skipped before any work is done.
/// All buffers use `REST_UNIT_STRIDE` elements per row.
///
/// After the call, for `row` in `1..h - 3` and `x` in `2..w - 2`, the entry at
/// `row * REST_UNIT_STRIDE + x` covers `src` rows `row..=row + 2` and columns
/// `x - 1..=x + 1`.
#[inline(always)]
fn boxsum3_8bpc(
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Drop the first source row; every index below is relative to the second.
    let rows = &src[REST_UNIT_STRIDE..];

    // Vertical pass over every column except the first and the last.
    for col in 1..w - 1 {
        let px = |row: usize| rows[row * REST_UNIT_STRIDE + col] as i32;
        let mut older = px(0);
        let mut older_sq = older * older;
        let mut newer = px(1);
        let mut newer_sq = newer * newer;
        for step in 2..h - 2 {
            let cur = px(step);
            let cur_sq = cur * cur;
            // Output row 0 is never written; results start at row 1.
            let out = (step - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (older + newer + cur) as i16;
            sumsq[out] = older_sq + newer_sq + cur_sq;
            (older, newer) = (newer, cur);
            (older_sq, newer_sq) = (newer_sq, cur_sq);
        }
    }

    // Horizontal pass, in place: the two values left of the newest column are
    // cached so overwriting `sum[x]` never corrupts later reads.
    for step in 2..h - 2 {
        let row_off = (step - 1) * REST_UNIT_STRIDE;
        let mut s0 = sum[row_off + 1];
        let mut q0 = sumsq[row_off + 1];
        let mut s1 = sum[row_off + 2];
        let mut q1 = sumsq[row_off + 2];
        // The first two and last two columns are not stored.
        for x in 2..w - 2 {
            let s2 = sum[row_off + x + 1];
            let q2 = sumsq[row_off + x + 1];
            sum[row_off + x] = s0 + s1 + s2;
            sumsq[row_off + x] = q0 + q1 + q2;
            (s0, s1) = (s1, s2);
            (q0, q1) = (q1, q2);
        }
    }
}
#[inline(never)]
fn selfguided_filter_8bpc(
dst: &mut [i16; 64 * MAX_RESTORATION_WIDTH],
src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
w: usize,
h: usize,
n: i32,
s: u32,
) {
let sgr_one_by_x: u32 = if n == 25 { 164 } else { 455 };
let mut sumsq = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut sum = [0i16; (64 + 2 + 2) * REST_UNIT_STRIDE];
let step = if n == 25 { 2 } else { 1 };
if n == 25 {
boxsum5_8bpc(&mut sumsq, &mut sum, src, w + 6, h + 6);
} else {
boxsum3_8bpc(&mut sumsq, &mut sum, src, w + 6, h + 6);
}
{
let mut sq = sumsq.as_mut_slice().flex_mut();
let mut sm = sum.as_mut_slice().flex_mut();
for row_offset in (0..(h + 2)).step_by(step) {
let aa_base = (row_offset + 1) * REST_UNIT_STRIDE + 2;
for i in 0..(w + 2) {
let idx = aa_base + i;
let a_val = sq[idx];
let b_val = sm[idx] as i32;
let p = cmp::max(a_val * n - b_val * b_val, 0) as u32;
let z = (p * s + (1 << 19)) >> 20;
let x = dav1d_sgr_x_by_x[cmp::min(z, 255) as usize] as u32;
sq[idx] = ((x * (b_val as u32) * sgr_one_by_x + (1 << 11)) >> 12) as i32;
sm[idx] = x as i16;
}
}
}
let base = 2 * REST_UNIT_STRIDE + 3; let src_base = 3 * REST_UNIT_STRIDE + 3;
let bb = sum.as_slice().flex();
let aa = sumsq.as_slice().flex();
let src = src.as_slice().flex();
let mut dst = dst.as_mut_slice().flex_mut();
if n == 25 {
let mut j = 0usize;
while j < h.saturating_sub(1) {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb[idx - REST_UNIT_STRIDE] as i32;
let below = bb[idx + REST_UNIT_STRIDE] as i32;
let above_left = bb[idx - REST_UNIT_STRIDE - 1] as i32;
let above_right = bb[idx - REST_UNIT_STRIDE + 1] as i32;
let below_left = bb[idx + REST_UNIT_STRIDE - 1] as i32;
let below_right = bb[idx + REST_UNIT_STRIDE + 1] as i32;
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let a_six = {
let above = aa[idx - REST_UNIT_STRIDE];
let below = aa[idx + REST_UNIT_STRIDE];
let above_left = aa[idx - REST_UNIT_STRIDE - 1];
let above_right = aa[idx - REST_UNIT_STRIDE + 1];
let below_left = aa[idx + REST_UNIT_STRIDE - 1];
let below_right = aa[idx + REST_UNIT_STRIDE + 1];
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i16;
}
if j + 1 < h {
for i in 0..w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let b_horiz = {
let center = bb[idx] as i32;
let left = bb[idx - 1] as i32;
let right = bb[idx + 1] as i32;
center * 6 + (left + right) * 5
};
let a_horiz = {
let center = aa[idx];
let left = aa[idx - 1];
let right = aa[idx + 1];
center * 6 + (left + right) * 5
};
let src_val = src[src_base + (j + 1) * REST_UNIT_STRIDE + i] as i32;
dst[(j + 1) * MAX_RESTORATION_WIDTH + i] =
((a_horiz - b_horiz * src_val + (1 << 7)) >> 8) as i16;
}
}
j += 2;
}
if j < h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb[idx - REST_UNIT_STRIDE] as i32;
let below = bb[idx + REST_UNIT_STRIDE] as i32;
let above_left = bb[idx - REST_UNIT_STRIDE - 1] as i32;
let above_right = bb[idx - REST_UNIT_STRIDE + 1] as i32;
let below_left = bb[idx + REST_UNIT_STRIDE - 1] as i32;
let below_right = bb[idx + REST_UNIT_STRIDE + 1] as i32;
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let a_six = {
let above = aa[idx - REST_UNIT_STRIDE];
let below = aa[idx + REST_UNIT_STRIDE];
let above_left = aa[idx - REST_UNIT_STRIDE - 1];
let above_right = aa[idx - REST_UNIT_STRIDE + 1];
let below_left = aa[idx + REST_UNIT_STRIDE - 1];
let below_right = aa[idx + REST_UNIT_STRIDE + 1];
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i16;
}
}
} else {
for j in 0..h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_eight = {
let center = bb[idx] as i32;
let left = bb[idx - 1] as i32;
let right = bb[idx + 1] as i32;
let above = bb[idx - REST_UNIT_STRIDE] as i32;
let below = bb[idx + REST_UNIT_STRIDE] as i32;
let above_left = bb[idx - REST_UNIT_STRIDE - 1] as i32;
let above_right = bb[idx - REST_UNIT_STRIDE + 1] as i32;
let below_left = bb[idx + REST_UNIT_STRIDE - 1] as i32;
let below_right = bb[idx + REST_UNIT_STRIDE + 1] as i32;
(center + left + right + above + below) * 4
+ (above_left + above_right + below_left + below_right) * 3
};
let a_eight = {
let center = aa[idx];
let left = aa[idx - 1];
let right = aa[idx + 1];
let above = aa[idx - REST_UNIT_STRIDE];
let below = aa[idx + REST_UNIT_STRIDE];
let above_left = aa[idx - REST_UNIT_STRIDE - 1];
let above_right = aa[idx - REST_UNIT_STRIDE + 1];
let below_left = aa[idx + REST_UNIT_STRIDE - 1];
let below_right = aa[idx + REST_UNIT_STRIDE + 1];
(center + left + right + above + below) * 4
+ (above_left + above_right + below_left + below_right) * 3
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_eight - b_eight * src_val + (1 << 8)) >> 9) as i16;
}
}
}
}
/// Vertical half of the 5x5 box sum, AVX2 version.
///
/// For every column in `0..w` and every `out_row` in `2..h - 2`, adds the five
/// `src` rows `out_row - 2..=out_row + 2` (and their squares) and stores the
/// result in row `out_row - 1` of `sum` / `sumsq`. Columns are processed 16 at
/// a time while a full group fits in `w`; remaining columns use a scalar
/// sliding window.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum5_v_avx2(
    _token: Desktop64,
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut x = 0usize;
    while x + 16 <= w {
        for out_row in 2..h - 2 {
            // Offsets of the five source rows feeding this output row.
            let o0 = (out_row - 2) * REST_UNIT_STRIDE + x;
            let o1 = o0 + REST_UNIT_STRIDE;
            let o2 = o1 + REST_UNIT_STRIDE;
            let o3 = o2 + REST_UNIT_STRIDE;
            let o4 = o3 + REST_UNIT_STRIDE;
            let px0 = loadu_128!(&src[o0..o0 + 16], [u8; 16]);
            let px1 = loadu_128!(&src[o1..o1 + 16], [u8; 16]);
            let px2 = loadu_128!(&src[o2..o2 + 16], [u8; 16]);
            let px3 = loadu_128!(&src[o3..o3 + 16], [u8; 16]);
            let px4 = loadu_128!(&src[o4..o4 + 16], [u8; 16]);

            // Plain sums: widen all 16 pixels to 16 bits and add the rows.
            let pair01 = _mm256_add_epi16(_mm256_cvtepu8_epi16(px0), _mm256_cvtepu8_epi16(px1));
            let pair23 = _mm256_add_epi16(_mm256_cvtepu8_epi16(px2), _mm256_cvtepu8_epi16(px3));
            let total = _mm256_add_epi16(
                _mm256_add_epi16(pair01, pair23),
                _mm256_cvtepu8_epi16(px4),
            );
            let dst_off = (out_row - 1) * REST_UNIT_STRIDE + x;
            storeu_256!(&mut sum[dst_off..dst_off + 16], [i16; 16], total);

            // Sums of squares need 32 bits, so handle the low and the high
            // eight pixels separately.
            let halves = [
                (dst_off, [px0, px1, px2, px3, px4]),
                (
                    dst_off + 8,
                    [
                        _mm_srli_si128::<8>(px0),
                        _mm_srli_si128::<8>(px1),
                        _mm_srli_si128::<8>(px2),
                        _mm_srli_si128::<8>(px3),
                        _mm_srli_si128::<8>(px4),
                    ],
                ),
            ];
            for (sq_off, [p0, p1, p2, p3, p4]) in halves {
                let d0 = _mm256_cvtepu8_epi32(p0);
                let d1 = _mm256_cvtepu8_epi32(p1);
                let d2 = _mm256_cvtepu8_epi32(p2);
                let d3 = _mm256_cvtepu8_epi32(p3);
                let d4 = _mm256_cvtepu8_epi32(p4);
                let sq01 =
                    _mm256_add_epi32(_mm256_mullo_epi32(d0, d0), _mm256_mullo_epi32(d1, d1));
                let sq23 =
                    _mm256_add_epi32(_mm256_mullo_epi32(d2, d2), _mm256_mullo_epi32(d3, d3));
                let sq_total =
                    _mm256_add_epi32(_mm256_add_epi32(sq01, sq23), _mm256_mullo_epi32(d4, d4));
                storeu_256!(&mut sumsq[sq_off..sq_off + 8], [i32; 8], sq_total);
            }
        }
        x += 16;
    }

    // Scalar tail for the columns that did not fill a 16-wide group.
    for col in x..w {
        let mut win = [
            src[col] as i32,
            src[REST_UNIT_STRIDE + col] as i32,
            src[2 * REST_UNIT_STRIDE + col] as i32,
            src[3 * REST_UNIT_STRIDE + col] as i32,
        ];
        let mut win_sq = [
            win[0] * win[0],
            win[1] * win[1],
            win[2] * win[2],
            win[3] * win[3],
        ];
        for out_row in 2..h - 2 {
            let newest = src[(out_row + 2) * REST_UNIT_STRIDE + col] as i32;
            let newest_sq = newest * newest;
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (win[0] + win[1] + win[2] + win[3] + newest) as i16;
            sumsq[out] = win_sq[0] + win_sq[1] + win_sq[2] + win_sq[3] + newest_sq;
            win = [win[1], win[2], win[3], newest];
            win_sq = [win_sq[1], win_sq[2], win_sq[3], newest_sq];
        }
    }
}
/// Horizontal half of the 5x5 box sum, AVX2 version.
///
/// For every row in `1..h - 3`, replaces columns `2..w - 2` of `sum` and
/// `sumsq` with the sum of the five entries centred on each column. Results
/// are first built in per-row scratch buffers and then copied back, so the
/// inputs of a row are never overwritten while they are still being read.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum5_h_avx2(
    _token: Desktop64,
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut sum_row = [0i16; REST_UNIT_STRIDE];
    let mut sumsq_row = [0i32; REST_UNIT_STRIDE];
    for row in 1..h - 3 {
        let row_off = row * REST_UNIT_STRIDE;
        let end = w - 2;
        let mut x = 2usize;

        // 16 output columns per iteration.
        while x + 16 <= end {
            // Leftmost tap of the window for column `x`.
            let lo = row_off + x - 2;
            let t0 = loadu_256!(&sum[lo..lo + 16], [i16; 16]);
            let t1 = loadu_256!(&sum[lo + 1..lo + 1 + 16], [i16; 16]);
            let t2 = loadu_256!(&sum[lo + 2..lo + 2 + 16], [i16; 16]);
            let t3 = loadu_256!(&sum[lo + 3..lo + 3 + 16], [i16; 16]);
            let t4 = loadu_256!(&sum[lo + 4..lo + 4 + 16], [i16; 16]);
            let t01 = _mm256_add_epi16(t0, t1);
            let t23 = _mm256_add_epi16(t2, t3);
            storeu_256!(
                &mut sum_row[x..x + 16],
                [i16; 16],
                _mm256_add_epi16(_mm256_add_epi16(t01, t23), t4)
            );

            // The 32-bit sums of squares take two 8-lane steps.
            for half in [0usize, 8] {
                let qlo = lo + half;
                let u0 = loadu_256!(&sumsq[qlo..qlo + 8], [i32; 8]);
                let u1 = loadu_256!(&sumsq[qlo + 1..qlo + 1 + 8], [i32; 8]);
                let u2 = loadu_256!(&sumsq[qlo + 2..qlo + 2 + 8], [i32; 8]);
                let u3 = loadu_256!(&sumsq[qlo + 3..qlo + 3 + 8], [i32; 8]);
                let u4 = loadu_256!(&sumsq[qlo + 4..qlo + 4 + 8], [i32; 8]);
                let u01 = _mm256_add_epi32(u0, u1);
                let u23 = _mm256_add_epi32(u2, u3);
                storeu_256!(
                    &mut sumsq_row[x + half..x + half + 8],
                    [i32; 8],
                    _mm256_add_epi32(_mm256_add_epi32(u01, u23), u4)
                );
            }
            x += 16;
        }

        // Scalar tail.
        for col in x..end {
            let c = row_off + col;
            sum_row[col] = sum[c - 2] + sum[c - 1] + sum[c] + sum[c + 1] + sum[c + 2];
            sumsq_row[col] =
                sumsq[c - 2] + sumsq[c - 1] + sumsq[c] + sumsq[c + 1] + sumsq[c + 2];
        }

        sum[row_off + 2..row_off + end].copy_from_slice(&sum_row[2..end]);
        sumsq[row_off + 2..row_off + end].copy_from_slice(&sumsq_row[2..end]);
    }
}
/// Vertical half of the 3x3 box sum, AVX2 version.
///
/// The first row of `src` is skipped. For every column in `1..w - 1` and every
/// `out_row` in `2..h - 2`, adds the three remaining-source rows
/// `out_row - 2..=out_row` (and their squares) and stores the result in row
/// `out_row - 1` of `sum` / `sumsq`. Columns are processed 16 at a time while
/// a full group stays below `w`; remaining columns use a scalar sliding window.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum3_v_avx2(
    _token: Desktop64,
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Every index below is relative to the second source row.
    let rows = &src[REST_UNIT_STRIDE..];
    let mut x = 1usize;
    while x + 16 < w {
        for out_row in 2..h - 2 {
            let o0 = (out_row - 2) * REST_UNIT_STRIDE + x;
            let o1 = o0 + REST_UNIT_STRIDE;
            let o2 = o1 + REST_UNIT_STRIDE;
            let px0 = loadu_128!(&rows[o0..o0 + 16], [u8; 16]);
            let px1 = loadu_128!(&rows[o1..o1 + 16], [u8; 16]);
            let px2 = loadu_128!(&rows[o2..o2 + 16], [u8; 16]);

            // Plain sums: widen all 16 pixels to 16 bits and add the rows.
            let pair01 = _mm256_add_epi16(_mm256_cvtepu8_epi16(px0), _mm256_cvtepu8_epi16(px1));
            let total = _mm256_add_epi16(pair01, _mm256_cvtepu8_epi16(px2));
            let dst_off = (out_row - 1) * REST_UNIT_STRIDE + x;
            storeu_256!(&mut sum[dst_off..dst_off + 16], [i16; 16], total);

            // Sums of squares need 32 bits: low eight pixels, then high eight.
            let halves = [
                (dst_off, [px0, px1, px2]),
                (
                    dst_off + 8,
                    [
                        _mm_srli_si128::<8>(px0),
                        _mm_srli_si128::<8>(px1),
                        _mm_srli_si128::<8>(px2),
                    ],
                ),
            ];
            for (sq_off, [p0, p1, p2]) in halves {
                let d0 = _mm256_cvtepu8_epi32(p0);
                let d1 = _mm256_cvtepu8_epi32(p1);
                let d2 = _mm256_cvtepu8_epi32(p2);
                let sq01 =
                    _mm256_add_epi32(_mm256_mullo_epi32(d0, d0), _mm256_mullo_epi32(d1, d1));
                let sq_total = _mm256_add_epi32(sq01, _mm256_mullo_epi32(d2, d2));
                storeu_256!(&mut sumsq[sq_off..sq_off + 8], [i32; 8], sq_total);
            }
        }
        x += 16;
    }

    // Scalar tail for the remaining columns (the last column is never needed).
    for col in x..w - 1 {
        let mut older = rows[col] as i32;
        let mut older_sq = older * older;
        let mut newer = rows[REST_UNIT_STRIDE + col] as i32;
        let mut newer_sq = newer * newer;
        for out_row in 2..h - 2 {
            let cur = rows[out_row * REST_UNIT_STRIDE + col] as i32;
            let cur_sq = cur * cur;
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (older + newer + cur) as i16;
            sumsq[out] = older_sq + newer_sq + cur_sq;
            (older, newer) = (newer, cur);
            (older_sq, newer_sq) = (newer_sq, cur_sq);
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum3_h_avx2(
_token: Desktop64,
sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
w: usize,
h: usize,
) {
let mut sum_tmp = [0i16; REST_UNIT_STRIDE];
let mut sumsq_tmp = [0i32; REST_UNIT_STRIDE];
for row in 1..h - 3 {
let row_off = row * REST_UNIT_STRIDE;
let mut x = 2usize;
while x + 16 <= w - 2 {
let s0 = loadu_256!(&sum[row_off + x - 1..row_off + x - 1 + 16], [i16; 16]);
let s1 = loadu_256!(&sum[row_off + x..row_off + x + 16], [i16; 16]);
let s2 = loadu_256!(&sum[row_off + x + 1..row_off + x + 1 + 16], [i16; 16]);
let hsum = _mm256_add_epi16(_mm256_add_epi16(s0, s1), s2);
storeu_256!(&mut sum_tmp[x..x + 16], [i16; 16], hsum);
for off in [0usize, 8] {
let q0 = loadu_256!(
&sumsq[row_off + x + off - 1..row_off + x + off - 1 + 8],
[i32; 8]
);
let q1 = loadu_256!(&sumsq[row_off + x + off..row_off + x + off + 8], [i32; 8]);
let q2 = loadu_256!(
&sumsq[row_off + x + off + 1..row_off + x + off + 1 + 8],
[i32; 8]
);
let hsumsq = _mm256_add_epi32(_mm256_add_epi32(q0, q1), q2);
storeu_256!(&mut sumsq_tmp[x + off..x + off + 8], [i32; 8], hsumsq);
}
x += 16;
}
while x < w - 2 {
sum_tmp[x] = sum[row_off + x - 1] + sum[row_off + x] + sum[row_off + x + 1];
sumsq_tmp[x] = sumsq[row_off + x - 1] + sumsq[row_off + x] + sumsq[row_off + x + 1];
x += 1;
}
sum[row_off + 2..row_off + w - 2].copy_from_slice(&sum_tmp[2..w - 2]);
sumsq[row_off + 2..row_off + w - 2].copy_from_slice(&sumsq_tmp[2..w - 2]);
}
}
/// Vertical half of the 5x5 box sum, AVX-512 version.
///
/// For every column in `0..w` and every `out_row` in `2..h - 2`, adds the five
/// `src` rows `out_row - 2..=out_row + 2` (and their squares) and stores the
/// result in row `out_row - 1` of `sum` / `sumsq`. Columns are processed 32 at
/// a time while a full group fits in `w`; remaining columns use a scalar
/// sliding window.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum5_v_avx512(
    _token: Server64,
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut x = 0usize;
    while x + 32 <= w {
        for out_row in 2..h - 2 {
            // Offsets of the five source rows feeding this output row.
            let o0 = (out_row - 2) * REST_UNIT_STRIDE + x;
            let o1 = o0 + REST_UNIT_STRIDE;
            let o2 = o1 + REST_UNIT_STRIDE;
            let o3 = o2 + REST_UNIT_STRIDE;
            let o4 = o3 + REST_UNIT_STRIDE;
            let px0 = loadu_256!(&src[o0..o0 + 32], [u8; 32]);
            let px1 = loadu_256!(&src[o1..o1 + 32], [u8; 32]);
            let px2 = loadu_256!(&src[o2..o2 + 32], [u8; 32]);
            let px3 = loadu_256!(&src[o3..o3 + 32], [u8; 32]);
            let px4 = loadu_256!(&src[o4..o4 + 32], [u8; 32]);

            // Plain sums: widen all 32 pixels to 16 bits and add the rows.
            let pair01 = _mm512_add_epi16(_mm512_cvtepu8_epi16(px0), _mm512_cvtepu8_epi16(px1));
            let pair23 = _mm512_add_epi16(_mm512_cvtepu8_epi16(px2), _mm512_cvtepu8_epi16(px3));
            let total = _mm512_add_epi16(
                _mm512_add_epi16(pair01, pair23),
                _mm512_cvtepu8_epi16(px4),
            );
            let dst_off = (out_row - 1) * REST_UNIT_STRIDE + x;
            storeu_512!(&mut sum[dst_off..dst_off + 32], [i16; 32], total);

            // Sums of squares need 32 bits, so handle the low and the high
            // sixteen pixels separately.
            let halves = [
                (
                    dst_off,
                    [
                        _mm256_castsi256_si128(px0),
                        _mm256_castsi256_si128(px1),
                        _mm256_castsi256_si128(px2),
                        _mm256_castsi256_si128(px3),
                        _mm256_castsi256_si128(px4),
                    ],
                ),
                (
                    dst_off + 16,
                    [
                        _mm256_extracti128_si256::<1>(px0),
                        _mm256_extracti128_si256::<1>(px1),
                        _mm256_extracti128_si256::<1>(px2),
                        _mm256_extracti128_si256::<1>(px3),
                        _mm256_extracti128_si256::<1>(px4),
                    ],
                ),
            ];
            for (sq_off, [p0, p1, p2, p3, p4]) in halves {
                let d0 = _mm512_cvtepu8_epi32(p0);
                let d1 = _mm512_cvtepu8_epi32(p1);
                let d2 = _mm512_cvtepu8_epi32(p2);
                let d3 = _mm512_cvtepu8_epi32(p3);
                let d4 = _mm512_cvtepu8_epi32(p4);
                let sq01 =
                    _mm512_add_epi32(_mm512_mullo_epi32(d0, d0), _mm512_mullo_epi32(d1, d1));
                let sq23 =
                    _mm512_add_epi32(_mm512_mullo_epi32(d2, d2), _mm512_mullo_epi32(d3, d3));
                let sq_total =
                    _mm512_add_epi32(_mm512_add_epi32(sq01, sq23), _mm512_mullo_epi32(d4, d4));
                storeu_512!(&mut sumsq[sq_off..sq_off + 16], [i32; 16], sq_total);
            }
        }
        x += 32;
    }

    // Scalar tail for the columns that did not fill a 32-wide group.
    for col in x..w {
        let mut win = [
            src[col] as i32,
            src[REST_UNIT_STRIDE + col] as i32,
            src[2 * REST_UNIT_STRIDE + col] as i32,
            src[3 * REST_UNIT_STRIDE + col] as i32,
        ];
        let mut win_sq = [
            win[0] * win[0],
            win[1] * win[1],
            win[2] * win[2],
            win[3] * win[3],
        ];
        for out_row in 2..h - 2 {
            let newest = src[(out_row + 2) * REST_UNIT_STRIDE + col] as i32;
            let newest_sq = newest * newest;
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (win[0] + win[1] + win[2] + win[3] + newest) as i16;
            sumsq[out] = win_sq[0] + win_sq[1] + win_sq[2] + win_sq[3] + newest_sq;
            win = [win[1], win[2], win[3], newest];
            win_sq = [win_sq[1], win_sq[2], win_sq[3], newest_sq];
        }
    }
}
/// Horizontal half of the 5x5 box sum, AVX-512 version.
///
/// For every row in `1..h - 3`, replaces columns `2..w - 2` of `sum` and
/// `sumsq` with the sum of the five entries centred on each column. Results
/// are first built in per-row scratch buffers and then copied back, so the
/// inputs of a row are never overwritten while they are still being read.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum5_h_avx512(
    _token: Server64,
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut sum_row = [0i16; REST_UNIT_STRIDE];
    let mut sumsq_row = [0i32; REST_UNIT_STRIDE];
    for row in 1..h - 3 {
        let row_off = row * REST_UNIT_STRIDE;
        let end = w - 2;
        let mut x = 2usize;

        // 32 output columns per iteration.
        while x + 32 <= end {
            // Leftmost tap of the window for column `x`.
            let lo = row_off + x - 2;
            let t0 = loadu_512!(&sum[lo..lo + 32], [i16; 32]);
            let t1 = loadu_512!(&sum[lo + 1..lo + 1 + 32], [i16; 32]);
            let t2 = loadu_512!(&sum[lo + 2..lo + 2 + 32], [i16; 32]);
            let t3 = loadu_512!(&sum[lo + 3..lo + 3 + 32], [i16; 32]);
            let t4 = loadu_512!(&sum[lo + 4..lo + 4 + 32], [i16; 32]);
            let t01 = _mm512_add_epi16(t0, t1);
            let t23 = _mm512_add_epi16(t2, t3);
            storeu_512!(
                &mut sum_row[x..x + 32],
                [i16; 32],
                _mm512_add_epi16(_mm512_add_epi16(t01, t23), t4)
            );

            // The 32-bit sums of squares take two 16-lane steps.
            for half in [0usize, 16] {
                let qlo = lo + half;
                let u0 = loadu_512!(&sumsq[qlo..qlo + 16], [i32; 16]);
                let u1 = loadu_512!(&sumsq[qlo + 1..qlo + 1 + 16], [i32; 16]);
                let u2 = loadu_512!(&sumsq[qlo + 2..qlo + 2 + 16], [i32; 16]);
                let u3 = loadu_512!(&sumsq[qlo + 3..qlo + 3 + 16], [i32; 16]);
                let u4 = loadu_512!(&sumsq[qlo + 4..qlo + 4 + 16], [i32; 16]);
                let u01 = _mm512_add_epi32(u0, u1);
                let u23 = _mm512_add_epi32(u2, u3);
                storeu_512!(
                    &mut sumsq_row[x + half..x + half + 16],
                    [i32; 16],
                    _mm512_add_epi32(_mm512_add_epi32(u01, u23), u4)
                );
            }
            x += 32;
        }

        // Scalar tail.
        for col in x..end {
            let c = row_off + col;
            sum_row[col] = sum[c - 2] + sum[c - 1] + sum[c] + sum[c + 1] + sum[c + 2];
            sumsq_row[col] =
                sumsq[c - 2] + sumsq[c - 1] + sumsq[c] + sumsq[c + 1] + sumsq[c + 2];
        }

        sum[row_off + 2..row_off + end].copy_from_slice(&sum_row[2..end]);
        sumsq[row_off + 2..row_off + end].copy_from_slice(&sumsq_row[2..end]);
    }
}
/// Vertical half of the 3x3 box sum, AVX-512 version.
///
/// The first row of `src` is skipped. For every column in `1..w - 1` and every
/// `out_row` in `2..h - 2`, adds the three remaining-source rows
/// `out_row - 2..=out_row` (and their squares) and stores the result in row
/// `out_row - 1` of `sum` / `sumsq`. Columns are processed 32 at a time while
/// a full group stays below `w`; remaining columns use a scalar sliding window.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum3_v_avx512(
    _token: Server64,
    sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Every index below is relative to the second source row.
    let rows = &src[REST_UNIT_STRIDE..];
    let mut x = 1usize;
    while x + 32 < w {
        for out_row in 2..h - 2 {
            let o0 = (out_row - 2) * REST_UNIT_STRIDE + x;
            let o1 = o0 + REST_UNIT_STRIDE;
            let o2 = o1 + REST_UNIT_STRIDE;
            let px0 = loadu_256!(&rows[o0..o0 + 32], [u8; 32]);
            let px1 = loadu_256!(&rows[o1..o1 + 32], [u8; 32]);
            let px2 = loadu_256!(&rows[o2..o2 + 32], [u8; 32]);

            // Plain sums: widen all 32 pixels to 16 bits and add the rows.
            let pair01 = _mm512_add_epi16(_mm512_cvtepu8_epi16(px0), _mm512_cvtepu8_epi16(px1));
            let total = _mm512_add_epi16(pair01, _mm512_cvtepu8_epi16(px2));
            let dst_off = (out_row - 1) * REST_UNIT_STRIDE + x;
            storeu_512!(&mut sum[dst_off..dst_off + 32], [i16; 32], total);

            // Sums of squares need 32 bits: low sixteen pixels, then high sixteen.
            let halves = [
                (
                    dst_off,
                    [
                        _mm256_castsi256_si128(px0),
                        _mm256_castsi256_si128(px1),
                        _mm256_castsi256_si128(px2),
                    ],
                ),
                (
                    dst_off + 16,
                    [
                        _mm256_extracti128_si256::<1>(px0),
                        _mm256_extracti128_si256::<1>(px1),
                        _mm256_extracti128_si256::<1>(px2),
                    ],
                ),
            ];
            for (sq_off, [p0, p1, p2]) in halves {
                let d0 = _mm512_cvtepu8_epi32(p0);
                let d1 = _mm512_cvtepu8_epi32(p1);
                let d2 = _mm512_cvtepu8_epi32(p2);
                let sq01 =
                    _mm512_add_epi32(_mm512_mullo_epi32(d0, d0), _mm512_mullo_epi32(d1, d1));
                let sq_total = _mm512_add_epi32(sq01, _mm512_mullo_epi32(d2, d2));
                storeu_512!(&mut sumsq[sq_off..sq_off + 16], [i32; 16], sq_total);
            }
        }
        x += 32;
    }

    // Scalar tail for the remaining columns (the last column is never needed).
    for col in x..w - 1 {
        let mut older = rows[col] as i32;
        let mut older_sq = older * older;
        let mut newer = rows[REST_UNIT_STRIDE + col] as i32;
        let mut newer_sq = newer * newer;
        for out_row in 2..h - 2 {
            let cur = rows[out_row * REST_UNIT_STRIDE + col] as i32;
            let cur_sq = cur * cur;
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (older + newer + cur) as i16;
            sumsq[out] = older_sq + newer_sq + cur_sq;
            (older, newer) = (newer, cur);
            (older_sq, newer_sq) = (newer_sq, cur_sq);
        }
    }
}
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum3_h_avx512(
_token: Server64,
sumsq: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
sum: &mut [i16; (64 + 2 + 2) * REST_UNIT_STRIDE],
w: usize,
h: usize,
) {
let mut sum_tmp = [0i16; REST_UNIT_STRIDE];
let mut sumsq_tmp = [0i32; REST_UNIT_STRIDE];
for row in 1..h - 3 {
let row_off = row * REST_UNIT_STRIDE;
let mut x = 2usize;
while x + 32 <= w - 2 {
let s0 = loadu_512!(&sum[row_off + x - 1..row_off + x - 1 + 32], [i16; 32]);
let s1 = loadu_512!(&sum[row_off + x..row_off + x + 32], [i16; 32]);
let s2 = loadu_512!(&sum[row_off + x + 1..row_off + x + 1 + 32], [i16; 32]);
let hsum = _mm512_add_epi16(_mm512_add_epi16(s0, s1), s2);
storeu_512!(&mut sum_tmp[x..x + 32], [i16; 32], hsum);
for off in [0usize, 16] {
let q0 = loadu_512!(
&sumsq[row_off + x + off - 1..row_off + x + off - 1 + 16],
[i32; 16]
);
let q1 = loadu_512!(&sumsq[row_off + x + off..row_off + x + off + 16], [i32; 16]);
let q2 = loadu_512!(
&sumsq[row_off + x + off + 1..row_off + x + off + 1 + 16],
[i32; 16]
);
let hsumsq = _mm512_add_epi32(_mm512_add_epi32(q0, q1), q2);
storeu_512!(&mut sumsq_tmp[x + off..x + off + 16], [i32; 16], hsumsq);
}
x += 32;
}
while x < w - 2 {
sum_tmp[x] = sum[row_off + x - 1] + sum[row_off + x] + sum[row_off + x + 1];
sumsq_tmp[x] = sumsq[row_off + x - 1] + sumsq[row_off + x] + sumsq[row_off + x + 1];
x += 1;
}
sum[row_off + 2..row_off + w - 2].copy_from_slice(&sum_tmp[2..w - 2]);
sumsq[row_off + 2..row_off + w - 2].copy_from_slice(&sumsq_tmp[2..w - 2]);
}
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn selfguided_filter_8bpc_avx2(
_token: Desktop64,
dst: &mut [i16; 64 * MAX_RESTORATION_WIDTH],
src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
w: usize,
h: usize,
n: i32,
s: u32,
) {
let sgr_one_by_x: u32 = if n == 25 { 164 } else { 455 };
let mut sumsq = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut sum = [0i16; (64 + 2 + 2) * REST_UNIT_STRIDE];
let step = if n == 25 { 2 } else { 1 };
if n == 25 {
boxsum5_v_avx2(_token, &mut sumsq, &mut sum, src, w + 6, h + 6);
boxsum5_h_avx2(_token, &mut sumsq, &mut sum, w + 6, h + 6);
} else {
boxsum3_v_avx2(_token, &mut sumsq, &mut sum, src, w + 6, h + 6);
boxsum3_h_avx2(_token, &mut sumsq, &mut sum, w + 6, h + 6);
}
for row_offset in (0..(h + 2)).step_by(step) {
let aa_base = (row_offset + 1) * REST_UNIT_STRIDE + 2;
for i in 0..(w + 2) {
let idx = aa_base + i;
let a_val = sumsq[idx];
let b_val = sum[idx] as i32;
let p = cmp::max(a_val * n - b_val * b_val, 0) as u32;
let z = (p * s + (1 << 19)) >> 20;
let x = dav1d_sgr_x_by_x[cmp::min(z, 255) as usize] as u32;
sumsq[idx] = ((x * (b_val as u32) * sgr_one_by_x + (1 << 11)) >> 12) as i32;
sum[idx] = x as i16;
}
}
let base = 2 * REST_UNIT_STRIDE + 3;
let src_base = 3 * REST_UNIT_STRIDE + 3;
let rounding_9 = _mm256_set1_epi32(1 << 8);
let rounding_8 = _mm256_set1_epi32(1 << 7);
let six = _mm256_set1_epi32(6);
let five = _mm256_set1_epi32(5);
let four = _mm256_set1_epi32(4);
let three = _mm256_set1_epi32(3);
if n == 25 {
let mut j = 0usize;
while j < h.saturating_sub(1) {
let mut i = 0usize;
while i + 8 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let sum_above = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i16; 8]
));
let sum_below = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i16; 8]
));
let sum_al = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i16; 8]
));
let sum_ar = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i16; 8]
));
let sum_bl = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i16; 8]
));
let sum_br = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i16; 8]
));
let b_six = _mm256_add_epi32(
_mm256_mullo_epi32(_mm256_add_epi32(sum_above, sum_below), six),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(sum_al, sum_ar),
_mm256_add_epi32(sum_bl, sum_br),
),
five,
),
);
let sq_above = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let sq_below = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let sq_al = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let sq_ar = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let sq_bl = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let sq_br = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let a_six = _mm256_add_epi32(
_mm256_mullo_epi32(_mm256_add_epi32(sq_above, sq_below), six),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(sq_al, sq_ar),
_mm256_add_epi32(sq_bl, sq_br),
),
five,
),
);
let src_val = _mm256_cvtepu8_epi32(loadi64!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 8]
));
let result = _mm256_srai_epi32::<9>(_mm256_add_epi32(
_mm256_sub_epi32(a_six, _mm256_mullo_epi32(b_six, src_val)),
rounding_9,
));
let result_16 = _mm256_packs_epi32(result, _mm256_setzero_si256());
let result_16 = _mm256_permute4x64_epi64::<0xD8>(result_16);
storeu_128!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 8],
[i16; 8],
_mm256_castsi256_si128(result_16)
);
i += 8;
}
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = sum[idx - REST_UNIT_STRIDE] as i32;
let below = sum[idx + REST_UNIT_STRIDE] as i32;
let al = sum[idx - REST_UNIT_STRIDE - 1] as i32;
let ar = sum[idx - REST_UNIT_STRIDE + 1] as i32;
let bl = sum[idx + REST_UNIT_STRIDE - 1] as i32;
let br = sum[idx + REST_UNIT_STRIDE + 1] as i32;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = sumsq[idx - REST_UNIT_STRIDE];
let below = sumsq[idx + REST_UNIT_STRIDE];
let al = sumsq[idx - REST_UNIT_STRIDE - 1];
let ar = sumsq[idx - REST_UNIT_STRIDE + 1];
let bl = sumsq[idx + REST_UNIT_STRIDE - 1];
let br = sumsq[idx + REST_UNIT_STRIDE + 1];
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i16;
i += 1;
}
if j + 1 < h {
let mut i = 0usize;
while i + 8 <= w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let sum_center =
_mm256_cvtepi16_epi32(loadu_128!(&sum[idx..idx + 8], [i16; 8]));
let sum_left =
_mm256_cvtepi16_epi32(loadu_128!(&sum[idx - 1..idx - 1 + 8], [i16; 8]));
let sum_right =
_mm256_cvtepi16_epi32(loadu_128!(&sum[idx + 1..idx + 1 + 8], [i16; 8]));
let b_horiz = _mm256_add_epi32(
_mm256_mullo_epi32(sum_center, six),
_mm256_mullo_epi32(_mm256_add_epi32(sum_left, sum_right), five),
);
let sq_center = loadu_256!(&sumsq[idx..idx + 8], [i32; 8]);
let sq_left = loadu_256!(&sumsq[idx - 1..idx - 1 + 8], [i32; 8]);
let sq_right = loadu_256!(&sumsq[idx + 1..idx + 1 + 8], [i32; 8]);
let a_horiz = _mm256_add_epi32(
_mm256_mullo_epi32(sq_center, six),
_mm256_mullo_epi32(_mm256_add_epi32(sq_left, sq_right), five),
);
let src_val = _mm256_cvtepu8_epi32(loadi64!(
&src[src_base + (j + 1) * REST_UNIT_STRIDE + i
..src_base + (j + 1) * REST_UNIT_STRIDE + i + 8]
));
let result = _mm256_srai_epi32::<8>(_mm256_add_epi32(
_mm256_sub_epi32(a_horiz, _mm256_mullo_epi32(b_horiz, src_val)),
rounding_8,
));
let result_16 = _mm256_packs_epi32(result, _mm256_setzero_si256());
let result_16 = _mm256_permute4x64_epi64::<0xD8>(result_16);
storeu_128!(
&mut dst[(j + 1) * MAX_RESTORATION_WIDTH + i
..(j + 1) * MAX_RESTORATION_WIDTH + i + 8],
[i16; 8],
_mm256_castsi256_si128(result_16)
);
i += 8;
}
while i < w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let b_horiz = {
let center = sum[idx] as i32;
let left = sum[idx - 1] as i32;
let right = sum[idx + 1] as i32;
center * 6 + (left + right) * 5
};
let a_horiz = {
let center = sumsq[idx];
let left = sumsq[idx - 1];
let right = sumsq[idx + 1];
center * 6 + (left + right) * 5
};
let src_val = src[src_base + (j + 1) * REST_UNIT_STRIDE + i] as i32;
dst[(j + 1) * MAX_RESTORATION_WIDTH + i] =
((a_horiz - b_horiz * src_val + (1 << 7)) >> 8) as i16;
i += 1;
}
}
j += 2;
}
if j < h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = sum[idx - REST_UNIT_STRIDE] as i32;
let below = sum[idx + REST_UNIT_STRIDE] as i32;
let al = sum[idx - REST_UNIT_STRIDE - 1] as i32;
let ar = sum[idx - REST_UNIT_STRIDE + 1] as i32;
let bl = sum[idx + REST_UNIT_STRIDE - 1] as i32;
let br = sum[idx + REST_UNIT_STRIDE + 1] as i32;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = sumsq[idx - REST_UNIT_STRIDE];
let below = sumsq[idx + REST_UNIT_STRIDE];
let al = sumsq[idx - REST_UNIT_STRIDE - 1];
let ar = sumsq[idx - REST_UNIT_STRIDE + 1];
let bl = sumsq[idx + REST_UNIT_STRIDE - 1];
let br = sumsq[idx + REST_UNIT_STRIDE + 1];
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i16;
}
}
} else {
for j in 0..h {
let mut i = 0usize;
while i + 8 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let s_c = _mm256_cvtepi16_epi32(loadu_128!(&sum[idx..idx + 8], [i16; 8]));
let s_l = _mm256_cvtepi16_epi32(loadu_128!(&sum[idx - 1..idx - 1 + 8], [i16; 8]));
let s_r = _mm256_cvtepi16_epi32(loadu_128!(&sum[idx + 1..idx + 1 + 8], [i16; 8]));
let s_a = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i16; 8]
));
let s_b = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i16; 8]
));
let s_al = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i16; 8]
));
let s_ar = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i16; 8]
));
let s_bl = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i16; 8]
));
let s_br = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i16; 8]
));
let b_eight = _mm256_add_epi32(
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(s_c, _mm256_add_epi32(s_l, s_r)),
_mm256_add_epi32(s_a, s_b),
),
four,
),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(s_al, s_ar),
_mm256_add_epi32(s_bl, s_br),
),
three,
),
);
let q_c = loadu_256!(&sumsq[idx..idx + 8], [i32; 8]);
let q_l = loadu_256!(&sumsq[idx - 1..idx - 1 + 8], [i32; 8]);
let q_r = loadu_256!(&sumsq[idx + 1..idx + 1 + 8], [i32; 8]);
let q_a = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let q_b = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let q_al = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let q_ar = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let q_bl = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let q_br = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let a_eight = _mm256_add_epi32(
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(q_c, _mm256_add_epi32(q_l, q_r)),
_mm256_add_epi32(q_a, q_b),
),
four,
),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(q_al, q_ar),
_mm256_add_epi32(q_bl, q_br),
),
three,
),
);
let src_val = _mm256_cvtepu8_epi32(loadi64!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 8]
));
let result = _mm256_srai_epi32::<9>(_mm256_add_epi32(
_mm256_sub_epi32(a_eight, _mm256_mullo_epi32(b_eight, src_val)),
rounding_9,
));
let result_16 = _mm256_packs_epi32(result, _mm256_setzero_si256());
let result_16 = _mm256_permute4x64_epi64::<0xD8>(result_16);
storeu_128!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 8],
[i16; 8],
_mm256_castsi256_si128(result_16)
);
i += 8;
}
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_eight = {
let center = sum[idx] as i32;
let left = sum[idx - 1] as i32;
let right = sum[idx + 1] as i32;
let above = sum[idx - REST_UNIT_STRIDE] as i32;
let below = sum[idx + REST_UNIT_STRIDE] as i32;
let al = sum[idx - REST_UNIT_STRIDE - 1] as i32;
let ar = sum[idx - REST_UNIT_STRIDE + 1] as i32;
let bl = sum[idx + REST_UNIT_STRIDE - 1] as i32;
let br = sum[idx + REST_UNIT_STRIDE + 1] as i32;
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let a_eight = {
let center = sumsq[idx];
let left = sumsq[idx - 1];
let right = sumsq[idx + 1];
let above = sumsq[idx - REST_UNIT_STRIDE];
let below = sumsq[idx + REST_UNIT_STRIDE];
let al = sumsq[idx - REST_UNIT_STRIDE - 1];
let ar = sumsq[idx - REST_UNIT_STRIDE + 1];
let bl = sumsq[idx + REST_UNIT_STRIDE - 1];
let br = sumsq[idx + REST_UNIT_STRIDE + 1];
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_eight - b_eight * src_val + (1 << 8)) >> 9) as i16;
i += 1;
}
}
}
}
/// Self-guided restoration filter core for 8 bpc pixels, using AVX-512 for the
/// final weighting pass (with AVX2 and scalar tails).
///
/// * `dst` - output buffer of `i16` values; row `j` starts at `j * MAX_RESTORATION_WIDTH`.
/// * `src` - padded source pixels; rows are `REST_UNIT_STRIDE` apart and the
///   first filtered pixel sits at row 3, column 3 (`src_base` below).
/// * `w`, `h` - dimensions of the block to filter.
/// * `n` - box area selector: `25` selects the 5x5 path, any other value the
///   3x3 path (callers visible in this file pass `25` or `9`).
/// * `s` - strength multiplier applied to the variance term `p` below.
///
/// The function works in three stages: box sums of `src` (helpers), an
/// in-place conversion of those sums into per-pixel coefficients, and a
/// weighted neighbourhood sum that produces `dst`.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn selfguided_filter_8bpc_avx512(
_token: Server64,
dst: &mut [i16; 64 * MAX_RESTORATION_WIDTH],
src: &[u8; (64 + 3 + 3) * REST_UNIT_STRIDE],
w: usize,
h: usize,
n: i32,
s: u32,
) {
// Fixed-point reciprocal used with the `>> 12` below: 164 ~= 4096 / 25, 455 ~= 4096 / 9.
let sgr_one_by_x: u32 = if n == 25 { 164 } else { 455 };
// Scratch buffers: filled by the box-sum helpers, then overwritten in place
// with the per-pixel coefficients computed in the loop below.
let mut sumsq = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut sum = [0i16; (64 + 2 + 2) * REST_UNIT_STRIDE];
// The 5x5 path only computes coefficients on every second row.
let step = if n == 25 { 2 } else { 1 };
// NOTE(review): helper bodies are not visible here; by name they run a vertical
// then a horizontal box-sum pass (5-tap or 3-tap) over the padded block.
if n == 25 {
boxsum5_v_avx512(_token, &mut sumsq, &mut sum, src, w + 6, h + 6);
boxsum5_h_avx512(_token, &mut sumsq, &mut sum, w + 6, h + 6);
} else {
boxsum3_v_avx512(_token, &mut sumsq, &mut sum, src, w + 6, h + 6);
boxsum3_h_avx512(_token, &mut sumsq, &mut sum, w + 6, h + 6);
}
// Convert box sums into coefficients, in place, for a (w + 2) x (h + 2) region
// starting at row 1, column 2 of the scratch buffers.
for row_offset in (0..(h + 2)).step_by(step) {
let aa_base = (row_offset + 1) * REST_UNIT_STRIDE + 2;
for i in 0..(w + 2) {
let idx = aa_base + i;
let a_val = sumsq[idx];
let b_val = sum[idx] as i32;
// p = n * sum(x^2) - (sum x)^2, clamped at zero.
let p = cmp::max(a_val * n - b_val * b_val, 0) as u32;
// Scale by `s` with rounding, then use the result (capped at 255) as a table index.
let z = (p * s + (1 << 19)) >> 20;
let x = dav1d_sgr_x_by_x[cmp::min(z, 255) as usize] as u32;
// `sumsq` now holds the rounded product x * sum * sgr_one_by_x >> 12; `sum` holds x.
sumsq[idx] = ((x * (b_val as u32) * sgr_one_by_x + (1 << 11)) >> 12) as i32;
sum[idx] = x as i16;
}
}
// Index of output pixel (0, 0) inside the scratch buffers and inside `src`.
let base = 2 * REST_UNIT_STRIDE + 3;
let src_base = 3 * REST_UNIT_STRIDE + 3;
// Rounding terms for the `>> 9` and `>> 8` shifts, and the neighbour weights.
let rounding_9 = _mm512_set1_epi32(1 << 8);
let rounding_8 = _mm512_set1_epi32(1 << 7);
let six = _mm512_set1_epi32(6);
let five = _mm512_set1_epi32(5);
let four = _mm512_set1_epi32(4);
let three = _mm512_set1_epi32(3);
if n == 25 {
// 5x5 path: rows are handled in pairs. Row `j` combines the coefficient rows
// directly above and below it; row `j + 1` uses only its own coefficient row.
let mut j = 0usize;
while j < h.saturating_sub(1) {
let mut i = 0usize;
// Row j, 16 pixels per iteration (AVX-512).
while i + 16 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let sum_above = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i16; 16]
));
let sum_below = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i16; 16]
));
let sum_al = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i16; 16]
));
let sum_ar = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i16; 16]
));
let sum_bl = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i16; 16]
));
let sum_br = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i16; 16]
));
// b = 6 * (above + below) + 5 * (four diagonal neighbours), from `sum`.
let b_six = _mm512_add_epi32(
_mm512_mullo_epi32(_mm512_add_epi32(sum_above, sum_below), six),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(sum_al, sum_ar),
_mm512_add_epi32(sum_bl, sum_br),
),
five,
),
);
let sq_above = loadu_512!(
&sumsq[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i32; 16]
);
let sq_below = loadu_512!(
&sumsq[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i32; 16]
);
let sq_al = loadu_512!(
&sumsq[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let sq_ar = loadu_512!(
&sumsq[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let sq_bl = loadu_512!(
&sumsq[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let sq_br = loadu_512!(
&sumsq[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
// a = same weighting applied to `sumsq`.
let a_six = _mm512_add_epi32(
_mm512_mullo_epi32(_mm512_add_epi32(sq_above, sq_below), six),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(sq_al, sq_ar),
_mm512_add_epi32(sq_bl, sq_br),
),
five,
),
);
let src_bytes = loadu_128!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 16],
[u8; 16]
);
let src_val = _mm512_cvtepu8_epi32(src_bytes);
// dst = (a - b * src + 256) >> 9, narrowed to i16 with signed saturation.
let result = _mm512_srai_epi32::<9>(_mm512_add_epi32(
_mm512_sub_epi32(a_six, _mm512_mullo_epi32(b_six, src_val)),
rounding_9,
));
let result_16 = _mm512_cvtsepi32_epi16(result);
storeu_256!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 16],
[i16; 16],
result_16
);
i += 16;
}
// Row j, remaining groups of 8 pixels (AVX2), same formula as above.
while i + 8 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let sum_above = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i16; 8]
));
let sum_below = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i16; 8]
));
let sum_al = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i16; 8]
));
let sum_ar = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i16; 8]
));
let sum_bl = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i16; 8]
));
let sum_br = _mm256_cvtepi16_epi32(loadu_128!(
&sum[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i16; 8]
));
// 256-bit copies of the weights (the outer `six` / `five` are 512-bit).
let six_256 = _mm256_set1_epi32(6);
let five_256 = _mm256_set1_epi32(5);
let b_six = _mm256_add_epi32(
_mm256_mullo_epi32(_mm256_add_epi32(sum_above, sum_below), six_256),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(sum_al, sum_ar),
_mm256_add_epi32(sum_bl, sum_br),
),
five_256,
),
);
let sq_above = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let sq_below = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let sq_al = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let sq_ar = loadu_256!(
&sumsq[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let sq_bl = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let sq_br = loadu_256!(
&sumsq[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let a_six = _mm256_add_epi32(
_mm256_mullo_epi32(_mm256_add_epi32(sq_above, sq_below), six_256),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(sq_al, sq_ar),
_mm256_add_epi32(sq_bl, sq_br),
),
five_256,
),
);
let src_val = _mm256_cvtepu8_epi32(loadi64!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 8]
));
let result = _mm256_srai_epi32::<9>(_mm256_add_epi32(
_mm256_sub_epi32(a_six, _mm256_mullo_epi32(b_six, src_val)),
_mm256_set1_epi32(1 << 8),
));
// `packs` works per 128-bit lane; the 0xD8 permute moves the eight packed
// results into the low 128 bits before they are stored.
let result_16 = _mm256_packs_epi32(result, _mm256_setzero_si256());
let result_16 = _mm256_permute4x64_epi64::<0xD8>(result_16);
storeu_128!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 8],
[i16; 8],
_mm256_castsi256_si128(result_16)
);
i += 8;
}
// Row j, scalar tail for the last (w % 8) pixels.
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = sum[idx - REST_UNIT_STRIDE] as i32;
let below = sum[idx + REST_UNIT_STRIDE] as i32;
let al = sum[idx - REST_UNIT_STRIDE - 1] as i32;
let ar = sum[idx - REST_UNIT_STRIDE + 1] as i32;
let bl = sum[idx + REST_UNIT_STRIDE - 1] as i32;
let br = sum[idx + REST_UNIT_STRIDE + 1] as i32;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = sumsq[idx - REST_UNIT_STRIDE];
let below = sumsq[idx + REST_UNIT_STRIDE];
let al = sumsq[idx - REST_UNIT_STRIDE - 1];
let ar = sumsq[idx - REST_UNIT_STRIDE + 1];
let bl = sumsq[idx + REST_UNIT_STRIDE - 1];
let br = sumsq[idx + REST_UNIT_STRIDE + 1];
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i16;
i += 1;
}
// Row j + 1. The loop condition `j < h - 1` already implies `j + 1 < h`,
// so this check always passes inside the loop.
if j + 1 < h {
let mut i = 0usize;
// 16 pixels per iteration: 6 * center + 5 * (left + right) on the same row,
// with a shift of 8 instead of 9.
while i + 16 <= w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let sum_center =
_mm512_cvtepi16_epi32(loadu_256!(&sum[idx..idx + 16], [i16; 16]));
let sum_left =
_mm512_cvtepi16_epi32(loadu_256!(&sum[idx - 1..idx - 1 + 16], [i16; 16]));
let sum_right =
_mm512_cvtepi16_epi32(loadu_256!(&sum[idx + 1..idx + 1 + 16], [i16; 16]));
let b_horiz = _mm512_add_epi32(
_mm512_mullo_epi32(sum_center, six),
_mm512_mullo_epi32(_mm512_add_epi32(sum_left, sum_right), five),
);
let sq_center = loadu_512!(&sumsq[idx..idx + 16], [i32; 16]);
let sq_left = loadu_512!(&sumsq[idx - 1..idx - 1 + 16], [i32; 16]);
let sq_right = loadu_512!(&sumsq[idx + 1..idx + 1 + 16], [i32; 16]);
let a_horiz = _mm512_add_epi32(
_mm512_mullo_epi32(sq_center, six),
_mm512_mullo_epi32(_mm512_add_epi32(sq_left, sq_right), five),
);
let src_bytes = loadu_128!(
&src[src_base + (j + 1) * REST_UNIT_STRIDE + i
..src_base + (j + 1) * REST_UNIT_STRIDE + i + 16],
[u8; 16]
);
let src_val = _mm512_cvtepu8_epi32(src_bytes);
let result = _mm512_srai_epi32::<8>(_mm512_add_epi32(
_mm512_sub_epi32(a_horiz, _mm512_mullo_epi32(b_horiz, src_val)),
rounding_8,
));
let result_16 = _mm512_cvtsepi32_epi16(result);
storeu_256!(
&mut dst[(j + 1) * MAX_RESTORATION_WIDTH + i
..(j + 1) * MAX_RESTORATION_WIDTH + i + 16],
[i16; 16],
result_16
);
i += 16;
}
// Scalar tail for row j + 1 (there is no 8-wide step on this row).
while i < w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let b_horiz = {
let center = sum[idx] as i32;
let left = sum[idx - 1] as i32;
let right = sum[idx + 1] as i32;
center * 6 + (left + right) * 5
};
let a_horiz = {
let center = sumsq[idx];
let left = sumsq[idx - 1];
let right = sumsq[idx + 1];
center * 6 + (left + right) * 5
};
let src_val = src[src_base + (j + 1) * REST_UNIT_STRIDE + i] as i32;
dst[(j + 1) * MAX_RESTORATION_WIDTH + i] =
((a_horiz - b_horiz * src_val + (1 << 7)) >> 8) as i16;
i += 1;
}
}
j += 2;
}
// When `h` is odd one row is left over; it is processed with the row-j formula, scalar only.
if j < h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = sum[idx - REST_UNIT_STRIDE] as i32;
let below = sum[idx + REST_UNIT_STRIDE] as i32;
let al = sum[idx - REST_UNIT_STRIDE - 1] as i32;
let ar = sum[idx - REST_UNIT_STRIDE + 1] as i32;
let bl = sum[idx + REST_UNIT_STRIDE - 1] as i32;
let br = sum[idx + REST_UNIT_STRIDE + 1] as i32;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = sumsq[idx - REST_UNIT_STRIDE];
let below = sumsq[idx + REST_UNIT_STRIDE];
let al = sumsq[idx - REST_UNIT_STRIDE - 1];
let ar = sumsq[idx - REST_UNIT_STRIDE + 1];
let bl = sumsq[idx + REST_UNIT_STRIDE - 1];
let br = sumsq[idx + REST_UNIT_STRIDE + 1];
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i16;
}
}
} else {
// 3x3 path: every row uses the full 3x3 neighbourhood of coefficients, with
// weight 4 for center/left/right/above/below and weight 3 for the diagonals.
for j in 0..h {
let mut i = 0usize;
// 16 pixels per iteration (AVX-512).
while i + 16 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let s_c = _mm512_cvtepi16_epi32(loadu_256!(&sum[idx..idx + 16], [i16; 16]));
let s_l = _mm512_cvtepi16_epi32(loadu_256!(&sum[idx - 1..idx - 1 + 16], [i16; 16]));
let s_r = _mm512_cvtepi16_epi32(loadu_256!(&sum[idx + 1..idx + 1 + 16], [i16; 16]));
let s_a = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i16; 16]
));
let s_b = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i16; 16]
));
let s_al = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i16; 16]
));
let s_ar = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i16; 16]
));
let s_bl = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i16; 16]
));
let s_br = _mm512_cvtepi16_epi32(loadu_256!(
&sum[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i16; 16]
));
let b_eight = _mm512_add_epi32(
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(s_c, _mm512_add_epi32(s_l, s_r)),
_mm512_add_epi32(s_a, s_b),
),
four,
),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(s_al, s_ar),
_mm512_add_epi32(s_bl, s_br),
),
three,
),
);
let q_c = loadu_512!(&sumsq[idx..idx + 16], [i32; 16]);
let q_l = loadu_512!(&sumsq[idx - 1..idx - 1 + 16], [i32; 16]);
let q_r = loadu_512!(&sumsq[idx + 1..idx + 1 + 16], [i32; 16]);
let q_a = loadu_512!(
&sumsq[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i32; 16]
);
let q_b = loadu_512!(
&sumsq[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i32; 16]
);
let q_al = loadu_512!(
&sumsq[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let q_ar = loadu_512!(
&sumsq[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let q_bl = loadu_512!(
&sumsq[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let q_br = loadu_512!(
&sumsq[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let a_eight = _mm512_add_epi32(
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(q_c, _mm512_add_epi32(q_l, q_r)),
_mm512_add_epi32(q_a, q_b),
),
four,
),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(q_al, q_ar),
_mm512_add_epi32(q_bl, q_br),
),
three,
),
);
let src_bytes = loadu_128!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 16],
[u8; 16]
);
let src_val = _mm512_cvtepu8_epi32(src_bytes);
// dst = (a - b * src + 256) >> 9, narrowed to i16 with signed saturation.
let result = _mm512_srai_epi32::<9>(_mm512_add_epi32(
_mm512_sub_epi32(a_eight, _mm512_mullo_epi32(b_eight, src_val)),
rounding_9,
));
let result_16 = _mm512_cvtsepi32_epi16(result);
storeu_256!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 16],
[i16; 16],
result_16
);
i += 16;
}
// Scalar tail for the last (w % 16) pixels of the row.
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_eight = {
let center = sum[idx] as i32;
let left = sum[idx - 1] as i32;
let right = sum[idx + 1] as i32;
let above = sum[idx - REST_UNIT_STRIDE] as i32;
let below = sum[idx + REST_UNIT_STRIDE] as i32;
let al = sum[idx - REST_UNIT_STRIDE - 1] as i32;
let ar = sum[idx - REST_UNIT_STRIDE + 1] as i32;
let bl = sum[idx + REST_UNIT_STRIDE - 1] as i32;
let br = sum[idx + REST_UNIT_STRIDE + 1] as i32;
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let a_eight = {
let center = sumsq[idx];
let left = sumsq[idx - 1];
let right = sumsq[idx + 1];
let above = sumsq[idx - REST_UNIT_STRIDE];
let below = sumsq[idx + REST_UNIT_STRIDE];
let al = sumsq[idx - REST_UNIT_STRIDE - 1];
let ar = sumsq[idx - REST_UNIT_STRIDE + 1];
let bl = sumsq[idx + REST_UNIT_STRIDE - 1];
let br = sumsq[idx + REST_UNIT_STRIDE + 1];
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i32;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_eight - b_eight * src_val + (1 << 8)) >> 9) as i16;
i += 1;
}
}
}
}
/// Blends one self-guided filter output into an 8 bpc pixel block, in place.
///
/// Every pixel becomes `clip(px + ((w_k * dst + (1 << 10)) >> 11), 0, 255)`.
///
/// * `p_guard` - pixel bytes; row `j` starts at `p_base + j * stride`.
/// * `dst` - filter output; row `j` starts at `j * MAX_RESTORATION_WIDTH`.
/// * `w`, `h` - block dimensions.
/// * `w_k` - weight applied to the filter output.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn sgr_apply_8bpc(
    _t: Desktop64,
    p_guard: &mut [u8],
    p_base: usize,
    stride: isize,
    dst: &[i16],
    w: usize,
    h: usize,
    w_k: i32,
) {
    use super::pixel_access::{loadu_128, storeu_128};
    let coeffs = dst.flex();
    let mut pixels = p_guard.flex_mut();
    let weight = _mm256_set1_epi32(w_k);
    let half = _mm256_set1_epi32(1 << 10);
    let lower = _mm256_setzero_si256();
    let upper = _mm256_set1_epi32(255);
    // Columns [0, vec_cols) are covered by whole 8-pixel vectors.
    let vec_cols = w - w % 8;
    for row in 0..h {
        let px_row = p_base.wrapping_add_signed(row as isize * stride);
        let coef_row = row * MAX_RESTORATION_WIDTH;
        for col in (0..vec_cols).step_by(8) {
            // delta = (w_k * dst + 1024) >> 11, in 32-bit lanes.
            let coef = _mm256_cvtepi16_epi32(loadu_128!(
                &coeffs[coef_row + col..coef_row + col + 8],
                [i16; 8]
            ));
            let delta =
                _mm256_srai_epi32::<11>(_mm256_add_epi32(_mm256_mullo_epi32(coef, weight), half));
            // Stage the 8 pixels in a zero-padded 16-byte buffer for the 128-bit load.
            let mut staged = [0u8; 16];
            staged[0..8].copy_from_slice(&pixels[px_row + col..px_row + col + 8]);
            let current = _mm256_cvtepu8_epi32(loadu_128!(&staged));
            let updated = _mm256_max_epi32(
                _mm256_min_epi32(_mm256_add_epi32(current, delta), upper),
                lower,
            );
            // Narrow i32 -> i16 -> u8; values are already within 0..=255.
            let narrowed = _mm_packs_epi32(
                _mm256_castsi256_si128(updated),
                _mm256_extracti128_si256::<1>(updated),
            );
            let bytes = _mm_packus_epi16(narrowed, _mm_setzero_si128());
            let mut unpacked = [0u8; 16];
            storeu_128!(&mut unpacked, bytes);
            pixels[px_row + col..px_row + col + 8].copy_from_slice(&unpacked[0..8]);
        }
        // Scalar tail for the last (w % 8) columns.
        for col in vec_cols..w {
            let v = w_k * coeffs[coef_row + col] as i32;
            pixels[px_row + col] = iclip(
                pixels[px_row + col] as i32 + ((v + (1 << 10)) >> 11),
                0,
                255,
            ) as u8;
        }
    }
}
/// Blends two self-guided filter outputs into an 8 bpc pixel block, in place.
///
/// Every pixel becomes
/// `clip(px + ((w0 * dst0 + w1 * dst1 + (1 << 10)) >> 11), 0, 255)`.
///
/// * `p_guard` - pixel bytes; row `j` starts at `p_base + j * stride`.
/// * `dst0`, `dst1` - filter outputs; row `j` starts at `j * MAX_RESTORATION_WIDTH`.
/// * `w`, `h` - block dimensions.
/// * `w0`, `w1` - weights applied to `dst0` and `dst1` respectively.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn sgr_apply_mix_8bpc(
    _t: Desktop64,
    p_guard: &mut [u8],
    p_base: usize,
    stride: isize,
    dst0: &[i16],
    dst1: &[i16],
    w: usize,
    h: usize,
    w0: i32,
    w1: i32,
) {
    use super::pixel_access::{loadu_128, storeu_128};
    let coeffs0 = dst0.flex();
    let coeffs1 = dst1.flex();
    let mut pixels = p_guard.flex_mut();
    let weight0 = _mm256_set1_epi32(w0);
    let weight1 = _mm256_set1_epi32(w1);
    let half = _mm256_set1_epi32(1 << 10);
    let lower = _mm256_setzero_si256();
    let upper = _mm256_set1_epi32(255);
    // Columns [0, vec_cols) are covered by whole 8-pixel vectors.
    let vec_cols = w - w % 8;
    for row in 0..h {
        let px_row = p_base.wrapping_add_signed(row as isize * stride);
        let coef_row = row * MAX_RESTORATION_WIDTH;
        for col in (0..vec_cols).step_by(8) {
            let coef0 = _mm256_cvtepi16_epi32(loadu_128!(
                &coeffs0[coef_row + col..coef_row + col + 8],
                [i16; 8]
            ));
            let coef1 = _mm256_cvtepi16_epi32(loadu_128!(
                &coeffs1[coef_row + col..coef_row + col + 8],
                [i16; 8]
            ));
            // delta = (w0 * dst0 + w1 * dst1 + 1024) >> 11, in 32-bit lanes.
            let weighted = _mm256_add_epi32(
                _mm256_mullo_epi32(coef0, weight0),
                _mm256_mullo_epi32(coef1, weight1),
            );
            let delta = _mm256_srai_epi32::<11>(_mm256_add_epi32(weighted, half));
            // Stage the 8 pixels in a zero-padded 16-byte buffer for the 128-bit load.
            let mut staged = [0u8; 16];
            staged[0..8].copy_from_slice(&pixels[px_row + col..px_row + col + 8]);
            let current = _mm256_cvtepu8_epi32(loadu_128!(&staged));
            let updated = _mm256_max_epi32(
                _mm256_min_epi32(_mm256_add_epi32(current, delta), upper),
                lower,
            );
            // Narrow i32 -> i16 -> u8; values are already within 0..=255.
            let narrowed = _mm_packs_epi32(
                _mm256_castsi256_si128(updated),
                _mm256_extracti128_si256::<1>(updated),
            );
            let bytes = _mm_packus_epi16(narrowed, _mm_setzero_si128());
            let mut unpacked = [0u8; 16];
            storeu_128!(&mut unpacked, bytes);
            pixels[px_row + col..px_row + col + 8].copy_from_slice(&unpacked[0..8]);
        }
        // Scalar tail for the last (w % 8) columns.
        for col in vec_cols..w {
            let v = w0 * coeffs0[coef_row + col] as i32 + w1 * coeffs1[coef_row + col] as i32;
            pixels[px_row + col] = iclip(
                pixels[px_row + col] as i32 + ((v + (1 << 10)) >> 11),
                0,
                255,
            ) as u8;
        }
    }
}
/// Applies the 5x5 self-guided restoration filter to a `w` x `h` block of 8 bpc pixels, in place.
///
/// The block is copied into a padded scratch buffer by `padding`, the filter
/// (`n = 25`, strength `sgr.s0`) is evaluated into an `i16` buffer by the best
/// implementation the CPU offers (AVX-512, then AVX2, then the scalar
/// `selfguided_filter_8bpc`), and each pixel is finally updated as
/// `clip(px + ((w0 * dst + (1 << 10)) >> 11), 0, 255)`.
///
/// * `p` - location of the block in the picture.
/// * `left`, `lpf`, `lpf_off`, `edges` - forwarded unchanged to `padding`.
/// * `params` - restoration parameters; only `params.sgr()` is read here.
#[cfg(target_arch = "x86_64")]
fn sgr_5x5_8bpc_avx2_inner(
    p: PicOffset,
    left: &[LeftPixelRow<u8>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
) {
    let mut tmp = [0u8; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let mut dst = [0i16; 64 * MAX_RESTORATION_WIDTH];
    padding::<BitDepth8>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    let sgr = params.sgr();
    // This function only exists on x86_64, so the dispatch needs no further
    // `cfg` gating and no separate non-x86_64 fallback.
    if let Some(token) = crate::src::cpu::summon_avx512() {
        selfguided_filter_8bpc_avx512(token, &mut dst, &tmp, w, h, 25, sgr.s0);
    } else if let Some(token) = summon_avx2() {
        selfguided_filter_8bpc_avx2(token, &mut dst, &tmp, w, h, 25, sgr.s0);
    } else {
        selfguided_filter_8bpc(&mut dst, &tmp, w, h, 25, sgr.s0);
    }
    let w0 = sgr.w0 as i32;
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth8, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            if let Some(token) = summon_avx2() {
                sgr_apply_8bpc(token, bytes, offset, stride, &dst, w, h, w0);
            } else {
                // Scalar equivalent of `sgr_apply_8bpc`.
                let dst = dst.as_slice().flex();
                let mut cp = bytes.flex_mut();
                for j in 0..h {
                    let row_off = (offset as isize + j as isize * stride) as usize;
                    for i in 0..w {
                        let v = w0 * dst[j * MAX_RESTORATION_WIDTH + i] as i32;
                        cp[row_off + i] =
                            iclip(cp[row_off + i] as i32 + ((v + (1 << 10)) >> 11), 0, 255) as u8;
                    }
                }
            }
        },
    );
}
/// Applies the 3x3 self-guided restoration filter to a `w` x `h` block of 8 bpc pixels, in place.
///
/// The block is copied into a padded scratch buffer by `padding`, the filter
/// (`n = 9`, strength `sgr.s1`) is evaluated into an `i16` buffer by the best
/// implementation the CPU offers (AVX-512, then AVX2, then the scalar
/// `selfguided_filter_8bpc`), and each pixel is finally updated as
/// `clip(px + ((w1 * dst + (1 << 10)) >> 11), 0, 255)`.
///
/// * `p` - location of the block in the picture.
/// * `left`, `lpf`, `lpf_off`, `edges` - forwarded unchanged to `padding`.
/// * `params` - restoration parameters; only `params.sgr()` is read here.
#[cfg(target_arch = "x86_64")]
fn sgr_3x3_8bpc_avx2_inner(
    p: PicOffset,
    left: &[LeftPixelRow<u8>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
) {
    let mut tmp = [0u8; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let mut dst = [0i16; 64 * MAX_RESTORATION_WIDTH];
    padding::<BitDepth8>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    let sgr = params.sgr();
    // This function only exists on x86_64, so the dispatch needs no further
    // `cfg` gating and no separate non-x86_64 fallback.
    if let Some(token) = crate::src::cpu::summon_avx512() {
        selfguided_filter_8bpc_avx512(token, &mut dst, &tmp, w, h, 9, sgr.s1);
    } else if let Some(token) = summon_avx2() {
        selfguided_filter_8bpc_avx2(token, &mut dst, &tmp, w, h, 9, sgr.s1);
    } else {
        selfguided_filter_8bpc(&mut dst, &tmp, w, h, 9, sgr.s1);
    }
    let w1 = sgr.w1 as i32;
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth8, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            if let Some(token) = summon_avx2() {
                sgr_apply_8bpc(token, bytes, offset, stride, &dst, w, h, w1);
            } else {
                // Scalar equivalent of `sgr_apply_8bpc`.
                let dst = dst.as_slice().flex();
                let mut cp = bytes.flex_mut();
                for j in 0..h {
                    let row_off = (offset as isize + j as isize * stride) as usize;
                    for i in 0..w {
                        let v = w1 * dst[j * MAX_RESTORATION_WIDTH + i] as i32;
                        cp[row_off + i] =
                            iclip(cp[row_off + i] as i32 + ((v + (1 << 10)) >> 11), 0, 255) as u8;
                    }
                }
            }
        },
    );
}
/// Applies the mixed (5x5 + 3x3) self-guided restoration filter to a `w` x `h`
/// block of 8 bpc pixels, in place.
///
/// The block is copied into a padded scratch buffer by `padding`. Both filters
/// are then evaluated from that same buffer (`n = 25` with `sgr.s0` into
/// `dst0`, `n = 9` with `sgr.s1` into `dst1`) by the best implementation the
/// CPU offers (AVX-512, then AVX2, then the scalar `selfguided_filter_8bpc`).
/// Each pixel is finally updated as
/// `clip(px + ((w0 * dst0 + w1 * dst1 + (1 << 10)) >> 11), 0, 255)`.
///
/// * `p` - location of the block in the picture.
/// * `left`, `lpf`, `lpf_off`, `edges` - forwarded unchanged to `padding`.
/// * `params` - restoration parameters; only `params.sgr()` is read here.
#[cfg(target_arch = "x86_64")]
fn sgr_mix_8bpc_avx2_inner(
    p: PicOffset,
    left: &[LeftPixelRow<u8>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
) {
    let mut tmp = [0u8; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let mut dst0 = [0i16; 64 * MAX_RESTORATION_WIDTH];
    let mut dst1 = [0i16; 64 * MAX_RESTORATION_WIDTH];
    padding::<BitDepth8>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    let sgr = params.sgr();
    // This function only exists on x86_64, so the dispatch needs no further
    // `cfg` gating and no separate non-x86_64 fallback.
    if let Some(token) = crate::src::cpu::summon_avx512() {
        selfguided_filter_8bpc_avx512(token, &mut dst0, &tmp, w, h, 25, sgr.s0);
        selfguided_filter_8bpc_avx512(token, &mut dst1, &tmp, w, h, 9, sgr.s1);
    } else if let Some(token) = summon_avx2() {
        selfguided_filter_8bpc_avx2(token, &mut dst0, &tmp, w, h, 25, sgr.s0);
        selfguided_filter_8bpc_avx2(token, &mut dst1, &tmp, w, h, 9, sgr.s1);
    } else {
        selfguided_filter_8bpc(&mut dst0, &tmp, w, h, 25, sgr.s0);
        selfguided_filter_8bpc(&mut dst1, &tmp, w, h, 9, sgr.s1);
    }
    let w0 = sgr.w0 as i32;
    let w1 = sgr.w1 as i32;
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth8, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            if let Some(token) = summon_avx2() {
                sgr_apply_mix_8bpc(token, bytes, offset, stride, &dst0, &dst1, w, h, w0, w1);
            } else {
                // Scalar equivalent of `sgr_apply_mix_8bpc`.
                let d0 = dst0.as_slice().flex();
                let d1 = dst1.as_slice().flex();
                let mut cp = bytes.flex_mut();
                for j in 0..h {
                    let row_off = (offset as isize + j as isize * stride) as usize;
                    for i in 0..w {
                        let v = w0 * d0[j * MAX_RESTORATION_WIDTH + i] as i32
                            + w1 * d1[j * MAX_RESTORATION_WIDTH + i] as i32;
                        cp[row_off + i] =
                            iclip(cp[row_off + i] as i32 + ((v + (1 << 10)) >> 11), 0, 255) as u8;
                    }
                }
            }
        },
    );
}
/// C-ABI entry point for the 5x5 self-guided filter on 8 bpc pixels.
///
/// Unpacks the FFI arguments and forwards them to `sgr_5x5_8bpc_avx2_inner`.
/// `_p_ptr`, `_stride` and `_bitdepth_max` are not used; the picture is reached
/// through `p` and the loop-filter buffer through `lpf`.
///
/// # Safety
///
/// * `p` and `lpf` must be pointers that `FFISafe::get` can turn into valid references.
/// * `left` must point to at least `h` readable `LeftPixelRow<u8>` entries.
/// * `w` and `h` are converted with `as usize`, so they must not be negative.
/// * The function is compiled with the `avx2` target feature enabled.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn sgr_filter_5x5_8bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    _bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: the caller guarantees `p` is valid for `FFISafe::get`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: the caller guarantees `lpf` is valid for `FFISafe::get`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset(lpf_buf, lpf_ptr.cast::<u8>());
    // SAFETY: the caller guarantees `left` points to at least `h` rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u8>>(), height) };
    sgr_5x5_8bpc_avx2_inner(pic, left_rows, lpf_buf, lpf_off, width, height, params, edges);
}
/// C-ABI entry point for the 3x3 self-guided filter on 8 bpc pixels.
///
/// Unpacks the FFI arguments and forwards them to `sgr_3x3_8bpc_avx2_inner`.
/// `_p_ptr`, `_stride` and `_bitdepth_max` are not used; the picture is reached
/// through `p` and the loop-filter buffer through `lpf`.
///
/// # Safety
///
/// * `p` and `lpf` must be pointers that `FFISafe::get` can turn into valid references.
/// * `left` must point to at least `h` readable `LeftPixelRow<u8>` entries.
/// * `w` and `h` are converted with `as usize`, so they must not be negative.
/// * The function is compiled with the `avx2` target feature enabled.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn sgr_filter_3x3_8bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    _bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: the caller guarantees `p` is valid for `FFISafe::get`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: the caller guarantees `lpf` is valid for `FFISafe::get`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset(lpf_buf, lpf_ptr.cast::<u8>());
    // SAFETY: the caller guarantees `left` points to at least `h` rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u8>>(), height) };
    sgr_3x3_8bpc_avx2_inner(pic, left_rows, lpf_buf, lpf_off, width, height, params, edges);
}
/// C-ABI entry point for the mixed (5x5 + 3x3) self-guided filter on 8 bpc pixels.
///
/// Unpacks the FFI arguments and forwards them to `sgr_mix_8bpc_avx2_inner`.
/// `_p_ptr`, `_stride` and `_bitdepth_max` are not used; the picture is reached
/// through `p` and the loop-filter buffer through `lpf`.
///
/// # Safety
///
/// * `p` and `lpf` must be pointers that `FFISafe::get` can turn into valid references.
/// * `left` must point to at least `h` readable `LeftPixelRow<u8>` entries.
/// * `w` and `h` are converted with `as usize`, so they must not be negative.
/// * The function is compiled with the `avx2` target feature enabled.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn sgr_filter_mix_8bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    _bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let (width, height) = (w as usize, h as usize);
    // SAFETY: the caller guarantees `p` is valid for `FFISafe::get`.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: the caller guarantees `lpf` is valid for `FFISafe::get`.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_off = reconstruct_lpf_offset(lpf_buf, lpf_ptr.cast::<u8>());
    // SAFETY: the caller guarantees `left` points to at least `h` rows.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u8>>(), height) };
    sgr_mix_8bpc_avx2_inner(pic, left_rows, lpf_buf, lpf_off, width, height, params, edges);
}
/// Computes 5x5 box sums of `src` and of its squares, for 16 bpc input.
///
/// Two passes over buffers whose rows are `REST_UNIT_STRIDE` apart:
///
/// 1. Vertical: for every column `x < w`, output row `r` (for `r` in
///    `1..=h - 4`) receives the sum of source rows `r - 1 ..= r + 3`.
/// 2. Horizontal, in place: for every such row, columns `2..w - 2` are replaced
///    by the sum of the five vertical sums in columns `x - 2 ..= x + 2`.
///
/// Row 0 and the two outermost columns on each side never receive a full 5x5 sum.
///
/// * `sumsq` - receives sums of squared pixels.
/// * `sum` - receives sums of pixels.
/// * `src` - source pixels.
/// * `w`, `h` - width and height of the region to process.
#[inline(always)]
fn boxsum5_16bpc(
    sumsq: &mut [i64; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Vertical pass: slide a 5-row window down each column.
    for col in 0..w {
        // The four rows already inside the window; the fifth arrives in the loop.
        let mut win = [
            src[col] as i64,
            src[REST_UNIT_STRIDE + col] as i64,
            src[2 * REST_UNIT_STRIDE + col] as i64,
            src[3 * REST_UNIT_STRIDE + col] as i64,
        ];
        let mut win_sq = win.map(|v| v * v);
        let mut in_idx = 3 * REST_UNIT_STRIDE + col;
        let mut out_idx = col;
        for _ in 2..h - 2 {
            in_idx += REST_UNIT_STRIDE;
            let incoming = src[in_idx] as i64;
            let incoming_sq = incoming * incoming;
            out_idx += REST_UNIT_STRIDE;
            sum[out_idx] = (win.iter().sum::<i64>() + incoming) as i32;
            sumsq[out_idx] = win_sq.iter().sum::<i64>() + incoming_sq;
            win = [win[1], win[2], win[3], incoming];
            win_sq = [win_sq[1], win_sq[2], win_sq[3], incoming_sq];
        }
    }
    // Horizontal pass, in place. The window keeps the original values of the
    // four columns to the left, so overwriting `sum[row + x]` is safe.
    for k in 2..h - 2 {
        let row = (k - 1) * REST_UNIT_STRIDE;
        let mut win = [
            sum[row] as i64,
            sum[row + 1] as i64,
            sum[row + 2] as i64,
            sum[row + 3] as i64,
        ];
        let mut win_sq = [
            sumsq[row],
            sumsq[row + 1],
            sumsq[row + 2],
            sumsq[row + 3],
        ];
        for x in 2..w - 2 {
            let incoming = sum[row + x + 2] as i64;
            let incoming_sq = sumsq[row + x + 2];
            sum[row + x] = (win.iter().sum::<i64>() + incoming) as i32;
            sumsq[row + x] = win_sq.iter().sum::<i64>() + incoming_sq;
            win = [win[1], win[2], win[3], incoming];
            win_sq = [win_sq[1], win_sq[2], win_sq[3], incoming_sq];
        }
    }
}
/// Scalar 3x3 box sums (plain sum and sum of squares) for 16 bpc samples.
///
/// The first source row is skipped up front. Two separable passes follow, on
/// buffers with a row pitch of `REST_UNIT_STRIDE`:
///
/// 1. vertical: for columns `1..w - 1`, output row `r` (`1..h - 3`) receives
///    the sum of three consecutive source rows;
/// 2. horizontal (in place): for columns `2..w - 2`, each entry becomes the
///    sum of the three vertical results centred on it.
///
/// Columns `1` and `w - 2` keep the vertical-only value; columns `0` and
/// `w - 1` and output row 0 are never written.
#[inline(always)]
fn boxsum3_16bpc(
    sumsq: &mut [i64; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Drop the first source row; all row indices below are relative to this view.
    let src = &src[REST_UNIT_STRIDE..];

    // Vertical pass: a 2-deep window per column plus the entering sample.
    for col in 1..w - 1 {
        let mut taps = [src[col] as i64, src[REST_UNIT_STRIDE + col] as i64];
        let mut squares = [taps[0] * taps[0], taps[1] * taps[1]];
        for row in 2..h - 2 {
            let newest = src[row * REST_UNIT_STRIDE + col] as i64;
            let newest_sq = newest * newest;
            let out = (row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (taps[0] + taps[1] + newest) as i32;
            sumsq[out] = squares[0] + squares[1] + newest_sq;
            taps = [taps[1], newest];
            squares = [squares[1], newest_sq];
        }
    }

    // Horizontal pass, in place. Values enter the window before the slot they
    // came from is overwritten, so only vertical-pass data is combined.
    for row in 2..h - 2 {
        let base = (row - 1) * REST_UNIT_STRIDE;
        let mut s_win = [sum[base + 1] as i64, sum[base + 2] as i64];
        let mut q_win = [sumsq[base + 1], sumsq[base + 2]];
        for x in 2..w - 2 {
            let s_in = sum[base + x + 1] as i64;
            let q_in = sumsq[base + x + 1];
            sum[base + x] = (s_win[0] + s_win[1] + s_in) as i32;
            sumsq[base + x] = q_win[0] + q_win[1] + q_in;
            s_win = [s_win[1], s_in];
            q_win = [q_win[1], q_in];
        }
    }
}
/// Vertical half of the 5x5 box sum for 16 bpc samples, 16 columns at a time.
///
/// For every column `0..w` and output row `r` in `1..h - 3`, writes the sum
/// (and sum of squares) of source rows `r - 1 ..= r + 3` to
/// `sum[r * REST_UNIT_STRIDE + col]` / `sumsq[...]`. Full groups of 16 columns
/// use 512-bit intrinsics; the remaining columns use a scalar rolling window.
///
/// The vector path accumulates squares in 32-bit lanes before widening them
/// (zero-extended) to 64 bits, whereas the scalar tail accumulates in `i64`.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum5_v_16bpc_avx512(
    _token: Server64,
    sumsq: &mut [i64; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut col = 0usize;
    while col + 16 <= w {
        for out_row in 2..h - 2 {
            // Offsets of the five contributing source rows for this column group.
            let row0 = (out_row - 2) * REST_UNIT_STRIDE + col;
            let row1 = row0 + REST_UNIT_STRIDE;
            let row2 = row1 + REST_UNIT_STRIDE;
            let row3 = row2 + REST_UNIT_STRIDE;
            let row4 = row3 + REST_UNIT_STRIDE;

            // Load 16 samples per row and zero-extend them to 32-bit lanes.
            let p0 = _mm512_cvtepu16_epi32(loadu_256!(&src[row0..row0 + 16], [u16; 16]));
            let p1 = _mm512_cvtepu16_epi32(loadu_256!(&src[row1..row1 + 16], [u16; 16]));
            let p2 = _mm512_cvtepu16_epi32(loadu_256!(&src[row2..row2 + 16], [u16; 16]));
            let p3 = _mm512_cvtepu16_epi32(loadu_256!(&src[row3..row3 + 16], [u16; 16]));
            let p4 = _mm512_cvtepu16_epi32(loadu_256!(&src[row4..row4 + 16], [u16; 16]));

            let col_sum = _mm512_add_epi32(
                _mm512_add_epi32(_mm512_add_epi32(p0, p1), _mm512_add_epi32(p2, p3)),
                p4,
            );
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            storeu_512!(&mut sum[out..out + 16], [i32; 16], col_sum);

            let sq01 = _mm512_add_epi32(_mm512_mullo_epi32(p0, p0), _mm512_mullo_epi32(p1, p1));
            let sq23 = _mm512_add_epi32(_mm512_mullo_epi32(p2, p2), _mm512_mullo_epi32(p3, p3));
            let col_sumsq =
                _mm512_add_epi32(_mm512_add_epi32(sq01, sq23), _mm512_mullo_epi32(p4, p4));

            // Widen each half of the 32-bit result to eight 64-bit lanes.
            let lo_wide = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(col_sumsq));
            let hi_wide = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64::<1>(col_sumsq));
            storeu_512!(&mut sumsq[out..out + 8], [i64; 8], lo_wide);
            storeu_512!(&mut sumsq[out + 8..out + 16], [i64; 8], hi_wide);
        }
        col += 16;
    }

    // Scalar tail for the columns that do not fill a whole vector.
    for col in col..w {
        let mut taps = [
            src[col] as i64,
            src[REST_UNIT_STRIDE + col] as i64,
            src[2 * REST_UNIT_STRIDE + col] as i64,
            src[3 * REST_UNIT_STRIDE + col] as i64,
        ];
        let mut squares = [
            taps[0] * taps[0],
            taps[1] * taps[1],
            taps[2] * taps[2],
            taps[3] * taps[3],
        ];
        for out_row in 2..h - 2 {
            let newest = src[(out_row + 2) * REST_UNIT_STRIDE + col] as i64;
            let newest_sq = newest * newest;
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (taps.iter().sum::<i64>() + newest) as i32;
            sumsq[out] = squares.iter().sum::<i64>() + newest_sq;
            taps = [taps[1], taps[2], taps[3], newest];
            squares = [squares[1], squares[2], squares[3], newest_sq];
        }
    }
}
/// Horizontal half of the 5x5 box sum, applied in place to `sum` / `sumsq`.
///
/// For every row in `1..h - 3` and column `x` in `2..w - 2`, the entry at `x`
/// is replaced by the sum of the five entries `x - 2 ..= x + 2` of that row.
/// Each row is first computed into a scratch buffer and then copied back, so
/// all five taps are read from the unmodified row. Groups of 16 columns use
/// 512-bit intrinsics; the rest are handled by a scalar loop.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum5_h_16bpc_avx512(
    _token: Server64,
    sumsq: &mut [i64; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut row_sum = [0i32; REST_UNIT_STRIDE];
    let mut row_sumsq = [0i64; REST_UNIT_STRIDE];

    for row in 1..h - 3 {
        let row_off = row * REST_UNIT_STRIDE;
        let mut x = 2usize;

        while x + 16 <= w - 2 {
            // Five overlapping 16-lane windows starting two columns to the left.
            let left = row_off + x - 2;
            let s0 = loadu_512!(&sum[left..left + 16], [i32; 16]);
            let s1 = loadu_512!(&sum[left + 1..left + 17], [i32; 16]);
            let s2 = loadu_512!(&sum[left + 2..left + 18], [i32; 16]);
            let s3 = loadu_512!(&sum[left + 3..left + 19], [i32; 16]);
            let s4 = loadu_512!(&sum[left + 4..left + 20], [i32; 16]);
            let hsum = _mm512_add_epi32(
                _mm512_add_epi32(_mm512_add_epi32(s0, s1), _mm512_add_epi32(s2, s3)),
                s4,
            );
            storeu_512!(&mut row_sum[x..x + 16], [i32; 16], hsum);

            // The 64-bit squares only fit eight per vector: do the group in two halves.
            for half in 0..2usize {
                let dst_x = x + half * 8;
                let left = row_off + dst_x - 2;
                let q0 = loadu_512!(&sumsq[left..left + 8], [i64; 8]);
                let q1 = loadu_512!(&sumsq[left + 1..left + 9], [i64; 8]);
                let q2 = loadu_512!(&sumsq[left + 2..left + 10], [i64; 8]);
                let q3 = loadu_512!(&sumsq[left + 3..left + 11], [i64; 8]);
                let q4 = loadu_512!(&sumsq[left + 4..left + 12], [i64; 8]);
                let hsumsq = _mm512_add_epi64(
                    _mm512_add_epi64(_mm512_add_epi64(q0, q1), _mm512_add_epi64(q2, q3)),
                    q4,
                );
                storeu_512!(&mut row_sumsq[dst_x..dst_x + 8], [i64; 8], hsumsq);
            }
            x += 16;
        }

        // Scalar tail.
        for x in x..w - 2 {
            let lo = row_off + x - 2;
            let hi = row_off + x + 2;
            row_sum[x] = sum[lo..=hi].iter().sum::<i32>();
            row_sumsq[x] = sumsq[lo..=hi].iter().sum::<i64>();
        }

        // Commit the finished row.
        sum[row_off + 2..row_off + w - 2].copy_from_slice(&row_sum[2..w - 2]);
        sumsq[row_off + 2..row_off + w - 2].copy_from_slice(&row_sumsq[2..w - 2]);
    }
}
/// Vertical half of the 3x3 box sum for 16 bpc samples, 16 columns at a time.
///
/// The first source row is skipped. For every column `1..w - 1` and output
/// row `r` in `1..h - 3`, the sum (and sum of squares) of three consecutive
/// source rows is written to `sum[r * REST_UNIT_STRIDE + col]` / `sumsq[...]`.
/// Full groups of 16 columns use 512-bit intrinsics; the remaining columns
/// use a scalar rolling window.
///
/// The vector path accumulates squares in 32-bit lanes before widening them
/// (zero-extended) to 64 bits, whereas the scalar tail accumulates in `i64`.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum3_v_16bpc_avx512(
    _token: Server64,
    sumsq: &mut [i64; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    // Drop the first source row; row indices below are relative to this view.
    let src = &src[REST_UNIT_STRIDE..];

    let mut col = 1usize;
    // Strict `<`: the last column (`w - 1`) is excluded from the pass.
    while col + 16 < w {
        for out_row in 2..h - 2 {
            let row0 = (out_row - 2) * REST_UNIT_STRIDE + col;
            let row1 = row0 + REST_UNIT_STRIDE;
            let row2 = row1 + REST_UNIT_STRIDE;

            // Load 16 samples per row and zero-extend them to 32-bit lanes.
            let p0 = _mm512_cvtepu16_epi32(loadu_256!(&src[row0..row0 + 16], [u16; 16]));
            let p1 = _mm512_cvtepu16_epi32(loadu_256!(&src[row1..row1 + 16], [u16; 16]));
            let p2 = _mm512_cvtepu16_epi32(loadu_256!(&src[row2..row2 + 16], [u16; 16]));

            let col_sum = _mm512_add_epi32(_mm512_add_epi32(p0, p1), p2);
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            storeu_512!(&mut sum[out..out + 16], [i32; 16], col_sum);

            let col_sumsq = _mm512_add_epi32(
                _mm512_add_epi32(_mm512_mullo_epi32(p0, p0), _mm512_mullo_epi32(p1, p1)),
                _mm512_mullo_epi32(p2, p2),
            );

            // Widen each half of the 32-bit result to eight 64-bit lanes.
            let lo_wide = _mm512_cvtepu32_epi64(_mm512_castsi512_si256(col_sumsq));
            let hi_wide = _mm512_cvtepu32_epi64(_mm512_extracti64x4_epi64::<1>(col_sumsq));
            storeu_512!(&mut sumsq[out..out + 8], [i64; 8], lo_wide);
            storeu_512!(&mut sumsq[out + 8..out + 16], [i64; 8], hi_wide);
        }
        col += 16;
    }

    // Scalar tail for the columns that do not fill a whole vector.
    for col in col..w - 1 {
        let mut taps = [src[col] as i64, src[REST_UNIT_STRIDE + col] as i64];
        let mut squares = [taps[0] * taps[0], taps[1] * taps[1]];
        for out_row in 2..h - 2 {
            let newest = src[out_row * REST_UNIT_STRIDE + col] as i64;
            let newest_sq = newest * newest;
            let out = (out_row - 1) * REST_UNIT_STRIDE + col;
            sum[out] = (taps[0] + taps[1] + newest) as i32;
            sumsq[out] = squares[0] + squares[1] + newest_sq;
            taps = [taps[1], newest];
            squares = [squares[1], newest_sq];
        }
    }
}
/// Horizontal half of the 3x3 box sum, applied in place to `sum` / `sumsq`.
///
/// For every row in `1..h - 3` and column `x` in `2..w - 2`, the entry at `x`
/// is replaced by the sum of the three entries `x - 1 ..= x + 1` of that row.
/// Each row is first computed into a scratch buffer and then copied back, so
/// all taps are read from the unmodified row. Groups of 16 columns use
/// 512-bit intrinsics; the rest are handled by a scalar loop.
#[cfg(target_arch = "x86_64")]
#[rite]
fn boxsum3_h_16bpc_avx512(
    _token: Server64,
    sumsq: &mut [i64; (64 + 2 + 2) * REST_UNIT_STRIDE],
    sum: &mut [i32; (64 + 2 + 2) * REST_UNIT_STRIDE],
    w: usize,
    h: usize,
) {
    let mut row_sum = [0i32; REST_UNIT_STRIDE];
    let mut row_sumsq = [0i64; REST_UNIT_STRIDE];

    for row in 1..h - 3 {
        let row_off = row * REST_UNIT_STRIDE;
        let mut x = 2usize;

        while x + 16 <= w - 2 {
            // Three overlapping 16-lane windows starting one column to the left.
            let left = row_off + x - 1;
            let s0 = loadu_512!(&sum[left..left + 16], [i32; 16]);
            let s1 = loadu_512!(&sum[left + 1..left + 17], [i32; 16]);
            let s2 = loadu_512!(&sum[left + 2..left + 18], [i32; 16]);
            let hsum = _mm512_add_epi32(_mm512_add_epi32(s0, s1), s2);
            storeu_512!(&mut row_sum[x..x + 16], [i32; 16], hsum);

            // The 64-bit squares only fit eight per vector: do the group in two halves.
            for half in 0..2usize {
                let dst_x = x + half * 8;
                let left = row_off + dst_x - 1;
                let q0 = loadu_512!(&sumsq[left..left + 8], [i64; 8]);
                let q1 = loadu_512!(&sumsq[left + 1..left + 9], [i64; 8]);
                let q2 = loadu_512!(&sumsq[left + 2..left + 10], [i64; 8]);
                let hsumsq = _mm512_add_epi64(_mm512_add_epi64(q0, q1), q2);
                storeu_512!(&mut row_sumsq[dst_x..dst_x + 8], [i64; 8], hsumsq);
            }
            x += 16;
        }

        // Scalar tail.
        for x in x..w - 2 {
            let lo = row_off + x - 1;
            let hi = row_off + x + 1;
            row_sum[x] = sum[lo..=hi].iter().sum::<i32>();
            row_sumsq[x] = sumsq[lo..=hi].iter().sum::<i64>();
        }

        // Commit the finished row.
        sum[row_off + 2..row_off + w - 2].copy_from_slice(&row_sum[2..w - 2]);
        sumsq[row_off + 2..row_off + w - 2].copy_from_slice(&row_sumsq[2..w - 2]);
    }
}
#[inline(never)]
fn selfguided_filter_16bpc(
dst: &mut [i32; 64 * MAX_RESTORATION_WIDTH],
src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
w: usize,
h: usize,
n: i32,
s: u32,
bitdepth_max: i32,
) {
let sgr_one_by_x: u32 = if n == 25 { 164 } else { 455 };
let bitdepth = if bitdepth_max == 1023 { 10 } else { 12 };
let bitdepth_min_8 = bitdepth - 8;
let mut sumsq = [0i64; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut sum = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut aa = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut bb = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let step = if n == 25 { 2 } else { 1 };
if n == 25 {
boxsum5_16bpc(&mut sumsq, &mut sum, src, w + 6, h + 6);
} else {
boxsum3_16bpc(&mut sumsq, &mut sum, src, w + 6, h + 6);
}
{
let sq = sumsq.as_slice().flex();
let sm = sum.as_slice().flex();
let mut aa_m = aa.as_mut_slice().flex_mut();
let mut bb_m = bb.as_mut_slice().flex_mut();
for row_offset in (0..(h + 2)).step_by(step) {
let aa_base = (row_offset + 1) * REST_UNIT_STRIDE + 2;
for i in 0..(w + 2) {
let idx = aa_base + i;
let a_val = sq[idx];
let b_val = sm[idx] as i64;
let a_scaled =
((a_val + (1 << (2 * bitdepth_min_8 - 1))) >> (2 * bitdepth_min_8)) as i32;
let b_scaled = ((b_val + (1 << (bitdepth_min_8 - 1))) >> bitdepth_min_8) as i32;
let p = cmp::max(a_scaled * n - b_scaled * b_scaled, 0) as u32;
let z = (p * s + (1 << 19)) >> 20;
let x = dav1d_sgr_x_by_x[cmp::min(z, 255) as usize] as u32;
aa_m[idx] = ((x * (b_val as u32) * sgr_one_by_x + (1 << 11)) >> 12) as i32;
bb_m[idx] = x as i32;
}
}
}
let base = 2 * REST_UNIT_STRIDE + 3; let src_base = 3 * REST_UNIT_STRIDE + 3;
let bb_f = bb.as_slice().flex();
let aa_f = aa.as_slice().flex();
let src = src.as_slice().flex();
let mut dst = dst.as_mut_slice().flex_mut();
if n == 25 {
let mut j = 0usize;
while j < h.saturating_sub(1) {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb_f[idx - REST_UNIT_STRIDE] as i64;
let below = bb_f[idx + REST_UNIT_STRIDE] as i64;
let above_left = bb_f[idx - REST_UNIT_STRIDE - 1] as i64;
let above_right = bb_f[idx - REST_UNIT_STRIDE + 1] as i64;
let below_left = bb_f[idx + REST_UNIT_STRIDE - 1] as i64;
let below_right = bb_f[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let a_six = {
let above = aa_f[idx - REST_UNIT_STRIDE] as i64;
let below = aa_f[idx + REST_UNIT_STRIDE] as i64;
let above_left = aa_f[idx - REST_UNIT_STRIDE - 1] as i64;
let above_right = aa_f[idx - REST_UNIT_STRIDE + 1] as i64;
let below_left = aa_f[idx + REST_UNIT_STRIDE - 1] as i64;
let below_right = aa_f[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i32;
}
if j + 1 < h {
for i in 0..w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let b_horiz = {
let center = bb_f[idx] as i64;
let left = bb_f[idx - 1] as i64;
let right = bb_f[idx + 1] as i64;
center * 6 + (left + right) * 5
};
let a_horiz = {
let center = aa_f[idx] as i64;
let left = aa_f[idx - 1] as i64;
let right = aa_f[idx + 1] as i64;
center * 6 + (left + right) * 5
};
let src_val = src[src_base + (j + 1) * REST_UNIT_STRIDE + i] as i64;
dst[(j + 1) * MAX_RESTORATION_WIDTH + i] =
((a_horiz - b_horiz * src_val + (1 << 7)) >> 8) as i32;
}
}
j += 2;
}
if j < h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb_f[idx - REST_UNIT_STRIDE] as i64;
let below = bb_f[idx + REST_UNIT_STRIDE] as i64;
let above_left = bb_f[idx - REST_UNIT_STRIDE - 1] as i64;
let above_right = bb_f[idx - REST_UNIT_STRIDE + 1] as i64;
let below_left = bb_f[idx + REST_UNIT_STRIDE - 1] as i64;
let below_right = bb_f[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let a_six = {
let above = aa_f[idx - REST_UNIT_STRIDE] as i64;
let below = aa_f[idx + REST_UNIT_STRIDE] as i64;
let above_left = aa_f[idx - REST_UNIT_STRIDE - 1] as i64;
let above_right = aa_f[idx - REST_UNIT_STRIDE + 1] as i64;
let below_left = aa_f[idx + REST_UNIT_STRIDE - 1] as i64;
let below_right = aa_f[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (above_left + above_right + below_left + below_right) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i32;
}
}
} else {
for j in 0..h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_eight = {
let center = bb_f[idx] as i64;
let left = bb_f[idx - 1] as i64;
let right = bb_f[idx + 1] as i64;
let above = bb_f[idx - REST_UNIT_STRIDE] as i64;
let below = bb_f[idx + REST_UNIT_STRIDE] as i64;
let above_left = bb_f[idx - REST_UNIT_STRIDE - 1] as i64;
let above_right = bb_f[idx - REST_UNIT_STRIDE + 1] as i64;
let below_left = bb_f[idx + REST_UNIT_STRIDE - 1] as i64;
let below_right = bb_f[idx + REST_UNIT_STRIDE + 1] as i64;
(center + left + right + above + below) * 4
+ (above_left + above_right + below_left + below_right) * 3
};
let a_eight = {
let center = aa_f[idx] as i64;
let left = aa_f[idx - 1] as i64;
let right = aa_f[idx + 1] as i64;
let above = aa_f[idx - REST_UNIT_STRIDE] as i64;
let below = aa_f[idx + REST_UNIT_STRIDE] as i64;
let above_left = aa_f[idx - REST_UNIT_STRIDE - 1] as i64;
let above_right = aa_f[idx - REST_UNIT_STRIDE + 1] as i64;
let below_left = aa_f[idx + REST_UNIT_STRIDE - 1] as i64;
let below_right = aa_f[idx + REST_UNIT_STRIDE + 1] as i64;
(center + left + right + above + below) * 4
+ (above_left + above_right + below_left + below_right) * 3
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_eight - b_eight * src_val + (1 << 8)) >> 9) as i32;
}
}
}
}
/// AVX2 variant of the 16 bpc self-guided filter kernel.
///
/// The box sums (`boxsum5_16bpc` / `boxsum3_16bpc`) and the per-position
/// coefficient computation are scalar here; only the final neighbourhood
/// weighting is vectorised, eight output pixels per iteration, with a scalar
/// loop for the remaining columns of each row.
///
/// * `n == 25` selects the 5x5 path (coefficients on every second row),
///   anything else the 3x3 path.
/// * `bitdepth_max == 1023` selects 10-bit handling; anything else is treated
///   as 12-bit.
/// * Results are written to `dst` with a row pitch of `MAX_RESTORATION_WIDTH`.
///
/// The vector path does its arithmetic in 32-bit lanes, while the scalar
/// tails compute in `i64` and truncate; NOTE(review): these agree as long as
/// the intermediate values fit in 32 bits — confirm for the supported ranges.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn selfguided_filter_16bpc_avx2(
_token: Desktop64,
dst: &mut [i32; 64 * MAX_RESTORATION_WIDTH],
src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
w: usize,
h: usize,
n: i32,
s: u32,
bitdepth_max: i32,
) {
let sgr_one_by_x: u32 = if n == 25 { 164 } else { 455 };
let bitdepth = if bitdepth_max == 1023 { 10 } else { 12 };
let bitdepth_min_8 = bitdepth - 8;
let mut sumsq = [0i64; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut sum = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut aa = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut bb = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
// The 5x5 path only needs coefficients on every second row.
let step = if n == 25 { 2 } else { 1 };
if n == 25 {
boxsum5_16bpc(&mut sumsq, &mut sum, src, w + 6, h + 6);
} else {
boxsum3_16bpc(&mut sumsq, &mut sum, src, w + 6, h + 6);
}
// Coefficient pass (scalar): covers one extra row/column on each side of the
// output area because the weighting below reads all eight neighbours.
for row_offset in (0..(h + 2)).step_by(step) {
let aa_base = (row_offset + 1) * REST_UNIT_STRIDE + 2;
for i in 0..(w + 2) {
let idx = aa_base + i;
let a_val = sumsq[idx];
let b_val = sum[idx] as i64;
// Scale the box statistics back by `bitdepth - 8` bits, with rounding.
let a_scaled =
((a_val + (1 << (2 * bitdepth_min_8 - 1))) >> (2 * bitdepth_min_8)) as i32;
let b_scaled = ((b_val + (1 << (bitdepth_min_8 - 1))) >> bitdepth_min_8) as i32;
let p = cmp::max(a_scaled * n - b_scaled * b_scaled, 0) as u32;
let z = (p * s + (1 << 19)) >> 20;
let x = dav1d_sgr_x_by_x[cmp::min(z, 255) as usize] as u32;
// `bb` keeps the table value; `aa` keeps table value * box sum * sgr_one_by_x (>> 12).
aa[idx] = ((x * (b_val as u32) * sgr_one_by_x + (1 << 11)) >> 12) as i32;
bb[idx] = x as i32;
}
}
// Index of output pixel (0, 0) inside the coefficient planes and inside `src`.
let base = 2 * REST_UNIT_STRIDE + 3;
let src_base = 3 * REST_UNIT_STRIDE + 3;
// Broadcast rounding terms and neighbour weights.
let rounding_9 = _mm256_set1_epi32(1 << 8);
let rounding_8 = _mm256_set1_epi32(1 << 7);
let six = _mm256_set1_epi32(6);
let five = _mm256_set1_epi32(5);
let four = _mm256_set1_epi32(4);
let three = _mm256_set1_epi32(3);
if n == 25 {
// 5x5 path: rows are processed in pairs (j, j + 1).
let mut j = 0usize;
while j < h.saturating_sub(1) {
// Row j: rows above/below weighted 6, diagonals weighted 5, result >> 9.
let mut i = 0usize;
while i + 8 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let bb_above = loadu_256!(
&bb[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let bb_below = loadu_256!(
&bb[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let bb_al = loadu_256!(
&bb[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let bb_ar = loadu_256!(
&bb[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let bb_bl = loadu_256!(
&bb[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let bb_br = loadu_256!(
&bb[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let b_six = _mm256_add_epi32(
_mm256_mullo_epi32(_mm256_add_epi32(bb_above, bb_below), six),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(bb_al, bb_ar),
_mm256_add_epi32(bb_bl, bb_br),
),
five,
),
);
let aa_above = loadu_256!(
&aa[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let aa_below = loadu_256!(
&aa[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let aa_al = loadu_256!(
&aa[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let aa_ar = loadu_256!(
&aa[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let aa_bl = loadu_256!(
&aa[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let aa_br = loadu_256!(
&aa[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let a_six = _mm256_add_epi32(
_mm256_mullo_epi32(_mm256_add_epi32(aa_above, aa_below), six),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(aa_al, aa_ar),
_mm256_add_epi32(aa_bl, aa_br),
),
five,
),
);
// Eight source samples, zero-extended to 32-bit lanes.
let src_val = _mm256_cvtepu16_epi32(loadu_128!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 8],
[u16; 8]
));
let result = _mm256_srai_epi32::<9>(_mm256_add_epi32(
_mm256_sub_epi32(a_six, _mm256_mullo_epi32(b_six, src_val)),
rounding_9,
));
storeu_256!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 8],
[i32; 8],
result
);
i += 8;
}
// Scalar tail of row j.
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb[idx - REST_UNIT_STRIDE] as i64;
let below = bb[idx + REST_UNIT_STRIDE] as i64;
let al = bb[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = bb[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = bb[idx + REST_UNIT_STRIDE - 1] as i64;
let br = bb[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = aa[idx - REST_UNIT_STRIDE] as i64;
let below = aa[idx + REST_UNIT_STRIDE] as i64;
let al = aa[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = aa[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = aa[idx + REST_UNIT_STRIDE - 1] as i64;
let br = aa[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i32;
i += 1;
}
// Row j + 1: same-row taps only (centre 6, left/right 5), result >> 8.
// The loop condition already guarantees `j + 1 < h` here.
if j + 1 < h {
let mut i = 0usize;
while i + 8 <= w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let bb_center = loadu_256!(&bb[idx..idx + 8], [i32; 8]);
let bb_left = loadu_256!(&bb[idx - 1..idx - 1 + 8], [i32; 8]);
let bb_right = loadu_256!(&bb[idx + 1..idx + 1 + 8], [i32; 8]);
let b_horiz = _mm256_add_epi32(
_mm256_mullo_epi32(bb_center, six),
_mm256_mullo_epi32(_mm256_add_epi32(bb_left, bb_right), five),
);
let aa_center = loadu_256!(&aa[idx..idx + 8], [i32; 8]);
let aa_left = loadu_256!(&aa[idx - 1..idx - 1 + 8], [i32; 8]);
let aa_right = loadu_256!(&aa[idx + 1..idx + 1 + 8], [i32; 8]);
let a_horiz = _mm256_add_epi32(
_mm256_mullo_epi32(aa_center, six),
_mm256_mullo_epi32(_mm256_add_epi32(aa_left, aa_right), five),
);
let src_val = _mm256_cvtepu16_epi32(loadu_128!(
&src[src_base + (j + 1) * REST_UNIT_STRIDE + i
..src_base + (j + 1) * REST_UNIT_STRIDE + i + 8],
[u16; 8]
));
let result = _mm256_srai_epi32::<8>(_mm256_add_epi32(
_mm256_sub_epi32(a_horiz, _mm256_mullo_epi32(b_horiz, src_val)),
rounding_8,
));
storeu_256!(
&mut dst[(j + 1) * MAX_RESTORATION_WIDTH + i
..(j + 1) * MAX_RESTORATION_WIDTH + i + 8],
[i32; 8],
result
);
i += 8;
}
// Scalar tail of row j + 1.
while i < w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let b_horiz = {
let center = bb[idx] as i64;
let left = bb[idx - 1] as i64;
let right = bb[idx + 1] as i64;
center * 6 + (left + right) * 5
};
let a_horiz = {
let center = aa[idx] as i64;
let left = aa[idx - 1] as i64;
let right = aa[idx + 1] as i64;
center * 6 + (left + right) * 5
};
let src_val = src[src_base + (j + 1) * REST_UNIT_STRIDE + i] as i64;
dst[(j + 1) * MAX_RESTORATION_WIDTH + i] =
((a_horiz - b_horiz * src_val + (1 << 7)) >> 8) as i32;
i += 1;
}
}
j += 2;
}
// Odd `h`: one unpaired final row remains; it is handled entirely in scalar code.
if j < h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb[idx - REST_UNIT_STRIDE] as i64;
let below = bb[idx + REST_UNIT_STRIDE] as i64;
let al = bb[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = bb[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = bb[idx + REST_UNIT_STRIDE - 1] as i64;
let br = bb[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = aa[idx - REST_UNIT_STRIDE] as i64;
let below = aa[idx + REST_UNIT_STRIDE] as i64;
let al = aa[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = aa[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = aa[idx + REST_UNIT_STRIDE - 1] as i64;
let br = aa[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i32;
}
}
} else {
// 3x3 path: every row uses all nine taps — centre and edge neighbours
// weighted 4, diagonals weighted 3, result >> 9.
for j in 0..h {
let mut i = 0usize;
while i + 8 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_c = loadu_256!(&bb[idx..idx + 8], [i32; 8]);
let b_l = loadu_256!(&bb[idx - 1..idx - 1 + 8], [i32; 8]);
let b_r = loadu_256!(&bb[idx + 1..idx + 1 + 8], [i32; 8]);
let b_a = loadu_256!(
&bb[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let b_b = loadu_256!(
&bb[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let b_al = loadu_256!(
&bb[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let b_ar = loadu_256!(
&bb[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let b_bl = loadu_256!(
&bb[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let b_br = loadu_256!(
&bb[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let b_eight = _mm256_add_epi32(
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(b_c, _mm256_add_epi32(b_l, b_r)),
_mm256_add_epi32(b_a, b_b),
),
four,
),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(b_al, b_ar),
_mm256_add_epi32(b_bl, b_br),
),
three,
),
);
let a_c = loadu_256!(&aa[idx..idx + 8], [i32; 8]);
let a_l = loadu_256!(&aa[idx - 1..idx - 1 + 8], [i32; 8]);
let a_r = loadu_256!(&aa[idx + 1..idx + 1 + 8], [i32; 8]);
let a_a = loadu_256!(
&aa[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 8],
[i32; 8]
);
let a_b = loadu_256!(
&aa[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 8],
[i32; 8]
);
let a_al = loadu_256!(
&aa[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let a_ar = loadu_256!(
&aa[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let a_bl = loadu_256!(
&aa[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 8],
[i32; 8]
);
let a_br = loadu_256!(
&aa[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 8],
[i32; 8]
);
let a_eight = _mm256_add_epi32(
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(a_c, _mm256_add_epi32(a_l, a_r)),
_mm256_add_epi32(a_a, a_b),
),
four,
),
_mm256_mullo_epi32(
_mm256_add_epi32(
_mm256_add_epi32(a_al, a_ar),
_mm256_add_epi32(a_bl, a_br),
),
three,
),
);
let src_val = _mm256_cvtepu16_epi32(loadu_128!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 8],
[u16; 8]
));
let result = _mm256_srai_epi32::<9>(_mm256_add_epi32(
_mm256_sub_epi32(a_eight, _mm256_mullo_epi32(b_eight, src_val)),
rounding_9,
));
storeu_256!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 8],
[i32; 8],
result
);
i += 8;
}
// Scalar tail of row j.
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_eight = {
let center = bb[idx] as i64;
let left = bb[idx - 1] as i64;
let right = bb[idx + 1] as i64;
let above = bb[idx - REST_UNIT_STRIDE] as i64;
let below = bb[idx + REST_UNIT_STRIDE] as i64;
let al = bb[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = bb[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = bb[idx + REST_UNIT_STRIDE - 1] as i64;
let br = bb[idx + REST_UNIT_STRIDE + 1] as i64;
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let a_eight = {
let center = aa[idx] as i64;
let left = aa[idx - 1] as i64;
let right = aa[idx + 1] as i64;
let above = aa[idx - REST_UNIT_STRIDE] as i64;
let below = aa[idx + REST_UNIT_STRIDE] as i64;
let al = aa[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = aa[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = aa[idx + REST_UNIT_STRIDE - 1] as i64;
let br = aa[idx + REST_UNIT_STRIDE + 1] as i64;
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_eight - b_eight * src_val + (1 << 8)) >> 9) as i32;
i += 1;
}
}
}
}
#[cfg(target_arch = "x86_64")]
#[arcane]
fn selfguided_filter_16bpc_avx512(
_token: Server64,
dst: &mut [i32; 64 * MAX_RESTORATION_WIDTH],
src: &[u16; (64 + 3 + 3) * REST_UNIT_STRIDE],
w: usize,
h: usize,
n: i32,
s: u32,
bitdepth_max: i32,
) {
let sgr_one_by_x: u32 = if n == 25 { 164 } else { 455 };
let bitdepth = if bitdepth_max == 1023 { 10 } else { 12 };
let bitdepth_min_8 = bitdepth - 8;
let mut sumsq = [0i64; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut sum = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut aa = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let mut bb = [0i32; (64 + 2 + 2) * REST_UNIT_STRIDE];
let step = if n == 25 { 2 } else { 1 };
if n == 25 {
boxsum5_v_16bpc_avx512(_token, &mut sumsq, &mut sum, src, w + 6, h + 6);
boxsum5_h_16bpc_avx512(_token, &mut sumsq, &mut sum, w + 6, h + 6);
} else {
boxsum3_v_16bpc_avx512(_token, &mut sumsq, &mut sum, src, w + 6, h + 6);
boxsum3_h_16bpc_avx512(_token, &mut sumsq, &mut sum, w + 6, h + 6);
}
for row_offset in (0..(h + 2)).step_by(step) {
let aa_base = (row_offset + 1) * REST_UNIT_STRIDE + 2;
for i in 0..(w + 2) {
let idx = aa_base + i;
let a_val = sumsq[idx];
let b_val = sum[idx] as i64;
let a_scaled =
((a_val + (1 << (2 * bitdepth_min_8 - 1))) >> (2 * bitdepth_min_8)) as i32;
let b_scaled = ((b_val + (1 << (bitdepth_min_8 - 1))) >> bitdepth_min_8) as i32;
let p = cmp::max(a_scaled * n - b_scaled * b_scaled, 0) as u32;
let z = (p * s + (1 << 19)) >> 20;
let x = dav1d_sgr_x_by_x[cmp::min(z, 255) as usize] as u32;
aa[idx] = ((x * (b_val as u32) * sgr_one_by_x + (1 << 11)) >> 12) as i32;
bb[idx] = x as i32;
}
}
let base = 2 * REST_UNIT_STRIDE + 3;
let src_base = 3 * REST_UNIT_STRIDE + 3;
let rounding_9 = _mm512_set1_epi32(1 << 8);
let rounding_8 = _mm512_set1_epi32(1 << 7);
let six = _mm512_set1_epi32(6);
let five = _mm512_set1_epi32(5);
let four = _mm512_set1_epi32(4);
let three = _mm512_set1_epi32(3);
if n == 25 {
let mut j = 0usize;
while j < h.saturating_sub(1) {
let mut i = 0usize;
while i + 16 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let bb_above = loadu_512!(
&bb[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i32; 16]
);
let bb_below = loadu_512!(
&bb[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i32; 16]
);
let bb_al = loadu_512!(
&bb[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let bb_ar = loadu_512!(
&bb[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let bb_bl = loadu_512!(
&bb[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let bb_br = loadu_512!(
&bb[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let b_six = _mm512_add_epi32(
_mm512_mullo_epi32(_mm512_add_epi32(bb_above, bb_below), six),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(bb_al, bb_ar),
_mm512_add_epi32(bb_bl, bb_br),
),
five,
),
);
let aa_above = loadu_512!(
&aa[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i32; 16]
);
let aa_below = loadu_512!(
&aa[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i32; 16]
);
let aa_al = loadu_512!(
&aa[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let aa_ar = loadu_512!(
&aa[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let aa_bl = loadu_512!(
&aa[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let aa_br = loadu_512!(
&aa[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let a_six = _mm512_add_epi32(
_mm512_mullo_epi32(_mm512_add_epi32(aa_above, aa_below), six),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(aa_al, aa_ar),
_mm512_add_epi32(aa_bl, aa_br),
),
five,
),
);
let src_words = loadu_256!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 16],
[u16; 16]
);
let src_val = _mm512_cvtepu16_epi32(src_words);
let result = _mm512_srai_epi32::<9>(_mm512_add_epi32(
_mm512_sub_epi32(a_six, _mm512_mullo_epi32(b_six, src_val)),
rounding_9,
));
storeu_512!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 16],
[i32; 16],
result
);
i += 16;
}
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb[idx - REST_UNIT_STRIDE] as i64;
let below = bb[idx + REST_UNIT_STRIDE] as i64;
let al = bb[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = bb[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = bb[idx + REST_UNIT_STRIDE - 1] as i64;
let br = bb[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = aa[idx - REST_UNIT_STRIDE] as i64;
let below = aa[idx + REST_UNIT_STRIDE] as i64;
let al = aa[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = aa[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = aa[idx + REST_UNIT_STRIDE - 1] as i64;
let br = aa[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i32;
i += 1;
}
if j + 1 < h {
let mut i = 0usize;
while i + 16 <= w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let bb_center = loadu_512!(&bb[idx..idx + 16], [i32; 16]);
let bb_left = loadu_512!(&bb[idx - 1..idx - 1 + 16], [i32; 16]);
let bb_right = loadu_512!(&bb[idx + 1..idx + 1 + 16], [i32; 16]);
let b_horiz = _mm512_add_epi32(
_mm512_mullo_epi32(bb_center, six),
_mm512_mullo_epi32(_mm512_add_epi32(bb_left, bb_right), five),
);
let aa_center = loadu_512!(&aa[idx..idx + 16], [i32; 16]);
let aa_left = loadu_512!(&aa[idx - 1..idx - 1 + 16], [i32; 16]);
let aa_right = loadu_512!(&aa[idx + 1..idx + 1 + 16], [i32; 16]);
let a_horiz = _mm512_add_epi32(
_mm512_mullo_epi32(aa_center, six),
_mm512_mullo_epi32(_mm512_add_epi32(aa_left, aa_right), five),
);
let src_words = loadu_256!(
&src[src_base + (j + 1) * REST_UNIT_STRIDE + i
..src_base + (j + 1) * REST_UNIT_STRIDE + i + 16],
[u16; 16]
);
let src_val = _mm512_cvtepu16_epi32(src_words);
let result = _mm512_srai_epi32::<8>(_mm512_add_epi32(
_mm512_sub_epi32(a_horiz, _mm512_mullo_epi32(b_horiz, src_val)),
rounding_8,
));
storeu_512!(
&mut dst[(j + 1) * MAX_RESTORATION_WIDTH + i
..(j + 1) * MAX_RESTORATION_WIDTH + i + 16],
[i32; 16],
result
);
i += 16;
}
while i < w {
let idx = base + (j + 1) * REST_UNIT_STRIDE + i;
let b_horiz = {
let center = bb[idx] as i64;
let left = bb[idx - 1] as i64;
let right = bb[idx + 1] as i64;
center * 6 + (left + right) * 5
};
let a_horiz = {
let center = aa[idx] as i64;
let left = aa[idx - 1] as i64;
let right = aa[idx + 1] as i64;
center * 6 + (left + right) * 5
};
let src_val = src[src_base + (j + 1) * REST_UNIT_STRIDE + i] as i64;
dst[(j + 1) * MAX_RESTORATION_WIDTH + i] =
((a_horiz - b_horiz * src_val + (1 << 7)) >> 8) as i32;
i += 1;
}
}
j += 2;
}
if j < h {
for i in 0..w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_six = {
let above = bb[idx - REST_UNIT_STRIDE] as i64;
let below = bb[idx + REST_UNIT_STRIDE] as i64;
let al = bb[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = bb[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = bb[idx + REST_UNIT_STRIDE - 1] as i64;
let br = bb[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let a_six = {
let above = aa[idx - REST_UNIT_STRIDE] as i64;
let below = aa[idx + REST_UNIT_STRIDE] as i64;
let al = aa[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = aa[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = aa[idx + REST_UNIT_STRIDE - 1] as i64;
let br = aa[idx + REST_UNIT_STRIDE + 1] as i64;
(above + below) * 6 + (al + ar + bl + br) * 5
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_six - b_six * src_val + (1 << 8)) >> 9) as i32;
}
}
} else {
for j in 0..h {
let mut i = 0usize;
while i + 16 <= w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_c = loadu_512!(&bb[idx..idx + 16], [i32; 16]);
let b_l = loadu_512!(&bb[idx - 1..idx - 1 + 16], [i32; 16]);
let b_r = loadu_512!(&bb[idx + 1..idx + 1 + 16], [i32; 16]);
let b_a = loadu_512!(
&bb[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i32; 16]
);
let b_b = loadu_512!(
&bb[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i32; 16]
);
let b_al = loadu_512!(
&bb[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let b_ar = loadu_512!(
&bb[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let b_bl = loadu_512!(
&bb[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let b_br = loadu_512!(
&bb[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let b_eight = _mm512_add_epi32(
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(b_c, _mm512_add_epi32(b_l, b_r)),
_mm512_add_epi32(b_a, b_b),
),
four,
),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(b_al, b_ar),
_mm512_add_epi32(b_bl, b_br),
),
three,
),
);
let a_c = loadu_512!(&aa[idx..idx + 16], [i32; 16]);
let a_l = loadu_512!(&aa[idx - 1..idx - 1 + 16], [i32; 16]);
let a_r = loadu_512!(&aa[idx + 1..idx + 1 + 16], [i32; 16]);
let a_a = loadu_512!(
&aa[idx - REST_UNIT_STRIDE..idx - REST_UNIT_STRIDE + 16],
[i32; 16]
);
let a_b = loadu_512!(
&aa[idx + REST_UNIT_STRIDE..idx + REST_UNIT_STRIDE + 16],
[i32; 16]
);
let a_al = loadu_512!(
&aa[idx - REST_UNIT_STRIDE - 1..idx - REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let a_ar = loadu_512!(
&aa[idx - REST_UNIT_STRIDE + 1..idx - REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let a_bl = loadu_512!(
&aa[idx + REST_UNIT_STRIDE - 1..idx + REST_UNIT_STRIDE - 1 + 16],
[i32; 16]
);
let a_br = loadu_512!(
&aa[idx + REST_UNIT_STRIDE + 1..idx + REST_UNIT_STRIDE + 1 + 16],
[i32; 16]
);
let a_eight = _mm512_add_epi32(
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(a_c, _mm512_add_epi32(a_l, a_r)),
_mm512_add_epi32(a_a, a_b),
),
four,
),
_mm512_mullo_epi32(
_mm512_add_epi32(
_mm512_add_epi32(a_al, a_ar),
_mm512_add_epi32(a_bl, a_br),
),
three,
),
);
let src_words = loadu_256!(
&src[src_base + j * REST_UNIT_STRIDE + i
..src_base + j * REST_UNIT_STRIDE + i + 16],
[u16; 16]
);
let src_val = _mm512_cvtepu16_epi32(src_words);
let result = _mm512_srai_epi32::<9>(_mm512_add_epi32(
_mm512_sub_epi32(a_eight, _mm512_mullo_epi32(b_eight, src_val)),
rounding_9,
));
storeu_512!(
&mut dst[j * MAX_RESTORATION_WIDTH + i..j * MAX_RESTORATION_WIDTH + i + 16],
[i32; 16],
result
);
i += 16;
}
while i < w {
let idx = base + j * REST_UNIT_STRIDE + i;
let b_eight = {
let center = bb[idx] as i64;
let left = bb[idx - 1] as i64;
let right = bb[idx + 1] as i64;
let above = bb[idx - REST_UNIT_STRIDE] as i64;
let below = bb[idx + REST_UNIT_STRIDE] as i64;
let al = bb[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = bb[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = bb[idx + REST_UNIT_STRIDE - 1] as i64;
let br = bb[idx + REST_UNIT_STRIDE + 1] as i64;
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let a_eight = {
let center = aa[idx] as i64;
let left = aa[idx - 1] as i64;
let right = aa[idx + 1] as i64;
let above = aa[idx - REST_UNIT_STRIDE] as i64;
let below = aa[idx + REST_UNIT_STRIDE] as i64;
let al = aa[idx - REST_UNIT_STRIDE - 1] as i64;
let ar = aa[idx - REST_UNIT_STRIDE + 1] as i64;
let bl = aa[idx + REST_UNIT_STRIDE - 1] as i64;
let br = aa[idx + REST_UNIT_STRIDE + 1] as i64;
(center + left + right + above + below) * 4 + (al + ar + bl + br) * 3
};
let src_val = src[src_base + j * REST_UNIT_STRIDE + i] as i64;
dst[j * MAX_RESTORATION_WIDTH + i] =
((a_eight - b_eight * src_val + (1 << 8)) >> 9) as i32;
i += 1;
}
}
}
}
/// Adds one weighted self-guided filter plane onto a 16-bit pixel region, in place.
///
/// For every `row < h` and `col < w` the pixel at
/// `p_guard[p_base + row * stride + col]` becomes
/// `clip(pixel + ((w_k * dst[row * MAX_RESTORATION_WIDTH + col] + (1 << 10)) >> 11), 0, bitdepth_max)`.
///
/// * `_t` - capability token; the callers in this file obtain it from `summon_avx2()`.
/// * `p_guard` - pixel storage, indexed in `u16` elements.
/// * `p_base` / `stride` - element index of the first pixel and the signed
///   element distance between consecutive rows.
/// * `dst` - filter output with a row pitch of `MAX_RESTORATION_WIDTH`.
/// * `w_k` - weight applied to every `dst` value.
///
/// Columns are processed eight at a time; the remaining `w % 8` columns of
/// each row go through the scalar tail.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn sgr_apply_16bpc(
    _t: Desktop64,
    p_guard: &mut [u16],
    p_base: usize,
    stride: isize,
    dst: &[i32],
    w: usize,
    h: usize,
    w_k: i32,
    bitdepth_max: c_int,
) {
    use super::pixel_access::{loadu_128, storeu_128};

    // Loop-invariant vector constants.
    let weight = _mm256_set1_epi32(w_k);
    let round_bias = _mm256_set1_epi32(1 << 10);
    let lower_bound = _mm256_setzero_si256();
    let upper_bound = _mm256_set1_epi32(bitdepth_max);

    let filtered = dst.flex();
    let mut pixels = p_guard.flex_mut();

    // First column that no longer fits a full 8-lane vector.
    let vec_end = w - w % 8;

    for row in 0..h {
        let px_row = p_base.wrapping_add_signed(row as isize * stride);
        let flt_row = row * MAX_RESTORATION_WIDTH;

        for col in (0..vec_end).step_by(8) {
            let fl = flt_row + col;
            let px = px_row + col;

            let coeffs = loadu_256!(&filtered[fl..fl + 8], [i32; 8]);
            let delta = _mm256_srai_epi32::<11>(_mm256_add_epi32(
                _mm256_mullo_epi32(coeffs, weight),
                round_bias,
            ));

            let current_u16 = loadu_128!(&pixels[px..px + 8], [u16; 8]);
            let summed = _mm256_add_epi32(_mm256_cvtepu16_epi32(current_u16), delta);
            let clamped = _mm256_max_epi32(_mm256_min_epi32(summed, upper_bound), lower_bound);

            // Narrow the eight clamped i32 lanes back to u16, low half first.
            let packed = _mm_packus_epi32(
                _mm256_castsi256_si128(clamped),
                _mm256_extracti128_si256::<1>(clamped),
            );
            storeu_128!(
                <&mut [u16; 8]>::try_from(&mut pixels[px..px + 8]).unwrap(),
                packed
            );
        }

        // Scalar tail for the last `w % 8` columns.
        for col in vec_end..w {
            let scaled = w_k * filtered[flt_row + col];
            let px = px_row + col;
            let current = pixels[px] as i32;
            let updated = iclip(current + ((scaled + (1 << 10)) >> 11), 0, bitdepth_max);
            pixels[px] = updated as u16;
        }
    }
}
/// Adds the weighted sum of two self-guided filter planes onto a 16-bit pixel
/// region, in place.
///
/// For every `row < h` and `col < w`, with `k = row * MAX_RESTORATION_WIDTH + col`,
/// the pixel at `p_guard[p_base + row * stride + col]` becomes
/// `clip(pixel + ((w0 * dst0[k] + w1 * dst1[k] + (1 << 10)) >> 11), 0, bitdepth_max)`.
///
/// * `_t` - capability token; the callers in this file obtain it from `summon_avx2()`.
/// * `p_guard` - pixel storage, indexed in `u16` elements.
/// * `p_base` / `stride` - element index of the first pixel and the signed
///   element distance between consecutive rows.
/// * `dst0`, `dst1` - filter outputs, both with a row pitch of `MAX_RESTORATION_WIDTH`.
/// * `w0`, `w1` - weights applied to `dst0` and `dst1` respectively.
///
/// Columns are processed eight at a time; the remaining `w % 8` columns of
/// each row go through the scalar tail.
#[cfg(target_arch = "x86_64")]
#[arcane]
fn sgr_apply_mix_16bpc(
    _t: Desktop64,
    p_guard: &mut [u16],
    p_base: usize,
    stride: isize,
    dst0: &[i32],
    dst1: &[i32],
    w: usize,
    h: usize,
    w0: i32,
    w1: i32,
    bitdepth_max: c_int,
) {
    use super::pixel_access::{loadu_128, storeu_128};

    // Loop-invariant vector constants.
    let weight0 = _mm256_set1_epi32(w0);
    let weight1 = _mm256_set1_epi32(w1);
    let round_bias = _mm256_set1_epi32(1 << 10);
    let lower_bound = _mm256_setzero_si256();
    let upper_bound = _mm256_set1_epi32(bitdepth_max);

    let filtered0 = dst0.flex();
    let filtered1 = dst1.flex();
    let mut pixels = p_guard.flex_mut();

    // First column that no longer fits a full 8-lane vector.
    let vec_end = w - w % 8;

    for row in 0..h {
        let px_row = p_base.wrapping_add_signed(row as isize * stride);
        let flt_row = row * MAX_RESTORATION_WIDTH;

        for col in (0..vec_end).step_by(8) {
            let fl = flt_row + col;
            let px = px_row + col;

            let first = loadu_256!(&filtered0[fl..fl + 8], [i32; 8]);
            let second = loadu_256!(&filtered1[fl..fl + 8], [i32; 8]);
            let blended = _mm256_add_epi32(
                _mm256_mullo_epi32(first, weight0),
                _mm256_mullo_epi32(second, weight1),
            );
            let delta = _mm256_srai_epi32::<11>(_mm256_add_epi32(blended, round_bias));

            let current_u16 = loadu_128!(&pixels[px..px + 8], [u16; 8]);
            let summed = _mm256_add_epi32(_mm256_cvtepu16_epi32(current_u16), delta);
            let clamped = _mm256_max_epi32(_mm256_min_epi32(summed, upper_bound), lower_bound);

            // Narrow the eight clamped i32 lanes back to u16, low half first.
            let packed = _mm_packus_epi32(
                _mm256_castsi256_si128(clamped),
                _mm256_extracti128_si256::<1>(clamped),
            );
            storeu_128!(
                <&mut [u16; 8]>::try_from(&mut pixels[px..px + 8]).unwrap(),
                packed
            );
        }

        // Scalar tail for the last `w % 8` columns.
        for col in vec_end..w {
            let blended = w0 * filtered0[flt_row + col] + w1 * filtered1[flt_row + col];
            let px = px_row + col;
            let current = pixels[px] as i32;
            let updated = iclip(current + ((blended + (1 << 10)) >> 11), 0, bitdepth_max);
            pixels[px] = updated as u16;
        }
    }
}
/// Runs the 16 bpc self-guided restoration filter in its 5x5 configuration
/// and writes the result back into the picture at `p`.
///
/// Steps, as performed below:
/// 1. `padding` fills the working buffer `tmp` from `p`, `left` and `lpf`.
/// 2. The self-guided filter is run with the `25` / `sgr.s0` parameter pair
///    into `dst`, using the AVX-512 kernel when `summon_avx512()` yields a
///    token, else the AVX2 kernel, else the scalar implementation.
/// 3. Each pixel is updated to
///    `clip(pixel + ((sgr.w0 * dst + (1 << 10)) >> 11), 0, bitdepth_max)`.
///
/// `w` and `h` are the width and height of the processed region in pixels;
/// `dst` rows have a pitch of `MAX_RESTORATION_WIDTH`.
///
/// # Panics
/// Panics if the bytes handed out by `with_pixel_guard_mut` cannot be
/// reinterpreted as `u16` (alignment or size mismatch).
#[cfg(target_arch = "x86_64")]
fn sgr_5x5_16bpc_avx2_inner(
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    let mut tmp = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let mut dst = [0i32; 64 * MAX_RESTORATION_WIDTH];
    padding::<BitDepth16>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    let sgr = params.sgr();

    // This function only exists on x86_64, so the kernel selection needs no
    // nested `cfg` and no separate non-x86_64 fallback statement.
    if let Some(token) = crate::src::cpu::summon_avx512() {
        selfguided_filter_16bpc_avx512(token, &mut dst, &tmp, w, h, 25, sgr.s0, bitdepth_max);
    } else if let Some(token) = summon_avx2() {
        selfguided_filter_16bpc_avx2(token, &mut dst, &tmp, w, h, 25, sgr.s0, bitdepth_max);
    } else {
        selfguided_filter_16bpc(&mut dst, &tmp, w, h, 25, sgr.s0, bitdepth_max);
    }

    let w0 = sgr.w0 as i32;
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth16, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            let p_u16: &mut [u16] = zerocopy::FromBytes::mut_from_bytes(&mut bytes[..])
                .expect("bytes alignment/size mismatch for u16 reinterpretation");
            if let Some(token) = summon_avx2() {
                // `offset` and `stride` arrive in bytes; the apply kernel
                // indexes `u16` elements, hence the division by two.
                sgr_apply_16bpc(
                    token,
                    p_u16,
                    offset / 2,
                    stride / 2,
                    &dst,
                    w,
                    h,
                    w0,
                    bitdepth_max,
                );
            } else {
                // Scalar write-back, same arithmetic as `sgr_apply_16bpc`.
                let dst = dst.as_slice().flex();
                let mut cp = p_u16.flex_mut();
                for j in 0..h {
                    let row_off = (offset as isize + j as isize * stride) as usize / 2;
                    for i in 0..w {
                        let v = w0 * dst[j * MAX_RESTORATION_WIDTH + i];
                        cp[row_off + i] = iclip(
                            cp[row_off + i] as i32 + ((v + (1 << 10)) >> 11),
                            0,
                            bitdepth_max,
                        ) as u16;
                    }
                }
            }
        },
    );
}
/// Runs the 16 bpc self-guided restoration filter in its 3x3 configuration
/// and writes the result back into the picture at `p`.
///
/// Steps, as performed below:
/// 1. `padding` fills the working buffer `tmp` from `p`, `left` and `lpf`.
/// 2. The self-guided filter is run with the `9` / `sgr.s1` parameter pair
///    into `dst`, using the AVX-512 kernel when `summon_avx512()` yields a
///    token, else the AVX2 kernel, else the scalar implementation.
/// 3. Each pixel is updated to
///    `clip(pixel + ((sgr.w1 * dst + (1 << 10)) >> 11), 0, bitdepth_max)`.
///
/// `w` and `h` are the width and height of the processed region in pixels;
/// `dst` rows have a pitch of `MAX_RESTORATION_WIDTH`.
///
/// # Panics
/// Panics if the bytes handed out by `with_pixel_guard_mut` cannot be
/// reinterpreted as `u16` (alignment or size mismatch).
#[cfg(target_arch = "x86_64")]
fn sgr_3x3_16bpc_avx2_inner(
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    let mut tmp = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let mut dst = [0i32; 64 * MAX_RESTORATION_WIDTH];
    padding::<BitDepth16>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    let sgr = params.sgr();

    // This function only exists on x86_64, so the kernel selection needs no
    // nested `cfg` and no separate non-x86_64 fallback statement.
    if let Some(token) = crate::src::cpu::summon_avx512() {
        selfguided_filter_16bpc_avx512(token, &mut dst, &tmp, w, h, 9, sgr.s1, bitdepth_max);
    } else if let Some(token) = summon_avx2() {
        selfguided_filter_16bpc_avx2(token, &mut dst, &tmp, w, h, 9, sgr.s1, bitdepth_max);
    } else {
        selfguided_filter_16bpc(&mut dst, &tmp, w, h, 9, sgr.s1, bitdepth_max);
    }

    let w1 = sgr.w1 as i32;
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth16, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            let p_u16: &mut [u16] = zerocopy::FromBytes::mut_from_bytes(&mut bytes[..])
                .expect("bytes alignment/size mismatch for u16 reinterpretation");
            if let Some(token) = summon_avx2() {
                // `offset` and `stride` arrive in bytes; the apply kernel
                // indexes `u16` elements, hence the division by two.
                sgr_apply_16bpc(
                    token,
                    p_u16,
                    offset / 2,
                    stride / 2,
                    &dst,
                    w,
                    h,
                    w1,
                    bitdepth_max,
                );
            } else {
                // Scalar write-back, same arithmetic as `sgr_apply_16bpc`.
                let dst = dst.as_slice().flex();
                let mut cp = p_u16.flex_mut();
                for j in 0..h {
                    let row_off = (offset as isize + j as isize * stride) as usize / 2;
                    for i in 0..w {
                        let v = w1 * dst[j * MAX_RESTORATION_WIDTH + i];
                        cp[row_off + i] = iclip(
                            cp[row_off + i] as i32 + ((v + (1 << 10)) >> 11),
                            0,
                            bitdepth_max,
                        ) as u16;
                    }
                }
            }
        },
    );
}
/// Runs both 16 bpc self-guided restoration filter configurations and blends
/// their outputs back into the picture at `p`.
///
/// Steps, as performed below:
/// 1. `padding` fills the working buffer `tmp` from `p`, `left` and `lpf`.
/// 2. The self-guided filter is run twice on `tmp`: with `25` / `sgr.s0`
///    into `dst0` and with `9` / `sgr.s1` into `dst1`. The AVX-512 kernel is
///    used when `summon_avx512()` yields a token, else the AVX2 kernel, else
///    the scalar implementation.
/// 3. Each pixel is updated to
///    `clip(pixel + ((sgr.w0 * dst0 + sgr.w1 * dst1 + (1 << 10)) >> 11), 0, bitdepth_max)`.
///
/// `w` and `h` are the width and height of the processed region in pixels;
/// `dst0` / `dst1` rows have a pitch of `MAX_RESTORATION_WIDTH`.
///
/// # Panics
/// Panics if the bytes handed out by `with_pixel_guard_mut` cannot be
/// reinterpreted as `u16` (alignment or size mismatch).
#[cfg(target_arch = "x86_64")]
fn sgr_mix_16bpc_avx2_inner(
    p: PicOffset,
    left: &[LeftPixelRow<u16>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: usize,
    h: usize,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
) {
    let mut tmp = [0u16; (64 + 3 + 3) * REST_UNIT_STRIDE];
    let mut dst0 = [0i32; 64 * MAX_RESTORATION_WIDTH];
    let mut dst1 = [0i32; 64 * MAX_RESTORATION_WIDTH];
    padding::<BitDepth16>(&mut tmp, p, left, lpf, lpf_off, w, h, edges);
    let sgr = params.sgr();

    // This function only exists on x86_64, so the kernel selection needs no
    // nested `cfg` and no separate non-x86_64 fallback block.
    if let Some(token) = crate::src::cpu::summon_avx512() {
        selfguided_filter_16bpc_avx512(token, &mut dst0, &tmp, w, h, 25, sgr.s0, bitdepth_max);
        selfguided_filter_16bpc_avx512(token, &mut dst1, &tmp, w, h, 9, sgr.s1, bitdepth_max);
    } else if let Some(token) = summon_avx2() {
        selfguided_filter_16bpc_avx2(token, &mut dst0, &tmp, w, h, 25, sgr.s0, bitdepth_max);
        selfguided_filter_16bpc_avx2(token, &mut dst1, &tmp, w, h, 9, sgr.s1, bitdepth_max);
    } else {
        selfguided_filter_16bpc(&mut dst0, &tmp, w, h, 25, sgr.s0, bitdepth_max);
        selfguided_filter_16bpc(&mut dst1, &tmp, w, h, 9, sgr.s1, bitdepth_max);
    }

    let w0 = sgr.w0 as i32;
    let w1 = sgr.w1 as i32;
    crate::include::dav1d::picture::with_pixel_guard_mut::<BitDepth16, _>(
        &p,
        w,
        h,
        |bytes, offset, stride| {
            let p_u16: &mut [u16] = zerocopy::FromBytes::mut_from_bytes(&mut bytes[..])
                .expect("bytes alignment/size mismatch for u16 reinterpretation");
            if let Some(token) = summon_avx2() {
                // `offset` and `stride` arrive in bytes; the apply kernel
                // indexes `u16` elements, hence the division by two.
                sgr_apply_mix_16bpc(
                    token,
                    p_u16,
                    offset / 2,
                    stride / 2,
                    &dst0,
                    &dst1,
                    w,
                    h,
                    w0,
                    w1,
                    bitdepth_max,
                );
            } else {
                // Scalar write-back, same arithmetic as `sgr_apply_mix_16bpc`.
                let d0 = dst0.as_slice().flex();
                let d1 = dst1.as_slice().flex();
                let mut cp = p_u16.flex_mut();
                for j in 0..h {
                    let row_off = (offset as isize + j as isize * stride) as usize / 2;
                    for i in 0..w {
                        let v = w0 * d0[j * MAX_RESTORATION_WIDTH + i]
                            + w1 * d1[j * MAX_RESTORATION_WIDTH + i];
                        cp[row_off + i] = iclip(
                            cp[row_off + i] as i32 + ((v + (1 << 10)) >> 11),
                            0,
                            bitdepth_max,
                        ) as u16;
                    }
                }
            }
        },
    );
}
/// C-ABI entry point for the 16 bpc 5x5 self-guided restoration filter.
///
/// Unwraps the FFI-safe handles, rebuilds the `lpf` offset from `lpf_ptr`
/// and forwards to [`sgr_5x5_16bpc_avx2_inner`]. `_p_ptr` and `_stride` are
/// not used; the picture is reached through `p` instead.
///
/// # Safety
/// * `p` and `lpf` must be valid inputs for `FFISafe::get`; `p` is read by value.
/// * `left` must point to at least `h` readable `LeftPixelRow<u16>` values.
/// * The function is compiled with `avx2` enabled, so the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn sgr_filter_5x5_16bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let width = w as usize;
    let height = h as usize;

    // SAFETY: the caller guarantees `p` is a valid `FFISafe<PicOffset>` pointer.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: the caller guarantees `lpf` is a valid `FFISafe` pointer.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_offset = reconstruct_lpf_offset_16bpc(lpf_buf, lpf_ptr.cast::<u16>());

    // SAFETY: the caller guarantees `left` addresses `h` rows of 16-bit pixels.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u16>>(), height) };

    sgr_5x5_16bpc_avx2_inner(
        pic,
        left_rows,
        lpf_buf,
        lpf_offset,
        width,
        height,
        params,
        edges,
        bitdepth_max,
    );
}
/// C-ABI entry point for the 16 bpc 3x3 self-guided restoration filter.
///
/// Unwraps the FFI-safe handles, rebuilds the `lpf` offset from `lpf_ptr`
/// and forwards to [`sgr_3x3_16bpc_avx2_inner`]. `_p_ptr` and `_stride` are
/// not used; the picture is reached through `p` instead.
///
/// # Safety
/// * `p` and `lpf` must be valid inputs for `FFISafe::get`; `p` is read by value.
/// * `left` must point to at least `h` readable `LeftPixelRow<u16>` values.
/// * The function is compiled with `avx2` enabled, so the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn sgr_filter_3x3_16bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let width = w as usize;
    let height = h as usize;

    // SAFETY: the caller guarantees `p` is a valid `FFISafe<PicOffset>` pointer.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: the caller guarantees `lpf` is a valid `FFISafe` pointer.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_offset = reconstruct_lpf_offset_16bpc(lpf_buf, lpf_ptr.cast::<u16>());

    // SAFETY: the caller guarantees `left` addresses `h` rows of 16-bit pixels.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u16>>(), height) };

    sgr_3x3_16bpc_avx2_inner(
        pic,
        left_rows,
        lpf_buf,
        lpf_offset,
        width,
        height,
        params,
        edges,
        bitdepth_max,
    );
}
/// C-ABI entry point for the 16 bpc mixed (5x5 + 3x3) self-guided restoration filter.
///
/// Unwraps the FFI-safe handles, rebuilds the `lpf` offset from `lpf_ptr`
/// and forwards to [`sgr_mix_16bpc_avx2_inner`]. `_p_ptr` and `_stride` are
/// not used; the picture is reached through `p` instead.
///
/// # Safety
/// * `p` and `lpf` must be valid inputs for `FFISafe::get`; `p` is read by value.
/// * `left` must point to at least `h` readable `LeftPixelRow<u16>` values.
/// * The function is compiled with `avx2` enabled, so the CPU must support it.
#[cfg(all(feature = "asm", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
pub unsafe extern "C" fn sgr_filter_mix_16bpc_avx2(
    _p_ptr: *mut DynPixel,
    _stride: ptrdiff_t,
    left: *const LeftPixelRow<DynPixel>,
    lpf_ptr: *const DynPixel,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bitdepth_max: c_int,
    p: *const FFISafe<PicOffset>,
    lpf: *const FFISafe<DisjointMut<AlignedVec64<u8>>>,
) {
    let width = w as usize;
    let height = h as usize;

    // SAFETY: the caller guarantees `p` is a valid `FFISafe<PicOffset>` pointer.
    let pic = unsafe { *FFISafe::get(p) };
    // SAFETY: the caller guarantees `lpf` is a valid `FFISafe` pointer.
    let lpf_buf = unsafe { FFISafe::get(lpf) };
    let lpf_offset = reconstruct_lpf_offset_16bpc(lpf_buf, lpf_ptr.cast::<u16>());

    // SAFETY: the caller guarantees `left` addresses `h` rows of 16-bit pixels.
    let left_rows = unsafe { slice::from_raw_parts(left.cast::<LeftPixelRow<u16>>(), height) };

    sgr_mix_16bpc_avx2_inner(
        pic,
        left_rows,
        lpf_buf,
        lpf_offset,
        width,
        height,
        params,
        edges,
        bitdepth_max,
    );
}
/// Selects and runs the safe-SIMD loop-restoration kernel for bit depth `BD`.
///
/// `variant` picks the filter:
///
/// | `variant` | kernel                   |
/// |-----------|--------------------------|
/// | `0`       | Wiener, 7-tap            |
/// | `1`       | Wiener, 5-tap            |
/// | `2`       | self-guided, 5x5         |
/// | `3`       | self-guided, 3x3         |
/// | other     | self-guided, mixed       |
///
/// The Wiener kernels use their AVX-512 implementation when
/// `summon_avx512()` yields a token and the AVX2 one otherwise; the
/// self-guided entry points are called the same way in both cases.
///
/// Returns `false` without running anything when no AVX2 token can be
/// summoned, and `true` after a kernel has run.
///
/// # Panics
/// Panics if `left` cannot be reinterpreted as rows of the pixel type that
/// matches `BD::BPC`.
#[cfg(target_arch = "x86_64")]
pub fn lr_filter_dispatch<BD: BitDepth>(
    variant: usize,
    dst: PicOffset,
    left: &[LeftPixelRow<BD::Pixel>],
    lpf: &DisjointMut<AlignedVec64<u8>>,
    lpf_off: isize,
    w: c_int,
    h: c_int,
    params: &LooprestorationParams,
    edges: LrEdgeFlags,
    bd: BD,
) -> bool {
    use crate::include::common::bitdepth::BPC;
    use crate::src::safe_simd::pixel_access::reinterpret_slice;

    let avx512_token = crate::src::cpu::summon_avx512();
    let Some(token) = crate::src::cpu::summon_avx2() else {
        return false;
    };
    let (w, h) = (w as usize, h as usize);

    match BD::BPC {
        BPC::BPC8 => {
            let left8: &[LeftPixelRow<u8>] =
                reinterpret_slice(left).expect("BD::Pixel layout matches u8");
            match (variant, avx512_token) {
                (0, Some(t512)) => wiener_filter7_8bpc_avx512_inner(
                    t512, dst, left8, lpf, lpf_off, w, h, params, edges,
                ),
                (0, None) => wiener_filter7_8bpc_avx2_inner(
                    token, dst, left8, lpf, lpf_off, w, h, params, edges,
                ),
                (1, Some(t512)) => wiener_filter5_8bpc_avx512_inner(
                    t512, dst, left8, lpf, lpf_off, w, h, params, edges,
                ),
                (1, None) => wiener_filter5_8bpc_avx2_inner(
                    token, dst, left8, lpf, lpf_off, w, h, params, edges,
                ),
                (2, _) => sgr_5x5_8bpc_avx2_inner(dst, left8, lpf, lpf_off, w, h, params, edges),
                (3, _) => sgr_3x3_8bpc_avx2_inner(dst, left8, lpf, lpf_off, w, h, params, edges),
                (_, _) => sgr_mix_8bpc_avx2_inner(dst, left8, lpf, lpf_off, w, h, params, edges),
            }
        }
        BPC::BPC16 => {
            let left16: &[LeftPixelRow<u16>] =
                reinterpret_slice(left).expect("BD::Pixel layout matches u16");
            let bitdepth_max = bd.into_c();
            match (variant, avx512_token) {
                (0, Some(t512)) => wiener_filter7_16bpc_avx512_inner(
                    t512,
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
                (0, None) => wiener_filter7_16bpc_avx2_inner(
                    token,
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
                (1, Some(t512)) => wiener_filter5_16bpc_avx512_inner(
                    t512,
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
                (1, None) => wiener_filter5_16bpc_avx2_inner(
                    token,
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
                (2, _) => sgr_5x5_16bpc_avx2_inner(
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
                (3, _) => sgr_3x3_16bpc_avx2_inner(
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
                (_, _) => sgr_mix_16bpc_avx2_inner(
                    dst,
                    left16,
                    lpf,
                    lpf_off,
                    w,
                    h,
                    params,
                    edges,
                    bitdepth_max,
                ),
            }
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the row pitch of the padded working buffers, both against the
    /// formula it is defined with and against its literal value.
    #[test]
    fn test_rest_unit_stride() {
        const FROM_FORMULA: usize = 256 * 3 / 2 + 3 + 3;
        const LITERAL: usize = 390;
        assert_eq!(REST_UNIT_STRIDE, FROM_FORMULA);
        assert_eq!(REST_UNIT_STRIDE, LITERAL);
    }

    /// Pins the row pitch of the intermediate filter output buffers.
    #[test]
    fn test_max_restoration_width() {
        const LITERAL: usize = 384;
        assert_eq!(MAX_RESTORATION_WIDTH, LITERAL);
    }
}