// citadel_buffer/pool.rs

use std::sync::Arc;

use citadel_core::types::PageId;
use citadel_core::{Error, Result};
use citadel_core::{BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE};

use citadel_crypto::page_cipher;
use citadel_io::file_manager::page_offset;
use citadel_io::traits::PageIO;
use citadel_page::page::Page;

use crate::sieve::SieveCache;

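/// Read one encrypted page at `offset`, verify its MAC, decrypt it, and
/// check the plaintext xxHash64 checksum.
///
/// MAC verification happens inside `page_cipher::decrypt_page`, before any
/// plaintext is accepted (see the `BufferPool` invariants below); the
/// checksum then guards the decrypted body.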
pub fn read_and_decrypt(
    io: &dyn PageIO,
    page_id: PageId,
    offset: u64,
    dek: &[u8; DEK_SIZE],
    mac_key: &[u8; MAC_KEY_SIZE],
    encryption_epoch: u32,
) -> Result<Page> {
    let mut encrypted = [0u8; PAGE_SIZE];
    io.read_page(offset, &mut encrypted)?;

    let mut body = [0u8; BODY_SIZE];
    page_cipher::decrypt_page(
        dek,
        mac_key,
        page_id,
        encryption_epoch,
        &encrypted,
        &mut body,
    )?;

    let page = Page::from_bytes(body);

    if !page.verify_checksum() {
        return Err(Error::ChecksumMismatch(page_id));
    }

    Ok(page)
}

/// Buffer pool: caches decrypted pages in memory with SIEVE eviction.
///
/// Keyed by physical disk offset (not logical page_id) because under CoW/MVCC
/// the same logical page_id can exist at different disk locations.
///
/// Invariants:
/// - HMAC is verified BEFORE decryption on every page fetch (cache miss).
/// - Dirty pages are PINNED and never evictable until commit.
/// - Transaction size is bounded by buffer pool capacity.
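///
/// A minimal lifecycle sketch (the `io` handle, keys, and epoch are
/// placeholders; the tests at the bottom of this file show a concrete setup):
///
/// ```ignore
/// let mut pool = BufferPool::new(16);            // capacity in pages
/// pool.insert_new(PageId(7), page)?;             // new page enters pinned dirty
/// let p = pool.fetch(&io, PageId(7), &dek, &mac_key, epoch)?;
/// pool.flush_dirty(&io, &dek, &mac_key, epoch)?; // encrypt + MAC + write
/// assert_eq!(pool.dirty_count(), 0);
/// ```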
pub struct BufferPool {
    cache: SieveCache<Arc<Page>>,
    capacity: usize,
}

impl BufferPool {
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: SieveCache::new(capacity),
            capacity,
        }
    }

    /// Fetch a page by page_id. Reads from cache or disk.
    ///
    /// On cache miss: reads from disk, verifies HMAC BEFORE decrypting,
    /// verifies xxHash64 checksum after decrypting.
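    ///
    /// A call sketch (placeholder `io` and keys, mirroring the tests below):
    ///
    /// ```ignore
    /// let page = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch)?;
    /// assert_eq!(page.page_id(), PageId(0));
    /// ```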
    pub fn fetch(
        &mut self,
        io: &dyn PageIO,
        page_id: PageId,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        encryption_epoch: u32,
    ) -> Result<&Page> {
        let offset = page_offset(page_id);

        // Cache hit
        if self.cache.contains(offset) {
            return Ok(self.cache.get(offset).unwrap());
        }

        // Cache miss: read from disk
        let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;

        // Insert into cache (may evict)
        self.cache
            .insert(offset, Arc::new(page))
            .map_err(|()| Error::BufferPoolFull)?;

        Ok(self.cache.get(offset).unwrap())
    }

    /// Fetch a page mutably (for modification during write transaction).
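    ///
    /// Uses `Arc::make_mut`: if the cached `Arc<Page>` is still shared (e.g.
    /// a reader holds a clone from `get_cached`), the page is cloned first,
    /// so the caller always mutates an exclusive copy.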
    pub fn fetch_mut(
        &mut self,
        io: &dyn PageIO,
        page_id: PageId,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        encryption_epoch: u32,
    ) -> Result<&mut Page> {
        let offset = page_offset(page_id);

        if !self.cache.contains(offset) {
            let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
            self.cache
                .insert(offset, Arc::new(page))
                .map_err(|()| Error::BufferPoolFull)?;
        }

        Ok(Arc::make_mut(self.cache.get_mut(offset).unwrap()))
    }

    /// Insert a new page directly into the buffer pool (for newly allocated pages).
    /// Marks it as dirty immediately.
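    ///
    /// If the pool is at capacity and every resident page is pinned dirty,
    /// the insert fails with `Error::TransactionTooLarge`; this is how the
    /// "transaction size is bounded by pool capacity" invariant is enforced.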
    pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
        let offset = page_offset(page_id);

        // If the pool is at capacity and this is a genuinely new entry, a
        // failed insert means every resident page is pinned dirty: the
        // transaction has outgrown the pool.
        let at_capacity = self.cache.len() >= self.capacity && !self.cache.contains(offset);

        // Insert (may evict a clean page)
        self.cache.insert(offset, Arc::new(page)).map_err(|()| {
            if at_capacity {
                Error::TransactionTooLarge {
                    capacity: self.capacity,
                }
            } else {
                Error::BufferPoolFull
            }
        })?;

        self.cache.set_dirty(offset);
        Ok(())
    }

    /// Mark a page as dirty (modified in current write transaction).
    pub fn mark_dirty(&mut self, page_id: PageId) {
        let offset = page_offset(page_id);
        self.cache.set_dirty(offset);
    }

    /// Flush all dirty pages to disk: encrypt + compute MAC + write.
    /// Clears dirty flags after successful flush.
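    ///
    /// A commit-path sketch (placeholder handles, as in the tests below):
    ///
    /// ```ignore
    /// pool.mark_dirty(PageId(5));
    /// pool.flush_dirty(&io, &dek, &mac_key, epoch)?;
    /// assert_eq!(pool.dirty_count(), 0);
    /// ```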
    pub fn flush_dirty(
        &mut self,
        io: &dyn PageIO,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        encryption_epoch: u32,
    ) -> Result<()> {
        // Collect dirty page offsets and data
        let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self
            .cache
            .dirty_entries()
            .map(|(offset, arc)| {
                let page_id = arc.page_id();
                let body = *arc.as_bytes();
                (offset, page_id, body)
            })
            .collect();

        for (offset, page_id, body) in &dirty {
            let mut encrypted = [0u8; PAGE_SIZE];
            page_cipher::encrypt_page(
                dek,
                mac_key,
                *page_id,
                encryption_epoch,
                body,
                &mut encrypted,
            );
            io.write_page(*offset, &encrypted)?;
        }

        self.cache.clear_all_dirty();
        Ok(())
    }

    /// Discard all dirty pages (on transaction abort).
    /// Removes dirty entries from the cache.
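    ///
    /// Entries are removed outright rather than merely un-flagged, so a later
    /// fetch re-reads the last committed version from disk instead of serving
    /// the aborted modifications.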
    pub fn discard_dirty(&mut self) {
        let dirty_offsets: Vec<u64> = self
            .cache
            .dirty_entries()
            .map(|(offset, _)| offset)
            .collect();

        for offset in dirty_offsets {
            self.cache.remove(offset);
        }
    }

    pub fn get_cached(&mut self, page_id: PageId) -> Option<Arc<Page>> {
        let offset = page_offset(page_id);
        self.cache.get(offset).map(Arc::clone)
    }

    pub fn insert_if_absent(&mut self, page_id: PageId, page: Arc<Page>) {
        let offset = page_offset(page_id);
        if !self.cache.contains(offset) {
            let _ = self.cache.insert(offset, page);
        }
    }

    /// Number of pages currently in the cache.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Number of dirty pages.
    pub fn dirty_count(&self) -> usize {
        self.cache.dirty_count()
    }

    /// Cache capacity.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Check if a page is cached.
    pub fn is_cached(&self, page_id: PageId) -> bool {
        let offset = page_offset(page_id);
        self.cache.contains(offset)
    }

    /// Remove a single page from the cache (e.g. after CoW overwrites it).
    pub fn invalidate(&mut self, page_id: PageId) {
        let offset = page_offset(page_id);
        self.cache.remove(offset);
    }

    /// Clear all entries from the cache.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        fn new() -> Self {
            Self {
                pages: std::sync::Mutex::new(std::collections::HashMap::new()),
            }
        }
    }

    impl PageIO for MockIO {
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> {
            Ok(())
        }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> {
            Ok(())
        }
        fn fsync(&self) -> Result<()> {
            Ok(())
        }
        fn file_size(&self) -> Result<u64> {
            Ok(0)
        }
        fn truncate(&self, _size: u64) -> Result<()> {
            Ok(())
        }
    }

    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    fn write_encrypted_page(
        io: &MockIO,
        page: &Page,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        epoch: u32,
    ) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(
            dek,
            mac_key,
            page_id,
            epoch,
            page.as_bytes(),
            &mut encrypted,
        );
        io.write_page(offset, &encrypted).unwrap();
    }

    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        // Remove from "disk"; the page should still be served from cache
        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        // Tamper with encrypted data
        let offset = page_offset(PageId(0));
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        // Insert 3 pages (all dirty from insert_new)
        for i in 0..3 {
            let mut page = Page::new(PageId(i), PageType::Leaf, TxnId(1));
            page.update_checksum();
            pool.insert_new(PageId(i), page).unwrap();
        }

        assert_eq!(pool.dirty_count(), 3);

        // Clear dirty on pages 0 and 2, making them evictable
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        // Insert page 3; it should evict page 0 or 2 (not dirty page 1)
        let mut page3 = Page::new(PageId(3), PageType::Leaf, TxnId(1));
        page3.update_checksum();
        pool.insert_new(PageId(3), page3).unwrap();
        // Dirty page 1 must still be in the cache
        assert!(pool.is_cached(PageId(1)));
    }

    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(5), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(5), page).unwrap();

        assert_eq!(pool.dirty_count(), 1);

        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        // Verify the encrypted page actually landed on disk
        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(1), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(1), page).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }
}