use std::ffi::CString;
use libduckdb_sys::{
duckdb_add_aggregate_function_to_set, duckdb_aggregate_function_set_destructor,
duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
duckdb_connection, duckdb_create_aggregate_function, duckdb_create_aggregate_function_set,
duckdb_destroy_aggregate_function, duckdb_destroy_aggregate_function_set,
duckdb_register_aggregate_function_set, DuckDBSuccess,
};
use crate::aggregate::callbacks::{
CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
};
use crate::error::ExtensionError;
use crate::types::{LogicalType, NullHandling, TypeId};
use crate::validate::validate_function_name;
/// Builder for a DuckDB aggregate function *set*.
///
/// A set has one name and one return type, shared by every overload.
/// All overloads are registered together by [`AggregateFunctionSetBuilder::register`].
#[must_use]
pub struct AggregateFunctionSetBuilder {
    /// Name handed to DuckDB for both the set and each overload. It is kept
    /// NUL-terminated so it can be passed straight to the C API.
    pub(super) name: CString,
    /// Simple return type. It is only consulted when `return_logical` is `None`.
    pub(super) return_type: Option<TypeId>,
    /// Logical return type. When present, it takes precedence over `return_type`.
    pub(super) return_logical: Option<LogicalType>,
    /// Overloads collected so far. They are added to the DuckDB set in this order.
    pub(super) overloads: Vec<OverloadSpec>,
}
/// One collected overload.
///
/// It is produced from an [`OverloadBuilder`] by
/// [`AggregateFunctionSetBuilder::overloads`] and consumed at registration time.
pub(super) struct OverloadSpec {
    /// Simple parameter types, in the order they were declared.
    pub(super) params: Vec<TypeId>,
    /// Logical-type parameters. Each is paired with its absolute position in the
    /// combined (simple + logical) parameter list.
    pub(super) logical_params: Vec<(usize, LogicalType)>,
    /// Required at registration.
    pub(super) state_size: Option<StateSizeFn>,
    /// Required at registration.
    pub(super) init: Option<StateInitFn>,
    /// Required at registration.
    pub(super) update: Option<UpdateFn>,
    /// Required at registration.
    pub(super) combine: Option<CombineFn>,
    /// Required at registration.
    pub(super) finalize: Option<FinalizeFn>,
    /// Optional. It is installed only when set.
    pub(super) destructor: Option<DestroyFn>,
    /// `SpecialNullHandling` makes registration enable DuckDB's special NULL handling.
    pub(super) null_handling: NullHandling,
}
impl AggregateFunctionSetBuilder {
pub fn new(name: &str) -> Self {
Self {
name: CString::new(name).expect("function name must not contain null bytes"),
return_type: None,
return_logical: None,
overloads: Vec::new(),
}
}
pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
validate_function_name(name)?;
let c_name = CString::new(name)
.map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
Ok(Self {
name: c_name,
return_type: None,
return_logical: None,
overloads: Vec::new(),
})
}
pub fn name(&self) -> &str {
self.name.to_str().unwrap_or("")
}
pub const fn returns(mut self, type_id: TypeId) -> Self {
self.return_type = Some(type_id);
self
}
pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
self.return_logical = Some(logical_type);
self
}
pub fn overloads<F>(mut self, range: std::ops::RangeInclusive<usize>, f: F) -> Self
where
F: Fn(usize, OverloadBuilder) -> OverloadBuilder,
{
for n in range {
let builder = f(n, OverloadBuilder::new());
self.overloads.push(OverloadSpec {
params: builder.params,
logical_params: builder.logical_params,
state_size: builder.state_size,
init: builder.init,
update: builder.update,
combine: builder.combine,
finalize: builder.finalize,
destructor: builder.destructor,
null_handling: builder.null_handling,
});
}
self
}
#[allow(clippy::too_many_lines)]
pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
let ret_lt = if let Some(lt) = self.return_logical {
lt
} else if let Some(id) = self.return_type {
LogicalType::new(id)
} else {
return Err(ExtensionError::new("return type not set for function set"));
};
if self.overloads.is_empty() {
return Err(ExtensionError::new("no overloads added to function set"));
}
let mut set = unsafe { duckdb_create_aggregate_function_set(self.name.as_ptr()) };
let mut register_error: Option<ExtensionError> = None;
for overload in &self.overloads {
let Some(state_size) = overload.state_size else {
register_error = Some(ExtensionError::new("overload missing state_size"));
break;
};
let Some(init) = overload.init else {
register_error = Some(ExtensionError::new("overload missing init"));
break;
};
let Some(update) = overload.update else {
register_error = Some(ExtensionError::new("overload missing update"));
break;
};
let Some(combine) = overload.combine else {
register_error = Some(ExtensionError::new("overload missing combine"));
break;
};
let Some(finalize) = overload.finalize else {
register_error = Some(ExtensionError::new("overload missing finalize"));
break;
};
let mut func = unsafe { duckdb_create_aggregate_function() };
unsafe {
duckdb_aggregate_function_set_name(func, self.name.as_ptr());
}
{
let mut simple_idx = 0;
let mut logical_idx = 0;
let total = overload.params.len() + overload.logical_params.len();
for pos in 0..total {
if logical_idx < overload.logical_params.len()
&& overload.logical_params[logical_idx].0 == pos
{
unsafe {
libduckdb_sys::duckdb_aggregate_function_add_parameter(
func,
overload.logical_params[logical_idx].1.as_raw(),
);
}
logical_idx += 1;
} else if simple_idx < overload.params.len() {
let lt = LogicalType::new(overload.params[simple_idx]);
unsafe {
libduckdb_sys::duckdb_aggregate_function_add_parameter(
func,
lt.as_raw(),
);
}
simple_idx += 1;
}
}
}
unsafe {
duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
}
unsafe {
duckdb_aggregate_function_set_functions(
func,
Some(state_size),
Some(init),
Some(update),
Some(combine),
Some(finalize),
);
}
if let Some(dtor) = overload.destructor {
unsafe {
duckdb_aggregate_function_set_destructor(func, Some(dtor));
}
}
if overload.null_handling == NullHandling::SpecialNullHandling {
unsafe {
duckdb_aggregate_function_set_special_handling(func);
}
}
unsafe {
duckdb_add_aggregate_function_to_set(set, func);
}
unsafe {
duckdb_destroy_aggregate_function(&raw mut func);
}
}
if register_error.is_none() {
let result = unsafe { duckdb_register_aggregate_function_set(con, set) };
if result != DuckDBSuccess {
register_error = Some(ExtensionError::new(format!(
"duckdb_register_aggregate_function_set failed for '{}'",
self.name.to_string_lossy()
)));
}
}
unsafe {
duckdb_destroy_aggregate_function_set(&raw mut set);
}
register_error.map_or(Ok(()), Err)
}
}
/// Per-overload configuration, handed to the closure passed to
/// [`AggregateFunctionSetBuilder::overloads`].
///
/// Parameters are recorded in call order. `state_size`, `init`, `update`,
/// `combine` and `finalize` must all be set before registration succeeds.
/// `destructor` is optional.
#[must_use]
pub struct OverloadBuilder {
    /// Simple parameter types, in declaration order.
    pub(super) params: Vec<TypeId>,
    /// Logical-type parameters, each paired with its absolute parameter position.
    pub(super) logical_params: Vec<(usize, LogicalType)>,
    /// Required callback.
    pub(super) state_size: Option<StateSizeFn>,
    /// Required callback.
    pub(super) init: Option<StateInitFn>,
    /// Required callback.
    pub(super) update: Option<UpdateFn>,
    /// Required callback.
    pub(super) combine: Option<CombineFn>,
    /// Required callback.
    pub(super) finalize: Option<FinalizeFn>,
    /// Optional callback.
    pub(super) destructor: Option<DestroyFn>,
    /// NULL-handling mode for this overload.
    pub(super) null_handling: NullHandling,
}
impl OverloadBuilder {
    /// Returns an empty overload: no parameters, no callbacks, and default NULL
    /// handling.
    pub(super) fn new() -> Self {
        Self {
            null_handling: NullHandling::DefaultNullHandling,
            params: Vec::new(),
            logical_params: Vec::new(),
            state_size: None,
            init: None,
            update: None,
            combine: None,
            finalize: None,
            destructor: None,
        }
    }

    /// Appends a parameter described by a simple [`TypeId`].
    pub fn param(mut self, type_id: TypeId) -> Self {
        self.params.push(type_id);
        self
    }

    /// Appends a parameter described by a full [`LogicalType`].
    #[mutants::skip]
    pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
        // Record the absolute slot among *all* parameters added so far. This lets
        // registration interleave logical and simple parameters in call order.
        let slot = self.logical_params.len() + self.params.len();
        self.logical_params.push((slot, logical_type));
        self
    }

    /// Sets the `state_size` callback (required at registration).
    pub fn state_size(self, f: StateSizeFn) -> Self {
        Self {
            state_size: Some(f),
            ..self
        }
    }

    /// Sets the `init` callback (required at registration).
    pub fn init(self, f: StateInitFn) -> Self {
        Self {
            init: Some(f),
            ..self
        }
    }

    /// Sets the `update` callback (required at registration).
    pub fn update(self, f: UpdateFn) -> Self {
        Self {
            update: Some(f),
            ..self
        }
    }

    /// Sets the `combine` callback (required at registration).
    pub fn combine(self, f: CombineFn) -> Self {
        Self {
            combine: Some(f),
            ..self
        }
    }

    /// Sets the `finalize` callback (required at registration).
    pub fn finalize(self, f: FinalizeFn) -> Self {
        Self {
            finalize: Some(f),
            ..self
        }
    }

    /// Sets the optional destructor callback.
    pub fn destructor(self, f: DestroyFn) -> Self {
        Self {
            destructor: Some(f),
            ..self
        }
    }

    /// Chooses the NULL-handling mode.
    ///
    /// `SpecialNullHandling` causes registration to enable DuckDB's special handling
    /// for this overload.
    pub const fn null_handling(mut self, handling: NullHandling) -> Self {
        self.null_handling = handling;
        self
    }
}