use triblespace_core::inline::Encodes;
use anybytes::area::{SectionHandle, SectionWriter};
use anybytes::view::View;
use anybytes::{ByteArea, Bytes};
use jerky::int_vectors::compact_vector::CompactVectorMeta;
use jerky::int_vectors::{CompactVector, CompactVectorBuilder};
use jerky::serialization::Serializable;
use triblespace_core::blob::encodings::succinctarchive::{
CompressedUniverse, CompressedUniverseMeta, Universe,
};
use triblespace_core::blob::{Blob, BlobEncoding, TryFromBlob};
use triblespace_core::id::ExclusiveId;
use triblespace_core::id_hex;
use triblespace_core::macros::entity;
use triblespace_core::metadata::{self, MetaDescribe};
use triblespace_core::query::Variable;
use triblespace_core::trible::Fragment;
use triblespace_core::inline::{RawInline, Inline, InlineEncoding};
use crate::schemas::{EmbHandle, Embedding};
use std::collections::HashMap;
use crate::hnsw::HNSWIndex;
/// Errors raised while building or decoding the succinct structures in this
/// module (doc-length vectors, postings, layered graphs).
#[derive(Debug)]
pub enum SuccinctDocLensError {
    /// Underlying byte-area / I/O failure.
    Bytes(std::io::Error),
    /// Failure inside the `jerky` compact-vector library.
    Jerky(jerky::error::Error),
    /// A declared size did not match the data actually provided. Also reused
    /// as a generic mismatch carrier (e.g. out-of-range neighbour ids in
    /// `SuccinctGraph::build_into`), so the fields are "have" vs "expected".
    SizeMismatch {
        /// The size (or offending value) actually observed.
        bytes: usize,
        /// The size (or bound) that was declared/required.
        expected: usize,
    },
}
impl std::fmt::Display for SuccinctDocLensError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Bytes(e) => write!(f, "succinct: bytes error: {e}"),
Self::Jerky(e) => write!(f, "succinct: jerky error: {e}"),
Self::SizeMismatch { bytes, expected } => write!(
f,
"succinct: size mismatch: have {bytes} bytes, declared needs {expected}"
),
}
}
}
impl std::error::Error for SuccinctDocLensError {}
/// Allow `?` on I/O results inside the builders below.
impl From<std::io::Error> for SuccinctDocLensError {
    fn from(e: std::io::Error) -> Self {
        Self::Bytes(e)
    }
}
/// Allow `?` on `jerky` results inside the builders below.
impl From<jerky::error::Error> for SuccinctDocLensError {
    fn from(e: jerky::error::Error) -> Self {
        Self::Jerky(e)
    }
}
/// Bit-packed table of per-document lengths, backed by a `jerky`
/// `CompactVector` over shared bytes (zero-copy reopen via `from_bytes`).
#[derive(Debug)]
pub(crate) struct SuccinctDocLens {
    // The packed integer vector; width is the minimum bits for the max length.
    inner: CompactVector,
}
impl SuccinctDocLens {
    /// Pack `lens` into a bit-packed `CompactVector` written into `sections`,
    /// returning the metadata needed to reopen it via [`Self::from_bytes`].
    pub(crate) fn build_into(
        sections: &mut SectionWriter<'_>,
        lens: &[u32],
    ) -> Result<CompactVectorMeta, SuccinctDocLensError> {
        // Minimum number of bits that fits the largest length (>= 1).
        let width = required_width(lens);
        let mut builder = CompactVectorBuilder::with_capacity(lens.len(), width, sections)?;
        builder.set_ints(0..lens.len(), lens.iter().map(|&n| n as usize))?;
        let cv = builder.freeze();
        Ok(cv.metadata())
    }
    /// Test helper: build into a fresh `ByteArea` and return the frozen bytes
    /// together with the metadata.
    #[cfg(test)]
    pub(crate) fn build(
        lens: &[u32],
    ) -> Result<(Bytes, CompactVectorMeta), SuccinctDocLensError> {
        let mut area = ByteArea::new()?;
        let mut sections = area.sections();
        let meta = Self::build_into(&mut sections, lens)?;
        let bytes = area.freeze()?;
        Ok((bytes, meta))
    }
    /// Reopen a previously built table over `bytes` using its metadata.
    pub(crate) fn from_bytes(
        meta: CompactVectorMeta,
        bytes: Bytes,
    ) -> Result<Self, SuccinctDocLensError> {
        let inner = CompactVector::from_bytes(meta, bytes)?;
        Ok(Self { inner })
    }
    /// Length of document `i`, or `None` when `i` is out of range.
    pub(crate) fn get(&self, i: usize) -> Option<u32> {
        self.inner.get_int(i).map(|n| n as u32)
    }
    #[cfg(test)]
    pub(crate) fn len(&self) -> usize {
        self.inner.len()
    }
    #[cfg(test)]
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.len() == 0
    }
    #[cfg(test)]
    pub(crate) fn to_vec(&self) -> Vec<u32> {
        self.inner.to_vec().into_iter().map(|n| n as u32).collect()
    }
    #[cfg(test)]
    pub(crate) fn width(&self) -> usize {
        self.inner.metadata().width
    }
}
/// Minimum number of bits needed to store every value in `lens`.
/// Always at least 1, so an all-zero (or empty) table still has a valid width.
fn required_width(lens: &[u32]) -> usize {
    lens.iter()
        .copied()
        .max()
        .filter(|&m| m > 0)
        .map(|m| (u32::BITS - m.leading_zeros()) as usize)
        .unwrap_or(1)
}
/// Copy `rows` verbatim into a reserved fixed-width (`[u8; N]`) table section
/// and return the handle needed to re-view it over the frozen bytes later.
pub(crate) fn pack_byte_table<const N: usize>(
    sections: &mut SectionWriter<'_>,
    rows: &[[u8; N]],
) -> Result<SectionHandle<[u8; N]>, std::io::Error> {
    let mut sec = sections.reserve::<[u8; N]>(rows.len())?;
    sec.as_mut_slice().copy_from_slice(rows);
    // Capture the handle before freezing; the frozen section value itself is
    // not needed by callers.
    let handle = sec.handle();
    let _ = sec.freeze()?;
    Ok(handle)
}
/// Succinct inverted-index postings: for each term, a contiguous run of
/// `(doc_code, quantized_score)` pairs delimited by an offsets table.
#[derive(Debug)]
pub(crate) struct SuccinctPostings {
    // Flat doc codes for all terms, term-by-term.
    doc_idx: CompactVector,
    // `n_terms + 1` cumulative offsets into `doc_idx`/`scores`.
    offsets: CompactVector,
    // 16-bit quantized scores, parallel to `doc_idx`.
    scores: CompactVector,
    // Quantization scale: scores are stored as fractions of this maximum.
    max_score: f32,
    n_terms: usize,
}
/// Fixed-layout metadata for [`SuccinctPostings`], embedded in larger meta
/// structs and decoded zerocopy-style from the blob suffix.
#[derive(
    Debug, Clone, Copy, zerocopy::FromBytes, zerocopy::KnownLayout, zerocopy::Immutable,
)]
#[repr(C)]
pub(crate) struct SuccinctPostingsMeta {
    pub n_terms: u64,
    pub doc_idx: CompactVectorMeta,
    pub offsets: CompactVectorMeta,
    pub scores: CompactVectorMeta,
    pub max_score: f32,
    // Explicit padding so the layout is fully determined under #[repr(C)].
    _pad: u32,
}
/// Scores are stored as 16-bit fixed-point fractions of a list's `max_score`.
const SCORE_WIDTH: usize = 16;
/// Largest representable quantized score (2^16 - 1).
const SCORE_MAX_Q: u32 = u16::MAX as u32;
/// Map a raw score onto the `0..=SCORE_MAX_Q` grid relative to `max_score`.
/// Out-of-range ratios are clamped; a non-positive `max_score` collapses
/// every score to zero.
fn quantize_score(s: f32, max_score: f32) -> u16 {
    if max_score > 0.0 {
        let ratio = (s / max_score).clamp(0.0, 1.0);
        (ratio * SCORE_MAX_Q as f32).round() as u16
    } else {
        0
    }
}
/// Inverse of [`quantize_score`]: recover an approximate score from its grid
/// cell. Exact at the endpoints (0 and `max_score`).
fn dequantize_score(q: u16, max_score: f32) -> f32 {
    if max_score > 0.0 {
        f32::from(q) / SCORE_MAX_Q as f32 * max_score
    } else {
        0.0
    }
}
impl SuccinctPostings {
    /// Test helper: build postings from fully materialized per-term lists.
    #[cfg(test)]
    pub(crate) fn build(
        lists: &[Vec<(u32, f32)>],
        n_docs: u32,
    ) -> Result<(Bytes, SuccinctPostingsMeta), SuccinctDocLensError> {
        Self::build_with(n_docs, lists.len(), |t, buf| {
            buf.extend_from_slice(&lists[t]);
        })
    }
    /// Test helper: build postings from a closure that materializes one term's
    /// `(doc_code, score)` list at a time.
    ///
    /// The closure is invoked twice per term — once here to count postings and
    /// find the global maximum score, and once again inside
    /// `build_with_into` — so it must be deterministic.
    #[cfg(test)]
    pub(crate) fn build_with<F>(
        n_docs: u32,
        n_terms: usize,
        mut materialize_term: F,
    ) -> Result<(Bytes, SuccinctPostingsMeta), SuccinctDocLensError>
    where
        F: FnMut(usize, &mut Vec<(u32, f32)>),
    {
        // First pass: total posting count (for capacity) and max score (the
        // quantization scale) must be known before any bits are written.
        let mut buf: Vec<(u32, f32)> = Vec::new();
        let mut total: usize = 0;
        let mut max_score = 0.0f32;
        for t in 0..n_terms {
            buf.clear();
            materialize_term(t, &mut buf);
            total += buf.len();
            for &(_, s) in &buf {
                if s > max_score {
                    max_score = s;
                }
            }
        }
        let mut area = ByteArea::new()?;
        let mut sections = area.sections();
        let meta = Self::build_with_into(
            &mut sections,
            n_docs,
            n_terms,
            total,
            max_score,
            materialize_term,
        )?;
        let bytes = area.freeze()?;
        Ok((bytes, meta))
    }
    /// Write postings into `sections`. `total` and `max_score` must describe
    /// exactly what `materialize_term` will produce (checked via
    /// `debug_assert_eq!` below).
    pub(crate) fn build_with_into<F>(
        sections: &mut SectionWriter<'_>,
        n_docs: u32,
        n_terms: usize,
        total: usize,
        max_score: f32,
        mut materialize_term: F,
    ) -> Result<SuccinctPostingsMeta, SuccinctDocLensError>
    where
        F: FnMut(usize, &mut Vec<(u32, f32)>),
    {
        // Widths sized so every doc code (< n_docs) and every offset
        // (<= total) fits.
        let doc_idx_width = width_for(n_docs as usize + 1);
        let offsets_width = width_for(total + 1);
        let mut doc_idx_b =
            CompactVectorBuilder::with_capacity(total, doc_idx_width, sections)?;
        let mut offsets_b =
            CompactVectorBuilder::with_capacity(n_terms + 1, offsets_width, sections)?;
        let mut scores_b = CompactVectorBuilder::with_capacity(total, SCORE_WIDTH, sections)?;
        offsets_b.set_int(0, 0)?;
        let mut buf: Vec<(u32, f32)> = Vec::new();
        let mut pos = 0usize;
        for t in 0..n_terms {
            buf.clear();
            materialize_term(t, &mut buf);
            for &(idx, s) in &buf {
                doc_idx_b.set_int(pos, idx as usize)?;
                scores_b.set_int(pos, quantize_score(s, max_score) as usize)?;
                pos += 1;
            }
            // offsets[t + 1] is the exclusive end of term t's postings.
            offsets_b.set_int(t + 1, pos)?;
        }
        debug_assert_eq!(
            pos, total,
            "build_with_into: closure produced {pos} postings; caller said total = {total}"
        );
        let doc_idx_meta = doc_idx_b.freeze().metadata();
        let offsets_meta = offsets_b.freeze().metadata();
        let scores_meta = scores_b.freeze().metadata();
        Ok(SuccinctPostingsMeta {
            n_terms: n_terms as u64,
            doc_idx: doc_idx_meta,
            offsets: offsets_meta,
            scores: scores_meta,
            max_score,
            _pad: 0,
        })
    }
    /// Reopen postings over `bytes` using previously recorded metadata.
    pub fn from_bytes(
        meta: SuccinctPostingsMeta,
        bytes: Bytes,
    ) -> Result<Self, SuccinctDocLensError> {
        let doc_idx = CompactVector::from_bytes(meta.doc_idx, bytes.clone())?;
        let offsets = CompactVector::from_bytes(meta.offsets, bytes.clone())?;
        let scores = CompactVector::from_bytes(meta.scores, bytes)?;
        Ok(Self {
            doc_idx,
            offsets,
            scores,
            max_score: meta.max_score,
            n_terms: meta.n_terms as usize,
        })
    }
    #[cfg(test)]
    pub(crate) fn term_count(&self) -> usize {
        self.n_terms
    }
    /// Worst-case absolute error introduced by 16-bit score quantization.
    // NOTE(review): divides by 65534.0 rather than SCORE_MAX_Q (65535) —
    // presumably deliberate one-step slack; confirm before tightening.
    pub fn score_tolerance(&self) -> f32 {
        if self.max_score <= 0.0 {
            f32::EPSILON
        } else {
            self.max_score / 65534.0
        }
    }
    /// Number of postings stored for term `t`, or `None` if `t` out of range.
    pub fn posting_count(&self, t: usize) -> Option<usize> {
        if t >= self.n_terms {
            return None;
        }
        let start = self.offsets.get_int(t)?;
        let end = self.offsets.get_int(t + 1)?;
        Some(end - start)
    }
    /// Iterate term `t`'s postings as `(doc_code, dequantized_score)` pairs,
    /// or `None` if `t` is out of range.
    pub fn postings_for(&self, t: usize) -> Option<impl Iterator<Item = (u32, f32)> + '_> {
        if t >= self.n_terms {
            return None;
        }
        let start = self.offsets.get_int(t)?;
        let end = self.offsets.get_int(t + 1)?;
        let max = self.max_score;
        Some((start..end).map(move |i| {
            // Positions in start..end were written by the builder, so the
            // inner lookups cannot fail on well-formed data.
            let idx = self.doc_idx.get_int(i).unwrap() as u32;
            let q = self.scores.get_int(i).unwrap() as u16;
            (idx, dequantize_score(q, max))
        }))
    }
}
/// Bits needed to store any code in `0..n` (i.e. max value `n - 1`).
/// Returns at least 1 so zero-capacity vectors still get a valid width.
fn width_for(n: usize) -> usize {
    let max_code = n.saturating_sub(1);
    if max_code == 0 {
        1
    } else {
        (usize::BITS - max_code.leading_zeros()) as usize
    }
}
/// Succinct layered adjacency structure (used for the HNSW graph): all
/// neighbour lists flattened into one bit-packed vector, delimited per
/// layer/node by a cumulative offsets table.
#[derive(Debug)]
pub struct SuccinctGraph {
    // Flat neighbour ids for all layers, layer-major then node-major.
    neighbours: CompactVector,
    // `n_layers * (n_nodes + 1)` cumulative offsets into `neighbours`.
    offsets: CompactVector,
    n_nodes: usize,
    n_layers: usize,
}
/// Fixed-layout metadata for [`SuccinctGraph`], decoded zerocopy-style from
/// the blob suffix.
#[derive(
    Debug, Clone, Copy, zerocopy::FromBytes, zerocopy::KnownLayout, zerocopy::Immutable,
)]
#[repr(C)]
pub struct SuccinctGraphMeta {
    pub n_nodes: u64,
    pub n_layers: u64,
    pub neighbours: CompactVectorMeta,
    pub offsets: CompactVectorMeta,
}
impl SuccinctGraph {
    /// Build a graph blob in a fresh `ByteArea`; see [`Self::build_into`].
    pub fn build(
        layer_graph: &[Vec<Vec<u32>>],
        n_nodes: usize,
    ) -> Result<(Bytes, SuccinctGraphMeta), SuccinctDocLensError> {
        let mut area = ByteArea::new()?;
        let mut sections = area.sections();
        let meta = Self::build_into(&mut sections, layer_graph, n_nodes)?;
        let bytes = area.freeze()?;
        Ok((bytes, meta))
    }
    /// Serialize `layer_graph` (`layer -> node -> neighbour ids`) into
    /// `sections`.
    ///
    /// Validates up front that every layer has exactly `n_nodes` adjacency
    /// lists and every neighbour id is `< n_nodes`; violations are reported
    /// through `SizeMismatch` (fields carry the offending/expected values).
    pub(crate) fn build_into(
        sections: &mut SectionWriter<'_>,
        layer_graph: &[Vec<Vec<u32>>],
        n_nodes: usize,
    ) -> Result<SuccinctGraphMeta, SuccinctDocLensError> {
        let n_layers = layer_graph.len();
        for layer in layer_graph {
            if layer.len() != n_nodes {
                return Err(SuccinctDocLensError::SizeMismatch {
                    bytes: layer.len(),
                    expected: n_nodes,
                });
            }
            for list in layer {
                for &n in list {
                    if (n as usize) >= n_nodes {
                        return Err(SuccinctDocLensError::SizeMismatch {
                            bytes: n as usize,
                            expected: n_nodes,
                        });
                    }
                }
            }
        }
        let total_edges: usize = layer_graph
            .iter()
            .flat_map(|layer| layer.iter().map(|l| l.len()))
            .sum();
        // Widths sized so every neighbour id (< n_nodes) and every offset
        // (<= total_edges) fits.
        let neighbours_width = width_for(n_nodes + 1);
        let offsets_width = width_for(total_edges + 1);
        let offsets_len = n_layers * (n_nodes + 1);
        let mut neighbours_b =
            CompactVectorBuilder::with_capacity(total_edges, neighbours_width, sections)?;
        let mut pos = 0usize;
        for layer in layer_graph {
            for list in layer {
                for &n in list {
                    neighbours_b.set_int(pos, n as usize)?;
                    pos += 1;
                }
            }
        }
        let neighbours_meta = neighbours_b.freeze().metadata();
        // Offsets layout: per layer, one leading start entry followed by one
        // cumulative end per node — so node i of a layer spans
        // offsets[slot + i]..offsets[slot + i + 1].
        let mut offsets_b =
            CompactVectorBuilder::with_capacity(offsets_len, offsets_width, sections)?;
        let mut cum = 0usize;
        let mut slot = 0usize;
        for layer in layer_graph {
            offsets_b.set_int(slot, cum)?;
            slot += 1;
            for list in layer {
                cum += list.len();
                offsets_b.set_int(slot, cum)?;
                slot += 1;
            }
        }
        // Defensive: pad any remaining slots with the final cumulative value
        // so every slot holds a valid (empty-range) offset.
        while slot < offsets_len {
            offsets_b.set_int(slot, cum)?;
            slot += 1;
        }
        let offsets_meta = offsets_b.freeze().metadata();
        Ok(SuccinctGraphMeta {
            neighbours: neighbours_meta,
            offsets: offsets_meta,
            n_nodes: n_nodes as u64,
            n_layers: n_layers as u64,
        })
    }
    /// Reopen a graph over `bytes` using previously recorded metadata.
    pub fn from_bytes(meta: SuccinctGraphMeta, bytes: Bytes) -> Result<Self, SuccinctDocLensError> {
        let neighbours = CompactVector::from_bytes(meta.neighbours, bytes.clone())?;
        let offsets = CompactVector::from_bytes(meta.offsets, bytes)?;
        Ok(Self {
            neighbours,
            offsets,
            n_nodes: meta.n_nodes as usize,
            n_layers: meta.n_layers as usize,
        })
    }
    pub fn n_nodes(&self) -> usize {
        self.n_nodes
    }
    pub fn n_layers(&self) -> usize {
        self.n_layers
    }
    /// Iterate the neighbour ids of `node` on `layer`. Out-of-range arguments
    /// yield an empty iterator rather than panicking.
    pub fn neighbours(&self, node: usize, layer: usize) -> impl Iterator<Item = u32> + '_ {
        let (start, end) = if node >= self.n_nodes || layer >= self.n_layers {
            (0usize, 0usize)
        } else {
            // See build_into: each layer owns n_nodes + 1 offset slots.
            let slot = layer * (self.n_nodes + 1) + node;
            let start = self.offsets.get_int(slot).unwrap_or(0);
            let end = self.offsets.get_int(slot + 1).unwrap_or(start);
            (start, end)
        };
        (start..end).map(move |i| self.neighbours.get_int(i).unwrap() as u32)
    }
}
/// Fixed-layout suffix metadata for [`SuccinctHNSWIndex`] blobs, decoded
/// zerocopy-style from the end of the byte area.
#[derive(
    Debug, Clone, Copy, zerocopy::FromBytes, zerocopy::KnownLayout, zerocopy::Immutable,
)]
#[repr(C)]
pub struct SuccinctHNSWMeta {
    pub(crate) n_nodes: u64,
    pub(crate) graph: SuccinctGraphMeta,
    /// Section of 32-byte embedding-handle rows, one per node.
    pub(crate) handles: SectionHandle<[u8; 32]>,
    pub(crate) dim: u32,
    /// Node id of the search entry point; `u32::MAX` sentinel is written by
    /// `from_naive` when there is none (gated by `has_entry_point`).
    pub(crate) entry_point: u32,
    pub(crate) m: u16,
    pub(crate) m0: u16,
    pub(crate) max_level: u8,
    /// Boolean-as-byte (zerocopy has no bool): nonzero means `entry_point`
    /// is valid.
    pub(crate) has_entry_point: u8,
    _pad: [u8; 10],
}
// Pin the on-disk size: stored blobs would silently mis-decode if the layout
// drifted, so fail the build instead.
const _: () = assert!(
    std::mem::size_of::<SuccinctHNSWMeta>() == 128,
    "SuccinctHNSWMeta must be 128 bytes — re-tune _pad if the layout shifts",
);
/// Read-only, blob-backed HNSW index: handles table + layered graph share one
/// byte area (`bytes`); embeddings live as separate blobs referenced by handle.
pub struct SuccinctHNSWIndex {
    /// The canonical backing bytes; the index *is* this blob.
    pub bytes: Bytes,
    dim: usize,
    m: u16,
    m0: u16,
    max_level: u8,
    // None when the index is empty (no search entry point recorded).
    entry_point: Option<u32>,
    // Zero-copy view of the per-node 32-byte embedding handles.
    handles: View<[[u8; 32]]>,
    graph: SuccinctGraph,
}
// Manual Debug: summarize instead of dumping the raw bytes/graph contents.
impl std::fmt::Debug for SuccinctHNSWIndex {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SuccinctHNSWIndex")
            .field("n_nodes", &self.handles.len())
            .field("dim", &self.dim)
            .field("max_level", &self.max_level)
            .finish()
    }
}
impl SuccinctHNSWIndex {
    /// Freeze a mutable [`HNSWIndex`] into its canonical-bytes succinct form:
    /// handle table, layered graph, and a `SuccinctHNSWMeta` suffix section
    /// all written into one `ByteArea`, then reopened via `from_bytes`.
    pub fn from_naive(idx: &HNSWIndex) -> Result<Self, SuccinctDocLensError> {
        let n = idx.doc_count();
        let dim = idx.dim();
        let max_level = idx.max_level();
        let n_layers = max_level as usize + 1;
        let mut area = ByteArea::new()?;
        let mut sections = area.sections();
        let handle_rows: Vec<RawInline> = idx.handles().iter().map(|h| h.raw).collect();
        let handles_handle = pack_byte_table::<32>(&mut sections, &handle_rows)?;
        // Densify the source graph: one adjacency list per node per layer;
        // nodes below a layer's level keep an empty list.
        let mut layer_graph: Vec<Vec<Vec<u32>>> = (0..n_layers)
            .map(|_| (0..n).map(|_| Vec::new()).collect())
            .collect();
        for (layer, row) in layer_graph.iter_mut().enumerate() {
            for (i, slot) in row.iter_mut().enumerate() {
                let lvl = idx.node_level(i).expect("node in range") as usize;
                if lvl >= layer {
                    *slot = idx.node_neighbours(i, layer as u8).to_vec();
                }
            }
        }
        let graph_meta = SuccinctGraph::build_into(&mut sections, &layer_graph, n)?;
        let entry_point_raw = idx.entry_point();
        let meta = SuccinctHNSWMeta {
            n_nodes: n as u64,
            graph: graph_meta,
            handles: handles_handle,
            dim: dim as u32,
            // u32::MAX sentinel when absent; has_entry_point gates its use.
            entry_point: entry_point_raw.unwrap_or(u32::MAX),
            m: idx.m(),
            m0: idx.m0(),
            max_level,
            has_entry_point: entry_point_raw.is_some() as u8,
            _pad: [0u8; 10],
        };
        {
            // Meta goes last so readers can locate it as a suffix section.
            let mut meta_sec = sections.reserve::<SuccinctHNSWMeta>(1)?;
            meta_sec.as_mut_slice()[0] = meta;
            meta_sec.freeze()?;
        }
        drop(sections);
        let bytes = area.freeze()?;
        // NOTE(review): a load failure here is flattened into a meaningless
        // SizeMismatch { 0, 0 } because this fn returns SuccinctDocLensError,
        // not SuccinctLoadError. It "can't happen" for bytes we just built,
        // but the error is unhelpful if it ever does — consider a dedicated
        // variant.
        Self::from_bytes(meta, bytes).map_err(|_| SuccinctDocLensError::SizeMismatch {
            bytes: 0,
            expected: 0,
        })
    }
    /// Reopen an index over `bytes` using its suffix metadata.
    pub fn from_bytes(
        meta: SuccinctHNSWMeta,
        bytes: Bytes,
    ) -> Result<Self, SuccinctLoadError> {
        let handles = meta
            .handles
            .view(&bytes)
            .map_err(|_| SuccinctLoadError::TruncatedSection("handles"))?;
        let graph = SuccinctGraph::from_bytes(meta.graph, bytes.clone())
            .map_err(|_| SuccinctLoadError::TruncatedSection("graph"))?;
        Ok(Self {
            bytes,
            dim: meta.dim as usize,
            m: meta.m,
            m0: meta.m0,
            max_level: meta.max_level,
            entry_point: if meta.has_entry_point != 0 {
                Some(meta.entry_point)
            } else {
                None
            },
            handles,
            graph,
        })
    }
    /// Re-read the metadata from the canonical bytes' suffix section.
    /// Panics if `bytes` is not a blob produced by this module.
    pub fn meta(&self) -> SuccinctHNSWMeta {
        let mut tail = self.bytes.clone();
        *tail
            .view_suffix::<SuccinctHNSWMeta>()
            .expect("canonical bytes carry meta as suffix-section")
    }
    pub fn dim(&self) -> usize {
        self.dim
    }
    pub fn doc_count(&self) -> usize {
        self.handles.len()
    }
    pub fn m(&self) -> u16 {
        self.m
    }
    pub fn m0(&self) -> u16 {
        self.m0
    }
    pub fn max_level(&self) -> u8 {
        self.max_level
    }
    /// Pair this index with a blob store so searches can fetch embeddings;
    /// `ef_search` defaults to 200 (tune via `with_ef_search`).
    pub fn attach<'a, B>(&'a self, store: &B) -> AttachedSuccinctHNSWIndex<'a, B>
    where
        B: triblespace_core::repo::BlobStoreGet + Clone,
    {
        AttachedSuccinctHNSWIndex {
            index: self,
            cache: triblespace_core::blob::BlobCache::new(store.clone()),
            ef_search: 200,
        }
    }
}
/// A [`SuccinctHNSWIndex`] paired with a blob store, so searches can resolve
/// embedding handles to vectors (via a caching layer).
pub struct AttachedSuccinctHNSWIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    index: &'a SuccinctHNSWIndex,
    // Caches decoded `View<[f32]>` embeddings fetched from the store.
    cache: triblespace_core::blob::BlobCache<B, Embedding, anybytes::View<[f32]>>,
    // Beam width for the layer-0 search (HNSW "ef" parameter).
    ef_search: usize,
}
impl<'a, B> AttachedSuccinctHNSWIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    pub fn index(&self) -> &SuccinctHNSWIndex {
        self.index
    }
    /// Builder-style override of the layer-0 beam width.
    pub fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }
    /// Two-variable similarity constraint for the query engine.
    pub fn similar(
        &self,
        a: Variable<EmbHandle>,
        b: Variable<EmbHandle>,
        score_floor: f32,
    ) -> crate::constraint::Similar<'_, Self> {
        crate::constraint::Similar::new(self, a, b, score_floor)
    }
    /// One-sided constraint: candidates similar to a fixed `probe` handle.
    /// Search failures degrade to an empty candidate set rather than erroring.
    pub fn similar_to(
        &self,
        probe: Inline<EmbHandle>,
        var: Variable<EmbHandle>,
        score_floor: f32,
    ) -> crate::constraint::SimilarTo {
        let candidates = self
            .candidates_above(probe, score_floor)
            .map(|v| v.into_iter().map(|h| h.raw).collect())
            .unwrap_or_default();
        crate::constraint::SimilarTo::from_candidates(var, candidates)
    }
    /// Full HNSW search: greedy descent from the entry point through the upper
    /// layers, then a beam search (width `ef_search`) on layer 0, filtered by
    /// `score_floor`.
    #[doc(hidden)]
    pub fn candidates_above(
        &self,
        from_handle: Inline<EmbHandle>,
        score_floor: f32,
    ) -> Result<Vec<Inline<EmbHandle>>, B::GetError<anybytes::view::ViewError>> {
        // Empty index: nothing to search.
        let Some(entry) = self.index.entry_point else {
            return Ok(Vec::new());
        };
        let from = self.cache.get(from_handle)?;
        let query: Vec<f32> = from.as_ref().as_ref().to_vec();
        // Dimension mismatch is treated as "no matches", not an error.
        if query.len() != self.index.dim {
            return Ok(Vec::new());
        }
        let mut curr = entry;
        for lvl in (1..=self.index.max_level).rev() {
            curr = self.greedy_search_layer(&query, curr, lvl)?;
        }
        let candidates = self.search_layer(&query, curr, self.ef_search, 0)?;
        Ok(candidates
            .into_iter()
            // Convert distance to similarity; assumes cosine_dist is
            // 1 - cosine similarity (see crate::hnsw::cosine_dist).
            .filter(|(_, dist)| 1.0 - dist >= score_floor)
            .map(|(i, _)| {
                let raw = *self.index.handles.get(i as usize).expect("in range");
                Inline::new(raw)
            })
            .collect())
    }
    /// Distance from `q` to node `i`'s embedding (fetched through the cache).
    fn dist_to(
        &self,
        q: &[f32],
        i: u32,
    ) -> Result<f32, B::GetError<anybytes::view::ViewError>> {
        let raw = *self.index.handles.get(i as usize).expect("in range");
        let handle: Inline<EmbHandle> = Inline::new(raw);
        let view = self.cache.get(handle)?;
        Ok(crate::hnsw::cosine_dist(q, view.as_ref().as_ref()))
    }
    /// Greedy hill-descent on one upper layer: repeatedly hop to the closest
    /// neighbour until no neighbour improves on the current node.
    fn greedy_search_layer(
        &self,
        q: &[f32],
        entry: u32,
        layer: u8,
    ) -> Result<u32, B::GetError<anybytes::view::ViewError>> {
        let mut curr = entry;
        let mut curr_dist = self.dist_to(q, curr)?;
        loop {
            let mut changed = false;
            let neigh: Vec<u32> = self
                .index
                .graph
                .neighbours(curr as usize, layer as usize)
                .collect();
            if neigh.is_empty() {
                return Ok(curr);
            }
            for n in neigh {
                let d = self.dist_to(q, n)?;
                if d < curr_dist {
                    curr_dist = d;
                    curr = n;
                    changed = true;
                }
            }
            if !changed {
                return Ok(curr);
            }
        }
    }
    /// Beam search on one layer keeping the best `ef` results; returns
    /// `(node, distance)` pairs in no particular order.
    fn search_layer(
        &self,
        q: &[f32],
        entry: u32,
        ef: usize,
        layer: u8,
    ) -> Result<Vec<(u32, f32)>, B::GetError<anybytes::view::ViewError>> {
        use std::collections::{BinaryHeap, HashSet};
        let mut visited: HashSet<u32> = HashSet::new();
        visited.insert(entry);
        let d0 = self.dist_to(q, entry)?;
        // Min-heap wrapper: Ord is reversed so BinaryHeap (a max-heap) pops
        // the SMALLEST distance first — the next candidate to expand.
        #[derive(Clone, Copy)]
        struct MinD {
            idx: u32,
            dist: f32,
        }
        impl PartialEq for MinD {
            fn eq(&self, o: &Self) -> bool {
                self.dist == o.dist
            }
        }
        impl Eq for MinD {}
        impl PartialOrd for MinD {
            fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> {
                Some(self.cmp(o))
            }
        }
        impl Ord for MinD {
            fn cmp(&self, o: &Self) -> std::cmp::Ordering {
                // Reversed operand order = reversed ordering; NaN treated as
                // Equal.
                o.dist
                    .partial_cmp(&self.dist)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }
        }
        // Max-heap wrapper: pops the LARGEST distance first — the current
        // worst result, evicted when the beam overflows `ef`.
        #[derive(Clone, Copy)]
        struct MaxD {
            idx: u32,
            dist: f32,
        }
        impl PartialEq for MaxD {
            fn eq(&self, o: &Self) -> bool {
                self.dist == o.dist
            }
        }
        impl Eq for MaxD {}
        impl PartialOrd for MaxD {
            fn partial_cmp(&self, o: &Self) -> Option<std::cmp::Ordering> {
                Some(self.cmp(o))
            }
        }
        impl Ord for MaxD {
            fn cmp(&self, o: &Self) -> std::cmp::Ordering {
                self.dist
                    .partial_cmp(&o.dist)
                    .unwrap_or(std::cmp::Ordering::Equal)
            }
        }
        let mut candidates: BinaryHeap<MinD> = BinaryHeap::new();
        candidates.push(MinD {
            idx: entry,
            dist: d0,
        });
        let mut results: BinaryHeap<MaxD> = BinaryHeap::new();
        results.push(MaxD {
            idx: entry,
            dist: d0,
        });
        while let Some(c) = candidates.pop() {
            let farthest = results.peek().map(|r| r.dist).unwrap_or(f32::INFINITY);
            // Standard HNSW termination: the nearest unexpanded candidate is
            // already worse than the worst of a full beam.
            if c.dist > farthest && results.len() >= ef {
                break;
            }
            let neigh: Vec<u32> = self
                .index
                .graph
                .neighbours(c.idx as usize, layer as usize)
                .collect();
            for n in neigh {
                if !visited.insert(n) {
                    continue;
                }
                let d = self.dist_to(q, n)?;
                let farthest = results.peek().map(|r| r.dist).unwrap_or(f32::INFINITY);
                if d < farthest || results.len() < ef {
                    candidates.push(MinD { idx: n, dist: d });
                    results.push(MaxD { idx: n, dist: d });
                    if results.len() > ef {
                        results.pop();
                    }
                }
            }
        }
        Ok(results.into_iter().map(|m| (m.idx, m.dist)).collect())
    }
}
impl<'a, B> crate::constraint::SimilaritySearch for AttachedSuccinctHNSWIndex<'a, B>
where
    B: triblespace_core::repo::BlobStoreGet,
{
    /// Trait adapter over [`AttachedSuccinctHNSWIndex::candidates_above`];
    /// store errors degrade to an empty result.
    fn neighbours_above(
        &self,
        from: Inline<EmbHandle>,
        score_floor: f32,
    ) -> Vec<Inline<EmbHandle>> {
        self.candidates_above(from, score_floor).unwrap_or_default()
    }
    /// Similarity between two stored embeddings, or `None` on fetch failure
    /// or dimension mismatch.
    // NOTE(review): this computes a plain dot product, not a normalized
    // cosine — presumably embeddings are stored pre-normalized (consistent
    // with cosine_dist usage elsewhere); confirm before reusing on raw
    // vectors.
    fn cosine_between(
        &self,
        a: Inline<EmbHandle>,
        b: Inline<EmbHandle>,
    ) -> Option<f32> {
        let va = self.cache.get(a).ok()?;
        let vb = self.cache.get(b).ok()?;
        let a_slice: &[f32] = va.as_ref().as_ref();
        let b_slice: &[f32] = vb.as_ref().as_ref();
        if a_slice.len() != b_slice.len() {
            return None;
        }
        let mut sum = 0.0f32;
        for (x, y) in a_slice.iter().zip(b_slice.iter()) {
            sum += x * y;
        }
        Some(sum)
    }
}
/// Fixed-layout suffix metadata for [`SuccinctBM25Index`] blobs, decoded
/// zerocopy-style from the end of the byte area.
#[derive(
    Debug, Clone, Copy, zerocopy::FromBytes, zerocopy::KnownLayout, zerocopy::Immutable,
)]
#[repr(C)]
pub struct SuccinctBM25Meta {
    pub(crate) n_docs: u64,
    pub(crate) n_terms: u64,
    pub(crate) avg_doc_len: f32,
    /// BM25 `k1` parameter baked in at build time.
    pub(crate) k1: f32,
    /// BM25 `b` (length-normalization) parameter baked in at build time.
    pub(crate) b: f32,
    // Explicit padding so the layout is fully determined under #[repr(C)].
    _pad: u32,
    pub(crate) keys: CompressedUniverseMeta,
    pub(crate) doc_lens: CompactVectorMeta,
    pub(crate) postings: SuccinctPostingsMeta,
    /// Section of sorted 32-byte term ids (binary-searched at query time).
    pub(crate) terms: SectionHandle<[u8; 32]>,
}
/// Read-only, blob-backed BM25 index with pre-computed per-posting scores.
/// `D` is the document-key encoding, `T` the term encoding.
pub struct SuccinctBM25Index<
    D: InlineEncoding = triblespace_core::inline::encodings::genid::GenId,
    T: InlineEncoding = crate::tokens::WordHash,
> {
    /// The canonical backing bytes; the index *is* this blob.
    pub bytes: Bytes,
    // Maps document keys <-> dense codes used by the postings.
    keys: CompressedUniverse,
    doc_lens: SuccinctDocLens,
    // Sorted term-id table; index into it == term index into `postings`.
    terms: View<[[u8; 32]]>,
    postings: SuccinctPostings,
    avg_doc_len: f32,
    k1: f32,
    b: f32,
    _phantom: std::marker::PhantomData<(D, T)>,
}
// Manual Debug: summarize sizes and parameters instead of dumping contents.
impl<D: InlineEncoding, T: InlineEncoding> std::fmt::Debug for SuccinctBM25Index<D, T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SuccinctBM25Index")
            .field("n_docs", &self.keys.len())
            .field("n_terms", &self.terms.len())
            .field("avg_doc_len", &self.avg_doc_len)
            .field("k1", &self.k1)
            .field("b", &self.b)
            .finish()
    }
}
impl<D: InlineEncoding, T: InlineEncoding> SuccinctBM25Index<D, T> {
    /// Freeze a [`crate::bm25::BM25Builder`] into canonical bytes: compressed
    /// key universe, doc lengths, sorted term table, pre-scored postings, and
    /// a `SuccinctBM25Meta` suffix section, all in one `ByteArea`.
    ///
    /// Panics on allocation/serialization failure (`expect`), which is
    /// treated as a build-time bug rather than a recoverable error.
    pub(crate) fn from_builder(
        builder: crate::bm25::BM25Builder<D, T>,
    ) -> Self {
        let crate::bm25::BM25Builder { docs, k1, b, _phantom: _ } = builder;
        let mut area = ByteArea::new().expect("alloc ByteArea");
        let mut sections = area.sections();
        let build_universe =
            CompressedUniverse::with(docs.iter().map(|(k, _)| *k), &mut sections);
        let keys_meta = build_universe.metadata();
        let n_universe = build_universe.len();
        // Per-document term counts and, per term, per-document term
        // frequencies, keyed by the dense codes assigned by the universe.
        let mut doc_lens_vec = vec![0u32; n_universe];
        let mut term_to_tfs: HashMap<RawInline, HashMap<u32, u32>> = HashMap::new();
        for (key, terms) in docs {
            let code = build_universe
                .search(&key)
                .expect("key just inserted into universe") as u32;
            doc_lens_vec[code as usize] = terms.len() as u32;
            for term in terms {
                *term_to_tfs.entry(term).or_default().entry(code).or_insert(0) += 1;
            }
        }
        drop(build_universe);
        let avg_doc_len = if n_universe == 0 {
            0.0
        } else {
            // Sum in f64 to limit accumulation error before narrowing to f32.
            doc_lens_vec.iter().map(|&n| n as f64).sum::<f64>() as f32
                / n_universe as f32
        };
        let doc_lens_meta = SuccinctDocLens::build_into(&mut sections, &doc_lens_vec)
            .expect("build doc_lens");
        // Term table is sorted so query_term can binary-search it; a term's
        // position in this table is its index into the postings.
        let mut term_rows: Vec<RawInline> = term_to_tfs.keys().copied().collect();
        term_rows.sort_unstable();
        let n_terms = term_rows.len();
        let terms_handle = pack_byte_table::<32>(&mut sections, &term_rows)
            .expect("build terms");
        let n = n_universe as f32;
        // BM25 term-document score; idf is passed in pre-computed per term.
        let bm25_score = |df: f32, idf: f32, tf: u32, code: u32| -> f32 {
            let tf_f = tf as f32;
            let dl = doc_lens_vec[code as usize] as f32;
            let norm = if avg_doc_len > 0.0 {
                1.0 - b + b * (dl / avg_doc_len)
            } else {
                1.0
            };
            // df is unused (folded into idf) but kept for signature symmetry.
            let _ = df;
            idf * (tf_f * (k1 + 1.0)) / (tf_f + k1 * norm)
        };
        let total: usize = term_to_tfs.values().map(|m| m.len()).sum();
        // Pre-pass: the global maximum score is the quantization scale and
        // must be known before any posting is written.
        let max_score: f32 = term_rows.iter().fold(0.0f32, |acc, term| {
            let tfs = &term_to_tfs[term];
            let df = tfs.len() as f32;
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
            tfs.iter().fold(acc, |m, (&code, &tf)| {
                m.max(bm25_score(df, idf, tf, code))
            })
        });
        let postings_meta = SuccinctPostings::build_with_into(
            &mut sections,
            n_universe as u32,
            n_terms,
            total,
            max_score,
            |t, buf| {
                let tfs = &term_to_tfs[&term_rows[t]];
                let df = tfs.len() as f32;
                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
                buf.extend(
                    tfs.iter()
                        .map(|(&code, &tf)| (code, bm25_score(df, idf, tf, code))),
                );
                // Deterministic posting order regardless of HashMap iteration.
                buf.sort_unstable_by_key(|&(code, _)| code);
            },
        )
        .expect("build postings");
        let meta = SuccinctBM25Meta {
            n_docs: n_universe as u64,
            n_terms: n_terms as u64,
            avg_doc_len,
            k1,
            b,
            _pad: 0,
            keys: keys_meta,
            doc_lens: doc_lens_meta,
            postings: postings_meta,
            terms: terms_handle,
        };
        {
            // Meta goes last so readers can locate it as a suffix section.
            let mut meta_sec = sections
                .reserve::<SuccinctBM25Meta>(1)
                .expect("reserve meta section");
            meta_sec.as_mut_slice()[0] = meta;
            meta_sec.freeze().expect("freeze meta section");
        }
        drop(sections);
        let bytes = area.freeze().expect("freeze ByteArea");
        Self::from_bytes(meta, bytes).expect("round-trip the bytes we just built")
    }
    /// Reopen an index over `bytes` using its suffix metadata; each
    /// sub-structure failure is reported as a truncated named section.
    pub fn from_bytes(
        meta: SuccinctBM25Meta,
        bytes: Bytes,
    ) -> Result<Self, SuccinctLoadError> {
        let keys = CompressedUniverse::from_bytes(meta.keys, bytes.clone())
            .map_err(|_| SuccinctLoadError::TruncatedSection("keys"))?;
        let doc_lens = SuccinctDocLens::from_bytes(meta.doc_lens, bytes.clone())
            .map_err(|_| SuccinctLoadError::TruncatedSection("doc_lens"))?;
        let terms = meta
            .terms
            .view(&bytes)
            .map_err(|_| SuccinctLoadError::TruncatedSection("terms"))?;
        let postings = SuccinctPostings::from_bytes(meta.postings, bytes.clone())
            .map_err(|_| SuccinctLoadError::TruncatedSection("postings"))?;
        Ok(Self {
            bytes,
            keys,
            doc_lens,
            terms,
            postings,
            avg_doc_len: meta.avg_doc_len,
            k1: meta.k1,
            b: meta.b,
            _phantom: std::marker::PhantomData,
        })
    }
    /// Re-read the metadata from the canonical bytes' suffix section.
    /// Panics if `bytes` is not a blob produced by this module.
    pub fn meta(&self) -> SuccinctBM25Meta {
        let mut tail = self.bytes.clone();
        *tail
            .view_suffix::<SuccinctBM25Meta>()
            .expect("canonical bytes carry meta as suffix-section")
    }
    pub fn doc_count(&self) -> usize {
        self.keys.len()
    }
    pub fn term_count(&self) -> usize {
        self.terms.len()
    }
    pub fn avg_doc_len(&self) -> f32 {
        self.avg_doc_len
    }
    pub fn k1(&self) -> f32 {
        self.k1
    }
    pub fn b(&self) -> f32 {
        self.b
    }
    /// Approximate byte footprint of the compressed key universe.
    // NOTE(review): sums the fragments and data-levels section lengths from
    // the metadata; whether that covers every universe section is defined by
    // CompressedUniverseMeta — verify if exactness matters.
    pub fn keys_size_bytes(&self) -> usize {
        let meta = self.meta();
        meta.keys.fragments.len + meta.keys.data.levels.len
    }
    /// Length of the document with dense code `i`, or `None` out of range.
    pub fn doc_len(&self, i: usize) -> Option<u32> {
        self.doc_lens.get(i)
    }
    /// Worst-case score error introduced by quantization; see
    /// [`SuccinctPostings::score_tolerance`].
    pub fn score_tolerance(&self) -> f32 {
        self.postings.score_tolerance()
    }
    /// Number of documents containing `term` (0 for unknown terms).
    pub fn doc_frequency(&self, term: &Inline<T>) -> usize {
        match self.terms.binary_search(&term.raw) {
            Ok(t) => self.postings.posting_count(t).unwrap_or(0),
            Err(_) => 0,
        }
    }
    /// Iterate `(doc_key, bm25_score)` for every document containing `term`;
    /// unknown terms yield an empty iterator.
    pub fn query_term<'a>(
        &'a self,
        term: &Inline<T>,
    ) -> Box<dyn Iterator<Item = (Inline<D>, f32)> + 'a> {
        match self.terms.binary_search(&term.raw) {
            Ok(t) => match self.postings.postings_for(t) {
                Some(iter) => Box::new(iter.map(move |(doc_idx, score)| {
                    // Translate the dense doc code back to its original key.
                    let key = self.keys.access(doc_idx as usize);
                    (Inline::<D>::new(key), score)
                })),
                None => Box::new(std::iter::empty()),
            },
            Err(_) => Box::new(std::iter::empty()),
        }
    }
    /// Multi-term query: per-document scores are summed across terms and the
    /// results returned best-first (NaN-tolerant comparison).
    pub fn query_multi(&self, terms: &[Inline<T>]) -> Vec<(Inline<D>, f32)> {
        let mut acc: std::collections::HashMap<RawInline, f32> =
            std::collections::HashMap::new();
        for term in terms {
            for (key, score) in self.query_term(term) {
                *acc.entry(key.raw).or_insert(0.0) += score;
            }
        }
        let mut out: Vec<(Inline<D>, f32)> =
            acc.into_iter().map(|(raw, s)| (Inline::<D>::new(raw), s)).collect();
        out.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        out
    }
}
/// Errors raised while reopening a succinct index from an existing blob.
#[derive(Debug, Clone, PartialEq)]
pub enum SuccinctLoadError {
    /// Blob too small to contain the metadata header/suffix.
    ShortHeader,
    /// A named section's bytes were missing or too short.
    TruncatedSection(&'static str),
    /// The named metadata could not be decoded.
    BadMeta(&'static str),
}
impl std::fmt::Display for SuccinctLoadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::ShortHeader => write!(f, "succinct blob shorter than header"),
Self::TruncatedSection(name) => {
write!(f, "succinct blob: truncated section `{name}`")
}
Self::BadMeta(name) => write!(f, "succinct blob: bad meta `{name}`"),
}
}
}
impl std::error::Error for SuccinctLoadError {}
/// Marker (uninhabited) type naming the blob encoding for BM25 index bytes.
pub enum SuccinctBM25Blob {}
impl BlobEncoding for SuccinctBM25Blob {}
impl MetaDescribe for SuccinctBM25Blob {
    /// Self-describing metadata entity registered under a fixed id.
    fn describe() -> Fragment {
        let id = id_hex!("DA527A8FF09A3709B2AC6425CD5AF7A8");
        entity! { ExclusiveId::force_ref(&id) @
            metadata::name: "SuccinctBM25Blob",
            metadata::description: "Canonical-bytes blob format for the succinct BM25 index. The index *is* its blob: term-id table, postings, document-frequency table, and an `SuccinctBM25Meta` suffix all share one `anybytes::ByteArea`.",
            metadata::tag: metadata::KIND_BLOB_ENCODING,
        }
    }
}
// Encoding is trivial because the index already *is* its canonical bytes:
// borrowing encodes by cloning the (cheap, shared) byte handle...
impl<D: InlineEncoding, T: InlineEncoding> Encodes<&SuccinctBM25Index<D, T>> for SuccinctBM25Blob
where triblespace_core::inline::encodings::hash::Handle<SuccinctBM25Blob>: triblespace_core::inline::InlineEncoding,
{
    type Output = Blob<SuccinctBM25Blob>;
    fn encode(source: &SuccinctBM25Index<D, T>) -> Blob<SuccinctBM25Blob> {
        Blob::new(source.bytes.clone())
    }
}
// ...and consuming encodes by moving the bytes out.
impl<D: InlineEncoding, T: InlineEncoding> Encodes<SuccinctBM25Index<D, T>> for SuccinctBM25Blob
where triblespace_core::inline::encodings::hash::Handle<SuccinctBM25Blob>: triblespace_core::inline::InlineEncoding,
{
    type Output = Blob<SuccinctBM25Blob>;
    fn encode(source: SuccinctBM25Index<D, T>) -> Blob<SuccinctBM25Blob> {
        Blob::new(source.bytes)
    }
}
impl<D: InlineEncoding, T: InlineEncoding> TryFromBlob<SuccinctBM25Blob> for SuccinctBM25Index<D, T> {
    type Error = SuccinctLoadError;
    /// Decode by reading the `SuccinctBM25Meta` stored as the byte area's
    /// suffix section, then reopening all sub-structures over the same bytes.
    fn try_from_blob(blob: Blob<SuccinctBM25Blob>) -> Result<Self, Self::Error> {
        let bytes = blob.bytes;
        let mut tail = bytes.clone();
        let meta = *tail
            .view_suffix::<SuccinctBM25Meta>()
            .map_err(|_| SuccinctLoadError::BadMeta("suffix"))?;
        SuccinctBM25Index::from_bytes(meta, bytes)
    }
}
/// Marker (uninhabited) type naming the blob encoding for HNSW index bytes.
pub enum SuccinctHNSWBlob {}
impl BlobEncoding for SuccinctHNSWBlob {}
impl MetaDescribe for SuccinctHNSWBlob {
    /// Self-describing metadata entity registered under a fixed id.
    fn describe() -> Fragment {
        let id = id_hex!("8DF997D25C15B73EDCEE9E08076F251E");
        entity! { ExclusiveId::force_ref(&id) @
            metadata::name: "SuccinctHNSWBlob",
            metadata::description: "Canonical-bytes blob format for the succinct HNSW vector index. Handles + graph live in one shared `anybytes::ByteArea` with a suffix `SuccinctHNSWMeta`; embeddings themselves live as separate blobs in the pile referenced by handle.",
            metadata::tag: metadata::KIND_BLOB_ENCODING,
        }
    }
}
// Encoding is trivial because the index already *is* its canonical bytes:
// borrowing encodes by cloning the (cheap, shared) byte handle...
impl Encodes<&SuccinctHNSWIndex> for SuccinctHNSWBlob
where triblespace_core::inline::encodings::hash::Handle<SuccinctHNSWBlob>: triblespace_core::inline::InlineEncoding,
{
    type Output = Blob<SuccinctHNSWBlob>;
    fn encode(source: &SuccinctHNSWIndex) -> Blob<SuccinctHNSWBlob> {
        Blob::new(source.bytes.clone())
    }
}
// ...and consuming encodes by moving the bytes out.
impl Encodes<SuccinctHNSWIndex> for SuccinctHNSWBlob
where triblespace_core::inline::encodings::hash::Handle<SuccinctHNSWBlob>: triblespace_core::inline::InlineEncoding,
{
    type Output = Blob<SuccinctHNSWBlob>;
    fn encode(source: SuccinctHNSWIndex) -> Blob<SuccinctHNSWBlob> {
        Blob::new(source.bytes)
    }
}
impl TryFromBlob<SuccinctHNSWBlob> for SuccinctHNSWIndex {
    type Error = SuccinctLoadError;
    /// Decode by reading the `SuccinctHNSWMeta` stored as the byte area's
    /// suffix section, then reopening handles and graph over the same bytes.
    fn try_from_blob(blob: Blob<SuccinctHNSWBlob>) -> Result<Self, Self::Error> {
        let bytes = blob.bytes;
        let mut tail = bytes.clone();
        let meta = *tail
            .view_suffix::<SuccinctHNSWMeta>()
            .map_err(|_| SuccinctLoadError::BadMeta("suffix"))?;
        SuccinctHNSWIndex::from_bytes(meta, bytes)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use triblespace_core::repo::BlobStore;
#[test]
fn empty_roundtrip() {
let (bytes, meta) = SuccinctDocLens::build(&[]).unwrap();
let view = SuccinctDocLens::from_bytes(meta, bytes).unwrap();
assert!(view.is_empty());
assert_eq!(view.len(), 0);
assert_eq!(view.get(0), None);
}
#[test]
fn small_roundtrip() {
let lens = vec![3u32, 0, 7, 1, 15];
let (bytes, meta) = SuccinctDocLens::build(&lens).unwrap();
let view = SuccinctDocLens::from_bytes(meta, bytes).unwrap();
assert_eq!(view.len(), lens.len());
for (i, &n) in lens.iter().enumerate() {
assert_eq!(view.get(i), Some(n), "mismatch at {i}");
}
assert_eq!(view.to_vec(), lens);
}
#[test]
fn out_of_range_is_none() {
let (bytes, meta) = SuccinctDocLens::build(&[1u32, 2, 3]).unwrap();
let view = SuccinctDocLens::from_bytes(meta, bytes).unwrap();
assert_eq!(view.get(3), None);
assert_eq!(view.get(99), None);
}
#[test]
fn width_matches_max_value() {
assert_eq!(required_width(&[0, 15, 7, 3]), 4);
assert_eq!(required_width(&[0, 16]), 5);
assert_eq!(required_width(&[0, 0, 0]), 1);
assert_eq!(required_width(&[]), 1);
}
#[test]
fn large_lens_pack_correctly() {
let lens: Vec<u32> = (0..200).map(|i| i * 5_000).collect();
let (bytes, meta) = SuccinctDocLens::build(&lens).unwrap();
let view = SuccinctDocLens::from_bytes(meta, bytes).unwrap();
assert_eq!(view.to_vec(), lens);
assert_eq!(view.width(), 20); }
#[test]
fn bit_packing_beats_raw_u32() {
    // All values are below 200 (8 bits), so the packed form must beat a
    // naive 4-bytes-per-entry layout.
    let input: Vec<u32> = (0..1000).map(|i| (i % 200) as u32).collect();
    let (bytes, _meta) = SuccinctDocLens::build(&input).unwrap();
    let naive = input.len() * 4;
    assert!(
        bytes.len() < naive,
        "succinct {} < naive {}",
        bytes.len(),
        naive
    );
}
#[test]
fn pack_byte_table_round_trip_via_section_handle() {
    // Packs sorted fixed-width rows into a ByteArea section and checks the
    // returned handle resolves to the same rows after the area is frozen.
    let mut rows: Vec<[u8; 32]> =
        vec![[5u8; 32], [1u8; 32], [9u8; 32], [3u8; 32]];
    rows.sort();
    let mut area = ByteArea::new().unwrap();
    let mut sections = area.sections();
    let handle = pack_byte_table::<32>(&mut sections, &rows).unwrap();
    // The section writer borrows the area; drop it before freezing.
    drop(sections);
    let bytes = area.freeze().unwrap();
    let view: View<[[u8; 32]]> = handle.view(&bytes).unwrap();
    let view_slice: &[[u8; 32]] = &view;
    assert_eq!(view_slice.len(), rows.len());
    assert_eq!(view_slice, rows.as_slice());
    // Sorted order is preserved, so binary search works directly on the view:
    // rows sort to [1.., 3.., 5.., 9..], so 3.. is at index 1 and a missing
    // 7.. would insert at index 3.
    assert_eq!(view_slice.binary_search(&[3u8; 32]), Ok(1));
    assert_eq!(view_slice.binary_search(&[7u8; 32]), Err(3));
}
#[test]
fn pack_byte_table_empty_section() {
    // An empty table still yields a valid, zero-length section view.
    let mut area = ByteArea::new().unwrap();
    let mut sections = area.sections();
    let handle = pack_byte_table::<16>(&mut sections, &[]).unwrap();
    drop(sections);
    let frozen = area.freeze().unwrap();
    let table: View<[[u8; 16]]> = handle.view(&frozen).unwrap();
    assert_eq!(table.len(), 0);
}
#[test]
fn postings_roundtrip_simple() {
    // Round-trips per-term posting lists of (doc index, score) through the
    // succinct encoding. Doc indices must match exactly; scores are
    // lossy-compressed and may drift within the encoder's tolerance.
    let lists = vec![
        vec![(0u32, 1.5f32), (3, 0.75), (7, 2.0)],
        vec![(1, 0.5), (2, 3.25)],
        vec![], // a term with no postings must also round-trip
        vec![(4, 9.0)],
    ];
    let (bytes, meta) = SuccinctPostings::build(&lists, 8).unwrap();
    let view = SuccinctPostings::from_bytes(meta, bytes).unwrap();
    assert_eq!(view.term_count(), 4);
    assert_eq!(view.posting_count(0), Some(3));
    assert_eq!(view.posting_count(1), Some(2));
    assert_eq!(view.posting_count(2), Some(0));
    assert_eq!(view.posting_count(3), Some(1));
    // Out-of-range term index reports None rather than panicking.
    assert_eq!(view.posting_count(4), None);
    let tol = view.score_tolerance();
    for (t, expected) in lists.iter().enumerate() {
        let got: Vec<(u32, f32)> = view.postings_for(t).unwrap().collect();
        assert_eq!(got.len(), expected.len(), "term {t} length");
        for ((gd, gs), (ed, es)) in got.iter().zip(expected.iter()) {
            assert_eq!(gd, ed, "term {t} doc idx");
            assert!(
                (gs - es).abs() <= tol,
                "term {t} score drift {gs} vs {es} exceeds tol {tol}"
            );
        }
    }
}
#[test]
fn postings_empty_corpus() {
    // Zero terms and zero documents must decode to a queryable empty index.
    let no_lists: &[Vec<(u32, f32)>] = &[];
    let (bytes, meta) = SuccinctPostings::build(no_lists, 0).unwrap();
    let view = SuccinctPostings::from_bytes(meta, bytes).unwrap();
    assert_eq!(view.term_count(), 0);
    assert!(view.postings_for(0).is_none());
}
#[test]
fn build_with_streaming_matches_lists_build() {
    // The streaming builder must be byte-for-byte compatible with the
    // slice-based builder for identical input.
    let lists = vec![
        vec![(0u32, 1.5f32), (3, 0.75), (7, 2.0)],
        vec![(1, 0.5), (2, 3.25)],
        vec![],
        vec![(4, 9.0)],
    ];
    let (slice_bytes, slice_meta) = SuccinctPostings::build(&lists, 8).unwrap();
    let (stream_bytes, stream_meta) =
        SuccinctPostings::build_with(8, lists.len(), |t, buf| {
            buf.extend_from_slice(&lists[t]);
        })
        .unwrap();
    assert_eq!(
        slice_bytes.as_ref(),
        stream_bytes.as_ref(),
        "byte-identical output"
    );
    assert_eq!(slice_meta.max_score, stream_meta.max_score);
    assert_eq!(slice_meta.n_terms, stream_meta.n_terms);
}
#[test]
fn succinct_bm25_matches_naive_on_sample() {
    // The succinct BM25 index must agree with the naive reference build on
    // global statistics, per-term postings (ids exactly, scores within the
    // quantization tolerance), document frequencies, and unknown terms.
    use crate::bm25::BM25Builder;
    use crate::tokens::hash_tokens;
    use triblespace_core::id::Id;
    // Deterministic 16-byte ids keep the sample corpus reproducible.
    fn iid(byte: u8) -> Id {
        Id::new([byte; 16]).unwrap()
    }
    let mut b: BM25Builder = BM25Builder::new();
    b.insert(iid(1), hash_tokens("the quick brown fox"));
    b.insert(iid(2), hash_tokens("the lazy brown dog"));
    b.insert(iid(3), hash_tokens("quick silver fox jumps"));
    b.insert(iid(4), hash_tokens("unrelated filler content"));
    let naive = b.clone().build_naive();
    let succinct = b.build();
    assert_eq!(succinct.doc_count(), naive.doc_count());
    assert_eq!(succinct.term_count(), naive.term_count());
    assert_eq!(succinct.k1(), naive.k1());
    assert_eq!(succinct.b(), naive.b());
    assert!((succinct.avg_doc_len() - naive.avg_doc_len()).abs() < 1e-6);
    // Scores are lossy-compressed; allow drift up to the reported tolerance.
    let tol = succinct.score_tolerance();
    for term_raw in naive.terms_slice() {
        let term: Inline<crate::tokens::WordHash> = Inline::new(*term_raw);
        let n: Vec<_> = naive.query_term(&term).collect();
        let s: Vec<_> = succinct.query_term(&term).collect();
        assert_eq!(
            n.len(),
            s.len(),
            "posting count mismatch for term {term_raw:x?}"
        );
        for ((n_id, n_s), (s_id, s_s)) in n.iter().zip(s.iter()) {
            assert_eq!(n_id.raw, s_id.raw);
            assert!(
                (n_s - s_s).abs() <= tol,
                "score drift for {n_id:?}: naive={n_s} succinct={s_s} > tol {tol}"
            );
        }
        assert_eq!(naive.doc_frequency(&term), succinct.doc_frequency(&term));
    }
    // A term absent from the corpus yields no postings and zero frequency.
    let missing = hash_tokens("banana");
    assert!(succinct.query_term(&missing[0]).next().is_none());
    assert_eq!(succinct.doc_frequency(&missing[0]), 0);
}
#[test]
fn succinct_bm25_empty_corpus() {
    // An index built from no documents must answer any query with nothing.
    use crate::bm25::BM25Builder;
    use triblespace_core::inline::encodings::genid::GenId;
    let index = BM25Builder::<GenId, crate::tokens::WordHash>::new().build();
    assert_eq!(index.doc_count(), 0);
    assert_eq!(index.term_count(), 0);
    let probe: Inline<crate::tokens::WordHash> = Inline::new([0u8; 32]);
    assert!(index.query_term(&probe).next().is_none());
}
#[test]
fn succinct_bm25_query_multi_matches_naive() {
    // Multi-term queries must rank identically in both implementations; the
    // score tolerance is doubled because each query term contributes its own
    // quantization error.
    use crate::bm25::BM25Builder;
    use crate::tokens::hash_tokens;
    use triblespace_core::id::Id;
    fn iid(byte: u8) -> Id {
        Id::new([byte; 16]).unwrap()
    }
    let mut builder: BM25Builder = BM25Builder::new();
    builder.insert(iid(1), hash_tokens("quick fox"));
    builder.insert(
        iid(2),
        hash_tokens("quick red rapid fox jumps high over fences"),
    );
    builder.insert(iid(3), hash_tokens("slow brown dog"));
    let naive = builder.clone().build_naive();
    let succinct = builder.build();
    let query = hash_tokens("quick fox");
    let naive_hits = naive.query_multi(&query);
    let succinct_hits = succinct.query_multi(&query);
    assert_eq!(naive_hits.len(), succinct_hits.len());
    let tol = succinct.score_tolerance() * 2.0;
    for ((a_s, b_s), (a_id, b_id)) in naive_hits
        .iter()
        .zip(succinct_hits.iter())
        .map(|((a_id, a_s), (b_id, b_s))| ((a_s, b_s), (a_id, b_id)))
    {
        assert_eq!(a_id, b_id, "ranking order mismatch");
        assert!(
            (a_s - b_s).abs() <= tol,
            "score drift: naive={a_s} succinct={b_s} > tol {tol}"
        );
    }
    // Only the two documents containing a query term should match.
    assert_eq!(succinct_hits.len(), 2);
}
/// Builds a small four-document succinct BM25 index with non-default
/// k1/b parameters, shared by the round-trip tests below.
fn build_succinct_sample() -> SuccinctBM25Index {
    use crate::bm25::BM25Builder;
    use crate::tokens::hash_tokens;
    use triblespace_core::id::Id;
    fn iid(byte: u8) -> Id {
        Id::new([byte; 16]).unwrap()
    }
    let docs: [(u8, &str); 4] = [
        (1, "the quick brown fox"),
        (2, "the lazy brown dog"),
        (3, "quick silver fox jumps"),
        (4, "completely unrelated filler content"),
    ];
    let mut builder: BM25Builder = BM25Builder::new().k1(1.4).b(0.72);
    for (id_byte, text) in docs {
        builder.insert(iid(id_byte), hash_tokens(text));
    }
    builder.build()
}
#[test]
fn succinct_bm25_bytes_round_trip() {
    // Serialize the sample index to a blob and reload it; parameters and
    // per-term postings must survive (scores within tolerance).
    use crate::tokens::hash_tokens;
    use triblespace_core::blob::{Blob, TryFromBlob};
    let original = build_succinct_sample();
    let blob: Blob<SuccinctBM25Blob> = Blob::new(original.bytes.clone());
    let reloaded = SuccinctBM25Index::try_from_blob(blob).expect("valid blob");
    assert_eq!(reloaded.doc_count(), original.doc_count());
    assert_eq!(reloaded.term_count(), original.term_count());
    assert_eq!(reloaded.k1(), original.k1());
    assert_eq!(reloaded.b(), original.b());
    assert!((reloaded.avg_doc_len() - original.avg_doc_len()).abs() < 1e-6);
    // Floor the tolerance so a zero-tolerance encoder still permits f32 noise.
    let tol = original.score_tolerance().max(1e-5);
    for word in ["the", "fox", "quick", "brown", "dog"] {
        let term = hash_tokens(word)[0];
        let a: Vec<_> = original.query_term(&term).collect();
        let b: Vec<_> = reloaded.query_term(&term).collect();
        assert_eq!(a.len(), b.len(), "term '{word}' count mismatch");
        for ((a_id, a_s), (b_id, b_s)) in a.iter().zip(b.iter()) {
            assert_eq!(a_id, b_id);
            assert!(
                (a_s - b_s).abs() <= tol,
                "term '{word}': score drift {a_s} vs {b_s} > tol {tol}"
            );
        }
    }
}
#[test]
fn succinct_bm25_empty_round_trip() {
    // Serializing and reloading an empty index preserves its emptiness.
    use crate::bm25::BM25Builder;
    use triblespace_core::blob::{Blob, TryFromBlob};
    use triblespace_core::inline::encodings::genid::GenId;
    let empty = BM25Builder::<GenId, crate::tokens::WordHash>::new().build();
    let blob: Blob<SuccinctBM25Blob> = Blob::new(empty.bytes.clone());
    let reloaded: SuccinctBM25Index =
        SuccinctBM25Index::try_from_blob(blob).expect("valid blob");
    assert_eq!(reloaded.doc_count(), 0);
    assert_eq!(reloaded.term_count(), 0);
}
#[test]
fn succinct_bm25_rejects_short_header() {
    // A buffer too small to hold the metadata footer must fail cleanly.
    use triblespace_core::blob::{Blob, TryFromBlob};
    let short = Bytes::from_source(vec![0u8; 10]);
    let blob: Blob<SuccinctBM25Blob> = Blob::new(short);
    let err = SuccinctBM25Index::<
        triblespace_core::inline::encodings::genid::GenId,
        crate::tokens::WordHash,
    >::try_from_blob(blob)
    .unwrap_err();
    assert_eq!(err, SuccinctLoadError::BadMeta("suffix"));
}
#[test]
fn succinct_bm25_rejects_truncation() {
    // Chopping bytes off the end must surface as a load error rather than
    // a panic or a silently corrupt index.
    use triblespace_core::blob::{Blob, TryFromBlob};
    let sample = build_succinct_sample();
    let full = sample.bytes.as_ref();
    let truncated = full[..full.len() - 2].to_vec();
    let blob: Blob<SuccinctBM25Blob> = Blob::new(Bytes::from_source(truncated));
    let err = SuccinctBM25Index::<
        triblespace_core::inline::encodings::genid::GenId,
        crate::tokens::WordHash,
    >::try_from_blob(blob)
    .unwrap_err();
    let acceptable = matches!(
        err,
        SuccinctLoadError::TruncatedSection(_) | SuccinctLoadError::BadMeta(_)
    );
    assert!(acceptable, "expected TruncatedSection or BadMeta, got {err:?}");
}
#[test]
fn succinct_bm25_blob_schema_round_trip() {
    // The IntoBlob / TryFromBlob schema pair must round-trip the index.
    use triblespace_core::blob::{IntoBlob, TryFromBlob};
    let sample = build_succinct_sample();
    let blob: triblespace_core::blob::Blob<SuccinctBM25Blob> = (&sample).to_blob();
    let decoded: SuccinctBM25Index =
        SuccinctBM25Index::try_from_blob(blob).expect("valid blob");
    assert_eq!(decoded.doc_count(), sample.doc_count());
    assert_eq!(decoded.term_count(), sample.term_count());
}
#[test]
fn succinct_bm25_blob_is_deterministic() {
    // Two independent builds of the same corpus must serialize identically.
    let first = build_succinct_sample();
    let second = build_succinct_sample();
    assert_eq!(first.bytes.as_ref(), second.bytes.as_ref());
}
#[test]
fn graph_roundtrip_simple() {
    // Two layers over four nodes: a dense base layer plus a sparse upper
    // layer. Every adjacency list must survive the round trip intact.
    let layers = vec![
        vec![vec![1u32, 2], vec![0, 3], vec![0, 3], vec![1, 2]],
        vec![vec![2u32], vec![], vec![0], vec![]],
    ];
    let (bytes, meta) = SuccinctGraph::build(&layers, 4).unwrap();
    let graph = SuccinctGraph::from_bytes(meta, bytes).unwrap();
    assert_eq!(graph.n_nodes(), 4);
    assert_eq!(graph.n_layers(), 2);
    for (layer_idx, layer) in layers.iter().enumerate() {
        for (node, expected) in layer.iter().enumerate() {
            let got: Vec<u32> = graph.neighbours(node, layer_idx).collect();
            assert_eq!(
                &got, expected,
                "mismatch at (node {node}, layer {layer_idx})"
            );
        }
    }
}
#[test]
fn graph_out_of_range() {
    // Unknown node or layer indices yield empty neighbour iterators.
    let layers = vec![vec![vec![1u32], vec![0]]];
    let (bytes, meta) = SuccinctGraph::build(&layers, 2).unwrap();
    let graph = SuccinctGraph::from_bytes(meta, bytes).unwrap();
    assert_eq!(graph.neighbours(5, 0).count(), 0);
    assert_eq!(graph.neighbours(0, 99).count(), 0);
}
#[test]
fn graph_empty() {
    // A graph with no layers and no nodes round-trips cleanly.
    let layers: Vec<Vec<Vec<u32>>> = Vec::new();
    let (bytes, meta) = SuccinctGraph::build(&layers, 0).unwrap();
    let graph = SuccinctGraph::from_bytes(meta, bytes).unwrap();
    assert_eq!(graph.n_nodes(), 0);
    assert_eq!(graph.n_layers(), 0);
}
#[test]
fn graph_rejects_out_of_range_neighbour() {
    // Node 0 lists neighbour 5 but only 3 nodes exist; build must fail.
    let bad_layers = vec![vec![vec![5u32], vec![0], vec![0]]];
    let err = SuccinctGraph::build(&bad_layers, 3).unwrap_err();
    assert!(matches!(err, SuccinctDocLensError::SizeMismatch { .. }));
}
#[test]
fn succinct_hnsw_matches_naive_on_sample() {
    // The succinct HNSW must return exactly the same candidate sets as the
    // naive graph it was converted from, at the same similarity floor.
    use crate::hnsw::HNSWBuilder;
    use triblespace_core::blob::MemoryBlobStore;
    use triblespace_core::repo::BlobStore;
    let mut store = MemoryBlobStore::new();
    // Fixed seed keeps the layer assignment (and thus the graph) deterministic.
    let mut b = HNSWBuilder::new(4).with_seed(42);
    let mut handles = Vec::new();
    for i in 1..=16u8 {
        let f = i as f32;
        let vec = vec![f.sin(), f.cos(), (f * 0.5).sin(), (f * 0.3).cos()];
        let h = crate::schemas::put_embedding::<_>(&mut store, vec.clone()).unwrap();
        b.insert(h, vec).unwrap();
        handles.push(h);
    }
    let naive = b.build_naive();
    let succinct = SuccinctHNSWIndex::from_naive(&naive).unwrap();
    let reader = store.reader().unwrap();
    assert_eq!(succinct.doc_count(), naive.doc_count());
    assert_eq!(succinct.dim(), naive.dim());
    assert_eq!(succinct.max_level(), naive.max_level());
    let naive_view = naive.attach(&reader);
    let succinct_view = succinct.attach(&reader);
    let floor = 0.5f32;
    // Compare as sets: candidate ordering is not part of the contract here.
    for probe in handles.iter().take(3) {
        let n: std::collections::HashSet<_> =
            naive_view.candidates_above(*probe, floor).unwrap().into_iter().collect();
        let s: std::collections::HashSet<_> =
            succinct_view.candidates_above(*probe, floor).unwrap().into_iter().collect();
        assert_eq!(n, s, "mismatch for probe {probe:?}");
    }
}
/// Builds a 20-vector succinct HNSW sample index (dim 4, fixed seed 17)
/// and returns it together with the backing blob store and the embedding
/// handles, for reuse across the round-trip tests below.
fn build_succinct_hnsw_sample() -> (
    SuccinctHNSWIndex,
    triblespace_core::blob::MemoryBlobStore,
    Vec<
        triblespace_core::inline::Inline<
            triblespace_core::inline::encodings::hash::Handle<
                crate::schemas::Embedding,
            >,
        >,
    >,
) {
    use crate::hnsw::HNSWBuilder;
    use triblespace_core::blob::MemoryBlobStore;
    let mut store = MemoryBlobStore::new();
    // Fixed seed so repeated builds yield identical graphs (see the
    // determinism test).
    let mut b = HNSWBuilder::new(4).with_seed(17);
    let mut handles = Vec::new();
    for i in 1..=20u8 {
        let f = i as f32;
        // Trig combinations give well-spread unit-ish vectors per doc.
        let v = vec![f.sin(), f.cos(), (f * 0.7).sin(), (f * 0.3).cos()];
        let h = crate::schemas::put_embedding::<_>(&mut store, v.clone()).unwrap();
        b.insert(h, v).unwrap();
        handles.push(h);
    }
    let idx = b.build();
    (idx, store, handles)
}
#[test]
fn succinct_hnsw_bytes_round_trip() {
    // Serialize to a blob, reload, and verify parameters plus identical
    // candidate sets against the same backing store.
    use triblespace_core::blob::{Blob, TryFromBlob};
    let (original, mut store, handles) = build_succinct_hnsw_sample();
    let blob: Blob<SuccinctHNSWBlob> = Blob::new(original.bytes.clone());
    let reloaded = SuccinctHNSWIndex::try_from_blob(blob).expect("valid blob");
    assert_eq!(reloaded.doc_count(), original.doc_count());
    assert_eq!(reloaded.dim(), original.dim());
    assert_eq!(reloaded.m(), original.m());
    assert_eq!(reloaded.m0(), original.m0());
    assert_eq!(reloaded.max_level(), original.max_level());
    let reader = store.reader().unwrap();
    // Compare as sets: ordering of candidates is not asserted here.
    let orig_hits: std::collections::HashSet<_> = original
        .attach(&reader)
        .candidates_above(handles[0], 0.5)
        .unwrap()
        .into_iter()
        .collect();
    let load_hits: std::collections::HashSet<_> = reloaded
        .attach(&reader)
        .candidates_above(handles[0], 0.5)
        .unwrap()
        .into_iter()
        .collect();
    assert_eq!(orig_hits, load_hits);
}
#[test]
fn succinct_hnsw_empty_round_trip() {
    // An empty HNSW index survives serialization and still answers
    // candidate queries with an empty set.
    use crate::hnsw::HNSWBuilder;
    use triblespace_core::blob::{Blob, MemoryBlobStore, TryFromBlob};
    use triblespace_core::repo::BlobStore;
    let empty = HNSWBuilder::new(3).build();
    let blob: Blob<SuccinctHNSWBlob> = Blob::new(empty.bytes.clone());
    let reloaded: SuccinctHNSWIndex =
        SuccinctHNSWIndex::try_from_blob(blob).expect("valid blob");
    assert_eq!(reloaded.doc_count(), 0);
    let mut store: MemoryBlobStore = MemoryBlobStore::new();
    let probe =
        crate::schemas::put_embedding::<_>(&mut store, vec![1.0, 0.0, 0.0]).unwrap();
    let hits = reloaded
        .attach(&store.reader().unwrap())
        .candidates_above(probe, 0.0)
        .unwrap();
    assert!(hits.is_empty());
}
#[test]
fn succinct_hnsw_rejects_short_header() {
    // Ten zero bytes cannot contain the metadata footer.
    use triblespace_core::blob::{Blob, TryFromBlob};
    let short = Bytes::from_source(vec![0u8; 10]);
    let blob: Blob<SuccinctHNSWBlob> = Blob::new(short);
    let err = SuccinctHNSWIndex::try_from_blob(blob).unwrap_err();
    assert_eq!(err, SuccinctLoadError::BadMeta("suffix"));
}
#[test]
fn succinct_hnsw_rejects_truncation() {
    // Dropping trailing bytes must produce a load error, not a panic.
    use triblespace_core::blob::{Blob, TryFromBlob};
    let (sample, _, _) = build_succinct_hnsw_sample();
    let full = sample.bytes.as_ref();
    let truncated = full[..full.len() - 2].to_vec();
    let blob: Blob<SuccinctHNSWBlob> = Blob::new(Bytes::from_source(truncated));
    let err = SuccinctHNSWIndex::try_from_blob(blob).unwrap_err();
    let acceptable = matches!(
        err,
        SuccinctLoadError::TruncatedSection(_) | SuccinctLoadError::BadMeta(_)
    );
    assert!(acceptable, "expected TruncatedSection or BadMeta, got {err:?}");
}
#[test]
fn succinct_hnsw_blob_schema_round_trip() {
    // The IntoBlob / TryFromBlob schema pair must round-trip the HNSW index.
    use triblespace_core::blob::{IntoBlob, TryFromBlob};
    let (sample, _, _) = build_succinct_hnsw_sample();
    let blob: triblespace_core::blob::Blob<SuccinctHNSWBlob> = (&sample).to_blob();
    let decoded: SuccinctHNSWIndex =
        SuccinctHNSWIndex::try_from_blob(blob).expect("valid blob");
    assert_eq!(decoded.doc_count(), sample.doc_count());
    assert_eq!(decoded.dim(), sample.dim());
}
#[test]
fn succinct_hnsw_blob_is_deterministic() {
    // Seeded construction means two builds serialize to identical bytes.
    let (first, _, _) = build_succinct_hnsw_sample();
    let (second, _, _) = build_succinct_hnsw_sample();
    assert_eq!(first.bytes.as_ref(), second.bytes.as_ref());
}
#[test]
fn succinct_hnsw_empty_index() {
    // A freshly built empty index returns no candidates for any probe.
    use crate::hnsw::HNSWBuilder;
    use triblespace_core::blob::MemoryBlobStore;
    use triblespace_core::repo::BlobStore;
    let index = HNSWBuilder::new(3).build();
    assert_eq!(index.doc_count(), 0);
    let mut store: MemoryBlobStore = MemoryBlobStore::new();
    let probe =
        crate::schemas::put_embedding::<_>(&mut store, vec![1.0, 0.0, 0.0]).unwrap();
    let hits = index
        .attach(&store.reader().unwrap())
        .candidates_above(probe, 0.0)
        .unwrap();
    assert!(hits.is_empty());
}
#[test]
fn graph_rejects_mismatched_layer_width() {
    // The layer describes 2 nodes but the graph claims 3; build must fail.
    let bad_layers = vec![vec![vec![1u32], vec![0]]];
    let err = SuccinctGraph::build(&bad_layers, 3).unwrap_err();
    assert!(matches!(err, SuccinctDocLensError::SizeMismatch { .. }));
}
#[test]
fn succinct_blob_smaller_than_naive_at_scale() {
    // At a few hundred documents the succinct representation must be
    // strictly smaller than the naive in-memory index.
    use crate::bm25::BM25Builder;
    use crate::tokens::hash_tokens;
    use triblespace_core::id::Id;
    let mut b: BM25Builder = BM25Builder::new();
    for i in 1..=250u16 {
        let text = format!("doc {i} contains the quick brown fox {}", i % 17);
        // Unique non-nil id: big-endian doc number, then 0xaa filler.
        // The loop starts at 1, so the id bytes are never all zero.
        // (The original code carried a dead `if i == 0` guard on the last
        // byte that could never fire; the bytes produced here are identical.)
        let mut raw = [0xaau8; 16];
        raw[0] = (i >> 8) as u8;
        raw[1] = (i & 0xff) as u8;
        let id = Id::new(raw).unwrap();
        b.insert(id, hash_tokens(&text));
    }
    let naive_size = b.clone().build_naive().byte_size();
    let succinct_size = b.build().bytes.len();
    assert!(
        succinct_size < naive_size,
        "succinct {succinct_size} should be < naive baseline {naive_size}",
    );
}
#[test]
fn postings_scale_saves_space_vs_naive() {
    // 500 terms with 3 postings each: the succinct encoding must beat a
    // naive layout of 4-byte ids + 4-byte scores + 4-byte offsets.
    let lists: Vec<Vec<(u32, f32)>> = (0..500)
        .map(|t| {
            (0..3)
                .map(|j| ((t * 3 + j) as u32 % 1000, 1.0 + j as f32))
                .collect()
        })
        .collect();
    let total: usize = lists.iter().map(|l| l.len()).sum();
    let (bytes, _meta) = SuccinctPostings::build(&lists, 1000).unwrap();
    let naive = total * 4 + (lists.len() + 1) * 4 + total * 4;
    assert!(
        bytes.len() < naive,
        "succinct body {} < naive total {}",
        bytes.len(),
        naive
    );
}
}