quack_rs/aggregate/builder/single.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6use std::ffi::CString;
7use std::os::raw::c_void;
8
9use libduckdb_sys::{
10 duckdb_aggregate_function_set_destructor, duckdb_aggregate_function_set_extra_info,
11 duckdb_aggregate_function_set_functions, duckdb_aggregate_function_set_name,
12 duckdb_aggregate_function_set_return_type, duckdb_aggregate_function_set_special_handling,
13 duckdb_connection, duckdb_create_aggregate_function, duckdb_delete_callback_t,
14 duckdb_destroy_aggregate_function, duckdb_register_aggregate_function, DuckDBSuccess,
15};
16
17use crate::aggregate::callbacks::{
18 CombineFn, DestroyFn, FinalizeFn, StateInitFn, StateSizeFn, UpdateFn,
19};
20use crate::error::ExtensionError;
21use crate::types::{LogicalType, NullHandling, TypeId};
22use crate::validate::validate_function_name;
23
24/// Builder for registering a single-signature `DuckDB` aggregate function.
25///
26/// # Pitfall L6
27///
28/// Unlike `duckdb_register_aggregate_function`, this builder also handles
29/// the case where you later want to convert to a function set — it sets the
30/// function name correctly.
31///
32/// # Example
33///
34/// ```rust,no_run
35/// use quack_rs::aggregate::AggregateFunctionBuilder;
36/// use quack_rs::types::TypeId;
37/// use libduckdb_sys::{duckdb_connection, duckdb_function_info, duckdb_aggregate_state,
38/// duckdb_data_chunk, duckdb_vector, idx_t};
39///
40/// unsafe extern "C" fn state_size(_: duckdb_function_info) -> idx_t { 8 }
41/// unsafe extern "C" fn state_init(_: duckdb_function_info, _: duckdb_aggregate_state) {}
42/// unsafe extern "C" fn update(_: duckdb_function_info, _: duckdb_data_chunk, _: duckdb_aggregate_state) {}
43/// unsafe extern "C" fn combine(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_aggregate_state, _: idx_t) {}
44/// unsafe extern "C" fn finalize(_: duckdb_function_info, _: duckdb_aggregate_state, _: duckdb_vector, _: idx_t, _: idx_t) {}
45///
46/// // fn register(con: duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
47/// // AggregateFunctionBuilder::new("word_count")
48/// // .param(TypeId::Varchar)
49/// // .returns(TypeId::BigInt)
50/// // .state_size(state_size)
51/// // .init(state_init)
52/// // .update(update)
53/// // .combine(combine)
54/// // .finalize(finalize)
55/// // .register(con)
56/// // }
57/// ```
58#[must_use]
59pub struct AggregateFunctionBuilder {
60 pub(super) name: CString,
61 pub(super) params: Vec<TypeId>,
62 pub(super) logical_params: Vec<(usize, LogicalType)>,
63 pub(super) return_type: Option<TypeId>,
64 pub(super) return_logical: Option<LogicalType>,
65 pub(super) state_size: Option<StateSizeFn>,
66 pub(super) init: Option<StateInitFn>,
67 pub(super) update: Option<UpdateFn>,
68 pub(super) combine: Option<CombineFn>,
69 pub(super) finalize: Option<FinalizeFn>,
70 pub(super) destructor: Option<DestroyFn>,
71 pub(super) null_handling: NullHandling,
72 pub(super) extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
73}
74
75impl AggregateFunctionBuilder {
76 /// Creates a new builder for an aggregate function with the given name.
77 ///
78 /// # Panics
79 ///
80 /// Panics if `name` contains an interior null byte.
81 pub fn new(name: &str) -> Self {
82 Self {
83 name: CString::new(name).expect("function name must not contain null bytes"),
84 params: Vec::new(),
85 logical_params: Vec::new(),
86 return_type: None,
87 return_logical: None,
88 state_size: None,
89 init: None,
90 update: None,
91 combine: None,
92 finalize: None,
93 destructor: None,
94 null_handling: NullHandling::DefaultNullHandling,
95 extra_info: None,
96 }
97 }
98
99 /// Creates a new builder with function name validation.
100 ///
101 /// Unlike [`new`][Self::new], this method validates the function name against
102 /// `DuckDB` naming conventions and returns an error instead of panicking.
103 ///
104 /// # Errors
105 ///
106 /// Returns `ExtensionError` if the name is empty, too long, contains invalid
107 /// characters, or does not start with a lowercase letter or underscore.
108 pub fn try_new(name: &str) -> Result<Self, ExtensionError> {
109 validate_function_name(name)?;
110 let c_name = CString::new(name)
111 .map_err(|_| ExtensionError::new("function name contains interior null byte"))?;
112 Ok(Self {
113 name: c_name,
114 params: Vec::new(),
115 logical_params: Vec::new(),
116 return_type: None,
117 return_logical: None,
118 state_size: None,
119 init: None,
120 update: None,
121 combine: None,
122 finalize: None,
123 destructor: None,
124 null_handling: NullHandling::DefaultNullHandling,
125 extra_info: None,
126 })
127 }
128
129 /// Returns the function name.
130 ///
131 /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
132 pub fn name(&self) -> &str {
133 self.name.to_str().unwrap_or("")
134 }
135
136 /// Adds a positional parameter with the given type.
137 ///
138 /// Call this once per parameter in order. For complex types like
139 /// `LIST(BIGINT)` or `MAP(VARCHAR, INTEGER)`, use [`param_logical`][Self::param_logical].
140 pub fn param(mut self, type_id: TypeId) -> Self {
141 self.params.push(type_id);
142 self
143 }
144
145 /// Adds a positional parameter with a complex [`LogicalType`].
146 ///
147 /// Use this for parameterized types that [`TypeId`] cannot express, such as
148 /// `LIST(BIGINT)`, `MAP(VARCHAR, INTEGER)`, or `STRUCT(...)`.
149 ///
150 /// The parameter position is determined by the total number of `param` and
151 /// `param_logical` calls made so far.
152 ///
153 /// # Example
154 ///
155 /// ```rust,no_run
156 /// use quack_rs::aggregate::AggregateFunctionBuilder;
157 /// use quack_rs::types::{LogicalType, TypeId};
158 ///
159 /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
160 /// // AggregateFunctionBuilder::new("my_func")
161 /// // .param(TypeId::Varchar)
162 /// // .param_logical(LogicalType::list(TypeId::BigInt))
163 /// // .returns(TypeId::BigInt)
164 /// // // ... callbacks ...
165 /// // ;
166 /// // Ok(())
167 /// // }
168 /// ```
169 pub fn param_logical(mut self, logical_type: LogicalType) -> Self {
170 let position = self.params.len() + self.logical_params.len();
171 self.logical_params.push((position, logical_type));
172 self
173 }
174
175 /// Sets the return type for this function.
176 ///
177 /// For complex return types like `LIST(BIGINT)`, use
178 /// [`returns_logical`][Self::returns_logical] instead.
179 pub const fn returns(mut self, type_id: TypeId) -> Self {
180 self.return_type = Some(type_id);
181 self
182 }
183
184 /// Sets the return type to a complex [`LogicalType`].
185 ///
186 /// Use this for parameterized return types that [`TypeId`] cannot express,
187 /// such as `LIST(BOOLEAN)`, `LIST(TIMESTAMP)`, `MAP(VARCHAR, INTEGER)`, etc.
188 ///
189 /// If both `returns` and `returns_logical` are called, the logical type takes
190 /// precedence.
191 ///
192 /// # Example
193 ///
194 /// ```rust,no_run
195 /// use quack_rs::aggregate::AggregateFunctionBuilder;
196 /// use quack_rs::types::{LogicalType, TypeId};
197 ///
198 /// // fn register(con: libduckdb_sys::duckdb_connection) -> Result<(), quack_rs::error::ExtensionError> {
199 /// // AggregateFunctionBuilder::new("retention")
200 /// // .param(TypeId::Boolean)
201 /// // .param(TypeId::Boolean)
202 /// // .returns_logical(LogicalType::list(TypeId::Boolean))
203 /// // // ... callbacks ...
204 /// // ;
205 /// // Ok(())
206 /// // }
207 /// ```
208 pub fn returns_logical(mut self, logical_type: LogicalType) -> Self {
209 self.return_logical = Some(logical_type);
210 self
211 }
212
213 /// Sets the `state_size` callback.
214 pub fn state_size(mut self, f: StateSizeFn) -> Self {
215 self.state_size = Some(f);
216 self
217 }
218
219 /// Sets the `state_init` callback.
220 pub fn init(mut self, f: StateInitFn) -> Self {
221 self.init = Some(f);
222 self
223 }
224
225 /// Sets the `update` callback.
226 pub fn update(mut self, f: UpdateFn) -> Self {
227 self.update = Some(f);
228 self
229 }
230
231 /// Sets the `combine` callback.
232 pub fn combine(mut self, f: CombineFn) -> Self {
233 self.combine = Some(f);
234 self
235 }
236
237 /// Sets the `finalize` callback.
238 pub fn finalize(mut self, f: FinalizeFn) -> Self {
239 self.finalize = Some(f);
240 self
241 }
242
243 /// Sets the optional `destructor` callback.
244 ///
245 /// Required if your state allocates heap memory (e.g., when using
246 /// [`FfiState<T>`][crate::aggregate::FfiState]).
247 pub fn destructor(mut self, f: DestroyFn) -> Self {
248 self.destructor = Some(f);
249 self
250 }
251
252 /// Sets the NULL handling behaviour for this aggregate function.
253 ///
254 /// By default, `DuckDB` skips NULL rows in aggregate functions
255 /// ([`DefaultNullHandling`][NullHandling::DefaultNullHandling]).
256 /// Set to [`SpecialNullHandling`][NullHandling::SpecialNullHandling] to receive
257 /// NULL values in your `update` callback.
258 pub const fn null_handling(mut self, handling: NullHandling) -> Self {
259 self.null_handling = handling;
260 self
261 }
262
263 /// Attaches arbitrary data to this aggregate function.
264 ///
265 /// The data pointer is available inside callbacks via
266 /// `duckdb_aggregate_function_get_extra_info`. The `destroy` callback is
267 /// called by `DuckDB` when the function is dropped to free the data.
268 ///
269 /// # Safety
270 ///
271 /// `data` must point to valid memory that outlives the function registration,
272 /// or will be freed by `destroy`. The typical pattern
273 /// is to box your data: `Box::into_raw(Box::new(my_data)).cast()`.
274 pub unsafe fn extra_info(
275 mut self,
276 data: *mut c_void,
277 destroy: duckdb_delete_callback_t,
278 ) -> Self {
279 self.extra_info = Some((data, destroy));
280 self
281 }
282
283 /// Registers the aggregate function on the given connection.
284 ///
285 /// # Errors
286 ///
287 /// Returns `ExtensionError` if:
288 /// - The return type was not set.
289 /// - Any required callback was not set.
290 /// - `DuckDB` reports a registration failure.
291 ///
292 /// # Safety
293 ///
294 /// `con` must be a valid, open `duckdb_connection`.
295 pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
296 // Resolve return type: prefer explicit LogicalType over TypeId.
297 let ret_lt = if let Some(lt) = self.return_logical {
298 lt
299 } else if let Some(id) = self.return_type {
300 LogicalType::new(id)
301 } else {
302 return Err(ExtensionError::new("return type not set"));
303 };
304
305 let state_size = self
306 .state_size
307 .ok_or_else(|| ExtensionError::new("state_size callback not set"))?;
308 let init = self
309 .init
310 .ok_or_else(|| ExtensionError::new("init callback not set"))?;
311 let update = self
312 .update
313 .ok_or_else(|| ExtensionError::new("update callback not set"))?;
314 let combine = self
315 .combine
316 .ok_or_else(|| ExtensionError::new("combine callback not set"))?;
317 let finalize = self
318 .finalize
319 .ok_or_else(|| ExtensionError::new("finalize callback not set"))?;
320
321 // SAFETY: duckdb_create_aggregate_function allocates a new function handle.
322 let func = unsafe { duckdb_create_aggregate_function() };
323
324 // SAFETY: func is a valid newly created function handle.
325 unsafe {
326 duckdb_aggregate_function_set_name(func, self.name.as_ptr());
327 }
328
329 // Add parameters: merge simple TypeId params and complex LogicalType params
330 // in the order they were added (tracked by position).
331 {
332 let mut simple_idx = 0;
333 let mut logical_idx = 0;
334 let total = self.params.len() + self.logical_params.len();
335 for pos in 0..total {
336 if logical_idx < self.logical_params.len()
337 && self.logical_params[logical_idx].0 == pos
338 {
339 // SAFETY: func and logical type handle are valid.
340 unsafe {
341 libduckdb_sys::duckdb_aggregate_function_add_parameter(
342 func,
343 self.logical_params[logical_idx].1.as_raw(),
344 );
345 }
346 logical_idx += 1;
347 } else if simple_idx < self.params.len() {
348 let lt = LogicalType::new(self.params[simple_idx]);
349 // SAFETY: func and lt.as_raw() are valid.
350 unsafe {
351 libduckdb_sys::duckdb_aggregate_function_add_parameter(func, lt.as_raw());
352 }
353 simple_idx += 1;
354 }
355 }
356 }
357
358 // Set return type
359 // SAFETY: func and ret_lt.as_raw() are valid.
360 unsafe {
361 duckdb_aggregate_function_set_return_type(func, ret_lt.as_raw());
362 }
363
364 // Set callbacks
365 // SAFETY: All function pointers are valid extern "C" fn pointers.
366 unsafe {
367 duckdb_aggregate_function_set_functions(
368 func,
369 Some(state_size),
370 Some(init),
371 Some(update),
372 Some(combine),
373 Some(finalize),
374 );
375 }
376
377 if let Some(dtor) = self.destructor {
378 // SAFETY: dtor is a valid extern "C" fn pointer.
379 unsafe {
380 duckdb_aggregate_function_set_destructor(func, Some(dtor));
381 }
382 }
383
384 // Set special NULL handling if requested
385 if self.null_handling == NullHandling::SpecialNullHandling {
386 // SAFETY: func is a valid aggregate function handle.
387 unsafe {
388 duckdb_aggregate_function_set_special_handling(func);
389 }
390 }
391
392 // Set extra info if provided
393 if let Some((data, destroy)) = self.extra_info {
394 // SAFETY: func is valid; data and destroy are provided by caller.
395 unsafe {
396 duckdb_aggregate_function_set_extra_info(func, data, destroy);
397 }
398 }
399
400 // Register
401 // SAFETY: con is a valid open connection, func is fully configured.
402 let result = unsafe { duckdb_register_aggregate_function(con, func) };
403
404 // SAFETY: func was created above and must be destroyed after use.
405 unsafe {
406 duckdb_destroy_aggregate_function(&mut { func });
407 }
408
409 if result == DuckDBSuccess {
410 Ok(())
411 } else {
412 Err(ExtensionError::new(format!(
413 "duckdb_register_aggregate_function failed for '{}'",
414 self.name.to_string_lossy()
415 )))
416 }
417 }
418}