use std::collections::{BTreeSet, HashMap};
use std::sync::Arc;
use super::tokenize::{ngrams, normalize};
use super::types::{ConversationMessage, CrossThreadHit};
/// Minimum length, in bytes, a whitespace-separated query term must have to be
/// searched; `InvertedIndex::search` drops shorter terms.
const MIN_TERM_BYTES: usize = 3;
/// Largest candidate list `InvertedIndex::search` will score for a single term.
/// When any term produces more candidates than this, the search returns the
/// recency fallback instead of scored hits.
const LARGE_CANDIDATE_LIMIT: usize = 10_000;
/// One indexed message as stored in the index's document table.
#[derive(Debug, Clone)]
struct DocEntry {
    /// Id of the thread the message belongs to (shared via the thread-id pool).
    thread_id: Arc<str>,
    /// Id of the message within its thread.
    message_id: String,
    /// Sender of the message (shared via the role pool).
    role: Arc<str>,
    /// Message text exactly as it was inserted.
    content: String,
    /// Output of `normalize` for `content`; this is the text that is n-grammed
    /// and substring-matched.
    content_normalized: String,
    /// Creation timestamp; compared as a plain string when ordering hits.
    created_at: String,
}
/// In-memory n-gram inverted index over conversation messages.
#[derive(Debug, Default)]
pub(crate) struct InvertedIndex {
/// N-gram -> sorted set of ids of the documents whose normalized content produced it.
postings: HashMap<Box<str>, BTreeSet<u32>>,
/// Document table indexed by doc id; a removed document leaves `None` in its slot.
docs: Vec<Option<DocEntry>>,
/// `(thread_id, message_id)` -> doc id, used to skip duplicate inserts and to find a thread's documents.
by_message: HashMap<(String, String), u32>,
/// Interning pool so documents of the same thread share one `Arc<str>` thread id.
thread_id_pool: HashMap<String, Arc<str>>,
/// Interning pool so documents with the same sender share one `Arc<str>` role.
role_pool: HashMap<String, Arc<str>>,
}
impl InvertedIndex {
pub fn new() -> Self {
Self::default()
}
/// Adds `msg` to the index under `thread_id`.
///
/// A message whose `(thread_id, id)` pair is already indexed is ignored: the
/// stored content is not updated and no second document is created.
pub fn insert(&mut self, thread_id: &str, msg: ConversationMessage) {
    // Only these four fields are indexed; the remaining ones are dropped.
    let ConversationMessage {
        id: message_id,
        content,
        sender,
        created_at,
        message_type: _,
        extra_metadata: _,
    } = msg;

    let lookup_key = (thread_id.to_owned(), message_id.clone());
    if self.by_message.contains_key(&lookup_key) {
        return;
    }

    // A doc id is the document's position in `docs`. Slots are never reused
    // (removal leaves `None` behind), so the next id is always the length.
    // NOTE: the cast wraps once the table holds more than `u32::MAX` slots.
    let doc_id = self.docs.len() as u32;

    let content_normalized = normalize(&content);
    for gram in ngrams(&content_normalized) {
        match self.postings.get_mut(gram) {
            Some(ids) => {
                ids.insert(doc_id);
            }
            None => {
                // Allocate the boxed key only for n-grams seen for the first time.
                let mut ids = BTreeSet::new();
                ids.insert(doc_id);
                self.postings.insert(gram.into(), ids);
            }
        }
    }

    let entry = DocEntry {
        thread_id: self.intern_thread_id(thread_id),
        message_id,
        role: self.intern_role(&sender),
        content,
        content_normalized,
        created_at,
    };
    self.docs.push(Some(entry));
    self.by_message.insert(lookup_key, doc_id);
}
/// Removes every indexed message of `thread_id` and forgets its interned id.
///
/// Does nothing beyond the pool cleanup when the thread has no documents.
pub fn remove_thread(&mut self, thread_id: &str) {
    // Snapshot the ids first: `remove_doc` mutates `by_message`, which is
    // the map being scanned here.
    let mut doomed: Vec<u32> = Vec::new();
    for ((owner, _), doc_id) in &self.by_message {
        if owner == thread_id {
            doomed.push(*doc_id);
        }
    }

    for doc_id in doomed {
        self.remove_doc(doc_id);
    }

    self.thread_id_pool.remove(thread_id);
}
#[allow(dead_code)]
pub fn clear(&mut self) {
self.postings.clear();
self.docs.clear();
self.by_message.clear();
self.thread_id_pool.clear();
self.role_pool.clear();
}
/// Removes the document with id `doc_id`, if it is still live.
///
/// The slot in `docs` is left as `None` so other doc ids stay valid; the
/// document's entry in `by_message` and its ids in `postings` are deleted, and
/// posting lists that become empty are dropped.
fn remove_doc(&mut self, doc_id: u32) {
    let Some(slot) = self.docs.get_mut(doc_id as usize) else {
        return;
    };
    let Some(removed) = slot.take() else {
        return;
    };

    let lookup_key = (removed.thread_id.to_string(), removed.message_id);
    self.by_message.remove(&lookup_key);

    // Re-derive the n-grams from the stored normalized text so exactly the
    // postings written at insert time are visited.
    for gram in ngrams(&removed.content_normalized) {
        let Some(ids) = self.postings.get_mut(gram) else {
            continue;
        };
        ids.remove(&doc_id);
        if ids.is_empty() {
            self.postings.remove(gram);
        }
    }
}
/// Finds indexed messages whose normalized content contains the query's terms.
///
/// The query is normalized and split on whitespace; terms shorter than
/// `MIN_TERM_BYTES` bytes are ignored. A message's score is the fraction of
/// the remaining terms that occur in its normalized content as substrings.
///
/// Results are ordered by score (highest first), then by `created_at`
/// (greatest string first), then by insertion order, and cut to `limit`
/// entries. The final tie-break makes the output deterministic.
///
/// Messages belonging to `exclude_thread_id`, when given, are never returned.
///
/// Returns an empty vector when `limit` is 0 or no term survives filtering.
/// When a single term has more than `LARGE_CANDIDATE_LIMIT` candidate
/// documents the query is abandoned and the most recent messages are returned
/// with a score of `0.0` instead.
pub fn search(
    &self,
    query: &str,
    limit: usize,
    exclude_thread_id: Option<&str>,
) -> Vec<CrossThreadHit> {
    if limit == 0 {
        return Vec::new();
    }

    let query_lower = normalize(query);
    let terms: Vec<&str> = query_lower
        .split_whitespace()
        .filter(|term| term.len() >= MIN_TERM_BYTES)
        .collect();
    if terms.is_empty() {
        return Vec::new();
    }

    // Resolve every term before scoring so an oversized term bails out
    // without any scoring work having been done.
    let mut per_term: Vec<Vec<u32>> = Vec::with_capacity(terms.len());
    for &term in &terms {
        let candidates = self.candidates_for_term(term).unwrap_or_else(|| {
            // The term yields no n-grams, so the postings cannot narrow it
            // down: every live document has to be checked.
            self.docs
                .iter()
                .enumerate()
                .filter_map(|(idx, slot)| slot.as_ref().map(|_| idx as u32))
                .collect()
        });
        if candidates.len() > LARGE_CANDIDATE_LIMIT {
            return self.recency_fallback(exclude_thread_id, limit);
        }
        per_term.push(candidates);
    }

    // doc id -> (number of terms matched, borrowed entry).
    let mut matches: HashMap<u32, (usize, _)> = HashMap::new();
    for (&term, candidates) in terms.iter().zip(per_term) {
        for doc_id in candidates {
            let Some(entry) = self.docs.get(doc_id as usize).and_then(Option::as_ref) else {
                continue;
            };
            if exclude_thread_id == Some(entry.thread_id.as_ref()) {
                continue;
            }
            // Sharing all n-grams does not imply containing the term, so
            // confirm with a real substring check.
            if entry.content_normalized.contains(term) {
                matches.entry(doc_id).or_insert((0, entry)).0 += 1;
            }
        }
    }

    // Rank borrowed entries and only clone the ones that survive `limit`.
    let mut ranked: Vec<(u32, usize, _)> = matches
        .into_iter()
        .map(|(doc_id, (matched, entry))| (doc_id, matched, entry))
        .collect();
    // The term count is the same for every hit, so ordering by `matched` is
    // ordering by score. Doc ids are unique, which makes this a total order
    // independent of the hash map's iteration order.
    ranked.sort_unstable_by(|&(id_a, matched_a, entry_a), &(id_b, matched_b, entry_b)| {
        matched_b
            .cmp(&matched_a)
            .then_with(|| entry_b.created_at.cmp(&entry_a.created_at))
            .then_with(|| id_a.cmp(&id_b))
    });
    ranked.truncate(limit);

    let total_terms = terms.len() as f64;
    ranked
        .into_iter()
        .map(|(_, matched, entry)| CrossThreadHit {
            thread_id: entry.thread_id.to_string(),
            message_id: entry.message_id.clone(),
            role: entry.role.to_string(),
            content: entry.content.clone(),
            created_at: entry.created_at.clone(),
            score: matched as f64 / total_terms,
        })
        .collect()
}
/// Looks up the documents whose posting lists contain every n-gram of `term`.
///
/// Returns `None` when `term` produces no n-grams at all, which tells the
/// caller the postings cannot be used for this term. Otherwise returns the
/// matching doc ids in ascending order; the list is empty as soon as one
/// n-gram is unknown to the index.
fn candidates_for_term(&self, term: &str) -> Option<Vec<u32>> {
    let grams = ngrams(term);
    let mut remaining = grams.iter();
    let seed = remaining.next()?;

    let Some(seed_ids) = self.postings.get(*seed) else {
        return Some(Vec::new());
    };
    let mut survivors: Vec<u32> = seed_ids.iter().copied().collect();

    for gram in remaining {
        if survivors.is_empty() {
            // Nothing left to narrow down.
            break;
        }
        let Some(ids) = self.postings.get(*gram) else {
            return Some(Vec::new());
        };
        intersect_sorted_with_btreeset(&mut survivors, ids);
    }
    Some(survivors)
}
/// Returns the shared `Arc<str>` for `thread_id`, adding it to the pool on first use.
fn intern_thread_id(&mut self, thread_id: &str) -> Arc<str> {
    match self.thread_id_pool.get(thread_id) {
        Some(shared) => Arc::clone(shared),
        None => {
            let fresh: Arc<str> = Arc::from(thread_id);
            self.thread_id_pool
                .insert(thread_id.to_owned(), Arc::clone(&fresh));
            fresh
        }
    }
}
/// Returns the shared `Arc<str>` for `role`, adding it to the pool on first use.
fn intern_role(&mut self, role: &str) -> Arc<str> {
    match self.role_pool.get(role) {
        Some(shared) => Arc::clone(shared),
        None => {
            let fresh: Arc<str> = Arc::from(role);
            self.role_pool.insert(role.to_owned(), Arc::clone(&fresh));
            fresh
        }
    }
}
/// Returns up to `limit` live messages ordered by `created_at`, greatest
/// string first, ignoring the query entirely.
///
/// Messages belonging to `exclude_thread_id`, when given, are skipped. Every
/// hit carries a score of `0.0`. Messages with equal `created_at` keep their
/// insertion order.
fn recency_fallback(
    &self,
    exclude_thread_id: Option<&str>,
    limit: usize,
) -> Vec<CrossThreadHit> {
    // Sort borrowed entries and clone only the ones that survive `limit`,
    // rather than copying every document's strings up front.
    let mut recent: Vec<_> = self
        .docs
        .iter()
        .flatten()
        .filter(|entry| exclude_thread_id != Some(entry.thread_id.as_ref()))
        .collect();
    // Stable sort: ties stay in document-table (insertion) order.
    recent.sort_by(|a, b| b.created_at.cmp(&a.created_at));
    recent.truncate(limit);

    recent
        .into_iter()
        .map(|entry| CrossThreadHit {
            thread_id: entry.thread_id.to_string(),
            message_id: entry.message_id.clone(),
            role: entry.role.to_string(),
            content: entry.content.clone(),
            created_at: entry.created_at.clone(),
            score: 0.0,
        })
        .collect()
}
}
/// Shrinks `acc` in place to the ids that also occur in `other`.
///
/// `acc` is expected to be sorted ascending; both sequences are walked once
/// in lockstep, so an id of `other` that has been passed is never revisited.
fn intersect_sorted_with_btreeset(acc: &mut Vec<u32>, other: &BTreeSet<u32>) {
    let mut cursor = other.iter().copied().peekable();
    acc.retain(|&wanted| {
        // Skip ids of `other` that are smaller than the one being checked.
        while cursor.next_if(|&seen| seen < wanted).is_some() {}
        // Keep `wanted` only if `other` holds it too, consuming that id.
        cursor.next_if_eq(&wanted).is_some()
    });
}
// Unit tests live in the file named by `#[path]` and are compiled only for test builds.
#[cfg(test)]
#[path = "inverted_index_tests.rs"]
mod tests;