use std::ffi::CString;
use std::os::raw::c_void;
use libduckdb_sys::{
duckdb_add_scalar_function_to_set, duckdb_connection, duckdb_create_scalar_function,
duckdb_create_scalar_function_set, duckdb_delete_callback_t, duckdb_destroy_scalar_function,
duckdb_destroy_scalar_function_set, duckdb_register_scalar_function_set,
duckdb_scalar_function_add_parameter, duckdb_scalar_function_set_extra_info,
duckdb_scalar_function_set_function, duckdb_scalar_function_set_name,
duckdb_scalar_function_set_return_type, duckdb_scalar_function_set_special_handling,
DuckDBSuccess,
};
use crate::error::ExtensionError;
use crate::types::{LogicalType, NullHandling, TypeId};
use crate::validate::validate_function_name;
use super::single::ScalarFn;
/// Builder that gathers several overloads under one scalar function name and
/// registers them with DuckDB as a single scalar function set.
///
/// Overloads are appended with `overload` and handed to DuckDB by `register`.
#[must_use]
pub struct ScalarFunctionSetBuilder {
/// Function name as a C string; its pointer is passed to the DuckDB C API
/// when the set and each overload are created.
pub(super) name: CString,
/// Overloads added so far, in the order they were appended.
pub(super) overloads: Vec<ScalarOverloadSpec>,
}
/// Stored description of one overload inside a scalar function set.
///
/// Field-for-field copy of a `ScalarOverloadBuilder`, created when the builder
/// is appended to a `ScalarFunctionSetBuilder`.
pub(super) struct ScalarOverloadSpec {
/// Parameters given as plain type ids, in the order they were added.
pub(super) params: Vec<TypeId>,
/// Parameters given as full logical types, each paired with its absolute
/// position in the overall parameter list.
pub(super) logical_params: Vec<(usize, LogicalType)>,
/// Return type as a plain type id; only consulted when `return_logical` is `None`.
pub(super) return_type: Option<TypeId>,
/// Return type as a full logical type; takes precedence over `return_type`.
pub(super) return_logical: Option<LogicalType>,
/// Callback implementing the overload; registration fails if it is `None`.
pub(super) function: Option<ScalarFn>,
/// Null-handling mode; `SpecialNullHandling` enables DuckDB's special handling flag.
pub(super) null_handling: NullHandling,
/// Optional extra-info pointer and its delete callback, forwarded unchanged to
/// `duckdb_scalar_function_set_extra_info` during registration.
pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
}
impl ScalarFunctionSetBuilder {
pub fn new(name: &str) -> Self {
Self {
name: CString::new(name).expect("function name must not contain null bytes"),
overloads: Vec::new(),
}
}
pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
validate_function_name(name)?;
let c_name = CString::new(name)
.map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
Ok(Self {
name: c_name,
overloads: Vec::new(),
})
}
pub fn name(&self) -> &str {
self.name.to_str().unwrap_or("")
}
pub fn overload(mut self, builder: ScalarOverloadBuilder) -> Self {
self.overloads.push(ScalarOverloadSpec {
params: builder.params,
logical_params: builder.logical_params,
return_type: builder.return_type,
return_logical: builder.return_logical,
function: builder.function,
null_handling: builder.null_handling,
extra_info: builder.extra_info,
});
self
}
pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
if self.overloads.is_empty() {
return Err(ExtensionError::new(
"no overloads added to scalar function set",
));
}
let mut set = unsafe { duckdb_create_scalar_function_set(self.name.as_ptr()) };
let mut register_error: Option<ExtensionError> = None;
for overload in &self.overloads {
let (_ret_lt_owner, ret_raw) = if let Some(ref lt) = overload.return_logical {
(None, lt.as_raw())
} else if let Some(id) = overload.return_type {
let lt = LogicalType::new(id);
let raw = lt.as_raw();
(Some(lt), raw)
} else {
register_error = Some(ExtensionError::new("overload missing return type"));
break;
};
let Some(function) = overload.function else {
register_error = Some(ExtensionError::new("overload missing function callback"));
break;
};
let mut func = unsafe { duckdb_create_scalar_function() };
unsafe {
duckdb_scalar_function_set_name(func, self.name.as_ptr());
}
{
let mut simple_idx = 0;
let mut logical_idx = 0;
let total = overload.params.len() + overload.logical_params.len();
for pos in 0..total {
if logical_idx < overload.logical_params.len()
&& overload.logical_params[logical_idx].0 == pos
{
unsafe {
duckdb_scalar_function_add_parameter(
func,
overload.logical_params[logical_idx].1.as_raw(),
);
}
logical_idx += 1;
} else if simple_idx < overload.params.len() {
let lt = LogicalType::new(overload.params[simple_idx]);
unsafe {
duckdb_scalar_function_add_parameter(func, lt.as_raw());
}
simple_idx += 1;
}
}
}
unsafe {
duckdb_scalar_function_set_return_type(func, ret_raw);
}
unsafe {
duckdb_scalar_function_set_function(func, Some(function));
}
if overload.null_handling == NullHandling::SpecialNullHandling {
unsafe {
duckdb_scalar_function_set_special_handling(func);
}
}
if let Some((data, destroy)) = overload.extra_info {
unsafe {
duckdb_scalar_function_set_extra_info(func, data, destroy);
}
}
unsafe {
duckdb_add_scalar_function_to_set(set, func);
}
unsafe {
duckdb_destroy_scalar_function(&raw mut func);
}
}
if register_error.is_none() {
let result = unsafe { duckdb_register_scalar_function_set(con, set) };
if result != DuckDBSuccess {
register_error = Some(ExtensionError::new(format!(
"duckdb_register_scalar_function_set failed for '{}'",
self.name.to_string_lossy()
)));
}
}
unsafe {
duckdb_destroy_scalar_function_set(&raw mut set);
}
register_error.map_or(Ok(()), Err)
}
}
/// Builder describing a single overload (parameters, return type, callback
/// and options) to be added to a `ScalarFunctionSetBuilder`.
#[must_use]
pub struct ScalarOverloadBuilder {
/// Parameters given as plain type ids, in the order they were added.
pub(super) params: Vec<TypeId>,
/// Parameters given as full logical types, each paired with the position it
/// had in the overall parameter list when it was added.
pub(super) logical_params: Vec<(usize, LogicalType)>,
/// Return type as a plain type id, if one was set.
pub(super) return_type: Option<TypeId>,
/// Return type as a full logical type, if one was set; registration prefers
/// this over `return_type` when both are present.
pub(super) return_logical: Option<LogicalType>,
/// Callback implementing the overload, if one was set.
pub(super) function: Option<ScalarFn>,
/// Null-handling mode; starts as `NullHandling::DefaultNullHandling`.
pub(super) null_handling: NullHandling,
/// Optional extra-info pointer together with its delete callback.
pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
}
impl ScalarOverloadBuilder {
    /// Creates an overload description with no parameters, no return type, no
    /// callback, default null handling and no extra info.
    pub fn new() -> Self {
        Self {
            params: vec![],
            logical_params: vec![],
            return_type: None,
            return_logical: None,
            function: None,
            null_handling: NullHandling::DefaultNullHandling,
            extra_info: None,
        }
    }

    /// Appends a parameter described by a plain type id.
    pub fn param(self, type_id: TypeId) -> Self {
        let mut builder = self;
        builder.params.push(type_id);
        builder
    }

    /// Appends a parameter described by a full logical type.
    ///
    /// The parameter remembers its position, which is the number of parameters
    /// of either kind added before it.
    #[mutants::skip]
    pub fn param_logical(self, logical_type: LogicalType) -> Self {
        let mut builder = self;
        let slot = builder.params.len() + builder.logical_params.len();
        builder.logical_params.push((slot, logical_type));
        builder
    }

    /// Sets the return type from a plain type id.
    ///
    /// A logical return type set through `returns_logical` is left untouched
    /// and is the one used at registration when both are present.
    pub const fn returns(mut self, type_id: TypeId) -> Self {
        self.return_type = Some(type_id);
        self
    }

    /// Sets the return type from a full logical type.
    #[mutants::skip]
    pub fn returns_logical(self, logical_type: LogicalType) -> Self {
        let mut builder = self;
        builder.return_logical = Some(logical_type);
        builder
    }

    /// Sets the callback that implements this overload.
    pub fn function(self, f: ScalarFn) -> Self {
        Self {
            function: Some(f),
            ..self
        }
    }

    /// Sets the null-handling mode for this overload.
    pub const fn null_handling(mut self, handling: NullHandling) -> Self {
        self.null_handling = handling;
        self
    }

    /// Attaches an extra-info pointer and its delete callback to this overload.
    ///
    /// Both values are stored as given and forwarded to
    /// `duckdb_scalar_function_set_extra_info` when the set is registered.
    ///
    /// # Safety
    ///
    /// `data` must remain valid until it is handed to DuckDB, and `destroy`
    /// must be a callback that is sound to invoke on `data`.
    pub unsafe fn extra_info(
        self,
        data: *mut c_void,
        destroy: duckdb_delete_callback_t,
    ) -> Self {
        Self {
            extra_info: Some((data, destroy)),
            ..self
        }
    }
}
impl Default for ScalarOverloadBuilder {
fn default() -> Self {
Self::new()
}
}