libsql_wal/segment/current.rs

1use std::hash::Hasher;
2use std::io::{BufWriter, IoSlice, Write};
3use std::num::NonZeroU64;
4use std::ops::DerefMut;
5use std::path::PathBuf;
6use std::sync::atomic::AtomicU32;
7use std::sync::{
8    atomic::{AtomicBool, AtomicU64, Ordering},
9    Arc,
10};
11
12use chrono::{DateTime, Utc};
13use crossbeam_skiplist::SkipMap;
14use fst::MapBuilder;
15use parking_lot::{Mutex, RwLock};
16use roaring::RoaringBitmap;
17use tokio_stream::Stream;
18use uuid::Uuid;
19use zerocopy::little_endian::U32;
20use zerocopy::{AsBytes, FromZeroes};
21
22use crate::io::buf::{IoBufMut, ZeroCopyBoxIoBuf, ZeroCopyBuf};
23use crate::io::file::FileExt;
24use crate::io::Inspect;
25use crate::segment::{checked_frame_offset, SegmentFlags};
26use crate::segment::{frame_offset, page_offset, sealed::SealedSegment};
27use crate::transaction::{Transaction, TxGuardOwned, TxGuardShared};
28use crate::{LIBSQL_MAGIC, LIBSQL_PAGE_SIZE, LIBSQL_WAL_VERSION};
29
30use super::list::SegmentList;
31use super::{CheckedFrame, Frame, FrameHeader, SegmentHeader};
32
33use crate::error::Result;
34
35pub struct CurrentSegment<F> {
36    path: PathBuf,
37    index: SegmentIndex,
38    header: Mutex<SegmentHeader>,
39    file: Arc<F>,
40    /// Read lock count on this segment. Each begin_read increments the reader count on the
41    /// current segment.
42    read_locks: Arc<AtomicU64>,
43    sealed: AtomicBool,
44    /// current running checksum
45    current_checksum: AtomicU32,
46    tail: Arc<SegmentList<SealedSegment<F>>>,
47}
48
49impl<F> CurrentSegment<F> {
50    /// Create a new segment from the given path and metadata. The file pointed to by path must not
51    /// exist.
52    pub fn create(
53        segment_file: F,
54        path: PathBuf,
55        start_frame_no: NonZeroU64,
56        db_size: u32,
57        tail: Arc<SegmentList<SealedSegment<F>>>,
58        salt: u32,
59        log_id: Uuid,
60    ) -> Result<Self>
61    where
62        F: FileExt,
63    {
64        let mut header = SegmentHeader {
65            start_frame_no: start_frame_no.get().into(),
66            last_commited_frame_no: 0.into(),
67            size_after: db_size.into(),
68            index_offset: 0.into(),
69            index_size: 0.into(),
70            header_cheksum: 0.into(),
71            flags: 0.into(),
72            magic: LIBSQL_MAGIC.into(),
73            version: LIBSQL_WAL_VERSION.into(),
74            salt: salt.into(),
75            page_size: LIBSQL_PAGE_SIZE.into(),
76            log_id: log_id.as_u128().into(),
77            frame_count: 0.into(),
78            sealed_at_timestamp: 0.into(),
79        };
80
81        header.recompute_checksum();
82
83        segment_file.write_all_at(header.as_bytes(), 0)?;
84
85        Ok(Self {
86            path: path.to_path_buf(),
87            index: SegmentIndex::new(start_frame_no.get()),
88            header: Mutex::new(header),
89            file: segment_file.into(),
90            read_locks: Arc::new(AtomicU64::new(0)),
91            sealed: AtomicBool::default(),
92            tail,
93            current_checksum: salt.into(),
94        })
95    }
96
97    pub fn log_id(&self) -> Uuid {
98        Uuid::from_u128(self.header.lock().log_id.get())
99    }
100
101    pub fn is_empty(&self) -> bool {
102        self.header.lock().is_empty()
103    }
104
105    pub fn with_header<R>(&self, f: impl FnOnce(&SegmentHeader) -> R) -> R {
106        let header = self.header.lock();
107        f(&header)
108    }
109
110    pub fn last_committed(&self) -> u64 {
111        self.header.lock().last_committed()
112    }
113
114    pub fn next_frame_no(&self) -> NonZeroU64 {
115        self.header.lock().next_frame_no()
116    }
117
118    pub fn count_committed(&self) -> usize {
119        self.header.lock().frame_count()
120    }
121
122    pub fn db_size(&self) -> u32 {
123        self.header.lock().size_after.get()
124    }
125
126    pub fn current_checksum(&self) -> u32 {
127        self.current_checksum.load(Ordering::Relaxed)
128    }
129
130    /// Insert a batch of frames into the WAL. The frames need not be ordered; therefore, on commit,
131    /// the last frame_no must be passed alongside the new size_after.
132    #[tracing::instrument(skip_all)]
133    pub async fn inject_frames(
134        &self,
135        frames: Vec<Box<Frame>>,
136        // (size_after, last_frame_no)
137        commit_data: Option<(u32, u64)>,
138        tx: &mut TxGuardOwned<F>,
139    ) -> Result<Vec<Box<Frame>>>
140    where
141        F: FileExt,
142    {
143        assert!(!self.sealed.load(Ordering::SeqCst));
144        assert_eq!(
145            tx.savepoints.len(),
146            1,
147            "injecting wal should not use savepoints"
148        );
149        {
150            let tx = tx.deref_mut();
151            // let mut commit_frame_written = false;
152            let current_savepoint = tx.savepoints.last_mut().expect("no savepoints initialized");
153            let mut frames = frame_list_to_option(frames);
154            // For each frame, we compute and write the frame checksum, followed by the frame
155            // itself; on disk this forms an array of CheckedFrame
156            for i in 0..frames.len() {
157                let offset = tx.next_offset;
158                let current_checksum = current_savepoint.current_checksum;
159                let mut digest = crc32fast::Hasher::new_with_initial(current_checksum);
160                digest.write(frames[i].as_ref().unwrap().as_bytes());
161                let new_checksum = digest.finalize();
162                let (_buf, ret) = self
163                    .file
164                    .write_all_at_async(
165                        ZeroCopyBuf::new_init(zerocopy::byteorder::little_endian::U32::new(
166                            new_checksum,
167                        )),
168                        checked_frame_offset(offset),
169                    )
170                    .await;
171                ret?;
172
173                let buf = ZeroCopyBoxIoBuf::new(frames[i].take().unwrap());
174                let (buf, ret) = self
175                    .file
176                    .write_all_at_async(buf, frame_offset(offset))
177                    .await;
178                ret?;
179
180                let frame = buf.into_inner();
181
182                current_savepoint
183                    .index
184                    .insert(frame.header().page_no(), offset);
185                current_savepoint.current_checksum = new_checksum;
186                tx.next_offset += 1;
187                frames[i] = Some(frame);
188            }
189
190            if let Some((size_after, last_frame_no)) = commit_data {
191                if tx.not_empty() {
192                    let mut header = { *self.header.lock() };
193                    header.last_commited_frame_no = last_frame_no.into();
194                    header.size_after = size_after.into();
195                    // set frames unordered because there are no guarantees that we received frames
196                    // in order.
197                    header.set_flags(header.flags().union(SegmentFlags::FRAME_UNORDERED));
198                    {
199                        let savepoint = tx.savepoints.first().unwrap();
200                        header.frame_count = (header.frame_count.get()
201                            + (tx.next_offset - savepoint.next_offset) as u64)
202                            .into();
203                    }
204                    header.recompute_checksum();
205
206                    let (header, ret) = self
207                        .file
208                        .write_all_at_async(ZeroCopyBuf::new_init(header), 0)
209                        .await;
210
211                    ret?;
212
213                    // self.file.sync_data().unwrap();
214                    tx.merge_savepoints(&self.index);
215                    // set the header last, so that a transaction does not witness a write before
216                    // it's actually committed.
217                    self.current_checksum
218                        .store(tx.current_checksum(), Ordering::Relaxed);
219                    *self.header.lock() = header.into_inner();
220
221                    tx.is_commited = true;
222                }
223            }
224
225            let frames = options_to_frame_list(frames);
226
227            Ok(frames)
228        }
229    }
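
    // A minimal sketch (illustrative only, using just crc32fast as imported above) of the
    // running-checksum scheme relied on here: each frame's checksum is a CRC32 seeded with the
    // previous frame's checksum (or the segment salt for the first frame), so chaining per-frame
    // digests is equivalent to a single digest over the concatenated frame bytes.
    #[cfg(test)]
    #[allow(dead_code)]
    fn checksum_chaining_sketch() {
        let salt = 0xdead_beef_u32;
        let frames: [&[u8]; 3] = [b"frame-1", b"frame-2", b"frame-3"];

        // chain one digest per frame, seeding each with the previous result
        let mut running = salt;
        for frame in frames {
            let mut digest = crc32fast::Hasher::new_with_initial(running);
            digest.write(frame);
            running = digest.finalize();
        }

        // the chained result matches a single digest over all the bytes
        let mut digest = crc32fast::Hasher::new_with_initial(salt);
        for frame in frames {
            digest.write(frame);
        }
        assert_eq!(running, digest.finalize());
    }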
230
231    #[tracing::instrument(skip(self, pages, tx))]
232    pub fn insert_pages<'a>(
233        &self,
234        pages: impl Iterator<Item = (u32, &'a [u8])>,
235        size_after: Option<u32>,
236        tx: &mut TxGuardShared<F>,
237    ) -> Result<Option<u64>>
238    where
239        F: FileExt,
240    {
241        assert!(!self.sealed.load(Ordering::SeqCst));
242        {
243            let tx = tx.deref_mut();
244            let mut pages = pages.peekable();
245            // let mut commit_frame_written = false;
246            let current_savepoint = tx.savepoints.last_mut().expect("no savepoints initialized");
247            while let Some((page_no, page)) = pages.next() {
248                // optim: if the page is already present, overwrite its content
249                if let Some(offset) = current_savepoint.index.get(&page_no) {
250                    tracing::trace!(page_no, "recycling frame");
251                    self.file.write_all_at(page, page_offset(*offset))?;
252                    // we overwrote a frame, record that for later rewrite
253                    tx.recompute_checksum = Some(
254                        tx.recompute_checksum
255                            .map(|old| old.min(*offset))
256                            .unwrap_or(*offset),
257                    );
258                    continue;
259                }
260
261                tracing::trace!(page_no, "inserting new frame");
262                let size_after = if let Some(size) = size_after {
263                    pages.peek().is_none().then_some(size).unwrap_or(0)
264                } else {
265                    0
266                };
267
268                let frame_no = tx.next_frame_no;
269                let header = FrameHeader {
270                    page_no: page_no.into(),
271                    size_after: size_after.into(),
272                    frame_no: frame_no.into(),
273                };
274
275                // only compute checksum if we don't need to recompute it later
276                let checksum = if tx.recompute_checksum.is_none() {
277                    let mut digest =
278                        crc32fast::Hasher::new_with_initial(current_savepoint.current_checksum);
279                    digest.write(header.as_bytes());
280                    digest.write(page);
281                    digest.finalize()
282                } else {
283                    0
284                };
285
286                let checksum_bytes = checksum.to_le_bytes();
287                // We write an instance of a CheckedFrame
288                let slices = &[
289                    IoSlice::new(&checksum_bytes),
290                    IoSlice::new(header.as_bytes()),
291                    IoSlice::new(&page),
292                ];
293                let offset = tx.next_offset;
294                debug_assert_eq!(
295                    self.header.lock().start_frame_no.get() + offset as u64,
296                    frame_no
297                );
298                self.file
299                    .write_at_vectored(slices, checked_frame_offset(offset))?;
300                assert!(
301                    current_savepoint.index.insert(page_no, offset).is_none(),
302                    "existing frames should be recycled"
303                );
304                current_savepoint.current_checksum = checksum;
305                tx.next_frame_no += 1;
306                tx.next_offset += 1;
307            }
308        }
309
310        // commit
311        if let Some(size_after) = size_after {
312            if tx.not_empty() {
313                let new_checksum = if let Some(offset) = tx.recompute_checksum {
314                    self.recompute_checksum(offset, tx.next_offset - 1)?
315                } else {
316                    tx.current_checksum()
317                };
318
319                #[cfg(debug_assertions)]
320                {
321                    // ensure that file checksum for that transaction is valid
322                    let from = {
323                        let header = self.header.lock();
324                        if header.last_commited_frame_no() == 0 {
325                            0
326                        } else {
327                            (header.last_commited_frame_no() - header.start_frame_no.get()) as u32
328                        }
329                    };
330
331                    self.assert_valid_checksum(from, tx.next_offset - 1)?;
332                }
333
334                let last_frame_no = tx.next_frame_no - 1;
335                let mut header = { *self.header.lock() };
336                header.last_commited_frame_no = last_frame_no.into();
337                header.size_after = size_after.into();
338                // count how many frames were appended: last appended offset - initial
339                // offset
340                let tx = tx.deref_mut();
341                let savepoint = tx.savepoints.first().unwrap();
342                header.frame_count = (header.frame_count.get()
343                    + (tx.next_offset - savepoint.next_offset) as u64)
344                    .into();
345                header.recompute_checksum();
346
347                self.file.write_all_at(header.as_bytes(), 0)?;
348                // todo: sync if sync mode is EXTRA
349                // self.file.sync_data().unwrap();
350                tx.merge_savepoints(&self.index);
351                // set the header last, so that a transaction does not witness a write before
352                // it's actually committed.
353                *self.header.lock() = header;
354                self.current_checksum.store(new_checksum, Ordering::Relaxed);
355
356                tx.is_commited = true;
357
358                return Ok(Some(last_frame_no));
359            }
360        }
361        Ok(None)
362    }
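
    // A minimal sketch (illustrative only) of the commit-marker convention used by insert_pages
    // above: within a commit batch only the last frame records the size_after value, while
    // intermediate frames carry 0.
    #[cfg(test)]
    #[allow(dead_code)]
    fn size_after_marker_sketch() {
        let size_after = Some(42u32);
        let pages = [1u32, 2, 3];
        let mut pages = pages.iter().peekable();
        let mut recorded = Vec::new();
        while let Some(page_no) = pages.next() {
            // mirror the logic above: only the final page of the batch gets the db size
            let frame_size_after = if let Some(size) = size_after {
                pages.peek().is_none().then_some(size).unwrap_or(0)
            } else {
                0
            };
            recorded.push((*page_no, frame_size_after));
        }
        assert_eq!(recorded, vec![(1, 0), (2, 0), (3, 42)]);
    }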
363
364    /// Return the offset of the frame for page_no, with frame_no no larger than max_frame_no, if
365    /// it exists
366    pub fn find_frame(&self, page_no: u32, tx: &Transaction<F>) -> Option<u32> {
367        // if it's a write transaction, check its transient index first
368        if let Transaction::Write(ref tx) = tx {
369            if let Some(offset) = tx.find_frame_offset(page_no) {
370                return Some(offset);
371            }
372        }
373
374        // not a write tx, or page is not in write tx, look into the segment
375        self.index.locate(page_no, tx.max_offset)
376    }
377
378    /// Reads the page contained in the frame at offset into buf
379    #[tracing::instrument(skip(self, buf))]
380    pub fn read_page_offset(&self, offset: u32, buf: &mut [u8]) -> Result<()>
381    where
382        F: FileExt,
383    {
384        tracing::trace!("read page");
385        debug_assert_eq!(buf.len(), 4096);
386        self.file.read_exact_at(buf, page_offset(offset))?;
387
388        Ok(())
389    }
390
391    async fn read_frame_offset_async<B>(&self, offset: u32, buf: B) -> (B, std::io::Result<()>)
392    where
393        F: FileExt,
394        B: IoBufMut + Send + 'static,
395    {
396        let byte_offset = frame_offset(offset);
397        self.file.read_exact_at_async(buf, byte_offset).await
398    }
399
400    #[allow(dead_code)]
401    pub fn frame_header_at(&self, offset: u32) -> Result<FrameHeader>
402    where
403        F: FileExt,
404    {
405        let mut header = FrameHeader::new_zeroed();
406        self.file
407            .read_exact_at(header.as_bytes_mut(), frame_offset(offset))?;
408        Ok(header)
409    }
410
411    /// It is expected that sealing is performed under a write lock
412    #[tracing::instrument(skip_all)]
413    pub fn seal(&self, now: DateTime<Utc>) -> Result<Option<SealedSegment<F>>>
414    where
415        F: FileExt,
416    {
417        let mut header = self.header.lock();
418        let index_offset = header.frame_count() as u32;
419        let index_byte_offset = checked_frame_offset(index_offset);
420        let mut cursor = self.file.cursor(index_byte_offset);
421        let writer = BufWriter::new(&mut cursor);
422
423        let current = self.current_checksum();
424        let mut digest = crc32fast::Hasher::new_with_initial(current);
425        let mut writer = Inspect::new(writer, |data: &[u8]| {
426            digest.write(data);
427        });
428        self.index.merge_all(&mut writer)?;
429        let mut writer = writer.into_inner();
430        let index_checksum = digest.finalize();
431        let index_size = writer.get_ref().count();
432        writer.write_all(&index_checksum.to_le_bytes())?;
433
434        writer.into_inner().map_err(|e| e.into_parts().0)?;
435        // we perform a first sync to ensure that the whole segment has been flushed to disk. We then
436        // write the header and flush again. We want to guarantee that if we find a segment marked
437        // as "SEALED", then there was no partial flush.
438        //
439        // If a segment is found that doesn't have the SEALED flag, then we enter crash recovery,
440        // and we need to check the segment.
441        self.file.sync_all()?;
442
443        header.index_offset = index_byte_offset.into();
444        header.index_size = index_size.into();
445        let flags = header.flags();
446        header.set_flags(flags | SegmentFlags::SEALED);
447        header.sealed_at_timestamp = (now.timestamp_millis() as u64).into();
448        header.recompute_checksum();
449        self.file.write_all_at(header.as_bytes(), 0)?;
450
451        // flush the header.
452        self.file.sync_all()?;
453
454        let sealed = SealedSegment::open(
455            self.file.clone(),
456            self.path.clone(),
457            self.read_locks.clone(),
458            now,
459        )?;
460
461        // we only flip the sealed mark when no more errors can occur, or we risk deadlocking a read
462        // transaction waiting for a more recent version of the segment that is never going to arrive
463        assert!(
464            self.sealed
465                .compare_exchange(false, true, Ordering::SeqCst, Ordering::Relaxed)
466                .is_ok(),
467            "attempt to seal an already sealed segment"
468        );
469
470        tracing::debug!("segment sealed");
471
472        Ok(sealed)
473    }
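
    // A minimal sketch (illustrative only) of the index written by seal above, assuming the fst
    // crate as imported at the top of this file: page numbers are encoded as big-endian keys (so
    // byte order matches numeric order) mapping to the offset of the most recent frame for that
    // page.
    #[cfg(test)]
    #[allow(dead_code)]
    fn sealed_index_sketch() {
        let mut builder = MapBuilder::memory();
        // fst requires keys in ascending order, which to_be_bytes preserves for u32
        for (page_no, frame_offset) in [(1u32, 7u64), (2, 3), (42, 12)] {
            builder.insert(page_no.to_be_bytes(), frame_offset).unwrap();
        }
        let index = builder.into_map();
        assert_eq!(index.get(2u32.to_be_bytes()), Some(3));
        assert_eq!(index.get(5u32.to_be_bytes()), None);
    }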
474
475    pub fn last_committed_frame_no(&self) -> u64 {
476        let header = self.header.lock();
477        if header.last_commited_frame_no.get() == 0 {
478            header.start_frame_no.get()
479        } else {
480            header.last_commited_frame_no.get()
481        }
482    }
483
484    pub fn inc_reader_count(&self) {
485        self.read_locks().fetch_add(1, Ordering::SeqCst);
486    }
487
488    /// return true if the reader count is 0
489    pub fn dec_reader_count(&self) -> bool {
490        self.read_locks().fetch_sub(1, Ordering::SeqCst) - 1 == 0
491    }
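
    // A minimal sketch (illustrative only) of the read-lock accounting used by dec_reader_count
    // above: fetch_sub returns the previous value, so `previous - 1 == 0` means this call
    // released the last reader.
    #[cfg(test)]
    #[allow(dead_code)]
    fn reader_count_sketch() {
        let read_locks = AtomicU64::new(0);
        read_locks.fetch_add(1, Ordering::SeqCst);
        let was_last = read_locks.fetch_sub(1, Ordering::SeqCst) - 1 == 0;
        assert!(was_last);
        assert_eq!(read_locks.load(Ordering::SeqCst), 0);
    }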
492
493    pub fn read_locks(&self) -> &AtomicU64 {
494        self.read_locks.as_ref()
495    }
496
497    pub fn is_sealed(&self) -> bool {
498        self.sealed.load(Ordering::SeqCst)
499    }
500
501    pub fn tail(&self) -> &Arc<SegmentList<SealedSegment<F>>> {
502        &self.tail
503    }
504
505    /// returns all the frames that changed between start_frame_no and the current commit index
506    pub(crate) fn frame_stream_from<'a>(
507        &'a self,
508        start_frame_no: u64,
509        seen: &'a mut RoaringBitmap,
510    ) -> (impl Stream<Item = Result<Box<Frame>>> + 'a, u64, u32)
511    where
512        F: FileExt,
513    {
514        let (seg_start_frame_no, last_committed, db_size) =
515            self.with_header(|h| (h.start_frame_no.get(), h.last_committed(), h.size_after()));
516        let replicated_until = seg_start_frame_no
517            // if current is empty, start_frame_no doesn't exist
518            .min(last_committed)
519            .max(start_frame_no);
520
521        // TODO: optim, we could read fewer frames if we had a mapping from frame_no to page_no in
522        // the index
523        let stream = async_stream::try_stream! {
524            if !self.is_empty() {
525                let mut frame_offset = (last_committed - seg_start_frame_no) as u32;
526                loop {
527                    let buf = ZeroCopyBoxIoBuf::new(Frame::new_box_zeroed());
528                    let (buf, res) = self.read_frame_offset_async(frame_offset, buf).await;
529                    res?;
530
531                    let mut frame = buf.into_inner();
532                    frame.header_mut().size_after = 0.into();
533                    let page_no = frame.header().page_no();
534
535                    let frame_no = frame.header().frame_no();
536                    if frame_no < start_frame_no {
537                        break
538                    }
539
540                    if !seen.contains(page_no) {
541                        seen.insert(page_no);
542                        yield frame;
543                    }
544
545                    if frame_offset == 0 {
546                        break
547                    }
548
549                    frame_offset -= 1;
550                }
551            }
552        };
553
554        (stream, replicated_until, db_size)
555    }
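
    // A minimal sketch (illustrative only) of the deduplication performed by frame_stream_from
    // above: frames are walked from newest to oldest and the seen bitmap ensures only the most
    // recent frame for each page is yielded.
    #[cfg(test)]
    #[allow(dead_code)]
    fn stream_dedup_sketch() {
        // (frame_no, page_no) pairs, newest first
        let frames_newest_first = [(5u64, 2u32), (4, 1), (3, 2), (2, 1)];
        let mut seen = RoaringBitmap::new();
        let mut yielded = Vec::new();
        for (frame_no, page_no) in frames_newest_first {
            if !seen.contains(page_no) {
                seen.insert(page_no);
                yielded.push(frame_no);
            }
        }
        // only the newest frame for each page survives
        assert_eq!(yielded, vec![5, 4]);
    }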
556
557    fn recompute_checksum(&self, start_offset: u32, until_offset: u32) -> Result<u32>
558    where
559        F: FileExt,
560    {
561        let mut current_checksum = if start_offset == 0 {
562            self.header.lock().salt.get()
563        } else {
564            // we get the checksum from the frame just before the start offset
565            let frame_offset = checked_frame_offset(start_offset - 1);
566            let mut out = U32::new(0);
567            self.file.read_exact_at(out.as_bytes_mut(), frame_offset)?;
568            out.get()
569        };
570
571        let mut checked_frame: Box<CheckedFrame> = CheckedFrame::new_box_zeroed();
572        for offset in start_offset..=until_offset {
573            let frame_offset = checked_frame_offset(offset);
574            self.file
575                .read_exact_at(checked_frame.as_bytes_mut(), frame_offset)?;
576            current_checksum = checked_frame.frame.checksum(current_checksum);
577            self.file
578                .write_all_at(&current_checksum.to_le_bytes(), frame_offset)?;
579        }
580
581        Ok(current_checksum)
582    }
583
584    /// Test function to ensure checksum integrity
585    #[cfg(debug_assertions)]
586    #[track_caller]
587    fn assert_valid_checksum(&self, from: u32, until: u32) -> Result<()>
588    where
589        F: FileExt,
590    {
591        let mut frame: Box<CheckedFrame> = CheckedFrame::new_box_zeroed();
592        let mut current_checksum = if from != 0 {
593            let offset = checked_frame_offset(from - 1);
594            self.file.read_exact_at(frame.as_bytes_mut(), offset)?;
595            frame.checksum.get()
596        } else {
597            self.header.lock().salt.get()
598        };
599
600        for i in from..=until {
601            let offset = checked_frame_offset(i);
602            self.file.read_exact_at(frame.as_bytes_mut(), offset)?;
603            current_checksum = frame.frame.checksum(current_checksum);
604            assert_eq!(
605                current_checksum,
606                frame.checksum.get(),
607                "invalid checksum at offset {i}"
608            );
609        }
610
611        Ok(())
612    }
613}
614
615fn frame_list_to_option(frames: Vec<Box<Frame>>) -> Vec<Option<Box<Frame>>> {
616    // this is safe because Option<Box<T>> and Box<T> are the same size and Frame is sized:
617    // https://doc.rust-lang.org/std/option/index.html#representation
618    unsafe { std::mem::transmute(frames) }
619}
620
621fn options_to_frame_list(frames: Vec<Option<Box<Frame>>>) -> Vec<Box<Frame>> {
622    debug_assert!(frames.iter().all(|f| f.is_some()));
623    // this is safe because Option<Box<T>> and Box<T> are the same size and Frame is sized:
624    // https://doc.rust-lang.org/std/option/index.html#representation
625    unsafe { std::mem::transmute(frames) }
626}
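
// A minimal sketch (illustrative only) of the layout guarantee the two conversions above rely
// on: the null-pointer niche makes Option<Box<T>> the same size as Box<T> for sized T, so the
// Vec buffer layout is unchanged by the transmute.
#[cfg(test)]
#[allow(dead_code)]
fn option_box_niche_sketch() {
    use std::mem::size_of;
    assert_eq!(size_of::<Option<Box<Frame>>>(), size_of::<Box<Frame>>());
    assert_eq!(
        size_of::<Vec<Option<Box<Frame>>>>(),
        size_of::<Vec<Box<Frame>>>()
    );
}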
627
628impl<F> Drop for CurrentSegment<F> {
629    fn drop(&mut self) {
630        // todo: if reader is 0 and segment is sealed, register for compaction.
631    }
632}
633
634/// TODO: implement spill-to-disk when txn is too large
635/// TODO: optimize that data structure with something more custom. I can't find a wholly satisfying
636/// structure in the wild.
637pub(crate) struct SegmentIndex {
638    start_frame_no: u64,
639    // TODO: measure perf, and consider using https://docs.rs/bplustree/latest/bplustree/
640    index: SkipMap<u32, RwLock<Vec<u32>>>,
641}
642
643impl SegmentIndex {
644    pub fn new(start_frame_no: u64) -> Self {
645        Self {
646            start_frame_no,
647            index: Default::default(),
648        }
649    }
650
651    fn locate(&self, page_no: u32, max_offset: u64) -> Option<u32> {
652        let offsets = self.index.get(&page_no)?;
653        let offsets = offsets.value().read();
654        offsets
655            .iter()
656            .rev()
657            .find(|fno| **fno as u64 <= max_offset)
658            .copied()
659    }
660
661    #[tracing::instrument(skip_all)]
662    fn merge_all<W: Write>(&self, writer: W) -> Result<()> {
663        let mut builder = MapBuilder::new(writer)?;
664        let Some(mut entry) = self.index.front() else {
665            return Ok(());
666        };
667        loop {
668            let offset = *entry.value().read().last().unwrap();
669            builder.insert(entry.key().to_be_bytes(), offset as u64)?;
670            if !entry.move_next() {
671                break;
672            }
673        }
674
675        builder.finish()?;
676        Ok(())
677    }
678
679    pub(crate) fn insert(&self, page_no: u32, offset: u32) {
680        let entry = self.index.get_or_insert(page_no, Default::default());
681        let mut offsets = entry.value().write();
682        if offsets.is_empty() || *offsets.last().unwrap() < offset {
683            offsets.push(offset);
684        }
685    }
686}
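
// A minimal sketch (illustrative only) of the SegmentIndex semantics defined above: offsets for
// a page accumulate in insertion order, and locate returns the most recent offset that does not
// exceed max_offset.
#[cfg(test)]
#[allow(dead_code)]
fn segment_index_sketch() {
    let index = SegmentIndex::new(1);
    index.insert(42, 1);
    index.insert(42, 5);
    index.insert(7, 3);

    // the most recent frame for page 42 visible at max_offset 3 is at offset 1
    assert_eq!(index.locate(42, 3), Some(1));
    // a larger max_offset sees the newer frame
    assert_eq!(index.locate(42, 10), Some(5));
    // pages that were never inserted return None
    assert_eq!(index.locate(1, 10), None);
}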
687
688#[cfg(test)]
689mod test {
690    use std::io::{self, Read};
691
692    use chrono::{DateTime, Utc};
693    use hashbrown::HashMap;
694    use insta::assert_debug_snapshot;
695    use rand::rngs::ThreadRng;
696    use tempfile::{tempdir, tempfile};
697    use tokio_stream::StreamExt;
698    use uuid::Uuid;
699
700    use crate::io::{FileExt, Io};
701    use crate::test::{seal_current_segment, TestEnv};
702
703    use super::*;
704
705    #[tokio::test]
706    async fn current_stream_frames() {
707        let env = TestEnv::new();
708        let conn = env.open_conn("test");
709        let shared = env.shared("test");
710
711        conn.execute("create table test (x)", ()).unwrap();
712        for _ in 0..50 {
713            conn.execute("insert into test values (randomblob(256))", ())
714                .unwrap();
715        }
716
717        let mut seen = RoaringBitmap::new();
718        let current = shared.current.load();
719        let (stream, replicated_until, size_after) = current.frame_stream_from(1, &mut seen);
720        tokio::pin!(stream);
721        assert_eq!(replicated_until, 1);
722        assert_eq!(size_after, 6);
723
724        let mut tmp = tempfile().unwrap();
725        while let Some(frame) = stream.next().await {
726            let frame = frame.unwrap();
727            let offset = (frame.header().page_no() - 1) * 4096;
728            tmp.write_all_at(frame.data(), offset as _).unwrap();
729        }
730
731        seal_current_segment(&shared);
732        *shared.durable_frame_no.lock() = 999999;
733        shared.checkpoint().await.unwrap();
734
735        let mut orig = Vec::new();
736        shared
737            .db_file
738            .try_clone()
739            .unwrap()
740            .read_to_end(&mut orig)
741            .unwrap();
742
743        let mut copy = Vec::new();
744        tmp.read_to_end(&mut copy).unwrap();
745
746        assert_eq!(db_payload(&copy), db_payload(&orig));
747    }
748
749    #[tokio::test]
750    async fn current_stream_frames_incomplete() {
751        let env = TestEnv::new();
752        let conn = env.open_conn("test");
753        let shared = env.shared("test");
754
755        conn.execute("create table test (x)", ()).unwrap();
756
757        for _ in 0..50 {
758            conn.execute("insert into test values (randomblob(256))", ())
759                .unwrap();
760        }
761
762        seal_current_segment(&shared);
763
764        for _ in 0..50 {
765            conn.execute("insert into test values (randomblob(256))", ())
766                .unwrap();
767        }
768
769        let mut seen = RoaringBitmap::new();
770        {
771            let current = shared.current.load();
772            let (stream, replicated_until, size_after) = current.frame_stream_from(1, &mut seen);
773            tokio::pin!(stream);
774            assert_eq!(replicated_until, 60);
775            assert_eq!(size_after, 9);
776            assert_eq!(stream.fold(0, |count, _| count + 1).await, 6);
777        }
778        assert_debug_snapshot!(seen);
779    }
780
781    #[tokio::test]
782    async fn current_stream_too_recent_frame_no() {
783        let env = TestEnv::new();
784        let conn = env.open_conn("test");
785        let shared = env.shared("test");
786
787        conn.execute("create table test (x)", ()).unwrap();
788
789        let mut seen = RoaringBitmap::new();
790        let current = shared.current.load();
791        let (stream, replicated_until, size_after) = current.frame_stream_from(100, &mut seen);
792        tokio::pin!(stream);
793        assert_eq!(replicated_until, 100);
794        assert_eq!(stream.fold(0, |count, _| count + 1).await, 0);
795        assert_eq!(size_after, 2);
796    }
797
798    #[tokio::test]
799    async fn current_stream_empty_segment() {
800        let env = TestEnv::new();
801        let conn = env.open_conn("test");
802        let shared = env.shared("test");
803
804        conn.execute("create table test (x)", ()).unwrap();
805        seal_current_segment(&shared);
806
807        let mut seen = RoaringBitmap::new();
808        let current = shared.current.load();
809        let (stream, replicated_until, size_after) = current.frame_stream_from(1, &mut seen);
810        tokio::pin!(stream);
811        assert_eq!(replicated_until, 2);
812        assert_eq!(size_after, 2);
813        assert_eq!(stream.fold(0, |count, _| count + 1).await, 0);
814    }
815
816    #[tokio::test]
817    async fn crash_on_flush() {
818        #[derive(Clone, Default)]
819        struct SyncFailBufferIo {
820            inner: Arc<Mutex<HashMap<PathBuf, Arc<Mutex<Vec<u8>>>>>>,
821        }
822
823        struct File {
824            path: PathBuf,
825            io: SyncFailBufferIo,
826        }
827
828        impl File {
829            fn inner(&self) -> Arc<Mutex<Vec<u8>>> {
830                self.io.inner.lock().get(&self.path).cloned().unwrap()
831            }
832        }
833
834        impl FileExt for File {
835            fn len(&self) -> std::io::Result<u64> {
836                Ok(self.inner().lock().len() as u64)
837            }
838
839            fn write_at_vectored(&self, bufs: &[IoSlice], offset: u64) -> std::io::Result<usize> {
840                let mut written = 0;
841                for buf in bufs {
842                    self.write_at(buf.as_bytes(), written + offset)?;
843                    written += buf.len() as u64;
844                }
845                Ok(written as _)
846            }
847
848            fn write_at(&self, buf: &[u8], offset: u64) -> std::io::Result<usize> {
849                let data = self.inner();
850                let mut data = data.lock();
851                let new_len = offset as usize + buf.len();
852                let old_len = data.len();
853                if old_len < new_len {
854                    data.extend(std::iter::repeat(0).take(new_len - old_len));
855                }
856                data[offset as usize..offset as usize + buf.len()].copy_from_slice(buf);
857                Ok(buf.len())
858            }
859
860            fn read_at(&self, buf: &mut [u8], offset: u64) -> std::io::Result<usize> {
861                let inner = self.inner();
862                let inner = inner.lock();
863                if offset >= inner.len() as u64 {
864                    return Ok(0);
865                }
866
867                let read_len = buf.len().min(inner.len() - offset as usize);
868                buf[..read_len]
869                    .copy_from_slice(&inner[offset as usize..offset as usize + read_len]);
870                Ok(read_len)
871            }
872
873            fn sync_all(&self) -> std::io::Result<()> {
874                // simulate a flush that only flushes half the pages and then fails
875                let inner = self.inner();
876                let inner = inner.lock();
877                // just keep 5 pages from the log. The log will be incomplete and frames will be
878                // broken.
879                std::fs::write(&self.path, &inner[..4096 * 5])?;
880                Err(io::Error::new(io::ErrorKind::BrokenPipe, ""))
881            }
882
883            fn set_len(&self, _len: u64) -> std::io::Result<()> {
884                todo!()
885            }
886
887            async fn read_exact_at_async<B: IoBufMut + Send + 'static>(
888                &self,
889                mut buf: B,
890                offset: u64,
891            ) -> (B, std::io::Result<()>) {
892                let slice = unsafe {
893                    std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total())
894                };
895                let ret = self.read_at(slice, offset);
896                (buf, ret.map(|_| ()))
897            }
898
899            async fn read_at_async<B: IoBufMut + Send + 'static>(
900                &self,
901                _buf: B,
902                _offset: u64,
903            ) -> (B, std::io::Result<usize>) {
904                todo!()
905            }
906
907            async fn write_all_at_async<B: crate::io::buf::IoBuf + Send + 'static>(
908                &self,
909                _buf: B,
910                _offset: u64,
911            ) -> (B, std::io::Result<()>) {
912                todo!()
913            }
914        }
915
916        impl Io for SyncFailBufferIo {
917            type File = File;
918            type Rng = ThreadRng;
919            type TempFile = File;
920
921            fn create_dir_all(&self, path: &std::path::Path) -> std::io::Result<()> {
922                std::fs::create_dir_all(path)
923            }
924
925            fn open(
926                &self,
927                _create_new: bool,
928                _read: bool,
929                _write: bool,
930                path: &std::path::Path,
931            ) -> std::io::Result<Self::File> {
932                let mut inner = self.inner.lock();
933                if !inner.contains_key(path) {
934                    let data = if path.exists() {
935                        std::fs::read(path)?
936                    } else {
937                        vec![]
938                    };
939                    inner.insert(path.to_owned(), Arc::new(Mutex::new(data)));
940                }
941
942                Ok(File {
943                    path: path.into(),
944                    io: self.clone(),
945                })
946            }
947
948            fn tempfile(&self) -> std::io::Result<Self::TempFile> {
949                todo!()
950            }
951
952            fn now(&self) -> DateTime<Utc> {
953                Utc::now()
954            }
955
956            fn uuid(&self) -> uuid::Uuid {
957                Uuid::new_v4()
958            }
959
960            fn hard_link(
961                &self,
962                _src: &std::path::Path,
963                _dst: &std::path::Path,
964            ) -> std::io::Result<()> {
965                todo!()
966            }
967
968            fn with_rng<F, R>(&self, f: F) -> R
969            where
970                F: FnOnce(&mut Self::Rng) -> R,
971            {
972                f(&mut rand::thread_rng())
973            }
974
975            fn remove_file_async(
976                &self,
977                path: &std::path::Path,
978            ) -> impl std::future::Future<Output = io::Result<()>> + Send {
979                async move { std::fs::remove_file(path) }
980            }
981        }
982
983        let tmp = Arc::new(tempdir().unwrap());
984        {
985            let env = TestEnv::new_io_and_tmp(SyncFailBufferIo::default(), tmp.clone(), false);
986            let conn = env.open_conn("test");
987            let shared = env.shared("test");
988
989            conn.execute("create table test (x)", ()).unwrap();
990            for _ in 0..6 {
991                conn.execute("insert into test values (1234)", ()).unwrap();
992            }
993
994            // trigger a flush that will fail. When we reopen the db, the log should need recovery;
995            // this simulates a crash before the flush
996            {
997                let mut tx = shared.begin_read(99999).into();
998                shared.upgrade(&mut tx).unwrap();
999                let mut guard = tx.as_write_mut().unwrap().lock();
1000                guard.commit();
1001                let _ = shared.swap_current(&mut guard);
1002            }
1003        }
1004
1005        {
1006            let env = TestEnv::new_io_and_tmp(SyncFailBufferIo::default(), tmp.clone(), false);
1007            let conn = env.open_conn("test");
1008            // the db was recovered: we lost some rows, but it still works
1009            conn.query_row("select count(*) from test", (), |row| {
1010                assert_eq!(row.get::<_, u32>(0).unwrap(), 2);
1011                Ok(())
1012            })
1013            .unwrap();
1014        }
1015    }
1016
1017    fn db_payload(db: &[u8]) -> &[u8] {
1018        let size = (db.len() / 4096) * 4096;
1019        &db[..size]
1020    }
1021}