// bf_tree/wal/mod.rs
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
use std::path::Path;
use std::sync::Arc;

#[cfg(unix)]
use std::os::unix::fs::FileExt;
#[cfg(windows)]
use std::os::windows::fs::FileExt;

mod operations;

use crate::config::WalConfig;
use crate::fs::VfsImpl;
use crate::storage::make_vfs;
use crate::sync::{atomic::AtomicBool, Condvar, Mutex};

pub(crate) use operations::{LogEntry, WriteOp};
20
/// All WAL buffers are aligned to this size so they can be handed directly
/// to the storage device.
const BLOCK_SIZE: usize = 512;

/// Serialization contract for anything that can be appended to the WAL.
pub(crate) trait LogEntryImpl<'a> {
    /// Number of bytes `write_to_buffer` will produce (payload only,
    /// excluding the `LogHeader`).
    fn log_size(&self) -> usize;
    /// Serialize the entry into `buffer`, which is exactly `log_size()` bytes.
    fn write_to_buffer(&self, buffer: &mut [u8]);
    /// Deserialize an entry from `buffer`.
    fn read_from_buffer(buffer: &'a [u8]) -> Self;
}

/// Ptr aligned to block size, so that it can be directly write to storage device
struct RawBuffer {
    buffer_size: usize,
    ptr: *mut u8,
}

impl RawBuffer {
    /// Allocate a zero-initialized, `BLOCK_SIZE`-aligned buffer.
    ///
    /// Zero-initialization matters: `flush` writes the whole buffer to disk
    /// even when only partially filled, and the reader treats a zeroed
    /// `log_len` as end-of-segment. It also keeps `as_slice` sound — reading
    /// uninitialized memory would be undefined behavior.
    fn new(buffer_size: usize) -> RawBuffer {
        let layout = std::alloc::Layout::from_size_align(buffer_size, BLOCK_SIZE).unwrap();
        // SAFETY: `layout` has a valid, power-of-two alignment; `buffer_size`
        // is expected to be non-zero (it is a configured segment size).
        let ptr = unsafe { std::alloc::alloc_zeroed(layout) };
        if ptr.is_null() {
            // Allocation failed: abort via the standard handler instead of
            // returning a null pointer that would be UB to dereference.
            std::alloc::handle_alloc_error(layout);
        }
        RawBuffer { ptr, buffer_size }
    }

    fn as_slice(&self) -> &[u8] {
        // SAFETY: `ptr` points to `buffer_size` initialized (zeroed) bytes
        // owned by `self`; the returned lifetime is tied to `&self`.
        unsafe { std::slice::from_raw_parts(self.ptr, self.buffer_size) }
    }

    /// # Safety
    /// Caller must guarantee `offset + size <= buffer_size`.
    unsafe fn as_mut_slice_at_exact(&mut self, offset: usize, size: usize) -> &mut [u8] {
        debug_assert!(offset + size <= self.buffer_size, "slice out of bounds");
        unsafe { std::slice::from_raw_parts_mut(self.ptr.add(offset), size) }
    }
}

// SAFETY: RawBuffer exclusively owns its heap allocation; all shared access
// is synchronized externally (it lives inside a Mutex in WriteAheadLogInner).
unsafe impl Send for RawBuffer {}
unsafe impl Sync for RawBuffer {}

impl Drop for RawBuffer {
    fn drop(&mut self) {
        // Must reconstruct the exact layout used in `new` for dealloc to be sound.
        let layout = std::alloc::Layout::from_size_align(self.buffer_size, BLOCK_SIZE).unwrap();
        // SAFETY: `ptr` was allocated in `new` with this same layout.
        unsafe { std::alloc::dealloc(self.ptr, layout) };
    }
}
60
/// Mutable WAL state; protected by the mutex in `WriteAheadLog`.
struct WriteAheadLogInner {
    /// In-memory staging buffer for log entries, block-aligned for direct I/O.
    buffer: RawBuffer,
    /// Backing log file, behind the virtual file system abstraction.
    file_handle: Arc<dyn VfsImpl>,
    /// Offset of the next free byte within `buffer`.
    buffer_cursor: usize,
    /// File offset at which `buffer` will be written on the next flush.
    file_offset: usize,
    /// LSN to hand out to the next appended entry.
    next_lsn: u64,
    /// Highest LSN known to have been written to storage.
    flushed_lsn: u64,
    /// Set by writers to request an immediate flush from the background job.
    need_flush: bool,
}
70
impl WriteAheadLogInner {
    /// Write the current buffer to storage and update bookkeeping.
    ///
    /// The entire (block-aligned) buffer is written even when it is only
    /// partially filled; the zeroed tail is what the reader interprets as
    /// end-of-segment.
    fn flush(&mut self) {
        if self.buffer_cursor == 0 {
            // nothing to flush
            return;
        }

        self.clear_next_header();
        self.file_handle
            .write(self.file_offset, self.buffer.as_slice());

        if !self.should_inplace_flush() {
            // Buffer at least half full: advance to a fresh segment. Otherwise
            // keep appending to the same buffer and rewrite this segment in
            // place on the next flush.
            self.file_offset += self.buffer.buffer_size;
            self.buffer_cursor = 0;
        }

        // buffer_cursor > 0 implies at least one LSN was allocated, so
        // next_lsn >= 1 and this subtraction cannot underflow.
        self.flushed_lsn = self.next_lsn - 1;
        self.need_flush = false;
    }

    /// Zero the first 8 bytes after the cursor — the `log_len` field of the
    /// next would-be `LogHeader` — so a reader replaying this segment sees
    /// `log_len == 0` and stops at the end of the valid data.
    fn clear_next_header(&mut self) {
        if self.buffer_cursor + 8 <= self.buffer.buffer_size {
            // SAFETY: the bounds check above guarantees the 8-byte window fits.
            let slice = unsafe { self.buffer.as_mut_slice_at_exact(self.buffer_cursor, 8) };
            slice.copy_from_slice(&[0u8; 8]);
        }
    }

    /// Reserve `size` bytes in the buffer and return them for writing.
    ///
    /// # Safety
    /// Caller must ensure `buffer_cursor + size <= buffer_size`; this is
    /// only checked via `debug_assert` here.
    unsafe fn alloc_buffer(&mut self, size: usize) -> &mut [u8] {
        debug_assert!(
            self.buffer_cursor + size <= self.buffer.buffer_size,
            "buffer overflow"
        );
        let cursor = self.buffer_cursor;
        self.buffer_cursor += size;
        unsafe { self.buffer.as_mut_slice_at_exact(cursor, size) }
    }

    /// if buffer is less than half full, we should not create a new buffer,
    /// instead inplace flush the buffer
    fn should_inplace_flush(&self) -> bool {
        self.buffer_cursor < (self.buffer.buffer_size / 2)
    }

    /// Hand out the next log sequence number.
    fn alloc_lsn(&mut self) -> u64 {
        let lsn = self.next_lsn;
        self.next_lsn += 1;
        lsn
    }
}
120
/// Write-ahead log with a single in-memory buffer that is flushed by a
/// dedicated background thread (group commit).
pub(crate) struct WriteAheadLog {
    inner: Mutex<WriteAheadLogInner>,
    flushed_cond: Condvar,    // wakes workers that are waiting for a flush
    need_flush_cond: Condvar, // wakes the background flush job
    /// Cleared by `stop_background_job` to shut the flush thread down.
    background_job_running: AtomicBool,
    config: Arc<WalConfig>,
}
128
129impl WriteAheadLog {
130    /// Create a new wal instance, and start a background thread to flush wal buffer.
131    pub(crate) fn new(config: Arc<WalConfig>) -> Arc<Self> {
132        let vfs = make_vfs(&config.storage_backend, &config.file_path);
133        let wal = WriteAheadLog {
134            inner: Mutex::new(WriteAheadLogInner {
135                buffer: RawBuffer::new(config.segment_size),
136                file_handle: vfs,
137                buffer_cursor: 0,
138                file_offset: 0,
139                next_lsn: 0,
140                flushed_lsn: 0,
141                need_flush: false,
142            }),
143            flushed_cond: Condvar::new(),
144            need_flush_cond: Condvar::new(),
145            background_job_running: AtomicBool::new(true),
146            config,
147        };
148
149        let wal = Arc::new(wal);
150        WriteAheadLog::start_flush_job(wal.clone());
151        wal
152    }
153
154    fn start_flush_job(wal: Arc<Self>) {
155        let h = crate::sync::thread::spawn(move || wal.background_flush_job());
156        drop(h); // detach the thread
157    }
158
159    pub(crate) fn stop_background_job(&self) {
160        self.background_job_running
161            .store(false, std::sync::atomic::Ordering::Relaxed);
162        self.need_flush_cond.notify_all();
163    }
164
165    pub(crate) fn background_flush_job(&self) {
166        let mut inner = self.inner.lock().unwrap();
167
168        let flush_interval = self.config.flush_interval;
169        let mut last_flush = std::time::Instant::now();
170        loop {
171            let v = self
172                .need_flush_cond
173                .wait_timeout(inner, flush_interval)
174                // wait for a notification or a interval, whichever happens first.
175                .unwrap();
176
177            inner = v.0;
178
179            if !self
180                .background_job_running
181                .load(std::sync::atomic::Ordering::Relaxed)
182            {
183                // stop the background job, gracefully shutdown.
184                break;
185            }
186
187            if inner.need_flush || last_flush.elapsed() > flush_interval {
188                inner.flush();
189                last_flush = std::time::Instant::now();
190                self.flushed_cond.notify_all();
191            }
192        }
193    }
194
195    #[must_use = "The returned flushed lsn must be write to page meta"]
196    pub(crate) fn append_and_wait<'a>(
197        &self,
198        log_entry: &impl LogEntryImpl<'a>,
199        page_offset: u64,
200    ) -> u64 {
201        let mut inner = self.inner.lock().unwrap();
202
203        // log header + wal size
204        let required_bytes = std::mem::size_of::<LogHeader>() + log_entry.log_size();
205        let remaining = inner.buffer.buffer_size - inner.buffer_cursor;
206        if required_bytes > remaining {
207            // need to flush buffer
208            inner.need_flush = true;
209            self.need_flush_cond.notify_all();
210            inner = self
211                .flushed_cond
212                .wait_while(inner, |inner| !inner.need_flush)
213                .unwrap();
214            // we need to retry here because by the time we wake up, the buffer maybe already full again.
215            drop(inner);
216            return self.append_and_wait(log_entry, page_offset);
217        }
218
219        let lsn = inner.alloc_lsn();
220        let header = LogHeader::new(lsn, page_offset, required_bytes);
221        let buffer = unsafe { inner.alloc_buffer(required_bytes) };
222        buffer[0..LogHeader::size()].copy_from_slice(header.as_slice());
223        log_entry.write_to_buffer(&mut buffer[LogHeader::size()..]);
224
225        while inner.flushed_lsn < lsn {
226            inner = self.flushed_cond.wait(inner).unwrap();
227        }
228        lsn
229    }
230}
231
/// Read the write-ahead-log file produced by Bf-Tree.
///
/// Allows users to iterate over the log entries in the file and decide what to do with them.
///
///
/// Example
/// ```ignore
/// let reader = WalReader::new(&file, 4096);
/// for segment in reader.segment_iter() {
///     for (header, buffer) in segment.entry_iter() {
///         ...
///     }
/// }
/// ```
pub struct WalReader {
    /// WAL file, opened read-only.
    log_file: std::fs::File,
    /// Segment size the file was written with.
    segment_size: usize,
    /// File size in bytes, captured when the reader was created.
    file_size: usize,
}

impl WalReader {
    /// Create a new WalReader instance.
    ///
    /// The `segment_size` should be the same as the one used to create the WriteAheadLog instance.
    ///
    /// Todo: we should include segment_size as a field in the wal file, so that we don't need to pass it in.
    ///
    /// # Panics
    /// Panics if the file cannot be opened or its metadata cannot be read.
    pub fn new(path: impl AsRef<Path>, segment_size: usize) -> Self {
        let log_file = std::fs::OpenOptions::new()
            .read(true)
            .open(path)
            .expect("failed to open wal file for reading");
        let file_size = log_file
            .metadata()
            .expect("failed to read wal file metadata")
            .len() as usize;
        WalReader {
            log_file,
            segment_size,
            file_size,
        }
    }

    /// Iterate through all the segments in the wal file.
    ///
    /// Each segment contains multiple log entries,
    /// you can iterate through the log entries in each segment using the `entry_iter` method on `WalSegment`.
    pub fn segment_iter(&self) -> WalSegmentIter<'_> {
        WalSegmentIter {
            reader: self,
            cursor: 0,
        }
    }
}

/// Iterator over the fixed-size segments of a wal file.
pub struct WalSegmentIter<'a> {
    reader: &'a WalReader,
    /// Byte offset of the next segment to read.
    cursor: u64,
}
285
impl Iterator for WalSegmentIter<'_> {
    type Item = WalSegment;
    /// Read the next whole segment from the file into an owned buffer.
    fn next(&mut self) -> Option<Self::Item> {
        if self.cursor as usize >= self.reader.file_size {
            return None;
        }

        // NOTE(review): if file_size is not a multiple of segment_size, the
        // final short read will panic via unwrap below — this assumes the
        // writer always emits whole segments; confirm against the flush path.
        let mut buffer = vec![0u8; self.reader.segment_size];
        let page_offset = self.cursor;

        #[cfg(unix)]
        {
            // Positioned read; does not move the file's seek cursor.
            self.reader
                .log_file
                .read_exact_at(&mut buffer, page_offset)
                .unwrap();
        }
        #[cfg(windows)]
        {
            // seek_read may return fewer bytes than requested, so verify the
            // full segment was read.
            let bytes_to_read = buffer.len();
            let bytes_read = self
                .reader
                .log_file
                .seek_read(&mut buffer, page_offset)
                .unwrap();
            assert_eq!(bytes_to_read, bytes_read);
        }

        self.cursor += self.reader.segment_size as u64;

        Some(WalSegment { data: buffer })
    }
}
319
/// One fixed-size segment of the wal file, read into memory.
pub struct WalSegment {
    data: Vec<u8>,
}

impl WalSegment {
    /// Iterate through all the log entries in the segment.
    pub fn entry_iter(&self) -> WalEntryIter<'_> {
        WalEntryIter {
            segment: self,
            cur_offset: 0,
        }
    }
}

/// Iterator over the `(LogHeader, payload)` pairs within one segment.
pub struct WalEntryIter<'a> {
    segment: &'a WalSegment,
    /// Byte offset of the next header within the segment.
    cur_offset: u64,
}

impl<'a> Iterator for WalEntryIter<'a> {
    type Item = (LogHeader, &'a [u8]);
    fn next(&mut self) -> Option<Self::Item> {
        let cur = self.cur_offset as usize;
        // Not enough room left for another header: end of segment.
        if cur + LogHeader::size() >= self.segment.data.len() {
            return None;
        }

        let header = LogHeader::from_slice(&self.segment.data[cur..]);

        // A zeroed log_len marks the end of valid data (see clear_next_header).
        if header.log_len == 0 {
            return None;
        }

        // Defensive bounds checks: a corrupt header must not cause an
        // integer underflow or an out-of-bounds slice panic; treat such a
        // header as end-of-segment instead.
        if header.log_len < LogHeader::size() {
            return None;
        }
        let data_start = cur + LogHeader::size();
        let data_end = cur + header.log_len; // == data_start + payload length
        if data_end > self.segment.data.len() {
            return None;
        }

        let data = &self.segment.data[data_start..data_end];
        self.cur_offset += header.log_len as u64;
        Some((header, data))
    }
}

/// The header of a log entry in the wal file.
///
/// Serialized by copying the raw in-memory representation (`as_slice`) and
/// parsed field-by-field as little-endian (`from_slice`); the two only agree
/// on little-endian, 64-bit targets — partially enforced by the size
/// assertion below.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct LogHeader {
    /// Total entry length in bytes, including this header.
    pub log_len: usize,
    /// Log sequence number of the entry.
    pub lsn: u64,
    /// Page offset the entry applies to.
    pub page_offset: u64,
}

impl LogHeader {
    fn new(lsn: u64, page_offset: u64, log_len: usize) -> Self {
        LogHeader {
            log_len,
            lsn,
            page_offset,
        }
    }

    /// View the header's raw bytes (native endianness, `repr(C)` layout).
    fn as_slice(&self) -> &[u8] {
        // SAFETY: LogHeader is repr(C) with three 8-byte fields and no
        // padding on 64-bit targets (checked by the const assertion below),
        // so reading size_of::<Self>() bytes from it is valid.
        unsafe {
            std::slice::from_raw_parts(self as *const _ as *const u8, std::mem::size_of::<Self>())
        }
    }

    /// Parse a header from the first 24 bytes of `buffer` (little-endian).
    fn from_slice(buffer: &[u8]) -> Self {
        let log_len = usize::from_le_bytes(buffer[0..8].try_into().unwrap());
        let lsn = u64::from_le_bytes(buffer[8..16].try_into().unwrap());
        let page_offset = u64::from_le_bytes(buffer[16..24].try_into().unwrap());
        Self::new(lsn, page_offset, log_len)
    }

    const fn size() -> usize {
        std::mem::size_of::<Self>()
    }
}

// Guards the fixed field offsets assumed by `from_slice`.
const _: () = assert!(LogHeader::size() == 24);
397
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::utils;

    use super::*;

    /// Minimal log entry for tests: a single usize payload, 8-byte encoding.
    struct TestLogEntry {
        val: usize,
    }

    impl TestLogEntry {
        fn new(val: usize) -> Self {
            TestLogEntry { val }
        }
    }

    impl LogEntryImpl<'_> for TestLogEntry {
        fn log_size(&self) -> usize {
            8
        }

        fn write_to_buffer(&self, buffer: &mut [u8]) {
            buffer.copy_from_slice(&self.val.to_le_bytes());
        }

        fn read_from_buffer(buffer: &[u8]) -> Self {
            let val = usize::from_le_bytes(buffer.try_into().unwrap());
            TestLogEntry { val }
        }
    }

    /// Build a wal backed by a temp file with a very short flush interval,
    /// so appends in tests do not block for long waiting on the timer.
    fn make_test_wal(name: &str, segment_size: usize) -> Arc<WriteAheadLog> {
        let tmp_dir = std::env::temp_dir();
        let tmp_file = tmp_dir.join(name);
        let mut wal_config = WalConfig::new(&tmp_file);
        wal_config.segment_size(segment_size);
        wal_config.flush_interval(Duration::from_micros(1));
        WriteAheadLog::new(Arc::new(wal_config))
    }

    /// Single-threaded round trip: append entries, then replay the file and
    /// verify headers, LSNs, page offsets and payloads in order.
    #[test]
    fn simple_wal() {
        const TEST_SEGMENT_SIZE: usize = 4096;
        let wal = make_test_wal("wal_simple_test.log", TEST_SEGMENT_SIZE);
        let tmp_file = wal.config.file_path.clone();

        let log_entry_cnt = 4096;

        for i in 0..log_entry_cnt {
            let log = TestLogEntry::new(i);
            let lsn = wal.append_and_wait(&log, log.val as u64);
            assert_eq!(lsn, i as u64);
        }

        wal.stop_background_job();
        drop(wal);

        let reader = WalReader::new(&tmp_file, TEST_SEGMENT_SIZE);
        let mut cnt = 0;
        for segment in reader.segment_iter() {
            let seg_iter = segment.entry_iter();
            for (header, data) in seg_iter {
                let val = TestLogEntry::read_from_buffer(data);
                assert_eq!(
                    header.log_len,
                    TestLogEntry::new(0).log_size() + LogHeader::size()
                );
                assert_eq!(header.lsn, cnt as u64);
                assert_eq!(header.page_offset, cnt as u64);
                assert_eq!(val.val, cnt);
                cnt += 1;
            }
        }
        assert_eq!(cnt, log_entry_cnt);
        std::fs::remove_file(tmp_file).unwrap();
    }

    /// Concurrent appends from several threads; on replay only the total
    /// count and per-entry payload/page_offset agreement can be checked,
    /// since interleaving order is nondeterministic.
    #[test]
    fn multi_thread_wal() {
        const TEST_SEGMENT_SIZE: usize = 4096;
        // Unique file name per process/thread so parallel test runs
        // (and shuttle iterations) do not clobber each other.
        let pid = std::process::id();
        let tid = utils::thread_id_to_u64(std::thread::current().id());
        let wal = make_test_wal(
            &format!("wal_multi_thread_test_{}_{}.log", pid, tid),
            TEST_SEGMENT_SIZE,
        );
        let tmp_file = wal.config.file_path.clone();

        let log_entry_cnt = 4096;
        let thread_cnt = 4;

        let join_handles = (0..thread_cnt)
            .map(|_| {
                let wal_t = wal.clone();
                crate::sync::thread::spawn(move || {
                    for i in 0..log_entry_cnt {
                        let log = TestLogEntry::new(i);
                        let _lsn = wal_t.append_and_wait(&log, log.val as u64);
                    }
                })
            })
            .collect::<Vec<_>>();

        for h in join_handles.into_iter() {
            h.join().unwrap();
        }

        wal.stop_background_job();
        drop(wal);

        let reader = WalReader::new(&tmp_file, TEST_SEGMENT_SIZE);
        let mut cnt = 0;
        for segment in reader.segment_iter() {
            let seg_iter = segment.entry_iter();
            for (header, data) in seg_iter {
                let val = TestLogEntry::read_from_buffer(data);
                assert_eq!(
                    header.log_len,
                    TestLogEntry::new(0).log_size() + LogHeader::size()
                );
                assert_eq!(val.val, header.page_offset as usize);
                cnt += 1;
            }
        }
        assert_eq!(cnt, log_entry_cnt * thread_cnt);
        std::fs::remove_file(tmp_file).unwrap();
    }

    /// As of https://github.com/awslabs/shuttle/issues/74
    /// Shuttle can not properly handle wait_timeout, so we can't really test this with shuttle.
    #[cfg(feature = "shuttle")]
    #[test]
    fn shuttle_wal_concurrent_op() {
        use std::{path::PathBuf, str::FromStr};

        tracing_subscriber::fmt()
            .with_ansi(true)
            .with_thread_names(false)
            .with_target(false)
            .init();
        let mut config = shuttle::Config::default();
        config.max_steps = shuttle::MaxSteps::None;
        config.failure_persistence =
            shuttle::FailurePersistence::File(Some(PathBuf::from_str("target").unwrap()));

        let mut runner = shuttle::PortfolioRunner::new(true, config);

        let available_cores = std::thread::available_parallelism().unwrap().get().min(4);

        for _i in 0..available_cores {
            runner.add(shuttle::scheduler::PctScheduler::new(10, 4_000));
        }

        runner.run(multi_thread_wal);
    }

    /// Replay a previously recorded failing shuttle schedule for debugging.
    #[cfg(feature = "shuttle")]
    #[test]
    fn shuttle_wal_replay() {
        tracing_subscriber::fmt()
            .with_ansi(true)
            .with_thread_names(false)
            .with_target(false)
            .init();

        shuttle::replay_from_file(multi_thread_wal, "target/schedule003.txt");
    }
}