use crate::types::{bf16, bf16c, f16, f16c, f32c, f64c, StorageElement};
#[link(name = "numkong")]
extern "C" {
fn nk_bilinear_f64(a: *const f64, b: *const f64, c: *const f64, n: usize, result: *mut f64);
fn nk_bilinear_f32(a: *const f32, b: *const f32, c: *const f32, n: usize, result: *mut f64);
fn nk_bilinear_f16(a: *const u16, b: *const u16, c: *const u16, n: usize, result: *mut f32);
fn nk_bilinear_bf16(a: *const u16, b: *const u16, c: *const u16, n: usize, result: *mut f32);
fn nk_bilinear_f64c(a: *const f64, b: *const f64, c: *const f64, n: usize, results: *mut f64);
fn nk_bilinear_f32c(a: *const f32, b: *const f32, c: *const f32, n: usize, results: *mut f64);
fn nk_bilinear_f16c(a: *const u16, b: *const u16, c: *const u16, n: usize, results: *mut f32);
fn nk_bilinear_bf16c(a: *const u16, b: *const u16, c: *const u16, n: usize, results: *mut f32);
fn nk_mahalanobis_f64(a: *const f64, b: *const f64, c: *const f64, n: usize, result: *mut f64);
fn nk_mahalanobis_f32(a: *const f32, b: *const f32, c: *const f32, n: usize, result: *mut f64);
fn nk_mahalanobis_f16(a: *const u16, b: *const u16, c: *const u16, n: usize, result: *mut f32);
fn nk_mahalanobis_bf16(a: *const u16, b: *const u16, c: *const u16, n: usize, result: *mut f32);
}
pub trait Bilinear: StorageElement {
type Output;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output>;
}
impl Bilinear for f64 {
type Output = f64;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f64 = 0.0;
unsafe {
nk_bilinear_f64(a.as_ptr(), b.as_ptr(), c.as_ptr(), n, &mut result);
}
Some(result)
}
}
impl Bilinear for f32 {
type Output = f64;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f64 = 0.0;
unsafe {
nk_bilinear_f32(a.as_ptr(), b.as_ptr(), c.as_ptr(), n, &mut result);
}
Some(result)
}
}
impl Bilinear for f16 {
type Output = f32;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f32 = 0.0;
unsafe {
nk_bilinear_f16(
a.as_ptr() as *const u16,
b.as_ptr() as *const u16,
c.as_ptr() as *const u16,
n,
&mut result,
);
}
Some(result)
}
}
impl Bilinear for bf16 {
type Output = f32;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f32 = 0.0;
unsafe {
nk_bilinear_bf16(
a.as_ptr() as *const u16,
b.as_ptr() as *const u16,
c.as_ptr() as *const u16,
n,
&mut result,
);
}
Some(result)
}
}
impl Bilinear for f64c {
type Output = f64c;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result = [0.0f64; 2];
unsafe {
nk_bilinear_f64c(
a.as_ptr() as *const f64,
b.as_ptr() as *const f64,
c.as_ptr() as *const f64,
n,
result.as_mut_ptr(),
);
}
Some(f64c {
re: result[0],
im: result[1],
})
}
}
impl Bilinear for f32c {
type Output = f64c;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result = [0.0f64; 2];
unsafe {
nk_bilinear_f32c(
a.as_ptr() as *const f32,
b.as_ptr() as *const f32,
c.as_ptr() as *const f32,
n,
result.as_mut_ptr(),
);
}
Some(f64c {
re: result[0],
im: result[1],
})
}
}
impl Bilinear for f16c {
type Output = f32c;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result = [0.0f32; 2];
unsafe {
nk_bilinear_f16c(
a.as_ptr() as *const u16,
b.as_ptr() as *const u16,
c.as_ptr() as *const u16,
n,
result.as_mut_ptr(),
);
}
Some(f32c {
re: result[0],
im: result[1],
})
}
}
impl Bilinear for bf16c {
type Output = f32c;
fn bilinear(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result = [0.0f32; 2];
unsafe {
nk_bilinear_bf16c(
a.as_ptr() as *const u16,
b.as_ptr() as *const u16,
c.as_ptr() as *const u16,
n,
result.as_mut_ptr(),
);
}
Some(f32c {
re: result[0],
im: result[1],
})
}
}
pub trait Mahalanobis: StorageElement {
type Output;
fn mahalanobis(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output>;
}
impl Mahalanobis for f64 {
type Output = f64;
fn mahalanobis(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f64 = 0.0;
unsafe {
nk_mahalanobis_f64(a.as_ptr(), b.as_ptr(), c.as_ptr(), n, &mut result);
}
Some(result)
}
}
impl Mahalanobis for f32 {
type Output = f64;
fn mahalanobis(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f64 = 0.0;
unsafe {
nk_mahalanobis_f32(a.as_ptr(), b.as_ptr(), c.as_ptr(), n, &mut result);
}
Some(result)
}
}
impl Mahalanobis for f16 {
type Output = f32;
fn mahalanobis(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f32 = 0.0;
unsafe {
nk_mahalanobis_f16(
a.as_ptr() as *const u16,
b.as_ptr() as *const u16,
c.as_ptr() as *const u16,
n,
&mut result,
);
}
Some(result)
}
}
impl Mahalanobis for bf16 {
type Output = f32;
fn mahalanobis(a: &[Self], b: &[Self], c: &[Self]) -> Option<Self::Output> {
let n = a.len();
if n == 0 || b.len() != n || c.len() != n * n {
return None;
}
let mut result: f32 = 0.0;
unsafe {
nk_mahalanobis_bf16(
a.as_ptr() as *const u16,
b.as_ptr() as *const u16,
c.as_ptr() as *const u16,
n,
&mut result,
);
}
Some(result)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{assert_close, bf16, f16, FloatLike, NumberLike, TestableType};
pub(crate) fn make_identity<T: FloatLike>(n: usize) -> Vec<T> {
let mut v = vec![T::zero(); n * n];
for i in 0..n {
v[i * n + i] = T::one();
}
v
}
fn check_bilinear<T>(first_values: &[f32], second_values: &[f32], expected: f64)
where
T: FloatLike + TestableType + Bilinear,
T::Output: FloatLike,
{
let first: Vec<T> = first_values.iter().map(|&v| T::from_f32(v)).collect();
let second: Vec<T> = second_values.iter().map(|&v| T::from_f32(v)).collect();
let identity = make_identity::<T>(first.len());
let result = T::bilinear(&first, &second, &identity).unwrap();
assert_close(
result.to_f64(),
expected,
T::atol(),
T::rtol(),
&format!("bilinear<{}>", core::any::type_name::<T>()),
);
}
#[test]
fn bilinear() {
let first_values = &[1.0, 2.0, 3.0];
let second_values = &[4.0, 5.0, 6.0];
check_bilinear::<f64>(first_values, second_values, 32.0);
check_bilinear::<f32>(first_values, second_values, 32.0);
check_bilinear::<f16>(first_values, second_values, 32.0);
check_bilinear::<bf16>(first_values, second_values, 32.0);
}
fn check_mahalanobis<T>(first_values: &[f32], second_values: &[f32], expected: f64)
where
T: FloatLike + TestableType + Mahalanobis,
T::Output: FloatLike,
{
let first: Vec<T> = first_values.iter().map(|&v| T::from_f32(v)).collect();
let second: Vec<T> = second_values.iter().map(|&v| T::from_f32(v)).collect();
let identity = make_identity::<T>(first.len());
let result = T::mahalanobis(&first, &second, &identity).unwrap();
assert_close(
result.to_f64(),
expected,
T::atol(),
T::rtol(),
&format!("mahalanobis<{}>", core::any::type_name::<T>()),
);
}
#[test]
fn mahalanobis() {
let first_values = &[1.0, 2.0, 3.0];
let second_values = &[4.0, 5.0, 6.0];
let expected = (27.0_f64).sqrt();
check_mahalanobis::<f64>(first_values, second_values, expected);
check_mahalanobis::<f32>(first_values, second_values, expected);
check_mahalanobis::<f16>(first_values, second_values, expected);
check_mahalanobis::<bf16>(first_values, second_values, expected);
}
}