libsql_wal/
wal.rs

1use std::ffi::OsStr;
2use std::os::unix::prelude::OsStrExt;
3use std::sync::atomic::AtomicU64;
4use std::sync::Arc;
5
6use libsql_sys::name::NamespaceResolver;
7use libsql_sys::wal::{Wal, WalManager};
8
9use crate::io::Io;
10use crate::registry::WalRegistry;
11use crate::segment::sealed::SealedSegment;
12use crate::shared_wal::SharedWal;
13use crate::storage::Storage;
14use crate::transaction::Transaction;
15
16pub struct LibsqlWalManager<IO: Io, S> {
17    registry: Arc<WalRegistry<IO, S>>,
18    next_conn_id: Arc<AtomicU64>,
19    namespace_resolver: Arc<dyn NamespaceResolver>,
20}
21
22impl<IO: Io, S> Clone for LibsqlWalManager<IO, S> {
23    fn clone(&self) -> Self {
24        Self {
25            registry: self.registry.clone(),
26            next_conn_id: self.next_conn_id.clone(),
27            namespace_resolver: self.namespace_resolver.clone(),
28        }
29    }
30}
31
32impl<FS: Io, S> LibsqlWalManager<FS, S> {
33    pub fn new(
34        registry: Arc<WalRegistry<FS, S>>,
35        namespace_resolver: Arc<dyn NamespaceResolver>,
36    ) -> Self {
37        Self {
38            registry,
39            next_conn_id: Default::default(),
40            namespace_resolver,
41        }
42    }
43}
44
45pub struct LibsqlWal<FS: Io> {
46    last_read_frame_no: Option<u64>,
47    tx: Option<Transaction<FS::File>>,
48    shared: Arc<SharedWal<FS>>,
49    conn_id: u64,
50}
51
52impl<IO: Io, S: Storage<Segment = SealedSegment<IO::File>>> WalManager for LibsqlWalManager<IO, S> {
53    type Wal = LibsqlWal<IO>;
54
55    fn use_shared_memory(&self) -> bool {
56        false
57    }
58
59    fn open(
60        &self,
61        _vfs: &mut libsql_sys::wal::Vfs,
62        _file: &mut libsql_sys::wal::Sqlite3File,
63        _no_shm_mode: std::ffi::c_int,
64        _max_log_size: i64,
65        db_path: &std::ffi::CStr,
66    ) -> libsql_sys::wal::Result<Self::Wal> {
67        let db_path = OsStr::from_bytes(&db_path.to_bytes());
68        let namespace = self.namespace_resolver.resolve(db_path.as_ref());
69        let shared = self
70            .registry
71            .clone()
72            .open(db_path.as_ref(), &namespace)
73            .map_err(|e| e.into())?;
74        let conn_id = self
75            .next_conn_id
76            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
77        Ok(LibsqlWal {
78            last_read_frame_no: None,
79            tx: None,
80            shared,
81            conn_id,
82        })
83    }
84
85    fn close(
86        &self,
87        wal: &mut Self::Wal,
88        _db: &mut libsql_sys::wal::Sqlite3Db,
89        _sync_flags: std::ffi::c_int,
90        _scratch: Option<&mut [u8]>,
91    ) -> libsql_sys::wal::Result<()> {
92        wal.end_read_txn();
93        Ok(())
94    }
95
96    fn destroy_log(
97        &self,
98        _vfs: &mut libsql_sys::wal::Vfs,
99        _db_path: &std::ffi::CStr,
100    ) -> libsql_sys::wal::Result<()> {
101        Ok(())
102    }
103
104    fn log_exists(
105        &self,
106        _vfs: &mut libsql_sys::wal::Vfs,
107        _db_path: &std::ffi::CStr,
108    ) -> libsql_sys::wal::Result<bool> {
109        Ok(true)
110    }
111
112    fn destroy(self)
113    where
114        Self: Sized,
115    {
116    }
117}
118
119impl<FS: Io> Wal for LibsqlWal<FS> {
120    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
121    fn limit(&mut self, _size: i64) {}
122
123    #[tracing::instrument(skip_all, fields(id = self.conn_id, ns = self.shared.namespace().as_str()))]
124    fn begin_read_txn(&mut self) -> libsql_sys::wal::Result<bool> {
125        tracing::trace!("begin read");
126        let tx = self.shared.begin_read(self.conn_id);
127        let invalidate_cache = self
128            .last_read_frame_no
129            .map(|idx| tx.max_frame_no != idx)
130            .unwrap_or(true);
131        self.last_read_frame_no = Some(tx.max_frame_no);
132        self.tx = Some(Transaction::Read(tx));
133
134        tracing::trace!(invalidate_cache, "read started");
135        Ok(invalidate_cache)
136    }
137
138    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
139    fn end_read_txn(&mut self) {
140        self.tx.take().map(|tx| tx.end());
141        tracing::trace!("end read tx");
142    }
143
144    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
145    fn find_frame(
146        &mut self,
147        page_no: std::num::NonZeroU32,
148    ) -> libsql_sys::wal::Result<Option<std::num::NonZeroU32>> {
149        tracing::trace!(page_no, "find frame");
150        // this is a trick: we defer the frame read to the `read_frame` method. The read_frame
151        // method will read from the journal if the page exist, or from the db_file if it doesn't
152        Ok(Some(page_no))
153    }
154
155    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
156    fn read_frame(
157        &mut self,
158        page_no: std::num::NonZeroU32,
159        buffer: &mut [u8],
160    ) -> libsql_sys::wal::Result<()> {
161        tracing::trace!(page_no, "reading frame");
162        let tx = self.tx.as_mut().unwrap();
163        self.shared
164            .read_page(tx, page_no.get(), buffer)
165            .map_err(Into::into)?;
166        Ok(())
167    }
168
169    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
170    fn db_size(&self) -> u32 {
171        let db_size = match self.tx.as_ref() {
172            Some(tx) => tx.db_size,
173            None => 0,
174        };
175        tracing::trace!(db_size, "db_size");
176        db_size
177    }
178
179    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
180    fn begin_write_txn(&mut self) -> libsql_sys::wal::Result<()> {
181        tracing::trace!("begin write");
182        match self.tx.as_mut() {
183            Some(tx) => {
184                self.shared.upgrade(tx).map_err(Into::into)?;
185                tracing::trace!("write lock acquired");
186            }
187            None => panic!("should acquire read txn first"),
188        }
189
190        Ok(())
191    }
192
193    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
194    fn end_write_txn(&mut self) -> libsql_sys::wal::Result<()> {
195        tracing::trace!("end write");
196        match self.tx.take() {
197            Some(Transaction::Write(tx)) => {
198                self.last_read_frame_no = Some(tx.next_frame_no - 1);
199                self.tx = Some(Transaction::Read(tx.downgrade()));
200            }
201            other => {
202                self.tx = other;
203            }
204        }
205
206        Ok(())
207    }
208
209    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
210    fn undo<U: libsql_sys::wal::UndoHandler>(
211        &mut self,
212        handler: Option<&mut U>,
213    ) -> libsql_sys::wal::Result<()> {
214        match self.tx {
215            Some(Transaction::Write(ref mut tx)) => {
216                if tx.is_commited() {
217                    return Ok(());
218                }
219                if let Some(handler) = handler {
220                    for page_no in tx.index_page_iter() {
221                        // FIXME: maybe it's not OK to call that callback with duplicated pages_no,
222                        // need to test that
223                        if let Err(e) = handler.handle_undo(page_no) {
224                            tracing::debug!("undo handler error: {e}");
225                            break;
226                        }
227                    }
228                }
229
230                tx.reset(0);
231
232                tracing::debug!("rolled back tx");
233
234                Ok(())
235            }
236            _ => Ok(()),
237        }
238    }
239
240    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
241    fn savepoint(&mut self, rollback_data: &mut [u32]) {
242        match self.tx {
243            Some(Transaction::Write(ref mut tx)) => {
244                let id = tx.savepoint() as u32;
245                rollback_data[0] = id;
246            }
247            _ => {
248                // if we don't have a write tx, we always point to the beginning of the tx
249                rollback_data[0] = 0;
250            }
251        }
252    }
253
254    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
255    fn savepoint_undo(&mut self, rollback_data: &mut [u32]) -> libsql_sys::wal::Result<()> {
256        match self.tx {
257            Some(Transaction::Write(ref mut tx)) => {
258                tx.reset(rollback_data[0] as usize);
259                Ok(())
260            }
261            _ => Ok(()),
262        }
263    }
264
265    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
266    fn insert_frames(
267        &mut self,
268        page_size: std::ffi::c_int,
269        page_headers: &mut libsql_sys::wal::PageHeaders,
270        size_after: u32,
271        _is_commit: bool,
272        _sync_flags: std::ffi::c_int,
273    ) -> libsql_sys::wal::Result<usize> {
274        assert_eq!(page_size, 4096);
275        match self.tx.as_mut() {
276            Some(Transaction::Write(ref mut tx)) => {
277                self.shared
278                    .insert_frames(
279                        tx,
280                        page_headers.iter(),
281                        (size_after != 0).then_some(size_after),
282                    )
283                    .map_err(Into::into)?;
284            }
285            _ => todo!("no write transaction"),
286        }
287        Ok(0)
288    }
289
290    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
291    fn checkpoint(
292        &mut self,
293        _db: &mut libsql_sys::wal::Sqlite3Db,
294        _mode: libsql_sys::wal::CheckpointMode,
295        _busy_handler: Option<&mut dyn libsql_sys::wal::BusyHandler>,
296        _sync_flags: u32,
297        _buf: &mut [u8],
298        _checkpoint_cb: Option<&mut dyn libsql_sys::wal::CheckpointCallback>,
299        _in_wal: Option<&mut i32>,
300        _backfilled: Option<&mut i32>,
301    ) -> libsql_sys::wal::Result<()> {
302        // self.shared.segments.checkpoint();
303        Ok(())
304    }
305
306    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
307    fn exclusive_mode(&mut self, op: std::ffi::c_int) -> libsql_sys::wal::Result<()> {
308        tracing::trace!(op, "trying to acquire exclusive mode");
309        Ok(())
310    }
311
312    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
313    fn uses_heap_memory(&self) -> bool {
314        true
315    }
316
317    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
318    fn set_db(&mut self, _db: &mut libsql_sys::wal::Sqlite3Db) {}
319
320    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
321    fn callback(&self) -> i32 {
322        0
323    }
324
325    #[tracing::instrument(skip_all, fields(id = self.conn_id))]
326    fn frames_in_wal(&self) -> u32 {
327        0
328    }
329}