Skip to main content

sqlite_provider/
vtab.rs

1use core::ffi::{c_char, c_void};
2use core::mem::MaybeUninit;
3use core::ptr::NonNull;
4
5use crate::connection::Connection;
6use crate::error::{Error, Result};
7use crate::function::Context;
8use crate::provider::{FeatureSet, Sqlite3Api, ValueType, sqlite3_module};
9use crate::value::ValueRef;
10
/// Opaque handle to the backend's best-index information for one
/// query-planning call.
///
/// `x_best_index` wraps the `info` pointer it receives in this struct and
/// passes it to [`VirtualTable::best_index`]. The glue itself never
/// dereferences the pointer; interpreting it is up to the implementation.
#[derive(Debug)]
pub struct BestIndexInfo {
    /// Raw pointer exactly as supplied by the backend (may be null if the
    /// backend passes null).
    pub raw: *mut c_void,
}
15
16/// Virtual table implementation.
17pub trait VirtualTable<P: Sqlite3Api>: Sized + Send {
18    type Cursor: VTabCursor<P>;
19    type Error: Into<Error>;
20
21    fn connect(args: &[&str]) -> core::result::Result<(Self, String), Self::Error>;
22    fn disconnect(self) -> core::result::Result<(), Self::Error>;
23
24    fn best_index(&self, _info: &mut BestIndexInfo) -> core::result::Result<(), Self::Error> {
25        Ok(())
26    }
27
28    fn open(&self) -> core::result::Result<Self::Cursor, Self::Error>;
29}
30
31/// Virtual table cursor implementation.
32pub trait VTabCursor<P: Sqlite3Api>: Sized + Send {
33    type Error: Into<Error>;
34
35    fn filter(
36        &mut self,
37        idx_num: i32,
38        idx_str: Option<&str>,
39        args: &[ValueRef<'_>],
40    ) -> core::result::Result<(), Self::Error>;
41    fn next(&mut self) -> core::result::Result<(), Self::Error>;
42    fn eof(&self) -> bool;
43    fn column(&self, ctx: &Context<'_, P>, col: i32) -> core::result::Result<(), Self::Error>;
44    fn rowid(&self) -> core::result::Result<i64, Self::Error>;
45}
46
47/// Glue wrapper stored in the backend's `sqlite3_vtab` pointer.
48#[repr(C)]
49pub struct VTab<P: Sqlite3Api, T: VirtualTable<P>> {
50    api: *const P,
51    table: T,
52}
53
54/// Glue wrapper stored in the backend's `sqlite3_vtab_cursor` pointer.
55#[repr(C)]
56pub struct Cursor<P: Sqlite3Api, C: VTabCursor<P>> {
57    api: *const P,
58    cursor: C,
59}
60
61const INLINE_ARGS: usize = 8;
62
63struct ArgBuffer<'a> {
64    inline: [MaybeUninit<ValueRef<'a>>; INLINE_ARGS],
65    len: usize,
66    heap: Option<Vec<ValueRef<'a>>>,
67}
68
69impl<'a> ArgBuffer<'a> {
70    fn new(argc: usize) -> Self {
71        // SAFETY: MaybeUninit allows an uninitialized array.
72        let inline = unsafe {
73            MaybeUninit::<[MaybeUninit<ValueRef<'a>>; INLINE_ARGS]>::uninit().assume_init()
74        };
75        let heap = if argc > INLINE_ARGS { Some(Vec::with_capacity(argc)) } else { None };
76        Self { inline, len: 0, heap }
77    }
78
79    fn push(&mut self, value: ValueRef<'a>) {
80        if let Some(heap) = &mut self.heap {
81            heap.push(value);
82            return;
83        }
84        let slot = &mut self.inline[self.len];
85        slot.write(value);
86        self.len += 1;
87    }
88
89    fn as_slice(&self) -> &[ValueRef<'a>] {
90        if let Some(heap) = &self.heap {
91            return heap.as_slice();
92        }
93        // SAFETY: We only read the initialized prefix `len`.
94        unsafe { core::slice::from_raw_parts(self.inline.as_ptr() as *const ValueRef<'a>, self.len) }
95    }
96}
97
98fn err_code(err: &Error) -> i32 {
99    err.code.code().unwrap_or(1)
100}
101
/// Borrows a NUL-terminated byte string from the backend as `&str`.
///
/// Returns `None` for a null pointer or for bytes that are not valid UTF-8.
///
/// # Safety
/// `ptr` must be null or point to a NUL-terminated buffer that stays valid
/// and unmodified for the caller-chosen lifetime `'a`.
unsafe fn cstr_to_str<'a>(ptr: *const u8) -> Option<&'a str> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: non-null; the caller guarantees NUL termination and validity for `'a`.
        let c_string = unsafe { core::ffi::CStr::from_ptr(ptr.cast::<c_char>()) };
        c_string.to_str().ok()
    }
}
108
109unsafe fn parse_args<'a>(argc: i32, argv: *const *const u8) -> Vec<&'a str> {
110    let argc = if argc < 0 { 0 } else { argc as usize };
111    let mut out = Vec::with_capacity(argc);
112    if argc == 0 || argv.is_null() {
113        return out;
114    }
115    let values = unsafe { core::slice::from_raw_parts(argv, argc) };
116    for value in values {
117        let value = unsafe { cstr_to_str(*value) }.unwrap_or("");
118        out.push(value);
119    }
120    out
121}
122
123unsafe fn value_ref_from_raw<'a, P: Sqlite3Api>(api: &P, value: NonNull<P::Value>) -> ValueRef<'a> {
124    match unsafe { api.value_type(value) } {
125        ValueType::Null => ValueRef::Null,
126        ValueType::Integer => ValueRef::Integer(unsafe { api.value_int64(value) }),
127        ValueType::Float => ValueRef::Float(unsafe { api.value_double(value) }),
128        ValueType::Text => unsafe { ValueRef::from_raw_text(api.value_text(value)) },
129        ValueType::Blob => unsafe { ValueRef::from_raw_blob(api.value_blob(value)) },
130    }
131}
132
133unsafe fn args_from_values<'a, P: Sqlite3Api>(
134    api: &P,
135    argc: i32,
136    argv: *mut *mut P::Value,
137) -> ArgBuffer<'a> {
138    let argc = if argc < 0 { 0 } else { argc as usize };
139    let mut out = ArgBuffer::new(argc);
140    if argc == 0 || argv.is_null() {
141        return out;
142    }
143    let values = unsafe { core::slice::from_raw_parts(argv, argc) };
144    for value in values {
145        if let Some(ptr) = NonNull::new(*value) {
146            out.push(unsafe { value_ref_from_raw(api, ptr) });
147        } else {
148            out.push(ValueRef::Null);
149        }
150    }
151    out
152}
153
154extern "C" fn x_create<P, T>(
155    db: *mut P::Db,
156    aux: *mut c_void,
157    argc: i32,
158    argv: *const *const u8,
159    out_vtab: *mut *mut P::VTab,
160    out_err: *mut *mut u8,
161) -> i32
162where
163    P: Sqlite3Api,
164    T: VirtualTable<P>,
165{
166    if out_vtab.is_null() {
167        return 1;
168    }
169    unsafe {
170        if !out_err.is_null() {
171            *out_err = core::ptr::null_mut();
172        }
173    }
174    if aux.is_null() || db.is_null() {
175        return 1;
176    }
177    let api = unsafe { &*(aux as *const P) };
178    let args = unsafe { parse_args(argc, argv) };
179    match T::connect(&args) {
180        Ok((table, schema)) => {
181            if let Err(err) = unsafe { api.declare_vtab(NonNull::new_unchecked(db), &schema) } {
182                return err_code(&err);
183            }
184            let vtab = Box::new(VTab { api: api as *const P, table });
185            unsafe {
186                *out_vtab = Box::into_raw(vtab) as *mut P::VTab;
187            }
188            0
189        }
190        Err(err) => {
191            let err: Error = err.into();
192            err_code(&err)
193        }
194    }
195}
196
197extern "C" fn x_connect<P, T>(
198    db: *mut P::Db,
199    aux: *mut c_void,
200    argc: i32,
201    argv: *const *const u8,
202    out_vtab: *mut *mut P::VTab,
203    out_err: *mut *mut u8,
204) -> i32
205where
206    P: Sqlite3Api,
207    T: VirtualTable<P>,
208{
209    x_create::<P, T>(db, aux, argc, argv, out_vtab, out_err)
210}
211
212extern "C" fn x_best_index<P, T>(vtab: *mut P::VTab, info: *mut c_void) -> i32
213where
214    P: Sqlite3Api<VTab = VTab<P, T>>,
215    T: VirtualTable<P>,
216{
217    if vtab.is_null() {
218        return 1;
219    }
220    let vtab: &mut VTab<P, T> = unsafe { &mut *vtab };
221    let mut info = BestIndexInfo { raw: info };
222    match vtab.table.best_index(&mut info) {
223        Ok(()) => 0,
224        Err(err) => err_code(&err.into()),
225    }
226}
227
228extern "C" fn x_disconnect<P, T>(vtab: *mut P::VTab) -> i32
229where
230    P: Sqlite3Api<VTab = VTab<P, T>>,
231    T: VirtualTable<P>,
232{
233    if vtab.is_null() {
234        return 0;
235    }
236    let vtab: Box<VTab<P, T>> = unsafe { Box::from_raw(vtab) };
237    match vtab.table.disconnect() {
238        Ok(()) => 0,
239        Err(err) => err_code(&err.into()),
240    }
241}
242
243extern "C" fn x_destroy<P, T>(vtab: *mut P::VTab) -> i32
244where
245    P: Sqlite3Api<VTab = VTab<P, T>>,
246    T: VirtualTable<P>,
247{
248    x_disconnect::<P, T>(vtab)
249}
250
251extern "C" fn x_open<P, T>(vtab: *mut P::VTab, out_cursor: *mut *mut P::VTabCursor) -> i32
252where
253    P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
254    T: VirtualTable<P>,
255{
256    if vtab.is_null() || out_cursor.is_null() {
257        return 1;
258    }
259    let vtab: &mut VTab<P, T> = unsafe { &mut *vtab };
260    match vtab.table.open() {
261        Ok(cursor) => {
262            let handle = Box::new(Cursor { api: vtab.api, cursor });
263            unsafe { *out_cursor = Box::into_raw(handle) };
264            0
265        }
266        Err(err) => err_code(&err.into()),
267    }
268}
269
270extern "C" fn x_close<P, T>(cursor: *mut P::VTabCursor) -> i32
271where
272    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
273    T: VirtualTable<P>,
274{
275    if cursor.is_null() {
276        return 0;
277    }
278    unsafe { drop(Box::from_raw(cursor)) };
279    0
280}
281
282extern "C" fn x_filter<P, T>(
283    cursor: *mut P::VTabCursor,
284    idx_num: i32,
285    idx_str: *const u8,
286    argc: i32,
287    argv: *mut *mut P::Value,
288) -> i32
289where
290    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
291    T: VirtualTable<P>,
292{
293    if cursor.is_null() {
294        return 1;
295    }
296    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
297    let api = unsafe { &*cursor.api };
298    let idx_str = unsafe { cstr_to_str(idx_str) };
299    let args = unsafe { args_from_values(api, argc, argv) };
300    match cursor.cursor.filter(idx_num, idx_str, args.as_slice()) {
301        Ok(()) => 0,
302        Err(err) => err_code(&err.into()),
303    }
304}
305
306extern "C" fn x_next<P, T>(cursor: *mut P::VTabCursor) -> i32
307where
308    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
309    T: VirtualTable<P>,
310{
311    if cursor.is_null() {
312        return 1;
313    }
314    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
315    match cursor.cursor.next() {
316        Ok(()) => 0,
317        Err(err) => err_code(&err.into()),
318    }
319}
320
321extern "C" fn x_eof<P, T>(cursor: *mut P::VTabCursor) -> i32
322where
323    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
324    T: VirtualTable<P>,
325{
326    if cursor.is_null() {
327        return 1;
328    }
329    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
330    if cursor.cursor.eof() {
331        1
332    } else {
333        0
334    }
335}
336
337extern "C" fn x_column<P, T>(
338    cursor: *mut P::VTabCursor,
339    ctx: *mut P::Context,
340    col: i32,
341) -> i32
342where
343    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
344    T: VirtualTable<P>,
345{
346    if cursor.is_null() || ctx.is_null() {
347        return 1;
348    }
349    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
350    let api = unsafe { &*cursor.api };
351    let context = Context::new(api, unsafe { NonNull::new_unchecked(ctx) });
352    match cursor.cursor.column(&context, col) {
353        Ok(()) => 0,
354        Err(err) => {
355            let err: Error = err.into();
356            context.result_error(err.message.as_deref().unwrap_or("virtual table error"));
357            err_code(&err)
358        }
359    }
360}
361
362extern "C" fn x_rowid<P, T>(cursor: *mut P::VTabCursor, rowid: *mut i64) -> i32
363where
364    P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
365    T: VirtualTable<P>,
366{
367    if cursor.is_null() || rowid.is_null() {
368        return 1;
369    }
370    let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
371    match cursor.cursor.rowid() {
372        Ok(id) => {
373            unsafe { *rowid = id };
374            0
375        }
376        Err(err) => err_code(&err.into()),
377    }
378}
379
380/// Build a `sqlite3_module` backed by `VirtualTable` and `VTabCursor`.
381pub fn module<P, T>() -> sqlite3_module<P>
382where
383    P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
384    T: VirtualTable<P>,
385{
386    sqlite3_module {
387        i_version: 1,
388        x_create: Some(x_create::<P, T>),
389        x_connect: Some(x_connect::<P, T>),
390        x_best_index: Some(x_best_index::<P, T>),
391        x_disconnect: Some(x_disconnect::<P, T>),
392        x_destroy: Some(x_destroy::<P, T>),
393        x_open: Some(x_open::<P, T>),
394        x_close: Some(x_close::<P, T>),
395        x_filter: Some(x_filter::<P, T>),
396        x_next: Some(x_next::<P, T>),
397        x_eof: Some(x_eof::<P, T>),
398        x_column: Some(x_column::<P, T>),
399        x_rowid: Some(x_rowid::<P, T>),
400    }
401}
402
403impl<'p, P: Sqlite3Api> Connection<'p, P> {
404    /// Register a virtual table module using the default glue.
405    pub fn create_module<T>(&self, name: &str) -> Result<()>
406    where
407        P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
408        T: VirtualTable<P>,
409    {
410        if !self.api.feature_set().contains(FeatureSet::VIRTUAL_TABLES) {
411            return Err(Error::feature_unavailable("virtual tables unsupported"));
412        }
413        let module = Box::leak(Box::new(module::<P, T>()));
414        unsafe {
415            self.api.create_module_v2(
416                self.db,
417                name,
418                module,
419                self.api as *const P as *mut c_void,
420                None,
421            )
422        }
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::{ArgBuffer, BestIndexInfo};
    use crate::value::ValueRef;

    /// `BestIndexInfo` is a transparent carrier for whatever pointer it is given.
    #[test]
    fn best_index_info_is_pointer() {
        let null_info = BestIndexInfo { raw: core::ptr::null_mut() };
        assert!(null_info.raw.is_null());
    }

    /// Small argument lists stay in the inline array and keep their order.
    #[test]
    fn arg_buffer_inline() {
        let mut buffer = ArgBuffer::new(2);
        for n in 1..=2 {
            buffer.push(ValueRef::Integer(n));
        }
        let expected = [ValueRef::Integer(1), ValueRef::Integer(2)];
        assert_eq!(buffer.as_slice(), &expected);
    }

    /// Argument lists larger than the inline capacity go to the heap.
    #[test]
    fn arg_buffer_heap() {
        let mut buffer = ArgBuffer::new(9);
        (0..9).for_each(|n| buffer.push(ValueRef::Integer(n)));
        let values = buffer.as_slice();
        assert_eq!(values.len(), 9);
        assert_eq!(values.first(), Some(&ValueRef::Integer(0)));
    }
}