1use core::ffi::{c_char, c_void};
2use core::mem::MaybeUninit;
3use core::ptr::NonNull;
4use std::collections::HashMap;
5use std::sync::{Mutex, OnceLock};
6
7use crate::connection::Connection;
8use crate::error::{Error, Result};
9use crate::function::Context;
10use crate::provider::{FeatureSet, Sqlite3Api, ValueType, sqlite3_module};
11use crate::value::ValueRef;
12
/// Planner information handed to [`VirtualTable::best_index`].
///
/// `raw` is the `info` pointer that `x_best_index` receives. It is forwarded unchanged and
/// without a null check. It is presumably SQLite's `sqlite3_index_info`; confirm against the
/// provider before dereferencing it.
#[derive(Debug)]
pub struct BestIndexInfo {
    /// Opaque pointer to the planner's index-info structure.
    pub raw: *mut c_void,
}
18
/// A virtual table implementation that can be registered with
/// [`Connection::create_module`].
///
/// The module's `xCreate` and `xConnect` callbacks both end up in [`VirtualTable::connect`],
/// and both `xDisconnect` and `xDestroy` end up in [`VirtualTable::disconnect`].
pub trait VirtualTable<P: Sqlite3Api>: Sized + Send {
    /// Cursor type produced by [`VirtualTable::open`].
    type Cursor: VTabCursor<P>;
    /// Error type. It is converted into the crate [`Error`], whose code is returned to SQLite.
    type Error: Into<Error>;

    /// Builds the table from the module arguments.
    ///
    /// Returns the table and the schema string that is passed to the provider's
    /// `declare_vtab` (presumably a `CREATE TABLE` statement). Arguments that are null or
    /// not valid UTF-8 arrive as `""`.
    fn connect(args: &[&str]) -> core::result::Result<(Self, String), Self::Error>;
    /// Consumes the table when SQLite disconnects or destroys it.
    fn disconnect(self) -> core::result::Result<(), Self::Error>;

    /// Query-planner hook. The default leaves `info` untouched and succeeds.
    fn best_index(&self, _info: &mut BestIndexInfo) -> core::result::Result<(), Self::Error> {
        Ok(())
    }

    /// Opens a new cursor over this table.
    fn open(&self) -> core::result::Result<Self::Cursor, Self::Error>;
}
39
/// A cursor over a [`VirtualTable`], created by [`VirtualTable::open`].
pub trait VTabCursor<P: Sqlite3Api>: Sized + Send {
    /// Error type. It is converted into the crate [`Error`], whose code is returned to SQLite.
    type Error: Into<Error>;

    /// Starts a scan using the values SQLite passes to `xFilter`.
    ///
    /// `idx_str` is `None` when SQLite passes a null pointer or text that is not valid UTF-8.
    /// Null argument pointers arrive as [`ValueRef::Null`].
    fn filter(
        &mut self,
        idx_num: i32,
        idx_str: Option<&str>,
        args: &[ValueRef<'_>],
    ) -> core::result::Result<(), Self::Error>;
    /// Advances to the next row.
    fn next(&mut self) -> core::result::Result<(), Self::Error>;
    /// Returns `true` when there is no current row. A panic here is reported to SQLite as
    /// end-of-scan.
    fn eof(&self) -> bool;
    /// Produces column `col` of the current row, presumably through `ctx`'s result setters.
    /// On error, the error's message is reported through `ctx.result_error`.
    fn column(&self, ctx: &Context<'_, P>, col: i32) -> core::result::Result<(), Self::Error>;
    /// Returns the rowid of the current row.
    fn rowid(&self) -> core::result::Result<i64, Self::Error>;
}
61
/// Boxed handle for one connected virtual table, given to SQLite as its table pointer.
///
/// Allocated in `x_create` (`Box::into_raw`) and reclaimed in `x_disconnect`, which
/// `x_destroy` also uses.
///
/// NOTE(review): SQLite's C API expects a table pointer to begin with an `sqlite3_vtab`
/// that SQLite itself writes to, but this `#[repr(C)]` layout begins with `api`. Presumably
/// the provider's `create_module_v2` adapts these callbacks rather than passing this
/// pointer straight to SQLite. Confirm.
#[repr(C)]
pub struct VTab<P: Sqlite3Api, T: VirtualTable<P>> {
    /// Provider pointer taken from the module's `aux` data; copied into each cursor.
    api: *const P,
    /// The user's table implementation.
    table: T,
}
68
/// Boxed handle for one open cursor, given to SQLite as its cursor pointer.
///
/// Allocated in `x_open` (`Box::into_raw`) and freed in `x_close`.
///
/// NOTE(review): SQLite's C API expects a cursor pointer to begin with an
/// `sqlite3_vtab_cursor`, which SQLite writes to. This layout begins with `api` instead.
/// Presumably the provider adapts these callbacks. Confirm.
#[repr(C)]
pub struct Cursor<P: Sqlite3Api, C: VTabCursor<P>> {
    /// Provider pointer copied from the owning `VTab` when the cursor was opened.
    api: *const P,
    /// The user's cursor implementation.
    cursor: C,
}
75
/// `xFilter` calls with at most this many arguments are decoded without a heap allocation.
const INLINE_ARGS: usize = 8;
/// Maps a module key (the address of an `x_create` instantiation; see `module_key`) to the
/// address of a leaked `sqlite3_module`.
type ModuleCache = HashMap<usize, usize>;

/// Process-wide module cache shared by every `create_module` call; initialized lazily by
/// `module_cache`.
static MODULE_CACHE: OnceLock<Mutex<ModuleCache>> = OnceLock::new();
80
/// Storage for decoded `xFilter` arguments: up to `INLINE_ARGS` values inline, or a heap `Vec`.
///
/// Invariant: when `heap` is `Some`, every pushed value lives there. Otherwise `inline[..len]`
/// holds the initialized values.
///
/// NOTE(review): there is no `Drop` impl, so inline values are never dropped. Presumably this
/// is fine because `ValueRef` only borrows. Confirm it has no drop glue.
struct ArgBuffer<'a> {
    /// Inline slots; only the first `len` are initialized.
    inline: [MaybeUninit<ValueRef<'a>>; INLINE_ARGS],
    /// Number of initialized inline slots.
    len: usize,
    /// Heap storage, used instead of `inline` when present.
    heap: Option<Vec<ValueRef<'a>>>,
}
86
87impl<'a> ArgBuffer<'a> {
88 fn new(argc: usize) -> Self {
89 let inline = unsafe {
91 MaybeUninit::<[MaybeUninit<ValueRef<'a>>; INLINE_ARGS]>::uninit().assume_init()
92 };
93 let heap = if argc > INLINE_ARGS {
94 Some(Vec::with_capacity(argc))
95 } else {
96 None
97 };
98 Self {
99 inline,
100 len: 0,
101 heap,
102 }
103 }
104
105 fn push(&mut self, value: ValueRef<'a>) {
106 if let Some(heap) = &mut self.heap {
107 heap.push(value);
108 return;
109 }
110 let slot = &mut self.inline[self.len];
111 slot.write(value);
112 self.len += 1;
113 }
114
115 fn as_slice(&self) -> &[ValueRef<'a>] {
116 if let Some(heap) = &self.heap {
117 return heap.as_slice();
118 }
119 unsafe {
121 core::slice::from_raw_parts(self.inline.as_ptr() as *const ValueRef<'a>, self.len)
122 }
123 }
124}
125
126fn module_cache() -> &'static Mutex<ModuleCache> {
127 MODULE_CACHE.get_or_init(|| Mutex::new(HashMap::new()))
128}
129
130fn module_key<P, T>() -> usize
131where
132 P: Sqlite3Api,
133 T: VirtualTable<P>,
134{
135 x_create::<P, T> as *const () as usize
136}
137
138fn cached_module<P, T>() -> &'static sqlite3_module<P>
139where
140 P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
141 T: VirtualTable<P>,
142{
143 let key = module_key::<P, T>();
144 let mut cache = module_cache()
145 .lock()
146 .expect("sqlite-provider vtab module cache poisoned");
147 if let Some(raw) = cache.get(&key) {
148 return unsafe { &*(*raw as *const sqlite3_module<P>) };
151 }
152 let module = Box::leak(Box::new(module::<P, T>()));
153 cache.insert(key, module as *const sqlite3_module<P> as usize);
154 module
155}
156
157fn err_code(err: &Error) -> i32 {
158 err.code.code().unwrap_or(1)
159}
160
161fn set_out_err_message<P: Sqlite3Api>(api: &P, out_err: *mut *mut u8, message: &str) {
162 if out_err.is_null() {
163 return;
164 }
165 let bytes = message.as_bytes();
166 let payload_len = bytes
167 .iter()
168 .position(|byte| *byte == 0)
169 .unwrap_or(bytes.len());
170 let alloc_len = payload_len.saturating_add(1);
171 let out = unsafe { api.malloc(alloc_len) } as *mut u8;
172 if out.is_null() {
173 return;
174 }
175 unsafe {
176 if payload_len > 0 {
177 core::ptr::copy_nonoverlapping(bytes.as_ptr(), out, payload_len);
178 }
179 *out.add(payload_len) = 0;
180 *out_err = out;
181 }
182}
183
184fn set_out_err_from_error<P: Sqlite3Api>(api: &P, out_err: *mut *mut u8, err: &Error) {
185 if let Some(message) = err.message.as_deref() {
186 set_out_err_message(api, out_err, message);
187 } else {
188 set_out_err_message(api, out_err, &err.to_string());
189 }
190}
191
/// Borrows a NUL-terminated C string as `&str`.
///
/// Returns `None` for a null pointer or for bytes that are not valid UTF-8.
///
/// # Safety
/// A non-null `ptr` must point to a NUL-terminated string that stays valid for `'a`.
unsafe fn cstr_to_str<'a>(ptr: *const u8) -> Option<&'a str> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: non-null and NUL-terminated per the caller contract.
        let c_string = unsafe { core::ffi::CStr::from_ptr(ptr.cast::<c_char>()) };
        c_string.to_str().ok()
    }
}
200
201unsafe fn parse_args<'a>(argc: i32, argv: *const *const u8) -> Vec<&'a str> {
202 let argc = if argc < 0 { 0 } else { argc as usize };
203 let mut out = Vec::with_capacity(argc);
204 if argc == 0 || argv.is_null() {
205 return out;
206 }
207 let values = unsafe { core::slice::from_raw_parts(argv, argc) };
208 for value in values {
209 let value = unsafe { cstr_to_str(*value) }.unwrap_or("");
210 out.push(value);
211 }
212 out
213}
214
215unsafe fn value_ref_from_raw<'a, P: Sqlite3Api>(api: &P, value: NonNull<P::Value>) -> ValueRef<'a> {
216 match unsafe { api.value_type(value) } {
217 ValueType::Null => ValueRef::Null,
218 ValueType::Integer => ValueRef::Integer(unsafe { api.value_int64(value) }),
219 ValueType::Float => ValueRef::Float(unsafe { api.value_double(value) }),
220 ValueType::Text => unsafe { ValueRef::from_raw_text(api.value_text(value)) },
221 ValueType::Blob => unsafe { ValueRef::from_raw_blob(api.value_blob(value)) },
222 }
223}
224
225unsafe fn args_from_values<'a, P: Sqlite3Api>(
226 api: &P,
227 argc: i32,
228 argv: *mut *mut P::Value,
229) -> ArgBuffer<'a> {
230 let argc = if argc < 0 { 0 } else { argc as usize };
231 let mut out = ArgBuffer::new(argc);
232 if argc == 0 || argv.is_null() {
233 return out;
234 }
235 let values = unsafe { core::slice::from_raw_parts(argv, argc) };
236 for value in values {
237 if let Some(ptr) = NonNull::new(*value) {
238 out.push(unsafe { value_ref_from_raw(api, ptr) });
239 } else {
240 out.push(ValueRef::Null);
241 }
242 }
243 out
244}
245
246extern "C" fn x_create<P, T>(
247 db: *mut P::Db,
248 aux: *mut c_void,
249 argc: i32,
250 argv: *const *const u8,
251 out_vtab: *mut *mut P::VTab,
252 out_err: *mut *mut u8,
253) -> i32
254where
255 P: Sqlite3Api,
256 T: VirtualTable<P>,
257{
258 if out_vtab.is_null() {
259 return 1;
260 }
261 unsafe {
262 if !out_err.is_null() {
263 *out_err = core::ptr::null_mut();
264 }
265 }
266 if aux.is_null() || db.is_null() {
267 return 1;
268 }
269 let api = unsafe { &*(aux as *const P) };
270 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
271 let args = unsafe { parse_args(argc, argv) };
272 match T::connect(&args) {
273 Ok((table, schema)) => {
274 if let Err(err) = unsafe { api.declare_vtab(NonNull::new_unchecked(db), &schema) } {
275 set_out_err_from_error(api, out_err, &err);
276 return err_code(&err);
277 }
278 let vtab = Box::new(VTab {
279 api: api as *const P,
280 table,
281 });
282 unsafe {
283 *out_vtab = Box::into_raw(vtab) as *mut P::VTab;
284 }
285 0
286 }
287 Err(err) => {
288 let err: Error = err.into();
289 set_out_err_from_error(api, out_err, &err);
290 err_code(&err)
291 }
292 }
293 }));
294 match out {
295 Ok(code) => code,
296 Err(_) => {
297 set_out_err_message(api, out_err, "panic in virtual table create/connect");
298 1
299 }
300 }
301}
302
303extern "C" fn x_connect<P, T>(
304 db: *mut P::Db,
305 aux: *mut c_void,
306 argc: i32,
307 argv: *const *const u8,
308 out_vtab: *mut *mut P::VTab,
309 out_err: *mut *mut u8,
310) -> i32
311where
312 P: Sqlite3Api,
313 T: VirtualTable<P>,
314{
315 x_create::<P, T>(db, aux, argc, argv, out_vtab, out_err)
316}
317
318extern "C" fn x_best_index<P, T>(vtab: *mut P::VTab, info: *mut c_void) -> i32
319where
320 P: Sqlite3Api<VTab = VTab<P, T>>,
321 T: VirtualTable<P>,
322{
323 if vtab.is_null() {
324 return 1;
325 }
326 let vtab: &mut VTab<P, T> = unsafe { &mut *vtab };
327 let mut info = BestIndexInfo { raw: info };
328 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
329 vtab.table.best_index(&mut info)
330 }));
331 match out {
332 Ok(Ok(())) => 0,
333 Ok(Err(err)) => err_code(&err.into()),
334 Err(_) => 1,
335 }
336}
337
338extern "C" fn x_disconnect<P, T>(vtab: *mut P::VTab) -> i32
339where
340 P: Sqlite3Api<VTab = VTab<P, T>>,
341 T: VirtualTable<P>,
342{
343 if vtab.is_null() {
344 return 0;
345 }
346 let vtab: Box<VTab<P, T>> = unsafe { Box::from_raw(vtab) };
347 let VTab { table, .. } = *vtab;
348 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| table.disconnect()));
349 match out {
350 Ok(Ok(())) => 0,
351 Ok(Err(err)) => err_code(&err.into()),
352 Err(_) => 1,
353 }
354}
355
356extern "C" fn x_destroy<P, T>(vtab: *mut P::VTab) -> i32
357where
358 P: Sqlite3Api<VTab = VTab<P, T>>,
359 T: VirtualTable<P>,
360{
361 x_disconnect::<P, T>(vtab)
362}
363
364extern "C" fn x_open<P, T>(vtab: *mut P::VTab, out_cursor: *mut *mut P::VTabCursor) -> i32
365where
366 P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
367 T: VirtualTable<P>,
368{
369 if vtab.is_null() || out_cursor.is_null() {
370 return 1;
371 }
372 let vtab: &mut VTab<P, T> = unsafe { &mut *vtab };
373 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| vtab.table.open()));
374 match out {
375 Ok(Ok(cursor)) => {
376 let handle = Box::new(Cursor {
377 api: vtab.api,
378 cursor,
379 });
380 unsafe { *out_cursor = Box::into_raw(handle) };
381 0
382 }
383 Ok(Err(err)) => err_code(&err.into()),
384 Err(_) => 1,
385 }
386}
387
388extern "C" fn x_close<P, T>(cursor: *mut P::VTabCursor) -> i32
389where
390 P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
391 T: VirtualTable<P>,
392{
393 if cursor.is_null() {
394 return 0;
395 }
396 unsafe { drop(Box::from_raw(cursor)) };
397 0
398}
399
400extern "C" fn x_filter<P, T>(
401 cursor: *mut P::VTabCursor,
402 idx_num: i32,
403 idx_str: *const u8,
404 argc: i32,
405 argv: *mut *mut P::Value,
406) -> i32
407where
408 P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
409 T: VirtualTable<P>,
410{
411 if cursor.is_null() {
412 return 1;
413 }
414 let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
415 let api = unsafe { &*cursor.api };
416 let idx_str = unsafe { cstr_to_str(idx_str) };
417 let args = unsafe { args_from_values(api, argc, argv) };
418 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
419 cursor.cursor.filter(idx_num, idx_str, args.as_slice())
420 }));
421 match out {
422 Ok(Ok(())) => 0,
423 Ok(Err(err)) => err_code(&err.into()),
424 Err(_) => 1,
425 }
426}
427
428extern "C" fn x_next<P, T>(cursor: *mut P::VTabCursor) -> i32
429where
430 P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
431 T: VirtualTable<P>,
432{
433 if cursor.is_null() {
434 return 1;
435 }
436 let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
437 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cursor.cursor.next()));
438 match out {
439 Ok(Ok(())) => 0,
440 Ok(Err(err)) => err_code(&err.into()),
441 Err(_) => 1,
442 }
443}
444
445extern "C" fn x_eof<P, T>(cursor: *mut P::VTabCursor) -> i32
446where
447 P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
448 T: VirtualTable<P>,
449{
450 if cursor.is_null() {
451 return 1;
452 }
453 let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
454 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cursor.cursor.eof()));
455 match out {
456 Ok(true) => 1,
457 Ok(false) => 0,
458 Err(_) => 1,
459 }
460}
461
462extern "C" fn x_column<P, T>(cursor: *mut P::VTabCursor, ctx: *mut P::Context, col: i32) -> i32
463where
464 P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
465 T: VirtualTable<P>,
466{
467 if cursor.is_null() || ctx.is_null() {
468 return 1;
469 }
470 let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
471 let api = unsafe { &*cursor.api };
472 let context = Context::new(api, unsafe { NonNull::new_unchecked(ctx) });
473 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
474 cursor.cursor.column(&context, col)
475 }));
476 match out {
477 Ok(Ok(())) => 0,
478 Ok(Err(err)) => {
479 let err: Error = err.into();
480 context.result_error(err.message.as_deref().unwrap_or("virtual table error"));
481 err_code(&err)
482 }
483 Err(_) => {
484 context.result_error("panic in virtual table column");
485 1
486 }
487 }
488}
489
490extern "C" fn x_rowid<P, T>(cursor: *mut P::VTabCursor, rowid: *mut i64) -> i32
491where
492 P: Sqlite3Api<VTabCursor = Cursor<P, T::Cursor>>,
493 T: VirtualTable<P>,
494{
495 if cursor.is_null() || rowid.is_null() {
496 return 1;
497 }
498 let cursor: &mut Cursor<P, T::Cursor> = unsafe { &mut *cursor };
499 let out = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| cursor.cursor.rowid()));
500 match out {
501 Ok(Ok(id)) => {
502 unsafe { *rowid = id };
503 0
504 }
505 Ok(Err(err)) => err_code(&err.into()),
506 Err(_) => 1,
507 }
508}
509
510pub fn module<P, T>() -> sqlite3_module<P>
512where
513 P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
514 T: VirtualTable<P>,
515{
516 sqlite3_module {
517 i_version: 1,
518 x_create: Some(x_create::<P, T>),
519 x_connect: Some(x_connect::<P, T>),
520 x_best_index: Some(x_best_index::<P, T>),
521 x_disconnect: Some(x_disconnect::<P, T>),
522 x_destroy: Some(x_destroy::<P, T>),
523 x_open: Some(x_open::<P, T>),
524 x_close: Some(x_close::<P, T>),
525 x_filter: Some(x_filter::<P, T>),
526 x_next: Some(x_next::<P, T>),
527 x_eof: Some(x_eof::<P, T>),
528 x_column: Some(x_column::<P, T>),
529 x_rowid: Some(x_rowid::<P, T>),
530 }
531}
532
533impl<'p, P: Sqlite3Api> Connection<'p, P> {
534 pub fn create_module<T>(&self, name: &str) -> Result<()>
536 where
537 P: Sqlite3Api<VTab = VTab<P, T>, VTabCursor = Cursor<P, T::Cursor>>,
538 T: VirtualTable<P>,
539 {
540 if !self.api.feature_set().contains(FeatureSet::VIRTUAL_TABLES) {
541 return Err(Error::feature_unavailable("virtual tables unsupported"));
542 }
543 let module = cached_module::<P, T>();
544 unsafe {
545 self.api.create_module_v2(
546 self.db,
547 name,
548 module,
549 self.api as *const P as *mut c_void,
550 None,
551 )
552 }
553 }
554}
555
#[cfg(test)]
mod tests {
    use super::{ArgBuffer, BestIndexInfo};
    use crate::value::ValueRef;

    #[test]
    fn best_index_info_is_pointer() {
        let planner = BestIndexInfo {
            raw: core::ptr::null_mut(),
        };
        assert!(planner.raw.is_null());
    }

    #[test]
    fn arg_buffer_inline() {
        // Two values fit in the inline slots.
        let mut buffer = ArgBuffer::new(2);
        for n in 1..=2 {
            buffer.push(ValueRef::Integer(n));
        }
        assert_eq!(
            buffer.as_slice(),
            &[ValueRef::Integer(1), ValueRef::Integer(2)]
        );
    }

    #[test]
    fn arg_buffer_heap() {
        // Nine values exceed the inline capacity, so the heap path is used.
        let mut buffer = ArgBuffer::new(9);
        (0..9).for_each(|n| buffer.push(ValueRef::Integer(n)));
        let stored = buffer.as_slice();
        assert_eq!(stored.len(), 9);
        assert_eq!(stored[0], ValueRef::Integer(0));
    }
}