Skip to main content

mylibsql/
db.rs

1use crate::snapshot::write_last_frame_no;
2
3use super::*;
4use std::{fmt::Debug, future::Future, ops::Range, sync::Arc};
5
6use anyhow::{anyhow, bail, Result};
7use futures::future::BoxFuture;
8use libsql::replication::{Frame, FrameNo};
9use libsql_replication::injector::{Injector, SqliteInjector};
10use libsql_sys::{
11    connection::NO_AUTOCHECKPOINT,
12    rusqlite::OpenFlags,
13    wal::{wrapper::WrappedWal, Sqlite3Wal, Sqlite3WalManager},
14    Connection,
15};
16use log::Log;
17use parking_lot::Mutex;
18use rusqlite::{DatabaseName, TransactionState};
19use snapshot::Snapshot;
20use tempfile::NamedTempFile;
21use wal::ShadowWal;
22
23pub struct MylibsqlDB {
24    db: NamedTempFile,
25    shadow_wal: ShadowWal,
26    injector: SqliteInjector,
27    rw_conn: Arc<Mutex<Connection<WrappedWal<ShadowWal, Sqlite3Wal>>>>,
28    ro_conn: Arc<Mutex<Connection<Sqlite3Wal>>>,
29}
30
31impl Debug for MylibsqlDB {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        f.debug_struct("MylibsqlDB")
34            .field("path", &self.db.path())
35            .finish()
36    }
37}
38
39impl MylibsqlDB {
40    pub async fn init(
41        init: impl FnOnce(&rusqlite::Connection) -> Result<()> + Send + 'static,
42    ) -> Result<(Snapshot, Log)> {
43        let (wal, wal_manager) = ShadowWal::new(0).await?;
44        let (path, log) = tokio::task::spawn_blocking(move || {
45            let db = NamedTempFile::new()?;
46            let conn = Connection::open(
47                db.path(),
48                OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
49                wal_manager,
50                NO_AUTOCHECKPOINT,
51                None,
52            )?;
53            // run user code initialization script
54            init(&conn)?;
55            // truncate WAL
56            conn.query_row_and_then("PRAGMA wal_checkpoint(TRUNCATE)", (), |row| {
57                let status: i32 = row.get(0)?;
58                if status != 0 {
59                    Err(anyhow!("WAL checkpoint failed with status {}", status))
60                } else {
61                    Ok(())
62                }
63            })?;
64            // write the last frame number in the database header
65            drop(conn);
66            write_last_frame_no(&db, wal.log().last_commited_frame_no())?;
67            anyhow::Ok((db, wal.into_log()?))
68        })
69        .await??;
70        Ok((Snapshot::open(path).await?, log))
71    }
72
73    pub async fn open(snapshot: &Snapshot) -> Result<Self> {
74        let start_frame_no = snapshot.last_frame_no().map(|f| f + 1).unwrap_or_default();
75        let db = NamedTempFile::new()?;
76        tokio::fs::copy(snapshot.path(), db.path()).await?;
77        let injector =
78            SqliteInjector::new(db.path().to_path_buf(), 4096, NO_AUTOCHECKPOINT, None).await?;
79        let (shadow_wal, wal_manager) = ShadowWal::new(start_frame_no).await?;
80        let db = tokio::task::spawn_blocking(move || {
81            let rw_conn = Connection::open(
82                db.path(),
83                OpenFlags::SQLITE_OPEN_READ_WRITE,
84                wal_manager,
85                NO_AUTOCHECKPOINT,
86                None,
87            )?;
88            let ro_conn = Connection::open(
89                db.path(),
90                OpenFlags::SQLITE_OPEN_READ_ONLY,
91                Sqlite3WalManager::new(),
92                NO_AUTOCHECKPOINT,
93                None,
94            )?;
95            anyhow::Ok(MylibsqlDB {
96                db,
97                rw_conn: Arc::new(Mutex::new(rw_conn)),
98                ro_conn: Arc::new(Mutex::new(ro_conn)),
99                shadow_wal,
100                injector,
101            })
102        })
103        .await??;
104        Ok(db)
105    }
106
107    pub async fn inject_log(&mut self, additional_log: &Log) -> Result<()> {
108        let additional_log_start_frame_no = additional_log.start_frame_no();
109        {
110            let log = self.shadow_wal.log();
111            let expected_frame_no = log.next_frame_no();
112            if !log.is_empty() || additional_log_start_frame_no != expected_frame_no {
113                bail!("log does not start at the expected frame number (expected {expected_frame_no}, got {additional_log_start_frame_no})");
114            }
115        }
116        for frame in additional_log.frames_iter() {
117            Self::inject_frame(&mut self.injector, frame?).await?;
118        }
119        let shadow_wal = self.shadow_wal.clone();
120        let next_frame_no = additional_log.next_frame_no();
121        tokio::task::spawn_blocking(move || {
122            shadow_wal.swap_log(move |_| Log::new_from(next_frame_no))
123        })
124        .await??;
125        let rw_conn = self.rw_conn.clone();
126        tokio::task::spawn_blocking(move || {
127            // truncate WAL
128            rw_conn
129                .lock()
130                .query_row_and_then("PRAGMA wal_checkpoint(TRUNCATE)", (), |row| {
131                    let status: i32 = row.get(0)?;
132                    if status != 0 {
133                        Err(anyhow!("WAL checkpoint failed with status {}", status))
134                    } else {
135                        Ok(())
136                    }
137                })
138        })
139        .await??;
140        Ok(())
141    }
142
143    async fn inject_frame(injector: &mut SqliteInjector, frame: Frame) -> Result<()> {
144        injector
145            .inject_frame(libsql_replication::rpc::replication::Frame {
146                data: frame.bytes(),
147                ..Default::default()
148            })
149            .await?;
150        Ok(())
151    }
152
153    pub async fn with_rw_connection<A>(
154        &self,
155        f: impl FnOnce(&mut rusqlite::Connection) -> Result<A> + Send + 'static,
156    ) -> Result<(A, Log, Range<FrameNo>)>
157    where
158        A: Send + 'static,
159    {
160        let conn = self.rw_conn.clone();
161        let shadow_wal = self.shadow_wal.clone();
162        shadow_wal.check_poisoned()?;
163
164        Ok(tokio::task::spawn_blocking(move || {
165            let mut conn = conn.lock();
166
167            // record the start frame number
168            let start_frame_no = shadow_wal.log().next_frame_no();
169            let a = f(&mut conn)?;
170
171            // auto-rollback pending transaction
172            let result = if conn.transaction_state(None) != Ok(TransactionState::None) {
173                conn.execute_batch("ROLLBACK;")?;
174                Err(anyhow!("a transaction was still pending"))
175            } else {
176                // record the end frame number
177                let end_frame_no = shadow_wal.log().next_frame_no();
178                let frames = start_frame_no..end_frame_no;
179
180                Ok((a, shadow_wal.log().clone(), frames))
181            };
182
183            // sanity check
184            if shadow_wal.log().has_uncommitted_frames() {
185                shadow_wal.poison();
186                bail!("fatal error: uncommitted frames in the log");
187            }
188
189            result
190        })
191        .await??)
192    }
193
194    pub async fn with_ro_connection<A>(
195        &self,
196        f: impl FnOnce(&mut rusqlite::Connection) -> Result<A> + Send + 'static,
197    ) -> Result<A>
198    where
199        A: Send + 'static,
200    {
201        self.shadow_wal.check_poisoned()?;
202        let conn = self.ro_conn.clone();
203        Ok(tokio::task::spawn_blocking(move || {
204            let mut conn = conn.lock();
205            debug_assert!(conn.is_readonly(DatabaseName::Main)?);
206            f(&mut conn)
207        })
208        .await??)
209    }
210
211    pub async fn checkpoint(
212        &self,
213        f: impl FnOnce(Log) -> BoxFuture<'static, ()>,
214    ) -> Result<impl Future<Output = Option<FrameNo>>> {
215        let shadow_wal = self.shadow_wal.clone();
216        shadow_wal.check_poisoned()?;
217        let rw_conn = self.rw_conn.clone();
218        let old_log = tokio::task::spawn_blocking(move || {
219            let rw_conn = rw_conn.lock();
220            let old_log = shadow_wal.swap_log(|log| Log::new_from(log.next_frame_no()))?;
221            // checkpoint WAL as much as possible
222            rw_conn.query_row_and_then("PRAGMA wal_checkpoint(PASSIVE)", (), |row| {
223                let status: i32 = row.get(0)?;
224                if status != 0 {
225                    shadow_wal.poison();
226                    Err(anyhow!(
227                        "fatal error: WAL checkpoint failed with status {}",
228                        status
229                    ))
230                } else {
231                    Ok(())
232                }
233            })?;
234            anyhow::Ok(old_log)
235        })
236        .await??;
237        let last_frame_no = old_log.last_commited_frame_no();
238        Ok(async move {
239            f(old_log).await;
240            last_frame_no
241        })
242    }
243
244    pub fn into_inner(self) -> (NamedTempFile, Log) {
245        drop(self.rw_conn);
246        drop(self.ro_conn);
247        (self.db, self.shadow_wal.into_log().unwrap())
248    }
249}
250
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use futures::FutureExt;

    use super::*;

    /// Builds an empty snapshot (the initialization script does nothing).
    async fn blank_db() -> Result<Snapshot> {
        let (snapshot, _) = init(|_| Ok(())).await?;
        Ok(snapshot)
    }

    /// Runs a single-value query through the read-write connection, discarding the log
    /// and frame range that come back with it.
    async fn query_rw<T>(db: &MylibsqlDB, sql: &'static str) -> Result<T>
    where
        T: rusqlite::types::FromSql + Send + 'static,
    {
        let (value, _, _) = db
            .with_rw_connection(move |conn| Ok(conn.query_row(sql, (), |row| row.get(0))?))
            .await?;
        Ok(value)
    }

    /// Runs a single-value query through the read-only connection.
    async fn query_ro<T>(db: &MylibsqlDB, sql: &'static str) -> Result<T>
    where
        T: rusqlite::types::FromSql + Send + 'static,
    {
        db.with_ro_connection(move |conn| Ok(conn.query_row(sql, (), |row| row.get(0))?))
            .await
    }

    #[tokio::test]
    async fn create_blank_database() -> Result<()> {
        let (snapshot, log) = MylibsqlDB::init(|_| Ok(())).await?;
        assert!(snapshot.last_frame_no().is_none());
        assert!(log.last_commited_frame_no().is_none());
        assert_eq!(log.next_frame_no(), 0);

        // reopen from snapshot (the handle is closed right away)
        drop(MylibsqlDB::open(&snapshot).await?);

        // reopen from log
        let mut from_log = MylibsqlDB::open(&blank_db().await?).await?;
        from_log.inject_log(&log).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_initial_database() -> Result<()> {
        const COUNT_LOL: &str = "select count(*) from lol";

        let (snapshot, log) = MylibsqlDB::init(|conn| {
            conn.execute("create table lol(x integer)", ())?;
            conn.execute("insert into lol values (1)", ())?;
            Ok(())
        })
        .await?;
        assert!(snapshot.last_frame_no().is_some());
        assert!(log.last_commited_frame_no().is_some());
        assert_eq!(log.next_frame_no(), 3);

        // reopen from snapshot
        let from_snapshot = MylibsqlDB::open(&snapshot).await?;
        assert_eq!(query_rw::<usize>(&from_snapshot, COUNT_LOL).await?, 1);
        assert_eq!(query_ro::<usize>(&from_snapshot, COUNT_LOL).await?, 1);

        // reopen from log
        let mut from_log = MylibsqlDB::open(&blank_db().await?).await?;
        from_log.inject_log(&log).await?;
        assert_eq!(query_rw::<usize>(&from_log, COUNT_LOL).await?, 1);
        assert_eq!(query_ro::<usize>(&from_log, COUNT_LOL).await?, 1);

        Ok(())
    }

    #[tokio::test]
    async fn with_ro_connection() -> Result<()> {
        let db = MylibsqlDB::open(&blank_db().await?).await?;
        let write_rejected = db
            .with_ro_connection(|conn| Ok(conn.execute("create table lol(x integer)", ())?))
            .await
            .is_err();
        assert!(write_rejected, "writes must fail on the read-only connection");
        Ok(())
    }

    #[tokio::test]
    async fn checkpoints() -> Result<()> {
        const COUNT_BOO: &str = "select count(*) from boo";
        const SELECT_X: &str = "select x from boo";

        let logs_store = Arc::new(Mutex::new(VecDeque::new()));
        let save_log = {
            let logs_store = logs_store.clone();
            move |log| {
                let logs_store = logs_store.clone();
                async move {
                    logs_store.lock().push_back(log);
                }
                .boxed()
            }
        };

        let db = MylibsqlDB::open(&blank_db().await?).await?;
        assert_eq!(None, db.checkpoint(&save_log).await?.await);

        // one checkpoint per write: create table, insert, update, delete
        let writes: [(FrameNo, &'static str); 4] = [
            (1, "create table boo(x string)"),
            (2, "insert into boo values ('YO')"),
            (3, "update boo set x = 'YOO'"),
            (4, "delete from boo"),
        ];
        for (expected_last_frame_no, sql) in writes {
            db.with_rw_connection(move |conn| Ok(conn.execute(sql, ())?))
                .await?;
            assert_eq!(
                Some(expected_last_frame_no),
                db.checkpoint(&save_log).await?.await
            );
        }

        drop(save_log);
        let mut logs_store = Arc::into_inner(logs_store).unwrap().into_inner();
        assert_eq!(logs_store.len(), 5);

        // restart with a new db
        let mut db = MylibsqlDB::open(&blank_db().await?).await?;
        db.inject_log(&logs_store.pop_front().unwrap()).await?; // this one is blank

        // apply first checkpoint (create table)
        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<usize>(&db, COUNT_BOO).await?, 0);
        assert_eq!(query_ro::<usize>(&db, COUNT_BOO).await?, 0);

        // apply second checkpoint (insert)
        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<String>(&db, SELECT_X).await?, "YO");
        assert_eq!(query_ro::<String>(&db, SELECT_X).await?, "YO");

        // apply third checkpoint (update)
        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<String>(&db, SELECT_X).await?, "YOO");
        assert_eq!(query_ro::<String>(&db, SELECT_X).await?, "YOO");

        // apply fourth checkpoint (delete)
        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<usize>(&db, COUNT_BOO).await?, 0);
        assert_eq!(query_ro::<usize>(&db, COUNT_BOO).await?, 0);

        // now we can add additional data
        db.with_rw_connection(|conn| Ok(conn.execute("insert into boo values ('LOL')", ())?))
            .await?;
        assert_eq!(query_ro::<usize>(&db, COUNT_BOO).await?, 1);

        // and checkpoint it
        db.checkpoint(|log| {
            async move {
                assert_eq!(log.next_frame_no(), 6);
            }
            .boxed()
        })
        .await?
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn bad_transaction() -> Result<()> {
        let (snapshot, log) = MylibsqlDB::init(|conn| {
            conn.execute("create table lol(x integer)", ())?;
            Ok(())
        })
        .await?;
        assert_eq!(2, log.next_frame_no());

        let db = MylibsqlDB::open(&snapshot).await?;

        // this is not valid to keep a transaction pending like this
        let leaked_txn_failed = db
            .with_rw_connection(|conn| {
                conn.execute_batch("begin;")?;
                conn.execute_batch("insert into lol values (1)")?;
                Ok(())
            })
            .await
            .is_err();
        assert!(leaked_txn_failed);

        // here it is ok because txn will be auto-rollback during the drop
        let dropped_txn_ok = db
            .with_rw_connection(|conn| {
                let txn = conn.transaction()?;
                txn.execute_batch("insert into lol values (1)")?;
                Ok(())
            })
            .await
            .is_ok();
        assert!(dropped_txn_ok);

        let (_, _, frames) = db
            .with_rw_connection(|conn| {
                let txn = conn.transaction()?;
                txn.execute_batch("insert into lol values (1)")?;
                txn.commit()?;
                Ok(())
            })
            .await?;
        assert_eq!(frames, 2..3); // we should have a single frame

        Ok(())
    }
}