Skip to main content

sqlite_provider/
vtab.rs

1use core::ffi::{c_char, c_void};
2use core::mem::MaybeUninit;
3use core::ptr::NonNull;
4use std::collections::HashMap;
5use std::sync::{Mutex, OnceLock};
6
7use crate::connection::Connection;
8use crate::error::{Error, Result};
9use crate::function::Context;
10use crate::provider::{FeatureSet, Sqlite3Api, ValueType, sqlite3_module};
11use crate::value::ValueRef;
12
/// Opaque handle around the backend's best-index information.
///
/// The glue forwards the `xBestIndex` pointer here unchanged and never reads or writes
/// through it itself; presumably implementations of `VirtualTable::best_index` cast it to the
/// backend's `sqlite3_index_info` type.
pub struct BestIndexInfo {
    /// Raw backend pointer to `sqlite3_index_info`.
    pub raw: *mut c_void,
}
18
19/// Virtual table implementation.
20pub trait VirtualTable<P: Sqlite3Api>: Sized + Send {
21    /// Cursor type opened by this table.
22    type Cursor: VTabCursor<P>;
23    /// Error type mapped into crate `Error`.
24    type Error: Into<Error>;
25
26    /// Create/connect a table instance from SQLite module arguments.
27    fn connect(args: &[&str]) -> core::result::Result<(Self, String), Self::Error>;
28    /// Disconnect and release table resources.
29    fn disconnect(self) -> core::result::Result<(), Self::Error>;
30
31    /// Populate best-index constraints/order information.
32    fn best_index(&self, _info: &mut BestIndexInfo) -> core::result::Result<(), Self::Error> {
33        Ok(())
34    }
35
36    /// Open a new cursor over this table.
37    fn open(&self) -> core::result::Result<Self::Cursor, Self::Error>;
38}
39
40/// Virtual table cursor implementation.
41pub trait VTabCursor<P: Sqlite3Api>: Sized + Send {
42    /// Error type mapped into crate `Error`.
43    type Error: Into<Error>;
44
45    /// Apply constraints and initialize iteration.
46    fn filter(
47        &mut self,
48        idx_num: i32,
49        idx_str: Option<&str>,
50        args: &[ValueRef<'_>],
51    ) -> core::result::Result<(), Self::Error>;
52    /// Advance to the next row.
53    fn next(&mut self) -> core::result::Result<(), Self::Error>;
54    /// Whether iteration reached end-of-input.
55    fn eof(&self) -> bool;
56    /// Emit one output column into SQLite context.
57    fn column(&self, ctx: &Context<'_, P>, col: i32) -> core::result::Result<(), Self::Error>;
58    /// Return current rowid.
59    fn rowid(&self) -> core::result::Result<i64, Self::Error>;
60}
61
62/// Glue wrapper stored in the backend's `sqlite3_vtab` pointer.
63#[repr(C)]
64pub struct VTab<P: Sqlite3Api, T: VirtualTable<P>> {
65    api: *const P,
66    table: T,
67}
68
69/// Glue wrapper stored in the backend's `sqlite3_vtab_cursor` pointer.
70#[repr(C)]
71pub struct Cursor<P: Sqlite3Api, C: VTabCursor<P>> {
72    api: *const P,
73    cursor: C,
74}
75
/// Maps a module identity key (from `module_key`) to the address of a leaked
/// `sqlite3_module` built for it.
type ModuleCache = HashMap<usize, usize>;

/// How many filter arguments `ArgBuffer` stores inline before using a heap `Vec`.
const INLINE_ARGS: usize = 8;

/// Process-wide module cache, created lazily and guarded by a mutex.
static MODULE_CACHE: OnceLock<Mutex<ModuleCache>> = OnceLock::new();
80
81struct ArgBuffer<'a> {
82    inline: [MaybeUninit<ValueRef<'a>>; INLINE_ARGS],
83    len: usize,
84    heap: Option<Vec<ValueRef<'a>>>,
85}
86
87impl<'a> ArgBuffer<'a> {
88    fn new(argc: usize) -> Self {
89        // SAFETY: MaybeUninit allows an uninitialized array.
90        let inline = unsafe {
91            MaybeUninit::<[MaybeUninit<ValueRef<'a>>; INLINE_ARGS]>::uninit().assume_init()
92        };
93        let heap = if argc > INLINE_ARGS {
94            Some(Vec::with_capacity(argc))
95        } else {
96            None
97        };
98        Self {
99            inline,
100            len: 0,
101            heap,
102        }
103    }
104
105    fn push(&mut self, value: ValueRef<'a>) {
106        if let Some(heap) = &mut self.heap {
107            heap.push(value);
108            return;
109        }
110        let slot = &mut self.inline[self.len];
111        slot.write(value);
112        self.len += 1;
113    }
114
115    fn as_slice(&self) -> &[ValueRef<'a>] {
116        if let Some(heap) = &self.heap {
117            return heap.as_slice();
118        }
119        // SAFETY: We only read the initialized prefix `len`.
120        unsafe {
121            core::slice::from_raw_parts(self.inline.as_ptr() as *const ValueRef<'a>, self.len)
122        }
123    }
124}
125
126fn module_cache() -> &'static Mutex<ModuleCache> {
127    MODULE_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
128}
129
130fn module_key<P, T>() -> usize
131where
132    P: Sqlite3Api,
133    T: VirtualTable<P>,
134{
135    x_create::<P, T> as *const () as usize
136}
137
138fn cached_module<P, T>() -> &'static sqlite3_module<P>
139where
140    P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
141    T: VirtualTable<P>,
142{
143    let key = module_key::<P, T>();
144    let mut cache = module_cache()
145        .lock()
146        .expect("sqlite-provider vtab module cache poisoned");
147    if let Some(raw) = cache.get(&key) {
148        // SAFETY: `raw` was inserted from a leaked `sqlite3_module<P>` built
149        // for this exact monomorphized callback key.
150        return unsafe { &*(*raw as *const sqlite3_module<P>) };
151    }
152    let module = Box::leak(Box::new(module::<P, T>()));
153    cache.insert(key, module as *const sqlite3_module<P> as usize);
154    module
155}
156
157fn err_code(err: &Error) -> i32 {
158    err.code.code().unwrap_or(1)
159}
160
161fn set_out_err_message<P: Sqlite3Api>(api: &P, out_err: *mut *mut u8, message: &str) {
162    if out_err.is_null() {
163        return;
164    }
165    let bytes = message.as_bytes();
166    let payload_len = bytes
167        .iter()
168        .position(|byte| *byte == 0)
169        .unwrap_or(bytes.len());
170    let alloc_len = payload_len.saturating_add(1);
171    let out = unsafe { api.malloc(alloc_len) } as *mut u8;
172    if out.is_null() {
173        return;
174    }
175    unsafe {
176        if payload_len > 0 {
177            core::ptr::copy_nonoverlapping(bytes.as_ptr(), out, payload_len);
178        }
179        *out.add(payload_len) = 0;
180        *out_err = out;
181    }
182}
183
184fn set_out_err_from_error<P: Sqlite3Api>(api: &P, out_err: *mut *mut u8, err: &Error) {
185    if let Some(message) = err.message.as_deref() {
186        set_out_err_message(api, out_err, message);
187    } else {
188        set_out_err_message(api, out_err, &err.to_string());
189    }
190}
191
/// Borrows a NUL-terminated C string as `&str`.
///
/// Returns `None` for a null pointer or when the bytes are not valid UTF-8.
///
/// # Safety
///
/// A non-null `ptr` must point to a NUL-terminated buffer that stays valid and unchanged
/// for `'a`.
unsafe fn cstr_to_str<'a>(ptr: *const u8) -> Option<&'a str> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: non-null and NUL-terminated per the caller's contract.
        let cstr = unsafe { std::ffi::CStr::from_ptr(ptr.cast::<c_char>()) };
        cstr.to_str().ok()
    }
}
200
201unsafe fn parse_args<'a>(argc: i32, argv: *const *const u8) -> Vec<&'a str> {
202    let argc = if argc < 0 { 0 } else { argc as usize };
203    let mut out = Vec::with_capacity(argc);
204    if argc == 0 || argv.is_null() {
205        return out;
206    }
207    let values = unsafe { core::slice::from_raw_parts(argv, argc) };
208    for value in values {
209        let value = unsafe { cstr_to_str(*value) }.unwrap_or("");
210        out.push(value);
211    }
212    out
213}
214
215unsafe fn value_ref_from_raw<'a, P: Sqlite3Api>(api: &P, value: NonNull<P::Value>) -> ValueRef<'a> {
216    match unsafe { api.value_type(value) } {
217        ValueType::Null => ValueRef::Null,
218        ValueType::Integer => ValueRef::Integer(unsafe { api.value_int64(value) }),
219        ValueType::Float => ValueRef::Float(unsafe { api.value_double(value) }),
220        ValueType::Text => unsafe { ValueRef::from_raw_text(api.value_text(value)) },
221        ValueType::Blob => unsafe { ValueRef::from_raw_blob(api.value_blob(value)) },
222    }
223}
224
225unsafe fn args_from_values<'a, P: Sqlite3Api>(
226    api: &P,
227    argc: i32,
228    argv: *mut *mut P::Value,
229) -> ArgBuffer<'a> {
230    let argc = if argc < 0 { 0 } else { argc as usize };
231    let mut out = ArgBuffer::new(argc);
232    if argc == 0 || argv.is_null() {
233        return out;
234    }
235    let values = unsafe { core::slice::from_raw_parts(argv, argc) };
236    for value in values {
237        if let Some(ptr) = NonNull::new(*value) {
238            out.push(unsafe { value_ref_from_raw(api, ptr) });
239        } else {
240            out.push(ValueRef::Null);
241        }
242    }
243    out
244}
245
246extern "C" fn x_create<P, T>(
247    db: *mut P::Db,
248    aux: *mut c_void,
249    argc: i32,
250    argv: *const *const u8,
251    out_vtab: *mut *mut P::VTab,
252    out_err: *mut *mut u8,
253) -> i32
254where
255    P: Sqlite3Api,
256    T: VirtualTable<P>,
257{
258    if out_vtab.is_null() {
259        return 1;
260    }
261    unsafe {
262        if !out_err.is_null() {
263            *out_err = core::ptr::null_mut();
264        }
265    }
266    if aux.is_null() || db.is_null() {
267        return 1;
268    }
269    let api = unsafe { &*(aux as *const P) };
270    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
271        let args = unsafe { parse_args(argc, argv) };
272        match T::connect(&args) {
273            Ok((table, schema)) => {
274                if let Err(err) = unsafe { api.declare_vtab(NonNull::new_unchecked(db), &schema) } {
275                    set_out_err_from_error(api, out_err, &err);
276                    return err_code(&err);
277                }
278                let vtab = Box::new(VTab {
279                    api: api as *const P,
280                    table,
281                });
282                unsafe {
283                    *out_vtab = Box::into_raw(vtab) as *mut P::VTab;
284                }
285                0
286            }
287            Err(err) => {
288                let err: Error = err.into();
289                set_out_err_from_error(api, out_err, &err);
290                err_code(&err)
291            }
292        }
293    }));
294    match out {
295        Ok(code) => code,
296        Err(_) => {
297            set_out_err_message(api, out_err, "panic in virtual table create/connect");
298            1
299        }
300    }
301}
302
303extern "C" fn x_connect<P, T>(
304    db: *mut P::Db,
305    aux: *mut c_void,
306    argc: i32,
307    argv: *const *const u8,
308    out_vtab: *mut *mut P::VTab,
309    out_err: *mut *mut u8,
310) -> i32
311where
312    P: Sqlite3Api,
313    T: VirtualTable<P>,
314{
315    x_create::<P, T>(db, aux, argc, argv, out_vtab, out_err)
316}
317
318extern "C" fn x_best_index<P, T>(vtab: *mut P::VTab, info: *mut c_void) -> i32
319where
320    P: Sqlite3Api<VTab = VTab<P, T>>,
321    T: VirtualTable<P>,
322{
323    if vtab.is_null() {
324        return 1;
325    }
326    let vtab: &mut VTab<P, T> = unsafe { &mut *vtab };
327    let mut info = BestIndexInfo { raw: info };
328    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
329        vtab.table.best_index(&mut info)
330    }));
331    match out {
332        Ok(Ok(())) => 0,
333        Ok(Err(err)) => err_code(&err.into()),
334        Err(_) => 1,
335    }
336}
337
338extern "C" fn x_disconnect<P, T>(vtab: *mut P::VTab) -> i32
339where
340    P: Sqlite3Api<VTab = VTab<P, T>>,
341    T: VirtualTable<P>,
342{
343    if vtab.is_null() {
344        return 0;
345    }
346    let vtab: Box<VTab<P, T>> = unsafe { Box::from_raw(vtab) };
347    let VTab { table, .. } = *vtab;
348    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| table.disconnect()));
349    match out {
350        Ok(Ok(())) => 0,
351        Ok(Err(err)) => err_code(&err.into()),
352        Err(_) => 1,
353    }
354}
355
356extern "C" fn x_destroy<P, T>(vtab: *mut P::VTab) -> i32
357where
358    P: Sqlite3Api<VTab = VTab<P, T>>,
359    T: VirtualTable<P>,
360{
361    x_disconnect::<P, T>(vtab)
362}
363
364extern "C" fn x_open<P, T>(vtab: *mut P::VTab, out_cursor: *mut *mut P::VTabCursor) -> i32
365where
366    P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
367    T: VirtualTable<P>,
368{
369    if vtab.is_null() || out_cursor.is_null() {
370        return 1;
371    }
372    let vtab: &mut VTab<P, T> = unsafe { &mut *vtab };
373    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| vtab.table.open()));
374    match out {
375        Ok(Ok(cursor)) => {
376            let handle = Box::new(Cursor {
377                api: vtab.api,
378                cursor,
379            });
380            unsafe { *out_cursor = Box::into_raw(handle) };
381            0
382        }
383        Ok(Err(err)) => err_code(&err.into()),
384        Err(_) => 1,
385    }
386}
387
388extern "C" fn x_close<P, T>(cursor: *mut P::VTabCursor) -> i32
389where
390    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
391    T: VirtualTable<P>,
392{
393    if cursor.is_null() {
394        return 0;
395    }
396    unsafe { drop(Box::from_raw(cursor)) };
397    0
398}
399
400extern "C" fn x_filter<P, T>(
401    cursor: *mut P::VTabCursor,
402    idx_num: i32,
403    idx_str: *const u8,
404    argc: i32,
405    argv: *mut *mut P::Value,
406) -> i32
407where
408    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
409    T: VirtualTable<P>,
410{
411    if cursor.is_null() {
412        return 1;
413    }
414    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
415    let api = unsafe { &*cursor.api };
416    let idx_str = unsafe { cstr_to_str(idx_str) };
417    let args = unsafe { args_from_values(api, argc, argv) };
418    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
419        cursor.cursor.filter(idx_num, idx_str, args.as_slice())
420    }));
421    match out {
422        Ok(Ok(())) => 0,
423        Ok(Err(err)) => err_code(&err.into()),
424        Err(_) => 1,
425    }
426}
427
428extern "C" fn x_next<P, T>(cursor: *mut P::VTabCursor) -> i32
429where
430    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
431    T: VirtualTable<P>,
432{
433    if cursor.is_null() {
434        return 1;
435    }
436    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
437    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cursor.cursor.next()));
438    match out {
439        Ok(Ok(())) => 0,
440        Ok(Err(err)) => err_code(&err.into()),
441        Err(_) => 1,
442    }
443}
444
445extern "C" fn x_eof<P, T>(cursor: *mut P::VTabCursor) -> i32
446where
447    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
448    T: VirtualTable<P>,
449{
450    if cursor.is_null() {
451        return 1;
452    }
453    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
454    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cursor.cursor.eof()));
455    match out {
456        Ok(true) => 1,
457        Ok(false) => 0,
458        Err(_) => 1,
459    }
460}
461
462extern "C" fn x_column<P, T>(cursor: *mut P::VTabCursor, ctx: *mut P::Context, col: i32) -> i32
463where
464    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
465    T: VirtualTable<P>,
466{
467    if cursor.is_null() || ctx.is_null() {
468        return 1;
469    }
470    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
471    let api = unsafe { &*cursor.api };
472    let context = Context::new(api, unsafe { NonNull::new_unchecked(ctx) });
473    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
474        cursor.cursor.column(&context, col)
475    }));
476    match out {
477        Ok(Ok(())) => 0,
478        Ok(Err(err)) => {
479            let err: Error = err.into();
480            context.result_error(err.message.as_deref().unwrap_or("virtual table error"));
481            err_code(&err)
482        }
483        Err(_) => {
484            context.result_error("panic in virtual table column");
485            1
486        }
487    }
488}
489
490extern "C" fn x_rowid<P, T>(cursor: *mut P::VTabCursor, rowid: *mut i64) -> i32
491where
492    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
493    T: VirtualTable<P>,
494{
495    if cursor.is_null() || rowid.is_null() {
496        return 1;
497    }
498    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
499    let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cursor.cursor.rowid()));
500    match out {
501        Ok(Ok(id)) => {
502            unsafe { *rowid = id };
503            0
504        }
505        Ok(Err(err)) => err_code(&err.into()),
506        Err(_) => 1,
507    }
508}
509
510/// Build a `sqlite3_module` backed by `VirtualTable` and `VTabCursor`.
511pub fn module<P, T>() -> sqlite3_module<P>
512where
513    P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
514    T: VirtualTable<P>,
515{
516    sqlite3_module {
517        i_version: 1,
518        x_create: Some(x_create::<P, T>),
519        x_connect: Some(x_connect::<P, T>),
520        x_best_index: Some(x_best_index::<P, T>),
521        x_disconnect: Some(x_disconnect::<P, T>),
522        x_destroy: Some(x_destroy::<P, T>),
523        x_open: Some(x_open::<P, T>),
524        x_close: Some(x_close::<P, T>),
525        x_filter: Some(x_filter::<P, T>),
526        x_next: Some(x_next::<P, T>),
527        x_eof: Some(x_eof::<P, T>),
528        x_column: Some(x_column::<P, T>),
529        x_rowid: Some(x_rowid::<P, T>),
530    }
531}
532
533impl<'p, P: Sqlite3Api> Connection<'p, P> {
534    /// Register a virtual table module using the default glue.
535    pub fn create_module<T>(&self, name: &str) -> Result<()>
536    where
537        P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
538        T: VirtualTable<P>,
539    {
540        if !self.api.feature_set().contains(FeatureSet::VIRTUAL_TABLES) {
541            return Err(Error::feature_unavailable("virtual tables unsupported"));
542        }
543        let module = cached_module::<P, T>();
544        unsafe {
545            self.api.create_module_v2(
546                self.db,
547                name,
548                module,
549                self.api as *const P as *mut c_void,
550                None,
551            )
552        }
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::{ArgBuffer, BestIndexInfo, INLINE_ARGS};
    use crate::value::ValueRef;

    #[test]
    fn best_index_info_is_pointer() {
        let info = BestIndexInfo {
            raw: core::ptr::null_mut(),
        };
        assert!(info.raw.is_null());
    }

    #[test]
    fn arg_buffer_inline() {
        let mut buf = ArgBuffer::new(2);
        buf.push(ValueRef::Integer(1));
        buf.push(ValueRef::Integer(2));
        assert_eq!(
            buf.as_slice(),
            &[ValueRef::Integer(1), ValueRef::Integer(2)]
        );
    }

    #[test]
    fn arg_buffer_heap() {
        let mut buf = ArgBuffer::new(9);
        for i in 0..9 {
            buf.push(ValueRef::Integer(i));
        }
        // More than `INLINE_ARGS` values must spill to the heap up front.
        assert!(buf.heap.is_some());
        assert_eq!(buf.as_slice().len(), 9);
        assert_eq!(buf.as_slice()[0], ValueRef::Integer(0));
    }

    #[test]
    fn arg_buffer_empty() {
        let buf = ArgBuffer::new(0);
        assert!(buf.as_slice().is_empty());
    }

    #[test]
    fn arg_buffer_inline_at_capacity() {
        // Exactly `INLINE_ARGS` values is the largest count that stays inline.
        let mut buf = ArgBuffer::new(INLINE_ARGS);
        for i in 0..INLINE_ARGS {
            buf.push(ValueRef::Integer(i as i64));
        }
        assert!(buf.heap.is_none());
        let values = buf.as_slice();
        assert_eq!(values.len(), INLINE_ARGS);
        assert_eq!(
            values[INLINE_ARGS - 1],
            ValueRef::Integer(INLINE_ARGS as i64 - 1)
        );
    }
}