amaters_core/storage/
encrypted_index.rs1use crate::error::{AmateRSError, ErrorContext, Result};
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use siphasher::sip128::{Hash128, Hasher128, SipHasher13};
9use std::hash::Hasher;
10
11pub struct EncryptedIndex {
15 name: String,
16 collection: String,
17 field: String,
18 index: DashMap<u128, Vec<u64>>,
20}
21
22#[derive(Debug, Serialize, Deserialize)]
24struct EncryptedIndexSnapshot {
25 name: String,
26 collection: String,
27 field: String,
28 entries: Vec<(u128, Vec<u64>)>,
29}
30
31impl EncryptedIndex {
32 pub fn new(
49 name: impl Into<String>,
50 collection: impl Into<String>,
51 field: impl Into<String>,
52 ) -> Self {
53 Self {
54 name: name.into(),
55 collection: collection.into(),
56 field: field.into(),
57 index: DashMap::new(),
58 }
59 }
60
61 pub fn name(&self) -> &str {
63 &self.name
64 }
65
66 pub fn collection(&self) -> &str {
68 &self.collection
69 }
70
71 pub fn field(&self) -> &str {
73 &self.field
74 }
75
76 fn hash_bytes(ciphertext_bytes: &[u8]) -> u128 {
78 let mut hasher = SipHasher13::new();
79 hasher.write(ciphertext_bytes);
80 let Hash128 { h1, h2 } = hasher.finish128();
81 ((h1 as u128) << 64) | (h2 as u128)
82 }
83
84 pub fn insert(&self, ciphertext_bytes: &[u8], record_id: u64) {
86 let key = Self::hash_bytes(ciphertext_bytes);
87 self.index.entry(key).or_default().push(record_id);
88 }
89
90 pub fn remove(&self, ciphertext_bytes: &[u8], record_id: u64) {
92 let key = Self::hash_bytes(ciphertext_bytes);
93 let mut remove_entry = false;
94 if let Some(mut entry) = self.index.get_mut(&key) {
95 if let Some(pos) = entry.iter().position(|&id| id == record_id) {
96 entry.swap_remove(pos);
97 }
98 remove_entry = entry.is_empty();
99 }
100 if remove_entry {
101 self.index.remove(&key);
102 }
103 }
104
105 pub fn lookup_candidates(&self, ciphertext_bytes: &[u8]) -> Vec<u64> {
107 let key = Self::hash_bytes(ciphertext_bytes);
108 self.index
109 .get(&key)
110 .map(|entry| entry.value().clone())
111 .unwrap_or_default()
112 }
113
114 pub fn len(&self) -> usize {
116 self.index.len()
117 }
118
119 pub fn is_empty(&self) -> bool {
121 self.index.is_empty()
122 }
123
124 pub fn total_records(&self) -> usize {
126 self.index.iter().map(|e| e.value().len()).sum()
127 }
128
129 pub fn serialize(&self) -> Result<Vec<u8>> {
131 let entries: Vec<(u128, Vec<u64>)> = self
132 .index
133 .iter()
134 .map(|e| (*e.key(), e.value().clone()))
135 .collect();
136 let snapshot = EncryptedIndexSnapshot {
137 name: self.name.clone(),
138 collection: self.collection.clone(),
139 field: self.field.clone(),
140 entries,
141 };
142 oxicode::serde::encode_serde(&snapshot).map_err(|e| {
143 AmateRSError::SerializationError(ErrorContext::new(format!(
144 "EncryptedIndex serialize failed: {e}"
145 )))
146 })
147 }
148
149 pub fn deserialize(data: &[u8]) -> Result<Self> {
151 let snapshot: EncryptedIndexSnapshot = oxicode::serde::decode_serde(data).map_err(|e| {
152 AmateRSError::SerializationError(ErrorContext::new(format!(
153 "EncryptedIndex deserialize failed: {e}"
154 )))
155 })?;
156 let index = DashMap::new();
157 for (hash_key, record_ids) in snapshot.entries {
158 index.insert(hash_key, record_ids);
159 }
160 Ok(Self {
161 name: snapshot.name,
162 collection: snapshot.collection,
163 field: snapshot.field,
164 index,
165 })
166 }
167}
168
169impl std::fmt::Debug for EncryptedIndex {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 f.debug_struct("EncryptedIndex")
172 .field("name", &self.name)
173 .field("collection", &self.collection)
174 .field("field", &self.field)
175 .field("bucket_count", &self.index.len())
176 .finish()
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `ids` in ascending order so assertions do not depend on the
    /// order in which a bucket happens to store them.
    fn sorted(mut ids: Vec<u64>) -> Vec<u64> {
        ids.sort_unstable();
        ids
    }

    /// Ids inserted under the same ciphertext all come back from a lookup.
    #[test]
    fn test_insert_and_lookup() {
        let index = EncryptedIndex::new("test", "col", "field");
        let ciphertext = b"same_ciphertext";
        for record_id in 1..=3 {
            index.insert(ciphertext, record_id);
        }
        assert_eq!(sorted(index.lookup_candidates(ciphertext)), vec![1, 2, 3]);
    }

    /// Different ciphertexts land in separate buckets.
    #[test]
    fn test_insert_different_keys() {
        let index = EncryptedIndex::new("test", "col", "field");
        index.insert(b"key_a", 10);
        index.insert(b"key_b", 20);

        let from_a = sorted(index.lookup_candidates(b"key_a"));
        let from_b = sorted(index.lookup_candidates(b"key_b"));

        assert_eq!(from_a, vec![10]);
        assert_eq!(from_b, vec![20]);
        assert_eq!(index.len(), 2);
    }

    /// Removing one id leaves the other ids of the bucket in place.
    #[test]
    fn test_remove_record() {
        let index = EncryptedIndex::new("test", "col", "field");
        index.insert(b"ct_bytes", 100);
        index.insert(b"ct_bytes", 101);

        index.remove(b"ct_bytes", 100);

        let remaining = index.lookup_candidates(b"ct_bytes");
        assert!(!remaining.contains(&100));
        assert!(remaining.contains(&101));
    }

    /// Metadata and buckets survive a serialize/deserialize round trip.
    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let original = EncryptedIndex::new("persist", "docs", "content");
        original.insert(b"cipher_a", 1);
        original.insert(b"cipher_a", 2);
        original.insert(b"cipher_b", 3);

        let encoded = original.serialize().expect("serialize ok");
        let restored = EncryptedIndex::deserialize(&encoded).expect("deserialize ok");

        assert_eq!(restored.name(), "persist");
        assert_eq!(restored.collection(), "docs");
        assert_eq!(restored.field(), "content");

        assert_eq!(sorted(restored.lookup_candidates(b"cipher_a")), vec![1, 2]);
        assert_eq!(restored.lookup_candidates(b"cipher_b"), vec![3]);
    }
}