1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3
4use anyhow::{Context, Result};
5use redb::{Database, ReadableTable};
6
7use crate::graph::{GraphFile, Node};
8
/// Postings table: term -> JSON-encoded set of document (node) ids.
const REDB_TERMS_TABLE: redb::TableDefinition<&str, &[u8]> =
    redb::TableDefinition::new("bm25_terms");
/// Metadata table: scalar BM25 parameters plus the JSON-serialized IDF map.
const REDB_META_TABLE: redb::TableDefinition<&str, &[u8]> = redb::TableDefinition::new("bm25_meta");
12
/// An in-memory BM25 index over graph nodes, with an inverted term->doc map
/// and precomputed IDF values. Built by [`Bm25Index::build`] and persisted
/// via redb by [`Bm25Index::save`] / [`Bm25Index::load`].
#[derive(Debug, Clone)]
pub struct Bm25Index {
    // Mean document length (term count) across all indexed nodes; 0.0 when empty.
    pub avg_doc_len: f32,
    // Number of documents (graph nodes) indexed.
    pub doc_count: usize,
    // BM25 term-frequency saturation parameter (default 1.5).
    pub k1: f32,
    // BM25 length-normalization parameter (default 0.75).
    pub b: f32,
    // Precomputed inverse document frequency per term.
    pub idf: HashMap<String, f32>,
    // Inverted index: term -> ids of the nodes that contain it.
    pub term_to_docs: HashMap<String, HashSet<String>>,
}
22
23impl Bm25Index {
24 pub fn new() -> Self {
25 Self {
26 avg_doc_len: 0.0,
27 doc_count: 0,
28 k1: 1.5,
29 b: 0.75,
30 idf: HashMap::new(),
31 term_to_docs: HashMap::new(),
32 }
33 }
34
35 pub fn build(graph: &GraphFile) -> Self {
36 let mut index = Self::new();
37 let mut doc_lengths: Vec<usize> = Vec::new();
38
39 for node in &graph.nodes {
40 let terms = extract_terms(node);
41 let doc_len = terms.len();
42 doc_lengths.push(doc_len);
43
44 let mut doc_terms: HashSet<String> = HashSet::new();
45 for term in terms {
46 doc_terms.insert(term.clone());
47 index
48 .term_to_docs
49 .entry(term)
50 .or_default()
51 .insert(node.id.clone());
52 }
53 }
54
55 index.doc_count = graph.nodes.len();
56 if !doc_lengths.is_empty() {
57 index.avg_doc_len = doc_lengths.iter().sum::<usize>() as f32 / doc_lengths.len() as f32;
58 }
59
60 let num_docs = index.doc_count.max(1) as f32;
61 for (term, docs) in &index.term_to_docs {
62 let doc_freq = docs.len() as f32;
63 index.idf.insert(
64 term.clone(),
65 ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln(),
66 );
67 }
68
69 index
70 }
71
72 pub fn save(&self, db_path: &Path) -> Result<()> {
73 let db = open_index_db(db_path)?;
74 let write_txn = db
75 .begin_write()
76 .context("failed to start write transaction")?;
77 {
78 let mut terms_table = write_txn.open_table(REDB_TERMS_TABLE)?;
79 for (term, doc_ids) in &self.term_to_docs {
80 let doc_ids_json =
81 serde_json::to_string(doc_ids).context("failed to serialize doc ids")?;
82 terms_table.insert(term.as_str(), doc_ids_json.as_bytes())?;
83 }
84 }
85 {
86 let mut meta_table = write_txn.open_table(REDB_META_TABLE)?;
87 meta_table.insert("avg_doc_len", self.avg_doc_len.to_string().as_bytes())?;
88 meta_table.insert("doc_count", self.doc_count.to_string().as_bytes())?;
89 meta_table.insert("k1", self.k1.to_string().as_bytes())?;
90 meta_table.insert("b", self.b.to_string().as_bytes())?;
91 let idf_json = serde_json::to_string(&self.idf).context("failed to serialize idf")?;
92 meta_table.insert("idf", idf_json.as_bytes())?;
93 }
94 write_txn.commit().context("failed to commit index")?;
95 Ok(())
96 }
97
98 pub fn load(db_path: &Path) -> Result<Self> {
99 let db = open_index_db(db_path)?;
100 let read_txn = db
101 .begin_read()
102 .context("failed to start read transaction")?;
103
104 let avg_doc_len = read_txn
105 .open_table(REDB_META_TABLE)?
106 .get("avg_doc_len")?
107 .map(|v| {
108 std::str::from_utf8(v.value())
109 .unwrap_or("0")
110 .parse::<f32>()
111 .unwrap_or(0.0)
112 })
113 .unwrap_or(0.0);
114
115 let doc_count = read_txn
116 .open_table(REDB_META_TABLE)?
117 .get("doc_count")?
118 .map(|v| {
119 std::str::from_utf8(v.value())
120 .unwrap_or("0")
121 .parse::<usize>()
122 .unwrap_or(0)
123 })
124 .unwrap_or(0);
125
126 let k1 = read_txn
127 .open_table(REDB_META_TABLE)?
128 .get("k1")?
129 .map(|v| {
130 std::str::from_utf8(v.value())
131 .unwrap_or("1.5")
132 .parse::<f32>()
133 .unwrap_or(1.5)
134 })
135 .unwrap_or(1.5);
136
137 let b = read_txn
138 .open_table(REDB_META_TABLE)?
139 .get("b")?
140 .map(|v| {
141 std::str::from_utf8(v.value())
142 .unwrap_or("0.75")
143 .parse::<f32>()
144 .unwrap_or(0.75)
145 })
146 .unwrap_or(0.75);
147
148 let idf_json = read_txn
149 .open_table(REDB_META_TABLE)?
150 .get("idf")?
151 .map(|v| -> String { String::from_utf8_lossy(v.value()).into_owned() })
152 .unwrap_or_else(|| "{}".to_string());
153 let idf: HashMap<String, f32> = serde_json::from_str(&idf_json).unwrap_or_default();
154
155 let mut term_to_docs: HashMap<String, HashSet<String>> = HashMap::new();
156 let terms_table = read_txn.open_table(REDB_TERMS_TABLE)?;
157 let entries: Vec<_> = terms_table.iter()?.collect();
158 for entry in entries {
159 let entry = entry?;
160 let term_str = entry.0.value();
161 let doc_ids_str = std::str::from_utf8(entry.1.value())?;
162 let doc_ids: HashSet<String> = serde_json::from_str(doc_ids_str).unwrap_or_default();
163 term_to_docs.insert(term_str.to_string(), doc_ids);
164 }
165
166 Ok(Self {
167 avg_doc_len,
168 doc_count,
169 k1,
170 b,
171 idf,
172 term_to_docs,
173 })
174 }
175
176 pub fn search(&self, query_terms: &[String], graph: &GraphFile) -> Vec<(String, f32)> {
177 if query_terms.is_empty() || self.doc_count == 0 {
178 return Vec::new();
179 }
180
181 let mut scores: HashMap<String, f32> = HashMap::new();
182
183 for term in query_terms {
184 let idf = self.idf.get(term).copied().unwrap_or(0.0);
185 if idf <= 0.0 {
186 continue;
187 }
188
189 if let Some(doc_ids) = self.term_to_docs.get(term) {
190 for doc_id in doc_ids {
191 if let Some(node) = graph.node_by_id(doc_id) {
192 let terms = extract_terms(node);
193 let doc_len = terms.len() as f32;
194 let tf = terms.iter().filter(|t| *t == term).count() as f32;
195
196 let numerator = idf * tf * (self.k1 + 1.0);
197 let denominator =
198 tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_len);
199
200 let score = if denominator > 0.0 {
201 numerator / denominator
202 } else {
203 0.0
204 };
205
206 *scores.entry(doc_id.clone()).or_insert(0.0) += score;
207 }
208 }
209 }
210 }
211
212 let mut results: Vec<(String, f32)> = scores.into_iter().collect();
213 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
214 results
215 }
216}
217
218fn open_index_db(db_path: &Path) -> Result<Database> {
219 if db_path.exists() {
220 Database::open(db_path)
221 .with_context(|| format!("failed to open index db: {}", db_path.display()))
222 } else {
223 Database::create(db_path)
224 .with_context(|| format!("failed to create index db: {}", db_path.display()))
225 }
226}
227
228fn extract_terms(node: &Node) -> Vec<String> {
229 let mut terms: Vec<String> = Vec::new();
230
231 for word in node.id.split(|c: char| !c.is_alphanumeric()) {
232 if word.len() > 2 {
233 terms.push(word.to_lowercase());
234 }
235 }
236
237 for word in node.name.split(|c: char| !c.is_alphanumeric()) {
238 if word.len() > 2 {
239 terms.push(word.to_lowercase());
240 }
241 }
242
243 for word in node
244 .properties
245 .description
246 .split(|c: char| !c.is_alphanumeric())
247 {
248 if word.len() > 2 {
249 terms.push(word.to_lowercase());
250 }
251 }
252
253 for alias in &node.properties.alias {
254 for word in alias.split(|c: char| !c.is_alphanumeric()) {
255 if word.len() > 2 {
256 terms.push(word.to_lowercase());
257 }
258 }
259 }
260
261 for fact in &node.properties.key_facts {
262 for word in fact.split(|c: char| !c.is_alphanumeric()) {
263 if word.len() > 2 {
264 terms.push(word.to_lowercase());
265 }
266 }
267 }
268
269 terms.sort();
270 terms.dedup();
271 terms
272}