Skip to main content

limbo_core/ext/
dynamic.rs

1use crate::{
2    ext::{register_aggregate_function, register_scalar_function, register_vtab_module},
3    Connection, LimboError,
4};
5use libloading::{Library, Symbol};
6use limbo_ext::{ExtensionApi, ExtensionApiRef, ExtensionEntryPoint, ResultCode, VfsImpl};
7use std::{
8    ffi::{c_char, CString},
9    sync::{Arc, Mutex, OnceLock},
10};
11
12type ExtensionStore = Vec<(Arc<Library>, ExtensionApiRef)>;
13static EXTENSIONS: OnceLock<Arc<Mutex<ExtensionStore>>> = OnceLock::new();
14pub fn get_extension_libraries() -> Arc<Mutex<ExtensionStore>> {
15    EXTENSIONS
16        .get_or_init(|| Arc::new(Mutex::new(Vec::new())))
17        .clone()
18}
19
20type Vfs = (String, Arc<VfsMod>);
21static VFS_MODULES: OnceLock<Mutex<Vec<Vfs>>> = OnceLock::new();
22
23#[derive(Clone, Debug)]
24pub struct VfsMod {
25    pub ctx: *const VfsImpl,
26}
27
28unsafe impl Send for VfsMod {}
29unsafe impl Sync for VfsMod {}
30
31impl Connection {
32    pub fn load_extension<P: AsRef<std::ffi::OsStr>>(
33        self: &Arc<Connection>,
34        path: P,
35    ) -> crate::Result<()> {
36        use limbo_ext::ExtensionApiRef;
37
38        let api = Box::new(self.build_limbo_ext());
39        let lib =
40            unsafe { Library::new(path).map_err(|e| LimboError::ExtensionError(e.to_string()))? };
41        let entry: Symbol<ExtensionEntryPoint> = unsafe {
42            lib.get(b"register_extension")
43                .map_err(|e| LimboError::ExtensionError(e.to_string()))?
44        };
45        let api_ptr: *const ExtensionApi = Box::into_raw(api);
46        let api_ref = ExtensionApiRef { api: api_ptr };
47        let result_code = unsafe { entry(api_ptr) };
48        if result_code.is_ok() {
49            let extensions = get_extension_libraries();
50            extensions
51                .lock()
52                .map_err(|_| {
53                    LimboError::ExtensionError("Error locking extension libraries".to_string())
54                })?
55                .push((Arc::new(lib), api_ref));
56            {
57                self.parse_schema_rows()?;
58            }
59            Ok(())
60        } else {
61            if !api_ptr.is_null() {
62                let _ = unsafe { Box::from_raw(api_ptr.cast_mut()) };
63            }
64            Err(LimboError::ExtensionError(
65                "Extension registration failed".to_string(),
66            ))
67        }
68    }
69}
70
71#[allow(clippy::arc_with_non_send_sync)]
72pub(crate) unsafe extern "C" fn register_vfs(
73    name: *const c_char,
74    vfs: *const VfsImpl,
75) -> ResultCode {
76    if name.is_null() || vfs.is_null() {
77        return ResultCode::Error;
78    }
79    let c_str = unsafe { CString::from_raw(name as *mut _) };
80    let name_str = match c_str.to_str() {
81        Ok(s) => s.to_string(),
82        Err(_) => return ResultCode::Error,
83    };
84    add_vfs_module(name_str, Arc::new(VfsMod { ctx: vfs }));
85    ResultCode::OK
86}
87
/// Collects the VFS implementations that are built into the binary at compile time.
///
/// Any other types defined in the same extension are not registered here; they are
/// registered when the database file is opened and `register_builtins` is called.
///
/// If `api` is provided, its `builtin_vfs` buffer is replaced with a fresh, empty one (and
/// the count reset to match) before the static modules are registered. Otherwise a default
/// [`ExtensionApi`] wired to this crate's registration callbacks is used.
///
/// Ownership of each VFS name string is taken back with `CString::from_raw`, so names must
/// have been produced by `CString::into_raw`.
///
/// # Errors
///
/// Returns [`LimboError::ExtensionError`] if a registered VFS has a null or non-UTF-8 name.
#[cfg(feature = "fs")]
#[allow(clippy::arc_with_non_send_sync)]
pub fn add_builtin_vfs_extensions(
    api: Option<ExtensionApi>,
) -> crate::Result<Vec<(String, Arc<VfsMod>)>> {
    use limbo_ext::VfsInterface;

    // Backing storage handed to the static modules. It stays bound until the function
    // returns so the pointer stored in `api.vfs_interface.builtin_vfs` remains valid.
    let mut builtin_ptrs: Vec<*const VfsImpl> = Vec::new();
    let mut api = match api {
        None => ExtensionApi {
            ctx: std::ptr::null_mut(),
            register_scalar_function,
            register_aggregate_function,
            register_vtab_module,
            vfs_interface: VfsInterface {
                register_vfs,
                builtin_vfs: builtin_ptrs.as_mut_ptr(),
                builtin_vfs_count: 0,
            },
        },
        Some(mut api) => {
            api.vfs_interface.builtin_vfs = builtin_ptrs.as_mut_ptr();
            // The pointer now refers to an empty buffer, so the count must match; a stale
            // caller-supplied count would make the slice below read out of bounds.
            api.vfs_interface.builtin_vfs_count = 0;
            api
        }
    };
    register_static_vfs_modules(&mut api);

    let count = api.vfs_interface.builtin_vfs_count as usize;
    let entries_ptr = api.vfs_interface.builtin_vfs;
    // `from_raw_parts` requires a non-null pointer even for an empty slice.
    if entries_ptr.is_null() || count == 0 {
        return Ok(Vec::new());
    }
    // SAFETY: `entries_ptr` is non-null and, by the contract of whatever populated it during
    // `register_static_vfs_modules`, points to `count` initialised entries.
    // NOTE(review): currently unreachable since no static modules are registered.
    let entries = unsafe { std::slice::from_raw_parts(entries_ptr, count) };

    let mut modules = Vec::with_capacity(count);
    for &vfs in entries {
        if vfs.is_null() {
            continue;
        }
        // SAFETY: non-null entries are assumed to point to live `VfsImpl`s — TODO confirm.
        let vfsimpl = unsafe { &*vfs };
        if vfsimpl.name.is_null() {
            return Err(LimboError::ExtensionError(
                "unable to register vfs extension".to_string(),
            ));
        }
        // SAFETY: the name is non-null and assumed to come from `CString::into_raw`.
        // It is freed here.
        // NOTE(review): `vfsimpl.name` dangles afterwards; confirm nothing reads it later.
        let owned_name = unsafe { CString::from_raw(vfsimpl.name as *mut _) };
        let name = owned_name.into_string().map_err(|_| {
            LimboError::ExtensionError("unable to register vfs extension".to_string())
        })?;
        modules.push((name, Arc::new(VfsMod { ctx: vfs })));
    }
    Ok(modules)
}
146
/// Hook for registering VFS modules that are statically linked into the binary.
///
/// No static VFS modules are linked in at present. The `testvfs` extension was removed
/// because `limbo_ext_tests` was a crates.io dependency incompatible with the local
/// oxisqlite-ext fork. The API is therefore left untouched.
#[cfg(feature = "fs")]
fn register_static_vfs_modules(_api: &mut ExtensionApi) {}
153
154pub fn add_vfs_module(name: String, vfs: Arc<VfsMod>) {
155    let mut modules = VFS_MODULES
156        .get_or_init(|| Mutex::new(Vec::new()))
157        .lock()
158        .expect("VFS_MODULES mutex poisoned");
159    if !modules.iter().any(|v| v.0 == name) {
160        modules.push((name, vfs));
161    }
162}
163
164pub fn list_vfs_modules() -> Vec<String> {
165    VFS_MODULES
166        .get_or_init(|| Mutex::new(Vec::new()))
167        .lock()
168        .expect("VFS_MODULES mutex poisoned")
169        .iter()
170        .map(|v| v.0.clone())
171        .collect()
172}
173
174pub fn get_vfs_modules() -> Vec<Vfs> {
175    VFS_MODULES
176        .get_or_init(|| Mutex::new(Vec::new()))
177        .lock()
178        .expect("VFS_MODULES mutex poisoned")
179        .clone()
180}