use crate::{ffi, sqlite3_match_version, types::*, value::*, Connection};
use sealed::sealed;
use std::{any::TypeId, ffi::CString, mem::size_of};
/// Typed view of a raw `ffi::sqlite3_context` used while running an aggregate function.
///
/// `U` is the type returned by [`AggregateContext::user_data`] and `F` is the per-aggregate
/// state managed by [`AggregateContext::get_or_insert_with`] / [`AggregateContext::take`].
/// The struct is `repr(transparent)` so a `*mut ffi::sqlite3_context` can be cast directly
/// to `*mut Self` (see `from_ptr`).
#[repr(transparent)]
pub(crate) struct AggregateContext<U, F> {
    /// The wrapped SQLite context; only ever accessed by address.
    base: ffi::sqlite3_context,
    /// Carries the `U` / `F` type parameters without storing any data.
    phantom: std::marker::PhantomData<(U, F)>,
}
/// Two-state slot that lives inside the buffer returned by `sqlite3_aggregate_context`.
///
/// `Uninitialized` is both the `Default` value and the first variant of this `repr(u8)`
/// enum. SQLite documents that the aggregate buffer is zero-filled on first allocation;
/// the code in `AggregateContext::get_or_insert_with` relies on a fresh buffer reading
/// back as `Uninitialized`, so the variant order must not change.
#[repr(u8)]
#[derive(Default)]
enum SqliteManagedBox<T> {
    /// No value has been stored yet (or it has been taken out again).
    #[default]
    Uninitialized,
    /// Holds the live aggregate state.
    Initialized(T),
}
/// Safe-ish handle around a raw `ffi::sqlite3_context`, as passed to SQL function callbacks.
///
/// The struct is `repr(transparent)`, which is what allows `Context::from_ptr` to cast a
/// `*mut ffi::sqlite3_context` straight to `*mut Context`.
#[repr(transparent)]
pub struct Context {
    /// The wrapped SQLite context; only ever accessed by address.
    base: ffi::sqlite3_context,
}
/// Heap payload stored through `Context::set_aux_data`.
///
/// The value is tagged with the `TypeId` it was stored as so that `Context::aux_data`
/// can report `AuxDataError::WrongType` instead of reinterpreting foreign data.
/// `repr(C)` keeps `type_id` as the first field regardless of `T`.
#[repr(C)]
struct AuxData<T> {
    /// `TypeId::of::<T>()` recorded at the time the value was stored.
    type_id: TypeId,
    /// The user-supplied value.
    val: T,
}
/// Reasons why `Context::aux_data` / `Context::aux_data_mut` could not return a value.
#[derive(Debug, thiserror::Error)]
pub enum AuxDataError {
    /// `sqlite3_get_auxdata` returned a null pointer for the requested index.
    #[error("not set")]
    Unset,
    /// Auxiliary data exists, but its recorded `TypeId` differs from the requested type.
    #[error("wrong type")]
    WrongType,
}
impl<U, F> AggregateContext<U, F> {
    /// Reinterprets a raw SQLite context pointer as a typed aggregate context.
    ///
    /// # Safety
    ///
    /// `base` must be non-null and must stay valid, and not be accessed through any other
    /// Rust reference, for the caller-chosen lifetime `'a`.
    pub unsafe fn from_ptr<'a>(base: *mut ffi::sqlite3_context) -> &'a mut Self {
        &mut *(base as *mut Self)
    }

    /// Returns the value that `sqlite3_user_data` reports for this context, viewed as `U`.
    ///
    /// # Safety
    ///
    /// The pointer returned by `sqlite3_user_data` must point to a live, properly aligned
    /// `U` for as long as the returned reference is used.
    pub unsafe fn user_data(&self) -> &U {
        // Only a shared borrow is handed out, so go through `*const U`; manufacturing a
        // temporary `&mut U` here would assert uniqueness that this method cannot promise.
        &*(ffi::sqlite3_user_data(&raw const self.base as _) as *const U)
    }

    /// Returns the aggregate state, creating it with `f` on first use.
    ///
    /// The state lives in the buffer SQLite hands out from `sqlite3_aggregate_context`;
    /// `f` receives the function's user data and is only called while the slot is still
    /// `Uninitialized`.
    ///
    /// # Errors
    ///
    /// Returns `SQLITE_NOMEM` when `sqlite3_aggregate_context` returns a null pointer.
    ///
    /// # Safety
    ///
    /// The context must belong to an aggregate function whose state type is `F`, and the
    /// requirements of [`Self::user_data`] must hold.
    pub unsafe fn get_or_insert_with(&mut self, f: impl Fn(&U) -> F) -> Result<&mut F> {
        let ptr = ffi::sqlite3_aggregate_context(
            &raw const self.base as _,
            size_of::<SqliteManagedBox<F>>() as _,
        ) as *mut SqliteManagedBox<F>;
        if ptr.is_null() {
            return Err(SQLITE_NOMEM);
        }
        let slot = &mut *ptr;
        if matches!(slot, SqliteManagedBox::Uninitialized) {
            // The previous content is the field-less `Uninitialized` variant, so the
            // drop performed by this assignment is a no-op.
            *slot = SqliteManagedBox::Initialized(f(self.user_data()));
        }
        match slot {
            SqliteManagedBox::Initialized(val) => Ok(val),
            SqliteManagedBox::Uninitialized => unreachable!(),
        }
    }

    /// Moves the aggregate state out, leaving the slot `Uninitialized`.
    ///
    /// The buffer is requested with a size of `0`, so no allocation is asked for here.
    /// Returns `None` when SQLite reports no buffer (null pointer) or when the slot was
    /// never initialized.
    ///
    /// # Safety
    ///
    /// The context must belong to an aggregate function whose state type is `F`.
    pub unsafe fn take(&mut self) -> Option<F> {
        let ptr = ffi::sqlite3_aggregate_context(&raw const self.base as _, 0 as _)
            as *mut SqliteManagedBox<F>;
        if ptr.is_null() {
            return None;
        }
        match std::mem::take(&mut *ptr) {
            SqliteManagedBox::Uninitialized => None,
            SqliteManagedBox::Initialized(val) => Some(val),
        }
    }
}
impl Context {
    /// Reinterprets a raw SQLite context pointer as a `Context`.
    ///
    /// # Safety
    ///
    /// `base` must be non-null and must stay valid, and not be accessed through any other
    /// Rust reference, for the caller-chosen lifetime `'a`.
    pub(crate) unsafe fn from_ptr<'a>(base: *mut ffi::sqlite3_context) -> &'a mut Self {
        &mut *(base as *mut Self)
    }

    /// Returns the address of the wrapped context as the `*mut` pointer the FFI expects.
    pub(crate) fn as_ptr(&self) -> *mut ffi::sqlite3_context {
        &raw const self.base as _
    }

    /// Returns the connection reported by `sqlite3_context_db_handle` for this context.
    pub fn db(&self) -> &Connection {
        unsafe { Connection::from_ptr(ffi::sqlite3_context_db_handle(self.as_ptr())) }
    }

    /// Looks up the auxiliary data stored at `idx` and checks that it was stored as `T`.
    ///
    /// Returns `AuxDataError::Unset` when SQLite has nothing stored for `idx` and
    /// `AuxDataError::WrongType` when the recorded `TypeId` is not `T`'s.
    unsafe fn aux_data_inner<T: 'static>(
        &self,
        idx: usize,
    ) -> std::result::Result<*mut AuxData<T>, AuxDataError> {
        let data = ffi::sqlite3_get_auxdata(self.as_ptr(), idx as _) as *mut AuxData<T>;
        if data.is_null() {
            return Err(AuxDataError::Unset);
        }
        // Until the tag has been compared we do not know that the allocation really is an
        // `AuxData<T>`, so read the tag through the raw pointer instead of first forming
        // a reference to a possibly larger or differently aligned type. `AuxData` is
        // `repr(C)` with `type_id` first, so the tag's location does not depend on `T`.
        let stored = (&raw const (*data).type_id).read();
        if stored == TypeId::of::<T>() {
            Ok(data)
        } else {
            Err(AuxDataError::WrongType)
        }
    }

    /// Borrows the auxiliary data previously stored at `idx` as a `T`.
    ///
    /// # Errors
    ///
    /// `AuxDataError::Unset` if nothing is stored, `AuxDataError::WrongType` if the value
    /// was stored as a different type.
    pub fn aux_data<T: 'static>(&self, idx: usize) -> std::result::Result<&T, AuxDataError> {
        unsafe { self.aux_data_inner::<T>(idx).map(|data| &(*data).val) }
    }

    /// Mutably borrows the auxiliary data previously stored at `idx` as a `T`.
    ///
    /// # Errors
    ///
    /// `AuxDataError::Unset` if nothing is stored, `AuxDataError::WrongType` if the value
    /// was stored as a different type.
    pub fn aux_data_mut<T: 'static>(
        &mut self,
        idx: usize,
    ) -> std::result::Result<&mut T, AuxDataError> {
        unsafe { self.aux_data_inner::<T>(idx).map(|data| &mut (*data).val) }
    }

    /// Stores `val` as the auxiliary data for `idx`.
    ///
    /// The value is boxed together with its `TypeId` and handed to `sqlite3_set_auxdata`
    /// with `ffi::drop_boxed::<AuxData<T>>` registered as the destructor callback.
    pub fn set_aux_data<T: 'static>(&self, idx: usize, val: T) {
        let data = Box::new(AuxData {
            type_id: TypeId::of::<T>(),
            val,
        });
        unsafe {
            ffi::sqlite3_set_auxdata(
                self.as_ptr(),
                idx as _,
                Box::into_raw(data) as _,
                Some(ffi::drop_boxed::<AuxData<T>>),
            )
        };
    }

    /// Assigns `val` as the result of the current function call.
    ///
    /// As written this always returns `Ok(())`; the conversion itself reports nothing.
    pub fn set_result(&self, val: impl ToContextResult) -> Result<()> {
        unsafe { val.assign_to(self.as_ptr()) };
        Ok(())
    }
}
#[sealed]
/// A value that can be written into a SQLite function context as its result.
///
/// The trait is marked `#[sealed]`; the implementations provided in this file cover unit,
/// booleans, integers, floats, strings, blobs, errors, value references, and `Option` /
/// `Result` wrappers around other implementors.
pub trait ToContextResult {
    #[doc(hidden)]
    /// Consumes `self` and stores it as the result on `context`.
    ///
    /// # Safety
    ///
    /// Implementations pass `context` straight to the `sqlite3_result_*` FFI functions,
    /// so it must be a valid context pointer for the duration of the call.
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context);
}
// Generates one sealed `ToContextResult` impl per comma-separated entry of the form
//
//     [#[attr]...] match <Type> as (<ctx>, <val>) => <expr>
//
// Inside `<expr>`, `<ctx>` names the raw context pointer and `<val>` names the value being
// converted (`self`). Any leading attributes are copied onto the generated impl.
macro_rules! to_context_result {
    ($($(#[$attr:meta])* match $ty:ty as ($ctx:ident, $val:ident) => $impl:expr),*) => {
        $(
            $(#[$attr])*
            #[sealed]
            impl ToContextResult for $ty {
                unsafe fn assign_to(self, $ctx: *mut ffi::sqlite3_context) {
                    // For the `()` entry this binding is a unit value, which clippy would
                    // otherwise flag.
                    #[allow(clippy::let_unit_value)]
                    let $val = self;
                    $impl
                }
            }
        )*
    };
}
// `ToContextResult` impls for the plain Rust types. Each entry is expanded by
// `to_context_result!` into `impl ToContextResult for <Type>`.
to_context_result! {
    // Unit maps to SQL NULL.
    match () as (ctx, _val) => ffi::sqlite3_result_null(ctx),
    // Booleans are stored as the integers 0 / 1.
    match bool as (ctx, val) => ffi::sqlite3_result_int(ctx, val as i32),
    match i32 as (ctx, val) => ffi::sqlite3_result_int(ctx, val),
    match i64 as (ctx, val) => ffi::sqlite3_result_int64(ctx, val),
    match f64 as (ctx, val) => ffi::sqlite3_result_double(ctx, val),
    // A `'static` string outlives the call, so no destructor is registered (`None`).
    match &'static str as (ctx, val) => {
        let val = val.as_bytes();
        let len = val.len();
        sqlite3_match_version! {
            3_008_007 => ffi::sqlite3_result_text64(ctx, val.as_ptr() as _, len as _, None, ffi::SQLITE_UTF8 as _),
            _ => ffi::sqlite3_result_text(ctx, val.as_ptr() as _, len as _, None),
        }
    },
    // Owned strings are handed over to SQLite as a `CString` that SQLite releases through
    // `ffi::drop_cstring`. `CString::new` rejects interior NUL bytes; rather than
    // panicking inside an FFI callback, such strings are passed with the transient
    // destructor so SQLite copies the `len` bytes before this function returns.
    match String as (ctx, val) => {
        let len = val.len();
        // Moving the `String` in lets `CString::new` reuse its buffer instead of copying.
        match CString::new(val) {
            Ok(text) => {
                let text = text.into_raw();
                sqlite3_match_version! {
                    3_008_007 => ffi::sqlite3_result_text64(ctx, text, len as _, Some(ffi::drop_cstring), ffi::SQLITE_UTF8 as _),
                    _ => ffi::sqlite3_result_text(ctx, text, len as _, Some(ffi::drop_cstring)),
                }
            },
            Err(nul) => {
                // `into_vec` gives back the original bytes, interior NULs included.
                let bytes = nul.into_vec();
                sqlite3_match_version! {
                    3_008_007 => ffi::sqlite3_result_text64(ctx, bytes.as_ptr() as _, len as _, ffi::sqlite_transient(), ffi::SQLITE_UTF8 as _),
                    _ => ffi::sqlite3_result_text(ctx, bytes.as_ptr() as _, len as _, ffi::sqlite_transient()),
                }
            },
        }
    },
    // Blobs transfer ownership of their buffer; SQLite releases it via `ffi::drop_blob`.
    match Blob as (ctx, val) => {
        let len = val.len();
        sqlite3_match_version! {
            3_008_007 => ffi::sqlite3_result_blob64(ctx, val.into_raw(), len as _, Some(ffi::drop_blob),),
            _ => ffi::sqlite3_result_blob(ctx, val.into_raw(), len as _, Some(ffi::drop_blob)),
        }
    },
    match Error as (ctx, err) => {
        match err {
            // NOTE(review): the error code is discarded when a description is present, and
            // SQLite documents that `sqlite3_result_error` reports `SQLITE_ERROR`. If the
            // code should survive, call `sqlite3_result_error_code` afterwards; confirm the
            // intended behavior before changing it.
            Error::Sqlite(_, Some(desc)) => {
                let bytes = desc.as_bytes();
                ffi::sqlite3_result_error(ctx, bytes.as_ptr() as _, bytes.len() as _)
            },
            Error::Sqlite(code, None) => ffi::sqlite3_result_error_code(ctx, code),
            // `NoChange` deliberately leaves the context untouched.
            Error::NoChange => (),
            // Every other error is reported through its `Display` text.
            _ => {
                let msg = format!("{err}");
                let msg = msg.as_bytes();
                let len = msg.len();
                ffi::sqlite3_result_error(ctx, msg.as_ptr() as _, len as _);
            }
        }
    }
}
#[sealed]
impl ToContextResult for &ValueRef {
    /// Sets the function result from the referenced value via `sqlite3_result_value`.
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        let value = self.as_ptr();
        ffi::sqlite3_result_value(ctx, value)
    }
}
#[sealed]
impl ToContextResult for &mut ValueRef {
    /// Sets the function result from the referenced value via `sqlite3_result_value`.
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        let value = self.as_ptr();
        ffi::sqlite3_result_value(ctx, value)
    }
}
#[sealed]
impl ToContextResult for &[u8] {
    /// Sets the function result to a blob holding these bytes.
    ///
    /// The borrowed slice is passed with `ffi::sqlite_transient()` as its destructor
    /// argument rather than transferring ownership.
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        let (data, len) = (self.as_ptr(), self.len());
        sqlite3_match_version! {
            3_008_007 => ffi::sqlite3_result_blob64(ctx, data as _, len as _, ffi::sqlite_transient()),
            _ => ffi::sqlite3_result_blob(ctx, data as _, len as _, ffi::sqlite_transient()),
        }
    }
}
#[sealed]
impl<const N: usize> ToContextResult for &[u8; N] {
    /// Sets the function result to a blob by delegating to the `&[u8]` implementation.
    unsafe fn assign_to(self, ctx: *mut ffi::sqlite3_context) {
        <&[u8] as ToContextResult>::assign_to(&self[..], ctx);
    }
}
#[sealed]
impl<T: ToContextResult> ToContextResult for Option<T> {
    /// Assigns the wrapped value when present; `None` is assigned as unit (SQL NULL).
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        if let Some(inner) = self {
            inner.assign_to(context)
        } else {
            ().assign_to(context)
        }
    }
}
#[sealed]
impl<T: ToContextResult> ToContextResult for Result<T> {
    /// Assigns the success value, or reports the error through its own `assign_to`.
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        let value = match self {
            Ok(value) => value,
            Err(error) => return error.assign_to(context),
        };
        value.assign_to(context)
    }
}
#[sealed]
impl ToContextResult for Value {
    /// Dispatches to the `ToContextResult` impl of whichever payload the value carries.
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        match self {
            Self::Null => ().assign_to(context),
            Self::Integer(int) => int.assign_to(context),
            Self::Float(float) => float.assign_to(context),
            Self::Text(text) => text.assign_to(context),
            Self::Blob(blob) => blob.assign_to(context),
        }
    }
}
#[sealed]
impl<T: 'static + ?Sized> ToContextResult for UnsafePtr<T> {
    /// Assigns the pointer's byte representation (`into_bytes`) as the result.
    ///
    /// In the `3_009_000` arm the pointer's `subtype` is additionally attached with
    /// `sqlite3_result_subtype`; the fallback arm assigns the bytes only.
    /// NOTE(review): `sqlite3_match_version!` presumably picks the arm from the SQLite
    /// version number; confirm against the macro definition.
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        sqlite3_match_version! {
            3_009_000 => {
                // Read the subtype first: `into_bytes` consumes `self`.
                let subtype = self.subtype;
                self.into_bytes().assign_to(context);
                ffi::sqlite3_result_subtype(context, subtype as _);
            },
            _ => self.into_bytes().assign_to(context),
        }
    }
}
#[sealed]
impl<T: 'static> ToContextResult for PassedRef<T> {
    /// Boxes the `PassedRef` and hands it to `sqlite3_result_pointer` tagged with
    /// `POINTER_TAG`, registering `ffi::drop_boxed::<PassedRef<T>>` as the destructor.
    ///
    /// In the fallback arm no result is assigned, and `self` is simply dropped when the
    /// function returns.
    unsafe fn assign_to(self, context: *mut ffi::sqlite3_context) {
        // Mentions both names so neither is unused when the `_ => ()` arm is the one
        // that gets compiled (presumably the purpose; confirm against the macro).
        let _ = (POINTER_TAG, context);
        sqlite3_match_version! {
            3_020_000 => ffi::sqlite3_result_pointer(
                context,
                Box::into_raw(Box::new(self)) as _,
                POINTER_TAG,
                Some(ffi::drop_boxed::<PassedRef<T>>),
            ),
            _ => (),
        }
    }
}