#![allow(unsafe_code)]
use crate::error::{AlgorithmError, Result};
fn validate_unary(data: &[f32], out: &[f32]) -> Result<()> {
if data.len() != out.len() {
return Err(AlgorithmError::InvalidParameter {
parameter: "input",
message: format!(
"Slice length mismatch: data={}, out={}",
data.len(),
out.len()
),
});
}
Ok(())
}
#[cfg(target_arch = "aarch64")]
mod neon_impl {
use std::arch::aarch64::*;
#[target_feature(enable = "neon")]
pub(crate) unsafe fn sqrt_f32(data: &[f32], out: &mut [f32]) {
unsafe {
let len = data.len();
let chunks = len / 4;
let d_ptr = data.as_ptr();
let o_ptr = out.as_mut_ptr();
for i in 0..chunks {
let off = i * 4;
let vd = vld1q_f32(d_ptr.add(off));
let vr = vsqrtq_f32(vd);
vst1q_f32(o_ptr.add(off), vr);
}
let rem = chunks * 4;
for i in rem..len {
*o_ptr.add(i) = (*d_ptr.add(i)).sqrt();
}
}
}
#[target_feature(enable = "neon")]
pub(crate) unsafe fn abs_f32(data: &[f32], out: &mut [f32]) {
unsafe {
let len = data.len();
let chunks = len / 4;
let d_ptr = data.as_ptr();
let o_ptr = out.as_mut_ptr();
for i in 0..chunks {
let off = i * 4;
let vd = vld1q_f32(d_ptr.add(off));
let vr = vabsq_f32(vd);
vst1q_f32(o_ptr.add(off), vr);
}
let rem = chunks * 4;
for i in rem..len {
*o_ptr.add(i) = (*d_ptr.add(i)).abs();
}
}
}
#[target_feature(enable = "neon")]
pub(crate) unsafe fn floor_f32(data: &[f32], out: &mut [f32]) {
unsafe {
let len = data.len();
let chunks = len / 4;
let d_ptr = data.as_ptr();
let o_ptr = out.as_mut_ptr();
for i in 0..chunks {
let off = i * 4;
let vd = vld1q_f32(d_ptr.add(off));
let vr = vrndmq_f32(vd);
vst1q_f32(o_ptr.add(off), vr);
}
let rem = chunks * 4;
for i in rem..len {
*o_ptr.add(i) = (*d_ptr.add(i)).floor();
}
}
}
#[target_feature(enable = "neon")]
pub(crate) unsafe fn ceil_f32(data: &[f32], out: &mut [f32]) {
unsafe {
let len = data.len();
let chunks = len / 4;
let d_ptr = data.as_ptr();
let o_ptr = out.as_mut_ptr();
for i in 0..chunks {
let off = i * 4;
let vd = vld1q_f32(d_ptr.add(off));
let vr = vrndpq_f32(vd);
vst1q_f32(o_ptr.add(off), vr);
}
let rem = chunks * 4;
for i in rem..len {
*o_ptr.add(i) = (*d_ptr.add(i)).ceil();
}
}
}
#[target_feature(enable = "neon")]
pub(crate) unsafe fn round_f32(data: &[f32], out: &mut [f32]) {
unsafe {
let len = data.len();
let chunks = len / 4;
let d_ptr = data.as_ptr();
let o_ptr = out.as_mut_ptr();
for i in 0..chunks {
let off = i * 4;
let vd = vld1q_f32(d_ptr.add(off));
let vr = vrndaq_f32(vd);
vst1q_f32(o_ptr.add(off), vr);
}
let rem = chunks * 4;
for i in rem..len {
*o_ptr.add(i) = (*d_ptr.add(i)).round();
}
}
}
#[target_feature(enable = "neon")]
pub(crate) unsafe fn exp_f32(data: &[f32], out: &mut [f32]) {
for i in 0..data.len() {
out[i] = data[i].exp();
}
}
#[target_feature(enable = "neon")]
pub(crate) unsafe fn ln_f32(data: &[f32], out: &mut [f32]) {
for i in 0..data.len() {
out[i] = data[i].ln();
}
}
}
mod scalar_impl {
pub(crate) fn apply_unary(data: &[f32], out: &mut [f32], f: fn(f32) -> f32) {
const LANES: usize = 8;
let chunks = data.len() / LANES;
for i in 0..chunks {
let start = i * LANES;
let end = start + LANES;
for j in start..end {
out[j] = f(data[j]);
}
}
let remainder_start = chunks * LANES;
for i in remainder_start..data.len() {
out[i] = f(data[i]);
}
}
pub(crate) fn apply_binary(a: &[f32], b: &[f32], out: &mut [f32], f: fn(f32, f32) -> f32) {
const LANES: usize = 8;
let chunks = a.len() / LANES;
for i in 0..chunks {
let start = i * LANES;
let end = start + LANES;
for j in start..end {
out[j] = f(a[j], b[j]);
}
}
let remainder_start = chunks * LANES;
for i in remainder_start..a.len() {
out[i] = f(a[i], b[i]);
}
}
}
pub fn sqrt_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::sqrt_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::sqrt);
}
Ok(())
}
pub fn ln_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::ln_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::ln);
}
Ok(())
}
pub fn log10_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::ln_f32(data, out);
}
let log10e = std::f32::consts::LOG10_E;
for val in out.iter_mut() {
*val *= log10e;
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::log10);
}
Ok(())
}
pub fn log2_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::ln_f32(data, out);
}
let log2e = std::f32::consts::LOG2_E;
for val in out.iter_mut() {
*val *= log2e;
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::log2);
}
Ok(())
}
pub fn exp_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::exp_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::exp);
}
Ok(())
}
pub fn exp2_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::exp2);
Ok(())
}
pub fn pow_f32(base: &[f32], exponent: &[f32], out: &mut [f32]) -> Result<()> {
if base.len() != exponent.len() || base.len() != out.len() {
return Err(AlgorithmError::InvalidParameter {
parameter: "input",
message: "Slice length mismatch".to_string(),
});
}
scalar_impl::apply_binary(base, exponent, out, f32::powf);
Ok(())
}
pub fn sin_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::sin);
Ok(())
}
pub fn cos_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::cos);
Ok(())
}
pub fn tan_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::tan);
Ok(())
}
pub fn asin_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::asin);
Ok(())
}
pub fn acos_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::acos);
Ok(())
}
pub fn atan_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::atan);
Ok(())
}
pub fn atan2_f32(y: &[f32], x: &[f32], out: &mut [f32]) -> Result<()> {
if y.len() != x.len() || y.len() != out.len() {
return Err(AlgorithmError::InvalidParameter {
parameter: "input",
message: "Slice length mismatch".to_string(),
});
}
scalar_impl::apply_binary(y, x, out, f32::atan2);
Ok(())
}
pub fn sinh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::sinh);
Ok(())
}
pub fn cosh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::cosh);
Ok(())
}
pub fn tanh_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
scalar_impl::apply_unary(data, out, f32::tanh);
Ok(())
}
pub fn abs_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::abs_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::abs);
}
Ok(())
}
pub fn floor_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::floor_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::floor);
}
Ok(())
}
pub fn ceil_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::ceil_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::ceil);
}
Ok(())
}
pub fn round_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
#[cfg(target_arch = "aarch64")]
{
unsafe {
neon_impl::round_f32(data, out);
}
}
#[cfg(not(target_arch = "aarch64"))]
{
scalar_impl::apply_unary(data, out, f32::round);
}
Ok(())
}
pub fn fract_f32(data: &[f32], out: &mut [f32]) -> Result<()> {
validate_unary(data, out)?;
floor_f32(data, out)?;
for i in 0..data.len() {
out[i] = data[i] - out[i];
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use std::f32::consts::PI;
#[test]
fn test_sqrt_f32() {
let data = vec![1.0, 4.0, 9.0, 16.0, 25.0];
let mut out = vec![0.0; 5];
sqrt_f32(&data, &mut out).expect("sqrt_f32 failed");
assert_relative_eq!(out[0], 1.0);
assert_relative_eq!(out[1], 2.0);
assert_relative_eq!(out[2], 3.0);
assert_relative_eq!(out[3], 4.0);
assert_relative_eq!(out[4], 5.0);
}
#[test]
fn test_sqrt_large() {
let data = vec![4.0; 1000];
let mut out = vec![0.0; 1000];
sqrt_f32(&data, &mut out).expect("sqrt_f32 large failed");
for &val in &out {
assert_relative_eq!(val, 2.0);
}
}
#[test]
fn test_exp_ln() {
let data = vec![0.0, 1.0, 2.0, 3.0];
let mut exp_out = vec![0.0; 4];
let mut ln_out = vec![0.0; 4];
exp_f32(&data, &mut exp_out).expect("exp_f32 failed");
ln_f32(&exp_out, &mut ln_out).expect("ln_f32 failed");
for i in 0..4 {
assert_relative_eq!(ln_out[i], data[i], epsilon = 1e-5);
}
}
#[test]
fn test_exp_large() {
let data: Vec<f32> = (0..100).map(|i| i as f32 * 0.1).collect();
let mut out = vec![0.0; 100];
exp_f32(&data, &mut out).expect("exp_f32 large failed");
for i in 0..100 {
assert_relative_eq!(out[i], data[i].exp(), epsilon = 1e-4);
}
}
#[test]
fn test_ln_large() {
let data: Vec<f32> = (1..=100).map(|i| i as f32).collect();
let mut out = vec![0.0; 100];
ln_f32(&data, &mut out).expect("ln_f32 large failed");
for i in 0..100 {
assert_relative_eq!(out[i], data[i].ln(), epsilon = 1e-4);
}
}
#[test]
fn test_log10() {
let data = vec![1.0, 10.0, 100.0, 1000.0];
let mut out = vec![0.0; 4];
log10_f32(&data, &mut out).expect("log10_f32 failed");
assert_relative_eq!(out[0], 0.0, epsilon = 1e-5);
assert_relative_eq!(out[1], 1.0, epsilon = 1e-5);
assert_relative_eq!(out[2], 2.0, epsilon = 1e-4);
assert_relative_eq!(out[3], 3.0, epsilon = 1e-4);
}
#[test]
fn test_log2() {
let data = vec![1.0, 2.0, 4.0, 8.0, 16.0];
let mut out = vec![0.0; 5];
log2_f32(&data, &mut out).expect("log2_f32 failed");
assert_relative_eq!(out[0], 0.0, epsilon = 1e-5);
assert_relative_eq!(out[1], 1.0, epsilon = 1e-4);
assert_relative_eq!(out[2], 2.0, epsilon = 1e-4);
assert_relative_eq!(out[3], 3.0, epsilon = 1e-4);
assert_relative_eq!(out[4], 4.0, epsilon = 1e-4);
}
#[test]
fn test_pow() {
let base = vec![2.0, 3.0, 4.0, 5.0];
let exp = vec![2.0, 2.0, 2.0, 2.0];
let mut out = vec![0.0; 4];
pow_f32(&base, &exp, &mut out).expect("pow_f32 failed");
assert_relative_eq!(out[0], 4.0);
assert_relative_eq!(out[1], 9.0);
assert_relative_eq!(out[2], 16.0);
assert_relative_eq!(out[3], 25.0);
}
#[test]
fn test_sin_cos() {
let data = vec![0.0, PI / 6.0, PI / 4.0, PI / 3.0, PI / 2.0];
let mut sin_out = vec![0.0; 5];
let mut cos_out = vec![0.0; 5];
sin_f32(&data, &mut sin_out).expect("sin_f32 failed");
cos_f32(&data, &mut cos_out).expect("cos_f32 failed");
assert_relative_eq!(sin_out[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(sin_out[4], 1.0, epsilon = 1e-6);
assert_relative_eq!(cos_out[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(cos_out[4], 0.0, epsilon = 1e-6);
for i in 0..5 {
let sum = sin_out[i] * sin_out[i] + cos_out[i] * cos_out[i];
assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
}
}
#[test]
fn test_tan() {
let data = vec![0.0, PI / 4.0];
let mut out = vec![0.0; 2];
tan_f32(&data, &mut out).expect("tan_f32 failed");
assert_relative_eq!(out[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(out[1], 1.0, epsilon = 1e-6);
}
#[test]
fn test_asin_acos() {
let data = vec![0.0, 0.5, 1.0];
let mut asin_out = vec![0.0; 3];
let mut acos_out = vec![0.0; 3];
asin_f32(&data, &mut asin_out).expect("asin_f32 failed");
acos_f32(&data, &mut acos_out).expect("acos_f32 failed");
assert_relative_eq!(asin_out[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(asin_out[2], PI / 2.0, epsilon = 1e-6);
assert_relative_eq!(acos_out[0], PI / 2.0, epsilon = 1e-6);
assert_relative_eq!(acos_out[2], 0.0, epsilon = 1e-6);
}
#[test]
fn test_atan2() {
let y = vec![0.0, 1.0, 0.0, -1.0];
let x = vec![1.0, 0.0, -1.0, 0.0];
let mut out = vec![0.0; 4];
atan2_f32(&y, &x, &mut out).expect("atan2_f32 failed");
assert_relative_eq!(out[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(out[1], PI / 2.0, epsilon = 1e-6);
assert_relative_eq!(out[2], PI, epsilon = 1e-6);
assert_relative_eq!(out[3], -PI / 2.0, epsilon = 1e-6);
}
#[test]
fn test_hyperbolic() {
let data = vec![0.0, 1.0];
let mut sinh_out = vec![0.0; 2];
let mut cosh_out = vec![0.0; 2];
let mut tanh_out = vec![0.0; 2];
sinh_f32(&data, &mut sinh_out).expect("sinh_f32 failed");
cosh_f32(&data, &mut cosh_out).expect("cosh_f32 failed");
tanh_f32(&data, &mut tanh_out).expect("tanh_f32 failed");
assert_relative_eq!(sinh_out[0], 0.0, epsilon = 1e-6);
assert_relative_eq!(cosh_out[0], 1.0, epsilon = 1e-6);
assert_relative_eq!(tanh_out[0], 0.0, epsilon = 1e-6);
}
#[test]
fn test_abs() {
let data = vec![-1.0, -2.0, 3.0, -4.0, 5.0];
let mut out = vec![0.0; 5];
abs_f32(&data, &mut out).expect("abs_f32 failed");
assert_relative_eq!(out[0], 1.0);
assert_relative_eq!(out[1], 2.0);
assert_relative_eq!(out[2], 3.0);
assert_relative_eq!(out[3], 4.0);
assert_relative_eq!(out[4], 5.0);
}
#[test]
fn test_abs_large() {
let data: Vec<f32> = (-500..500).map(|i| i as f32).collect();
let mut out = vec![0.0; 1000];
abs_f32(&data, &mut out).expect("abs_f32 large failed");
for i in 0..1000 {
assert_relative_eq!(out[i], (data[i]).abs());
}
}
#[test]
fn test_floor_ceil_round() {
let data = vec![1.2, 1.7, -1.2, -1.7];
let mut floor_out = vec![0.0; 4];
let mut ceil_out = vec![0.0; 4];
let mut round_out = vec![0.0; 4];
floor_f32(&data, &mut floor_out).expect("floor_f32 failed");
ceil_f32(&data, &mut ceil_out).expect("ceil_f32 failed");
round_f32(&data, &mut round_out).expect("round_f32 failed");
assert_relative_eq!(floor_out[0], 1.0);
assert_relative_eq!(floor_out[1], 1.0);
assert_relative_eq!(floor_out[2], -2.0);
assert_relative_eq!(floor_out[3], -2.0);
assert_relative_eq!(ceil_out[0], 2.0);
assert_relative_eq!(ceil_out[1], 2.0);
assert_relative_eq!(ceil_out[2], -1.0);
assert_relative_eq!(ceil_out[3], -1.0);
assert_relative_eq!(round_out[0], 1.0);
assert_relative_eq!(round_out[1], 2.0);
assert_relative_eq!(round_out[2], -1.0);
assert_relative_eq!(round_out[3], -2.0);
}
#[test]
fn test_fract() {
let data = vec![1.3, 2.7, -1.3, -2.7];
let mut out = vec![0.0; 4];
fract_f32(&data, &mut out).expect("fract_f32 failed");
assert_relative_eq!(out[0], 0.3, epsilon = 1e-6);
assert_relative_eq!(out[1], 0.7, epsilon = 1e-6);
assert_relative_eq!(out[2], 0.7, epsilon = 1e-6);
assert_relative_eq!(out[3], 0.3, epsilon = 1e-6);
}
#[test]
fn test_length_mismatch() {
let data = vec![1.0; 10];
let mut out = vec![0.0; 5];
assert!(sqrt_f32(&data, &mut out).is_err());
}
}