Skip to main content

quack_rs/cast/
builder.rs

1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tom F. <https://github.com/tomtom215/>
3// My way of giving something small back to the open source community
4// and encouraging more Rust development!
5
6//! Builder for registering custom `DuckDB` cast functions.
7
8use std::ffi::CString;
9use std::os::raw::c_void;
10
11use libduckdb_sys::{
12    duckdb_cast_function_get_cast_mode, duckdb_cast_function_get_extra_info,
13    duckdb_cast_function_set_error, duckdb_cast_function_set_extra_info,
14    duckdb_cast_function_set_function, duckdb_cast_function_set_implicit_cast_cost,
15    duckdb_cast_function_set_row_error, duckdb_cast_function_set_source_type,
16    duckdb_cast_function_set_target_type, duckdb_cast_mode_DUCKDB_CAST_TRY, duckdb_connection,
17    duckdb_create_cast_function, duckdb_delete_callback_t, duckdb_destroy_cast_function,
18    duckdb_function_info, duckdb_register_cast_function, duckdb_vector, idx_t, DuckDBSuccess,
19};
20
21use crate::error::ExtensionError;
22use crate::types::{LogicalType, TypeId};
23
24/// Converts a `&str` to `CString` without panicking.
25#[mutants::skip] // private FFI helper — tested in replacement_scan::tests
fn str_to_cstring(s: &str) -> CString {
    // Find the first NUL byte up front (or the end of the input when there is
    // none) so that a single `CString::new` call covers both cases.
    let bytes = s.as_bytes();
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    // The selected prefix contains no NUL, so construction is not expected to
    // fail; the empty-string fallback only keeps this helper panic-free.
    CString::new(&bytes[..end]).unwrap_or_default()
}
32
33// ── Cast mode ─────────────────────────────────────────────────────────────────
34
35/// Whether the cast is called as a regular `CAST` or a `TRY_CAST`.
36///
37/// In [`Try`][CastMode::Try] mode, conversion failures should write `NULL` for
38/// the failed row and call [`CastFunctionInfo::set_row_error`] rather than
39/// aborting the whole query.
40#[derive(Debug, Clone, Copy, PartialEq, Eq)]
41pub enum CastMode {
42    /// Regular `CAST` — any failure aborts the query.
43    Normal,
44    /// `TRY_CAST` — failures produce `NULL`; use per-row error reporting.
45    Try,
46}
47
48impl CastMode {
49    const fn from_raw(raw: libduckdb_sys::duckdb_cast_mode) -> Self {
50        if raw == duckdb_cast_mode_DUCKDB_CAST_TRY {
51            Self::Try
52        } else {
53            Self::Normal
54        }
55    }
56}
57
58// ── Callback info wrapper ──────────────────────────────────────────────────────
59
60/// Ergonomic wrapper around the `duckdb_function_info` handle provided to a
61/// cast callback.
62///
63/// Exposes the cast-specific methods that are only meaningful inside a cast
64/// function callback.
65pub struct CastFunctionInfo {
66    info: duckdb_function_info,
67}
68
69impl CastFunctionInfo {
70    /// Wraps a raw `duckdb_function_info` provided by `DuckDB` inside a cast
71    /// callback.
72    ///
73    /// # Safety
74    ///
75    /// `info` must be a valid `duckdb_function_info` passed by `DuckDB` to a
76    /// cast callback.
77    #[inline]
78    #[must_use]
79    pub const unsafe fn new(info: duckdb_function_info) -> Self {
80        Self { info }
81    }
82
83    /// Returns whether this invocation is a `TRY_CAST` or a regular `CAST`.
84    ///
85    /// Check this inside your callback to decide between aborting on error
86    /// ([`CastMode::Normal`]) and producing `NULL` with a per-row error
87    /// ([`CastMode::Try`]).
88    #[must_use]
89    pub fn cast_mode(&self) -> CastMode {
90        // SAFETY: self.info is valid per constructor contract.
91        let raw = unsafe { duckdb_cast_function_get_cast_mode(self.info) };
92        CastMode::from_raw(raw)
93    }
94
95    /// Retrieves the extra-info pointer previously set via
96    /// [`CastFunctionBuilder::extra_info`].
97    ///
98    /// Returns a raw `*mut c_void`.  Cast it back to your concrete type.
99    ///
100    /// # Safety
101    ///
102    /// The returned pointer is only valid as long as the cast function is
103    /// registered and `DuckDB` has not yet called the destructor.
104    #[must_use]
105    pub unsafe fn get_extra_info(&self) -> *mut c_void {
106        // SAFETY: self.info is valid per constructor contract.
107        unsafe { duckdb_cast_function_get_extra_info(self.info) }
108    }
109
110    /// Reports a fatal error, causing `DuckDB` to abort the current query.
111    ///
112    /// Use this only in [`CastMode::Normal`]; in [`CastMode::Try`] prefer
113    /// [`set_row_error`][Self::set_row_error] so that failed rows become `NULL`.
114    ///
115    /// If `message` contains an interior null byte it is truncated at that point.
116    #[mutants::skip]
117    pub fn set_error(&self, message: &str) {
118        let c_msg = str_to_cstring(message);
119        // SAFETY: self.info is valid per constructor contract.
120        unsafe {
121            duckdb_cast_function_set_error(self.info, c_msg.as_ptr());
122        }
123    }
124
125    /// Reports a per-row error for `TRY_CAST`.
126    ///
127    /// Records `message` for `row` in the output error vector.  The row's
128    /// output value should be set to `NULL` by the caller.
129    ///
130    /// If `message` contains an interior null byte it is truncated at that point.
131    ///
132    /// # Safety
133    ///
134    /// `output` must be the same `duckdb_vector` passed to the cast callback.
135    pub unsafe fn set_row_error(&self, message: &str, row: idx_t, output: duckdb_vector) {
136        let c_msg = str_to_cstring(message);
137        // SAFETY: self.info is valid; output and row are caller-supplied.
138        unsafe {
139            duckdb_cast_function_set_row_error(self.info, c_msg.as_ptr(), row, output);
140        }
141    }
142}
143
144// ── Callback type alias ────────────────────────────────────────────────────────
145
/// The cast function callback signature.
///
/// A function of this type is supplied through
/// [`CastFunctionBuilder::function`] and handed to `DuckDB` when the cast is
/// registered.
///
/// - `info`   — cast function info; use [`CastFunctionInfo`] to wrap it and
///   to query the cast mode or report errors.
/// - `count`  — number of rows in this chunk.
/// - `input`  — source vector (read from this).
/// - `output` — destination vector (write results here).
///
/// Return `true` on success, `false` to signal a fatal cast error.
pub type CastFn = unsafe extern "C" fn(
    info: duckdb_function_info,
    count: idx_t,
    input: duckdb_vector,
    output: duckdb_vector,
) -> bool;
160
161// ── Builder ────────────────────────────────────────────────────────────────────
162
163/// Builder for registering a custom `DuckDB` cast function.
164///
165/// A cast function converts values from a **source** type to a **target** type.
166/// Registering a cast lets `DuckDB` use it both for explicit
167/// `CAST(x AS Target)` syntax and (if an implicit cost is set) for automatic
168/// coercions.
169///
170/// # Example
171///
172/// ```rust,no_run
173/// use quack_rs::cast::{CastFunctionBuilder, CastFunctionInfo, CastMode};
174/// use quack_rs::types::TypeId;
175/// use libduckdb_sys::{duckdb_function_info, duckdb_vector, idx_t};
176///
177/// unsafe extern "C" fn my_cast(
178///     _info: duckdb_function_info,
179///     _count: idx_t,
180///     _input: duckdb_vector,
181///     _output: duckdb_vector,
182/// ) -> bool {
183///     true // implement real conversion here
184/// }
185///
186/// // fn register(con: libduckdb_sys::duckdb_connection)
187/// //     -> Result<(), quack_rs::error::ExtensionError>
188/// // {
189/// //     unsafe {
190/// //         CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer)
191/// //             .function(my_cast)
192/// //             .register(con)
193/// //     }
194/// // }
195/// ```
196#[must_use]
197pub struct CastFunctionBuilder {
198    source: Option<TypeId>,
199    source_logical: Option<LogicalType>,
200    target: Option<TypeId>,
201    target_logical: Option<LogicalType>,
202    function: Option<CastFn>,
203    implicit_cost: Option<i64>,
204    extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
205}
206
207// SAFETY: CastFunctionBuilder owns the extra_info pointer and LogicalType handles
208// until registration. The raw pointers are only sent across threads as part of the
209// builder, which extension authors typically use on a single thread.
210#[allow(clippy::non_send_fields_in_send_ty)]
211unsafe impl Send for CastFunctionBuilder {}
212
213impl CastFunctionBuilder {
214    /// Creates a new builder that will cast `source` values into `target` values.
215    pub const fn new(source: TypeId, target: TypeId) -> Self {
216        Self {
217            source: Some(source),
218            source_logical: None,
219            target: Some(target),
220            target_logical: None,
221            function: None,
222            implicit_cost: None,
223            extra_info: None,
224        }
225    }
226
227    /// Creates a new builder using [`LogicalType`]s for source and target.
228    ///
229    /// Use this when the source or target types are complex (e.g.
230    /// `DECIMAL(18, 3)`, `LIST(VARCHAR)`, etc.) and cannot be expressed as
231    /// simple [`TypeId`] values.
232    pub fn new_logical(source: LogicalType, target: LogicalType) -> Self {
233        Self {
234            source: None,
235            source_logical: Some(source),
236            target: None,
237            target_logical: Some(target),
238            function: None,
239            implicit_cost: None,
240            extra_info: None,
241        }
242    }
243
244    /// Returns the source type this cast converts from (if set via [`new`][Self::new]).
245    ///
246    /// Returns `None` if the source was set via [`new_logical`][Self::new_logical].
247    ///
248    /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
249    pub const fn source(&self) -> Option<TypeId> {
250        self.source
251    }
252
253    /// Returns the target type this cast converts to (if set via [`new`][Self::new]).
254    ///
255    /// Returns `None` if the target was set via [`new_logical`][Self::new_logical].
256    ///
257    /// Useful for introspection and for [`MockRegistrar`][crate::testing::MockRegistrar].
258    pub const fn target(&self) -> Option<TypeId> {
259        self.target
260    }
261
262    /// Sets the cast callback.
263    pub fn function(mut self, f: CastFn) -> Self {
264        self.function = Some(f);
265        self
266    }
267
268    /// Sets the implicit cast cost.
269    ///
270    /// When a non-negative cost is provided, `DuckDB` may use this cast
271    /// automatically in expressions where an implicit coercion is needed.
272    /// Lower cost means higher priority. A negative cost or omitting this
273    /// method makes the cast explicit-only.
274    pub const fn implicit_cost(mut self, cost: i64) -> Self {
275        self.implicit_cost = Some(cost);
276        self
277    }
278
279    /// Attaches extra data to the cast function.
280    ///
281    /// The pointer is available inside the callback via
282    /// [`CastFunctionInfo::get_extra_info`].
283    ///
284    /// # Safety
285    ///
286    /// `ptr` must remain valid until `DuckDB` calls `destroy`, or for the
287    /// lifetime of the database if `destroy` is `None`.
288    pub unsafe fn extra_info(
289        mut self,
290        ptr: *mut c_void,
291        destroy: duckdb_delete_callback_t,
292    ) -> Self {
293        self.extra_info = Some((ptr, destroy));
294        self
295    }
296
297    /// Registers the cast function on the given connection.
298    ///
299    /// # Errors
300    ///
301    /// Returns `ExtensionError` if:
302    /// - The function callback was not set.
303    /// - `DuckDB` reports a registration failure.
304    ///
305    /// # Safety
306    ///
307    /// `con` must be a valid, open `duckdb_connection`.
308    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
309        let function = self
310            .function
311            .ok_or_else(|| ExtensionError::new("cast function callback not set"))?;
312
313        // SAFETY: allocates a new cast function handle.
314        let mut cast = unsafe { duckdb_create_cast_function() };
315
316        // Resolve source type: prefer explicit LogicalType over TypeId.
317        let src_lt = if let Some(lt) = self.source_logical {
318            lt
319        } else if let Some(id) = self.source {
320            LogicalType::new(id)
321        } else {
322            return Err(ExtensionError::new("cast source type not set"));
323        };
324        // SAFETY: cast and src_lt.as_raw() are valid.
325        unsafe {
326            duckdb_cast_function_set_source_type(cast, src_lt.as_raw());
327        }
328
329        // Resolve target type: prefer explicit LogicalType over TypeId.
330        let tgt_lt = if let Some(lt) = self.target_logical {
331            lt
332        } else if let Some(id) = self.target {
333            LogicalType::new(id)
334        } else {
335            return Err(ExtensionError::new("cast target type not set"));
336        };
337        // SAFETY: cast and tgt_lt.as_raw() are valid.
338        unsafe {
339            duckdb_cast_function_set_target_type(cast, tgt_lt.as_raw());
340        }
341
342        // Set callback
343        // SAFETY: function is a valid extern "C" fn pointer.
344        unsafe {
345            duckdb_cast_function_set_function(cast, Some(function));
346        }
347
348        // Set implicit cost if requested
349        if let Some(cost) = self.implicit_cost {
350            // SAFETY: cast is a valid handle.
351            unsafe {
352                duckdb_cast_function_set_implicit_cast_cost(cast, cost);
353            }
354        }
355
356        // Attach extra info if provided
357        if let Some((ptr, destroy)) = self.extra_info {
358            // SAFETY: ptr validity is the caller's responsibility per the safety
359            // contract on extra_info().
360            unsafe {
361                duckdb_cast_function_set_extra_info(cast, ptr, destroy);
362            }
363        }
364
365        // Register
366        // SAFETY: con is a valid open connection, cast is fully configured.
367        let result = unsafe { duckdb_register_cast_function(con, cast) };
368
369        // SAFETY: cast was created above and must be destroyed after use.
370        unsafe {
371            duckdb_destroy_cast_function(&raw mut cast);
372        }
373
374        if result == DuckDBSuccess {
375            Ok(())
376        } else {
377            Err(ExtensionError::new("duckdb_register_cast_function failed"))
378        }
379    }
380}
381
382// ── Tests ──────────────────────────────────────────────────────────────────────
383
#[cfg(test)]
mod tests {
    use super::*;
    use libduckdb_sys::{duckdb_function_info, duckdb_vector, idx_t};

    // Minimal callback used only to populate the builder; it is never invoked
    // by these tests.
    unsafe extern "C" fn noop_cast(
        _: duckdb_function_info,
        _: idx_t,
        _: duckdb_vector,
        _: duckdb_vector,
    ) -> bool {
        true
    }

    #[test]
    fn builder_stores_source_and_target() {
        let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer);
        assert_eq!(b.source(), Some(TypeId::Varchar));
        assert_eq!(b.target(), Some(TypeId::Integer));
    }

    #[test]
    fn builder_stores_function() {
        let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer).function(noop_cast);
        assert!(b.function.is_some());
    }

    #[test]
    fn builder_stores_implicit_cost() {
        let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer).implicit_cost(10);
        assert_eq!(b.implicit_cost, Some(10));
    }

    #[test]
    fn builder_no_function_is_error() {
        let b = CastFunctionBuilder::new(TypeId::BigInt, TypeId::Double);
        assert!(b.function.is_none());

        // `register` checks for the callback before it makes any DuckDB call,
        // so this error path can be exercised without a live database.
        // SAFETY: the missing-callback check returns before `con` is used, so
        // the null connection is never passed to DuckDB.
        let result = unsafe { b.register(std::ptr::null_mut()) };
        assert!(result.is_err());
    }

    #[test]
    fn cast_mode_from_raw_normal() {
        use libduckdb_sys::duckdb_cast_mode_DUCKDB_CAST_NORMAL;
        assert_eq!(
            CastMode::from_raw(duckdb_cast_mode_DUCKDB_CAST_NORMAL),
            CastMode::Normal
        );
    }

    #[test]
    fn cast_mode_from_raw_try() {
        assert_eq!(
            CastMode::from_raw(duckdb_cast_mode_DUCKDB_CAST_TRY),
            CastMode::Try
        );
    }

    #[test]
    fn cast_function_info_wraps_null() {
        // Constructing with null must not crash (no DuckDB calls made).
        let _info = unsafe { CastFunctionInfo::new(std::ptr::null_mut()) };
    }
}
447}