// libsql_sys/wal/sqlite3_wal.rs

use std::ffi::{c_int, c_void, CStr};
use std::mem::MaybeUninit;
use std::num::NonZeroU32;
use std::ptr::null_mut;

use libsql_ffi::{
    libsql_wal, libsql_wal_manager, sqlite3_wal, sqlite3_wal_manager, Error, SQLITE_MISUSE,
    SQLITE_OK, WAL_SAVEPOINT_NDATA,
};

use super::{
    BusyHandler, CheckpointCallback, CheckpointMode, PageHeaders, Result, Sqlite3Db, Sqlite3File,
    UndoHandler, Vfs, Wal, WalManager,
};

16/// SQLite3 default wal_manager implementation.
17#[derive(Clone, Copy)]
18pub struct Sqlite3WalManager {
19    inner: libsql_wal_manager,
20}
21
22/// Safety: the create pointer is an immutable global pointer
23unsafe impl Send for Sqlite3WalManager {}
24unsafe impl Sync for Sqlite3WalManager {}
25
26impl Sqlite3WalManager {
27    pub fn new() -> Self {
28        Self {
29            inner: unsafe { sqlite3_wal_manager },
30        }
31    }
32}
33
34impl Default for Sqlite3WalManager {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl WalManager for Sqlite3WalManager {
41    type Wal = Sqlite3Wal;
42
43    fn use_shared_memory(&self) -> bool {
44        self.inner.bUsesShm != 0
45    }
46
47    fn open(
48        &self,
49        vfs: &mut Vfs,
50        file: &mut Sqlite3File,
51        no_shm_mode: c_int,
52        max_log_size: i64,
53        db_path: &CStr,
54    ) -> Result<Self::Wal> {
55        let mut wal: MaybeUninit<libsql_wal> = MaybeUninit::uninit();
56        let rc = unsafe {
57            (self.inner.xOpen.unwrap())(
58                self.inner.pData,
59                vfs.as_ptr(),
60                file.as_ptr(),
61                no_shm_mode,
62                max_log_size,
63                db_path.as_ptr(),
64                wal.as_mut_ptr(),
65            )
66        };
67
68        if rc != 0 {
69            Err(Error::new(rc))?
70        }
71
72        let inner = unsafe { wal.assume_init() };
73
74        Ok(Sqlite3Wal { inner })
75    }
76
77    fn close(
78        &self,
79        wal: &mut Self::Wal,
80        db: &mut Sqlite3Db,
81        sync_flags: c_int,
82        scratch: Option<&mut [u8]>,
83    ) -> Result<()> {
84        let scratch_len = scratch.as_ref().map(|s| s.len()).unwrap_or(0);
85        let scratch_ptr = scratch.map(|s| s.as_mut_ptr()).unwrap_or(null_mut());
86        let rc = unsafe {
87            (self.inner.xClose.unwrap())(
88                self.inner.pData,
89                wal.inner.pData,
90                db.as_ptr(),
91                sync_flags,
92                scratch_len as _,
93                scratch_ptr as _,
94            )
95        };
96
97        if rc != 0 {
98            Err(Error::new(rc))?
99        } else {
100            Ok(())
101        }
102    }
103
104    fn destroy_log(&self, vfs: &mut Vfs, db_path: &CStr) -> Result<()> {
105        let rc = unsafe {
106            (self.inner.xLogDestroy.unwrap())(self.inner.pData, vfs.as_ptr(), db_path.as_ptr())
107        };
108
109        if rc != 0 {
110            Err(Error::new(rc))?
111        } else {
112            Ok(())
113        }
114    }
115
116    fn log_exists(&self, vfs: &mut Vfs, db_path: &CStr) -> Result<bool> {
117        let mut out: c_int = 0;
118        let rc = unsafe {
119            (self.inner.xLogExists.unwrap())(
120                self.inner.pData,
121                vfs.as_ptr(),
122                db_path.as_ptr(),
123                &mut out,
124            )
125        };
126
127        if rc != 0 {
128            Err(Error::new(rc))?
129        } else {
130            Ok(out != 0)
131        }
132    }
133
134    fn destroy(self)
135    where
136        Self: Sized,
137    {
138        unsafe { (self.inner.xDestroy.unwrap())(self.inner.pData) }
139    }
140}
141
142unsafe impl Send for Sqlite3Wal {}
143
144/// SQLite3 wal implementation
145pub struct Sqlite3Wal {
146    inner: libsql_wal,
147}
148
149impl Wal for Sqlite3Wal {
150    fn limit(&mut self, size: i64) {
151        unsafe {
152            (self.inner.methods.xLimit.unwrap())(self.inner.pData, size);
153        }
154    }
155
156    fn begin_read_txn(&mut self) -> Result<bool> {
157        let mut out: c_int = 0;
158        let rc = unsafe {
159            (self.inner.methods.xBeginReadTransaction.unwrap())(
160                self.inner.pData,
161                &mut out as *mut _,
162            )
163        };
164        if rc != 0 {
165            Err(Error::new(rc))
166        } else {
167            Ok(out != 0)
168        }
169    }
170
171    fn end_read_txn(&mut self) {
172        unsafe {
173            (self.inner.methods.xEndReadTransaction.unwrap())(self.inner.pData);
174        }
175    }
176
177    fn find_frame(&mut self, page_no: NonZeroU32) -> Result<Option<NonZeroU32>> {
178        let mut out: u32 = 0;
179        let rc = unsafe {
180            (self.inner.methods.xFindFrame.unwrap())(self.inner.pData, page_no.into(), &mut out)
181        };
182
183        if rc != 0 {
184            Err(Error::new(rc))
185        } else {
186            Ok(NonZeroU32::new(out))
187        }
188    }
189
190    fn read_frame(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> Result<()> {
191        let rc = unsafe {
192            (self.inner.methods.xReadFrame.unwrap())(
193                self.inner.pData,
194                frame_no.into(),
195                buffer.len() as _,
196                buffer.as_mut_ptr(),
197            )
198        };
199        if rc != 0 {
200            Err(Error::new(rc))
201        } else {
202            Ok(())
203        }
204    }
205
206    fn read_frame_raw(&mut self, frame_no: NonZeroU32, buffer: &mut [u8]) -> Result<()> {
207        let rc = unsafe {
208            (self.inner.methods.xReadFrameRaw.unwrap())(
209                self.inner.pData,
210                frame_no.into(),
211                buffer.len() as _,
212                buffer.as_mut_ptr(),
213            )
214        };
215        if rc != 0 {
216            Err(Error::new(rc))
217        } else {
218            Ok(())
219        }
220    }
221
222    fn db_size(&self) -> u32 {
223        unsafe { (self.inner.methods.xDbsize.unwrap())(self.inner.pData) }
224    }
225
226    fn begin_write_txn(&mut self) -> Result<()> {
227        let rc = unsafe { (self.inner.methods.xBeginWriteTransaction.unwrap())(self.inner.pData) };
228        if rc != 0 {
229            Err(Error::new(rc))
230        } else {
231            Ok(())
232        }
233    }
234
235    fn end_write_txn(&mut self) -> Result<()> {
236        let rc = unsafe { (self.inner.methods.xEndWriteTransaction.unwrap())(self.inner.pData) };
237        if rc != 0 {
238            Err(Error::new(rc))
239        } else {
240            Ok(())
241        }
242    }
243
244    fn undo<U: UndoHandler>(&mut self, undo_handler: Option<&mut U>) -> Result<()> {
245        unsafe extern "C" fn call_handler<U: UndoHandler>(p: *mut c_void, page_no: u32) -> c_int {
246            let this = &mut *(p as *mut U);
247            match this.handle_undo(page_no) {
248                Ok(_) => SQLITE_OK,
249                Err(e) => e.extended_code,
250            }
251        }
252
253        let handler = undo_handler
254            .is_some()
255            .then_some(call_handler::<U> as unsafe extern "C" fn(*mut c_void, u32) -> i32);
256        let handler_data = undo_handler
257            .map(|d| d as *mut _ as *mut _)
258            .unwrap_or(std::ptr::null_mut());
259
260        let rc =
261            unsafe { (self.inner.methods.xUndo.unwrap())(self.inner.pData, handler, handler_data) };
262        if rc != 0 {
263            Err(Error::new(rc))
264        } else {
265            Ok(())
266        }
267    }
268
269    fn savepoint(&mut self, rollback_data: &mut [u32]) {
270        assert_eq!(rollback_data.len(), WAL_SAVEPOINT_NDATA as usize);
271        unsafe {
272            (self.inner.methods.xSavepoint.unwrap())(self.inner.pData, rollback_data.as_mut_ptr());
273        }
274    }
275
276    fn savepoint_undo(&mut self, rollback_data: &mut [u32]) -> Result<()> {
277        assert_eq!(rollback_data.len(), WAL_SAVEPOINT_NDATA as usize);
278        let rc = unsafe {
279            (self.inner.methods.xSavepointUndo.unwrap())(
280                self.inner.pData,
281                rollback_data.as_mut_ptr(),
282            )
283        };
284        if rc != 0 {
285            Err(Error::new(rc))
286        } else {
287            Ok(())
288        }
289    }
290
291    fn frame_count(&self, locked: i32) -> Result<u32> {
292        let mut out: u32 = 0;
293        let rc = unsafe {
294            (self.inner.methods.xFrameCount.unwrap())(self.inner.pData, locked, &mut out)
295        };
296        if rc != 0 {
297            Err(Error::new(rc))
298        } else {
299            Ok(out)
300        }
301    }
302
303    fn insert_frames(
304        &mut self,
305        page_size: c_int,
306        page_headers: &mut PageHeaders,
307        size_after: u32,
308        is_commit: bool,
309        sync_flags: c_int,
310    ) -> Result<usize> {
311        let mut frames = 0;
312        let rc = unsafe {
313            (self.inner.methods.xFrames.unwrap())(
314                self.inner.pData,
315                page_size,
316                page_headers.as_mut_ptr(),
317                size_after,
318                is_commit as _,
319                sync_flags,
320                &mut frames,
321            )
322        };
323        if rc != 0 {
324            Err(Error::new(rc))
325        } else {
326            Ok(frames as _)
327        }
328    }
329
330    fn checkpoint(
331        &mut self,
332        db: &mut Sqlite3Db,
333        mode: CheckpointMode,
334        mut busy_handler: Option<&mut dyn BusyHandler>,
335        sync_flags: u32,
336        // temporary scratch buffer
337        buf: &mut [u8],
338        mut checkpoint_cb: Option<&mut dyn CheckpointCallback>,
339        in_wal: Option<&mut i32>,
340        backfilled: Option<&mut i32>,
341    ) -> Result<()> {
342        unsafe extern "C" fn call_handler(p: *mut c_void) -> c_int {
343            let this = &mut *(p as *mut &mut dyn BusyHandler);
344            this.handle_busy() as _
345        }
346
347        unsafe extern "C" fn call_cb(
348            data: *mut c_void,
349            max_safe_frame_no: c_int,
350            page: *const u8,
351            page_len: c_int,
352            page_no: c_int,
353            frame_no: c_int,
354        ) -> c_int {
355            let this = &mut *(data as *mut &mut dyn CheckpointCallback);
356            let ret = if page.is_null() {
357                this.finish()
358            } else {
359                this.frame(
360                    max_safe_frame_no as _,
361                    std::slice::from_raw_parts(page, page_len as _),
362                    NonZeroU32::new(page_no as _).unwrap(),
363                    NonZeroU32::new(frame_no as _).unwrap(),
364                )
365            };
366
367            match ret {
368                Ok(()) => 0,
369                Err(e) => e.extended_code,
370            }
371        }
372
373        let handler = busy_handler
374            .is_some()
375            .then_some(call_handler as unsafe extern "C" fn(*mut c_void) -> i32);
376        let handler_data = busy_handler
377            .as_mut()
378            .map(|d| d as *mut _ as *mut _)
379            .unwrap_or(std::ptr::null_mut());
380
381        let checkpoint_cb_fn = checkpoint_cb.is_some().then_some(call_cb as _);
382        let checkpoint_cb_data = checkpoint_cb
383            .as_mut()
384            .map(|d| d as *mut &mut dyn CheckpointCallback as *mut _)
385            .unwrap_or(std::ptr::null_mut());
386
387        let out_log_num_frames = in_wal.map(|ptr| ptr as _).unwrap_or(std::ptr::null_mut());
388        let out_backfilled = backfilled
389            .map(|ptr| ptr as _)
390            .unwrap_or(std::ptr::null_mut());
391
392        let rc = unsafe {
393            (self.inner.methods.xCheckpoint.unwrap())(
394                self.inner.pData,
395                db.as_ptr(),
396                mode as _,
397                handler,
398                handler_data,
399                sync_flags as _,
400                buf.len() as _,
401                buf.as_mut_ptr(),
402                out_log_num_frames,
403                out_backfilled,
404                checkpoint_cb_fn,
405                checkpoint_cb_data,
406            )
407        };
408
409        if rc != 0 {
410            Err(Error::new(rc))
411        } else {
412            Ok(())
413        }
414    }
415
416    fn exclusive_mode(&mut self, op: c_int) -> Result<()> {
417        let rc = unsafe { (self.inner.methods.xExclusiveMode.unwrap())(self.inner.pData, op) };
418
419        if rc != 0 {
420            Err(Error::new(rc))
421        } else {
422            Ok(())
423        }
424    }
425
426    fn uses_heap_memory(&self) -> bool {
427        unsafe { (self.inner.methods.xHeapMemory.unwrap())(self.inner.pData) != 0 }
428    }
429
430    fn set_db(&mut self, db: &mut Sqlite3Db) {
431        unsafe {
432            (self.inner.methods.xDb.unwrap())(self.inner.pData, db.as_ptr());
433        }
434    }
435
436    fn callback(&self) -> i32 {
437        unsafe { (self.inner.methods.xCallback.unwrap())(self.inner.pData) }
438    }
439
440    fn frames_in_wal(&self) -> u32 {
441        unsafe {
442            let wal = &*(self.inner.pData as *const sqlite3_wal);
443            wal.hdr.mxFrame
444        }
445    }
446}