use std::collections::HashMap;

use tantivy::collector::TopDocs;
use tantivy::query::QueryParser;
use tantivy::schema::{Field, OwnedValue, Schema, Value, FAST, INDEXED, STORED, STRING, TEXT};
use tantivy::{doc, Index, IndexWriter, TantivyDocument, Term};
use uuid::Uuid;

use crate::error::CorpFinanceError;
use crate::memory::types::{RunSummary, Surface};
use crate::CorpFinanceResult;
/// Memory budget, in bytes, handed to `Index::writer` each time `ingest`
/// opens an index writer.
const WRITER_HEAP_BYTES: usize = 50_000_000;
/// One result returned by `BM25MemoryIndex::query`.
#[derive(Debug, Clone, PartialEq)]
pub struct BM25Hit {
/// Clone of the stored summary whose indexed document matched the query.
pub run_summary: RunSummary,
/// Relevance score the index search reported for that document.
pub bm25_score: f32,
}
/// Full-text index over `RunSummary` records.
///
/// Pairs a tantivy `Index` (used for text search) with a map holding the
/// full summaries, so that a search hit can be resolved back to its
/// `RunSummary` through the stored `run_id`.
pub struct BM25MemoryIndex {
/// The underlying tantivy index.
index: Index,
/// Schema handle for the `run_id` field.
f_run_id: Field,
/// Schema handle for the `surface_event_id` field.
f_surface_event_id: Field,
/// Schema handle for the `summary_text` field, the field queries run against.
f_summary_text: Field,
/// Schema handle for the `surface` field.
f_surface: Field,
/// Schema handle for the `ts` field.
f_ts: Field,
/// Schema handle for the `tenant_id` field.
f_tenant_id: Field,
/// Ingested summaries keyed by their `run_id`.
summaries: HashMap<Uuid, RunSummary>,
}
impl BM25MemoryIndex {
pub fn new() -> CorpFinanceResult<Self> {
let mut sb = Schema::builder();
let f_run_id = sb.add_text_field("run_id", STRING | STORED);
let f_surface_event_id = sb.add_text_field("surface_event_id", STRING | STORED);
let f_summary_text = sb.add_text_field("summary_text", TEXT | STORED);
let f_surface = sb.add_text_field("surface", STRING | STORED);
let f_ts = sb.add_i64_field("ts", INDEXED | STORED | FAST);
let f_tenant_id = sb.add_text_field("tenant_id", STRING | STORED);
let schema = sb.build();
let index = Index::create_in_ram(schema);
Ok(Self {
index,
f_run_id,
f_surface_event_id,
f_summary_text,
f_surface,
f_ts,
f_tenant_id,
summaries: HashMap::new(),
})
}
pub fn len(&self) -> usize {
self.summaries.len()
}
pub fn is_empty(&self) -> bool {
self.summaries.is_empty()
}
pub fn ingest(&mut self, summary: &RunSummary) -> CorpFinanceResult<()> {
let mut writer: IndexWriter = self
.index
.writer(WRITER_HEAP_BYTES)
.map_err(tantivy_to_cf)?;
let tenant = summary.tenant_id.clone().unwrap_or_default();
writer
.add_document(doc!(
self.f_run_id => summary.run_id.to_string(),
self.f_surface_event_id => summary.surface_event_id.clone(),
self.f_summary_text => summary.summary_text.clone(),
self.f_surface => summary.surface.as_str().to_string(),
self.f_ts => summary.ts.timestamp(),
self.f_tenant_id => tenant,
))
.map_err(tantivy_to_cf)?;
writer.commit().map_err(tantivy_to_cf)?;
self.summaries.insert(summary.run_id, summary.clone());
Ok(())
}
pub fn query(
&self,
query_text: &str,
limit: usize,
filter_tenant: Option<&str>,
) -> CorpFinanceResult<Vec<BM25Hit>> {
if query_text.trim().is_empty() {
return Ok(Vec::new());
}
let reader = self.index.reader().map_err(tantivy_to_cf)?;
let searcher = reader.searcher();
let qp = QueryParser::for_index(&self.index, vec![self.f_summary_text]);
let q = qp
.parse_query(query_text)
.map_err(|e| CorpFinanceError::InvalidInput {
field: "query_text".into(),
reason: format!("parse: {e}"),
})?;
let over = limit.saturating_mul(4).max(limit).max(1);
let hits = searcher
.search(&q, &TopDocs::with_limit(over))
.map_err(tantivy_to_cf)?;
let mut out: Vec<BM25Hit> = Vec::with_capacity(limit);
for (score, addr) in hits {
let stored: TantivyDocument = searcher.doc(addr).map_err(tantivy_to_cf)?;
let run_id_str = match stored.get_first(self.f_run_id) {
Some(v) => owned_str(v).unwrap_or_default(),
None => continue,
};
let run_id = match Uuid::parse_str(&run_id_str) {
Ok(u) => u,
Err(_) => continue,
};
if let Some(summary) = self.summaries.get(&run_id) {
if let Some(want) = filter_tenant {
if summary.tenant_id.as_deref() != Some(want) {
continue;
}
}
out.push(BM25Hit {
run_summary: summary.clone(),
bm25_score: score,
});
if out.len() >= limit {
break;
}
}
}
Ok(out)
}
pub fn matches_surface(summary: &RunSummary, want: Option<Surface>) -> bool {
match want {
Some(s) => summary.surface == s,
None => true,
}
}
}
/// Wraps a tantivy error in `CorpFinanceError::SerializationError`, with the
/// error's display text prefixed by `tantivy: `.
fn tantivy_to_cf(err: tantivy::TantivyError) -> CorpFinanceError {
    let message = format!("tantivy: {err}");
    CorpFinanceError::SerializationError(message)
}
/// Returns an owned copy of the value's string content, or `None` when
/// `as_str` yields nothing for it.
fn owned_str(v: &OwnedValue) -> Option<String> {
    let text = v.as_str()?;
    Some(text.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Surface::Mcp` summary for tests; only the text and the
    /// surface event id vary, the remaining constructor arguments are fixed.
    fn mk_summary(text: &str, surface_event_id: &str) -> RunSummary {
        RunSummary::new(Surface::Mcp, surface_event_id, "djb2:0xaaaa", text, vec![0.0, 0.0, 0.0])
    }

    /// A freshly ingested summary is the top hit for a query on its own words.
    #[test]
    fn ingest_then_query_finds_doc() {
        let summary = mk_summary("Apple revenue beat consensus", "earnings");
        let mut index = BM25MemoryIndex::new().unwrap();
        index.ingest(&summary).unwrap();

        let hits = index.query("apple revenue", 5, None).unwrap();
        assert!(!hits.is_empty());
        assert_eq!(hits[0].run_summary.run_id, summary.run_id);
    }

    /// A matching document is reported with a strictly positive score.
    #[test]
    fn bm25_score_is_nonzero_for_relevant_hits() {
        let summary = mk_summary(
            "AAPL Q3 earnings beat consensus on iPhone revenue",
            "earnings",
        );
        let mut index = BM25MemoryIndex::new().unwrap();
        index.ingest(&summary).unwrap();

        let hits = index.query("earnings", 10, None).unwrap();
        assert!(!hits.is_empty());
        let top_score = hits[0].bm25_score;
        assert!(
            top_score > 0.0,
            "expected positive BM25 score, got {}",
            top_score
        );
    }
}