use std::ffi::CString;
use std::os::raw::c_void;
use libduckdb_sys::{
duckdb_cast_function_get_cast_mode, duckdb_cast_function_get_extra_info,
duckdb_cast_function_set_error, duckdb_cast_function_set_extra_info,
duckdb_cast_function_set_function, duckdb_cast_function_set_implicit_cast_cost,
duckdb_cast_function_set_row_error, duckdb_cast_function_set_source_type,
duckdb_cast_function_set_target_type, duckdb_cast_mode_DUCKDB_CAST_TRY, duckdb_connection,
duckdb_create_cast_function, duckdb_delete_callback_t, duckdb_destroy_cast_function,
duckdb_function_info, duckdb_register_cast_function, duckdb_vector, idx_t, DuckDBSuccess,
};
use crate::error::ExtensionError;
use crate::types::{LogicalType, TypeId};
/// Converts `s` into a `CString` suitable for the DuckDB C API.
///
/// C strings cannot contain interior NUL bytes, so instead of failing the
/// text is cut at the first NUL; a string without NULs is copied unchanged.
#[mutants::skip]
fn str_to_cstring(s: &str) -> CString {
    let bytes = s.as_bytes();
    // Index of the first NUL, or the full length when there is none.
    let end = bytes.iter().position(|&b| b == 0).unwrap_or(bytes.len());
    // `bytes[..end]` is NUL-free, so the default fallback is never taken.
    CString::new(&bytes[..end]).unwrap_or_default()
}
/// Execution mode of a cast invocation, as reported by DuckDB.
///
/// Obtained from [`CastFunctionInfo::cast_mode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastMode {
    /// Any raw mode other than `duckdb_cast_mode_DUCKDB_CAST_TRY`.
    Normal,
    /// The raw mode `duckdb_cast_mode_DUCKDB_CAST_TRY` (presumably `TRY_CAST`).
    Try,
}

impl CastMode {
    /// Maps DuckDB's raw `duckdb_cast_mode` value onto [`CastMode`].
    ///
    /// Only the `TRY` constant is recognised specially; every other value
    /// falls back to [`CastMode::Normal`].
    const fn from_raw(raw: libduckdb_sys::duckdb_cast_mode) -> Self {
        match raw {
            mode if mode == duckdb_cast_mode_DUCKDB_CAST_TRY => Self::Try,
            _ => Self::Normal,
        }
    }
}
/// Thin wrapper around the `duckdb_function_info` handle that DuckDB passes
/// into a cast callback.
///
/// Nothing in this type frees the handle; its methods only forward to the
/// `duckdb_cast_function_*` C API.
pub struct CastFunctionInfo {
    info: duckdb_function_info,
}

impl CastFunctionInfo {
    /// Wraps a raw function-info handle.
    ///
    /// # Safety
    ///
    /// `info` must be a valid handle (the one DuckDB passed to the running cast
    /// callback) for as long as the wrapper is used: the safe methods below
    /// hand it straight to DuckDB without any checks.
    #[inline]
    #[must_use]
    pub const unsafe fn new(info: duckdb_function_info) -> Self {
        Self { info }
    }

    /// Returns the mode DuckDB is executing this cast in.
    #[must_use]
    pub fn cast_mode(&self) -> CastMode {
        // SAFETY: `self.info` is valid per the contract of `new`.
        CastMode::from_raw(unsafe { duckdb_cast_function_get_cast_mode(self.info) })
    }

    /// Returns the opaque pointer attached via `CastFunctionBuilder::extra_info`
    /// (null presumably when none was attached; confirm against DuckDB docs).
    ///
    /// # Safety
    ///
    /// The pointer is untyped; the caller must cast it back to the type that was
    /// originally attached and must not use it after it has been destroyed.
    #[must_use]
    pub unsafe fn get_extra_info(&self) -> *mut c_void {
        // SAFETY: `self.info` is valid per the contract of `new`.
        unsafe { duckdb_cast_function_get_extra_info(self.info) }
    }

    /// Sets an error message for the cast via `duckdb_cast_function_set_error`.
    ///
    /// If `message` contains a NUL byte it is truncated at the first one.
    #[mutants::skip]
    pub fn set_error(&self, message: &str) {
        let text = str_to_cstring(message);
        // SAFETY: `self.info` is valid per `new`; `text` outlives the call.
        unsafe { duckdb_cast_function_set_error(self.info, text.as_ptr()) };
    }

    /// Sets an error for a single `row` of `output` via
    /// `duckdb_cast_function_set_row_error`.
    ///
    /// If `message` contains a NUL byte it is truncated at the first one.
    ///
    /// # Safety
    ///
    /// `output` must be the valid output vector of the running cast callback
    /// and `row` must index a row of it.
    pub unsafe fn set_row_error(&self, message: &str, row: idx_t, output: duckdb_vector) {
        let text = str_to_cstring(message);
        // SAFETY: `self.info` is valid per `new`; `output`/`row` are the caller's
        // obligation; `text` outlives the call.
        unsafe { duckdb_cast_function_set_row_error(self.info, text.as_ptr(), row, output) };
    }
}
/// Signature of the raw C callback DuckDB invokes to execute a cast.
///
/// Parameters:
/// - `info`: function-info handle; wrap it with [`CastFunctionInfo::new`].
/// - `count`: presumably the number of rows in this batch.
/// - `input`: the source vector.
/// - `output`: the target vector.
///
/// NOTE(review): the returned `bool` presumably signals success; confirm
/// against the DuckDB C API docs for `duckdb_cast_function_t`.
pub type CastFn = unsafe extern "C" fn(
info: duckdb_function_info,
count: idx_t,
input: duckdb_vector,
output: duckdb_vector,
) -> bool;
/// Collects the pieces of a DuckDB cast function and registers it on a
/// connection via `register`.
///
/// Source and target are each given either as a [`TypeId`] or as a
/// [`LogicalType`]. A callback must be set with `function` before registering.
#[must_use]
pub struct CastFunctionBuilder {
// Source type id; only used when `source_logical` is `None`.
source: Option<TypeId>,
// Explicit source logical type; takes precedence over `source`.
source_logical: Option<LogicalType>,
// Target type id; only used when `target_logical` is `None`.
target: Option<TypeId>,
// Explicit target logical type; takes precedence over `target`.
target_logical: Option<LogicalType>,
// Raw cast callback; `register` fails when this is `None`.
function: Option<CastFn>,
// Optional implicit-cast cost forwarded to DuckDB when set.
implicit_cost: Option<i64>,
// Optional opaque pointer plus its destroy callback, forwarded to DuckDB.
extra_info: Option<(*mut c_void, duckdb_delete_callback_t)>,
}
// The raw pointer in `extra_info` makes the auto-derived impl `!Send`; the lint
// is silenced because `Send` is asserted manually right below.
#[allow(clippy::non_send_fields_in_send_ty)]
// SAFETY: NOTE(review) — the builder only stores the pointer and hands it to
// DuckDB in `register`. Soundness relies on callers of the `unsafe`
// `extra_info` setter supplying data that may move across threads, and
// presumably on `LogicalType` being safe to send; confirm both.
unsafe impl Send for CastFunctionBuilder {}
impl CastFunctionBuilder {
    /// Starts a builder for a cast between two types identified by [`TypeId`].
    pub const fn new(source: TypeId, target: TypeId) -> Self {
        Self {
            source: Some(source),
            source_logical: None,
            target: Some(target),
            target_logical: None,
            function: None,
            implicit_cost: None,
            extra_info: None,
        }
    }

    /// Starts a builder whose source and target are given as [`LogicalType`]
    /// values rather than type ids.
    pub fn new_logical(source: LogicalType, target: LogicalType) -> Self {
        Self {
            source: None,
            source_logical: Some(source),
            target: None,
            target_logical: Some(target),
            function: None,
            implicit_cost: None,
            extra_info: None,
        }
    }

    /// Returns the source type id, or `None` when built with `new_logical`.
    pub const fn source(&self) -> Option<TypeId> {
        self.source
    }

    /// Returns the target type id, or `None` when built with `new_logical`.
    pub const fn target(&self) -> Option<TypeId> {
        self.target
    }

    /// Sets the raw callback that performs the cast; required by `register`.
    pub fn function(mut self, f: CastFn) -> Self {
        self.function = Some(f);
        self
    }

    /// Sets the cost passed to `duckdb_cast_function_set_implicit_cast_cost`.
    /// When never called, no implicit cost is set on the DuckDB side.
    pub const fn implicit_cost(mut self, cost: i64) -> Self {
        self.implicit_cost = Some(cost);
        self
    }

    /// Attaches an opaque pointer and the callback that frees it. The pointer can
    /// be read back inside the cast via [`CastFunctionInfo::get_extra_info`].
    ///
    /// # Safety
    ///
    /// `ptr` must remain valid until `destroy` is invoked on it, and `destroy`
    /// must be able to free it. If `register` returns an error before the
    /// pointer is handed to DuckDB (missing callback or type), `destroy` is not
    /// called and the caller remains responsible for `ptr`.
    pub unsafe fn extra_info(
        mut self,
        ptr: *mut c_void,
        destroy: duckdb_delete_callback_t,
    ) -> Self {
        self.extra_info = Some((ptr, destroy));
        self
    }

    /// Creates the DuckDB cast function described by this builder and registers
    /// it on `con`.
    ///
    /// # Errors
    ///
    /// Returns an error if no callback was set, if the source or target type is
    /// missing, or if `duckdb_register_cast_function` reports failure. In the
    /// first three cases no DuckDB cast-function handle is created and the
    /// `extra_info` pointer (if any) is not handed to DuckDB.
    ///
    /// # Safety
    ///
    /// `con` must be a valid, open `duckdb_connection`.
    pub unsafe fn register(self, con: duckdb_connection) -> Result<(), ExtensionError> {
        let function = self
            .function
            .ok_or_else(|| ExtensionError::new("cast function callback not set"))?;
        // Resolve both types *before* creating the DuckDB handle: an early
        // `return Err(..)` after `duckdb_create_cast_function` would leak it.
        let src_lt =
            Self::resolve_type(self.source_logical, self.source, "cast source type not set")?;
        let tgt_lt =
            Self::resolve_type(self.target_logical, self.target, "cast target type not set")?;

        // SAFETY: creating a cast function has no preconditions; the handle is
        // destroyed unconditionally below.
        let mut cast = unsafe { duckdb_create_cast_function() };
        // SAFETY: `cast` is a live handle; `src_lt` and `tgt_lt` stay alive until
        // the end of this function.
        unsafe {
            duckdb_cast_function_set_source_type(cast, src_lt.as_raw());
            duckdb_cast_function_set_target_type(cast, tgt_lt.as_raw());
            duckdb_cast_function_set_function(cast, Some(function));
        }
        if let Some(cost) = self.implicit_cost {
            // SAFETY: `cast` is a live handle.
            unsafe { duckdb_cast_function_set_implicit_cast_cost(cast, cost) };
        }
        if let Some((ptr, destroy)) = self.extra_info {
            // SAFETY: `cast` is a live handle; the validity of `ptr`/`destroy` is
            // the caller's obligation under `extra_info`.
            unsafe { duckdb_cast_function_set_extra_info(cast, ptr, destroy) };
        }
        // SAFETY: `con` is valid per this function's contract; `cast` is live.
        let result = unsafe { duckdb_register_cast_function(con, cast) };
        // The local handle is released whether or not registration succeeded.
        // SAFETY: `cast` was created above and is destroyed exactly once.
        unsafe { duckdb_destroy_cast_function(&raw mut cast) };
        if result == DuckDBSuccess {
            Ok(())
        } else {
            Err(ExtensionError::new("duckdb_register_cast_function failed"))
        }
    }

    /// Picks the explicit logical type when present, otherwise builds one from
    /// the type id; fails with the `missing` message when neither was supplied.
    fn resolve_type(
        logical: Option<LogicalType>,
        id: Option<TypeId>,
        missing: &'static str,
    ) -> Result<LogicalType, ExtensionError> {
        match (logical, id) {
            (Some(lt), _) => Ok(lt),
            (None, Some(id)) => Ok(LogicalType::new(id)),
            (None, None) => Err(ExtensionError::new(missing)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use libduckdb_sys::{duckdb_function_info, duckdb_vector, idx_t};

    /// Cast callback that does nothing and reports success.
    unsafe extern "C" fn noop_cast(
        _: duckdb_function_info,
        _: idx_t,
        _: duckdb_vector,
        _: duckdb_vector,
    ) -> bool {
        true
    }

    #[test]
    fn builder_stores_source_and_target() {
        let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer);
        assert_eq!(b.source(), Some(TypeId::Varchar));
        assert_eq!(b.target(), Some(TypeId::Integer));
    }

    #[test]
    fn builder_stores_function() {
        let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer).function(noop_cast);
        assert!(b.function.is_some());
    }

    #[test]
    fn builder_stores_implicit_cost() {
        let b = CastFunctionBuilder::new(TypeId::Varchar, TypeId::Integer).implicit_cost(10);
        assert_eq!(b.implicit_cost, Some(10));
    }

    #[test]
    fn builder_no_function_is_error() {
        let b = CastFunctionBuilder::new(TypeId::BigInt, TypeId::Double);
        assert!(b.function.is_none());
        // `register` rejects a missing callback before making any DuckDB call,
        // so the null connection is never touched.
        let result = unsafe { b.register(std::ptr::null_mut()) };
        assert!(result.is_err());
    }

    #[test]
    fn cast_mode_from_raw_normal() {
        use libduckdb_sys::duckdb_cast_mode_DUCKDB_CAST_NORMAL;
        assert_eq!(
            CastMode::from_raw(duckdb_cast_mode_DUCKDB_CAST_NORMAL),
            CastMode::Normal
        );
    }

    #[test]
    fn cast_mode_from_raw_try() {
        assert_eq!(
            CastMode::from_raw(duckdb_cast_mode_DUCKDB_CAST_TRY),
            CastMode::Try
        );
    }

    #[test]
    fn cast_function_info_wraps_null() {
        let _info = unsafe { CastFunctionInfo::new(std::ptr::null_mut()) };
    }
}