use std::os::raw::c_void;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::sync::Mutex;
use libduckdb_sys::{
duckdb_bind_info, duckdb_data_chunk, duckdb_data_chunk_set_size, duckdb_function_info,
duckdb_init_info,
};
use crate::data_chunk::DataChunk;
use crate::error::ExtensionError;
use crate::table::bind_data::FfiBindData;
use crate::table::builder::TableFunctionBuilder;
use crate::table::info::{BindInfo, FunctionInfo, InitInfo};
use crate::table::init_data::FfiInitData;
use crate::types::{LogicalType, TypeId};
/// Type-erased bind closure: receives the bind-time [`BindInfo`] and either produces the
/// state value `S` or fails with an [`ExtensionError`].
type BindClosure<S> = dyn Fn(&BindInfo) -> Result<S, ExtensionError> + Send + Sync + 'static;
/// Type-erased scan closure: receives mutable access to the state `S` together with the
/// output [`DataChunk`], and reports failure through an [`ExtensionError`].
type ScanClosure<S> =
dyn Fn(&mut S, &DataChunk) -> Result<(), ExtensionError> + Send + Sync + 'static;
/// The pair of user closures behind one typed table function.
///
/// `TypedTableFunctionBuilder::build` boxes a value of this type, leaks it with
/// `Box::into_raw`, and hands the pointer to the inner builder's `extra_info`; the bind and
/// scan trampolines read it back through `get_extra_info`.
struct TypedCallbacks<S: Send + 'static> {
/// Closure run by the bind trampoline to create the state `S`.
bind: Box<BindClosure<S>>,
/// Closure run by the scan trampoline with the state and the output chunk.
scan: Box<ScanClosure<S>>,
}
impl<S: Send + 'static> TypedCallbacks<S> {
    /// Destructor for a `TypedCallbacks<S>` that was leaked with `Box::into_raw`.
    ///
    /// A null pointer is accepted and ignored; any other pointer is turned back into a
    /// `Box` and dropped, which releases both closures.
    ///
    /// # Safety
    ///
    /// `ptr` must be null, or a pointer obtained from `Box::into_raw` on a
    /// `Box<TypedCallbacks<S>>` with this exact `S` that has not been freed already.
    unsafe extern "C" fn destroy_extra(ptr: *mut c_void) {
        let callbacks = ptr.cast::<Self>();
        if !callbacks.is_null() {
            // SAFETY: non-null, and by this function's contract the pointer came from
            // `Box::into_raw` for this same type and is being released exactly once.
            let owned = unsafe { Box::from_raw(callbacks) };
            drop(owned);
        }
    }
}
/// Builder that pairs a [`TableFunctionBuilder`] with typed bind and scan closures sharing
/// a state value of type `S`.
///
/// Created by [`TableFunctionBuilder::with_state`] and turned back into a plain
/// [`TableFunctionBuilder`] by `build`.
#[must_use]
pub struct TypedTableFunctionBuilder<S: Send + 'static> {
/// The untyped builder that receives parameters, flags, and finally the trampolines.
inner: TableFunctionBuilder,
/// Bind closure; set by `with_state`, required by `build`.
bind: Option<Box<BindClosure<S>>>,
/// Scan closure; `None` until `scan` is called, required by `build`.
scan: Option<Box<ScanClosure<S>>>,
}
impl TableFunctionBuilder {
pub fn with_state<S, F>(self, bind: F) -> TypedTableFunctionBuilder<S>
where
S: Send + 'static,
F: Fn(&BindInfo) -> Result<S, ExtensionError> + Send + Sync + 'static,
{
TypedTableFunctionBuilder {
inner: self,
bind: Some(Box::new(bind)),
scan: None,
}
}
}
impl<S: Send + 'static> TypedTableFunctionBuilder<S> {
pub fn scan<F>(mut self, f: F) -> Self
where
F: Fn(&mut S, &DataChunk) -> Result<(), ExtensionError> + Send + Sync + 'static,
{
self.scan = Some(Box::new(f));
self
}
pub fn name(&self) -> &str {
self.inner.name()
}
pub fn param(mut self, type_id: TypeId) -> Self {
self.inner = self.inner.param(type_id);
self
}
pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
self.inner = self.inner.param_logical(logical_type);
self
}
pub fn named_param(mut self, name: &str, type_id: TypeId) -> Self {
self.inner = self.inner.named_param(name, type_id);
self
}
pub fn named_param_logical(mut self, name: &str, logical_type: LogicalType) -> Self {
self.inner = self.inner.named_param_logical(name, logical_type);
self
}
pub fn projection_pushdown(mut self, enable: bool) -> Self {
self.inner = self.inner.projection_pushdown(enable);
self
}
pub fn build(self) -> Result<TableFunctionBuilder, ExtensionError> {
let bind = self
.bind
.ok_or_else(|| ExtensionError::new("typed table function: bind closure not set"))?;
let scan = self
.scan
.ok_or_else(|| ExtensionError::new("typed table function: scan closure not set"))?;
let cbs = Box::new(TypedCallbacks::<S> { bind, scan });
let raw = Box::into_raw(cbs).cast::<c_void>();
let builder = unsafe {
self.inner
.bind(typed_bind_trampoline::<S>)
.init(typed_init_trampoline::<S>)
.scan(typed_scan_trampoline::<S>)
.extra_info(raw, TypedCallbacks::<S>::destroy_extra)
};
Ok(builder)
}
}
/// Chooses the fixed error text reported when a user closure panics.
///
/// Payloads that are a `&'static str` or a `String` (what `panic!` produces for message
/// panics) get the plain message; every other payload type gets the "unknown payload"
/// variant. The payload's content itself is never included in the returned text.
fn panic_message(payload: &(dyn std::any::Any + Send)) -> &'static str {
    let is_textual = payload.is::<&'static str>() || payload.is::<String>();
    if !is_textual {
        return "quack-rs: typed table function closure panicked (unknown payload)";
    }
    "quack-rs: typed table function closure panicked"
}
/// Bind callback installed by `TypedTableFunctionBuilder::build`.
///
/// Reads the [`TypedCallbacks`] pointer from the function's extra info, runs the user's
/// bind closure, and on success stores the returned state as bind data wrapped in
/// `Mutex<Option<S>>` so the init trampoline can take it out later. A closure error, a
/// missing extra-info pointer, or a panic is reported through `set_error` instead of
/// unwinding out of this function.
///
/// # Safety
///
/// `info` must be a bind-info handle that is valid for the duration of the call, and the
/// function's extra info must be null or a pointer to a live `TypedCallbacks<S>` with this
/// same `S`.
unsafe extern "C" fn typed_bind_trampoline<S: Send + 'static>(info: duckdb_bind_info) {
    let body = AssertUnwindSafe(|| {
        let binder = unsafe { BindInfo::new(info) };
        let extra = unsafe { binder.get_extra_info() };
        if extra.is_null() {
            binder.set_error("quack-rs: typed table function missing extra_info");
            return;
        }
        // SAFETY: non-null extra info is the `TypedCallbacks<S>` installed by `build`.
        let callbacks = unsafe { &*extra.cast::<TypedCallbacks<S>>() };

        let state = match (callbacks.bind)(&binder) {
            Ok(state) => state,
            Err(err) => {
                binder.set_error(err.as_str());
                return;
            }
        };
        // The state is parked behind `Mutex<Option<_>>` so init can move it out once.
        unsafe {
            FfiBindData::<Mutex<Option<S>>>::set(info, Mutex::new(Some(state)));
        }
    });

    if let Err(payload) = catch_unwind(body) {
        let binder = unsafe { BindInfo::new(info) };
        binder.set_error(panic_message(&*payload));
    }
}
/// Init callback installed by `TypedTableFunctionBuilder::build`.
///
/// Moves the state out of the `Mutex<Option<S>>` bind data and stores it as init data,
/// then limits the scan to one thread. Missing bind data, a poisoned mutex, an already
/// emptied slot, or a panic is reported through `set_error`.
///
/// # Safety
///
/// `info` must be an init-info handle that is valid for the duration of the call, and any
/// bind data attached to it must have been stored as `Mutex<Option<S>>` with this same `S`.
unsafe extern "C" fn typed_init_trampoline<S: Send + 'static>(info: duckdb_init_info) {
    let body = AssertUnwindSafe(|| {
        let init = unsafe { InitInfo::new(info) };

        let slot = unsafe { FfiBindData::<Mutex<Option<S>>>::get_from_init(info) };
        let Some(slot) = slot else {
            init.set_error("quack-rs: typed table function missing bind state");
            return;
        };

        // Take the state out while holding the lock; the guard is released right after.
        let moved_out = match slot.lock() {
            Ok(mut guard) => guard.take(),
            Err(_) => {
                init.set_error("quack-rs: typed table function bind-state mutex poisoned");
                return;
            }
        };
        let Some(state) = moved_out else {
            init.set_error("quack-rs: typed table function bind state already consumed");
            return;
        };

        unsafe {
            FfiInitData::<S>::set(info, state);
        }
        init.set_max_threads(1);
    });

    if let Err(payload) = catch_unwind(body) {
        let init = unsafe { InitInfo::new(info) };
        init.set_error(panic_message(&*payload));
    }
}
/// Scan callback installed by `TypedTableFunctionBuilder::build`.
///
/// Fetches the [`TypedCallbacks`] from the extra info and the state from the init data,
/// then runs the user's scan closure against the output chunk. On any failure (missing
/// extra info, missing state, closure error, or panic) the error is reported through
/// `set_error` and the output chunk's size is set to 0.
///
/// # Safety
///
/// `info` and `output` must be handles that are valid for the duration of the call; the
/// extra info must be null or a live `TypedCallbacks<S>`, and any init data must have been
/// stored as `S`, both with this same `S`.
unsafe extern "C" fn typed_scan_trampoline<S: Send + 'static>(
    info: duckdb_function_info,
    output: duckdb_data_chunk,
) {
    let body = AssertUnwindSafe(|| {
        let function = unsafe { FunctionInfo::new(info) };
        // Shared failure path: report the message and hand back an empty chunk.
        let fail = |message: &str| {
            function.set_error(message);
            unsafe { duckdb_data_chunk_set_size(output, 0) };
        };

        let extra = unsafe { function.get_extra_info() };
        if extra.is_null() {
            fail("quack-rs: typed table function missing extra_info");
            return;
        }
        // SAFETY: non-null extra info is the `TypedCallbacks<S>` installed by `build`.
        let callbacks = unsafe { &*extra.cast::<TypedCallbacks<S>>() };

        let scan_state = unsafe { FfiInitData::<S>::get_mut(info) };
        let Some(scan_state) = scan_state else {
            fail("quack-rs: typed table function missing scan state");
            return;
        };

        let chunk = unsafe { DataChunk::from_raw(output) };
        match (callbacks.scan)(scan_state, &chunk) {
            Ok(()) => {}
            Err(err) => fail(err.as_str()),
        }
    });

    if let Err(payload) = catch_unwind(body) {
        let function = unsafe { FunctionInfo::new(info) };
        function.set_error(panic_message(&*payload));
        unsafe { duckdb_data_chunk_set_size(output, 0) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Minimal state type used to instantiate the generic builder in tests.
    struct DummyState {
        _rows: u64,
    }

    /// `with_state` keeps the name, stores the bind closure, and leaves scan unset.
    #[test]
    fn with_state_produces_typed_builder() {
        let typed = TableFunctionBuilder::new("demo")
            .with_state::<DummyState, _>(|_bind| Ok(DummyState { _rows: 10 }));
        assert_eq!(typed.name(), "demo");
        assert!(typed.bind.is_some());
        assert!(typed.scan.is_none());
    }

    /// `build` must refuse a builder that never received a scan closure.
    #[test]
    fn build_without_scan_errors() {
        let typed = TableFunctionBuilder::new("demo")
            .with_state::<DummyState, _>(|_bind| Ok(DummyState { _rows: 10 }));
        match typed.build() {
            Err(e) => assert!(e.as_str().contains("scan closure not set")),
            Ok(_) => panic!("expected error"),
        }
    }

    /// With both closures present, `build` returns the inner builder.
    #[test]
    fn build_with_bind_and_scan_succeeds() {
        let typed = TableFunctionBuilder::new("demo")
            .param(TypeId::BigInt)
            .with_state::<DummyState, _>(|_bind| Ok(DummyState { _rows: 10 }))
            .scan(|_state, _chunk| Ok(()));
        let builder = typed.build().expect("build should succeed");
        assert_eq!(builder.name(), "demo");
    }

    /// The pass-through methods can be chained on the typed builder.
    #[test]
    fn passthroughs_mutate_inner_builder() {
        let typed = TableFunctionBuilder::new("demo")
            .with_state::<DummyState, _>(|_| Ok(DummyState { _rows: 0 }))
            .param(TypeId::Varchar)
            .named_param("path", TypeId::Varchar)
            .projection_pushdown(true);
        assert_eq!(typed.name(), "demo");
    }

    /// A null pointer must be ignored rather than dereferenced.
    #[test]
    fn destroy_extra_null_is_noop() {
        unsafe {
            TypedCallbacks::<DummyState>::destroy_extra(std::ptr::null_mut());
        }
    }

    /// `destroy_extra` must actually drop the boxed closures. The bind closure owns a
    /// clone of an `Arc`, so the strong count reveals whether the closure was released.
    #[test]
    fn destroy_extra_drops_box() {
        let token = Arc::new(());
        let captured = Arc::clone(&token);
        let cbs: Box<TypedCallbacks<DummyState>> = Box::new(TypedCallbacks {
            bind: Box::new(move |_| {
                // Touch `captured` so the closure takes ownership of the clone.
                let _ = Arc::strong_count(&captured);
                Ok(DummyState { _rows: 0 })
            }),
            scan: Box::new(|_, _| Ok(())),
        });
        assert_eq!(Arc::strong_count(&token), 2);

        let raw = Box::into_raw(cbs).cast::<c_void>();
        // SAFETY: `raw` comes from `Box::into_raw` on a `TypedCallbacks<DummyState>` and is
        // released exactly once here.
        unsafe { TypedCallbacks::<DummyState>::destroy_extra(raw) };

        assert_eq!(Arc::strong_count(&token), 1);
    }

    /// String-like payloads get the plain message; anything else is "unknown payload".
    #[test]
    fn panic_message_classifies_known_payloads() {
        let s: Box<dyn std::any::Any + Send> = Box::new("boom");
        assert!(panic_message(&*s).contains("panicked"));
        let s: Box<dyn std::any::Any + Send> = Box::new(String::from("boom"));
        assert!(panic_message(&*s).contains("panicked"));
        let s: Box<dyn std::any::Any + Send> = Box::new(42_i32);
        assert!(panic_message(&*s).contains("unknown payload"));
    }
}