use crate::digest::ValueDigest;
use crate::proximity::chunker::{Chunker, IdentityChunker};
use crate::proximity::embedder::{EmbedError, Embedder};
use crate::proximity::index::{ProximityConfig, ProximityError, ProximityIndex};
use crate::proximity::Metric;
use crate::storage::{NodeStorage, StorageError};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use thiserror::Error;
/// Builds the storage key for one chunk of a document.
///
/// Layout: the byte length of `doc_id` cast to `u32` (little-endian), followed
/// by the bytes of `doc_id`, followed by `chunk_idx` as a little-endian `u32`.
/// [`parse_chunk_id`] is the inverse of this encoding.
pub(crate) fn make_chunk_id(doc_id: &[u8], chunk_idx: u32) -> Vec<u8> {
    let len_prefix = (doc_id.len() as u32).to_le_bytes();
    let idx_suffix = chunk_idx.to_le_bytes();
    [&len_prefix[..], doc_id, &idx_suffix[..]].concat()
}
/// Decodes a chunk key produced by `make_chunk_id` into `(doc_id, chunk_idx)`.
///
/// Expected layout: a little-endian `u32` length, that many document-id bytes,
/// then a little-endian `u32` chunk index. Returns `None` when the input is
/// shorter than the 8 bytes of framing or when the declared length does not
/// account for exactly the remaining bytes.
pub fn parse_chunk_id(bytes: &[u8]) -> Option<(Vec<u8>, u32)> {
    // Bytes left for the document id once both 4-byte fields are removed.
    // Comparing against this (instead of computing `4 + len + 4`) keeps the
    // check free of additions that could overflow `usize` on 32-bit targets.
    let payload_len = bytes.len().checked_sub(8)?;
    let (len_field, rest) = bytes.split_at(4);
    let declared_len: usize = u32::from_le_bytes(len_field.try_into().ok()?)
        .try_into()
        .ok()?;
    if declared_len != payload_len {
        return None;
    }
    // `rest` is `payload_len + 4` bytes long, so this split leaves exactly the
    // 4-byte chunk index in `idx_field`.
    let (doc_id, idx_field) = rest.split_at(declared_len);
    let chunk_idx = u32::from_le_bytes(idx_field.try_into().ok()?);
    Some((doc_id.to_vec(), chunk_idx))
}
/// Returns the leading bytes shared by every chunk key of `doc_id`: the
/// document-id length cast to `u32` (little-endian) followed by the id bytes.
pub(crate) fn doc_id_prefix(doc_id: &[u8]) -> Vec<u8> {
    let len_field = (doc_id.len() as u32).to_le_bytes();
    len_field.iter().chain(doc_id).copied().collect()
}
/// Factor by which `TextIndex::search` multiplies the requested `k` when asking
/// the inner index for chunk-level hits, before those hits are collapsed to one
/// entry per document.
pub(crate) const OVERFETCH_MULTIPLIER: usize = 4;
/// Collapses chunk-level hits into at most `k` document-level hits.
///
/// Each chunk id is decoded with `parse_chunk_id`; ids that do not decode are
/// treated as document ids verbatim. For every document the lowest score among
/// its chunks is kept. The result is sorted by ascending score, with ties
/// broken by document id so the output does not depend on `HashMap` iteration
/// order. NaN scores rank after every real score.
pub(crate) fn dedup_chunk_hits_by_doc(chunk_hits: Vec<(Vec<u8>, f32)>, k: usize) -> Vec<TextHit> {
    /// Total order over scores: ordinary numeric order, with NaN sorted last.
    /// A bare `partial_cmp(..).unwrap_or(Equal)` is not a total order once a
    /// NaN is present, which `sort_by` is allowed to reject by panicking.
    fn score_order(a: f32, b: f32) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(&b).unwrap_or(std::cmp::Ordering::Equal),
            // `false < true`, so a real score sorts before a NaN one.
            (a_nan, b_nan) => a_nan.cmp(&b_nan),
        }
    }

    let mut best_per_doc: HashMap<Vec<u8>, f32> = HashMap::new();
    for (chunk_id, score) in chunk_hits {
        let doc_id = match parse_chunk_id(&chunk_id) {
            Some((doc, _)) => doc,
            // Not a chunk key: reuse the owned bytes as the document id.
            None => chunk_id,
        };
        best_per_doc
            .entry(doc_id)
            .and_modify(|best| {
                if score_order(score, *best) == std::cmp::Ordering::Less {
                    *best = score;
                }
            })
            .or_insert(score);
    }

    let mut docs: Vec<(Vec<u8>, f32)> = best_per_doc.into_iter().collect();
    // Document ids are unique here, so the comparator never returns `Equal`
    // for two distinct entries and an unstable sort is deterministic.
    docs.sort_unstable_by(|a, b| score_order(a.1, b.1).then_with(|| a.0.cmp(&b.0)));
    docs.truncate(k);
    docs.into_iter()
        .map(|(id, score)| TextHit { id, score })
        .collect()
}
/// Errors produced by the text index and its persistence helpers.
#[derive(Debug, Error)]
pub enum TextIndexError {
/// The stored embedder id or version differs from the embedder supplied by the caller.
#[error("embedder mismatch: stored {stored_id}@{stored_version}, supplied {provided_id}@{provided_version}")]
EmbedderMismatch {
stored_id: String,
stored_version: String,
provided_id: String,
provided_version: String,
},
/// The stored vector dimension differs from the dimension reported by the supplied embedder.
#[error("dimension mismatch: stored index uses dim {stored}, embedder produces dim {got}")]
DimensionMismatch { stored: u16, got: u16 },
/// Wraps an error from the underlying proximity index.
#[error("proximity error: {0}")]
Proximity(#[from] ProximityError),
/// Wraps an error from the embedder.
#[error("embedder error: {0}")]
Embed(#[from] EmbedError),
/// `TextIndex::load` found no saved state under the given name.
#[error("no text index persisted under name {0:?}")]
NotFound(String),
/// The saved state bytes could not be deserialized; carries the decoder's message.
#[error("could not decode saved text index state: {0}")]
InvalidSavedState(String),
/// The state could not be serialized; carries the encoder's message.
#[error("serialize error: {0}")]
Serialize(String),
/// Wraps a `StorageError`.
#[error("storage error: {0}")]
Storage(#[from] StorageError),
}
/// One document-level search result.
#[derive(Debug, Clone, PartialEq)]
pub struct TextHit {
/// Document id (the chunk index has already been stripped).
pub id: Vec<u8>,
/// Lowest score among the document's chunk hits; results are ordered by ascending score.
pub score: f32,
}
/// Construction parameters for a `TextIndex`.
#[derive(Clone)]
pub struct TextIndexConfig<E: Embedder> {
/// Embedder used for both documents and queries; its `dim()` becomes the index dimension.
pub embedder: E,
/// Splits a document's text into the chunks that are embedded individually.
pub chunker: Arc<dyn Chunker>,
/// Metric forwarded to the inner proximity index.
pub metric: Metric,
/// Forwarded unchanged to `ProximityConfig::level_bits`.
pub level_bits: u8,
/// Forwarded unchanged to `ProximityConfig::max_bucket_size`.
pub max_bucket_size: u16,
}
/// Manual `Debug`: the chunker is a trait object, so it is shown through its
/// `id()` under the field name `chunker_id`; every other field is printed as is.
impl<E: Embedder + std::fmt::Debug> std::fmt::Debug for TextIndexConfig<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            embedder,
            chunker,
            metric,
            level_bits,
            max_bucket_size,
        } = self;
        let mut out = f.debug_struct("TextIndexConfig");
        out.field("embedder", embedder);
        out.field("chunker_id", &chunker.id());
        out.field("metric", metric);
        out.field("level_bits", level_bits);
        out.field("max_bucket_size", max_bucket_size);
        out.finish()
    }
}
impl<E: Embedder> TextIndexConfig<E> {
    /// Creates a configuration with the defaults: `IdentityChunker`,
    /// `Metric::Cosine`, `level_bits = 4` and `max_bucket_size = 64`.
    pub fn new(embedder: E) -> Self {
        let chunker: Arc<dyn Chunker> = Arc::new(IdentityChunker);
        Self {
            embedder,
            chunker,
            metric: Metric::Cosine,
            level_bits: 4,
            max_bucket_size: 64,
        }
    }

    /// Returns the configuration with its chunker replaced by `chunker`.
    pub fn with_chunker<C: Chunker + 'static>(self, chunker: C) -> Self {
        Self {
            chunker: Arc::new(chunker),
            ..self
        }
    }
}
/// Identity record stored alongside a persisted text index. It is serialized
/// with `bincode` and saved as a config entry under `text_state_key(name)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub(crate) struct SavedTextIndexState {
// Embedder `id()` at the time the state was written.
pub(crate) embedder_id: String,
// Embedder `version()` at the time the state was written.
pub(crate) embedder_version: String,
// Vector dimension of the index.
pub(crate) dim: u16,
// Proximity-index settings recorded with the state.
pub(crate) metric: Metric,
pub(crate) level_bits: u8,
pub(crate) max_bucket_size: u16,
}
/// Config key under which the identity state of the text index `name` is stored.
pub(crate) fn text_state_key(name: &str) -> String {
    ["text:", name, ":state"].concat()
}
/// Name under which the inner proximity index of the text index `name` is persisted.
pub(crate) fn text_inner_proximity_name(name: &str) -> String {
    let mut inner_name = String::from("__text__");
    inner_name.push_str(name);
    inner_name
}
/// Checks the supplied embedder against the identity stored under `state_key`,
/// or records a new identity when nothing is stored yet.
///
/// * If state exists it is decoded and compared with `embedder`: id/version
///   first, then dimension. `metric`, `level_bits` and `max_bucket_size` are
///   not compared on this path.
/// * If no state exists, a new `SavedTextIndexState` built from `embedder` and
///   the three settings is serialized and saved.
///
/// # Errors
/// `InvalidSavedState` if the stored bytes cannot be decoded,
/// `EmbedderMismatch` / `DimensionMismatch` if the comparison fails, and
/// `Serialize` if the new state cannot be encoded.
pub(crate) fn validate_or_write_text_identity<const N: usize, E, S>(
    storage: &S,
    state_key: &str,
    embedder: &E,
    metric: Metric,
    level_bits: u8,
    max_bucket_size: u16,
) -> Result<(), TextIndexError>
where
    E: Embedder,
    S: NodeStorage<N>,
{
    let raw = match storage.get_config(state_key) {
        Some(raw) => raw,
        None => {
            // Nothing stored yet: write the identity and finish.
            let fresh = SavedTextIndexState {
                embedder_id: embedder.id().to_string(),
                embedder_version: embedder.version().to_string(),
                dim: embedder.dim(),
                metric,
                level_bits,
                max_bucket_size,
            };
            let encoded = bincode::serialize(&fresh)
                .map_err(|e| TextIndexError::Serialize(e.to_string()))?;
            storage.save_config(state_key, &encoded);
            return Ok(());
        }
    };

    let saved: SavedTextIndexState = bincode::deserialize(&raw)
        .map_err(|e| TextIndexError::InvalidSavedState(e.to_string()))?;

    let same_embedder =
        saved.embedder_id == embedder.id() && saved.embedder_version == embedder.version();
    if !same_embedder {
        return Err(TextIndexError::EmbedderMismatch {
            stored_id: saved.embedder_id,
            stored_version: saved.embedder_version,
            provided_id: embedder.id().to_string(),
            provided_version: embedder.version().to_string(),
        });
    }

    if saved.dim != embedder.dim() {
        return Err(TextIndexError::DimensionMismatch {
            stored: saved.dim,
            got: embedder.dim(),
        });
    }

    Ok(())
}
/// Text search index: documents are split into chunks, each chunk is embedded,
/// and the vectors are stored in an inner `ProximityIndex` keyed by chunk id.
pub struct TextIndex<const N: usize, E: Embedder, S: NodeStorage<N>> {
// Vector index holding one entry per chunk.
inner: ProximityIndex<N, S>,
// Embedder applied to chunk text on insert and to the query on search.
embedder: E,
// Splits document text into chunks on insert.
chunker: Arc<dyn Chunker>,
}
/// Manual `Debug` that reports the embedder's id, version and dimension plus
/// the inner index's `len()` instead of dumping the fields themselves.
impl<const N: usize, E: Embedder, S: NodeStorage<N>> std::fmt::Debug for TextIndex<N, E, S> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let embedder = &self.embedder;
        let mut out = f.debug_struct("TextIndex");
        out.field("embedder_id", &embedder.id());
        out.field("embedder_version", &embedder.version());
        out.field("dim", &embedder.dim());
        out.field("len", &self.inner.len());
        out.finish()
    }
}
impl<const N: usize, E: Embedder, S: NodeStorage<N>> TextIndex<N, E, S> {
    /// Creates an empty index on top of `storage`.
    ///
    /// The vector dimension comes from `config.embedder.dim()`; metric,
    /// `level_bits` and `max_bucket_size` are forwarded to the inner
    /// `ProximityIndex` unchanged.
    pub fn new(storage: S, config: TextIndexConfig<E>) -> Self {
        let inner = ProximityIndex::new(
            storage,
            ProximityConfig {
                dim: config.embedder.dim(),
                metric: config.metric,
                level_bits: config.level_bits,
                max_bucket_size: config.max_bucket_size,
            },
        );
        Self {
            inner,
            embedder: config.embedder,
            chunker: config.chunker,
        }
    }

    /// Reopens the index previously persisted under `name`.
    ///
    /// The chunker of the returned index is always `IdentityChunker`; call
    /// [`set_chunker`](Self::set_chunker) afterwards to use a different one.
    ///
    /// # Errors
    /// `NotFound` if no state is stored under `name`; any error from
    /// `validate_or_write_text_identity` if `embedder` does not match the
    /// stored identity; `Proximity` if the inner index fails to load.
    pub fn load(storage: S, name: &str, embedder: E) -> Result<Self, TextIndexError> {
        let state_key = text_state_key(name);
        if storage.get_config(&state_key).is_none() {
            return Err(TextIndexError::NotFound(name.to_string()));
        }
        // The metric/level/bucket arguments are only used when no state is
        // stored, which the check above has just ruled out.
        validate_or_write_text_identity::<N, _, _>(
            &storage,
            &state_key,
            &embedder,
            Metric::Cosine,
            4,
            64,
        )?;
        let proximity_name = text_inner_proximity_name(name);
        let inner = ProximityIndex::load(storage, &proximity_name)?;
        Ok(Self {
            inner,
            embedder,
            chunker: Arc::new(IdentityChunker),
        })
    }

    /// Replaces the chunker used by subsequent inserts.
    pub fn set_chunker<C: Chunker + 'static>(&mut self, chunker: C) {
        self.chunker = Arc::new(chunker);
    }

    /// Persists the index under `name`.
    ///
    /// Saves the identity state under `text_state_key(name)`, then persists
    /// the inner index under `text_inner_proximity_name(name)` and returns the
    /// value the inner `persist` call produced.
    ///
    /// # Errors
    /// `Serialize` if the state cannot be encoded; `Proximity` if the inner
    /// index fails to persist.
    pub fn persist(&mut self, name: &str) -> Result<Option<ValueDigest<N>>, TextIndexError> {
        let cfg = self.inner.config().clone();
        let state = SavedTextIndexState {
            embedder_id: self.embedder.id().to_string(),
            embedder_version: self.embedder.version().to_string(),
            dim: cfg.dim,
            metric: cfg.metric,
            level_bits: cfg.level_bits,
            max_bucket_size: cfg.max_bucket_size,
        };
        let bytes =
            bincode::serialize(&state).map_err(|e| TextIndexError::Serialize(e.to_string()))?;
        self.inner
            .storage()
            .save_config(&text_state_key(name), &bytes);
        let proximity_name = text_inner_proximity_name(name);
        let root = self.inner.persist(&proximity_name)?;
        Ok(root)
    }

    /// Inserts or replaces the document `id` with the chunks of `text`.
    ///
    /// All chunks are embedded and dimension-checked before the index is
    /// touched, so an embedder failure leaves any previous version of the
    /// document in place. If the chunker yields no chunks, the previous
    /// version is removed and nothing is inserted.
    ///
    /// # Errors
    /// `Embed` if embedding fails or an embedding has the wrong length;
    /// `Proximity` if the inner index rejects an insert.
    pub fn insert(&mut self, id: &[u8], text: &str) -> Result<(), TextIndexError> {
        let chunks = self.chunker.split(text);
        let dim = self.embedder.dim();

        // Stage every (chunk key, vector) pair first.
        let mut staged: Vec<(Vec<u8>, Vec<f32>)> = Vec::new();
        for (idx, chunk_text) in chunks.iter().enumerate() {
            let vec = self.embedder.embed(chunk_text)?;
            // Compare as `usize`: narrowing the length to `u16` first would
            // let an over-long vector wrap around and pass the check.
            if vec.len() != usize::from(dim) {
                return Err(TextIndexError::Embed(EmbedError::DimensionMismatch {
                    expected: dim,
                    got: vec.len().try_into().unwrap_or(u16::MAX),
                }));
            }
            // The chunk-key format stores the index in 32 bits; chunk counts
            // beyond `u32::MAX` are not expected.
            staged.push((make_chunk_id(id, idx as u32), vec));
        }

        self.delete_chunks_for_doc(id);
        for (chunk_id, vec) in staged {
            self.inner.insert(chunk_id, vec)?;
        }
        Ok(())
    }

    /// Removes every chunk of document `id`; returns `true` if anything was removed.
    pub fn delete(&mut self, id: &[u8]) -> bool {
        self.delete_chunks_for_doc(id)
    }

    /// Returns up to `k` documents for `query`, ordered by ascending score.
    ///
    /// The query is embedded, `k * OVERFETCH_MULTIPLIER` chunk-level
    /// neighbours are requested from the inner index, and the chunk hits are
    /// collapsed to one entry per document by `dedup_chunk_hits_by_doc`.
    ///
    /// # Errors
    /// `Embed` if the query cannot be embedded; `Proximity` if the inner
    /// nearest-neighbour query fails.
    pub fn search(&mut self, query: &str, k: usize) -> Result<Vec<TextHit>, TextIndexError> {
        if k == 0 {
            return Ok(Vec::new());
        }
        let q = self.embedder.embed(query)?;
        // Saturating arithmetic: a very large `k` must not overflow.
        let raw_k = k.saturating_mul(OVERFETCH_MULTIPLIER).max(k);
        let ef = raw_k.saturating_mul(4).max(32);
        let chunk_hits = self.inner.knn(&q, raw_k, ef)?;
        Ok(dedup_chunk_hits_by_doc(chunk_hits, k))
    }

    /// Removes all chunk entries belonging to `doc_id`; returns `true` if at
    /// least one entry was removed.
    fn delete_chunks_for_doc(&mut self, doc_id: &[u8]) -> bool {
        let prefix = doc_id_prefix(doc_id);
        // A chunk key is the prefix plus a 4-byte chunk index. Requiring the
        // exact length keeps keys that merely start with the same bytes (and
        // are not chunk keys of this document) from being removed.
        let chunk_key_len = prefix.len() + 4;
        let to_remove: Vec<Vec<u8>> = self
            .inner
            .entries_snapshot()
            .keys()
            .filter(|k| k.len() == chunk_key_len && k.starts_with(&prefix))
            .cloned()
            .collect();
        let mut any = false;
        for cid in to_remove {
            if self.inner.remove(&cid) {
                any = true;
            }
        }
        any
    }

    /// Removes every entry from the index, then inserts each `(id, text)` pair.
    ///
    /// # Errors
    /// Stops at the first failing insert and returns its error; entries
    /// inserted before the failure remain in the index.
    pub fn reindex_from_texts<I>(&mut self, texts: I) -> Result<(), TextIndexError>
    where
        I: IntoIterator<Item = (Vec<u8>, String)>,
    {
        let ids: Vec<Vec<u8>> = self.inner.entries_snapshot().keys().cloned().collect();
        for id in ids {
            self.inner.remove(&id);
        }
        for (id, text) in texts {
            self.insert(&id, &text)?;
        }
        Ok(())
    }

    /// Number of distinct documents. Keys that do not decode as chunk ids are
    /// each counted as their own document.
    pub fn len(&self) -> usize {
        use std::collections::HashSet;
        let docs: HashSet<Vec<u8>> = self
            .inner
            .entries_snapshot()
            .keys()
            .map(|key| match parse_chunk_id(key) {
                Some((doc, _)) => doc,
                None => key.clone(),
            })
            .collect();
        docs.len()
    }

    /// Number of entries in the inner index (one per chunk).
    pub fn chunk_count(&self) -> usize {
        self.inner.len()
    }

    /// `true` when the inner index holds no entries.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Root hash reported by the inner index, cloned into an owned value.
    ///
    /// # Errors
    /// `Proximity` if the inner index fails to compute it.
    pub fn root_hash(&mut self) -> Result<Option<ValueDigest<N>>, TextIndexError> {
        Ok(self.inner.root_hash()?.cloned())
    }

    /// The embedder this index was built with.
    pub fn embedder(&self) -> &E {
        &self.embedder
    }

    /// Configuration of the inner proximity index.
    pub fn proximity_config(&self) -> &ProximityConfig {
        self.inner.config()
    }
}
// Unit tests for `TextIndex`, run against `InMemoryNodeStorage` with `HashEmbedder`.
#[cfg(test)]
mod tests {
use super::*;
use crate::proximity::embedder::HashEmbedder;
use crate::storage::InMemoryNodeStorage;
// Default configuration around a `HashEmbedder` of the given dimension.
fn config(dim: u16) -> TextIndexConfig<HashEmbedder> {
TextIndexConfig::new(HashEmbedder::new(dim, 0))
}
// Searching for a document's exact text returns that document first with a near-zero score.
#[test]
fn insert_and_search_finds_exact_match() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, config(32));
idx.insert(b"doc:1", "the quick brown fox").unwrap();
idx.insert(b"doc:2", "another piece of text").unwrap();
idx.insert(b"doc:3", "yet more content").unwrap();
let hits = idx.search("the quick brown fox", 1).unwrap();
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].id, b"doc:1".to_vec());
assert!(
hits[0].score < 1e-4,
"expected near-zero, got {}",
hits[0].score
);
}
// `delete` reports whether anything was removed, and a deleted document no longer appears in results.
#[test]
fn delete_removes_document_from_search() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, config(8));
idx.insert(b"a", "hello world").unwrap();
idx.insert(b"b", "goodbye world").unwrap();
assert!(idx.delete(b"a"));
// Second delete of the same id finds nothing to remove.
assert!(!idx.delete(b"a"));
let hits = idx.search("hello world", 2).unwrap();
assert!(!hits.iter().any(|h| h.id == b"a".to_vec()));
assert!(hits.iter().any(|h| h.id == b"b".to_vec()));
}
// Searching an index with no documents succeeds and yields no hits.
#[test]
fn empty_index_search_returns_empty() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::<32, _, _>::new(storage, config(8));
let hits = idx.search("anything", 5).unwrap();
assert!(hits.is_empty());
}
// A persisted index reloaded with an equivalent embedder has the same document count and search results.
#[test]
fn persist_then_load_matches_original() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, config(16));
idx.insert(b"a", "first document").unwrap();
idx.insert(b"b", "second document").unwrap();
idx.insert(b"c", "third document").unwrap();
let original_hits = idx.search("second document", 3).unwrap();
idx.persist("docs").unwrap();
let storage_after = idx.inner.storage().clone();
let mut reopened: TextIndex<32, _, _> =
TextIndex::load(storage_after, "docs", HashEmbedder::new(16, 0)).unwrap();
assert_eq!(reopened.len(), 3);
let reopened_hits = reopened.search("second document", 3).unwrap();
assert_eq!(reopened_hits, original_hits);
}
// Loading with an embedder whose `id()` differs from the stored one fails with `EmbedderMismatch`.
#[test]
fn load_with_different_embedder_id_returns_mismatch() {
// Wrapper that delegates everything to `HashEmbedder` except `id()`.
struct DifferentIdEmbedder(HashEmbedder);
impl Embedder for DifferentIdEmbedder {
fn id(&self) -> &str {
"other:family/v1"
}
fn version(&self) -> &str {
self.0.version()
}
fn dim(&self) -> u16 {
self.0.dim()
}
fn embed(&self, text: &str) -> Result<Vec<f32>, EmbedError> {
self.0.embed(text)
}
}
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, config(8));
idx.insert(b"a", "x").unwrap();
idx.persist("docs").unwrap();
let storage_after = idx.inner.storage().clone();
let err = TextIndex::<32, _, _>::load(
storage_after,
"docs",
DifferentIdEmbedder(HashEmbedder::new(8, 0)),
)
.unwrap_err();
assert!(
matches!(err, TextIndexError::EmbedderMismatch { .. }),
"expected EmbedderMismatch, got {err:?}"
);
}
// Loading with `HashEmbedder::new(8, 1)` after persisting with `HashEmbedder::new(8, 0)` fails with `EmbedderMismatch`.
#[test]
fn load_with_different_embedder_version_returns_mismatch() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, TextIndexConfig::new(HashEmbedder::new(8, 0)));
idx.insert(b"a", "x").unwrap();
idx.persist("docs").unwrap();
let storage_after = idx.inner.storage().clone();
let err = TextIndex::<32, _, _>::load(storage_after, "docs", HashEmbedder::new(8, 1))
.unwrap_err();
assert!(matches!(err, TextIndexError::EmbedderMismatch { .. }));
}
// Loading a name that was never persisted fails with `NotFound` carrying that name.
#[test]
fn load_unknown_name_returns_not_found() {
let storage = InMemoryNodeStorage::<32>::new();
let err =
TextIndex::<32, _, _>::load(storage, "missing", HashEmbedder::new(8, 0)).unwrap_err();
assert!(matches!(err, TextIndexError::NotFound(name) if name == "missing"));
}
// `reindex_from_texts` drops the previous contents and indexes only the supplied documents.
#[test]
fn reindex_from_texts_clears_then_re_embeds() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, config(8));
idx.insert(b"a", "stale").unwrap();
idx.insert(b"b", "stale").unwrap();
assert_eq!(idx.len(), 2);
idx.reindex_from_texts(vec![
(b"a".to_vec(), "fresh a".to_string()),
(b"c".to_vec(), "fresh c".to_string()),
])
.unwrap();
assert_eq!(idx.len(), 2);
let hits = idx.search("fresh c", 1).unwrap();
assert_eq!(hits[0].id, b"c".to_vec());
}
// `search` returns `k` hits when enough documents exist, and all documents when `k` exceeds the count.
#[test]
fn search_returns_at_most_k() {
let storage = InMemoryNodeStorage::<32>::new();
let mut idx = TextIndex::new(storage, config(16));
for i in 0..10 {
let id = format!("id-{i}").into_bytes();
let text = format!("text-{i}");
idx.insert(&id, &text).unwrap();
}
assert_eq!(idx.search("text-3", 3).unwrap().len(), 3);
assert_eq!(idx.search("text-3", 100).unwrap().len(), 10);
}
}