1#![allow(non_snake_case)]
2#![allow(clippy::not_unsafe_ptr_arg_deref)]
3
4mod ffi;
5
6pub mod replicator;
7
8use crate::ffi::{libsql_wal_methods, sqlite3_file, sqlite3_vfs, PgHdr, Wal};
9use std::ffi::c_void;
10
11fn is_regular(vfs: *const sqlite3_vfs) -> bool {
13    let vfs = unsafe { std::ffi::CStr::from_ptr((*vfs).zName) }
14        .to_str()
15        .unwrap_or("[error]");
16    tracing::trace!("VFS: {}", vfs);
17    vfs.starts_with("unix") || vfs.starts_with("win32")
18}
19
/// Runs the future produced by `$e` to completion on `$runtime`, blocking the
/// calling thread, and evaluates to the future's output.
///
/// `$e` is placed inside an `async` block, so the expression itself is only
/// evaluated when `block_on` polls that block, not before the call.
/// NOTE(review): presumably this matters for expressions that must run inside
/// the runtime context (e.g. constructors touching tokio handles); confirm
/// before simplifying this to `$runtime.block_on($e)`.
macro_rules! block_on {
    ($runtime:expr, $e:expr) => {
        $runtime.block_on(async { $e.await })
    };
}
25
/// Reports whether bottomless runs in local-only mode, with no replication.
///
/// Local mode is enabled when `LIBSQL_BOTTOMLESS_LOCAL` is set to `true`, `t`,
/// `yes` or `y` (ASCII case-insensitive), or exactly `1`. An unset or
/// non-Unicode variable means "not local".
fn is_local() -> bool {
    match std::env::var("LIBSQL_BOTTOMLESS_LOCAL") {
        Ok(value) => matches!(
            value.to_ascii_lowercase().as_str(),
            "true" | "t" | "yes" | "y" | "1"
        ),
        Err(_) => false,
    }
}
35
36pub extern "C" fn xOpen(
37    vfs: *const sqlite3_vfs,
38    db_file: *mut sqlite3_file,
39    wal_name: *const i8,
40    no_shm_mode: i32,
41    max_size: i64,
42    methods: *mut libsql_wal_methods,
43    wal: *mut *mut Wal,
44) -> i32 {
45    tracing::debug!("Opening WAL {}", unsafe {
46        std::ffi::CStr::from_ptr(wal_name).to_str().unwrap()
47    });
48
49    let orig_methods = unsafe { &*(*methods).underlying_methods };
50    let rc = (orig_methods.xOpen)(vfs, db_file, wal_name, no_shm_mode, max_size, methods, wal);
51    if rc != ffi::SQLITE_OK {
52        return rc;
53    }
54
55    if !is_regular(vfs) {
56        tracing::error!("Bottomless WAL is currently only supported for regular VFS");
57        return ffi::SQLITE_CANTOPEN;
58    }
59
60    if is_local() {
61        tracing::info!("Running in local-mode only, without any replication");
62        return ffi::SQLITE_OK;
63    }
64
65    let runtime = match tokio::runtime::Builder::new_current_thread()
66        .enable_all()
67        .build()
68    {
69        Ok(runtime) => runtime,
70        Err(e) => {
71            tracing::error!("Failed to initialize async runtime: {}", e);
72            return ffi::SQLITE_CANTOPEN;
73        }
74    };
75
76    let replicator = block_on!(runtime, replicator::Replicator::new());
77    let mut replicator = match replicator {
78        Ok(repl) => repl,
79        Err(e) => {
80            tracing::error!("Failed to initialize replicator: {}", e);
81            return ffi::SQLITE_CANTOPEN;
82        }
83    };
84
85    let path = unsafe {
86        match std::ffi::CStr::from_ptr(wal_name).to_str() {
87            Ok(path) if path.len() >= 4 => &path[..path.len() - 4],
88            Ok(path) => path,
89            Err(e) => {
90                tracing::error!("Failed to parse the main database path: {}", e);
91                return ffi::SQLITE_CANTOPEN;
92            }
93        }
94    };
95
96    replicator.register_db(path);
97    let rc = block_on!(runtime, try_restore(&mut replicator));
98    if rc != ffi::SQLITE_OK {
99        return rc;
100    }
101
102    let context = replicator::Context {
103        replicator,
104        runtime,
105    };
106    unsafe { (*(*wal)).replicator_context = Box::leak(Box::new(context)) };
107
108    ffi::SQLITE_OK
109}
110
111fn get_orig_methods(wal: *mut Wal) -> &'static libsql_wal_methods {
112    unsafe { &*((*(*wal).wal_methods).underlying_methods) }
113}
114
115fn get_replicator_context(wal: *mut Wal) -> &'static mut replicator::Context {
116    unsafe { &mut *((*wal).replicator_context) }
117}
118
119pub extern "C" fn xClose(
120    wal: *mut Wal,
121    db: *mut c_void,
122    sync_flags: i32,
123    n_buf: i32,
124    z_buf: *mut u8,
125) -> i32 {
126    tracing::debug!("Closing wal");
127    let orig_methods = get_orig_methods(wal);
128    if !is_local() {
129        let _replicator_box = unsafe { Box::from_raw((*wal).replicator_context) };
130    }
131
132    (orig_methods.xClose)(wal, db, sync_flags, n_buf, z_buf)
133}
134
135pub extern "C" fn xLimit(wal: *mut Wal, limit: i64) {
136    let orig_methods = get_orig_methods(wal);
137    (orig_methods.xLimit)(wal, limit)
138}
139
140pub extern "C" fn xBeginReadTransaction(wal: *mut Wal, changed: *mut i32) -> i32 {
141    let orig_methods = get_orig_methods(wal);
142    (orig_methods.xBeginReadTransaction)(wal, changed)
143}
144
145pub extern "C" fn xEndReadTransaction(wal: *mut Wal) -> i32 {
146    let orig_methods = get_orig_methods(wal);
147    (orig_methods.xEndReadTransaction)(wal)
148}
149
150pub extern "C" fn xFindFrame(wal: *mut Wal, pgno: i32, frame: *mut i32) -> i32 {
151    let orig_methods = get_orig_methods(wal);
152    (orig_methods.xFindFrame)(wal, pgno, frame)
153}
154
155pub extern "C" fn xReadFrame(wal: *mut Wal, frame: u32, n_out: i32, p_out: *mut u8) -> i32 {
156    let orig_methods = get_orig_methods(wal);
157    (orig_methods.xReadFrame)(wal, frame, n_out, p_out)
158}
159
160pub extern "C" fn xDbSize(wal: *mut Wal) -> i32 {
161    let orig_methods = get_orig_methods(wal);
162    (orig_methods.xDbSize)(wal)
163}
164
165pub extern "C" fn xBeginWriteTransaction(wal: *mut Wal) -> i32 {
166    let orig_methods = get_orig_methods(wal);
167    (orig_methods.xBeginWriteTransaction)(wal)
168}
169
170pub extern "C" fn xEndWriteTransaction(wal: *mut Wal) -> i32 {
171    let orig_methods = get_orig_methods(wal);
172    (orig_methods.xEndWriteTransaction)(wal)
173}
174
175pub extern "C" fn xUndo(
176    wal: *mut Wal,
177    func: extern "C" fn(*mut c_void, i32) -> i32,
178    ctx: *mut c_void,
179) -> i32 {
180    let orig_methods = get_orig_methods(wal);
181    let rc = (orig_methods.xUndo)(wal, func, ctx);
182    if is_local() || rc != ffi::SQLITE_OK {
183        return rc;
184    }
185
186    let last_valid_frame = unsafe { (*wal).hdr.last_valid_frame };
187    let ctx = get_replicator_context(wal);
188    tracing::trace!(
189        "Undo: rolling back from frame {} to {}",
190        ctx.replicator.peek_last_valid_frame(),
191        last_valid_frame
192    );
193    ctx.replicator.rollback_to_frame(last_valid_frame);
194
195    ffi::SQLITE_OK
196}
197
198pub extern "C" fn xSavepoint(wal: *mut Wal, wal_data: *mut u32) {
199    let orig_methods = get_orig_methods(wal);
200    (orig_methods.xSavepoint)(wal, wal_data)
201}
202
203pub extern "C" fn xSavepointUndo(wal: *mut Wal, wal_data: *mut u32) -> i32 {
204    let orig_methods = get_orig_methods(wal);
205    let rc = (orig_methods.xSavepointUndo)(wal, wal_data);
206    if is_local() || rc != ffi::SQLITE_OK {
207        return rc;
208    }
209
210    let last_valid_frame = unsafe { *wal_data };
211    let ctx = get_replicator_context(wal);
212    tracing::trace!(
213        "Savepoint: rolling back from frame {} to {}",
214        ctx.replicator.peek_last_valid_frame(),
215        last_valid_frame
216    );
217    ctx.replicator.rollback_to_frame(last_valid_frame);
218
219    ffi::SQLITE_OK
220}
221
222pub extern "C" fn xFrames(
223    wal: *mut Wal,
224    page_size: u32,
225    page_headers: *const PgHdr,
226    size_after: u32,
227    is_commit: i32,
228    sync_flags: i32,
229) -> i32 {
230    let mut last_consistent_frame = 0;
231    if !is_local() {
232        let ctx = get_replicator_context(wal);
233        let last_valid_frame = unsafe { (*wal).hdr.last_valid_frame };
234        ctx.replicator.register_last_valid_frame(last_valid_frame);
235        if let Err(e) = ctx.replicator.set_page_size(page_size as usize) {
241            tracing::error!("{}", e);
242            return ffi::SQLITE_IOERR_WRITE;
243        }
244        for (pgno, data) in ffi::PageHdrIter::new(page_headers, page_size as usize) {
245            ctx.replicator.write(pgno, data);
246        }
247
248        if is_commit != 0 {
254            last_consistent_frame = match block_on!(ctx.runtime, ctx.replicator.flush()) {
255                Ok(frame) => frame,
256                Err(e) => {
257                    tracing::error!("Failed to replicate: {}", e);
258                    return ffi::SQLITE_IOERR_WRITE;
259                }
260            };
261        }
262    }
263
264    let orig_methods = get_orig_methods(wal);
265    let rc = (orig_methods.xFrames)(
266        wal,
267        page_size,
268        page_headers,
269        size_after,
270        is_commit,
271        sync_flags,
272    );
273    if is_local() || rc != ffi::SQLITE_OK {
274        return rc;
275    }
276
277    let ctx = get_replicator_context(wal);
278    if is_commit != 0 {
279        let frame_checksum = unsafe { (*wal).hdr.frame_checksum };
280
281        if let Err(e) = block_on!(
282            ctx.runtime,
283            ctx.replicator
284                .finalize_commit(last_consistent_frame, frame_checksum)
285        ) {
286            tracing::error!("Failed to finalize replication: {}", e);
287            return ffi::SQLITE_IOERR_WRITE;
288        }
289    }
290
291    ffi::SQLITE_OK
292}
293
/// Fallback busy handler: sleeps briefly, then asks the caller to retry.
///
/// Always returns 1 ("retry"), so it never gives up on its own.
extern "C" fn always_wait(_busy_param: *mut c_void) -> i32 {
    const RETRY_DELAY: std::time::Duration = std::time::Duration::from_millis(10);
    const RETRY: i32 = 1;

    std::thread::sleep(RETRY_DELAY);
    RETRY
}
298
/// Checkpoints the WAL and, if anything was committed in the current
/// generation, starts a new generation from a snapshot of the main db file.
///
/// Any `emode` below `SQLITE_CHECKPOINT_TRUNCATE` is upgraded to
/// `SQLITE_CHECKPOINT_TRUNCATE` before calling the underlying `xCheckpoint`.
/// NOTE(review): presumably this ensures the WAL is fully reset before a new
/// generation begins — confirm.
///
/// Returns:
/// - the underlying result code if it fails, or in local mode;
/// - `SQLITE_OK` if no commits happened in this generation (no snapshot);
/// - `SQLITE_IOERR_WRITE` if snapshotting the main db file fails;
/// - otherwise `SQLITE_OK`.
#[tracing::instrument(skip(wal, db, busy_handler, busy_arg))]
pub extern "C" fn xCheckpoint(
    wal: *mut Wal,
    db: *mut c_void,
    emode: i32,
    busy_handler: extern "C" fn(busy_param: *mut c_void) -> i32,
    busy_arg: *const c_void,
    sync_flags: i32,
    n_buf: i32,
    z_buf: *mut u8,
    frames_in_wal: *mut i32,
    backfilled_frames: *mut i32,
) -> i32 {
    tracing::trace!("Checkpoint");

    // Force at least TRUNCATE mode (see the doc comment above).
    let emode = if emode < ffi::SQLITE_CHECKPOINT_TRUNCATE {
        tracing::trace!("Upgrading checkpoint to TRUNCATE mode");
        ffi::SQLITE_CHECKPOINT_TRUNCATE
    } else {
        emode
    };
    // Substitute `always_wait` when no busy handler was supplied.
    // NOTE(review): Rust fn pointers are never null, so a null value arriving
    // from C is already UB and the compiler may fold this check to `false`.
    // The ffi type should be `Option<extern "C" fn(..)>` for this to be sound.
    let busy_handler = if (busy_handler as *const c_void).is_null() {
        tracing::trace!("Falling back to the default busy handler - always wait");
        always_wait
    } else {
        busy_handler
    };

    let orig_methods = get_orig_methods(wal);
    let rc = (orig_methods.xCheckpoint)(
        wal,
        db,
        emode,
        busy_handler,
        busy_arg,
        sync_flags,
        n_buf,
        z_buf,
        frames_in_wal,
        backfilled_frames,
    );

    if is_local() || rc != ffi::SQLITE_OK {
        return rc;
    }

    // Nothing new to capture: keep the current generation as-is.
    let ctx = get_replicator_context(wal);
    if ctx.replicator.commits_in_current_generation == 0 {
        tracing::debug!("No commits happened in this generation, not snapshotting");
        return ffi::SQLITE_OK;
    }

    // Start a fresh generation, seeded with a snapshot of the checkpointed main db file.
    ctx.replicator.new_generation();
    tracing::debug!("Snapshotting after checkpoint");
    let result = block_on!(ctx.runtime, ctx.replicator.snapshot_main_db_file());
    if let Err(e) = result {
        tracing::error!(
            "Failed to snapshot the main db file during checkpoint: {}",
            e
        );
        return ffi::SQLITE_IOERR_WRITE;
    }

    ffi::SQLITE_OK
}
376
377pub extern "C" fn xCallback(wal: *mut Wal) -> i32 {
378    let orig_methods = get_orig_methods(wal);
379    (orig_methods.xCallback)(wal)
380}
381
382pub extern "C" fn xExclusiveMode(wal: *mut Wal) -> i32 {
383    let orig_methods = get_orig_methods(wal);
384    (orig_methods.xExclusiveMode)(wal)
385}
386
387pub extern "C" fn xHeapMemory(wal: *mut Wal) -> i32 {
388    let orig_methods = get_orig_methods(wal);
389    (orig_methods.xHeapMemory)(wal)
390}
391
392pub extern "C" fn xFile(wal: *mut Wal) -> *const c_void {
393    let orig_methods = get_orig_methods(wal);
394    (orig_methods.xFile)(wal)
395}
396
397pub extern "C" fn xDb(wal: *mut Wal, db: *const c_void) {
398    let orig_methods = get_orig_methods(wal);
399    (orig_methods.xDb)(wal, db)
400}
401
/// Length of the WAL path derived from a database path of `orig_len` bytes:
/// the original path plus the 4-byte `-wal` suffix written by `xGetPathname`.
pub extern "C" fn xPathnameLen(orig_len: i32) -> i32 {
    const WAL_SUFFIX_LEN: i32 = 4; // "-wal"
    orig_len + WAL_SUFFIX_LEN
}
405
/// Writes the WAL path into `buf`: the `orig_len` bytes of `orig`, then `-wal`.
///
/// No NUL terminator is written. `buf` must have room for `orig_len + 4`
/// bytes (see `xPathnameLen`).
pub extern "C" fn xGetPathname(buf: *mut u8, orig: *const u8, orig_len: i32) {
    const WAL_SUFFIX: &[u8] = b"-wal";
    let prefix_len = orig_len as usize;
    // SAFETY: caller guarantees `orig` has `orig_len` readable bytes and `buf`
    // has `orig_len + 4` writable bytes. `ptr::copy` tolerates overlap.
    unsafe {
        std::ptr::copy(orig, buf, prefix_len);
        std::ptr::copy(WAL_SUFFIX.as_ptr(), buf.add(prefix_len), WAL_SUFFIX.len());
    }
}
410
411async fn try_restore(replicator: &mut replicator::Replicator) -> i32 {
412    match replicator.restore().await {
413        Ok(replicator::RestoreAction::None) => (),
414        Ok(replicator::RestoreAction::SnapshotMainDbFile) => {
415            replicator.new_generation();
416            if let Err(e) = replicator.snapshot_main_db_file().await {
417                tracing::error!("Failed to snapshot the main db file: {}", e);
418                return ffi::SQLITE_CANTOPEN;
419            }
420            if let Err(e) = replicator.maybe_replicate_wal().await {
423                tracing::error!("Failed to replicate local WAL: {}", e);
424                return ffi::SQLITE_CANTOPEN;
425            }
426        }
427        Ok(replicator::RestoreAction::ReuseGeneration(gen)) => {
428            replicator.set_generation(gen);
429        }
430        Err(e) => {
431            tracing::error!("Failed to restore the database: {}", e);
432            return ffi::SQLITE_CANTOPEN;
433        }
434    }
435
436    ffi::SQLITE_OK
437}
438
439pub extern "C" fn xPreMainDbOpen(_methods: *mut libsql_wal_methods, path: *const i8) -> i32 {
440    if is_local() {
441        tracing::info!("Running in local-mode only, without any replication");
442        return ffi::SQLITE_OK;
443    }
444
445    if path.is_null() {
446        return ffi::SQLITE_OK;
447    }
448    let path = unsafe {
449        match std::ffi::CStr::from_ptr(path).to_str() {
450            Ok(path) => path,
451            Err(e) => {
452                tracing::error!("Failed to parse the main database path: {}", e);
453                return ffi::SQLITE_CANTOPEN;
454            }
455        }
456    };
457    tracing::debug!("Main database file {} will be open soon", path);
458
459    let runtime = match tokio::runtime::Builder::new_current_thread()
460        .enable_all()
461        .build()
462    {
463        Ok(runtime) => runtime,
464        Err(e) => {
465            tracing::error!("Failed to initialize async runtime: {}", e);
466            return ffi::SQLITE_CANTOPEN;
467        }
468    };
469
470    let replicator = block_on!(
471        runtime,
472        replicator::Replicator::create(replicator::Options {
473            create_bucket_if_not_exists: true,
474            verify_crc: true,
475        })
476    );
477    let mut replicator = match replicator {
478        Ok(repl) => repl,
479        Err(e) => {
480            tracing::error!("Failed to initialize replicator: {}", e);
481            return ffi::SQLITE_CANTOPEN;
482        }
483    };
484
485    replicator.register_db(path);
486    block_on!(runtime, try_restore(&mut replicator))
487}
488
489#[no_mangle]
490pub extern "C" fn bottomless_init() {
491    tracing::debug!("bottomless module initialized");
492}
493
494#[tracing::instrument]
495#[no_mangle]
496pub extern "C" fn bottomless_methods(
497    underlying_methods: *const libsql_wal_methods,
498) -> *const libsql_wal_methods {
499    let vwal_name: *const u8 = "bottomless\0".as_ptr();
500
501    Box::into_raw(Box::new(libsql_wal_methods {
502        iVersion: 1,
503        xOpen,
504        xClose,
505        xLimit,
506        xBeginReadTransaction,
507        xEndReadTransaction,
508        xFindFrame,
509        xReadFrame,
510        xDbSize,
511        xBeginWriteTransaction,
512        xEndWriteTransaction,
513        xUndo,
514        xSavepoint,
515        xSavepointUndo,
516        xFrames,
517        xCheckpoint,
518        xCallback,
519        xExclusiveMode,
520        xHeapMemory,
521        snapshot_get_stub: std::ptr::null(),
522        snapshot_open_stub: std::ptr::null(),
523        snapshot_recover_stub: std::ptr::null(),
524        snapshot_check_stub: std::ptr::null(),
525        snapshot_unlock_stub: std::ptr::null(),
526        framesize_stub: std::ptr::null(),
527        xFile,
528        write_lock_stub: std::ptr::null(),
529        xDb,
530        xPathnameLen,
531        xGetPathname,
532        xPreMainDbOpen,
533        name: vwal_name,
534        b_uses_shm: 0,
535        p_next: std::ptr::null(),
536        underlying_methods,
537    }))
538}
539
/// Registration helpers for builds where libsql is linked statically.
#[cfg(feature = "libsql_linked_statically")]
pub mod static_init {
    use crate::libsql_wal_methods;

    extern "C" {
        fn libsql_wal_methods_find(name: *const std::ffi::c_char) -> *const libsql_wal_methods;
        fn libsql_wal_methods_register(methods: *const libsql_wal_methods) -> i32;
    }

    /// Registers the bottomless WAL methods with libsql, at most once per
    /// process (guarded by `std::sync::Once`).
    ///
    /// If no underlying WAL methods can be found, nothing is registered and a
    /// warning is logged. Previously the code went on and built a table
    /// wrapping a null `underlying_methods`, which every later call would
    /// dereference. If registration fails, the allocated table is reclaimed and
    /// a warning is logged.
    pub fn register_bottomless_methods() {
        static INIT: std::sync::Once = std::sync::Once::new();
        INIT.call_once(|| {
            crate::bottomless_init();
            // NOTE(review): a null name presumably asks libsql for its default
            // WAL methods (mirroring `sqlite3_vfs_find(NULL)`); confirm.
            let orig_methods = unsafe { libsql_wal_methods_find(std::ptr::null()) };
            if orig_methods.is_null() {
                tracing::warn!(
                    "Failed to find the underlying WAL methods, bottomless WAL methods will not be registered"
                );
                return;
            }
            let methods = crate::bottomless_methods(orig_methods);
            let rc = unsafe { libsql_wal_methods_register(methods) };
            if rc != crate::ffi::SQLITE_OK {
                // Reclaim the `Box::into_raw` allocation from `bottomless_methods`;
                // this assumes libsql did not retain the pointer on failure.
                let _box = unsafe { Box::from_raw(methods as *mut libsql_wal_methods) };
                tracing::warn!("Failed to instantiate bottomless WAL methods");
            }
        })
    }
}