// disk_queue/lib.rs

1#![warn(
2    missing_docs,
3)]
4
5//! # disk-queue
6//!
7//! FIFO queue backed by disk.
8//!
9//! ## Usage
10//! 
//! ```no_run
//! use disk_queue::DiskQueue;
//!
//! let mut queue = DiskQueue::open("test.db");
//! queue.enqueue("https://sahamee.com".as_bytes().to_vec());
//! let item = queue.dequeue().unwrap();
//! let s = std::str::from_utf8(&item).unwrap();
//! println!("{}", s); // prints "https://sahamee.com"
//! ```
20//!
21
22use std::fs::{File, OpenOptions};
23use std::io::Write;
24use std::path::Path;
25use std::sync::{Arc, Mutex, RwLock};
26
27use memmap::{MmapMut, MmapOptions};
28
29mod constant;
30mod page;
31
32use constant::PAGE_SIZE;
33use page::{
34    buf_write_metadata_page, empty_page_buf,
35    Cursor, Metadata, MetaPage, RecordPage
36};
37
/// FIFO queue backed by disk.
///
/// The first two pages of the backing file are memory-mapped: the metadata
/// page and the current write page. Pages behind the write page are read
/// back from the file on demand.
pub struct DiskQueue {
    // Backing file; guarded by a mutex for page save/load I/O.
    file: Arc<Mutex<File>>,
    // Mmap-backed metadata page (page/item counts, read/write cursors).
    meta_page: Arc<RwLock<MetaPage>>,
    // Page currently being dequeued from; may share memory with the
    // mmap-backed write page when the read cursor is on the last page.
    read_page: Arc<RwLock<RecordPage>>,
    // Mmap-backed page currently being enqueued into.
    write_page: Arc<RwLock<RecordPage>>,
    // This protects reading/writing from write_page_mem
    rwlatch: Arc<RwLock<()>>,
    // Cast to usize so it can be sent safely between threads
    // Real type is `*const u8`
    write_page_mem: usize,
    // Held only to keep the mapping alive for the lifetime of the queue.
    _mmap: MmapMut,
}
51
52impl DiskQueue {
53    /// Open existing queue from file or create one if not exist.
54    pub fn open(path: &str) -> Self {
55        // Check if file exists, if it doesn't initialize file and close it
56        if !Path::new(path).exists() {
57            let mut file = File::create(path).unwrap();
58            let meta = Metadata {
59                num_pages: 1,
60                num_items: 0,
61                read_cursor: Cursor { pageid: 1, slotid: 0 },
62                write_cursor: Cursor { pageid: 1, slotid: 0 },
63            };
64            let meta_page_buf = buf_write_metadata_page(&meta);
65            let write_page_buf = empty_page_buf();
66            assert_eq!(meta_page_buf.len(), PAGE_SIZE);
67            assert_eq!(write_page_buf.len(), PAGE_SIZE);
68            file.write(&meta_page_buf).unwrap();
69            file.write(&write_page_buf).unwrap();
70            file.sync_all().unwrap();
71        }
72        
73        // Open the file, mmap-ing the first two pages
74        let file = OpenOptions::new()
75            .read(true)
76            .write(true)
77            .open(path)
78            .unwrap();
79        let mmap = unsafe {
80            MmapOptions::new()
81                .len(2 * PAGE_SIZE)
82                .map_mut(&file)
83                .unwrap()
84        };
85        
86        let meta_page_mem = mmap.as_ptr() as usize;
87        let write_page_mem = unsafe {
88            mmap.as_ptr().add(PAGE_SIZE)
89        } as usize;
90        let file = Arc::new(Mutex::new(file));
91
92        let rwlatch = Arc::new(RwLock::new(()));
93        
94        let meta_page = Arc::new(RwLock::new(MetaPage::from_mmap_ptr(meta_page_mem)));
95        let write_page = Arc::new(RwLock::new(
96            RecordPage::from_mmap_ptr(rwlatch.clone(), write_page_mem))
97        );
98        
99        let read_page;
100        {
101            let meta_page = meta_page.write().unwrap();
102            let num_pages = meta_page.get_num_pages();
103            let read_cursor = meta_page.get_read_cursor();
104            if read_cursor.pageid == num_pages {
105                read_page = Arc::new(RwLock::new(
106                    RecordPage::from_mmap_ptr(rwlatch.clone(), write_page_mem)
107                ));
108            } else {
109                read_page = Arc::new(RwLock::new(
110                    RecordPage::from_file(file.clone(), read_cursor.pageid)
111                ));
112            }
113        }
114
115        Self {
116            file,
117            meta_page,
118            read_page,
119            write_page,
120            write_page_mem,
121            rwlatch,
122            _mmap: mmap,
123        }
124    }
125    
126    /// Get number of items stored in the queue.
127    pub fn num_items(&self) -> u64 {
128        let meta_page = self.meta_page.read().unwrap();
129        meta_page.get_num_items()
130    }
131
132    /// Enqueue item
133    pub fn enqueue(&self, record: Vec<u8>) {
134        let mut meta_page = self.meta_page.write().unwrap();
135        let mut write_page = self.write_page.write().unwrap();
136        if write_page.can_insert(&record) {
137            // Case 1: the write page can still hold the record
138
139            write_page.insert(record);
140            
141            meta_page.incr_num_items();
142            let mut write_cursor = meta_page.get_write_cursor();
143            write_cursor.slotid += 1;
144            meta_page.set_write_cursor(write_cursor);
145        } else {
146            // Case 2: the write page cannot hold the new record
147            //
148            // This should write the page to disk and reset the write page
149            
150            // Copy write page to a new page and reset write page
151            let pageid = meta_page.get_num_pages();
152            write_page.save(self.file.clone(), pageid);
153            write_page.reset();
154            
155            write_page.insert(record);
156
157            let mut write_cursor = meta_page.get_write_cursor();
158            let mut read_cursor = meta_page.get_read_cursor();
159            
160            // There are two cases here:
161            // 1. Read page points to the write page (i.e. it shares the same 
162            //    underlying memory)
163            // 2. Read page points to a read-only page from disk.
164            //
165            // In case 2, we don't have to do anything.
166            //
167            // In case 1, we further need to determine if read_cursor is 
168            // equal to write_cursor. 
169            // If it is, then the read page should still point to write page. 
170            // Nothing should be done.
171            // If it is not, we need to load read_page from a recently written 
172            // page.
173            {
174                let mut read_page = self.read_page.write().unwrap();
175                if read_page.is_shared_mem() && 
176                   read_cursor != write_cursor 
177                {
178                    let rp = RecordPage::from_file(
179                        self.file.clone(),
180                        pageid,
181                    );
182                    *read_page = rp;
183                }
184            }
185
186            // Note that slotid is 1 since we just inserted a new record on 
187            // the newly inserted page
188            //
189            // Also, we need to fix read cursor to point to a new page if it 
190            // points to the same cursor as write cursor
191            if read_cursor == write_cursor {
192                read_cursor.pageid += 1;
193                read_cursor.slotid = 0;
194                meta_page.set_read_cursor(read_cursor);
195            }
196            write_cursor.pageid += 1;
197            write_cursor.slotid = 1;
198
199            meta_page.incr_num_items();
200            meta_page.incr_num_pages();
201            meta_page.set_write_cursor(write_cursor);
202        }
203    }
204
205    /// Dequeue item
206    pub fn dequeue(&self) -> Option<Vec<u8>> {
207        let mut meta_page = self.meta_page.write().unwrap();
208        
209        let mut assign_write_to_read_page = false;
210        let mut read_next_page = false;
211        let mut read_cursor;
212        let record;
213        {
214            let read_page = self.read_page.read().unwrap();
215            
216            let num_pages = meta_page.get_num_pages();
217            let num_records = read_page.num_records();
218            read_cursor = meta_page.get_read_cursor();
219            let write_cursor = meta_page.get_write_cursor();
220            
221            if read_cursor == write_cursor {
222                return None;
223            }
224            
225            match read_page.get_record(read_cursor.slotid as usize) {
226                Some(r) => record = r,
227                None => panic!("Invariant violated"),
228            }
229            if read_cursor.slotid + 1 < num_records as u16 || 
230               read_cursor.pageid == write_cursor.pageid {
231                read_cursor.slotid += 1;
232                meta_page.set_read_cursor(read_cursor.clone());
233            } else {
234                read_cursor.pageid += 1;
235                read_cursor.slotid = 0;
236                meta_page.set_read_cursor(read_cursor.clone());
237                
238                assert!(read_cursor.pageid <= num_pages);
239                if read_cursor.pageid == num_pages {
240                    assign_write_to_read_page = true;
241                } else {
242                    read_next_page = true;
243                }
244            }
245        }
246        
247        if assign_write_to_read_page {
248            let mut read_page = self.read_page.write().unwrap();
249            *read_page = RecordPage::from_mmap_ptr(
250                self.rwlatch.clone(), 
251                self.write_page_mem,
252            );
253        } else if read_next_page {
254            let mut read_page = self.read_page.write().unwrap();
255            *read_page = RecordPage::from_file(
256                self.file.clone(),
257                read_cursor.pageid,
258            );
259        }
260        
261        Some(record)
262    }
263}
264
#[cfg(test)]
mod tests {
    use rand::RngCore;
    use std::sync::Condvar;
    use std::sync::atomic::{AtomicU32, Ordering};
    use super::*;

    const TEST_DB_PATH: &str = "test.db";

    /// Remove the test database, retrying until it is actually gone.
    fn cleanup_test_db() {
        loop {
            // Ignore removal errors: on a retry the file may already be
            // gone (NotFound), which previously panicked via unwrap().
            // The exists() check below decides whether to retry.
            let _ = std::fs::remove_file(TEST_DB_PATH);
            if !Path::new(TEST_DB_PATH).exists() {
                break;
            }
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
    }

    #[test]
    fn basic() {
        {
            let records = vec![
                "https://www.google.com".as_bytes().to_vec(),
                "https://www.dexcode.com".as_bytes().to_vec(),
                "https://sahamee.com".as_bytes().to_vec(),
            ];

            let queue = DiskQueue::open(TEST_DB_PATH);
            for record in records.iter() {
                queue.enqueue(record.clone());
            }

            let mut popped_records = vec![];
            loop {
                match queue.dequeue() {
                    Some(record) => popped_records.push(record),
                    None => break,
                }
            }

            // FIFO order must be preserved.
            assert_eq!(records, popped_records);
        }

        cleanup_test_db();
    }

    /// Interleave enqueues and dequeues on one thread with the given
    /// read:write ratio, then verify FIFO order over `num_records` items.
    fn test_read_write_single_threaded(
        num_records: usize,
        read_ratio: u32,
        write_ratio: u32
    ) {
        // Requires the ratio sum to be even (both call sites pass a sum
        // of 4). NOTE(review): possibly intended as a power-of-two check
        // `(x & (x - 1)) == 0` — confirm; modulo below works either way.
        assert_eq!((read_ratio + write_ratio) & 1, 0);

        {
            let mut records = vec![];
            for i in 0..num_records {
                let s = format!("record_{}", i);
                records.push(s.as_bytes().to_vec());
            }

            let mut popped_records = vec![];

            let queue = DiskQueue::open(TEST_DB_PATH);

            let mut enqueue_finished = false;
            let mut dequeue_finished = false;
            let mut rng = rand::thread_rng();
            let mut records_iter = records.iter();
            loop {
                // Pick dequeue vs enqueue with probability
                // read_ratio : write_ratio.
                let num = rng.next_u32() % (read_ratio + write_ratio);
                if num < read_ratio {
                    // Dequeue
                    match queue.dequeue() {
                        Some(r) => {
                            popped_records.push(r)
                        }
                        None => {
                            if enqueue_finished {
                                dequeue_finished = true;
                            }
                        }
                    }
                } else {
                    // Enqueue
                    match records_iter.next() {
                        Some(r) => queue.enqueue(r.clone()),
                        None => enqueue_finished = true,
                    }
                }

                if enqueue_finished && dequeue_finished {
                    break;
                }
            }

            for (idx, record) in records.iter().enumerate() {
                let empty_vec = vec![];
                let popped_record = popped_records.get(idx).unwrap_or(&empty_vec);
                // Compare as lossy strings so a mismatch prints readably.
                assert_eq!(
                    String::from_utf8_lossy(record),
                    String::from_utf8_lossy(popped_record)
                );
            }
        }

        cleanup_test_db();
    }

    #[test]
    // Test reading & writing a lot of pages with read-write ratio of 1:3
    fn multiple_pages() {
        test_read_write_single_threaded(10000, 1, 3);
    }

    #[test]
    // Test reading & writing a lot of pages with read-write ratio of 3:1
    fn read_plenty() {
        test_read_write_single_threaded(10000, 3, 1);
    }

    #[test]
    fn multithreaded() {
        {
            let done_writing_mut = Arc::new(Mutex::new(false));

            let read_count = Arc::new(AtomicU32::new(0));

            // Barrier-style gate: readers and writers wait until all 8
            // writers have generated their records.
            let write_ready_mutex = Arc::new(Mutex::new(0));
            let write_ready_cond = Arc::new(Condvar::new());

            let disk_queue = Arc::new(DiskQueue::open(TEST_DB_PATH));

            // Spawn 8 read threads
            let mut read_handles = vec![];
            for _ in 0..8 {
                let dq = disk_queue.clone();
                let write_ready_mutex = write_ready_mutex.clone();
                let write_ready_cond = write_ready_cond.clone();
                let done_writing_mut = done_writing_mut.clone();
                let read_count = read_count.clone();
                let h = std::thread::spawn(move || {
                    {
                        let mut write_ready = write_ready_mutex.lock().unwrap();
                        while *write_ready < 8 {
                            write_ready = write_ready_cond.wait(write_ready).unwrap();
                        }
                    }

                    // TODO: Read threads is busy looping when there are no items
                    // Perhaps use condition variable to wake up read thread?
                    loop {
                        if let Some(_) = dq.dequeue() {
                            read_count.fetch_add(1, Ordering::Relaxed);
                        } else {
                            // Only exit once writers are done AND the queue
                            // reads empty.
                            let done_writing = done_writing_mut.lock().unwrap();
                            if *done_writing {
                                break;
                            }
                        }
                    }
                });
                read_handles.push(h);
            }

            // Spawn 8 write threads
            let mut write_handles = vec![];
            for tid in 0..8 {
                let dq = disk_queue.clone();
                let write_ready_mutex = write_ready_mutex.clone();
                let write_ready_cond = write_ready_cond.clone();
                let h = std::thread::spawn(move || {
                    // Generate records
                    let mut records = vec![];
                    for i in 0..1000 {
                        let s = format!("record_t{}_{}", tid, i);
                        records.push(s.as_bytes().to_vec());
                    }

                    // Increment write ready
                    {
                        let mut write_ready = write_ready_mutex.lock().unwrap();
                        *write_ready += 1;
                        if *write_ready >= 8 {
                            println!("All threads started");
                            write_ready_cond.notify_all();
                        } else {
                            while *write_ready < 8 {
                                write_ready = write_ready_cond
                                    .wait(write_ready).unwrap();
                            }
                        }
                    }

                    println!("Write thread {} start enqueue-ing items", tid);

                    // Start enqueue-ing items
                    for record in records {
                        dq.enqueue(record);
                    }

                    println!("Write thread {} done", tid);
                });
                write_handles.push(h);
            }

            // Wait for all write threads to be ready
            {
                let mut write_ready = write_ready_mutex.lock().unwrap();
                while *write_ready < 8 {
                    write_ready = write_ready_cond.wait(write_ready).unwrap();
                }
            }

            // Wait for write threads to finish and set done_writing
            for h in write_handles {
                h.join().unwrap();
            }
            {
                let mut done_writing = done_writing_mut.lock().unwrap();
                *done_writing = true;
            }

            // Wait for all read threads to finish
            for h in read_handles {
                h.join().unwrap();
            }

            // 8 writers x 1000 records each must all come back out.
            // Use load() for a plain read instead of fetch_add(0).
            assert_eq!(read_count.load(Ordering::Relaxed), 8000);
        }

        cleanup_test_db();
    }
}
499