use std::iter::Sum;
use std::sync::Arc;
use arrow_array::Float32Array;
use half::{bf16, f16};
use num_traits::real::Real;
/// Generic dot product of two slices.
///
/// Elements are paired positionally with `zip`, so if the slices differ in
/// length the surplus elements of the longer one are ignored. The pairwise
/// products are accumulated through `T`'s own [`Sum`] implementation.
#[inline]
pub fn dot<T: Real + Sum>(from: &[T], to: &[T]) -> T {
    let products = from.iter().zip(to).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Dot product between two values of the same type.
///
/// `Self` is not required to be `Sized`, so the trait can be implemented
/// directly on slice types.
pub trait Dot {
    /// Type of the value produced by [`Dot::dot`].
    type Output;

    /// Returns the dot product of `self` and `other`.
    fn dot(&self, other: &Self) -> Self::Output;
}
/// `bf16` slices use the generic scalar [`dot`] implementation.
impl Dot for [bf16] {
    type Output = bf16;

    /// Delegates to the free function [`dot`] instantiated for `bf16`.
    fn dot(&self, other: &Self) -> Self::Output {
        dot::<bf16>(self, other)
    }
}
/// `f16` slices use the generic scalar [`dot`] implementation.
impl Dot for [f16] {
    type Output = f16;

    /// Delegates to the free function [`dot`] instantiated for `f16`.
    fn dot(&self, other: &Self) -> Self::Output {
        dot::<f16>(self, other)
    }
}
/// `f32` slices: SIMD kernel on x86_64 when available, scalar [`dot`] otherwise.
impl Dot for [f32] {
    type Output = f32;

    /// Returns the dot product of `self` and `other`.
    ///
    /// On x86_64 targets the CPU is probed at runtime; when `fma` is reported
    /// the call is routed to `x86_64::avx::dot_f32`. Every other case falls
    /// through to the generic scalar [`dot`].
    fn dot(&self, other: &Self) -> Self::Output {
        #[cfg(target_arch = "x86_64")]
        {
            let has_fma = is_x86_feature_detected!("fma");
            if has_fma {
                return x86_64::avx::dot_f32(self, other);
            }
        }

        // Portable fallback (also the only path on non-x86_64 targets).
        dot::<f32>(self, other)
    }
}
/// `f64` slices use the generic scalar [`dot`] implementation.
impl Dot for [f64] {
    type Output = f64;

    /// Delegates to the free function [`dot`] instantiated for `f64`.
    fn dot(&self, other: &Self) -> Self::Output {
        dot::<f64>(self, other)
    }
}
/// Computes the dot distance from one vector to each vector of a flat batch.
///
/// `to` is read as consecutive vectors of `dimension` elements each; the
/// result holds one [`dot_distance`] value per vector, in order.
///
/// In debug builds this asserts that `from.len() == dimension` and that
/// `to.len()` is a multiple of `dimension`.
pub fn dot_distance_batch(from: &[f32], to: &[f32], dimension: usize) -> Arc<Float32Array> {
    debug_assert_eq!(from.len(), dimension);
    debug_assert_eq!(to.len() % dimension, 0);

    let distances = to
        .chunks_exact(dimension)
        .map(|candidate| Some(dot_distance(from, candidate)));

    // SAFETY: `from_trusted_len_iter` needs an iterator whose `size_hint`
    // upper bound is its true length; `ChunksExact` adapted by `map` reports
    // an exact length.
    let array = unsafe { Float32Array::from_trusted_len_iter(distances) };
    Arc::new(array)
}
#[inline]
pub fn dot_distance(from: &[f32], to: &[f32]) -> f32 {
-from.dot(to)
}
/// x86_64-specific kernels.
#[cfg(target_arch = "x86_64")]
mod x86_64 {
    /// Kernels built on 256-bit AVX registers with fused multiply-add.
    pub mod avx {
        // Brings `add_f32_register` into scope (it is not part of `std::arch`).
        use crate::linalg::x86_64::avx::*;
        use std::arch::x86_64::*;

        /// Dot product of two `f32` slices, processed eight lanes at a time.
        ///
        /// Only the shared prefix of `x` and `y` is used: if the lengths
        /// differ, the surplus elements of the longer slice are ignored,
        /// matching the behavior of the generic scalar `dot`. Elements that do
        /// not fill a whole 8-lane chunk are multiplied and summed with scalar
        /// code and added to the vector result.
        ///
        /// This function uses AVX and FMA intrinsics; callers are expected to
        /// invoke it only after runtime detection has confirmed CPU support.
        // NOTE(review): there is no `#[target_feature]` on this function, so the
        // intrinsics presumably only inline when the crate is compiled with
        // AVX/FMA enabled - confirm against the build flags.
        #[inline]
        pub fn dot_f32(x: &[f32], y: &[f32]) -> f32 {
            // Clamp both inputs to the common length. Previously the tail start
            // was derived from `x` alone, so `y[len..]` panicked whenever `y`
            // was shorter than the chunked part of `x`.
            let shared = x.len().min(y.len());
            let (x, y) = (&x[..shared], &y[..shared]);

            let x_chunks = x.chunks_exact(8);
            let y_chunks = y.chunks_exact(8);

            // Scalar tail: the (< 8) elements left over after the full chunks.
            let tail = x_chunks
                .remainder()
                .iter()
                .zip(y_chunks.remainder())
                .map(|(a, b)| a * b)
                .sum::<f32>();

            // SAFETY: every slice yielded by `chunks_exact(8)` holds exactly
            // eight `f32`s, so each unaligned 256-bit load stays in bounds.
            let mut sum = unsafe {
                let mut sums = _mm256_setzero_ps();
                for (a, b) in x_chunks.zip(y_chunks) {
                    let lhs = _mm256_loadu_ps(a.as_ptr());
                    let rhs = _mm256_loadu_ps(b.as_ptr());
                    sums = _mm256_fmadd_ps(lhs, rhs, sums);
                }
                // Presumably a horizontal add of the eight lanes - defined in
                // `crate::linalg`, not visible here.
                add_f32_register(sums)
            };
            sum += tail;
            sum
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use num_traits::FromPrimitive;

    /// The `Dot` trait implementations must agree with the generic `dot` helper.
    #[test]
    fn test_dot() {
        // f32: 20 elements.
        {
            let lhs = (0..20).map(|i| i as f32).collect::<Vec<f32>>();
            let rhs = (100..120).map(|i| i as f32).collect::<Vec<f32>>();
            assert_eq!(lhs.dot(&rhs), dot(&lhs, &rhs));
        }

        // f16.
        {
            let lhs = (0..20)
                .map(|i| f16::from_i32(i).unwrap())
                .collect::<Vec<f16>>();
            let rhs = (100..120)
                .map(|i| f16::from_i32(i).unwrap())
                .collect::<Vec<f16>>();
            assert_eq!(lhs.dot(&rhs), dot(&lhs, &rhs));
        }

        // f64.
        {
            let lhs = (20..40)
                .map(|i| f64::from_i32(i).unwrap())
                .collect::<Vec<f64>>();
            let rhs = (120..140)
                .map(|i| f64::from_i32(i).unwrap())
                .collect::<Vec<f64>>();
            assert_eq!(lhs.dot(&rhs), dot(&lhs, &rhs));
        }
    }
}