use core::marker::PhantomData;
use diskann_vector::conversion::SliceCast;
use diskann_wide::Architecture;
use diskann_wide::arch::Target2;
/// Tag trait implemented by layout descriptors.
///
/// The trait has no methods: an implementor only names the scalar type that is
/// stored under the layout it describes.
pub(super) trait Layout {
    /// Scalar type stored under this layout; it is required to be `Copy`.
    type Element: Copy;
}
/// Zero-sized descriptor for a block-transposed layout with element type `T`.
///
/// `GROUP` and `PACK` (the latter defaulting to `1`) exist only at the type
/// level. No value of type `T` is stored; `PhantomData` merely ties the
/// descriptor to `T`.
pub(super) struct BlockTransposed<T, const GROUP: usize, const PACK: usize = 1>(PhantomData<T>);

impl<T, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
    /// Creates the descriptor value.
    pub(super) fn new() -> Self {
        Self(PhantomData::<T>)
    }
}

// `Clone` and `Copy` are implemented by hand instead of derived: a derive would
// require `T: Clone` / `T: Copy`, whereas the descriptor is copyable for any `T`.
impl<T, const GROUP: usize, const PACK: usize> Clone for BlockTransposed<T, GROUP, PACK> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T, const GROUP: usize, const PACK: usize> Copy for BlockTransposed<T, GROUP, PACK> {}
/// A block-transposed descriptor over a `Copy` type `T` stores elements of type `T`.
impl<T, const GROUP: usize, const PACK: usize> Layout for BlockTransposed<T, GROUP, PACK>
where
    T: Copy,
{
    type Element = T;
}
/// Zero-sized descriptor for a row-major layout with element type `T`.
///
/// No value of type `T` is stored; `PhantomData` merely ties the descriptor
/// to `T`.
pub(super) struct RowMajor<T>(PhantomData<T>);

impl<T> RowMajor<T> {
    /// Creates the descriptor value.
    pub(super) fn new() -> Self {
        Self(PhantomData::<T>)
    }
}

// `Clone` and `Copy` are implemented by hand instead of derived: a derive would
// require `T: Clone` / `T: Copy`, whereas the descriptor is copyable for any `T`.
impl<T> Clone for RowMajor<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for RowMajor<T> {}
/// A row-major descriptor over a `Copy` type `T` stores elements of type `T`.
impl<T> Layout for RowMajor<T>
where
    T: Copy,
{
    type Element = T;
}
/// Associates a value with the descriptor of the layout it uses.
pub(super) trait DescribeLayout {
    /// Descriptor type reported for `Self`; it must implement [`Layout`].
    type Layout: Layout;

    /// Returns the layout descriptor for this value.
    fn layout(&self) -> Self::Layout;
}
/// A `BlockTransposedRef` reports the [`BlockTransposed`] descriptor carrying
/// the same element type and the same `GROUP` / `PACK` parameters.
impl<T, const GROUP: usize, const PACK: usize> DescribeLayout
    for crate::multi_vector::BlockTransposedRef<'_, T, GROUP, PACK>
where
    T: Copy,
{
    type Layout = BlockTransposed<T, GROUP, PACK>;

    /// Builds the descriptor; nothing is read from `self`.
    fn layout(&self) -> Self::Layout {
        BlockTransposed::<T, GROUP, PACK>::new()
    }
}
/// A `MatRef` over `Standard<T>` reports the [`RowMajor`] descriptor with the
/// same element type.
impl<T> DescribeLayout for crate::multi_vector::MatRef<'_, crate::multi_vector::Standard<T>>
where
    T: Copy,
{
    type Layout = RowMajor<T>;

    /// Builds the descriptor; nothing is read from `self`.
    fn layout(&self) -> Self::Layout {
        RowMajor::<T>::new()
    }
}
/// Conversion of data described by layout `Self` into data described by layout
/// `To`, parameterised by the architecture `A`.
///
/// # Safety
///
/// NOTE(review): the exact contract implementors must uphold is not written
/// down in this file. The implementations visible here return either `src`
/// itself or a pointer to the start of `buf`; callers presumably rely on that
/// pointer being readable for the converted data — confirm and document.
pub(super) unsafe trait ConvertTo<A, To>: Layout
where
    A: Architecture,
    To: Layout,
{
    /// Scratch state created by `new_buffer` and handed to every `convert` call.
    type Buffer;

    /// Creates the scratch state for conversions of at most `max_tile_rows`
    /// rows. The buffer-backed implementations in this file allocate
    /// `max_tile_rows * k` elements.
    fn new_buffer(&self, max_tile_rows: usize, k: usize) -> Self::Buffer;

    /// Converts the data at `src` and returns a pointer to elements of
    /// `To::Element`.
    ///
    /// # Safety
    ///
    /// NOTE(review): not specified in this file. The buffer-backed
    /// implementations here read `rows * k` elements starting at `src`, so
    /// `src` presumably has to be valid for at least that many reads — confirm
    /// against the callers.
    unsafe fn convert(
        &self,
        buf: &mut Self::Buffer,
        arch: A,
        src: *const Self::Element,
        rows: usize,
        k: usize,
    ) -> *const To::Element;
}
/// Identity conversion: every layout converts to itself without touching the
/// data, so no scratch buffer is needed.
unsafe impl<A, L> ConvertTo<A, L> for L
where
    A: Architecture,
    L: Layout,
{
    /// No scratch state is required.
    type Buffer = ();

    /// Returns the unit buffer; both size arguments are ignored.
    fn new_buffer(&self, _max_tile_rows: usize, _k: usize) -> Self::Buffer {}

    /// Returns `src` unchanged.
    ///
    /// # Safety
    ///
    /// This implementation never dereferences `src`; it only passes the
    /// pointer through.
    unsafe fn convert(
        &self,
        _buf: &mut (),
        _arch: A,
        src: *const L::Element,
        _rows: usize,
        _k: usize,
    ) -> *const L::Element {
        src
    }
}
/// Converts `half::f16` data to `f32` for the block-transposed layout.
///
/// The source is treated as one flat run of `rows * k` values and handed to
/// `SliceCast` together with an equally long prefix of the scratch buffer, so
/// the position of every element is left as it was.
unsafe impl<A, const GROUP: usize, const PACK: usize>
    ConvertTo<A, BlockTransposed<f32, GROUP, PACK>> for BlockTransposed<half::f16, GROUP, PACK>
where
    A: Architecture,
    SliceCast<f32, half::f16>: for<'a> Target2<A, (), &'a mut [f32], &'a [half::f16]>,
{
    /// Scratch storage that receives the `f32` values.
    type Buffer = Vec<f32>;

    /// Allocates a zero-filled buffer of `max_tile_rows * k` elements.
    ///
    /// # Panics
    ///
    /// Panics if `max_tile_rows * k` overflows `usize`. Without the check the
    /// product would wrap in release builds and silently yield an undersized
    /// buffer.
    fn new_buffer(&self, max_tile_rows: usize, k: usize) -> Vec<f32> {
        let len = max_tile_rows
            .checked_mul(k)
            .expect("conversion buffer size (max_tile_rows * k) overflows usize");
        vec![0.0f32; len]
    }

    /// Converts `rows * k` values starting at `src` into the front of `buf`
    /// and returns `buf.as_ptr()`.
    ///
    /// The returned pointer refers to `buf`'s storage, so it is only usable
    /// while `buf` is neither reallocated nor dropped.
    ///
    /// # Safety
    ///
    /// `src` must be non-null, aligned for `half::f16`, and valid for reads of
    /// `rows * k` initialised `half::f16` values that are not mutated for the
    /// duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if `rows * k` overflows `usize` or exceeds `buf.len()`.
    unsafe fn convert(
        &self,
        buf: &mut Vec<f32>,
        arch: A,
        src: *const half::f16,
        rows: usize,
        k: usize,
    ) -> *const f32 {
        let count = rows
            .checked_mul(k)
            .expect("tile element count (rows * k) overflows usize");
        // Slice the destination first so an undersized buffer panics before a
        // reference to the source memory is created.
        let dst = &mut buf[..count];
        // SAFETY: the caller guarantees that `src` is non-null, aligned, and
        // valid for reads of `rows * k` initialised `half::f16` values that
        // stay unmodified while this shared slice is alive.
        let src_slice = unsafe { std::slice::from_raw_parts(src, count) };
        arch.run2(SliceCast::new(), dst, src_slice);
        buf.as_ptr()
    }
}
/// Converts `half::f16` data to `f32` for the row-major layout.
///
/// The source is treated as one flat run of `rows * k` values and handed to
/// `SliceCast` together with an equally long prefix of the scratch buffer, so
/// the position of every element is left as it was.
unsafe impl<A> ConvertTo<A, RowMajor<f32>> for RowMajor<half::f16>
where
    A: Architecture,
    SliceCast<f32, half::f16>: for<'a> Target2<A, (), &'a mut [f32], &'a [half::f16]>,
{
    /// Scratch storage that receives the `f32` values.
    type Buffer = Vec<f32>;

    /// Allocates a zero-filled buffer of `max_tile_rows * k` elements.
    ///
    /// # Panics
    ///
    /// Panics if `max_tile_rows * k` overflows `usize`. Without the check the
    /// product would wrap in release builds and silently yield an undersized
    /// buffer.
    fn new_buffer(&self, max_tile_rows: usize, k: usize) -> Vec<f32> {
        let len = max_tile_rows
            .checked_mul(k)
            .expect("conversion buffer size (max_tile_rows * k) overflows usize");
        vec![0.0f32; len]
    }

    /// Converts `rows * k` values starting at `src` into the front of `buf`
    /// and returns `buf.as_ptr()`.
    ///
    /// The returned pointer refers to `buf`'s storage, so it is only usable
    /// while `buf` is neither reallocated nor dropped.
    ///
    /// # Safety
    ///
    /// `src` must be non-null, aligned for `half::f16`, and valid for reads of
    /// `rows * k` initialised `half::f16` values that are not mutated for the
    /// duration of the call.
    ///
    /// # Panics
    ///
    /// Panics if `rows * k` overflows `usize` or exceeds `buf.len()`.
    unsafe fn convert(
        &self,
        buf: &mut Vec<f32>,
        arch: A,
        src: *const half::f16,
        rows: usize,
        k: usize,
    ) -> *const f32 {
        let count = rows
            .checked_mul(k)
            .expect("tile element count (rows * k) overflows usize");
        // Slice the destination first so an undersized buffer panics before a
        // reference to the source memory is created.
        let dst = &mut buf[..count];
        // SAFETY: the caller guarantees that `src` is non-null, aligned, and
        // valid for reads of `rows * k` initialised `half::f16` values that
        // stay unmodified while this shared slice is alive.
        let src_slice = unsafe { std::slice::from_raw_parts(src, count) };
        arch.run2(SliceCast::new(), dst, src_slice);
        buf.as_ptr()
    }
}