// arrow_udf/ffi.rs

// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! FFI interfaces.

17use crate::{Error, ScalarFunction, TableFunction};
18use arrow_array::RecordBatchReader;
19use arrow_ipc::{reader::FileReader, writer::FileWriter};
20
/// A symbol indicating the ABI version.
///
/// The version is encoded in the symbol's name. It is exported unmangled (`no_mangle`) so the
/// exact name appears in the symbol table. It is marked `#[used]` so the compiler keeps it even
/// though nothing references it.
///
/// The version follows semantic versioning `MAJOR.MINOR`.
/// - The major version is incremented when incompatible API changes are made.
/// - The minor version is incremented when new functionality is added in a backward-compatible
///   manner.
///
/// # Changelog
///
/// - 3.0: Change type names in signatures.
/// - 2.0: Add user defined struct type.
/// - 1.0: Initial version.
#[unsafe(no_mangle)]
#[used]
pub static ARROWUDF_VERSION_3_0: () = ();

/// Allocate `len` bytes aligned to `align` from the global allocator.
///
/// Exported with an unmangled symbol name so that foreign code can call it. As with
/// [`std::alloc::alloc`], a null pointer is returned if the allocation fails.
///
/// # Safety
///
/// `len` and `align` are turned into a [`std::alloc::Layout`] without any validation, so the
/// caller must satisfy both [`std::alloc::Layout::from_size_align_unchecked`] (`align` is a
/// non-zero power of two and `len` rounded up to `align` does not overflow `isize`) and
/// [`std::alloc::GlobalAlloc::alloc`] (notably, `len` must be non-zero).
#[unsafe(no_mangle)]
pub unsafe extern "C" fn alloc(len: usize, align: usize) -> *mut u8 {
    // SAFETY: the validity of `len`/`align` is part of this function's caller contract.
    let layout = unsafe { std::alloc::Layout::from_size_align_unchecked(len, align) };
    // SAFETY: `layout` is valid and non-zero-sized by the caller contract above.
    unsafe { std::alloc::alloc(layout) }
}

/// Return a block of memory to the global allocator.
///
/// `ptr`, `len` and `align` must together describe exactly the allocation being released.
///
/// # Safety
///
/// `len` and `align` are turned into a [`std::alloc::Layout`] without validation. The caller
/// must therefore satisfy [`std::alloc::Layout::from_size_align_unchecked`] and
/// [`std::alloc::GlobalAlloc::dealloc`]: `ptr` must currently be allocated by the global
/// allocator with this same size and alignment, and must not be used afterwards.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn dealloc(ptr: *mut u8, len: usize, align: usize) {
    // SAFETY: the caller guarantees `len`/`align` form the layout `ptr` was allocated with.
    let layout = unsafe { std::alloc::Layout::from_size_align_unchecked(len, align) };
    // SAFETY: `ptr` was allocated by the global allocator with `layout` (caller contract).
    unsafe { std::alloc::dealloc(ptr, layout) }
}

/// A FFI-safe `(pointer, length)` pair describing a byte buffer.
///
/// This is the value written to the `out` parameters of the wrappers in this module. It
/// carries IPC-encoded data, error messages, or an opaque iterator handle.
#[derive(Debug)]
#[repr(C)]
pub struct CSlice {
    /// Address of the first byte. It may be null, e.g. when an iterator has no more batches.
    pub ptr: *const u8,
    /// Length of the buffer in bytes.
    pub len: usize,
}

69/// A wrapper for calling scalar functions from C.
70///
71/// The input record batch is read from the IPC buffer pointed to by `ptr` and `len`.
72///
73/// The output data is written to the buffer pointed to by `out_slice`.
74/// The caller is responsible for deallocating the output buffer.
75///
76/// The return value is 0 on success, -1 on error.
77/// If successful, the record batch is written to the buffer.
78/// If failed, the error message is written to the buffer.
79///
80/// # Safety
81///
82/// `ptr`, `len`, `out_slice` must point to a valid buffer.
83pub unsafe fn scalar_wrapper(
84    function: ScalarFunction,
85    ptr: *const u8,
86    len: usize,
87    out_slice: *mut CSlice,
88) -> i32 {
89    unsafe {
90        let input = std::slice::from_raw_parts(ptr, len);
91        match call_scalar(function, input) {
92            Ok(data) => {
93                out_slice.write(CSlice {
94                    ptr: data.as_ptr(),
95                    len: data.len(),
96                });
97                std::mem::forget(data);
98                0
99            }
100            Err(err) => {
101                let msg = err.to_string().into_boxed_str();
102                out_slice.write(CSlice {
103                    ptr: msg.as_ptr(),
104                    len: msg.len(),
105                });
106                std::mem::forget(msg);
107                -1
108            }
109        }
110    }
111}
112
113/// The internal wrapper that returns a Result.
114fn call_scalar(function: ScalarFunction, input_bytes: &[u8]) -> Result<Box<[u8]>, Error> {
115    let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
116    let input_batch = reader
117        .next()
118        .ok_or_else(|| Error::IpcError("no record batch".into()))??;
119
120    let output_batch = function(&input_batch)?;
121
122    // Write data to IPC buffer
123    let mut buf = vec![];
124    let mut writer = FileWriter::try_new(&mut buf, &output_batch.schema())?;
125    writer.write(&output_batch)?;
126    writer.finish()?;
127    drop(writer);
128
129    Ok(buf.into())
130}
131
132/// An opaque type for iterating over record batches.
133pub struct RecordBatchIter {
134    iter: Box<dyn RecordBatchReader + Send>,
135}
136
137/// A wrapper for calling table functions from C.
138///
139/// The input record batch is read from the IPC buffer pointed to by `ptr` and `len`.
140///
141/// The output iterator is written to `out_slice`.
142///
143/// The return value is 0 on success, -1 on error.
144/// If successful, the record batch is written to the buffer.
145/// If failed, the error message is written to the buffer.
146///
147/// # Safety
148///
149/// `ptr`, `len`, `out_slice` must point to a valid buffer.
150pub unsafe fn table_wrapper(
151    function: TableFunction,
152    ptr: *const u8,
153    len: usize,
154    out_slice: *mut CSlice,
155) -> i32 {
156    unsafe {
157        let input = std::slice::from_raw_parts(ptr, len);
158        match call_table(function, input) {
159            Ok(iter) => {
160                out_slice.write(CSlice {
161                    ptr: Box::into_raw(iter) as *const u8,
162                    len: std::mem::size_of::<RecordBatchIter>(),
163                });
164                0
165            }
166            Err(err) => {
167                let msg = err.to_string().into_boxed_str();
168                out_slice.write(CSlice {
169                    ptr: msg.as_ptr(),
170                    len: msg.len(),
171                });
172                std::mem::forget(msg);
173                -1
174            }
175        }
176    }
177}
178
179fn call_table(function: TableFunction, input_bytes: &[u8]) -> Result<Box<RecordBatchIter>, Error> {
180    let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
181    let input_batch = reader
182        .next()
183        .ok_or_else(|| Error::IpcError("no record batch".into()))??;
184
185    let iter = function(&input_batch)?;
186    Ok(Box::new(RecordBatchIter { iter }))
187}
188
189/// Get the next record batch from the iterator.
190///
191/// The output record batch is written to the buffer pointed to by `out`.
192/// The caller is responsible for deallocating the output buffer.
193///
194/// # Safety
195///
196/// `iter` and `out` must be valid pointers.
197#[unsafe(no_mangle)]
198pub unsafe extern "C" fn record_batch_iterator_next(iter: *mut RecordBatchIter, out: *mut CSlice) {
199    unsafe {
200        let iter = iter.as_mut().expect("null pointer");
201        if let Some(Ok(batch)) = iter.iter.next() {
202            let mut buf = vec![];
203            let mut writer = FileWriter::try_new(&mut buf, &batch.schema()).unwrap();
204            writer.write(&batch).unwrap();
205            writer.finish().unwrap();
206            drop(writer);
207            let buf = buf.into_boxed_slice();
208
209            out.write(CSlice {
210                ptr: buf.as_ptr(),
211                len: buf.len(),
212            });
213            std::mem::forget(buf);
214        } else {
215            // TODO: return error message
216            out.write(CSlice {
217                ptr: std::ptr::null(),
218                len: 0,
219            });
220        }
221    }
222}
223
224/// Drop the iterator.
225///
226/// # Safety
227///
228/// `iter` must be valid pointers.
229#[unsafe(no_mangle)]
230pub unsafe extern "C" fn record_batch_iterator_drop(iter: *mut RecordBatchIter) {
231    unsafe {
232        drop(Box::from_raw(iter));
233    }
234}