quack_rs/aggregate/builder/set.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6use std::ffi::CString;
7
8use libduckdb_sys::{
9 duckdb_add_aggregate_function_to_set, duckdb_aggregate_function_set_destructor,
10 duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
11 duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
12 duckdb_connection, duckdb_create_aggregate_function, duckdb_create_aggregate_function_set,
13 duckdb_destroy_aggregate_function, duckdb_destroy_aggregate_function_set,
14 duckdb_register_aggregate_function_set, DuckDBSuccess,
15};
16
17use crate::aggregate::callbacks::{
18 CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
19};
20use crate::error::ExtensionError;
21use crate::types::{LogicalType, NullHandling, TypeId};
22use crate::validate::validate_function_name;
23
/// Builder for registering a `DuckDB` aggregate function set (multiple overloads).
///
/// Use this when your function accepts a variable number of arguments by
/// registering N overloads (one per arity) under a single name.
///
/// The builder only collects configuration: the set name, one return type
/// shared by every overload, and the list of overload specifications added
/// through [`overloads`][Self::overloads]. No `DuckDB` call is made until
/// [`register`][Self::register] runs.
///
/// # ADR-2: Function sets for variadic signatures
///
/// `DuckDB` does not support true varargs for aggregate functions. For functions
/// that accept 2–32 boolean conditions, register 31 overloads.
///
/// # Pitfall L6: Name must be set on each member
///
/// This builder calls `duckdb_aggregate_function_set_name` on EVERY individual
/// function before adding it to the set. If you forget this call, `DuckDB`
/// silently rejects the registration. Discovery of this bug required reading
/// `DuckDB`'s own C++ test code at `test/api/capi/test_capi_aggregate_functions.cpp`.
///
/// # Example
///
/// The example body is commented out because it refers to user-supplied
/// callbacks (`state_size`, `state_init`, ...) that are not defined here.
///
/// ```rust,no_run
/// use quack_rs::aggregate::AggregateFunctionSetBuilder;
/// use quack_rs::types::{LogicalType, TypeId};
/// use libduckdb_sys::duckdb_connection;
///
/// // unsafe fn register_retention(con: duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
/// //     let set = AggregateFunctionSetBuilder::new("retention")
/// //         .returns_logical(LogicalType::list(TypeId::Boolean))
/// //         .overloads(2..=32, |n, builder| {
/// //             // Give each overload `n` boolean parameters so the arities differ.
/// //             (0..n)
/// //                 .fold(builder, |b, _| b.param(TypeId::Boolean))
/// //                 .state_size(state_size)
/// //                 .init(state_init)
/// //                 .update(update)
/// //                 .combine(combine)
/// //                 .finalize(finalize)
/// //                 .destructor(destroy)
/// //         });
/// //     // SAFETY: the caller guarantees `con` is a valid, open connection.
/// //     unsafe { set.register(con) }
/// // }
/// ```
#[must_use]
pub struct AggregateFunctionSetBuilder {
    /// Function set name as a C string; handed to `DuckDB` for the set itself
    /// and for every member function during registration.
    pub(super) name: CString,
    /// Simple return type; consulted only when `return_logical` is `None`.
    pub(super) return_type: Option<TypeId>,
    /// Complex return type; takes precedence over `return_type` when set.
    pub(super) return_logical: Option<LogicalType>,
    /// Overload specifications in the order they were added.
    pub(super) overloads: Vec<OverloadSpec>,
}
70
/// Specification for one overload within a function set.
///
/// Holds everything `AggregateFunctionSetBuilder::register` needs to build a
/// single `DuckDB` aggregate function: its parameters, callbacks, and NULL
/// handling mode.
pub(super) struct OverloadSpec {
    /// Parameters given as plain [`TypeId`]s, in the order they were added.
    pub(super) params: Vec<TypeId>,
    /// Parameters given as complex [`LogicalType`]s, each paired with its
    /// position in the overall parameter list.
    pub(super) logical_params: Vec<(usize, LogicalType)>,
    /// State-size callback; registration fails if this is `None`.
    pub(super) state_size: Option<StateSizeFn>,
    /// State-initialisation callback; registration fails if this is `None`.
    pub(super) init: Option<StateInitFn>,
    /// Update callback; registration fails if this is `None`.
    pub(super) update: Option<UpdateFn>,
    /// Combine callback; registration fails if this is `None`.
    pub(super) combine: Option<CombineFn>,
    /// Finalize callback; registration fails if this is `None`.
    pub(super) finalize: Option<FinalizeFn>,
    /// Optional destructor callback; installed only when present.
    pub(super) destructor: Option<DestroyFn>,
    /// NULL handling mode; special handling is enabled on the function only
    /// for `NullHandling::SpecialNullHandling`.
    pub(super) null_handling: NullHandling,
}
83
84impl AggregateFunctionSetBuilder {
85 /// Creates a new builder for a function set with the given name.
86 ///
87 /// # Panics
88 ///
89 /// Panics if `name` contains an interior null byte.
90 pub fn new(name: &str) -> Self {
91 Self {
92 name: CString::new(name).expect("function name must not contain null bytes"),
93 return_type: None,
94 return_logical: None,
95 overloads: Vec::new(),
96 }
97 }
98
99 /// Creates a new builder with function name validation.
100 ///
101 /// # Errors
102 ///
103 /// Returns `ExtensionError` if the name is invalid.
104 /// See [`validate_function_name`] for the full set of rules.
105 pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
106 validate_function_name(name)?;
107 let c_name = CString::new(name)
108 .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
109 Ok(Self {
110 name: c_name,
111 return_type: None,
112 return_logical: None,
113 overloads: Vec::new(),
114 })
115 }
116
117 /// Returns the function set name.
118 ///
119 /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
120 pub fn name(&self) -> &str {
121 self.name.to_str().unwrap_or("")
122 }
123
124 /// Sets the return type for all overloads in this function set.
125 ///
126 /// For complex return types like `LIST(BIGINT)`, use
127 /// [`returns_logical`][Self::returns_logical] instead.
128 pub const fn returns(mut self, type_id: TypeId) -> Self {
129 self.return_type = Some(type_id);
130 self
131 }
132
133 /// Sets the return type to a complex [`LogicalType`] for all overloads.
134 ///
135 /// Use this for parameterized return types that [`TypeId`] cannot express,
136 /// such as `LIST(BOOLEAN)`, `LIST(TIMESTAMP)`, `MAP(VARCHAR, INTEGER)`, etc.
137 ///
138 /// If both `returns` and `returns_logical` are called, the logical type takes
139 /// precedence.
140 ///
141 /// # Example
142 ///
143 /// ```rust,no_run
144 /// use quack_rs::aggregate::AggregateFunctionSetBuilder;
145 /// use quack_rs::types::{LogicalType, TypeId};
146 ///
147 /// // AggregateFunctionSetBuilder::new("retention")
148 /// // .returns_logical(LogicalType::list(TypeId::Boolean))
149 /// // .overloads(2..=32, |n, builder| {
150 /// // (0..n).fold(builder, |b, _| b.param(TypeId::Boolean))
151 /// // .state_size(my_state_size)
152 /// // .init(my_init)
153 /// // .update(my_update)
154 /// // .combine(my_combine)
155 /// // .finalize(my_finalize)
156 /// // });
157 /// ```
158 pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
159 self.return_logical = Some(logical_type);
160 self
161 }
162
163 /// Adds overloads for each arity in `range`, using the given builder closure.
164 ///
165 /// The closure receives:
166 /// - `n`: the number of parameters for this overload
167 /// - A fresh [`OverloadBuilder`] for configuring callbacks
168 ///
169 /// # Example
170 ///
171 /// ```rust,no_run
172 /// use quack_rs::aggregate::AggregateFunctionSetBuilder;
173 /// use quack_rs::types::TypeId;
174 ///
175 /// // AggregateFunctionSetBuilder::new("retention")
176 /// // .returns(TypeId::BigInt)
177 /// // .overloads(2..=32, |n, builder| {
178 /// // let builder = builder
179 /// // .state_size(my_state_size)
180 /// // .init(my_init)
181 /// // .update(my_update)
182 /// // .combine(my_combine)
183 /// // .finalize(my_finalize);
184 /// // // `n` booleans as params
185 /// // (0..n).fold(builder, |b, _| b.param(TypeId::Boolean))
186 /// // });
187 /// ```
188 pub fn overloads<F>(mut self, range: std::ops::RangeInclusive<usize>, f: F) -> Self
189 where
190 F: Fn(usize, OverloadBuilder) -> OverloadBuilder,
191 {
192 for n in range {
193 let builder = f(n, OverloadBuilder::new());
194 self.overloads.push(OverloadSpec {
195 params: builder.params,
196 logical_params: builder.logical_params,
197 state_size: builder.state_size,
198 init: builder.init,
199 update: builder.update,
200 combine: builder.combine,
201 finalize: builder.finalize,
202 destructor: builder.destructor,
203 null_handling: builder.null_handling,
204 });
205 }
206 self
207 }
208
209 /// Registers the function set on the given connection.
210 ///
211 /// # Pitfall L6
212 ///
213 /// This method calls `duckdb_aggregate_function_set_name` on EVERY individual
214 /// function in the set. Omitting this call causes silent registration failure.
215 ///
216 /// # Errors
217 ///
218 /// Returns `ExtensionError` if:
219 /// - Return type was not set.
220 /// - Any overload is missing required callbacks.
221 /// - `DuckDB` reports registration failure.
222 ///
223 /// # Safety
224 ///
225 /// `con` must be a valid, open `duckdb_connection`.
226 #[allow(clippy::too_many_lines)]
227 pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
228 // Resolve return type: prefer explicit LogicalType over TypeId.
229 let ret_lt = if let Some(lt) = self.return_logical {
230 lt
231 } else if let Some(id) = self.return_type {
232 LogicalType::new(id)
233 } else {
234 return Err(ExtensionError::new("return type not set for function set"));
235 };
236
237 if self.overloads.is_empty() {
238 return Err(ExtensionError::new("no overloads added to function set"));
239 }
240
241 // SAFETY: Creates a new aggregate function set handle.
242 let mut set = unsafe { duckdb_create_aggregate_function_set(self.name.as_ptr()) };
243
244 let mut register_error: Option<ExtensionError> = None;
245
246 for overload in &self.overloads {
247 let Some(state_size) = overload.state_size else {
248 register_error = Some(ExtensionError::new("overload missing state_size"));
249 break;
250 };
251 let Some(init) = overload.init else {
252 register_error = Some(ExtensionError::new("overload missing init"));
253 break;
254 };
255 let Some(update) = overload.update else {
256 register_error = Some(ExtensionError::new("overload missing update"));
257 break;
258 };
259 let Some(combine) = overload.combine else {
260 register_error = Some(ExtensionError::new("overload missing combine"));
261 break;
262 };
263 let Some(finalize) = overload.finalize else {
264 register_error = Some(ExtensionError::new("overload missing finalize"));
265 break;
266 };
267
268 // SAFETY: Creates a new aggregate function handle for this overload.
269 let mut func = unsafe { duckdb_create_aggregate_function() };
270
271 // PITFALL L6: CRITICAL — must call this on EACH function, not just the set.
272 // Without this, duckdb_register_aggregate_function_set silently returns DuckDBError.
273 // Discovered by reading DuckDB's test/api/capi/test_capi_aggregate_functions.cpp.
274 unsafe {
275 duckdb_aggregate_function_set_name(func, self.name.as_ptr());
276 }
277
278 // Add parameters: merge simple TypeId params and complex LogicalType params
279 // in the order they were added (tracked by position).
280 {
281 let mut simple_idx = 0;
282 let mut logical_idx = 0;
283 let total = overload.params.len() + overload.logical_params.len();
284 for pos in 0..total {
285 if logical_idx < overload.logical_params.len()
286 && overload.logical_params[logical_idx].0 == pos
287 {
288 // SAFETY: func and logical type handle are valid.
289 unsafe {
290 libduckdb_sys::duckdb_aggregate_function_add_parameter(
291 func,
292 overload.logical_params[logical_idx].1.as_raw(),
293 );
294 }
295 logical_idx += 1;
296 } else if simple_idx < overload.params.len() {
297 let lt = LogicalType::new(overload.params[simple_idx]);
298 // SAFETY: func and lt.as_raw() are valid handles.
299 unsafe {
300 libduckdb_sys::duckdb_aggregate_function_add_parameter(
301 func,
302 lt.as_raw(),
303 );
304 }
305 simple_idx += 1;
306 }
307 }
308 }
309
310 // Set return type (shared across all overloads)
311 // SAFETY: func and ret_lt.as_raw() are valid.
312 unsafe {
313 duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
314 }
315
316 // Set callbacks
317 unsafe {
318 duckdb_aggregate_function_set_functions(
319 func,
320 Some(state_size),
321 Some(init),
322 Some(update),
323 Some(combine),
324 Some(finalize),
325 );
326 }
327
328 if let Some(dtor) = overload.destructor {
329 unsafe {
330 duckdb_aggregate_function_set_destructor(func, Some(dtor));
331 }
332 }
333
334 // Set special NULL handling if requested
335 if overload.null_handling == NullHandling::SpecialNullHandling {
336 // SAFETY: func is a valid aggregate function handle.
337 unsafe {
338 duckdb_aggregate_function_set_special_handling(func);
339 }
340 }
341
342 // Add this function to the set
343 // SAFETY: set and func are valid handles.
344 unsafe {
345 duckdb_add_aggregate_function_to_set(set, func);
346 }
347
348 // SAFETY: func was created above and ownership transferred to the set.
349 unsafe {
350 duckdb_destroy_aggregate_function(&raw mut func);
351 }
352 }
353
354 if register_error.is_none() {
355 // SAFETY: con is valid and set is fully configured.
356 let result = unsafe { duckdb_register_aggregate_function_set(con, set) };
357 if result != DuckDBSuccess {
358 register_error = Some(ExtensionError::new(format!(
359 "duckdb_register_aggregate_function_set failed for '{}'",
360 self.name.to_string_lossy()
361 )));
362 }
363 }
364
365 // SAFETY: set was created above and must be destroyed.
366 unsafe {
367 duckdb_destroy_aggregate_function_set(&raw mut set);
368 }
369
370 register_error.map_or(Ok(()), Err)
371 }
372}
373
/// A builder for one overload within a [`AggregateFunctionSetBuilder`].
///
/// Returned by the closure passed to [`AggregateFunctionSetBuilder::overloads`].
/// It collects the overload's parameters, callbacks, and NULL handling mode;
/// `overloads` then copies these fields into the function set.
#[must_use]
pub struct OverloadBuilder {
    /// Parameters given as plain [`TypeId`]s, in the order they were added.
    pub(super) params: Vec<TypeId>,
    /// Parameters given as complex [`LogicalType`]s, each paired with the
    /// position it had in the overall parameter list when it was added.
    pub(super) logical_params: Vec<(usize, LogicalType)>,
    /// State-size callback, if set.
    pub(super) state_size: Option<StateSizeFn>,
    /// State-initialisation callback, if set.
    pub(super) init: Option<StateInitFn>,
    /// Update callback, if set.
    pub(super) update: Option<UpdateFn>,
    /// Combine callback, if set.
    pub(super) combine: Option<CombineFn>,
    /// Finalize callback, if set.
    pub(super) finalize: Option<FinalizeFn>,
    /// Optional destructor callback, if set.
    pub(super) destructor: Option<DestroyFn>,
    /// NULL handling mode; starts as `NullHandling::DefaultNullHandling`.
    pub(super) null_handling: NullHandling,
}
389
390impl OverloadBuilder {
391 /// Creates a new `OverloadBuilder`.
392 pub(super) fn new() -> Self {
393 Self {
394 params: Vec::new(),
395 logical_params: Vec::new(),
396 state_size: None,
397 init: None,
398 update: None,
399 combine: None,
400 finalize: None,
401 destructor: None,
402 null_handling: NullHandling::DefaultNullHandling,
403 }
404 }
405
406 /// Adds a positional parameter to this overload.
407 ///
408 /// For complex types like `LIST(BIGINT)`, use
409 /// [`param_logical`][Self::param_logical].
410 pub fn param(mut self, type_id: TypeId) -> Self {
411 self.params.push(type_id);
412 self
413 }
414
415 /// Adds a positional parameter with a complex [`LogicalType`].
416 ///
417 /// Use this for parameterized types that [`TypeId`] cannot express, such as
418 /// `LIST(BIGINT)`, `MAP(VARCHAR, INTEGER)`, or `STRUCT(...)`.
419 #[mutants::skip] // position arithmetic tested via E2E
420 pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
421 let position = self.params.len() + self.logical_params.len();
422 self.logical_params.push((position, logical_type));
423 self
424 }
425
426 /// Sets the `state_size` callback for this overload.
427 pub fn state_size(mut self, f: StateSizeFn) -> Self {
428 self.state_size = Some(f);
429 self
430 }
431
432 /// Sets the `init` callback for this overload.
433 pub fn init(mut self, f: StateInitFn) -> Self {
434 self.init = Some(f);
435 self
436 }
437
438 /// Sets the `update` callback for this overload.
439 pub fn update(mut self, f: UpdateFn) -> Self {
440 self.update = Some(f);
441 self
442 }
443
444 /// Sets the `combine` callback for this overload.
445 pub fn combine(mut self, f: CombineFn) -> Self {
446 self.combine = Some(f);
447 self
448 }
449
450 /// Sets the `finalize` callback for this overload.
451 pub fn finalize(mut self, f: FinalizeFn) -> Self {
452 self.finalize = Some(f);
453 self
454 }
455
456 /// Sets the optional destructor callback for this overload.
457 pub fn destructor(mut self, f: DestroyFn) -> Self {
458 self.destructor = Some(f);
459 self
460 }
461
462 /// Sets the NULL handling behaviour for this overload.
463 ///
464 /// By default, `DuckDB` skips NULL rows in aggregate functions
465 /// ([`DefaultNullHandling`][NullHandling::DefaultNullHandling]).
466 /// Set to [`SpecialNullHandling`][NullHandling::SpecialNullHandling] to receive
467 /// NULL values in your `update` callback.
468 pub const fn null_handling(mut self, handling: NullHandling) -> Self {
469 self.null_handling = handling;
470 self
471 }
472}