//! kg/index.rs — BM25 lexical index over knowledge-graph nodes, persisted in redb.
use std::collections::{HashMap, HashSet};
use std::path::Path;

use anyhow::{Context, Result};
use redb::{Database, ReadableTable};

use crate::graph::{GraphFile, Node};

9const REDB_TERMS_TABLE: redb::TableDefinition<&str, &[u8]> =
10    redb::TableDefinition::new("bm25_terms");
11const REDB_META_TABLE: redb::TableDefinition<&str, &[u8]> = redb::TableDefinition::new("bm25_meta");
12
13#[derive(Debug, Clone)]
14pub struct Bm25Index {
15    pub avg_doc_len: f32,
16    pub doc_count: usize,
17    pub k1: f32,
18    pub b: f32,
19    pub idf: HashMap<String, f32>,
20    pub term_to_docs: HashMap<String, HashSet<String>>,
21}
22
23impl Bm25Index {
24    pub fn new() -> Self {
25        Self {
26            avg_doc_len: 0.0,
27            doc_count: 0,
28            k1: 1.5,
29            b: 0.75,
30            idf: HashMap::new(),
31            term_to_docs: HashMap::new(),
32        }
33    }
34
35    pub fn build(graph: &GraphFile) -> Self {
36        let mut index = Self::new();
37        let mut doc_lengths: Vec<usize> = Vec::new();
38
39        for node in &graph.nodes {
40            let terms = extract_terms(node);
41            let doc_len = terms.len();
42            doc_lengths.push(doc_len);
43
44            let mut doc_terms: HashSet<String> = HashSet::new();
45            for term in terms {
46                doc_terms.insert(term.clone());
47                index
48                    .term_to_docs
49                    .entry(term)
50                    .or_default()
51                    .insert(node.id.clone());
52            }
53        }
54
55        index.doc_count = graph.nodes.len();
56        if !doc_lengths.is_empty() {
57            index.avg_doc_len = doc_lengths.iter().sum::<usize>() as f32 / doc_lengths.len() as f32;
58        }
59
60        let num_docs = index.doc_count.max(1) as f32;
61        for (term, docs) in &index.term_to_docs {
62            let doc_freq = docs.len() as f32;
63            index.idf.insert(
64                term.clone(),
65                ((num_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln(),
66            );
67        }
68
69        index
70    }
71
72    pub fn save(&self, db_path: &Path) -> Result<()> {
73        let db = open_index_db(db_path)?;
74        let write_txn = db
75            .begin_write()
76            .context("failed to start write transaction")?;
77        {
78            let mut terms_table = write_txn.open_table(REDB_TERMS_TABLE)?;
79            for (term, doc_ids) in &self.term_to_docs {
80                let doc_ids_json =
81                    serde_json::to_string(doc_ids).context("failed to serialize doc ids")?;
82                terms_table.insert(term.as_str(), doc_ids_json.as_bytes())?;
83            }
84        }
85        {
86            let mut meta_table = write_txn.open_table(REDB_META_TABLE)?;
87            meta_table.insert("avg_doc_len", self.avg_doc_len.to_string().as_bytes())?;
88            meta_table.insert("doc_count", self.doc_count.to_string().as_bytes())?;
89            meta_table.insert("k1", self.k1.to_string().as_bytes())?;
90            meta_table.insert("b", self.b.to_string().as_bytes())?;
91            let idf_json = serde_json::to_string(&self.idf).context("failed to serialize idf")?;
92            meta_table.insert("idf", idf_json.as_bytes())?;
93        }
94        write_txn.commit().context("failed to commit index")?;
95        Ok(())
96    }
97
98    pub fn load(db_path: &Path) -> Result<Self> {
99        let db = open_index_db(db_path)?;
100        let read_txn = db
101            .begin_read()
102            .context("failed to start read transaction")?;
103
104        let avg_doc_len = read_txn
105            .open_table(REDB_META_TABLE)?
106            .get("avg_doc_len")?
107            .map(|v| {
108                std::str::from_utf8(v.value())
109                    .unwrap_or("0")
110                    .parse::<f32>()
111                    .unwrap_or(0.0)
112            })
113            .unwrap_or(0.0);
114
115        let doc_count = read_txn
116            .open_table(REDB_META_TABLE)?
117            .get("doc_count")?
118            .map(|v| {
119                std::str::from_utf8(v.value())
120                    .unwrap_or("0")
121                    .parse::<usize>()
122                    .unwrap_or(0)
123            })
124            .unwrap_or(0);
125
126        let k1 = read_txn
127            .open_table(REDB_META_TABLE)?
128            .get("k1")?
129            .map(|v| {
130                std::str::from_utf8(v.value())
131                    .unwrap_or("1.5")
132                    .parse::<f32>()
133                    .unwrap_or(1.5)
134            })
135            .unwrap_or(1.5);
136
137        let b = read_txn
138            .open_table(REDB_META_TABLE)?
139            .get("b")?
140            .map(|v| {
141                std::str::from_utf8(v.value())
142                    .unwrap_or("0.75")
143                    .parse::<f32>()
144                    .unwrap_or(0.75)
145            })
146            .unwrap_or(0.75);
147
148        let idf_json = read_txn
149            .open_table(REDB_META_TABLE)?
150            .get("idf")?
151            .map(|v| -> String { String::from_utf8_lossy(v.value()).into_owned() })
152            .unwrap_or_else(|| "{}".to_string());
153        let idf: HashMap<String, f32> = serde_json::from_str(&idf_json).unwrap_or_default();
154
155        let mut term_to_docs: HashMap<String, HashSet<String>> = HashMap::new();
156        let terms_table = read_txn.open_table(REDB_TERMS_TABLE)?;
157        let entries: Vec<_> = terms_table.iter()?.collect();
158        for entry in entries {
159            let entry = entry?;
160            let term_str = entry.0.value();
161            let doc_ids_str = std::str::from_utf8(entry.1.value())?;
162            let doc_ids: HashSet<String> = serde_json::from_str(doc_ids_str).unwrap_or_default();
163            term_to_docs.insert(term_str.to_string(), doc_ids);
164        }
165
166        Ok(Self {
167            avg_doc_len,
168            doc_count,
169            k1,
170            b,
171            idf,
172            term_to_docs,
173        })
174    }
175
176    pub fn search(&self, query_terms: &[String], graph: &GraphFile) -> Vec<(String, f32)> {
177        if query_terms.is_empty() || self.doc_count == 0 {
178            return Vec::new();
179        }
180
181        let mut scores: HashMap<String, f32> = HashMap::new();
182
183        for term in query_terms {
184            let idf = self.idf.get(term).copied().unwrap_or(0.0);
185            if idf <= 0.0 {
186                continue;
187            }
188
189            if let Some(doc_ids) = self.term_to_docs.get(term) {
190                for doc_id in doc_ids {
191                    if let Some(node) = graph.node_by_id(doc_id) {
192                        let terms = extract_terms(node);
193                        let doc_len = terms.len() as f32;
194                        let tf = terms.iter().filter(|t| *t == term).count() as f32;
195
196                        let numerator = idf * tf * (self.k1 + 1.0);
197                        let denominator =
198                            tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_len);
199
200                        let score = if denominator > 0.0 {
201                            numerator / denominator
202                        } else {
203                            0.0
204                        };
205
206                        *scores.entry(doc_id.clone()).or_insert(0.0) += score;
207                    }
208                }
209            }
210        }
211
212        let mut results: Vec<(String, f32)> = scores.into_iter().collect();
213        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
214        results
215    }
216}
217
218fn open_index_db(db_path: &Path) -> Result<Database> {
219    if db_path.exists() {
220        Database::open(db_path)
221            .with_context(|| format!("failed to open index db: {}", db_path.display()))
222    } else {
223        Database::create(db_path)
224            .with_context(|| format!("failed to create index db: {}", db_path.display()))
225    }
226}
227
228fn extract_terms(node: &Node) -> Vec<String> {
229    let mut terms: Vec<String> = Vec::new();
230
231    for word in node.id.split(|c: char| !c.is_alphanumeric()) {
232        if word.len() > 2 {
233            terms.push(word.to_lowercase());
234        }
235    }
236
237    for word in node.name.split(|c: char| !c.is_alphanumeric()) {
238        if word.len() > 2 {
239            terms.push(word.to_lowercase());
240        }
241    }
242
243    for word in node
244        .properties
245        .description
246        .split(|c: char| !c.is_alphanumeric())
247    {
248        if word.len() > 2 {
249            terms.push(word.to_lowercase());
250        }
251    }
252
253    for alias in &node.properties.alias {
254        for word in alias.split(|c: char| !c.is_alphanumeric()) {
255            if word.len() > 2 {
256                terms.push(word.to_lowercase());
257            }
258        }
259    }
260
261    for fact in &node.properties.key_facts {
262        for word in fact.split(|c: char| !c.is_alphanumeric()) {
263            if word.len() > 2 {
264                terms.push(word.to_lowercase());
265            }
266        }
267    }
268
269    terms.sort();
270    terms.dedup();
271    terms
272}