Skip to main content

quack_rs/aggregate/builder/
single.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6use std::ffi::CString;
7use std::os::raw::c_void;
8
9use libduckdb_sys::{
10    duckdb_aggregate_function_set_destructor, duckdb_aggregate_function_set_extra_info,
11    duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
12    duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
13    duckdb_connection, duckdb_create_aggregate_function, duckdb_delete_callback_t,
14    duckdb_destroy_aggregate_function, duckdb_register_aggregate_function, DuckDBSuccess,
15};
16
17use crate::aggregate::callbacks::{
18    CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
19};
20use crate::error::ExtensionError;
21use crate::types::{LogicalType, NullHandling, TypeId};
22use crate::validate::validate_function_name;
23
24/// Builder for registering a single-signature `DuckDB` aggregate function.
25///
26/// # Pitfall L6
27///
28/// Unlike `duckdb_register_aggregate_function`, this builder also handles
29/// the case where you later want to convert to a function set — it sets the
30/// function name correctly.
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use quack_rs::aggregate::AggregateFunctionBuilder;
36/// use quack_rs::types::TypeId;
37/// use libduckdb_sys::{duckdb_connection, duckdb_function_info, duckdb_aggregate_state,
38///                     duckdb_data_chunk, duckdb_vector, idx_t};
39///
40/// unsafe extern "C" fn state_size(_: duckdb_function_info) -> idx_t { 8 }
41/// unsafe extern "C" fn state_init(_: duckdb_function_info, _: duckdb_aggregate_state) {}
42/// unsafe extern "C" fn update(_: duckdb_function_info, _: duckdb_data_chunk, _: duckdb_aggregate_state) {}
43/// unsafe extern "C" fn combine(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_aggregate_state, _: idx_t) {}
44/// unsafe extern "C" fn finalize(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_vector, _: idx_t, _: idx_t) {}
45///
46/// // fn register(con: duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
47/// //     AggregateFunctionBuilder::new("word_count")
48/// //         .param(TypeId::Varchar)
49/// //         .returns(TypeId::BigInt)
50/// //         .state_size(state_size)
51/// //         .init(state_init)
52/// //         .update(update)
53/// //         .combine(combine)
54/// //         .finalize(finalize)
55/// //         .register(con)
56/// // }
57/// ```
58#[must_use]
59pub struct AggregateFunctionBuilder {
60    pub(super) name: CString,
61    pub(super) params: Vec<TypeId>,
62    pub(super) logical_params: Vec<(usize, LogicalType)>,
63    pub(super) return_type: Option<TypeId>,
64    pub(super) return_logical: Option<LogicalType>,
65    pub(super) state_size: Option<StateSizeFn>,
66    pub(super) init: Option<StateInitFn>,
67    pub(super) update: Option<UpdateFn>,
68    pub(super) combine: Option<CombineFn>,
69    pub(super) finalize: Option<FinalizeFn>,
70    pub(super) destructor: Option<DestroyFn>,
71    pub(super) null_handling: NullHandling,
72    pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
73}
74
75impl AggregateFunctionBuilder {
76    /// Creates a new builder for an aggregate function with the given name.
77    ///
78    /// # Panics
79    ///
80    /// Panics if `name` contains an interior null byte.
81    pub fn new(name: &str) -> Self {
82        Self {
83            name: CString::new(name).expect("function name must not contain null bytes"),
84            params: Vec::new(),
85            logical_params: Vec::new(),
86            return_type: None,
87            return_logical: None,
88            state_size: None,
89            init: None,
90            update: None,
91            combine: None,
92            finalize: None,
93            destructor: None,
94            null_handling: NullHandling::DefaultNullHandling,
95            extra_info: None,
96        }
97    }
98
99    /// Creates a new builder with function name validation.
100    ///
101    /// Unlike [`new`][Self::new], this method validates the function name against
102    /// `DuckDB` naming conventions and returns an error instead of panicking.
103    ///
104    /// # Errors
105    ///
106    /// Returns `ExtensionError` if the name is empty, too long, contains invalid
107    /// characters, or does not start with a lowercase letter or underscore.
108    pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
109        validate_function_name(name)?;
110        let c_name = CString::new(name)
111            .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
112        Ok(Self {
113            name: c_name,
114            params: Vec::new(),
115            logical_params: Vec::new(),
116            return_type: None,
117            return_logical: None,
118            state_size: None,
119            init: None,
120            update: None,
121            combine: None,
122            finalize: None,
123            destructor: None,
124            null_handling: NullHandling::DefaultNullHandling,
125            extra_info: None,
126        })
127    }
128
129    /// Returns the function name.
130    ///
131    /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
132    pub fn name(&self) -> &str {
133        self.name.to_str().unwrap_or("")
134    }
135
136    /// Adds a positional parameter with the given type.
137    ///
138    /// Call this once per parameter in order. For complex types like
139    /// `LIST(BIGINT)` or `MAP(VARCHAR, INTEGER)`, use [`param_logical`][Self::param_logical].
140    pub fn param(mut self, type_id: TypeId) -> Self {
141        self.params.push(type_id);
142        self
143    }
144
145    /// Adds a positional parameter with a complex [`LogicalType`].
146    ///
147    /// Use this for parameterized types that [`TypeId`] cannot express, such as
148    /// `LIST(BIGINT)`, `MAP(VARCHAR, INTEGER)`, or `STRUCT(...)`.
149    ///
150    /// The parameter position is determined by the total number of `param` and
151    /// `param_logical` calls made so far.
152    ///
153    /// # Example
154    ///
155    /// ```rust,no_run
156    /// use quack_rs::aggregate::AggregateFunctionBuilder;
157    /// use quack_rs::types::{LogicalType, TypeId};
158    ///
159    /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
160    /// //     AggregateFunctionBuilder::new("my_func")
161    /// //         .param(TypeId::Varchar)
162    /// //         .param_logical(LogicalType::list(TypeId::BigInt))
163    /// //         .returns(TypeId::BigInt)
164    /// //         // ... callbacks ...
165    /// //         ;
166    /// //     Ok(())
167    /// // }
168    /// ```
169    pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
170        let position = self.params.len() + self.logical_params.len();
171        self.logical_params.push((position, logical_type));
172        self
173    }
174
175    /// Sets the return type for this function.
176    ///
177    /// For complex return types like `LIST(BIGINT)`, use
178    /// [`returns_logical`][Self::returns_logical] instead.
179    pub const fn returns(mut self, type_id: TypeId) -> Self {
180        self.return_type = Some(type_id);
181        self
182    }
183
184    /// Sets the return type to a complex [`LogicalType`].
185    ///
186    /// Use this for parameterized return types that [`TypeId`] cannot express,
187    /// such as `LIST(BOOLEAN)`, `LIST(TIMESTAMP)`, `MAP(VARCHAR, INTEGER)`, etc.
188    ///
189    /// If both `returns` and `returns_logical` are called, the logical type takes
190    /// precedence.
191    ///
192    /// # Example
193    ///
194    /// ```rust,no_run
195    /// use quack_rs::aggregate::AggregateFunctionBuilder;
196    /// use quack_rs::types::{LogicalType, TypeId};
197    ///
198    /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
199    /// //     AggregateFunctionBuilder::new("retention")
200    /// //         .param(TypeId::Boolean)
201    /// //         .param(TypeId::Boolean)
202    /// //         .returns_logical(LogicalType::list(TypeId::Boolean))
203    /// //         // ... callbacks ...
204    /// //         ;
205    /// //     Ok(())
206    /// // }
207    /// ```
208    pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
209        self.return_logical = Some(logical_type);
210        self
211    }
212
213    /// Sets the `state_size` callback.
214    pub fn state_size(mut self, f: StateSizeFn) -> Self {
215        self.state_size = Some(f);
216        self
217    }
218
219    /// Sets the `state_init` callback.
220    pub fn init(mut self, f: StateInitFn) -> Self {
221        self.init = Some(f);
222        self
223    }
224
225    /// Sets the `update` callback.
226    pub fn update(mut self, f: UpdateFn) -> Self {
227        self.update = Some(f);
228        self
229    }
230
231    /// Sets the `combine` callback.
232    pub fn combine(mut self, f: CombineFn) -> Self {
233        self.combine = Some(f);
234        self
235    }
236
237    /// Sets the `finalize` callback.
238    pub fn finalize(mut self, f: FinalizeFn) -> Self {
239        self.finalize = Some(f);
240        self
241    }
242
243    /// Sets the optional `destructor` callback.
244    ///
245    /// Required if your state allocates heap memory (e.g., when using
246    /// [`FfiState<T>`][crate::aggregate::FfiState]).
247    pub fn destructor(mut self, f: DestroyFn) -> Self {
248        self.destructor = Some(f);
249        self
250    }
251
252    /// Sets the NULL handling behaviour for this aggregate function.
253    ///
254    /// By default, `DuckDB` skips NULL rows in aggregate functions
255    /// ([`DefaultNullHandling`][NullHandling::DefaultNullHandling]).
256    /// Set to [`SpecialNullHandling`][NullHandling::SpecialNullHandling] to receive
257    /// NULL values in your `update` callback.
258    pub const fn null_handling(mut self, handling: NullHandling) -> Self {
259        self.null_handling = handling;
260        self
261    }
262
263    /// Attaches arbitrary data to this aggregate function.
264    ///
265    /// The data pointer is available inside callbacks via
266    /// `duckdb_aggregate_function_get_extra_info`. The `destroy` callback is
267    /// called by `DuckDB` when the function is dropped to free the data.
268    ///
269    /// # Safety
270    ///
271    /// `data` must point to valid memory that outlives the function registration,
272    /// or will be freed by `destroy`. The typical pattern
273    /// is to box your data: `Box::into_raw(Box::new(my_data)).cast()`.
274    pub unsafe fn extra_info(
275        mut self,
276        data: *mut c_void,
277        destroy: duckdb_delete_callback_t,
278    ) -> Self {
279        self.extra_info = Some((data, destroy));
280        self
281    }
282
283    /// Registers the aggregate function on the given connection.
284    ///
285    /// # Errors
286    ///
287    /// Returns `ExtensionError` if:
288    /// - The return type was not set.
289    /// - Any required callback was not set.
290    /// - `DuckDB` reports a registration failure.
291    ///
292    /// # Safety
293    ///
294    /// `con` must be a valid, open `duckdb_connection`.
295    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
296        // Resolve return type: prefer explicit LogicalType over TypeId.
297        let ret_lt = if let Some(lt) = self.return_logical {
298            lt
299        } else if let Some(id) = self.return_type {
300            LogicalType::new(id)
301        } else {
302            return Err(ExtensionError::new("return type not set"));
303        };
304
305        let state_size = self
306            .state_size
307            .ok_or_else(|| ExtensionError::new("state_size callback not set"))?;
308        let init = self
309            .init
310            .ok_or_else(|| ExtensionError::new("init callback not set"))?;
311        let update = self
312            .update
313            .ok_or_else(|| ExtensionError::new("update callback not set"))?;
314        let combine = self
315            .combine
316            .ok_or_else(|| ExtensionError::new("combine callback not set"))?;
317        let finalize = self
318            .finalize
319            .ok_or_else(|| ExtensionError::new("finalize callback not set"))?;
320
321        // SAFETY: duckdb_create_aggregate_function allocates a new function handle.
322        let func = unsafe { duckdb_create_aggregate_function() };
323
324        // SAFETY: func is a valid newly created function handle.
325        unsafe {
326            duckdb_aggregate_function_set_name(func, self.name.as_ptr());
327        }
328
329        // Add parameters: merge simple TypeId params and complex LogicalType params
330        // in the order they were added (tracked by position).
331        {
332            let mut simple_idx = 0;
333            let mut logical_idx = 0;
334            let total = self.params.len() + self.logical_params.len();
335            for pos in 0..total {
336                if logical_idx < self.logical_params.len()
337                    && self.logical_params[logical_idx].0 == pos
338                {
339                    // SAFETY: func and logical type handle are valid.
340                    unsafe {
341                        libduckdb_sys::duckdb_aggregate_function_add_parameter(
342                            func,
343                            self.logical_params[logical_idx].1.as_raw(),
344                        );
345                    }
346                    logical_idx += 1;
347                } else if simple_idx < self.params.len() {
348                    let lt = LogicalType::new(self.params[simple_idx]);
349                    // SAFETY: func and lt.as_raw() are valid.
350                    unsafe {
351                        libduckdb_sys::duckdb_aggregate_function_add_parameter(func, lt.as_raw());
352                    }
353                    simple_idx += 1;
354                }
355            }
356        }
357
358        // Set return type
359        // SAFETY: func and ret_lt.as_raw() are valid.
360        unsafe {
361            duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
362        }
363
364        // Set callbacks
365        // SAFETY: All function pointers are valid extern "C" fn pointers.
366        unsafe {
367            duckdb_aggregate_function_set_functions(
368                func,
369                Some(state_size),
370                Some(init),
371                Some(update),
372                Some(combine),
373                Some(finalize),
374            );
375        }
376
377        if let Some(dtor) = self.destructor {
378            // SAFETY: dtor is a valid extern "C" fn pointer.
379            unsafe {
380                duckdb_aggregate_function_set_destructor(func, Some(dtor));
381            }
382        }
383
384        // Set special NULL handling if requested
385        if self.null_handling == NullHandling::SpecialNullHandling {
386            // SAFETY: func is a valid aggregate function handle.
387            unsafe {
388                duckdb_aggregate_function_set_special_handling(func);
389            }
390        }
391
392        // Set extra info if provided
393        if let Some((data, destroy)) = self.extra_info {
394            // SAFETY: func is valid; data and destroy are provided by caller.
395            unsafe {
396                duckdb_aggregate_function_set_extra_info(func, data, destroy);
397            }
398        }
399
400        // Register
401        // SAFETY: con is a valid open connection, func is fully configured.
402        let result = unsafe { duckdb_register_aggregate_function(con, func) };
403
404        // SAFETY: func was created above and must be destroyed after use.
405        unsafe {
406            duckdb_destroy_aggregate_function(&mut { func });
407        }
408
409        if result == DuckDBSuccess {
410            Ok(())
411        } else {
412            Err(ExtensionError::new(format!(
413                "duckdb_register_aggregate_function failed for '{}'",
414                self.name.to_string_lossy()
415            )))
416        }
417    }
418}