// agentic_memory/index/term_index.rs

use std::collections::HashMap;

use crate::engine::tokenizer::Tokenizer;
use crate::graph::MemoryGraph;
use crate::types::CognitiveEvent;
/// Inverted index mapping terms to the graph nodes that contain them.
pub struct TermIndex {
    // term -> posting list of (node id, term frequency), kept sorted by node id
    // so lookups and removals can use binary search.
    postings: HashMap<String, Vec<(u64, u32)>>,
    // Number of documents (graph nodes) indexed.
    doc_count: u64,
    // Average document length in tokens across all indexed nodes.
    // NOTE(review): presumably consumed by a BM25-style scorer — confirm against callers.
    avg_doc_length: f32,
}
18
19impl TermIndex {
20 pub fn new() -> Self {
22 Self {
23 postings: HashMap::new(),
24 doc_count: 0,
25 avg_doc_length: 0.0,
26 }
27 }
28
29 pub fn build(graph: &MemoryGraph, tokenizer: &Tokenizer) -> Self {
31 let mut index = Self::new();
32 let mut total_tokens: u64 = 0;
33
34 for node in graph.nodes() {
35 let freqs = tokenizer.term_frequencies(&node.content);
36 let doc_len: u32 = freqs.values().sum();
37 total_tokens += doc_len as u64;
38
39 for (term, freq) in freqs {
40 let posting = index.postings.entry(term).or_default();
41 let pos = posting
43 .binary_search_by_key(&node.id, |(id, _)| *id)
44 .unwrap_or_else(|p| p);
45 posting.insert(pos, (node.id, freq));
46 }
47
48 index.doc_count += 1;
49 }
50
51 if index.doc_count > 0 {
52 index.avg_doc_length = total_tokens as f32 / index.doc_count as f32;
53 }
54
55 index
56 }
57
58 pub fn get(&self, term: &str) -> &[(u64, u32)] {
60 self.postings.get(term).map(|v| v.as_slice()).unwrap_or(&[])
61 }
62
63 pub fn doc_frequency(&self, term: &str) -> usize {
65 self.postings.get(term).map(|v| v.len()).unwrap_or(0)
66 }
67
68 pub fn doc_count(&self) -> u64 {
70 self.doc_count
71 }
72
73 pub fn avg_doc_length(&self) -> f32 {
75 self.avg_doc_length
76 }
77
78 pub fn term_count(&self) -> usize {
80 self.postings.len()
81 }
82
83 pub fn add_node(&mut self, event: &CognitiveEvent) {
85 let tokenizer = Tokenizer::new();
86 let freqs = tokenizer.term_frequencies(&event.content);
87 for (term, freq) in freqs {
88 let posting = self.postings.entry(term).or_default();
89 let pos = posting
90 .binary_search_by_key(&event.id, |(id, _)| *id)
91 .unwrap_or_else(|p| p);
92 posting.insert(pos, (event.id, freq));
93 }
94 self.doc_count += 1;
95 }
97
98 pub fn remove_node(&mut self, id: u64) {
100 for posting in self.postings.values_mut() {
101 if let Ok(pos) = posting.binary_search_by_key(&id, |(nid, _)| *nid) {
102 posting.remove(pos);
103 }
104 }
105 self.doc_count = self.doc_count.saturating_sub(1);
106 }
107
108 pub fn clear(&mut self) {
110 self.postings.clear();
111 self.doc_count = 0;
112 self.avg_doc_length = 0.0;
113 }
114
115 pub fn rebuild(&mut self, graph: &MemoryGraph) {
117 *self = Self::build(graph, &Tokenizer::new());
118 }
119
120 pub fn to_bytes(&self) -> Vec<u8> {
122 let mut buf: Vec<u8> = Vec::new();
123
124 buf.extend_from_slice(&self.doc_count.to_le_bytes());
125 buf.extend_from_slice(&self.avg_doc_length.to_le_bytes());
126 buf.extend_from_slice(&(self.postings.len() as u32).to_le_bytes());
127
128 let mut terms: Vec<&String> = self.postings.keys().collect();
130 terms.sort();
131
132 for term in terms {
133 let postings = &self.postings[term];
134 let term_bytes = term.as_bytes();
135 buf.extend_from_slice(&(term_bytes.len() as u16).to_le_bytes());
136 buf.extend_from_slice(term_bytes);
137 buf.extend_from_slice(&(postings.len() as u32).to_le_bytes());
138 for &(node_id, term_freq) in postings {
139 buf.extend_from_slice(&node_id.to_le_bytes());
140 buf.extend_from_slice(&term_freq.to_le_bytes());
141 }
142 }
143
144 buf
145 }
146
147 pub fn from_bytes(data: &[u8]) -> Option<Self> {
149 let mut pos = 0;
150
151 if data.len() < 16 {
152 return None;
153 }
154
155 let doc_count = u64::from_le_bytes(data[pos..pos + 8].try_into().ok()?);
156 pos += 8;
157 let avg_doc_length = f32::from_le_bytes(data[pos..pos + 4].try_into().ok()?);
158 pos += 4;
159 let term_count = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
160 pos += 4;
161
162 let mut postings = HashMap::with_capacity(term_count);
163
164 for _ in 0..term_count {
165 if pos + 2 > data.len() {
166 return None;
167 }
168 let term_len = u16::from_le_bytes(data[pos..pos + 2].try_into().ok()?) as usize;
169 pos += 2;
170
171 if pos + term_len > data.len() {
172 return None;
173 }
174 let term = std::str::from_utf8(&data[pos..pos + term_len])
175 .ok()?
176 .to_string();
177 pos += term_len;
178
179 if pos + 4 > data.len() {
180 return None;
181 }
182 let posting_count = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?) as usize;
183 pos += 4;
184
185 let mut posting_list = Vec::with_capacity(posting_count);
186 for _ in 0..posting_count {
187 if pos + 12 > data.len() {
188 return None;
189 }
190 let node_id = u64::from_le_bytes(data[pos..pos + 8].try_into().ok()?);
191 pos += 8;
192 let term_freq = u32::from_le_bytes(data[pos..pos + 4].try_into().ok()?);
193 pos += 4;
194 posting_list.push((node_id, term_freq));
195 }
196
197 postings.insert(term, posting_list);
198 }
199
200 Some(Self {
201 postings,
202 doc_count,
203 avg_doc_length,
204 })
205 }
206}
207
208impl Default for TermIndex {
209 fn default() -> Self {
210 Self::new()
211 }
212}