use std::ffi::CString;
use std::os::raw::c_void;
use libduckdb_sys::{
duckdb_aggregate_function_set_destructor, duckdb_aggregate_function_set_extra_info,
duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
duckdb_connection, duckdb_create_aggregate_function, duckdb_delete_callback_t,
duckdb_destroy_aggregate_function, duckdb_register_aggregate_function, DuckDBSuccess,
};
use crate::aggregate::callbacks::{
CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
};
use crate::error::ExtensionError;
use crate::types::{LogicalType, NullHandling, TypeId};
use crate::validate::validate_function_name;
/// Builder for a DuckDB aggregate function.
///
/// Collect the name, parameters, return type and callbacks through the chained
/// setter methods, then call `register` to hand the function to a connection.
/// `register` requires a return type and the `state_size`, `init`, `update`,
/// `combine` and `finalize` callbacks; the destructor, null handling mode and
/// extra info are optional.
#[must_use]
pub struct AggregateFunctionBuilder {
    /// Function name as a NUL-terminated C string.
    pub(super) name: CString,
    /// Parameters declared by simple type id, in the order they were added.
    pub(super) params: Vec<TypeId>,
    /// Parameters declared by full logical type, each paired with its index in
    /// the combined parameter list so it can be interleaved with `params`.
    pub(super) logical_params: Vec<(usize, LogicalType)>,
    /// Return type by simple type id; ignored when `return_logical` is set.
    pub(super) return_type: Option<TypeId>,
    /// Return type as a full logical type; takes precedence over `return_type`.
    pub(super) return_logical: Option<LogicalType>,
    /// Required callback reporting the size of the aggregate state.
    pub(super) state_size: Option<StateSizeFn>,
    /// Required state-initialisation callback.
    pub(super) init: Option<StateInitFn>,
    /// Required update callback.
    pub(super) update: Option<UpdateFn>,
    /// Required combine callback.
    pub(super) combine: Option<CombineFn>,
    /// Required finalize callback.
    pub(super) finalize: Option<FinalizeFn>,
    /// Optional destructor callback, forwarded to
    /// `duckdb_aggregate_function_set_destructor` when present.
    pub(super) destructor: Option<DestroyFn>,
    /// NULL handling mode; `SpecialNullHandling` enables DuckDB's special handling.
    pub(super) null_handling: NullHandling,
    /// Opaque user data paired with its delete callback, forwarded to
    /// `duckdb_aggregate_function_set_extra_info`.
    pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
}
impl AggregateFunctionBuilder {
    /// Builds a builder with no parameters, callbacks or return type around an
    /// already-converted C name.
    ///
    /// Shared by [`Self::new`] and [`Self::try_new`] so both constructors start
    /// from identical defaults.
    fn empty_with_name(name: CString) -> Self {
        Self {
            name,
            params: Vec::new(),
            logical_params: Vec::new(),
            return_type: None,
            return_logical: None,
            state_size: None,
            init: None,
            update: None,
            combine: None,
            finalize: None,
            destructor: None,
            null_handling: NullHandling::DefaultNullHandling,
            extra_info: None,
        }
    }

    /// Creates a builder for an aggregate function called `name`.
    ///
    /// Unlike [`Self::try_new`], the name is not checked with
    /// `validate_function_name`; only the NUL-byte check applies.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte.
    pub fn new(name: &str) -> Self {
        Self::empty_with_name(
            CString::new(name).expect("function name must not contain null bytes"),
        )
    }

    /// Fallible counterpart of [`Self::new`] that also validates the name.
    ///
    /// # Errors
    ///
    /// Returns an error if `validate_function_name` rejects `name`, or if
    /// `name` contains an interior NUL byte.
    pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
        validate_function_name(name)?;
        let c_name = CString::new(name)
            .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
        Ok(Self::empty_with_name(c_name))
    }

    /// Returns the function name, or `""` if it is not valid UTF-8.
    pub fn name(&self) -> &str {
        // Both constructors build the CString from a `&str`, so the fallback
        // is not expected to trigger.
        self.name.to_str().unwrap_or("")
    }

    /// Appends a positional parameter described by a simple [`TypeId`].
    pub fn param(mut self, type_id: TypeId) -> Self {
        self.params.push(type_id);
        self
    }

    /// Appends a positional parameter described by a full [`LogicalType`].
    ///
    /// The parameter's index among all parameters added so far is recorded, so
    /// `register` emits it in call order relative to [`Self::param`] entries.
    #[mutants::skip]
    pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
        let position = self.params.len() + self.logical_params.len();
        self.logical_params.push((position, logical_type));
        self
    }

    /// Sets the return type by simple [`TypeId`].
    ///
    /// Ignored if [`Self::returns_logical`] is also used.
    pub const fn returns(mut self, type_id: TypeId) -> Self {
        self.return_type = Some(type_id);
        self
    }

    /// Sets the return type as a full [`LogicalType`].
    ///
    /// Takes precedence over [`Self::returns`].
    pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
        self.return_logical = Some(logical_type);
        self
    }

    /// Sets the state-size callback (required by `register`).
    pub fn state_size(mut self, f: StateSizeFn) -> Self {
        self.state_size = Some(f);
        self
    }

    /// Sets the state-initialisation callback (required by `register`).
    pub fn init(mut self, f: StateInitFn) -> Self {
        self.init = Some(f);
        self
    }

    /// Sets the update callback (required by `register`).
    pub fn update(mut self, f: UpdateFn) -> Self {
        self.update = Some(f);
        self
    }

    /// Sets the combine callback (required by `register`).
    pub fn combine(mut self, f: CombineFn) -> Self {
        self.combine = Some(f);
        self
    }

    /// Sets the finalize callback (required by `register`).
    pub fn finalize(mut self, f: FinalizeFn) -> Self {
        self.finalize = Some(f);
        self
    }

    /// Sets the optional destructor callback, forwarded to
    /// `duckdb_aggregate_function_set_destructor`.
    pub fn destructor(mut self, f: DestroyFn) -> Self {
        self.destructor = Some(f);
        self
    }

    /// Selects the NULL handling mode.
    ///
    /// `NullHandling::SpecialNullHandling` makes `register` call
    /// `duckdb_aggregate_function_set_special_handling`; other values keep
    /// DuckDB's default.
    pub const fn null_handling(mut self, handling: NullHandling) -> Self {
        self.null_handling = handling;
        self
    }

    /// Attaches opaque user data and its delete callback to the function.
    ///
    /// # Safety
    ///
    /// Ownership of `data` passes to the builder.
    ///
    /// - If validation in [`Self::register`] fails, `destroy` (if `Some`) is
    ///   invoked on `data` before the error is returned.
    /// - Otherwise the pair is forwarded to
    ///   `duckdb_aggregate_function_set_extra_info`.
    ///
    /// The caller must ensure that `destroy` correctly releases `data`, and
    /// must not free `data` itself after passing it here. Dropping the builder
    /// without calling `register` does not release `data`.
    pub unsafe fn extra_info(
        mut self,
        data: *mut c_void,
        destroy: duckdb_delete_callback_t,
    ) -> Self {
        self.extra_info = Some((data, destroy));
        self
    }

    /// Validates the builder and registers the aggregate function on `con`.
    ///
    /// Parameters are added in the order the `param`/`param_logical` methods
    /// were called.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    ///
    /// - no return type was set;
    /// - any of the `state_size`, `init`, `update`, `combine` or `finalize`
    ///   callbacks is missing;
    /// - `duckdb_register_aggregate_function` reports failure.
    ///
    /// When a validation error occurs, any attached extra info is released
    /// through its delete callback before returning.
    ///
    /// # Safety
    ///
    /// `con` must be a valid DuckDB connection handle. The configured callbacks
    /// and extra-info pointer must be sound for DuckDB to invoke for as long as
    /// the function stays registered.
    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
        /// Owns the extra-info payload until it is handed to DuckDB; dropping a
        /// non-empty guard runs the delete callback so early returns do not leak.
        struct ExtraInfoGuard(Option<(*mut c_void, duckdb_delete_callback_t)>);

        impl Drop for ExtraInfoGuard {
            fn drop(&mut self) {
                if let Some((data, Some(destroy))) = self.0.take() {
                    // SAFETY: `data` was supplied together with `destroy` via
                    // `extra_info` and has not been passed to DuckDB, so this
                    // is its only release.
                    unsafe { destroy(data) };
                }
            }
        }

        let mut extra_info = ExtraInfoGuard(self.extra_info);

        // Validate everything before creating any DuckDB object; every early
        // exit below drops `extra_info`, releasing the payload.
        let ret_lt = match (self.return_logical, self.return_type) {
            (Some(lt), _) => lt,
            (None, Some(id)) => LogicalType::new(id),
            (None, None) => return Err(ExtensionError::new("return type not set")),
        };
        let state_size = self
            .state_size
            .ok_or_else(|| ExtensionError::new("state_size callback not set"))?;
        let init = self
            .init
            .ok_or_else(|| ExtensionError::new("init callback not set"))?;
        let update = self
            .update
            .ok_or_else(|| ExtensionError::new("update callback not set"))?;
        let combine = self
            .combine
            .ok_or_else(|| ExtensionError::new("combine callback not set"))?;
        let finalize = self
            .finalize
            .ok_or_else(|| ExtensionError::new("finalize callback not set"))?;

        // SAFETY: creates a fresh aggregate function object; it is destroyed
        // below on every path after this point.
        let mut func = unsafe { duckdb_create_aggregate_function() };
        // SAFETY: `func` is live; `self.name` is NUL-terminated and outlives the call.
        unsafe {
            duckdb_aggregate_function_set_name(func, self.name.as_ptr());
        }

        // `logical_params` records each entry's absolute position; `params`
        // fills the remaining slots in order. Together they reproduce the
        // order in which the builder methods were called.
        let total = self.params.len() + self.logical_params.len();
        let mut logical = self.logical_params.iter().peekable();
        let mut simple = self.params.iter();
        for pos in 0..total {
            if let Some((_, logical_type)) = logical.next_if(|entry| entry.0 == pos) {
                // SAFETY: `func` is live and `logical_type` is owned by `self`.
                unsafe {
                    libduckdb_sys::duckdb_aggregate_function_add_parameter(
                        func,
                        logical_type.as_raw(),
                    );
                }
            } else if let Some(&id) = simple.next() {
                let lt = LogicalType::new(id);
                // SAFETY: `func` is live and `lt` stays alive for the call.
                unsafe {
                    libduckdb_sys::duckdb_aggregate_function_add_parameter(func, lt.as_raw());
                }
            }
        }

        // SAFETY: `func` is live and `ret_lt` stays alive for the call.
        unsafe {
            duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
        }
        // SAFETY: `func` is live; all callbacks were validated above.
        unsafe {
            duckdb_aggregate_function_set_functions(
                func,
                Some(state_size),
                Some(init),
                Some(update),
                Some(combine),
                Some(finalize),
            );
        }
        if let Some(dtor) = self.destructor {
            // SAFETY: `func` is live.
            unsafe {
                duckdb_aggregate_function_set_destructor(func, Some(dtor));
            }
        }
        if self.null_handling == NullHandling::SpecialNullHandling {
            // SAFETY: `func` is live.
            unsafe {
                duckdb_aggregate_function_set_special_handling(func);
            }
        }
        // Disarm the guard: from here on the function object owns the payload.
        if let Some((data, destroy)) = extra_info.0.take() {
            // SAFETY: `func` is live; ownership of `data` moves to it, and the
            // now-empty guard will not release it again.
            unsafe {
                duckdb_aggregate_function_set_extra_info(func, data, destroy);
            }
        }

        // SAFETY: `con` is valid per this function's contract; `func` is live.
        let result = unsafe { duckdb_register_aggregate_function(con, func) };
        // SAFETY: `func` was created above and is destroyed exactly once here.
        unsafe {
            duckdb_destroy_aggregate_function(&raw mut func);
        }

        if result == DuckDBSuccess {
            Ok(())
        } else {
            Err(ExtensionError::new(format!(
                "duckdb_register_aggregate_function failed for '{}'",
                self.name.to_string_lossy()
            )))
        }
    }
}