libsql_wal/
checkpointer.rs

1use std::future::Future;
2use std::sync::Arc;
3
4use hashbrown::HashSet;
5use libsql_sys::name::NamespaceName;
6use tokio::sync::mpsc;
7use tokio::task::JoinSet;
8
9use crate::io::Io;
10use crate::registry::WalRegistry;
11
/// Sending half used to notify the checkpointer that a namespace may need checkpointing.
// NOTE(review): `Checkpointer` consumes an `mpsc::Receiver<CheckpointMessage>`, but this alias is an
// `mpsc::Sender<NamespaceName>`, so a sender of this type cannot be paired with a `Checkpointer`'s
// receiver directly — confirm whether it should be `mpsc::Sender<CheckpointMessage>`.
pub(crate) type NotifyCheckpointer = mpsc::Sender<NamespaceName>;
13
14pub enum CheckpointMessage {
15    /// notify that a namespace may be checkpointable
16    Namespace(NamespaceName),
17    /// shutdown initiated
18    Shutdown,
19}
20
21impl From<NamespaceName> for CheckpointMessage {
22    fn from(value: NamespaceName) -> Self {
23        Self::Namespace(value)
24    }
25}
26
27pub type LibsqlCheckpointer<IO, S> = Checkpointer<WalRegistry<IO, S>>;
28
29impl<IO, S> LibsqlCheckpointer<IO, S>
30where
31    IO: Io,
32    S: Sync + Send + 'static,
33{
34    pub fn new(
35        registry: Arc<WalRegistry<IO, S>>,
36        notifier: mpsc::Receiver<CheckpointMessage>,
37        max_checkpointing_conccurency: usize,
38    ) -> Self {
39        Self::new_with_performer(registry, notifier, max_checkpointing_conccurency)
40    }
41}
42
/// Abstraction over "checkpoint this namespace", so that the scheduling logic of
/// [`Checkpointer`] can be driven by test doubles as well as by the real registry.
trait PerformCheckpoint {
    /// Checkpoints `namespace`.
    ///
    /// Returning an `Err` makes the checkpointer count a consecutive failure and reschedule the
    /// namespace.
    fn checkpoint(
        &self,
        namespace: &NamespaceName,
    ) -> impl Future<Output = crate::error::Result<()>> + Send;
}
49
50impl<IO, S> PerformCheckpoint for WalRegistry<IO, S>
51where
52    IO: Io,
53    S: Sync + Send + 'static,
54{
55    #[tracing::instrument(skip(self))]
56    fn checkpoint(
57        &self,
58        namespace: &NamespaceName,
59    ) -> impl Future<Output = crate::error::Result<()>> + Send {
60        let namespace = namespace.clone();
61        async move {
62            if let Some(registry) = self.get_async(&namespace).await {
63                registry.checkpoint().await?;
64            }
65            Ok(())
66        }
67    }
68}
69
/// Number of consecutive failed checkpoints tolerated by [`Checkpointer::run`]: once more than
/// this many checkpoints have failed in a row, `run` reaches an unimplemented `todo!` (and
/// panics). The counter is reset by any successful checkpoint.
const CHECKPOINTER_ERROR_THRES: usize = 16;

/// The checkpointer checkpoints WAL segments into the main db file, and deletes checkpointed
/// segments.
/// For simplicity of implementation, we only delete segments when they are checkpointed, and only
/// checkpoint when they are reported as durable.
#[derive(Debug)]
pub struct Checkpointer<P> {
    /// Performs the actual checkpoint of a namespace; cloned into every spawned checkpoint task.
    perform_checkpoint: Arc<P>,
    /// Namespaces waiting to be checkpointed. A namespace may also be in `checkpointing` when it
    /// was notified again while its checkpoint was in flight; it is only spawned again once that
    /// checkpoint has completed.
    scheduled: HashSet<NamespaceName>,
    /// currently checkpointing databases
    checkpointing: HashSet<NamespaceName>,
    /// the checkpointer is notified whenever there is a change to a namespace that could trigger a
    /// checkpoint
    recv: mpsc::Receiver<CheckpointMessage>,
    /// Upper bound on the number of checkpoint tasks running at the same time.
    max_checkpointing_conccurency: usize,
    /// Set once a `Shutdown` message is received or the notification channel is closed.
    shutting_down: bool,
    /// In-flight checkpoint tasks, each yielding the namespace it worked on and its result.
    join_set: JoinSet<(NamespaceName, crate::error::Result<()>)>,
    /// Scratch buffer holding the namespaces spawned during the current step; they are moved from
    /// `scheduled` to `checkpointing` after iterating `scheduled` is done.
    processing: Vec<NamespaceName>,
    /// Number of consecutive failed checkpoints; reset to 0 after a successful one.
    errors: usize,
    /// previous iteration of the loop resulted in no work being enqueued
    no_work: bool,
}
94
95#[allow(private_bounds)]
96impl<P> Checkpointer<P>
97where
98    P: PerformCheckpoint + Send + Sync + 'static,
99{
100    fn new_with_performer(
101        perform_checkpoint: Arc<P>,
102        notifier: mpsc::Receiver<CheckpointMessage>,
103        max_checkpointing_conccurency: usize,
104    ) -> Self {
105        Self {
106            perform_checkpoint,
107            scheduled: Default::default(),
108            checkpointing: Default::default(),
109            recv: notifier,
110            max_checkpointing_conccurency,
111            shutting_down: false,
112            join_set: JoinSet::new(),
113            processing: Vec::new(),
114            errors: 0,
115            no_work: false,
116        }
117    }
118
119    #[tracing::instrument(skip(self))]
120    pub async fn run(mut self) {
121        loop {
122            if self.should_exit() {
123                tracing::info!("checkpointer exited cleanly.");
124                return;
125            }
126
127            if self.errors > CHECKPOINTER_ERROR_THRES {
128                todo!("handle too many consecutive errors");
129            }
130
131            self.step().await;
132        }
133    }
134
135    fn should_exit(&self) -> bool {
136        self.shutting_down
137            && self.recv.is_empty()
138            && self.scheduled.is_empty()
139            && self.checkpointing.is_empty()
140            && self.join_set.is_empty()
141    }
142
143    async fn step(&mut self) {
144        tokio::select! {
145            biased;
146            result = self.join_set.join_next(), if !self.join_set.is_empty() => {
147                match result {
148                    Some(Ok((namespace, result))) => {
149                        self.checkpointing.remove(&namespace);
150                        if let Err(e) = result {
151                            self.errors += 1;
152                            tracing::error!("error checkpointing ns {namespace}: {e}, rescheduling");
153                            // reschedule
154                            self.scheduled.insert(namespace);
155                        } else {
156                            self.errors = 0;
157                        }
158                    }
159                    Some(Err(e)) => panic!("checkpoint task panicked: {e}"),
160                    None => unreachable!("got None, but join set is not empty")
161                }
162            }
163            notified = self.recv.recv(), if !self.shutting_down => {
164                match notified {
165                    Some(CheckpointMessage::Namespace(namespace)) => {
166                        tracing::info!(namespace = namespace.as_str(), "notified for checkpoint");
167                        self.scheduled.insert(namespace);
168                    }
169                    None | Some(CheckpointMessage::Shutdown) => {
170                        tracing::info!("checkpointed is shutting down. {} namespaces to checkpoint", self.checkpointing.len());
171                        self.shutting_down = true;
172                    }
173                }
174            }
175            // don't wait if there is stuff to enqueue
176            _ = std::future::ready(()), if !self.scheduled.is_empty()
177                && self.join_set.len() < self.max_checkpointing_conccurency && !self.no_work => (),
178        }
179
180        let n_available = self.max_checkpointing_conccurency - self.join_set.len();
181        if n_available > 0 {
182            self.no_work = true;
183            for namespace in self
184                .scheduled
185                .difference(&self.checkpointing)
186                .take(n_available)
187                .cloned()
188            {
189                self.no_work = false;
190                self.processing.push(namespace.clone());
191                let perform_checkpoint = self.perform_checkpoint.clone();
192                self.join_set.spawn(async move {
193                    let ret = perform_checkpoint.checkpoint(&namespace).await;
194                    (namespace, ret)
195                });
196            }
197
198            for namespace in self.processing.drain(..) {
199                self.scheduled.remove(&namespace);
200                self.checkpointing.insert(namespace);
201            }
202        }
203    }
204}
205
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use tokio::time::Duration;

    use super::*;

    /// Test double whose checkpoints succeed immediately.
    #[derive(Debug)]
    struct Noop;

    impl PerformCheckpoint for Noop {
        async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
            Ok(())
        }
    }

    /// Test double whose checkpoints sleep for 1000 seconds, keeping the namespace in flight
    /// unless the test advances time.
    #[derive(Debug)]
    struct Slow;

    impl PerformCheckpoint for Slow {
        async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
            tokio::time::sleep(Duration::from_secs(1000)).await;
            Ok(())
        }
    }

    /// Builds a checkpointer around `performer`, returning it with the sender that feeds it.
    fn setup<P>(
        performer: P,
        concurrency: usize,
    ) -> (mpsc::Sender<CheckpointMessage>, Checkpointer<P>)
    where
        P: PerformCheckpoint + Send + Sync + 'static,
    {
        let (tx, rx) = mpsc::channel(8);
        let checkpointer = Checkpointer::new_with_performer(Arc::new(performer), rx, concurrency);
        (tx, checkpointer)
    }

    #[tokio::test]
    async fn process_checkpoint() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        #[derive(Debug)]
        struct RecordsCall;

        impl PerformCheckpoint for RecordsCall {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                CALLED.store(true, Relaxed);
                Ok(())
            }
        }

        let (tx, mut checkpointer) = setup(RecordsCall, 5);
        let ns = NamespaceName::from("test");
        tx.send(ns.clone().into()).await.unwrap();

        // notification received, checkpoint spawned
        checkpointer.step().await;
        assert!(checkpointer.checkpointing.contains(&ns));

        // checkpoint collected
        checkpointer.step().await;
        assert!(checkpointer.checkpointing.is_empty());
        assert!(checkpointer.scheduled.is_empty());
        assert!(CALLED.load(Relaxed));
    }

    #[tokio::test]
    async fn checkpoint_error() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        #[derive(Debug)]
        struct AlwaysFails;

        impl PerformCheckpoint for AlwaysFails {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                CALLED.store(true, Relaxed);
                // random error
                Err(crate::error::Error::BusySnapshot)
            }
        }

        let (tx, mut checkpointer) = setup(AlwaysFails, 5);
        let ns = NamespaceName::from("test");
        tx.send(ns.clone().into()).await.unwrap();

        checkpointer.step().await;
        assert_eq!(checkpointer.errors, 0);
        assert!(checkpointer.checkpointing.contains(&ns));

        checkpointer.step().await;

        // the failed job was re-enqueued and immediately spawned again
        assert!(CALLED.load(Relaxed));
        assert!(checkpointer.checkpointing.contains(&ns));
        assert!(checkpointer.scheduled.is_empty());
        assert_eq!(checkpointer.errors, 1);
    }

    #[tokio::test]
    async fn checkpointer_shutdown() {
        let (tx, mut checkpointer) = setup(Noop, 5);
        drop(tx);

        assert!(!checkpointer.should_exit());
        checkpointer.step().await;
        assert!(checkpointer.should_exit());

        // should return immediately.
        checkpointer.run().await;
    }

    #[tokio::test]
    async fn cant_exit_until_all_processed() {
        let (tx, mut checkpointer) = setup(Noop, 5);
        drop(tx);
        checkpointer.step().await;

        let ns: NamespaceName = "test".into();

        // pending scheduled work blocks exit
        checkpointer.scheduled.insert(ns.clone());
        assert!(!checkpointer.should_exit());
        checkpointer.scheduled.remove(&ns);

        // in-flight work blocks exit
        checkpointer.checkpointing.insert(ns.clone());
        assert!(!checkpointer.should_exit());
        checkpointer.checkpointing.remove(&ns);

        assert!(checkpointer.should_exit());
        // should return immediately.
        checkpointer.run().await;
    }

    #[tokio::test]
    async fn dont_schedule_already_scheduled() {
        let (tx, mut checkpointer) = setup(Slow, 5);
        let ns: NamespaceName = "test".into();

        for _ in 0..2 {
            tx.send(ns.clone().into()).await.unwrap();
        }

        checkpointer.step().await;
        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns));

        // the second notification is scheduled, but not spawned while the first is in flight
        checkpointer.step().await;
        assert!(checkpointer.scheduled.contains(&ns));
        assert!(checkpointer.checkpointing.contains(&ns));
        assert_eq!(checkpointer.join_set.len(), 1);
    }

    #[tokio::test]
    async fn schedule_conccurently_for_different_namespaces() {
        let (tx, mut checkpointer) = setup(Slow, 5);
        let ns1: NamespaceName = "test1".into();
        let ns2: NamespaceName = "test2".into();

        for ns in [&ns1, &ns2] {
            tx.send(ns.clone().into()).await.unwrap();
        }

        checkpointer.step().await;
        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns1));
        assert_eq!(checkpointer.checkpointing.len(), 1);

        checkpointer.step().await;
        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns2));
        assert_eq!(checkpointer.checkpointing.len(), 2);
        assert_eq!(checkpointer.join_set.len(), 2);
    }

    #[tokio::test]
    async fn checkpointer_limited_conccurency() {
        let (tx, mut checkpointer) = setup(Slow, 2);
        let ns1: NamespaceName = "test1".into();
        let ns2: NamespaceName = "test2".into();
        let ns3: NamespaceName = "test3".into();

        for ns in [&ns1, &ns2, &ns3] {
            tx.send(ns.clone().into()).await.unwrap();
        }

        for _ in 0..3 {
            checkpointer.step().await;
        }

        // only two checkpoints may run at once, the third namespace waits
        assert_eq!(checkpointer.scheduled.len(), 1);
        assert!(checkpointer.scheduled.contains(&ns3));

        assert!(checkpointer.checkpointing.contains(&ns1));
        assert!(checkpointer.checkpointing.contains(&ns2));
        assert_eq!(checkpointer.checkpointing.len(), 2);
        assert_eq!(checkpointer.join_set.len(), 2);

        // let the in-flight checkpoints finish
        tokio::time::pause();
        tokio::time::advance(Duration::from_secs(2000)).await;

        for _ in 0..2 {
            checkpointer.step().await;
        }

        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns3));
        assert_eq!(checkpointer.checkpointing.len(), 1);
    }
}