use std::any::type_name;
use std::backtrace::Backtrace;
use std::cmp::min;
use std::ffi::{c_char, c_ulong};
use std::ptr;
use crate::udf_log;
/// Copy `msg` into the C buffer at `buf` as a NUL-terminated string.
///
/// `N` is the capacity of the destination in bytes. At most `N - 1` bytes of `msg` are
/// copied, followed by a single NUL byte, so a message that is too long is truncated
/// rather than overflowing the buffer. When `N` is zero there is no room even for the
/// terminator, so nothing is written.
///
/// # Safety
///
/// `buf` must be valid for writes of `N` bytes and must not overlap `msg`.
pub unsafe fn write_msg_to_buf<const N: usize>(msg: &[u8], buf: *mut c_char) {
    // `N - 1` would underflow for a zero-capacity buffer: a panic in debug builds and an
    // out-of-bounds write in release builds. Bail out before touching `buf`.
    if N == 0 {
        return;
    }

    // Reserve the final byte for the NUL terminator.
    let bytes_to_write = min(msg.len(), N - 1);

    // SAFETY: `bytes_to_write < N`, so both the copy and the terminator stay inside the
    // `N` writable bytes the caller promised; the caller also guarantees no overlap.
    unsafe {
        ptr::copy_nonoverlapping(msg.as_ptr().cast::<c_char>(), buf, bytes_to_write);
        *buf.add(bytes_to_write) = 0;
    }
}
/// Describes where and how `buf_result_callback` may hand a byte result back.
pub struct BufOptions {
    /// Destination buffer that results are copied into when they fit.
    res_buf: *mut c_char,
    /// In/out length: read as the capacity of `res_buf`, then overwritten with the
    /// length of the result (or 0 on overflow).
    length: *mut c_ulong,
    /// Whether a pointer to the input itself may be returned when it does not fit in
    /// `res_buf`.
    can_return_ref: bool,
}

impl BufOptions {
    /// Bundle the raw result buffer, its in/out length pointer, and the
    /// return-by-reference flag. The pointers are stored as given and not dereferenced
    /// here.
    pub fn new(res_buf: *mut c_char, length: *mut c_ulong, can_return_ref: bool) -> Self {
        BufOptions {
            can_return_ref,
            length,
            res_buf,
        }
    }
}
pub unsafe fn buf_result_callback<U, T: AsRef<[u8]>>(
input: T,
opts: &BufOptions,
) -> Option<*const c_char> {
let slice_ref = input.as_ref();
let slice_len = slice_ref.len();
let slice_len_ulong: c_ulong = slice_len.try_into().unwrap_or_else(|_| {
udf_log!(Error: "Buffer size {}, platform limitation of {}. Truncating", slice_len, c_ulong::MAX);
c_ulong::MAX});
let slice_ptr: *const c_char = slice_ref.as_ptr().cast();
let buf_len = *opts.length as usize;
if slice_len <= buf_len {
ptr::copy(slice_ptr, opts.res_buf, slice_len);
*opts.length = slice_len_ulong;
return Some(opts.res_buf);
}
if !opts.can_return_ref {
*opts.length = 0;
udf_log!(
Critical: "output overflow, returning NULL. Buffer size: {}, data length: {} at `{}::process`",
buf_len, slice_len, type_name::<U>()
);
udf_log!(Critical: "contact your UDF vendor as this is a serious bug");
udf_log!(Critical: "run with `RUST_LIB_BACKTRACE=1` for a full backtrace");
eprintln!("{:#?}", Backtrace::capture());
return None;
}
*opts.length = slice_len_ulong;
Some(slice_ptr)
}
// Tests for `write_msg_to_buf` and for decoding raw UDF arguments through
// `SqlResult::from_ptr` and `ArgList`.
#[cfg(test)]
mod tests {
// Presumably needed for the `ival`/`rval`/`sval`/`dval` fixture names below, which
// differ by a single letter.
#![allow(clippy::similar_names)]
use std::ffi::{c_ulong, c_void, CStr};
use std::ptr;
use udf_sys::{Item_result, UDF_ARGS};
use super::*;
use crate::prelude::*;
// Message fixture and a buffer size that fits it exactly, including the NUL terminator.
const MSG: &str = "message";
const BUF_SIZE: usize = MSG.len() + 1;
// A message that exactly fits is written in full and can be read back as a C string.
#[test]
fn write_msg_ok() {
// Pre-fill with non-zero bytes so the terminator must come from `write_msg_to_buf`.
let mut mbuf = [1 as c_char; BUF_SIZE];
unsafe {
write_msg_to_buf::<BUF_SIZE>(MSG.as_bytes(), mbuf.as_mut_ptr());
let s = CStr::from_ptr(mbuf.as_ptr()).to_str().unwrap();
assert_eq!(s, MSG);
}
}
// With a buffer one byte too small, the last byte of the message is dropped to leave
// room for the NUL terminator.
#[test]
fn write_message_too_long() {
const NEW_BUF_SIZE: usize = BUF_SIZE - 1;
let mut mbuf = [1 as c_char; NEW_BUF_SIZE];
unsafe {
write_msg_to_buf::<NEW_BUF_SIZE>(MSG.as_bytes(), mbuf.as_mut_ptr());
let s = CStr::from_ptr(mbuf.as_ptr()).to_str().unwrap();
assert_eq!(*s, MSG[..MSG.len() - 1]);
};
}
// A null data pointer decodes to the `None` payload of the variant matching the
// requested type; `INVALID_RESULT` is rejected with an error.
#[test]
fn argtype_from_ptr_null() {
unsafe {
assert_eq!(
SqlResult::from_ptr(ptr::null(), Item_result::INT_RESULT, 0),
Ok(SqlResult::Int(None))
);
assert_eq!(
SqlResult::from_ptr(ptr::null(), Item_result::REAL_RESULT, 0),
Ok(SqlResult::Real(None))
);
assert_eq!(
SqlResult::from_ptr(ptr::null(), Item_result::STRING_RESULT, 0),
Ok(SqlResult::String(None))
);
assert_eq!(
SqlResult::from_ptr(ptr::null(), Item_result::DECIMAL_RESULT, 0),
Ok(SqlResult::Decimal(None))
);
assert!(SqlResult::from_ptr(ptr::null(), Item_result::INVALID_RESULT, 0).is_err());
}
}
// A non-null data pointer decodes to `Some(value)` of the variant matching the
// requested type; `INVALID_RESULT` is still rejected.
#[test]
fn argtype_from_ptr_notnull() {
unsafe {
// Integer and real values are passed as a pointer to the value, with length 0.
let ival = -1000i64;
assert_eq!(
SqlResult::from_ptr(ptr::addr_of!(ival).cast(), Item_result::INT_RESULT, 0),
Ok(SqlResult::Int(Some(ival)))
);
let rval = -1000.0f64;
assert_eq!(
SqlResult::from_ptr(ptr::addr_of!(rval).cast(), Item_result::REAL_RESULT, 0),
Ok(SqlResult::Real(Some(rval)))
);
// String and decimal values are passed as a pointer to their bytes plus a length.
let sval = "this is a string";
assert_eq!(
SqlResult::from_ptr(sval.as_ptr(), Item_result::STRING_RESULT, sval.len()),
Ok(SqlResult::String(Some(sval.as_bytes())))
);
let dval = "123.456";
assert_eq!(
SqlResult::from_ptr(dval.as_ptr(), Item_result::DECIMAL_RESULT, dval.len()),
Ok(SqlResult::Decimal(Some(dval)))
);
assert!(
SqlResult::from_ptr(dval.as_ptr(), Item_result::INVALID_RESULT, dval.len())
.is_err()
);
}
}
// Fixtures for `process_args_ok`: one argument of each supported result type.
const ARG_COUNT: usize = 4;
static IVAL: i64 = -1000i64;
static RVAL: f64 = -1234.5678f64;
static SVAL: &str = "this is a string";
static DVAL: &str = "123.456";
// Builds a `UDF_ARGS` by hand from four arguments and checks that iterating the
// resulting `ArgList` yields the expected value, `maybe_null` flag and attribute name
// for each position.
#[test]
fn process_args_ok() {
let mut arg_types = [
Item_result::INT_RESULT,
Item_result::REAL_RESULT,
Item_result::STRING_RESULT,
Item_result::DECIMAL_RESULT,
];
// Data pointers, in the same order as `arg_types`.
let mut arg_ptrs: [*const u8; ARG_COUNT] = [
ptr::addr_of!(IVAL).cast(),
ptr::addr_of!(RVAL).cast(),
SVAL.as_ptr(),
DVAL.as_ptr(),
];
// Lengths are 0 for the integer and real arguments and the byte length for the
// string and decimal arguments.
let mut arg_lens: [c_ulong; 4] = [0, 0, SVAL.len() as c_ulong, DVAL.len() as c_ulong];
let mut maybe_null = [true, true, false, false];
// Attribute names, passed as separate pointer and length arrays.
let attrs = ["ival", "rval", "sval", "dval"];
let mut attr_ptrs = [
attrs[0].as_ptr(),
attrs[1].as_ptr(),
attrs[2].as_ptr(),
attrs[3].as_ptr(),
];
let mut attr_lens: [c_ulong; 4] = [
attrs[0].len() as c_ulong,
attrs[1].len() as c_ulong,
attrs[2].len() as c_ulong,
attrs[3].len() as c_ulong,
];
// The `bool` and `u8` arrays are reinterpreted as the `c_char` pointers that the
// `UDF_ARGS` fields expect.
let mut udf_args = UDF_ARGS {
arg_count: ARG_COUNT as u32,
arg_types: arg_types.as_mut_ptr(),
args: arg_ptrs.as_mut_ptr() as *const *const c_char,
lengths: arg_lens.as_mut_ptr(),
maybe_null: maybe_null.as_mut_ptr() as *const c_char,
attributes: attr_ptrs.as_mut_ptr() as *const *const c_char,
attribute_lengths: attr_lens.as_mut_ptr() as *const c_ulong,
extension: ptr::null_mut::<c_void>(),
};
let arglist: &ArgList<Init> = unsafe { ArgList::from_raw_ptr(&mut udf_args) };
let res: Vec<_> = arglist.into_iter().collect();
let expected_args = [
SqlResult::Int(Some(IVAL)),
SqlResult::Real(Some(RVAL)),
SqlResult::String(Some(SVAL.as_bytes())),
SqlResult::Decimal(Some(DVAL)),
];
// Every decoded argument must line up with the fixture at the same index.
for i in 0..ARG_COUNT {
assert_eq!(res[i].value(), expected_args[i]);
assert_eq!(res[i].maybe_null(), maybe_null[i]);
assert_eq!(res[i].attribute(), attrs[i]);
}
}
}
/// Tests for the three outcomes of `buf_result_callback`: copy into the buffer, return
/// by reference, and overflow.
#[cfg(test)]
mod buffer_tests {
    use core::slice;

    use super::*;

    /// Capacity of the result buffer used by every test.
    const BUF_LEN: usize = 10;

    /// Data shorter than the buffer is copied into it and the buffer pointer is returned.
    #[test]
    fn test_buf_fits() {
        let payload = b"1234";
        let mut out = [0u8; BUF_LEN];
        let mut reported = out.len() as c_ulong;
        let opts = BufOptions::new(out.as_mut_ptr().cast(), &mut reported, false);

        let returned = unsafe { buf_result_callback::<u8, _>(payload, &opts) };
        let returned: *const u8 = returned.unwrap().cast();
        let written = unsafe { slice::from_raw_parts(returned, reported as usize) };

        assert_eq!(reported as usize, payload.len());
        assert_eq!(written, payload);
        assert_eq!(returned, out.as_ptr());
        // Everything past the copied bytes must still be zero.
        assert!(out[payload.len()..].iter().all(|&byte| byte == 0));
    }

    /// Data longer than the buffer is returned by reference when that is allowed.
    #[test]
    fn test_buf_no_fit_ref() {
        let payload = b"123456789012345";
        let mut out = [0u8; BUF_LEN];
        let mut reported = out.len() as c_ulong;
        let opts = BufOptions::new(out.as_mut_ptr().cast(), &mut reported, true);

        let returned = unsafe { buf_result_callback::<u8, _>(payload, &opts) };
        let returned: *const u8 = returned.unwrap().cast();
        let referenced = unsafe { slice::from_raw_parts(returned, reported as usize) };

        assert_eq!(reported as usize, payload.len());
        assert_eq!(referenced, payload);
        // The pointer must be the input itself, not the result buffer.
        assert_eq!(returned, payload.as_ptr());
    }

    /// Data longer than the buffer yields `None` and a zero length when references are
    /// not allowed.
    #[test]
    fn test_buf_no_fit_no_ref() {
        let payload = b"123456789012345";
        let mut out = [0u8; BUF_LEN];
        let mut reported = out.len() as c_ulong;
        let opts = BufOptions::new(out.as_mut_ptr().cast(), &mut reported, false);

        let outcome = unsafe { buf_result_callback::<u8, _>(payload, &opts) };

        assert_eq!(reported, 0);
        assert!(outcome.is_none());
    }
}