#![allow(clippy::io_other_error)]
#[cfg(feature = "memmap2")]
use std::path::Path;
#[cfg(feature = "memmap2")]
use memmap2::MmapOptions;
use xxhash_rust::xxh64::xxh64;
use super::reader::parse_segment_mmap;
#[cfg(feature = "memmap2")]
use super::reader::read_exact_at;
#[cfg(feature = "memmap2")]
use super::MAX_SEGMENT_SIZE;
use super::{MmapSegment, PostingsBacking, SegmentData, FORMAT_VERSION_V2, FORMAT_VERSION_V3};
use crate::IndexError;
// Eight-byte magic expected at offset 0 of postings data (checked both for
// on-disk post files and for in-memory postings buffers in this file).
const POST_MAGIC: &[u8; 8] = b"SNTXPOST";
// Smallest acceptable postings size: 8 bytes of magic plus the 8-byte
// trailer that the checksum routines in this file read as a little-endian u64.
const POST_MIN_SIZE: usize = 8 + 8;
/// How thoroughly `MmapSegment::open_split` validates the postings file
/// before returning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PostVerify {
/// Run the structural check and then verify the stored checksum against
/// the postings payload.
Full,
/// Run only the structural check (minimum length, magic, trailer is
/// readable); the checksum is not verified at open time.
Structural,
}
impl MmapSegment {
pub fn from_bytes(dict_bytes: Vec<u8>, post_bytes: Vec<u8>) -> Result<Self, IndexError> {
let layout = parse_segment_mmap(&dict_bytes, &[FORMAT_VERSION_V2, FORMAT_VERSION_V3])?;
let len = dict_bytes.len();
Ok(MmapSegment {
_file: None,
expected_len: len,
doc_count: layout.doc_count,
gram_count: layout.gram_count,
doc_table_offset: layout.doc_table_offset,
dict_offset: layout.dict_offset,
postings_start: layout.postings_start,
mmap: SegmentData::Heap(dict_bytes),
postings: PostingsBacking::InMemory(post_bytes),
})
}
/// Opens a single-file segment at `path` and memory-maps it.
///
/// Steps, in order: open the file, reject it if its size exceeds
/// `MAX_SEGMENT_SIZE`, try to take a shared lock on it, map it
/// copy-on-write/read-only, then parse the mapping accepting format
/// versions V2 and V3. The returned segment keeps the file handle and uses
/// the `V2Mmap` postings backing.
///
/// # Errors
///
/// * I/O errors from opening, `metadata`, or mapping the file.
/// * `IndexError::CorruptIndex` when the file is larger than
///   `MAX_SEGMENT_SIZE`.
/// * An `std::io::ErrorKind::Other` I/O error (carrying the lock error's
///   message) when the shared lock cannot be acquired.
/// * Any error reported by `parse_segment_mmap`.
#[cfg(feature = "memmap2")]
pub fn open(path: &Path) -> Result<Self, IndexError> {
    let file = std::fs::File::open(path)?;

    let size_on_disk = file.metadata()?.len();
    if size_on_disk > MAX_SEGMENT_SIZE {
        return Err(IndexError::CorruptIndex(format!(
            "segment too large ({} bytes, max {})",
            size_on_disk, MAX_SEGMENT_SIZE
        )));
    }

    if let Err(lock_err) = file.try_lock_shared() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, lock_err.to_string());
        return Err(io_err.into());
    }

    // SAFETY: NOTE(review): a file mapping is only sound while the underlying
    // file is not truncated or rewritten; the shared lock taken above is
    // presumably the protocol writers are expected to honour — confirm.
    let mmap = unsafe { MmapOptions::new().map_copy_read_only(&file)? };
    let expected_len = mmap.len();
    let accepted_versions = [FORMAT_VERSION_V2, FORMAT_VERSION_V3];
    let layout = parse_segment_mmap(&mmap, &accepted_versions)?;

    Ok(Self {
        // The handle is stored in the segment rather than dropped here.
        _file: Some(file),
        mmap: SegmentData::Mmap(mmap),
        postings: PostingsBacking::V2Mmap,
        expected_len,
        doc_count: layout.doc_count,
        gram_count: layout.gram_count,
        doc_table_offset: layout.doc_table_offset,
        dict_offset: layout.dict_offset,
        postings_start: layout.postings_start,
    })
}
/// Opens a split segment: a memory-mapped dictionary file plus a separate
/// postings file that stays file-backed.
///
/// The dictionary at `dict_path` is size-checked against
/// `MAX_SEGMENT_SIZE`, share-locked, mapped copy-on-write/read-only and
/// parsed accepting only format version V3. The postings file at
/// `post_path` is then opened, share-locked and structurally checked;
/// with `PostVerify::Full` its checksum is verified as well.
///
/// The returned segment always has `postings_start` set to `0`; the value
/// from the parsed layout is not used.
///
/// # Errors
///
/// * I/O errors from opening, `metadata`, mapping, or reading either file.
/// * `IndexError::CorruptIndex` when the dictionary exceeds
///   `MAX_SEGMENT_SIZE`, or when the postings file fails its structural or
///   checksum validation.
/// * An `std::io::ErrorKind::Other` I/O error (carrying the lock error's
///   message) when either shared lock cannot be acquired.
/// * Any error reported by `parse_segment_mmap`.
#[cfg(feature = "memmap2")]
pub fn open_split(
    dict_path: &Path,
    post_path: &Path,
    verify: PostVerify,
) -> Result<Self, IndexError> {
    // --- dictionary file -------------------------------------------------
    let dict_file = std::fs::File::open(dict_path)?;

    let dict_size = dict_file.metadata()?.len();
    if dict_size > MAX_SEGMENT_SIZE {
        return Err(IndexError::CorruptIndex(format!(
            "dict file too large ({} bytes, max {})",
            dict_size, MAX_SEGMENT_SIZE
        )));
    }

    if let Err(lock_err) = dict_file.try_lock_shared() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, lock_err.to_string());
        return Err(io_err.into());
    }

    // SAFETY: NOTE(review): a file mapping is only sound while the underlying
    // file is not truncated or rewritten; the shared lock taken above is
    // presumably the protocol writers are expected to honour — confirm.
    let mmap = unsafe { MmapOptions::new().map_copy_read_only(&dict_file)? };
    let expected_len = mmap.len();
    let layout = parse_segment_mmap(&mmap, &[FORMAT_VERSION_V3])?;

    // --- postings file ---------------------------------------------------
    let post_file = std::fs::File::open(post_path)?;
    if let Err(lock_err) = post_file.try_lock_shared() {
        let io_err = std::io::Error::new(std::io::ErrorKind::Other, lock_err.to_string());
        return Err(io_err.into());
    }

    check_post_file_structure(&post_file)?;
    match verify {
        PostVerify::Full => verify_post_file_checksum(&post_file)?,
        PostVerify::Structural => {}
    }

    Ok(Self {
        _file: Some(dict_file),
        mmap: SegmentData::Mmap(mmap),
        postings: PostingsBacking::V3File(post_file),
        expected_len,
        doc_count: layout.doc_count,
        gram_count: layout.gram_count,
        doc_table_offset: layout.doc_table_offset,
        dict_offset: layout.dict_offset,
        postings_start: 0,
    })
}
/// Validates the postings backing of this segment.
///
/// * `V2Mmap`: returns `Ok(())` without performing any check here.
/// * `V3File`: runs the structural check on the postings file and then
///   verifies its checksum.
/// * `InMemory`: requires at least `POST_MIN_SIZE` bytes, the `SNTXPOST`
///   magic in the first 8 bytes, and a trailing little-endian `u64` equal
///   to `xxh64` (seed 0) of the bytes between the magic and the trailer.
///
/// # Errors
///
/// Returns `IndexError::CorruptIndex` when any of the checks above fails,
/// and propagates I/O errors raised while reading a file-backed postings
/// file.
pub fn verify_postings(&self) -> Result<(), IndexError> {
    match &self.postings {
        #[cfg(feature = "memmap2")]
        PostingsBacking::V2Mmap => Ok(()),
        #[cfg(feature = "memmap2")]
        PostingsBacking::V3File(post_file) => check_post_file_structure(post_file)
            .and_then(|()| verify_post_file_checksum(post_file)),
        PostingsBacking::InMemory(bytes) => {
            let len = bytes.len();
            if len < POST_MIN_SIZE {
                return Err(IndexError::CorruptIndex(format!(
                    "post bytes too small: {len} bytes"
                )));
            }

            // Layout: [ 8-byte magic | payload | 8-byte checksum trailer ].
            // The length guard above makes both splits in-bounds.
            let (magic, after_magic) = bytes.split_at(8);
            if magic != POST_MAGIC {
                return Err(IndexError::CorruptIndex(
                    "post bytes have wrong magic (expected SNTXPOST)".into(),
                ));
            }

            let (payload, trailer) = after_magic.split_at(after_magic.len() - 8);
            let trailer: [u8; 8] = trailer
                .try_into()
                .map_err(|_| IndexError::CorruptIndex("post trailer slice".into()))?;
            let expected = u64::from_le_bytes(trailer);

            if xxh64(payload, 0) == expected {
                Ok(())
            } else {
                Err(IndexError::CorruptIndex(
                    "post file checksum mismatch".into(),
                ))
            }
        }
    }
}
}
/// Performs the cheap structural validation of a postings file.
///
/// Checks that the file is at least `POST_MIN_SIZE` bytes long, that its
/// first 8 bytes equal `POST_MAGIC`, and that the final 8 bytes (the trailer)
/// can be read. The trailer's value is not interpreted here.
///
/// # Errors
///
/// * `IndexError::CorruptIndex` when the file is too small or the magic does
///   not match.
/// * I/O errors from `metadata` or from `read_exact_at`.
#[cfg(feature = "memmap2")]
fn check_post_file_structure(post_file: &std::fs::File) -> Result<(), IndexError> {
    // Keep the length as `u64`: narrowing to `usize` would silently truncate
    // on 32-bit targets and yield a wrong trailer offset for large files.
    let post_len = post_file.metadata()?.len();
    // `usize` -> `u64` is a widening conversion on supported targets.
    if post_len < POST_MIN_SIZE as u64 {
        return Err(IndexError::CorruptIndex(format!(
            "post file too small: {post_len} bytes"
        )));
    }

    let mut post_magic = [0u8; 8];
    read_exact_at(post_file, &mut post_magic, 0)?;
    if &post_magic != POST_MAGIC {
        return Err(IndexError::CorruptIndex(
            "post file has wrong magic (expected SNTXPOST)".into(),
        ));
    }

    // Only confirms the trailer is readable; its contents are checked by the
    // checksum verification, not here.
    let mut trailer = [0u8; 8];
    read_exact_at(post_file, &mut trailer, post_len - 8)?;
    Ok(())
}
/// Verifies the checksum stored in a postings file's trailer.
///
/// The last 8 bytes of the file are read as a little-endian `u64` and
/// compared with the XXH64 digest (seed 0) of the bytes between the 8-byte
/// magic and that trailer. The payload is hashed incrementally in fixed-size
/// chunks, so memory use does not grow with the size of the file.
///
/// # Errors
///
/// * `IndexError::CorruptIndex` when the file is shorter than
///   `POST_MIN_SIZE` or the digest does not match the stored value.
/// * I/O errors from `metadata` or from `read_exact_at`.
#[cfg(feature = "memmap2")]
fn verify_post_file_checksum(post_file: &std::fs::File) -> Result<(), IndexError> {
    /// Upper bound on the read buffer used while hashing.
    const CHUNK_LEN: usize = 1 << 20;
    /// Size of the leading magic, i.e. the offset where the payload begins.
    const PAYLOAD_START: u64 = 8;

    // Keep the length as `u64`: narrowing to `usize` would silently truncate
    // on 32-bit targets.
    let post_len = post_file.metadata()?.len();
    // `usize` -> `u64` is a widening conversion on supported targets.
    if post_len < POST_MIN_SIZE as u64 {
        return Err(IndexError::CorruptIndex(format!(
            "post file too small: {post_len} bytes"
        )));
    }

    // The payload ends where the 8-byte checksum trailer begins.
    let payload_end = post_len - 8;
    let mut stored_cksum_bytes = [0u8; 8];
    read_exact_at(post_file, &mut stored_cksum_bytes, payload_end)?;
    let stored_post_checksum = u64::from_le_bytes(stored_cksum_bytes);

    // Both casts below are lossless because the values are capped at
    // `CHUNK_LEN` before narrowing.
    let buf_len = (payload_end - PAYLOAD_START).min(CHUNK_LEN as u64) as usize;
    let mut buf = vec![0u8; buf_len];
    let mut hasher = xxhash_rust::xxh64::Xxh64::new(0);
    let mut offset = PAYLOAD_START;
    while offset < payload_end {
        let want = (payload_end - offset).min(CHUNK_LEN as u64) as usize;
        let chunk = &mut buf[..want];
        read_exact_at(post_file, chunk, offset)?;
        hasher.update(chunk);
        offset += want as u64;
    }

    // With an empty payload the loop never runs and the digest is that of
    // empty input, matching a one-shot hash of an empty slice.
    if hasher.digest() != stored_post_checksum {
        return Err(IndexError::CorruptIndex(
            "post file checksum mismatch".into(),
        ));
    }
    Ok(())
}