// libsql_wal/segment/list.rs

1use core::fmt;
2use std::ops::Deref;
3use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
4use std::sync::Arc;
5
6use arc_swap::ArcSwapOption;
7use fst::raw::IndexedValue;
8use fst::Streamer;
9use roaring::RoaringBitmap;
10use tokio_stream::Stream;
11use uuid::Uuid;
12use zerocopy::FromZeroes;
13
14use crate::error::Result;
15use crate::io::buf::{ZeroCopyBoxIoBuf, ZeroCopyBuf};
16use crate::io::{FileExt, Io};
17use crate::segment::Frame;
18use crate::{LibsqlFooter, LIBSQL_MAGIC, LIBSQL_PAGE_SIZE, LIBSQL_WAL_VERSION};
19
20use super::Segment;
21
/// A list of sealed WAL segments, newest first, plus a flag ensuring that at
/// most one checkpoint runs at a time.
#[derive(Debug)]
pub struct SegmentList<Seg> {
    /// Lock-free linked list of segments; the head is the most recently pushed segment.
    list: List<Seg>,
    /// Set while a checkpoint is in progress, so concurrent checkpoints bail out early.
    checkpointing: AtomicBool,
}
27
28impl<Seg> Default for SegmentList<Seg> {
29    fn default() -> Self {
30        Self {
31            list: Default::default(),
32            checkpointing: Default::default(),
33        }
34    }
35}
36
37impl<Seg> Deref for SegmentList<Seg> {
38    type Target = List<Seg>;
39
40    fn deref(&self) -> &Self::Target {
41        &self.list
42    }
43}
44
impl<Seg> SegmentList<Seg>
where
    Seg: Segment,
{
    /// Prepend `segment` to the list; it becomes the most recent segment.
    pub(crate) fn push(&self, segment: Seg) {
        self.list.prepend(segment);
    }
    /// attempt to read page_no with frame_no less than max_frame_no. Returns whether such a page
    /// was found
    pub(crate) fn read_page(
        &self,
        page_no: u32,
        max_frame_no: u64,
        buf: &mut [u8],
    ) -> Result<bool> {
        // Walk from the most recent segment towards the oldest; the first hit
        // is therefore the newest version of the page visible at max_frame_no.
        let mut prev_seg = u64::MAX;
        let mut current = self.list.head.load();
        let mut i = 0;
        while let Some(link) = &*current {
            let last = link.item.last_committed();
            // Invariant: the list is ordered newest-first, so last_committed
            // must be strictly decreasing as we walk it.
            assert!(prev_seg > last);
            prev_seg = last;
            if link.item.read_page(page_no, max_frame_no, buf)? {
                tracing::trace!("found {page_no} in segment {i}");
                return Ok(true);
            }

            i += 1;
            current = link.next.load();
        }

        Ok(false)
    }

    /// Checkpoints as many segments as possible to the main db file, and return the checkpointed
    /// frame_no, if anything was checkpointed
    #[tracing::instrument(skip_all)]
    pub async fn checkpoint<IO: Io>(
        &self,
        db_file: &IO::File,
        until_frame_no: u64,
        log_id: Uuid,
        io: &IO,
    ) -> Result<Option<u64>> {
        // RAII guard clearing the `checkpointing` flag on every exit path
        // (including errors), so a failed checkpoint doesn't block future ones.
        struct Guard<'a>(&'a AtomicBool);
        impl<'a> Drop for Guard<'a> {
            fn drop(&mut self) {
                self.0.store(false, Ordering::SeqCst);
            }
        }

        // Only one checkpoint at a time: if another is in progress, report
        // "nothing checkpointed" rather than waiting.
        if self
            .checkpointing
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_err()
        {
            return Ok(None);
        }

        let _g = Guard(&self.checkpointing);

        let mut segs = Vec::new();
        let mut current = self.head.load();
        // find the longest chain of segments that can be checkpointed, iow, segments that do not have
        // readers pointing to them
        while let Some(segment) = &*current {
            // skip any segment more recent than until_frame_no
            tracing::debug!(
                last_committed = segment.last_committed(),
                until = until_frame_no
            );
            if segment.last_committed() <= until_frame_no {
                if !segment.is_checkpointable() {
                    // A reader still points into this segment. The
                    // checkpointed set must be a contiguous run at the tail of
                    // the list (see the unlink logic below), so drop any newer
                    // segments collected so far and restart past this one.
                    segs.clear();
                } else {
                    segs.push(segment.clone());
                }
            }
            current = segment.next.load();
        }

        // nothing to checkpoint rn
        if segs.is_empty() {
            tracing::debug!("nothing to checkpoint");
            return Ok(None);
        }

        // `segs` is ordered newest-first, so the first entry's size_after is
        // the database page count once all collected frames are applied.
        let size_after = segs.first().unwrap().size_after();

        let index_iter = segs.iter().map(|s| s.index());

        // Union of all segment indexes: streams every page_no present in any
        // collected segment, with its offset in each segment containing it.
        let mut union = send_fst_ops::SendUnion::from_index_iter(index_iter);

        let mut buf = ZeroCopyBuf::<Frame>::new_uninit();
        let mut last_replication_index = 0;
        while let Some((k, v)) = union.next() {
            let page_no = u32::from_be_bytes(k.try_into().unwrap());
            tracing::trace!(page_no);
            // Among all segments containing this page, pick the smallest index,
            // i.e. the most recent segment (segs is ordered newest-first).
            let v = v.iter().min_by_key(|i| i.index).unwrap();
            let offset = v.value as u32;

            let seg = &segs[v.index];
            let (frame, ret) = seg.item.read_frame_offset_async(offset, buf).await;
            ret?;
            assert_eq!(frame.get_ref().header().page_no(), page_no);
            last_replication_index =
                last_replication_index.max(frame.get_ref().header().frame_no());
            let read_buf = frame.map_slice(|f| f.get_ref().data());
            // NOTE(review): hard-coded 4096 here, while the footer math below
            // uses LIBSQL_PAGE_SIZE — confirm these are always the same value
            // and consider using the constant for consistency.
            let (read_buf, ret) = db_file
                .write_all_at_async(read_buf, (page_no as u64 - 1) * 4096)
                .await;
            ret?;
            // Recover the frame buffer so it can be reused for the next page.
            buf = read_buf.into_inner();
        }

        // update the footer at the end of the db file.
        let footer = LibsqlFooter {
            magic: LIBSQL_MAGIC.into(),
            version: LIBSQL_WAL_VERSION.into(),
            replication_index: last_replication_index.into(),
            log_id: log_id.as_u128().into(),
        };

        // Size the file to the post-checkpoint page count; the footer is then
        // appended right past the last page.
        db_file.set_len(size_after as u64 * LIBSQL_PAGE_SIZE as u64)?;

        let footer_offset = size_after as usize * LIBSQL_PAGE_SIZE as usize;
        let (_, ret) = db_file
            .write_all_at_async(ZeroCopyBuf::new_init(footer), footer_offset as u64)
            .await;
        ret?;

        // todo: truncate if necessary
        //// TODO: make async
        db_file.sync_all()?;

        // The frames are durable in the main db file; the segments' backing
        // storage can now be destroyed.
        for seg in segs.iter() {
            seg.destroy(io).await;
        }

        // Unlink the checkpointed suffix from the list: find the node whose
        // `next` is segs[0] (the newest checkpointed segment) and cut there.
        // If segs[0] is the head itself, the whole list is emptied.
        let mut current = self.head.compare_and_swap(&segs[0], None);
        if Arc::ptr_eq(&segs[0], current.as_ref().unwrap()) {
            // nothing to do
        } else {
            loop {
                let next = current
                    .as_ref()
                    .unwrap()
                    .next
                    .compare_and_swap(&segs[0], None);
                if Arc::ptr_eq(&segs[0], next.as_ref().unwrap()) {
                    break;
                } else {
                    current = next;
                }
            }
        }

        self.len.fetch_sub(segs.len(), Ordering::Relaxed);

        tracing::debug!(until = last_replication_index, "checkpointed");

        Ok(Some(last_replication_index))
    }

    /// returns a stream of pages from the sealed segment list, and what's the lowest replication index
    /// that was covered. If the returned index is less than start frame_no, the missing frames
    /// must be read somewhere else.
    pub async fn stream_pages_from<'a>(
        &self,
        current_fno: u64,
        until_fno: u64,
        seen: &'a mut RoaringBitmap,
    ) -> (
        impl Stream<Item = crate::error::Result<Box<Frame>>> + 'a,
        u64,
    ) {
        // collect all the segments we need to read from to be up to date.
        // We keep a reference to them so that they are not discarded while we read them.
        let mut segments = Vec::new();
        let mut current = self.list.head.load();
        while current.is_some() {
            let current_ref = current.as_ref().unwrap();
            // Segments are newest-first: stop at the first segment entirely
            // older than until_fno.
            if current_ref.item.last_committed() >= until_fno {
                segments.push(current_ref.clone());
                current = current_ref.next.load();
            } else {
                break;
            }
        }

        // No sealed segment covers the requested range: nothing is streamed
        // and the caller's current_fno is still the replication watermark.
        if segments.is_empty() {
            return (
                tokio_util::either::Either::Left(tokio_stream::empty()),
                current_fno,
            );
        }

        // Lowest frame_no actually covered: the start of the oldest collected
        // segment, clamped to until_fno.
        let new_current = segments
            .last()
            .map(|s| s.start_frame_no())
            .unwrap()
            .max(until_fno);

        let stream = async_stream::try_stream! {
            let index_iter = segments.iter().map(|s| s.index());
            let mut union = send_fst_ops::SendUnion::from_index_iter(index_iter);
            while let Some((key_bytes, indexes)) = union.next() {
                let page_no = u32::from_be_bytes(key_bytes.try_into().unwrap());
                // we already have a more recent version of this page.
                if seen.contains(page_no) {
                    continue;
                }
                // Smallest index == most recent segment containing the page.
                let IndexedValue { index: segment_offset, value: frame_offset } = indexes.iter().min_by_key(|i| i.index).unwrap();
                let segment = &segments[*segment_offset];

                // we can ignore any frame with a replication index less than start_frame_no
                if segment.start_frame_no() + frame_offset < until_fno {
                    continue
                }

                let buf = ZeroCopyBoxIoBuf::new(Frame::new_box_zeroed());
                let (buf, ret) = segment.read_frame_offset_async(*frame_offset as u32, buf).await;
                ret?;
                let mut frame = buf.into_inner();
                // size_after is only meaningful within its original segment;
                // reset it so the frame is not mistaken for a commit frame.
                frame.header_mut().size_after = 0.into();
                seen.insert(page_no);
                yield frame;
            }
        };

        (tokio_util::either::Either::Right(stream), new_current)
    }

    /// Returns a clone of the oldest (tail) segment, if the list is non-empty.
    pub(crate) fn last(&self) -> Option<Seg>
    where
        Seg: Clone,
    {
        // Walk to the tail of the list.
        // NOTE(review): `c.next.load()` is called twice per step; between the
        // two loads the list can change concurrently — confirm this is benign
        // (only checkpoint truncates the tail, and it runs exclusively).
        let mut current = self.list.head.load().clone();
        loop {
            match current.as_ref() {
                Some(c) => {
                    if c.next.load().is_none() {
                        return Some(c.item.clone());
                    }
                    current = c.next.load().clone();
                }
                None => return None,
            }
        }
    }
}
296
/// A single link of the lock-free singly-linked [`List`].
struct Node<T> {
    /// Payload stored in this link.
    item: T,
    /// Next node in the list, or `None` at the tail.
    next: ArcSwapOption<Node<T>>,
}
301
302impl<T> Deref for Node<T> {
303    type Target = T;
304
305    fn deref(&self) -> &Self::Target {
306        &self.item
307    }
308}
309
/// A lock-free, prepend-only singly-linked list.
pub struct List<T> {
    /// First node of the list; `None` when the list is empty.
    head: ArcSwapOption<Node<T>>,
    /// Cached node count, maintained with relaxed atomics.
    len: AtomicUsize,
}
314
315impl<T: fmt::Debug> fmt::Debug for List<T> {
316    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
317        let mut list = f.debug_list();
318        let mut current = self.head.load();
319        while current.is_some() {
320            list.entry(&current.as_ref().unwrap().item);
321            current = current.as_ref().unwrap().next.load();
322        }
323        list.finish()
324    }
325}
326
327impl<F> Default for List<F> {
328    fn default() -> Self {
329        Self {
330            head: Default::default(),
331            len: Default::default(),
332        }
333    }
334}
335
336impl<T> List<T> {
337    /// Prepend the list with the passed sealed segment
338    pub fn prepend(&self, item: T) {
339        let node = Arc::new(Node {
340            item,
341            next: self.head.load().clone().into(),
342        });
343
344        self.head.swap(Some(node));
345        self.len.fetch_add(1, Ordering::Relaxed);
346    }
347
348    /// Call f on the head of the segments list, if it exists. The head of the list is the most
349    /// recent segment.
350    pub fn with_head<R>(&self, f: impl FnOnce(&T) -> R) -> Option<R> {
351        let head = self.head.load();
352        head.as_ref().map(|link| f(&link.item))
353    }
354
355    pub fn len(&self) -> usize {
356        self.len.load(Ordering::Relaxed)
357    }
358
359    pub fn is_empty(&self) -> bool {
360        self.len() == 0
361    }
362}
363
/// Helper module wrapping [`fst::map::Union`] so it can be held across `.await`
/// points in `Send` futures.
mod send_fst_ops {
    use std::ops::{Deref, DerefMut};
    use std::sync::Arc;

    use fst::map::{OpBuilder, Union};

    /// Safety: Union contains a `Box<dyn Trait>` that doesn't require Send, so it's not Send.
    /// That's an issue for us, but all the indexes we have are safe to send, so we're good.
    /// FIXME: we could implement union ourselves.
    unsafe impl Send for SendUnion<'_> {}
    unsafe impl Sync for SendUnion<'_> {}

    /// Transparent wrapper around [`Union`] asserting it is safe to send across threads for
    /// the index maps used in this crate (see the safety note on the unsafe impls above).
    #[repr(transparent)]
    pub(super) struct SendUnion<'a>(Union<'a>);

    impl<'a> SendUnion<'a> {
        /// Build the union of all the passed segment indexes. Streaming the union yields each
        /// key once, together with the values it maps to in every index containing it.
        pub fn from_index_iter<I>(iter: I) -> Self
        where
            I: Iterator<Item = &'a fst::map::Map<Arc<[u8]>>>,
        {
            let op = iter.collect::<OpBuilder>().union();
            Self(op)
        }
    }

    impl<'a> Deref for SendUnion<'a> {
        type Target = Union<'a>;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl<'a> DerefMut for SendUnion<'a> {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }
}
403
#[cfg(test)]
mod test {
    use std::io::{Read, Seek, Write};
    use tempfile::{tempfile, NamedTempFile};
    use tokio_stream::StreamExt as _;

    use crate::test::{seal_current_segment, TestEnv};

    use super::*;

    /// Streams every page from the sealed segment list and checks the copy is
    /// byte-identical to the checkpointed database.
    #[tokio::test]
    async fn stream_pages() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("CREATE TABLE t1(a INTEGER PRIMARY KEY, b BLOB(16));", ())
            .unwrap();
        conn.execute("CREATE INDEX i1 ON t1(b);", ()).unwrap();

        // Create many sealed segments with random page updates so the stream
        // has to merge page versions across segments.
        for _ in 0..100 {
            for _ in 0..10 {
                conn.execute(
                    "REPLACE INTO t1 VALUES(abs(random() % 500), randomblob(16));",
                    (),
                )
                .unwrap();
            }
            seal_current_segment(&shared);
        }

        seal_current_segment(&shared);

        let current = shared.current.load();
        let segment_list = current.tail();
        let mut seen = RoaringBitmap::new();
        let (stream, _) = segment_list.stream_pages_from(0, 0, &mut seen).await;
        tokio::pin!(stream);

        let mut file = NamedTempFile::new().unwrap();
        let mut tx = shared.begin_read(999999).into();
        while let Some(frame) = stream.next().await {
            let frame = frame.unwrap();
            let mut buffer = [0; 4096];
            // Each streamed frame must match a direct read of the same page.
            shared
                .read_page(&mut tx, frame.header.page_no(), &mut buffer)
                .unwrap();
            assert_eq!(buffer, frame.data());
            file.write_all(frame.data()).unwrap();
        }

        drop(tx);

        *shared.durable_frame_no.lock() = 999999;
        shared.checkpoint().await.unwrap();
        file.seek(std::io::SeekFrom::Start(0)).unwrap();
        let mut copy_bytes = Vec::new();
        file.read_to_end(&mut copy_bytes).unwrap();

        let mut orig_bytes = Vec::new();
        shared
            .db_file
            .try_clone()
            .unwrap()
            .read_to_end(&mut orig_bytes)
            .unwrap();

        assert_eq!(db_payload(&orig_bytes), db_payload(&copy_bytes));
    }

    /// Frames older than the requested start frame_no must not be streamed.
    #[tokio::test]
    async fn stream_pages_skip_before_start_fno() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("CREATE TABLE test(x);", ()).unwrap();

        for _ in 0..10 {
            conn.execute("INSERT INTO test VALUES(42)", ()).unwrap();
        }

        seal_current_segment(&shared);

        let current = shared.current.load();
        let segment_list = current.tail();
        let mut seen = RoaringBitmap::new();
        let (stream, replicated_until) = segment_list.stream_pages_from(0, 10, &mut seen).await;
        tokio::pin!(stream);

        assert_eq!(replicated_until, 10);

        while let Some(frame) = stream.next().await {
            let frame = frame.unwrap();
            assert!(frame.header().frame_no() >= 10);
        }
    }

    /// Pages already present in the `seen` bitmap must be skipped by the stream.
    #[tokio::test]
    async fn stream_pages_ignore_already_seen_pages() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("CREATE TABLE test(x);", ()).unwrap();

        for _ in 0..10 {
            conn.execute("INSERT INTO test VALUES(42)", ()).unwrap();
        }

        seal_current_segment(&shared);

        let current = shared.current.load();
        let segment_list = current.tail();
        let mut seen = RoaringBitmap::from_sorted_iter([1]).unwrap();
        let (stream, replicated_until) = segment_list.stream_pages_from(0, 1, &mut seen).await;
        tokio::pin!(stream);

        assert_eq!(replicated_until, 1);

        while let Some(frame) = stream.next().await {
            let frame = frame.unwrap();
            // Page 1 was pre-marked as seen, so it must never be streamed.
            // (fix: the assertion previously applied bitwise-NOT to page_no,
            // `assert_ne!(!frame.header().page_no(), 1)`, making it vacuous)
            assert_ne!(frame.header().page_no(), 1);
        }
    }

    /// Replication can be resumed from an arbitrary frame_no and still
    /// reconstruct a database identical to the checkpointed original.
    #[tokio::test]
    async fn stream_pages_resume_replication() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("CREATE TABLE test(x);", ()).unwrap();

        for _ in 0..10 {
            conn.execute("INSERT INTO test VALUES(42)", ()).unwrap();
        }

        seal_current_segment(&shared);

        let current = shared.current.load();
        let segment_list = current.tail();
        let mut seen = RoaringBitmap::new();
        let (stream, replicated_until) = segment_list.stream_pages_from(0, 1, &mut seen).await;
        tokio::pin!(stream);

        assert_eq!(replicated_until, 1);

        let mut tmp = tempfile().unwrap();

        // Track the highest replicated frame_no so we can resume from there.
        let mut last_frame_no = 0;
        while let Some(frame) = stream.next().await {
            let frame = frame.unwrap();
            let offset = (frame.header().page_no() - 1) * 4096;
            tmp.write_all_at(frame.data(), offset as u64).unwrap();
            last_frame_no = last_frame_no.max(frame.header().frame_no());
        }

        for _ in 0..10 {
            conn.execute("INSERT INTO test VALUES(42)", ()).unwrap();
        }

        seal_current_segment(&shared);

        // Resume replication from where the first pass stopped.
        let mut seen = RoaringBitmap::new();
        let (stream, replicated_until) = segment_list
            .stream_pages_from(0, last_frame_no, &mut seen)
            .await;
        tokio::pin!(stream);

        assert_eq!(replicated_until, last_frame_no);

        while let Some(frame) = stream.next().await {
            let frame = frame.unwrap();
            let offset = (frame.header().page_no() - 1) * 4096;
            tmp.write_all_at(frame.data(), offset as u64).unwrap();
        }

        *shared.durable_frame_no.lock() = 999999;

        shared.checkpoint().await.unwrap();
        tmp.seek(std::io::SeekFrom::Start(0)).unwrap();
        let mut copy_bytes = Vec::new();
        tmp.read_to_end(&mut copy_bytes).unwrap();

        let mut orig_bytes = Vec::new();
        shared
            .db_file
            .try_clone()
            .unwrap()
            .read_to_end(&mut orig_bytes)
            .unwrap();

        assert_eq!(db_payload(&copy_bytes), db_payload(&orig_bytes));
    }

    /// When all segments before the start frame_no were already checkpointed,
    /// the stream only covers the remaining sealed segments and reports where
    /// replication actually started.
    #[tokio::test]
    async fn stream_start_frame_no_before_sealed_segments() {
        let env = TestEnv::new();
        let conn = env.open_conn("test");
        let shared = env.shared("test");

        conn.execute("CREATE TABLE test(x);", ()).unwrap();

        for _ in 0..10 {
            conn.execute("INSERT INTO test VALUES(42)", ()).unwrap();
        }

        seal_current_segment(&shared);
        *shared.durable_frame_no.lock() = 999999;
        shared.checkpoint().await.unwrap();

        for _ in 0..10 {
            conn.execute("INSERT INTO test VALUES(42)", ()).unwrap();
        }
        seal_current_segment(&shared);

        let current = shared.current.load();
        let segment_list = current.tail();
        let mut seen = RoaringBitmap::new();
        let (stream, replicated_from) = segment_list.stream_pages_from(0, 0, &mut seen).await;
        tokio::pin!(stream);

        let mut count = 0;
        while let Some(_) = stream.next().await {
            count += 1;
        }

        assert_eq!(count, 1);
        assert_eq!(replicated_from, 13);
    }

    /// Strip the trailing footer: compare only whole 4096-byte pages.
    fn db_payload(db: &[u8]) -> &[u8] {
        let size = (db.len() / 4096) * 4096;
        &db[..size]
    }
}