1use citadel_core::types::PageId;
2use citadel_core::{Error, Result};
3use citadel_core::{BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE};
4
5use citadel_crypto::page_cipher;
6use citadel_io::file_manager::page_offset;
7use citadel_io::traits::PageIO;
8use citadel_page::page::Page;
9
10use crate::sieve::SieveCache;
11
12pub struct BufferPool {
22 cache: SieveCache<Page>,
23 capacity: usize,
24}
25
26impl BufferPool {
27 pub fn new(capacity: usize) -> Self {
28 Self {
29 cache: SieveCache::new(capacity),
30 capacity,
31 }
32 }
33
34 pub fn fetch(
39 &mut self,
40 io: &dyn PageIO,
41 page_id: PageId,
42 dek: &[u8; DEK_SIZE],
43 mac_key: &[u8; MAC_KEY_SIZE],
44 encryption_epoch: u32,
45 ) -> Result<&Page> {
46 let offset = page_offset(page_id);
47
48 if self.cache.contains(offset) {
50 return Ok(self.cache.get(offset).unwrap());
51 }
52
53 let page = self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
55
56 self.cache
58 .insert(offset, page)
59 .map_err(|()| Error::BufferPoolFull)?;
60
61 Ok(self.cache.get(offset).unwrap())
62 }
63
64 pub fn fetch_mut(
66 &mut self,
67 io: &dyn PageIO,
68 page_id: PageId,
69 dek: &[u8; DEK_SIZE],
70 mac_key: &[u8; MAC_KEY_SIZE],
71 encryption_epoch: u32,
72 ) -> Result<&mut Page> {
73 let offset = page_offset(page_id);
74
75 if !self.cache.contains(offset) {
76 let page =
77 self.read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
78 self.cache
79 .insert(offset, page)
80 .map_err(|()| Error::BufferPoolFull)?;
81 }
82
83 Ok(self.cache.get_mut(offset).unwrap())
84 }
85
86 pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
89 let offset = page_offset(page_id);
90
91 if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
93 self.cache
95 .insert(offset, page)
96 .map_err(|()| Error::TransactionTooLarge {
97 capacity: self.capacity,
98 })?;
99 } else {
100 self.cache
101 .insert(offset, page)
102 .map_err(|()| Error::BufferPoolFull)?;
103 }
104
105 self.cache.set_dirty(offset);
106 Ok(())
107 }
108
109 pub fn mark_dirty(&mut self, page_id: PageId) {
111 let offset = page_offset(page_id);
112 self.cache.set_dirty(offset);
113 }
114
115 pub fn flush_dirty(
118 &mut self,
119 io: &dyn PageIO,
120 dek: &[u8; DEK_SIZE],
121 mac_key: &[u8; MAC_KEY_SIZE],
122 encryption_epoch: u32,
123 ) -> Result<()> {
124 let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self
126 .cache
127 .dirty_entries()
128 .map(|(offset, page)| {
129 let page_id = page.page_id();
130 let body = *page.as_bytes();
131 (offset, page_id, body)
132 })
133 .collect();
134
135 for (offset, page_id, body) in &dirty {
136 let mut encrypted = [0u8; PAGE_SIZE];
137 page_cipher::encrypt_page(
138 dek,
139 mac_key,
140 *page_id,
141 encryption_epoch,
142 body,
143 &mut encrypted,
144 );
145 io.write_page(*offset, &encrypted)?;
146 }
147
148 self.cache.clear_all_dirty();
149 Ok(())
150 }
151
152 pub fn discard_dirty(&mut self) {
155 let dirty_offsets: Vec<u64> = self
156 .cache
157 .dirty_entries()
158 .map(|(offset, _)| offset)
159 .collect();
160
161 for offset in dirty_offsets {
162 self.cache.remove(offset);
163 }
164 }
165
166 fn read_and_decrypt(
167 &self,
168 io: &dyn PageIO,
169 page_id: PageId,
170 offset: u64,
171 dek: &[u8; DEK_SIZE],
172 mac_key: &[u8; MAC_KEY_SIZE],
173 encryption_epoch: u32,
174 ) -> Result<Page> {
175 let mut encrypted = [0u8; PAGE_SIZE];
176 io.read_page(offset, &mut encrypted)?;
177
178 let mut body = [0u8; BODY_SIZE];
180 page_cipher::decrypt_page(
181 dek,
182 mac_key,
183 page_id,
184 encryption_epoch,
185 &encrypted,
186 &mut body,
187 )?;
188
189 let page = Page::from_bytes(body);
190
191 if !page.verify_checksum() {
193 return Err(Error::ChecksumMismatch(page_id));
194 }
195
196 Ok(page)
197 }
198
199 pub fn len(&self) -> usize {
201 self.cache.len()
202 }
203
204 pub fn is_empty(&self) -> bool {
206 self.cache.is_empty()
207 }
208
209 pub fn dirty_count(&self) -> usize {
211 self.cache.dirty_count()
212 }
213
214 pub fn capacity(&self) -> usize {
216 self.capacity
217 }
218
219 pub fn is_cached(&self, page_id: PageId) -> bool {
221 let offset = page_offset(page_id);
222 self.cache.contains(offset)
223 }
224
225 pub fn invalidate(&mut self, page_id: PageId) {
227 let offset = page_offset(page_id);
228 self.cache.remove(offset);
229 }
230
231 pub fn clear(&mut self) {
233 self.cache.clear();
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    /// In-memory `PageIO` that keeps whole encrypted pages keyed by byte offset.
    ///
    /// Only `read_page`/`write_page` do real work; the remaining methods are inert stubs.
    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        fn new() -> Self {
            Self {
                pages: std::sync::Mutex::new(std::collections::HashMap::new()),
            }
        }
    }

    impl PageIO for MockIO {
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> {
            Ok(())
        }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> {
            Ok(())
        }
        fn fsync(&self) -> Result<()> {
            Ok(())
        }
        fn file_size(&self) -> Result<u64> {
            Ok(0)
        }
        fn truncate(&self, _size: u64) -> Result<()> {
            Ok(())
        }
    }

    /// Deterministic DEK / MAC key pair derived from a fixed root key.
    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    /// Builds an empty leaf page for `page_id` with a valid checksum.
    fn leaf_page(page_id: PageId) -> Page {
        let mut page = Page::new(page_id, PageType::Leaf, TxnId(1));
        page.update_checksum();
        page
    }

    /// Encrypts `page` and stores it in `io` at the page's offset, bypassing the pool.
    fn write_encrypted_page(
        io: &MockIO,
        page: &Page,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        epoch: u32,
    ) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(
            dek,
            mac_key,
            page_id,
            epoch,
            page.as_bytes(),
            &mut encrypted,
        );
        io.write_page(offset, &encrypted).unwrap();
    }

    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        write_encrypted_page(&io, &leaf_page(PageId(0)), &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        write_encrypted_page(&io, &leaf_page(PageId(0)), &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        // Wipe the backing store: the second fetch can only succeed as a cache hit.
        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        write_encrypted_page(&io, &leaf_page(PageId(0)), &dek, &mac_key, epoch);

        // Flip a single bit of the stored ciphertext.
        let offset = page_offset(PageId(0));
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        for i in 0..3 {
            pool.insert_new(PageId(i), leaf_page(PageId(i))).unwrap();
        }
        assert_eq!(pool.dirty_count(), 3);

        // Leave only page 1 dirty; pages 0 and 2 become eviction candidates.
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        pool.insert_new(PageId(3), leaf_page(PageId(3))).unwrap();

        assert!(pool.is_cached(PageId(1)));
        assert!(pool.is_cached(PageId(3)));
        assert_eq!(pool.len(), 3);
        // Exactly one of the clean pages had to make room.
        assert!(!pool.is_cached(PageId(0)) || !pool.is_cached(PageId(2)));
    }

    #[test]
    fn insert_new_reports_transaction_too_large_when_all_dirty() {
        let mut pool = BufferPool::new(2);
        for i in 0..2 {
            pool.insert_new(PageId(i), leaf_page(PageId(i))).unwrap();
        }

        let result = pool.insert_new(PageId(2), leaf_page(PageId(2)));
        assert!(matches!(
            result,
            Err(Error::TransactionTooLarge { capacity: 2 })
        ));
    }

    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        pool.insert_new(PageId(5), leaf_page(PageId(5))).unwrap();
        assert_eq!(pool.dirty_count(), 1);

        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    #[test]
    fn flush_dirty_round_trips_through_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut writer = BufferPool::new(16);
        writer.insert_new(PageId(5), leaf_page(PageId(5))).unwrap();
        writer.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();

        // A fresh pool has nothing cached, so this must decrypt what flush wrote.
        let mut reader = BufferPool::new(16);
        let fetched = reader.fetch(&io, PageId(5), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(5));
    }

    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        pool.insert_new(PageId(1), leaf_page(PageId(1))).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn discard_dirty_keeps_clean_pages() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        write_encrypted_page(&io, &leaf_page(PageId(0)), &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        pool.insert_new(PageId(1), leaf_page(PageId(1))).unwrap();

        pool.discard_dirty();
        assert!(pool.is_cached(PageId(0)));
        assert!(!pool.is_cached(PageId(1)));
    }
}