Skip to main content

limbo_core/ext/mod.rs

1#[cfg(feature = "fs")]
2mod dynamic;
3mod vtab_xconnect;
4#[cfg(all(target_os = "linux", feature = "io_uring"))]
5use crate::UringIO;
6use crate::{function::ExternalFunc, Connection, Database, LimboError, IO};
7#[cfg(feature = "fs")]
8pub use dynamic::{add_builtin_vfs_extensions, add_vfs_module, list_vfs_modules, VfsMod};
9use limbo_ext::{
10    ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabKind, VTabModuleImpl,
11};
12pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType};
13use std::{
14    ffi::{c_char, c_void, CStr, CString},
15    rc::Rc,
16    sync::Arc,
17};
18pub use vtab_xconnect::{close, execute, prepare_stmt};
19type ExternAggFunc = (InitAggFunction, StepFunction, FinalizeFunction);
20
21#[derive(Clone)]
22pub struct VTabImpl {
23    pub module_kind: VTabKind,
24    pub implementation: Rc<VTabModuleImpl>,
25}
26
27pub(crate) unsafe extern "C" fn register_scalar_function(
28    ctx: *mut c_void,
29    name: *const c_char,
30    func: ScalarFunction,
31) -> ResultCode {
32    let c_str = unsafe { CStr::from_ptr(name) };
33    let name_str = match c_str.to_str() {
34        Ok(s) => s.to_string(),
35        Err(_) => return ResultCode::InvalidArgs,
36    };
37    if ctx.is_null() {
38        return ResultCode::Error;
39    }
40    let conn = unsafe { &*(ctx as *const Connection) };
41    conn.register_scalar_function_impl(&name_str, func)
42}
43
44pub(crate) unsafe extern "C" fn register_aggregate_function(
45    ctx: *mut c_void,
46    name: *const c_char,
47    args: i32,
48    init_func: InitAggFunction,
49    step_func: StepFunction,
50    finalize_func: FinalizeFunction,
51) -> ResultCode {
52    let c_str = unsafe { CStr::from_ptr(name) };
53    let name_str = match c_str.to_str() {
54        Ok(s) => s.to_string(),
55        Err(_) => return ResultCode::InvalidArgs,
56    };
57    if ctx.is_null() {
58        return ResultCode::Error;
59    }
60    let conn = unsafe { &*(ctx as *const Connection) };
61    conn.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func))
62}
63
64pub(crate) unsafe extern "C" fn register_vtab_module(
65    ctx: *mut c_void,
66    name: *const c_char,
67    module: VTabModuleImpl,
68    kind: VTabKind,
69) -> ResultCode {
70    if name.is_null() || ctx.is_null() {
71        return ResultCode::Error;
72    }
73    let c_str = unsafe { CString::from_raw(name as *mut _) };
74    let name_str = match c_str.to_str() {
75        Ok(s) => s.to_string(),
76        Err(_) => return ResultCode::Error,
77    };
78    if ctx.is_null() {
79        return ResultCode::Error;
80    }
81    let conn = unsafe { &mut *(ctx as *mut Connection) };
82
83    conn.register_vtab_module_impl(&name_str, module, kind)
84}
85
86impl Database {
87    #[cfg(feature = "fs")]
88    #[allow(clippy::arc_with_non_send_sync, dead_code)]
89    pub fn open_with_vfs(
90        &self,
91        path: &str,
92        vfs: &str,
93    ) -> crate::Result<(Arc<dyn IO>, Arc<Database>)> {
94        use crate::{MemoryIO, SyscallIO};
95        use dynamic::get_vfs_modules;
96
97        let io: Arc<dyn IO> = match vfs {
98            "memory" => Arc::new(MemoryIO::new()),
99            "syscall" => Arc::new(SyscallIO::new()?),
100            #[cfg(all(target_os = "linux", feature = "io_uring"))]
101            "io_uring" => Arc::new(UringIO::new()?),
102            other => match get_vfs_modules().iter().find(|v| v.0 == vfs) {
103                Some((_, vfs)) => vfs.clone(),
104                None => {
105                    return Err(LimboError::InvalidArgument(format!(
106                        "no such VFS: {}",
107                        other
108                    )));
109                }
110            },
111        };
112        let db = Self::open_file(io.clone(), path, false)?;
113        Ok((io, db))
114    }
115}
116
117impl Connection {
118    fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode {
119        self.syms.borrow_mut().functions.insert(
120            name.to_string(),
121            Rc::new(ExternalFunc::new_scalar(name.to_string(), func)),
122        );
123        ResultCode::OK
124    }
125
126    fn register_aggregate_function_impl(
127        &self,
128        name: &str,
129        args: i32,
130        func: ExternAggFunc,
131    ) -> ResultCode {
132        self.syms.borrow_mut().functions.insert(
133            name.to_string(),
134            Rc::new(ExternalFunc::new_aggregate(name.to_string(), args, func)),
135        );
136        ResultCode::OK
137    }
138
139    fn register_vtab_module_impl(
140        &mut self,
141        name: &str,
142        module: VTabModuleImpl,
143        kind: VTabKind,
144    ) -> ResultCode {
145        let module = Rc::new(module);
146        let vmodule = VTabImpl {
147            module_kind: kind,
148            implementation: module,
149        };
150        self.syms
151            .borrow_mut()
152            .vtab_modules
153            .insert(name.to_string(), vmodule.into());
154        ResultCode::OK
155    }
156
157    pub fn build_limbo_ext(&self) -> ExtensionApi {
158        ExtensionApi {
159            ctx: self as *const _ as *mut c_void,
160            register_scalar_function,
161            register_aggregate_function,
162            register_vtab_module,
163            #[cfg(feature = "fs")]
164            vfs_interface: limbo_ext::VfsInterface {
165                register_vfs: dynamic::register_vfs,
166                builtin_vfs: std::ptr::null_mut(),
167                builtin_vfs_count: 0,
168            },
169        }
170    }
171
172    pub fn register_builtins(&self) -> Result<(), String> {
173        #[allow(unused_variables)]
174        let mut ext_api = self.build_limbo_ext();
175        #[cfg(feature = "uuid")]
176        if unsafe { !limbo_uuid::register_extension_static(&mut ext_api).is_ok() } {
177            return Err("Failed to register uuid extension".to_string());
178        }
179        #[cfg(feature = "time")]
180        if unsafe { !limbo_time::register_extension_static(&mut ext_api).is_ok() } {
181            return Err("Failed to register time extension".to_string());
182        }
183        #[cfg(feature = "fs")]
184        {
185            let vfslist = add_builtin_vfs_extensions(Some(ext_api)).map_err(|e| e.to_string())?;
186            for (name, vfs) in vfslist {
187                add_vfs_module(name, vfs);
188            }
189        }
190        Ok(())
191    }
192}