//! Rust representation of `UDF_INIT`
#![allow(clippy::useless_conversion, clippy::unnecessary_cast)]
use std::cell::UnsafeCell;
use std::ffi::c_ulong;
use std::fmt::Debug;
use std::marker::PhantomData;
#[cfg(feature = "logging-debug")]
use std::{any::type_name, mem::size_of};
use udf_sys::UDF_INIT;
#[cfg(feature = "logging-debug")]
use crate::udf_log;
use crate::{Init, UdfState};
/// Helpful constants related to the `max_length` parameter
///
/// These can be helpful when calling [`UdfCfg::set_max_len()`]. That method
/// takes a `u64`, so cast the variant when passing it, e.g.
/// `MaxLenOptions::Blob as u64`.
#[repr(u32)]
#[non_exhaustive]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MaxLenOptions {
    /// The default max length for integers is 21
    IntDefault = 21,

    /// The default max length of a real value is 13 plus the result of
    /// [`UdfCfg::get_decimals()`]
    RealBase = 13,

    /// A `blob` can be up to 64 KiB (65,536 bytes).
    Blob = 1 << 16,

    /// A `mediumblob` can be up to 16 MiB (16,777,216 bytes).
    MediumBlob = 1 << 24,
}
/// Configuration and state storage for a single UDF invocation
///
/// This is a rusty wrapper around SQL's `UDF_INIT` struct. It provides methods
/// to read and adjust the function's settings (`maybe_null`, `decimals`,
/// `max_length` and `const_item`) and to stash a boxed value in the struct's
/// `ptr` field.
///
/// The type parameter `S` is a marker for the current UDF state; it is only
/// held as `PhantomData` and stores no data. Some setters are only available
/// when `S` is [`Init`].
///
/// The inner `UDF_INIT` lives in an `UnsafeCell` because every setter writes
/// through a shared `&self` reference.
#[repr(transparent)]
pub struct UdfCfg<S: UdfState>(pub(crate) UnsafeCell<UDF_INIT>, PhantomData<S>);
impl<S: UdfState> UdfCfg<S> {
    /// Create a `UdfCfg` reference from a pointer to a `UDF_INIT` struct
    ///
    /// Nothing is copied: the returned reference points at the same memory as
    /// `ptr`. The cast is layout-compatible because `UdfCfg` is
    /// `#[repr(transparent)]` over `UnsafeCell<UDF_INIT>`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that `ptr` is valid and remains valid for the
    /// lifetime of the returned value. Setters on the returned reference write
    /// through this pointer, so the pointee must also be writable.
    #[inline]
    pub(crate) unsafe fn from_raw_ptr<'p>(ptr: *const UDF_INIT) -> &'p Self {
        &*ptr.cast()
    }

    /// Consume a box and store its pointer in this `UDF_INIT`
    ///
    /// This takes a boxed object, turns it into a pointer, and stores that
    /// pointer in this struct's `ptr` field, overwriting whatever was there.
    /// After calling this function, [`Self::retrieve_box`] _must_ be called to
    /// free the memory!
    pub(crate) fn store_box<T>(&self, b: Box<T>) {
        let box_ptr = Box::into_raw(b);

        // Note: if T is zero-sized, this will print `0x1` for the address
        #[cfg(feature = "logging-debug")]
        udf_log!(
            Debug: "{box_ptr:p} {} bytes udf->server control transfer ({})",
            size_of::<T>(),type_name::<T>()
        );

        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        // here
        unsafe { (*self.0.get()).ptr = box_ptr.cast() };
    }

    /// Given this struct's `ptr` field is a boxed object, turn that pointer
    /// back into a box
    ///
    /// The `ptr` field itself is left untouched, so after this call it still
    /// holds the (now owned-elsewhere) address until something overwrites it.
    ///
    /// # Safety
    ///
    /// - T _must_ be the type of this struct's pointer, likely created with
    ///   [`Self::store_box`]
    /// - Because the stored pointer is not cleared, this must be called at
    ///   most once per stored box; a second call without an intervening
    ///   [`Self::store_box`] would produce two boxes owning the same
    ///   allocation
    pub(crate) unsafe fn retrieve_box<T>(&self) -> Box<T> {
        let box_ptr = (*self.0.get()).ptr.cast::<T>();

        #[cfg(feature = "logging-debug")]
        udf_log!(
            Debug: "{box_ptr:p} {} bytes server->udf control transfer ({})",
            size_of::<T>(),type_name::<T>()
        );

        Box::from_raw(box_ptr)
    }

    /// Retrieve the setting for whether this UDF may return `null`
    ///
    /// This defaults to true if any argument is nullable, false otherwise
    #[inline]
    pub fn get_maybe_null(&self) -> bool {
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*self.0.get()).maybe_null }
    }

    /// Retrieve the setting for number of decimal places
    ///
    /// This defaults to the longest number of digits of any argument, or 31 if
    /// there is no fixed number
    #[inline]
    pub fn get_decimals(&self) -> u32 {
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*self.0.get()).decimals as u32 }
    }

    /// Set the number of decimals this function returns
    ///
    /// This can be changed at any point in the UDF (init or process)
    #[inline]
    pub fn set_decimals(&self, v: u32) {
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*self.0.get()).decimals = v.into() }
    }

    /// Retrieve the current maximum length setting for this in-progress UDF
    ///
    /// The underlying field is a C `unsigned long`; it is widened to `u64` so
    /// the return type is the same on every platform.
    #[inline]
    pub fn get_max_len(&self) -> u64 {
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*self.0.get()).max_length as u64 }
    }

    /// Get the current `const_item` value
    #[inline]
    pub fn get_is_const(&self) -> bool {
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*self.0.get()).const_item }
    }
}
/// Setters on a `UdfCfg` that are only available while the UDF is being
/// initialized
impl UdfCfg<Init> {
    /// Configure whether this function is allowed to return null
    #[inline]
    pub fn set_maybe_null(&self, v: bool) {
        let raw = self.0.get();
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*raw).maybe_null = v };
    }

    /// Configure the maximum possible length of this UDF's result
    ///
    /// Mostly relevant for String and Decimal return types. [`MaxLenOptions`]
    /// lists possible defaults, including `BLOB` sizes.
    ///
    /// A value too large for the platform's C `unsigned long` is clamped to
    /// that type's maximum.
    #[inline]
    pub fn set_max_len(&self, v: u64) {
        // A fallible conversion is required: ulong is 64 bits in GNU but only
        // 32 bits in MSVC, so saturate when the value does not fit
        let clamped = c_ulong::try_from(v).unwrap_or(c_ulong::MAX);
        let raw = self.0.get();
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*raw).max_length = clamped };
    }

    /// Configure the `const_item` flag
    ///
    /// Pass true if your function always returns the same values with the
    /// same arguments
    #[inline]
    pub fn set_is_const(&self, v: bool) {
        let raw = self.0.get();
        // SAFETY: unsafe when called from different threads, but we are `!Sync`
        unsafe { (*raw).const_item = v };
    }
}
impl<T: UdfState> Debug for UdfCfg<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// SAFETY: unsafe when called from different threads, but we are `!Sync`
// here
let base = unsafe { &*self.0.get() };
f.debug_struct("UdfCfg")
.field("maybe_null", &base.maybe_null)
.field("decimals", &base.decimals)
.field("max_len", &base.max_length)
.field("is_const", &base.const_item)
.field("ptr", &base.ptr)
.finish()
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::mem::{align_of, size_of};

    use super::*;
    use crate::mock::MockUdfCfg;
    use crate::{Init, Process};

    // Verify no size issues: the wrapper must have the same layout as the raw
    // struct, since references are created by pointer cast
    #[test]
    fn cfg_init_size() {
        assert_eq!(
            size_of::<UDF_INIT>(),
            size_of::<UdfCfg<Init>>(),
            concat!("Size of: ", stringify!(UDF_INIT))
        );
        assert_eq!(
            align_of::<UDF_INIT>(),
            align_of::<UdfCfg<Init>>(),
            concat!("Alignment of ", stringify!(UDF_INIT))
        );
    }

    #[test]
    fn cfg_proc_size() {
        assert_eq!(
            size_of::<UDF_INIT>(),
            size_of::<UdfCfg<Process>>(),
            concat!("Size of: ", stringify!(UDF_INIT))
        );
        assert_eq!(
            align_of::<UDF_INIT>(),
            align_of::<UdfCfg<Process>>(),
            concat!("Alignment of ", stringify!(UDF_INIT))
        );
    }

    // The documented sizes must agree with the discriminant values
    #[test]
    fn max_len_options() {
        assert_eq!(MaxLenOptions::IntDefault as u32, 21);
        assert_eq!(MaxLenOptions::RealBase as u32, 13);
        assert_eq!(MaxLenOptions::Blob as u32, 64 * 1024);
        assert_eq!(MaxLenOptions::MediumBlob as u32, 16 * 1024 * 1024);
    }

    #[test]
    fn test_box_load_store() {
        // Verify store & retrieve on a box works
        #[derive(PartialEq, Debug, Clone)]
        struct X {
            s: String,
            map: HashMap<i64, f64>,
        }

        let mut map = HashMap::new();
        map.insert(930_984_098, 4_525_435_435.900_981);
        map.insert(12_341_234, -234.090_909_092);
        map.insert(-23_412_343_453, 838_383.6);

        let stored = X {
            s: "This is a string".to_owned(),
            map,
        };

        let mut m = MockUdfCfg::new();
        let cfg = m.as_init();
        cfg.store_box(Box::new(stored.clone()));

        let loaded: X = unsafe { *cfg.retrieve_box() };
        assert_eq!(stored, loaded);
    }

    #[test]
    fn maybe_null() {
        let mut m = MockUdfCfg::new();

        *m.maybe_null() = false;
        assert!(!m.as_init().get_maybe_null());
        *m.maybe_null() = true;
        assert!(m.as_init().get_maybe_null());

        // The setter must write through to the underlying struct
        m.as_init().set_maybe_null(false);
        assert!(!*m.maybe_null());
        m.as_init().set_maybe_null(true);
        assert!(*m.maybe_null());
    }

    #[test]
    fn decimals() {
        let mut m = MockUdfCfg::new();

        *m.decimals() = 1234;
        assert_eq!(m.as_init().get_decimals(), 1234);
        *m.decimals() = 0;
        assert_eq!(m.as_init().get_decimals(), 0);
        *m.decimals() = 1;
        assert_eq!(m.as_init().get_decimals(), 1);

        m.as_init().set_decimals(4);
        assert_eq!(*m.decimals(), 4);
    }

    #[test]
    fn max_len() {
        let mut m = MockUdfCfg::new();

        *m.max_len() = 1234;
        assert_eq!(m.as_init().get_max_len(), 1234);
        *m.max_len() = 0;
        assert_eq!(m.as_init().get_max_len(), 0);
        *m.max_len() = 1;
        assert_eq!(m.as_init().get_max_len(), 1);

        // The setter must write through to the underlying struct
        m.as_init().set_max_len(5678);
        assert_eq!(*m.max_len(), 5678);

        m.as_init().set_max_len(MaxLenOptions::Blob as u64);
        assert_eq!(m.as_init().get_max_len(), 65_536);

        // A value that does not fit in `c_ulong` saturates instead of
        // wrapping. Where `c_ulong` is 64 bits this is simply `u64::MAX`.
        m.as_init().set_max_len(u64::MAX);
        assert_eq!(m.as_init().get_max_len(), u64::from(c_ulong::MAX));
    }

    #[test]
    fn test_const() {
        let mut m = MockUdfCfg::new();

        *m.is_const() = false;
        assert!(!m.as_init().get_is_const());
        *m.is_const() = true;
        assert!(m.as_init().get_is_const());

        // The setter must write through to the underlying struct
        m.as_init().set_is_const(false);
        assert!(!*m.is_const());
        m.as_init().set_is_const(true);
        assert!(*m.is_const());
    }
}