use libduckdb_sys::{duckdb_aggregate_state, duckdb_function_info, idx_t};
/// Marker trait for the per-group state of a DuckDB aggregate function.
///
/// `FfiState::init_callback` builds each new state with `Default::default()`;
/// implementors must additionally be `Send` and `'static`.
pub trait AggregateState
where
    Self: Default + Send + 'static,
{
}
/// C-layout wrapper stored by DuckDB as the per-group aggregate state.
///
/// It holds only a pointer to a separately heap-allocated `T`. The pointer is set
/// by `init_callback` and reset to null by `destroy_callback`.
#[repr(C)]
pub struct FfiState<T>
where
    T: AggregateState,
{
    /// Owning pointer to the boxed user state, or null when no state is attached.
    pub inner: *mut T,
}
impl<T: AggregateState> FfiState<T> {
    /// Number of bytes DuckDB must reserve for one aggregate state.
    ///
    /// This is the size of the `FfiState` wrapper (a single pointer), not of `T`:
    /// the user state lives in its own heap allocation.
    #[inline]
    #[must_use]
    pub const fn size() -> usize {
        core::mem::size_of::<Self>()
    }

    /// State-size callback: returns [`FfiState::size`] as an `idx_t`.
    ///
    /// # Safety
    ///
    /// `_info` is never read, so any value (including null) may be passed. The
    /// function is `unsafe extern "C"` only to match the callback signature.
    pub const unsafe extern "C" fn size_callback(_info: duckdb_function_info) -> idx_t {
        // Widening the size of one pointer to `idx_t` cannot truncate; `TryFrom`
        // is not available in a `const fn`, hence the `as`.
        Self::size() as idx_t
    }

    /// Init callback: allocates a fresh `T::default()` and stores its pointer in
    /// the slot at `state`.
    ///
    /// The slot is written through a raw pointer with `ptr::write` instead of
    /// through a `&mut Self`. This stays sound even if the slot's bytes are still
    /// uninitialized. That is presumably the case for memory DuckDB has just
    /// reserved; confirm against the DuckDB executor. The old contents are neither
    /// read nor dropped.
    ///
    /// # Safety
    ///
    /// `state` must be non-null, aligned for `FfiState<T>`, and valid for writes of
    /// [`FfiState::size`] bytes. If the slot already holds a live state, that state
    /// is leaked. `T::default()` should not panic, because a panic cannot unwind
    /// across this `extern "C"` boundary.
    pub unsafe extern "C" fn init_callback(
        _info: duckdb_function_info,
        state: duckdb_aggregate_state,
    ) {
        let inner = Box::into_raw(Box::<T>::default());
        // SAFETY: the caller guarantees `state` is valid for an aligned write of
        // `Self`; `write` does not read or drop the previous (possibly
        // uninitialized) contents.
        unsafe { state.cast::<Self>().write(Self { inner }) };
    }

    /// Destroy callback: drops the boxed `T` behind each of the `count` states.
    ///
    /// Each `inner` pointer is swapped to null before its box is dropped. Destroying
    /// an already-destroyed state (or one whose `inner` is null) therefore does
    /// nothing. The call also does nothing when `states` is null, when an entry in
    /// the array is null, or when `count` does not fit in `usize`.
    ///
    /// # Safety
    ///
    /// Unless `states` is null, it must point to `count` readable
    /// `duckdb_aggregate_state` entries. Each non-null entry must point to an
    /// initialized `FfiState<T>` whose `inner` is null or was produced by
    /// [`FfiState::init_callback`] and not freed elsewhere. Dropping `T` should not
    /// panic, because a panic cannot unwind across this `extern "C"` boundary.
    pub unsafe extern "C" fn destroy_callback(states: *mut duckdb_aggregate_state, count: idx_t) {
        if states.is_null() {
            return;
        }
        // A count that does not fit in `usize` cannot describe an in-memory array.
        let Ok(len) = usize::try_from(count) else {
            return;
        };
        // SAFETY: the caller guarantees `states` points to `count` readable entries.
        let slots = unsafe { core::slice::from_raw_parts(states, len) };
        for &state in slots {
            if state.is_null() {
                continue;
            }
            // SAFETY: a non-null entry points to an initialized `FfiState<T>`
            // (caller contract). Taking the pointer and leaving null behind means
            // the box is freed at most once, even if destroy runs again.
            let inner = unsafe {
                core::ptr::replace(
                    core::ptr::addr_of_mut!((*state.cast::<Self>()).inner),
                    core::ptr::null_mut(),
                )
            };
            if !inner.is_null() {
                // SAFETY: a non-null `inner` came from `Box::into_raw` in
                // `init_callback`, and ownership is reclaimed exactly once here.
                drop(unsafe { Box::from_raw(inner) });
            }
        }
    }

    /// Returns a mutable reference to the user state stored at `state`.
    ///
    /// Returns `None` when no state is attached (`inner` is null, e.g. after
    /// [`FfiState::destroy_callback`]).
    ///
    /// # Safety
    ///
    /// `state` must point to an initialized `FfiState<T>`. The caller picks the
    /// lifetime `'a`. For all of `'a`, the boxed `T` must stay allocated and no
    /// other reference to it may be used.
    pub unsafe fn with_state_mut<'a>(state: duckdb_aggregate_state) -> Option<&'a mut T> {
        // SAFETY: `state` points to an initialized `FfiState<T>` (caller contract);
        // only the pointer field is copied out, no reference to the wrapper is made.
        let inner = unsafe { (*state.cast::<Self>()).inner };
        // SAFETY: `as_mut` maps null to `None`. A non-null `inner` is a live `T`
        // that the caller guarantees is not otherwise borrowed during `'a`.
        unsafe { inner.as_mut() }
    }

    /// Returns a shared reference to the user state stored at `state`.
    ///
    /// Returns `None` when no state is attached (`inner` is null).
    ///
    /// # Safety
    ///
    /// `state` must point to an initialized `FfiState<T>`. The caller picks the
    /// lifetime `'a`. For all of `'a`, the boxed `T` must stay allocated and must
    /// not be mutated through any other reference.
    pub unsafe fn with_state<'a>(state: duckdb_aggregate_state) -> Option<&'a T> {
        // SAFETY: `state` points to an initialized `FfiState<T>` (caller contract).
        let inner = unsafe { (*state.cast::<Self>()).inner };
        // SAFETY: `as_ref` maps null to `None`. A non-null `inner` is a live `T`
        // that is not mutated during `'a`.
        unsafe { inner.as_ref() }
    }
}
#[cfg(test)]
mod tests {
    //! Lifecycle tests for `FfiState`, run on wrappers owned by the test itself
    //! rather than on memory provided by a running DuckDB instance.

    use super::*;

    /// Smallest useful aggregate state: a single counter.
    #[derive(Default, Debug, PartialEq)]
    struct Counter {
        value: u64,
    }

    impl AggregateState for Counter {}

    /// Builds a wrapper that has no user state attached yet.
    fn detached() -> FfiState<Counter> {
        FfiState {
            inner: core::ptr::null_mut(),
        }
    }

    #[test]
    fn ffi_state_is_pointer_sized() {
        let wrapper_bytes = core::mem::size_of::<FfiState<Counter>>();
        let pointer_bytes = core::mem::size_of::<*mut Counter>();
        assert_eq!(wrapper_bytes, pointer_bytes);
    }

    #[test]
    fn size_returns_pointer_size() {
        let expected = core::mem::size_of::<usize>();
        assert_eq!(FfiState::<Counter>::size(), expected);
    }

    #[test]
    fn init_and_destroy_lifecycle() {
        let mut wrapper = detached();
        let handle: duckdb_aggregate_state = std::ptr::addr_of_mut!(wrapper).cast();

        // `init_callback` must attach a freshly allocated Counter.
        unsafe { FfiState::<Counter>::init_callback(core::ptr::null_mut(), handle) };
        assert!(!wrapper.inner.is_null());

        // Write through the mutable accessor...
        let counter = unsafe { FfiState::<Counter>::with_state_mut(handle) }
            .expect("an initialized state must expose its Counter");
        counter.value = 42;

        // ...and read the same value back through the shared accessor.
        let observed = unsafe { FfiState::<Counter>::with_state(handle) }.map(|c| c.value);
        assert_eq!(observed, Some(42));

        // `destroy_callback` frees the Counter and clears the pointer.
        let mut handles = [handle];
        unsafe { FfiState::<Counter>::destroy_callback(handles.as_mut_ptr(), 1) };
        assert!(wrapper.inner.is_null());
    }

    #[test]
    fn destroy_null_inner_is_noop() {
        let mut wrapper = detached();
        let mut handles: [duckdb_aggregate_state; 1] = [std::ptr::addr_of_mut!(wrapper).cast()];
        unsafe { FfiState::<Counter>::destroy_callback(handles.as_mut_ptr(), 1) };
        assert!(wrapper.inner.is_null());
    }

    #[test]
    fn with_state_mut_null_inner_returns_none() {
        let mut wrapper = detached();
        let handle: duckdb_aggregate_state = std::ptr::addr_of_mut!(wrapper).cast();
        assert!(unsafe { FfiState::<Counter>::with_state_mut(handle) }.is_none());
    }

    #[test]
    fn with_state_null_inner_returns_none() {
        let wrapper = detached();
        let handle: duckdb_aggregate_state = std::ptr::addr_of!(wrapper).cast_mut().cast();
        assert!(unsafe { FfiState::<Counter>::with_state(handle) }.is_none());
    }

    #[test]
    fn size_callback_returns_pointer_size() {
        let reported = unsafe { FfiState::<Counter>::size_callback(core::ptr::null_mut()) };
        let reported = usize::try_from(reported).unwrap();
        assert_eq!(reported, core::mem::size_of::<usize>());
    }

    #[test]
    fn multiple_state_destroy() {
        const COUNT: idx_t = 4;

        let mut wrappers: Vec<FfiState<Counter>> = (0..COUNT).map(|_| detached()).collect();
        let mut handles: Vec<duckdb_aggregate_state> = wrappers
            .iter_mut()
            .map(|w| -> duckdb_aggregate_state { std::ptr::from_mut(w).cast() })
            .collect();

        for &handle in &handles {
            unsafe { FfiState::<Counter>::init_callback(core::ptr::null_mut(), handle) };
        }
        assert!(wrappers.iter().all(|w| !w.inner.is_null()));

        unsafe { FfiState::<Counter>::destroy_callback(handles.as_mut_ptr(), COUNT) };
        assert!(wrappers.iter().all(|w| w.inner.is_null()));
    }
}