use memmap2::Mmap;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use super::format::*;
use crate::error::{Error, Result};
use crate::types::*;
/// A read-only segment file backed by a memory mapping.
///
/// The header is copied out of the mapping once, in [`Segment::open`]; the
/// data accessors return slices that borrow directly from the mapped bytes.
pub struct Segment {
    /// Shared handle to the mapped file contents (see [`Segment::clone_mmap`]).
    mmap: Arc<Mmap>,
    /// Header decoded from the first bytes of the file when it was opened.
    header: SegmentHeader,
    /// Lossy UTF-8 rendering of the path the segment was opened from.
    path: String,
}
impl Segment {
    /// Opens the segment file at `path`, memory-maps it, and validates its header.
    ///
    /// The header is copied out of the first [`SegmentHeader::SIZE`] bytes of the
    /// file and checked with [`SegmentHeader::validate`]. Individual sections are
    /// not validated here; each data accessor bounds-checks its section on use.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or mapped, is shorter than
    /// [`SegmentHeader::SIZE`], fails header validation, or is shorter than the
    /// `file_len` recorded in its header.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        let file = File::open(&path)?;
        // SAFETY: the mapping is only ever read. Soundness still relies on no
        // other process truncating or rewriting the file while it is mapped,
        // which cannot be enforced from here.
        let mmap = unsafe { Mmap::map(&file)? };
        if mmap.len() < SegmentHeader::SIZE {
            return Err(Error::Segment("File too small for header".into()));
        }
        // SAFETY: at least `SegmentHeader::SIZE` bytes are mapped (checked above)
        // and `read_unaligned` imposes no alignment requirement.
        // NOTE(review): assumes `SegmentHeader::SIZE >= size_of::<SegmentHeader>()`
        // and that every bit pattern is a valid `SegmentHeader` — confirm in `format`.
        let header: SegmentHeader =
            unsafe { std::ptr::read_unaligned(mmap.as_ptr() as *const SegmentHeader) };
        header.validate()?;
        // Compare as u64: casting `file_len` to `usize` could truncate on 32-bit
        // targets and let an oversized declared length slip past this check.
        if (mmap.len() as u64) < header.file_len {
            return Err(Error::Segment(format!(
                "File size {} < declared length {}",
                mmap.len(),
                header.file_len
            )));
        }
        Ok(Self {
            mmap: Arc::new(mmap),
            header,
            path: path_str,
        })
    }

    /// Borrows `count` consecutive `T` values starting `offset` bytes into the mapping.
    ///
    /// # Panics
    ///
    /// Panics if the byte range does not lie entirely within the mapping, or if
    /// its start is not aligned for `T`. Either condition means the header
    /// describes a section the file does not actually contain. Building the
    /// slice anyway would be undefined behaviour.
    ///
    /// # Safety
    ///
    /// `T` must be valid for every bit pattern (integers, floats, or `#[repr(C)]`
    /// structs made only of such fields), because the bytes come straight from
    /// the file.
    unsafe fn section<T>(&self, offset: usize, count: usize) -> &[T] {
        let byte_len = count
            .checked_mul(std::mem::size_of::<T>())
            .expect("segment section length overflows usize");
        let end = offset
            .checked_add(byte_len)
            .expect("segment section end overflows usize");
        assert!(
            end <= self.mmap.len(),
            "segment section {}..{} exceeds mapped length {}",
            offset,
            end,
            self.mmap.len()
        );
        let ptr = self.mmap[offset..end].as_ptr().cast::<T>();
        assert_eq!(
            ptr as usize % std::mem::align_of::<T>(),
            0,
            "segment section at offset {} is misaligned for its element type",
            offset
        );
        // SAFETY: the range `offset..end` is inside the mapping and `ptr` is
        // aligned for `T` (both asserted above). The mapping lives as long as
        // `self`, and the caller guarantees any bit pattern is a valid `T`.
        unsafe { std::slice::from_raw_parts(ptr, count) }
    }

    /// Returns the header decoded when the segment was opened.
    #[inline]
    pub fn header(&self) -> &SegmentHeader {
        &self.header
    }

    /// Returns the number of vectors recorded in the header (`n_vec`).
    #[inline]
    pub fn num_vectors(&self) -> u32 {
        self.header.n_vec
    }

    /// Returns the vector dimensionality recorded in the header (`dim`).
    #[inline]
    pub fn dim(&self) -> u32 {
        self.header.dim
    }

    /// Raw pointer to the start of the BPS section.
    ///
    /// Computed with `wrapping_add`, so it is never bounds-checked here.
    /// Dereferencing it is the caller's responsibility.
    #[inline]
    pub fn bps_ptr(&self) -> *const u8 {
        self.mmap.as_ptr().wrapping_add(self.header.off_bps as usize)
    }

    /// The BPS section: `header.bps_size()` bytes starting at `off_bps`.
    ///
    /// # Panics
    ///
    /// Panics if the section lies outside the mapped file.
    pub fn bps_data(&self) -> &[u8] {
        // SAFETY: `u8` is valid for every bit pattern.
        unsafe { self.section(self.header.off_bps as usize, self.header.bps_size()) }
    }

    /// Raw pointer to the start of the i8 section (not bounds-checked).
    #[inline]
    pub fn i8_ptr(&self) -> *const i8 {
        self.mmap
            .as_ptr()
            .wrapping_add(self.header.off_i8 as usize)
            .cast::<i8>()
    }

    /// The i8 section: `header.i8_size()` values starting at `off_i8`.
    ///
    /// # Panics
    ///
    /// Panics if the section lies outside the mapped file.
    pub fn i8_data(&self) -> &[i8] {
        // SAFETY: `i8` is valid for every bit pattern.
        unsafe { self.section(self.header.off_i8 as usize, self.header.i8_size()) }
    }

    /// Returns the `dim` i8 values for vector `vid`.
    ///
    /// Returns `None` if `vid >= n_vec`, or if the i8 section is too short to
    /// hold that vector.
    pub fn get_i8_vector(&self, vid: VectorId) -> Option<&[i8]> {
        if vid >= self.header.n_vec {
            return None;
        }
        let dim = self.header.dim as usize;
        let offset = vid as usize * dim;
        self.i8_data().get(offset..offset + dim)
    }

    /// Raw pointer to the start of the scales section (not bounds-checked).
    #[inline]
    pub fn scales_ptr(&self) -> *const f32 {
        self.mmap
            .as_ptr()
            .wrapping_add(self.header.off_scales as usize)
            .cast::<f32>()
    }

    /// The scales section: `num_bps_blocks() * n_vec` f32 values starting at `off_scales`.
    ///
    /// # Panics
    ///
    /// Panics if the section lies outside the mapped file or is not 4-byte aligned.
    pub fn scales_data(&self) -> &[f32] {
        let num_blocks = self.header.num_bps_blocks() as usize;
        let size = num_blocks * self.header.n_vec as usize;
        // SAFETY: `f32` is valid for every bit pattern.
        unsafe { self.section(self.header.off_scales as usize, size) }
    }

    /// Raw pointer to the start of the outlier table (not bounds-checked).
    #[inline]
    pub fn outliers_ptr(&self) -> *const OutlierEntry {
        self.mmap
            .as_ptr()
            .wrapping_add(self.header.off_outliers as usize)
            .cast::<OutlierEntry>()
    }

    /// Returns the `num_outliers` outlier entries stored for vector `vid`.
    ///
    /// Returns `None` if `vid >= n_vec` or the segment lacks `HAS_OUTLIERS`.
    ///
    /// # Panics
    ///
    /// Panics if this vector's entries lie outside the mapped file.
    pub fn get_outliers(&self, vid: VectorId) -> Option<&[OutlierEntry]> {
        if vid >= self.header.n_vec || !self.header.flags.has(SegmentFlags::HAS_OUTLIERS) {
            return None;
        }
        let per_vector = self.header.num_outliers as usize;
        let start = vid as usize * per_vector;
        // Check the table up to the end of this vector's entries, which is
        // exactly the extent that is read.
        // SAFETY: NOTE(review) assumes `OutlierEntry` is plain old data (valid
        // for any bit pattern) — confirm in `format`.
        let table: &[OutlierEntry] =
            unsafe { self.section(self.header.off_outliers as usize, start + per_vector) };
        Some(&table[start..])
    }

    /// Raw pointer to the start of the tombstone bitmap (not bounds-checked).
    #[inline]
    pub fn tombstone_ptr(&self) -> *const u64 {
        self.mmap
            .as_ptr()
            .wrapping_add(self.header.off_tombstone as usize)
            .cast::<u64>()
    }

    /// The tombstone bitmap: `ceil(n_vec / 64)` words starting at `off_tombstone`.
    ///
    /// # Panics
    ///
    /// Panics if the bitmap lies outside the mapped file or is not 8-byte aligned.
    pub fn tombstone_data(&self) -> &[u64] {
        let num_words = (self.header.n_vec as usize + 63) / 64;
        // SAFETY: `u64` is valid for every bit pattern.
        unsafe { self.section(self.header.off_tombstone as usize, num_words) }
    }

    /// Returns `true` if vector `vid` is marked in the tombstone bitmap.
    ///
    /// Ids at or beyond `n_vec` are reported as tombstoned. Bit `vid % 64` of
    /// word `vid / 64` holds the mark.
    pub fn is_tombstoned(&self, vid: VectorId) -> bool {
        if vid >= self.header.n_vec {
            return true;
        }
        let word_idx = vid as usize / 64;
        let bit_idx = vid as usize % 64;
        let tombstones = self.tombstone_data();
        // Defensive: a bitmap shorter than `n_vec` bits is treated as "not tombstoned".
        if word_idx >= tombstones.len() {
            return false;
        }
        (tombstones[word_idx] & (1u64 << bit_idx)) != 0
    }

    /// The RDF directory: `dim` entries at `off_rdf_dir`, or empty without `HAS_RDF`.
    ///
    /// # Panics
    ///
    /// Panics if the directory lies outside the mapped file or is misaligned.
    pub fn rdf_directory(&self) -> &[PostingListEntry] {
        if !self.header.flags.has(SegmentFlags::HAS_RDF) {
            return &[];
        }
        // SAFETY: NOTE(review) assumes `PostingListEntry` is plain old data —
        // confirm in `format`.
        unsafe { self.section(self.header.off_rdf_dir as usize, self.header.dim as usize) }
    }

    /// Raw pointer to the start of the RDF data area (not bounds-checked).
    #[inline]
    pub fn rdf_data_ptr(&self) -> *const u8 {
        self.mmap
            .as_ptr()
            .wrapping_add(self.header.off_rdf_data as usize)
    }

    /// Per-dimension weights: `dim` f32 values at `off_dim_weights`, or empty without `HAS_RDF`.
    ///
    /// # Panics
    ///
    /// Panics if the weights lie outside the mapped file or are not 4-byte aligned.
    pub fn dim_weights(&self) -> &[f32] {
        if !self.header.flags.has(SegmentFlags::HAS_RDF) {
            return &[];
        }
        // SAFETY: `f32` is valid for every bit pattern.
        unsafe { self.section(self.header.off_dim_weights as usize, self.header.dim as usize) }
    }

    /// All fp32 vectors (`n_vec * dim` values at `off_fp32`), or `None` without `HAS_FP32`.
    ///
    /// # Panics
    ///
    /// Panics if the section lies outside the mapped file or is not 4-byte aligned.
    pub fn fp32_data(&self) -> Option<&[f32]> {
        if !self.header.flags.has(SegmentFlags::HAS_FP32) {
            return None;
        }
        let size = self.header.n_vec as usize * self.header.dim as usize;
        // SAFETY: `f32` is valid for every bit pattern.
        Some(unsafe { self.section(self.header.off_fp32 as usize, size) })
    }

    /// BPS quantization parameters: `num_bps_blocks() * bps_proj` entries.
    ///
    /// Returns `None` when `off_bps_qparams` is 0 or the entry count is 0.
    ///
    /// # Panics
    ///
    /// Panics if the table lies outside the mapped file or is misaligned.
    pub fn bps_qparams(&self) -> Option<&[super::bps::BpsQParam]> {
        if self.header.off_bps_qparams == 0 {
            return None;
        }
        let num_slots = self.header.num_bps_blocks() as usize * self.header.bps_proj as usize;
        if num_slots == 0 {
            return None;
        }
        // SAFETY: NOTE(review) assumes `BpsQParam` is plain old data — confirm in `bps`.
        Some(unsafe { self.section(self.header.off_bps_qparams as usize, num_slots) })
    }

    /// Returns the `dim` fp32 values for vector `vid`.
    ///
    /// Returns `None` if `vid >= n_vec`, the segment lacks `HAS_FP32`, or the
    /// section is too short to hold that vector. This matches
    /// [`Segment::get_i8_vector`] rather than panicking on an out-of-range id.
    pub fn get_fp32_vector(&self, vid: VectorId) -> Option<&[f32]> {
        if vid >= self.header.n_vec {
            return None;
        }
        let fp32 = self.fp32_data()?;
        let dim = self.header.dim as usize;
        let offset = vid as usize * dim;
        fp32.get(offset..offset + dim)
    }

    /// The path this segment was opened from (lossily converted to UTF-8).
    pub fn path(&self) -> &str {
        &self.path
    }

    /// Returns another shared handle to the underlying mapping.
    pub fn clone_mmap(&self) -> Arc<Mmap> {
        Arc::clone(&self.mmap)
    }
}
impl std::fmt::Debug for Segment {
    /// Prints a compact summary (path, vector count, dimensionality and flags)
    /// instead of the mapped bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hdr = &self.header;
        let mut summary = f.debug_struct("Segment");
        summary.field("path", &self.path);
        summary.field("n_vec", &hdr.n_vec);
        summary.field("dim", &hdr.dim);
        summary.field("flags", &hdr.flags);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a small segment: 100 vectors of dimension 64, with zeroed BPS codes
    /// and i8 values, every scale set to 1.0, and an all-zero tombstone bitmap.
    /// The sections are laid out back to back after the header.
    fn create_test_segment() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        let n_vec = 100u32;
        let dim = 64u32;
        let num_blocks = (dim + 15) / 16;
        let mut header = SegmentHeader::new(n_vec, dim);
        header.flags.set(SegmentFlags::HAS_BPS);
        let mut offset = SegmentHeader::SIZE as u64;
        header.off_bps = offset;
        let bps_size = (num_blocks as usize * n_vec as usize) as u64;
        offset += bps_size;
        header.off_i8 = offset;
        let i8_size = (n_vec as usize * dim as usize) as u64;
        offset += i8_size;
        header.off_scales = offset;
        let scales_size = (num_blocks as usize * n_vec as usize * 4) as u64;
        offset += scales_size;
        header.off_tombstone = offset;
        let tombstone_size = ((n_vec as usize + 63) / 64 * 8) as u64;
        offset += tombstone_size;
        header.file_len = offset;
        file.write_all(bytemuck::bytes_of(&header)).unwrap();
        file.write_all(&vec![0u8; bps_size as usize]).unwrap();
        file.write_all(&vec![0u8; i8_size as usize]).unwrap();
        for _ in 0..(num_blocks * n_vec) {
            file.write_all(&1.0f32.to_le_bytes()).unwrap();
        }
        file.write_all(&vec![0u8; tombstone_size as usize]).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_segment_open() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();
        assert_eq!(segment.num_vectors(), 100);
        assert_eq!(segment.dim(), 64);
    }

    #[test]
    fn test_tombstone_check() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();
        assert!(!segment.is_tombstoned(0));
        assert!(!segment.is_tombstoned(50));
        assert!(!segment.is_tombstoned(99));
        assert!(segment.is_tombstoned(100));
    }

    #[test]
    fn test_vector_lookup_bounds() {
        let file = create_test_segment();
        let segment = Segment::open(file.path()).unwrap();

        let first = segment.get_i8_vector(0).unwrap();
        assert_eq!(first.len(), 64);
        assert!(first.iter().all(|&v| v == 0));
        assert_eq!(segment.get_i8_vector(99).map(|v| v.len()), Some(64));
        assert!(segment.get_i8_vector(100).is_none());
        assert!(segment.get_outliers(100).is_none());

        // The fixture only sets HAS_BPS, so no fp32 vectors are exposed.
        assert!(segment.fp32_data().is_none());
        assert!(segment.get_fp32_vector(0).is_none());
    }

    #[test]
    fn test_open_rejects_file_smaller_than_header() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(b"tiny").unwrap();
        file.flush().unwrap();
        assert!(Segment::open(file.path()).is_err());
    }

    #[test]
    fn test_open_rejects_truncated_file() {
        let file = create_test_segment();
        // Keep only the header, so the file is shorter than the length it declares.
        file.as_file().set_len(SegmentHeader::SIZE as u64).unwrap();
        assert!(Segment::open(file.path()).is_err());
    }
}