arrow_udf/ffi.rs
1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! FFI interfaces.
16
17use crate::{Error, ScalarFunction, TableFunction};
18use arrow_array::RecordBatchReader;
19use arrow_ipc::{reader::FileReader, writer::FileWriter};
20
/// Marker symbol whose name encodes the ABI version of this module.
///
/// Versions are written `MAJOR.MINOR`, following semantic versioning:
/// - `MAJOR` is bumped whenever an incompatible API change is made.
/// - `MINOR` is bumped whenever new functionality is added in a backward
///   compatible manner.
///
/// The value is the unit type and carries no data; only the symbol name
/// matters. `#[used]` and `no_mangle` ask the compiler to emit it under this
/// exact name even though no Rust code references it.
///
/// # Changelog
///
/// - 3.0: Change type names in signatures.
/// - 2.0: Add user defined struct type.
/// - 1.0: Initial version.
#[used]
#[unsafe(no_mangle)]
pub static ARROWUDF_VERSION_3_0: () = ();
35
/// Allocate `len` bytes aligned to `align` with the global allocator.
///
/// Returns a null pointer when the request cannot be satisfied:
/// - `len` is zero (zero-sized allocations are not supported by the global
///   allocator),
/// - `align` is zero or not a power of two,
/// - `len` rounded up to `align` overflows `isize::MAX`, or
/// - the allocator itself reports failure.
///
/// # Safety
///
/// The returned memory is uninitialized. A non-null result must be released
/// with [`dealloc`] using the same `len` and `align`.
/// See [`std::alloc::GlobalAlloc::alloc`].
#[unsafe(no_mangle)]
pub unsafe extern "C" fn alloc(len: usize, align: usize) -> *mut u8 {
    // Calling the global allocator with a zero-sized layout is undefined
    // behavior, so reject the request up front.
    if len == 0 {
        return std::ptr::null_mut();
    }
    // The arguments come from the other side of the FFI boundary, so validate
    // them instead of trusting them with `from_size_align_unchecked`.
    match std::alloc::Layout::from_size_align(len, align) {
        // SAFETY: `layout` has been validated and has a non-zero size.
        Ok(layout) => unsafe { std::alloc::alloc(layout) },
        Err(_) => std::ptr::null_mut(),
    }
}
45
/// Release memory previously obtained from this module.
///
/// The call is a no-op when there is nothing to free:
/// - `ptr` is null,
/// - `len` is zero (zero-length boxed buffers own no allocation, their
///   pointer is dangling), or
/// - `len`/`align` do not describe a valid layout, which no live allocation
///   can have.
///
/// # Safety
///
/// Unless one of the no-op conditions above holds, `ptr` must denote a block
/// currently allocated by the global allocator with exactly the layout
/// described by `len` and `align`, and must not be used afterwards.
/// See [`std::alloc::GlobalAlloc::dealloc`].
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dealloc(ptr: *mut u8, len: usize, align: usize) {
    if ptr.is_null() || len == 0 {
        return;
    }
    // Validate the caller-supplied layout rather than trusting it blindly.
    let Ok(layout) = std::alloc::Layout::from_size_align(len, align) else {
        return;
    };
    // SAFETY: the caller guarantees `ptr` was allocated with `layout`.
    unsafe { std::alloc::dealloc(ptr, layout) }
}
60
/// Pointer-and-length pair with a C-compatible layout, used to hand byte
/// buffers across the FFI boundary.
#[derive(Debug)]
#[repr(C)]
pub struct CSlice {
    /// Address of the first byte of the region.
    pub ptr: *const u8,
    /// Size of the region in bytes.
    pub len: usize,
}
68
69/// A wrapper for calling scalar functions from C.
70///
71/// The input record batch is read from the IPC buffer pointed to by `ptr` and `len`.
72///
73/// The output data is written to the buffer pointed to by `out_slice`.
74/// The caller is responsible for deallocating the output buffer.
75///
76/// The return value is 0 on success, -1 on error.
77/// If successful, the record batch is written to the buffer.
78/// If failed, the error message is written to the buffer.
79///
80/// # Safety
81///
82/// `ptr`, `len`, `out_slice` must point to a valid buffer.
83pub unsafe fn scalar_wrapper(
84 function: ScalarFunction,
85 ptr: *const u8,
86 len: usize,
87 out_slice: *mut CSlice,
88) -> i32 {
89 unsafe {
90 let input = std::slice::from_raw_parts(ptr, len);
91 match call_scalar(function, input) {
92 Ok(data) => {
93 out_slice.write(CSlice {
94 ptr: data.as_ptr(),
95 len: data.len(),
96 });
97 std::mem::forget(data);
98 0
99 }
100 Err(err) => {
101 let msg = err.to_string().into_boxed_str();
102 out_slice.write(CSlice {
103 ptr: msg.as_ptr(),
104 len: msg.len(),
105 });
106 std::mem::forget(msg);
107 -1
108 }
109 }
110 }
111}
112
113/// The internal wrapper that returns a Result.
114fn call_scalar(function: ScalarFunction, input_bytes: &[u8]) -> Result<Box<[u8]>, Error> {
115 let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
116 let input_batch = reader
117 .next()
118 .ok_or_else(|| Error::IpcError("no record batch".into()))??;
119
120 let output_batch = function(&input_batch)?;
121
122 // Write data to IPC buffer
123 let mut buf = vec![];
124 let mut writer = FileWriter::try_new(&mut buf, &output_batch.schema())?;
125 writer.write(&output_batch)?;
126 writer.finish()?;
127 drop(writer);
128
129 Ok(buf.into())
130}
131
132/// An opaque type for iterating over record batches.
133pub struct RecordBatchIter {
134 iter: Box<dyn RecordBatchReader + Send>,
135}
136
137/// A wrapper for calling table functions from C.
138///
139/// The input record batch is read from the IPC buffer pointed to by `ptr` and `len`.
140///
141/// The output iterator is written to `out_slice`.
142///
143/// The return value is 0 on success, -1 on error.
144/// If successful, the record batch is written to the buffer.
145/// If failed, the error message is written to the buffer.
146///
147/// # Safety
148///
149/// `ptr`, `len`, `out_slice` must point to a valid buffer.
150pub unsafe fn table_wrapper(
151 function: TableFunction,
152 ptr: *const u8,
153 len: usize,
154 out_slice: *mut CSlice,
155) -> i32 {
156 unsafe {
157 let input = std::slice::from_raw_parts(ptr, len);
158 match call_table(function, input) {
159 Ok(iter) => {
160 out_slice.write(CSlice {
161 ptr: Box::into_raw(iter) as *const u8,
162 len: std::mem::size_of::<RecordBatchIter>(),
163 });
164 0
165 }
166 Err(err) => {
167 let msg = err.to_string().into_boxed_str();
168 out_slice.write(CSlice {
169 ptr: msg.as_ptr(),
170 len: msg.len(),
171 });
172 std::mem::forget(msg);
173 -1
174 }
175 }
176 }
177}
178
179fn call_table(function: TableFunction, input_bytes: &[u8]) -> Result<Box<RecordBatchIter>, Error> {
180 let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
181 let input_batch = reader
182 .next()
183 .ok_or_else(|| Error::IpcError("no record batch".into()))??;
184
185 let iter = function(&input_batch)?;
186 Ok(Box::new(RecordBatchIter { iter }))
187}
188
189/// Get the next record batch from the iterator.
190///
191/// The output record batch is written to the buffer pointed to by `out`.
192/// The caller is responsible for deallocating the output buffer.
193///
194/// # Safety
195///
196/// `iter` and `out` must be valid pointers.
197#[unsafe(no_mangle)]
198pub unsafe extern "C" fn record_batch_iterator_next(iter: *mut RecordBatchIter, out: *mut CSlice) {
199 unsafe {
200 let iter = iter.as_mut().expect("null pointer");
201 if let Some(Ok(batch)) = iter.iter.next() {
202 let mut buf = vec![];
203 let mut writer = FileWriter::try_new(&mut buf, &batch.schema()).unwrap();
204 writer.write(&batch).unwrap();
205 writer.finish().unwrap();
206 drop(writer);
207 let buf = buf.into_boxed_slice();
208
209 out.write(CSlice {
210 ptr: buf.as_ptr(),
211 len: buf.len(),
212 });
213 std::mem::forget(buf);
214 } else {
215 // TODO: return error message
216 out.write(CSlice {
217 ptr: std::ptr::null(),
218 len: 0,
219 });
220 }
221 }
222}
223
224/// Drop the iterator.
225///
226/// # Safety
227///
228/// `iter` must be valid pointers.
229#[unsafe(no_mangle)]
230pub unsafe extern "C" fn record_batch_iterator_drop(iter: *mut RecordBatchIter) {
231 unsafe {
232 drop(Box::from_raw(iter));
233 }
234}