#[cfg(feature = "no-std")]
use alloc::vec;
#[cfg(feature = "no-std")]
use alloc::vec::Vec;
#[cfg(not(feature = "no-std"))]
use std::vec::Vec;
use crate::simd_types::*;
use crate::simd_utils::*;
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
assert_eq!(a.len(), b.len(), "Vectors must have the same length");
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
if crate::simd_feature_detected!("avx512f") {
return unsafe { dot_product_avx512(a, b) };
} else if crate::simd_feature_detected!("avx2") {
return unsafe { dot_product_avx2(a, b) };
} else if crate::simd_feature_detected!("sse2") {
return unsafe { dot_product_sse2(a, b) };
}
}
#[cfg(target_arch = "aarch64")]
{
return unsafe { dot_product_neon(a, b) };
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86", target_arch = "x86_64")))]
{
dot_product_scalar(a, b)
}
}
/// Returns the Euclidean (L2) norm of `x`: the square root of the dot
/// product of `x` with itself.
pub fn norm_l2(x: &[f32]) -> f32 {
    let squared_length = dot_product(x, x);
    squared_length.sqrt()
}
/// Returns the norm of `vector`.
///
/// This is a convenience alias that forwards directly to [`norm_l2`].
pub fn norm(vector: &[f32]) -> f32 {
    norm_l2(vector)
}
/// Multiplies every element of `vector` by `scalar`, in place.
///
/// Dispatch order:
/// * x86 / x86_64: the first kernel whose CPU feature is reported by
///   `simd_feature_detected!` (AVX-512F, then AVX2, then SSE2); if none is
///   reported, the scalar loop is used.
/// * aarch64: the NEON kernel, unconditionally.
/// * any other architecture: the scalar loop.
pub fn scale(vector: &mut [f32], scalar: f32) {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // SAFETY (all three calls): a kernel is only entered after its CPU
        // feature was detected.
        // NOTE(review): assumes each kernel needs no CPU feature beyond the
        // one checked here - confirm against the kernel definitions.
        if crate::simd_feature_detected!("avx512f") {
            return unsafe { scale_avx512(vector, scalar) };
        } else if crate::simd_feature_detected!("avx2") {
            return unsafe { scale_avx2(vector, scalar) };
        } else if crate::simd_feature_detected!("sse2") {
            return unsafe { scale_sse2(vector, scalar) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: NOTE(review): assumes NEON is available on every aarch64
        // target this crate is built for - confirm.
        unsafe { scale_neon(vector, scalar) };
    }

    // Scalar fallback. It is compiled out on aarch64, where the NEON kernel
    // always runs, so it cannot end up as unreachable code there.
    #[cfg(not(target_arch = "aarch64"))]
    {
        for element in vector.iter_mut() {
            *element *= scalar;
        }
    }
}
/// Computes `a[i] = a[i] * b[i] + c[i]` for every index, in place.
///
/// Dispatch order:
/// * x86 / x86_64: with AVX2 detected, the AVX2+FMA kernel when FMA is also
///   detected, otherwise the plain AVX2 kernel; without AVX2, the SSE2
///   kernel when SSE2 is detected; otherwise the scalar loop.
/// * aarch64: the NEON kernel, unconditionally.
/// * any other architecture: the scalar loop.
///
/// NOTE(review): the FMA-gated kernel presumably rounds once per element
/// while the other paths round twice, so results may differ in the last bit
/// between machines - confirm against the kernel definitions.
///
/// # Panics
///
/// Panics with "Vectors must have the same length" if `b` or `c` differs in
/// length from `a`.
pub fn fma(a: &mut [f32], b: &[f32], c: &[f32]) {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    assert_eq!(a.len(), c.len(), "Vectors must have the same length");

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // SAFETY (all three calls): a kernel is only entered after the CPU
        // features it is named after were detected, and all three slices
        // have the same length (asserted above).
        if crate::simd_feature_detected!("avx2") {
            // FMA alone is not sufficient for `fma_avx2_fma`: some CPUs
            // provide FMA3 without AVX2, so both features are required.
            if crate::simd_feature_detected!("fma") {
                return unsafe { fma_avx2_fma(a, b, c) };
            }
            return unsafe { fma_avx2(a, b, c) };
        } else if crate::simd_feature_detected!("sse2") {
            return unsafe { fma_sse2(a, b, c) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: all three slices have the same length (asserted above).
        // NOTE(review): assumes NEON is available on every aarch64 target
        // this crate is built for - confirm.
        unsafe { fma_neon(a, b, c) };
    }

    // Scalar fallback. It is compiled out on aarch64, where the NEON kernel
    // always runs, so it cannot end up as unreachable code there.
    #[cfg(not(target_arch = "aarch64"))]
    {
        for ((acc, &factor), &addend) in a.iter_mut().zip(b).zip(c) {
            *acc = *acc * factor + addend;
        }
    }
}
/// Returns a new vector holding the element-wise sum `a[i] + b[i]`.
///
/// Dispatch order:
/// * x86 / x86_64: the AVX2 kernel when AVX2 is detected, else the SSE2
///   kernel when SSE2 is detected, else the scalar loop.
/// * aarch64: the NEON kernel, unconditionally.
/// * any other architecture: the scalar loop.
///
/// # Panics
///
/// Panics with "Vectors must have the same length" if `a.len() != b.len()`.
pub fn add_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    let mut result = vec![0.0; a.len()];

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // SAFETY (both calls): a kernel is only entered after its CPU
        // feature was detected; inputs and output all have `a.len()`
        // elements.
        if crate::simd_feature_detected!("avx2") {
            unsafe { add_avx2(a, b, &mut result) };
            return result;
        } else if crate::simd_feature_detected!("sse2") {
            unsafe { add_sse2(a, b, &mut result) };
            return result;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: inputs and output all have `a.len()` elements.
        // NOTE(review): assumes NEON is available on every aarch64 target
        // this crate is built for - confirm.
        unsafe { add_neon(a, b, &mut result) };
    }

    // Scalar fallback. It is compiled out on aarch64, where the NEON kernel
    // always runs, so it cannot end up as unreachable code there.
    #[cfg(not(target_arch = "aarch64"))]
    {
        for ((out, &lhs), &rhs) in result.iter_mut().zip(a).zip(b) {
            *out = lhs + rhs;
        }
    }

    result
}
/// Returns a new vector holding `a[i] * b[i] + c[i]` for every index.
///
/// The inputs are left untouched: `a` is copied and the copy is updated in
/// place by [`fma`].
///
/// # Panics
///
/// Panics with "Vectors must have the same length" if `b` or `c` differs in
/// length from `a`.
pub fn fused_multiply_add(a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    // Validate before allocating the output copy.
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    assert_eq!(a.len(), c.len(), "Vectors must have the same length");

    let mut accumulator = Vec::from(a);
    fma(accumulator.as_mut_slice(), b, c);
    accumulator
}
/// Returns a new vector holding the element-wise difference `a[i] - b[i]`.
///
/// Dispatch order:
/// * x86 / x86_64: the AVX2 kernel when AVX2 is detected, else the SSE2
///   kernel when SSE2 is detected, else the scalar loop.
/// * aarch64: the NEON kernel, unconditionally.
/// * any other architecture: the scalar loop.
///
/// # Panics
///
/// Panics with "Vectors must have the same length" if `a.len() != b.len()`.
pub fn subtract_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    let mut result = vec![0.0; a.len()];

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // SAFETY (both calls): a kernel is only entered after its CPU
        // feature was detected; inputs and output all have `a.len()`
        // elements.
        if crate::simd_feature_detected!("avx2") {
            unsafe { subtract_avx2(a, b, &mut result) };
            return result;
        } else if crate::simd_feature_detected!("sse2") {
            unsafe { subtract_sse2(a, b, &mut result) };
            return result;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: inputs and output all have `a.len()` elements.
        // NOTE(review): assumes NEON is available on every aarch64 target
        // this crate is built for - confirm.
        unsafe { subtract_neon(a, b, &mut result) };
    }

    // Scalar fallback. It is compiled out on aarch64, where the NEON kernel
    // always runs, so it cannot end up as unreachable code there.
    #[cfg(not(target_arch = "aarch64"))]
    {
        for ((out, &minuend), &subtrahend) in result.iter_mut().zip(a).zip(b) {
            *out = minuend - subtrahend;
        }
    }

    result
}
/// Returns a new vector holding the element-wise product `a[i] * b[i]`.
///
/// Dispatch order:
/// * x86 / x86_64: the AVX2 kernel when AVX2 is detected, else the SSE2
///   kernel when SSE2 is detected, else the scalar loop.
/// * aarch64: the NEON kernel, unconditionally.
/// * any other architecture: the scalar loop.
///
/// # Panics
///
/// Panics with "Vectors must have the same length" if `a.len() != b.len()`.
pub fn multiply_vectors(a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
    let mut result = vec![0.0; a.len()];

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        // SAFETY (both calls): a kernel is only entered after its CPU
        // feature was detected; inputs and output all have `a.len()`
        // elements.
        if crate::simd_feature_detected!("avx2") {
            unsafe { multiply_avx2(a, b, &mut result) };
            return result;
        } else if crate::simd_feature_detected!("sse2") {
            unsafe { multiply_sse2(a, b, &mut result) };
            return result;
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        // SAFETY: inputs and output all have `a.len()` elements.
        // NOTE(review): assumes NEON is available on every aarch64 target
        // this crate is built for - confirm.
        unsafe { multiply_neon(a, b, &mut result) };
    }

    // Scalar fallback. It is compiled out on aarch64, where the NEON kernel
    // always runs, so it cannot end up as unreachable code there.
    #[cfg(not(target_arch = "aarch64"))]
    {
        for ((out, &lhs), &rhs) in result.iter_mut().zip(a).zip(b) {
            *out = lhs * rhs;
        }
    }

    result
}
/// Returns a copy of `vector` with every element multiplied by `scalar`.
///
/// The input is left untouched; the copy is scaled in place by [`scale`].
pub fn scale_vector(vector: &[f32], scalar: f32) -> Vec<f32> {
    let mut scaled = Vec::from(vector);
    scale(scaled.as_mut_slice(), scalar);
    scaled
}
/// Portable dot product: sums the pairwise products of `a` and `b` in index
/// order.
///
/// Pairs are formed with `zip`, so if the lengths differ the extra trailing
/// elements of the longer slice are ignored.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let pairwise_products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    pairwise_products.sum()
}
/// Unit tests for the public vector operations.
///
/// Besides the small hand-written cases, each operation is swept over a range
/// of lengths so that inputs shorter than, equal to, and longer than 4, 8 and
/// 16 `f32` lanes are all covered. All values are small integers, so every
/// expected result is exactly representable in `f32` regardless of the order
/// in which a kernel accumulates.
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Lengths around 4, 8 and 16 lanes, plus empty and larger inputs.
    const LENGTHS: [usize; 12] = [0, 1, 3, 4, 5, 8, 9, 16, 17, 33, 64, 100];

    /// Builds `[start, start + 1, start + 2, ...]` with `len` elements.
    fn ramp(len: usize, start: f32) -> Vec<f32> {
        // `i` stays below 100, so the cast to `f32` is exact.
        (0..len).map(|i| start + i as f32).collect()
    }

    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0, 8.0];
        let result = dot_product(&a, &b);
        assert!((result - 70.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_l2() {
        let x = vec![3.0, 4.0];
        let result = norm_l2(&x);
        assert!((result - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_norm_matches_norm_l2() {
        let x = vec![3.0, 4.0];
        assert_eq!(norm(&x), norm_l2(&x));
    }

    #[test]
    fn test_scale() {
        let mut x = vec![1.0, 2.0, 3.0, 4.0];
        scale(&mut x, 2.0);
        assert_eq!(x, vec![2.0, 4.0, 6.0, 8.0]);
    }

    #[test]
    fn test_fma() {
        let mut a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = vec![7.0, 8.0, 9.0];
        fma(&mut a, &b, &c);
        assert_eq!(a, vec![11.0, 18.0, 27.0]);
    }

    #[test]
    fn test_add_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let result = add_vectors(&a, &b);
        assert_eq!(result, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_subtract_vectors() {
        let a = vec![5.0, 7.0, 9.0];
        let b = vec![1.0, 2.0, 3.0];
        let result = subtract_vectors(&a, &b);
        assert_eq!(result, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_multiply_vectors() {
        let a = vec![2.0, 3.0, 4.0];
        let b = vec![5.0, 6.0, 7.0];
        let result = multiply_vectors(&a, &b);
        assert_eq!(result, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_fused_multiply_add() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let c = vec![7.0, 8.0, 9.0];
        let result = fused_multiply_add(&a, &b, &c);
        assert_eq!(result, vec![11.0, 18.0, 27.0]);
    }

    #[test]
    fn test_scale_vector() {
        let x = vec![1.0, 2.0, 3.0];
        let result = scale_vector(&x, 3.0);
        assert_eq!(result, vec![3.0, 6.0, 9.0]);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_dot_product_mismatched_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        dot_product(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_add_vectors_mismatched_lengths() {
        let a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0, 3.0];
        add_vectors(&a, &b);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_fma_mismatched_lengths() {
        let mut a = vec![1.0, 2.0];
        let b = vec![1.0, 2.0];
        let c = vec![1.0, 2.0, 3.0];
        fma(&mut a, &b, &c);
    }

    #[test]
    fn test_dot_product_across_lengths() {
        for &len in LENGTHS.iter() {
            let a = ramp(len, 0.0);
            let b = ramp(len, 1.0);
            // Largest sum is 333_300, far below 2^24, so it is exact in f32.
            let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
            assert_eq!(dot_product(&a, &b), expected, "length {}", len);
        }
    }

    #[test]
    fn test_scale_across_lengths() {
        for &len in LENGTHS.iter() {
            let mut x = ramp(len, 0.0);
            let expected: Vec<f32> = x.iter().map(|v| v * 2.0).collect();
            scale(&mut x, 2.0);
            assert_eq!(x, expected, "length {}", len);
            assert_eq!(scale_vector(&ramp(len, 0.0), 2.0), expected, "length {}", len);
        }
    }

    #[test]
    fn test_fma_across_lengths() {
        for &len in LENGTHS.iter() {
            let a = ramp(len, 0.0);
            let b = ramp(len, 1.0);
            let c = ramp(len, 2.0);
            // Products and sums are small integers, so fused and unfused
            // evaluation give the same value.
            let expected: Vec<f32> = (0..len).map(|i| a[i] * b[i] + c[i]).collect();

            let mut in_place = a.clone();
            fma(&mut in_place, &b, &c);
            assert_eq!(in_place, expected, "length {}", len);
            assert_eq!(fused_multiply_add(&a, &b, &c), expected, "length {}", len);
        }
    }

    #[test]
    fn test_elementwise_ops_across_lengths() {
        for &len in LENGTHS.iter() {
            let a = ramp(len, 3.0);
            let b = ramp(len, 1.0);

            let sum: Vec<f32> = (0..len).map(|i| a[i] + b[i]).collect();
            let difference: Vec<f32> = (0..len).map(|i| a[i] - b[i]).collect();
            let product: Vec<f32> = (0..len).map(|i| a[i] * b[i]).collect();

            assert_eq!(add_vectors(&a, &b), sum, "length {}", len);
            assert_eq!(subtract_vectors(&a, &b), difference, "length {}", len);
            assert_eq!(multiply_vectors(&a, &b), product, "length {}", len);
        }
    }
}