#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use super::super::super::math::avx512::{
exp_f32 as exp_vec_f32, exp_f64 as exp_vec_f64, log_f32 as log_vec_f32, log_f64 as log_vec_f64,
};
use super::super::super::streaming::{is_aligned_avx512, should_stream_f32, should_stream_f64};
use super::super::{binary_scalar_f32, binary_scalar_f64};
use crate::ops::BinaryOp;
// Lane counts of a 512-bit ZMM register: 512 / 32 = 16 f32 lanes.
const F32_LANES: usize = 16;
// 512 / 64 = 8 f64 lanes.
const F64_LANES: usize = 8;
/// Generates an AVX-512 kernel applying a lane-wise binary intrinsic over
/// `chunks` blocks of 16 f32 values.
///
/// `$vec_op` is the 512-bit arithmetic intrinsic (e.g. `_mm512_add_ps`);
/// `$store` selects the store flavor: `_mm512_storeu_ps` for ordinary
/// cache-filling stores or `_mm512_stream_ps` for non-temporal stores.
///
/// Safety: `a`, `b`, and `out` must be valid for `chunks * 16` elements and
/// the CPU must support AVX-512F. Loads are unaligned-tolerant; the
/// streaming store variant requires `out` to be 64-byte aligned (the
/// dispatcher checks this before selecting it).
macro_rules! impl_binary_f32_avx512 {
    ($name:ident, $vec_op:ident, $store:ident) => {
        #[target_feature(enable = "avx512f")]
        unsafe fn $name(a: *const f32, b: *const f32, out: *mut f32, chunks: usize) {
            // Walk the buffers one full 16-lane block at a time.
            for base in (0..chunks * F32_LANES).step_by(F32_LANES) {
                let lhs = _mm512_loadu_ps(a.add(base));
                let rhs = _mm512_loadu_ps(b.add(base));
                $store(out.add(base), $vec_op(lhs, rhs));
            }
        }
    };
}
/// Generates an AVX-512 kernel applying a lane-wise binary intrinsic over
/// `chunks` blocks of 8 f64 values.
///
/// `$vec_op` is the 512-bit arithmetic intrinsic (e.g. `_mm512_add_pd`);
/// `$store` is `_mm512_storeu_pd` (normal) or `_mm512_stream_pd`
/// (non-temporal).
///
/// Safety: `a`, `b`, and `out` must be valid for `chunks * 8` elements and
/// the CPU must support AVX-512F. Loads are unaligned-tolerant; the
/// streaming store variant requires `out` to be 64-byte aligned (the
/// dispatcher checks this before selecting it).
macro_rules! impl_binary_f64_avx512 {
    ($name:ident, $vec_op:ident, $store:ident) => {
        #[target_feature(enable = "avx512f")]
        unsafe fn $name(a: *const f64, b: *const f64, out: *mut f64, chunks: usize) {
            // Walk the buffers one full 8-lane block at a time.
            for base in (0..chunks * F64_LANES).step_by(F64_LANES) {
                let lhs = _mm512_loadu_pd(a.add(base));
                let rhs = _mm512_loadu_pd(b.add(base));
                $store(out.add(base), $vec_op(lhs, rhs));
            }
        }
    };
}
// f32 kernels with ordinary (cache-filling) stores.
impl_binary_f32_avx512!(binary_add_f32, _mm512_add_ps, _mm512_storeu_ps);
impl_binary_f32_avx512!(binary_sub_f32, _mm512_sub_ps, _mm512_storeu_ps);
impl_binary_f32_avx512!(binary_mul_f32, _mm512_mul_ps, _mm512_storeu_ps);
impl_binary_f32_avx512!(binary_div_f32, _mm512_div_ps, _mm512_storeu_ps);
impl_binary_f32_avx512!(binary_max_f32, _mm512_max_ps, _mm512_storeu_ps);
impl_binary_f32_avx512!(binary_min_f32, _mm512_min_ps, _mm512_storeu_ps);
// f32 kernels with non-temporal (streaming) stores; require 64-byte-aligned `out`.
impl_binary_f32_avx512!(binary_add_f32_stream, _mm512_add_ps, _mm512_stream_ps);
impl_binary_f32_avx512!(binary_sub_f32_stream, _mm512_sub_ps, _mm512_stream_ps);
impl_binary_f32_avx512!(binary_mul_f32_stream, _mm512_mul_ps, _mm512_stream_ps);
impl_binary_f32_avx512!(binary_div_f32_stream, _mm512_div_ps, _mm512_stream_ps);
impl_binary_f32_avx512!(binary_max_f32_stream, _mm512_max_ps, _mm512_stream_ps);
impl_binary_f32_avx512!(binary_min_f32_stream, _mm512_min_ps, _mm512_stream_ps);
// f64 kernels with ordinary (cache-filling) stores.
impl_binary_f64_avx512!(binary_add_f64, _mm512_add_pd, _mm512_storeu_pd);
impl_binary_f64_avx512!(binary_sub_f64, _mm512_sub_pd, _mm512_storeu_pd);
impl_binary_f64_avx512!(binary_mul_f64, _mm512_mul_pd, _mm512_storeu_pd);
impl_binary_f64_avx512!(binary_div_f64, _mm512_div_pd, _mm512_storeu_pd);
impl_binary_f64_avx512!(binary_max_f64, _mm512_max_pd, _mm512_storeu_pd);
impl_binary_f64_avx512!(binary_min_f64, _mm512_min_pd, _mm512_storeu_pd);
// f64 kernels with non-temporal (streaming) stores; require 64-byte-aligned `out`.
impl_binary_f64_avx512!(binary_add_f64_stream, _mm512_add_pd, _mm512_stream_pd);
impl_binary_f64_avx512!(binary_sub_f64_stream, _mm512_sub_pd, _mm512_stream_pd);
impl_binary_f64_avx512!(binary_mul_f64_stream, _mm512_mul_pd, _mm512_stream_pd);
impl_binary_f64_avx512!(binary_div_f64_stream, _mm512_div_pd, _mm512_stream_pd);
impl_binary_f64_avx512!(binary_max_f64_stream, _mm512_max_pd, _mm512_stream_pd);
impl_binary_f64_avx512!(binary_min_f64_stream, _mm512_min_pd, _mm512_stream_pd);
/// Generates an AVX-512 `pow` kernel over blocks of 16 f32 values, computed
/// via the identity `a^b = exp(b * ln(a))` using the vector math helpers.
///
/// NOTE(review): the exp/log identity only matches scalar `powf` for
/// positive bases; for lanes with `a <= 0` the result is whatever
/// `log_vec_f32` produces there — confirm against the scalar fallback if
/// exactness on that domain matters.
///
/// Safety: same contract as the other f32 kernels — valid pointers for
/// `chunks * 16` elements, AVX-512F support, and 64-byte alignment of
/// `out` when `$store` is the streaming variant.
macro_rules! impl_pow_f32_avx512 {
    ($name:ident, $store:ident) => {
        #[target_feature(enable = "avx512f")]
        unsafe fn $name(a: *const f32, b: *const f32, out: *mut f32, chunks: usize) {
            for base in (0..chunks * F32_LANES).step_by(F32_LANES) {
                let lhs = _mm512_loadu_ps(a.add(base));
                let rhs = _mm512_loadu_ps(b.add(base));
                // a^b == exp(b * ln(a))
                let result = exp_vec_f32(_mm512_mul_ps(rhs, log_vec_f32(lhs)));
                $store(out.add(base), result);
            }
        }
    };
}
/// Generates an AVX-512 `pow` kernel over blocks of 8 f64 values, computed
/// via the identity `a^b = exp(b * ln(a))` using the vector math helpers.
///
/// NOTE(review): as with the f32 variant, this only matches scalar `powf`
/// for positive bases; lanes with `a <= 0` inherit `log_vec_f64`'s behavior
/// — confirm against the scalar fallback if that domain matters.
///
/// Safety: valid pointers for `chunks * 8` elements, AVX-512F support, and
/// 64-byte alignment of `out` when `$store` is the streaming variant.
macro_rules! impl_pow_f64_avx512 {
    ($name:ident, $store:ident) => {
        #[target_feature(enable = "avx512f")]
        unsafe fn $name(a: *const f64, b: *const f64, out: *mut f64, chunks: usize) {
            for base in (0..chunks * F64_LANES).step_by(F64_LANES) {
                let lhs = _mm512_loadu_pd(a.add(base));
                let rhs = _mm512_loadu_pd(b.add(base));
                // a^b == exp(b * ln(a))
                let result = exp_vec_f64(_mm512_mul_pd(rhs, log_vec_f64(lhs)));
                $store(out.add(base), result);
            }
        }
    };
}
// Pow kernels: normal and streaming store flavors for each width.
impl_pow_f32_avx512!(binary_pow_f32, _mm512_storeu_ps);
impl_pow_f32_avx512!(binary_pow_f32_stream, _mm512_stream_ps);
impl_pow_f64_avx512!(binary_pow_f64, _mm512_storeu_pd);
impl_pow_f64_avx512!(binary_pow_f64_stream, _mm512_stream_pd);
/// Element-wise binary op on f32 buffers using 16-lane AVX-512 kernels,
/// with a scalar loop for the tail that doesn't fill a full vector.
///
/// Non-temporal (streaming) stores are selected when the length heuristic
/// says so and `out` is 64-byte aligned; a store fence then orders the
/// streamed data before the scalar tail writes and before returning.
/// `Atan2` has no vector kernel here and is handled entirely by the scalar
/// fallback.
///
/// # Safety
/// `a`, `b`, and `out` must be valid for `len` elements and the CPU must
/// support AVX-512F.
#[target_feature(enable = "avx512f")]
pub unsafe fn binary_f32(op: BinaryOp, a: *const f32, b: *const f32, out: *mut f32, len: usize) {
    let full = len / F32_LANES;
    let tail = len % F32_LANES;
    let stream = should_stream_f32(len) && is_aligned_avx512(out);
    if op == BinaryOp::Atan2 {
        binary_scalar_f32(op, a, b, out, len);
        return;
    }
    // Dispatch on (op, store flavor) in one exhaustive match.
    match (op, stream) {
        (BinaryOp::Add, true) => binary_add_f32_stream(a, b, out, full),
        (BinaryOp::Sub, true) => binary_sub_f32_stream(a, b, out, full),
        (BinaryOp::Mul, true) => binary_mul_f32_stream(a, b, out, full),
        (BinaryOp::Div, true) => binary_div_f32_stream(a, b, out, full),
        (BinaryOp::Max, true) => binary_max_f32_stream(a, b, out, full),
        (BinaryOp::Min, true) => binary_min_f32_stream(a, b, out, full),
        (BinaryOp::Pow, true) => binary_pow_f32_stream(a, b, out, full),
        (BinaryOp::Add, false) => binary_add_f32(a, b, out, full),
        (BinaryOp::Sub, false) => binary_sub_f32(a, b, out, full),
        (BinaryOp::Mul, false) => binary_mul_f32(a, b, out, full),
        (BinaryOp::Div, false) => binary_div_f32(a, b, out, full),
        (BinaryOp::Max, false) => binary_max_f32(a, b, out, full),
        (BinaryOp::Min, false) => binary_min_f32(a, b, out, full),
        (BinaryOp::Pow, false) => binary_pow_f32(a, b, out, full),
        (BinaryOp::Atan2, _) => unreachable!(),
    }
    if stream {
        // Make the non-temporal stores globally visible before anything else
        // touches `out`.
        _mm_sfence();
    }
    if tail > 0 {
        let done = full * F32_LANES;
        binary_scalar_f32(op, a.add(done), b.add(done), out.add(done), tail);
    }
}
/// Element-wise binary op on f64 buffers using 8-lane AVX-512 kernels,
/// with a scalar loop for the tail that doesn't fill a full vector.
///
/// Non-temporal (streaming) stores are selected when the length heuristic
/// says so and `out` is 64-byte aligned; a store fence then orders the
/// streamed data before the scalar tail writes and before returning.
/// `Atan2` has no vector kernel here and is handled entirely by the scalar
/// fallback.
///
/// # Safety
/// `a`, `b`, and `out` must be valid for `len` elements and the CPU must
/// support AVX-512F.
#[target_feature(enable = "avx512f")]
pub unsafe fn binary_f64(op: BinaryOp, a: *const f64, b: *const f64, out: *mut f64, len: usize) {
    let full = len / F64_LANES;
    let tail = len % F64_LANES;
    let stream = should_stream_f64(len) && is_aligned_avx512(out);
    if op == BinaryOp::Atan2 {
        binary_scalar_f64(op, a, b, out, len);
        return;
    }
    // Dispatch on (op, store flavor) in one exhaustive match.
    match (op, stream) {
        (BinaryOp::Add, true) => binary_add_f64_stream(a, b, out, full),
        (BinaryOp::Sub, true) => binary_sub_f64_stream(a, b, out, full),
        (BinaryOp::Mul, true) => binary_mul_f64_stream(a, b, out, full),
        (BinaryOp::Div, true) => binary_div_f64_stream(a, b, out, full),
        (BinaryOp::Max, true) => binary_max_f64_stream(a, b, out, full),
        (BinaryOp::Min, true) => binary_min_f64_stream(a, b, out, full),
        (BinaryOp::Pow, true) => binary_pow_f64_stream(a, b, out, full),
        (BinaryOp::Add, false) => binary_add_f64(a, b, out, full),
        (BinaryOp::Sub, false) => binary_sub_f64(a, b, out, full),
        (BinaryOp::Mul, false) => binary_mul_f64(a, b, out, full),
        (BinaryOp::Div, false) => binary_div_f64(a, b, out, full),
        (BinaryOp::Max, false) => binary_max_f64(a, b, out, full),
        (BinaryOp::Min, false) => binary_min_f64(a, b, out, full),
        (BinaryOp::Pow, false) => binary_pow_f64(a, b, out, full),
        (BinaryOp::Atan2, _) => unreachable!(),
    }
    if stream {
        // Make the non-temporal stores globally visible before anything else
        // touches `out`.
        _mm_sfence();
    }
    if tail > 0 {
        let done = full * F64_LANES;
        binary_scalar_f64(op, a.add(done), b.add(done), out.add(done), tail);
    }
}