arrow_udf/
ffi.rs

1// Copyright 2024 RisingWave Labs
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! FFI interfaces.
16
17use crate::{Error, ScalarFunction, TableFunction};
18use arrow_array::RecordBatchReader;
19use arrow_ipc::{reader::FileReader, writer::FileWriter};
20
/// A symbol indicating the ABI version.
///
/// The version is encoded in the symbol name itself; the value has the unit
/// type and carries no data.
///
/// The version follows semantic versioning `MAJOR.MINOR`.
/// - The major version is incremented when incompatible API changes are made.
/// - The minor version is incremented when new functionality is added in a
///   backward-compatible manner.
///
/// # Changelog
///
/// - 3.0: Change type names in signatures.
/// - 2.0: Add user defined struct type.
/// - 1.0: Initial version.
// `#[no_mangle]` exports the symbol under exactly this name and `#[used]`
// asks the compiler to keep it even though nothing in the crate reads it.
#[no_mangle]
#[used]
pub static ARROWUDF_VERSION_3_0: () = ();
35
/// Allocate `len` bytes of uninitialized memory aligned to `align`.
///
/// Returns a null pointer when the request cannot be served:
/// - `len` is zero (zero-sized requests are not supported by the global allocator),
/// - `align` is zero or not a power of two, or `len` overflows when rounded up
///   to `align`,
/// - the allocator itself reports failure.
///
/// # Safety
///
/// The returned memory is uninitialized. A non-null result must be released
/// with [`dealloc`] using the same `len` and `align`.
/// See also [`std::alloc::GlobalAlloc::alloc`].
#[no_mangle]
pub unsafe extern "C" fn alloc(len: usize, align: usize) -> *mut u8 {
    // Calling the global allocator with a zero-sized layout is undefined
    // behaviour, so reject such requests up front.
    if len == 0 {
        return std::ptr::null_mut();
    }
    match std::alloc::Layout::from_size_align(len, align) {
        // SAFETY: `layout` was validated by `from_size_align` and has a
        // non-zero size.
        Ok(layout) => unsafe { std::alloc::alloc(layout) },
        // An invalid size/alignment pair is reported like an allocation failure.
        Err(_) => std::ptr::null_mut(),
    }
}
45
/// Deallocate memory previously obtained with the given `len` and `align`.
///
/// The call is a no-op when `ptr` is null, when `len` is zero (zero-length
/// buffers never own an allocation), or when `len`/`align` do not form a valid
/// layout.
///
/// # Safety
///
/// For a non-null `ptr` and non-zero `len`, `ptr` must denote a block of memory
/// currently allocated by this module's global allocator with exactly this
/// `len` and `align`. See [`std::alloc::GlobalAlloc::dealloc`].
#[no_mangle]
pub unsafe extern "C" fn dealloc(ptr: *mut u8, len: usize, align: usize) {
    // Empty boxed slices/strings are represented by a dangling pointer with
    // length 0; passing those (or null) to the allocator would be undefined
    // behaviour.
    if ptr.is_null() || len == 0 {
        return;
    }
    // A size/alignment pair that is not a valid layout cannot describe a live
    // allocation, so there is nothing that could be freed.
    if let Ok(layout) = std::alloc::Layout::from_size_align(len, align) {
        // SAFETY: the caller guarantees that `ptr` was allocated with `layout`.
        unsafe { std::alloc::dealloc(ptr, layout) };
    }
}
58
/// An FFI-safe slice.
///
/// `#[repr(C)]` fixes the field order: the pointer comes first, followed by
/// the length.
#[repr(C)]
#[derive(Debug)]
pub struct CSlice {
    /// Address of the first byte. `record_batch_iterator_next` stores a null
    /// pointer here when it has no batch to return, and `table_wrapper` stores
    /// the address of a `RecordBatchIter` on success.
    pub ptr: *const u8,
    /// Number of bytes referenced by `ptr`.
    pub len: usize,
}
66
67/// A wrapper for calling scalar functions from C.
68///
69/// The input record batch is read from the IPC buffer pointed to by `ptr` and `len`.
70///
71/// The output data is written to the buffer pointed to by `out_slice`.
72/// The caller is responsible for deallocating the output buffer.
73///
74/// The return value is 0 on success, -1 on error.
75/// If successful, the record batch is written to the buffer.
76/// If failed, the error message is written to the buffer.
77///
78/// # Safety
79///
80/// `ptr`, `len`, `out_slice` must point to a valid buffer.
81pub unsafe fn scalar_wrapper(
82    function: ScalarFunction,
83    ptr: *const u8,
84    len: usize,
85    out_slice: *mut CSlice,
86) -> i32 {
87    let input = std::slice::from_raw_parts(ptr, len);
88    match call_scalar(function, input) {
89        Ok(data) => {
90            out_slice.write(CSlice {
91                ptr: data.as_ptr(),
92                len: data.len(),
93            });
94            std::mem::forget(data);
95            0
96        }
97        Err(err) => {
98            let msg = err.to_string().into_boxed_str();
99            out_slice.write(CSlice {
100                ptr: msg.as_ptr(),
101                len: msg.len(),
102            });
103            std::mem::forget(msg);
104            -1
105        }
106    }
107}
108
109/// The internal wrapper that returns a Result.
110fn call_scalar(function: ScalarFunction, input_bytes: &[u8]) -> Result<Box<[u8]>, Error> {
111    let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
112    let input_batch = reader
113        .next()
114        .ok_or_else(|| Error::IpcError("no record batch".into()))??;
115
116    let output_batch = function(&input_batch)?;
117
118    // Write data to IPC buffer
119    let mut buf = vec![];
120    let mut writer = FileWriter::try_new(&mut buf, &output_batch.schema())?;
121    writer.write(&output_batch)?;
122    writer.finish()?;
123    drop(writer);
124
125    Ok(buf.into())
126}
127
/// An opaque type for iterating over record batches.
///
/// Instances are created by `call_table`, handed out as a raw pointer by
/// `table_wrapper`, advanced with `record_batch_iterator_next` and released
/// with `record_batch_iterator_drop`.
pub struct RecordBatchIter {
    /// The boxed reader returned by the table function; it yields the output
    /// record batches.
    iter: Box<dyn RecordBatchReader + Send>,
}
132
133/// A wrapper for calling table functions from C.
134///
135/// The input record batch is read from the IPC buffer pointed to by `ptr` and `len`.
136///
137/// The output iterator is written to `out_slice`.
138///
139/// The return value is 0 on success, -1 on error.
140/// If successful, the record batch is written to the buffer.
141/// If failed, the error message is written to the buffer.
142///
143/// # Safety
144///
145/// `ptr`, `len`, `out_slice` must point to a valid buffer.
146pub unsafe fn table_wrapper(
147    function: TableFunction,
148    ptr: *const u8,
149    len: usize,
150    out_slice: *mut CSlice,
151) -> i32 {
152    let input = std::slice::from_raw_parts(ptr, len);
153    match call_table(function, input) {
154        Ok(iter) => {
155            out_slice.write(CSlice {
156                ptr: Box::into_raw(iter) as *const u8,
157                len: std::mem::size_of::<RecordBatchIter>(),
158            });
159            0
160        }
161        Err(err) => {
162            let msg = err.to_string().into_boxed_str();
163            out_slice.write(CSlice {
164                ptr: msg.as_ptr(),
165                len: msg.len(),
166            });
167            std::mem::forget(msg);
168            -1
169        }
170    }
171}
172
173fn call_table(function: TableFunction, input_bytes: &[u8]) -> Result<Box<RecordBatchIter>, Error> {
174    let mut reader = FileReader::try_new(std::io::Cursor::new(input_bytes), None)?;
175    let input_batch = reader
176        .next()
177        .ok_or_else(|| Error::IpcError("no record batch".into()))??;
178
179    let iter = function(&input_batch)?;
180    Ok(Box::new(RecordBatchIter { iter }))
181}
182
183/// Get the next record batch from the iterator.
184///
185/// The output record batch is written to the buffer pointed to by `out`.
186/// The caller is responsible for deallocating the output buffer.
187///
188/// # Safety
189///
190/// `iter` and `out` must be valid pointers.
191#[no_mangle]
192pub unsafe extern "C" fn record_batch_iterator_next(iter: *mut RecordBatchIter, out: *mut CSlice) {
193    let iter = iter.as_mut().expect("null pointer");
194    if let Some(Ok(batch)) = iter.iter.next() {
195        let mut buf = vec![];
196        let mut writer = FileWriter::try_new(&mut buf, &batch.schema()).unwrap();
197        writer.write(&batch).unwrap();
198        writer.finish().unwrap();
199        drop(writer);
200        let buf = buf.into_boxed_slice();
201
202        out.write(CSlice {
203            ptr: buf.as_ptr(),
204            len: buf.len(),
205        });
206        std::mem::forget(buf);
207    } else {
208        // TODO: return error message
209        out.write(CSlice {
210            ptr: std::ptr::null(),
211            len: 0,
212        });
213    }
214}
215
216/// Drop the iterator.
217///
218/// # Safety
219///
220/// `iter` must be valid pointers.
221#[no_mangle]
222pub unsafe extern "C" fn record_batch_iterator_drop(iter: *mut RecordBatchIter) {
223    drop(Box::from_raw(iter));
224}