use pulp::{Arch, Simd, WithSimd};
use serde::{Deserialize, Serialize};
/// Selects which implementation the numeric kernels in this file dispatch to.
///
/// Defaults to [`ExecutionBackend::Simd`]. No serde renaming is applied, so the
/// serialized form is the variant identifier (e.g. `"Simd"`), whereas
/// [`ExecutionBackend::name`] returns a lowercase label.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum ExecutionBackend {
    /// Straightforward iterator loops with no explicit vectorization.
    Scalar,
    /// Runtime-dispatched SIMD kernels built on `pulp`.
    #[default]
    Simd,
    /// GPU backend, present only with the `gpu` feature. The kernels in this
    /// file route it to the SIMD implementation.
    #[cfg(feature = "gpu")]
    Wgpu,
}
impl ExecutionBackend {
    /// Returns the lowercase label for this backend (`"scalar"`, `"simd"`, or
    /// `"wgpu"` when the `gpu` feature is enabled).
    pub fn name(self) -> &'static str {
        match self {
            Self::Scalar => "scalar",
            Self::Simd => "simd",
            #[cfg(feature = "gpu")]
            Self::Wgpu => "wgpu",
        }
    }

    /// Reports whether this backend targets a GPU.
    ///
    /// Always `false` without the `gpu` feature, because the `Wgpu` variant
    /// does not exist in that configuration.
    pub fn is_gpu(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Scalar | Self::Simd => false,
            #[cfg(feature = "gpu")]
            Self::Wgpu => true,
        }
    }
}
pub fn dot(backend: ExecutionBackend, lhs: &[f64], rhs: &[f64]) -> f64 {
assert_eq!(
lhs.len(),
rhs.len(),
"dot: length mismatch ({} vs {})",
lhs.len(),
rhs.len()
);
match backend {
ExecutionBackend::Scalar => dot_scalar(lhs, rhs),
ExecutionBackend::Simd => dot_simd(lhs, rhs),
#[cfg(feature = "gpu")]
ExecutionBackend::Wgpu => dot_simd(lhs, rhs),
}
}
pub fn squared_l2_norm(backend: ExecutionBackend, values: &[f64]) -> f64 {
match backend {
ExecutionBackend::Scalar => values.iter().map(|value| value * value).sum(),
ExecutionBackend::Simd => dot_simd(values, values),
#[cfg(feature = "gpu")]
ExecutionBackend::Wgpu => dot_simd(values, values),
}
}
/// Computes the sum of squared differences `Σ (lhs[i] - rhs[i])²` with the
/// requested backend.
///
/// The `Wgpu` backend (feature `gpu`) uses the SIMD kernel.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths.
pub fn sum_squared_error(backend: ExecutionBackend, lhs: &[f64], rhs: &[f64]) -> f64 {
    let (lhs_len, rhs_len) = (lhs.len(), rhs.len());
    assert_eq!(
        lhs_len, rhs_len,
        "sum_squared_error: length mismatch ({} vs {})",
        lhs_len, rhs_len
    );
    match backend {
        ExecutionBackend::Simd => sum_squared_error_simd(lhs, rhs),
        #[cfg(feature = "gpu")]
        ExecutionBackend::Wgpu => sum_squared_error_simd(lhs, rhs),
        ExecutionBackend::Scalar => lhs
            .iter()
            .zip(rhs)
            .map(|(a, b)| a - b)
            .map(|diff| diff * diff)
            .sum(),
    }
}
/// Accumulates `output[i] += weight * values[i]` in place with the requested backend.
///
/// The `Wgpu` backend (feature `gpu`) uses the SIMD kernel.
///
/// # Panics
///
/// Panics if `output` and `values` have different lengths.
pub fn weighted_sum_in_place(
    backend: ExecutionBackend,
    output: &mut [f64],
    weight: f64,
    values: &[f64],
) {
    let (output_len, values_len) = (output.len(), values.len());
    assert_eq!(
        output_len, values_len,
        "weighted_sum_in_place: length mismatch ({} vs {})",
        output_len, values_len
    );
    match backend {
        ExecutionBackend::Scalar => output
            .iter_mut()
            .zip(values)
            .for_each(|(slot, &value)| *slot += weight * value),
        ExecutionBackend::Simd => weighted_sum_in_place_simd(output, weight, values),
        #[cfg(feature = "gpu")]
        ExecutionBackend::Wgpu => weighted_sum_in_place_simd(output, weight, values),
    }
}
/// Plain iterator dot product over the common prefix of `lhs` and `rhs`.
///
/// `zip` stops at the shorter slice; callers in this file pass equal-length slices.
fn dot_scalar(lhs: &[f64], rhs: &[f64]) -> f64 {
    let products = lhs.iter().zip(rhs).map(|(&a, &b)| a * b);
    products.sum()
}
/// SIMD dot product dispatched through `pulp` to the best instruction set available
/// at runtime.
///
/// Callers must pass equal-length slices. The vector part accumulates into four
/// independent registers, which are combined as `(s0 + s1) + (s2 + s3)` and
/// horizontally reduced. Leftover elements that do not fill a vector are added
/// sequentially afterwards.
fn dot_simd(lhs: &[f64], rhs: &[f64]) -> f64 {
    /// `pulp` kernel holding the two operand slices.
    struct DotKernel<'a> {
        lhs: &'a [f64],
        rhs: &'a [f64],
    }

    impl WithSimd for DotKernel<'_> {
        type Output = f64;

        // Kept `inline(always)` so the body is compiled inside `Arch::dispatch`'s
        // feature-enabled context. Closures are avoided here for the same reason.
        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let (xs, xs_rest) = S::as_simd_f64s(self.lhs);
            let (ys, ys_rest) = S::as_simd_f64s(self.rhs);
            // Groups of four vectors feed the four accumulators; `*_single` holds
            // the up-to-three vectors that do not fill a group.
            let (xs_quads, xs_single) = pulp::as_arrays::<4, _>(xs);
            let (ys_quads, ys_single) = pulp::as_arrays::<4, _>(ys);

            let zero = simd.splat_f64s(0.0);
            let (mut s0, mut s1, mut s2, mut s3) = (zero, zero, zero, zero);
            for (x, y) in xs_quads.iter().zip(ys_quads) {
                s0 = simd.mul_add_f64s(x[0], y[0], s0);
                s1 = simd.mul_add_f64s(x[1], y[1], s1);
                s2 = simd.mul_add_f64s(x[2], y[2], s2);
                s3 = simd.mul_add_f64s(x[3], y[3], s3);
            }
            for (x, y) in xs_single.iter().zip(ys_single) {
                s0 = simd.mul_add_f64s(*x, *y, s0);
            }

            let combined = simd.add_f64s(simd.add_f64s(s0, s1), simd.add_f64s(s2, s3));
            let mut total = simd.reduce_sum_f64s(combined);
            for (x, y) in xs_rest.iter().zip(ys_rest) {
                total += x * y;
            }
            total
        }
    }

    Arch::new().dispatch(DotKernel { lhs, rhs })
}
/// SIMD sum of squared differences dispatched through `pulp`.
///
/// Callers must pass equal-length slices. It uses the same layout as the dot kernel:
/// four vector accumulators, a single-vector remainder folded into the first,
/// combination as `(s0 + s1) + (s2 + s3)`, a horizontal reduction, then a scalar tail.
fn sum_squared_error_simd(lhs: &[f64], rhs: &[f64]) -> f64 {
    /// `pulp` kernel holding the two operand slices.
    struct SquaredErrorKernel<'a> {
        lhs: &'a [f64],
        rhs: &'a [f64],
    }

    impl WithSimd for SquaredErrorKernel<'_> {
        type Output = f64;

        // Kept `inline(always)` so the body is compiled inside `Arch::dispatch`'s
        // feature-enabled context. Closures are avoided here for the same reason.
        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let (xs, xs_rest) = S::as_simd_f64s(self.lhs);
            let (ys, ys_rest) = S::as_simd_f64s(self.rhs);
            let (xs_quads, xs_single) = pulp::as_arrays::<4, _>(xs);
            let (ys_quads, ys_single) = pulp::as_arrays::<4, _>(ys);

            let zero = simd.splat_f64s(0.0);
            let (mut s0, mut s1, mut s2, mut s3) = (zero, zero, zero, zero);
            for (x, y) in xs_quads.iter().zip(ys_quads) {
                let d0 = simd.sub_f64s(x[0], y[0]);
                let d1 = simd.sub_f64s(x[1], y[1]);
                let d2 = simd.sub_f64s(x[2], y[2]);
                let d3 = simd.sub_f64s(x[3], y[3]);
                s0 = simd.mul_add_f64s(d0, d0, s0);
                s1 = simd.mul_add_f64s(d1, d1, s1);
                s2 = simd.mul_add_f64s(d2, d2, s2);
                s3 = simd.mul_add_f64s(d3, d3, s3);
            }
            for (x, y) in xs_single.iter().zip(ys_single) {
                let d = simd.sub_f64s(*x, *y);
                s0 = simd.mul_add_f64s(d, d, s0);
            }

            let combined = simd.add_f64s(simd.add_f64s(s0, s1), simd.add_f64s(s2, s3));
            let mut total = simd.reduce_sum_f64s(combined);
            for (x, y) in xs_rest.iter().zip(ys_rest) {
                let d = x - y;
                total += d * d;
            }
            total
        }
    }

    Arch::new().dispatch(SquaredErrorKernel { lhs, rhs })
}
/// SIMD `output[i] += weight * values[i]` dispatched through `pulp`.
///
/// Callers must pass equal-length slices. Full vectors are updated with
/// `mul_add`; the leftover scalar tail uses a plain multiply-then-add.
fn weighted_sum_in_place_simd(output: &mut [f64], weight: f64, values: &[f64]) {
    /// `pulp` kernel holding the destination, the scale factor and the source.
    struct AxpyKernel<'a> {
        output: &'a mut [f64],
        weight: f64,
        values: &'a [f64],
    }

    impl WithSimd for AxpyKernel<'_> {
        type Output = ();

        // Kept `inline(always)` so the body is compiled inside `Arch::dispatch`'s
        // feature-enabled context.
        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let factor = self.weight;
            let (dst_vecs, dst_rest) = S::as_mut_simd_f64s(self.output);
            let (src_vecs, src_rest) = S::as_simd_f64s(self.values);
            let factor_vec = simd.splat_f64s(factor);
            for (dst, src) in dst_vecs.iter_mut().zip(src_vecs) {
                *dst = simd.mul_add_f64s(*src, factor_vec, *dst);
            }
            for (dst, src) in dst_rest.iter_mut().zip(src_rest) {
                *dst += factor * *src;
            }
        }
    }

    Arch::new().dispatch(AxpyKernel {
        output,
        weight,
        values,
    });
}
#[cfg(test)]
mod tests {
    use super::{dot, squared_l2_norm, sum_squared_error, weighted_sum_in_place, ExecutionBackend};
    use approx::assert_abs_diff_eq;

    /// Largest length swept by `simd_matches_scalar_for_all_small_lengths`.
    /// Assuming at most 8 `f64` lanes per vector, this covers every combination of
    /// 4-vector blocks, 0-3 leftover single vectors and 0-7 scalar tail elements
    /// in the SIMD kernels.
    const MAX_SWEEP_LEN: usize = 72;

    /// Deterministic, non-trivial inputs of length `len`.
    fn sample_values_of_len(len: usize) -> (Vec<f64>, Vec<f64>) {
        let lhs: Vec<f64> = (0..len).map(|i| ((i as f64) * 0.03125).sin()).collect();
        let rhs: Vec<f64> = (0..len).map(|i| ((i as f64) * 0.015625).cos()).collect();
        (lhs, rhs)
    }

    /// The original 257-element fixture.
    fn sample_values() -> (Vec<f64>, Vec<f64>) {
        sample_values_of_len(257)
    }

    #[test]
    fn simd_dot_matches_scalar() {
        let (lhs, rhs) = sample_values();
        let scalar = dot(ExecutionBackend::Scalar, &lhs, &rhs);
        let simd = dot(ExecutionBackend::Simd, &lhs, &rhs);
        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);
    }

    #[test]
    fn simd_squared_norm_matches_scalar() {
        let (lhs, _) = sample_values();
        let scalar = squared_l2_norm(ExecutionBackend::Scalar, &lhs);
        let simd = squared_l2_norm(ExecutionBackend::Simd, &lhs);
        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);
    }

    #[test]
    fn simd_squared_error_matches_scalar() {
        let (lhs, rhs) = sample_values();
        let scalar = sum_squared_error(ExecutionBackend::Scalar, &lhs, &rhs);
        let simd = sum_squared_error(ExecutionBackend::Simd, &lhs, &rhs);
        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);
    }

    #[test]
    fn simd_weighted_sum_matches_scalar() {
        let (lhs, rhs) = sample_values();
        let mut scalar = lhs.clone();
        let mut simd = lhs;
        weighted_sum_in_place(ExecutionBackend::Scalar, &mut scalar, 0.37, &rhs);
        weighted_sum_in_place(ExecutionBackend::Simd, &mut simd, 0.37, &rhs);
        for (left, right) in scalar.iter().zip(simd.iter()) {
            assert_abs_diff_eq!(left, right, epsilon = 1e-12);
        }
    }

    /// Sweeps every length in `0..=MAX_SWEEP_LEN`. With 2-, 4- or 8-lane vectors,
    /// length 257 splits into whole 4-vector blocks plus a 1-element tail, so the
    /// tests above never reach the single-vector remainder loops; this one does.
    #[test]
    fn simd_matches_scalar_for_all_small_lengths() {
        for len in 0..=MAX_SWEEP_LEN {
            let (lhs, rhs) = sample_values_of_len(len);

            let scalar = dot(ExecutionBackend::Scalar, &lhs, &rhs);
            let simd = dot(ExecutionBackend::Simd, &lhs, &rhs);
            assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);

            let scalar = squared_l2_norm(ExecutionBackend::Scalar, &lhs);
            let simd = squared_l2_norm(ExecutionBackend::Simd, &lhs);
            assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);

            let scalar = sum_squared_error(ExecutionBackend::Scalar, &lhs, &rhs);
            let simd = sum_squared_error(ExecutionBackend::Simd, &lhs, &rhs);
            assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);

            let mut scalar_out = lhs.clone();
            let mut simd_out = lhs;
            weighted_sum_in_place(ExecutionBackend::Scalar, &mut scalar_out, 0.37, &rhs);
            weighted_sum_in_place(ExecutionBackend::Simd, &mut simd_out, 0.37, &rhs);
            for (left, right) in scalar_out.iter().zip(simd_out.iter()) {
                assert_abs_diff_eq!(left, right, epsilon = 1e-12);
            }
        }
    }

    /// Empty slices are valid input: reductions yield zero and the in-place
    /// update is a no-op.
    #[test]
    fn empty_inputs_produce_zero() {
        for backend in [ExecutionBackend::Scalar, ExecutionBackend::Simd] {
            assert_eq!(dot(backend, &[], &[]), 0.0);
            assert_eq!(squared_l2_norm(backend, &[]), 0.0);
            assert_eq!(sum_squared_error(backend, &[], &[]), 0.0);
            let mut output: [f64; 0] = [];
            weighted_sum_in_place(backend, &mut output, 0.37, &[]);
        }
    }

    #[test]
    #[should_panic(expected = "dot: length mismatch")]
    fn dot_rejects_length_mismatch() {
        dot(ExecutionBackend::Simd, &[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "sum_squared_error: length mismatch")]
    fn sum_squared_error_rejects_length_mismatch() {
        sum_squared_error(ExecutionBackend::Simd, &[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "weighted_sum_in_place: length mismatch")]
    fn weighted_sum_rejects_length_mismatch() {
        let mut output = [0.0; 2];
        weighted_sum_in_place(ExecutionBackend::Simd, &mut output, 1.0, &[1.0]);
    }

    #[test]
    fn backend_metadata() {
        assert_eq!(ExecutionBackend::default(), ExecutionBackend::Simd);
        assert_eq!(ExecutionBackend::Scalar.name(), "scalar");
        assert_eq!(ExecutionBackend::Simd.name(), "simd");
        assert!(!ExecutionBackend::Scalar.is_gpu());
        assert!(!ExecutionBackend::Simd.is_gpu());
    }
}