extern crate alloc;
use core::marker::PhantomData;
use core::ptr::NonNull;
use crate::tensor::{
Allocator, Global, Tensor, TensorError, TensorMut, TensorRef, TensorView, SIMD_ALIGNMENT,
};
use crate::types::{bf16, e2m3, e3m2, e4m3, e5m2, f16, i4x2, u1x8, u4x2, StorageElement};
// Raw FFI declarations for the `numkong` native library.
//
// Symbols follow the pattern `nk_<operation>_<variant>_<dtype>`. The traits defined later in
// this file (`Dots`, `Hammings`, `Jaccards`, `Angulars`, `Euclideans`) are thin wrappers that
// forward to these symbols.
//
// Element pointers for `f16`/`bf16` are declared as `*const u16`; the 8-bit float, 4-bit and
// 1-bit element types are declared as `*const u8`. The Rust wrappers cast `*const Self` to
// match.
//
// NOTE(review): the unit of the `*_stride` arguments (bytes vs. elements) and the layout of
// the `packed` buffers are defined by the C library and are not visible here — confirm
// against its headers.
#[link(name = "numkong")]
extern "C" {
// --- Dots: one `packed_size` / `pack` / `packed` triple per element type ---
fn nk_dots_packed_size_f32(width: usize, depth: usize) -> usize;
fn nk_dots_pack_f32(
b: *const f32,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_f32(
a: *const f32,
packed: *const u8,
c: *mut f64,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_f64(width: usize, depth: usize) -> usize;
fn nk_dots_pack_f64(
b: *const f64,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_f64(
a: *const f64,
packed: *const u8,
c: *mut f64,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_f16(width: usize, depth: usize) -> usize;
fn nk_dots_pack_f16(
b: *const u16,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_f16(
a: *const u16,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_bf16(width: usize, depth: usize) -> usize;
fn nk_dots_pack_bf16(
b: *const u16,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_bf16(
a: *const u16,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_i8(width: usize, depth: usize) -> usize;
fn nk_dots_pack_i8(b: *const i8, width: usize, depth: usize, b_stride: usize, packed: *mut u8);
fn nk_dots_packed_i8(
a: *const i8,
packed: *const u8,
c: *mut i32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_u8(width: usize, depth: usize) -> usize;
fn nk_dots_pack_u8(b: *const u8, width: usize, depth: usize, b_stride: usize, packed: *mut u8);
fn nk_dots_packed_u8(
a: *const u8,
packed: *const u8,
c: *mut u32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_e4m3(width: usize, depth: usize) -> usize;
fn nk_dots_pack_e4m3(
b: *const u8,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_e4m3(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_e5m2(width: usize, depth: usize) -> usize;
fn nk_dots_pack_e5m2(
b: *const u8,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_e5m2(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_e2m3(width: usize, depth: usize) -> usize;
fn nk_dots_pack_e2m3(
b: *const u8,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_e2m3(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_e3m2(width: usize, depth: usize) -> usize;
fn nk_dots_pack_e3m2(
b: *const u8,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
fn nk_dots_packed_e3m2(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_u4(width: usize, depth: usize) -> usize;
fn nk_dots_pack_u4(b: *const u8, width: usize, depth: usize, b_stride: usize, packed: *mut u8);
fn nk_dots_packed_u4(
a: *const u8,
packed: *const u8,
c: *mut u32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_packed_size_i4(width: usize, depth: usize) -> usize;
fn nk_dots_pack_i4(b: *const u8, width: usize, depth: usize, b_stride: usize, packed: *mut u8);
fn nk_dots_packed_i4(
a: *const u8,
packed: *const u8,
c: *mut i32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
// --- Dots, symmetric variant: a single set of vectors plus a `row_start`/`row_count` pair ---
fn nk_dots_symmetric_f32(
vectors: *const f32,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f64,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_f64(
vectors: *const f64,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f64,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_f16(
vectors: *const u16,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_bf16(
vectors: *const u16,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_i8(
vectors: *const i8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut i32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_u8(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut u32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_e4m3(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_e5m2(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_e2m3(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_e3m2(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_u4(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut u32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_dots_symmetric_i4(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut i32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
// --- 1-bit (`u1`) kernels: dots, hammings and jaccards ---
// The hammings/jaccards packed kernels take a `q_packed` buffer but declare no pack function
// of their own. NOTE(review): presumably they consume the output of `nk_dots_pack_u1` —
// confirm.
fn nk_dots_packed_size_u1(width: usize, depth: usize) -> usize;
fn nk_dots_pack_u1(
q: *const u8,
width: usize,
depth: usize,
q_stride: usize,
q_packed: *mut u8,
);
fn nk_dots_packed_u1(
a: *const u8,
packed: *const u8,
c: *mut u32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_dots_symmetric_u1(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut u32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_hammings_packed_u1(
a: *const u8,
q_packed: *const u8,
result: *mut u32,
height: usize,
width: usize,
depth: usize,
v_stride: usize,
r_stride: usize,
);
fn nk_hammings_symmetric_u1(
vectors: *const u8,
n_vectors: usize,
d: usize,
stride: usize,
result: *mut u32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_jaccards_packed_u1(
v: *const u8,
q_packed: *const u8,
result: *mut f32,
height: usize,
width: usize,
depth: usize,
v_stride: usize,
r_stride: usize,
);
fn nk_jaccards_symmetric_u1(
vectors: *const u8,
n_vectors: usize,
d: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
// --- Angulars: packed and symmetric variants per element type ---
// Results are `f64` for `f32`/`f64` inputs and `f32` for every other element type.
fn nk_angulars_packed_f32(
a: *const f32,
packed: *const u8,
c: *mut f64,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_f32(
vectors: *const f32,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f64,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_f64(
a: *const f64,
packed: *const u8,
c: *mut f64,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_f64(
vectors: *const f64,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f64,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_f16(
a: *const u16,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_f16(
vectors: *const u16,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_bf16(
a: *const u16,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_bf16(
vectors: *const u16,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_i8(
a: *const i8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_i8(
vectors: *const i8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_u8(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_u8(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_e4m3(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_e4m3(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_e5m2(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_e5m2(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_e2m3(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_e2m3(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_e3m2(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_e3m2(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_i4(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_i4(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_angulars_packed_u4(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_angulars_symmetric_u4(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
// --- Euclideans: same signatures and result types as the angulars family ---
fn nk_euclideans_packed_f32(
a: *const f32,
packed: *const u8,
c: *mut f64,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_f32(
vectors: *const f32,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f64,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_f64(
a: *const f64,
packed: *const u8,
c: *mut f64,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_f64(
vectors: *const f64,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f64,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_f16(
a: *const u16,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_f16(
vectors: *const u16,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_bf16(
a: *const u16,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_bf16(
vectors: *const u16,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_i8(
a: *const i8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_i8(
vectors: *const i8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_u8(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_u8(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_e4m3(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_e4m3(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_e5m2(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_e5m2(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_e2m3(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_e2m3(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_e3m2(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_e3m2(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_i4(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_i4(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
fn nk_euclideans_packed_u4(
a: *const u8,
packed: *const u8,
c: *mut f32,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
fn nk_euclideans_symmetric_u4(
vectors: *const u8,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut f32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
}
/// Sealing module for the kernel traits.
///
/// `Sealed` is a supertrait of `Dots`; because this module is not `pub`, code outside the
/// crate cannot name `Sealed` and therefore cannot implement `Dots` for new types.
mod private {
    pub trait Sealed {}

    // Emits one empty `Sealed` impl per listed type.
    macro_rules! seal {
        ($($element:ty),+ $(,)?) => {
            $(impl Sealed for $element {})+
        };
    }

    seal!(f32, f64);
    seal!(super::f16, super::bf16);
    seal!(super::e4m3, super::e5m2, super::e2m3, super::e3m2);
    seal!(i8, u8);
    seal!(super::u4x2, super::i4x2, super::u1x8);
}
/// Element types for which the native library exposes batched dot-product kernels.
///
/// The trait is sealed through `private::Sealed`. Each method corresponds to a per-type
/// `nk_dots_*` symbol declared in this file's `extern "C"` block.
///
/// NOTE(review): the exact meaning of `width`/`height`/`depth` and the unit of the stride
/// arguments (bytes vs. elements) are defined by the C library and are not visible here —
/// confirm against its headers.
pub trait Dots: StorageElement + private::Sealed {
/// Element type written through the `c` / `result` output pointers.
type Accumulator: StorageElement;
/// Size query for a `width` x `depth` operand; presumably the number of bytes the `packed`
/// buffer given to `dots_pack` must hold — TODO confirm.
fn dots_packed_size(width: usize, depth: usize) -> usize;
/// Hands the operand behind `b` and the destination buffer `packed` to the native pack
/// routine.
///
/// # Safety
/// Both pointers cross the FFI boundary unchecked; the caller must guarantee they are valid
/// for everything the native routine reads and writes.
unsafe fn dots_pack(
b: *const Self,
width: usize,
depth: usize,
b_stride: usize,
packed: *mut u8,
);
/// Runs the packed dot-product kernel on `a` and a previously packed operand, writing
/// `Self::Accumulator` values through `c`.
///
/// # Safety
/// `a`, `packed` and `c` cross the FFI boundary unchecked; the caller must guarantee they
/// are valid for the extents implied by `height`, `width`, `depth` and the strides.
unsafe fn dots_packed(
a: *const Self,
packed: *const u8,
c: *mut Self::Accumulator,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
/// Runs the symmetric dot-product kernel over a single set of `n_vectors` vectors.
///
/// NOTE(review): `row_start`/`row_count` presumably select the slice of output rows to
/// compute — confirm against the C headers.
///
/// # Safety
/// `vectors` and `result` cross the FFI boundary unchecked; the caller must guarantee they
/// are valid for the extents the native kernel accesses.
unsafe fn dots_symmetric(
vectors: *const Self,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut Self::Accumulator,
result_stride: usize,
row_start: usize,
row_count: usize,
);
}
/// `f32` dot-product kernels; results are written as `f64`.
///
/// Every method calls the matching `nk_dots_*_f32` symbol with its arguments passed through
/// untouched. For the `unsafe` methods the caller is responsible for the validity of every
/// pointer and extent.
impl Dots for f32 {
    type Accumulator = f64;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_f32(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_pack_f32(b, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_packed_f32(a, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe {
            nk_dots_symmetric_f32(
                vectors,
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `f64` dot-product kernels; results are written as `f64`.
///
/// Every method calls the matching `nk_dots_*_f64` symbol with its arguments passed through
/// untouched. For the `unsafe` methods the caller is responsible for the validity of every
/// pointer and extent.
impl Dots for f64 {
    type Accumulator = f64;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_f64(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_pack_f64(b, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_packed_f64(a, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe {
            nk_dots_symmetric_f64(
                vectors,
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `f16` dot-product kernels; results are written as `f32`.
///
/// The native symbols declare half-precision data as `*const u16`, so each `*const Self` is
/// pointer-cast before the call; all other arguments are forwarded untouched. For the
/// `unsafe` methods the caller is responsible for the validity of every pointer and extent.
impl Dots for f16 {
    type Accumulator = f32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_f16(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u16>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_f16(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u16>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_f16(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_f16(
                vectors.cast::<u16>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `bf16` dot-product kernels; results are written as `f32`.
///
/// The native symbols declare `bf16` data as `*const u16`, so each `*const Self` is
/// pointer-cast before the call; all other arguments are forwarded untouched. For the
/// `unsafe` methods the caller is responsible for the validity of every pointer and extent.
impl Dots for bf16 {
    type Accumulator = f32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_bf16(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u16>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_bf16(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u16>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_bf16(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_bf16(
                vectors.cast::<u16>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `i8` dot-product kernels; results are written as `i32`.
///
/// Every method calls the matching `nk_dots_*_i8` symbol with its arguments passed through
/// untouched. For the `unsafe` methods the caller is responsible for the validity of every
/// pointer and extent.
impl Dots for i8 {
    type Accumulator = i32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_i8(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_pack_i8(b, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_packed_i8(a, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe {
            nk_dots_symmetric_i8(
                vectors,
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `u8` dot-product kernels; results are written as `u32`.
///
/// Every method calls the matching `nk_dots_*_u8` symbol with its arguments passed through
/// untouched. For the `unsafe` methods the caller is responsible for the validity of every
/// pointer and extent.
impl Dots for u8 {
    type Accumulator = u32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_u8(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_pack_u8(b, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe { nk_dots_packed_u8(a, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; arguments are forwarded as-is.
        unsafe {
            nk_dots_symmetric_u8(
                vectors,
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `e4m3` dot-product kernels; results are written as `f32`.
///
/// The native symbols declare `e4m3` data as `*const u8`, so each `*const Self` is
/// pointer-cast before the call; all other arguments are forwarded untouched. For the
/// `unsafe` methods the caller is responsible for the validity of every pointer and extent.
impl Dots for e4m3 {
    type Accumulator = f32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_e4m3(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_e4m3(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_e4m3(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_e4m3(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `e5m2` dot-product kernels; results are written as `f32`.
///
/// The native symbols declare `e5m2` data as `*const u8`, so each `*const Self` is
/// pointer-cast before the call; all other arguments are forwarded untouched. For the
/// `unsafe` methods the caller is responsible for the validity of every pointer and extent.
impl Dots for e5m2 {
    type Accumulator = f32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_e5m2(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_e5m2(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_e5m2(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_e5m2(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `e2m3` dot-product kernels; results are written as `f32`.
///
/// The native symbols declare `e2m3` data as `*const u8`, so each `*const Self` is
/// pointer-cast before the call; all other arguments are forwarded untouched. For the
/// `unsafe` methods the caller is responsible for the validity of every pointer and extent.
impl Dots for e2m3 {
    type Accumulator = f32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_e2m3(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_e2m3(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_e2m3(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_e2m3(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `e3m2` dot-product kernels; results are written as `f32`.
///
/// The native symbols declare `e3m2` data as `*const u8`, so each `*const Self` is
/// pointer-cast before the call; all other arguments are forwarded untouched. For the
/// `unsafe` methods the caller is responsible for the validity of every pointer and extent.
impl Dots for e3m2 {
    type Accumulator = f32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_e3m2(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_e3m2(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_e3m2(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_e3m2(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `u4x2` dot-product kernels, backed by the `nk_dots_*_u4` symbols; results are `u32`.
///
/// The native symbols declare the data as `*const u8`, so each `*const Self` is pointer-cast
/// before the call; all other arguments are forwarded untouched. For the `unsafe` methods
/// the caller is responsible for the validity of every pointer and extent.
impl Dots for u4x2 {
    type Accumulator = u32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_u4(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_u4(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_u4(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_u4(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `i4x2` dot-product kernels, backed by the `nk_dots_*_i4` symbols; results are `i32`.
///
/// The native symbols declare the data as `*const u8`, so each `*const Self` is pointer-cast
/// before the call; all other arguments are forwarded untouched. For the `unsafe` methods
/// the caller is responsible for the validity of every pointer and extent.
impl Dots for i4x2 {
    type Accumulator = i32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_i4(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_i4(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_i4(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_i4(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// `u1x8` dot-product kernels, backed by the `nk_dots_*_u1` symbols; results are `u32`.
///
/// The native symbols declare the data as `*const u8`, so each `*const Self` is pointer-cast
/// before the call; all other arguments are forwarded untouched. For the `unsafe` methods
/// the caller is responsible for the validity of every pointer and extent.
impl Dots for u1x8 {
    type Accumulator = u32;

    fn dots_packed_size(width: usize, depth: usize) -> usize {
        // SAFETY: the call receives two integers by value and no pointers.
        unsafe { nk_dots_packed_size_u1(width, depth) }
    }

    unsafe fn dots_pack(
        b: *const Self,
        width: usize,
        depth: usize,
        b_stride: usize,
        packed: *mut u8,
    ) {
        let b_raw = b.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_pack_u1(b_raw, width, depth, b_stride, packed) }
    }

    unsafe fn dots_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut Self::Accumulator,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_raw = a.cast::<u8>();
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe { nk_dots_packed_u1(a_raw, packed, c, height, width, depth, a_stride, c_stride) }
    }

    unsafe fn dots_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::Accumulator,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_dots_symmetric_u1(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// Element types with `nk_hammings_*` kernels; results are written as `u32`.
///
/// Requires `Dots`. In the part of the file visible alongside this trait, `u1x8` is the
/// only implementor.
pub trait Hammings: Dots {
/// Packed variant: runs the kernel on `a` against the pre-packed operand `q_packed`,
/// writing `u32` values through `result`.
///
/// NOTE(review): `q_packed` is presumably a buffer produced by `Dots::dots_pack` — confirm.
///
/// # Safety
/// All pointers cross the FFI boundary unchecked; the caller must guarantee they are valid
/// for the extents the native kernel accesses.
unsafe fn hammings_packed(
a: *const Self,
q_packed: *const u8,
result: *mut u32,
height: usize,
width: usize,
depth: usize,
v_stride: usize,
r_stride: usize,
);
/// Symmetric variant over a single set of `n_vectors` vectors.
///
/// # Safety
/// `vectors` and `result` cross the FFI boundary unchecked; the caller must guarantee they
/// are valid for the extents the native kernel accesses.
unsafe fn hammings_symmetric(
vectors: *const Self,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut u32,
result_stride: usize,
row_start: usize,
row_count: usize,
);
}
/// `u1x8` bindings for the `nk_hammings_*_u1` symbols.
///
/// `*const Self` is pointer-cast to `*const u8` as the native signatures require; every
/// other argument is forwarded untouched. The caller is responsible for the validity of
/// every pointer and extent.
impl Hammings for u1x8 {
    unsafe fn hammings_packed(
        a: *const Self,
        q_packed: *const u8,
        result: *mut u32,
        height: usize,
        width: usize,
        depth: usize,
        v_stride: usize,
        r_stride: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_hammings_packed_u1(
                a.cast::<u8>(),
                q_packed,
                result,
                height,
                width,
                depth,
                v_stride,
                r_stride,
            )
        }
    }

    unsafe fn hammings_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut u32,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_hammings_symmetric_u1(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// Element types with `nk_jaccards_*` kernels.
///
/// Requires `Dots`. In the part of the file visible alongside this trait, `u1x8` is the
/// only implementor.
pub trait Jaccards: Dots {
/// Element type written through the `result` pointers.
type JaccardResult: StorageElement;
/// Packed variant: runs the kernel on `a` against the pre-packed operand `q_packed`.
///
/// NOTE(review): `q_packed` is presumably a buffer produced by `Dots::dots_pack` — confirm.
///
/// # Safety
/// All pointers cross the FFI boundary unchecked; the caller must guarantee they are valid
/// for the extents the native kernel accesses.
unsafe fn jaccards_packed(
a: *const Self,
q_packed: *const u8,
result: *mut Self::JaccardResult,
height: usize,
width: usize,
depth: usize,
v_stride: usize,
r_stride: usize,
);
/// Symmetric variant over a single set of `n_vectors` vectors.
///
/// # Safety
/// `vectors` and `result` cross the FFI boundary unchecked; the caller must guarantee they
/// are valid for the extents the native kernel accesses.
unsafe fn jaccards_symmetric(
vectors: *const Self,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut Self::JaccardResult,
result_stride: usize,
row_start: usize,
row_count: usize,
);
}
/// `u1x8` bindings for the `nk_jaccards_*_u1` symbols; results are written as `f32`.
///
/// `*const Self` is pointer-cast to `*const u8` as the native signatures require; every
/// other argument is forwarded untouched. The caller is responsible for the validity of
/// every pointer and extent.
impl Jaccards for u1x8 {
    type JaccardResult = f32;

    unsafe fn jaccards_packed(
        a: *const Self,
        q_packed: *const u8,
        result: *mut Self::JaccardResult,
        height: usize,
        width: usize,
        depth: usize,
        v_stride: usize,
        r_stride: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_jaccards_packed_u1(
                a.cast::<u8>(),
                q_packed,
                result,
                height,
                width,
                depth,
                v_stride,
                r_stride,
            )
        }
    }

    unsafe fn jaccards_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut Self::JaccardResult,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        // SAFETY: the caller upholds this function's contract; only the pointer type changes.
        unsafe {
            nk_jaccards_symmetric_u1(
                vectors.cast::<u8>(),
                n_vectors,
                depth,
                stride,
                result,
                result_stride,
                row_start,
                row_count,
            )
        }
    }
}
/// Element types with `nk_angulars_*` kernels. Requires `Dots`.
///
/// `Euclideans` declares an associated type with the same name (`SpatialResult`). Where both
/// traits bound one type parameter, write `<T as Angulars>::SpatialResult` to disambiguate.
pub trait Angulars: Dots {
/// Element type written through the `c` / `result` output pointers.
type SpatialResult: StorageElement;
/// Packed variant: runs the kernel on `a` against the pre-packed operand `packed`.
///
/// NOTE(review): `packed` is presumably a buffer produced by `Dots::dots_pack` — confirm.
///
/// # Safety
/// All pointers cross the FFI boundary unchecked; the caller must guarantee they are valid
/// for the extents the native kernel accesses.
unsafe fn angulars_packed(
a: *const Self,
packed: *const u8,
c: *mut Self::SpatialResult,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
/// Symmetric variant over a single set of `n_vectors` vectors.
///
/// # Safety
/// `vectors` and `result` cross the FFI boundary unchecked; the caller must guarantee they
/// are valid for the extents the native kernel accesses.
unsafe fn angulars_symmetric(
vectors: *const Self,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut Self::SpatialResult,
result_stride: usize,
row_start: usize,
row_count: usize,
);
}
/// Element types with `nk_euclideans_*` kernels. Requires `Dots`.
///
/// `Angulars` declares an associated type with the same name (`SpatialResult`). Where both
/// traits bound one type parameter, write `<T as Euclideans>::SpatialResult` to disambiguate.
pub trait Euclideans: Dots {
/// Element type written through the `c` / `result` output pointers.
type SpatialResult: StorageElement;
/// Packed variant: runs the kernel on `a` against the pre-packed operand `packed`.
///
/// NOTE(review): `packed` is presumably a buffer produced by `Dots::dots_pack` — confirm.
///
/// # Safety
/// All pointers cross the FFI boundary unchecked; the caller must guarantee they are valid
/// for the extents the native kernel accesses.
unsafe fn euclideans_packed(
a: *const Self,
packed: *const u8,
c: *mut Self::SpatialResult,
height: usize,
width: usize,
depth: usize,
a_stride: usize,
c_stride: usize,
);
/// Symmetric variant over a single set of `n_vectors` vectors.
///
/// # Safety
/// `vectors` and `result` cross the FFI boundary unchecked; the caller must guarantee they
/// are valid for the extents the native kernel accesses.
unsafe fn euclideans_symmetric(
vectors: *const Self,
n_vectors: usize,
depth: usize,
stride: usize,
result: *mut Self::SpatialResult,
result_stride: usize,
row_start: usize,
row_count: usize,
);
}
/// Generates the `Angulars` and `Euclideans` impls for one element type.
///
/// Arguments, in order:
/// 1. the element type receiving the impls,
/// 2. the `SpatialResult` type for both traits,
/// 3. the FFI pointer type (matched but never expanded; kept so invocations stay valid),
/// 4. a callable turning `*const Self` into the pointer type the native symbols expect,
/// 5. the angulars packed and symmetric symbols,
/// 6. the euclideans packed and symmetric symbols.
///
/// Each generated method applies the conversion to its element pointer and forwards every
/// other argument untouched.
macro_rules! impl_spatial_traits {
    (
        $element:ty, $output:ty, $ffi_ptr:ty, $to_ffi:expr,
        $angular_packed:ident, $angular_symmetric:ident,
        $euclidean_packed:ident, $euclidean_symmetric:ident
    ) => {
        impl Angulars for $element {
            type SpatialResult = $output;

            unsafe fn angulars_packed(
                a: *const Self,
                packed: *const u8,
                c: *mut Self::SpatialResult,
                height: usize,
                width: usize,
                depth: usize,
                a_stride: usize,
                c_stride: usize,
            ) {
                // SAFETY: the caller upholds this function's contract.
                unsafe {
                    $angular_packed(
                        $to_ffi(a),
                        packed,
                        c,
                        height,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    )
                }
            }

            unsafe fn angulars_symmetric(
                vectors: *const Self,
                n_vectors: usize,
                depth: usize,
                stride: usize,
                result: *mut Self::SpatialResult,
                result_stride: usize,
                row_start: usize,
                row_count: usize,
            ) {
                // SAFETY: the caller upholds this function's contract.
                unsafe {
                    $angular_symmetric(
                        $to_ffi(vectors),
                        n_vectors,
                        depth,
                        stride,
                        result,
                        result_stride,
                        row_start,
                        row_count,
                    )
                }
            }
        }

        impl Euclideans for $element {
            type SpatialResult = $output;

            unsafe fn euclideans_packed(
                a: *const Self,
                packed: *const u8,
                c: *mut Self::SpatialResult,
                height: usize,
                width: usize,
                depth: usize,
                a_stride: usize,
                c_stride: usize,
            ) {
                // SAFETY: the caller upholds this function's contract.
                unsafe {
                    $euclidean_packed(
                        $to_ffi(a),
                        packed,
                        c,
                        height,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    )
                }
            }

            unsafe fn euclideans_symmetric(
                vectors: *const Self,
                n_vectors: usize,
                depth: usize,
                stride: usize,
                result: *mut Self::SpatialResult,
                result_stride: usize,
                row_start: usize,
                row_count: usize,
            ) {
                // SAFETY: the caller upholds this function's contract.
                unsafe {
                    $euclidean_symmetric(
                        $to_ffi(vectors),
                        n_vectors,
                        depth,
                        stride,
                        result,
                        result_stride,
                        row_start,
                        row_count,
                    )
                }
            }
        }
    };
}
/// Pointer pass-through for `f32`, whose FFI pointer type is already `*const f32`.
#[inline(always)]
fn identity_f32(ptr: *const f32) -> *const f32 {
    ptr
}
/// Returns `p` unchanged; passed as the `$cast` argument of `impl_spatial_traits!` for `f64`.
#[inline(always)]
fn identity_f64(p: *const f64) -> *const f64 {
p
}
/// Returns `p` unchanged; passed as the `$cast` argument of `impl_spatial_traits!` for `i8`.
#[inline(always)]
fn identity_i8(p: *const i8) -> *const i8 {
p
}
/// Returns `p` unchanged; passed as the `$cast` argument of `impl_spatial_traits!` for `u8`.
#[inline(always)]
fn identity_u8(p: *const u8) -> *const u8 {
p
}
/// Reinterprets a pointer to any `Scalar` as a `*const u16` with the same address.
///
/// Only the pointee type changes; nothing is read or written.
#[inline(always)]
fn cast_to_u16<Scalar>(p: *const Scalar) -> *const u16 {
    p.cast::<u16>()
}
/// Reinterprets a pointer to any `Scalar` as a `*const u8` with the same address.
///
/// Only the pointee type changes; nothing is read or written.
#[inline(always)]
fn cast_to_u8<Scalar>(p: *const Scalar) -> *const u8 {
    p.cast::<u8>()
}
// `f32` inputs, `f64` results; the input pointer is forwarded unchanged.
impl_spatial_traits!(
f32,
f64,
*const f32,
identity_f32,
nk_angulars_packed_f32,
nk_angulars_symmetric_f32,
nk_euclideans_packed_f32,
nk_euclideans_symmetric_f32
);
// `f64` inputs, `f64` results; the input pointer is forwarded unchanged.
impl_spatial_traits!(
f64,
f64,
*const f64,
identity_f64,
nk_angulars_packed_f64,
nk_angulars_symmetric_f64,
nk_euclideans_packed_f64,
nk_euclideans_symmetric_f64
);
// `f16` inputs, `f32` results; the input pointer is reinterpreted as `*const u16`.
impl_spatial_traits!(
f16,
f32,
*const u16,
cast_to_u16,
nk_angulars_packed_f16,
nk_angulars_symmetric_f16,
nk_euclideans_packed_f16,
nk_euclideans_symmetric_f16
);
// `bf16` inputs, `f32` results; the input pointer is reinterpreted as `*const u16`.
impl_spatial_traits!(
bf16,
f32,
*const u16,
cast_to_u16,
nk_angulars_packed_bf16,
nk_angulars_symmetric_bf16,
nk_euclideans_packed_bf16,
nk_euclideans_symmetric_bf16
);
// `i8` inputs, `f32` results; the input pointer is forwarded unchanged.
impl_spatial_traits!(
i8,
f32,
*const i8,
identity_i8,
nk_angulars_packed_i8,
nk_angulars_symmetric_i8,
nk_euclideans_packed_i8,
nk_euclideans_symmetric_i8
);
// `u8` inputs, `f32` results; the input pointer is forwarded unchanged.
impl_spatial_traits!(
u8,
f32,
*const u8,
identity_u8,
nk_angulars_packed_u8,
nk_angulars_symmetric_u8,
nk_euclideans_packed_u8,
nk_euclideans_symmetric_u8
);
// `e4m3` inputs, `f32` results; the input pointer is reinterpreted as `*const u8`.
impl_spatial_traits!(
e4m3,
f32,
*const u8,
cast_to_u8,
nk_angulars_packed_e4m3,
nk_angulars_symmetric_e4m3,
nk_euclideans_packed_e4m3,
nk_euclideans_symmetric_e4m3
);
// `e5m2` inputs, `f32` results; the input pointer is reinterpreted as `*const u8`.
impl_spatial_traits!(
e5m2,
f32,
*const u8,
cast_to_u8,
nk_angulars_packed_e5m2,
nk_angulars_symmetric_e5m2,
nk_euclideans_packed_e5m2,
nk_euclideans_symmetric_e5m2
);
// `e2m3` inputs, `f32` results; the input pointer is reinterpreted as `*const u8`.
impl_spatial_traits!(
e2m3,
f32,
*const u8,
cast_to_u8,
nk_angulars_packed_e2m3,
nk_angulars_symmetric_e2m3,
nk_euclideans_packed_e2m3,
nk_euclideans_symmetric_e2m3
);
// `e3m2` inputs, `f32` results; the input pointer is reinterpreted as `*const u8`.
impl_spatial_traits!(
e3m2,
f32,
*const u8,
cast_to_u8,
nk_angulars_packed_e3m2,
nk_angulars_symmetric_e3m2,
nk_euclideans_packed_e3m2,
nk_euclideans_symmetric_e3m2
);
/// `Angulars` for `u4x2`, forwarded to the `nk_angulars_*_u4` functions.
///
/// Those functions take byte pointers, so each `*const u4x2` is reinterpreted as `*const u8`.
impl Angulars for u4x2 {
    type SpatialResult = f32;

    /// Forwards to `nk_angulars_packed_u4` with `a` viewed as bytes.
    unsafe fn angulars_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut f32,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_bytes = a.cast::<u8>();
        nk_angulars_packed_u4(a_bytes, packed, c, height, width, depth, a_stride, c_stride)
    }

    /// Forwards to `nk_angulars_symmetric_u4` with `vectors` viewed as bytes.
    unsafe fn angulars_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut f32,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        let vector_bytes = vectors.cast::<u8>();
        nk_angulars_symmetric_u4(
            vector_bytes,
            n_vectors,
            depth,
            stride,
            result,
            result_stride,
            row_start,
            row_count,
        )
    }
}
/// `Euclideans` for `u4x2`, forwarded to the `nk_euclideans_*_u4` functions.
///
/// Those functions take byte pointers, so each `*const u4x2` is reinterpreted as `*const u8`.
impl Euclideans for u4x2 {
    type SpatialResult = f32;

    /// Forwards to `nk_euclideans_packed_u4` with `a` viewed as bytes.
    unsafe fn euclideans_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut f32,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_bytes = a.cast::<u8>();
        nk_euclideans_packed_u4(a_bytes, packed, c, height, width, depth, a_stride, c_stride)
    }

    /// Forwards to `nk_euclideans_symmetric_u4` with `vectors` viewed as bytes.
    unsafe fn euclideans_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut f32,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        let vector_bytes = vectors.cast::<u8>();
        nk_euclideans_symmetric_u4(
            vector_bytes,
            n_vectors,
            depth,
            stride,
            result,
            result_stride,
            row_start,
            row_count,
        )
    }
}
/// `Angulars` for `i4x2`, forwarded to the `nk_angulars_*_i4` functions.
///
/// Those functions take byte pointers, so each `*const i4x2` is reinterpreted as `*const u8`.
impl Angulars for i4x2 {
    type SpatialResult = f32;

    /// Forwards to `nk_angulars_packed_i4` with `a` viewed as bytes.
    unsafe fn angulars_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut f32,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_bytes = a.cast::<u8>();
        nk_angulars_packed_i4(a_bytes, packed, c, height, width, depth, a_stride, c_stride)
    }

    /// Forwards to `nk_angulars_symmetric_i4` with `vectors` viewed as bytes.
    unsafe fn angulars_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut f32,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        let vector_bytes = vectors.cast::<u8>();
        nk_angulars_symmetric_i4(
            vector_bytes,
            n_vectors,
            depth,
            stride,
            result,
            result_stride,
            row_start,
            row_count,
        )
    }
}
/// `Euclideans` for `i4x2`, forwarded to the `nk_euclideans_*_i4` functions.
///
/// Those functions take byte pointers, so each `*const i4x2` is reinterpreted as `*const u8`.
impl Euclideans for i4x2 {
    type SpatialResult = f32;

    /// Forwards to `nk_euclideans_packed_i4` with `a` viewed as bytes.
    unsafe fn euclideans_packed(
        a: *const Self,
        packed: *const u8,
        c: *mut f32,
        height: usize,
        width: usize,
        depth: usize,
        a_stride: usize,
        c_stride: usize,
    ) {
        let a_bytes = a.cast::<u8>();
        nk_euclideans_packed_i4(a_bytes, packed, c, height, width, depth, a_stride, c_stride)
    }

    /// Forwards to `nk_euclideans_symmetric_i4` with `vectors` viewed as bytes.
    unsafe fn euclideans_symmetric(
        vectors: *const Self,
        n_vectors: usize,
        depth: usize,
        stride: usize,
        result: *mut f32,
        result_stride: usize,
        row_start: usize,
        row_count: usize,
    ) {
        let vector_bytes = vectors.cast::<u8>();
        nk_euclideans_symmetric_i4(
            vector_bytes,
            n_vectors,
            depth,
            stride,
            result,
            result_stride,
            row_start,
            row_count,
        )
    }
}
/// Owning handle to a matrix stored in the layout produced by `Scalar::dots_pack`.
///
/// The buffer holds `size` bytes obtained from `alloc` with `SIMD_ALIGNMENT` alignment and is
/// released by the `Drop` impl. `Scalar` is tracked at the type level only.
pub struct PackedMatrix<Scalar: Dots, Alloc: Allocator = Global> {
// Start of the packed bytes; a dangling placeholder when `size == 0`.
data: NonNull<u8>,
// Buffer length in bytes, as reported by `Scalar::dots_packed_size(width, depth)`.
size: usize,
// `shape()[0]` of the tensor that was packed.
width: usize,
// `shape()[1]` of the tensor that was packed.
depth: usize,
// Allocator that produced `data` and will be asked to release it.
alloc: Alloc,
// Ties the packed bytes to their element type without storing a `Scalar`.
_marker: PhantomData<Scalar>,
}
// SAFETY (NOTE(review) — confirm): the raw `data` pointer is what suppresses the auto traits.
// The buffer is allocated per value in `try_pack_in`/`try_clone`, and the accessors visible in
// this file (`as_bytes`, `as_ptr`, `dims`, `allocator`) only read through `&self`, so moving
// or sharing the handle should be as safe as moving or sharing `Scalar` and `Alloc`.
unsafe impl<Scalar: Dots + Send, Alloc: Allocator + Send> Send for PackedMatrix<Scalar, Alloc> {}
unsafe impl<Scalar: Dots + Sync, Alloc: Allocator + Sync> Sync for PackedMatrix<Scalar, Alloc> {}
impl<Scalar: Dots, Alloc: Allocator> Drop for PackedMatrix<Scalar, Alloc> {
    /// Returns the packed buffer to the allocator it came from.
    fn drop(&mut self) {
        // Zero-sized matrices carry a dangling placeholder that was never allocated.
        if self.size == 0 {
            return;
        }
        // SAFETY: a non-zero `size` means `data` was allocated with exactly this size and
        // `SIMD_ALIGNMENT`, a layout that `Layout::from_size_align` accepted at that time.
        unsafe {
            let layout =
                alloc::alloc::Layout::from_size_align_unchecked(self.size, SIMD_ALIGNMENT);
            self.alloc.deallocate(self.data, layout);
        }
    }
}
impl<Scalar: Dots, Alloc: Allocator + Clone> PackedMatrix<Scalar, Alloc> {
    /// Makes a deep copy of the packed bytes, reporting allocation failure as an error.
    ///
    /// # Errors
    /// Returns [`TensorError::AllocationFailed`] when the layout is invalid or the allocator
    /// cannot provide the buffer.
    pub fn try_clone(&self) -> Result<Self, TensorError> {
        let data = if self.size == 0 {
            // Nothing to copy; mirror the dangling placeholder.
            NonNull::dangling()
        } else {
            let layout = alloc::alloc::Layout::from_size_align(self.size, SIMD_ALIGNMENT)
                .map_err(|_| TensorError::AllocationFailed)?;
            let fresh = self
                .alloc
                .allocate(layout)
                .ok_or(TensorError::AllocationFailed)?;
            // SAFETY: both buffers are `self.size` bytes long and `fresh` is a new
            // allocation, so the ranges cannot overlap.
            unsafe {
                core::ptr::copy_nonoverlapping(self.data.as_ptr(), fresh.as_ptr(), self.size);
            }
            fresh
        };
        Ok(Self {
            data,
            size: self.size,
            width: self.width,
            depth: self.depth,
            alloc: self.alloc.clone(),
            _marker: PhantomData,
        })
    }
}
/// Infallible-looking clone built on `try_clone`.
impl<Scalar: Dots, Alloc: Allocator + Clone> Clone for PackedMatrix<Scalar, Alloc> {
/// Deep-copies the packed buffer.
///
/// # Panics
/// Panics if `try_clone` reports an allocation failure.
fn clone(&self) -> Self {
self.try_clone()
.expect("PackedMatrix clone allocation failed")
}
}
impl<Scalar: Dots, Alloc: Allocator> PackedMatrix<Scalar, Alloc> {
    /// Packs the 2-D tensor `b` into a buffer obtained from `alloc`.
    ///
    /// `b.shape()[0]` becomes the packed `width` and `b.shape()[1]` the packed `depth`.
    /// The buffer is zero-filled before `Scalar::dots_pack` writes into it.
    ///
    /// # Errors
    /// - [`TensorError::DimensionMismatch`] when `b` is not 2-D.
    /// - [`TensorError::NonContiguousRows`] when there is data to pack but the rows of `b`
    ///   are not contiguous.
    /// - [`TensorError::AllocationFailed`] when the buffer cannot be allocated.
    pub fn try_pack_in<PackedAlloc: Allocator, const MAX_RANK: usize>(
        b: &Tensor<Scalar, PackedAlloc, MAX_RANK>,
        alloc: Alloc,
    ) -> Result<Self, TensorError> {
        if b.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: b.ndim(),
            });
        }
        let (width, depth) = (b.shape()[0], b.shape()[1]);
        let size = Scalar::dots_packed_size(width, depth);
        let data = if size == 0 {
            NonNull::dangling()
        } else {
            // The packing kernel only receives a row stride, so it cannot describe rows
            // whose elements are not adjacent. Every other kernel entry point in this file
            // rejects such input; do the same here instead of packing garbage.
            if !b.has_contiguous_rows() {
                return Err(TensorError::NonContiguousRows);
            }
            let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
                .map_err(|_| TensorError::AllocationFailed)?;
            let ptr = alloc
                .allocate(layout)
                .ok_or(TensorError::AllocationFailed)?;
            // SAFETY: `ptr` is a fresh allocation of `size` bytes; `b` is a live 2-D tensor
            // with contiguous rows whose row stride is passed in bytes.
            unsafe {
                core::ptr::write_bytes(ptr.as_ptr(), 0, size);
                Scalar::dots_pack(
                    b.as_ptr(),
                    width,
                    depth,
                    b.stride_bytes(0) as usize,
                    ptr.as_ptr(),
                );
            }
            ptr
        };
        Ok(Self {
            data,
            size,
            width,
            depth,
            alloc,
            _marker: PhantomData,
        })
    }

    /// Packs the transpose of the 2-D tensor `b`.
    ///
    /// The transpose is first materialised with `transpose()` + `to_owned()` and then handed
    /// to [`Self::try_pack_in`].
    ///
    /// # Errors
    /// [`TensorError::DimensionMismatch`] when `b` is not 2-D, plus any error produced by the
    /// transpose, the copy, or `try_pack_in`.
    pub fn try_pack_transposed_in<PackedAlloc: Allocator, const MAX_RANK: usize>(
        b: &Tensor<Scalar, PackedAlloc, MAX_RANK>,
        alloc: Alloc,
    ) -> Result<Self, TensorError> {
        if b.ndim() != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: b.ndim(),
            });
        }
        let transposed = b.transpose()?.to_owned()?;
        Self::try_pack_in(&transposed, alloc)
    }

    /// Allocator that owns the packed buffer.
    pub fn allocator(&self) -> &Alloc {
        &self.alloc
    }

    /// Returns `(width, depth)` as recorded at packing time.
    pub fn dims(&self) -> (usize, usize) {
        (self.width, self.depth)
    }

    /// Read-only view of the packed bytes; empty when nothing was allocated.
    pub fn as_bytes(&self) -> &[u8] {
        // SAFETY: `data` is either a live allocation of `size` bytes or, when `size == 0`,
        // a dangling non-null pointer, which is valid for a zero-length `u8` slice.
        unsafe { core::slice::from_raw_parts(self.data.as_ptr(), self.size) }
    }

    /// Raw pointer to the first packed byte (dangling when the matrix is empty).
    pub fn as_ptr(&self) -> *const u8 {
        self.data.as_ptr()
    }
}
/// Convenience constructors that use the `Global` allocator.
impl<Scalar: Dots> PackedMatrix<Scalar, Global> {
/// Same as `try_pack_in(b, Global)`.
pub fn try_pack<PackedAlloc: Allocator, const MAX_RANK: usize>(
b: &Tensor<Scalar, PackedAlloc, MAX_RANK>,
) -> Result<Self, TensorError> {
Self::try_pack_in(b, Global)
}
/// Same as `try_pack_transposed_in(b, Global)`.
pub fn try_pack_transposed<PackedAlloc: Allocator, const MAX_RANK: usize>(
b: &Tensor<Scalar, PackedAlloc, MAX_RANK>,
) -> Result<Self, TensorError> {
Self::try_pack_transposed_in(b, Global)
}
/// Panicking form of `try_pack`.
///
/// # Panics
/// Panics when `try_pack` returns an error.
pub fn pack<PackedAlloc: Allocator, const MAX_RANK: usize>(
b: &Tensor<Scalar, PackedAlloc, MAX_RANK>,
) -> Self {
Self::try_pack(b).expect("PackedMatrix::pack failed")
}
/// Panicking form of `try_pack_transposed`.
///
/// # Panics
/// Panics when `try_pack_transposed` returns an error.
pub fn pack_transposed<PackedAlloc: Allocator, const MAX_RANK: usize>(
b: &Tensor<Scalar, PackedAlloc, MAX_RANK>,
) -> Self {
Self::try_pack_transposed(b).expect("PackedMatrix::pack_transposed failed")
}
}
/// Checks that `a` can be combined with `packed_b` and returns `(height, width, depth)`.
///
/// `height` and `depth` come from `a`'s shape; `width` comes from the packed matrix.
///
/// # Errors
/// - [`TensorError::DimensionMismatch`] when `a` is not 2-D.
/// - [`TensorError::NonContiguousRows`] when `a`'s rows are not contiguous.
/// - [`TensorError::ShapeMismatch`] on axis 1 when the depths differ.
#[inline]
fn validate_packed_input<Scalar, Alloc, PackedAlloc, const MAX_RANK: usize>(
    a: &Tensor<Scalar, Alloc, MAX_RANK>,
    packed_b: &PackedMatrix<Scalar, PackedAlloc>,
) -> Result<(usize, usize, usize), TensorError>
where
    Scalar: Dots,
    Alloc: Allocator,
    PackedAlloc: Allocator,
{
    let rank = a.ndim();
    if rank != 2 {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: rank,
        });
    }
    if !a.has_contiguous_rows() {
        return Err(TensorError::NonContiguousRows);
    }
    let dims = a.shape();
    let (height, depth) = (dims[0], dims[1]);
    let (width, packed_depth) = packed_b.dims();
    if depth == packed_depth {
        Ok((height, width, depth))
    } else {
        Err(TensorError::ShapeMismatch {
            axis: 1,
            expected: packed_depth,
            got: depth,
        })
    }
}
/// Checks that the output tensor `c` has shape `[height, width]` and contiguous rows.
///
/// # Errors
/// - [`TensorError::ShapeMismatch`] naming axis 0 when the first extent is wrong (or
///   missing), otherwise axis 1; a missing extent is reported as `0`.
/// - [`TensorError::NonContiguousRows`] when the shape matches but rows are not contiguous.
#[inline]
fn validate_matrix_output<R, OutputTensor, const OUTPUT_MAX_RANK: usize>(
    c: &OutputTensor,
    height: usize,
    width: usize,
) -> Result<(), TensorError>
where
    R: StorageElement,
    OutputTensor: TensorRef<R, OUTPUT_MAX_RANK> + ?Sized,
{
    if c.shape() != [height, width] {
        // Decide once which axis to blame instead of re-testing for every field.
        let leading = c.shape().first().copied();
        let (axis, expected, got) = if leading != Some(height) {
            (0, height, leading.unwrap_or(0))
        } else {
            (1, width, c.shape().get(1).copied().unwrap_or(0))
        };
        return Err(TensorError::ShapeMismatch {
            axis,
            expected,
            got,
        });
    }
    if c.has_contiguous_rows() {
        Ok(())
    } else {
        Err(TensorError::NonContiguousRows)
    }
}
/// Checks that `a` is usable as input to a symmetric kernel and returns
/// `(n_vectors, depth)`, i.e. `(a.shape()[0], a.shape()[1])`.
///
/// # Errors
/// - [`TensorError::InvalidShape`] when `a` is not 2-D.
/// - [`TensorError::NonContiguousRows`] when `a`'s rows are not contiguous.
#[inline]
fn validate_symmetric_input<Scalar, InputTensor, const MAX_RANK: usize>(
    a: &InputTensor,
) -> Result<(usize, usize), TensorError>
where
    Scalar: StorageElement,
    InputTensor: TensorRef<Scalar, MAX_RANK> + ?Sized,
{
    if a.ndim() != 2 {
        return Err(TensorError::InvalidShape {
            axis: 0,
            size: a.ndim(),
            reason: "symmetric operations require a 2D tensor",
        });
    }
    // The symmetric kernels are given a single row stride, so elements inside a row must be
    // adjacent. The packed-input validation enforces this; the symmetric path previously let
    // strided rows through, and the kernel would then read the wrong elements.
    if !a.has_contiguous_rows() {
        return Err(TensorError::NonContiguousRows);
    }
    Ok((a.shape()[0], a.shape()[1]))
}
impl<Scalar: Dots, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
{
    /// Runs `Scalar::dots_packed` on this 2-D tensor against `packed_b`, returning a new
    /// `[height, width]` tensor allocated from a clone of this tensor's allocator.
    ///
    /// # Errors
    /// - [`TensorError::DimensionMismatch`] when `self` is not 2-D.
    /// - [`TensorError::NonContiguousRows`] when the rows of `self` are not contiguous.
    /// - [`TensorError::ShapeMismatch`] on axis 1 when the depths differ.
    /// - Any error from allocating the output.
    pub fn try_dots_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Result<Tensor<Scalar::Accumulator, Alloc, MAX_RANK>, TensorError> {
        let rank = self.ndim();
        if rank != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: rank,
            });
        }
        if !self.has_contiguous_rows() {
            return Err(TensorError::NonContiguousRows);
        }
        let dims = self.shape();
        let (height, depth) = (dims[0], dims[1]);
        let (width, packed_depth) = packed_b.dims();
        if depth != packed_depth {
            return Err(TensorError::ShapeMismatch {
                axis: 1,
                expected: packed_depth,
                got: depth,
            });
        }
        // The output starts out filled with the accumulator's default value.
        let mut out = Tensor::try_full_in(
            &[height, width],
            Scalar::Accumulator::default(),
            self.alloc.clone(),
        )?;
        let (a_ptr, b_ptr, out_ptr) = (self.as_ptr(), packed_b.as_ptr(), out.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let out_stride = out.stride_bytes(0) as usize;
        // SAFETY: rank, row contiguity and depth agreement were checked above, and `out` was
        // just allocated with shape `[height, width]`.
        unsafe {
            Scalar::dots_packed(
                a_ptr, b_ptr, out_ptr, height, width, depth, a_stride, out_stride,
            );
        }
        Ok(out)
    }

    /// Panicking form of [`Self::try_dots_packed`].
    ///
    /// # Panics
    /// Panics when `try_dots_packed` returns an error.
    pub fn dots_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Tensor<Scalar::Accumulator, Alloc, MAX_RANK> {
        self.try_dots_packed(packed_b).expect("dots_packed failed")
    }
}
impl<Scalar: Dots, Alloc: Allocator, const MAX_RANK: usize> Tensor<Scalar, Alloc, MAX_RANK> {
    /// Runs `Scalar::dots_packed` on this tensor against `packed_b`, writing into the
    /// caller-provided output `c`.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports for `self`/`packed_b`, or
    /// `validate_matrix_output` reports for `c` (it must be `[height, width]` with
    /// contiguous rows).
    pub fn try_dots_packed_into<PackedAlloc, OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::Accumulator, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::Accumulator, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let (a_ptr, b_ptr, c_ptr) = (self.as_ptr(), packed_b.as_ptr(), c.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // SAFETY: both validations passed, so input, packed operand and output agree on
        // `height`, `width` and `depth`, and both tensors have contiguous rows.
        unsafe {
            Scalar::dots_packed(
                a_ptr, b_ptr, c_ptr, height, width, depth, a_stride, c_stride,
            );
        }
        Ok(())
    }
}
#[cfg(feature = "parallel")]
impl<Scalar: Dots + Clone + Send + Sync, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
where
    Scalar::Accumulator: Send + Sync,
{
    /// Thread-pool version of `try_dots_packed_into`.
    ///
    /// The `height` rows are split into contiguous chunks of `ceil(height / threads)` rows;
    /// each pool thread runs `Scalar::dots_packed` on its own chunk, so chunks never write
    /// to the same output rows. Threads whose chunk starts past the last row do nothing.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` or `validate_matrix_output` reports.
    pub fn try_dots_packed_parallel_into<PackedAlloc, OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::Accumulator, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::Accumulator, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let a_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let c_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let packed_ptr = fork_union::SyncConstPtr::new(packed_b.as_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // `max(1)` keeps the division below well-defined for a pool reporting zero threads.
        let num_threads = pool.threads().max(1);
        let rows_per_thread = height.div_ceil(num_threads);
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let row_start = thread_index * rows_per_thread;
            let row_end = (row_start + rows_per_thread).min(height);
            if row_start < height {
                // SAFETY: `row_start < row_end <= height`, so both offsets stay inside the
                // validated tensors; strides are in bytes, hence the `u8` pointer arithmetic.
                unsafe {
                    let a_row =
                        (a_ptr.as_ptr() as *const u8).add(row_start * a_stride) as *const Scalar;
                    let c_row = (c_ptr.as_ptr() as *mut u8).add(row_start * c_stride)
                        as *mut Scalar::Accumulator;
                    Scalar::dots_packed(
                        a_row,
                        packed_ptr.as_ptr(),
                        c_row,
                        row_end - row_start,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    );
                }
            }
        })
        .join();
        Ok(())
    }

    /// Allocating form of [`Self::try_dots_packed_parallel_into`]; the output lives in the
    /// `Global` allocator and starts out filled with the accumulator's default value.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports, or any error from allocating the output.
    pub fn try_dots_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::Accumulator, Global, MAX_RANK>, TensorError> {
        // Validate before indexing the shape or allocating: a tensor that is not 2-D must
        // come back as an error from this `try_` method, not as an index panic on
        // `shape()[0]`, and a depth mismatch should not cost an output allocation.
        let (height, width, _depth) = validate_packed_input(self, packed_b)?;
        let mut c = Tensor::<Scalar::Accumulator, Global, MAX_RANK>::try_full(
            &[height, width],
            Scalar::Accumulator::default(),
        )?;
        self.try_dots_packed_parallel_into(packed_b, &mut c, pool)?;
        Ok(c)
    }

    /// Panicking form of [`Self::try_dots_packed_parallel`].
    ///
    /// # Panics
    /// Panics when `try_dots_packed_parallel` returns an error.
    pub fn dots_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::Accumulator, Global, MAX_RANK> {
        self.try_dots_packed_parallel(packed_b, pool)
            .expect("parallel dots_packed failed")
    }
}
/// Splits the rows of an `n x n` upper-triangular workload across `num_threads` threads so
/// that each thread gets roughly the same number of cells, and returns
/// `(first_row, row_count)` for `thread_index`.
///
/// Row `r` counts as `n - r` cells, so the total is `n * (n + 1) / 2` and the first `r` rows
/// hold `r * (2n + 1 - r) / 2` cells. Row boundaries are found by solving that quadratic
/// for `r` and rounding up.
///
/// Threads whose share starts at or beyond the end of the work receive `(n, 0)`. Previously
/// a share starting strictly past the end made the discriminant negative, the square root
/// NaN, and the row cast `0`, which handed the whole matrix `(0, n)` to every surplus thread.
/// `num_threads == 0` is treated as one thread instead of dividing by zero.
#[cfg(feature = "std")]
#[inline]
fn compute_thread_rows(thread_index: usize, num_threads: usize, n: usize) -> (usize, usize) {
    let total_work = n * (n + 1) / 2;
    let work_per_thread = total_work.div_ceil(num_threads.max(1));
    let work_start = thread_index.saturating_mul(work_per_thread);
    if work_start >= total_work {
        // Nothing left for this thread (also covers `n == 0`).
        return (n, 0);
    }
    let work_end = work_start.saturating_add(work_per_thread).min(total_work);

    // Smallest row index whose preceding rows cover at least `work` cells.
    let row_covering = |work: usize| -> usize {
        let n_f64 = n as f64;
        let work_f64 = work as f64;
        // Clamp so rounding can never push the discriminant below zero.
        let discriminant = ((2.0 * n_f64 + 1.0).powi(2) - 8.0 * work_f64).max(0.0);
        let row_f64 = (2.0 * n_f64 + 1.0 - discriminant.sqrt()) / 2.0;
        (row_f64.ceil() as usize).min(n)
    };

    let start_row = if work_start == 0 {
        0
    } else {
        row_covering(work_start)
    };
    let end_row = if work_end >= total_work {
        n
    } else {
        row_covering(work_end)
    };
    (start_row, end_row.saturating_sub(start_row))
}
#[cfg(feature = "parallel")]
impl<Scalar: Dots + Clone + Send + Sync, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
where
    Scalar::Accumulator: Send + Sync,
{
    /// Allocating form of [`Self::try_dots_symmetric_parallel_into`]: builds an
    /// `[n_vectors, n_vectors]` output in the `Global` allocator, pre-filled with the
    /// accumulator's default value.
    ///
    /// # Errors
    /// Whatever `validate_symmetric_input` reports, or any error from allocating the output.
    pub fn try_dots_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::Accumulator, Global, MAX_RANK>, TensorError> {
        let (n_vectors, _) = validate_symmetric_input(self)?;
        let side = [n_vectors, n_vectors];
        let mut out = Tensor::<Scalar::Accumulator, Global, MAX_RANK>::try_full(
            &side,
            Scalar::Accumulator::default(),
        )?;
        self.try_dots_symmetric_parallel_into(&mut out, pool)?;
        Ok(out)
    }

    /// Runs `Scalar::dots_symmetric` over the rows of `self` on every pool thread, writing
    /// into the caller-provided `[n_vectors, n_vectors]` output `c`.
    ///
    /// Each thread asks `compute_thread_rows` for its `(row_start, row_count)` range.
    ///
    /// # Errors
    /// Whatever `validate_symmetric_input` or `validate_matrix_output` reports.
    pub fn try_dots_symmetric_parallel_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::Accumulator, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::Accumulator, _, OUTPUT_MAX_RANK>(c, n_vectors, n_vectors)?;
        let threads = pool.threads().max(1);
        let input = fork_union::SyncConstPtr::new(self.as_ptr());
        let output = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let task = pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let (row_start, row_count) = compute_thread_rows(thread_index, threads, n_vectors);
            // SAFETY: input and output were validated above. NOTE(review): assumes the
            // kernel only writes result rows inside `[row_start, row_start + row_count)`.
            unsafe {
                Scalar::dots_symmetric(
                    input.as_ptr(),
                    n_vectors,
                    depth,
                    input_stride,
                    output.as_ptr(),
                    output_stride,
                    row_start,
                    row_count,
                );
            }
        });
        task.join();
        Ok(())
    }

    /// Panicking form of [`Self::try_dots_symmetric_parallel`].
    ///
    /// # Panics
    /// Panics when `try_dots_symmetric_parallel` returns an error.
    pub fn dots_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::Accumulator, Global, MAX_RANK> {
        self.try_dots_symmetric_parallel(pool)
            .expect("parallel dots_symmetric failed")
    }
}
impl<Scalar: Angulars, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
{
    /// Runs `Scalar::angulars_packed` on this 2-D tensor against `packed_b`, returning a new
    /// `[height, width]` tensor allocated from a clone of this tensor's allocator.
    ///
    /// # Errors
    /// - [`TensorError::DimensionMismatch`] when `self` is not 2-D.
    /// - [`TensorError::NonContiguousRows`] when the rows of `self` are not contiguous.
    /// - [`TensorError::ShapeMismatch`] on axis 1 when the depths differ.
    /// - Any error from allocating the output.
    pub fn try_angulars_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Result<Tensor<Scalar::SpatialResult, Alloc, MAX_RANK>, TensorError> {
        let rank = self.ndim();
        if rank != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: rank,
            });
        }
        if !self.has_contiguous_rows() {
            return Err(TensorError::NonContiguousRows);
        }
        let dims = self.shape();
        let (height, depth) = (dims[0], dims[1]);
        let (width, packed_depth) = packed_b.dims();
        if depth != packed_depth {
            return Err(TensorError::ShapeMismatch {
                axis: 1,
                expected: packed_depth,
                got: depth,
            });
        }
        // The output starts out filled with the result type's default value.
        let mut out = Tensor::try_full_in(
            &[height, width],
            Scalar::SpatialResult::default(),
            self.alloc.clone(),
        )?;
        let (a_ptr, b_ptr, out_ptr) = (self.as_ptr(), packed_b.as_ptr(), out.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let out_stride = out.stride_bytes(0) as usize;
        // SAFETY: rank, row contiguity and depth agreement were checked above, and `out` was
        // just allocated with shape `[height, width]`.
        unsafe {
            Scalar::angulars_packed(
                a_ptr, b_ptr, out_ptr, height, width, depth, a_stride, out_stride,
            );
        }
        Ok(out)
    }

    /// Panicking form of [`Self::try_angulars_packed`].
    ///
    /// # Panics
    /// Panics when `try_angulars_packed` returns an error.
    pub fn angulars_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Tensor<Scalar::SpatialResult, Alloc, MAX_RANK> {
        self.try_angulars_packed(packed_b)
            .expect("angulars_packed failed")
    }
}
impl<Scalar: Angulars, Alloc: Allocator, const MAX_RANK: usize> Tensor<Scalar, Alloc, MAX_RANK> {
    /// Runs `Scalar::angulars_packed` on this tensor against `packed_b`, writing into the
    /// caller-provided output `c`.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports for `self`/`packed_b`, or
    /// `validate_matrix_output` reports for `c` (it must be `[height, width]` with
    /// contiguous rows).
    pub fn try_angulars_packed_into<PackedAlloc, OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let (a_ptr, b_ptr, c_ptr) = (self.as_ptr(), packed_b.as_ptr(), c.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // SAFETY: both validations passed, so input, packed operand and output agree on
        // `height`, `width` and `depth`, and both tensors have contiguous rows.
        unsafe {
            Scalar::angulars_packed(
                a_ptr, b_ptr, c_ptr, height, width, depth, a_stride, c_stride,
            );
        }
        Ok(())
    }
}
impl<Scalar: Euclideans, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
{
    /// Runs `Scalar::euclideans_packed` on this 2-D tensor against `packed_b`, returning a
    /// new `[height, width]` tensor allocated from a clone of this tensor's allocator.
    ///
    /// # Errors
    /// - [`TensorError::DimensionMismatch`] when `self` is not 2-D.
    /// - [`TensorError::NonContiguousRows`] when the rows of `self` are not contiguous.
    /// - [`TensorError::ShapeMismatch`] on axis 1 when the depths differ.
    /// - Any error from allocating the output.
    pub fn try_euclideans_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Result<Tensor<Scalar::SpatialResult, Alloc, MAX_RANK>, TensorError> {
        let rank = self.ndim();
        if rank != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: rank,
            });
        }
        if !self.has_contiguous_rows() {
            return Err(TensorError::NonContiguousRows);
        }
        let dims = self.shape();
        let (height, depth) = (dims[0], dims[1]);
        let (width, packed_depth) = packed_b.dims();
        if depth != packed_depth {
            return Err(TensorError::ShapeMismatch {
                axis: 1,
                expected: packed_depth,
                got: depth,
            });
        }
        // The output starts out filled with the result type's default value.
        let mut out = Tensor::try_full_in(
            &[height, width],
            Scalar::SpatialResult::default(),
            self.alloc.clone(),
        )?;
        let (a_ptr, b_ptr, out_ptr) = (self.as_ptr(), packed_b.as_ptr(), out.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let out_stride = out.stride_bytes(0) as usize;
        // SAFETY: rank, row contiguity and depth agreement were checked above, and `out` was
        // just allocated with shape `[height, width]`.
        unsafe {
            Scalar::euclideans_packed(
                a_ptr, b_ptr, out_ptr, height, width, depth, a_stride, out_stride,
            );
        }
        Ok(out)
    }

    /// Panicking form of [`Self::try_euclideans_packed`].
    ///
    /// # Panics
    /// Panics when `try_euclideans_packed` returns an error.
    pub fn euclideans_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Tensor<Scalar::SpatialResult, Alloc, MAX_RANK> {
        self.try_euclideans_packed(packed_b)
            .expect("euclideans_packed failed")
    }
}
impl<Scalar: Euclideans, Alloc: Allocator, const MAX_RANK: usize> Tensor<Scalar, Alloc, MAX_RANK> {
    /// Runs `Scalar::euclideans_packed` on this tensor against `packed_b`, writing into the
    /// caller-provided output `c`.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports for `self`/`packed_b`, or
    /// `validate_matrix_output` reports for `c` (it must be `[height, width]` with
    /// contiguous rows).
    pub fn try_euclideans_packed_into<PackedAlloc, OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let (a_ptr, b_ptr, c_ptr) = (self.as_ptr(), packed_b.as_ptr(), c.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // SAFETY: both validations passed, so input, packed operand and output agree on
        // `height`, `width` and `depth`, and both tensors have contiguous rows.
        unsafe {
            Scalar::euclideans_packed(
                a_ptr, b_ptr, c_ptr, height, width, depth, a_stride, c_stride,
            );
        }
        Ok(())
    }
}
#[cfg(feature = "parallel")]
impl<Scalar: Angulars + Clone + Send + Sync, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
where
    Scalar::SpatialResult: Send + Sync,
{
    /// Thread-pool version of `try_angulars_packed_into`.
    ///
    /// The `height` rows are split into contiguous chunks of `ceil(height / threads)` rows;
    /// each pool thread runs `Scalar::angulars_packed` on its own chunk, so chunks never
    /// write to the same output rows. Threads whose chunk starts past the last row do nothing.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` or `validate_matrix_output` reports.
    pub fn try_angulars_packed_parallel_into<
        PackedAlloc,
        OutputTensor,
        const OUTPUT_MAX_RANK: usize,
    >(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let a_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let c_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let packed_ptr = fork_union::SyncConstPtr::new(packed_b.as_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // `max(1)` keeps the division below well-defined for a pool reporting zero threads.
        let num_threads = pool.threads().max(1);
        let rows_per_thread = height.div_ceil(num_threads);
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let row_start = thread_index * rows_per_thread;
            let row_end = (row_start + rows_per_thread).min(height);
            if row_start < height {
                // SAFETY: `row_start < row_end <= height`, so both offsets stay inside the
                // validated tensors; strides are in bytes, hence the `u8` pointer arithmetic.
                unsafe {
                    let a_row =
                        (a_ptr.as_ptr() as *const u8).add(row_start * a_stride) as *const Scalar;
                    let c_row = (c_ptr.as_ptr() as *mut u8).add(row_start * c_stride)
                        as *mut Scalar::SpatialResult;
                    Scalar::angulars_packed(
                        a_row,
                        packed_ptr.as_ptr(),
                        c_row,
                        row_end - row_start,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    );
                }
            }
        })
        .join();
        Ok(())
    }

    /// Allocating form of [`Self::try_angulars_packed_parallel_into`]; the output lives in
    /// the `Global` allocator and starts out filled with the result type's default value.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports, or any error from allocating the output.
    pub fn try_angulars_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        // Validate before indexing the shape or allocating: a tensor that is not 2-D must
        // come back as an error from this `try_` method, not as an index panic on
        // `shape()[0]`, and a depth mismatch should not cost an output allocation.
        let (height, width, _depth) = validate_packed_input(self, packed_b)?;
        let mut c = Tensor::<Scalar::SpatialResult, Global, MAX_RANK>::try_full(
            &[height, width],
            Scalar::SpatialResult::default(),
        )?;
        self.try_angulars_packed_parallel_into(packed_b, &mut c, pool)?;
        Ok(c)
    }

    /// Panicking form of [`Self::try_angulars_packed_parallel`].
    ///
    /// # Panics
    /// Panics when `try_angulars_packed_parallel` returns an error.
    pub fn angulars_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::SpatialResult, Global, MAX_RANK> {
        self.try_angulars_packed_parallel(packed_b, pool)
            .expect("parallel angulars_packed failed")
    }

    /// Allocating form of [`Self::try_angulars_symmetric_parallel_into`]: builds an
    /// `[n_vectors, n_vectors]` output in the `Global` allocator, pre-filled with the
    /// result type's default value.
    ///
    /// # Errors
    /// Whatever `validate_symmetric_input` reports, or any error from allocating the output.
    pub fn try_angulars_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        let (n_vectors, _) = validate_symmetric_input(self)?;
        let mut result = Tensor::<Scalar::SpatialResult, Global, MAX_RANK>::try_full(
            &[n_vectors, n_vectors],
            Scalar::SpatialResult::default(),
        )?;
        self.try_angulars_symmetric_parallel_into(&mut result, pool)?;
        Ok(result)
    }

    /// Runs `Scalar::angulars_symmetric` over the rows of `self` on every pool thread,
    /// writing into the caller-provided `[n_vectors, n_vectors]` output `c`.
    ///
    /// Each thread asks `compute_thread_rows` for its `(row_start, row_count)` range.
    ///
    /// # Errors
    /// Whatever `validate_symmetric_input` or `validate_matrix_output` reports.
    pub fn try_angulars_symmetric_parallel_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(
            c, n_vectors, n_vectors,
        )?;
        let num_threads = pool.threads().max(1);
        let vectors_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let result_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let stride = self.stride_bytes(0) as usize;
        let result_stride = c.stride_bytes(0) as usize;
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let (row_start, row_count) = compute_thread_rows(thread_index, num_threads, n_vectors);
            // SAFETY: input and output were validated above. NOTE(review): assumes the
            // kernel only writes result rows inside `[row_start, row_start + row_count)`.
            unsafe {
                Scalar::angulars_symmetric(
                    vectors_ptr.as_ptr(),
                    n_vectors,
                    depth,
                    stride,
                    result_ptr.as_ptr(),
                    result_stride,
                    row_start,
                    row_count,
                );
            }
        })
        .join();
        Ok(())
    }

    /// Panicking form of [`Self::try_angulars_symmetric_parallel`].
    ///
    /// # Panics
    /// Panics when `try_angulars_symmetric_parallel` returns an error.
    pub fn angulars_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::SpatialResult, Global, MAX_RANK> {
        self.try_angulars_symmetric_parallel(pool)
            .expect("parallel angulars_symmetric failed")
    }
}
#[cfg(feature = "parallel")]
impl<Scalar: Euclideans + Clone + Send + Sync, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
where
    Scalar::SpatialResult: Send + Sync,
{
    /// Thread-pool version of `try_euclideans_packed_into`.
    ///
    /// The `height` rows are split into contiguous chunks of `ceil(height / threads)` rows;
    /// each pool thread runs `Scalar::euclideans_packed` on its own chunk, so chunks never
    /// write to the same output rows. Threads whose chunk starts past the last row do nothing.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` or `validate_matrix_output` reports.
    pub fn try_euclideans_packed_parallel_into<
        PackedAlloc,
        OutputTensor,
        const OUTPUT_MAX_RANK: usize,
    >(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let a_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let c_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let packed_ptr = fork_union::SyncConstPtr::new(packed_b.as_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // `max(1)` keeps the division below well-defined for a pool reporting zero threads.
        let num_threads = pool.threads().max(1);
        let rows_per_thread = height.div_ceil(num_threads);
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let row_start = thread_index * rows_per_thread;
            let row_end = (row_start + rows_per_thread).min(height);
            if row_start < height {
                // SAFETY: `row_start < row_end <= height`, so both offsets stay inside the
                // validated tensors; strides are in bytes, hence the `u8` pointer arithmetic.
                unsafe {
                    let a_row =
                        (a_ptr.as_ptr() as *const u8).add(row_start * a_stride) as *const Scalar;
                    let c_row = (c_ptr.as_ptr() as *mut u8).add(row_start * c_stride)
                        as *mut Scalar::SpatialResult;
                    Scalar::euclideans_packed(
                        a_row,
                        packed_ptr.as_ptr(),
                        c_row,
                        row_end - row_start,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    );
                }
            }
        })
        .join();
        Ok(())
    }

    /// Allocating form of [`Self::try_euclideans_packed_parallel_into`]; the output lives in
    /// the `Global` allocator and starts out filled with the result type's default value.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports, or any error from allocating the output.
    pub fn try_euclideans_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        // Validate before indexing the shape or allocating: a tensor that is not 2-D must
        // come back as an error from this `try_` method, not as an index panic on
        // `shape()[0]`, and a depth mismatch should not cost an output allocation.
        let (height, width, _depth) = validate_packed_input(self, packed_b)?;
        let mut c = Tensor::<Scalar::SpatialResult, Global, MAX_RANK>::try_full(
            &[height, width],
            Scalar::SpatialResult::default(),
        )?;
        self.try_euclideans_packed_parallel_into(packed_b, &mut c, pool)?;
        Ok(c)
    }

    /// Panicking form of [`Self::try_euclideans_packed_parallel`].
    ///
    /// # Panics
    /// Panics when `try_euclideans_packed_parallel` returns an error.
    pub fn euclideans_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::SpatialResult, Global, MAX_RANK> {
        self.try_euclideans_packed_parallel(packed_b, pool)
            .expect("parallel euclideans_packed failed")
    }

    /// Allocating form of [`Self::try_euclideans_symmetric_parallel_into`]: builds an
    /// `[n_vectors, n_vectors]` output in the `Global` allocator, pre-filled with the
    /// result type's default value.
    ///
    /// # Errors
    /// Whatever `validate_symmetric_input` reports, or any error from allocating the output.
    pub fn try_euclideans_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        let (n_vectors, _) = validate_symmetric_input(self)?;
        let mut result = Tensor::<Scalar::SpatialResult, Global, MAX_RANK>::try_full(
            &[n_vectors, n_vectors],
            Scalar::SpatialResult::default(),
        )?;
        self.try_euclideans_symmetric_parallel_into(&mut result, pool)?;
        Ok(result)
    }

    /// Runs `Scalar::euclideans_symmetric` over the rows of `self` on every pool thread,
    /// writing into the caller-provided `[n_vectors, n_vectors]` output `c`.
    ///
    /// Each thread asks `compute_thread_rows` for its `(row_start, row_count)` range.
    ///
    /// # Errors
    /// Whatever `validate_symmetric_input` or `validate_matrix_output` reports.
    pub fn try_euclideans_symmetric_parallel_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(
            c, n_vectors, n_vectors,
        )?;
        let num_threads = pool.threads().max(1);
        let vectors_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let result_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let stride = self.stride_bytes(0) as usize;
        let result_stride = c.stride_bytes(0) as usize;
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let (row_start, row_count) = compute_thread_rows(thread_index, num_threads, n_vectors);
            // SAFETY: input and output were validated above. NOTE(review): assumes the
            // kernel only writes result rows inside `[row_start, row_start + row_count)`.
            unsafe {
                Scalar::euclideans_symmetric(
                    vectors_ptr.as_ptr(),
                    n_vectors,
                    depth,
                    stride,
                    result_ptr.as_ptr(),
                    result_stride,
                    row_start,
                    row_count,
                );
            }
        })
        .join();
        Ok(())
    }

    /// Panicking form of [`Self::try_euclideans_symmetric_parallel`].
    ///
    /// # Panics
    /// Panics when `try_euclideans_symmetric_parallel` returns an error.
    pub fn euclideans_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::SpatialResult, Global, MAX_RANK> {
        self.try_euclideans_symmetric_parallel(pool)
            .expect("parallel euclideans_symmetric failed")
    }
}
impl<Scalar: Hammings, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
{
    /// Runs `Scalar::hammings_packed` on this 2-D tensor against `packed_b`, returning a new
    /// `[height, width]` tensor of `u32` allocated from a clone of this tensor's allocator.
    ///
    /// # Errors
    /// - [`TensorError::DimensionMismatch`] when `self` is not 2-D.
    /// - [`TensorError::NonContiguousRows`] when the rows of `self` are not contiguous.
    /// - [`TensorError::ShapeMismatch`] on axis 1 when the depths differ.
    /// - Any error from allocating the output.
    pub fn try_hammings_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Result<Tensor<u32, Alloc, MAX_RANK>, TensorError> {
        let rank = self.ndim();
        if rank != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: rank,
            });
        }
        if !self.has_contiguous_rows() {
            return Err(TensorError::NonContiguousRows);
        }
        let dims = self.shape();
        let (height, depth) = (dims[0], dims[1]);
        let (width, packed_depth) = packed_b.dims();
        if depth != packed_depth {
            return Err(TensorError::ShapeMismatch {
                axis: 1,
                expected: packed_depth,
                got: depth,
            });
        }
        // The output starts out filled with `u32::default()`.
        let mut out = Tensor::try_full_in(&[height, width], u32::default(), self.alloc.clone())?;
        let (a_ptr, b_ptr, out_ptr) = (self.as_ptr(), packed_b.as_ptr(), out.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let out_stride = out.stride_bytes(0) as usize;
        // SAFETY: rank, row contiguity and depth agreement were checked above, and `out` was
        // just allocated with shape `[height, width]`.
        unsafe {
            Scalar::hammings_packed(
                a_ptr, b_ptr, out_ptr, height, width, depth, a_stride, out_stride,
            );
        }
        Ok(out)
    }

    /// Panicking form of [`Self::try_hammings_packed`].
    ///
    /// # Panics
    /// Panics when `try_hammings_packed` returns an error.
    pub fn hammings_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Tensor<u32, Alloc, MAX_RANK> {
        self.try_hammings_packed(packed_b)
            .expect("hammings_packed failed")
    }
}
impl<Scalar: Hammings, Alloc: Allocator, const MAX_RANK: usize> Tensor<Scalar, Alloc, MAX_RANK> {
    /// Runs `Scalar::hammings_packed` on this tensor against `packed_b`, writing into the
    /// caller-provided `u32` output `c`.
    ///
    /// # Errors
    /// Whatever `validate_packed_input` reports for `self`/`packed_b`, or
    /// `validate_matrix_output` reports for `c` (it must be `[height, width]` with
    /// contiguous rows).
    pub fn try_hammings_packed_into<PackedAlloc, OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<u32, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<u32, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let (a_ptr, b_ptr, c_ptr) = (self.as_ptr(), packed_b.as_ptr(), c.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // SAFETY: both validations passed, so input, packed operand and output agree on
        // `height`, `width` and `depth`, and both tensors have contiguous rows.
        unsafe {
            Scalar::hammings_packed(
                a_ptr, b_ptr, c_ptr, height, width, depth, a_stride, c_stride,
            );
        }
        Ok(())
    }
}
impl<Scalar: Jaccards, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
{
    /// Runs `Scalar::jaccards_packed` on this 2-D tensor against `packed_b`, returning a new
    /// `[height, width]` tensor allocated from a clone of this tensor's allocator.
    ///
    /// # Errors
    /// - [`TensorError::DimensionMismatch`] when `self` is not 2-D.
    /// - [`TensorError::NonContiguousRows`] when the rows of `self` are not contiguous.
    /// - [`TensorError::ShapeMismatch`] on axis 1 when the depths differ.
    /// - Any error from allocating the output.
    pub fn try_jaccards_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Result<Tensor<Scalar::JaccardResult, Alloc, MAX_RANK>, TensorError> {
        let rank = self.ndim();
        if rank != 2 {
            return Err(TensorError::DimensionMismatch {
                expected: 2,
                got: rank,
            });
        }
        if !self.has_contiguous_rows() {
            return Err(TensorError::NonContiguousRows);
        }
        let dims = self.shape();
        let (height, depth) = (dims[0], dims[1]);
        let (width, packed_depth) = packed_b.dims();
        if depth != packed_depth {
            return Err(TensorError::ShapeMismatch {
                axis: 1,
                expected: packed_depth,
                got: depth,
            });
        }
        // The output starts out filled with the result type's default value.
        let mut out = Tensor::try_full_in(
            &[height, width],
            Scalar::JaccardResult::default(),
            self.alloc.clone(),
        )?;
        let (a_ptr, b_ptr, out_ptr) = (self.as_ptr(), packed_b.as_ptr(), out.as_mut_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let out_stride = out.stride_bytes(0) as usize;
        // SAFETY: rank, row contiguity and depth agreement were checked above, and `out` was
        // just allocated with shape `[height, width]`.
        unsafe {
            Scalar::jaccards_packed(
                a_ptr, b_ptr, out_ptr, height, width, depth, a_stride, out_stride,
            );
        }
        Ok(out)
    }

    /// Panicking form of [`Self::try_jaccards_packed`].
    ///
    /// # Panics
    /// Panics when `try_jaccards_packed` returns an error.
    pub fn jaccards_packed<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
    ) -> Tensor<Scalar::JaccardResult, Alloc, MAX_RANK> {
        self.try_jaccards_packed(packed_b)
            .expect("jaccards_packed failed")
    }
}
impl<Scalar: Jaccards, Alloc: Allocator, const MAX_RANK: usize> Tensor<Scalar, Alloc, MAX_RANK> {
    /// Runs `Scalar::jaccards_packed` for `self` against `packed_b`, writing into the
    /// caller-provided output `c` instead of allocating a new tensor.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_packed_input` for the
    /// `self` / `packed_b` pair, or by `validate_matrix_output` when `c` is not accepted as a
    /// `height x width` output.
    pub fn try_jaccards_packed_into<PackedAlloc, OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::JaccardResult, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::JaccardResult, _, OUTPUT_MAX_RANK>(c, height, width)?;

        // Row strides are handed to the kernel in bytes.
        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let input_ptr = self.as_ptr();
        let packed_ptr = packed_b.as_ptr();
        let output_ptr = c.as_mut_ptr();

        // SAFETY: both validators succeeded, and every pointer, extent and stride passed below
        // comes from the operands they checked.
        // NOTE(review): the exact guarantees live in the validators — confirm they cover
        // everything the kernel dereferences.
        unsafe {
            Scalar::jaccards_packed(
                input_ptr,
                packed_ptr,
                output_ptr,
                height,
                width,
                depth,
                input_stride,
                output_stride,
            );
        }
        Ok(())
    }
}
#[cfg(feature = "parallel")]
impl<Scalar: Hammings + Clone + Send + Sync, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
{
    /// Runs `Scalar::hammings_packed` for `self` against `packed_b` on `pool`, writing into
    /// the caller-provided output `c`.
    ///
    /// Rows of `self` are split into contiguous chunks of `ceil(height / threads)` rows; each
    /// pool thread processes the chunk selected by its thread index and writes the matching
    /// rows of `c`. Threads whose chunk starts at or beyond `height` do nothing.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_packed_input` or
    /// `validate_matrix_output`.
    pub fn try_hammings_packed_parallel_into<
        PackedAlloc,
        OutputTensor,
        const OUTPUT_MAX_RANK: usize,
    >(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<u32, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<u32, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let a_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let c_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let packed_ptr = fork_union::SyncConstPtr::new(packed_b.as_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // `max(1)` keeps the `div_ceil` below from dividing by zero.
        let num_threads = pool.threads().max(1);
        let rows_per_thread = height.div_ceil(num_threads);
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let row_start = thread_index * rows_per_thread;
            let row_end = (row_start + rows_per_thread).min(height);
            if row_start < height {
                // SAFETY: `row_start < height`, so both row offsets stay inside the matrices
                // accepted by the validators above, and each thread index maps to a disjoint
                // row range of `c`.
                // NOTE(review): assumes every pool thread receives a distinct `thread_index`
                // — confirm against `fork_union`.
                unsafe {
                    let a_row =
                        (a_ptr.as_ptr() as *const u8).add(row_start * a_stride) as *const Scalar;
                    let c_row = (c_ptr.as_ptr() as *mut u8).add(row_start * c_stride) as *mut u32;
                    Scalar::hammings_packed(
                        a_row,
                        packed_ptr.as_ptr(),
                        c_row,
                        row_end - row_start,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    );
                }
            }
        })
        .join();
        Ok(())
    }

    /// Allocating variant of [`Self::try_hammings_packed_parallel_into`]: creates a zeroed
    /// `[height, width]` `u32` tensor in the `Global` allocator and fills it on `pool`.
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_packed_input`, from allocating the output,
    /// or from [`Self::try_hammings_packed_parallel_into`].
    pub fn try_hammings_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<u32, Global, MAX_RANK>, TensorError> {
        // Validate before reading the shape or allocating: indexing `shape()[0]` directly
        // would panic on an empty shape, and an invalid operand pair should be rejected
        // without first paying for an output buffer.
        let (height, width, _) = validate_packed_input(self, packed_b)?;
        let mut c = Tensor::<u32, Global, MAX_RANK>::try_full(&[height, width], 0u32)?;
        self.try_hammings_packed_parallel_into(packed_b, &mut c, pool)?;
        Ok(c)
    }

    /// Panicking counterpart of [`Self::try_hammings_packed_parallel`].
    ///
    /// # Panics
    ///
    /// Panics with the message `parallel hammings_packed failed` on any error.
    pub fn hammings_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<u32, Global, MAX_RANK> {
        self.try_hammings_packed_parallel(packed_b, pool)
            .expect("parallel hammings_packed failed")
    }

    /// Allocating variant of [`Self::try_hammings_symmetric_parallel_into`]: creates a zeroed
    /// `[n_vectors, n_vectors]` `u32` tensor in the `Global` allocator and fills it on `pool`.
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_hammings_symmetric_parallel_into`].
    pub fn try_hammings_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<u32, Global, MAX_RANK>, TensorError> {
        let (n_vectors, _) = validate_symmetric_input(self)?;
        let mut result = Tensor::<u32, Global, MAX_RANK>::try_full(&[n_vectors, n_vectors], 0u32)?;
        self.try_hammings_symmetric_parallel_into(&mut result, pool)?;
        Ok(result)
    }

    /// Runs `Scalar::hammings_symmetric` over the rows of `self` on `pool`, writing into the
    /// caller-provided `n_vectors x n_vectors` output `c`.
    ///
    /// Each pool thread handles the `(row_start, row_count)` range that
    /// `compute_thread_rows` assigns to its thread index.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or
    /// `validate_matrix_output`.
    pub fn try_hammings_symmetric_parallel_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<u32, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<u32, _, OUTPUT_MAX_RANK>(c, n_vectors, n_vectors)?;
        let num_threads = pool.threads().max(1);
        let vectors_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let result_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let stride = self.stride_bytes(0) as usize;
        let result_stride = c.stride_bytes(0) as usize;
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let (row_start, row_count) = compute_thread_rows(thread_index, num_threads, n_vectors);
            // SAFETY: input and output passed validation above; the kernel receives the whole
            // matrices plus the row window assigned to this thread.
            // NOTE(review): presumably `compute_thread_rows` yields non-overlapping windows
            // and the kernel only writes rows inside its window — confirm.
            unsafe {
                Scalar::hammings_symmetric(
                    vectors_ptr.as_ptr(),
                    n_vectors,
                    depth,
                    stride,
                    result_ptr.as_ptr(),
                    result_stride,
                    row_start,
                    row_count,
                );
            }
        })
        .join();
        Ok(())
    }

    /// Panicking counterpart of [`Self::try_hammings_symmetric_parallel`].
    ///
    /// # Panics
    ///
    /// Panics with the message `parallel hammings_symmetric failed` on any error.
    pub fn hammings_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<u32, Global, MAX_RANK> {
        self.try_hammings_symmetric_parallel(pool)
            .expect("parallel hammings_symmetric failed")
    }
}
#[cfg(feature = "parallel")]
impl<Scalar: Jaccards + Clone + Send + Sync, Alloc: Allocator + Clone, const MAX_RANK: usize>
    Tensor<Scalar, Alloc, MAX_RANK>
where
    Scalar::JaccardResult: Send + Sync,
{
    /// Runs `Scalar::jaccards_packed` for `self` against `packed_b` on `pool`, writing into
    /// the caller-provided output `c`.
    ///
    /// Rows of `self` are split into contiguous chunks of `ceil(height / threads)` rows; each
    /// pool thread processes the chunk selected by its thread index and writes the matching
    /// rows of `c`. Threads whose chunk starts at or beyond `height` do nothing.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_packed_input` or
    /// `validate_matrix_output`.
    pub fn try_jaccards_packed_parallel_into<
        PackedAlloc,
        OutputTensor,
        const OUTPUT_MAX_RANK: usize,
    >(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        PackedAlloc: Allocator,
        OutputTensor: TensorMut<Scalar::JaccardResult, OUTPUT_MAX_RANK>,
    {
        let (height, width, depth) = validate_packed_input(self, packed_b)?;
        validate_matrix_output::<Scalar::JaccardResult, _, OUTPUT_MAX_RANK>(c, height, width)?;
        let a_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let c_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let packed_ptr = fork_union::SyncConstPtr::new(packed_b.as_ptr());
        let a_stride = self.stride_bytes(0) as usize;
        let c_stride = c.stride_bytes(0) as usize;
        // `max(1)` keeps the `div_ceil` below from dividing by zero.
        let num_threads = pool.threads().max(1);
        let rows_per_thread = height.div_ceil(num_threads);
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let row_start = thread_index * rows_per_thread;
            let row_end = (row_start + rows_per_thread).min(height);
            if row_start < height {
                // SAFETY: `row_start < height`, so both row offsets stay inside the matrices
                // accepted by the validators above, and each thread index maps to a disjoint
                // row range of `c`.
                // NOTE(review): assumes every pool thread receives a distinct `thread_index`
                // — confirm against `fork_union`.
                unsafe {
                    let a_row =
                        (a_ptr.as_ptr() as *const u8).add(row_start * a_stride) as *const Scalar;
                    let c_row = (c_ptr.as_ptr() as *mut u8).add(row_start * c_stride)
                        as *mut Scalar::JaccardResult;
                    Scalar::jaccards_packed(
                        a_row,
                        packed_ptr.as_ptr(),
                        c_row,
                        row_end - row_start,
                        width,
                        depth,
                        a_stride,
                        c_stride,
                    );
                }
            }
        })
        .join();
        Ok(())
    }

    /// Allocating variant of [`Self::try_jaccards_packed_parallel_into`]: creates a
    /// default-filled `[height, width]` tensor in the `Global` allocator and fills it on
    /// `pool`.
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_packed_input`, from allocating the output,
    /// or from [`Self::try_jaccards_packed_parallel_into`].
    pub fn try_jaccards_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::JaccardResult, Global, MAX_RANK>, TensorError> {
        // Validate before reading the shape or allocating: indexing `shape()[0]` directly
        // would panic on an empty shape, and an invalid operand pair should be rejected
        // without first paying for an output buffer.
        let (height, width, _) = validate_packed_input(self, packed_b)?;
        let mut c = Tensor::<Scalar::JaccardResult, Global, MAX_RANK>::try_full(
            &[height, width],
            Scalar::JaccardResult::default(),
        )?;
        self.try_jaccards_packed_parallel_into(packed_b, &mut c, pool)?;
        Ok(c)
    }

    /// Panicking counterpart of [`Self::try_jaccards_packed_parallel`].
    ///
    /// # Panics
    ///
    /// Panics with the message `parallel jaccards_packed failed` on any error.
    pub fn jaccards_packed_parallel<PackedAlloc: Allocator>(
        &self,
        packed_b: &PackedMatrix<Scalar, PackedAlloc>,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::JaccardResult, Global, MAX_RANK> {
        self.try_jaccards_packed_parallel(packed_b, pool)
            .expect("parallel jaccards_packed failed")
    }

    /// Allocating variant of [`Self::try_jaccards_symmetric_parallel_into`]: creates a
    /// default-filled `[n_vectors, n_vectors]` tensor in the `Global` allocator and fills it
    /// on `pool`.
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_jaccards_symmetric_parallel_into`].
    pub fn try_jaccards_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<Tensor<Scalar::JaccardResult, Global, MAX_RANK>, TensorError> {
        let (n_vectors, _) = validate_symmetric_input(self)?;
        let mut result = Tensor::<Scalar::JaccardResult, Global, MAX_RANK>::try_full(
            &[n_vectors, n_vectors],
            Scalar::JaccardResult::default(),
        )?;
        self.try_jaccards_symmetric_parallel_into(&mut result, pool)?;
        Ok(result)
    }

    /// Runs `Scalar::jaccards_symmetric` over the rows of `self` on `pool`, writing into the
    /// caller-provided `n_vectors x n_vectors` output `c`.
    ///
    /// Each pool thread handles the `(row_start, row_count)` range that
    /// `compute_thread_rows` assigns to its thread index.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or
    /// `validate_matrix_output`.
    pub fn try_jaccards_symmetric_parallel_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
        pool: &mut fork_union::ThreadPool,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::JaccardResult, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::JaccardResult, _, OUTPUT_MAX_RANK>(
            c, n_vectors, n_vectors,
        )?;
        let num_threads = pool.threads().max(1);
        let vectors_ptr = fork_union::SyncConstPtr::new(self.as_ptr());
        let result_ptr = fork_union::SyncMutPtr::new(c.as_mut_ptr());
        let stride = self.stride_bytes(0) as usize;
        let result_stride = c.stride_bytes(0) as usize;
        pool.for_threads(move |thread_index, _colocation_index| {
            crate::capabilities::configure_thread();
            let (row_start, row_count) = compute_thread_rows(thread_index, num_threads, n_vectors);
            // SAFETY: input and output passed validation above; the kernel receives the whole
            // matrices plus the row window assigned to this thread.
            // NOTE(review): presumably `compute_thread_rows` yields non-overlapping windows
            // and the kernel only writes rows inside its window — confirm.
            unsafe {
                Scalar::jaccards_symmetric(
                    vectors_ptr.as_ptr(),
                    n_vectors,
                    depth,
                    stride,
                    result_ptr.as_ptr(),
                    result_stride,
                    row_start,
                    row_count,
                );
            }
        })
        .join();
        Ok(())
    }

    /// Panicking counterpart of [`Self::try_jaccards_symmetric_parallel`].
    ///
    /// # Panics
    ///
    /// Panics with the message `parallel jaccards_symmetric failed` on any error.
    pub fn jaccards_symmetric_parallel(
        &self,
        pool: &mut fork_union::ThreadPool,
    ) -> Tensor<Scalar::JaccardResult, Global, MAX_RANK> {
        self.try_jaccards_symmetric_parallel(pool)
            .expect("parallel jaccards_symmetric failed")
    }
}
impl<'a, Scalar: Dots, const MAX_RANK: usize> TensorView<'a, Scalar, MAX_RANK>
where
    Scalar::Accumulator: 'static,
{
    /// Allocates a default-filled `[n_vectors, n_vectors]` accumulator tensor in the `Global`
    /// allocator and fills it via [`Self::try_dots_symmetric_into`].
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_dots_symmetric_into`].
    pub fn try_dots_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::Accumulator, Global, MAX_RANK>, TensorError> {
        let (side, _depth) = validate_symmetric_input(self)?;
        let zero = Scalar::Accumulator::default();
        let mut gram =
            Tensor::<Scalar::Accumulator, Global, MAX_RANK>::try_full(&[side, side], zero)?;
        self.try_dots_symmetric_into(&mut gram)?;
        Ok(gram)
    }

    /// Runs `Scalar::dots_symmetric` over all rows of this view (row window `0..n_vectors`),
    /// writing into the caller-provided output `c`.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or by
    /// `validate_matrix_output` when `c` is not accepted as an `n_vectors x n_vectors` output.
    pub fn try_dots_symmetric_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::Accumulator, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::Accumulator, _, OUTPUT_MAX_RANK>(c, n_vectors, n_vectors)?;

        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let input_ptr = self.as_ptr();
        let output_ptr = c.as_mut_ptr();

        // SAFETY: both validators succeeded; pointers and byte strides come from the operands
        // they checked. NOTE(review): confirm the validators cover the kernel's full contract.
        unsafe {
            Scalar::dots_symmetric(
                input_ptr,
                n_vectors,
                depth,
                input_stride,
                output_ptr,
                output_stride,
                0,
                n_vectors,
            );
        }
        Ok(())
    }
}
impl<'a, Scalar: Angulars, const MAX_RANK: usize> TensorView<'a, Scalar, MAX_RANK> {
    /// Allocates a default-filled `[n_vectors, n_vectors]` result tensor in the `Global`
    /// allocator and fills it via [`Self::try_angulars_symmetric_into`].
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_angulars_symmetric_into`].
    pub fn try_angulars_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        let (side, _depth) = validate_symmetric_input(self)?;
        let zero = Scalar::SpatialResult::default();
        let mut distances =
            Tensor::<Scalar::SpatialResult, Global, MAX_RANK>::try_full(&[side, side], zero)?;
        self.try_angulars_symmetric_into(&mut distances)?;
        Ok(distances)
    }

    /// Runs `Scalar::angulars_symmetric` over all rows of this view (row window
    /// `0..n_vectors`), writing into the caller-provided output `c`.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or by
    /// `validate_matrix_output` when `c` is not accepted as an `n_vectors x n_vectors` output.
    pub fn try_angulars_symmetric_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(
            c, n_vectors, n_vectors,
        )?;

        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let input_ptr = self.as_ptr();
        let output_ptr = c.as_mut_ptr();

        // SAFETY: both validators succeeded; pointers and byte strides come from the operands
        // they checked. NOTE(review): confirm the validators cover the kernel's full contract.
        unsafe {
            Scalar::angulars_symmetric(
                input_ptr,
                n_vectors,
                depth,
                input_stride,
                output_ptr,
                output_stride,
                0,
                n_vectors,
            );
        }
        Ok(())
    }
}
impl<'a, Scalar: Euclideans, const MAX_RANK: usize> TensorView<'a, Scalar, MAX_RANK> {
    /// Allocates a default-filled `[n_vectors, n_vectors]` result tensor in the `Global`
    /// allocator and fills it via [`Self::try_euclideans_symmetric_into`].
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_euclideans_symmetric_into`].
    pub fn try_euclideans_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        let (side, _depth) = validate_symmetric_input(self)?;
        let zero = Scalar::SpatialResult::default();
        let mut distances =
            Tensor::<Scalar::SpatialResult, Global, MAX_RANK>::try_full(&[side, side], zero)?;
        self.try_euclideans_symmetric_into(&mut distances)?;
        Ok(distances)
    }

    /// Runs `Scalar::euclideans_symmetric` over all rows of this view (row window
    /// `0..n_vectors`), writing into the caller-provided output `c`.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or by
    /// `validate_matrix_output` when `c` is not accepted as an `n_vectors x n_vectors` output.
    pub fn try_euclideans_symmetric_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::SpatialResult, _, OUTPUT_MAX_RANK>(
            c, n_vectors, n_vectors,
        )?;

        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let input_ptr = self.as_ptr();
        let output_ptr = c.as_mut_ptr();

        // SAFETY: both validators succeeded; pointers and byte strides come from the operands
        // they checked. NOTE(review): confirm the validators cover the kernel's full contract.
        unsafe {
            Scalar::euclideans_symmetric(
                input_ptr,
                n_vectors,
                depth,
                input_stride,
                output_ptr,
                output_stride,
                0,
                n_vectors,
            );
        }
        Ok(())
    }
}
impl<'a, Scalar: Hammings, const MAX_RANK: usize> TensorView<'a, Scalar, MAX_RANK> {
    /// Allocates a zeroed `[n_vectors, n_vectors]` `u32` tensor in the `Global` allocator and
    /// fills it via [`Self::try_hammings_symmetric_into`].
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_hammings_symmetric_into`].
    pub fn try_hammings_symmetric(&self) -> Result<Tensor<u32, Global, MAX_RANK>, TensorError> {
        let (side, _depth) = validate_symmetric_input(self)?;
        let zero = u32::default();
        let mut distances = Tensor::<u32, Global, MAX_RANK>::try_full(&[side, side], zero)?;
        self.try_hammings_symmetric_into(&mut distances)?;
        Ok(distances)
    }

    /// Runs `Scalar::hammings_symmetric` over all rows of this view (row window
    /// `0..n_vectors`), writing into the caller-provided output `c`.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or by
    /// `validate_matrix_output` when `c` is not accepted as an `n_vectors x n_vectors` output.
    pub fn try_hammings_symmetric_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<u32, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<u32, _, OUTPUT_MAX_RANK>(c, n_vectors, n_vectors)?;

        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let input_ptr = self.as_ptr();
        let output_ptr = c.as_mut_ptr();

        // SAFETY: both validators succeeded; pointers and byte strides come from the operands
        // they checked. NOTE(review): confirm the validators cover the kernel's full contract.
        unsafe {
            Scalar::hammings_symmetric(
                input_ptr,
                n_vectors,
                depth,
                input_stride,
                output_ptr,
                output_stride,
                0,
                n_vectors,
            );
        }
        Ok(())
    }
}
impl<'a, Scalar: Jaccards, const MAX_RANK: usize> TensorView<'a, Scalar, MAX_RANK> {
    /// Allocates a default-filled `[n_vectors, n_vectors]` result tensor in the `Global`
    /// allocator and fills it via [`Self::try_jaccards_symmetric_into`].
    ///
    /// # Errors
    ///
    /// Returns the [`TensorError`] from `validate_symmetric_input`, from allocating the
    /// output, or from [`Self::try_jaccards_symmetric_into`].
    pub fn try_jaccards_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::JaccardResult, Global, MAX_RANK>, TensorError> {
        let (side, _depth) = validate_symmetric_input(self)?;
        let zero = Scalar::JaccardResult::default();
        let mut distances =
            Tensor::<Scalar::JaccardResult, Global, MAX_RANK>::try_full(&[side, side], zero)?;
        self.try_jaccards_symmetric_into(&mut distances)?;
        Ok(distances)
    }

    /// Runs `Scalar::jaccards_symmetric` over all rows of this view (row window
    /// `0..n_vectors`), writing into the caller-provided output `c`.
    ///
    /// # Errors
    ///
    /// Propagates the [`TensorError`] returned by `validate_symmetric_input` or by
    /// `validate_matrix_output` when `c` is not accepted as an `n_vectors x n_vectors` output.
    pub fn try_jaccards_symmetric_into<OutputTensor, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut OutputTensor,
    ) -> Result<(), TensorError>
    where
        OutputTensor: TensorMut<Scalar::JaccardResult, OUTPUT_MAX_RANK>,
    {
        let (n_vectors, depth) = validate_symmetric_input(self)?;
        validate_matrix_output::<Scalar::JaccardResult, _, OUTPUT_MAX_RANK>(
            c, n_vectors, n_vectors,
        )?;

        let input_stride = self.stride_bytes(0) as usize;
        let output_stride = c.stride_bytes(0) as usize;
        let input_ptr = self.as_ptr();
        let output_ptr = c.as_mut_ptr();

        // SAFETY: both validators succeeded; pointers and byte strides come from the operands
        // they checked. NOTE(review): confirm the validators cover the kernel's full contract.
        unsafe {
            Scalar::jaccards_symmetric(
                input_ptr,
                n_vectors,
                depth,
                input_stride,
                output_ptr,
                output_stride,
                0,
                n_vectors,
            );
        }
        Ok(())
    }
}
/// Extension trait exposing the symmetric dot-product API on any
/// `TensorRef<Scalar, MAX_RANK>` by delegating to the type's `view()`.
pub trait SymmetricDots<Scalar: Dots, const MAX_RANK: usize>: TensorRef<Scalar, MAX_RANK>
where
    Scalar::Accumulator: 'static,
{
    /// Delegates to `try_dots_symmetric` on `self.view()`, returning a newly allocated result.
    fn try_dots_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::Accumulator, Global, MAX_RANK>, TensorError> {
        let view = self.view();
        view.try_dots_symmetric()
    }

    /// Delegates to `try_dots_symmetric_into` on `self.view()`, writing into `c`.
    fn try_dots_symmetric_into<Out, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut Out,
    ) -> Result<(), TensorError>
    where
        Out: TensorMut<Scalar::Accumulator, OUTPUT_MAX_RANK>,
    {
        let view = self.view();
        view.try_dots_symmetric_into(c)
    }
}
// Blanket impl: every type readable as `TensorRef<Scalar, MAX_RANK>` gets `SymmetricDots`.
impl<Scalar, Source, const MAX_RANK: usize> SymmetricDots<Scalar, MAX_RANK> for Source
where
    Scalar: Dots,
    Scalar::Accumulator: 'static,
    Source: TensorRef<Scalar, MAX_RANK>,
{
}
/// Extension trait exposing the symmetric angular API on any `TensorRef<Scalar, MAX_RANK>`
/// by delegating to the type's `view()`.
pub trait SymmetricAngulars<Scalar: Angulars, const MAX_RANK: usize>:
    TensorRef<Scalar, MAX_RANK>
{
    /// Delegates to `try_angulars_symmetric` on `self.view()`, returning a newly allocated
    /// result.
    fn try_angulars_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        let view = self.view();
        view.try_angulars_symmetric()
    }

    /// Delegates to `try_angulars_symmetric_into` on `self.view()`, writing into `c`.
    fn try_angulars_symmetric_into<Out, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut Out,
    ) -> Result<(), TensorError>
    where
        Out: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let view = self.view();
        view.try_angulars_symmetric_into(c)
    }
}
// Blanket impl: every type readable as `TensorRef<Scalar, MAX_RANK>` gets `SymmetricAngulars`.
impl<Scalar, Source, const MAX_RANK: usize> SymmetricAngulars<Scalar, MAX_RANK> for Source
where
    Scalar: Angulars,
    Source: TensorRef<Scalar, MAX_RANK>,
{
}
/// Extension trait exposing the symmetric Euclidean API on any `TensorRef<Scalar, MAX_RANK>`
/// by delegating to the type's `view()`.
pub trait SymmetricEuclideans<Scalar: Euclideans, const MAX_RANK: usize>:
    TensorRef<Scalar, MAX_RANK>
{
    /// Delegates to `try_euclideans_symmetric` on `self.view()`, returning a newly allocated
    /// result.
    fn try_euclideans_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::SpatialResult, Global, MAX_RANK>, TensorError> {
        let view = self.view();
        view.try_euclideans_symmetric()
    }

    /// Delegates to `try_euclideans_symmetric_into` on `self.view()`, writing into `c`.
    fn try_euclideans_symmetric_into<Out, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut Out,
    ) -> Result<(), TensorError>
    where
        Out: TensorMut<Scalar::SpatialResult, OUTPUT_MAX_RANK>,
    {
        let view = self.view();
        view.try_euclideans_symmetric_into(c)
    }
}
// Blanket impl: every type readable as `TensorRef<Scalar, MAX_RANK>` gets
// `SymmetricEuclideans`.
impl<Scalar, Source, const MAX_RANK: usize> SymmetricEuclideans<Scalar, MAX_RANK> for Source
where
    Scalar: Euclideans,
    Source: TensorRef<Scalar, MAX_RANK>,
{
}
/// Extension trait exposing the symmetric Hamming API on any `TensorRef<Scalar, MAX_RANK>`
/// by delegating to the type's `view()`.
pub trait SymmetricHammings<Scalar: Hammings, const MAX_RANK: usize>:
    TensorRef<Scalar, MAX_RANK>
{
    /// Delegates to `try_hammings_symmetric` on `self.view()`, returning a newly allocated
    /// `u32` result.
    fn try_hammings_symmetric(&self) -> Result<Tensor<u32, Global, MAX_RANK>, TensorError> {
        let view = self.view();
        view.try_hammings_symmetric()
    }

    /// Delegates to `try_hammings_symmetric_into` on `self.view()`, writing into `c`.
    fn try_hammings_symmetric_into<Out, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut Out,
    ) -> Result<(), TensorError>
    where
        Out: TensorMut<u32, OUTPUT_MAX_RANK>,
    {
        let view = self.view();
        view.try_hammings_symmetric_into(c)
    }
}
// Blanket impl: every type readable as `TensorRef<Scalar, MAX_RANK>` gets `SymmetricHammings`.
impl<Scalar, Source, const MAX_RANK: usize> SymmetricHammings<Scalar, MAX_RANK> for Source
where
    Scalar: Hammings,
    Source: TensorRef<Scalar, MAX_RANK>,
{
}
/// Extension trait exposing the symmetric Jaccard API on any `TensorRef<Scalar, MAX_RANK>`
/// by delegating to the type's `view()`.
pub trait SymmetricJaccards<Scalar: Jaccards, const MAX_RANK: usize>:
    TensorRef<Scalar, MAX_RANK>
{
    /// Delegates to `try_jaccards_symmetric` on `self.view()`, returning a newly allocated
    /// result.
    fn try_jaccards_symmetric(
        &self,
    ) -> Result<Tensor<Scalar::JaccardResult, Global, MAX_RANK>, TensorError> {
        let view = self.view();
        view.try_jaccards_symmetric()
    }

    /// Delegates to `try_jaccards_symmetric_into` on `self.view()`, writing into `c`.
    fn try_jaccards_symmetric_into<Out, const OUTPUT_MAX_RANK: usize>(
        &self,
        c: &mut Out,
    ) -> Result<(), TensorError>
    where
        Out: TensorMut<Scalar::JaccardResult, OUTPUT_MAX_RANK>,
    {
        let view = self.view();
        view.try_jaccards_symmetric_into(c)
    }
}
// Blanket impl: every type readable as `TensorRef<Scalar, MAX_RANK>` gets `SymmetricJaccards`.
impl<Scalar, Source, const MAX_RANK: usize> SymmetricJaccards<Scalar, MAX_RANK> for Source
where
    Scalar: Jaccards,
    Source: TensorRef<Scalar, MAX_RANK>,
{
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{FloatLike, NumberLike, TestableType};
use std::sync::Once;
static INIT: Once = Once::new();
/// Calls `crate::capabilities::configure_thread` at most once per test process, guarded by
/// the shared `INIT` latch.
fn init_thread() {
    let configure = || {
        crate::capabilities::configure_thread();
    };
    INIT.call_once(configure);
}
// Shape triples shared by the checks in this module. The packed checks destructure each entry
// as `(height, width, depth)`; the symmetric checks read it as `(num_vectors, _, depth)`.
const DIMS: &[(usize, usize, usize)] =
&[(1, 1, 1), (1, 8, 3), (3, 1, 7), (7, 5, 3), (33, 17, 65)];
/// Rounds `depth` up to the nearest multiple of `Scalar::dimensions_per_value()`.
fn align_depth<Scalar: StorageElement>(depth: usize) -> usize {
    let per_value = Scalar::dimensions_per_value();
    let value_count = depth.div_ceil(per_value);
    value_count * per_value
}
/// For every `DIMS` entry: packs an all-ones `B`, multiplies an all-ones `A` against it and
/// checks that each product equals the aligned depth within `atol + rtol * expected`, then
/// verifies that the `_into` variants (tensor and span outputs) reproduce the same values.
fn check_dots_packed<Scalar: TestableType + Dots>()
where
    Scalar::Accumulator: FloatLike + PartialEq + core::fmt::Debug,
{
    init_thread();
    for &(height, width, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let lhs = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let rhs = Tensor::<Scalar>::try_full(&[width, depth], Scalar::one()).unwrap();
        let rhs_packed = PackedMatrix::try_pack(&rhs).unwrap();

        let products = lhs.dots_packed(&rhs_packed);
        assert_eq!(
            products.shape(),
            &[height, width],
            "shape @ ({height},{width},{depth})"
        );

        // A dot product of two all-ones rows is the number of dimensions.
        let expected = depth as f64;
        let tol = Scalar::atol() + Scalar::rtol() * expected.abs();
        for (i, &v) in products.as_slice().iter().enumerate() {
            let actual = v.to_f64();
            assert!(
                (actual - expected).abs() <= tol,
                "({height},{width},{depth})[{i}]: {} vs {expected} (tol={tol})",
                actual
            );
        }

        // Fresh default-filled output buffer of the result shape.
        let blank_output = || {
            Tensor::<Scalar::Accumulator>::try_full(
                &[height, width],
                Scalar::Accumulator::default(),
            )
            .unwrap()
        };

        let mut into_tensor = blank_output();
        lhs.try_dots_packed_into(&rhs_packed, &mut into_tensor)
            .unwrap();
        assert_eq!(
            products.as_slice(),
            into_tensor.as_slice(),
            "_into(Tensor) @ ({height},{width},{depth})"
        );

        let mut into_span_buf = blank_output();
        lhs.try_dots_packed_into(&rhs_packed, &mut into_span_buf.span())
            .unwrap();
        assert_eq!(
            products.as_slice(),
            into_span_buf.as_slice(),
            "_into(span) @ ({height},{width},{depth})"
        );
    }
}
/// For every `DIMS` entry: packs a `[depth, width]` matrix of twos through
/// `try_pack_transposed`, multiplies an all-ones `A` against it and checks that each product
/// equals `2 * depth` within `atol + rtol * expected`.
fn check_dots_packed_transposed<Scalar: TestableType + Dots>()
where
    Scalar::Accumulator: FloatLike,
{
    init_thread();
    for &(height, width, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let lhs = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let rhs_transposed =
            Tensor::<Scalar>::try_full(&[depth, width], Scalar::from_f32(2.0)).unwrap();
        let rhs_packed = PackedMatrix::try_pack_transposed(&rhs_transposed).unwrap();

        let products = lhs.dots_packed(&rhs_packed);
        assert_eq!(
            products.shape(),
            &[height, width],
            "shape @ ({height},{width},{depth})"
        );

        let expected = depth as f64 * 2.0;
        let tol = Scalar::atol() + Scalar::rtol() * expected.abs();
        for (i, &v) in products.as_slice().iter().enumerate() {
            let actual = v.to_f64();
            assert!(
                (actual - expected).abs() <= tol,
                "({height},{width},{depth})[{i}]: {} vs {expected} (tol={tol})",
                actual
            );
        }
    }
}
/// For every `DIMS` entry: computes packed angular results between all-ones matrices, checks
/// that every value is within `Scalar::atol()` of zero, then verifies that the `_into`
/// variants (tensor and span outputs) reproduce the same values.
fn check_angulars_packed<Scalar: TestableType + Angulars>()
where
    Scalar::SpatialResult: FloatLike + PartialEq + core::fmt::Debug,
{
    init_thread();
    let tol = Scalar::atol();
    for &(height, width, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let lhs = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let rhs = Tensor::<Scalar>::try_full(&[width, depth], Scalar::one()).unwrap();
        let rhs_packed = PackedMatrix::try_pack(&rhs).unwrap();

        let distances = lhs.angulars_packed(&rhs_packed);
        assert_eq!(
            distances.shape(),
            &[height, width],
            "shape @ ({height},{width},{depth})"
        );
        for (i, &v) in distances.as_slice().iter().enumerate() {
            let actual = v.to_f64();
            assert!(
                actual.abs() <= tol,
                "({height},{width},{depth})[{i}]: {} vs 0.0 (tol={tol})",
                actual
            );
        }

        // Fresh default-filled output buffer of the result shape.
        let blank_output = || {
            Tensor::<Scalar::SpatialResult>::try_full(
                &[height, width],
                Scalar::SpatialResult::default(),
            )
            .unwrap()
        };

        let mut into_tensor = blank_output();
        lhs.try_angulars_packed_into(&rhs_packed, &mut into_tensor)
            .unwrap();
        assert_eq!(
            distances.as_slice(),
            into_tensor.as_slice(),
            "_into(Tensor) @ ({height},{width},{depth})"
        );

        let mut into_span_buf = blank_output();
        lhs.try_angulars_packed_into(&rhs_packed, &mut into_span_buf.span())
            .unwrap();
        assert_eq!(
            distances.as_slice(),
            into_span_buf.as_slice(),
            "_into(span) @ ({height},{width},{depth})"
        );
    }
}
/// For every `DIMS` entry: computes packed Euclidean results between all-ones matrices,
/// checks that every value is within `Scalar::atol()` of zero, then verifies that the `_into`
/// variants (tensor and span outputs) reproduce the same values.
fn check_euclideans_packed<Scalar: TestableType + Euclideans>()
where
    Scalar::SpatialResult: FloatLike + PartialEq + core::fmt::Debug,
{
    init_thread();
    let tol = Scalar::atol();
    for &(height, width, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let lhs = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let rhs = Tensor::<Scalar>::try_full(&[width, depth], Scalar::one()).unwrap();
        let rhs_packed = PackedMatrix::try_pack(&rhs).unwrap();

        let distances = lhs.euclideans_packed(&rhs_packed);
        assert_eq!(
            distances.shape(),
            &[height, width],
            "shape @ ({height},{width},{depth})"
        );
        for (i, &v) in distances.as_slice().iter().enumerate() {
            let actual = v.to_f64();
            assert!(
                actual.abs() <= tol,
                "({height},{width},{depth})[{i}]: {} vs 0.0 (tol={tol})",
                actual
            );
        }

        // Fresh default-filled output buffer of the result shape.
        let blank_output = || {
            Tensor::<Scalar::SpatialResult>::try_full(
                &[height, width],
                Scalar::SpatialResult::default(),
            )
            .unwrap()
        };

        let mut into_tensor = blank_output();
        lhs.try_euclideans_packed_into(&rhs_packed, &mut into_tensor)
            .unwrap();
        assert_eq!(
            distances.as_slice(),
            into_tensor.as_slice(),
            "_into(Tensor) @ ({height},{width},{depth})"
        );

        let mut into_span_buf = blank_output();
        lhs.try_euclideans_packed_into(&rhs_packed, &mut into_span_buf.span())
            .unwrap();
        assert_eq!(
            distances.as_slice(),
            into_span_buf.as_slice(),
            "_into(span) @ ({height},{width},{depth})"
        );
    }
}
/// For every `DIMS` entry: checks that `dots_packed_parallel` on a 4-thread pool returns
/// exactly the same values as the serial `dots_packed`.
#[cfg(feature = "parallel")]
fn check_dots_packed_parallel<Scalar: TestableType + Dots + Send + Sync>()
where
    Scalar::Accumulator: PartialEq + core::fmt::Debug + Send + Sync,
{
    init_thread();
    let mut pool = fork_union::ThreadPool::try_spawn(4).unwrap();
    for &(height, width, depth) in DIMS {
        // Align like the serial packed checks do, so the helper stays usable for scalar types
        // that hold several dimensions per stored value.
        let depth = align_depth::<Scalar>(depth);
        let a = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let b = Tensor::<Scalar>::try_full(&[width, depth], Scalar::one()).unwrap();
        let b_packed = PackedMatrix::try_pack(&b).unwrap();
        let serial = a.dots_packed(&b_packed);
        let parallel = a.dots_packed_parallel(&b_packed, &mut pool);
        assert_eq!(
            serial.as_slice(),
            parallel.as_slice(),
            "serial != parallel @ ({height},{width},{depth})"
        );
    }
}
/// For every `DIMS` entry: checks that `angulars_packed_parallel` on a 4-thread pool returns
/// exactly the same values as the serial `angulars_packed`.
#[cfg(feature = "parallel")]
fn check_angulars_packed_parallel<Scalar: TestableType + Angulars + Send + Sync>()
where
    Scalar::SpatialResult: PartialEq + core::fmt::Debug + Send + Sync,
{
    init_thread();
    let mut pool = fork_union::ThreadPool::try_spawn(4).unwrap();
    for &(height, width, depth) in DIMS {
        // Align like the serial packed checks do, so the helper stays usable for scalar types
        // that hold several dimensions per stored value.
        let depth = align_depth::<Scalar>(depth);
        let a = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let b = Tensor::<Scalar>::try_full(&[width, depth], Scalar::one()).unwrap();
        let b_packed = PackedMatrix::try_pack(&b).unwrap();
        let serial = a.angulars_packed(&b_packed);
        let parallel = a.angulars_packed_parallel(&b_packed, &mut pool);
        assert_eq!(
            serial.as_slice(),
            parallel.as_slice(),
            "serial != parallel @ ({height},{width},{depth})"
        );
    }
}
/// For every `DIMS` entry: checks that `euclideans_packed_parallel` and
/// `try_euclideans_packed_parallel_into` (span output) on a 4-thread pool return exactly the
/// same values as the serial `euclideans_packed`.
#[cfg(feature = "parallel")]
fn check_euclideans_packed_parallel<Scalar: TestableType + Euclideans + Send + Sync>()
where
    Scalar::SpatialResult: PartialEq + core::fmt::Debug + Send + Sync,
{
    init_thread();
    let mut pool = fork_union::ThreadPool::try_spawn(4).unwrap();
    for &(height, width, depth) in DIMS {
        // Align like the serial packed checks do, so the helper stays usable for scalar types
        // that hold several dimensions per stored value.
        let depth = align_depth::<Scalar>(depth);
        let a = Tensor::<Scalar>::try_full(&[height, depth], Scalar::one()).unwrap();
        let b = Tensor::<Scalar>::try_full(&[width, depth], Scalar::one()).unwrap();
        let b_packed = PackedMatrix::try_pack(&b).unwrap();
        let serial = a.euclideans_packed(&b_packed);
        let parallel = a.euclideans_packed_parallel(&b_packed, &mut pool);
        assert_eq!(
            serial.as_slice(),
            parallel.as_slice(),
            "serial != parallel @ ({height},{width},{depth})"
        );
        let mut into_span = Tensor::<Scalar::SpatialResult>::try_full(
            &[height, width],
            Scalar::SpatialResult::default(),
        )
        .unwrap();
        a.try_euclideans_packed_parallel_into(&b_packed, &mut into_span.span(), &mut pool)
            .unwrap();
        assert_eq!(
            serial.as_slice(),
            into_span.as_slice(),
            "_parallel_into(span)"
        );
    }
}
/// For every `DIMS` entry, using `u1x8(0xFF)`-filled matrices: checks that the parallel
/// packed Hamming and Jaccard entry points (allocating and `_into(span)`) on a 4-thread pool
/// return exactly the same values as their serial counterparts.
#[cfg(feature = "parallel")]
fn check_hammings_packed_parallel_u1() {
    init_thread();
    let mut pool = fork_union::ThreadPool::try_spawn(4).unwrap();
    for &(height, width, depth) in DIMS {
        let depth = align_depth::<u1x8>(depth);
        let lhs = Tensor::<u1x8>::try_full(&[height, depth], u1x8(0xFF)).unwrap();
        let rhs = Tensor::<u1x8>::try_full(&[width, depth], u1x8(0xFF)).unwrap();
        let rhs_packed = PackedMatrix::try_pack(&rhs).unwrap();

        // Hamming: allocating parallel variant, then the span-output variant.
        let serial_h = lhs.hammings_packed(&rhs_packed);
        let parallel_h = lhs.hammings_packed_parallel(&rhs_packed, &mut pool);
        assert_eq!(
            serial_h.as_slice(),
            parallel_h.as_slice(),
            "hammings @ ({height},{width},{depth})"
        );
        let mut span_buf_h = Tensor::<u32>::try_full(&[height, width], 0u32).unwrap();
        lhs.try_hammings_packed_parallel_into(&rhs_packed, &mut span_buf_h.span(), &mut pool)
            .unwrap();
        assert_eq!(
            serial_h.as_slice(),
            span_buf_h.as_slice(),
            "hammings _parallel_into(span)"
        );

        // Jaccard: allocating parallel variant, then the span-output variant.
        let serial_j = lhs.jaccards_packed(&rhs_packed);
        let parallel_j = lhs.jaccards_packed_parallel(&rhs_packed, &mut pool);
        assert_eq!(
            serial_j.as_slice(),
            parallel_j.as_slice(),
            "jaccards @ ({height},{width},{depth})"
        );
        let mut span_buf_j = Tensor::<f32>::try_full(&[height, width], 0.0f32).unwrap();
        lhs.try_jaccards_packed_parallel_into(&rhs_packed, &mut span_buf_j.span(), &mut pool)
            .unwrap();
        assert_eq!(
            serial_j.as_slice(),
            span_buf_j.as_slice(),
            "jaccards _parallel_into(span)"
        );
    }
}
/// For every `DIMS` entry (first and third fields), using all-ones vectors: checks that the
/// parallel symmetric dots / angulars / euclideans entry points (allocating and
/// `_into(span)`) on a 4-thread pool agree with the serial view-based results on the upper
/// triangle.
#[cfg(feature = "parallel")]
fn check_symmetric_parallel<Scalar: TestableType + Dots + Angulars + Euclideans + Send + Sync>()
where
    Scalar::Accumulator:
        Clone + Default + Copy + PartialEq + core::fmt::Debug + Send + Sync + 'static,
    <Scalar as Angulars>::SpatialResult:
        Clone + Default + Copy + PartialEq + core::fmt::Debug + Send + Sync,
    <Scalar as Euclideans>::SpatialResult:
        Clone + Default + Copy + PartialEq + core::fmt::Debug + Send + Sync,
{
    init_thread();
    let mut pool = fork_union::ThreadPool::try_spawn(4).unwrap();
    for &(num_vectors, _, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let vectors = Tensor::<Scalar>::try_full(&[num_vectors, depth], Scalar::one()).unwrap();
        let square = [num_vectors, num_vectors];

        // Dots.
        let serial_d = vectors.view().try_dots_symmetric().unwrap();
        let parallel_d = vectors.dots_symmetric_parallel(&mut pool);
        assert_upper_triangle_eq(
            serial_d.as_slice(),
            parallel_d.as_slice(),
            num_vectors,
            "dots_symmetric_parallel",
        );
        let mut span_buf_d =
            Tensor::<Scalar::Accumulator>::try_full(&square, Scalar::Accumulator::default())
                .unwrap();
        vectors
            .try_dots_symmetric_parallel_into(&mut span_buf_d.span(), &mut pool)
            .unwrap();
        assert_upper_triangle_eq(
            serial_d.as_slice(),
            span_buf_d.as_slice(),
            num_vectors,
            "dots_symmetric_parallel_into(span)",
        );

        // Angulars.
        let serial_a = vectors.view().try_angulars_symmetric().unwrap();
        let parallel_a = vectors.angulars_symmetric_parallel(&mut pool);
        assert_upper_triangle_eq(
            serial_a.as_slice(),
            parallel_a.as_slice(),
            num_vectors,
            "angulars_symmetric_parallel",
        );
        let mut span_buf_a = Tensor::<<Scalar as Angulars>::SpatialResult>::try_full(
            &square,
            <Scalar as Angulars>::SpatialResult::default(),
        )
        .unwrap();
        vectors
            .try_angulars_symmetric_parallel_into(&mut span_buf_a.span(), &mut pool)
            .unwrap();
        assert_upper_triangle_eq(
            serial_a.as_slice(),
            span_buf_a.as_slice(),
            num_vectors,
            "angulars_symmetric_parallel_into(span)",
        );

        // Euclideans.
        let serial_e = vectors.view().try_euclideans_symmetric().unwrap();
        let parallel_e = vectors.euclideans_symmetric_parallel(&mut pool);
        assert_upper_triangle_eq(
            serial_e.as_slice(),
            parallel_e.as_slice(),
            num_vectors,
            "euclideans_symmetric_parallel",
        );
        let mut span_buf_e = Tensor::<<Scalar as Euclideans>::SpatialResult>::try_full(
            &square,
            <Scalar as Euclideans>::SpatialResult::default(),
        )
        .unwrap();
        vectors
            .try_euclideans_symmetric_parallel_into(&mut span_buf_e.span(), &mut pool)
            .unwrap();
        assert_upper_triangle_eq(
            serial_e.as_slice(),
            span_buf_e.as_slice(),
            num_vectors,
            "euclideans_symmetric_parallel_into(span)",
        );
    }
}
/// For every `DIMS` entry (first and third fields), using `u1x8(0xFF)`-filled vectors: checks
/// that the parallel symmetric Hamming and Jaccard entry points (allocating and
/// `_into(span)`) on a 4-thread pool agree with the serial view-based results on the upper
/// triangle.
#[cfg(feature = "parallel")]
fn check_symmetric_parallel_u1() {
    init_thread();
    let mut pool = fork_union::ThreadPool::try_spawn(4).unwrap();
    for &(num_vectors, _, depth) in DIMS {
        let depth = align_depth::<u1x8>(depth);
        let vectors = Tensor::<u1x8>::try_full(&[num_vectors, depth], u1x8(0xFF)).unwrap();
        let square = [num_vectors, num_vectors];

        // Hamming.
        let serial_h = vectors.view().try_hammings_symmetric().unwrap();
        let parallel_h = vectors.hammings_symmetric_parallel(&mut pool);
        assert_upper_triangle_eq(
            serial_h.as_slice(),
            parallel_h.as_slice(),
            num_vectors,
            "hammings_symmetric_parallel",
        );
        let mut span_buf_h = Tensor::<u32>::try_full(&square, 0u32).unwrap();
        vectors
            .try_hammings_symmetric_parallel_into(&mut span_buf_h.span(), &mut pool)
            .unwrap();
        assert_upper_triangle_eq(
            serial_h.as_slice(),
            span_buf_h.as_slice(),
            num_vectors,
            "hammings_symmetric_parallel_into(span)",
        );

        // Jaccard.
        let serial_j = vectors.view().try_jaccards_symmetric().unwrap();
        let parallel_j = vectors.jaccards_symmetric_parallel(&mut pool);
        assert_upper_triangle_eq(
            serial_j.as_slice(),
            parallel_j.as_slice(),
            num_vectors,
            "jaccards_symmetric_parallel",
        );
        let mut span_buf_j = Tensor::<f32>::try_full(&square, 0.0f32).unwrap();
        vectors
            .try_jaccards_symmetric_parallel_into(&mut span_buf_j.span(), &mut pool)
            .unwrap();
        assert_upper_triangle_eq(
            serial_j.as_slice(),
            span_buf_j.as_slice(),
            num_vectors,
            "jaccards_symmetric_parallel_into(span)",
        );
    }
}
/// Runs `check_dots_packed` once per scalar type, in the listed order.
#[test]
fn dots_packed() {
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $(check_dots_packed::<$scalar>();)+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8, i4x2, u4x2);
}
/// Runs `check_dots_packed_transposed` once per scalar type, in the listed order.
#[test]
fn dots_packed_transposed() {
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $(check_dots_packed_transposed::<$scalar>();)+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8);
}
/// Runs `check_angulars_packed` once per scalar type, in the listed order.
#[test]
fn angulars_packed() {
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $(check_angulars_packed::<$scalar>();)+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8, i4x2, u4x2);
}
/// Runs `check_euclideans_packed` once per scalar type, in the listed order.
#[test]
fn euclideans_packed() {
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $(check_euclideans_packed::<$scalar>();)+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8, i4x2, u4x2);
}
// Serial-vs-parallel agreement for the packed kernels; only built with the `parallel` feature.
#[test]
#[cfg(feature = "parallel")]
fn packed_parallel() {
// Dot products are exercised for both `f32` and `bf16`.
check_dots_packed_parallel::<f32>();
check_dots_packed_parallel::<bf16>();
// Spatial kernels are exercised for `f32` only.
check_angulars_packed_parallel::<f32>();
check_euclideans_packed_parallel::<f32>();
// Hamming and Jaccard are both covered by the `u1x8` check.
check_hammings_packed_parallel_u1();
}
/// Runs the parallel variants of the symmetric-kernel checks.
///
/// Only compiled when the `parallel` cargo feature is enabled. Covers the generic helper
/// instantiated for `f32` and the dedicated `u1` helper.
#[test]
#[cfg(feature = "parallel")]
fn symmetric_parallel() {
check_symmetric_parallel::<f32>();
check_symmetric_parallel_u1();
}
/// Asserts that two row-major `n x n` buffers hold equal values on the upper triangle.
///
/// Only positions `(i, j)` with `j >= i` (diagonal included) are compared; entries below the
/// diagonal are never read, so they may differ freely.
///
/// # Arguments
/// * `left`, `right` - flat buffers indexed as `i * n + j`; each must hold at least `n * n`
///   elements.
/// * `n` - side length of the square matrix.
/// * `tag` - label placed at the start of every failure message to identify the caller.
///
/// # Panics
/// * If either buffer holds fewer than `n * n` elements. This is checked up front so the
///   failure names the tag and the offending side instead of surfacing as a bare
///   "index out of bounds" panic from inside the comparison loop.
/// * If any compared pair differs; the message starts with `{tag}[{i},{j}]`.
fn assert_upper_triangle_eq<X: Copy + PartialEq + core::fmt::Debug>(
    left: &[X],
    right: &[X],
    n: usize,
    tag: &str,
) {
    let required = n * n;
    assert!(
        left.len() >= required,
        "{tag}: left has {} elements, expected at least {required} for a {n}x{n} matrix",
        left.len()
    );
    assert!(
        right.len() >= required,
        "{tag}: right has {} elements, expected at least {required} for a {n}x{n} matrix",
        right.len()
    );
    for i in 0..n {
        // Start at the diagonal: columns `j < i` belong to the lower triangle and are skipped.
        for j in i..n {
            let index = i * n + j;
            assert_eq!(left[index], right[index], "{tag}[{i},{j}]");
        }
    }
}
/// Checks the symmetric dot-product entry points for one scalar type.
///
/// For every `(num_vectors, _, depth)` triple in `DIMS` (the middle field is unused and the
/// depth is first passed through `align_depth::<Scalar>`), a `num_vectors x depth` tensor
/// filled with `Scalar::one()` is built and three paths are exercised:
///
/// 1. `try_dots_symmetric` on a view: the result must have shape
///    `[num_vectors, num_vectors]` and every upper-triangle entry must be within
///    `atol + rtol * |depth|` of `depth`.
/// 2. `try_dots_symmetric_into` called on the tensor with a tensor output.
/// 3. `try_dots_symmetric_into` called on a view with a span output.
///
/// Paths 2 and 3 must match path 1 exactly on the upper triangle.
/// NOTE(review): only the upper triangle is ever inspected — presumably the kernels do not
/// guarantee the lower triangle; confirm against the kernel documentation.
fn check_dots_symmetric<Scalar: TestableType + Dots>()
where
    Scalar::Accumulator:
        Clone + Default + Copy + FloatLike + PartialEq + core::fmt::Debug + 'static,
{
    init_thread();
    for &(num_vectors, _num_targets, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let gram_shape = [num_vectors, num_vectors];
        let ones = Tensor::<Scalar>::try_full(&[num_vectors, depth], Scalar::one()).unwrap();

        // Path 1: allocating variant, used as the reference for the `_into` variants below.
        let reference = ones.view().try_dots_symmetric().unwrap();
        assert_eq!(
            reference.shape(),
            &gram_shape,
            "shape @ ({num_vectors},{depth})"
        );
        let expected = depth as f64;
        let tolerance = Scalar::atol() + Scalar::rtol() * expected.abs();
        let products = reference.as_slice();
        let upper_triangle =
            (0..num_vectors).flat_map(move |i| (i..num_vectors).map(move |j| (i, j)));
        for (i, j) in upper_triangle {
            let value = products[i * num_vectors + j];
            let actual = value.to_f64();
            assert!(
                (actual - expected).abs() <= tolerance,
                "({num_vectors},{depth})[{i},{j}]: {} vs {expected}",
                actual
            );
        }

        // Fresh default-initialized output buffer of the Gram-matrix shape.
        let zeroed_output = || {
            Tensor::<Scalar::Accumulator>::try_full(&gram_shape, Scalar::Accumulator::default())
                .unwrap()
        };

        // Path 2: tensor receiver, tensor output.
        let mut via_tensor = zeroed_output();
        ones.try_dots_symmetric_into(&mut via_tensor).unwrap();
        assert_upper_triangle_eq(
            reference.as_slice(),
            via_tensor.as_slice(),
            num_vectors,
            "dots_symmetric_into(Tensor)",
        );

        // Path 3: view receiver, span output.
        let mut via_span = zeroed_output();
        ones.view()
            .try_dots_symmetric_into(&mut via_span.span())
            .unwrap();
        assert_upper_triangle_eq(
            reference.as_slice(),
            via_span.as_slice(),
            num_vectors,
            "dots_symmetric_into(span)",
        );
    }
}
/// Checks the symmetric angular-distance entry points for one scalar type.
///
/// For every `(num_vectors, _, depth)` triple in `DIMS` (depth passed through
/// `align_depth::<Scalar>`), a `num_vectors x depth` tensor filled with `Scalar::one()` is
/// built. Since all rows are identical, every upper-triangle entry of the result is expected
/// to have magnitude at most `Scalar::atol()`.
///
/// The allocating `try_angulars_symmetric` result is then used as the reference for
/// `try_angulars_symmetric_into`, once with a tensor receiver and tensor output and once with
/// a view receiver and span output; both must match it exactly on the upper triangle.
fn check_angulars_symmetric<Scalar: TestableType + Angulars>()
where
    Scalar::SpatialResult:
        Clone + Default + Copy + FloatLike + PartialEq + core::fmt::Debug + 'static,
{
    init_thread();
    let tolerance = Scalar::atol();
    for &(num_vectors, _num_targets, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let gram_shape = [num_vectors, num_vectors];
        let ones = Tensor::<Scalar>::try_full(&[num_vectors, depth], Scalar::one()).unwrap();

        // Allocating variant: shape plus near-zero check on the upper triangle.
        let reference = ones.view().try_angulars_symmetric().unwrap();
        assert_eq!(reference.shape(), &gram_shape);
        let distances = reference.as_slice();
        let upper_triangle =
            (0..num_vectors).flat_map(move |i| (i..num_vectors).map(move |j| (i, j)));
        for (i, j) in upper_triangle {
            let value = distances[i * num_vectors + j];
            let actual = value.to_f64();
            assert!(
                actual.abs() <= tolerance,
                "angular symmetric [{i},{j}]: {}",
                actual
            );
        }

        // Fresh default-initialized output buffer of the Gram-matrix shape.
        let zeroed_output = || {
            Tensor::<Scalar::SpatialResult>::try_full(
                &gram_shape,
                Scalar::SpatialResult::default(),
            )
            .unwrap()
        };

        // Tensor receiver, tensor output.
        let mut via_tensor = zeroed_output();
        ones.try_angulars_symmetric_into(&mut via_tensor).unwrap();
        assert_upper_triangle_eq(
            reference.as_slice(),
            via_tensor.as_slice(),
            num_vectors,
            "angulars_symmetric_into(Tensor)",
        );

        // View receiver, span output.
        let mut via_span = zeroed_output();
        ones.view()
            .try_angulars_symmetric_into(&mut via_span.span())
            .unwrap();
        assert_upper_triangle_eq(
            reference.as_slice(),
            via_span.as_slice(),
            num_vectors,
            "angulars_symmetric_into(span)",
        );
    }
}
/// Checks the symmetric Euclidean-distance entry points for one scalar type.
///
/// For every `(num_vectors, _, depth)` triple in `DIMS` (depth passed through
/// `align_depth::<Scalar>`), a `num_vectors x depth` tensor filled with `Scalar::one()` is
/// built. Since all rows are identical, every upper-triangle entry of the result is expected
/// to have magnitude at most `Scalar::atol()`.
///
/// The allocating `try_euclideans_symmetric` result is then used as the reference for
/// `try_euclideans_symmetric_into`, once with a tensor receiver and tensor output and once
/// with a view receiver and span output; both must match it exactly on the upper triangle.
fn check_euclideans_symmetric<Scalar: TestableType + Euclideans>()
where
    Scalar::SpatialResult:
        Clone + Default + Copy + FloatLike + PartialEq + core::fmt::Debug + 'static,
{
    init_thread();
    let tolerance = Scalar::atol();
    for &(num_vectors, _num_targets, depth) in DIMS {
        let depth = align_depth::<Scalar>(depth);
        let gram_shape = [num_vectors, num_vectors];
        let ones = Tensor::<Scalar>::try_full(&[num_vectors, depth], Scalar::one()).unwrap();

        // Allocating variant: shape plus near-zero check on the upper triangle.
        let reference = ones.view().try_euclideans_symmetric().unwrap();
        assert_eq!(reference.shape(), &gram_shape);
        let distances = reference.as_slice();
        let upper_triangle =
            (0..num_vectors).flat_map(move |i| (i..num_vectors).map(move |j| (i, j)));
        for (i, j) in upper_triangle {
            let value = distances[i * num_vectors + j];
            let actual = value.to_f64();
            assert!(
                actual.abs() <= tolerance,
                "euclidean symmetric [{i},{j}]: {}",
                actual
            );
        }

        // Fresh default-initialized output buffer of the Gram-matrix shape.
        let zeroed_output = || {
            Tensor::<Scalar::SpatialResult>::try_full(
                &gram_shape,
                Scalar::SpatialResult::default(),
            )
            .unwrap()
        };

        // Tensor receiver, tensor output.
        let mut via_tensor = zeroed_output();
        ones.try_euclideans_symmetric_into(&mut via_tensor).unwrap();
        assert_upper_triangle_eq(
            reference.as_slice(),
            via_tensor.as_slice(),
            num_vectors,
            "euclideans_symmetric_into(Tensor)",
        );

        // View receiver, span output.
        let mut via_span = zeroed_output();
        ones.view()
            .try_euclideans_symmetric_into(&mut via_span.span())
            .unwrap();
        assert_upper_triangle_eq(
            reference.as_slice(),
            via_span.as_slice(),
            num_vectors,
            "euclideans_symmetric_into(span)",
        );
    }
}
/// Runs `check_dots_symmetric` once for every scalar type listed below, in the listed order.
#[test]
fn dots_symmetric() {
    // Expands to one `check_dots_symmetric::<T>()` statement per listed type.
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $( check_dots_symmetric::<$scalar>(); )+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8, i4x2, u4x2);
}
/// Runs `check_angulars_symmetric` once for every scalar type listed below, in the listed order.
#[test]
fn angulars_symmetric() {
    // Expands to one `check_angulars_symmetric::<T>()` statement per listed type.
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $( check_angulars_symmetric::<$scalar>(); )+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8, i4x2, u4x2);
}
/// Runs `check_euclideans_symmetric` once for every scalar type listed below, in order.
#[test]
fn euclideans_symmetric() {
    // Expands to one `check_euclideans_symmetric::<T>()` statement per listed type.
    macro_rules! check_each {
        ($($scalar:ty),+ $(,)?) => {
            $( check_euclideans_symmetric::<$scalar>(); )+
        };
    }
    check_each!(f32, f64, f16, bf16, e4m3, e5m2, e2m3, e3m2, i8, u8, i4x2, u4x2);
}
/// Smoke test for the packed kernels on `u1x8` tensors whose every element is `u1x8(0xFF)`.
///
/// A `[4, 64]` tensor is evaluated against a packed `[16, 64]` tensor. For dots, hammings and
/// jaccards the output shape must be `[4, 16]` and the first element must be `64`, `0` and
/// (within `1e-5`) `0` respectively. For hammings and jaccards the `_into` variant writing
/// through a span must reproduce the allocating result exactly.
#[test]
fn binary_packed_u1() {
    init_thread();
    let queries = Tensor::<u1x8>::try_full(&[4, 64], u1x8(0xFF)).unwrap();
    let targets = Tensor::<u1x8>::try_full(&[16, 64], u1x8(0xFF)).unwrap();
    let packed_targets = PackedMatrix::try_pack(&targets).unwrap();
    let output_shape = [4usize, 16];

    // Dot products.
    let dots = queries.dots_packed(&packed_targets);
    assert_eq!(dots.shape(), &output_shape);
    assert_eq!(dots.as_slice()[0], 64);

    // Hamming distances: allocating variant, then the span-output variant.
    let hammings = queries.hammings_packed(&packed_targets);
    assert_eq!(hammings.shape(), &output_shape);
    assert_eq!(hammings.as_slice()[0], 0);
    let mut hammings_into = Tensor::<u32>::try_full(&output_shape, 0u32).unwrap();
    queries
        .try_hammings_packed_into(&packed_targets, &mut hammings_into.span())
        .unwrap();
    assert_eq!(hammings.as_slice(), hammings_into.as_slice());

    // Jaccard distances: allocating variant, then the span-output variant.
    let jaccards = queries.jaccards_packed(&packed_targets);
    assert_eq!(jaccards.shape(), &output_shape);
    assert!(jaccards.as_slice()[0].abs() < 1e-5);
    let mut jaccards_into = Tensor::<f32>::try_full(&output_shape, 0.0f32).unwrap();
    queries
        .try_jaccards_packed_into(&packed_targets, &mut jaccards_into.span())
        .unwrap();
    assert_eq!(jaccards.as_slice(), jaccards_into.as_slice());
}
/// Smoke test for the symmetric kernels on a `[4, 64]` `u1x8` tensor filled with `u1x8(0xFF)`.
///
/// For dots, hammings and jaccards the result shape must be `[4, 4]` and the first element
/// must be `64`, `0` and (within `1e-5`) `0` respectively. For hammings and jaccards the
/// `_into` variant (view receiver, span output) must agree with the allocating result on the
/// upper triangle.
#[test]
fn binary_symmetric_u1() {
    init_thread();
    let vectors = Tensor::<u1x8>::try_full(&[4, 64], u1x8(0xFF)).unwrap();
    let gram_shape = [4usize, 4];

    // Dot products.
    let dots = vectors.view().try_dots_symmetric().unwrap();
    assert_eq!(dots.shape(), &gram_shape);
    assert_eq!(dots.as_slice()[0], 64);

    // Hamming distances: allocating variant, then the span-output variant.
    let hammings = vectors.try_hammings_symmetric().unwrap();
    assert_eq!(hammings.shape(), &gram_shape);
    assert_eq!(hammings.as_slice()[0], 0);
    let mut hammings_into = Tensor::<u32>::try_full(&gram_shape, 0u32).unwrap();
    vectors
        .view()
        .try_hammings_symmetric_into(&mut hammings_into.span())
        .unwrap();
    assert_upper_triangle_eq(hammings.as_slice(), hammings_into.as_slice(), 4, "hammings");

    // Jaccard distances: allocating variant, then the span-output variant.
    let jaccards = vectors.try_jaccards_symmetric().unwrap();
    assert_eq!(jaccards.shape(), &gram_shape);
    assert!(jaccards.as_slice()[0].abs() < 1e-5);
    let mut jaccards_into = Tensor::<f32>::try_full(&gram_shape, 0.0f32).unwrap();
    vectors
        .view()
        .try_jaccards_symmetric_into(&mut jaccards_into.span())
        .unwrap();
    assert_upper_triangle_eq(jaccards.as_slice(), jaccards_into.as_slice(), 4, "jaccards");
}
}