1#![warn(
2 missing_docs,
3)]
4
5use std::fs::{File, OpenOptions};
23use std::io::Write;
24use std::path::Path;
25use std::sync::{Arc, Mutex, RwLock};
26
27use memmap::{MmapMut, MmapOptions};
28
29mod constant;
30mod page;
31
32use constant::PAGE_SIZE;
33use page::{
34 buf_write_metadata_page, empty_page_buf,
35 Cursor, Metadata, MetaPage, RecordPage
36};
37
/// A FIFO queue of byte records persisted to a single file on disk.
///
/// The first page of the backing file holds queue metadata and the
/// second page holds the active write page; both are memory-mapped.
/// Full pages are written out to the file through ordinary file I/O
/// (see `enqueue`).
pub struct DiskQueue {
    /// Backing file; shared with `RecordPage`s that load/save pages.
    file: Arc<Mutex<File>>,
    /// Memory-mapped metadata page (cursors, item and page counts).
    meta_page: Arc<RwLock<MetaPage>>,
    /// Page records are currently dequeued from. May alias the mmap'd
    /// write page once the reader has caught up with the writer.
    read_page: Arc<RwLock<RecordPage>>,
    /// Memory-mapped page records are enqueued into.
    write_page: Arc<RwLock<RecordPage>>,
    /// Latch passed to every page that shares the mmap'd write-page
    /// memory — presumably serializes access to that shared memory;
    /// TODO confirm against the `page` module.
    rwlatch: Arc<RwLock<()>>,
    /// Raw address of the mmap'd write page, kept so `dequeue` can
    /// re-alias the read page to it.
    write_page_mem: usize,
    /// Keeps the memory map alive for the queue's lifetime; the raw
    /// pointers above point into this mapping.
    _mmap: MmapMut,
}
51
52impl DiskQueue {
53 pub fn open(path: &str) -> Self {
55 if !Path::new(path).exists() {
57 let mut file = File::create(path).unwrap();
58 let meta = Metadata {
59 num_pages: 1,
60 num_items: 0,
61 read_cursor: Cursor { pageid: 1, slotid: 0 },
62 write_cursor: Cursor { pageid: 1, slotid: 0 },
63 };
64 let meta_page_buf = buf_write_metadata_page(&meta);
65 let write_page_buf = empty_page_buf();
66 assert_eq!(meta_page_buf.len(), PAGE_SIZE);
67 assert_eq!(write_page_buf.len(), PAGE_SIZE);
68 file.write(&meta_page_buf).unwrap();
69 file.write(&write_page_buf).unwrap();
70 file.sync_all().unwrap();
71 }
72
73 let file = OpenOptions::new()
75 .read(true)
76 .write(true)
77 .open(path)
78 .unwrap();
79 let mmap = unsafe {
80 MmapOptions::new()
81 .len(2 * PAGE_SIZE)
82 .map_mut(&file)
83 .unwrap()
84 };
85
86 let meta_page_mem = mmap.as_ptr() as usize;
87 let write_page_mem = unsafe {
88 mmap.as_ptr().add(PAGE_SIZE)
89 } as usize;
90 let file = Arc::new(Mutex::new(file));
91
92 let rwlatch = Arc::new(RwLock::new(()));
93
94 let meta_page = Arc::new(RwLock::new(MetaPage::from_mmap_ptr(meta_page_mem)));
95 let write_page = Arc::new(RwLock::new(
96 RecordPage::from_mmap_ptr(rwlatch.clone(), write_page_mem))
97 );
98
99 let read_page;
100 {
101 let meta_page = meta_page.write().unwrap();
102 let num_pages = meta_page.get_num_pages();
103 let read_cursor = meta_page.get_read_cursor();
104 if read_cursor.pageid == num_pages {
105 read_page = Arc::new(RwLock::new(
106 RecordPage::from_mmap_ptr(rwlatch.clone(), write_page_mem)
107 ));
108 } else {
109 read_page = Arc::new(RwLock::new(
110 RecordPage::from_file(file.clone(), read_cursor.pageid)
111 ));
112 }
113 }
114
115 Self {
116 file,
117 meta_page,
118 read_page,
119 write_page,
120 write_page_mem,
121 rwlatch,
122 _mmap: mmap,
123 }
124 }
125
126 pub fn num_items(&self) -> u64 {
128 let meta_page = self.meta_page.read().unwrap();
129 meta_page.get_num_items()
130 }
131
132 pub fn enqueue(&self, record: Vec<u8>) {
134 let mut meta_page = self.meta_page.write().unwrap();
135 let mut write_page = self.write_page.write().unwrap();
136 if write_page.can_insert(&record) {
137 write_page.insert(record);
140
141 meta_page.incr_num_items();
142 let mut write_cursor = meta_page.get_write_cursor();
143 write_cursor.slotid += 1;
144 meta_page.set_write_cursor(write_cursor);
145 } else {
146 let pageid = meta_page.get_num_pages();
152 write_page.save(self.file.clone(), pageid);
153 write_page.reset();
154
155 write_page.insert(record);
156
157 let mut write_cursor = meta_page.get_write_cursor();
158 let mut read_cursor = meta_page.get_read_cursor();
159
160 {
174 let mut read_page = self.read_page.write().unwrap();
175 if read_page.is_shared_mem() &&
176 read_cursor != write_cursor
177 {
178 let rp = RecordPage::from_file(
179 self.file.clone(),
180 pageid,
181 );
182 *read_page = rp;
183 }
184 }
185
186 if read_cursor == write_cursor {
192 read_cursor.pageid += 1;
193 read_cursor.slotid = 0;
194 meta_page.set_read_cursor(read_cursor);
195 }
196 write_cursor.pageid += 1;
197 write_cursor.slotid = 1;
198
199 meta_page.incr_num_items();
200 meta_page.incr_num_pages();
201 meta_page.set_write_cursor(write_cursor);
202 }
203 }
204
205 pub fn dequeue(&self) -> Option<Vec<u8>> {
207 let mut meta_page = self.meta_page.write().unwrap();
208
209 let mut assign_write_to_read_page = false;
210 let mut read_next_page = false;
211 let mut read_cursor;
212 let record;
213 {
214 let read_page = self.read_page.read().unwrap();
215
216 let num_pages = meta_page.get_num_pages();
217 let num_records = read_page.num_records();
218 read_cursor = meta_page.get_read_cursor();
219 let write_cursor = meta_page.get_write_cursor();
220
221 if read_cursor == write_cursor {
222 return None;
223 }
224
225 match read_page.get_record(read_cursor.slotid as usize) {
226 Some(r) => record = r,
227 None => panic!("Invariant violated"),
228 }
229 if read_cursor.slotid + 1 < num_records as u16 ||
230 read_cursor.pageid == write_cursor.pageid {
231 read_cursor.slotid += 1;
232 meta_page.set_read_cursor(read_cursor.clone());
233 } else {
234 read_cursor.pageid += 1;
235 read_cursor.slotid = 0;
236 meta_page.set_read_cursor(read_cursor.clone());
237
238 assert!(read_cursor.pageid <= num_pages);
239 if read_cursor.pageid == num_pages {
240 assign_write_to_read_page = true;
241 } else {
242 read_next_page = true;
243 }
244 }
245 }
246
247 if assign_write_to_read_page {
248 let mut read_page = self.read_page.write().unwrap();
249 *read_page = RecordPage::from_mmap_ptr(
250 self.rwlatch.clone(),
251 self.write_page_mem,
252 );
253 } else if read_next_page {
254 let mut read_page = self.read_page.write().unwrap();
255 *read_page = RecordPage::from_file(
256 self.file.clone(),
257 read_cursor.pageid,
258 );
259 }
260
261 Some(record)
262 }
263}
264
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use std::sync::Condvar;
    use std::sync::atomic::{AtomicU32, Ordering};
    use super::*;

    /// Deletes the database file at `path`, retrying until the OS
    /// reports it gone (removal can lag briefly on some platforms).
    ///
    /// Each test uses its own file so the tests stay independent when
    /// the harness runs them in parallel (the previous shared
    /// `test.db` made them clobber each other).
    fn cleanup_test_db(path: &str) {
        loop {
            // Ignore the result: on a retry the file is already gone
            // and remove_file fails, but the exit check then succeeds.
            // (The old `.unwrap()` here panicked on any retry.)
            let _ = std::fs::remove_file(path);
            if !Path::new(path).exists() {
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
    }

    /// Enqueue a few records, drain the queue, and check FIFO order.
    #[test]
    fn basic() {
        const DB: &str = "test_basic.db";
        {
            let records = vec![
                "https://www.google.com".as_bytes().to_vec(),
                "https://www.dexcode.com".as_bytes().to_vec(),
                "https://sahamee.com".as_bytes().to_vec(),
            ];

            let queue = DiskQueue::open(DB);
            for record in records.iter() {
                queue.enqueue(record.clone());
            }

            let mut popped_records = vec![];
            loop {
                match queue.dequeue() {
                    Some(record) => popped_records.push(record),
                    None => break,
                }
            }

            assert_eq!(records, popped_records);
        }

        cleanup_test_db(DB);
    }

    /// Randomly interleaves enqueues and dequeues on one thread with
    /// the given read/write ratio and verifies records come out in
    /// FIFO order. `path` is the per-test database file.
    fn test_read_write_single_threaded(
        path: &str,
        num_records: usize,
        read_ratio: u32,
        write_ratio: u32
    ) {
        assert_eq!((read_ratio + write_ratio) & 1, 0);

        {
            let mut records = vec![];
            for i in 0..num_records {
                let s = format!("record_{}", i);
                records.push(s.as_bytes().to_vec());
            }

            let mut popped_records = vec![];

            let queue = DiskQueue::open(path);

            let mut enqueue_finished = false;
            let mut dequeue_finished = false;
            let mut rng = rand::thread_rng();
            let mut records_iter = records.iter();
            loop {
                // Pick read vs. write with probability proportional to
                // the requested ratios.
                let num = rng.next_u32() % (read_ratio + write_ratio);
                if num < read_ratio {
                    match queue.dequeue() {
                        Some(r) => {
                            popped_records.push(r)
                        }
                        None => {
                            // Queue is drained only once enqueueing
                            // has stopped producing new records.
                            if enqueue_finished {
                                dequeue_finished = true;
                            }
                        }
                    }
                } else {
                    match records_iter.next() {
                        Some(r) => queue.enqueue(r.clone()),
                        None => enqueue_finished = true,
                    }
                }

                if enqueue_finished && dequeue_finished {
                    break;
                }
            }

            for (idx, record) in records.iter().enumerate() {
                let empty_vec = vec![];
                let popped_record = popped_records.get(idx).unwrap_or(&empty_vec);
                assert_eq!(
                    String::from_utf8_lossy(record),
                    String::from_utf8_lossy(popped_record)
                );
            }
        }

        cleanup_test_db(path);
    }

    /// Enough records to force the queue onto multiple pages.
    #[test]
    fn multiple_pages() {
        test_read_write_single_threaded("test_multiple_pages.db", 10000, 1, 3);
    }

    /// Read-heavy workload: dequeues outnumber enqueues 3:1.
    #[test]
    fn read_plenty() {
        test_read_write_single_threaded("test_read_plenty.db", 10000, 3, 1);
    }

    /// 8 writer threads each enqueue 1000 records while 8 reader
    /// threads drain the queue; verifies every record is consumed
    /// exactly once.
    #[test]
    fn multithreaded() {
        const DB: &str = "test_multithreaded.db";
        {
            let done_writing_mut = Arc::new(Mutex::new(false));

            let read_count = Arc::new(AtomicU32::new(0));

            let write_ready_mutex = Arc::new(Mutex::new(0));
            let write_ready_cond = Arc::new(Condvar::new());

            let disk_queue = Arc::new(DiskQueue::open(DB));

            // Readers: wait until all 8 writers have started, then
            // dequeue until the queue is empty AND writing is done.
            let mut read_handles = vec![];
            for _ in 0..8 {
                let dq = disk_queue.clone();
                let write_ready_mutex = write_ready_mutex.clone();
                let write_ready_cond = write_ready_cond.clone();
                let done_writing_mut = done_writing_mut.clone();
                let read_count = read_count.clone();
                let h = std::thread::spawn(move || {
                    {
                        let mut write_ready = write_ready_mutex.lock().unwrap();
                        while *write_ready < 8 {
                            write_ready = write_ready_cond.wait(write_ready).unwrap();
                        }
                    }

                    loop {
                        if let Some(_) = dq.dequeue() {
                            read_count.fetch_add(1, Ordering::Relaxed);
                        } else {
                            // Empty queue is final only after all
                            // writers have finished (set post-join).
                            let done_writing = done_writing_mut.lock().unwrap();
                            if *done_writing {
                                break;
                            }
                        }
                    }
                });
                read_handles.push(h);
            }

            // Writers: build their records, rendezvous on the condvar,
            // then enqueue concurrently.
            let mut write_handles = vec![];
            for tid in 0..8 {
                let dq = disk_queue.clone();
                let write_ready_mutex = write_ready_mutex.clone();
                let write_ready_cond = write_ready_cond.clone();
                let h = std::thread::spawn(move || {
                    let mut records = vec![];
                    for i in 0..1000 {
                        let s = format!("record_t{}_{}", tid, i);
                        records.push(s.as_bytes().to_vec());
                    }

                    {
                        let mut write_ready = write_ready_mutex.lock().unwrap();
                        *write_ready += 1;
                        if *write_ready >= 8 {
                            println!("All threads started");
                            write_ready_cond.notify_all();
                        } else {
                            while *write_ready < 8 {
                                write_ready = write_ready_cond
                                    .wait(write_ready).unwrap();
                            }
                        }
                    }

                    println!("Write thread {} start enqueue-ing items", tid);

                    for record in records {
                        dq.enqueue(record);
                    }

                    println!("Write thread {} done", tid);
                });
                write_handles.push(h);
            }

            {
                let mut write_ready = write_ready_mutex.lock().unwrap();
                while *write_ready < 8 {
                    write_ready = write_ready_cond.wait(write_ready).unwrap();
                }
            }

            for h in write_handles {
                h.join().unwrap();
            }
            {
                let mut done_writing = done_writing_mut.lock().unwrap();
                *done_writing = true;
            }

            for h in read_handles {
                h.join().unwrap();
            }

            // `load` is the idiomatic read; fetch_add(0) was a
            // needless read-modify-write.
            assert_eq!(read_count.load(Ordering::Relaxed), 8000);
        }

        cleanup_test_db(DB);
    }
}
499