1use ahash::AHashMap;
8use parking_lot::Mutex;
9
10use crate::page::{Page, PageId, PageManager};
11use mentedb_core::error::{MenteError, MenteResult};
12use tracing::{debug, trace};
13
/// Index of a frame slot inside the buffer pool's `frames` vector.
type FrameId = usize;
15
/// One slot of the buffer pool: a page-sized buffer plus the bookkeeping the
/// replacement policy needs.
struct Frame {
    /// In-memory copy of the page held by this frame.
    page: Box<Page>,
    /// Id of the page currently held; `None` marks the frame as free.
    page_id: Option<PageId>,
    /// Number of outstanding pins; a frame with a non-zero count is never chosen as a victim.
    pin_count: u32,
    /// Set when the in-memory copy was modified and has not been written back yet.
    dirty: bool,
    /// Clock reference bit: set on access, cleared when the clock hand passes over the frame.
    reference: bool,
}
25
26impl Frame {
27 fn new() -> Self {
28 Self {
29 page: Box::new(Page::zeroed()),
30 page_id: None,
31 pin_count: 0,
32 dirty: false,
33 reference: false,
34 }
35 }
36}
37
/// Mutable state of the buffer pool; always accessed through the `Mutex` in `BufferPool`.
struct BufferPoolInner {
    /// Fixed set of frame slots, created once when the pool is built.
    frames: Vec<Frame>,
    /// Maps each resident page id to the frame that holds it.
    page_table: AHashMap<PageId, FrameId>,
    /// Position at which the next clock sweep for a victim starts.
    clock_hand: usize,
    /// Number of frames in the pool.
    capacity: usize,
}
44
/// Fixed-capacity page cache that evicts unpinned pages with a clock sweep.
///
/// All state lives behind a single mutex, so every method takes `&self`.
pub struct BufferPool {
    /// Frames, page table and clock hand, guarded by one lock.
    inner: Mutex<BufferPoolInner>,
}
49
50impl BufferPool {
51 pub fn new(capacity: usize) -> Self {
53 assert!(capacity > 0, "buffer pool capacity must be > 0");
54 let frames = (0..capacity).map(|_| Frame::new()).collect();
55 Self {
56 inner: Mutex::new(BufferPoolInner {
57 frames,
58 page_table: AHashMap::with_capacity(capacity),
59 clock_hand: 0,
60 capacity,
61 }),
62 }
63 }
64
65 pub fn fetch_page(&self, page_id: PageId, pm: &mut PageManager) -> MenteResult<Box<Page>> {
70 let mut inner = self.inner.lock();
71
72 if let Some(&frame_id) = inner.page_table.get(&page_id) {
74 let frame = &mut inner.frames[frame_id];
75 frame.pin_count += 1;
76 frame.reference = true;
77 trace!(page_id = page_id.0, frame_id, "buffer pool hit");
78 return Ok(frame.page.clone());
79 }
80
81 let frame_id = Self::find_victim(&mut inner)?;
83
84 if inner.frames[frame_id].dirty
86 && let Some(old_pid) = inner.frames[frame_id].page_id
87 {
88 pm.write_page(old_pid, &inner.frames[frame_id].page)?;
89 debug!(page_id = old_pid.0, frame_id, "flushed dirty victim");
90 }
91
92 if let Some(old_pid) = inner.frames[frame_id].page_id {
94 inner.page_table.remove(&old_pid);
95 }
96
97 let page = pm.read_page(page_id)?;
99 {
100 let frame = &mut inner.frames[frame_id];
101 *frame.page = *page;
102 frame.page_id = Some(page_id);
103 frame.pin_count = 1;
104 frame.dirty = false;
105 frame.reference = true;
106 }
107
108 inner.page_table.insert(page_id, frame_id);
109 trace!(
110 page_id = page_id.0,
111 frame_id, "loaded page into buffer pool"
112 );
113
114 Ok(inner.frames[frame_id].page.clone())
115 }
116
117 pub fn pin_page(&self, page_id: PageId) -> MenteResult<()> {
119 let mut inner = self.inner.lock();
120 match inner.page_table.get(&page_id) {
121 Some(&fid) => {
122 inner.frames[fid].pin_count += 1;
123 Ok(())
124 }
125 None => Err(MenteError::Storage(format!(
126 "page {} not in buffer pool",
127 page_id.0
128 ))),
129 }
130 }
131
132 pub fn unpin_page(&self, page_id: PageId, dirty: bool) -> MenteResult<()> {
134 let mut inner = self.inner.lock();
135 match inner.page_table.get(&page_id) {
136 Some(&fid) => {
137 let frame = &mut inner.frames[fid];
138 if frame.pin_count > 0 {
139 frame.pin_count -= 1;
140 }
141 if dirty {
142 frame.dirty = true;
143 }
144 Ok(())
145 }
146 None => Err(MenteError::Storage(format!(
147 "page {} not in buffer pool",
148 page_id.0
149 ))),
150 }
151 }
152
153 pub fn update_page(&self, page_id: PageId, page: &Page) -> MenteResult<()> {
155 let mut inner = self.inner.lock();
156 match inner.page_table.get(&page_id) {
157 Some(&fid) => {
158 let frame = &mut inner.frames[fid];
159 *frame.page = page.clone();
160 frame.dirty = true;
161 Ok(())
162 }
163 None => Err(MenteError::Storage(format!(
164 "page {} not in buffer pool",
165 page_id.0
166 ))),
167 }
168 }
169
170 pub fn flush_page(&self, page_id: PageId, pm: &mut PageManager) -> MenteResult<()> {
172 let mut inner = self.inner.lock();
173 match inner.page_table.get(&page_id) {
174 Some(&fid) => {
175 let frame = &mut inner.frames[fid];
176 if frame.dirty {
177 pm.write_page(page_id, &frame.page)?;
178 frame.dirty = false;
179 debug!(page_id = page_id.0, "flushed page");
180 }
181 Ok(())
182 }
183 None => Err(MenteError::Storage(format!(
184 "page {} not in buffer pool",
185 page_id.0
186 ))),
187 }
188 }
189
190 pub fn flush_all(&self, pm: &mut PageManager) -> MenteResult<()> {
192 let mut inner = self.inner.lock();
193 for frame in &mut inner.frames {
194 if frame.dirty
195 && let Some(pid) = frame.page_id
196 {
197 pm.write_page(pid, &frame.page)?;
198 frame.dirty = false;
199 }
200 }
201 debug!("flushed all dirty pages");
202 Ok(())
203 }
204
205 fn find_victim(inner: &mut BufferPoolInner) -> MenteResult<FrameId> {
207 let cap = inner.capacity;
208
209 for i in 0..cap {
211 if inner.frames[i].page_id.is_none() {
212 return Ok(i);
213 }
214 }
215
216 let max_sweeps = cap * 2;
218 for _ in 0..max_sweeps {
219 let idx = inner.clock_hand;
220 inner.clock_hand = (inner.clock_hand + 1) % cap;
221
222 let frame = &mut inner.frames[idx];
223 if frame.pin_count == 0 {
224 if !frame.reference {
225 return Ok(idx);
226 }
227 frame.reference = false;
228 }
229 }
230
231 Err(MenteError::Storage(
232 "buffer pool full: all pages are pinned".into(),
233 ))
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::page::Page;

    /// Opens a `PageManager` on a fresh temporary directory. The directory handle is
    /// returned as well so the caller can keep it alive for the duration of the test.
    fn setup() -> (tempfile::TempDir, PageManager) {
        let tmp = tempfile::tempdir().unwrap();
        let manager = PageManager::open(tmp.path()).unwrap();
        (tmp, manager)
    }

    /// Returns a zeroed page whose header carries `pid`.
    fn page_for(pid: PageId) -> Page {
        let mut page = Page::zeroed();
        page.header.page_id = pid.0;
        page
    }

    /// A page written to disk is readable through the pool, both on the initial load
    /// and on the subsequent cache hit.
    #[test]
    fn test_fetch_and_cache_hit() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(4);

        let pid = pm.allocate_page().unwrap();
        let mut seeded = page_for(pid);
        seeded.data[0..3].copy_from_slice(b"abc");
        pm.write_page(pid, &seeded).unwrap();

        // First fetch loads from disk.
        let loaded = pool.fetch_page(pid, &mut pm).unwrap();
        assert_eq!(&loaded.data[0..3], b"abc");

        pool.unpin_page(pid, false).unwrap();

        // Second fetch is served from the pool.
        let cached = pool.fetch_page(pid, &mut pm).unwrap();
        assert_eq!(&cached.data[0..3], b"abc");
        pool.unpin_page(pid, false).unwrap();
    }

    /// An updated page reaches disk once it is flushed.
    #[test]
    fn test_dirty_flush() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(4);

        let pid = pm.allocate_page().unwrap();

        let mut initial = page_for(pid);
        initial.data[0] = 42;
        pm.write_page(pid, &initial).unwrap();

        let _ = pool.fetch_page(pid, &mut pm).unwrap();

        let mut changed = page_for(pid);
        changed.data[0] = 99;
        pool.update_page(pid, &changed).unwrap();
        pool.unpin_page(pid, true).unwrap();

        pool.flush_page(pid, &mut pm).unwrap();

        let stored = pm.read_page(pid).unwrap();
        assert_eq!(stored.data[0], 99);
    }

    /// With two frames, fetching a third page succeeds once the first two are unpinned.
    #[test]
    fn test_eviction() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(2);

        let p1 = pm.allocate_page().unwrap();
        let p2 = pm.allocate_page().unwrap();
        let p3 = pm.allocate_page().unwrap();

        for pid in [p1, p2, p3] {
            let mut page = page_for(pid);
            page.data[0] = pid.0 as u8;
            pm.write_page(pid, &page).unwrap();
        }

        // Fill both frames, releasing each pin right away.
        for pid in [p1, p2] {
            let _ = pool.fetch_page(pid, &mut pm).unwrap();
            pool.unpin_page(pid, false).unwrap();
        }

        let third = pool.fetch_page(p3, &mut pm).unwrap();
        assert_eq!(third.data[0], p3.0 as u8);
        pool.unpin_page(p3, false).unwrap();
    }

    /// When every frame is still pinned, a further fetch must fail.
    #[test]
    fn test_all_pinned_error() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(2);

        let p1 = pm.allocate_page().unwrap();
        let p2 = pm.allocate_page().unwrap();
        let p3 = pm.allocate_page().unwrap();

        for pid in [p1, p2, p3] {
            pm.write_page(pid, &page_for(pid)).unwrap();
        }

        // Both frames stay pinned: no unpin calls here.
        let _ = pool.fetch_page(p1, &mut pm).unwrap();
        let _ = pool.fetch_page(p2, &mut pm).unwrap();

        let outcome = pool.fetch_page(p3, &mut pm);
        assert!(outcome.is_err());
    }

    /// `flush_all` writes every dirty page back to disk.
    #[test]
    fn test_flush_all() {
        let (_dir, mut pm) = setup();
        let pool = BufferPool::new(4);

        let p1 = pm.allocate_page().unwrap();
        let p2 = pm.allocate_page().unwrap();

        for pid in [p1, p2] {
            pm.write_page(pid, &page_for(pid)).unwrap();
        }

        for pid in [p1, p2] {
            let _ = pool.fetch_page(pid, &mut pm).unwrap();
        }

        let marked = [(p1, 0xAA_u8), (p2, 0xBB_u8)];

        for (pid, marker) in marked {
            let mut replacement = Page::zeroed();
            replacement.data[0] = marker;
            pool.update_page(pid, &replacement).unwrap();
        }

        for (pid, _) in marked {
            pool.unpin_page(pid, true).unwrap();
        }

        pool.flush_all(&mut pm).unwrap();

        for (pid, marker) in marked {
            let stored = pm.read_page(pid).unwrap();
            assert_eq!(stored.data[0], marker);
        }
    }
}