use crate::directories::{FileHandle, OwnedBytes};
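/// Number of blocks summarized by one superblock-grid entry.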
pub const BMP_SUPERBLOCK_SIZE: u32 = 64;
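/// One posting within a block: the document's local slot inside the block and its
/// quantized impact. Laid out as two consecutive bytes in the block data section.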
#[derive(Clone, Copy)]
#[repr(C)]
pub struct BmpPosting {
pub local_slot: u8,
pub impact: u8,
}
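/// Reads the `idx`-th little-endian `u32` from `base` without bounds checking.
///
/// # Safety
/// `base` must point to at least `(idx + 1) * 4` readable bytes.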
#[inline(always)]
unsafe fn read_u32_unchecked(base: *const u8, idx: usize) -> u32 {
unsafe {
let p = base.add(idx * 4);
u32::from_le((p as *const u32).read_unaligned())
}
}
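/// Reads the `idx`-th little-endian `u64` from `base` without bounds checking.
///
/// # Safety
/// `base` must point to at least `(idx + 1) * 8` readable bytes.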
#[inline(always)]
unsafe fn read_u64_unchecked(base: *const u8, idx: usize) -> u64 {
unsafe {
let p = base.add(idx * 8);
u64::from_le((p as *const u64).read_unaligned())
}
}
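/// Zero-copy view over a BMP index blob (format V13).
///
/// The blob is sliced into independent sections: per-block posting data, block start
/// offsets, a packed 4-bit grid with one row per dimension, a coarser per-superblock
/// grid, and a doc map translating virtual document ids to `(doc_id, ordinal)` pairs.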
#[derive(Clone)]
pub struct BmpIndex {
pub bmp_block_size: u32,
pub num_blocks: u32,
pub num_virtual_docs: u32,
pub max_weight_scale: f32,
pub total_vectors: u32,
dims: u32,
total_terms: u32,
total_postings: u32,
packed_row_size: u32,
num_real_docs: u32,
block_data_starts_bytes: OwnedBytes,
block_data_bytes: OwnedBytes,
grid_bytes: OwnedBytes,
sb_grid_bytes: OwnedBytes,
pub num_superblocks: u32,
doc_map_ids_bytes: OwnedBytes,
doc_map_ordinals_bytes: OwnedBytes,
}
impl BmpIndex {
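/// Parses a V13 BMP blob stored at `blob_offset..blob_offset + blob_len` in `handle`.
///
/// The blob ends with a fixed-size footer, validated by its magic, that records the
/// section counts and the offsets of the grid, superblock grid, and doc map.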
pub fn parse(
handle: FileHandle,
blob_offset: u64,
blob_len: u64,
_total_docs: u32,
total_vectors: u32,
) -> crate::Result<Self> {
use crate::segment::format::{BMP_BLOB_FOOTER_SIZE_V13, BMP_BLOB_MAGIC_V13};
if blob_len < BMP_BLOB_FOOTER_SIZE_V13 as u64 {
return Err(crate::Error::Corruption(
"BMP blob too small for V13 footer".into(),
));
}
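// Read and decode the fixed-size footer at the end of the blob; all fields are little-endian.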
let footer_start = blob_offset + blob_len - BMP_BLOB_FOOTER_SIZE_V13 as u64;
let footer_bytes = handle
.read_bytes_range_sync(footer_start..footer_start + BMP_BLOB_FOOTER_SIZE_V13 as u64)
.map_err(crate::Error::Io)?;
let fb = footer_bytes.as_slice();
let total_terms = u32::from_le_bytes(fb[0..4].try_into().unwrap());
let total_postings = u32::from_le_bytes(fb[4..8].try_into().unwrap());
let grid_offset = u64::from_le_bytes(fb[8..16].try_into().unwrap());
let sb_grid_offset = u64::from_le_bytes(fb[16..24].try_into().unwrap());
let num_blocks = u32::from_le_bytes(fb[24..28].try_into().unwrap());
let dims = u32::from_le_bytes(fb[28..32].try_into().unwrap());
let bmp_block_size = u32::from_le_bytes(fb[32..36].try_into().unwrap());
let num_virtual_docs = u32::from_le_bytes(fb[36..40].try_into().unwrap());
let max_weight_scale = f32::from_le_bytes(fb[40..44].try_into().unwrap());
let doc_map_offset = u64::from_le_bytes(fb[44..52].try_into().unwrap());
let num_real_docs = u32::from_le_bytes(fb[52..56].try_into().unwrap());
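// Footer bytes 56..60 are not read here; the magic occupies the final four bytes.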
let magic = u32::from_le_bytes(fb[60..64].try_into().unwrap());
if magic != BMP_BLOB_MAGIC_V13 {
return Err(crate::Error::Corruption(format!(
"Invalid BMP blob magic: {:#x} (expected BMP3 {:#x})",
magic, BMP_BLOB_MAGIC_V13
)));
}
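// An index with no blocks carries no section data; return an empty placeholder.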
if num_blocks == 0 {
return Ok(Self {
bmp_block_size,
num_blocks,
num_virtual_docs,
max_weight_scale,
total_vectors,
dims,
total_terms: 0,
total_postings: 0,
packed_row_size: 0,
num_real_docs,
block_data_starts_bytes: OwnedBytes::empty(),
block_data_bytes: OwnedBytes::empty(),
grid_bytes: OwnedBytes::empty(),
sb_grid_bytes: OwnedBytes::empty(),
num_superblocks: 0,
doc_map_ids_bytes: OwnedBytes::empty(),
doc_map_ordinals_bytes: OwnedBytes::empty(),
});
}
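// Read the data region (everything before the footer) and slice it into sections.
// The block start offsets (one u64 per block plus a sentinel) sit immediately before the grid.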
let data_len = blob_len - BMP_BLOB_FOOTER_SIZE_V13 as u64;
let blob = handle
.read_bytes_range_sync(blob_offset..blob_offset + data_len)
.map_err(crate::Error::Io)?;
let section_a_size = (num_blocks as usize + 1) * 8;
let bds_start = grid_offset as usize - section_a_size;
let block_data_bytes = blob.slice(0..bds_start);
let block_data_starts_bytes = blob.slice(bds_start..grid_offset as usize);
let packed_row_size = (num_blocks as usize).div_ceil(2) as u32;
let grid_start = grid_offset as usize;
let grid_end = grid_start + dims as usize * packed_row_size as usize;
let num_superblocks = num_blocks.div_ceil(BMP_SUPERBLOCK_SIZE);
let sb_grid_start = sb_grid_offset as usize;
let sb_grid_end = sb_grid_start + dims as usize * num_superblocks as usize;
let dm_start = doc_map_offset as usize;
let dm_ids_end = dm_start + num_virtual_docs as usize * 4;
let dm_ords_end = dm_ids_end + num_virtual_docs as usize * 2;
let grid_bytes = blob.slice(grid_start..grid_end);
let sb_grid_bytes = blob.slice(sb_grid_start..sb_grid_end);
let doc_map_ids_bytes = blob.slice(dm_start..dm_ids_end);
let doc_map_ordinals_bytes = blob.slice(dm_ids_end..dm_ords_end);
log::debug!(
"BMP V13 index loaded: num_blocks={}, num_superblocks={}, dims={}, bmp_block_size={}, \
num_virtual_docs={}, num_real_docs={}, max_weight_scale={:.4}, postings={}, \
packed_row_size={}, block_data={}B, doc_map={}B",
num_blocks,
num_superblocks,
dims,
bmp_block_size,
num_virtual_docs,
num_real_docs,
max_weight_scale,
total_postings,
packed_row_size,
bds_start,
num_virtual_docs as usize * 6,
);
Ok(Self {
bmp_block_size,
num_blocks,
num_virtual_docs,
max_weight_scale,
total_vectors,
dims,
total_terms,
total_postings,
packed_row_size,
num_real_docs,
block_data_starts_bytes,
block_data_bytes,
grid_bytes,
sb_grid_bytes,
num_superblocks,
doc_map_ids_bytes,
doc_map_ordinals_bytes,
})
}
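/// Maps a virtual document id to its `(doc_id, ordinal)` pair via the doc map.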
#[inline(always)]
pub fn virtual_to_doc(&self, virtual_id: u32) -> (u32, u16) {
let ids = self.doc_map_ids_bytes.as_slice();
let ords = self.doc_map_ordinals_bytes.as_slice();
debug_assert!((virtual_id as usize + 1) * 4 <= ids.len());
debug_assert!((virtual_id as usize + 1) * 2 <= ords.len());
unsafe {
let doc_id = read_u32_unchecked(ids.as_ptr(), virtual_id as usize);
let p = ords.as_ptr().add(virtual_id as usize * 2);
let ordinal = u16::from_le((p as *const u16).read_unaligned());
(doc_id, ordinal)
}
}
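/// Returns only the document id for the given virtual document id.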
#[inline(always)]
pub fn doc_id_for_virtual(&self, virtual_id: u32) -> u32 {
let d = self.doc_map_ids_bytes.as_slice();
debug_assert!((virtual_id as usize + 1) * 4 <= d.len());
unsafe { read_u32_unchecked(d.as_ptr(), virtual_id as usize) }
}
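/// Returns the `[start, end)` byte range of `block_id` within the block data section.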
#[inline(always)]
pub(crate) fn block_data_range(&self, block_id: u32) -> (u64, u64) {
let d = self.block_data_starts_bytes.as_slice();
debug_assert!((block_id as usize + 2) * 8 <= d.len());
unsafe {
let start = read_u64_unchecked(d.as_ptr(), block_id as usize);
let end = read_u64_unchecked(d.as_ptr(), block_id as usize + 1);
(start, end)
}
}
#[inline(always)]
pub(crate) fn block_data_ptr(&self, block_id: u32) -> *const u8 {
let (start, _) = self.block_data_range(block_id);
unsafe {
self.block_data_bytes
.as_slice()
.as_ptr()
.add(start as usize)
}
}
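/// Decodes a block header: returns the term count together with raw pointers to the
/// dimension ids (one `u32` per term), the posting start offsets (one `u16` per term
/// plus a sentinel), and the packed postings. Empty blocks yield null pointers.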
#[inline(always)]
pub(crate) fn parse_block(&self, block_id: u32) -> (u16, *const u8, *const u8, *const u8) {
let (start, end) = self.block_data_range(block_id);
if start == end {
return (0, std::ptr::null(), std::ptr::null(), std::ptr::null());
}
let base = unsafe {
self.block_data_bytes
.as_slice()
.as_ptr()
.add(start as usize)
};
let num_terms = unsafe { u16::from_le((base as *const u16).read_unaligned()) };
let dim_ptr = unsafe { base.add(2) };
let ps_ptr = unsafe { dim_ptr.add(num_terms as usize * 4) };
let post_ptr = unsafe { ps_ptr.add((num_terms as usize + 1) * 2) };
(num_terms, dim_ptr, ps_ptr, post_ptr)
}
#[inline(always)]
pub(crate) fn block_data_starts_ptr(&self, block_id: u32) -> *const u8 {
unsafe {
self.block_data_starts_bytes
.as_slice()
.as_ptr()
.add(block_id as usize * 8)
}
}
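/// Iterates over `(dim_id, postings)` pairs for every term stored in `block_id`.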
pub fn iter_block_terms(&self, block_id: u32) -> BlockTermIter<'_> {
let (num_terms, dim_ptr, ps_ptr, post_ptr) = self.parse_block(block_id);
BlockTermIter {
dim_ptr,
ps_ptr,
post_ptr,
num_terms,
current: 0,
_marker: std::marker::PhantomData,
}
}
pub fn dims(&self) -> u32 {
self.dims
}
pub fn total_terms(&self) -> u64 {
self.total_terms as u64
}
pub fn total_postings(&self) -> u64 {
self.total_postings as u64
}
pub fn num_real_docs(&self) -> u32 {
self.num_real_docs
}
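/// Approximate memory footprint: the struct itself plus every byte section it holds.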
pub fn estimated_memory_bytes(&self) -> usize {
std::mem::size_of::<Self>()
+ self.block_data_starts_bytes.len()
+ self.block_data_bytes.len()
+ self.grid_bytes.len()
+ self.sb_grid_bytes.len()
+ self.doc_map_ids_bytes.len()
+ self.doc_map_ordinals_bytes.len()
}
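/// Gathers the grid and superblock-grid rows for the given dimensions into contiguous
/// compact buffers, one row per query dimension.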
pub(crate) fn extract_compact_grids(
&self,
dim_indices: &[usize],
compact_sb_grid: &mut Vec<u8>,
compact_grid: &mut Vec<u8>,
) {
let nsb = self.num_superblocks as usize;
let prs = self.packed_row_size as usize;
let nqd = dim_indices.len();
compact_sb_grid.resize(nqd * nsb, 0);
compact_grid.resize(nqd * prs, 0);
let sb_grid = self.sb_grid_bytes.as_slice();
let grid = self.grid_bytes.as_slice();
for (local, &dim_idx) in dim_indices.iter().enumerate() {
compact_sb_grid[local * nsb..(local + 1) * nsb]
.copy_from_slice(&sb_grid[dim_idx * nsb..(dim_idx + 1) * nsb]);
compact_grid[local * prs..(local + 1) * prs]
.copy_from_slice(&grid[dim_idx * prs..(dim_idx + 1) * prs]);
}
}
#[inline]
pub fn packed_row_size(&self) -> usize {
self.packed_row_size as usize
}
#[inline]
pub(crate) fn sb_grid_slice(&self) -> &[u8] {
self.sb_grid_bytes.as_slice()
}
#[inline]
pub fn grid_slice(&self) -> &[u8] {
self.grid_bytes.as_slice()
}
#[inline]
pub fn block_data_slice(&self) -> &[u8] {
self.block_data_bytes.as_slice()
}
#[inline]
pub fn block_data_start(&self, block_id: u32) -> u64 {
let d = self.block_data_starts_bytes.as_slice();
let off = block_id as usize * 8;
u64::from_le_bytes(d[off..off + 8].try_into().unwrap())
}
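/// The sentinel entry past the last block, i.e. the end offset of the block data section.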
#[inline]
pub fn block_data_sentinel(&self) -> u64 {
self.block_data_start(self.num_blocks)
}
#[inline]
pub fn doc_map_ids_slice(&self) -> &[u8] {
self.doc_map_ids_bytes.as_slice()
}
#[inline]
pub fn doc_map_ordinals_slice(&self) -> &[u8] {
self.doc_map_ordinals_bytes.as_slice()
}
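/// Advises the kernel that every mmap-backed section will be read sequentially.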
#[cfg(feature = "native")]
pub fn madvise_sequential(&self) {
Self::madvise_owned(&self.block_data_bytes, libc::MADV_SEQUENTIAL);
Self::madvise_owned(&self.block_data_starts_bytes, libc::MADV_SEQUENTIAL);
Self::madvise_owned(&self.grid_bytes, libc::MADV_SEQUENTIAL);
Self::madvise_owned(&self.sb_grid_bytes, libc::MADV_SEQUENTIAL);
Self::madvise_owned(&self.doc_map_ids_bytes, libc::MADV_SEQUENTIAL);
Self::madvise_owned(&self.doc_map_ordinals_bytes, libc::MADV_SEQUENTIAL);
}
#[cfg(feature = "native")]
pub fn madvise_dontneed_block_data(&self) {
Self::madvise_owned(&self.block_data_bytes, libc::MADV_DONTNEED);
}
#[cfg(feature = "native")]
pub fn madvise_dontneed_grids(&self) {
Self::madvise_owned(&self.grid_bytes, libc::MADV_DONTNEED);
Self::madvise_owned(&self.sb_grid_bytes, libc::MADV_DONTNEED);
}
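// Applies `madvise` to one mmap-backed section, page-aligning the start address first.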
#[cfg(feature = "native")]
fn madvise_owned(bytes: &crate::directories::OwnedBytes, advice: i32) {
if !bytes.is_mmap() {
return;
}
let slice = bytes.as_slice();
if slice.is_empty() {
return;
}
let ptr = slice.as_ptr();
let len = slice.len();
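// madvise requires a page-aligned address; assume 4 KiB pages and round the start down,
// extending the length so the advised range still covers the whole slice.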
let page_size = 4096usize;
let aligned_ptr = (ptr as usize) & !(page_size - 1);
let aligned_len = len + (ptr as usize - aligned_ptr);
unsafe {
libc::madvise(aligned_ptr as *mut libc::c_void, aligned_len, advice);
}
}
}
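/// Iterator over the `(dim_id, postings)` pairs of a single block, produced by
/// [`BmpIndex::iter_block_terms`]. Holds raw pointers into the block data section.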
pub struct BlockTermIter<'a> {
dim_ptr: *const u8,
ps_ptr: *const u8,
post_ptr: *const u8,
num_terms: u16,
current: u16,
_marker: std::marker::PhantomData<&'a ()>,
}
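// Safety: the pointers reference immutable block data owned by the `BmpIndex` the
// iterator borrows from (`'a`), so moving or sharing the iterator across threads is sound.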
unsafe impl<'a> Send for BlockTermIter<'a> {}
unsafe impl<'a> Sync for BlockTermIter<'a> {}
impl<'a> Iterator for BlockTermIter<'a> {
type Item = (u32, &'a [BmpPosting]);
#[inline]
fn next(&mut self) -> Option<Self::Item> {
if self.current >= self.num_terms {
return None;
}
let i = self.current;
self.current += 1;
let dim_id = unsafe { read_u32_unchecked(self.dim_ptr, i as usize) };
let postings = unsafe { block_term_postings(self.ps_ptr, self.post_ptr, i) };
Some((dim_id, postings))
}
fn size_hint(&self) -> (usize, Option<usize>) {
let rem = (self.num_terms - self.current) as usize;
(rem, Some(rem))
}
}
impl<'a> ExactSizeIterator for BlockTermIter<'a> {}
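/// Binary search over a block's dimension ids (sorted ascending) for `dim_id`,
/// returning the term's local index within the block.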
#[inline(always)]
pub(crate) fn find_dim_in_block_data(
dim_ptr: *const u8,
num_terms: u16,
dim_id: u32,
) -> Option<u16> {
let count = num_terms as usize;
if count == 0 {
return None;
}
let mut lo = 0usize;
let mut hi = count;
while lo < hi {
let mid = lo + (hi - lo) / 2;
let val = unsafe { read_u32_unchecked(dim_ptr, mid) };
match val.cmp(&dim_id) {
std::cmp::Ordering::Less => lo = mid + 1,
std::cmp::Ordering::Equal => return Some(mid as u16),
std::cmp::Ordering::Greater => hi = mid,
}
}
None
}
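/// Returns the posting slice for `local_term` using the block's posting-start offsets.
///
/// # Safety
/// `ps_ptr` and `post_ptr` must come from `parse_block` for a non-empty block, and
/// `local_term` must be less than that block's term count.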
#[inline(always)]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn block_term_postings<'a>(
ps_ptr: *const u8,
post_ptr: *const u8,
local_term: u16,
) -> &'a [BmpPosting] {
let start_p = ps_ptr.add(local_term as usize * 2);
let end_p = ps_ptr.add((local_term as usize + 1) * 2);
let start = u16::from_le((start_p as *const u16).read_unaligned()) as usize;
let end = u16::from_le((end_p as *const u16).read_unaligned()) as usize;
let count = end - start;
if count == 0 {
return &[];
}
let ptr = post_ptr.add(start * 2) as *const BmpPosting;
std::slice::from_raw_parts(ptr, count)
}
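/// Accumulates `weight * (17 * v)` into `out[i]` for each 4-bit value `v` stored at
/// element `elem_offset + i` of `packed` (two values per byte, low nibble first).
/// The factor 17 expands a 4-bit value to the full 8-bit range (15 * 17 = 255).
/// SIMD paths are taken when the starting element is byte-aligned.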
#[inline]
pub(crate) fn accumulate_u4_weighted(
packed: &[u8],
elem_offset: usize,
count: usize,
weight: f32,
out: &mut [f32],
) {
if count == 0 {
return;
}
#[cfg(target_arch = "aarch64")]
{
if elem_offset.is_multiple_of(2) {
unsafe { accumulate_u4_weighted_neon(packed, elem_offset, count, weight, out) };
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if elem_offset.is_multiple_of(2) && is_x86_feature_detected!("sse4.1") {
unsafe { accumulate_u4_weighted_sse41(packed, elem_offset, count, weight, out) };
return;
}
}
for i in 0..count {
let abs_idx = elem_offset + i;
let byte_val = unsafe { *packed.get_unchecked(abs_idx / 2) };
let val = if abs_idx.is_multiple_of(2) {
byte_val & 0x0F
} else {
byte_val >> 4
};
unsafe {
*out.get_unchecked_mut(i) += (val as u32 * 17) as f32 * weight;
}
}
}
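// NEON path: expands 16 packed bytes (32 nibbles) per iteration and FMAs them into `out`.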
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn accumulate_u4_weighted_neon(
packed: &[u8],
elem_offset: usize,
count: usize,
weight: f32,
out: &mut [f32],
) {
use std::arch::aarch64::*;
debug_assert!(elem_offset.is_multiple_of(2));
let weight_v = vdupq_n_f32(weight);
let mask_lo = vdupq_n_u8(0x0F);
let scale17 = vdupq_n_u8(17);
let byte_offset = elem_offset / 2;
let packed_ptr = packed.as_ptr().add(byte_offset);
let out_ptr = out.as_mut_ptr();
let chunks = count / 32;
let remainder = count % 32;
for chunk in 0..chunks {
let pb = packed_ptr.add(chunk * 16);
let ob = out_ptr.add(chunk * 32);
let bytes = vld1q_u8(pb);
let low = vandq_u8(bytes, mask_lo);
let high = vshrq_n_u8::<4>(bytes);
let low_scaled = vmulq_u8(low, scale17);
let high_scaled = vmulq_u8(high, scale17);
let elems_0_15 = vzip1q_u8(low_scaled, high_scaled);
let elems_16_31 = vzip2q_u8(low_scaled, high_scaled);
{
let lo8 = vget_low_u8(elems_0_15);
let hi8 = vget_high_u8(elems_0_15);
let lo16 = vmovl_u8(lo8);
let hi16 = vmovl_u8(hi8);
let u32_0 = vmovl_u16(vget_low_u16(lo16));
let f32_0 = vcvtq_f32_u32(u32_0);
let acc_0 = vld1q_f32(ob);
vst1q_f32(ob, vfmaq_f32(acc_0, f32_0, weight_v));
let u32_1 = vmovl_u16(vget_high_u16(lo16));
let f32_1 = vcvtq_f32_u32(u32_1);
let acc_1 = vld1q_f32(ob.add(4));
vst1q_f32(ob.add(4), vfmaq_f32(acc_1, f32_1, weight_v));
let u32_2 = vmovl_u16(vget_low_u16(hi16));
let f32_2 = vcvtq_f32_u32(u32_2);
let acc_2 = vld1q_f32(ob.add(8));
vst1q_f32(ob.add(8), vfmaq_f32(acc_2, f32_2, weight_v));
let u32_3 = vmovl_u16(vget_high_u16(hi16));
let f32_3 = vcvtq_f32_u32(u32_3);
let acc_3 = vld1q_f32(ob.add(12));
vst1q_f32(ob.add(12), vfmaq_f32(acc_3, f32_3, weight_v));
}
{
let lo8 = vget_low_u8(elems_16_31);
let hi8 = vget_high_u8(elems_16_31);
let lo16 = vmovl_u8(lo8);
let hi16 = vmovl_u8(hi8);
let u32_0 = vmovl_u16(vget_low_u16(lo16));
let f32_0 = vcvtq_f32_u32(u32_0);
let acc_0 = vld1q_f32(ob.add(16));
vst1q_f32(ob.add(16), vfmaq_f32(acc_0, f32_0, weight_v));
let u32_1 = vmovl_u16(vget_high_u16(lo16));
let f32_1 = vcvtq_f32_u32(u32_1);
let acc_1 = vld1q_f32(ob.add(20));
vst1q_f32(ob.add(20), vfmaq_f32(acc_1, f32_1, weight_v));
let u32_2 = vmovl_u16(vget_low_u16(hi16));
let f32_2 = vcvtq_f32_u32(u32_2);
let acc_2 = vld1q_f32(ob.add(24));
vst1q_f32(ob.add(24), vfmaq_f32(acc_2, f32_2, weight_v));
let u32_3 = vmovl_u16(vget_high_u16(hi16));
let f32_3 = vcvtq_f32_u32(u32_3);
let acc_3 = vld1q_f32(ob.add(28));
vst1q_f32(ob.add(28), vfmaq_f32(acc_3, f32_3, weight_v));
}
}
let base_elem = chunks * 32;
for i in 0..remainder {
let abs_idx = elem_offset + base_elem + i;
let byte_val = *packed.get_unchecked(abs_idx / 2);
let val = if abs_idx.is_multiple_of(2) {
byte_val & 0x0F
} else {
byte_val >> 4
};
*out.get_unchecked_mut(base_elem + i) += (val as u32 * 17) as f32 * weight;
}
}
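// SSE4.1 path: same layout as the NEON path, 16 packed bytes (32 nibbles) per iteration.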
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn accumulate_u4_weighted_sse41(
packed: &[u8],
elem_offset: usize,
count: usize,
weight: f32,
out: &mut [f32],
) {
use std::arch::x86_64::*;
debug_assert!(elem_offset.is_multiple_of(2));
let weight_v = _mm_set1_ps(weight);
let mask_lo = _mm_set1_epi8(0x0F);
let zero = _mm_setzero_si128();
let byte_offset = elem_offset / 2;
let packed_ptr = packed.as_ptr().add(byte_offset);
let out_ptr = out.as_mut_ptr();
let chunks = count / 32;
let remainder = count % 32;
for chunk in 0..chunks {
let pb = packed_ptr.add(chunk * 16);
let ob = out_ptr.add(chunk * 32);
let bytes = _mm_loadu_si128(pb as *const __m128i);
let low = _mm_and_si128(bytes, mask_lo);
let high = _mm_srli_epi16::<4>(bytes);
let high = _mm_and_si128(high, mask_lo);
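// (v << 4) + v == v * 17; every nibble's upper bits are zero, so the 16-bit shifts
// cannot carry bits across byte boundaries.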
let low_scaled = _mm_add_epi8(_mm_slli_epi16::<4>(low), low);
let high_scaled = _mm_add_epi8(_mm_slli_epi16::<4>(high), high);
let elems_0_15 = _mm_unpacklo_epi8(low_scaled, high_scaled);
let elems_16_31 = _mm_unpackhi_epi8(low_scaled, high_scaled);
{
let lo8 = _mm_unpacklo_epi8(elems_0_15, zero);
let hi8 = _mm_unpackhi_epi8(elems_0_15, zero);
let u32_0 = _mm_unpacklo_epi16(lo8, zero);
let f32_0 = _mm_cvtepi32_ps(u32_0);
let acc_0 = _mm_loadu_ps(ob);
_mm_storeu_ps(ob, _mm_add_ps(acc_0, _mm_mul_ps(f32_0, weight_v)));
let u32_1 = _mm_unpackhi_epi16(lo8, zero);
let f32_1 = _mm_cvtepi32_ps(u32_1);
let acc_1 = _mm_loadu_ps(ob.add(4));
_mm_storeu_ps(ob.add(4), _mm_add_ps(acc_1, _mm_mul_ps(f32_1, weight_v)));
let u32_2 = _mm_unpacklo_epi16(hi8, zero);
let f32_2 = _mm_cvtepi32_ps(u32_2);
let acc_2 = _mm_loadu_ps(ob.add(8));
_mm_storeu_ps(ob.add(8), _mm_add_ps(acc_2, _mm_mul_ps(f32_2, weight_v)));
let u32_3 = _mm_unpackhi_epi16(hi8, zero);
let f32_3 = _mm_cvtepi32_ps(u32_3);
let acc_3 = _mm_loadu_ps(ob.add(12));
_mm_storeu_ps(ob.add(12), _mm_add_ps(acc_3, _mm_mul_ps(f32_3, weight_v)));
}
{
let lo8 = _mm_unpacklo_epi8(elems_16_31, zero);
let hi8 = _mm_unpackhi_epi8(elems_16_31, zero);
let u32_0 = _mm_unpacklo_epi16(lo8, zero);
let f32_0 = _mm_cvtepi32_ps(u32_0);
let acc_0 = _mm_loadu_ps(ob.add(16));
_mm_storeu_ps(ob.add(16), _mm_add_ps(acc_0, _mm_mul_ps(f32_0, weight_v)));
let u32_1 = _mm_unpackhi_epi16(lo8, zero);
let f32_1 = _mm_cvtepi32_ps(u32_1);
let acc_1 = _mm_loadu_ps(ob.add(20));
_mm_storeu_ps(ob.add(20), _mm_add_ps(acc_1, _mm_mul_ps(f32_1, weight_v)));
let u32_2 = _mm_unpacklo_epi16(hi8, zero);
let f32_2 = _mm_cvtepi32_ps(u32_2);
let acc_2 = _mm_loadu_ps(ob.add(24));
_mm_storeu_ps(ob.add(24), _mm_add_ps(acc_2, _mm_mul_ps(f32_2, weight_v)));
let u32_3 = _mm_unpackhi_epi16(hi8, zero);
let f32_3 = _mm_cvtepi32_ps(u32_3);
let acc_3 = _mm_loadu_ps(ob.add(28));
_mm_storeu_ps(ob.add(28), _mm_add_ps(acc_3, _mm_mul_ps(f32_3, weight_v)));
}
}
let base_elem = chunks * 32;
for i in 0..remainder {
let abs_idx = elem_offset + base_elem + i;
let byte_val = *packed.get_unchecked(abs_idx / 2);
let val = if abs_idx.is_multiple_of(2) {
byte_val & 0x0F
} else {
byte_val >> 4
};
*out.get_unchecked_mut(base_elem + i) += (val as u32 * 17) as f32 * weight;
}
}
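/// For each of `count` blocks starting at `block_start`, sets bit `q` of the block's mask
/// when query dimension `q` has a nonzero 4-bit grid entry for that block.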
pub(crate) fn compute_block_masks_4bit(
grid: &[u8],
prs: usize,
query_dims: &[(usize, f32)],
block_start: usize,
count: usize,
masks: &mut [u64],
) {
debug_assert!(masks.len() >= count);
masks[..count].fill(0);
#[cfg(target_arch = "aarch64")]
{
if block_start.is_multiple_of(2) {
unsafe {
compute_block_masks_range_neon(grid, prs, query_dims, block_start, count, masks)
};
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if block_start.is_multiple_of(2) && is_x86_feature_detected!("sse4.1") {
unsafe {
compute_block_masks_range_sse41(grid, prs, query_dims, block_start, count, masks)
};
return;
}
}
for (q, &(dim_idx, _)) in query_dims.iter().enumerate() {
let row = &grid[dim_idx * prs..(dim_idx + 1) * prs];
let bit = 1u64 << q;
for b in 0..count {
let abs_b = block_start + b;
let byte_val = unsafe { *row.get_unchecked(abs_b / 2) };
let val = if abs_b.is_multiple_of(2) {
byte_val & 0x0F
} else {
byte_val >> 4
};
if val > 0 {
unsafe { *masks.get_unchecked_mut(b) |= bit };
}
}
}
}
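// NEON path: tests 32 blocks per iteration and sets mask bits for the nonzero lanes.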
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn compute_block_masks_range_neon(
grid: &[u8],
prs: usize,
query_dims: &[(usize, f32)],
block_start: usize,
count: usize,
masks: &mut [u64],
) {
use std::arch::aarch64::*;
debug_assert!(block_start.is_multiple_of(2));
let byte_offset = block_start / 2;
let zero = vdupq_n_u8(0);
let mask_lo = vdupq_n_u8(0x0F);
for (q, &(dim_idx, _)) in query_dims.iter().enumerate() {
let row_ptr = grid.as_ptr().add(dim_idx * prs + byte_offset);
let bit = 1u64 << q;
let chunks = count / 32;
let remainder = count % 32;
for chunk in 0..chunks {
let pb = row_ptr.add(chunk * 16);
let base = chunk * 32;
let bytes = vld1q_u8(pb);
let low = vandq_u8(bytes, mask_lo);
let high = vshrq_n_u8::<4>(bytes);
let elems_lo = vzip1q_u8(low, high);
let elems_hi = vzip2q_u8(low, high);
let nz_lo = vcgtq_u8(elems_lo, zero);
let nz_hi = vcgtq_u8(elems_hi, zero);
let mut lo_arr = [0u8; 16];
let mut hi_arr = [0u8; 16];
vst1q_u8(lo_arr.as_mut_ptr(), nz_lo);
vst1q_u8(hi_arr.as_mut_ptr(), nz_hi);
for (i, &v) in lo_arr.iter().enumerate() {
if v != 0 {
*masks.get_unchecked_mut(base + i) |= bit;
}
}
for (i, &v) in hi_arr.iter().enumerate() {
if v != 0 {
*masks.get_unchecked_mut(base + 16 + i) |= bit;
}
}
}
let base = chunks * 32;
for i in 0..remainder {
let abs_b = block_start + base + i;
let byte_val = *grid.get_unchecked(dim_idx * prs + abs_b / 2);
let val = if abs_b.is_multiple_of(2) {
byte_val & 0x0F
} else {
byte_val >> 4
};
if val > 0 {
*masks.get_unchecked_mut(base + i) |= bit;
}
}
}
}
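// SSE4.1 path: uses movemask to visit only the nonzero lanes of each 32-block chunk.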
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn compute_block_masks_range_sse41(
grid: &[u8],
prs: usize,
query_dims: &[(usize, f32)],
block_start: usize,
count: usize,
masks: &mut [u64],
) {
use std::arch::x86_64::*;
debug_assert!(block_start.is_multiple_of(2));
let byte_offset = block_start / 2;
let zero = _mm_setzero_si128();
let mask_lo_v = _mm_set1_epi8(0x0F);
for (q, &(dim_idx, _)) in query_dims.iter().enumerate() {
let row_ptr = grid.as_ptr().add(dim_idx * prs + byte_offset);
let bit = 1u64 << q;
let chunks = count / 32;
let remainder = count % 32;
for chunk in 0..chunks {
let pb = row_ptr.add(chunk * 16);
let base = chunk * 32;
let bytes = _mm_loadu_si128(pb as *const __m128i);
let low = _mm_and_si128(bytes, mask_lo_v);
let high = _mm_and_si128(_mm_srli_epi16::<4>(bytes), mask_lo_v);
let elems_lo = _mm_unpacklo_epi8(low, high);
let elems_hi = _mm_unpackhi_epi8(low, high);
let nz_lo = _mm_cmpgt_epi8(elems_lo, zero);
let nz_hi = _mm_cmpgt_epi8(elems_hi, zero);
let mut m = _mm_movemask_epi8(nz_lo) as u32;
while m != 0 {
let i = m.trailing_zeros() as usize;
m &= m - 1;
*masks.get_unchecked_mut(base + i) |= bit;
}
let mut m = _mm_movemask_epi8(nz_hi) as u32;
while m != 0 {
let i = m.trailing_zeros() as usize;
m &= m - 1;
*masks.get_unchecked_mut(base + 16 + i) |= bit;
}
}
let base = chunks * 32;
for i in 0..remainder {
let abs_b = block_start + base + i;
let byte_val = *grid.get_unchecked(dim_idx * prs + abs_b / 2);
let val = if abs_b.is_multiple_of(2) {
byte_val & 0x0F
} else {
byte_val >> 4
};
if val > 0 {
*masks.get_unchecked_mut(base + i) |= bit;
}
}
}
}