use alloc::vec::Vec;
use core::result::Result;
use crate::codec::{
decode_match_type, EngineKind, EntryFlags, EntryRecord, Header, Version,
FULL_HEADER_SIZE, MAGIC,
};
/// Errors returned when opening or validating an IDF file.
#[derive(Debug)]
pub enum OpenError {
    /// The file could not be opened or memory-mapped (`std` feature only).
    #[cfg(feature = "std")]
    Io(std::io::Error),
    /// The buffer is shorter than the fixed-size header.
    TooShort,
    /// The magic bytes do not match, or the header could not be parsed.
    BadMagic,
    /// The header's format-version byte is not a recognised [`Version`].
    UnsupportedVersion(u8),
    /// The header's engine-kind byte is not a recognised [`EngineKind`].
    UnknownEngineKind(u8),
    /// A section's offset/size points past the end of the buffer.
    CorruptOffsets,
    /// The payload's SHA-256 does not match the header (checked only with
    /// the `std` feature).
    Sha256Mismatch,
}

impl core::fmt::Display for OpenError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            #[cfg(feature = "std")]
            Self::Io(err) => write!(f, "I/O error: {err}"),
            Self::TooShort => f.write_str("buffer is shorter than the IDF header"),
            Self::BadMagic => f.write_str("bad magic bytes or unparseable IDF header"),
            Self::UnsupportedVersion(v) => write!(f, "unsupported IDF format version {v}"),
            Self::UnknownEngineKind(k) => write!(f, "unknown IDF engine kind {k}"),
            Self::CorruptOffsets => {
                f.write_str("IDF section offsets extend past the end of the buffer")
            }
            Self::Sha256Mismatch => f.write_str("payload SHA-256 does not match the header"),
        }
    }
}

/// Uses `core::error::Error` so the impl is available in `no_std` builds too;
/// only the `Io` variant (present with `std`) has an underlying source.
impl core::error::Error for OpenError {
    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
        match self {
            #[cfg(feature = "std")]
            Self::Io(err) => Some(err),
            _ => None,
        }
    }
}
#[cfg(feature = "std")]
impl From<std::io::Error> for OpenError {
    /// Wraps an I/O failure so `?` can lift it into [`OpenError::Io`].
    fn from(err: std::io::Error) -> Self {
        OpenError::Io(err)
    }
}
/// One decoded dictionary entry, borrowing its strings from an [`IdfReader`]'s buffer.
///
/// `word` and `code` are read from the file's string pool: each string ends at
/// the first NUL byte (or the end of the pool), and an out-of-range offset or
/// invalid UTF-8 yields `""` rather than an error.
#[derive(Copy, Clone, Debug)]
pub struct Entry<'a> {
/// The entry's word text; compared exactly by `IdfReader::find_by_word`.
pub word: &'a str,
/// The entry's code; compared by `IdfReader::lookup` and the prefix queries.
pub code: &'a str,
/// Prior score copied from the record; the `prefix_top_k*` queries rank
/// higher values first. NOTE(review): log scale per the name — confirm.
pub log_prior: i16,
/// Raw frequency value copied verbatim from the record.
pub raw_freq: u32,
/// Match type decoded from the record's raw value via `decode_match_type`.
pub match_type: inputx_scoring::MatchType,
/// Flag bits wrapping the record's raw `flags` field.
pub flags: EntryFlags,
}
/// Read-only view over an IDF dictionary file held in a byte container `B`.
///
/// Every accessor borrows from `B` (any `AsRef<[u8]>` type; with the `std`
/// feature, `IdfReader::open` uses a memory map). Construct it with
/// `IdfReader::from_bytes`, which validates the header and section bounds
/// before the reader is handed out.
///
/// The `AsRef<[u8]>` requirement is stated on the `impl` blocks rather than
/// on the struct, so the type can be named without repeating the bound.
pub struct IdfReader<B> {
    /// The raw file bytes.
    bytes: B,
    /// Header parsed and validated at construction.
    header: Header,
}
impl<B: AsRef<[u8]>> IdfReader<B> {
    /// Validates `bytes` as an IDF file and wraps it in a reader.
    ///
    /// Checks run in this order: minimum length, magic bytes, header parse,
    /// format version, engine kind, and that every section (string pool,
    /// entry table, code FST, word FST) lies inside the buffer. With the
    /// `std` feature, the payload after the header is also hashed and
    /// compared with the SHA-256 recorded in the header.
    ///
    /// # Errors
    ///
    /// - [`OpenError::TooShort`] if `bytes` is shorter than the full header.
    /// - [`OpenError::BadMagic`] if the magic bytes differ or the header
    ///   cannot be parsed.
    /// - [`OpenError::UnsupportedVersion`] / [`OpenError::UnknownEngineKind`]
    ///   for unrecognised header bytes.
    /// - [`OpenError::CorruptOffsets`] if any section extends past the end of
    ///   the buffer.
    /// - [`OpenError::Sha256Mismatch`] (only with `std`) if the payload hash
    ///   does not match.
    pub fn from_bytes(bytes: B) -> Result<Self, OpenError> {
        let buf = bytes.as_ref();
        if buf.len() < FULL_HEADER_SIZE {
            return Err(OpenError::TooShort);
        }
        if buf[0..4] != MAGIC {
            return Err(OpenError::BadMagic);
        }
        let header = Header::parse(buf).ok_or(OpenError::BadMagic)?;
        if Version::from_byte(header.format_version).is_none() {
            return Err(OpenError::UnsupportedVersion(header.format_version));
        }
        if EngineKind::from_byte(header.engine_kind).is_none() {
            return Err(OpenError::UnknownEngineKind(header.engine_kind));
        }
        Self::check_sections(&header, buf.len())?;
        #[cfg(feature = "std")]
        {
            use sha2::{Digest, Sha256};
            let mut hasher = Sha256::new();
            hasher.update(&buf[FULL_HEADER_SIZE..]);
            let got: [u8; 32] = hasher.finalize().into();
            if got != header.sha256_of_payload {
                return Err(OpenError::Sha256Mismatch);
            }
        }
        Ok(Self { bytes, header })
    }

    /// Fails with [`OpenError::CorruptOffsets`] unless every section described
    /// by `header` lies within a buffer of `file_len` bytes.
    ///
    /// The arithmetic is done in checked `u64`. Truncating the length to `u32`
    /// would wrongly reject buffers of 4 GiB or more. Saturating `u32` sums
    /// could accept a section that overflows when the buffer is exactly
    /// `u32::MAX` bytes long.
    fn check_sections(header: &Header, file_len: usize) -> Result<(), OpenError> {
        // usize -> u64 is lossless on every supported target.
        let file_len = file_len as u64;
        let entry_table_len = u64::from(header.entry_count)
            .checked_mul(crate::codec::ENTRY_SIZE as u64)
            .ok_or(OpenError::CorruptOffsets)?;
        let sections = [
            (header.string_pool_offset, u64::from(header.string_pool_size)),
            (header.entry_table_offset, entry_table_len),
            (header.fst_code_index_offset, u64::from(header.fst_code_index_size)),
            (header.fst_word_index_offset, u64::from(header.fst_word_index_size)),
        ];
        let all_in_bounds = sections.iter().all(|&(offset, len)| {
            u64::from(offset)
                .checked_add(len)
                .is_some_and(|end| end <= file_len)
        });
        if all_in_bounds {
            Ok(())
        } else {
            Err(OpenError::CorruptOffsets)
        }
    }

    /// Returns the header parsed at construction.
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// Returns the file's format version.
    ///
    /// # Panics
    ///
    /// Never for a reader built by [`IdfReader::from_bytes`], which rejects
    /// unknown versions.
    pub fn version(&self) -> Version {
        Version::from_byte(self.header.format_version).expect("validated at open")
    }

    /// Returns the file's engine kind.
    ///
    /// # Panics
    ///
    /// Never for a reader built by [`IdfReader::from_bytes`], which rejects
    /// unknown engine kinds.
    pub fn engine_kind(&self) -> EngineKind {
        EngineKind::from_byte(self.header.engine_kind).expect("validated at open")
    }

    /// Number of records in the entry table.
    pub fn entry_count(&self) -> u32 {
        self.header.entry_count
    }

    /// Payload SHA-256 as recorded in the header. It is verified at open only
    /// when the `std` feature is enabled.
    pub fn sha256(&self) -> [u8; 32] {
        self.header.sha256_of_payload
    }

    /// Iterates over every entry in table order.
    pub fn entries(&self) -> impl Iterator<Item = Entry<'_>> + '_ {
        let pool = self.string_pool();
        self.entry_table()
            .chunks_exact(crate::codec::ENTRY_SIZE)
            .map(move |chunk| {
                let rec: &[u8; crate::codec::ENTRY_SIZE] = chunk
                    .try_into()
                    .expect("chunks_exact yields ENTRY_SIZE-byte chunks");
                decode_entry(&EntryRecord::parse(rec), pool)
            })
    }

    /// Decodes the entry at `index`, or returns `None` if it is out of range.
    pub fn entry_at(&self, index: u32) -> Option<Entry<'_>> {
        if index >= self.header.entry_count {
            return None;
        }
        let start = index as usize * crate::codec::ENTRY_SIZE;
        let rec: &[u8; crate::codec::ENTRY_SIZE] = self
            .entry_table()
            .get(start..start + crate::codec::ENTRY_SIZE)?
            .try_into()
            .ok()?;
        Some(decode_entry(&EntryRecord::parse(rec), self.string_pool()))
    }

    /// Returns every entry whose code equals `code` exactly.
    ///
    /// With a usable code FST, the FST gives the index of the first matching
    /// entry and the contiguous run from there is collected. Without one (the
    /// section is empty or fails to parse), this falls back to a linear scan.
    pub fn lookup<'a>(&'a self, code: &[u8]) -> Vec<Entry<'a>> {
        if let Some(fst_bytes) = self.fst_code_index_bytes()
            && let Ok(fst) = inputx_fsa::Fsa::new(fst_bytes)
        {
            let mut out: Vec<Entry<'a>> = Vec::new();
            if let Some(first) = fst.get(code) {
                self.visit_code_run(first, code, |entry| out.push(entry));
            }
            return out;
        }
        self.entries()
            .filter(|e| e.code.as_bytes() == code)
            .collect()
    }

    /// Returns every entry whose word equals `word`, in table order.
    ///
    /// This is a linear scan. The word-index section is bounds-checked at open
    /// but not consulted here.
    pub fn find_by_word<'a>(&'a self, word: &str) -> Vec<Entry<'a>> {
        self.entries().filter(|e| e.word == word).collect()
    }

    /// Returns the top `k` entries whose code starts with `prefix`, ranked by
    /// descending `log_prior` and then ascending code.
    ///
    /// Candidates come from [`IdfReader::prefix_for_each_entry`], which uses
    /// the code FST when available. For a well-formed file the result should
    /// match [`IdfReader::prefix_top_k`].
    pub fn prefix_top_k_fst<'a>(&'a self, prefix: &[u8], k: usize) -> Vec<Entry<'a>> {
        if k == 0 {
            return Vec::new();
        }
        let mut hits: Vec<Entry<'a>> = Vec::new();
        self.prefix_for_each_entry(prefix, |entry| hits.push(entry));
        Self::rank_by_prior(&mut hits, k);
        hits
    }

    /// Calls `visit` for every entry whose code starts with `prefix`.
    ///
    /// With a usable code FST, each matching code's run of entries is visited
    /// in FST order. Otherwise the whole table is scanned in table order.
    pub fn prefix_for_each_entry<'a, F: FnMut(Entry<'a>)>(
        &'a self,
        prefix: &[u8],
        mut visit: F,
    ) {
        if let Some(fst_bytes) = self.fst_code_index_bytes()
            && let Ok(fst) = inputx_fsa::Fsa::new(fst_bytes)
        {
            fst.prefix_for_each(prefix, |code, first_idx| {
                self.visit_code_run(first_idx, &code[..], &mut visit);
            });
            return;
        }
        self.entries()
            .filter(|e| e.code.as_bytes().starts_with(prefix))
            .for_each(visit);
    }

    /// Linear-scan version of [`IdfReader::prefix_top_k_fst`]: the top `k`
    /// entries whose code starts with `prefix`, ranked by descending
    /// `log_prior` and then ascending code.
    pub fn prefix_top_k<'a>(&'a self, prefix: &[u8], k: usize) -> Vec<Entry<'a>> {
        if k == 0 {
            return Vec::new();
        }
        let mut candidates: Vec<Entry<'a>> = self
            .entries()
            .filter(|e| e.code.as_bytes().starts_with(prefix))
            .collect();
        Self::rank_by_prior(&mut candidates, k);
        candidates
    }

    /// Visits the contiguous run of entries that starts at index `first` and
    /// whose code equals `code`. It stops at the first mismatch or at the end
    /// of the table.
    fn visit_code_run<'a>(&'a self, first: u64, code: &[u8], mut visit: impl FnMut(Entry<'a>)) {
        let total = u64::from(self.header.entry_count);
        for idx in first..total {
            // `idx < entry_count`, which is a `u32`, so this cast is lossless.
            let Some(entry) = self.entry_at(idx as u32) else {
                break;
            };
            if entry.code.as_bytes() != code {
                break;
            }
            visit(entry);
        }
    }

    /// Sorts `hits` by descending `log_prior`, then ascending code, and keeps
    /// the first `k`. The sort is stable, so full ties keep their visit order.
    fn rank_by_prior(hits: &mut Vec<Entry<'_>>, k: usize) {
        hits.sort_by(|a, b| {
            b.log_prior
                .cmp(&a.log_prior)
                .then_with(|| a.code.cmp(b.code))
        });
        hits.truncate(k);
    }

    /// Returns `len` bytes of the underlying buffer starting at `offset`.
    ///
    /// It is only called for sections validated by `from_bytes`, so the slice
    /// is in bounds.
    fn section_bytes(&self, offset: u32, len: usize) -> &[u8] {
        let start = offset as usize;
        &self.bytes.as_ref()[start..start + len]
    }

    /// The code-FST section, or `None` when the file has none (size 0).
    fn fst_code_index_bytes(&self) -> Option<&[u8]> {
        let size = self.header.fst_code_index_size as usize;
        (size != 0).then(|| self.section_bytes(self.header.fst_code_index_offset, size))
    }

    /// The fixed-size entry-record table.
    fn entry_table(&self) -> &[u8] {
        let len = self.header.entry_count as usize * crate::codec::ENTRY_SIZE;
        self.section_bytes(self.header.entry_table_offset, len)
    }

    /// The string pool that entry word/code offsets point into.
    fn string_pool(&self) -> &[u8] {
        self.section_bytes(
            self.header.string_pool_offset,
            self.header.string_pool_size as usize,
        )
    }
}
#[cfg(feature = "std")]
impl IdfReader<memmap2::Mmap> {
    /// Memory-maps the file at `path` and validates it with
    /// [`IdfReader::from_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`OpenError::Io`] if the file cannot be opened or mapped. Any
    /// error from [`IdfReader::from_bytes`] is returned for the mapped bytes.
    pub fn open<P: AsRef<std::path::Path>>(path: P) -> Result<Self, OpenError> {
        let file = std::fs::File::open(path)?;
        // SAFETY: `Mmap::map` is unsafe because the mapping aliases the file on
        // disk. If the file is truncated or rewritten while this reader (which
        // owns the mapping) is alive, reads through it are undefined behavior.
        // We assume dictionary files are not modified while open. Callers
        // that cannot guarantee this should read the file into memory and use
        // `from_bytes` instead.
        let mmap = unsafe { memmap2::Mmap::map(&file)? };
        Self::from_bytes(mmap)
    }
}
/// Builds an [`Entry`] from a parsed record, resolving its word and code
/// offsets against the string pool `pool`. Numeric fields are copied as-is.
fn decode_entry<'a>(rec: &EntryRecord, pool: &'a [u8]) -> Entry<'a> {
    Entry {
        word: read_string(pool, rec.word_offset),
        code: read_string(pool, rec.code_offset),
        log_prior: rec.log_prior,
        raw_freq: rec.raw_freq,
        match_type: decode_match_type(rec.match_type),
        flags: EntryFlags(rec.flags),
    }
}
/// Reads the string that starts at byte `offset` of `pool` and ends at the
/// first NUL byte (or the end of the pool).
///
/// Decoding is deliberately best-effort: it returns `""` when `offset` is at
/// or past the end of the pool, or when the bytes are not valid UTF-8.
fn read_string(pool: &[u8], offset: u32) -> &str {
    let Some(tail) = pool.get(offset as usize..) else {
        return "";
    };
    let text = match tail.iter().position(|&b| b == 0) {
        Some(nul) => &tail[..nul],
        None => tail,
    };
    core::str::from_utf8(text).unwrap_or("")
}