picodata_plugin/
util.rs

1use crate::error_code::ErrorCode;
2use abi_stable::StableAbi;
3use std::ptr::NonNull;
4use tarantool::error::BoxError;
5use tarantool::error::TarantoolErrorCode;
6use tarantool::ffi::tarantool as ffi;
7
8////////////////////////////////////////////////////////////////////////////////
9// FfiSafeBytes
10////////////////////////////////////////////////////////////////////////////////
11
12/// A helper struct for passing byte slices over the ABI boundary.
13#[repr(C)]
14#[derive(StableAbi, Clone, Copy, Debug)]
15pub struct FfiSafeBytes {
16    pointer: NonNull<u8>,
17    len: usize,
18}
19
20impl FfiSafeBytes {
21    #[inline(always)]
22    pub fn len(self) -> usize {
23        self.len
24    }
25
26    #[inline(always)]
27    pub unsafe fn from_raw_parts(pointer: NonNull<u8>, len: usize) -> Self {
28        Self { pointer, len }
29    }
30
31    #[inline(always)]
32    pub fn into_raw_parts(self) -> (*mut u8, usize) {
33        (self.pointer.as_ptr(), self.len)
34    }
35
36    /// Converts `self` back to a borrowed string `&[u8]`.
37    ///
38    /// # Safety
39    /// `FfiSafeBytes` can only be constructed from a valid rust byte slice,
40    /// so you only need to make sure that the origial `&[u8]` outlives the lifetime `'a`.
41    ///
42    /// This should generally be true when borrowing strings owned by the current
43    /// function and calling a function via FFI, but borrowing global data or
44    /// data stored within a `Rc` for example is probably unsafe.
45    pub unsafe fn as_bytes<'a>(self) -> &'a [u8] {
46        std::slice::from_raw_parts(self.pointer.as_ptr(), self.len)
47    }
48}
49
50impl Default for FfiSafeBytes {
51    #[inline(always)]
52    fn default() -> Self {
53        Self {
54            pointer: NonNull::dangling(),
55            len: 0,
56        }
57    }
58}
59
60impl<'a> From<&'a [u8]> for FfiSafeBytes {
61    #[inline(always)]
62    fn from(value: &'a [u8]) -> Self {
63        Self {
64            pointer: as_non_null_ptr(value),
65            len: value.len(),
66        }
67    }
68}
69
70impl<'a> From<&'a str> for FfiSafeBytes {
71    #[inline(always)]
72    fn from(value: &'a str) -> Self {
73        Self {
74            pointer: as_non_null_ptr(value.as_bytes()),
75            len: value.len(),
76        }
77    }
78}
79
80////////////////////////////////////////////////////////////////////////////////
81// FfiSafeStr
82////////////////////////////////////////////////////////////////////////////////
83
84/// A helper struct for passing rust strings over the ABI boundary.
85///
86/// This type can only be constructed from a valid rust string, so it's not
87/// necessary to validate the utf8 encoding when converting back to `&str`.
88#[repr(C)]
89#[derive(StableAbi, Clone, Copy, Debug)]
90pub struct FfiSafeStr {
91    pointer: NonNull<u8>,
92    len: usize,
93}
94
95impl FfiSafeStr {
96    #[inline(always)]
97    pub fn len(self) -> usize {
98        self.len
99    }
100
101    #[inline(always)]
102    pub unsafe fn from_raw_parts(pointer: NonNull<u8>, len: usize) -> Self {
103        Self { pointer, len }
104    }
105
106    /// # Safety
107    /// `bytes` must represent a valid utf8 string.
108    pub unsafe fn from_utf8_unchecked(bytes: &[u8]) -> Self {
109        let pointer = as_non_null_ptr(bytes);
110        let len = bytes.len();
111        Self { pointer, len }
112    }
113
114    #[inline(always)]
115    pub fn into_raw_parts(self) -> (*mut u8, usize) {
116        (self.pointer.as_ptr(), self.len)
117    }
118
119    /// Converts `self` back to a borrowed string `&str`.
120    ///
121    /// # Safety
122    /// `FfiSafeStr` can only be constructed from a valid rust `str`,
123    /// so you only need to make sure that the origial `str` outlives the lifetime `'a`.
124    ///
125    /// This should generally be true when borrowing strings owned by the current
126    /// function and calling a function via FFI, but borrowing global data or
127    /// data stored within a `Rc` for example is probably unsafe.
128    #[inline]
129    pub unsafe fn as_str<'a>(self) -> &'a str {
130        if cfg!(debug_assertions) {
131            std::str::from_utf8(self.as_bytes()).expect("should only be used with valid utf8")
132        } else {
133            std::str::from_utf8_unchecked(self.as_bytes())
134        }
135    }
136
137    #[inline(always)]
138    pub unsafe fn as_bytes<'a>(self) -> &'a [u8] {
139        std::slice::from_raw_parts(self.pointer.as_ptr(), self.len)
140    }
141}
142
143impl Default for FfiSafeStr {
144    #[inline(always)]
145    fn default() -> Self {
146        Self {
147            pointer: NonNull::dangling(),
148            len: 0,
149        }
150    }
151}
152
153impl<'a> From<&'a str> for FfiSafeStr {
154    #[inline(always)]
155    fn from(value: &'a str) -> Self {
156        Self {
157            pointer: as_non_null_ptr(value.as_bytes()),
158            len: value.len(),
159        }
160    }
161}
162
163////////////////////////////////////////////////////////////////////////////////
164// RegionGuard
165////////////////////////////////////////////////////////////////////////////////
166
167// TODO: move to tarantool-module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
168pub struct RegionGuard {
169    save_point: usize,
170}
171
172impl RegionGuard {
173    /// TODO
174    #[inline(always)]
175    pub fn new() -> Self {
176        // This is safe as long as the function is called within an initialized
177        // fiber runtime
178        let save_point = unsafe { ffi::box_region_used() };
179        Self { save_point }
180    }
181
182    /// TODO
183    #[inline(always)]
184    pub fn used_at_creation(&self) -> usize {
185        self.save_point
186    }
187}
188
189impl Drop for RegionGuard {
190    fn drop(&mut self) {
191        // This is safe as long as the function is called within an initialized
192        // fiber runtime
193        unsafe { ffi::box_region_truncate(self.save_point) }
194    }
195}
196
197////////////////////////////////////////////////////////////////////////////////
198// region allocation
199////////////////////////////////////////////////////////////////////////////////
200
201// TODO: move to tarantool module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
202/// TODO: doc
203#[inline]
204fn allocate_on_region(size: usize) -> Result<&'static mut [u8], BoxError> {
205    // SAFETY: requires initialized fiber runtime
206    let pointer = unsafe { ffi::box_region_alloc(size).cast::<u8>() };
207    if pointer.is_null() {
208        return Err(BoxError::last());
209    }
210    // SAFETY: safe because pointer is not null
211    let region_slice = unsafe { std::slice::from_raw_parts_mut(pointer, size) };
212    Ok(region_slice)
213}
214
215// TODO: move to tarantool module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
216/// Copies the provided `data` to the current fiber's region allocator returning
217/// a reference to the new allocation.
218///
219/// Use this to return dynamically sized values over the ABI boundary, for
220/// example in RPC handlers.
221///
222/// Note that the returned slice's lifetime is not really `'static`, but is
223/// determined by the following call to `box_region_truncate`.
224#[inline]
225pub fn copy_to_region(data: &[u8]) -> Result<&'static [u8], BoxError> {
226    let region_slice = allocate_on_region(data.len())?;
227    region_slice.copy_from_slice(data);
228    Ok(region_slice)
229}
230
231////////////////////////////////////////////////////////////////////////////////
232// RegionBuffer
233////////////////////////////////////////////////////////////////////////////////
234
235// TODO: move to tarantool module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
236/// TODO
237pub struct RegionBuffer {
238    guard: RegionGuard,
239
240    start: *mut u8,
241    count: usize,
242}
243
244impl RegionBuffer {
245    #[inline(always)]
246    pub fn new() -> Self {
247        Self {
248            guard: RegionGuard::new(),
249            start: NonNull::dangling().as_ptr(),
250            count: 0,
251        }
252    }
253
254    #[track_caller]
255    pub fn push(&mut self, data: &[u8]) -> Result<(), BoxError> {
256        let added_count = data.len();
257        let new_count = self.count + added_count;
258        unsafe {
259            let save_point = ffi::box_region_used();
260            let pointer: *mut u8 = ffi::box_region_alloc(added_count) as _;
261
262            if pointer.is_null() {
263                #[rustfmt::skip]
264                return Err(BoxError::new(TarantoolErrorCode::MemoryIssue, format!("failed to allocate {added_count} bytes on the region allocator")));
265            }
266
267            if self.start.is_null() || pointer == self.start.add(self.count) {
268                // New allocation is contiguous with the previous one
269                memcpy(pointer, data.as_ptr(), added_count);
270                self.count = new_count;
271                if self.start.is_null() {
272                    self.start = pointer;
273                }
274            } else {
275                // New allocation is in a different slab, need to reallocate
276                ffi::box_region_truncate(save_point);
277
278                let new_count = self.count + added_count;
279                let pointer: *mut u8 = ffi::box_region_alloc(new_count) as _;
280                memcpy(pointer, self.start, self.count);
281                memcpy(pointer.add(self.count), data.as_ptr(), added_count);
282                self.start = pointer;
283                self.count = new_count;
284            }
285        }
286
287        Ok(())
288    }
289
290    #[inline(always)]
291    pub fn get(&self) -> &[u8] {
292        if self.start.is_null() {
293            // Cannot construct a slice from a null pointer even if len is 0
294            &[]
295        } else {
296            unsafe { std::slice::from_raw_parts(self.start, self.count) }
297        }
298    }
299
300    #[inline]
301    pub fn into_raw_parts(self) -> (&'static [u8], usize) {
302        let save_point = self.guard.used_at_creation();
303        std::mem::forget(self.guard);
304        if self.start.is_null() {
305            // Cannot construct a slice from a null pointer even if len is 0
306            return (&[], save_point);
307        }
308        let slice = unsafe { std::slice::from_raw_parts(self.start, self.count) };
309        (slice, save_point)
310    }
311}
312
313impl std::io::Write for RegionBuffer {
314    #[inline(always)]
315    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
316        if let Err(e) = self.push(data) {
317            #[rustfmt::skip]
318            return Err(std::io::Error::new(std::io::ErrorKind::OutOfMemory, e.message()));
319        }
320
321        Ok(data.len())
322    }
323
324    #[inline(always)]
325    fn flush(&mut self) -> std::io::Result<()> {
326        Ok(())
327    }
328}
329
/// Copies `count` bytes from `source` to `destination`.
///
/// # Safety
/// `source` must be valid for reads of `count` bytes, `destination` must be
/// valid for writes of `count` bytes, both must be non-null and the two
/// memory ranges must not overlap.
#[inline(always)]
unsafe fn memcpy(destination: *mut u8, source: *const u8, count: usize) {
    std::ptr::copy_nonoverlapping(source, destination, count)
}
336
337////////////////////////////////////////////////////////////////////////////////
338// DisplayErrorLocation
339////////////////////////////////////////////////////////////////////////////////
340
341// TODO: move to taratool-module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/211
342pub struct DisplayErrorLocation<'a>(pub &'a BoxError);
343
344impl std::fmt::Display for DisplayErrorLocation<'_> {
345    #[inline]
346    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
347        if let Some((file, line)) = self.0.file().zip(self.0.line()) {
348            write!(f, "{file}:{line}: ")?;
349        }
350        Ok(())
351    }
352}
353
354////////////////////////////////////////////////////////////////////////////////
355// DisplayAsHexBytesLimitted
356////////////////////////////////////////////////////////////////////////////////
357
358// TODO: move to taratool-module https://git.picodata.io/picodata/picodata/tarantool-module/-/merge_requests/523
359pub struct DisplayAsHexBytesLimitted<'a>(pub &'a [u8]);
360
361impl std::fmt::Display for DisplayAsHexBytesLimitted<'_> {
362    #[inline]
363    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
364        if self.0.len() > 512 {
365            f.write_str("<too-big-to-display>")
366        } else {
367            tarantool::util::DisplayAsHexBytes(self.0).fmt(f)
368        }
369    }
370}
371
372////////////////////////////////////////////////////////////////////////////////
373// miscellaneous
374////////////////////////////////////////////////////////////////////////////////
375
/// Returns a non-null pointer to the first element of `data`.
///
/// For an empty slice the pointer is still non-null, but it must not be
/// dereferenced.
#[inline(always)]
fn as_non_null_ptr<T>(data: &[T]) -> NonNull<T> {
    // A reference is never null, so the conversion needs no `unsafe`. Casting
    // the slice pointer to an element pointer keeps the address and drops
    // the length.
    NonNull::from(data).cast::<T>()
}
384
385// TODO: this should be in tarantool module
386pub fn tarantool_error_to_box_error(e: tarantool::error::Error) -> BoxError {
387    match e {
388        tarantool::error::Error::Tarantool(e) => e,
389        other => BoxError::new(ErrorCode::Other, other.to_string()),
390    }
391}
392
393////////////////////////////////////////////////////////////////////////////////
394// test
395////////////////////////////////////////////////////////////////////////////////
396
#[cfg(feature = "internal_test")]
mod test {
    use super::*;

    /// Serializing into a `RegionBuffer` via its `std::io::Write` impl must
    /// produce the same bytes as serializing into a `Vec`.
    #[tarantool::test]
    fn region_buffer() {
        #[derive(serde::Serialize, Debug)]
        struct S {
            name: String,
            x: f32,
            y: f32,
            array: Vec<(i32, i32, bool)>,
        }

        let value = S {
            name: "foo".into(),
            x: 4.2,
            y: 6.9,
            array: vec![(1, 2, true), (3, 4, false)],
        };

        // Reference encoding
        let expected = rmp_serde::to_vec(&value).unwrap();

        // Same value encoded straight into the region
        let mut region_buffer = RegionBuffer::new();
        rmp_serde::encode::write(&mut region_buffer, &value).unwrap();

        assert_eq!(expected, region_buffer.get());
    }
}