1use std::sync::Arc;
2
3use citadel_core::types::PageId;
4use citadel_core::{Error, Result};
5use citadel_core::{BODY_SIZE, DEK_SIZE, MAC_KEY_SIZE, PAGE_SIZE};
6
7use citadel_crypto::page_cipher;
8use citadel_io::file_manager::page_offset;
9use citadel_io::traits::PageIO;
10use citadel_page::page::Page;
11
12use crate::sieve::SieveCache;
13
14pub fn read_and_decrypt(
15 io: &dyn PageIO,
16 page_id: PageId,
17 offset: u64,
18 dek: &[u8; DEK_SIZE],
19 mac_key: &[u8; MAC_KEY_SIZE],
20 encryption_epoch: u32,
21) -> Result<Page> {
22 let mut encrypted = [0u8; PAGE_SIZE];
23 io.read_page(offset, &mut encrypted)?;
24
25 let mut body = [0u8; BODY_SIZE];
26 page_cipher::decrypt_page(
27 dek,
28 mac_key,
29 page_id,
30 encryption_epoch,
31 &encrypted,
32 &mut body,
33 )?;
34
35 let page = Page::from_bytes(body);
36
37 if !page.verify_checksum() {
38 return Err(Error::ChecksumMismatch(page_id));
39 }
40
41 Ok(page)
42}
43
/// A write-back cache of decrypted pages, keyed by file offset.
///
/// Pages are stored behind [`Arc`] so read-only callers can share them
/// cheaply; mutation goes through `Arc::make_mut` (clone-on-write) in
/// [`BufferPool::fetch_mut`].
pub struct BufferPool {
    // SIEVE-managed map: file offset -> shared decrypted page.
    cache: SieveCache<Arc<Page>>,
    // Maximum resident page count; also consulted by `insert_new` to
    // decide which error variant an insert failure maps to.
    capacity: usize,
}
57
impl BufferPool {
    /// Creates an empty pool that holds at most `capacity` pages.
    pub fn new(capacity: usize) -> Self {
        Self {
            cache: SieveCache::new(capacity),
            capacity,
        }
    }

    /// Returns a shared reference to the page for `page_id`, reading
    /// and decrypting it from `io` on a cache miss.
    ///
    /// # Errors
    ///
    /// Propagates read/decrypt failures from [`read_and_decrypt`];
    /// returns `Error::BufferPoolFull` when the missed page cannot be
    /// inserted (presumably because no resident page is evictable —
    /// confirm against `SieveCache::insert`).
    pub fn fetch(
        &mut self,
        io: &dyn PageIO,
        page_id: PageId,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        encryption_epoch: u32,
    ) -> Result<&Page> {
        // Pages are addressed in the cache by their file offset.
        let offset = page_offset(page_id);

        // Fast path: already resident. NOTE(review): `get` presumably
        // also refreshes the entry's SIEVE visited bit — verify in
        // SieveCache before collapsing this into a single lookup.
        if self.cache.contains(offset) {
            return Ok(self.cache.get(offset).unwrap());
        }

        // Miss: load + decrypt from disk, then cache the page.
        let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
        self.cache
            .insert(offset, Arc::new(page))
            .map_err(|()| Error::BufferPoolFull)?;

        // The key was inserted just above, so this lookup cannot fail.
        Ok(self.cache.get(offset).unwrap())
    }

    /// Like [`Self::fetch`], but returns a mutable reference.
    ///
    /// Uses `Arc::make_mut`: if the cached page is still shared with
    /// other `Arc` holders it is cloned first (clone-on-write). This
    /// does NOT mark the page dirty — callers must do that via
    /// [`Self::mark_dirty`].
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Self::fetch`].
    pub fn fetch_mut(
        &mut self,
        io: &dyn PageIO,
        page_id: PageId,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        encryption_epoch: u32,
    ) -> Result<&mut Page> {
        let offset = page_offset(page_id);

        // Populate the cache on miss, mirroring `fetch`.
        if !self.cache.contains(offset) {
            let page = read_and_decrypt(io, page_id, offset, dek, mac_key, encryption_epoch)?;
            self.cache
                .insert(offset, Arc::new(page))
                .map_err(|()| Error::BufferPoolFull)?;
        }

        // Inserted above if absent, so the lookup cannot fail.
        Ok(Arc::make_mut(self.cache.get_mut(offset).unwrap()))
    }

    /// Inserts a freshly created page (not yet on disk) and marks it
    /// dirty so the next [`Self::flush_dirty`] persists it.
    ///
    /// # Errors
    ///
    /// An insert failure is reported as `Error::TransactionTooLarge`
    /// when the pool is at capacity and `page_id` is not already
    /// resident (the set of unevictable pages fills the pool), and as
    /// `Error::BufferPoolFull` otherwise.
    pub fn insert_new(&mut self, page_id: PageId, page: Page) -> Result<()> {
        let offset = page_offset(page_id);

        // Both branches perform the same insert; the condition only
        // selects which error variant a failure maps to.
        if self.cache.len() >= self.capacity && !self.cache.contains(offset) {
            self.cache
                .insert(offset, Arc::new(page))
                .map_err(|()| Error::TransactionTooLarge {
                    capacity: self.capacity,
                })?;
        } else {
            self.cache
                .insert(offset, Arc::new(page))
                .map_err(|()| Error::BufferPoolFull)?;
        }

        self.cache.set_dirty(offset);
        Ok(())
    }

    /// Flags the page for `page_id` as dirty so it is written back on
    /// the next [`Self::flush_dirty`]. No-op semantics for absent
    /// pages depend on `SieveCache::set_dirty` — TODO confirm.
    pub fn mark_dirty(&mut self, page_id: PageId) {
        let offset = page_offset(page_id);
        self.cache.set_dirty(offset);
    }

    /// Encrypts every dirty page and writes it back through `io`,
    /// then clears all dirty flags.
    ///
    /// Bodies are copied into a temporary `Vec` first so the cache is
    /// not borrowed while writing. Flags are cleared only after every
    /// write succeeds; on a mid-flush failure all pages stay dirty.
    /// Does not call `io.fsync()` — durability is the caller's job.
    ///
    /// # Errors
    ///
    /// Propagates the first `write_page` failure.
    pub fn flush_dirty(
        &mut self,
        io: &dyn PageIO,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        encryption_epoch: u32,
    ) -> Result<()> {
        // Snapshot (offset, id, body) for each dirty page.
        let dirty: Vec<(u64, PageId, [u8; BODY_SIZE])> = self
            .cache
            .dirty_entries()
            .map(|(offset, arc)| {
                let page_id = arc.page_id();
                let body = *arc.as_bytes();
                (offset, page_id, body)
            })
            .collect();

        for (offset, page_id, body) in &dirty {
            // Re-encrypt with the current epoch before writing back.
            let mut encrypted = [0u8; PAGE_SIZE];
            page_cipher::encrypt_page(
                dek,
                mac_key,
                *page_id,
                encryption_epoch,
                body,
                &mut encrypted,
            );
            io.write_page(*offset, &encrypted)?;
        }

        self.cache.clear_all_dirty();
        Ok(())
    }

    /// Removes every dirty page from the cache WITHOUT writing it
    /// back, so the next access rereads the on-disk version
    /// (presumably used for transaction rollback).
    pub fn discard_dirty(&mut self) {
        // Collect first: `dirty_entries` borrows the cache, so we
        // cannot remove while iterating.
        let dirty_offsets: Vec<u64> = self
            .cache
            .dirty_entries()
            .map(|(offset, _)| offset)
            .collect();

        for offset in dirty_offsets {
            self.cache.remove(offset);
        }
    }

    /// Returns a shared handle to the cached page, if resident.
    /// Takes `&mut self` because the lookup updates cache metadata
    /// (assumed — SIEVE visited bit).
    pub fn get_cached(&mut self, page_id: PageId) -> Option<Arc<Page>> {
        let offset = page_offset(page_id);
        self.cache.get(offset).map(Arc::clone)
    }

    /// Best-effort insert: caches `page` only if `page_id` is not
    /// already resident; a failed insert is silently ignored.
    pub fn insert_if_absent(&mut self, page_id: PageId, page: Arc<Page>) {
        let offset = page_offset(page_id);
        if !self.cache.contains(offset) {
            let _ = self.cache.insert(offset, page);
        }
    }

    /// Number of pages currently resident in the pool.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// True when no pages are resident.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Number of resident pages currently flagged dirty.
    pub fn dirty_count(&self) -> usize {
        self.cache.dirty_count()
    }

    /// Maximum number of pages the pool may hold.
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// True when the page for `page_id` is resident (does not touch
    /// eviction metadata — uses `contains`, not `get`).
    pub fn is_cached(&self, page_id: PageId) -> bool {
        let offset = page_offset(page_id);
        self.cache.contains(offset)
    }

    /// Evicts the page for `page_id` from the cache, if present.
    /// Any dirty state for that page is lost without write-back.
    pub fn invalidate(&mut self, page_id: PageId) {
        let offset = page_offset(page_id);
        self.cache.remove(offset);
    }

    /// Drops every cached page, dirty or clean, without write-back.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use citadel_core::types::PageType;
    use citadel_core::types::TxnId;
    use citadel_crypto::hkdf_utils::derive_keys_from_rek;

    /// In-memory `PageIO` backed by an offset-keyed `HashMap`;
    /// non-page operations are stubbed out as no-ops.
    struct MockIO {
        pages: std::sync::Mutex<std::collections::HashMap<u64, [u8; PAGE_SIZE]>>,
    }

    impl MockIO {
        fn new() -> Self {
            Self {
                pages: std::sync::Mutex::new(std::collections::HashMap::new()),
            }
        }
    }

    impl PageIO for MockIO {
        // Returns Io/NotFound for offsets that were never written,
        // mimicking a read past the end of a real file.
        fn read_page(&self, offset: u64, buf: &mut [u8; PAGE_SIZE]) -> Result<()> {
            let pages = self.pages.lock().unwrap();
            if let Some(data) = pages.get(&offset) {
                buf.copy_from_slice(data);
                Ok(())
            } else {
                Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::NotFound,
                    format!("no page at offset {offset}"),
                )))
            }
        }

        fn write_page(&self, offset: u64, buf: &[u8; PAGE_SIZE]) -> Result<()> {
            self.pages.lock().unwrap().insert(offset, *buf);
            Ok(())
        }

        // Byte-level and maintenance operations are irrelevant to the
        // buffer-pool tests, so they succeed without doing anything.
        fn read_at(&self, _offset: u64, _buf: &mut [u8]) -> Result<()> {
            Ok(())
        }
        fn write_at(&self, _offset: u64, _buf: &[u8]) -> Result<()> {
            Ok(())
        }
        fn fsync(&self) -> Result<()> {
            Ok(())
        }
        fn file_size(&self) -> Result<u64> {
            Ok(0)
        }
        fn truncate(&self, _size: u64) -> Result<()> {
            Ok(())
        }
    }

    /// Derives a deterministic (DEK, MAC key) pair from a fixed REK.
    fn test_keys() -> ([u8; DEK_SIZE], [u8; MAC_KEY_SIZE]) {
        let rek = [0x42u8; 32];
        let keys = derive_keys_from_rek(&rek);
        (keys.dek, keys.mac_key)
    }

    /// Encrypts `page` with the given keys/epoch and stores it in the
    /// mock I/O at the page's canonical offset.
    fn write_encrypted_page(
        io: &MockIO,
        page: &Page,
        dek: &[u8; DEK_SIZE],
        mac_key: &[u8; MAC_KEY_SIZE],
        epoch: u32,
    ) {
        let page_id = page.page_id();
        let offset = page_offset(page_id);
        let mut encrypted = [0u8; PAGE_SIZE];
        page_cipher::encrypt_page(
            dek,
            mac_key,
            page_id,
            epoch,
            page.as_bytes(),
            &mut encrypted,
        );
        io.write_page(offset, &encrypted).unwrap();
    }

    // A miss decrypts the on-disk page and leaves it resident.
    #[test]
    fn fetch_reads_and_caches() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
        assert!(pool.is_cached(PageId(0)));
    }

    // A second fetch must be served from cache: the backing store is
    // wiped between fetches, so a reread would fail with Io/NotFound.
    #[test]
    fn fetch_from_cache_on_second_call() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let mut pool = BufferPool::new(16);
        pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();

        io.pages.lock().unwrap().clear();
        let fetched = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch).unwrap();
        assert_eq!(fetched.page_id(), PageId(0));
    }

    // Flipping one ciphertext bit must surface as PageTampered (MAC
    // failure in the cipher layer), not as a checksum mismatch.
    #[test]
    fn tampered_page_detected_on_fetch() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut page = Page::new(PageId(0), PageType::Leaf, TxnId(1));
        page.update_checksum();
        write_encrypted_page(&io, &page, &dek, &mac_key, epoch);

        let offset = page_offset(PageId(0));
        {
            let mut pages = io.pages.lock().unwrap();
            let data = pages.get_mut(&offset).unwrap();
            data[100] ^= 0x01;
        }

        let mut pool = BufferPool::new(16);
        let result = pool.fetch(&io, PageId(0), &dek, &mac_key, epoch);
        assert!(matches!(result, Err(Error::PageTampered(_))));
    }

    // Fill a capacity-3 pool, leave only PageId(1) dirty, then insert
    // a fourth page: eviction must pick a clean page, so the dirty
    // PageId(1) stays resident.
    #[test]
    fn dirty_pages_survive_eviction() {
        let mut pool = BufferPool::new(3);

        for i in 0..3 {
            let mut page = Page::new(PageId(i), PageType::Leaf, TxnId(1));
            page.update_checksum();
            pool.insert_new(PageId(i), page).unwrap();
        }

        assert_eq!(pool.dirty_count(), 3);

        // Reaches into pool internals to clean pages 0 and 2.
        pool.cache.clear_dirty(page_offset(PageId(0)));
        pool.cache.clear_dirty(page_offset(PageId(2)));

        let mut page3 = Page::new(PageId(3), PageType::Leaf, TxnId(1));
        page3.update_checksum();
        pool.insert_new(PageId(3), page3).unwrap();
        assert!(pool.is_cached(PageId(1)));
    }

    // flush_dirty writes the page back (encrypted) and resets the
    // dirty counter to zero.
    #[test]
    fn flush_dirty_writes_encrypted() {
        let (dek, mac_key) = test_keys();
        let io = MockIO::new();
        let epoch = 1;

        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(5), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(5), page).unwrap();

        assert_eq!(pool.dirty_count(), 1);

        pool.flush_dirty(&io, &dek, &mac_key, epoch).unwrap();
        assert_eq!(pool.dirty_count(), 0);

        let offset = page_offset(PageId(5));
        assert!(io.pages.lock().unwrap().contains_key(&offset));
    }

    // Discarding dirty pages evicts them outright rather than merely
    // clearing their dirty flag.
    #[test]
    fn discard_dirty_removes_from_cache() {
        let mut pool = BufferPool::new(16);
        let mut page = Page::new(PageId(1), PageType::Leaf, TxnId(1));
        page.update_checksum();
        pool.insert_new(PageId(1), page).unwrap();

        assert_eq!(pool.len(), 1);
        pool.discard_dirty();
        assert_eq!(pool.len(), 0);
    }
}