quack_rs/aggregate/builder/single.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6use std::ffi::CString;
7use std::os::raw::c_void;
8
9use libduckdb_sys::{
10 duckdb_aggregate_function_set_destructor, duckdb_aggregate_function_set_extra_info,
11 duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
12 duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
13 duckdb_connection, duckdb_create_aggregate_function, duckdb_delete_callback_t,
14 duckdb_destroy_aggregate_function, duckdb_register_aggregate_function, DuckDBSuccess,
15};
16
17use crate::aggregate::callbacks::{
18 CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
19};
20use crate::error::ExtensionError;
21use crate::types::{LogicalType, NullHandling, TypeId};
22use crate::validate::validate_function_name;
23
24/// Builder for registering a single-signature `DuckDB` aggregate function.
25///
26/// # Pitfall L6
27///
28/// Unlike `duckdb_register_aggregate_function`, this builder also handles
29/// the case where you later want to convert to a function set — it sets the
30/// function name correctly.
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use quack_rs::aggregate::AggregateFunctionBuilder;
36/// use quack_rs::types::TypeId;
37/// use libduckdb_sys::{duckdb_connection, duckdb_function_info, duckdb_aggregate_state,
38/// duckdb_data_chunk, duckdb_vector, idx_t};
39///
40/// unsafe extern "C" fn state_size(_: duckdb_function_info) -> idx_t { 8 }
41/// unsafe extern "C" fn state_init(_: duckdb_function_info, _: duckdb_aggregate_state) {}
42/// unsafe extern "C" fn update(_: duckdb_function_info, _: duckdb_data_chunk, _: duckdb_aggregate_state) {}
43/// unsafe extern "C" fn combine(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_aggregate_state, _: idx_t) {}
44/// unsafe extern "C" fn finalize(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_vector, _: idx_t, _: idx_t) {}
45///
46/// // fn register(con: duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
47/// // AggregateFunctionBuilder::new("word_count")
48/// // .param(TypeId::Varchar)
49/// // .returns(TypeId::BigInt)
50/// // .state_size(state_size)
51/// // .init(state_init)
52/// // .update(update)
53/// // .combine(combine)
54/// // .finalize(finalize)
55/// // .register(con)
56/// // }
57/// ```
58#[must_use]
59pub struct AggregateFunctionBuilder {
60 pub(super) name: CString,
61 pub(super) params: Vec<TypeId>,
62 pub(super) logical_params: Vec<(usize, LogicalType)>,
63 pub(super) return_type: Option<TypeId>,
64 pub(super) return_logical: Option<LogicalType>,
65 pub(super) state_size: Option<StateSizeFn>,
66 pub(super) init: Option<StateInitFn>,
67 pub(super) update: Option<UpdateFn>,
68 pub(super) combine: Option<CombineFn>,
69 pub(super) finalize: Option<FinalizeFn>,
70 pub(super) destructor: Option<DestroyFn>,
71 pub(super) null_handling: NullHandling,
72 pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
73}
74
75impl AggregateFunctionBuilder {
76 /// Creates a new builder for an aggregate function with the given name.
77 ///
78 /// # Panics
79 ///
80 /// Panics if `name` contains an interior null byte.
81 pub fn new(name: &str) -> Self {
82 Self {
83 name: CString::new(name).expect("function name must not contain null bytes"),
84 params: Vec::new(),
85 logical_params: Vec::new(),
86 return_type: None,
87 return_logical: None,
88 state_size: None,
89 init: None,
90 update: None,
91 combine: None,
92 finalize: None,
93 destructor: None,
94 null_handling: NullHandling::DefaultNullHandling,
95 extra_info: None,
96 }
97 }
98
99 /// Creates a new builder with function name validation.
100 ///
101 /// Unlike [`new`][Self::new], this method validates the function name against
102 /// `DuckDB` naming conventions and returns an error instead of panicking.
103 ///
104 /// # Errors
105 ///
106 /// Returns `ExtensionError` if the name is empty, too long, contains invalid
107 /// characters, or does not start with a lowercase letter or underscore.
108 pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
109 validate_function_name(name)?;
110 let c_name = CString::new(name)
111 .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
112 Ok(Self {
113 name: c_name,
114 params: Vec::new(),
115 logical_params: Vec::new(),
116 return_type: None,
117 return_logical: None,
118 state_size: None,
119 init: None,
120 update: None,
121 combine: None,
122 finalize: None,
123 destructor: None,
124 null_handling: NullHandling::DefaultNullHandling,
125 extra_info: None,
126 })
127 }
128
129 /// Returns the function name.
130 ///
131 /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
132 pub fn name(&self) -> &str {
133 self.name.to_str().unwrap_or("")
134 }
135
136 /// Adds a positional parameter with the given type.
137 ///
138 /// Call this once per parameter in order. For complex types like
139 /// `LIST(BIGINT)` or `MAP(VARCHAR, INTEGER)`, use [`param_logical`][Self::param_logical].
140 pub fn param(mut self, type_id: TypeId) -> Self {
141 self.params.push(type_id);
142 self
143 }
144
145 /// Adds a positional parameter with a complex [`LogicalType`].
146 ///
147 /// Use this for parameterized types that [`TypeId`] cannot express, such as
148 /// `LIST(BIGINT)`, `MAP(VARCHAR, INTEGER)`, or `STRUCT(...)`.
149 ///
150 /// The parameter position is determined by the total number of `param` and
151 /// `param_logical` calls made so far.
152 ///
153 /// # Example
154 ///
155 /// ```rust,no_run
156 /// use quack_rs::aggregate::AggregateFunctionBuilder;
157 /// use quack_rs::types::{LogicalType, TypeId};
158 ///
159 /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
160 /// // AggregateFunctionBuilder::new("my_func")
161 /// // .param(TypeId::Varchar)
162 /// // .param_logical(LogicalType::list(TypeId::BigInt))
163 /// // .returns(TypeId::BigInt)
164 /// // // ... callbacks ...
165 /// // ;
166 /// // Ok(())
167 /// // }
168 /// ```
169 #[mutants::skip] // position arithmetic tested via E2E
170 pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
171 let position = self.params.len() + self.logical_params.len();
172 self.logical_params.push((position, logical_type));
173 self
174 }
175
176 /// Sets the return type for this function.
177 ///
178 /// For complex return types like `LIST(BIGINT)`, use
179 /// [`returns_logical`][Self::returns_logical] instead.
180 pub const fn returns(mut self, type_id: TypeId) -> Self {
181 self.return_type = Some(type_id);
182 self
183 }
184
185 /// Sets the return type to a complex [`LogicalType`].
186 ///
187 /// Use this for parameterized return types that [`TypeId`] cannot express,
188 /// such as `LIST(BOOLEAN)`, `LIST(TIMESTAMP)`, `MAP(VARCHAR, INTEGER)`, etc.
189 ///
190 /// If both `returns` and `returns_logical` are called, the logical type takes
191 /// precedence.
192 ///
193 /// # Example
194 ///
195 /// ```rust,no_run
196 /// use quack_rs::aggregate::AggregateFunctionBuilder;
197 /// use quack_rs::types::{LogicalType, TypeId};
198 ///
199 /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
200 /// // AggregateFunctionBuilder::new("retention")
201 /// // .param(TypeId::Boolean)
202 /// // .param(TypeId::Boolean)
203 /// // .returns_logical(LogicalType::list(TypeId::Boolean))
204 /// // // ... callbacks ...
205 /// // ;
206 /// // Ok(())
207 /// // }
208 /// ```
209 pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
210 self.return_logical = Some(logical_type);
211 self
212 }
213
214 /// Sets the `state_size` callback.
215 pub fn state_size(mut self, f: StateSizeFn) -> Self {
216 self.state_size = Some(f);
217 self
218 }
219
220 /// Sets the `state_init` callback.
221 pub fn init(mut self, f: StateInitFn) -> Self {
222 self.init = Some(f);
223 self
224 }
225
226 /// Sets the `update` callback.
227 pub fn update(mut self, f: UpdateFn) -> Self {
228 self.update = Some(f);
229 self
230 }
231
232 /// Sets the `combine` callback.
233 pub fn combine(mut self, f: CombineFn) -> Self {
234 self.combine = Some(f);
235 self
236 }
237
238 /// Sets the `finalize` callback.
239 pub fn finalize(mut self, f: FinalizeFn) -> Self {
240 self.finalize = Some(f);
241 self
242 }
243
244 /// Sets the optional `destructor` callback.
245 ///
246 /// Required if your state allocates heap memory (e.g., when using
247 /// [`FfiState<T>`][crate::aggregate::FfiState]).
248 pub fn destructor(mut self, f: DestroyFn) -> Self {
249 self.destructor = Some(f);
250 self
251 }
252
253 /// Sets the NULL handling behaviour for this aggregate function.
254 ///
255 /// By default, `DuckDB` skips NULL rows in aggregate functions
256 /// ([`DefaultNullHandling`][NullHandling::DefaultNullHandling]).
257 /// Set to [`SpecialNullHandling`][NullHandling::SpecialNullHandling] to receive
258 /// NULL values in your `update` callback.
259 pub const fn null_handling(mut self, handling: NullHandling) -> Self {
260 self.null_handling = handling;
261 self
262 }
263
264 /// Attaches arbitrary data to this aggregate function.
265 ///
266 /// The data pointer is available inside callbacks via
267 /// `duckdb_aggregate_function_get_extra_info`. The `destroy` callback is
268 /// called by `DuckDB` when the function is dropped to free the data.
269 ///
270 /// # Safety
271 ///
272 /// `data` must point to valid memory that outlives the function registration,
273 /// or will be freed by `destroy`. The typical pattern
274 /// is to box your data: `Box::into_raw(Box::new(my_data)).cast()`.
275 pub unsafe fn extra_info(
276 mut self,
277 data: *mut c_void,
278 destroy: duckdb_delete_callback_t,
279 ) -> Self {
280 self.extra_info = Some((data, destroy));
281 self
282 }
283
284 /// Registers the aggregate function on the given connection.
285 ///
286 /// # Errors
287 ///
288 /// Returns `ExtensionError` if:
289 /// - The return type was not set.
290 /// - Any required callback was not set.
291 /// - `DuckDB` reports a registration failure.
292 ///
293 /// # Safety
294 ///
295 /// `con` must be a valid, open `duckdb_connection`.
296 pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
297 // Resolve return type: prefer explicit LogicalType over TypeId.
298 let ret_lt = if let Some(lt) = self.return_logical {
299 lt
300 } else if let Some(id) = self.return_type {
301 LogicalType::new(id)
302 } else {
303 return Err(ExtensionError::new("return type not set"));
304 };
305
306 let state_size = self
307 .state_size
308 .ok_or_else(|| ExtensionError::new("state_size callback not set"))?;
309 let init = self
310 .init
311 .ok_or_else(|| ExtensionError::new("init callback not set"))?;
312 let update = self
313 .update
314 .ok_or_else(|| ExtensionError::new("update callback not set"))?;
315 let combine = self
316 .combine
317 .ok_or_else(|| ExtensionError::new("combine callback not set"))?;
318 let finalize = self
319 .finalize
320 .ok_or_else(|| ExtensionError::new("finalize callback not set"))?;
321
322 // SAFETY: duckdb_create_aggregate_function allocates a new function handle.
323 let mut func = unsafe { duckdb_create_aggregate_function() };
324
325 // SAFETY: func is a valid newly created function handle.
326 unsafe {
327 duckdb_aggregate_function_set_name(func, self.name.as_ptr());
328 }
329
330 // Add parameters: merge simple TypeId params and complex LogicalType params
331 // in the order they were added (tracked by position).
332 {
333 let mut simple_idx = 0;
334 let mut logical_idx = 0;
335 let total = self.params.len() + self.logical_params.len();
336 for pos in 0..total {
337 if logical_idx < self.logical_params.len()
338 && self.logical_params[logical_idx].0 == pos
339 {
340 // SAFETY: func and logical type handle are valid.
341 unsafe {
342 libduckdb_sys::duckdb_aggregate_function_add_parameter(
343 func,
344 self.logical_params[logical_idx].1.as_raw(),
345 );
346 }
347 logical_idx += 1;
348 } else if simple_idx < self.params.len() {
349 let lt = LogicalType::new(self.params[simple_idx]);
350 // SAFETY: func and lt.as_raw() are valid.
351 unsafe {
352 libduckdb_sys::duckdb_aggregate_function_add_parameter(func, lt.as_raw());
353 }
354 simple_idx += 1;
355 }
356 }
357 }
358
359 // Set return type
360 // SAFETY: func and ret_lt.as_raw() are valid.
361 unsafe {
362 duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
363 }
364
365 // Set callbacks
366 // SAFETY: All function pointers are valid extern "C" fn pointers.
367 unsafe {
368 duckdb_aggregate_function_set_functions(
369 func,
370 Some(state_size),
371 Some(init),
372 Some(update),
373 Some(combine),
374 Some(finalize),
375 );
376 }
377
378 if let Some(dtor) = self.destructor {
379 // SAFETY: dtor is a valid extern "C" fn pointer.
380 unsafe {
381 duckdb_aggregate_function_set_destructor(func, Some(dtor));
382 }
383 }
384
385 // Set special NULL handling if requested
386 if self.null_handling == NullHandling::SpecialNullHandling {
387 // SAFETY: func is a valid aggregate function handle.
388 unsafe {
389 duckdb_aggregate_function_set_special_handling(func);
390 }
391 }
392
393 // Set extra info if provided
394 if let Some((data, destroy)) = self.extra_info {
395 // SAFETY: func is valid; data and destroy are provided by caller.
396 unsafe {
397 duckdb_aggregate_function_set_extra_info(func, data, destroy);
398 }
399 }
400
401 // Register
402 // SAFETY: con is a valid open connection, func is fully configured.
403 let result = unsafe { duckdb_register_aggregate_function(con, func) };
404
405 // SAFETY: func was created above and must be destroyed after use.
406 unsafe {
407 duckdb_destroy_aggregate_function(&raw mut func);
408 }
409
410 if result == DuckDBSuccess {
411 Ok(())
412 } else {
413 Err(ExtensionError::new(format!(
414 "duckdb_register_aggregate_function failed for '{}'",
415 self.name.to_string_lossy()
416 )))
417 }
418 }
419}