#![allow(clippy::not_unsafe_ptr_arg_deref)]
use std::{
ffi::CString,
os::raw::{c_int, c_void},
slice,
};
use crate::{
api,
constants::{SQLITE_INTERNAL, SQLITE_OKAY},
errors::{Error, ErrorKind, Result},
ext::sqlite3ext_create_function_v2,
};
use sqlite3ext_sys::{sqlite3, sqlite3_context, sqlite3_user_data, sqlite3_value};
use bitflags::bitflags;
use sqlite3ext_sys::{
SQLITE_DETERMINISTIC, SQLITE_DIRECTONLY, SQLITE_INNOCUOUS, SQLITE_SUBTYPE, SQLITE_UTF16,
SQLITE_UTF16BE, SQLITE_UTF16LE, SQLITE_UTF8,
};
bitflags! {
/// Flags passed to SQLite when a scalar function is registered or deleted.
///
/// Every constant is the identically named `SQLITE_*` constant from
/// `sqlite3ext_sys`, cast to `i32`. The functions in this file forward the
/// combined `bits` unchanged as the fourth argument of
/// `sqlite3ext_create_function_v2`. See the SQLite documentation for the
/// meaning of each individual flag.
///
/// NOTE(review): in SQLite's headers the text-encoding constants
/// (`UTF8`, `UTF16LE`, `UTF16BE`, `UTF16`) appear to be small enumerated
/// values rather than independent bits, so OR-ing two encodings together may
/// not mean what it looks like — confirm before combining them.
pub struct FunctionFlags: i32 {
/// Mirrors `SQLITE_UTF8`.
const UTF8 = SQLITE_UTF8 as i32;
/// Mirrors `SQLITE_UTF16LE`.
const UTF16LE = SQLITE_UTF16LE as i32;
/// Mirrors `SQLITE_UTF16BE`.
const UTF16BE = SQLITE_UTF16BE as i32;
/// Mirrors `SQLITE_UTF16`.
const UTF16 = SQLITE_UTF16 as i32;
/// Mirrors `SQLITE_DETERMINISTIC`.
const DETERMINISTIC = SQLITE_DETERMINISTIC as i32;
/// Mirrors `SQLITE_DIRECTONLY`.
const DIRECTONLY = SQLITE_DIRECTONLY as i32;
/// Mirrors `SQLITE_SUBTYPE`.
const SUBTYPE = SQLITE_SUBTYPE as i32;
/// Mirrors `SQLITE_INNOCUOUS`.
const INNOCUOUS = SQLITE_INNOCUOUS as i32;
}
}
/// Registers `x_func` as a scalar SQL function named `name` on `db`.
///
/// The closure is moved onto the heap and handed to SQLite as the function's
/// user-data pointer. A destructor callback is registered alongside it, so
/// SQLite releases the closure when the function is deleted, overloaded, the
/// connection closes, or the registration itself fails.
///
/// # Parameters
/// * `db` - raw SQLite connection handle the function is registered on.
/// * `name` - SQL name of the function; must not contain an interior NUL byte.
/// * `num_args` - argument count, forwarded unchanged to SQLite.
/// * `x_func` - callback invoked with the call context and the argument values.
/// * `func_flags` - encoding/behaviour flags, forwarded as raw bits.
///
/// # Errors
/// Returns an error if `name` cannot be converted to a C string, or
/// `ErrorKind::DefineScalarFunction` carrying SQLite's result code when the
/// registration call does not return `SQLITE_OKAY`.
pub fn define_scalar_function<F>(
    db: *mut sqlite3,
    name: &str,
    num_args: c_int,
    x_func: F,
    func_flags: FunctionFlags,
) -> Result<()>
where
    F: Fn(*mut sqlite3_context, &[*mut sqlite3_value]) -> Result<()>,
{
    // Trampoline SQLite calls for every invocation of the SQL function.
    unsafe extern "C" fn x_func_wrapper<F>(
        context: *mut sqlite3_context,
        argc: c_int,
        argv: *mut *mut sqlite3_value,
    ) where
        F: Fn(*mut sqlite3_context, &[*mut sqlite3_value]) -> Result<()>,
    {
        // The user-data pointer is the `Box<F>` leaked at registration time.
        let boxed_function: *mut F = sqlite3_user_data(context).cast::<F>();
        // `slice::from_raw_parts` requires a non-null pointer even for empty
        // slices, so fall back to a static empty slice when there is nothing
        // to view.
        let args: &[*mut sqlite3_value] = if argc > 0 && !argv.is_null() {
            slice::from_raw_parts(argv, argc as usize)
        } else {
            &[]
        };
        if let Err(e) = (*boxed_function)(context, args) {
            // If even reporting the message fails, fall back to a bare code.
            if api::result_error(context, &e.result_error_message()).is_err() {
                api::result_error_code(context, SQLITE_INTERNAL);
            }
        }
    }

    // Destructor SQLite calls exactly once to release the boxed closure.
    unsafe extern "C" fn x_destroy<F>(p_app: *mut c_void) {
        if !p_app.is_null() {
            drop(Box::from_raw(p_app.cast::<F>()));
        }
    }

    // Validate the name first: an early return here must not leak the closure.
    let cname = CString::new(name)?;
    let function_pointer: *mut F = Box::into_raw(Box::new(x_func));

    // SAFETY: `cname` outlives the call, `function_pointer` is a valid leaked
    // `Box<F>` matching the type both callbacks cast back to, and ownership of
    // it passes to SQLite, which invokes `x_destroy` even when this call fails.
    let result = unsafe {
        sqlite3ext_create_function_v2(
            db,
            cname.as_ptr(),
            num_args,
            func_flags.bits,
            function_pointer.cast::<c_void>(),
            Some(x_func_wrapper::<F>),
            None,
            None,
            Some(x_destroy::<F>),
        )
    };

    if result == SQLITE_OKAY {
        Ok(())
    } else {
        Err(Error::new(ErrorKind::DefineScalarFunction(result)))
    }
}
/// Removes a previously registered scalar SQL function from `db`.
///
/// This re-issues the registration call with a null user-data pointer and no
/// callbacks, which is how SQLite's `sqlite3_create_function_v2` API deletes a
/// function. `name`, `num_args` and `func_flags` are forwarded unchanged, so
/// they should match the values used when the function was defined.
///
/// # Errors
/// Returns an error if `name` cannot be converted to a C string, or
/// `ErrorKind::DefineScalarFunction` carrying SQLite's result code when the
/// call does not return `SQLITE_OKAY`. The same error kind as registration is
/// reused for deletion.
pub fn delete_scalar_function(
    db: *mut sqlite3,
    name: &str,
    num_args: c_int,
    func_flags: FunctionFlags,
) -> Result<()> {
    let cname = CString::new(name)?;

    // SAFETY: `cname` is a valid NUL-terminated string that outlives the
    // call; all callback slots and the user-data pointer are null.
    let result = unsafe {
        sqlite3ext_create_function_v2(
            db,
            cname.as_ptr(),
            num_args,
            func_flags.bits,
            std::ptr::null_mut(),
            None,
            None,
            None,
            None,
        )
    };

    // The failure is reported only through the returned error; library code
    // must not write diagnostics to the host process's stdout.
    if result == SQLITE_OKAY {
        Ok(())
    } else {
        Err(Error::new(ErrorKind::DefineScalarFunction(result)))
    }
}
/// Registers `x_func` as a scalar SQL function named `name` on `db`, passing
/// a shared reference to `aux` on every invocation.
///
/// The closure and `aux` are moved together into a single heap allocation
/// that is handed to SQLite as the function's user-data pointer. A destructor
/// callback is registered alongside it, so SQLite releases both when the
/// function is deleted, overloaded, the connection closes, or the
/// registration itself fails.
///
/// # Parameters
/// * `db` - raw SQLite connection handle the function is registered on.
/// * `name` - SQL name of the function; must not contain an interior NUL byte.
/// * `num_args` - argument count, forwarded unchanged to SQLite.
/// * `x_func` - callback invoked with the call context, the argument values
///   and a reference to `aux`.
/// * `func_flags` - encoding/behaviour flags, forwarded as raw bits.
/// * `aux` - auxiliary value owned by the registration and lent to `x_func`.
///
/// # Errors
/// Returns an error if `name` cannot be converted to a C string, or
/// `ErrorKind::DefineScalarFunction` carrying SQLite's result code when the
/// registration call does not return `SQLITE_OKAY`.
pub fn define_scalar_function_with_aux<F, T>(
    db: *mut sqlite3,
    name: &str,
    num_args: c_int,
    x_func: F,
    func_flags: FunctionFlags,
    aux: T,
) -> Result<()>
where
    F: Fn(*mut sqlite3_context, &[*mut sqlite3_value], &T) -> Result<()>,
{
    // Trampoline SQLite calls for every invocation of the SQL function.
    unsafe extern "C" fn x_func_wrapper<F, T>(
        context: *mut sqlite3_context,
        argc: c_int,
        argv: *mut *mut sqlite3_value,
    ) where
        F: Fn(*mut sqlite3_context, &[*mut sqlite3_value], &T) -> Result<()>,
    {
        // The user-data pointer is the `Box<(F, T)>` leaked at registration.
        // Only borrow it: ownership stays with SQLite until `x_destroy` runs,
        // so nothing is freed here even if the callback unwinds.
        let state = sqlite3_user_data(context).cast::<(F, T)>();
        let (callback, aux) = &*state;
        // `slice::from_raw_parts` requires a non-null pointer even for empty
        // slices, so fall back to a static empty slice when there is nothing
        // to view.
        let args: &[*mut sqlite3_value] = if argc > 0 && !argv.is_null() {
            slice::from_raw_parts(argv, argc as usize)
        } else {
            &[]
        };
        if let Err(e) = callback(context, args, aux) {
            // If even reporting the message fails, fall back to a bare code.
            if api::result_error(context, &e.result_error_message()).is_err() {
                api::result_error_code(context, SQLITE_INTERNAL);
            }
        }
    }

    // Destructor SQLite calls exactly once to release the closure and `aux`.
    unsafe extern "C" fn x_destroy<F, T>(p_app: *mut c_void) {
        if !p_app.is_null() {
            drop(Box::from_raw(p_app.cast::<(F, T)>()));
        }
    }

    // Validate the name first: an early return here must not leak anything.
    let cname = CString::new(name)?;
    let app_pointer: *mut (F, T) = Box::into_raw(Box::new((x_func, aux)));

    // SAFETY: `cname` outlives the call, `app_pointer` is a valid leaked
    // `Box<(F, T)>` matching the type both callbacks cast back to, and
    // ownership of it passes to SQLite, which invokes `x_destroy` even when
    // this call fails.
    let result = unsafe {
        sqlite3ext_create_function_v2(
            db,
            cname.as_ptr(),
            num_args,
            func_flags.bits,
            app_pointer.cast::<c_void>(),
            Some(x_func_wrapper::<F, T>),
            None,
            None,
            Some(x_destroy::<F, T>),
        )
    };

    if result == SQLITE_OKAY {
        Ok(())
    } else {
        Err(Error::new(ErrorKind::DefineScalarFunction(result)))
    }
}