use crate::node::NodeId;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, MatchKind};
use std::collections::HashMap;
/// In-memory text index keyed by [`NodeId`], offering two independent lookups:
/// keyword matching with an Aho-Corasick automaton (`search_ac`) and
/// character-bigram BM25 scoring (`search_bm25`).
///
/// Populate it with `add_keyword` / `add_text`, then call `build` before
/// searching; `build` compiles the keyword matcher and snapshots the corpus
/// statistics used by BM25.
#[derive(Default)]
pub struct TextIndex {
    /// Compiled keyword automaton; `None` until `build` succeeds.
    ac_matcher: Option<AhoCorasick>,
    /// Keywords in automaton pattern order, so a match's pattern id indexes
    /// directly into this list.
    keywords: Vec<String>,
    /// Lowercased keyword -> nodes registered for it.
    keyword_to_nodes: HashMap<String, Vec<NodeId>>,
    /// Posting lists: token -> (node -> term frequency of that token).
    bm25_tf: HashMap<String, HashMap<NodeId, usize>>,
    /// Number of tokens indexed for each node.
    doc_lengths: HashMap<NodeId, usize>,
    /// Mean of `doc_lengths`, as computed by the last `build`.
    avg_dl: f32,
    /// Number of entries in `doc_lengths`, as computed by the last `build`.
    total_docs: usize,
}
impl TextIndex {
    /// Creates an empty index; equivalent to `TextIndex::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Removes all keywords, indexed text, and derived statistics, returning
    /// the index to its freshly constructed state.
    pub fn clear(&mut self) {
        self.ac_matcher = None;
        self.keywords.clear();
        self.keyword_to_nodes.clear();
        self.bm25_tf.clear();
        self.doc_lengths.clear();
        self.avg_dl = 0.0;
        self.total_docs = 0;
    }

    /// Registers `keyword` (case-insensitively) for node `id`, for use by
    /// `search_ac`.
    ///
    /// The keyword is lowercased before storage. Empty keywords are ignored,
    /// and registering the same (lowercased) keyword twice for one node is a
    /// no-op. New keyword strings are only matched after the next `build`.
    pub fn add_keyword(&mut self, id: NodeId, keyword: &str) {
        let kw = keyword.to_lowercase();
        // An empty pattern would match at (nearly) every position of every
        // query, crediting this node for arbitrary input.
        if kw.is_empty() {
            return;
        }
        let nodes = self.keyword_to_nodes.entry(kw).or_default();
        // Case variants ("Foo" / "foo") collapse to one key; without this
        // check the node would be counted twice per match.
        if !nodes.contains(&id) {
            nodes.push(id);
        }
    }

    /// Indexes `text` for node `id`, for use by `search_bm25`.
    ///
    /// The text is lowercased and split into overlapping character bigrams
    /// (a single-character text yields that one character as its token).
    /// Empty text is ignored. Repeated calls for the same `id` accumulate:
    /// term frequencies and the document length are added to what was
    /// indexed before, keeping the two consistent. Corpus statistics are
    /// refreshed by `build`.
    pub fn add_text(&mut self, id: NodeId, text: &str) {
        let tokens = Self::bigram_tokens(text);
        if tokens.is_empty() {
            return;
        }
        *self.doc_lengths.entry(id).or_insert(0) += tokens.len();
        for token in tokens {
            *self.bm25_tf.entry(token).or_default().entry(id).or_insert(0) += 1;
        }
    }

    /// Compiles the keyword matcher and recomputes the BM25 corpus
    /// statistics (`total_docs`, `avg_dl`).
    ///
    /// If there are no keywords, or the automaton fails to build, the
    /// previously built matcher and its keyword list are left unchanged.
    pub fn build(&mut self) {
        let mut keys: Vec<String> = self.keyword_to_nodes.keys().cloned().collect();
        // Longest first, ties broken lexicographically, so pattern ids do not
        // depend on HashMap iteration order. Keys are unique, so
        // leftmost-longest matching itself should not depend on this order.
        keys.sort_unstable_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));
        if !keys.is_empty()
            && let Ok(ac) = AhoCorasickBuilder::new()
                .match_kind(MatchKind::LeftmostLongest)
                .build(&keys)
        {
            self.ac_matcher = Some(ac);
            self.keywords = keys;
        }

        self.total_docs = self.doc_lengths.len();
        if self.total_docs > 0 {
            let sum_dl: usize = self.doc_lengths.values().sum();
            self.avg_dl = sum_dl as f32 / self.total_docs as f32;
        }
    }

    /// Scores indexed nodes against `query` using Okapi BM25.
    ///
    /// The query is tokenized exactly like indexed text. Each distinct query
    /// token contributes once (query term frequency is not weighted). `k1`
    /// controls term-frequency saturation and `b` document-length
    /// normalization. Only nodes sharing at least one token with the query
    /// appear in the result.
    ///
    /// Returns an empty map when `build` has not yet recorded any documents
    /// or the query is empty.
    pub fn search_bm25(&self, query: &str, k1: f32, b: f32) -> HashMap<NodeId, f32> {
        let mut results = HashMap::new();
        if self.total_docs == 0 || self.avg_dl <= 0.0 {
            return results;
        }

        let mut tokens = Self::bigram_tokens(query);
        // Dedupe in a fixed order so each node's float score is summed in the
        // same order on every run (HashMap iteration order is not stable).
        tokens.sort_unstable();
        tokens.dedup();

        let n = self.total_docs as f32;
        for token in &tokens {
            let Some(postings) = self.bm25_tf.get(token) else {
                continue;
            };
            let df = postings.len() as f32;
            // "+1" IDF variant: ln((N - df + 0.5) / (df + 0.5) + 1), which is
            // non-negative whenever df <= N.
            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
            for (&id, &tf) in postings {
                let dl = self.doc_lengths.get(&id).copied().unwrap_or(0) as f32;
                let tf = tf as f32;
                let norm = tf * (k1 + 1.0) / (tf + k1 * (1.0 - b + b * dl / self.avg_dl));
                *results.entry(id).or_insert(0.0) += idf * norm;
            }
        }
        results
    }

    /// Scores nodes by keyword occurrences in `query`.
    ///
    /// The lowercased query is scanned with the matcher compiled by `build`
    /// (leftmost-longest matches). Each match adds `1.0` to every node
    /// registered for the matched keyword. Returns an empty map if no matcher
    /// has been built.
    pub fn search_ac(&self, query: &str) -> HashMap<NodeId, f32> {
        let mut results = HashMap::new();
        let Some(ac) = &self.ac_matcher else {
            return results;
        };
        let query_lower = query.to_lowercase();
        for mat in ac.find_iter(&query_lower) {
            let keyword = &self.keywords[mat.pattern()];
            let Some(nodes) = self.keyword_to_nodes.get(keyword) else {
                continue;
            };
            for &id in nodes {
                *results.entry(id).or_insert(0.0) += 1.0;
            }
        }
        results
    }

    /// Shared tokenizer for indexing and querying.
    ///
    /// Lowercases `text`, then returns its overlapping character bigrams. A
    /// text with exactly one character yields that character as its single
    /// token, and empty text yields no tokens.
    fn bigram_tokens(text: &str) -> Vec<String> {
        let lower = text.to_lowercase();
        let chars: Vec<char> = lower.chars().collect();
        match chars.len() {
            0 => Vec::new(),
            1 => vec![lower],
            _ => chars.windows(2).map(|pair| pair.iter().collect()).collect(),
        }
    }
}