use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
use tantivy::collector::TopDocs;
use tantivy::directory::MmapDirectory;
use tantivy::query::QueryParser;
use tantivy::schema::{STORED, Schema, TEXT, Value};
use tantivy::{Index, IndexWriter, TantivyDocument};
use crate::domain::MemoryRecord;
/// Tantivy-backed full-text index over `MemoryRecord`s, stored in a directory on disk.
///
/// The field handles below are the ones returned by `build_schema()` and are used
/// both when writing documents and when building search queries.
pub struct Bm25Index {
/// Handle to the tantivy index opened (or created) at `index_path`.
index: Index,
/// Directory that holds the index files.
index_path: PathBuf,
/// Schema produced by `build_schema()`; reused when the index is rebuilt from scratch.
schema: Schema,
/// Searchable text field holding the record title.
field_title: tantivy::schema::Field,
/// Searchable text field holding the record summary.
field_summary: tantivy::schema::Field,
/// Searchable text field holding the record's entities, joined with spaces.
field_entities: tantivy::schema::Field,
/// Searchable text field holding the record's tags, joined with spaces.
field_tags: tantivy::schema::Field,
/// Searchable text field holding the record's triggers, joined with spaces.
field_triggers: tantivy::schema::Field,
/// Stored field carrying the caller-supplied record id; it is what `search` returns.
field_record_id: tantivy::schema::Field,
}
impl Bm25Index {
pub fn open_or_create(index_path: &Path) -> Result<Self> {
let (schema, fields) = build_schema();
std::fs::create_dir_all(index_path).with_context(|| {
format!("failed to create bm25 index dir: {}", index_path.display())
})?;
let mmap_dir = MmapDirectory::open(index_path)
.with_context(|| format!("failed to open mmap dir: {}", index_path.display()))?;
let index = if Index::exists(&mmap_dir)? {
Index::open_in_dir(index_path)
.with_context(|| format!("failed to open bm25 index: {}", index_path.display()))?
} else {
Index::create_in_dir(index_path, schema.clone())
.with_context(|| format!("failed to create bm25 index: {}", index_path.display()))?
};
Ok(Self {
index,
index_path: index_path.to_path_buf(),
schema,
field_title: fields.0,
field_summary: fields.1,
field_entities: fields.2,
field_tags: fields.3,
field_triggers: fields.4,
field_record_id: fields.5,
})
}
pub fn build_from_records(&self, records: &[(String, MemoryRecord)]) -> Result<()> {
if self.index_path.exists() {
std::fs::remove_dir_all(&self.index_path)?;
std::fs::create_dir_all(&self.index_path)?;
}
let index = Index::create_in_dir(&self.index_path, self.schema.clone())?;
let mut writer: IndexWriter = index.writer(50_000_000)?;
for (record_id, record) in records {
let mut doc = TantivyDocument::new();
doc.add_text(self.field_record_id, record_id);
doc.add_text(self.field_title, &record.title);
doc.add_text(self.field_summary, &record.summary);
doc.add_text(self.field_entities, &record.entities.join(" "));
doc.add_text(self.field_tags, &record.tags.join(" "));
doc.add_text(self.field_triggers, &record.triggers.join(" "));
writer.add_document(doc)?;
}
writer.commit()?;
Ok(())
}
pub fn add_record(&self, record_id: &str, record: &MemoryRecord) -> Result<()> {
let mut writer: IndexWriter = self.index.writer(15_000_000)?;
let mut doc = TantivyDocument::new();
doc.add_text(self.field_record_id, record_id);
doc.add_text(self.field_title, &record.title);
doc.add_text(self.field_summary, &record.summary);
doc.add_text(self.field_entities, &record.entities.join(" "));
doc.add_text(self.field_tags, &record.tags.join(" "));
doc.add_text(self.field_triggers, &record.triggers.join(" "));
writer.add_document(doc)?;
writer.commit()?;
Ok(())
}
pub fn search(&self, query: &str, limit: usize) -> Result<Vec<(String, f32)>> {
if query.trim().is_empty() {
return Ok(Vec::new());
}
let reader = self.index.reader()?;
let searcher = reader.searcher();
let query_parser = QueryParser::for_index(
&self.index,
vec![
self.field_title,
self.field_summary,
self.field_entities,
self.field_tags,
self.field_triggers,
],
);
let parsed_query = query_parser.parse_query(query).unwrap_or_else(|_| {
let term_query = tantivy::query::TermQuery::new(
tantivy::Term::from_field_text(self.field_summary, query),
tantivy::schema::IndexRecordOption::WithFreqs,
);
Box::new(term_query)
});
let top_docs = searcher.search(&parsed_query, &TopDocs::with_limit(limit))?;
let mut results = Vec::with_capacity(top_docs.len());
for (score, doc_address) in top_docs {
let doc: TantivyDocument = searcher.doc(doc_address)?;
if let Some(value) = doc.get_first(self.field_record_id) {
if let Some(record_id) = value.as_str() {
results.push((record_id.to_string(), score));
}
}
}
Ok(results)
}
pub fn index_path(&self) -> &Path {
&self.index_path
}
}
/// Returns the location of the BM25 index directory (`bm25-index`) beneath
/// `lifecycle_root`. The path is only computed; nothing is created on disk.
pub fn bm25_index_path(lifecycle_root: &Path) -> PathBuf {
    let mut index_dir = lifecycle_root.to_path_buf();
    index_dir.push("bm25-index");
    index_dir
}
/// Builds the index schema and returns it together with its field handles, in the
/// order `(title, summary, entities, tags, triggers, record_id)`.
///
/// The first five fields are registered as `TEXT`; `record_id` is `TEXT | STORED`,
/// making it the only field whose value can be read back from a search hit.
fn build_schema() -> (
    Schema,
    (
        tantivy::schema::Field,
        tantivy::schema::Field,
        tantivy::schema::Field,
        tantivy::schema::Field,
        tantivy::schema::Field,
        tantivy::schema::Field,
    ),
) {
    let mut builder = Schema::builder();
    // Registration order matters: it determines the field handles handed out.
    let [title, summary, entities, tags, triggers] =
        ["title", "summary", "entities", "tags", "triggers"]
            .map(|name| builder.add_text_field(name, TEXT));
    let record_id = builder.add_text_field("record_id", TEXT | STORED);
    (
        builder.build(),
        (title, summary, entities, tags, triggers, record_id),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{
        MemoryLifecycleState, MemoryOrigin, MemoryRecord, MemoryScope, MemorySourceKind,
    };
    use tempfile::tempdir;

    /// Builds an accepted, user-scoped "decision" record with the given title and
    /// summary; every optional or list field is left empty.
    fn make_record(title: &str, summary: &str) -> MemoryRecord {
        let origin = MemoryOrigin {
            source_kind: MemorySourceKind::Manual,
            source_ref: "test".to_owned(),
        };
        MemoryRecord {
            title: title.to_owned(),
            summary: summary.to_owned(),
            memory_type: "decision".to_owned(),
            scope: MemoryScope::User,
            state: MemoryLifecycleState::Accepted,
            origin,
            project_id: None,
            user_id: None,
            sensitivity: None,
            entities: vec![],
            tags: vec![],
            triggers: vec![],
            related_files: vec![],
            related_records: vec![],
            supersedes: None,
            applies_to: vec![],
            valid_until: None,
        }
    }

    /// Expands `(id, title, summary)` triples into the `(record_id, record)` pairs
    /// that `build_from_records` consumes.
    fn records_from(entries: &[(&str, &str, &str)]) -> Vec<(String, MemoryRecord)> {
        entries
            .iter()
            .map(|&(id, title, summary)| (id.to_owned(), make_record(title, summary)))
            .collect()
    }

    #[test]
    fn bm25_should_index_and_search_records() {
        let scratch = tempdir().unwrap();
        let index_path = scratch.path().join("bm25-index");
        let builder_handle = Bm25Index::open_or_create(&index_path).unwrap();
        let records = records_from(&[
            (
                "r1",
                "Rust error handling",
                "Use anyhow for application errors and thiserror for library errors",
            ),
            (
                "r2",
                "Database migrations",
                "Always use reversible migrations with up/down scripts",
            ),
            (
                "r3",
                "API design",
                "REST endpoints should use consistent error envelopes",
            ),
        ]);
        builder_handle.build_from_records(&records).unwrap();

        // Search through a freshly opened handle, after the rebuild has been committed.
        let search_handle = Bm25Index::open_or_create(&index_path).unwrap();
        let hits = search_handle.search("error handling", 10).unwrap();
        assert!(
            !hits.is_empty(),
            "should find results for 'error handling'"
        );
        assert_eq!(hits[0].0, "r1");
    }

    #[test]
    fn bm25_should_find_results_token_matching_would_miss() {
        let scratch = tempdir().unwrap();
        let index_path = scratch.path().join("bm25-index");
        let builder_handle = Bm25Index::open_or_create(&index_path).unwrap();
        let records = records_from(&[
            (
                "r1",
                "Authentication strategy",
                "Use OAuth2 with PKCE flow for single-page applications",
            ),
            (
                "r2",
                "Caching policy",
                "Redis for session storage, local LRU for hot paths",
            ),
        ]);
        builder_handle.build_from_records(&records).unwrap();

        let search_handle = Bm25Index::open_or_create(&index_path).unwrap();
        let hits = search_handle.search("OAuth authentication", 10).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].0, "r1");
    }

    #[test]
    fn bm25_empty_query_should_return_empty() {
        let scratch = tempdir().unwrap();
        let index_path = scratch.path().join("bm25-index");
        let empty_index = Bm25Index::open_or_create(&index_path).unwrap();
        let hits = empty_index.search("", 10).unwrap();
        assert!(hits.is_empty());
    }
}