Skip to main content

quack_rs/aggregate/builder/
set.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6use std::ffi::CString;
7
8use libduckdb_sys::{
9    duckdb_add_aggregate_function_to_set, duckdb_aggregate_function_set_destructor,
10    duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
11    duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
12    duckdb_connection, duckdb_create_aggregate_function, duckdb_create_aggregate_function_set,
13    duckdb_destroy_aggregate_function, duckdb_destroy_aggregate_function_set,
14    duckdb_register_aggregate_function_set, DuckDBSuccess,
15};
16
17use crate::aggregate::callbacks::{
18    CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
19};
20use crate::error::ExtensionError;
21use crate::types::{LogicalType, NullHandling, TypeId};
22use crate::validate::validate_function_name;
23
24/// Builder for registering a `DuckDB` aggregate function set (multiple overloads).
25///
26/// Use this when your function accepts a variable number of arguments by
27/// registering N overloads (one per arity) under a single name.
28///
29/// # ADR-2: Function sets for variadic signatures
30///
31/// `DuckDB` does not support true varargs for aggregate functions. For functions
32/// that accept 2–32 boolean conditions, register 31 overloads.
33///
34/// # Pitfall L6: Name must be set on each member
35///
36/// This builder calls `duckdb_aggregate_function_set_name` on EVERY individual
37/// function before adding it to the set. If you forget this call, `DuckDB`
38/// silently rejects the registration. Discovery of this bug required reading
39/// `DuckDB`'s own C++ test code at `test/api/capi/test_capi_aggregate_functions.cpp`.
40///
41/// # Example
42///
43/// ```rust,no_run
44/// use quack_rs::aggregate::AggregateFunctionSetBuilder;
45/// use quack_rs::types::{LogicalType, TypeId};
46/// use libduckdb_sys::duckdb_connection;
47///
48/// // fn register_retention(con: duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
49/// //     AggregateFunctionSetBuilder::new("retention")
50/// //         .returns_logical(LogicalType::list(TypeId::Boolean))
51/// //         .overloads(2..=32, |_n, builder| {
52/// //             builder
53/// //                 .state_size(state_size)
54/// //                 .init(state_init)
55/// //                 .update(update)
56/// //                 .combine(combine)
57/// //                 .finalize(finalize)
58/// //                 .destructor(destroy)
59/// //         })
60/// //         .register(con)
61/// // }
62/// ```
63#[must_use]
64pub struct AggregateFunctionSetBuilder {
65    pub(super) name: CString,
66    pub(super) return_type: Option<TypeId>,
67    pub(super) return_logical: Option<LogicalType>,
68    pub(super) overloads: Vec<OverloadSpec>,
69}
70
71/// Specification for one overload within a function set.
72pub(super) struct OverloadSpec {
73    pub(super) params: Vec<TypeId>,
74    pub(super) logical_params: Vec<(usize, LogicalType)>,
75    pub(super) state_size: Option<StateSizeFn>,
76    pub(super) init: Option<StateInitFn>,
77    pub(super) update: Option<UpdateFn>,
78    pub(super) combine: Option<CombineFn>,
79    pub(super) finalize: Option<FinalizeFn>,
80    pub(super) destructor: Option<DestroyFn>,
81    pub(super) null_handling: NullHandling,
82}
83
84impl AggregateFunctionSetBuilder {
85    /// Creates a new builder for a function set with the given name.
86    ///
87    /// # Panics
88    ///
89    /// Panics if `name` contains an interior null byte.
90    pub fn new(name: &str) -> Self {
91        Self {
92            name: CString::new(name).expect("function name must not contain null bytes"),
93            return_type: None,
94            return_logical: None,
95            overloads: Vec::new(),
96        }
97    }
98
99    /// Creates a new builder with function name validation.
100    ///
101    /// # Errors
102    ///
103    /// Returns `ExtensionError` if the name is invalid.
104    /// See [`validate_function_name`] for the full set of rules.
105    pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
106        validate_function_name(name)?;
107        let c_name = CString::new(name)
108            .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
109        Ok(Self {
110            name: c_name,
111            return_type: None,
112            return_logical: None,
113            overloads: Vec::new(),
114        })
115    }
116
117    /// Returns the function set name.
118    ///
119    /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
120    pub fn name(&self) -> &str {
121        self.name.to_str().unwrap_or("")
122    }
123
124    /// Sets the return type for all overloads in this function set.
125    ///
126    /// For complex return types like `LIST(BIGINT)`, use
127    /// [`returns_logical`][Self::returns_logical] instead.
128    pub const fn returns(mut self, type_id: TypeId) -> Self {
129        self.return_type = Some(type_id);
130        self
131    }
132
133    /// Sets the return type to a complex [`LogicalType`] for all overloads.
134    ///
135    /// Use this for parameterized return types that [`TypeId`] cannot express,
136    /// such as `LIST(BOOLEAN)`, `LIST(TIMESTAMP)`, `MAP(VARCHAR, INTEGER)`, etc.
137    ///
138    /// If both `returns` and `returns_logical` are called, the logical type takes
139    /// precedence.
140    ///
141    /// # Example
142    ///
143    /// ```rust,no_run
144    /// use quack_rs::aggregate::AggregateFunctionSetBuilder;
145    /// use quack_rs::types::{LogicalType, TypeId};
146    ///
147    /// // AggregateFunctionSetBuilder::new("retention")
148    /// //     .returns_logical(LogicalType::list(TypeId::Boolean))
149    /// //     .overloads(2..=32, |n, builder| {
150    /// //         (0..n).fold(builder, |b, _| b.param(TypeId::Boolean))
151    /// //             .state_size(my_state_size)
152    /// //             .init(my_init)
153    /// //             .update(my_update)
154    /// //             .combine(my_combine)
155    /// //             .finalize(my_finalize)
156    /// //     });
157    /// ```
158    pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
159        self.return_logical = Some(logical_type);
160        self
161    }
162
163    /// Adds overloads for each arity in `range`, using the given builder closure.
164    ///
165    /// The closure receives:
166    /// - `n`: the number of parameters for this overload
167    /// - A fresh [`OverloadBuilder`] for configuring callbacks
168    ///
169    /// # Example
170    ///
171    /// ```rust,no_run
172    /// use quack_rs::aggregate::AggregateFunctionSetBuilder;
173    /// use quack_rs::types::TypeId;
174    ///
175    /// // AggregateFunctionSetBuilder::new("retention")
176    /// //     .returns(TypeId::BigInt)
177    /// //     .overloads(2..=32, |n, builder| {
178    /// //         let builder = builder
179    /// //             .state_size(my_state_size)
180    /// //             .init(my_init)
181    /// //             .update(my_update)
182    /// //             .combine(my_combine)
183    /// //             .finalize(my_finalize);
184    /// //         // `n` booleans as params
185    /// //         (0..n).fold(builder, |b, _| b.param(TypeId::Boolean))
186    /// //     });
187    /// ```
188    pub fn overloads<F>(mut self, range: std::ops::RangeInclusive<usize>, f: F) -> Self
189    where
190        F: Fn(usize, OverloadBuilder) -> OverloadBuilder,
191    {
192        for n in range {
193            let builder = f(n, OverloadBuilder::new());
194            self.overloads.push(OverloadSpec {
195                params: builder.params,
196                logical_params: builder.logical_params,
197                state_size: builder.state_size,
198                init: builder.init,
199                update: builder.update,
200                combine: builder.combine,
201                finalize: builder.finalize,
202                destructor: builder.destructor,
203                null_handling: builder.null_handling,
204            });
205        }
206        self
207    }
208
209    /// Registers the function set on the given connection.
210    ///
211    /// # Pitfall L6
212    ///
213    /// This method calls `duckdb_aggregate_function_set_name` on EVERY individual
214    /// function in the set. Omitting this call causes silent registration failure.
215    ///
216    /// # Errors
217    ///
218    /// Returns `ExtensionError` if:
219    /// - Return type was not set.
220    /// - Any overload is missing required callbacks.
221    /// - `DuckDB` reports registration failure.
222    ///
223    /// # Safety
224    ///
225    /// `con` must be a valid, open `duckdb_connection`.
226    #[allow(clippy::too_many_lines)]
227    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
228        // Resolve return type: prefer explicit LogicalType over TypeId.
229        let ret_lt = if let Some(lt) = self.return_logical {
230            lt
231        } else if let Some(id) = self.return_type {
232            LogicalType::new(id)
233        } else {
234            return Err(ExtensionError::new("return type not set for function set"));
235        };
236
237        if self.overloads.is_empty() {
238            return Err(ExtensionError::new("no overloads added to function set"));
239        }
240
241        // SAFETY: Creates a new aggregate function set handle.
242        let mut set = unsafe { duckdb_create_aggregate_function_set(self.name.as_ptr()) };
243
244        let mut register_error: Option<ExtensionError> = None;
245
246        for overload in &self.overloads {
247            let Some(state_size) = overload.state_size else {
248                register_error = Some(ExtensionError::new("overload missing state_size"));
249                break;
250            };
251            let Some(init) = overload.init else {
252                register_error = Some(ExtensionError::new("overload missing init"));
253                break;
254            };
255            let Some(update) = overload.update else {
256                register_error = Some(ExtensionError::new("overload missing update"));
257                break;
258            };
259            let Some(combine) = overload.combine else {
260                register_error = Some(ExtensionError::new("overload missing combine"));
261                break;
262            };
263            let Some(finalize) = overload.finalize else {
264                register_error = Some(ExtensionError::new("overload missing finalize"));
265                break;
266            };
267
268            // SAFETY: Creates a new aggregate function handle for this overload.
269            let mut func = unsafe { duckdb_create_aggregate_function() };
270
271            // PITFALL L6: CRITICAL — must call this on EACH function, not just the set.
272            // Without this, duckdb_register_aggregate_function_set silently returns DuckDBError.
273            // Discovered by reading DuckDB's test/api/capi/test_capi_aggregate_functions.cpp.
274            unsafe {
275                duckdb_aggregate_function_set_name(func, self.name.as_ptr());
276            }
277
278            // Add parameters: merge simple TypeId params and complex LogicalType params
279            // in the order they were added (tracked by position).
280            {
281                let mut simple_idx = 0;
282                let mut logical_idx = 0;
283                let total = overload.params.len() + overload.logical_params.len();
284                for pos in 0..total {
285                    if logical_idx < overload.logical_params.len()
286                        && overload.logical_params[logical_idx].0 == pos
287                    {
288                        // SAFETY: func and logical type handle are valid.
289                        unsafe {
290                            libduckdb_sys::duckdb_aggregate_function_add_parameter(
291                                func,
292                                overload.logical_params[logical_idx].1.as_raw(),
293                            );
294                        }
295                        logical_idx += 1;
296                    } else if simple_idx < overload.params.len() {
297                        let lt = LogicalType::new(overload.params[simple_idx]);
298                        // SAFETY: func and lt.as_raw() are valid handles.
299                        unsafe {
300                            libduckdb_sys::duckdb_aggregate_function_add_parameter(
301                                func,
302                                lt.as_raw(),
303                            );
304                        }
305                        simple_idx += 1;
306                    }
307                }
308            }
309
310            // Set return type (shared across all overloads)
311            // SAFETY: func and ret_lt.as_raw() are valid.
312            unsafe {
313                duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
314            }
315
316            // Set callbacks
317            unsafe {
318                duckdb_aggregate_function_set_functions(
319                    func,
320                    Some(state_size),
321                    Some(init),
322                    Some(update),
323                    Some(combine),
324                    Some(finalize),
325                );
326            }
327
328            if let Some(dtor) = overload.destructor {
329                unsafe {
330                    duckdb_aggregate_function_set_destructor(func, Some(dtor));
331                }
332            }
333
334            // Set special NULL handling if requested
335            if overload.null_handling == NullHandling::SpecialNullHandling {
336                // SAFETY: func is a valid aggregate function handle.
337                unsafe {
338                    duckdb_aggregate_function_set_special_handling(func);
339                }
340            }
341
342            // Add this function to the set
343            // SAFETY: set and func are valid handles.
344            unsafe {
345                duckdb_add_aggregate_function_to_set(set, func);
346            }
347
348            // SAFETY: func was created above and ownership transferred to the set.
349            unsafe {
350                duckdb_destroy_aggregate_function(&raw mut func);
351            }
352        }
353
354        if register_error.is_none() {
355            // SAFETY: con is valid and set is fully configured.
356            let result = unsafe { duckdb_register_aggregate_function_set(con, set) };
357            if result != DuckDBSuccess {
358                register_error = Some(ExtensionError::new(format!(
359                    "duckdb_register_aggregate_function_set failed for '{}'",
360                    self.name.to_string_lossy()
361                )));
362            }
363        }
364
365        // SAFETY: set was created above and must be destroyed.
366        unsafe {
367            duckdb_destroy_aggregate_function_set(&raw mut set);
368        }
369
370        register_error.map_or(Ok(()), Err)
371    }
372}
373
374/// A builder for one overload within a [`AggregateFunctionSetBuilder`].
375///
376/// Returned by the closure passed to [`AggregateFunctionSetBuilder::overloads`].
377#[must_use]
378pub struct OverloadBuilder {
379    pub(super) params: Vec<TypeId>,
380    pub(super) logical_params: Vec<(usize, LogicalType)>,
381    pub(super) state_size: Option<StateSizeFn>,
382    pub(super) init: Option<StateInitFn>,
383    pub(super) update: Option<UpdateFn>,
384    pub(super) combine: Option<CombineFn>,
385    pub(super) finalize: Option<FinalizeFn>,
386    pub(super) destructor: Option<DestroyFn>,
387    pub(super) null_handling: NullHandling,
388}
389
390impl OverloadBuilder {
391    /// Creates a new `OverloadBuilder`.
392    pub(super) fn new() -> Self {
393        Self {
394            params: Vec::new(),
395            logical_params: Vec::new(),
396            state_size: None,
397            init: None,
398            update: None,
399            combine: None,
400            finalize: None,
401            destructor: None,
402            null_handling: NullHandling::DefaultNullHandling,
403        }
404    }
405
406    /// Adds a positional parameter to this overload.
407    ///
408    /// For complex types like `LIST(BIGINT)`, use
409    /// [`param_logical`][Self::param_logical].
410    pub fn param(mut self, type_id: TypeId) -> Self {
411        self.params.push(type_id);
412        self
413    }
414
415    /// Adds a positional parameter with a complex [`LogicalType`].
416    ///
417    /// Use this for parameterized types that [`TypeId`] cannot express, such as
418    /// `LIST(BIGINT)`, `MAP(VARCHAR, INTEGER)`, or `STRUCT(...)`.
419    #[mutants::skip] // position arithmetic tested via E2E
420    pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
421        let position = self.params.len() + self.logical_params.len();
422        self.logical_params.push((position, logical_type));
423        self
424    }
425
426    /// Sets the `state_size` callback for this overload.
427    pub fn state_size(mut self, f: StateSizeFn) -> Self {
428        self.state_size = Some(f);
429        self
430    }
431
432    /// Sets the `init` callback for this overload.
433    pub fn init(mut self, f: StateInitFn) -> Self {
434        self.init = Some(f);
435        self
436    }
437
438    /// Sets the `update` callback for this overload.
439    pub fn update(mut self, f: UpdateFn) -> Self {
440        self.update = Some(f);
441        self
442    }
443
444    /// Sets the `combine` callback for this overload.
445    pub fn combine(mut self, f: CombineFn) -> Self {
446        self.combine = Some(f);
447        self
448    }
449
450    /// Sets the `finalize` callback for this overload.
451    pub fn finalize(mut self, f: FinalizeFn) -> Self {
452        self.finalize = Some(f);
453        self
454    }
455
456    /// Sets the optional destructor callback for this overload.
457    pub fn destructor(mut self, f: DestroyFn) -> Self {
458        self.destructor = Some(f);
459        self
460    }
461
462    /// Sets the NULL handling behaviour for this overload.
463    ///
464    /// By default, `DuckDB` skips NULL rows in aggregate functions
465    /// ([`DefaultNullHandling`][NullHandling::DefaultNullHandling]).
466    /// Set to [`SpecialNullHandling`][NullHandling::SpecialNullHandling] to receive
467    /// NULL values in your `update` callback.
468    pub const fn null_handling(mut self, handling: NullHandling) -> Self {
469        self.null_handling = handling;
470        self
471    }
472}