//! citadeldb_buffer/pool.rs — buffer pool with SIEVE eviction.
1use citadel_core::{
2    BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE,
3};
4use citadel_core::types::PageId;
5use citadel_core::{Error, Result};
6
7use citadel_crypto::page_cipher;
8use citadel_io::file_manager::page_offset;
9use citadel_io::traits::PageIO;
10use citadel_page::page::Page;
11
12use crate::sieve::SieveCache;
13
/// Buffer pool: caches decrypted pages in memory with SIEVE eviction.
///
/// Keyed by physical disk offset (not logical page_id) because under CoW/MVCC
/// the same logical page_id can exist at different disk locations.
///
/// Invariants:
/// - HMAC is verified BEFORE decryption on every page fetch (cache miss).
/// - Dirty pages are PINNED and never evictable until commit.
/// - Transaction size is bounded by buffer pool capacity.
pub struct BufferPool {
    // Plaintext (already-decrypted) pages, keyed by physical disk offset.
    cache: SieveCache<Page>,
    // Maximum number of resident pages; since dirty pages are pinned, this
    // also bounds the size of a single write transaction.
    capacity: usize,
}
27
28impl BufferPool {
29    pub fn new(capacity: usize) -> Self {
30        Self {
31            cache: SieveCache::new(capacity),
32            capacity,
33        }
34    }
35
36    /// Fetch a page by page_id. Reads from cache or disk.
37    ///
38    /// On cache miss: reads from disk, verifies HMAC BEFORE decrypting,
39    /// verifies xxHash64 checksum after decrypting.
40    pub fn fetch(
41        &mut self,
42        io: &dyn PageIO,
43        page_id: PageId,
44        dek: &[u8; DEK_SIZE],
45        mac_key: &[u8; MAC_KEY_SIZE],
46        encryption_epoch: u32,
47    ) -> Result<&Page> {
48        let offset = page_offset(page_id);
49
50        // Cache hit
51        if self.cache.contains(offset) {
52            return Ok(self.cache.get(offset).unwrap());
53        }
54
55        // Cache miss: read from disk
56        let page = self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
57
58        // Insert into cache (may evict)
59        self.cache.insert(offset, page)
60            .map_err(|()| Error::BufferPoolFull)?;
61
62        Ok(self.cache.get(offset).unwrap())
63    }
64
65    /// Fetch a page mutably (for modification during write transaction).
66    pub fn fetch_mut(
67        &mut self,
68        io: &dyn PageIO,
69        page_id: PageId,
70        dek: &[u8; DEK_SIZE],
71        mac_key: &[u8; MAC_KEY_SIZE],
72        encryption_epoch: u32,
73    ) -> Result<&mut Page> {
74        let offset = page_offset(page_id);
75
76        if !self.cache.contains(offset) {
77            let page = self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
78            self.cache.insert(offset, page)
79                .map_err(|()| Error::BufferPoolFull)?;
80        }
81
82        Ok(self.cache.get_mut(offset).unwrap())
83    }
84
85    /// Insert a new page directly into the buffer pool (for newly allocated pages).
86    /// Marks it as dirty immediately.
87    pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
88        let offset = page_offset(page_id);
89
90        // Check if we'd exceed capacity with all dirty pages
91        if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
92            // Try to insert (may evict a clean page)
93            self.cache.insert(offset, page)
94                .map_err(|()| Error::TransactionTooLarge { capacity: self.capacity })?;
95        } else {
96            self.cache.insert(offset, page)
97                .map_err(|()| Error::BufferPoolFull)?;
98        }
99
100        self.cache.set_dirty(offset);
101        Ok(())
102    }
103
104    /// Mark a page as dirty (modified in current write transaction).
105    pub fn mark_dirty(&mut self, page_id: PageId) {
106        let offset = page_offset(page_id);
107        self.cache.set_dirty(offset);
108    }
109
110    /// Flush all dirty pages to disk: encrypt + compute MAC + write.
111    /// Clears dirty flags after successful flush.
112    pub fn flush_dirty(
113        &mut self,
114        io: &dyn PageIO,
115        dek: &[u8; DEK_SIZE],
116        mac_key: &[u8; MAC_KEY_SIZE],
117        encryption_epoch: u32,
118    ) -> Result<()> {
119        // Collect dirty page offsets and data
120        let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self.cache.dirty_entries()
121            .map(|(offset, page)| {
122                let page_id = page.page_id();
123                let body = *page.as_bytes();
124                (offset, page_id, body)
125            })
126            .collect();
127
128        for (offset, page_id, body) in &dirty {
129            let mut encrypted = [0u8; PAGE_SIZE];
130            page_cipher::encrypt_page(dek, mac_key, *page_id, encryption_epoch, body, &mut encrypted);
131            io.write_page(*offset, &encrypted)?;
132        }
133
134        self.cache.clear_all_dirty();
135        Ok(())
136    }
137
138    /// Discard all dirty pages (on transaction abort).
139    /// Removes dirty entries from the cache.
140    pub fn discard_dirty(&mut self) {
141        let dirty_offsets: Vec<u64> = self.cache.dirty_entries()
142            .map(|(offset, _)| offset)
143            .collect();
144
145        for offset in dirty_offsets {
146            self.cache.remove(offset);
147        }
148    }
149
150    fn read_and_decrypt(
151        &self,
152        io: &dyn PageIO,
153        page_id: PageId,
154        offset: u64,
155        dek: &[u8; DEK_SIZE],
156        mac_key: &[u8; MAC_KEY_SIZE],
157        encryption_epoch: u32,
158    ) -> Result<Page> {
159        let mut encrypted = [0u8; PAGE_SIZE];
160        io.read_page(offset, &mut encrypted)?;
161
162        // INVARIANT: Verify HMAC BEFORE decryption
163        let mut body = [0u8; BODY_SIZE];
164        page_cipher::decrypt_page(dek, mac_key, page_id, encryption_epoch, &encrypted, &mut body)?;
165
166        let page = Page::from_bytes(body);
167
168        // Defense-in-depth: verify xxHash64 checksum
169        if !page.verify_checksum() {
170            return Err(Error::ChecksumMismatch(page_id));
171        }
172
173        Ok(page)
174    }
175
176    /// Number of pages currently in the cache.
177    pub fn len(&self) -> usize {
178        self.cache.len()
179    }
180
181    /// Number of dirty pages.
182    pub fn dirty_count(&self) -> usize {
183        self.cache.dirty_count()
184    }
185
186    /// Cache capacity.
187    pub fn capacity(&self) -> usize {
188        self.capacity
189    }
190
191    /// Check if a page is cached.
192    pub fn is_cached(&self, page_id: PageId) -> bool {
193        let offset = page_offset(page_id);
194        self.cache.contains(offset)
195    }
196
197    /// Clear all entries from the cache.
198    pub fn clear(&mut self) {
199        self.cache.clear();
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    /// In-memory `PageIO` backed by an offset-keyed HashMap.
    /// The Mutex provides interior mutability behind `PageIO`'s `&self` methods.
    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        fn new() -> Self {
            Self { pages: std::sync::Mutex::new(std::collections::HashMap::new()) }
        }
    }

    impl PageIO for MockIO {
        // Reading an offset that was never written is a NotFound I/O error,
        // mirroring a read past the end of a real file.
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        // The remaining trait methods are unused by these tests; no-op stubs.
        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> { Ok(()) }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> { Ok(()) }
        fn fsync(&self) -> Result<()> { Ok(()) }
        fn file_size(&self) -> Result<u64> { Ok(0) }
        fn truncate(&self, _size: u64) -> Result<()> { Ok(()) }
    }

    // Deterministic DEK + MAC key derived from a fixed root key.
    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    // Encrypt `page` with the same cipher the pool uses and place it on the
    // mock "disk" at its canonical offset, so a later fetch can find it.
    fn write_encrypted_page(io: &MockIO, page: &Page, dek: &[u8; DEK_SIZE], mac_key: &[u8; MAC_KEY_SIZE], epoch: u32) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(dek, mac_key, page_id, epoch, page.as_bytes(), &mut encrypted);
        io.write_page(offset, &encrypted).unwrap();
    }

    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        // Remove from "disk" — should still be in cache
        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        // Tamper with encrypted data
        let offset = page_offset(PageId(0));
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        // A single flipped ciphertext bit must surface as PageTampered (HMAC
        // failure) rather than as garbage plaintext or a checksum error.
        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        // Insert 3 pages (all dirty from insert_new)
        for i in 0..3 {
            let mut page = Page::new(PageId(i), PageType::Leaf, TxnId(1));
            page.update_checksum();
            pool.insert_new(PageId(i), page).unwrap();
        }

        assert_eq!(pool.dirty_count(), 3);

        // Clear dirty on pages 0 and 2, making them evictable
        // (white-box: pokes the private `cache` field directly).
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        // Insert page 3 — should evict page 0 or 2 (not dirty page 1)
        let mut page3 = Page::new(PageId(3), PageType::Leaf, TxnId(1));
        page3.update_checksum();
        pool.insert_new(PageId(3), page3).unwrap();
        // Dirty page 1 must still be in the cache
        assert!(pool.is_cached(PageId(1)));
    }

    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(5), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(5), page).unwrap();

        assert_eq!(pool.dirty_count(), 1);

        // Flush must clear the dirty flag...
        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        // ...and land the (encrypted) page at its canonical disk offset.
        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(1), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(1), page).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }
}
375}