libsql_wal/
registry.rs

use std::io;
use std::num::NonZeroU64;
use std::path::Path;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use libsql_sys::ffi::Sqlite3DbHeader;
use parking_lot::{Condvar, Mutex};
use rand::Rng;
use roaring::RoaringBitmap;
use tokio::sync::{mpsc, Notify, Semaphore};
use tokio::task::JoinSet;
use tokio_stream::StreamExt;
use uuid::Uuid;
use zerocopy::{AsBytes, FromZeroes};

use crate::checkpointer::CheckpointMessage;
use crate::error::Result;
use crate::io::file::FileExt;
use crate::io::{Io, StdIO};
use crate::replication::injector::Injector;
use crate::replication::storage::{ReplicateFromStorage as _, StorageReplicator};
use crate::segment::list::SegmentList;
use crate::segment::Segment;
use crate::segment::{current::CurrentSegment, sealed::SealedSegment};
use crate::segment_swap_strategy::duration::DurationSwapStrategy;
use crate::segment_swap_strategy::frame_count::FrameCountSwapStrategy;
use crate::segment_swap_strategy::SegmentSwapStrategy;
use crate::shared_wal::{SharedWal, SwapLog};
use crate::storage::{OnStoreCallback, Storage};
use crate::transaction::TxGuard;
use crate::{LibsqlFooter, LIBSQL_PAGE_SIZE};
use libsql_sys::name::NamespaceName;

enum Slot<IO: Io> {
    Wal(Arc<SharedWal<IO>>),
    /// Only a single thread is allowed to instantiate the wal. The first thread to acquire an
    /// entry in the registry map puts a building slot. Other connections then wait on the condvar
    /// until the flag turns true, once the slot has been updated to contain the wal.
    Building(Arc<(Condvar, Mutex<bool>)>, Arc<Notify>),
    /// The namespace was removed.
    Tombstone,
}

/// The WAL registry maintains the set of shared WALs, and their respective sets of files.
pub struct WalRegistry<IO: Io, S> {
    io: Arc<IO>,
    shutdown: AtomicBool,
    opened: DashMap<NamespaceName, Slot<IO>>,
    storage: Arc<S>,
    checkpoint_notifier: mpsc::Sender<CheckpointMessage>,
}

impl<S> WalRegistry<StdIO, S> {
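    /// Creates a registry backed by the standard filesystem I/O.
    ///
    /// A minimal usage sketch (hypothetical setup: `storage`, `checkpoint_tx`,
    /// `db_path` and `namespace` are assumed to be built by the caller):
    ///
    /// ```ignore
    /// let registry = Arc::new(WalRegistry::new(storage, checkpoint_tx)?);
    /// let shared = registry.clone().open(db_path.as_ref(), &namespace)?;
    /// ```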
    pub fn new(
        storage: Arc<S>,
        checkpoint_notifier: mpsc::Sender<CheckpointMessage>,
    ) -> Result<Self> {
        Self::new_with_io(StdIO(()), storage, checkpoint_notifier)
    }
}

impl<IO: Io, S> WalRegistry<IO, S> {
    pub fn new_with_io(
        io: IO,
        storage: Arc<S>,
        checkpoint_notifier: mpsc::Sender<CheckpointMessage>,
    ) -> Result<Self> {
        let registry = Self {
            io: io.into(),
            opened: Default::default(),
            shutdown: Default::default(),
            storage,
            checkpoint_notifier,
        };

        Ok(registry)
    }

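    /// Returns the shared WAL for `namespace` if it is already open, waiting
    /// asynchronously while it is being built. Returns `None` if the namespace
    /// is not in the registry or was tombstoned.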
    pub async fn get_async(&self, namespace: &NamespaceName) -> Option<Arc<SharedWal<IO>>> {
        loop {
            let notify = {
                match self.opened.get(namespace).as_deref() {
                    Some(Slot::Wal(wal)) => return Some(wal.clone()),
                    Some(Slot::Building(_, notify)) => notify.clone(),
                    Some(Slot::Tombstone) => return None,
                    None => return None,
                }
            };

            notify.notified().await
        }
    }
}

impl<IO, S> SwapLog<IO> for WalRegistry<IO, S>
where
    IO: Io,
    S: Storage<Segment = SealedSegment<IO::File>>,
{
    #[tracing::instrument(skip_all)]
    fn swap_current(
        &self,
        shared: &SharedWal<IO>,
        tx: &dyn TxGuard<<IO as Io>::File>,
    ) -> Result<()> {
        assert!(tx.is_commited());
        self.swap_current_inner(shared)
    }
}

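/// Hands `seg` over to storage if it contains frames newer than the durable
/// frame_no; otherwise the segment holds no new data and a checkpoint of the
/// namespace is requested right away.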
#[tracing::instrument(skip_all, fields(namespace = namespace.as_str(), start_frame_no = seg.start_frame_no()))]
fn maybe_store_segment<S: Storage>(
    storage: &S,
    notifier: &tokio::sync::mpsc::Sender<CheckpointMessage>,
    namespace: &NamespaceName,
    durable_frame_no: &Arc<Mutex<u64>>,
    seg: S::Segment,
) {
    if seg.last_committed() > *durable_frame_no.lock() {
        let cb: OnStoreCallback = Box::new({
            let notifier = notifier.clone();
            let durable_frame_no = durable_frame_no.clone();
            let namespace = namespace.clone();
            move |fno| {
                Box::pin(async move {
                    update_durable(fno, notifier, durable_frame_no, namespace).await;
                })
            }
        });
        storage.store(namespace, seg, None, cb);
    } else {
        // segment can be checkpointed right away.
        // FIXME: this is only necessary because some tests call this method in an async context.
        #[cfg(debug_assertions)]
        {
            let namespace = namespace.clone();
            let notifier = notifier.clone();
            tokio::spawn(async move {
                let _ = notifier.send(CheckpointMessage::Namespace(namespace)).await;
            });
        }

        #[cfg(not(debug_assertions))]
        {
            let _ = notifier.blocking_send(CheckpointMessage::Namespace(namespace.clone()));
        }

        tracing::debug!(
            segment_end = seg.last_committed(),
            durable_frame_no = *durable_frame_no.lock(),
            "segment doesn't contain any new data"
        );
    }
}

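/// Bumps the durable frame_no for the namespace if `new_durable` is more
/// recent, then notifies the checkpointer so that durable frames can be
/// checkpointed.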
async fn update_durable(
    new_durable: u64,
    notifier: mpsc::Sender<CheckpointMessage>,
    durable_frame_no_slot: Arc<Mutex<u64>>,
    namespace: NamespaceName,
) {
    {
        let mut g = durable_frame_no_slot.lock();
        if *g < new_durable {
            *g = new_durable;
        }
    }
    let _ = notifier.send(CheckpointMessage::Namespace(namespace)).await;
}

impl<IO, S> WalRegistry<IO, S>
where
    IO: Io,
    S: Storage<Segment = SealedSegment<IO::File>>,
{
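    /// Returns the shared WAL for `namespace`, opening it from the database at
    /// `db_path` if necessary. Only one caller performs the actual open:
    /// concurrent callers block on the `Building` slot until it is resolved.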
    #[tracing::instrument(skip(self))]
    pub fn open(
        self: Arc<Self>,
        db_path: &Path,
        namespace: &NamespaceName,
    ) -> Result<Arc<SharedWal<IO>>> {
        if self.shutdown.load(Ordering::SeqCst) {
            return Err(crate::error::Error::ShuttingDown);
        }

        loop {
            if let Some(entry) = self.opened.get(namespace) {
                match &*entry {
                    Slot::Wal(wal) => return Ok(wal.clone()),
                    Slot::Building(cond, _) => {
                        let cond = cond.clone();
                        cond.0
                            .wait_while(&mut cond.1.lock(), |ready: &mut bool| !*ready);
                        // the slot was updated: try again
                        continue;
                    }
                    Slot::Tombstone => return Err(crate::error::Error::DeletingWal),
                }
            }

            let action = match self.opened.entry(namespace.clone()) {
                dashmap::Entry::Occupied(e) => match e.get() {
                    Slot::Wal(shared) => return Ok(shared.clone()),
                    Slot::Building(wait, _) => Err(wait.clone()),
                    Slot::Tombstone => return Err(crate::error::Error::DeletingWal),
                },
                dashmap::Entry::Vacant(e) => {
                    let notifier = Arc::new((Condvar::new(), Mutex::new(false)));
                    let async_notifier = Arc::new(Notify::new());
                    e.insert(Slot::Building(notifier.clone(), async_notifier.clone()));
                    Ok((notifier, async_notifier))
                }
            };

            match action {
                Ok((notifier, async_notifier)) => {
                    // if try_open succeeded, then the slot was updated and contains the shared wal; if it
                    // failed, we need to remove the slot. Either way, notify all waiters.
                    let ret = self.clone().try_open(&namespace, db_path);
                    if ret.is_err() {
                        self.opened.remove(namespace);
                    }

                    *notifier.1.lock() = true;
                    notifier.0.notify_all();
                    async_notifier.notify_waiters();

                    return ret;
                }
                Err(cond) => {
                    cond.0
                        .wait_while(&mut cond.1.lock(), |ready: &mut bool| !*ready);
                    // the slot was updated: try again
                    continue;
                }
            }
        }
    }

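    /// Performs the actual open: reads the database header and libsql footer,
    /// collects the sealed segments found in the `<db-name>-wal` directory
    /// (storing any that are not durable yet), creates a fresh current
    /// segment, and registers the resulting `SharedWal` in the map.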
    fn try_open(
        self: Arc<Self>,
        namespace: &NamespaceName,
        db_path: &Path,
    ) -> Result<Arc<SharedWal<IO>>> {
        let db_file = self.io.open(false, true, true, db_path)?;
        let db_file_len = db_file.len()?;
        let header = if db_file_len > 0 {
            let mut header: Sqlite3DbHeader = Sqlite3DbHeader::new_zeroed();
            db_file.read_exact_at(header.as_bytes_mut(), 0)?;
            Some(header)
        } else {
            None
        };

        let footer = self.try_read_footer(&db_file)?;

        let mut checkpointed_frame_no = footer.map(|f| f.replication_index.get()).unwrap_or(0);

        // the trick here to prevent sqlite from opening our db is to create a dir named
        // <db-name>-wal. Sqlite will think that this is a wal file, but since it's in fact a
        // directory, it will refuse to open it.
        let mut wals_path = db_path.to_owned();
        wals_path.set_file_name(format!(
            "{}-wal",
            db_path.file_name().unwrap().to_str().unwrap()
        ));
        self.io.create_dir_all(&wals_path)?;
        // TODO: handle that with abstract io
        let dir = walkdir::WalkDir::new(&wals_path)
            .sort_by_file_name()
            .into_iter();

        // we only checkpoint the durable frame_no, so this is a good first estimate without an
        // actual network call.
        let durable_frame_no = Arc::new(Mutex::new(checkpointed_frame_no));

        let list = SegmentList::default();
        for entry in dir {
            let entry = entry.map_err(|e| e.into_io_error().unwrap())?;
            if entry
                .path()
                .extension()
                .map(|e| e.to_str().unwrap() != "seg")
                .unwrap_or(true)
            {
                continue;
            }

            let file = self.io.open(false, true, true, entry.path())?;

            if let Some(sealed) = SealedSegment::open(
                file.into(),
                entry.path().to_path_buf(),
                Default::default(),
                self.io.now(),
            )? {
                list.push(sealed.clone());
                maybe_store_segment(
                    self.storage.as_ref(),
                    &self.checkpoint_notifier,
                    &namespace,
                    &durable_frame_no,
                    sealed,
                );
            }
        }

        let log_id = match footer {
            Some(footer) if list.is_empty() => footer.log_id(),
            None if list.is_empty() => self.io.uuid(),
            Some(footer) => {
                let log_id = list
                    .with_head(|h| h.header().log_id.get())
                    .expect("non-empty list should have a head");
                let log_id = Uuid::from_u128(log_id);
                assert_eq!(log_id, footer.log_id());
                log_id
            }
            None => {
                let log_id = list
                    .with_head(|h| h.header().log_id.get())
                    .expect("non-empty list should have a head");
                Uuid::from_u128(log_id)
            }
        };

        // if there is a tail, then the latest checkpointed frame_no is one before the
        // start frame_no of the tail. We must read it from the tail, because a partial
        // checkpoint may have occurred before a crash.
        if let Some(last) = list.last() {
            checkpointed_frame_no = (last.start_frame_no() - 1).max(1)
        }

        let (db_size, next_frame_no) = list
            .with_head(|segment| {
                let header = segment.header();
                (header.size_after(), header.next_frame_no())
            })
            .unwrap_or_else(|| match header {
                Some(header) => (
                    header.db_size.get(),
                    NonZeroU64::new(checkpointed_frame_no + 1)
                        .unwrap_or(NonZeroU64::new(1).unwrap()),
                ),
                None => (0, NonZeroU64::new(1).unwrap()),
            });

        let current_segment_path = wals_path.join(format!("{next_frame_no:020}.seg"));

        let segment_file = self.io.open(true, true, true, &current_segment_path)?;
        let salt = self.io.with_rng(|rng| rng.gen());

        let current = arc_swap::ArcSwap::new(Arc::new(CurrentSegment::create(
            segment_file,
            current_segment_path,
            next_frame_no,
            db_size,
            list.into(),
            salt,
            log_id,
        )?));

        let (new_frame_notifier, _) = tokio::sync::watch::channel(next_frame_no.get() - 1);

        // FIXME: make swap strategy configurable
        // This strategy will perform a swap if either the wal is bigger than 20k frames, or older
        // than 5 minutes, or if the frame count is greater than 1000 and the wal was last
        // swapped more than 30 seconds ago
        let swap_strategy = Box::new(
            DurationSwapStrategy::new(Duration::from_secs(5 * 60))
                .or(FrameCountSwapStrategy::new(20_000))
                .or(FrameCountSwapStrategy::new(1000)
                    .and(DurationSwapStrategy::new(Duration::from_secs(30)))),
        );

        let shared = Arc::new(SharedWal {
            current,
            wal_lock: Default::default(),
            db_file,
            registry: self.clone(),
            namespace: namespace.clone(),
            checkpointed_frame_no: checkpointed_frame_no.into(),
            new_frame_notifier,
            durable_frame_no,
            stored_segments: Box::new(StorageReplicator::new(
                self.storage.clone(),
                namespace.clone(),
            )),
            shutdown: false.into(),
            checkpoint_notifier: self.checkpoint_notifier.clone(),
            io: self.io.clone(),
            swap_strategy,
            wals_path: wals_path.to_owned(),
        });

        self.opened
            .insert(namespace.clone(), Slot::Wal(shared.clone()));

        return Ok(shared);
    }

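    /// Reads the libsql footer, if present. The footer lives right after the
    /// last full page, so it is detected when
    /// `len % LIBSQL_PAGE_SIZE == size_of::<LibsqlFooter>()`.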
    fn try_read_footer(&self, db_file: &impl FileExt) -> Result<Option<LibsqlFooter>> {
        let len = db_file.len()?;
        if len as usize % LIBSQL_PAGE_SIZE as usize == size_of::<LibsqlFooter>() {
            let mut footer: LibsqlFooter = LibsqlFooter::new_zeroed();
            let footer_offset = (len / LIBSQL_PAGE_SIZE as u64) * LIBSQL_PAGE_SIZE as u64;
            db_file.read_exact_at(footer.as_bytes_mut(), footer_offset)?;
            footer.validate()?;
            Ok(Some(footer))
        } else {
            Ok(None)
        }
    }

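    /// Replaces the slot for `namespace` with a `Tombstone` and returns the
    /// previously opened WAL, if any. Until the slot is removed, `open` will
    /// fail with `Error::DeletingWal` for this namespace.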
    pub async fn tombstone(&self, namespace: &NamespaceName) -> Option<Arc<SharedWal<IO>>> {
        // if a wal is currently being opened, wait for it to finish
        {
            let v = self.opened.get(namespace)?;
            if let Slot::Building(_, ref notify) = *v {
                notify.clone().notified().await;
            }
        }

        match self.opened.insert(namespace.clone(), Slot::Tombstone) {
            Some(Slot::Tombstone) => None,
            Some(Slot::Building(_, _)) => {
                // FIXME: this could happen if someone removed the wal and immediately reopened
                // it. fix by retrying in a loop
                unreachable!("already waited for ns to open")
            }
            Some(Slot::Wal(wal)) => Some(wal),
            None => None,
        }
    }

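    /// Removes the slot for `namespace` from the registry, waiting for any
    /// in-flight open to complete first.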
    pub async fn remove(&self, namespace: &NamespaceName) {
        // if a wal is currently being opened, wait for it to finish
        {
            let v = self.opened.get(namespace);
            if let Some(Slot::Building(_, ref notify)) = v.as_deref() {
                notify.clone().notified().await;
            }
        }

        self.opened.remove(namespace);
    }

    /// Attempts to sync all loaded DBs with durable storage.
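    ///
    /// A minimal sketch of a caller (hypothetical `registry`, built as in the
    /// example on [`WalRegistry::new`]):
    ///
    /// ```ignore
    /// // sync up to 8 namespaces concurrently before serving traffic
    /// registry.sync_all(8).await?;
    /// ```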
    pub async fn sync_all(&self, concurrency: usize) -> Result<()>
    where
        S: Storage,
    {
        let mut join_set = JoinSet::new();
        tracing::info!("syncing {} namespaces", self.opened.len());
        // FIXME: arbitrary value, maybe use something like numcpu * 2?
        let before_sync = Instant::now();
        let sem = Arc::new(Semaphore::new(concurrency));
        for entry in self.opened.iter() {
            let Slot::Wal(shared) = entry.value() else {
                panic!("all wals should already be opened")
            };
            let storage = self.storage.clone();
            let shared = shared.clone();
            let sem = sem.clone();
            let permit = sem.acquire_owned().await.unwrap();

            join_set.spawn(async move {
                let _permit = permit;
                sync_one(shared, storage).await
            });

            if let Some(ret) = join_set.try_join_next() {
                ret.unwrap()?;
            }
        }

        while let Some(ret) = join_set.join_next().await {
            ret.unwrap()?;
        }

        tracing::info!("synced in {:?}", before_sync.elapsed());

        Ok(())
    }

    // On shutdown, we checkpoint all the WALs. This requires sealing the current segment, and
    // then checkpointing all the segments.
    pub async fn shutdown(self: Arc<Self>) -> Result<()> {
        tracing::info!("shutting down registry");
        self.shutdown.store(true, Ordering::SeqCst);

        let mut join_set = JoinSet::<Result<()>>::new();
        let semaphore = Arc::new(Semaphore::new(8));
        for item in self.opened.iter() {
            let (name, slot) = item.pair();
            loop {
                match slot {
                    Slot::Wal(shared) => {
                        // acquire a permit or drain the join set
                        let permit = loop {
                            tokio::select! {
                                permit = semaphore.clone().acquire_owned() => break permit,
                                _ = join_set.join_next() => (),
                            }
                        };
                        let shared = shared.clone();
                        let name = name.clone();

                        join_set.spawn_blocking(move || {
                            let _permit = permit;
                            if let Err(e) = shared.shutdown() {
                                tracing::error!("error shutting down `{name}`: {e}");
                            }

                            Ok(())
                        });
                        break;
                    }
                    Slot::Building(_, notify) => {
                        // wait for shared to finish building
                        notify.notified().await;
                    }
                    Slot::Tombstone => continue,
                }
            }
        }

        while join_set.join_next().await.is_some() {}

        // we process any pending storage job, then checkpoint everything
        self.storage.shutdown().await;

        // wait for checkpointer to exit
        let _ = self
            .checkpoint_notifier
            .send(CheckpointMessage::Shutdown)
            .await;
        self.checkpoint_notifier.closed().await;

        tracing::info!("registry shutdown gracefully");

        Ok(())
    }

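    /// Swaps the current segment for a fresh one: a new `CurrentSegment` is
    /// created starting at the current segment's next frame_no, the old
    /// segment is sealed, pushed onto the tail, and handed to storage if it
    /// contains frames that are not durable yet.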
    #[tracing::instrument(skip_all)]
    fn swap_current_inner(&self, shared: &SharedWal<IO>) -> Result<()> {
        let current = shared.current.load();
        if current.is_empty() {
            return Ok(());
        }
        let start_frame_no = current.next_frame_no();
        let path = shared.wals_path.join(format!("{start_frame_no:020}.seg"));

        let segment_file = self.io.open(true, true, true, &path)?;
        let salt = self.io.with_rng(|rng| rng.gen());
        let new = CurrentSegment::create(
            segment_file,
            path,
            start_frame_no,
            current.db_size(),
            current.tail().clone(),
            salt,
            current.log_id(),
        )?;
        // sealing must be the last fallible operation, because we don't want to end up in a
        // situation where the current log is sealed but wasn't swapped.
        if let Some(sealed) = current.seal(self.io.now())? {
            new.tail().push(sealed.clone());
            maybe_store_segment(
                self.storage.as_ref(),
                &self.checkpoint_notifier,
                &shared.namespace,
                &shared.durable_frame_no,
                sealed,
            );
        }

        shared.current.swap(Arc::new(new));
        tracing::debug!("current segment swapped");

        Ok(())
    }

    pub fn storage(&self) -> Arc<S> {
        self.storage.clone()
    }
}

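/// Syncs a single namespace with durable storage: if the remote durable
/// frame_no is ahead of the local WAL, the missing frames are streamed from
/// storage and injected locally, with the last injected frame carrying the
/// final size_after.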
#[tracing::instrument(skip_all, fields(namespace = shared.namespace().as_str()))]
async fn sync_one<IO, S>(shared: Arc<SharedWal<IO>>, storage: Arc<S>) -> Result<()>
where
    IO: Io,
    S: Storage,
{
    let remote_durable_frame_no = storage
        .durable_frame_no(shared.namespace(), None)
        .await
        .map_err(Box::new)?;
    let local_current_frame_no = shared.current.load().next_frame_no().get() - 1;

    if remote_durable_frame_no > local_current_frame_no {
        tracing::info!(
            remote_durable_frame_no,
            local_current_frame_no,
            "remote storage has newer segments"
        );
        let mut seen = RoaringBitmap::new();
        let replicator = StorageReplicator::new(storage, shared.namespace().clone());
        let stream = replicator
            .stream(&mut seen, remote_durable_frame_no, 1)
            .peekable();
        let mut injector = Injector::new(shared.clone(), 10)?;
        // we set the durable frame_no before we start injecting, because the wal may want to
        // checkpoint on commit.
        injector.set_durable(remote_durable_frame_no);
        // pin the stream to the heap so that we can drop it inside the loop and read the count
        // of `seen`.
        let mut stream = Box::pin(stream);
        loop {
            match stream.next().await {
                Some(Ok(mut frame)) => {
                    if stream.peek().await.is_none() {
                        drop(stream);
                        frame.header_mut().set_size_after(seen.len() as _);
                        injector.insert_frame(frame).await?;
                        break;
                    } else {
                        injector.insert_frame(frame).await?;
                    }
                }
                Some(Err(e)) => todo!("handle error: {e}, {}", shared.namespace()),
                None => break,
            }
        }
    }

    tracing::info!("local database is up to date");

    Ok(())
}

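/// Reads the log id from the libsql footer located right after the last page
/// of a database spanning `db_size` pages.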
fn read_log_id_from_footer<F: FileExt>(db_file: &F, db_size: u64) -> io::Result<Uuid> {
    let mut footer: LibsqlFooter = LibsqlFooter::new_zeroed();
    let footer_offset = LIBSQL_PAGE_SIZE as u64 * db_size;
    // FIXME: failing to read the footer here is a sign of a corrupted database: either we
    // have a tail to the segment list, or we have fully checkpointed the database. Can we
    // recover from that?
    db_file.read_exact_at(footer.as_bytes_mut(), footer_offset)?;
    Ok(footer.log_id())
}