use std::os::raw::c_int;
/// Status code reported by the native `hsd_*` routines.
///
/// The enum is `#[repr(i32)]` and every discriminant is pinned explicitly, so
/// casting a variant with `as i32` yields the numeric code listed next to it.
/// Note that `-2` is not assigned to any variant here.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]
pub enum HsdStatus {
    /// Code `0`: the call completed successfully.
    Success = 0,
    /// Code `-1`. By its name, a null pointer was rejected — confirm against
    /// the C header.
    ErrNullPtr = -1,
    /// Code `-3`. By its name, the input was rejected as invalid — confirm
    /// against the C header.
    ErrInvalidInput = -3,
    /// Code `-4`. By its name, the CPU lacks a required feature — confirm
    /// against the C header.
    ErrCpuNotSupported = -4,
    /// Code `-99`: generic failure.
    Failure = -99,
}
impl HsdStatus {
#[inline]
pub fn is_success(self) -> bool {
self == HsdStatus::Success
}
}
impl From<c_int> for HsdStatus {
fn from(value: c_int) -> Self {
match value {
0 => HsdStatus::Success,
-1 => HsdStatus::ErrNullPtr,
-3 => HsdStatus::ErrInvalidInput,
-4 => HsdStatus::ErrCpuNotSupported,
_ => HsdStatus::Failure,
}
}
}
// Raw declarations of the C entry points this module binds to.
//
// The four numeric routines share one shape: two `*const f32` inputs, an
// element count `n`, a `*mut f32` output slot, and a `c_int` status return.
// The safe wrappers in this file turn that status into an `HsdStatus`.
//
// NOTE(review): the safety notes below assume each routine reads at most `n`
// floats from `a` and from `b` and writes a single `f32` through `result` —
// confirm against the C header.
unsafe extern "C" {
/// Native routine behind `sqeuclidean_f32` (squared Euclidean distance).
///
/// # Safety
/// Presumably `a` and `b` must each be valid for reading `n` `f32` values and
/// `result` must be valid for writing one `f32`.
pub fn hsd_dist_sqeuclidean_f32(
a: *const f32,
b: *const f32,
n: usize,
result: *mut f32,
) -> c_int;
/// Native routine behind `manhattan_f32` (Manhattan distance).
///
/// # Safety
/// Presumably `a` and `b` must each be valid for reading `n` `f32` values and
/// `result` must be valid for writing one `f32`.
pub fn hsd_dist_manhattan_f32(
a: *const f32,
b: *const f32,
n: usize,
result: *mut f32,
) -> c_int;
/// Native routine behind `cosine_f32` (cosine similarity).
///
/// # Safety
/// Presumably `a` and `b` must each be valid for reading `n` `f32` values and
/// `result` must be valid for writing one `f32`.
pub fn hsd_sim_cosine_f32(a: *const f32, b: *const f32, n: usize, result: *mut f32) -> c_int;
/// Dot-product routine, judging by its name; same signature as the routines
/// above. No safe wrapper for it is visible in this part of the file.
///
/// # Safety
/// Presumably `a` and `b` must each be valid for reading `n` `f32` values and
/// `result` must be valid for writing one `f32`.
pub fn hsd_sim_dot_f32(a: *const f32, b: *const f32, n: usize, result: *mut f32) -> c_int;
/// Returns a C string pointer naming the active backend. `get_simd_backend`
/// treats a null return as "no name available".
///
/// # Safety
/// NOTE(review): lifetime and ownership of the returned string are not
/// visible here; the wrapper only reads it and never frees it.
pub fn hsd_get_backend() -> *const std::os::raw::c_char;
}
#[inline]
pub fn sqeuclidean_f32(a: &[f32], b: &[f32]) -> Option<f32> {
if a.len() != b.len() {
return None;
}
let mut result: f32 = 0.0;
let status: HsdStatus =
unsafe { hsd_dist_sqeuclidean_f32(a.as_ptr(), b.as_ptr(), a.len(), &mut result) }.into();
if status.is_success() {
Some(result)
} else {
None
}
}
#[inline]
pub fn manhattan_f32(a: &[f32], b: &[f32]) -> Option<f32> {
if a.len() != b.len() {
return None;
}
let mut result: f32 = 0.0;
let status: HsdStatus =
unsafe { hsd_dist_manhattan_f32(a.as_ptr(), b.as_ptr(), a.len(), &mut result) }.into();
if status.is_success() {
Some(result)
} else {
None
}
}
#[inline]
pub fn cosine_f32(a: &[f32], b: &[f32]) -> Option<f32> {
if a.len() != b.len() {
return None;
}
let mut result: f32 = 0.0;
let status: HsdStatus =
unsafe { hsd_sim_cosine_f32(a.as_ptr(), b.as_ptr(), a.len(), &mut result) }.into();
if status.is_success() {
Some(result)
} else {
None
}
}
/// Returns the backend name reported by the native `hsd_get_backend` routine.
///
/// A null pointer from the native side yields the literal `"Unknown"`.
/// Otherwise the C string is copied into an owned `String`, with any invalid
/// UTF-8 replaced lossily.
pub fn get_simd_backend() -> String {
    // SAFETY: the routine takes no arguments, so there is nothing for the
    // caller to uphold beyond the library being linked.
    let raw = unsafe { hsd_get_backend() };
    if raw.is_null() {
        return "Unknown".to_string();
    }

    // SAFETY: `raw` is non-null (checked above).
    // NOTE(review): assumes it points to a NUL-terminated string that stays
    // valid for the duration of this call — confirm against the C header.
    let name = unsafe { std::ffi::CStr::from_ptr(raw) };
    name.to_string_lossy().into_owned()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance applied when comparing `f32` results.
    const TOLERANCE: f32 = 1e-5;

    /// Panics unless `actual` lies strictly within [`TOLERANCE`] of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    /// The native library must report a non-empty backend name, and the
    /// null-pointer fallback text must not appear in it.
    #[test]
    fn test_backend_detection() {
        let name = get_simd_backend();
        assert!(!name.is_empty());
        assert!(!name.contains("Unknown"));
    }

    /// Three coordinate differences of 3 each: 9 + 9 + 9 = 27.
    #[test]
    fn test_sqeuclidean_f32_valid() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        let distance = sqeuclidean_f32(&lhs, &rhs).expect("SIMD sqeuclidean failed");
        assert_close(distance, 27.0);
    }

    /// Slices of different lengths are rejected with `None`.
    #[test]
    fn test_sqeuclidean_f32_mismatch() {
        let longer = [1.0_f32, 2.0];
        let shorter = [1.0_f32];
        assert!(sqeuclidean_f32(&longer, &shorter).is_none());
    }

    /// |1 - 3| + |2 - 1| = 3.
    #[test]
    fn test_manhattan_f32_valid() {
        let lhs = [1.0_f32, 2.0];
        let rhs = [3.0_f32, 1.0];
        let distance = manhattan_f32(&lhs, &rhs).expect("SIMD manhattan failed");
        assert_close(distance, 3.0);
    }

    /// Orthogonal unit vectors give 0; identical unit vectors give 1.
    #[test]
    fn test_cosine_f32_valid() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];

        let orthogonal = cosine_f32(&x_axis, &y_axis).expect("SIMD cosine failed");
        assert_close(orthogonal, 0.0);

        let identical = cosine_f32(&x_axis, &x_axis).expect("SIMD cosine failed");
        assert_close(identical, 1.0);
    }
}