//! citadel_buffer/pool.rs — buffer pool caching decrypted pages with SIEVE eviction.
1use std::sync::Arc;
2
3use citadel_core::types::PageId;
4use citadel_core::{Error, Result};
5use citadel_core::{BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE};
6
7use citadel_crypto::page_cipher;
8use citadel_io::file_manager::page_offset;
9use citadel_io::traits::PageIO;
10use citadel_page::page::Page;
11
12use crate::sieve::SieveCache;
13
14pub fn read_and_decrypt(
15    io: &dyn PageIO,
16    page_id: PageId,
17    offset: u64,
18    dek: &[u8; DEK_SIZE],
19    mac_key: &[u8; MAC_KEY_SIZE],
20    encryption_epoch: u32,
21) -> Result<Page> {
22    let mut encrypted = [0u8; PAGE_SIZE];
23    io.read_page(offset, &mut encrypted)?;
24
25    let mut body = [0u8; BODY_SIZE];
26    page_cipher::decrypt_page(
27        dek,
28        mac_key,
29        page_id,
30        encryption_epoch,
31        &encrypted,
32        &mut body,
33    )?;
34
35    let page = Page::from_bytes(body);
36
37    if !page.verify_checksum() {
38        return Err(Error::ChecksumMismatch(page_id));
39    }
40
41    Ok(page)
42}
43
/// Buffer pool: caches decrypted pages in memory with SIEVE eviction.
///
/// Keyed by physical disk offset (not logical page_id) because under CoW/MVCC
/// the same logical page_id can exist at different disk locations.
///
/// Invariants:
/// - HMAC is verified BEFORE decryption on every page fetch (cache miss).
/// - Dirty pages are PINNED and never evictable until commit.
/// - Transaction size is bounded by buffer pool capacity.
pub struct BufferPool {
    // Offset-keyed SIEVE cache of decrypted pages. `Arc` lets readers take
    // cheap clones (`get_cached`) while writers get copy-on-write access
    // through `Arc::make_mut` (`fetch_mut`).
    cache: SieveCache<Arc<Page>>,
    // Maximum number of resident pages; because dirty pages are pinned,
    // this also bounds a transaction's dirty working set.
    capacity: usize,
}
57
58impl BufferPool {
59    pub fn new(capacity: usize) -> Self {
60        Self {
61            cache: SieveCache::new(capacity),
62            capacity,
63        }
64    }
65
66    /// Fetch a page by page_id. Reads from cache or disk.
67    ///
68    /// On cache miss: reads from disk, verifies HMAC BEFORE decrypting,
69    /// verifies xxHash64 checksum after decrypting.
70    pub fn fetch(
71        &mut self,
72        io: &dyn PageIO,
73        page_id: PageId,
74        dek: &[u8; DEK_SIZE],
75        mac_key: &[u8; MAC_KEY_SIZE],
76        encryption_epoch: u32,
77    ) -> Result<&Page> {
78        let offset = page_offset(page_id);
79
80        if self.cache.contains(offset) {
81            return Ok(self.cache.get(offset).unwrap());
82        }
83
84        let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
85        self.cache
86            .insert(offset, Arc::new(page))
87            .map_err(|()| Error::BufferPoolFull)?;
88
89        Ok(self.cache.get(offset).unwrap())
90    }
91
92    pub fn fetch_mut(
93        &mut self,
94        io: &dyn PageIO,
95        page_id: PageId,
96        dek: &[u8; DEK_SIZE],
97        mac_key: &[u8; MAC_KEY_SIZE],
98        encryption_epoch: u32,
99    ) -> Result<&mut Page> {
100        let offset = page_offset(page_id);
101
102        if !self.cache.contains(offset) {
103            let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
104            self.cache
105                .insert(offset, Arc::new(page))
106                .map_err(|()| Error::BufferPoolFull)?;
107        }
108
109        Ok(Arc::make_mut(self.cache.get_mut(offset).unwrap()))
110    }
111
112    /// Insert a newly allocated page. Marks it dirty immediately.
113    pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
114        let offset = page_offset(page_id);
115
116        if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
117            self.cache
118                .insert(offset, Arc::new(page))
119                .map_err(|()| Error::TransactionTooLarge {
120                    capacity: self.capacity,
121                })?;
122        } else {
123            self.cache
124                .insert(offset, Arc::new(page))
125                .map_err(|()| Error::BufferPoolFull)?;
126        }
127
128        self.cache.set_dirty(offset);
129        Ok(())
130    }
131
132    pub fn mark_dirty(&mut self, page_id: PageId) {
133        let offset = page_offset(page_id);
134        self.cache.set_dirty(offset);
135    }
136
137    /// Flush all dirty pages to disk: encrypt + compute MAC + write.
138    /// Clears dirty flags after successful flush.
139    pub fn flush_dirty(
140        &mut self,
141        io: &dyn PageIO,
142        dek: &[u8; DEK_SIZE],
143        mac_key: &[u8; MAC_KEY_SIZE],
144        encryption_epoch: u32,
145    ) -> Result<()> {
146        let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self
147            .cache
148            .dirty_entries()
149            .map(|(offset, arc)| {
150                let page_id = arc.page_id();
151                let body = *arc.as_bytes();
152                (offset, page_id, body)
153            })
154            .collect();
155
156        for (offset, page_id, body) in &dirty {
157            let mut encrypted = [0u8; PAGE_SIZE];
158            page_cipher::encrypt_page(
159                dek,
160                mac_key,
161                *page_id,
162                encryption_epoch,
163                body,
164                &mut encrypted,
165            );
166            io.write_page(*offset, &encrypted)?;
167        }
168
169        self.cache.clear_all_dirty();
170        Ok(())
171    }
172
173    /// Discard all dirty pages (on transaction abort).
174    /// Removes dirty entries from the cache.
175    pub fn discard_dirty(&mut self) {
176        let dirty_offsets: Vec<u64> = self
177            .cache
178            .dirty_entries()
179            .map(|(offset, _)| offset)
180            .collect();
181
182        for offset in dirty_offsets {
183            self.cache.remove(offset);
184        }
185    }
186
187    pub fn get_cached(&mut self, page_id: PageId) -> Option<Arc<Page>> {
188        let offset = page_offset(page_id);
189        self.cache.get(offset).map(Arc::clone)
190    }
191
192    pub fn insert_if_absent(&mut self, page_id: PageId, page: Arc<Page>) {
193        let offset = page_offset(page_id);
194        if !self.cache.contains(offset) {
195            let _ = self.cache.insert(offset, page);
196        }
197    }
198
199    pub fn len(&self) -> usize {
200        self.cache.len()
201    }
202
203    pub fn is_empty(&self) -> bool {
204        self.cache.is_empty()
205    }
206
207    pub fn dirty_count(&self) -> usize {
208        self.cache.dirty_count()
209    }
210
211    pub fn capacity(&self) -> usize {
212        self.capacity
213    }
214
215    pub fn is_cached(&self, page_id: PageId) -> bool {
216        let offset = page_offset(page_id);
217        self.cache.contains(offset)
218    }
219
220    pub fn invalidate(&mut self, page_id: PageId) {
221        let offset = page_offset(page_id);
222        self.cache.remove(offset);
223    }
224
225    pub fn clear(&mut self) {
226        self.cache.clear();
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    /// In-memory `PageIO` stand-in: pages live in a `HashMap` keyed by
    /// disk offset. The `Mutex` provides interior mutability because the
    /// `PageIO` trait methods take `&self`.
    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        fn new() -> Self {
            Self {
                pages: std::sync::Mutex::new(std::collections::HashMap::new()),
            }
        }
    }

    impl PageIO for MockIO {
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                // An unwritten offset surfaces as an I/O NotFound error.
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        // The remaining trait methods are inert stubs: these tests only
        // exercise whole-page reads and writes.
        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> {
            Ok(())
        }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> {
            Ok(())
        }
        fn fsync(&self) -> Result<()> {
            Ok(())
        }
        fn file_size(&self) -> Result<u64> {
            Ok(0)
        }
        fn truncate(&self, _size: u64) -> Result<()> {
            Ok(())
        }
    }

    /// Derive a deterministic (DEK, MAC key) pair from a fixed test REK.
    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    /// Encrypt `page` with the given keys/epoch and store the ciphertext
    /// in the mock backend at the page's canonical offset.
    fn write_encrypted_page(
        io: &MockIO,
        page: &Page,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        epoch: u32,
    ) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(
            dek,
            mac_key,
            page_id,
            epoch,
            page.as_bytes(),
            &mut encrypted,
        );
        io.write_page(offset, &encrypted).unwrap();
    }

    // A cache miss reads from "disk" and leaves the page resident.
    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    // A second fetch must be served from the cache, not the backend.
    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        // Remove from "disk" - should still be in cache
        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    // Flipping one ciphertext bit must fail MAC verification on fetch.
    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        // Tamper with encrypted data
        let offset = page_offset(PageId(0));
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    // Dirty (pinned) pages must never be evicted to admit new entries.
    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        // Insert 3 pages (all dirty from insert_new)
        for i in 0..3 {
            let mut page = Page::new(PageId(i), PageType::Leaf, TxnId(1));
            page.update_checksum();
            pool.insert_new(PageId(i), page).unwrap();
        }

        assert_eq!(pool.dirty_count(), 3);

        // Clear dirty on pages 0 and 2, making them evictable
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        // Insert page 3 - should evict page 0 or 2 (not dirty page 1)
        let mut page3 = Page::new(PageId(3), PageType::Leaf, TxnId(1));
        page3.update_checksum();
        pool.insert_new(PageId(3), page3).unwrap();
        // Dirty page 1 must still be in the cache
        assert!(pool.is_cached(PageId(1)));
    }

    // flush_dirty encrypts + writes every dirty page and clears the flags.
    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(5), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(5), page).unwrap();

        assert_eq!(pool.dirty_count(), 1);

        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        // Verify we can read it back from disk
        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    // Abort path: discard_dirty drops dirty pages from the cache entirely.
    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(1), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(1), page).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }
}