1use citadel_core::{
2 BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE,
3};
4use citadel_core::types::PageId;
5use citadel_core::{Error, Result};
6
7use citadel_crypto::page_cipher;
8use citadel_io::file_manager::page_offset;
9use citadel_io::traits::PageIO;
10use citadel_page::page::Page;
11
12use crate::sieve::SieveCache;
13
14pub struct BufferPool {
24 cache: SieveCache<Page>,
25 capacity: usize,
26}
27
28impl BufferPool {
29 pub fn new(capacity: usize) -> Self {
30 Self {
31 cache: SieveCache::new(capacity),
32 capacity,
33 }
34 }
35
36 pub fn fetch(
41 &mut self,
42 io: &dyn PageIO,
43 page_id: PageId,
44 dek: &[u8; DEK_SIZE],
45 mac_key: &[u8; MAC_KEY_SIZE],
46 encryption_epoch: u32,
47 ) -> Result<&Page> {
48 let offset = page_offset(page_id);
49
50 if self.cache.contains(offset) {
52 return Ok(self.cache.get(offset).unwrap());
53 }
54
55 let page = self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
57
58 self.cache.insert(offset, page)
60 .map_err(|()| Error::BufferPoolFull)?;
61
62 Ok(self.cache.get(offset).unwrap())
63 }
64
65 pub fn fetch_mut(
67 &mut self,
68 io: &dyn PageIO,
69 page_id: PageId,
70 dek: &[u8; DEK_SIZE],
71 mac_key: &[u8; MAC_KEY_SIZE],
72 encryption_epoch: u32,
73 ) -> Result<&mut Page> {
74 let offset = page_offset(page_id);
75
76 if !self.cache.contains(offset) {
77 let page = self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
78 self.cache.insert(offset, page)
79 .map_err(|()| Error::BufferPoolFull)?;
80 }
81
82 Ok(self.cache.get_mut(offset).unwrap())
83 }
84
85 pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
88 let offset = page_offset(page_id);
89
90 if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
92 self.cache.insert(offset, page)
94 .map_err(|()| Error::TransactionTooLarge { capacity: self.capacity })?;
95 } else {
96 self.cache.insert(offset, page)
97 .map_err(|()| Error::BufferPoolFull)?;
98 }
99
100 self.cache.set_dirty(offset);
101 Ok(())
102 }
103
104 pub fn mark_dirty(&mut self, page_id: PageId) {
106 let offset = page_offset(page_id);
107 self.cache.set_dirty(offset);
108 }
109
110 pub fn flush_dirty(
113 &mut self,
114 io: &dyn PageIO,
115 dek: &[u8; DEK_SIZE],
116 mac_key: &[u8; MAC_KEY_SIZE],
117 encryption_epoch: u32,
118 ) -> Result<()> {
119 let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self.cache.dirty_entries()
121 .map(|(offset, page)| {
122 let page_id = page.page_id();
123 let body = *page.as_bytes();
124 (offset, page_id, body)
125 })
126 .collect();
127
128 for (offset, page_id, body) in &dirty {
129 let mut encrypted = [0u8; PAGE_SIZE];
130 page_cipher::encrypt_page(dek, mac_key, *page_id, encryption_epoch, body, &mut encrypted);
131 io.write_page(*offset, &encrypted)?;
132 }
133
134 self.cache.clear_all_dirty();
135 Ok(())
136 }
137
138 pub fn discard_dirty(&mut self) {
141 let dirty_offsets: Vec<u64> = self.cache.dirty_entries()
142 .map(|(offset, _)| offset)
143 .collect();
144
145 for offset in dirty_offsets {
146 self.cache.remove(offset);
147 }
148 }
149
150 fn read_and_decrypt(
151 &self,
152 io: &dyn PageIO,
153 page_id: PageId,
154 offset: u64,
155 dek: &[u8; DEK_SIZE],
156 mac_key: &[u8; MAC_KEY_SIZE],
157 encryption_epoch: u32,
158 ) -> Result<Page> {
159 let mut encrypted = [0u8; PAGE_SIZE];
160 io.read_page(offset, &mut encrypted)?;
161
162 let mut body = [0u8; BODY_SIZE];
164 page_cipher::decrypt_page(dek, mac_key, page_id, encryption_epoch, &encrypted, &mut body)?;
165
166 let page = Page::from_bytes(body);
167
168 if !page.verify_checksum() {
170 return Err(Error::ChecksumMismatch(page_id));
171 }
172
173 Ok(page)
174 }
175
176 pub fn len(&self) -> usize {
178 self.cache.len()
179 }
180
181 pub fn dirty_count(&self) -> usize {
183 self.cache.dirty_count()
184 }
185
186 pub fn capacity(&self) -> usize {
188 self.capacity
189 }
190
191 pub fn is_cached(&self, page_id: PageId) -> bool {
193 let offset = page_offset(page_id);
194 self.cache.contains(offset)
195 }
196
197 pub fn clear(&mut self) {
199 self.cache.clear();
200 }
201}
202
203#[cfg(test)]
204mod tests {
205 use super::*;
206 use citadel_core::types::PageType;
207 use citadel_core::types::TxnId;
208 use citadel_crypto::hkdf_utils::derive_keys_from_rek;
209
210 struct MockIO {
211 pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
212 }
213
214 impl MockIO {
215 fn new() -> Self {
216 Self { pages: std::sync::Mutex::new(std::collections::HashMap::new()) }
217 }
218 }
219
220 impl PageIO for MockIO {
221 fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
222 let pages = self.pages.lock().unwrap();
223 if let Some(data) = pages.get(&offset) {
224 buf.copy_from_slice(data);
225 Ok(())
226 } else {
227 Err(Error::Io(std::io::Error::new(
228 std::io::ErrorKind::NotFound,
229 format!("no page at offset {offset}"),
230 )))
231 }
232 }
233
234 fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
235 self.pages.lock().unwrap().insert(offset, *buf);
236 Ok(())
237 }
238
239 fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> { Ok(()) }
240 fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> { Ok(()) }
241 fn fsync(&self) -> Result<()> { Ok(()) }
242 fn file_size(&self) -> Result<u64> { Ok(0) }
243 fn truncate(&self, _size: u64) -> Result<()> { Ok(()) }
244 }
245
246 fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
247 let rek = [0x42u8; 32];
248 let keys = derive_keys_from_rek(&rek);
249 (keys.dek, keys.mac_key)
250 }
251
252 fn write_encrypted_page(io: &MockIO, page: &Page, dek: &[u8; DEK_SIZE], mac_key: &[u8; MAC_KEY_SIZE], epoch: u32) {
253 let page_id = page.page_id();
254 let offset = page_offset(page_id);
255 let mut encrypted = [0u8; PAGE_SIZE];
256 page_cipher::encrypt_page(dek, mac_key, page_id, epoch, page.as_bytes(), &mut encrypted);
257 io.write_page(offset, &encrypted).unwrap();
258 }
259
/// A miss reads the page from I/O and leaves it resident in the pool.
#[test]
fn fetch_reads_and_caches() {
    let io = MockIO::new();
    let (dek, mac_key) = test_keys();
    let epoch = 1;
    let id = PageId(0);

    let mut on_disk = Page::new(id, PageType::Leaf, TxnId(1));
    on_disk.update_checksum();
    write_encrypted_page(&io, &on_disk, &dek, &mac_key, epoch);

    let mut pool = BufferPool::new(16);
    let fetched = pool.fetch(&io, id, &dek, &mac_key, epoch).unwrap();
    assert_eq!(fetched.page_id(), id);
    assert!(pool.is_cached(id));
}
275
/// A second fetch succeeds even after the backing store has been emptied,
/// so it must have been served from the cache.
#[test]
fn fetch_from_cache_on_second_call() {
    let io = MockIO::new();
    let (dek, mac_key) = test_keys();
    let epoch = 1;
    let id = PageId(0);

    let mut on_disk = Page::new(id, PageType::Leaf, TxnId(1));
    on_disk.update_checksum();
    write_encrypted_page(&io, &on_disk, &dek, &mac_key, epoch);

    let mut pool = BufferPool::new(16);
    pool.fetch(&io, id, &dek, &mac_key, epoch).unwrap();

    // Any further read through `io` would now fail with NotFound.
    io.pages.lock().unwrap().clear();

    let again = pool.fetch(&io, id, &dek, &mac_key, epoch).unwrap();
    assert_eq!(again.page_id(), id);
}
294
/// Flipping a single bit of the stored image makes `fetch` fail with
/// `Error::PageTampered`.
#[test]
fn tampered_page_detected_on_fetch() {
    let io = MockIO::new();
    let (dek, mac_key) = test_keys();
    let epoch = 1;
    let id = PageId(0);

    let mut on_disk = Page::new(id, PageType::Leaf, TxnId(1));
    on_disk.update_checksum();
    write_encrypted_page(&io, &on_disk, &dek, &mac_key, epoch);

    // Corrupt byte 100 in place; the lock guard is released at the end of
    // this statement.
    let offset = page_offset(id);
    io.pages.lock().unwrap().get_mut(&offset).unwrap()[100] ^= 0x01;

    let mut pool = BufferPool::new(16);
    let outcome = pool.fetch(&io, id, &dek, &mac_key, epoch);
    assert!(matches!(outcome, Err(Error::PageTampered(_))));
}
317
/// With the pool full and only page 1 still dirty, inserting a fourth page
/// succeeds and page 1 stays cached.
#[test]
fn dirty_pages_survive_eviction() {
    let mut pool = BufferPool::new(3);

    // Fill the pool; `insert_new` marks every page dirty.
    for id in [PageId(0), PageId(1), PageId(2)] {
        let mut page = Page::new(id, PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(id, page).unwrap();
    }
    assert_eq!(pool.dirty_count(), 3);

    // Leave only page 1 dirty.
    for id in [PageId(0), PageId(2)] {
        pool.cache.clear_dirty(page_offset(id));
    }

    let newcomer_id = PageId(3);
    let mut newcomer = Page::new(newcomer_id, PageType::Leaf, TxnId(1));
    newcomer.update_checksum();
    pool.insert_new(newcomer_id, newcomer).unwrap();

    assert!(pool.is_cached(PageId(1)));
}
342
/// Flushing clears the dirty count and leaves an image in the backing store
/// at the page's offset.
#[test]
fn flush_dirty_writes_encrypted() {
    let io = MockIO::new();
    let (dek, mac_key) = test_keys();
    let epoch = 1;
    let id = PageId(5);

    let mut pool = BufferPool::new(16);
    let mut page = Page::new(id, PageType::Leaf, TxnId(1));
    page.update_checksum();
    pool.insert_new(id, page).unwrap();
    assert_eq!(pool.dirty_count(), 1);

    pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
    assert_eq!(pool.dirty_count(), 0);

    let written = io.pages.lock().unwrap().contains_key(&page_offset(id));
    assert!(written);
}
363
/// Discarding dirty pages removes them from the pool entirely.
#[test]
fn discard_dirty_removes_from_cache() {
    let id = PageId(1);
    let mut page = Page::new(id, PageType::Leaf, TxnId(1));
    page.update_checksum();

    let mut pool = BufferPool::new(16);
    pool.insert_new(id, page).unwrap();
    assert_eq!(pool.len(), 1);

    pool.discard_dirty();
    assert_eq!(pool.len(), 0);
}
375}