// turboquant 0.1.1
//
// Implementation of Google's TurboQuant algorithm for vector quantization.
use pulp::{Arch, Simd, WithSimd};
use serde::{Deserialize, Serialize};

/// Selects which implementation the batch helpers and benchmark code run on.
///
/// [`ExecutionBackend::Simd`] is the [`Default`] choice.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
pub enum ExecutionBackend {
    /// Plain, portable scalar loops.
    Scalar,
    /// SIMD kernels selected at runtime through `pulp`.
    #[default]
    Simd,
    /// Burn/WGPU batch path; only present when built with the `gpu` feature.
    #[cfg(feature = "gpu")]
    Wgpu,
}

impl ExecutionBackend {
    /// Returns the lowercase identifier of this backend (`"scalar"`, `"simd"`, or
    /// `"wgpu"` when the `gpu` feature is enabled).
    pub fn name(self) -> &'static str {
        let label = match self {
            Self::Scalar => "scalar",
            Self::Simd => "simd",
            #[cfg(feature = "gpu")]
            Self::Wgpu => "wgpu",
        };
        label
    }

    /// Reports whether this backend is the GPU (`Wgpu`) variant.
    ///
    /// Always `false` when the crate is built without the `gpu` feature.
    pub fn is_gpu(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Scalar | Self::Simd => false,
            #[cfg(feature = "gpu")]
            Self::Wgpu => true,
        }
    }
}

/// Computes the dot product `Σ lhs[i] * rhs[i]` with the requested backend.
///
/// The `Wgpu` backend (feature `gpu`) currently runs the SIMD kernel. SIMD
/// results may differ from the scalar path in the last bits because the
/// summation order differs.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths.
pub fn dot(backend: ExecutionBackend, lhs: &[f64], rhs: &[f64]) -> f64 {
    let (left_len, right_len) = (lhs.len(), rhs.len());
    assert_eq!(
        left_len, right_len,
        "dot: length mismatch ({} vs {})",
        left_len, right_len
    );

    let result = match backend {
        ExecutionBackend::Scalar => dot_scalar(lhs, rhs),
        ExecutionBackend::Simd => dot_simd(lhs, rhs),
        #[cfg(feature = "gpu")]
        ExecutionBackend::Wgpu => dot_simd(lhs, rhs),
    };
    result
}

pub fn squared_l2_norm(backend: ExecutionBackend, values: &[f64]) -> f64 {
    match backend {
        ExecutionBackend::Scalar => values.iter().map(|value| value * value).sum(),
        ExecutionBackend::Simd => dot_simd(values, values),
        #[cfg(feature = "gpu")]
        ExecutionBackend::Wgpu => dot_simd(values, values),
    }
}

/// Returns `Σ (lhs[i] - rhs[i])²` with the requested backend.
///
/// The `Wgpu` backend (feature `gpu`) currently runs the SIMD kernel.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths.
pub fn sum_squared_error(backend: ExecutionBackend, lhs: &[f64], rhs: &[f64]) -> f64 {
    let (left_len, right_len) = (lhs.len(), rhs.len());
    assert_eq!(
        left_len, right_len,
        "sum_squared_error: length mismatch ({} vs {})",
        left_len, right_len
    );

    match backend {
        ExecutionBackend::Scalar => std::iter::zip(lhs, rhs)
            .map(|(&left, &right)| left - right)
            .map(|delta| delta * delta)
            .sum(),
        ExecutionBackend::Simd => sum_squared_error_simd(lhs, rhs),
        #[cfg(feature = "gpu")]
        ExecutionBackend::Wgpu => sum_squared_error_simd(lhs, rhs),
    }
}

/// Updates `output` in place with `output[i] += weight * values[i]`.
///
/// The `Wgpu` backend (feature `gpu`) currently runs the SIMD kernel.
///
/// # Panics
///
/// Panics if `output` and `values` have different lengths.
pub fn weighted_sum_in_place(
    backend: ExecutionBackend,
    output: &mut [f64],
    weight: f64,
    values: &[f64],
) {
    let (output_len, values_len) = (output.len(), values.len());
    assert_eq!(
        output_len, values_len,
        "weighted_sum_in_place: length mismatch ({} vs {})",
        output_len, values_len
    );

    match backend {
        ExecutionBackend::Scalar => output
            .iter_mut()
            .zip(values)
            .for_each(|(slot, &value)| *slot += weight * value),
        ExecutionBackend::Simd => weighted_sum_in_place_simd(output, weight, values),
        #[cfg(feature = "gpu")]
        ExecutionBackend::Wgpu => weighted_sum_in_place_simd(output, weight, values),
    }
}

/// Portable scalar dot product over the common prefix of `lhs` and `rhs`.
///
/// Callers are expected to pass equal-length slices. Extra trailing elements of
/// the longer slice are ignored.
fn dot_scalar(lhs: &[f64], rhs: &[f64]) -> f64 {
    std::iter::zip(lhs, rhs)
        .map(|(&a, &b)| a * b)
        .sum()
}

/// SIMD dot product kernel, dispatched at runtime through `pulp`.
///
/// Assumes `lhs.len() == rhs.len()`; the public wrappers (`dot`,
/// `squared_l2_norm`) guarantee this. If the lengths differed, the `zip`s
/// below would silently truncate and the scalar tails would be misaligned.
///
/// The result may differ from `dot_scalar` in the last bits because the
/// summation order differs and the vector loop uses `mul_add_f64s`.
fn dot_simd(lhs: &[f64], rhs: &[f64]) -> f64 {
    // Arguments carried into `WithSimd::with_simd`, which `Arch::dispatch`
    // invokes with the SIMD token selected at runtime.
    struct Dot<'a> {
        lhs: &'a [f64],
        rhs: &'a [f64],
    }

    impl WithSimd for Dot<'_> {
        type Output = f64;

        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let Self { lhs, rhs } = self;
            // Split each slice into whole SIMD registers (`head`) and the
            // leftover scalars that do not fill a register (`tail`).
            let (lhs_head, lhs_tail) = S::as_simd_f64s(lhs);
            let (rhs_head, rhs_tail) = S::as_simd_f64s(rhs);

            // Four independent accumulators, presumably to break the
            // dependency chain between consecutive multiply-adds.
            let mut acc0 = simd.splat_f64s(0.0);
            let mut acc1 = simd.splat_f64s(0.0);
            let mut acc2 = simd.splat_f64s(0.0);
            let mut acc3 = simd.splat_f64s(0.0);

            // Group registers four at a time; `*_1` holds the 0..=3 registers
            // left over after the last full group.
            let (lhs_4, lhs_1) = pulp::as_arrays::<4, _>(lhs_head);
            let (rhs_4, rhs_1) = pulp::as_arrays::<4, _>(rhs_head);

            // Unrolled main loop: one multiply-add per accumulator per group.
            for ([lhs0, lhs1, lhs2, lhs3], [rhs0, rhs1, rhs2, rhs3]) in
                lhs_4.iter().zip(rhs_4.iter())
            {
                acc0 = simd.mul_add_f64s(*lhs0, *rhs0, acc0);
                acc1 = simd.mul_add_f64s(*lhs1, *rhs1, acc1);
                acc2 = simd.mul_add_f64s(*lhs2, *rhs2, acc2);
                acc3 = simd.mul_add_f64s(*lhs3, *rhs3, acc3);
            }

            // Remaining whole registers all go into `acc0`.
            for (lhs0, rhs0) in lhs_1.iter().zip(rhs_1.iter()) {
                acc0 = simd.mul_add_f64s(*lhs0, *rhs0, acc0);
            }

            // Pairwise combine: (acc0 + acc1) + (acc2 + acc3).
            acc0 = simd.add_f64s(acc0, acc1);
            acc2 = simd.add_f64s(acc2, acc3);
            acc0 = simd.add_f64s(acc0, acc2);

            // Horizontal sum of the lanes, then finish the scalar tail.
            let mut acc = simd.reduce_sum_f64s(acc0);
            for (left, right) in lhs_tail.iter().zip(rhs_tail.iter()) {
                acc += left * right;
            }

            acc
        }
    }

    Arch::new().dispatch(Dot { lhs, rhs })
}

/// SIMD kernel for `Σ (lhs[i] - rhs[i])²`, dispatched at runtime through `pulp`.
///
/// Assumes `lhs.len() == rhs.len()`; `sum_squared_error` asserts this before
/// calling. If the lengths differed, the `zip`s below would silently truncate
/// and the scalar tails would be misaligned.
///
/// The structure mirrors `dot_simd`, applied to the element-wise differences.
/// Results may differ from the scalar path in the last bits because the
/// summation order differs.
fn sum_squared_error_simd(lhs: &[f64], rhs: &[f64]) -> f64 {
    // Arguments carried into `WithSimd::with_simd` for `Arch::dispatch`.
    struct SquaredError<'a> {
        lhs: &'a [f64],
        rhs: &'a [f64],
    }

    impl WithSimd for SquaredError<'_> {
        type Output = f64;

        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let Self { lhs, rhs } = self;
            // Whole SIMD registers (`head`) plus leftover scalars (`tail`).
            let (lhs_head, lhs_tail) = S::as_simd_f64s(lhs);
            let (rhs_head, rhs_tail) = S::as_simd_f64s(rhs);

            // Four independent accumulators, presumably to break the
            // dependency chain between consecutive multiply-adds.
            let mut acc0 = simd.splat_f64s(0.0);
            let mut acc1 = simd.splat_f64s(0.0);
            let mut acc2 = simd.splat_f64s(0.0);
            let mut acc3 = simd.splat_f64s(0.0);

            // Groups of four registers, plus the 0..=3 registers left over.
            let (lhs_4, lhs_1) = pulp::as_arrays::<4, _>(lhs_head);
            let (rhs_4, rhs_1) = pulp::as_arrays::<4, _>(rhs_head);

            // Unrolled main loop: diff = lhs - rhs, then acc += diff * diff.
            for ([lhs0, lhs1, lhs2, lhs3], [rhs0, rhs1, rhs2, rhs3]) in
                lhs_4.iter().zip(rhs_4.iter())
            {
                let diff0 = simd.sub_f64s(*lhs0, *rhs0);
                let diff1 = simd.sub_f64s(*lhs1, *rhs1);
                let diff2 = simd.sub_f64s(*lhs2, *rhs2);
                let diff3 = simd.sub_f64s(*lhs3, *rhs3);

                acc0 = simd.mul_add_f64s(diff0, diff0, acc0);
                acc1 = simd.mul_add_f64s(diff1, diff1, acc1);
                acc2 = simd.mul_add_f64s(diff2, diff2, acc2);
                acc3 = simd.mul_add_f64s(diff3, diff3, acc3);
            }

            // Remaining whole registers all go into `acc0`.
            for (lhs0, rhs0) in lhs_1.iter().zip(rhs_1.iter()) {
                let diff = simd.sub_f64s(*lhs0, *rhs0);
                acc0 = simd.mul_add_f64s(diff, diff, acc0);
            }

            // Pairwise combine: (acc0 + acc1) + (acc2 + acc3).
            acc0 = simd.add_f64s(acc0, acc1);
            acc2 = simd.add_f64s(acc2, acc3);
            acc0 = simd.add_f64s(acc0, acc2);

            // Horizontal lane sum, then the scalar tail.
            let mut acc = simd.reduce_sum_f64s(acc0);
            for (left, right) in lhs_tail.iter().zip(rhs_tail.iter()) {
                let diff = left - right;
                acc += diff * diff;
            }

            acc
        }
    }

    Arch::new().dispatch(SquaredError { lhs, rhs })
}

/// SIMD kernel for `output[i] += weight * values[i]`, dispatched through `pulp`.
///
/// Assumes `output.len() == values.len()`; `weighted_sum_in_place` asserts this
/// before calling. If the lengths differed, the `zip`s below would silently
/// truncate and the scalar tails would be misaligned.
///
/// Register-sized chunks use `mul_add_f64s`, while the scalar tail uses a
/// separate multiply and add. Results may therefore differ from the scalar
/// backend in the last bits.
fn weighted_sum_in_place_simd(output: &mut [f64], weight: f64, values: &[f64]) {
    // Arguments carried into `WithSimd::with_simd` for `Arch::dispatch`.
    struct WeightedSum<'a> {
        output: &'a mut [f64],
        weight: f64,
        values: &'a [f64],
    }

    impl WithSimd for WeightedSum<'_> {
        type Output = ();

        #[inline(always)]
        fn with_simd<S: Simd>(self, simd: S) -> Self::Output {
            let Self {
                output,
                weight,
                values,
            } = self;
            // Whole SIMD registers (`head`) plus leftover scalars (`tail`).
            let (output_head, output_tail) = S::as_mut_simd_f64s(output);
            let (values_head, values_tail) = S::as_simd_f64s(values);
            // Broadcast the scalar weight into every lane.
            let scale = simd.splat_f64s(weight);

            // dst = src * scale + dst, one register at a time.
            for (dst, src) in output_head.iter_mut().zip(values_head.iter()) {
                *dst = simd.mul_add_f64s(*src, scale, *dst);
            }

            // Scalar tail for the elements that do not fill a register.
            for (dst, src) in output_tail.iter_mut().zip(values_tail.iter()) {
                *dst += weight * *src;
            }
        }
    }

    Arch::new().dispatch(WeightedSum {
        output,
        weight,
        values,
    });
}

#[cfg(test)]
mod tests {
    use super::{dot, squared_l2_norm, sum_squared_error, weighted_sum_in_place, ExecutionBackend};
    use approx::assert_abs_diff_eq;

    /// Input lengths covering the empty case, very short (tail-only) inputs, and
    /// lengths around several multiples of common SIMD widths. Together they
    /// exercise the tail, leftover-register, and unrolled paths on typical targets.
    const LENGTHS: [usize; 10] = [0, 1, 2, 3, 5, 8, 15, 17, 64, 257];

    /// Both CPU backends that are always compiled in.
    const CPU_BACKENDS: [ExecutionBackend; 2] = [ExecutionBackend::Scalar, ExecutionBackend::Simd];

    /// Deterministic pair of test vectors of the requested length.
    fn sample_values_len(len: usize) -> (Vec<f64>, Vec<f64>) {
        let lhs: Vec<f64> = (0..len).map(|i| ((i as f64) * 0.03125).sin()).collect();
        let rhs: Vec<f64> = (0..len).map(|i| ((i as f64) * 0.015625).cos()).collect();
        (lhs, rhs)
    }

    fn sample_values() -> (Vec<f64>, Vec<f64>) {
        sample_values_len(257)
    }

    #[test]
    fn simd_dot_matches_scalar() {
        let (lhs, rhs) = sample_values();
        let scalar = dot(ExecutionBackend::Scalar, &lhs, &rhs);
        let simd = dot(ExecutionBackend::Simd, &lhs, &rhs);
        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);
    }

    #[test]
    fn simd_squared_norm_matches_scalar() {
        let (lhs, _) = sample_values();
        let scalar = squared_l2_norm(ExecutionBackend::Scalar, &lhs);
        let simd = squared_l2_norm(ExecutionBackend::Simd, &lhs);
        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);
    }

    #[test]
    fn simd_squared_error_matches_scalar() {
        let (lhs, rhs) = sample_values();
        let scalar = sum_squared_error(ExecutionBackend::Scalar, &lhs, &rhs);
        let simd = sum_squared_error(ExecutionBackend::Simd, &lhs, &rhs);
        assert_abs_diff_eq!(scalar, simd, epsilon = 1e-12);
    }

    #[test]
    fn simd_weighted_sum_matches_scalar() {
        let (lhs, rhs) = sample_values();
        let mut scalar = lhs.clone();
        let mut simd = lhs;
        weighted_sum_in_place(ExecutionBackend::Scalar, &mut scalar, 0.37, &rhs);
        weighted_sum_in_place(ExecutionBackend::Simd, &mut simd, 0.37, &rhs);
        for (left, right) in scalar.iter().zip(simd.iter()) {
            assert_abs_diff_eq!(left, right, epsilon = 1e-12);
        }
    }

    /// SIMD and scalar agree for every length class, not just one long input.
    #[test]
    fn simd_matches_scalar_across_lengths() {
        for &len in &LENGTHS {
            let (lhs, rhs) = sample_values_len(len);

            assert_abs_diff_eq!(
                dot(ExecutionBackend::Scalar, &lhs, &rhs),
                dot(ExecutionBackend::Simd, &lhs, &rhs),
                epsilon = 1e-12
            );
            assert_abs_diff_eq!(
                squared_l2_norm(ExecutionBackend::Scalar, &rhs),
                squared_l2_norm(ExecutionBackend::Simd, &rhs),
                epsilon = 1e-12
            );
            assert_abs_diff_eq!(
                sum_squared_error(ExecutionBackend::Scalar, &lhs, &rhs),
                sum_squared_error(ExecutionBackend::Simd, &lhs, &rhs),
                epsilon = 1e-12
            );

            let mut scalar = lhs.clone();
            let mut simd = lhs.clone();
            weighted_sum_in_place(ExecutionBackend::Scalar, &mut scalar, -1.25, &rhs);
            weighted_sum_in_place(ExecutionBackend::Simd, &mut simd, -1.25, &rhs);
            for (left, right) in scalar.iter().zip(simd.iter()) {
                assert_abs_diff_eq!(left, right, epsilon = 1e-12);
            }
        }
    }

    /// Small integer-valued inputs are exact in f64, so every backend must hit them exactly.
    #[test]
    fn known_small_values() {
        for backend in CPU_BACKENDS {
            assert_eq!(dot(backend, &[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
            assert_eq!(squared_l2_norm(backend, &[3.0, 4.0]), 25.0);
            assert_eq!(sum_squared_error(backend, &[1.0, 5.0], &[4.0, 1.0]), 25.0);

            let mut out = [1.0, 2.0, 3.0];
            weighted_sum_in_place(backend, &mut out, 2.0, &[1.0, 0.5, -1.0]);
            assert_eq!(out, [3.0, 3.0, 1.0]);
        }
    }

    #[test]
    fn empty_inputs_reduce_to_zero() {
        for backend in CPU_BACKENDS {
            assert_eq!(dot(backend, &[], &[]), 0.0);
            assert_eq!(squared_l2_norm(backend, &[]), 0.0);
            assert_eq!(sum_squared_error(backend, &[], &[]), 0.0);

            let mut out: [f64; 0] = [];
            weighted_sum_in_place(backend, &mut out, 3.0, &[]);
        }
    }

    #[test]
    #[should_panic(expected = "dot: length mismatch")]
    fn dot_rejects_length_mismatch() {
        let _ = dot(ExecutionBackend::Simd, &[1.0, 2.0], &[1.0]);
    }

    #[test]
    #[should_panic(expected = "sum_squared_error: length mismatch")]
    fn sum_squared_error_rejects_length_mismatch() {
        let _ = sum_squared_error(ExecutionBackend::Scalar, &[1.0], &[1.0, 2.0]);
    }

    #[test]
    #[should_panic(expected = "weighted_sum_in_place: length mismatch")]
    fn weighted_sum_rejects_length_mismatch() {
        let mut out = [0.0; 2];
        weighted_sum_in_place(ExecutionBackend::Simd, &mut out, 1.0, &[1.0]);
    }

    #[test]
    fn backend_metadata() {
        assert_eq!(ExecutionBackend::default(), ExecutionBackend::Simd);
        assert_eq!(ExecutionBackend::Scalar.name(), "scalar");
        assert_eq!(ExecutionBackend::Simd.name(), "simd");
        assert!(!ExecutionBackend::Scalar.is_gpu());
        assert!(!ExecutionBackend::Simd.is_gpu());
    }
}