use crate::EdgeMode;
use crate::edge_mode::clamp_edge;
use crate::neon::{load_f32_f16, store_f32_f16};
use crate::unsafe_slice::UnsafeSlice;
use core::f16;
use std::arch::aarch64::*;
/// NEON vertical pass of the fast-gaussian approximation for `f16` images.
///
/// For every column in `start..min(width, end)` this walks the rows with a
/// second-order running sum:
/// - `diffs` accumulates each sample entering at `y + radius`, minus twice the
///   sample at `y`, plus the sample at `y - radius`.
/// - `summs` accumulates `diffs`.
///
/// Each output row is `summs / radius²`, written back into `bytes` in place.
/// Source rows are read `radius` rows ahead of the row being written.
///
/// * `stride` - distance between rows, in `f16` elements (it is added directly to
///   the element index passed to `get_ptr`).
/// * `CN` - channels per pixel, carried in one `float32x4_t`; presumably `1..=4`.
///   TODO confirm against `load_f32_f16`/`store_f32_f16`.
/// * `edge_mode` - how sample rows outside `0..height` are mapped back in
///   (via `clamp_edge!`).
///
/// A `radius` of 0 returns without touching the image, because the normalisation
/// `1 / radius²` would be infinite and every output would be NaN/inf.
/// Sample history lives in a 1024-entry ring buffer. Larger radii alias inside it,
/// so `radius` should not exceed 512.
pub(crate) fn fg_vertical_pass_neon_f16<const CN: usize>(
    bytes: &UnsafeSlice<f16>,
    stride: u32,
    width: u32,
    height: u32,
    radius: u32,
    start: u32,
    end: u32,
    edge_mode: EdgeMode,
) {
    // A zero radius would divide by zero below; the identity blur is "do nothing".
    if radius == 0 {
        return;
    }
    let height_wide = height as i64;
    let radius_64 = radius as i64;
    let stride_usize = stride as usize;
    let weight = 1.0f32 / ((radius as f32) * (radius as f32));
    // Ring buffer of recent samples. Indices are reduced with `& 1023`, so negative
    // indices during the warm-up wrap around via two's complement.
    let mut buffer = [[0f32; 4]; 1024];
    // SAFETY: every `get_unchecked{,_mut}` index is masked with `& 1023`, so it is
    // always < 1024. The validity of the pointers returned by `bytes.get_ptr` for
    // rows `0..height` and columns `start..min(width, end)` is assumed to be
    // guaranteed by the caller's `stride`/`width`/`height` (not checked here).
    unsafe {
        let f_weight = vdupq_n_f32(weight);
        for x in start..std::cmp::min(width, end) {
            // The column offset is constant for the whole inner loop.
            let current_px = x as usize * CN;
            let mut diffs = vdupq_n_f32(0f32);
            let mut summs = vdupq_n_f32(0f32);
            for y in -2 * radius_64..height_wide {
                if y >= 0 {
                    // Emit the finished row before updating the running sums.
                    let current_y = y as usize * stride_usize;
                    let prepared_px = vmulq_f32(summs, f_weight);
                    let dst_ptr = bytes.get_ptr(current_y + current_px);
                    store_f32_f16::<CN>(dst_ptr, prepared_px);
                    let arr_index = ((y - radius_64) & 1023) as usize;
                    let d_arr_index = (y & 1023) as usize;
                    let d_stored =
                        vmulq_n_f32(vld1q_f32(buffer.get_unchecked(d_arr_index).as_ptr()), 2f32);
                    let a_stored = vld1q_f32(buffer.get_unchecked(arr_index).as_ptr());
                    diffs = vaddq_f32(diffs, vsubq_f32(a_stored, d_stored));
                } else if y + radius_64 >= 0 {
                    // Warm-up: the sample at `y - radius` has not been produced yet,
                    // so only the centre term is subtracted.
                    let arr_index = (y & 1023) as usize;
                    let stored =
                        vmulq_n_f32(vld1q_f32(buffer.get_unchecked(arr_index).as_ptr()), 2f32);
                    diffs = vsubq_f32(diffs, stored);
                }
                let next_row_y =
                    clamp_edge!(edge_mode, y + radius_64, 0, height_wide) * stride_usize;
                let pixel_color = load_f32_f16::<CN>(bytes.get_ptr(next_row_y + current_px));
                let arr_index = ((y + radius_64) & 1023) as usize;
                diffs = vaddq_f32(diffs, pixel_color);
                summs = vaddq_f32(summs, diffs);
                vst1q_f32(buffer.get_unchecked_mut(arr_index).as_mut_ptr(), pixel_color);
            }
        }
    }
}
/// NEON horizontal pass of the fast-gaussian approximation for `f16` images.
///
/// For every row in `start..min(height, end)` this walks the pixels with a
/// second-order running sum:
/// - `diffs` accumulates each sample entering at `x + radius`, minus twice the
///   sample at `x`, plus the sample at `x - radius`.
/// - `summs` accumulates `diffs`.
///
/// Each output pixel is `summs / radius²`, written back into `bytes` in place.
/// Source pixels are read `radius` pixels ahead of the pixel being written.
///
/// * `stride` - distance between rows, in `f16` elements (it is added directly to
///   the element index passed to `get_ptr`).
/// * `CN` - channels per pixel, carried in one `float32x4_t`; presumably `1..=4`.
///   TODO confirm against `load_f32_f16`/`store_f32_f16`.
/// * `edge_mode` - how sample columns outside `0..width` are mapped back in
///   (via `clamp_edge!`).
///
/// A `radius` of 0 returns without touching the image, because the normalisation
/// `1 / radius²` would be infinite and every output would be NaN/inf.
/// Sample history lives in a 1024-entry ring buffer. Larger radii alias inside it,
/// so `radius` should not exceed 512.
pub(crate) fn fg_horizontal_pass_neon_f16<const CN: usize>(
    bytes: &UnsafeSlice<f16>,
    stride: u32,
    width: u32,
    height: u32,
    radius: u32,
    start: u32,
    end: u32,
    edge_mode: EdgeMode,
) {
    // A zero radius would divide by zero below; the identity blur is "do nothing".
    if radius == 0 {
        return;
    }
    let radius_64 = radius as i64;
    let width_wide = width as i64;
    let weight = 1.0f32 / ((radius as f32) * (radius as f32));
    // Ring buffer of recent samples. Indices are reduced with `& 1023`, so negative
    // indices during the warm-up wrap around via two's complement.
    let mut buffer: [[f32; 4]; 1024] = [[0f32; 4]; 1024];
    // SAFETY: every `get_unchecked{,_mut}` index is masked with `& 1023`, so it is
    // always < 1024. The validity of the pointers returned by `bytes.get_ptr` for
    // rows `start..min(height, end)` and columns `0..width` is assumed to be
    // guaranteed by the caller's `stride`/`width`/`height` (not checked here).
    unsafe {
        let f_weight = vdupq_n_f32(weight);
        for y in start..std::cmp::min(height, end) {
            let mut diffs: float32x4_t = vdupq_n_f32(0f32);
            let mut summs: float32x4_t = vdupq_n_f32(0f32);
            // The row offset is constant for the whole inner loop and is used for
            // both reads and writes.
            let current_y = (y as usize) * (stride as usize);
            for x in -2 * radius_64..width_wide {
                if x >= 0 {
                    // Emit the finished pixel before updating the running sums.
                    let current_px = x as usize * CN;
                    let prepared_px = vmulq_f32(summs, f_weight);
                    let dst_ptr = bytes.get_ptr(current_y + current_px);
                    store_f32_f16::<CN>(dst_ptr, prepared_px);
                    let arr_index = ((x - radius_64) & 1023) as usize;
                    let d_arr_index = (x & 1023) as usize;
                    let d_stored =
                        vmulq_n_f32(vld1q_f32(buffer.get_unchecked(d_arr_index).as_ptr()), 2f32);
                    let a_stored = vld1q_f32(buffer.get_unchecked(arr_index).as_ptr());
                    diffs = vaddq_f32(diffs, vsubq_f32(a_stored, d_stored));
                } else if x + radius_64 >= 0 {
                    // Warm-up: the sample at `x - radius` has not been produced yet,
                    // so only the centre term is subtracted.
                    let arr_index = (x & 1023) as usize;
                    let stored =
                        vmulq_n_f32(vld1q_f32(buffer.get_unchecked(arr_index).as_ptr()), 2f32);
                    diffs = vsubq_f32(diffs, stored);
                }
                let next_row_x = clamp_edge!(edge_mode, x + radius_64, 0, width_wide);
                let next_row_px = next_row_x * CN;
                let pixel_color = load_f32_f16::<CN>(bytes.get_ptr(current_y + next_row_px));
                let arr_index = ((x + radius_64) & 1023) as usize;
                diffs = vaddq_f32(diffs, pixel_color);
                summs = vaddq_f32(summs, diffs);
                vst1q_f32(buffer.get_unchecked_mut(arr_index).as_mut_ptr(), pixel_color);
            }
        }
    }
}