use crate::{Error, ScalarFunction, TableFunction};
use arrow_array::RecordBatchReader;
use arrow_ipc::{reader::FileReader, writer::FileWriter};
/// Exported marker symbol with an unmangled name; `#[used]` keeps it in the output binary even
/// though nothing in this crate references it.
/// NOTE(review): presumably lets the host detect that this module implements ABI version 3.0 —
/// confirm against the loader.
#[unsafe(no_mangle)]
#[used]
pub static ARROWUDF_VERSION_3_0: () = ();
/// Allocates `len` bytes aligned to `align` on behalf of the host.
///
/// Returns a null pointer when `len`/`align` do not form a valid layout (for example when
/// `align` is not a power of two) or when the global allocator fails.
///
/// A zero-byte request never reaches the allocator, because calling [`std::alloc::alloc`] with
/// a zero-sized layout is undefined behavior. Instead it returns a non-null, suitably aligned
/// dangling pointer that must never be dereferenced.
///
/// # Safety
///
/// A non-null result must be released with [`dealloc`] using the same `len` and `align`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn alloc(len: usize, align: usize) -> *mut u8 {
    let Ok(layout) = std::alloc::Layout::from_size_align(len, align) else {
        return std::ptr::null_mut();
    };
    if layout.size() == 0 {
        // Same convention as `NonNull::dangling`: the alignment itself is a valid,
        // non-null, aligned address for a zero-sized region.
        return layout.align() as *mut u8;
    }
    // SAFETY: `layout` was validated above and has a non-zero size.
    unsafe { std::alloc::alloc(layout) }
}
/// Releases memory previously returned by [`alloc`] with the same `len` and `align`.
///
/// Null pointers, zero-length regions and invalid `len`/`align` combinations are ignored,
/// because none of them can correspond to a live heap allocation.
///
/// Zero-length regions matter in practice: a zero-sized request to [`alloc`] yields a dangling
/// pointer, and an empty `Box<str>`/`Box<[u8]>` leaked by this module also carries a dangling
/// pointer with length 0. Passing either to [`std::alloc::dealloc`] would be undefined behavior.
///
/// # Safety
///
/// A non-null `ptr` with non-zero `len` must have been allocated with exactly this layout and
/// must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dealloc(ptr: *mut u8, len: usize, align: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    let Ok(layout) = std::alloc::Layout::from_size_align(len, align) else {
        // An invalid layout can never have been handed out by `alloc`.
        return;
    };
    // SAFETY: the caller guarantees `ptr` was allocated with `layout`, which is non-zero-sized.
    unsafe { std::alloc::dealloc(ptr, layout) }
}
/// A `(pointer, length)` pair with C layout, used to pass byte buffers and opaque handles across
/// the FFI boundary.
#[repr(C)]
#[derive(Debug)]
pub struct CSlice {
/// Start of the region. `table_wrapper` stores an iterator handle here, and
/// `record_batch_iterator_next` writes null to signal that the stream is exhausted.
pub ptr: *const u8,
/// Length in bytes of the region at `ptr`.
pub len: usize,
}
/// Runs a scalar function over an Arrow IPC file and reports the result through `out_slice`.
///
/// Returns `0` on success, in which case `*out_slice` describes the Arrow IPC file produced by
/// `function`. Returns `-1` on failure, in which case `*out_slice` describes the UTF-8 error
/// message.
///
/// In both cases the buffer is leaked to the caller. It was allocated as a boxed byte/str slice,
/// so its layout is `(len, align 1)`.
///
/// # Safety
///
/// `ptr` must point to `len` readable bytes that stay valid for the duration of the call, and
/// `out_slice` must be valid for writing one [`CSlice`].
pub unsafe fn scalar_wrapper(
    function: ScalarFunction,
    ptr: *const u8,
    len: usize,
    out_slice: *mut CSlice,
) -> i32 {
    // SAFETY: the caller guarantees `ptr..ptr + len` is a readable byte range.
    let input_bytes = unsafe { std::slice::from_raw_parts(ptr, len) };

    let (payload, status): (Box<[u8]>, i32) = match call_scalar(function, input_bytes) {
        Ok(ipc_file) => (ipc_file, 0),
        Err(e) => (e.to_string().into_boxed_str().into_boxed_bytes(), -1),
    };

    let described = CSlice {
        ptr: payload.as_ptr(),
        len: payload.len(),
    };
    // Ownership of the buffer moves to the caller.
    std::mem::forget(payload);

    // SAFETY: the caller guarantees `out_slice` is valid for writes.
    unsafe { out_slice.write(described) };
    status
}
fn call_scalar(function: ScalarFunction, input_bytes: &[u8]) -> Result<Box<[u8]>, Error> {
let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
let input_batch = reader
.next()
.ok_or_else(|| Error::IpcError("no record batch".into()))??;
let output_batch = function(&input_batch)?;
let mut buf = vec![];
let mut writer = FileWriter::try_new(&mut buf, &output_batch.schema())?;
writer.write(&output_batch)?;
writer.finish()?;
drop(writer);
Ok(buf.into())
}
/// Opaque handle wrapping the reader returned by a table function.
///
/// `call_table` heap-allocates it, `table_wrapper` hands it out as a raw pointer,
/// `record_batch_iterator_next` advances it, and `record_batch_iterator_drop` frees it.
pub struct RecordBatchIter {
/// Source of the output batches produced by the table function.
iter: Box<dyn RecordBatchReader + Send>,
}
/// Runs a table function over an Arrow IPC file and reports the result through `out_slice`.
///
/// Returns `0` on success. In that case `out_slice.ptr` is an opaque `*mut RecordBatchIter`
/// handle and `out_slice.len` is `size_of::<RecordBatchIter>()`. Advance the handle with
/// `record_batch_iterator_next` and release it with `record_batch_iterator_drop`.
///
/// Returns `-1` on failure. In that case `*out_slice` describes a leaked UTF-8 error message
/// whose layout is `(len, align 1)`.
///
/// # Safety
///
/// `ptr` must point to `len` readable bytes that stay valid for the duration of the call, and
/// `out_slice` must be valid for writing one [`CSlice`].
pub unsafe fn table_wrapper(
    function: TableFunction,
    ptr: *const u8,
    len: usize,
    out_slice: *mut CSlice,
) -> i32 {
    // SAFETY: the caller guarantees `ptr..ptr + len` is a readable byte range.
    let input_bytes = unsafe { std::slice::from_raw_parts(ptr, len) };

    let (described, status) = match call_table(function, input_bytes) {
        Ok(boxed_iter) => {
            let handle = Box::into_raw(boxed_iter) as *const u8;
            let slice = CSlice {
                ptr: handle,
                len: std::mem::size_of::<RecordBatchIter>(),
            };
            (slice, 0)
        }
        Err(e) => {
            let message = e.to_string().into_boxed_str();
            let slice = CSlice {
                ptr: message.as_ptr(),
                len: message.len(),
            };
            // Ownership of the message bytes moves to the caller.
            std::mem::forget(message);
            (slice, -1)
        }
    };

    // SAFETY: the caller guarantees `out_slice` is valid for writes.
    unsafe { out_slice.write(described) };
    status
}
fn call_table(function: TableFunction, input_bytes: &[u8]) -> Result<Box<RecordBatchIter>, Error> {
let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
let input_batch = reader
.next()
.ok_or_else(|| Error::IpcError("no record batch".into()))??;
let iter = function(&input_batch)?;
Ok(Box::new(RecordBatchIter { iter }))
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn record_batch_iterator_next(iter: *mut RecordBatchIter, out: *mut CSlice) {
unsafe {
let iter = iter.as_mut().expect("null pointer");
if let Some(Ok(batch)) = iter.iter.next() {
let mut buf = vec![];
let mut writer = FileWriter::try_new(&mut buf, &batch.schema()).unwrap();
writer.write(&batch).unwrap();
writer.finish().unwrap();
drop(writer);
let buf = buf.into_boxed_slice();
out.write(CSlice {
ptr: buf.as_ptr(),
len: buf.len(),
});
std::mem::forget(buf);
} else {
out.write(CSlice {
ptr: std::ptr::null(),
len: 0,
});
}
}
}
/// Destroys a table-function iterator handle produced by `table_wrapper`.
///
/// A null `iter` is ignored, mirroring `free(NULL)`. Calling `Box::from_raw` on a null pointer
/// would be undefined behavior.
///
/// # Safety
///
/// `iter` must be null or a handle produced by `table_wrapper` that has not already been
/// dropped. It must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn record_batch_iterator_drop(iter: *mut RecordBatchIter) {
    if iter.is_null() {
        return;
    }
    // SAFETY: `iter` is non-null and, per the contract above, came from `Box::into_raw` in
    // `table_wrapper` and is being released exactly once.
    drop(unsafe { Box::from_raw(iter) });
}