sqlite3_ext 0.2.0

Build loadable extensions for SQLite using Rust
Documentation
use crate::{ffi, sqlite3_match_version, types::*, value::*, Connection};
use sealed::sealed;
use std::{any::TypeId, ffi::CString, mem::size_of};

/// Typed view over a raw `sqlite3_context` that gives access to `sqlite3_aggregate_context`.
///
/// `U` is the type of the function's user data and `F` is the type of the per-invocation
/// aggregate state. Neither type is checked at runtime; callers must name the same types the
/// function was registered with.
#[repr(transparent)]
pub(crate) struct AggregateContext<U, F> {
    base: ffi::sqlite3_context,
    // Zero-sized marker tying `U` and `F` to this type without storing either.
    phantom: core::marker::PhantomData<(U, F)>,
}

/// In-place slot for aggregate state stored in SQLite-owned memory.
///
/// `Uninitialized` is pinned to discriminant 0, so a zero-filled allocation reads as an empty
/// slot. `Default` also yields `Uninitialized`, which lets `std::mem::take` move the state out
/// and leave an empty slot behind.
#[repr(u8)]
#[derive(Default)]
enum SqliteManagedBox<T> {
    #[default]
    Uninitialized = 0,
    Initialized(T) = 1,
}

/// The run-time environment of an application-defined SQL function.
///
/// The struct is `#[repr(transparent)]` over `ffi::sqlite3_context`, so a raw context pointer
/// can be reinterpreted as a `Context` reference without conversion.
#[repr(transparent)]
pub struct Context {
    // The SQLite-owned context this wrapper is laid over.
    base: ffi::sqlite3_context,
}

/// Heap record stored behind an auxiliary-data pointer.
///
/// `#[repr(C)]` keeps `type_id` at offset 0 for every payload type `T`. The tag can therefore
/// be read back before the payload type has been confirmed.
#[repr(C)]
struct AuxData<T> {
    // `TypeId` of `T`, recorded when the value is stored.
    type_id: TypeId,
    // The cached value itself.
    val: T,
}

#[derive(Debug, thiserror::Error)]
pub enum AuxDataError {
    /// No previous call to set_aux_data
    #[error("not set")]
    Unset,
    /// Previous call to set_aux_data used a different type
    #[error("wrong type")]
    WrongType,
}

impl<U, F> AggregateContext<U, F> {
    /// Reinterpret a raw `sqlite3_context` pointer as an `AggregateContext`.
    ///
    /// # Safety
    ///
    /// `base` must be a valid, non-null context pointer that stays live for `'a`. `U` and `F`
    /// must match the types the aggregate function was registered with.
    pub unsafe fn from_ptr<'a>(base: *mut ffi::sqlite3_context) -> &'a mut Self {
        &mut *(base as *mut Self)
    }

    /// Return the user data registered with the function.
    ///
    /// # Safety
    ///
    /// The function's user-data pointer must point to a live `U`.
    pub unsafe fn user_data(&self) -> &U {
        // Only shared access is handed out, so derive a shared reference. Reborrowing through
        // `&mut *` would assert exclusive access that `&self` cannot guarantee.
        &*(ffi::sqlite3_user_data(&raw const self.base as _) as *const U)
    }

    /// Return the aggregate state for this invocation, creating it with `f` on first use.
    ///
    /// `f` is called at most once per invocation, with the function's user data. Freshly
    /// allocated aggregate memory must read as `SqliteManagedBox::Uninitialized`. SQLite
    /// zero-fills it, and `Uninitialized` is discriminant 0 under `#[repr(u8)]`.
    ///
    /// # Errors
    ///
    /// Returns `SQLITE_NOMEM` if SQLite could not allocate the aggregate storage.
    ///
    /// # Safety
    ///
    /// Same requirements as [`user_data`](Self::user_data). `F` must be the state type used
    /// by every other call for this aggregate invocation.
    pub unsafe fn get_or_insert_with(&mut self, f: impl FnOnce(&U) -> F) -> Result<&mut F> {
        let ptr = ffi::sqlite3_aggregate_context(
            &raw const self.base as _,
            size_of::<SqliteManagedBox<F>>() as _,
        ) as *mut SqliteManagedBox<F>;
        if ptr.is_null() {
            return Err(SQLITE_NOMEM);
        }
        let context = &mut *ptr;
        if let SqliteManagedBox::Uninitialized = context {
            *context = SqliteManagedBox::Initialized(f(self.user_data()));
        }
        let SqliteManagedBox::Initialized(val) = context else {
            unreachable!()
        };
        Ok(val)
    }

    /// Try to get the aggregate context, consuming it if it is found.
    ///
    /// Passing a size of 0 asks SQLite not to allocate. A null pointer therefore means no
    /// state was ever created. After a successful take, the slot is left `Uninitialized`.
    ///
    /// # Safety
    ///
    /// `F` must be the state type used by [`get_or_insert_with`](Self::get_or_insert_with).
    pub unsafe fn take(&mut self) -> Option<F> {
        let ptr = ffi::sqlite3_aggregate_context(&raw const self.base as _, 0 as _)
            as *mut SqliteManagedBox<F>;
        if ptr.is_null() {
            return None;
        }
        let context = std::mem::take(&mut *ptr);
        match context {
            SqliteManagedBox::Uninitialized => None,
            SqliteManagedBox::Initialized(val) => Some(val),
        }
    }
}

impl Context {
    /// Reinterpret a raw `sqlite3_context` pointer as a `Context`.
    ///
    /// # Safety
    ///
    /// `base` must be a valid, non-null `sqlite3_context` pointer that stays live for `'a`.
    pub(crate) unsafe fn from_ptr<'a>(base: *mut ffi::sqlite3_context) -> &'a mut Self {
        // `Context` is `#[repr(transparent)]` over `sqlite3_context`, so the cast is layout-safe.
        &mut *(base as *mut Self)
    }

    /// Return the raw `sqlite3_context` pointer wrapped by this `Context`.
    pub(crate) fn as_ptr(&self) -> *mut ffi::sqlite3_context {
        &raw const self.base as _
    }

    /// Return a handle to the current database.
    pub fn db(&self) -> &Connection {
        // SAFETY: assumes the connection handle SQLite reports for a live context stays valid
        // for as long as `self` is borrowed.
        unsafe { Connection::from_ptr(ffi::sqlite3_context_db_handle(self.as_ptr())) }
    }

    /// Fetch the aux-data pointer for argument `idx` and check that it stores a `T`.
    ///
    /// # Errors
    ///
    /// Returns [`AuxDataError::Unset`] when nothing is stored at `idx`. Returns
    /// [`AuxDataError::WrongType`] when the stored value is not a `T`.
    ///
    /// # Safety
    ///
    /// Assumes every aux pointer for this call was stored by
    /// [`set_aux_data`](Context::set_aux_data), i.e. points to a boxed `AuxData<_>`.
    unsafe fn aux_data_inner<T: 'static>(
        &self,
        idx: usize,
    ) -> std::result::Result<*mut AuxData<T>, AuxDataError> {
        let data = ffi::sqlite3_get_auxdata(self.as_ptr(), idx as _) as *mut AuxData<T>;
        if data.is_null() {
            return Err(AuxDataError::Unset);
        }
        // Until the tag matches, `data` may point at an `AuxData<V>` for some other `V`, which
        // can be smaller than `AuxData<T>`. Read only the `type_id` field through the raw
        // pointer; it sits at offset 0 for every payload type thanks to `#[repr(C)]`. Creating a
        // `&mut AuxData<T>` here would claim the whole, possibly larger, allocation.
        let stored_type = (*data).type_id;
        if stored_type == TypeId::of::<T>() {
            Ok(data)
        } else {
            Err(AuxDataError::WrongType)
        }
    }

    /// Retrieve data about a function parameter that was previously set with
    /// [set_aux_data](Context::set_aux_data).
    ///
    /// # Errors
    ///
    /// [`AuxDataError::Unset`] if nothing was stored for `idx`, or
    /// [`AuxDataError::WrongType`] if it was stored with a different type.
    pub fn aux_data<T: 'static>(&self, idx: usize) -> std::result::Result<&T, AuxDataError> {
        unsafe { self.aux_data_inner::<T>(idx).map(|data| &(*data).val) }
    }

    /// Mutable version of [aux_data](Context::aux_data).
    pub fn aux_data_mut<T: 'static>(
        &mut self,
        idx: usize,
    ) -> std::result::Result<&mut T, AuxDataError> {
        unsafe { self.aux_data_inner::<T>(idx).map(|data| &mut (*data).val) }
    }

    /// Set the auxiliary data associated with the corresponding function parameter.
    ///
    /// If some processing is necessary in order for a function parameter to be useful (for
    /// example, compiling a regular expression), this method can be used to cache the
    /// processed value in case it is later reused in the same query. The cached value can
    /// be retrieved with the [aux_data](Context::aux_data) method.
    pub fn set_aux_data<T: 'static>(&self, idx: usize, val: T) {
        // Tag the value with its `TypeId` so `aux_data` can reject mismatched reads.
        let data = Box::new(AuxData {
            type_id: TypeId::of::<T>(),
            val,
        });
        // NOTE(review): SQLite discards previously stored aux data when this is called again
        // for the same parameter. A `&T` from `aux_data(idx)` held across this call (both take
        // `&self`) would then dangle. Consider requiring `&mut self`; confirm against callers.
        unsafe {
            ffi::sqlite3_set_auxdata(
                self.as_ptr(),
                idx as _,
                Box::into_raw(data) as _,
                Some(ffi::drop_boxed::<AuxData<T>>),
            )
        };
    }

    /// Assign the given value to the result of the function. This function always returns Ok.
    pub fn set_result(&self, val: impl ToContextResult) -> Result<()> {
        unsafe { val.assign_to(self.as_ptr()) };
        Ok(())
    }
}

/// A Rust value that can be stored as the result of an SQL function.
///
/// This trait is sealed. The implementations include:
///
/// - `()` for NULL, plus `bool`, `i32`, `i64`, `f64`, `&'static str`, `String`, [Blob],
///   `&[u8]` and `&[u8; N]`.
/// - `Option<T>` for nullable values, where `T: ToContextResult`; `None` becomes NULL.
/// - [Result]\<T\> for fallible functions, where `T: ToContextResult`; an `Err` is reported
///   as the function's error.
/// - [PassedRef] for arbitrary Rust objects, and [UnsafePtr] for arbitrary pointers.
/// - `&`[ValueRef] and `&mut `[ValueRef] for borrowed SQLite values.
/// - [Value] for owned values whose type is only known at run-time.
#[sealed]
pub trait ToContextResult {
    #[doc(hidden)]
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context);
}

// Generates one sealed `ToContextResult` impl per `match TYPE as (ctx, val) => BODY` arm.
// Inside BODY, `ctx` names the raw `sqlite3_context` pointer and `val` the value being
// assigned. Attributes (including doc comments) written before an arm are copied onto the
// generated impl. Arms are comma-separated; a trailing comma is not accepted.
macro_rules! to_context_result {
    (
        $(
            $(#[$meta:meta])*
            match $target:ty as ($context:ident, $value:ident) => $body:expr
        ),*
    ) => {
        $(
            $(#[$meta])*
            #[sealed]
            impl ToContextResult for $target {
                unsafe fn assign_to(self, $context: *mut ffi::sqlite3_context) {
                    // Re-expose `self` under the arm's chosen name. For `()` this binds a unit
                    // value, which clippy would otherwise flag.
                    #[allow(clippy::let_unit_value)]
                    let $value = self;
                    $body
                }
            }
        )*
    };
}

to_context_result! {
    /// Assign NULL to the context result.
    match () as (ctx, _val) => ffi::sqlite3_result_null(ctx),
    /// Assign 1 for `true` or 0 for `false` to the context result.
    match bool as (ctx, val) => ffi::sqlite3_result_int(ctx, val as i32),
    /// Assign a 32-bit integer to the context result.
    match i32 as (ctx, val) => ffi::sqlite3_result_int(ctx, val),
    /// Assign a 64-bit integer to the context result.
    match i64 as (ctx, val) => ffi::sqlite3_result_int64(ctx, val),
    /// Assign a floating-point number to the context result.
    match f64 as (ctx, val) => ffi::sqlite3_result_double(ctx, val),
    /// Assign a static string to the context result.
    ///
    /// No destructor is passed (SQLITE_STATIC semantics) because the data is `'static`.
    match &'static str as (ctx, val) => {
        let val = val.as_bytes();
        let len = val.len();
        sqlite3_match_version! {
            3_008_007 => ffi::sqlite3_result_text64(ctx, val.as_ptr() as _, len as _, None, ffi::SQLITE_UTF8 as _),
            _ => ffi::sqlite3_result_text(ctx, val.as_ptr() as _, len as _, None),
        }
    },
    /// Assign an owned string to the context result.
    ///
    /// Strings containing interior NUL characters are passed through intact instead of
    /// panicking inside the SQL function callback.
    match String as (ctx, val) => {
        match CString::new(val) {
            // No interior NUL: reuse the string's buffer and hand ownership to SQLite, which
            // releases it through `drop_cstring`.
            Ok(owned) => {
                let len = owned.as_bytes().len();
                let raw = owned.into_raw();
                sqlite3_match_version! {
                    3_008_007 => ffi::sqlite3_result_text64(ctx, raw, len as _, Some(ffi::drop_cstring), ffi::SQLITE_UTF8 as _),
                    _ => ffi::sqlite3_result_text(ctx, raw, len as _, Some(ffi::drop_cstring)),
                }
            }
            // Interior NUL: not representable as a `CString`, so pass the bytes with
            // `sqlite_transient()` and let SQLite take its own copy.
            Err(err) => {
                let bytes = err.into_vec();
                let len = bytes.len();
                sqlite3_match_version! {
                    3_008_007 => ffi::sqlite3_result_text64(ctx, bytes.as_ptr() as _, len as _, ffi::sqlite_transient(), ffi::SQLITE_UTF8 as _),
                    _ => ffi::sqlite3_result_text(ctx, bytes.as_ptr() as _, len as _, ffi::sqlite_transient()),
                }
            }
        }
    },
    /// Assign an owned BLOB to the context result; SQLite frees it with `drop_blob`.
    match Blob as (ctx, val) => {
        let len = val.len();
        sqlite3_match_version! {
            3_008_007 => ffi::sqlite3_result_blob64(ctx, val.into_raw(), len as _, Some(ffi::drop_blob)),
            _ => ffi::sqlite3_result_blob(ctx, val.into_raw(), len as _, Some(ffi::drop_blob)),
        }
    },
    /// Sets the context error to this error.
    ///
    /// An SQLite error keeps its result code even when it carries a message.
    /// `Error::NoChange` leaves the context result untouched.
    match Error as (ctx, err) => {
        match err {
            Error::Sqlite(code, Some(desc)) => {
                let bytes = desc.as_bytes();
                ffi::sqlite3_result_error(ctx, bytes.as_ptr() as _, bytes.len() as _);
                // `sqlite3_result_error` resets the error code to SQLITE_ERROR, so the
                // specific code must be applied afterwards.
                ffi::sqlite3_result_error_code(ctx, code);
            },
            Error::Sqlite(code, None) => ffi::sqlite3_result_error_code(ctx, code),
            Error::NoChange => (),
            _ => {
                let msg = format!("{err}");
                let msg = msg.as_bytes();
                let len = msg.len();
                ffi::sqlite3_result_error(ctx, msg.as_ptr() as _, len as _);
            }
        }
    }
}

/// Sets the context result to the referenced SQLite value.
#[sealed]
impl ToContextResult for &ValueRef {
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        let raw_value = self.as_ptr();
        ffi::sqlite3_result_value(context, raw_value)
    }
}

/// Sets the context result to the referenced SQLite value.
///
/// Only read access is needed, so this forwards to the `&ValueRef` implementation.
#[sealed]
impl ToContextResult for &mut ValueRef {
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        <&ValueRef as ToContextResult>::assign_to(self, ctx)
    }
}

/// Sets the context result to the given BLOB.
///
/// The bytes are passed with `sqlite_transient()` (SQLITE_TRANSIENT), so SQLite takes its
/// own copy and the borrow does not need to outlive this call.
#[sealed]
impl ToContextResult for &[u8] {
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        let (data, size) = (self.as_ptr(), self.len());
        sqlite3_match_version! {
            3_008_007 => ffi::sqlite3_result_blob64(ctx, data as _, size as _, ffi::sqlite_transient()),
            _ => ffi::sqlite3_result_blob(ctx, data as _, size as _, ffi::sqlite_transient()),
        }
    }
}

/// Sets the context result to the given BLOB.
///
/// Forwards to the `&[u8]` implementation.
#[sealed]
impl<const N: usize> ToContextResult for &[u8; N] {
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        let bytes: &[u8] = self;
        bytes.assign_to(ctx)
    }
}

/// Sets the context result to the contained value, or to NULL for `None`.
#[sealed]
impl<T: ToContextResult> ToContextResult for Option<T> {
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        if let Some(inner) = self {
            inner.assign_to(context)
        } else {
            ().assign_to(context)
        }
    }
}

/// Sets the context result for `Ok`, or the context error for `Err`.
#[sealed]
impl<T: ToContextResult> ToContextResult for Result<T> {
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        match self {
            Err(error) => error.assign_to(context),
            Ok(value) => value.assign_to(context),
        }
    }
}

/// Sets a dynamically typed [Value] as the context result.
///
/// Each variant forwards to the implementation for its payload type.
#[sealed]
impl ToContextResult for Value {
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        match self {
            Value::Null => ().assign_to(context),
            Value::Integer(int) => int.assign_to(context),
            Value::Float(float) => float.assign_to(context),
            Value::Text(text) => text.assign_to(context),
            Value::Blob(blob) => blob.assign_to(context),
        }
    }
}

/// Sets an arbitrary pointer to the context result.
///
/// The pointer is stored as bytes. Under the `3_009_000` version gate, its subtype is also
/// attached with `sqlite3_result_subtype`.
#[sealed]
impl<T: 'static + ?Sized> ToContextResult for UnsafePtr<T> {
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        sqlite3_match_version! {
            3_009_000 => {
                // Capture the subtype first: `into_bytes` consumes `self`.
                let tag = self.subtype;
                self.into_bytes().assign_to(context);
                ffi::sqlite3_result_subtype(context, tag as _);
            },
            _ => self.into_bytes().assign_to(context),
        }
    }
}

/// Sets the context result to NULL with this value as an associated pointer.
///
/// The value is boxed and handed to `sqlite3_result_pointer` under `POINTER_TAG`, with
/// `drop_boxed::<PassedRef<T>>` registered as its destructor. When the `3_020_000` version
/// gate does not match, no result is assigned and the value is simply dropped.
#[sealed]
impl<T: 'static> ToContextResult for PassedRef<T> {
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        // Mention both names so neither is reported unused when the pointer API is
        // compiled out.
        let _ = (POINTER_TAG, context);
        sqlite3_match_version! {
            3_020_000 => {
                let owned = Box::into_raw(Box::new(self));
                ffi::sqlite3_result_pointer(
                    context,
                    owned as _,
                    POINTER_TAG,
                    Some(ffi::drop_boxed::<PassedRef<T>>),
                );
            },
            _ => (),
        }
    }
}