1use std::sync::Arc;
2
3use citadel_core::types::PageId;
4use citadel_core::{Error, Result};
5use citadel_core::{BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE};
6
7use citadel_crypto::page_cipher;
8use citadel_io::file_manager::page_offset;
9use citadel_io::traits::PageIO;
10use citadel_page::page::Page;
11
12use crate::sieve::SieveCache;
13
14pub fn read_and_decrypt(
15 io: &dyn PageIO,
16 page_id: PageId,
17 offset: u64,
18 dek: &[u8; DEK_SIZE],
19 mac_key: &[u8; MAC_KEY_SIZE],
20 encryption_epoch: u32,
21) -> Result<Page> {
22 let mut encrypted = [0u8; PAGE_SIZE];
23 io.read_page(offset, &mut encrypted)?;
24
25 let mut body = [0u8; BODY_SIZE];
26 page_cipher::decrypt_page(
27 dek,
28 mac_key,
29 page_id,
30 encryption_epoch,
31 &encrypted,
32 &mut body,
33 )?;
34
35 let page = Page::from_bytes(body);
36
37 if !page.verify_checksum() {
38 return Err(Error::ChecksumMismatch(page_id));
39 }
40
41 Ok(page)
42}
43
44pub struct BufferPool {
54 cache: SieveCache<Arc<Page>>,
55 capacity: usize,
56}
57
58impl BufferPool {
59 pub fn new(capacity: usize) -> Self {
60 Self {
61 cache: SieveCache::new(capacity),
62 capacity,
63 }
64 }
65
66 pub fn fetch(
71 &mut self,
72 io: &dyn PageIO,
73 page_id: PageId,
74 dek: &[u8; DEK_SIZE],
75 mac_key: &[u8; MAC_KEY_SIZE],
76 encryption_epoch: u32,
77 ) -> Result<&Page> {
78 let offset = page_offset(page_id);
79
80 if self.cache.contains(offset) {
82 return Ok(self.cache.get(offset).unwrap());
83 }
84
85 let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
87
88 self.cache
90 .insert(offset, Arc::new(page))
91 .map_err(|()| Error::BufferPoolFull)?;
92
93 Ok(self.cache.get(offset).unwrap())
94 }
95
96 pub fn fetch_mut(
98 &mut self,
99 io: &dyn PageIO,
100 page_id: PageId,
101 dek: &[u8; DEK_SIZE],
102 mac_key: &[u8; MAC_KEY_SIZE],
103 encryption_epoch: u32,
104 ) -> Result<&mut Page> {
105 let offset = page_offset(page_id);
106
107 if !self.cache.contains(offset) {
108 let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
109 self.cache
110 .insert(offset, Arc::new(page))
111 .map_err(|()| Error::BufferPoolFull)?;
112 }
113
114 Ok(Arc::make_mut(self.cache.get_mut(offset).unwrap()))
115 }
116
117 pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
120 let offset = page_offset(page_id);
121
122 if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
124 self.cache
126 .insert(offset, Arc::new(page))
127 .map_err(|()| Error::TransactionTooLarge {
128 capacity: self.capacity,
129 })?;
130 } else {
131 self.cache
132 .insert(offset, Arc::new(page))
133 .map_err(|()| Error::BufferPoolFull)?;
134 }
135
136 self.cache.set_dirty(offset);
137 Ok(())
138 }
139
140 pub fn mark_dirty(&mut self, page_id: PageId) {
142 let offset = page_offset(page_id);
143 self.cache.set_dirty(offset);
144 }
145
146 pub fn flush_dirty(
149 &mut self,
150 io: &dyn PageIO,
151 dek: &[u8; DEK_SIZE],
152 mac_key: &[u8; MAC_KEY_SIZE],
153 encryption_epoch: u32,
154 ) -> Result<()> {
155 let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self
157 .cache
158 .dirty_entries()
159 .map(|(offset, arc)| {
160 let page_id = arc.page_id();
161 let body = *arc.as_bytes();
162 (offset, page_id, body)
163 })
164 .collect();
165
166 for (offset, page_id, body) in &dirty {
167 let mut encrypted = [0u8; PAGE_SIZE];
168 page_cipher::encrypt_page(
169 dek,
170 mac_key,
171 *page_id,
172 encryption_epoch,
173 body,
174 &mut encrypted,
175 );
176 io.write_page(*offset, &encrypted)?;
177 }
178
179 self.cache.clear_all_dirty();
180 Ok(())
181 }
182
183 pub fn discard_dirty(&mut self) {
186 let dirty_offsets: Vec<u64> = self
187 .cache
188 .dirty_entries()
189 .map(|(offset, _)| offset)
190 .collect();
191
192 for offset in dirty_offsets {
193 self.cache.remove(offset);
194 }
195 }
196
197 pub fn get_cached(&mut self, page_id: PageId) -> Option<Arc<Page>> {
198 let offset = page_offset(page_id);
199 self.cache.get(offset).map(Arc::clone)
200 }
201
202 pub fn insert_if_absent(&mut self, page_id: PageId, page: Arc<Page>) {
203 let offset = page_offset(page_id);
204 if !self.cache.contains(offset) {
205 let _ = self.cache.insert(offset, page);
206 }
207 }
208
209 pub fn len(&self) -> usize {
211 self.cache.len()
212 }
213
214 pub fn is_empty(&self) -> bool {
216 self.cache.is_empty()
217 }
218
219 pub fn dirty_count(&self) -> usize {
221 self.cache.dirty_count()
222 }
223
224 pub fn capacity(&self) -> usize {
226 self.capacity
227 }
228
229 pub fn is_cached(&self, page_id: PageId) -> bool {
231 let offset = page_offset(page_id);
232 self.cache.contains(offset)
233 }
234
235 pub fn invalidate(&mut self, page_id: PageId) {
237 let offset = page_offset(page_id);
238 self.cache.remove(offset);
239 }
240
241 pub fn clear(&mut self) {
243 self.cache.clear();
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    // In-memory `PageIO` backed by a map from byte offset to one full on-disk page.
    // Only `read_page`/`write_page` do real work. The other trait methods are no-op stubs
    // because the buffer-pool code in this file only calls those two.
    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        // Creates a mock with no pages stored.
        fn new() -> Self {
            Self {
                pages: std::sync::Mutex::new(std::collections::HashMap::new()),
            }
        }
    }

    impl PageIO for MockIO {
        // Copies the stored page into `buf`. An unknown offset is reported as an
        // `Error::Io` of kind `NotFound`, so a test can prove a cache hit by clearing the map.
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        // Stores (or overwrites) the page at `offset`.
        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        // Remaining methods are unused stubs that always succeed.
        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> {
            Ok(())
        }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> {
            Ok(())
        }
        fn fsync(&self) -> Result<()> {
            Ok(())
        }
        fn file_size(&self) -> Result<u64> {
            Ok(0)
        }
        fn truncate(&self, _size: u64) -> Result<()> {
            Ok(())
        }
    }

    // Deterministic (dek, mac_key) pair derived from a fixed 32-byte `rek`,
    // so every test run uses the same keys.
    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    // Encrypts `page` with its own page id and `epoch` (the same `encrypt_page` call shape
    // `BufferPool::flush_dirty` uses) and stores it at `page_offset(page.page_id())`.
    fn write_encrypted_page(
        io: &MockIO,
        page: &Page,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        epoch: u32,
    ) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(
            dek,
            mac_key,
            page_id,
            epoch,
            page.as_bytes(),
            &mut encrypted,
        );
        io.write_page(offset, &encrypted).unwrap();
    }

    // A miss reads and decrypts the page from I/O and leaves it cached.
    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        // The checksum must be valid or `read_and_decrypt` rejects the page.
        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    // A second fetch must be served from the cache without touching I/O.
    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        // Wipe the backing store: any I/O read would now fail with NotFound.
        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    // Flipping a single ciphertext bit must be reported as `PageTampered`
    // (from decryption), not as a checksum mismatch.
    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let offset = page_offset(PageId(0));
        // Corrupt the stored bytes in place; the scope drops the mutex guard before fetching.
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    // With the pool full, inserting a new page must evict a clean entry and keep the dirty one.
    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        // Fill the pool to capacity; `insert_new` marks every page dirty.
        for i in 0..3 {
            let mut page = Page::new(PageId(i), PageType::Leaf, TxnId(1));
            page.update_checksum();
            pool.insert_new(PageId(i), page).unwrap();
        }

        assert_eq!(pool.dirty_count(), 3);

        // Make pages 0 and 2 clean, leaving page 1 as the only dirty entry.
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        // This insert needs an eviction; it must succeed and must not evict dirty page 1.
        let mut page3 = Page::new(PageId(3), PageType::Leaf, TxnId(1));
        page3.update_checksum();
        pool.insert_new(PageId(3), page3).unwrap();
        assert!(pool.is_cached(PageId(1)));
    }

    // Flushing writes the dirty page to I/O at its offset and clears the dirty flag.
    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(5), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(5), page).unwrap();

        assert_eq!(pool.dirty_count(), 1);

        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        // Only presence at the expected offset is checked, not the ciphertext contents.
        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    // Discarding drops uncommitted (dirty) pages from the cache entirely.
    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(1), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(1), page).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }
}