use super::{ffi, sqlite3_match_version, types::*, value::*, Connection, RiskLevel};
pub use context::*;
use std::{cmp::Ordering, ffi::CString, ptr::null_mut};
mod context;
mod stubs;
mod test;
/// Builds a value from the user data that was registered together with a SQL function.
///
/// The aggregate registration methods on [`Connection`] box and store a `user_data` value;
/// implementations of this trait turn a borrow of that value into a fresh instance.
/// Every [`Default`] type implements it by ignoring the user data.
pub trait FromUserData<T> {
    /// Create `Self` from a borrow of the registered user data.
    fn from_user_data(user_data: &T) -> Self;
}
/// A value that can be registered as a scalar SQL function through
/// [`Connection::create_scalar_function_object`].
///
/// The `'db` parameter is the lifetime of the `&'db Connection` borrow used at registration.
pub trait ScalarFunction<'db> {
    /// Evaluate one invocation of the function.
    ///
    /// `ctx` is the SQL call context and `args` holds this invocation's argument values.
    /// Presumably invoked from `stubs::call_scalar` (confirm there).
    fn call(&self, ctx: &mut Context, args: &mut [&mut ValueRef]) -> Result<()>;
}
/// Adapter that lets a plain closure be registered as a [`ScalarFunction`].
///
/// The required closure signature is enforced on the `ScalarFunction` impl, not on the
/// struct definition, so the wrapper itself carries no trait bounds.
struct ScalarClosure<F>(F);
/// Forwards every scalar-function invocation to the wrapped closure.
impl<F> ScalarFunction<'_> for ScalarClosure<F>
where
    F: Fn(&mut Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
{
    fn call(&self, context: &mut Context, values: &mut [&mut ValueRef]) -> Result<()> {
        let closure = &self.0;
        closure(context, values)
    }
}
/// An aggregate SQL function that can be registered with
/// [`Connection::create_legacy_aggregate_function`], which registers no window-function
/// (inverse/value) callbacks.
///
/// Instances are built from the registered user data through [`FromUserData`], as the
/// default `default_value` does.
pub trait LegacyAggregateFunction<UserData>: FromUserData<UserData> {
    /// Produce a result from `user_data` alone, without an instance that has seen any rows.
    ///
    /// Presumably called by `stubs` when the aggregate received no rows (confirm there).
    /// The default builds a fresh instance from `user_data` and reports its `value`.
    fn default_value(user_data: &UserData, context: &Context) -> Result<()>
    where
        Self: Sized,
    {
        let fresh = <Self as FromUserData<UserData>>::from_user_data(user_data);
        <Self as LegacyAggregateFunction<UserData>>::value(&fresh, context)
    }

    /// Accumulate one row's `args` into this instance.
    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;

    /// Produce the aggregate's current result; `context` is the only output channel.
    fn value(&self, context: &Context) -> Result<()>;
}
/// A window-capable aggregate SQL function, registered with
/// [`Connection::create_aggregate_function`].
///
/// Instances are built from the registered user data through [`FromUserData`], as the
/// default `default_value` does. Every implementor is also a [`LegacyAggregateFunction`].
pub trait AggregateFunction<UserData>: FromUserData<UserData> {
    /// Produce a result from `user_data` alone, without an instance that has seen any rows.
    ///
    /// Presumably called by `stubs` when the aggregate received no rows (confirm there).
    /// The default builds a fresh instance from `user_data` and reports its `value`.
    fn default_value(user_data: &UserData, context: &Context) -> Result<()>
    where
        Self: Sized,
    {
        let fresh = <Self as FromUserData<UserData>>::from_user_data(user_data);
        <Self as AggregateFunction<UserData>>::value(&fresh, context)
    }

    /// Accumulate one row's `args` into this instance.
    fn step(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;

    /// Produce the aggregate's current result; `context` is the only output channel.
    fn value(&self, context: &Context) -> Result<()>;

    /// Remove one row's `args` from this instance.
    ///
    /// Registered as the window function's inverse callback on SQLite >= 3.25.0.
    fn inverse(&mut self, context: &Context, args: &mut [&mut ValueRef]) -> Result<()>;
}
/// Any [`Default`] type can be built from any user data: the data is ignored and
/// `Default::default()` is returned.
impl<U, F> FromUserData<U> for F
where
    F: Default,
{
    fn from_user_data(_user_data: &U) -> Self {
        Default::default()
    }
}
/// Every window-capable [`AggregateFunction`] can also be registered as a
/// [`LegacyAggregateFunction`].
///
/// Each method forwards to its `AggregateFunction` counterpart; `inverse` is never used on
/// this path because the legacy registration installs no inverse callback.
impl<U, T> LegacyAggregateFunction<U> for T
where
    T: AggregateFunction<U>,
{
    fn default_value(user_data: &U, ctx: &Context) -> Result<()> {
        <Self as AggregateFunction<U>>::default_value(user_data, ctx)
    }

    fn step(&mut self, ctx: &Context, args: &mut [&mut ValueRef]) -> Result<()> {
        <Self as AggregateFunction<U>>::step(self, ctx, args)
    }

    fn value(&self, ctx: &Context) -> Result<()> {
        <Self as AggregateFunction<U>>::value(self, ctx)
    }
}
/// Registration options for SQL functions: the accepted argument count and SQLite flags.
///
/// Built with [`FunctionOptions::default`] and the chained `set_*` methods.
#[derive(Debug, Clone)]
pub struct FunctionOptions {
    // Passed to SQLite as the function's argument count; -1 accepts any number.
    n_args: i32,
    // Bitmask passed to SQLite as the function flags (deterministic / risk level).
    flags: i32,
}
impl Default for FunctionOptions {
fn default() -> Self {
FunctionOptions::default()
}
}
impl FunctionOptions {
    /// Options accepting any number of arguments (`n_args = -1`) with no flags set.
    ///
    /// A `const fn` so options can be built in constant contexts; the [`Default`] impl
    /// delegates here.
    pub const fn default() -> Self {
        Self {
            n_args: -1,
            flags: 0,
        }
    }

    /// Set the number of arguments the function accepts; `-1` means any number.
    ///
    /// # Panics
    ///
    /// Panics with `"n_args invalid"` unless `-1 <= n_args <= 127`; SQLite leaves other
    /// values undefined.
    pub const fn set_n_args(mut self, n_args: i32) -> Self {
        assert!(-1 <= n_args && n_args <= 127, "n_args invalid");
        self.n_args = n_args;
        self
    }

    /// Set (`val == true`) or clear the `SQLITE_DETERMINISTIC` flag.
    pub const fn set_deterministic(mut self, val: bool) -> Self {
        self.flags = if val {
            self.flags | ffi::SQLITE_DETERMINISTIC
        } else {
            self.flags & !ffi::SQLITE_DETERMINISTIC
        };
        self
    }

    /// Mark the function `SQLITE_INNOCUOUS` or `SQLITE_DIRECTONLY`, clearing the other flag.
    ///
    /// When the `modern_sqlite` cfg is not set, this is a no-op and `level` is ignored.
    pub const fn set_risk_level(
        #[cfg_attr(not(modern_sqlite), allow(unused_mut))] mut self,
        level: RiskLevel,
    ) -> Self {
        // Keeps `level` "used" when the cfg'd block below is compiled out.
        let _ = level;
        #[cfg(modern_sqlite)]
        {
            let (add, remove) = match level {
                RiskLevel::Innocuous => (ffi::SQLITE_INNOCUOUS, ffi::SQLITE_DIRECTONLY),
                RiskLevel::DirectOnly => (ffi::SQLITE_DIRECTONLY, ffi::SQLITE_INNOCUOUS),
            };
            self.flags = (self.flags | add) & !remove;
        }
        self
    }
}
// NOTE(review): every method below converts `name` with `CString::from_vec_unchecked`, whose
// contract forbids interior NUL bytes. A `name` containing '\0' violates that contract, and
// SQLite would only see the text before the first NUL. Callers must not pass such names.
impl Connection {
    /// Ensure a global SQL function `name` taking `opts.n_args` arguments exists, so that
    /// virtual tables may overload it (wraps `sqlite3_overload_function`).
    ///
    /// Only `opts.n_args` is used; the flags in `opts` are ignored.
    pub fn create_overloaded_function(&self, name: &str, opts: &FunctionOptions) -> Result<()> {
        let guard = self.lock();
        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
        unsafe {
            Error::from_sqlite_desc(
                ffi::sqlite3_overload_function(self.as_mut_ptr(), name.as_ptr() as _, opts.n_args),
                guard,
            )
        }
    }

    /// Register the closure `func` as a scalar SQL function called `name`.
    ///
    /// The closure is wrapped in the private `ScalarClosure` adapter and registered through
    /// [`Connection::create_scalar_function_object`].
    pub fn create_scalar_function<F>(
        &self,
        name: &str,
        opts: &FunctionOptions,
        func: F,
    ) -> Result<()>
    where
        F: Fn(&mut Context, &mut [&mut ValueRef]) -> Result<()> + 'static,
    {
        self.create_scalar_function_object(name, opts, ScalarClosure(func))
    }

    /// Register `func` as a scalar SQL function called `name` with the given options.
    ///
    /// `func` is boxed and ownership of the allocation passes to SQLite:
    /// - SQLite >= 3.7.3: registered with `sqlite3_create_function_v2` and `drop_boxed` as the
    ///   destructor. SQLite runs the destructor when the function is overridden, when the
    ///   connection closes, and also if registration fails.
    /// - Older SQLite: the legacy API has no destructor. The allocation is reclaimed here if
    ///   registration fails and is otherwise never freed.
    ///
    /// NOTE(review): `F` is not bounded by `'db`/`'static` even though SQLite keeps it after
    /// this call returns; confirm this is intended.
    pub fn create_scalar_function_object<'db, F>(
        &'db self,
        name: &str,
        opts: &FunctionOptions,
        func: F,
    ) -> Result<()>
    where
        F: ScalarFunction<'db>,
    {
        let guard = self.lock();
        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
        let func = Box::into_raw(Box::new(func));
        unsafe {
            let rc = sqlite3_match_version! {
                3_007_003 => ffi::sqlite3_create_function_v2(
                    self.as_mut_ptr(),
                    name.as_ptr() as _,
                    opts.n_args,
                    opts.flags,
                    func as _,
                    Some(stubs::call_scalar::<F>),
                    None,
                    None,
                    Some(ffi::drop_boxed::<F>),
                ),
                _ => {
                    let legacy_rc = ffi::sqlite3_create_function(
                        self.as_mut_ptr(),
                        name.as_ptr() as _,
                        opts.n_args,
                        opts.flags,
                        func as _,
                        Some(stubs::call_scalar::<F>),
                        None,
                        None,
                    );
                    // No destructor exists on this path; on failure SQLite did not keep
                    // the pointer, so reclaim it instead of leaking it.
                    if legacy_rc != ffi::SQLITE_OK {
                        drop(Box::from_raw(func));
                    }
                    legacy_rc
                },
            };
            Error::from_sqlite_desc(rc, guard)
        }
    }

    /// Register `F` as a plain (non-window) aggregate SQL function called `name`.
    ///
    /// `user_data` is boxed and handed to SQLite; the step/final callbacks are
    /// `stubs::aggregate_step::<U, F>` / `stubs::aggregate_final::<U, F>`.
    /// - SQLite >= 3.7.3: `drop_boxed::<U>` is installed as the destructor. SQLite runs it
    ///   when the function is overridden, when the connection closes, or if registration fails.
    /// - Older SQLite: no destructor exists. `user_data` is reclaimed here if registration
    ///   fails and is otherwise never freed.
    ///
    /// NOTE(review): `U` has no `'static` bound although SQLite keeps it after this call.
    pub fn create_legacy_aggregate_function<U, F: LegacyAggregateFunction<U>>(
        &self,
        name: &str,
        opts: &FunctionOptions,
        user_data: U,
    ) -> Result<()> {
        let guard = self.lock();
        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
        let user_data = Box::into_raw(Box::new(user_data));
        unsafe {
            let rc = sqlite3_match_version! {
                3_007_003 => ffi::sqlite3_create_function_v2(
                    self.as_mut_ptr(),
                    name.as_ptr() as _,
                    opts.n_args,
                    opts.flags,
                    user_data as _,
                    None,
                    Some(stubs::aggregate_step::<U, F>),
                    Some(stubs::aggregate_final::<U, F>),
                    Some(ffi::drop_boxed::<U>),
                ),
                _ => {
                    let legacy_rc = ffi::sqlite3_create_function(
                        self.as_mut_ptr(),
                        name.as_ptr() as _,
                        opts.n_args,
                        opts.flags,
                        user_data as _,
                        None,
                        Some(stubs::aggregate_step::<U, F>),
                        Some(stubs::aggregate_final::<U, F>),
                    );
                    // No destructor exists on this path; on failure SQLite did not keep
                    // the pointer, so reclaim it instead of leaking it.
                    if legacy_rc != ffi::SQLITE_OK {
                        drop(Box::from_raw(user_data));
                    }
                    legacy_rc
                },
            };
            Error::from_sqlite_desc(rc, guard)
        }
    }

    /// Register `F` as an aggregate SQL function called `name`.
    ///
    /// On SQLite >= 3.25.0 this uses `sqlite3_create_window_function` and installs step,
    /// final, value and inverse callbacks, so the function can also be used as a window
    /// function. `user_data` is boxed with `drop_boxed::<U>` as its destructor.
    /// On older SQLite it falls back to [`Connection::create_legacy_aggregate_function`].
    pub fn create_aggregate_function<U, F: AggregateFunction<U>>(
        &self,
        name: &str,
        opts: &FunctionOptions,
        user_data: U,
    ) -> Result<()> {
        sqlite3_match_version! {
            3_025_000 => {
                let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
                let user_data = Box::new(user_data);
                let guard = self.lock();
                unsafe {
                    Error::from_sqlite_desc(ffi::sqlite3_create_window_function(
                        self.as_mut_ptr(),
                        name.as_ptr() as _,
                        opts.n_args,
                        opts.flags,
                        Box::into_raw(user_data) as _,
                        Some(stubs::aggregate_step::<U, F>),
                        Some(stubs::aggregate_final::<U, F>),
                        Some(stubs::aggregate_value::<U, F>),
                        Some(stubs::aggregate_inverse::<U, F>),
                        Some(ffi::drop_boxed::<U>),
                    ), guard)
                }
            },
            _ => self.create_legacy_aggregate_function::<U, F>(name, opts, user_data),
        }
    }

    /// Remove the SQL function `name` that takes `n_args` arguments.
    ///
    /// This re-registers the name with null user data and no callbacks. SQLite treats that as
    /// deleting the function and runs any destructor installed when it was registered.
    pub fn remove_function(&self, name: &str, n_args: i32) -> Result<()> {
        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
        let guard = self.lock();
        unsafe {
            Error::from_sqlite_desc(
                ffi::sqlite3_create_function(
                    self.as_mut_ptr(),
                    name.as_ptr() as _,
                    n_args,
                    0,
                    null_mut(),
                    None,
                    None,
                    None,
                ),
                guard,
            )
        }
    }

    /// Register `func` as the UTF-8 collating sequence `name`.
    ///
    /// `func` compares two strings and returns their [`Ordering`]. It is boxed with
    /// `drop_boxed::<F>` as its destructor. `sqlite3_create_collation_v2` does not run that
    /// destructor when it fails, so the closure is dropped here in that case.
    ///
    /// NOTE(review): `F` has no `'static` bound even though SQLite keeps calling it after
    /// this method returns (`create_scalar_function` does require `'static`); confirm.
    pub fn create_collation<F: Fn(&str, &str) -> Ordering>(
        &self,
        name: &str,
        func: F,
    ) -> Result<()> {
        let name = unsafe { CString::from_vec_unchecked(name.as_bytes().into()) };
        let func = Box::into_raw(Box::new(func));
        let guard = self.lock();
        unsafe {
            let rc = ffi::sqlite3_create_collation_v2(
                self.as_mut_ptr(),
                name.as_ptr() as _,
                ffi::SQLITE_UTF8,
                func as _,
                Some(stubs::compare::<F>),
                Some(ffi::drop_boxed::<F>),
            );
            if rc != ffi::SQLITE_OK {
                drop(Box::from_raw(func));
            }
            Error::from_sqlite_desc(rc, guard)
        }
    }

    /// Install `func` as this connection's collation-needed callback
    /// (`sqlite3_collation_needed`).
    ///
    /// `func` receives a `&str`, presumably the name of the missing collation (confirm in
    /// `stubs::collation_needed`). The closure is handed to SQLite, which has no destructor
    /// hook for it. If installation fails, the closure is reclaimed here.
    ///
    /// NOTE(review): installing a new callback does not free a previously installed closure,
    /// because nothing here keeps a handle to it. `F` also lacks a `'static` bound; confirm.
    pub fn set_collation_needed_func<F: Fn(&str)>(&self, func: F) -> Result<()> {
        let func = Box::into_raw(Box::new(func));
        let guard = self.lock();
        unsafe {
            let rc = ffi::sqlite3_collation_needed(
                self.as_mut_ptr(),
                func as _,
                Some(stubs::collation_needed::<F>),
            );
            // On failure SQLite did not store the pointer, so ownership is still ours.
            if rc != ffi::SQLITE_OK {
                drop(Box::from_raw(func));
            }
            Error::from_sqlite_desc(rc, guard)
        }
    }
}