// agentic_memory/index/term_index.rs
//! BM25 inverted index mapping terms to posting lists.

use std::collections::HashMap;

use crate::engine::tokenizer::Tokenizer;
use crate::graph::MemoryGraph;
use crate::types::CognitiveEvent;

/// An inverted index mapping tokenized terms to posting lists (sorted node ID arrays).
#[derive(Debug, Clone)]
pub struct TermIndex {
    /// term → sorted Vec of (node_id, term_frequency_in_node)
    postings: HashMap<String, Vec<(u64, u32)>>,
    /// Total number of documents (nodes) indexed.
    doc_count: u64,
    /// Average document length in tokens (0.0 when nothing is indexed).
    avg_doc_length: f32,
}
18
19impl TermIndex {
20    /// Create an empty term index.
21    pub fn new() -> Self {
22        Self {
23            postings: HashMap::new(),
24            doc_count: 0,
25            avg_doc_length: 0.0,
26        }
27    }
28
29    /// Build the index from all node contents in the graph.
30    pub fn build(graph: &MemoryGraph, tokenizer: &Tokenizer) -> Self {
31        let mut index = Self::new();
32        let mut total_tokens: u64 = 0;
33
34        for node in graph.nodes() {
35            let freqs = tokenizer.term_frequencies(&node.content);
36            let doc_len: u32 = freqs.values().sum();
37            total_tokens += doc_len as u64;
38
39            for (term, freq) in freqs {
40                let posting = index.postings.entry(term).or_default();
41                // Maintain sort order by node_id
42                let pos = posting
43                    .binary_search_by_key(&node.id, |(id, _)| *id)
44                    .unwrap_or_else(|p| p);
45                posting.insert(pos, (node.id, freq));
46            }
47
48            index.doc_count += 1;
49        }
50
51        if index.doc_count > 0 {
52            index.avg_doc_length = total_tokens as f32 / index.doc_count as f32;
53        }
54
55        index
56    }
57
58    /// Look up a term. Returns (node_id, term_frequency) pairs.
59    pub fn get(&self, term: &str) -> &[(u64, u32)] {
60        self.postings.get(term).map(|v| v.as_slice()).unwrap_or(&[])
61    }
62
63    /// Number of nodes containing a term (document frequency).
64    pub fn doc_frequency(&self, term: &str) -> usize {
65        self.postings.get(term).map(|v| v.len()).unwrap_or(0)
66    }
67
68    /// Total number of indexed documents.
69    pub fn doc_count(&self) -> u64 {
70        self.doc_count
71    }
72
73    /// Average document length.
74    pub fn avg_doc_length(&self) -> f32 {
75        self.avg_doc_length
76    }
77
78    /// Number of unique terms.
79    pub fn term_count(&self) -> usize {
80        self.postings.len()
81    }
82
83    /// Add a single node to the index incrementally.
84    pub fn add_node(&mut self, event: &CognitiveEvent) {
85        let tokenizer = Tokenizer::new();
86        let freqs = tokenizer.term_frequencies(&event.content);
87        for (term, freq) in freqs {
88            let posting = self.postings.entry(term).or_default();
89            let pos = posting
90                .binary_search_by_key(&event.id, |(id, _)| *id)
91                .unwrap_or_else(|p| p);
92            posting.insert(pos, (event.id, freq));
93        }
94        self.doc_count += 1;
95        // avg_doc_length becomes approximate after incremental adds
96    }
97
98    /// Remove a node from the index.
99    pub fn remove_node(&mut self, id: u64) {
100        for posting in self.postings.values_mut() {
101            if let Ok(pos) = posting.binary_search_by_key(&id, |(nid, _)| *nid) {
102                posting.remove(pos);
103            }
104        }
105        self.doc_count = self.doc_count.saturating_sub(1);
106    }
107
108    /// Clear the index.
109    pub fn clear(&mut self) {
110        self.postings.clear();
111        self.doc_count = 0;
112        self.avg_doc_length = 0.0;
113    }
114
115    /// Rebuild the index from a graph.
116    pub fn rebuild(&mut self, graph: &MemoryGraph) {
117        *self = Self::build(graph, &Tokenizer::new());
118    }
119
120    /// Serialize the term index to bytes for file writing.
121    pub fn to_bytes(&self) -> Vec<u8> {
122        let mut buf: Vec<u8> = Vec::new();
123
124        buf.extend_from_slice(&self.doc_count.to_le_bytes());
125        buf.extend_from_slice(&self.avg_doc_length.to_le_bytes());
126        buf.extend_from_slice(&(self.postings.len() as u32).to_le_bytes());
127
128        // Sort terms for deterministic output
129        let mut terms: Vec<&String> = self.postings.keys().collect();
130        terms.sort();
131
132        for term in terms {
133            let postings = &self.postings[term];
134            let term_bytes = term.as_bytes();
135            buf.extend_from_slice(&(term_bytes.len() as u16).to_le_bytes());
136            buf.extend_from_slice(term_bytes);
137            buf.extend_from_slice(&(postings.len() as u32).to_le_bytes());
138            for &(node_id, term_freq) in postings {
139                buf.extend_from_slice(&node_id.to_le_bytes());
140                buf.extend_from_slice(&term_freq.to_le_bytes());
141            }
142        }
143
144        buf
145    }
146
147    /// Deserialize a term index from bytes.
148    pub fn from_bytes(data: &[u8]) -> Option<Self> {
149        let mut pos = 0;
150
151        if data.len() < 16 {
152            return None;
153        }
154
155        let doc_count = u64::from_le_bytes(data[pos..pos + 8].try_into().ok()?);
156        pos += 8;
157        let avg_doc_length = f32::from_le_bytes(data[pos..pos + 4].try_into().ok()?);
158        pos += 4;
159        let term_count = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
160        pos += 4;
161
162        let mut postings = HashMap::with_capacity(term_count);
163
164        for _ in 0..term_count {
165            if pos + 2 > data.len() {
166                return None;
167            }
168            let term_len = u16::from_le_bytes(data[pos..pos + 2].try_into().ok()?) as usize;
169            pos += 2;
170
171            if pos + term_len > data.len() {
172                return None;
173            }
174            let term = std::str::from_utf8(&data[pos..pos + term_len])
175                .ok()?
176                .to_string();
177            pos += term_len;
178
179            if pos + 4 > data.len() {
180                return None;
181            }
182            let posting_count = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
183            pos += 4;
184
185            let mut posting_list = Vec::with_capacity(posting_count);
186            for _ in 0..posting_count {
187                if pos + 12 > data.len() {
188                    return None;
189                }
190                let node_id = u64::from_le_bytes(data[pos..pos + 8].try_into().ok()?);
191                pos += 8;
192                let term_freq = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?);
193                pos += 4;
194                posting_list.push((node_id, term_freq));
195            }
196
197            postings.insert(term, posting_list);
198        }
199
200        Some(Self {
201            postings,
202            doc_count,
203            avg_doc_length,
204        })
205    }
206}
207
208impl Default for TermIndex {
209    fn default() -> Self {
210        Self::new()
211    }
212}