Skip to main content

quack_rs/aggregate/builder/
single.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6use std::ffi::CString;
7use std::os::raw::c_void;
8
9use libduckdb_sys::{
10    duckdb_aggregate_function_set_destructor, duckdb_aggregate_function_set_extra_info,
11    duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
12    duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
13    duckdb_connection, duckdb_create_aggregate_function, duckdb_delete_callback_t,
14    duckdb_destroy_aggregate_function, duckdb_register_aggregate_function, DuckDBSuccess,
15};
16
17use crate::aggregate::callbacks::{
18    CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
19};
20use crate::error::ExtensionError;
21use crate::types::{LogicalType, NullHandling, TypeId};
22use crate::validate::validate_function_name;
23
24/// Builder for registering a single-signature `DuckDB` aggregate function.
25///
26/// # Pitfall L6
27///
28/// Unlike `duckdb_register_aggregate_function`, this builder also handles
29/// the case where you later want to convert to a function set — it sets the
30/// function name correctly.
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use quack_rs::aggregate::AggregateFunctionBuilder;
36/// use quack_rs::types::TypeId;
37/// use libduckdb_sys::{duckdb_connection, duckdb_function_info, duckdb_aggregate_state,
38///                     duckdb_data_chunk, duckdb_vector, idx_t};
39///
40/// unsafe extern "C" fn state_size(_: duckdb_function_info) -> idx_t { 8 }
41/// unsafe extern "C" fn state_init(_: duckdb_function_info, _: duckdb_aggregate_state) {}
42/// unsafe extern "C" fn update(_: duckdb_function_info, _: duckdb_data_chunk, _: duckdb_aggregate_state) {}
43/// unsafe extern "C" fn combine(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_aggregate_state, _: idx_t) {}
44/// unsafe extern "C" fn finalize(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_vector, _: idx_t, _: idx_t) {}
45///
46/// // fn register(con: duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
47/// //     AggregateFunctionBuilder::new("word_count")
48/// //         .param(TypeId::Varchar)
49/// //         .returns(TypeId::BigInt)
50/// //         .state_size(state_size)
51/// //         .init(state_init)
52/// //         .update(update)
53/// //         .combine(combine)
54/// //         .finalize(finalize)
55/// //         .register(con)
56/// // }
57/// ```
58#[must_use]
59pub struct AggregateFunctionBuilder {
60    pub(super) name: CString,
61    pub(super) params: Vec<TypeId>,
62    pub(super) logical_params: Vec<(usize, LogicalType)>,
63    pub(super) return_type: Option<TypeId>,
64    pub(super) return_logical: Option<LogicalType>,
65    pub(super) state_size: Option<StateSizeFn>,
66    pub(super) init: Option<StateInitFn>,
67    pub(super) update: Option<UpdateFn>,
68    pub(super) combine: Option<CombineFn>,
69    pub(super) finalize: Option<FinalizeFn>,
70    pub(super) destructor: Option<DestroyFn>,
71    pub(super) null_handling: NullHandling,
72    pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
73}
74
75impl AggregateFunctionBuilder {
76    /// Creates a new builder for an aggregate function with the given name.
77    ///
78    /// # Panics
79    ///
80    /// Panics if `name` contains an interior null byte.
81    pub fn new(name: &str) -> Self {
82        Self {
83            name: CString::new(name).expect("function name must not contain null bytes"),
84            params: Vec::new(),
85            logical_params: Vec::new(),
86            return_type: None,
87            return_logical: None,
88            state_size: None,
89            init: None,
90            update: None,
91            combine: None,
92            finalize: None,
93            destructor: None,
94            null_handling: NullHandling::DefaultNullHandling,
95            extra_info: None,
96        }
97    }
98
99    /// Creates a new builder with function name validation.
100    ///
101    /// Unlike [`new`][Self::new], this method validates the function name against
102    /// `DuckDB` naming conventions and returns an error instead of panicking.
103    ///
104    /// # Errors
105    ///
106    /// Returns `ExtensionError` if the name is empty, too long, contains invalid
107    /// characters, or does not start with a lowercase letter or underscore.
108    pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
109        validate_function_name(name)?;
110        let c_name = CString::new(name)
111            .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
112        Ok(Self {
113            name: c_name,
114            params: Vec::new(),
115            logical_params: Vec::new(),
116            return_type: None,
117            return_logical: None,
118            state_size: None,
119            init: None,
120            update: None,
121            combine: None,
122            finalize: None,
123            destructor: None,
124            null_handling: NullHandling::DefaultNullHandling,
125            extra_info: None,
126        })
127    }
128
129    /// Returns the function name.
130    ///
131    /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
132    pub fn name(&self) -> &str {
133        self.name.to_str().unwrap_or("")
134    }
135
136    /// Adds a positional parameter with the given type.
137    ///
138    /// Call this once per parameter in order. For complex types like
139    /// `LIST(BIGINT)` or `MAP(VARCHAR, INTEGER)`, use [`param_logical`][Self::param_logical].
140    pub fn param(mut self, type_id: TypeId) -> Self {
141        self.params.push(type_id);
142        self
143    }
144
145    /// Adds a positional parameter with a complex [`LogicalType`].
146    ///
147    /// Use this for parameterized types that [`TypeId`] cannot express, such as
148    /// `LIST(BIGINT)`, `MAP(VARCHAR, INTEGER)`, or `STRUCT(...)`.
149    ///
150    /// The parameter position is determined by the total number of `param` and
151    /// `param_logical` calls made so far.
152    ///
153    /// # Example
154    ///
155    /// ```rust,no_run
156    /// use quack_rs::aggregate::AggregateFunctionBuilder;
157    /// use quack_rs::types::{LogicalType, TypeId};
158    ///
159    /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
160    /// //     AggregateFunctionBuilder::new("my_func")
161    /// //         .param(TypeId::Varchar)
162    /// //         .param_logical(LogicalType::list(TypeId::BigInt))
163    /// //         .returns(TypeId::BigInt)
164    /// //         // ... callbacks ...
165    /// //         ;
166    /// //     Ok(())
167    /// // }
168    /// ```
169    #[mutants::skip] // position arithmetic tested via E2E
170    pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
171        let position = self.params.len() + self.logical_params.len();
172        self.logical_params.push((position, logical_type));
173        self
174    }
175
176    /// Sets the return type for this function.
177    ///
178    /// For complex return types like `LIST(BIGINT)`, use
179    /// [`returns_logical`][Self::returns_logical] instead.
180    pub const fn returns(mut self, type_id: TypeId) -> Self {
181        self.return_type = Some(type_id);
182        self
183    }
184
185    /// Sets the return type to a complex [`LogicalType`].
186    ///
187    /// Use this for parameterized return types that [`TypeId`] cannot express,
188    /// such as `LIST(BOOLEAN)`, `LIST(TIMESTAMP)`, `MAP(VARCHAR, INTEGER)`, etc.
189    ///
190    /// If both `returns` and `returns_logical` are called, the logical type takes
191    /// precedence.
192    ///
193    /// # Example
194    ///
195    /// ```rust,no_run
196    /// use quack_rs::aggregate::AggregateFunctionBuilder;
197    /// use quack_rs::types::{LogicalType, TypeId};
198    ///
199    /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
200    /// //     AggregateFunctionBuilder::new("retention")
201    /// //         .param(TypeId::Boolean)
202    /// //         .param(TypeId::Boolean)
203    /// //         .returns_logical(LogicalType::list(TypeId::Boolean))
204    /// //         // ... callbacks ...
205    /// //         ;
206    /// //     Ok(())
207    /// // }
208    /// ```
209    pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
210        self.return_logical = Some(logical_type);
211        self
212    }
213
214    /// Sets the `state_size` callback.
215    pub fn state_size(mut self, f: StateSizeFn) -> Self {
216        self.state_size = Some(f);
217        self
218    }
219
220    /// Sets the `state_init` callback.
221    pub fn init(mut self, f: StateInitFn) -> Self {
222        self.init = Some(f);
223        self
224    }
225
226    /// Sets the `update` callback.
227    pub fn update(mut self, f: UpdateFn) -> Self {
228        self.update = Some(f);
229        self
230    }
231
232    /// Sets the `combine` callback.
233    pub fn combine(mut self, f: CombineFn) -> Self {
234        self.combine = Some(f);
235        self
236    }
237
238    /// Sets the `finalize` callback.
239    pub fn finalize(mut self, f: FinalizeFn) -> Self {
240        self.finalize = Some(f);
241        self
242    }
243
244    /// Sets the optional `destructor` callback.
245    ///
246    /// Required if your state allocates heap memory (e.g., when using
247    /// [`FfiState<T>`][crate::aggregate::FfiState]).
248    pub fn destructor(mut self, f: DestroyFn) -> Self {
249        self.destructor = Some(f);
250        self
251    }
252
253    /// Sets the NULL handling behaviour for this aggregate function.
254    ///
255    /// By default, `DuckDB` skips NULL rows in aggregate functions
256    /// ([`DefaultNullHandling`][NullHandling::DefaultNullHandling]).
257    /// Set to [`SpecialNullHandling`][NullHandling::SpecialNullHandling] to receive
258    /// NULL values in your `update` callback.
259    pub const fn null_handling(mut self, handling: NullHandling) -> Self {
260        self.null_handling = handling;
261        self
262    }
263
264    /// Attaches arbitrary data to this aggregate function.
265    ///
266    /// The data pointer is available inside callbacks via
267    /// `duckdb_aggregate_function_get_extra_info`. The `destroy` callback is
268    /// called by `DuckDB` when the function is dropped to free the data.
269    ///
270    /// # Safety
271    ///
272    /// `data` must point to valid memory that outlives the function registration,
273    /// or will be freed by `destroy`. The typical pattern
274    /// is to box your data: `Box::into_raw(Box::new(my_data)).cast()`.
275    pub unsafe fn extra_info(
276        mut self,
277        data: *mut c_void,
278        destroy: duckdb_delete_callback_t,
279    ) -> Self {
280        self.extra_info = Some((data, destroy));
281        self
282    }
283
284    /// Registers the aggregate function on the given connection.
285    ///
286    /// # Errors
287    ///
288    /// Returns `ExtensionError` if:
289    /// - The return type was not set.
290    /// - Any required callback was not set.
291    /// - `DuckDB` reports a registration failure.
292    ///
293    /// # Safety
294    ///
295    /// `con` must be a valid, open `duckdb_connection`.
296    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
297        // Resolve return type: prefer explicit LogicalType over TypeId.
298        let ret_lt = if let Some(lt) = self.return_logical {
299            lt
300        } else if let Some(id) = self.return_type {
301            LogicalType::new(id)
302        } else {
303            return Err(ExtensionError::new("return type not set"));
304        };
305
306        let state_size = self
307            .state_size
308            .ok_or_else(|| ExtensionError::new("state_size callback not set"))?;
309        let init = self
310            .init
311            .ok_or_else(|| ExtensionError::new("init callback not set"))?;
312        let update = self
313            .update
314            .ok_or_else(|| ExtensionError::new("update callback not set"))?;
315        let combine = self
316            .combine
317            .ok_or_else(|| ExtensionError::new("combine callback not set"))?;
318        let finalize = self
319            .finalize
320            .ok_or_else(|| ExtensionError::new("finalize callback not set"))?;
321
322        // SAFETY: duckdb_create_aggregate_function allocates a new function handle.
323        let mut func = unsafe { duckdb_create_aggregate_function() };
324
325        // SAFETY: func is a valid newly created function handle.
326        unsafe {
327            duckdb_aggregate_function_set_name(func, self.name.as_ptr());
328        }
329
330        // Add parameters: merge simple TypeId params and complex LogicalType params
331        // in the order they were added (tracked by position).
332        {
333            let mut simple_idx = 0;
334            let mut logical_idx = 0;
335            let total = self.params.len() + self.logical_params.len();
336            for pos in 0..total {
337                if logical_idx < self.logical_params.len()
338                    && self.logical_params[logical_idx].0 == pos
339                {
340                    // SAFETY: func and logical type handle are valid.
341                    unsafe {
342                        libduckdb_sys::duckdb_aggregate_function_add_parameter(
343                            func,
344                            self.logical_params[logical_idx].1.as_raw(),
345                        );
346                    }
347                    logical_idx += 1;
348                } else if simple_idx < self.params.len() {
349                    let lt = LogicalType::new(self.params[simple_idx]);
350                    // SAFETY: func and lt.as_raw() are valid.
351                    unsafe {
352                        libduckdb_sys::duckdb_aggregate_function_add_parameter(func, lt.as_raw());
353                    }
354                    simple_idx += 1;
355                }
356            }
357        }
358
359        // Set return type
360        // SAFETY: func and ret_lt.as_raw() are valid.
361        unsafe {
362            duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
363        }
364
365        // Set callbacks
366        // SAFETY: All function pointers are valid extern "C" fn pointers.
367        unsafe {
368            duckdb_aggregate_function_set_functions(
369                func,
370                Some(state_size),
371                Some(init),
372                Some(update),
373                Some(combine),
374                Some(finalize),
375            );
376        }
377
378        if let Some(dtor) = self.destructor {
379            // SAFETY: dtor is a valid extern "C" fn pointer.
380            unsafe {
381                duckdb_aggregate_function_set_destructor(func, Some(dtor));
382            }
383        }
384
385        // Set special NULL handling if requested
386        if self.null_handling == NullHandling::SpecialNullHandling {
387            // SAFETY: func is a valid aggregate function handle.
388            unsafe {
389                duckdb_aggregate_function_set_special_handling(func);
390            }
391        }
392
393        // Set extra info if provided
394        if let Some((data, destroy)) = self.extra_info {
395            // SAFETY: func is valid; data and destroy are provided by caller.
396            unsafe {
397                duckdb_aggregate_function_set_extra_info(func, data, destroy);
398            }
399        }
400
401        // Register
402        // SAFETY: con is a valid open connection, func is fully configured.
403        let result = unsafe { duckdb_register_aggregate_function(con, func) };
404
405        // SAFETY: func was created above and must be destroyed after use.
406        unsafe {
407            duckdb_destroy_aggregate_function(&raw mut func);
408        }
409
410        if result == DuckDBSuccess {
411            Ok(())
412        } else {
413            Err(ExtensionError::new(format!(
414                "duckdb_register_aggregate_function failed for '{}'",
415                self.name.to_string_lossy()
416            )))
417        }
418    }
419}