extern crate alloc;
use core::marker::PhantomData;
use core::ptr::NonNull;
use crate::tensor::{Allocator, Global, TensorError, TensorView, SIMD_ALIGNMENT};
use crate::types::{bf16, f16, StorageElement};
// Raw FFI bindings into the native `numkong` library. Each scalar type
// (f32 / f16 / bf16) gets three entry points, surfaced through the safe
// `MaxSim` trait below:
//   * `nk_maxsim_packed_size_*` — byte size of the packed buffer needed for
//     a `vector_count x depth` matrix,
//   * `nk_maxsim_pack_*`        — repack row-major vectors into that buffer,
//   * `nk_maxsim_packed_*`      — score two packed buffers, writing one
//     accumulated value through `result`.
// NOTE(review): the packed layout and the units of `stride` (presumably
// bytes, given `row_stride_bytes` at the call sites) are defined by the C
// side — confirm against the numkong headers.
#[link(name = "numkong")]
extern "C" {
    // --- f32: scores accumulate in f64 ---
    fn nk_maxsim_packed_size_f32(vector_count: usize, depth: usize) -> usize;
    fn nk_maxsim_pack_f32(
        v: *const f32,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    );
    fn nk_maxsim_packed_f32(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut f64,
    );
    // --- f16: scores accumulate in f32 ---
    fn nk_maxsim_packed_size_f16(vector_count: usize, depth: usize) -> usize;
    fn nk_maxsim_pack_f16(
        v: *const f16,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    );
    fn nk_maxsim_packed_f16(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut f32,
    );
    // --- bf16: scores accumulate in f32 ---
    fn nk_maxsim_packed_size_bf16(vector_count: usize, depth: usize) -> usize;
    fn nk_maxsim_pack_bf16(
        v: *const bf16,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    );
    fn nk_maxsim_packed_bf16(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut f32,
    );
}
/// Scalar types that support the packed MaxSim similarity kernels.
///
/// Implementations forward to the matching `nk_maxsim_*` FFI entry points.
/// `Score` is the accumulator type the native kernel writes: `f64` for `f32`
/// inputs, `f32` for the half-precision types (see the impls below).
pub trait MaxSim: StorageElement + Clone {
    /// Output type of [`Self::maxsim_packed`]; `Default` supplies the zero
    /// value used to initialize the result slot.
    type Score: Clone + Default;

    /// Returns the number of bytes the packed representation of a
    /// `vector_count x depth` matrix requires.
    fn maxsim_packed_size(vector_count: usize, depth: usize) -> usize;

    /// Repacks `vector_count` vectors of `depth` elements into `packed`.
    ///
    /// # Safety
    /// `vectors` must be valid for reading `vector_count` rows separated by
    /// `stride` (bytes, per the call site in `try_pack_in` — TODO confirm
    /// against the C headers), and `packed` must be valid for writing
    /// `maxsim_packed_size(vector_count, depth)` bytes.
    unsafe fn maxsim_pack(
        vectors: *const Self,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    );

    /// Scores two packed buffers, writing the result through `result`.
    ///
    /// # Safety
    /// `q` and `d` must point to buffers produced by [`Self::maxsim_pack`]
    /// for the given counts and `depth`; `result` must be valid for writing
    /// one `Self::Score`.
    unsafe fn maxsim_packed(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut Self::Score,
    );
}
impl MaxSim for f32 {
    /// Single-precision inputs accumulate their score in double precision.
    type Score = f64;

    fn maxsim_packed_size(vector_count: usize, depth: usize) -> usize {
        // SAFETY: pure size query; the native call dereferences no pointers.
        unsafe { nk_maxsim_packed_size_f32(vector_count, depth) }
    }

    unsafe fn maxsim_pack(
        vectors: *const Self,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds the pointer/size contract documented
        // on the trait.
        unsafe { nk_maxsim_pack_f32(vectors, vector_count, depth, stride, packed) }
    }

    unsafe fn maxsim_packed(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut Self::Score,
    ) {
        // SAFETY: the caller upholds the pointer/size contract documented
        // on the trait.
        unsafe { nk_maxsim_packed_f32(q, d, query_count, document_count, depth, result) }
    }
}
impl MaxSim for f16 {
    /// Half-precision inputs accumulate their score in single precision.
    type Score = f32;

    fn maxsim_packed_size(vector_count: usize, depth: usize) -> usize {
        // SAFETY: pure size query; the native call dereferences no pointers.
        unsafe { nk_maxsim_packed_size_f16(vector_count, depth) }
    }

    unsafe fn maxsim_pack(
        vectors: *const Self,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds the pointer/size contract documented
        // on the trait.
        unsafe { nk_maxsim_pack_f16(vectors, vector_count, depth, stride, packed) }
    }

    unsafe fn maxsim_packed(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut Self::Score,
    ) {
        // SAFETY: the caller upholds the pointer/size contract documented
        // on the trait.
        unsafe { nk_maxsim_packed_f16(q, d, query_count, document_count, depth, result) }
    }
}
impl MaxSim for bf16 {
    /// Brain-float inputs accumulate their score in single precision.
    type Score = f32;

    fn maxsim_packed_size(vector_count: usize, depth: usize) -> usize {
        // SAFETY: pure size query; the native call dereferences no pointers.
        unsafe { nk_maxsim_packed_size_bf16(vector_count, depth) }
    }

    unsafe fn maxsim_pack(
        vectors: *const Self,
        vector_count: usize,
        depth: usize,
        stride: usize,
        packed: *mut u8,
    ) {
        // SAFETY: the caller upholds the pointer/size contract documented
        // on the trait.
        unsafe { nk_maxsim_pack_bf16(vectors, vector_count, depth, stride, packed) }
    }

    unsafe fn maxsim_packed(
        q: *const u8,
        d: *const u8,
        query_count: usize,
        document_count: usize,
        depth: usize,
        result: *mut Self::Score,
    ) {
        // SAFETY: the caller upholds the pointer/size contract documented
        // on the trait.
        unsafe { nk_maxsim_packed_bf16(q, d, query_count, document_count, depth, result) }
    }
}
/// Owned, SIMD-aligned buffer holding a `vector_count x depth` matrix
/// repacked into the native MaxSim layout by the `nk_maxsim_pack_*` kernels.
pub struct MaxSimPackedMatrix<Scalar: MaxSim, Alloc: Allocator = Global> {
    /// Packed bytes. Dangling (and never dereferenced) when `size == 0`.
    data: NonNull<u8>,
    /// Allocation size in bytes; `0` means no allocation was made.
    size: usize,
    /// Number of packed vectors (rows).
    vector_count: usize,
    /// Components per vector (columns).
    depth: usize,
    /// Allocator that produced `data`; used again to free it in `Drop`.
    alloc: Alloc,
    /// The buffer is type-erased to bytes; remember the logical scalar type.
    _marker: PhantomData<Scalar>,
}
// SAFETY: the matrix exclusively owns its heap buffer through `data`
// (a raw `NonNull<u8>` prevents the auto-impl), so it may move to another
// thread whenever the scalar and allocator are themselves `Send`.
unsafe impl<Scalar: MaxSim + Send, Alloc: Allocator + Send> Send
    for MaxSimPackedMatrix<Scalar, Alloc>
{
}
// SAFETY: the public API only reads the buffer through `&self`, so shared
// references are safe to use concurrently when `Scalar`/`Alloc` are `Sync`.
unsafe impl<Scalar: MaxSim + Sync, Alloc: Allocator + Sync> Sync
    for MaxSimPackedMatrix<Scalar, Alloc>
{
}
impl<Scalar: MaxSim, Alloc: Allocator> Drop for MaxSimPackedMatrix<Scalar, Alloc> {
    fn drop(&mut self) {
        // A zero-sized matrix never allocated, so there is nothing to free.
        if self.size == 0 {
            return;
        }
        // SAFETY: `data` was obtained from `self.alloc` with exactly this
        // size/alignment pair, and `Layout::from_size_align` validated it at
        // allocation time, so rebuilding the layout unchecked is sound.
        unsafe {
            let layout =
                alloc::alloc::Layout::from_size_align_unchecked(self.size, SIMD_ALIGNMENT);
            self.alloc.deallocate(self.data, layout);
        }
    }
}
impl<Scalar: MaxSim, Alloc: Allocator + Clone> MaxSimPackedMatrix<Scalar, Alloc> {
    /// Deep-copies the packed buffer into a fresh allocation obtained from a
    /// clone of this matrix's allocator.
    ///
    /// # Errors
    /// Returns `TensorError::AllocationFailed` when the layout is invalid or
    /// the allocator cannot provide a buffer.
    pub fn try_clone(&self) -> Result<Self, TensorError> {
        let data = if self.size == 0 {
            // Empty matrices carry no allocation; a dangling pointer is the
            // canonical placeholder and is never dereferenced.
            NonNull::dangling()
        } else {
            let layout = alloc::alloc::Layout::from_size_align(self.size, SIMD_ALIGNMENT)
                .map_err(|_| TensorError::AllocationFailed)?;
            let dst = self
                .alloc
                .allocate(layout)
                .ok_or(TensorError::AllocationFailed)?;
            // SAFETY: source and destination are distinct live allocations,
            // each at least `self.size` bytes long.
            unsafe {
                core::ptr::copy_nonoverlapping(self.data.as_ptr(), dst.as_ptr(), self.size);
            }
            dst
        };
        Ok(Self {
            data,
            size: self.size,
            vector_count: self.vector_count,
            depth: self.depth,
            alloc: self.alloc.clone(),
            _marker: PhantomData,
        })
    }
}
impl<Scalar: MaxSim, Alloc: Allocator + Clone> Clone for MaxSimPackedMatrix<Scalar, Alloc> {
    /// Panicking wrapper around [`MaxSimPackedMatrix::try_clone`]; use the
    /// fallible variant when allocation failure must be recoverable.
    fn clone(&self) -> Self {
        self.try_clone().expect("MaxSimPackedMatrix clone allocation failed")
    }
}
impl<Scalar: MaxSim, Alloc: Allocator> MaxSimPackedMatrix<Scalar, Alloc> {
    /// Packs a rank-2 tensor view (`vector_count x depth` with contiguous
    /// rows) into the native MaxSim layout, allocating from `alloc`.
    ///
    /// # Errors
    /// Propagates the validation errors of `validate_maxsim_view` and
    /// returns `TensorError::AllocationFailed` when the buffer cannot be
    /// obtained.
    pub fn try_pack_in<const MAX_RANK: usize>(
        vectors: &TensorView<'_, Scalar, MAX_RANK>,
        alloc: Alloc,
    ) -> Result<Self, TensorError> {
        let (vector_count, depth, row_stride_bytes) = validate_maxsim_view(vectors)?;
        // The native library decides how many bytes the packed form needs.
        let size = Scalar::maxsim_packed_size(vector_count, depth);
        let data = if size == 0 {
            // No bytes to store: keep a dangling (never-dereferenced) pointer.
            NonNull::dangling()
        } else {
            let layout = alloc::alloc::Layout::from_size_align(size, SIMD_ALIGNMENT)
                .map_err(|_| TensorError::AllocationFailed)?;
            let ptr = alloc
                .allocate(layout)
                .ok_or(TensorError::AllocationFailed)?;
            // SAFETY: `ptr` is a fresh allocation of `size` bytes; the view
            // was validated to have `vector_count` contiguous rows separated
            // by `row_stride_bytes`. Zero-filling first keeps any padding
            // the packer skips deterministic.
            unsafe {
                core::ptr::write_bytes(ptr.as_ptr(), 0, size);
                Scalar::maxsim_pack(
                    vectors.as_ptr(),
                    vector_count,
                    depth,
                    row_stride_bytes,
                    ptr.as_ptr(),
                );
            }
            ptr
        };
        Ok(Self {
            data,
            size,
            vector_count,
            depth,
            alloc,
            _marker: PhantomData,
        })
    }

    /// Computes the MaxSim score of `self` (queries) against `other`
    /// (documents).
    ///
    /// # Errors
    /// Returns `TensorError::DimensionMismatch` when the two matrices were
    /// packed with different depths.
    pub fn try_score<OtherAlloc: Allocator>(
        &self,
        other: &MaxSimPackedMatrix<Scalar, OtherAlloc>,
    ) -> Result<Scalar::Score, TensorError> {
        if self.depth != other.depth {
            return Err(TensorError::DimensionMismatch {
                expected: self.depth,
                got: other.depth,
            });
        }
        // Fix: never hand a dangling pointer to the C kernel. An empty
        // matrix (no packed bytes) contributes nothing, so the score is the
        // zero value of the accumulator type.
        if self.size == 0 || other.size == 0 {
            return Ok(Scalar::Score::default());
        }
        let mut score = Scalar::Score::default();
        // SAFETY: both buffers were produced by `maxsim_pack` for these
        // counts and this depth, stay alive for the duration of the call,
        // and `score` is a valid slot for exactly one `Scalar::Score`.
        unsafe {
            Scalar::maxsim_packed(
                self.as_ptr(),
                other.as_ptr(),
                self.vector_count,
                other.vector_count,
                self.depth,
                &mut score,
            );
        }
        Ok(score)
    }

    /// Panicking wrapper around [`Self::try_score`].
    pub fn score<OtherAlloc: Allocator>(
        &self,
        other: &MaxSimPackedMatrix<Scalar, OtherAlloc>,
    ) -> Scalar::Score {
        self.try_score(other)
            .expect("MaxSimPackedMatrix::score failed")
    }

    /// Borrows the allocator that owns this matrix's buffer.
    pub fn allocator(&self) -> &Alloc {
        &self.alloc
    }

    /// Returns `(vector_count, depth)` as recorded at pack time.
    pub fn dims(&self) -> (usize, usize) {
        (self.vector_count, self.depth)
    }

    /// Views the packed representation as raw bytes (empty when `size == 0`).
    pub fn as_bytes(&self) -> &[u8] {
        // SAFETY: `data` is valid for `size` bytes; for `size == 0` a
        // dangling-but-aligned pointer with length zero is a valid slice.
        unsafe { core::slice::from_raw_parts(self.data.as_ptr(), self.size) }
    }

    /// Raw pointer to the packed bytes (dangling when the matrix is empty).
    pub fn as_ptr(&self) -> *const u8 {
        self.data.as_ptr()
    }
}
impl<Scalar: MaxSim> MaxSimPackedMatrix<Scalar, Global> {
    /// Packs `vectors` using the global allocator.
    pub fn try_pack<const MAX_RANK: usize>(
        vectors: &TensorView<'_, Scalar, MAX_RANK>,
    ) -> Result<Self, TensorError> {
        Self::try_pack_in(vectors, Global)
    }

    /// Panicking wrapper around [`Self::try_pack`].
    pub fn pack<const MAX_RANK: usize>(vectors: &TensorView<'_, Scalar, MAX_RANK>) -> Self {
        Self::try_pack(vectors).expect("MaxSimPackedMatrix::pack failed")
    }
}
/// Validates that `vectors` is usable by the MaxSim packers and returns
/// `(vector_count, depth, row_stride_bytes)`.
///
/// Requirements: exactly two dimensions, contiguous rows (so each vector is
/// densely laid out along the depth axis), and a non-negative row stride.
///
/// # Errors
/// `DimensionMismatch` for non-rank-2 views, `NonContiguousRows` for strided
/// depth axes, and `InvalidShape` for reversed (negative-stride) rows.
fn validate_maxsim_view<Scalar, const MAX_RANK: usize>(
    vectors: &TensorView<'_, Scalar, MAX_RANK>,
) -> Result<(usize, usize, usize), TensorError> {
    if vectors.ndim() != 2 {
        return Err(TensorError::DimensionMismatch {
            expected: 2,
            got: vectors.ndim(),
        });
    }
    if !vectors.has_contiguous_rows() {
        return Err(TensorError::NonContiguousRows);
    }
    let row_stride_bytes = vectors.stride_bytes(0);
    if row_stride_bytes < 0 {
        return Err(TensorError::InvalidShape {
            axis: 0,
            // Fix: report the stride magnitude instead of `as usize`, which
            // wrapped the negative value into a huge meaningless number.
            size: row_stride_bytes.unsigned_abs() as usize,
            reason: "MaxSim requires non-negative row strides",
        });
    }
    Ok((
        vectors.shape()[0],
        vectors.shape()[1],
        row_stride_bytes as usize,
    ))
}
impl<'a, Scalar: MaxSim, const MAX_RANK: usize> TensorView<'a, Scalar, MAX_RANK> {
    /// Packs this view for MaxSim scoring, allocating from `alloc`.
    pub fn try_maxsim_pack_in<Alloc: Allocator>(
        &self,
        alloc: Alloc,
    ) -> Result<MaxSimPackedMatrix<Scalar, Alloc>, TensorError> {
        MaxSimPackedMatrix::try_pack_in(self, alloc)
    }

    /// Packs this view for MaxSim scoring with the global allocator.
    pub fn try_maxsim_pack(&self) -> Result<MaxSimPackedMatrix<Scalar, Global>, TensorError> {
        self.try_maxsim_pack_in(Global)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{SliceRange, Tensor};

    #[test]
    fn maxsim_packs_from_tensor_view() {
        let q = Tensor::<f32>::try_full(&[4, 16], 1.0).unwrap();
        let d = Tensor::<f32>::try_full(&[8, 16], 1.0).unwrap();
        let q_packed = q.view().try_maxsim_pack().unwrap();
        let d_packed = d.view().try_maxsim_pack().unwrap();
        assert_eq!(q_packed.dims(), (4, 16));
        assert_eq!(d_packed.dims(), (8, 16));
        // Uniform all-ones inputs must yield a finite (non-NaN/inf) score.
        assert!(q_packed.score(&d_packed).is_finite());
    }

    #[test]
    fn maxsim_rejects_non_contiguous_depth_axis() {
        // Transposing makes the depth axis strided, which the packer rejects.
        let base = Tensor::<f32>::try_full(&[4, 16], 1.0).unwrap();
        let transposed = base.transpose().unwrap();
        assert!(matches!(
            transposed.try_maxsim_pack(),
            Err(TensorError::NonContiguousRows)
        ));
    }

    #[test]
    fn maxsim_accepts_outer_strided_views() {
        // Skipping rows keeps each row contiguous, so packing must succeed.
        let base = Tensor::<f32>::try_full(&[8, 16], 1.0).unwrap();
        let odd_rows = base
            .slice(&[
                SliceRange::range_step(1, 7, 2),
                SliceRange::range_step(0, 16, 1),
            ])
            .unwrap();
        let packed = odd_rows.try_maxsim_pack().unwrap();
        assert_eq!(packed.dims(), (3, 16));
    }

    #[test]
    fn maxsim_rejects_negative_row_stride() {
        // Reversing the row order produces a negative row stride.
        let base = Tensor::<f32>::try_full(&[8, 16], 1.0).unwrap();
        let reversed = base
            .slice(&[
                SliceRange::range_step(7, 0, -1),
                SliceRange::range_step(0, 16, 1),
            ])
            .unwrap();
        assert!(matches!(
            reversed.try_maxsim_pack(),
            Err(TensorError::InvalidShape { .. })
        ));
    }
}