quack_rs/cast/builder.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6//! Builder for registering custom `DuckDB` cast functions.
7
8use std::ffi::CString;
9use std::os::raw::c_void;
10
11use libduckdb_sys::{
12 duckdb_cast_function_get_cast_mode, duckdb_cast_function_get_extra_info,
13 duckdb_cast_function_set_error, duckdb_cast_function_set_extra_info,
14 duckdb_cast_function_set_function, duckdb_cast_function_set_implicit_cast_cost,
15 duckdb_cast_function_set_row_error, duckdb_cast_function_set_source_type,
16 duckdb_cast_function_set_target_type, duckdb_cast_mode_DUCKDB_CAST_TRY, duckdb_connection,
17 duckdb_create_cast_function, duckdb_delete_callback_t, duckdb_destroy_cast_function,
18 duckdb_function_info, duckdb_register_cast_function, duckdb_vector, idx_t, DuckDBSuccess,
19};
20
21use crate::error::ExtensionError;
22use crate::types::{LogicalType, TypeId};
23
24/// Converts a `&str` to `CString` without panicking.
25#[mutants::skip] // private FFI helper — tested in replacement_scan::tests
26fn str_to_cstring(s: &str) -> CString {
27 CString::new(s).unwrap_or_else(|_| {
28 let pos = s.bytes().position(|b| b == 0).unwrap_or(s.len());
29 CString::new(&s.as_bytes()[..pos]).unwrap_or_default()
30 })
31}
32
33// ── Cast mode ─────────────────────────────────────────────────────────────────
34
35/// Whether the cast is called as a regular `CAST` or a `TRY_CAST`.
36///
37/// In [`Try`][CastMode::Try] mode, conversion failures should write `NULL` for
38/// the failed row and call [`CastFunctionInfo::set_row_error`] rather than
39/// aborting the whole query.
40#[derive(Debug, Clone, Copy, PartialEq, Eq)]
41pub enum CastMode {
42 /// Regular `CAST` — any failure aborts the query.
43 Normal,
44 /// `TRY_CAST` — failures produce `NULL`; use per-row error reporting.
45 Try,
46}
47
48impl CastMode {
49 const fn from_raw(raw: libduckdb_sys::duckdb_cast_mode) -> Self {
50 if raw == duckdb_cast_mode_DUCKDB_CAST_TRY {
51 Self::Try
52 } else {
53 Self::Normal
54 }
55 }
56}
57
58// ── Callback info wrapper ──────────────────────────────────────────────────────
59
60/// Ergonomic wrapper around the `duckdb_function_info` handle provided to a
61/// cast callback.
62///
63/// Exposes the cast-specific methods that are only meaningful inside a cast
64/// function callback.
65pub struct CastFunctionInfo {
66 info: duckdb_function_info,
67}
68
69impl CastFunctionInfo {
70 /// Wraps a raw `duckdb_function_info` provided by `DuckDB` inside a cast
71 /// callback.
72 ///
73 /// # Safety
74 ///
75 /// `info` must be a valid `duckdb_function_info` passed by `DuckDB` to a
76 /// cast callback.
77 #[inline]
78 #[must_use]
79 pub const unsafe fn new(info: duckdb_function_info) -> Self {
80 Self { info }
81 }
82
83 /// Returns whether this invocation is a `TRY_CAST` or a regular `CAST`.
84 ///
85 /// Check this inside your callback to decide between aborting on error
86 /// ([`CastMode::Normal`]) and producing `NULL` with a per-row error
87 /// ([`CastMode::Try`]).
88 #[must_use]
89 pub fn cast_mode(&self) -> CastMode {
90 // SAFETY: self.info is valid per constructor contract.
91 let raw = unsafe { duckdb_cast_function_get_cast_mode(self.info) };
92 CastMode::from_raw(raw)
93 }
94
95 /// Retrieves the extra-info pointer previously set via
96 /// [`CastFunctionBuilder::extra_info`].
97 ///
98 /// Returns a raw `*mut c_void`. Cast it back to your concrete type.
99 ///
100 /// # Safety
101 ///
102 /// The returned pointer is only valid as long as the cast function is
103 /// registered and `DuckDB` has not yet called the destructor.
104 #[must_use]
105 pub unsafe fn get_extra_info(&self) -> *mut c_void {
106 // SAFETY: self.info is valid per constructor contract.
107 unsafe { duckdb_cast_function_get_extra_info(self.info) }
108 }
109
110 /// Reports a fatal error, causing `DuckDB` to abort the current query.
111 ///
112 /// Use this only in [`CastMode::Normal`]; in [`CastMode::Try`] prefer
113 /// [`set_row_error`][Self::set_row_error] so that failed rows become `NULL`.
114 ///
115 /// If `message` contains an interior null byte it is truncated at that point.
116 #[mutants::skip]
117 pub fn set_error(&self, message: &str) {
118 let c_msg = str_to_cstring(message);
119 // SAFETY: self.info is valid per constructor contract.
120 unsafe {
121 duckdb_cast_function_set_error(self.info, c_msg.as_ptr());
122 }
123 }
124
125 /// Reports a per-row error for `TRY_CAST`.
126 ///
127 /// Records `message` for `row` in the output error vector. The row's
128 /// output value should be set to `NULL` by the caller.
129 ///
130 /// If `message` contains an interior null byte it is truncated at that point.
131 ///
132 /// # Safety
133 ///
134 /// `output` must be the same `duckdb_vector` passed to the cast callback.
135 pub unsafe fn set_row_error(&self, message: &str, row: idx_t, output: duckdb_vector) {
136 let c_msg = str_to_cstring(message);
137 // SAFETY: self.info is valid; output and row are caller-supplied.
138 unsafe {
139 duckdb_cast_function_set_row_error(self.info, c_msg.as_ptr(), row, output);
140 }
141 }
142}
143
144// ── Callback type alias ────────────────────────────────────────────────────────
145
146/// The cast function callback signature.
147///
148/// - `info` — cast function info; use [`CastFunctionInfo`] to wrap it.
149/// - `count` — number of rows in this chunk.
150/// - `input` — source vector (read from this).
151/// - `output` — destination vector (write results here).
152///
153/// Return `true` on success, `false` to signal a fatal cast error.
154pub type CastFn = unsafe extern "C" fn(
155 info: duckdb_function_info,
156 count: idx_t,
157 input: duckdb_vector,
158 output: duckdb_vector,
159) -> bool;
160
161// ── Builder ────────────────────────────────────────────────────────────────────
162
163/// Builder for registering a custom `DuckDB` cast function.
164///
165/// A cast function converts values from a **source** type to a **target** type.
166/// Registering a cast lets `DuckDB` use it both for explicit
167/// `CAST(x AS Target)` syntax and (if an implicit cost is set) for automatic
168/// coercions.
169///
170/// # Example
171///
172/// ```rust,no_run
173/// use quack_rs::cast::{CastFunctionBuilder, CastFunctionInfo, CastMode};
174/// use quack_rs::types::TypeId;
175/// use libduckdb_sys::{duckdb_function_info, duckdb_vector, idx_t};
176///
177/// unsafe extern "C" fn my_cast(
178/// _info: duckdb_function_info,
179/// _count: idx_t,
180/// _input: duckdb_vector,
181/// _output: duckdb_vector,
182/// ) -> bool {
183/// true // implement real conversion here
184/// }
185///
186/// // fn register(con: libduckdb_sys::duckdb_connection)
187/// // -> Result<(), quack_rs::error::ExtensionError>
188/// // {
189/// // unsafe {
190/// // CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer)
191/// // .function(my_cast)
192/// // .register(con)
193/// // }
194/// // }
195/// ```
196#[must_use]
197pub struct CastFunctionBuilder {
198 source: Option<TypeId>,
199 source_logical: Option<LogicalType>,
200 target: Option<TypeId>,
201 target_logical: Option<LogicalType>,
202 function: Option<CastFn>,
203 implicit_cost: Option<i64>,
204 extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
205}
206
207// SAFETY: CastFunctionBuilder owns the extra_info pointer and LogicalType handles
208// until registration. The raw pointers are only sent across threads as part of the
209// builder, which extension authors typically use on a single thread.
210#[allow(clippy::non_send_fields_in_send_ty)]
211unsafe impl Send for CastFunctionBuilder {}
212
213impl CastFunctionBuilder {
214 /// Creates a new builder that will cast `source` values into `target` values.
215 pub const fn new(source: TypeId, target: TypeId) -> Self {
216 Self {
217 source: Some(source),
218 source_logical: None,
219 target: Some(target),
220 target_logical: None,
221 function: None,
222 implicit_cost: None,
223 extra_info: None,
224 }
225 }
226
227 /// Creates a new builder using [`LogicalType`]s for source and target.
228 ///
229 /// Use this when the source or target types are complex (e.g.
230 /// `DECIMAL(18, 3)`, `LIST(VARCHAR)`, etc.) and cannot be expressed as
231 /// simple [`TypeId`] values.
232 pub fn new_logical(source: LogicalType, target: LogicalType) -> Self {
233 Self {
234 source: None,
235 source_logical: Some(source),
236 target: None,
237 target_logical: Some(target),
238 function: None,
239 implicit_cost: None,
240 extra_info: None,
241 }
242 }
243
244 /// Returns the source type this cast converts from (if set via [`new`][Self::new]).
245 ///
246 /// Returns `None` if the source was set via [`new_logical`][Self::new_logical].
247 ///
248 /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
249 pub const fn source(&self) -> Option<TypeId> {
250 self.source
251 }
252
253 /// Returns the target type this cast converts to (if set via [`new`][Self::new]).
254 ///
255 /// Returns `None` if the target was set via [`new_logical`][Self::new_logical].
256 ///
257 /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
258 pub const fn target(&self) -> Option<TypeId> {
259 self.target
260 }
261
262 /// Sets the cast callback.
263 pub fn function(mut self, f: CastFn) -> Self {
264 self.function = Some(f);
265 self
266 }
267
268 /// Sets the implicit cast cost.
269 ///
270 /// When a non-negative cost is provided, `DuckDB` may use this cast
271 /// automatically in expressions where an implicit coercion is needed.
272 /// Lower cost means higher priority. A negative cost or omitting this
273 /// method makes the cast explicit-only.
274 pub const fn implicit_cost(mut self, cost: i64) -> Self {
275 self.implicit_cost = Some(cost);
276 self
277 }
278
279 /// Attaches extra data to the cast function.
280 ///
281 /// The pointer is available inside the callback via
282 /// [`CastFunctionInfo::get_extra_info`].
283 ///
284 /// # Safety
285 ///
286 /// `ptr` must remain valid until `DuckDB` calls `destroy`, or for the
287 /// lifetime of the database if `destroy` is `None`.
288 pub unsafe fn extra_info(
289 mut self,
290 ptr: *mut c_void,
291 destroy: duckdb_delete_callback_t,
292 ) -> Self {
293 self.extra_info = Some((ptr, destroy));
294 self
295 }
296
297 /// Registers the cast function on the given connection.
298 ///
299 /// # Errors
300 ///
301 /// Returns `ExtensionError` if:
302 /// - The function callback was not set.
303 /// - `DuckDB` reports a registration failure.
304 ///
305 /// # Safety
306 ///
307 /// `con` must be a valid, open `duckdb_connection`.
308 pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
309 let function = self
310 .function
311 .ok_or_else(|| ExtensionError::new("cast function callback not set"))?;
312
313 // SAFETY: allocates a new cast function handle.
314 let mut cast = unsafe { duckdb_create_cast_function() };
315
316 // Resolve source type: prefer explicit LogicalType over TypeId.
317 let src_lt = if let Some(lt) = self.source_logical {
318 lt
319 } else if let Some(id) = self.source {
320 LogicalType::new(id)
321 } else {
322 return Err(ExtensionError::new("cast source type not set"));
323 };
324 // SAFETY: cast and src_lt.as_raw() are valid.
325 unsafe {
326 duckdb_cast_function_set_source_type(cast, src_lt.as_raw());
327 }
328
329 // Resolve target type: prefer explicit LogicalType over TypeId.
330 let tgt_lt = if let Some(lt) = self.target_logical {
331 lt
332 } else if let Some(id) = self.target {
333 LogicalType::new(id)
334 } else {
335 return Err(ExtensionError::new("cast target type not set"));
336 };
337 // SAFETY: cast and tgt_lt.as_raw() are valid.
338 unsafe {
339 duckdb_cast_function_set_target_type(cast, tgt_lt.as_raw());
340 }
341
342 // Set callback
343 // SAFETY: function is a valid extern "C" fn pointer.
344 unsafe {
345 duckdb_cast_function_set_function(cast, Some(function));
346 }
347
348 // Set implicit cost if requested
349 if let Some(cost) = self.implicit_cost {
350 // SAFETY: cast is a valid handle.
351 unsafe {
352 duckdb_cast_function_set_implicit_cast_cost(cast, cost);
353 }
354 }
355
356 // Attach extra info if provided
357 if let Some((ptr, destroy)) = self.extra_info {
358 // SAFETY: ptr validity is the caller's responsibility per the safety
359 // contract on extra_info().
360 unsafe {
361 duckdb_cast_function_set_extra_info(cast, ptr, destroy);
362 }
363 }
364
365 // Register
366 // SAFETY: con is a valid open connection, cast is fully configured.
367 let result = unsafe { duckdb_register_cast_function(con, cast) };
368
369 // SAFETY: cast was created above and must be destroyed after use.
370 unsafe {
371 duckdb_destroy_cast_function(&raw mut cast);
372 }
373
374 if result == DuckDBSuccess {
375 Ok(())
376 } else {
377 Err(ExtensionError::new("duckdb_register_cast_function failed"))
378 }
379 }
380}
381
382// ── Tests ──────────────────────────────────────────────────────────────────────
383
384#[cfg(test)]
385mod tests {
386 use super::*;
387 use libduckdb_sys::{duckdb_function_info, duckdb_vector, idx_t};
388
389 unsafe extern "C" fn noop_cast(
390 _: duckdb_function_info,
391 _: idx_t,
392 _: duckdb_vector,
393 _: duckdb_vector,
394 ) -> bool {
395 true
396 }
397
398 #[test]
399 fn builder_stores_source_and_target() {
400 let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer);
401 assert_eq!(b.source(), Some(TypeId::Varchar));
402 assert_eq!(b.target(), Some(TypeId::Integer));
403 }
404
405 #[test]
406 fn builder_stores_function() {
407 let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer).function(noop_cast);
408 assert!(b.function.is_some());
409 }
410
411 #[test]
412 fn builder_stores_implicit_cost() {
413 let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer).implicit_cost(10);
414 assert_eq!(b.implicit_cost, Some(10));
415 }
416
417 #[test]
418 fn builder_no_function_is_error() {
419 // We cannot call register without a live DuckDB, but we can assert the
420 // function field starts as None.
421 let b = CastFunctionBuilder::new(TypeId::BigInt, TypeId::Double);
422 assert!(b.function.is_none());
423 }
424
425 #[test]
426 fn cast_mode_from_raw_normal() {
427 use libduckdb_sys::duckdb_cast_mode_DUCKDB_CAST_NORMAL;
428 assert_eq!(
429 CastMode::from_raw(duckdb_cast_mode_DUCKDB_CAST_NORMAL),
430 CastMode::Normal
431 );
432 }
433
434 #[test]
435 fn cast_mode_from_raw_try() {
436 assert_eq!(
437 CastMode::from_raw(duckdb_cast_mode_DUCKDB_CAST_TRY),
438 CastMode::Try
439 );
440 }
441
442 #[test]
443 fn cast_function_info_wraps_null() {
444 // Constructing with null must not crash (no DuckDB calls made).
445 let _info = unsafe { CastFunctionInfo::new(std::ptr::null_mut()) };
446 }
447}