// citadel_buffer/pool.rs

use citadel_core::types::PageId;
use citadel_core::{Error, Result};
use citadel_core::{BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE};
use citadel_crypto::page_cipher;
use citadel_io::file_manager::page_offset;
use citadel_io::traits::PageIO;
use citadel_page::page::Page;

use crate::sieve::SieveCache;

/// Buffer pool: caches decrypted pages in memory with SIEVE eviction.
///
/// Keyed by physical disk offset (not logical page_id) because under CoW/MVCC
/// the same logical page_id can exist at different disk locations.
///
/// Invariants:
/// - HMAC is verified BEFORE decryption on every page fetch (cache miss).
/// - Dirty pages are PINNED and never evictable until commit.
/// - Transaction size is bounded by buffer pool capacity.
pub struct BufferPool {
    // Offset-keyed SIEVE cache of decrypted, checksum-verified pages.
    cache: SieveCache<Page>,
    // Maximum number of resident pages; also bounds write-transaction size
    // (see `insert_new`, which reports TransactionTooLarge with this value).
    capacity: usize,
}

26impl BufferPool {
27    pub fn new(capacity: usize) -> Self {
28        Self {
29            cache: SieveCache::new(capacity),
30            capacity,
31        }
32    }
33
34    /// Fetch a page by page_id. Reads from cache or disk.
35    ///
36    /// On cache miss: reads from disk, verifies HMAC BEFORE decrypting,
37    /// verifies xxHash64 checksum after decrypting.
38    pub fn fetch(
39        &mut self,
40        io: &dyn PageIO,
41        page_id: PageId,
42        dek: &[u8; DEK_SIZE],
43        mac_key: &[u8; MAC_KEY_SIZE],
44        encryption_epoch: u32,
45    ) -> Result<&Page> {
46        let offset = page_offset(page_id);
47
48        // Cache hit
49        if self.cache.contains(offset) {
50            return Ok(self.cache.get(offset).unwrap());
51        }
52
53        // Cache miss: read from disk
54        let page = self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
55
56        // Insert into cache (may evict)
57        self.cache
58            .insert(offset, page)
59            .map_err(|()| Error::BufferPoolFull)?;
60
61        Ok(self.cache.get(offset).unwrap())
62    }
63
64    /// Fetch a page mutably (for modification during write transaction).
65    pub fn fetch_mut(
66        &mut self,
67        io: &dyn PageIO,
68        page_id: PageId,
69        dek: &[u8; DEK_SIZE],
70        mac_key: &[u8; MAC_KEY_SIZE],
71        encryption_epoch: u32,
72    ) -> Result<&mut Page> {
73        let offset = page_offset(page_id);
74
75        if !self.cache.contains(offset) {
76            let page =
77                self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
78            self.cache
79                .insert(offset, page)
80                .map_err(|()| Error::BufferPoolFull)?;
81        }
82
83        Ok(self.cache.get_mut(offset).unwrap())
84    }
85
86    /// Insert a new page directly into the buffer pool (for newly allocated pages).
87    /// Marks it as dirty immediately.
88    pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
89        let offset = page_offset(page_id);
90
91        // Check if we'd exceed capacity with all dirty pages
92        if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
93            // Try to insert (may evict a clean page)
94            self.cache
95                .insert(offset, page)
96                .map_err(|()| Error::TransactionTooLarge {
97                    capacity: self.capacity,
98                })?;
99        } else {
100            self.cache
101                .insert(offset, page)
102                .map_err(|()| Error::BufferPoolFull)?;
103        }
104
105        self.cache.set_dirty(offset);
106        Ok(())
107    }
108
109    /// Mark a page as dirty (modified in current write transaction).
110    pub fn mark_dirty(&mut self, page_id: PageId) {
111        let offset = page_offset(page_id);
112        self.cache.set_dirty(offset);
113    }
114
115    /// Flush all dirty pages to disk: encrypt + compute MAC + write.
116    /// Clears dirty flags after successful flush.
117    pub fn flush_dirty(
118        &mut self,
119        io: &dyn PageIO,
120        dek: &[u8; DEK_SIZE],
121        mac_key: &[u8; MAC_KEY_SIZE],
122        encryption_epoch: u32,
123    ) -> Result<()> {
124        // Collect dirty page offsets and data
125        let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self
126            .cache
127            .dirty_entries()
128            .map(|(offset, page)| {
129                let page_id = page.page_id();
130                let body = *page.as_bytes();
131                (offset, page_id, body)
132            })
133            .collect();
134
135        for (offset, page_id, body) in &dirty {
136            let mut encrypted = [0u8; PAGE_SIZE];
137            page_cipher::encrypt_page(
138                dek,
139                mac_key,
140                *page_id,
141                encryption_epoch,
142                body,
143                &mut encrypted,
144            );
145            io.write_page(*offset, &encrypted)?;
146        }
147
148        self.cache.clear_all_dirty();
149        Ok(())
150    }
151
152    /// Discard all dirty pages (on transaction abort).
153    /// Removes dirty entries from the cache.
154    pub fn discard_dirty(&mut self) {
155        let dirty_offsets: Vec<u64> = self
156            .cache
157            .dirty_entries()
158            .map(|(offset, _)| offset)
159            .collect();
160
161        for offset in dirty_offsets {
162            self.cache.remove(offset);
163        }
164    }
165
166    fn read_and_decrypt(
167        &self,
168        io: &dyn PageIO,
169        page_id: PageId,
170        offset: u64,
171        dek: &[u8; DEK_SIZE],
172        mac_key: &[u8; MAC_KEY_SIZE],
173        encryption_epoch: u32,
174    ) -> Result<Page> {
175        let mut encrypted = [0u8; PAGE_SIZE];
176        io.read_page(offset, &mut encrypted)?;
177
178        // INVARIANT: Verify HMAC BEFORE decryption
179        let mut body = [0u8; BODY_SIZE];
180        page_cipher::decrypt_page(
181            dek,
182            mac_key,
183            page_id,
184            encryption_epoch,
185            &encrypted,
186            &mut body,
187        )?;
188
189        let page = Page::from_bytes(body);
190
191        // Defense-in-depth: verify xxHash64 checksum
192        if !page.verify_checksum() {
193            return Err(Error::ChecksumMismatch(page_id));
194        }
195
196        Ok(page)
197    }
198
199    /// Number of pages currently in the cache.
200    pub fn len(&self) -> usize {
201        self.cache.len()
202    }
203
204    /// Whether the cache is empty.
205    pub fn is_empty(&self) -> bool {
206        self.cache.is_empty()
207    }
208
209    /// Number of dirty pages.
210    pub fn dirty_count(&self) -> usize {
211        self.cache.dirty_count()
212    }
213
214    /// Cache capacity.
215    pub fn capacity(&self) -> usize {
216        self.capacity
217    }
218
219    /// Check if a page is cached.
220    pub fn is_cached(&self, page_id: PageId) -> bool {
221        let offset = page_offset(page_id);
222        self.cache.contains(offset)
223    }
224
225    /// Remove a single page from the cache (e.g. after CoW overwrites it).
226    pub fn invalidate(&mut self, page_id: PageId) {
227        let offset = page_offset(page_id);
228        self.cache.remove(offset);
229    }
230
231    /// Clear all entries from the cache.
232    pub fn clear(&mut self) {
233        self.cache.clear();
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    /// In-memory `PageIO` backed by a HashMap of offset -> raw page bytes.
    /// Mutex-wrapped because the trait takes `&self` even for writes.
    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        fn new() -> Self {
            Self {
                pages: std::sync::Mutex::new(std::collections::HashMap::new()),
            }
        }
    }

    impl PageIO for MockIO {
        // Copies the stored page into `buf`; NotFound if never written.
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        // The remaining trait methods are unused by these tests: no-ops.
        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> {
            Ok(())
        }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> {
            Ok(())
        }
        fn fsync(&self) -> Result<()> {
            Ok(())
        }
        fn file_size(&self) -> Result<u64> {
            Ok(0)
        }
        fn truncate(&self, _size: u64) -> Result<()> {
            Ok(())
        }
    }

    /// Deterministic DEK + MAC key derived from a fixed root key.
    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    /// Encrypt `page` exactly as the pool's flush path would and store it in
    /// the mock "disk" at its canonical offset.
    fn write_encrypted_page(
        io: &MockIO,
        page: &Page,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        epoch: u32,
    ) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(
            dek,
            mac_key,
            page_id,
            epoch,
            page.as_bytes(),
            &mut encrypted,
        );
        io.write_page(offset, &encrypted).unwrap();
    }

    // Cache-miss path: fetch reads from disk and the page becomes resident.
    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    // Cache-hit path: second fetch must not touch the disk at all.
    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        // Remove from "disk" — should still be in cache
        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    // HMAC verification: a single flipped ciphertext bit must be rejected
    // before any decrypted data reaches the cache.
    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        // Tamper with encrypted data
        let offset = page_offset(PageId(0));
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    // Pinning invariant: eviction may only remove clean pages.
    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        // Insert 3 pages (all dirty from insert_new)
        for i in 0..3 {
            let mut page = Page::new(PageId(i), PageType::Leaf, TxnId(1));
            page.update_checksum();
            pool.insert_new(PageId(i), page).unwrap();
        }

        assert_eq!(pool.dirty_count(), 3);

        // Clear dirty on pages 0 and 2, making them evictable
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        // Insert page 3 — should evict page 0 or 2 (not dirty page 1)
        let mut page3 = Page::new(PageId(3), PageType::Leaf, TxnId(1));
        page3.update_checksum();
        pool.insert_new(PageId(3), page3).unwrap();
        // Dirty page 1 must still be in the cache
        assert!(pool.is_cached(PageId(1)));
    }

    // Commit path: flush clears dirty flags and the page lands on "disk".
    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(5), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(5), page).unwrap();

        assert_eq!(pool.dirty_count(), 1);

        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        // Verify we can read it back from disk
        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    // Abort path: discarding dirty pages removes them from the cache entirely.
    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(1), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(1), page).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }
}