use std::io;
use std::mem::size_of;
use serde::{Deserialize, Serialize};
use crate::directories::{FileHandle, OwnedBytes};
use crate::dsl::DenseVectorQuantization;
use crate::segment::format::{DOC_ID_ENTRY_SIZE, FLAT_BINARY_HEADER_SIZE, FLAT_BINARY_MAGIC};
use crate::structures::simd::{batch_f32_to_f16, batch_f32_to_u8, f16_to_f32, u8_to_f32};
#[inline]
pub fn dequantize_raw(
raw: &[u8],
quant: DenseVectorQuantization,
num_floats: usize,
out: &mut [f32],
) {
debug_assert!(out.len() >= num_floats);
match quant {
DenseVectorQuantization::F32 => {
debug_assert!(
(raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<f32>()),
"f32 vector data not 4-byte aligned"
);
out[..num_floats].copy_from_slice(unsafe {
std::slice::from_raw_parts(raw.as_ptr() as *const f32, num_floats)
});
}
DenseVectorQuantization::F16 => {
debug_assert!(
(raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
"f16 vector data not 2-byte aligned"
);
let f16_slice =
unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, num_floats) };
for (i, &h) in f16_slice.iter().enumerate() {
out[i] = f16_to_f32(h);
}
}
DenseVectorQuantization::UInt8 => {
for (i, &b) in raw.iter().enumerate().take(num_floats) {
out[i] = u8_to_f32(b);
}
}
DenseVectorQuantization::Binary => {
unreachable!("Binary vectors use raw bytes, not f32 dequantization");
}
}
}
/// Namespace for writing the flat-vector binary format:
/// header, quantized vector payload, then a (doc_id, ordinal) table.
pub struct FlatVectorData;

impl FlatVectorData {
    /// Writes the fixed-size header: LE magic, dim (u32), num_vectors (u32),
    /// then the quantization tag padded to 4 bytes.
    pub fn write_binary_header(
        dim: usize,
        num_vectors: usize,
        quant: DenseVectorQuantization,
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        writer.write_all(&FLAT_BINARY_MAGIC.to_le_bytes())?;
        writer.write_all(&(dim as u32).to_le_bytes())?;
        writer.write_all(&(num_vectors as u32).to_le_bytes())?;
        // Tag plus three zero padding bytes keeps the header 4-byte aligned.
        writer.write_all(&[quant.tag(), 0, 0, 0])?;
        Ok(())
    }

    /// Total serialized size in bytes: header + vector payload + doc-id table.
    pub fn serialized_binary_size(
        dim: usize,
        num_vectors: usize,
        quant: DenseVectorQuantization,
    ) -> usize {
        let bytes_per_vector = match quant {
            // Binary packs one bit per dimension, rounded up to whole bytes.
            DenseVectorQuantization::Binary => dim.div_ceil(8),
            _ => dim * quant.element_size(),
        };
        FLAT_BINARY_HEADER_SIZE + num_vectors * bytes_per_vector + num_vectors * DOC_ID_ENTRY_SIZE
    }

    /// Serializes `flat_vectors` (row-major, `dim` floats per vector, one row
    /// per `doc_ids` entry) in the requested quantization, streaming each
    /// vector so only one `dim`-sized scratch buffer is ever allocated.
    ///
    /// # Panics
    /// `Binary` quantization must go through
    /// [`Self::serialize_binary_from_bits_streaming`] and hits `unreachable!`.
    pub fn serialize_binary_from_flat_streaming(
        dim: usize,
        flat_vectors: &[f32],
        doc_ids: &[(u32, u16)],
        quant: DenseVectorQuantization,
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        let num_vectors = doc_ids.len();
        // A length mismatch here would silently truncate (chunks_exact drops
        // the remainder) and produce a corrupt file; catch it in debug builds.
        debug_assert_eq!(
            flat_vectors.len(),
            num_vectors * dim,
            "flat_vectors length does not match num_vectors * dim"
        );
        Self::write_binary_header(dim, num_vectors, quant, writer)?;
        match quant {
            DenseVectorQuantization::F32 => {
                // SAFETY: any initialized &[f32] may be reinterpreted as &[u8]
                // over the same bytes — u8 has alignment 1 and every bit
                // pattern is valid; the length is the exact byte size.
                let bytes: &[u8] = unsafe {
                    std::slice::from_raw_parts(
                        flat_vectors.as_ptr() as *const u8,
                        std::mem::size_of_val(flat_vectors),
                    )
                };
                writer.write_all(bytes)?;
            }
            DenseVectorQuantization::F16 => {
                let mut buf = vec![0u16; dim];
                for v in flat_vectors.chunks_exact(dim) {
                    batch_f32_to_f16(v, &mut buf);
                    // SAFETY: viewing an initialized &[u16] as &[u8] of twice
                    // the length is always valid (alignment 1, no invalid
                    // bit patterns).
                    let bytes: &[u8] =
                        unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, dim * 2) };
                    writer.write_all(bytes)?;
                }
            }
            DenseVectorQuantization::UInt8 => {
                let mut buf = vec![0u8; dim];
                for v in flat_vectors.chunks_exact(dim) {
                    batch_f32_to_u8(v, &mut buf);
                    writer.write_all(&buf)?;
                }
            }
            DenseVectorQuantization::Binary => {
                unreachable!("Binary quantization should use serialize_binary_from_bits_streaming");
            }
        }
        Self::write_doc_ids(doc_ids, writer)
    }

    /// Serializes bit-packed binary vectors: `dim_bits` bits per vector,
    /// rounded up to whole bytes, one packed row per `doc_ids` entry.
    pub fn serialize_binary_from_bits_streaming(
        dim_bits: usize,
        packed_vectors: &[u8],
        doc_ids: &[(u32, u16)],
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        let num_vectors = doc_ids.len();
        let byte_len = dim_bits.div_ceil(8);
        debug_assert_eq!(packed_vectors.len(), num_vectors * byte_len);
        Self::write_binary_header(
            dim_bits,
            num_vectors,
            DenseVectorQuantization::Binary,
            writer,
        )?;
        writer.write_all(packed_vectors)?;
        Self::write_doc_ids(doc_ids, writer)
    }

    /// Passes pre-quantized vector bytes straight through to `writer`.
    pub fn write_raw_vector_bytes(
        raw_bytes: &[u8],
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        writer.write_all(raw_bytes)
    }

    /// Writes the trailing (doc_id: u32 LE, ordinal: u16 LE) table shared by
    /// both streaming serializers.
    fn write_doc_ids(
        doc_ids: &[(u32, u16)],
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        for &(doc_id, ordinal) in doc_ids {
            writer.write_all(&doc_id.to_le_bytes())?;
            writer.write_all(&ordinal.to_le_bytes())?;
        }
        Ok(())
    }
}
#[derive(Debug, Clone)]
/// Reader for the flat-vector binary format that keeps vector payloads on
/// disk: the header and (doc_id, ordinal) table are loaded eagerly in
/// [`LazyFlatVectorData::open`], while individual vectors are fetched by range
/// reads through `handle` on demand.
pub struct LazyFlatVectorData {
    // Vector dimensionality; for Binary quantization this is the bit count.
    pub dim: usize,
    // Number of vectors (== number of doc-id entries) in the file.
    pub num_vectors: usize,
    // Element encoding used by the stored vector payload.
    pub quantization: DenseVectorQuantization,
    // Eagerly-loaded (doc_id: u32 LE, ordinal: u16 LE) table, sorted by doc_id
    // (the binary searches in flat_indexes_for_doc* rely on that order).
    doc_ids_bytes: OwnedBytes,
    // Backing file; vector bytes are range-read from it lazily.
    handle: FileHandle,
    // Byte offset of the first vector (i.e. just past the header).
    vectors_offset: u64,
    // Cached per-vector byte size derived from dim + quantization.
    vbs: usize,
}
impl LazyFlatVectorData {
    /// Opens a flat-vector file: validates the magic, decodes dim /
    /// num_vectors / quantization from the header, computes the per-vector
    /// byte size, and eagerly loads the trailing (doc_id, ordinal) table.
    /// Vector payload bytes stay on disk and are read lazily.
    ///
    /// # Errors
    /// `InvalidData` on a bad magic or unknown quantization tag; otherwise
    /// whatever the underlying range reads return.
    pub async fn open(handle: FileHandle) -> io::Result<Self> {
        let header = handle
            .read_bytes_range(0..FLAT_BINARY_HEADER_SIZE as u64)
            .await?;
        let hdr = header.as_slice();
        let magic = u32::from_le_bytes([hdr[0], hdr[1], hdr[2], hdr[3]]);
        if magic != FLAT_BINARY_MAGIC {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid FlatVectorData binary magic",
            ));
        }
        let dim = u32::from_le_bytes([hdr[4], hdr[5], hdr[6], hdr[7]]) as usize;
        let num_vectors = u32::from_le_bytes([hdr[8], hdr[9], hdr[10], hdr[11]]) as usize;
        let quantization = DenseVectorQuantization::from_tag(hdr[12]).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Unknown quantization tag: {}", hdr[12]),
            )
        })?;
        // Binary packs one bit per dimension; everything else is
        // element_size bytes per dimension.
        let vbs = if quantization == DenseVectorQuantization::Binary {
            dim.div_ceil(8)
        } else {
            dim * quantization.element_size()
        };
        let vectors_byte_len = num_vectors * vbs;
        let doc_ids_start = (FLAT_BINARY_HEADER_SIZE + vectors_byte_len) as u64;
        let doc_ids_byte_len = (num_vectors * DOC_ID_ENTRY_SIZE) as u64;
        let doc_ids_bytes = handle
            .read_bytes_range(doc_ids_start..doc_ids_start + doc_ids_byte_len)
            .await?;
        Ok(Self {
            dim,
            num_vectors,
            quantization,
            doc_ids_bytes,
            handle,
            vectors_offset: FLAT_BINARY_HEADER_SIZE as u64,
            vbs,
        })
    }

    /// Reads vector `idx` from disk and dequantizes it into `out[..dim]`.
    pub async fn read_vector_into(&self, idx: usize, out: &mut [f32]) -> io::Result<()> {
        debug_assert!(out.len() >= self.dim);
        let vbs = self.vector_byte_size();
        let byte_offset = self.vectors_offset + (idx * vbs) as u64;
        let bytes = self
            .handle
            .read_bytes_range(byte_offset..byte_offset + vbs as u64)
            .await?;
        let raw = bytes.as_slice();
        dequantize_raw(raw, self.quantization, self.dim, out);
        Ok(())
    }

    /// Convenience wrapper: allocates a fresh `Vec<f32>` for vector `idx`.
    pub async fn get_vector(&self, idx: usize) -> io::Result<Vec<f32>> {
        let mut vector = vec![0f32; self.dim];
        self.read_vector_into(idx, &mut vector).await?;
        Ok(vector)
    }

    /// Copies the raw (still-quantized) bytes of vector `idx` into `out`.
    pub async fn read_vector_raw_into(&self, idx: usize, out: &mut [u8]) -> io::Result<()> {
        let vbs = self.vector_byte_size();
        debug_assert!(out.len() >= vbs);
        let byte_offset = self.vectors_offset + (idx * vbs) as u64;
        let bytes = self
            .handle
            .read_bytes_range(byte_offset..byte_offset + vbs as u64)
            .await?;
        out[..vbs].copy_from_slice(bytes.as_slice());
        Ok(())
    }

    /// Reads `count` consecutive vectors starting at `start_idx` as one
    /// contiguous raw-byte range (one I/O instead of `count`).
    pub async fn read_vectors_batch(
        &self,
        start_idx: usize,
        count: usize,
    ) -> io::Result<OwnedBytes> {
        debug_assert!(start_idx + count <= self.num_vectors);
        let vbs = self.vector_byte_size();
        let byte_offset = self.vectors_offset + (start_idx * vbs) as u64;
        let byte_len = (count * vbs) as u64;
        self.handle
            .read_bytes_range(byte_offset..byte_offset + byte_len)
            .await
    }

    /// Synchronous twin of [`Self::read_vector_raw_into`].
    #[cfg(feature = "sync")]
    pub fn read_vector_raw_into_sync(&self, idx: usize, out: &mut [u8]) -> io::Result<()> {
        let vbs = self.vector_byte_size();
        debug_assert!(out.len() >= vbs);
        let byte_offset = self.vectors_offset + (idx * vbs) as u64;
        let bytes = self
            .handle
            .read_bytes_range_sync(byte_offset..byte_offset + vbs as u64)?;
        out[..vbs].copy_from_slice(bytes.as_slice());
        Ok(())
    }

    /// Synchronous twin of [`Self::read_vectors_batch`].
    #[cfg(feature = "sync")]
    pub fn read_vectors_batch_sync(
        &self,
        start_idx: usize,
        count: usize,
    ) -> io::Result<OwnedBytes> {
        debug_assert!(start_idx + count <= self.num_vectors);
        let vbs = self.vector_byte_size();
        let byte_offset = self.vectors_offset + (start_idx * vbs) as u64;
        let byte_len = (count * vbs) as u64;
        self.handle
            .read_bytes_range_sync(byte_offset..byte_offset + byte_len)
    }

    /// Index of the first doc-id entry whose doc_id is >= `doc_id` (classic
    /// lower-bound binary search over the doc-id-sorted table). Shared by both
    /// `flat_indexes_for_doc*` lookups, which previously duplicated it.
    #[inline]
    fn lower_bound(&self, doc_id: u32) -> usize {
        let mut lo = 0usize;
        let mut hi = self.num_vectors;
        while lo < hi {
            let mid = lo + (hi - lo) / 2;
            if self.doc_id_at(mid) < doc_id {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        lo
    }

    /// Returns (first flat index, number of vectors) belonging to `doc_id`;
    /// count is 0 when the doc has no vectors.
    pub fn flat_indexes_for_doc_range(&self, doc_id: u32) -> (usize, usize) {
        let start = self.lower_bound(doc_id);
        let mut end = start;
        while end < self.num_vectors && self.doc_id_at(end) == doc_id {
            end += 1;
        }
        (start, end - start)
    }

    /// Returns the first flat index for `doc_id` plus its full list of
    /// (doc_id, ordinal) entries (empty when the doc has no vectors).
    pub fn flat_indexes_for_doc(&self, doc_id: u32) -> (usize, Vec<(u32, u16)>) {
        let start = self.lower_bound(doc_id);
        let mut entries = Vec::new();
        let mut i = start;
        while i < self.num_vectors {
            let (did, ord) = self.get_doc_id(i);
            if did != doc_id {
                break;
            }
            entries.push((did, ord));
            i += 1;
        }
        (start, entries)
    }

    /// Doc id of entry `idx` (first 4 LE bytes of the 6-byte entry).
    #[inline]
    fn doc_id_at(&self, idx: usize) -> u32 {
        let off = idx * DOC_ID_ENTRY_SIZE;
        let d = &self.doc_ids_bytes[off..];
        u32::from_le_bytes([d[0], d[1], d[2], d[3]])
    }

    /// Full (doc_id, ordinal) pair of entry `idx`.
    #[inline]
    pub fn get_doc_id(&self, idx: usize) -> (u32, u16) {
        let off = idx * DOC_ID_ENTRY_SIZE;
        let d = &self.doc_ids_bytes[off..];
        let doc_id = u32::from_le_bytes([d[0], d[1], d[2], d[3]]);
        let ordinal = u16::from_le_bytes([d[4], d[5]]);
        (doc_id, ordinal)
    }

    /// Bytes occupied by one stored vector (cached at open time).
    #[inline]
    pub fn vector_byte_size(&self) -> usize {
        self.vbs
    }

    /// Total byte length of the vector payload region.
    pub fn vector_bytes_len(&self) -> u64 {
        (self.num_vectors as u64) * (self.vector_byte_size() as u64)
    }

    /// Byte offset of the first vector within the file.
    pub fn vectors_byte_offset(&self) -> u64 {
        self.vectors_offset
    }

    /// Backing file handle (e.g. for callers doing their own range reads).
    pub fn handle(&self) -> &FileHandle {
        &self.handle
    }

    /// Rough resident-memory estimate.
    /// NOTE(review): this counts only the struct headers — the heap bytes
    /// behind `doc_ids_bytes` (num_vectors * DOC_ID_ENTRY_SIZE) are not
    /// included; confirm whether callers expect them to be.
    pub fn estimated_memory_bytes(&self) -> usize {
        size_of::<Self>() + size_of::<OwnedBytes>()
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Bundle of an IVF-RaBitQ index and the codebook it was trained with,
/// serialized together so they can never drift apart on disk.
pub struct IVFRaBitQIndexData {
    pub index: crate::structures::IVFRaBitQIndex,
    pub codebook: crate::structures::RaBitQCodebook,
}

impl IVFRaBitQIndexData {
    /// Encodes the bundle with bincode's standard configuration, surfacing
    /// encoder failures as `InvalidData` I/O errors.
    pub fn to_bytes(&self) -> std::io::Result<Vec<u8>> {
        let cfg = bincode::config::standard();
        bincode::serde::encode_to_vec(self, cfg)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))
    }

    /// Decodes a bundle previously produced by [`Self::to_bytes`]; trailing
    /// byte counts from the decoder are discarded.
    pub fn from_bytes(data: &[u8]) -> std::io::Result<Self> {
        let cfg = bincode::config::standard();
        match bincode::serde::decode_from_slice(data, cfg) {
            Ok((value, _bytes_read)) => Ok(value),
            Err(err) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err)),
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
/// Bundle of an IVF-PQ (ScaNN-style) index and its product-quantization
/// codebook, serialized as a single unit.
pub struct ScaNNIndexData {
    pub index: crate::structures::IVFPQIndex,
    pub codebook: crate::structures::PQCodebook,
}

impl ScaNNIndexData {
    /// Encodes the bundle with bincode's standard configuration, surfacing
    /// encoder failures as `InvalidData` I/O errors.
    pub fn to_bytes(&self) -> std::io::Result<Vec<u8>> {
        let cfg = bincode::config::standard();
        bincode::serde::encode_to_vec(self, cfg)
            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))
    }

    /// Decodes a bundle previously produced by [`Self::to_bytes`]; trailing
    /// byte counts from the decoder are discarded.
    pub fn from_bytes(data: &[u8]) -> std::io::Result<Self> {
        let cfg = bincode::config::standard();
        match bincode::serde::decode_from_slice(data, cfg) {
            Ok((value, _bytes_read)) => Ok(value),
            Err(err) => Err(std::io::Error::new(std::io::ErrorKind::InvalidData, err)),
        }
    }
}