1use crate::snapshot::write_last_frame_no;
2
3use super::*;
4use std::{fmt::Debug, future::Future, ops::Range, sync::Arc};
5
6use anyhow::{anyhow, bail, Result};
7use futures::future::BoxFuture;
8use libsql::replication::{Frame, FrameNo};
9use libsql_replication::injector::{Injector, SqliteInjector};
10use libsql_sys::{
11 connection::NO_AUTOCHECKPOINT,
12 rusqlite::OpenFlags,
13 wal::{wrapper::WrappedWal, Sqlite3Wal, Sqlite3WalManager},
14 Connection,
15};
16use log::Log;
17use parking_lot::Mutex;
18use rusqlite::{DatabaseName, TransactionState};
19use snapshot::Snapshot;
20use tempfile::NamedTempFile;
21use wal::ShadowWal;
22
23pub struct MylibsqlDB {
24 db: NamedTempFile,
25 shadow_wal: ShadowWal,
26 injector: SqliteInjector,
27 rw_conn: Arc<Mutex<Connection<WrappedWal<ShadowWal, Sqlite3Wal>>>>,
28 ro_conn: Arc<Mutex<Connection<Sqlite3Wal>>>,
29}
30
31impl Debug for MylibsqlDB {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("MylibsqlDB")
34 .field("path", &self.db.path())
35 .finish()
36 }
37}
38
39impl MylibsqlDB {
40 pub async fn init(
41 init: impl FnOnce(&rusqlite::Connection) -> Result<()> + Send + 'static,
42 ) -> Result<(Snapshot, Log)> {
43 let (wal, wal_manager) = ShadowWal::new(0).await?;
44 let (path, log) = tokio::task::spawn_blocking(move || {
45 let db = NamedTempFile::new()?;
46 let conn = Connection::open(
47 db.path(),
48 OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
49 wal_manager,
50 NO_AUTOCHECKPOINT,
51 None,
52 )?;
53 init(&conn)?;
55 conn.query_row_and_then("PRAGMA wal_checkpoint(TRUNCATE)", (), |row| {
57 let status: i32 = row.get(0)?;
58 if status != 0 {
59 Err(anyhow!("WAL checkpoint failed with status {}", status))
60 } else {
61 Ok(())
62 }
63 })?;
64 drop(conn);
66 write_last_frame_no(&db, wal.log().last_commited_frame_no())?;
67 anyhow::Ok((db, wal.into_log()?))
68 })
69 .await??;
70 Ok((Snapshot::open(path).await?, log))
71 }
72
73 pub async fn open(snapshot: &Snapshot) -> Result<Self> {
74 let start_frame_no = snapshot.last_frame_no().map(|f| f + 1).unwrap_or_default();
75 let db = NamedTempFile::new()?;
76 tokio::fs::copy(snapshot.path(), db.path()).await?;
77 let injector =
78 SqliteInjector::new(db.path().to_path_buf(), 4096, NO_AUTOCHECKPOINT, None).await?;
79 let (shadow_wal, wal_manager) = ShadowWal::new(start_frame_no).await?;
80 let db = tokio::task::spawn_blocking(move || {
81 let rw_conn = Connection::open(
82 db.path(),
83 OpenFlags::SQLITE_OPEN_READ_WRITE,
84 wal_manager,
85 NO_AUTOCHECKPOINT,
86 None,
87 )?;
88 let ro_conn = Connection::open(
89 db.path(),
90 OpenFlags::SQLITE_OPEN_READ_ONLY,
91 Sqlite3WalManager::new(),
92 NO_AUTOCHECKPOINT,
93 None,
94 )?;
95 anyhow::Ok(MylibsqlDB {
96 db,
97 rw_conn: Arc::new(Mutex::new(rw_conn)),
98 ro_conn: Arc::new(Mutex::new(ro_conn)),
99 shadow_wal,
100 injector,
101 })
102 })
103 .await??;
104 Ok(db)
105 }
106
107 pub async fn inject_log(&mut self, additional_log: &Log) -> Result<()> {
108 let additional_log_start_frame_no = additional_log.start_frame_no();
109 {
110 let log = self.shadow_wal.log();
111 let expected_frame_no = log.next_frame_no();
112 if !log.is_empty() || additional_log_start_frame_no != expected_frame_no {
113 bail!("log does not start at the expected frame number (expected {expected_frame_no}, got {additional_log_start_frame_no})");
114 }
115 }
116 for frame in additional_log.frames_iter() {
117 Self::inject_frame(&mut self.injector, frame?).await?;
118 }
119 let shadow_wal = self.shadow_wal.clone();
120 let next_frame_no = additional_log.next_frame_no();
121 tokio::task::spawn_blocking(move || {
122 shadow_wal.swap_log(move |_| Log::new_from(next_frame_no))
123 })
124 .await??;
125 let rw_conn = self.rw_conn.clone();
126 tokio::task::spawn_blocking(move || {
127 rw_conn
129 .lock()
130 .query_row_and_then("PRAGMA wal_checkpoint(TRUNCATE)", (), |row| {
131 let status: i32 = row.get(0)?;
132 if status != 0 {
133 Err(anyhow!("WAL checkpoint failed with status {}", status))
134 } else {
135 Ok(())
136 }
137 })
138 })
139 .await??;
140 Ok(())
141 }
142
143 async fn inject_frame(injector: &mut SqliteInjector, frame: Frame) -> Result<()> {
144 injector
145 .inject_frame(libsql_replication::rpc::replication::Frame {
146 data: frame.bytes(),
147 ..Default::default()
148 })
149 .await?;
150 Ok(())
151 }
152
153 pub async fn with_rw_connection<A>(
154 &self,
155 f: impl FnOnce(&mut rusqlite::Connection) -> Result<A> + Send + 'static,
156 ) -> Result<(A, Log, Range<FrameNo>)>
157 where
158 A: Send + 'static,
159 {
160 let conn = self.rw_conn.clone();
161 let shadow_wal = self.shadow_wal.clone();
162 shadow_wal.check_poisoned()?;
163
164 Ok(tokio::task::spawn_blocking(move || {
165 let mut conn = conn.lock();
166
167 let start_frame_no = shadow_wal.log().next_frame_no();
169 let a = f(&mut conn)?;
170
171 let result = if conn.transaction_state(None) != Ok(TransactionState::None) {
173 conn.execute_batch("ROLLBACK;")?;
174 Err(anyhow!("a transaction was still pending"))
175 } else {
176 let end_frame_no = shadow_wal.log().next_frame_no();
178 let frames = start_frame_no..end_frame_no;
179
180 Ok((a, shadow_wal.log().clone(), frames))
181 };
182
183 if shadow_wal.log().has_uncommitted_frames() {
185 shadow_wal.poison();
186 bail!("fatal error: uncommitted frames in the log");
187 }
188
189 result
190 })
191 .await??)
192 }
193
194 pub async fn with_ro_connection<A>(
195 &self,
196 f: impl FnOnce(&mut rusqlite::Connection) -> Result<A> + Send + 'static,
197 ) -> Result<A>
198 where
199 A: Send + 'static,
200 {
201 self.shadow_wal.check_poisoned()?;
202 let conn = self.ro_conn.clone();
203 Ok(tokio::task::spawn_blocking(move || {
204 let mut conn = conn.lock();
205 debug_assert!(conn.is_readonly(DatabaseName::Main)?);
206 f(&mut conn)
207 })
208 .await??)
209 }
210
211 pub async fn checkpoint(
212 &self,
213 f: impl FnOnce(Log) -> BoxFuture<'static, ()>,
214 ) -> Result<impl Future<Output = Option<FrameNo>>> {
215 let shadow_wal = self.shadow_wal.clone();
216 shadow_wal.check_poisoned()?;
217 let rw_conn = self.rw_conn.clone();
218 let old_log = tokio::task::spawn_blocking(move || {
219 let rw_conn = rw_conn.lock();
220 let old_log = shadow_wal.swap_log(|log| Log::new_from(log.next_frame_no()))?;
221 rw_conn.query_row_and_then("PRAGMA wal_checkpoint(PASSIVE)", (), |row| {
223 let status: i32 = row.get(0)?;
224 if status != 0 {
225 shadow_wal.poison();
226 Err(anyhow!(
227 "fatal error: WAL checkpoint failed with status {}",
228 status
229 ))
230 } else {
231 Ok(())
232 }
233 })?;
234 anyhow::Ok(old_log)
235 })
236 .await??;
237 let last_frame_no = old_log.last_commited_frame_no();
238 Ok(async move {
239 f(old_log).await;
240 last_frame_no
241 })
242 }
243
244 pub fn into_inner(self) -> (NamedTempFile, Log) {
245 drop(self.rw_conn);
246 drop(self.ro_conn);
247 (self.db, self.shadow_wal.into_log().unwrap())
248 }
249}
250
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;

    use futures::FutureExt;

    use super::*;

    /// Snapshot of a database created with no schema and no data.
    async fn blank_db() -> Result<Snapshot> {
        let (snapshot, _) = init(|_| Ok(())).await?;
        Ok(snapshot)
    }

    /// Runs a query yielding one row with one column on the read-write
    /// connection, and returns that value.
    async fn query_rw<T>(db: &MylibsqlDB, sql: &'static str) -> Result<T>
    where
        T: rusqlite::types::FromSql + Send + 'static,
    {
        let (value, _, _) = db
            .with_rw_connection(move |conn| Ok(conn.query_row(sql, (), |row| row.get(0))?))
            .await?;
        Ok(value)
    }

    /// Runs a query yielding one row with one column on the read-only
    /// connection, and returns that value.
    async fn query_ro<T>(db: &MylibsqlDB, sql: &'static str) -> Result<T>
    where
        T: rusqlite::types::FromSql + Send + 'static,
    {
        db.with_ro_connection(move |conn| Ok(conn.query_row(sql, (), |row| row.get(0))?))
            .await
    }

    #[tokio::test]
    async fn create_blank_database() -> Result<()> {
        let (snapshot, log) = MylibsqlDB::init(|_| Ok(())).await?;
        assert!(snapshot.last_frame_no().is_none());
        assert!(log.last_commited_frame_no().is_none());
        assert_eq!(log.next_frame_no(), 0);

        // The snapshot can be opened directly...
        let _ = MylibsqlDB::open(&snapshot).await?;

        // ...and its log replays onto a fresh blank database.
        let mut replica = MylibsqlDB::open(&blank_db().await?).await?;
        replica.inject_log(&log).await?;

        Ok(())
    }

    #[tokio::test]
    async fn create_initial_database() -> Result<()> {
        const COUNT: &str = "select count(*) from lol";

        let (snapshot, log) = MylibsqlDB::init(|conn| {
            conn.execute("create table lol(x integer)", ())?;
            conn.execute("insert into lol values (1)", ())?;
            Ok(())
        })
        .await?;
        assert!(snapshot.last_frame_no().is_some());
        assert!(log.last_commited_frame_no().is_some());
        assert_eq!(log.next_frame_no(), 3);

        // Opening the snapshot directly...
        let db = MylibsqlDB::open(&snapshot).await?;
        assert_eq!(query_rw::<usize>(&db, COUNT).await?, 1);
        assert_eq!(query_ro::<usize>(&db, COUNT).await?, 1);

        // ...and replaying the log onto a blank database give the same content.
        let mut db = MylibsqlDB::open(&blank_db().await?).await?;
        db.inject_log(&log).await?;
        assert_eq!(query_rw::<usize>(&db, COUNT).await?, 1);
        assert_eq!(query_ro::<usize>(&db, COUNT).await?, 1);

        Ok(())
    }

    #[tokio::test]
    async fn with_ro_connection() -> Result<()> {
        let db = MylibsqlDB::open(&blank_db().await?).await?;
        let rejected = db
            .with_ro_connection(|conn| Ok(conn.execute("create table lol(x integer)", ())?))
            .await
            .is_err();
        assert!(rejected);
        Ok(())
    }

    #[tokio::test]
    async fn checkpoints() -> Result<()> {
        let logs_store = Arc::new(Mutex::new(VecDeque::new()));
        let save_log = {
            let logs_store = Arc::clone(&logs_store);
            move |log| {
                let logs_store = Arc::clone(&logs_store);
                async move {
                    logs_store.lock().push_back(log);
                }
                .boxed()
            }
        };

        let db = MylibsqlDB::open(&blank_db().await?).await?;
        // An untouched database checkpoints an empty log.
        assert_eq!(None, db.checkpoint(&save_log).await?.await);

        // Each statement is checkpointed on its own; the last committed frame
        // number advances by one each time.
        let statements = [
            "create table boo(x string)",
            "insert into boo values ('YO')",
            "update boo set x = 'YOO'",
            "delete from boo",
        ];
        for (expected_last_frame_no, sql) in (1..).zip(statements) {
            db.with_rw_connection(move |conn| Ok(conn.execute(sql, ())?))
                .await?;
            assert_eq!(
                Some(expected_last_frame_no),
                db.checkpoint(&save_log).await?.await
            );
        }

        drop(save_log);
        let mut logs_store = Arc::into_inner(logs_store).unwrap().into_inner();
        assert_eq!(logs_store.len(), 5);

        // Replay the saved logs, one checkpoint at a time, onto a fresh database.
        let mut db = MylibsqlDB::open(&blank_db().await?).await?;
        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<usize>(&db, "select count(*) from boo").await?, 0);
        assert_eq!(query_ro::<usize>(&db, "select count(*) from boo").await?, 0);

        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<String>(&db, "select x from boo").await?, "YO");
        assert_eq!(query_ro::<String>(&db, "select x from boo").await?, "YO");

        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<String>(&db, "select x from boo").await?, "YOO");
        assert_eq!(query_ro::<String>(&db, "select x from boo").await?, "YOO");

        db.inject_log(&logs_store.pop_front().unwrap()).await?;
        assert_eq!(query_rw::<usize>(&db, "select count(*) from boo").await?, 0);
        assert_eq!(query_ro::<usize>(&db, "select count(*) from boo").await?, 0);

        // The replayed database keeps accepting local writes.
        db.with_rw_connection(|conn| Ok(conn.execute("insert into boo values ('LOL')", ())?))
            .await?;
        assert_eq!(query_ro::<usize>(&db, "select count(*) from boo").await?, 1);

        db.checkpoint(|log| {
            async move {
                assert_eq!(log.next_frame_no(), 6);
            }
            .boxed()
        })
        .await?
        .await;

        Ok(())
    }

    #[tokio::test]
    async fn bad_transaction() -> Result<()> {
        let (snapshot, log) = MylibsqlDB::init(|conn| {
            conn.execute("create table lol(x integer)", ())?;
            Ok(())
        })
        .await?;
        assert_eq!(2, log.next_frame_no());

        let db = MylibsqlDB::open(&snapshot).await?;

        // A raw `begin` that is never closed gets rejected.
        let left_open_rejected = db
            .with_rw_connection(|conn| {
                conn.execute_batch("begin;")?;
                conn.execute_batch("insert into lol values (1)")?;
                Ok(())
            })
            .await
            .is_err();
        assert!(left_open_rejected);

        // A `Transaction` guard dropped without commit rolls back by itself,
        // so this is not an error.
        let dropped_guard_accepted = db
            .with_rw_connection(|conn| {
                let txn = conn.transaction()?;
                txn.execute_batch("insert into lol values (1)")?;
                Ok(())
            })
            .await
            .is_ok();
        assert!(dropped_guard_accepted);

        // A committed transaction reports the frame range it appended.
        let (_, _, frames) = db
            .with_rw_connection(|conn| {
                let txn = conn.transaction()?;
                txn.execute_batch("insert into lol values (1)")?;
                txn.commit()?;
                Ok(())
            })
            .await?;
        assert_eq!(frames, 2..3);

        Ok(())
    }
}