use std::ffi::CString;
use std::os::raw::c_void;
use libduckdb_sys::{
duckdb_bind_info, duckdb_connection, duckdb_create_table_function, duckdb_data_chunk,
duckdb_destroy_table_function, duckdb_function_info, duckdb_init_info,
duckdb_register_table_function, duckdb_table_function_add_named_parameter,
duckdb_table_function_add_parameter, duckdb_table_function_set_bind,
duckdb_table_function_set_extra_info, duckdb_table_function_set_function,
duckdb_table_function_set_init, duckdb_table_function_set_local_init,
duckdb_table_function_set_name, duckdb_table_function_supports_projection_pushdown,
DuckDBSuccess,
};
use crate::error::ExtensionError;
use crate::types::{LogicalType, TypeId};
use crate::validate::validate_function_name;
/// Signature of a table-function bind callback.
///
/// This is the type accepted by `TableFunctionBuilder::bind` and forwarded to
/// `duckdb_table_function_set_bind` during registration.
pub type BindFn = unsafe extern "C" fn(info: duckdb_bind_info);
/// Signature of a table-function init callback.
///
/// Used for both `TableFunctionBuilder::init` (forwarded to
/// `duckdb_table_function_set_init`) and `TableFunctionBuilder::local_init`
/// (forwarded to `duckdb_table_function_set_local_init`).
pub type InitFn = unsafe extern "C" fn(info: duckdb_init_info);
/// Signature of the table-function scan callback.
///
/// Receives the function info handle and the data chunk to write into;
/// forwarded to `duckdb_table_function_set_function` during registration.
pub type ScanFn = unsafe extern "C" fn(info: duckdb_function_info, output: duckdb_data_chunk);
/// Signature of the destructor paired with the extra-info pointer.
///
/// Forwarded, together with the pointer, to
/// `duckdb_table_function_set_extra_info` during registration.
pub type ExtraDestroyFn = unsafe extern "C" fn(data: *mut c_void);
/// A named parameter queued on the builder until registration.
///
/// The two variants differ only in how the parameter type is supplied: as a
/// plain `TypeId` (converted with `LogicalType::new` at registration time) or
/// as a caller-built `LogicalType`.
enum NamedParam {
/// Named parameter whose type is given by a `TypeId`.
Simple {
// Parameter name, stored as a C string so it can be passed over FFI.
name: CString,
// Type identifier turned into a `LogicalType` when the function is registered.
type_id: TypeId,
},
/// Named parameter whose type is a caller-supplied `LogicalType`.
Logical {
// Parameter name, stored as a C string so it can be passed over FFI.
name: CString,
// Logical type handed to DuckDB as-is via `as_raw()`.
logical_type: LogicalType,
},
}
/// Builder that collects everything needed to register a DuckDB table function.
///
/// Configuration methods consume and return the builder; `register` performs
/// the actual FFI calls. The `bind`, `init` and `scan` callbacks are required
/// by `register`; every other setting is optional.
#[must_use]
pub struct TableFunctionBuilder {
// Function name as a C string (passed to `duckdb_table_function_set_name`).
name: CString,
// Positional parameters declared by `TypeId`, in insertion order.
params: Vec<TypeId>,
// Positional parameters declared by `LogicalType`, each tagged with the
// overall positional index it had when it was added, so `register` can
// interleave them correctly with `params`.
logical_params: Vec<(usize, LogicalType)>,
// Named parameters, in insertion order.
named_params: Vec<NamedParam>,
// Bind callback; required by `register`.
bind: Option<BindFn>,
// Init callback; required by `register`.
init: Option<InitFn>,
// Local init callback; only installed when set.
local_init: Option<InitFn>,
// Scan callback; required by `register`.
scan: Option<ScanFn>,
// Forwarded to `duckdb_table_function_supports_projection_pushdown`;
// constructors initialise it to `false`.
projection_pushdown: bool,
// Opaque pointer and its destructor, forwarded to
// `duckdb_table_function_set_extra_info` when set.
extra_info: Option<(*mut c_void, ExtraDestroyFn)>,
}
impl TableFunctionBuilder {
    /// Wraps an already-converted C name in a builder with every option unset.
    ///
    /// Shared by [`new`](Self::new) and [`try_new`](Self::try_new) so the two
    /// constructors cannot drift apart in their defaults.
    fn from_c_name(name: CString) -> Self {
        Self {
            name,
            params: Vec::new(),
            logical_params: Vec::new(),
            named_params: Vec::new(),
            bind: None,
            init: None,
            local_init: None,
            scan: None,
            projection_pushdown: false,
            extra_info: None,
        }
    }

    /// Returns the three mandatory callbacks, or an error naming the first one
    /// that is missing (checked in the order bind, init, scan).
    fn required_callbacks(&self) -> Result<(BindFn, InitFn, ScanFn), ExtensionError> {
        let bind = self
            .bind
            .ok_or_else(|| ExtensionError::new("bind callback not set"))?;
        let init = self
            .init
            .ok_or_else(|| ExtensionError::new("init callback not set"))?;
        let scan = self
            .scan
            .ok_or_else(|| ExtensionError::new("scan callback not set"))?;
        Ok((bind, init, scan))
    }

    /// Creates a builder for a table function called `name`.
    ///
    /// The name is not passed through `validate_function_name`; use
    /// [`try_new`](Self::try_new) for a validating, non-panicking constructor.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains a null byte.
    pub fn new(name: &str) -> Self {
        Self::from_c_name(
            CString::new(name).expect("function name must not contain null bytes"),
        )
    }

    /// Creates a builder after validating `name`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `validate_function_name` if the name is
    /// rejected, or an error if the name contains an interior null byte.
    pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
        validate_function_name(name)?;
        CString::new(name)
            .map(Self::from_c_name)
            .map_err(|_| ExtensionError::new("function name contains interior null byte"))
    }

    /// Returns the function name.
    ///
    /// Falls back to an empty string if the stored name is not valid UTF-8;
    /// both constructors build the name from a `&str`.
    pub fn name(&self) -> &str {
        self.name.to_str().unwrap_or("")
    }

    /// Appends a positional parameter described by a `TypeId`.
    pub fn param(mut self, type_id: TypeId) -> Self {
        self.params.push(type_id);
        self
    }

    /// Appends a positional parameter described by a caller-built `LogicalType`.
    ///
    /// The parameter remembers its overall positional index (counting both
    /// `param` and `param_logical` calls made so far) so that `register` can
    /// emit all positional parameters in declaration order.
    #[mutants::skip]
    pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
        let position = self.params.len() + self.logical_params.len();
        self.logical_params.push((position, logical_type));
        self
    }

    /// Adds a named parameter described by a `TypeId`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains a null byte.
    pub fn named_param(mut self, name: &str, type_id: TypeId) -> Self {
        let name = CString::new(name).expect("parameter name must not contain null bytes");
        self.named_params.push(NamedParam::Simple { name, type_id });
        self
    }

    /// Adds a named parameter described by a caller-built `LogicalType`.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains a null byte.
    pub fn named_param_logical(mut self, name: &str, logical_type: LogicalType) -> Self {
        let name = CString::new(name).expect("parameter name must not contain null bytes");
        self.named_params
            .push(NamedParam::Logical { name, logical_type });
        self
    }

    /// Sets the bind callback (required by [`register`](Self::register)).
    pub fn bind(mut self, f: BindFn) -> Self {
        self.bind = Some(f);
        self
    }

    /// Sets the init callback (required by [`register`](Self::register)).
    pub fn init(mut self, f: InitFn) -> Self {
        self.init = Some(f);
        self
    }

    /// Sets the optional local init callback.
    pub fn local_init(mut self, f: InitFn) -> Self {
        self.local_init = Some(f);
        self
    }

    /// Sets the scan callback (required by [`register`](Self::register)).
    pub fn scan(mut self, f: ScanFn) -> Self {
        self.scan = Some(f);
        self
    }

    /// Sets the flag later passed to
    /// `duckdb_table_function_supports_projection_pushdown`.
    pub const fn projection_pushdown(mut self, enable: bool) -> Self {
        self.projection_pushdown = enable;
        self
    }

    /// Attaches an opaque pointer and its destructor to the function.
    ///
    /// Both are handed to `duckdb_table_function_set_extra_info` by
    /// [`register`](Self::register). If `register` fails before a DuckDB
    /// function handle exists (a required callback is missing), `destroy` is
    /// invoked on `data` by `register` itself, so the payload is not leaked.
    /// A builder that is dropped without calling `register` does not run
    /// `destroy`.
    ///
    /// # Safety
    ///
    /// `data` must be a pointer that `destroy` can be called on exactly once,
    /// and the caller must not free or otherwise invalidate it after passing
    /// the builder to `register`.
    pub unsafe fn extra_info(mut self, data: *mut c_void, destroy: ExtraDestroyFn) -> Self {
        self.extra_info = Some((data, destroy));
        self
    }

    /// Registers the configured table function on `con`.
    ///
    /// Positional parameters are emitted in the order they were declared,
    /// followed by named parameters, callbacks, the projection-pushdown flag
    /// and the extra info. The temporary DuckDB function handle is destroyed
    /// before returning, whether or not registration succeeded.
    ///
    /// # Errors
    ///
    /// Returns an error if the bind, init or scan callback was never set, or
    /// if `duckdb_register_table_function` does not return `DuckDBSuccess`.
    /// On a missing-callback error any `extra_info` payload is destroyed via
    /// its destructor before returning.
    ///
    /// # Safety
    ///
    /// `con` must be a valid DuckDB connection handle, every `LogicalType`
    /// given to the builder must wrap a valid handle, and the contract of
    /// [`extra_info`](Self::extra_info) must hold if it was used.
    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
        let (bind, init, scan) = match self.required_callbacks() {
            Ok(callbacks) => callbacks,
            Err(err) => {
                // No DuckDB handle has taken the payload yet, so release it here;
                // otherwise this early return would leak it.
                if let Some((data, destroy)) = self.extra_info {
                    // SAFETY: the `extra_info` contract guarantees `destroy` may be
                    // called once on `data`, and nothing else has received it.
                    unsafe { destroy(data) };
                }
                return Err(err);
            }
        };

        // SAFETY: plain constructor call; the handle is released below.
        let mut func = unsafe { duckdb_create_table_function() };
        // SAFETY: `func` was just created and `self.name` is a live,
        // NUL-terminated string for the duration of the call.
        unsafe {
            duckdb_table_function_set_name(func, self.name.as_ptr());
        }

        // Positional parameters live in two vectors. Walk the overall positions
        // and pull from whichever vector claimed each slot, which reproduces the
        // order in which `param` / `param_logical` were called.
        let mut simple = self.params.iter().copied();
        let mut logical = self.logical_params.iter().peekable();
        let positional_count = self.params.len() + self.logical_params.len();
        for position in 0..positional_count {
            if let Some((_, logical_type)) = logical.next_if(|entry| entry.0 == position) {
                // SAFETY: `func` is live; the caller guarantees the logical type
                // wraps a valid handle.
                unsafe {
                    duckdb_table_function_add_parameter(func, logical_type.as_raw());
                }
            } else if let Some(type_id) = simple.next() {
                // Keep the temporary logical type bound until the call returns.
                let lt = LogicalType::new(type_id);
                // SAFETY: `func` is live and `lt` is alive across the call.
                unsafe {
                    duckdb_table_function_add_parameter(func, lt.as_raw());
                }
            }
        }

        for named in &self.named_params {
            match named {
                NamedParam::Simple { name, type_id } => {
                    let lt = LogicalType::new(*type_id);
                    // SAFETY: `func` is live; `name` and `lt` outlive the call.
                    unsafe {
                        duckdb_table_function_add_named_parameter(func, name.as_ptr(), lt.as_raw());
                    }
                }
                NamedParam::Logical { name, logical_type } => {
                    // SAFETY: `func` is live; `name` outlives the call and the
                    // caller guarantees the logical type wraps a valid handle.
                    unsafe {
                        duckdb_table_function_add_named_parameter(
                            func,
                            name.as_ptr(),
                            logical_type.as_raw(),
                        );
                    }
                }
            }
        }

        // SAFETY: `func` is live and the callbacks are plain function pointers.
        unsafe {
            duckdb_table_function_set_bind(func, Some(bind));
            duckdb_table_function_set_init(func, Some(init));
            duckdb_table_function_set_function(func, Some(scan));
        }
        if let Some(local_init) = self.local_init {
            // SAFETY: as above.
            unsafe {
                duckdb_table_function_set_local_init(func, Some(local_init));
            }
        }
        // SAFETY: `func` is live.
        unsafe {
            duckdb_table_function_supports_projection_pushdown(func, self.projection_pushdown);
        }
        if let Some((data, destroy)) = self.extra_info {
            // SAFETY: `func` is live; validity of `data`/`destroy` is the
            // caller's obligation under the `extra_info` contract.
            unsafe {
                duckdb_table_function_set_extra_info(func, data, Some(destroy));
            }
        }

        // SAFETY: the caller guarantees `con` is valid; `func` is fully configured.
        let result = unsafe { duckdb_register_table_function(con, func) };
        // SAFETY: `func` came from `duckdb_create_table_function` and is released
        // exactly once, on both the success and the failure path.
        unsafe {
            duckdb_destroy_table_function(&raw mut func);
        }

        if result == DuckDBSuccess {
            Ok(())
        } else {
            Err(ExtensionError::new(format!(
                "duckdb_register_table_function failed for '{}'",
                self.name.to_string_lossy()
            )))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the name of a `Simple` named parameter; panics on `Logical`.
    fn simple_name(param: &NamedParam) -> &str {
        match param {
            NamedParam::Simple { name, .. } => name.to_str().unwrap(),
            NamedParam::Logical { .. } => panic!("expected Simple"),
        }
    }

    #[test]
    fn builder_stores_name() {
        let builder = TableFunctionBuilder::new("my_table_fn");
        assert_eq!(builder.name.to_str().unwrap(), "my_table_fn");
    }

    #[test]
    fn builder_stores_params() {
        let builder = TableFunctionBuilder::new("f")
            .param(TypeId::Varchar)
            .param(TypeId::BigInt);
        // Length and order are checked together by comparing the whole list.
        assert_eq!(builder.params, [TypeId::Varchar, TypeId::BigInt]);
    }

    #[test]
    fn builder_stores_named_params() {
        let builder = TableFunctionBuilder::new("f")
            .named_param("path", TypeId::Varchar)
            .named_param("limit", TypeId::BigInt);
        assert_eq!(builder.named_params.len(), 2);
        assert_eq!(simple_name(&builder.named_params[0]), "path");
        assert_eq!(simple_name(&builder.named_params[1]), "limit");
    }

    #[test]
    fn builder_stores_callbacks() {
        unsafe extern "C" fn noop_bind(_: duckdb_bind_info) {}
        unsafe extern "C" fn noop_init(_: duckdb_init_info) {}
        unsafe extern "C" fn noop_scan(_: duckdb_function_info, _: duckdb_data_chunk) {}

        let builder = TableFunctionBuilder::new("f")
            .bind(noop_bind)
            .init(noop_init)
            .scan(noop_scan);
        assert!(builder.bind.is_some());
        assert!(builder.init.is_some());
        assert!(builder.scan.is_some());
    }

    #[test]
    fn builder_projection_pushdown() {
        let builder = TableFunctionBuilder::new("f").projection_pushdown(true);
        assert!(builder.projection_pushdown);
    }

    #[test]
    fn try_new_valid_name() {
        let outcome = TableFunctionBuilder::try_new("read_csv_ext");
        assert!(outcome.is_ok());
    }

    #[test]
    fn try_new_invalid_name() {
        for rejected in ["", "MyFunc"] {
            assert!(TableFunctionBuilder::try_new(rejected).is_err());
        }
    }

    #[test]
    fn try_new_null_byte_rejected() {
        let outcome = TableFunctionBuilder::try_new("func\0name");
        assert!(outcome.is_err());
    }

    #[test]
    fn param_logical_position_tracking() {
        // A dangling handle is enough here: the test only inspects bookkeeping.
        let placeholder =
            unsafe { LogicalType::from_raw(std::ptr::NonNull::dangling().as_ptr()) };
        let builder = TableFunctionBuilder::new("f")
            .param(TypeId::Integer)
            .param_logical(placeholder);
        assert_eq!(builder.params.len(), 1);
        assert_eq!(builder.logical_params.len(), 1);
        let recorded_position = builder.logical_params[0].0;
        assert_eq!(recorded_position, 1);
        // The placeholder must never be dropped, so leak the whole builder.
        std::mem::forget(builder);
    }
}