use core::fmt::Debug;
use std::marker::PhantomData;
use std::{slice, str};
use coerce::{get_coercion, get_current_type, get_desired_or_current, set_coercion};
use udf_sys::Item_result;
use crate::types::{SqlResult, SqlType};
use crate::wrapper::UDF_ARGSx;
use crate::{ArgList, Init, UdfState};
/// One argument of a UDF call, addressed by its position inside the owning
/// `ArgList`.
///
/// `S` is the `UdfState` marker for the call stage; stage-specific methods
/// (for example those on `SqlArg<'_, Init>`) are provided by separate `impl`s.
// The public type name intentionally repeats the module name.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug)]
pub struct SqlArg<'a, S>
where
    S: UdfState,
{
    /// Argument list this argument belongs to.
    pub(super) base: &'a ArgList<'a, S>,
    /// Zero-based offset of this argument within the `base` arrays.
    pub(super) index: usize,
    /// Carries the `S` type parameter without storing a value of it.
    pub(super) marker: PhantomData<S>,
}
impl<'a, T: UdfState> SqlArg<'a, T> {
    /// Return this argument's value, decoded according to its *current* type.
    ///
    /// During initialization a requested coercion (see `set_type_coercion`)
    /// is packed into the upper bytes of the argument's type code. Only the
    /// low byte, the current type, describes the data in the buffer, so the
    /// upper bytes are ignored here. This lets the value be read before the
    /// coercion is flushed.
    ///
    /// # Panics
    ///
    /// Panics if the current type code cannot be converted for
    /// `SqlResult::from_ptr`, or if `SqlResult::from_ptr` returns an error.
    #[inline]
    pub fn value(&self) -> SqlResult<'a> {
        // SAFETY: `get_base` borrows the `UDF_ARGSx` owned by `self.base`.
        // `self.index` is assumed (not checked here) to be smaller than the
        // argument count, so every `add` stays inside the per-argument arrays.
        unsafe {
            let base = self.get_base();
            let arg_buf_ptr: *const u8 = (*base.args.add(self.index)).cast();
            let raw_type = *base.arg_types.add(self.index);
            // Keep only the low byte, sign-extended. The upper bytes may hold
            // a pending coercion marker and desired type. This is the same
            // extraction as `coerce::get_current_type`.
            let arg_type = i32::from(raw_type as i8);
            let arg_len = *base.lengths.add(self.index);
            SqlResult::from_ptr(arg_buf_ptr, arg_type.try_into().unwrap(), arg_len as usize)
                .unwrap()
        }
    }

    /// Return this argument's attribute text as a UTF-8 string.
    ///
    /// A zero-length attribute yields `""` without touching the pointer.
    ///
    /// # Panics
    ///
    /// Panics if the attribute bytes are not valid UTF-8.
    #[inline]
    pub fn attribute(&'a self) -> &'a str {
        // SAFETY: same in-bounds assumption on `self.index` as in `value`.
        // The pointer and the length are read from the same argument slot.
        let attr_slice: &[u8] = unsafe {
            let base = self.get_base();
            let attr_len = *base.attribute_lengths.add(self.index) as usize;
            if attr_len == 0 {
                // `slice::from_raw_parts` requires a non-null pointer even for
                // an empty slice; avoid depending on the pointer in that case.
                return "";
            }
            // Reinterpret the slot as holding a `*const u8`, then read it.
            let attr_buf_ptr = *base.attributes.add(self.index).cast::<*const u8>();
            slice::from_raw_parts(attr_buf_ptr, attr_len)
        };
        str::from_utf8(attr_slice)
            .unwrap_or_else(|e| panic!("unexpected: attribute is not valid utf8. Error: {e:?}"))
    }

    /// Borrow the raw `UDF_ARGSx` behind the owning `ArgList`.
    ///
    /// # Safety
    ///
    /// Dereferences the raw pointer returned by `self.base.0.get()`. The
    /// caller must ensure no mutable reference to that `UDF_ARGSx` is live
    /// while the returned borrow is used.
    unsafe fn get_base(&'a self) -> &'a UDF_ARGSx {
        &(*self.base.0.get())
    }

    /// Pointer to this argument's entry in the `arg_types` array.
    ///
    /// # Safety
    ///
    /// Same requirements as `get_base`. In addition, `self.index` must be in
    /// bounds of `arg_types` for the pointer to be valid to read or write.
    unsafe fn arg_type_ptr(&self) -> *mut i32 {
        self.get_base().arg_types.add(self.index)
    }
}
impl<'a> SqlArg<'a, Init> {
    /// Return `true` when `value()` yields a non-`None` payload for any variant.
    ///
    /// In the init stage this presumably means the argument is a constant
    /// (per the MySQL loadable-function API); confirm against the server docs.
    #[inline]
    pub fn is_const(&self) -> bool {
        match self.value() {
            SqlResult::String(v) => v.is_some(),
            SqlResult::Decimal(v) => v.is_some(),
            SqlResult::Real(v) => v.is_some(),
            SqlResult::Int(v) => v.is_some(),
        }
    }

    /// Return `true` if this argument's `maybe_null` flag is non-zero.
    #[inline]
    pub fn maybe_null(&self) -> bool {
        // SAFETY: `self.index` is assumed to be within the argument count.
        unsafe { *self.get_base().maybe_null.add(self.index) != 0 }
    }

    /// Record a request to coerce this argument to `newtype`.
    ///
    /// The request is packed into the upper bytes of the argument's type code.
    /// The current type in the low byte is left unchanged until
    /// `flush_coercion` writes the final value.
    #[inline]
    pub fn set_type_coercion(&mut self, newtype: SqlType) {
        // SAFETY: `arg_type_ptr` points at this argument's `arg_types` slot;
        // `self.index` is assumed in bounds.
        unsafe {
            let arg_ptr = self.arg_type_ptr();
            *arg_ptr = set_coercion(*arg_ptr, newtype as i32);
        }
    }

    /// Return the requested coerced type, or the current type if no coercion
    /// has been requested.
    ///
    /// # Panics
    ///
    /// Panics if the resulting type code is not a valid `SqlType`.
    #[inline]
    pub fn get_type_coercion(&self) -> SqlType {
        // SAFETY: reads this argument's `arg_types` slot; `self.index` is
        // assumed in bounds.
        let arg_type = unsafe { *self.arg_type_ptr() };
        let coerced_type = get_coercion(arg_type).unwrap_or_else(|| get_current_type(arg_type));
        SqlType::try_from(coerced_type as i8).expect("critical: invalid sql type")
    }

    /// Replace the packed type code with a plain type value: the requested
    /// coerced type if one is pending, otherwise the current type.
    ///
    /// # Panics
    ///
    /// Panics if that value is not a valid `Item_result`. The check runs
    /// before the slot is written back.
    #[inline]
    pub(crate) fn flush_coercion(&mut self) {
        // SAFETY: reads and writes this argument's `arg_types` slot;
        // `self.index` is assumed in bounds.
        unsafe {
            let arg_ptr = self.arg_type_ptr();
            let to_set = get_desired_or_current(*arg_ptr);
            assert!(
                Item_result::try_from(to_set).is_ok(),
                "critical: coerced argument type {to_set} is not a valid Item_result"
            );
            *arg_ptr = to_set;
        }
    }
}
mod coerce {
    //! Pack a requested type coercion into an argument's `i32` type code.
    //!
    //! Packed layout, least significant byte first:
    //!
    //! | bits   | contents                                            |
    //! |--------|-----------------------------------------------------|
    //! | 0..8   | current type                                        |
    //! | 8..16  | desired type (low byte of the requested value)      |
    //! | 16..24 | carried over unchanged from the original code       |
    //! | 24..32 | `0b1010_1010` marker when a coercion has been set   |
    //!
    //! Types are read back by sign-extending one byte, so values in `i8`
    //! range (negatives included) round-trip.

    /// Marker byte stored in the top byte to flag a pending coercion.
    const MARKER_BYTE: i32 = 0b1010_1010;
    /// Bit offset of the marker byte.
    const MARKER_SHIFT: u32 = 24;
    /// Bit offset of the desired-type byte.
    const DESIRED_SHIFT: u32 = 8;
    /// Mask that selects a single byte.
    const LOW_BYTE: i32 = 0xFF;

    /// `true` when the top byte of `value` equals the coercion marker.
    fn coercion_is_set(value: i32) -> bool {
        ((value >> MARKER_SHIFT) & LOW_BYTE) == MARKER_BYTE
    }

    /// Return `current` with the marker set and its desired-type byte set to
    /// the low byte of `desired`.
    ///
    /// Bits 0..8 and 16..24 of `current` are kept. Any previously recorded
    /// desired type is overwritten.
    pub fn set_coercion(current: i32, desired: i32) -> i32 {
        let keep_mask = !((LOW_BYTE << MARKER_SHIFT) | (LOW_BYTE << DESIRED_SHIFT));
        let marker = MARKER_BYTE << MARKER_SHIFT;
        let desired_byte = (desired & LOW_BYTE) << DESIRED_SHIFT;
        (current & keep_mask) | marker | desired_byte
    }

    /// Desired type recorded by `set_coercion` (sign-extended from its byte),
    /// or `None` when the marker is absent.
    pub fn get_coercion(value: i32) -> Option<i32> {
        if !coercion_is_set(value) {
            return None;
        }
        Some(i32::from((value >> DESIRED_SHIFT) as i8))
    }

    /// Current type held in the low byte, sign-extended to `i32`.
    pub fn get_current_type(value: i32) -> i32 {
        i32::from(value as i8)
    }

    /// Desired type if a coercion is pending, otherwise the current type.
    pub fn get_desired_or_current(value: i32) -> i32 {
        match get_coercion(value) {
            Some(desired) => desired,
            None => get_current_type(value),
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        const SAMPLES: [i32; 8] = [-10, -5, -1, 0, 1, 5, 10, 20];

        #[test]
        fn test_unset_coercion() {
            for &raw in &SAMPLES {
                assert!(!coercion_is_set(raw));
                assert_eq!(get_coercion(raw), None);
                assert_eq!(get_current_type(raw), raw);
                assert_eq!(get_desired_or_current(raw), raw);
            }
        }

        #[test]
        fn test_coercion() {
            let pairs = SAMPLES
                .iter()
                .flat_map(|&cur| SAMPLES.iter().map(move |&want| (cur, want)));
            for (current, desired) in pairs {
                let packed = set_coercion(current, desired);
                assert!(coercion_is_set(packed));
                assert_eq!(get_coercion(packed), Some(desired));
                assert_eq!(get_current_type(packed), current);
                assert_eq!(get_desired_or_current(packed), desired);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::{align_of, size_of};

    /// Checks that `Item_result` has the same size and alignment as `i32`.
    /// Argument type codes in this file are accessed through `*mut i32` and
    /// converted with `Item_result::try_from`, presumably relying on this.
    #[test]
    fn verify_item_result_layout() {
        assert_eq!(size_of::<Item_result>(), size_of::<i32>());
        assert_eq!(align_of::<Item_result>(), align_of::<i32>());
    }
}