pub(crate) mod scalar;
#[cfg(target_arch = "aarch64")]
pub(crate) mod neon;
#[cfg(target_arch = "x86_64")]
pub(crate) mod x86;
/// Crate-internal entry point for the 768-element `f32` dot-product kernel.
///
/// Forwards `a` and `b` unchanged to `dot_768_dispatch` and returns its
/// result. Which `dot_768_dispatch` is compiled in depends on `target_arch`
/// (aarch64, x86_64, or the portable fallback).
///
/// Both inputs are fixed-size array references, so a length mismatch is
/// rejected at compile time rather than checked here.
// `inline(always)` is only applied when the `tarpaulin` cfg is NOT set.
// NOTE(review): presumably so coverage runs keep this function as a distinct
// symbol — confirm against the project's coverage setup.
#[cfg_attr(not(tarpaulin), inline(always))]
pub(crate) fn dot_768(a: &[f32; 768], b: &[f32; 768]) -> f32 {
dot_768_dispatch(a, b)
}
/// aarch64 variant of the kernel selector.
///
/// Returns `neon::dot_768(a, b)` for ordinary builds and
/// `scalar::dot_768(a, b)` when the crate is compiled with the `miri` cfg.
/// `cfg!(miri)` expands to a boolean constant at compile time, so the branch
/// below is not a runtime decision.
#[cfg(target_arch = "aarch64")]
#[cfg_attr(not(tarpaulin), inline(always))]
fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 {
    if !cfg!(miri) {
        // SAFETY: NOTE(review): no runtime feature check is performed on this
        // path; this presumably relies on NEON being available on every
        // aarch64 target the crate supports — confirm against the contract
        // documented on `neon::dot_768`.
        unsafe { neon::dot_768(a, b) }
    } else {
        // Under Miri the SIMD kernel is bypassed in favour of the scalar one.
        scalar::dot_768(a, b)
    }
}
/// x86_64 variant of the kernel selector.
///
/// Uses `x86::dot_768_avx2_fma` only when all of the following hold:
/// the crate is not built with the `miri` cfg, and runtime detection reports
/// both `avx2` and `fma`. In every other case it returns
/// `scalar::dot_768(a, b)`.
#[cfg(target_arch = "x86_64")]
#[cfg_attr(not(tarpaulin), inline(always))]
fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 {
    // `&&` short-circuits left to right, so under Miri the feature-detection
    // macros are never evaluated, and `fma` is only queried once `avx2` is
    // known to be present.
    let use_simd = !cfg!(miri)
        && std::arch::is_x86_feature_detected!("avx2")
        && std::arch::is_x86_feature_detected!("fma");

    if !use_simd {
        return scalar::dot_768(a, b);
    }

    // SAFETY: `avx2` and `fma` were both detected on the running CPU just
    // above. NOTE(review): confirm these are the only target features
    // `x86::dot_768_avx2_fma` is compiled with.
    unsafe { x86::dot_768_avx2_fma(a, b) }
}
/// Portable variant of the kernel selector.
///
/// Compiled only when `target_arch` is neither `aarch64` nor `x86_64`.
/// It performs no feature detection and always returns
/// `scalar::dot_768(a, b)`.
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
// `inline(always)` is only applied when the `tarpaulin` cfg is NOT set.
#[cfg_attr(not(tarpaulin), inline(always))]
fn dot_768_dispatch(a: &[f32; 768], b: &[f32; 768]) -> f32 {
scalar::dot_768(a, b)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic pair of heap-allocated 768-element vectors.
    ///
    /// Element `i` of the first vector is `sin(i * 0.013)` and element `i` of
    /// the second is `cos(i * 0.017)`, both computed in `f32`.
    fn fixture() -> (Box<[f32; 768]>, Box<[f32; 768]>) {
        /// Samples `wave` at indices `0..768` and boxes the result as a
        /// fixed-size array.
        fn sampled(wave: impl Fn(f32) -> f32) -> Box<[f32; 768]> {
            let samples: Vec<f32> = (0..768).map(|i| wave(i as f32)).collect();
            // The range yields exactly 768 items, so the length-checked
            // conversion to `Box<[f32; 768]>` cannot fail.
            samples.into_boxed_slice().try_into().unwrap()
        }

        let sines = sampled(|x| (x * 0.013).sin());
        let cosines = sampled(|x| (x * 0.017).cos());
        (sines, cosines)
    }

    /// The dispatched kernel must match the scalar kernel to within an
    /// absolute tolerance of `1e-3` on the fixture vectors.
    #[test]
    fn dispatch_agrees_with_scalar_within_tolerance() {
        let (lhs, rhs) = fixture();
        // `s` and `d` are captured by name in the failure message below.
        let d = dot_768(&lhs, &rhs);
        let s = scalar::dot_768(&lhs, &rhs);
        let gap = (s - d).abs();
        assert!(
            gap < 1e-3,
            "dispatch dot ({d}) disagrees with scalar ({s})"
        );
    }

    /// Two vectors whose only non-zero entries sit at different indices must
    /// produce exactly `0.0`.
    #[test]
    fn dispatch_zero_for_orthogonal_axes() {
        let mut first_axis = Box::new([0.0f32; 768]);
        first_axis[0] = 1.0;

        let mut second_axis = Box::new([0.0f32; 768]);
        second_axis[1] = 1.0;

        let product = dot_768(&first_axis, &second_axis);
        assert_eq!(product, 0.0);
    }

    /// Guards the fixed-size signature: a slice of the wrong length cannot be
    /// turned into the `&[f32; 768]` that `dot_768` requires.
    #[test]
    fn short_slice_cannot_be_converted_to_768_array() {
        let too_short = vec![0.0f32; 100];
        let attempt: Result<&[f32; 768], _> = too_short.as_slice().try_into();
        assert!(
            attempt.is_err(),
            "100-element slice must not convert to [f32; 768]"
        );
    }
}