1use std::future::Future;
2use std::sync::Arc;
3
4use hashbrown::HashSet;
5use libsql_sys::name::NamespaceName;
6use tokio::sync::mpsc;
7use tokio::task::JoinSet;
8
9use crate::io::Io;
10use crate::registry::WalRegistry;
11
/// Channel sender, presumably used to tell a checkpointer about namespaces (judging by the name).
// NOTE(review): `Checkpointer` consumes an `mpsc::Receiver<CheckpointMessage>`, but this alias
// carries bare `NamespaceName`s, so it cannot be paired with a checkpointer's receiver as-is.
// Confirm whether it should be `mpsc::Sender<CheckpointMessage>`.
pub(crate) type NotifyCheckpointer = mpsc::Sender<NamespaceName>;
13
14pub enum CheckpointMessage {
15 Namespace(NamespaceName),
17 Shutdown,
19}
20
21impl From<NamespaceName> for CheckpointMessage {
22 fn from(value: NamespaceName) -> Self {
23 Self::Namespace(value)
24 }
25}
26
/// A [`Checkpointer`] whose checkpoints are performed by a [`WalRegistry`].
pub type LibsqlCheckpointer<IO, S> = Checkpointer<WalRegistry<IO, S>>;
28
29impl<IO, S> LibsqlCheckpointer<IO, S>
30where
31 IO: Io,
32 S: Sync + Send + 'static,
33{
34 pub fn new(
35 registry: Arc<WalRegistry<IO, S>>,
36 notifier: mpsc::Receiver<CheckpointMessage>,
37 max_checkpointing_conccurency: usize,
38 ) -> Self {
39 Self::new_with_performer(registry, notifier, max_checkpointing_conccurency)
40 }
41}
42
/// Abstraction over "checkpoint this namespace", so that [`Checkpointer`] can be driven either
/// by the real [`WalRegistry`] or by a test double.
trait PerformCheckpoint {
    /// Checkpoints `namespace`.
    ///
    /// The future must be `Send` because the checkpointer runs it on a task spawned into a
    /// `JoinSet`.
    fn checkpoint(
        &self,
        namespace: &NamespaceName,
    ) -> impl Future<Output = crate::error::Result<()>> + Send;
}
49
50impl<IO, S> PerformCheckpoint for WalRegistry<IO, S>
51where
52 IO: Io,
53 S: Sync + Send + 'static,
54{
55 #[tracing::instrument(skip(self))]
56 fn checkpoint(
57 &self,
58 namespace: &NamespaceName,
59 ) -> impl Future<Output = crate::error::Result<()>> + Send {
60 let namespace = namespace.clone();
61 async move {
62 if let Some(registry) = self.get_async(&namespace).await {
63 registry.checkpoint().await?;
64 }
65 Ok(())
66 }
67 }
68}
69
/// Number of consecutive failed checkpoints the checkpointer tolerates.
///
/// `Checkpointer::run` gives up (currently via `todo!`) once the count exceeds this value.
const CHECKPOINTER_ERROR_THRES: usize = 16;
71
/// Schedules and runs checkpoints for the namespaces it is notified about.
///
/// Notifications arrive on `recv`. Each notified namespace is put in `scheduled`, then moved
/// to `checkpointing` while a task in `join_set` checkpoints it through `perform_checkpoint`.
/// At most `max_checkpointing_conccurency` tasks run at once, and a namespace never has more
/// than one checkpoint in flight.
#[derive(Debug)]
pub struct Checkpointer<P> {
    /// Performs the actual checkpoint; cloned into every spawned task.
    perform_checkpoint: Arc<P>,
    /// Namespaces waiting to be checkpointed. A namespace can be here and in `checkpointing`
    /// at the same time if it was notified again while its checkpoint was in flight.
    scheduled: HashSet<NamespaceName>,
    /// Namespaces that currently have a checkpoint task running in `join_set`.
    checkpointing: HashSet<NamespaceName>,
    /// Incoming checkpoint requests and the shutdown signal.
    recv: mpsc::Receiver<CheckpointMessage>,
    /// Upper bound on the number of checkpoint tasks running concurrently.
    max_checkpointing_conccurency: usize,
    /// Set once a `Shutdown` message is received or the channel is closed.
    shutting_down: bool,
    /// Running checkpoint tasks; each yields its namespace together with the result.
    join_set: JoinSet<(NamespaceName, crate::error::Result<()>)>,
    /// Scratch buffer used to move freshly spawned namespaces from `scheduled` to
    /// `checkpointing`.
    processing: Vec<NamespaceName>,
    /// Number of consecutive failed checkpoints; reset to 0 after a success.
    errors: usize,
    /// Set when the last scheduling pass that had free capacity started nothing. While set,
    /// `step` does not take its "work is available" shortcut.
    no_work: bool,
}
94
95#[allow(private_bounds)]
96impl<P> Checkpointer<P>
97where
98 P: PerformCheckpoint + Send + Sync + 'static,
99{
100 fn new_with_performer(
101 perform_checkpoint: Arc<P>,
102 notifier: mpsc::Receiver<CheckpointMessage>,
103 max_checkpointing_conccurency: usize,
104 ) -> Self {
105 Self {
106 perform_checkpoint,
107 scheduled: Default::default(),
108 checkpointing: Default::default(),
109 recv: notifier,
110 max_checkpointing_conccurency,
111 shutting_down: false,
112 join_set: JoinSet::new(),
113 processing: Vec::new(),
114 errors: 0,
115 no_work: false,
116 }
117 }
118
119 #[tracing::instrument(skip(self))]
120 pub async fn run(mut self) {
121 loop {
122 if self.should_exit() {
123 tracing::info!("checkpointer exited cleanly.");
124 return;
125 }
126
127 if self.errors > CHECKPOINTER_ERROR_THRES {
128 todo!("handle too many consecutive errors");
129 }
130
131 self.step().await;
132 }
133 }
134
135 fn should_exit(&self) -> bool {
136 self.shutting_down
137 && self.recv.is_empty()
138 && self.scheduled.is_empty()
139 && self.checkpointing.is_empty()
140 && self.join_set.is_empty()
141 }
142
143 async fn step(&mut self) {
144 tokio::select! {
145 biased;
146 result = self.join_set.join_next(), if !self.join_set.is_empty() => {
147 match result {
148 Some(Ok((namespace, result))) => {
149 self.checkpointing.remove(&namespace);
150 if let Err(e) = result {
151 self.errors += 1;
152 tracing::error!("error checkpointing ns {namespace}: {e}, rescheduling");
153 self.scheduled.insert(namespace);
155 } else {
156 self.errors = 0;
157 }
158 }
159 Some(Err(e)) => panic!("checkpoint task panicked: {e}"),
160 None => unreachable!("got None, but join set is not empty")
161 }
162 }
163 notified = self.recv.recv(), if !self.shutting_down => {
164 match notified {
165 Some(CheckpointMessage::Namespace(namespace)) => {
166 tracing::info!(namespace = namespace.as_str(), "notified for checkpoint");
167 self.scheduled.insert(namespace);
168 }
169 None | Some(CheckpointMessage::Shutdown) => {
170 tracing::info!("checkpointed is shutting down. {} namespaces to checkpoint", self.checkpointing.len());
171 self.shutting_down = true;
172 }
173 }
174 }
175 _ = std::future::ready(()), if !self.scheduled.is_empty()
177 && self.join_set.len() < self.max_checkpointing_conccurency && !self.no_work => (),
178 }
179
180 let n_available = self.max_checkpointing_conccurency - self.join_set.len();
181 if n_available > 0 {
182 self.no_work = true;
183 for namespace in self
184 .scheduled
185 .difference(&self.checkpointing)
186 .take(n_available)
187 .cloned()
188 {
189 self.no_work = false;
190 self.processing.push(namespace.clone());
191 let perform_checkpoint = self.perform_checkpoint.clone();
192 self.join_set.spawn(async move {
193 let ret = perform_checkpoint.checkpoint(&namespace).await;
194 (namespace, ret)
195 });
196 }
197
198 for namespace in self.processing.drain(..) {
199 self.scheduled.remove(&namespace);
200 self.checkpointing.insert(namespace);
201 }
202 }
203 }
204}
205
// Unit tests that drive `Checkpointer::step` by hand, using in-test `PerformCheckpoint` doubles.
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, Ordering::Relaxed};

    use tokio::time::Duration;

    use super::*;

    /// A notified namespace moves to `checkpointing` on the first step. Once its checkpoint
    /// completes on the second step, it is no longer tracked anywhere.
    #[tokio::test]
    async fn process_checkpoint() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                CALLED.store(true, Relaxed);
                Ok(())
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 5);
        let ns = NamespaceName::from("test");

        sender.send(ns.clone().into()).await.unwrap();

        // Step 1: receive the notification and spawn the checkpoint task.
        checkpointer.step().await;

        assert!(checkpointer.checkpointing.contains(&ns));

        // Step 2: join the finished checkpoint task.
        checkpointer.step().await;

        assert!(checkpointer.checkpointing.is_empty());
        assert!(checkpointer.scheduled.is_empty());
        assert!(CALLED.load(std::sync::atomic::Ordering::Relaxed));
    }

    /// A failing checkpoint bumps the error counter and re-queues the namespace. Since
    /// capacity is free, the namespace is spawned again within the same step.
    #[tokio::test]
    async fn checkpoint_error() {
        static CALLED: AtomicBool = AtomicBool::new(false);

        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                CALLED.store(true, Relaxed);
                Err(crate::error::Error::BusySnapshot)
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 5);
        let ns = NamespaceName::from("test");

        sender.send(ns.clone().into()).await.unwrap();

        checkpointer.step().await;
        assert_eq!(checkpointer.errors, 0);

        assert!(checkpointer.checkpointing.contains(&ns));

        checkpointer.step().await;

        assert!(CALLED.load(std::sync::atomic::Ordering::Relaxed));
        // The failed namespace was re-queued and immediately re-spawned in the same step.
        assert!(checkpointer.checkpointing.contains(&ns));
        assert!(checkpointer.scheduled.is_empty());
        assert_eq!(checkpointer.errors, 1);
    }

    /// Dropping every sender closes the channel, which is treated as a shutdown request.
    /// With nothing pending, the checkpointer can then exit.
    #[tokio::test]
    async fn checkpointer_shutdown() {
        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                Ok(())
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 5);

        drop(sender);

        // Shutdown has not been observed yet.
        assert!(!checkpointer.should_exit());

        checkpointer.step().await;

        assert!(checkpointer.should_exit());

        checkpointer.run().await;
    }

    /// Even after shutdown, any scheduled or in-flight namespace prevents exiting.
    #[tokio::test]
    async fn cant_exit_until_all_processed() {
        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                Ok(())
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 5);

        drop(sender);

        // Observe the closed channel and enter shutdown.
        checkpointer.step().await;

        let ns: NamespaceName = "test".into();
        checkpointer.scheduled.insert(ns.clone());
        assert!(!checkpointer.should_exit());
        checkpointer.scheduled.remove(&ns);

        checkpointer.checkpointing.insert(ns.clone());
        assert!(!checkpointer.should_exit());
        checkpointer.checkpointing.remove(&ns);

        assert!(checkpointer.should_exit());
        checkpointer.run().await;
    }

    /// A namespace notified again while its checkpoint is in flight stays in `scheduled`
    /// and does not get a second task.
    #[tokio::test]
    async fn dont_schedule_already_scheduled() {
        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                // Keep the checkpoint in flight for the duration of the test.
                tokio::time::sleep(Duration::from_secs(1000)).await;
                Ok(())
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 5);

        let ns: NamespaceName = "test".into();

        sender.send(ns.clone().into()).await.unwrap();
        sender.send(ns.clone().into()).await.unwrap();

        checkpointer.step().await;

        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns));

        checkpointer.step().await;

        assert!(checkpointer.scheduled.contains(&ns));
        assert!(checkpointer.checkpointing.contains(&ns));
        assert_eq!(checkpointer.join_set.len(), 1);
    }

    /// Distinct namespaces are checkpointed concurrently, one notification per step.
    #[tokio::test]
    async fn schedule_conccurently_for_different_namespaces() {
        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                tokio::time::sleep(Duration::from_secs(1000)).await;
                Ok(())
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 5);

        let ns1: NamespaceName = "test1".into();
        let ns2: NamespaceName = "test2".into();

        sender.send(ns1.clone().into()).await.unwrap();
        sender.send(ns2.clone().into()).await.unwrap();

        checkpointer.step().await;

        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns1));
        assert_eq!(checkpointer.checkpointing.len(), 1);

        checkpointer.step().await;

        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns2));
        assert_eq!(checkpointer.checkpointing.len(), 2);
        assert_eq!(checkpointer.join_set.len(), 2);
    }

    /// With a limit of 2, a third namespace waits in `scheduled` until a slot frees up.
    #[tokio::test]
    async fn checkpointer_limited_conccurency() {
        #[derive(Debug)]
        struct TestPerformCheckoint;

        impl PerformCheckpoint for TestPerformCheckoint {
            async fn checkpoint(&self, _namespace: &NamespaceName) -> crate::error::Result<()> {
                tokio::time::sleep(Duration::from_secs(1000)).await;
                Ok(())
            }
        }

        let (sender, receiver) = mpsc::channel(8);
        let mut checkpointer =
            Checkpointer::new_with_performer(TestPerformCheckoint.into(), receiver, 2);

        let ns1: NamespaceName = "test1".into();
        let ns2: NamespaceName = "test2".into();
        let ns3: NamespaceName = "test3".into();

        sender.send(ns1.clone().into()).await.unwrap();
        sender.send(ns2.clone().into()).await.unwrap();
        sender.send(ns3.clone().into()).await.unwrap();

        // Receive all three notifications; only two fit under the limit.
        checkpointer.step().await;
        checkpointer.step().await;
        checkpointer.step().await;

        assert_eq!(checkpointer.scheduled.len(), 1);
        assert!(checkpointer.scheduled.contains(&ns3));

        assert!(checkpointer.checkpointing.contains(&ns1));
        assert!(checkpointer.checkpointing.contains(&ns2));
        assert_eq!(checkpointer.checkpointing.len(), 2);
        assert_eq!(checkpointer.join_set.len(), 2);

        // Fast-forward the paused clock so both in-flight sleeps elapse.
        tokio::time::pause();
        tokio::time::advance(Duration::from_secs(2000)).await;

        // First join frees a slot and ns3 is spawned; the second join just completes.
        checkpointer.step().await;
        checkpointer.step().await;

        assert!(checkpointer.scheduled.is_empty());
        assert!(checkpointer.checkpointing.contains(&ns3));
        assert_eq!(checkpointer.checkpointing.len(), 1);
    }
}