Skip to main content

amaters_core/storage/
encrypted_index.rs

1//! Encrypted index for ciphertext equality lookups.
2//!
3//! Maps SipHash-1-3 (128-bit) of serialized ciphertext bytes to a list of record IDs.
4
5use crate::error::{AmateRSError, ErrorContext, Result};
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use siphasher::sip128::{Hash128, Hasher128, SipHasher13};
9use std::hash::Hasher;
10
11/// Inverted index mapping ciphertext hash → Vec<record_id>.
12///
13/// Thread-safe via `DashMap`.
14pub struct EncryptedIndex {
15    name: String,
16    collection: String,
17    field: String,
18    /// Maps SipHash-1-3-128 of ciphertext bytes → sorted record IDs.
19    index: DashMap<u128, Vec<u64>>,
20}
21
22/// Serializable snapshot of `EncryptedIndex` for persistence.
23#[derive(Debug, Serialize, Deserialize)]
24struct EncryptedIndexSnapshot {
25    name: String,
26    collection: String,
27    field: String,
28    entries: Vec<(u128, Vec<u64>)>,
29}
30
31impl EncryptedIndex {
32    /// Create a new empty index.
33    ///
34    /// # Example
35    ///
36    /// ```
37    /// use amaters_core::storage::EncryptedIndex;
38    ///
39    /// let index = EncryptedIndex::new("age_index", "users", "age");
40    /// // Insert a record: ciphertext bytes representing an encrypted age value
41    /// let ciphertext_bytes = b"fake_ciphertext_for_value_30";
42    /// index.insert(ciphertext_bytes, 42); // record_id = 42
43    ///
44    /// // Look up candidates by the same ciphertext bytes
45    /// let candidates = index.lookup_candidates(ciphertext_bytes);
46    /// assert_eq!(candidates, vec![42]);
47    /// ```
48    pub fn new(
49        name: impl Into<String>,
50        collection: impl Into<String>,
51        field: impl Into<String>,
52    ) -> Self {
53        Self {
54            name: name.into(),
55            collection: collection.into(),
56            field: field.into(),
57            index: DashMap::new(),
58        }
59    }
60
61    /// Name of this index.
62    pub fn name(&self) -> &str {
63        &self.name
64    }
65
66    /// Collection this index belongs to.
67    pub fn collection(&self) -> &str {
68        &self.collection
69    }
70
71    /// Field this index covers.
72    pub fn field(&self) -> &str {
73        &self.field
74    }
75
76    /// Hash `ciphertext_bytes` using SipHash-1-3 (128-bit).
77    fn hash_bytes(ciphertext_bytes: &[u8]) -> u128 {
78        let mut hasher = SipHasher13::new();
79        hasher.write(ciphertext_bytes);
80        let Hash128 { h1, h2 } = hasher.finish128();
81        ((h1 as u128) << 64) | (h2 as u128)
82    }
83
84    /// Insert a record into the index.
85    pub fn insert(&self, ciphertext_bytes: &[u8], record_id: u64) {
86        let key = Self::hash_bytes(ciphertext_bytes);
87        self.index.entry(key).or_default().push(record_id);
88    }
89
90    /// Remove a specific record from the index.
91    pub fn remove(&self, ciphertext_bytes: &[u8], record_id: u64) {
92        let key = Self::hash_bytes(ciphertext_bytes);
93        let mut remove_entry = false;
94        if let Some(mut entry) = self.index.get_mut(&key) {
95            if let Some(pos) = entry.iter().position(|&id| id == record_id) {
96                entry.swap_remove(pos);
97            }
98            remove_entry = entry.is_empty();
99        }
100        if remove_entry {
101            self.index.remove(&key);
102        }
103    }
104
105    /// Look up candidate record IDs by ciphertext hash.
106    pub fn lookup_candidates(&self, ciphertext_bytes: &[u8]) -> Vec<u64> {
107        let key = Self::hash_bytes(ciphertext_bytes);
108        self.index
109            .get(&key)
110            .map(|entry| entry.value().clone())
111            .unwrap_or_default()
112    }
113
114    /// Number of distinct hash buckets currently in the index.
115    pub fn len(&self) -> usize {
116        self.index.len()
117    }
118
119    /// Whether the index has no entries.
120    pub fn is_empty(&self) -> bool {
121        self.index.is_empty()
122    }
123
124    /// Total number of record IDs stored across all buckets.
125    pub fn total_records(&self) -> usize {
126        self.index.iter().map(|e| e.value().len()).sum()
127    }
128
129    /// Serialize the index to bytes using oxicode.
130    pub fn serialize(&self) -> Result<Vec<u8>> {
131        let entries: Vec<(u128, Vec<u64>)> = self
132            .index
133            .iter()
134            .map(|e| (*e.key(), e.value().clone()))
135            .collect();
136        let snapshot = EncryptedIndexSnapshot {
137            name: self.name.clone(),
138            collection: self.collection.clone(),
139            field: self.field.clone(),
140            entries,
141        };
142        oxicode::serde::encode_serde(&snapshot).map_err(|e| {
143            AmateRSError::SerializationError(ErrorContext::new(format!(
144                "EncryptedIndex serialize failed: {e}"
145            )))
146        })
147    }
148
149    /// Deserialize an index from bytes produced by `serialize`.
150    pub fn deserialize(data: &[u8]) -> Result<Self> {
151        let snapshot: EncryptedIndexSnapshot = oxicode::serde::decode_serde(data).map_err(|e| {
152            AmateRSError::SerializationError(ErrorContext::new(format!(
153                "EncryptedIndex deserialize failed: {e}"
154            )))
155        })?;
156        let index = DashMap::new();
157        for (hash_key, record_ids) in snapshot.entries {
158            index.insert(hash_key, record_ids);
159        }
160        Ok(Self {
161            name: snapshot.name,
162            collection: snapshot.collection,
163            field: snapshot.field,
164            index,
165        })
166    }
167}
168
169impl std::fmt::Debug for EncryptedIndex {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        f.debug_struct("EncryptedIndex")
172            .field("name", &self.name)
173            .field("collection", &self.collection)
174            .field("field", &self.field)
175            .field("bucket_count", &self.index.len())
176            .finish()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_and_lookup() {
        let idx = EncryptedIndex::new("test", "col", "field");
        let bytes = b"same_ciphertext";
        idx.insert(bytes, 1);
        idx.insert(bytes, 2);
        idx.insert(bytes, 3);
        let mut candidates = idx.lookup_candidates(bytes);
        candidates.sort_unstable();
        assert_eq!(candidates, vec![1, 2, 3]);
    }

    #[test]
    fn test_insert_different_keys() {
        let idx = EncryptedIndex::new("test", "col", "field");
        idx.insert(b"key_a", 10);
        idx.insert(b"key_b", 20);
        let mut a_cands = idx.lookup_candidates(b"key_a");
        let mut b_cands = idx.lookup_candidates(b"key_b");
        a_cands.sort_unstable();
        b_cands.sort_unstable();
        assert_eq!(a_cands, vec![10]);
        assert_eq!(b_cands, vec![20]);
        assert_eq!(idx.len(), 2);
    }

    #[test]
    fn test_lookup_unknown_returns_empty() {
        let idx = EncryptedIndex::new("test", "col", "field");
        assert!(idx.lookup_candidates(b"never_inserted").is_empty());
        idx.insert(b"present", 5);
        assert!(idx.lookup_candidates(b"never_inserted").is_empty());
    }

    #[test]
    fn test_remove_record() {
        let idx = EncryptedIndex::new("test", "col", "field");
        idx.insert(b"ct_bytes", 100);
        idx.insert(b"ct_bytes", 101);
        idx.remove(b"ct_bytes", 100);
        let candidates = idx.lookup_candidates(b"ct_bytes");
        assert!(!candidates.contains(&100));
        assert!(candidates.contains(&101));
    }

    #[test]
    fn test_remove_last_record_drops_bucket() {
        let idx = EncryptedIndex::new("test", "col", "field");
        idx.insert(b"only", 7);
        assert_eq!(idx.len(), 1);
        idx.remove(b"only", 7);
        // The emptied bucket must be dropped, not left behind as an empty Vec.
        assert!(idx.is_empty());
        assert_eq!(idx.total_records(), 0);
        assert!(idx.lookup_candidates(b"only").is_empty());
    }

    #[test]
    fn test_remove_missing_is_noop() {
        let idx = EncryptedIndex::new("test", "col", "field");
        idx.insert(b"ct", 1);
        idx.remove(b"ct", 999);
        idx.remove(b"absent", 1);
        assert_eq!(idx.lookup_candidates(b"ct"), vec![1]);
        assert_eq!(idx.len(), 1);
        assert_eq!(idx.total_records(), 1);
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let idx = EncryptedIndex::new("persist", "docs", "content");
        idx.insert(b"cipher_a", 1);
        idx.insert(b"cipher_a", 2);
        idx.insert(b"cipher_b", 3);

        let bytes = idx.serialize().expect("serialize ok");
        let restored = EncryptedIndex::deserialize(&bytes).expect("deserialize ok");

        assert_eq!(restored.name(), "persist");
        assert_eq!(restored.collection(), "docs");
        assert_eq!(restored.field(), "content");
        assert_eq!(restored.len(), 2);
        assert_eq!(restored.total_records(), 3);

        let mut a_cands = restored.lookup_candidates(b"cipher_a");
        a_cands.sort_unstable();
        assert_eq!(a_cands, vec![1, 2]);
        assert_eq!(restored.lookup_candidates(b"cipher_b"), vec![3]);
    }

    #[test]
    fn test_serialize_empty_roundtrip() {
        let idx = EncryptedIndex::new("empty", "col", "field");
        let bytes = idx.serialize().expect("serialize ok");
        let restored = EncryptedIndex::deserialize(&bytes).expect("deserialize ok");
        assert_eq!(restored.name(), "empty");
        assert!(restored.is_empty());
    }
}