picodata_plugin/
util.rs

1use crate::error_code::ErrorCode;
2use abi_stable::StableAbi;
3use std::io::Cursor;
4use std::ptr::NonNull;
5use tarantool::error::BoxError;
6use tarantool::error::TarantoolErrorCode;
7use tarantool::ffi::tarantool as ffi;
8
9////////////////////////////////////////////////////////////////////////////////
10// FfiSafeBytes
11////////////////////////////////////////////////////////////////////////////////
12
13/// A helper struct for passing byte slices over the ABI boundary.
14#[repr(C)]
15#[derive(StableAbi, Clone, Copy, Debug)]
16pub struct FfiSafeBytes {
17    pointer: NonNull<u8>,
18    len: usize,
19}
20
21impl FfiSafeBytes {
22    #[inline(always)]
23    pub fn len(self) -> usize {
24        self.len
25    }
26
27    #[inline(always)]
28    pub fn is_empty(self) -> bool {
29        self.len == 0
30    }
31
32    /// # Safety
33    ///
34    /// `pointer` and `len` must be correct pointer and length
35    #[inline(always)]
36    pub unsafe fn from_raw_parts(pointer: NonNull<u8>, len: usize) -> Self {
37        Self { pointer, len }
38    }
39
40    #[inline(always)]
41    pub fn into_raw_parts(self) -> (*mut u8, usize) {
42        (self.pointer.as_ptr(), self.len)
43    }
44
45    /// Converts `self` back to a borrowed string `&[u8]`.
46    ///
47    /// # Safety
48    /// `FfiSafeBytes` can only be constructed from a valid rust byte slice,
49    /// so you only need to make sure that the origial `&[u8]` outlives the lifetime `'a`.
50    ///
51    /// This should generally be true when borrowing strings owned by the current
52    /// function and calling a function via FFI, but borrowing global data or
53    /// data stored within a `Rc` for example is probably unsafe.
54    pub unsafe fn as_bytes<'a>(self) -> &'a [u8] {
55        std::slice::from_raw_parts(self.pointer.as_ptr(), self.len)
56    }
57}
58
59impl Default for FfiSafeBytes {
60    #[inline(always)]
61    fn default() -> Self {
62        Self {
63            pointer: NonNull::dangling(),
64            len: 0,
65        }
66    }
67}
68
69impl<'a> From<&'a [u8]> for FfiSafeBytes {
70    #[inline(always)]
71    fn from(value: &'a [u8]) -> Self {
72        Self {
73            pointer: as_non_null_ptr(value),
74            len: value.len(),
75        }
76    }
77}
78
79impl<'a> From<&'a str> for FfiSafeBytes {
80    #[inline(always)]
81    fn from(value: &'a str) -> Self {
82        Self {
83            pointer: as_non_null_ptr(value.as_bytes()),
84            len: value.len(),
85        }
86    }
87}
88
89////////////////////////////////////////////////////////////////////////////////
90// FfiSafeStr
91////////////////////////////////////////////////////////////////////////////////
92
93/// A helper struct for passing rust strings over the ABI boundary.
94///
95/// This type can only be constructed from a valid rust string, so it's not
96/// necessary to validate the utf8 encoding when converting back to `&str`.
97#[repr(C)]
98#[derive(StableAbi, Clone, Copy, Debug)]
99pub struct FfiSafeStr {
100    pointer: NonNull<u8>,
101    len: usize,
102}
103
104impl FfiSafeStr {
105    #[inline(always)]
106    pub fn len(self) -> usize {
107        self.len
108    }
109
110    #[inline(always)]
111    pub fn is_empty(self) -> bool {
112        self.len == 0
113    }
114
115    /// # Safety
116    ///
117    /// `pointer` and `len` must be correct pointer and length
118    #[inline(always)]
119    pub unsafe fn from_raw_parts(pointer: NonNull<u8>, len: usize) -> Self {
120        Self { pointer, len }
121    }
122
123    /// # Safety
124    /// `bytes` must represent a valid utf8 string.
125    pub unsafe fn from_utf8_unchecked(bytes: &[u8]) -> Self {
126        let pointer = as_non_null_ptr(bytes);
127        let len = bytes.len();
128        Self { pointer, len }
129    }
130
131    #[inline(always)]
132    pub fn into_raw_parts(self) -> (*mut u8, usize) {
133        (self.pointer.as_ptr(), self.len)
134    }
135
136    /// Converts `self` back to a borrowed string `&str`.
137    ///
138    /// # Safety
139    /// `FfiSafeStr` can only be constructed from a valid rust `str`,
140    /// so you only need to make sure that the origial `str` outlives the lifetime `'a`.
141    ///
142    /// This should generally be true when borrowing strings owned by the current
143    /// function and calling a function via FFI, but borrowing global data or
144    /// data stored within a `Rc` for example is probably unsafe.
145    #[inline]
146    pub unsafe fn as_str<'a>(self) -> &'a str {
147        if cfg!(debug_assertions) {
148            std::str::from_utf8(self.as_bytes()).expect("should only be used with valid utf8")
149        } else {
150            std::str::from_utf8_unchecked(self.as_bytes())
151        }
152    }
153
154    /// Converts `self` back to a borrowed string `&[u8]`.
155    ///
156    /// # Safety
157    /// `FfiSafeStr` can only be constructed from a valid rust byte slice,
158    /// so you only need to make sure that the original `&[u8]` outlives the lifetime `'a`.
159    ///
160    /// This should generally be true when borrowing strings owned by the current
161    /// function and calling a function via FFI, but borrowing global data or
162    /// data stored within a `Rc` for example is probably unsafe.
163    #[inline(always)]
164    pub unsafe fn as_bytes<'a>(self) -> &'a [u8] {
165        std::slice::from_raw_parts(self.pointer.as_ptr(), self.len)
166    }
167}
168
169impl Default for FfiSafeStr {
170    #[inline(always)]
171    fn default() -> Self {
172        Self {
173            pointer: NonNull::dangling(),
174            len: 0,
175        }
176    }
177}
178
179impl<'a> From<&'a str> for FfiSafeStr {
180    #[inline(always)]
181    fn from(value: &'a str) -> Self {
182        Self {
183            pointer: as_non_null_ptr(value.as_bytes()),
184            len: value.len(),
185        }
186    }
187}
188
189////////////////////////////////////////////////////////////////////////////////
190// RegionGuard
191////////////////////////////////////////////////////////////////////////////////
192
193// TODO: move to tarantool-module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
194pub struct RegionGuard {
195    save_point: usize,
196}
197
198impl RegionGuard {
199    /// TODO
200    #[inline(always)]
201    #[allow(clippy::new_without_default)]
202    pub fn new() -> Self {
203        // This is safe as long as the function is called within an initialized
204        // fiber runtime
205        let save_point = unsafe { ffi::box_region_used() };
206        Self { save_point }
207    }
208
209    /// TODO
210    #[inline(always)]
211    pub fn used_at_creation(&self) -> usize {
212        self.save_point
213    }
214}
215
216impl Drop for RegionGuard {
217    fn drop(&mut self) {
218        // This is safe as long as the function is called within an initialized
219        // fiber runtime
220        unsafe { ffi::box_region_truncate(self.save_point) }
221    }
222}
223
224////////////////////////////////////////////////////////////////////////////////
225// region allocation
226////////////////////////////////////////////////////////////////////////////////
227
228// TODO: move to tarantool module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
229/// TODO: doc
230#[inline]
231fn allocate_on_region(size: usize) -> Result<&'static mut [u8], BoxError> {
232    // SAFETY: requires initialized fiber runtime
233    let pointer = unsafe { ffi::box_region_alloc(size).cast::<u8>() };
234    if pointer.is_null() {
235        return Err(BoxError::last());
236    }
237    // SAFETY: safe because pointer is not null
238    let region_slice = unsafe { std::slice::from_raw_parts_mut(pointer, size) };
239    Ok(region_slice)
240}
241
242// TODO: move to tarantool module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
243/// Copies the provided `data` to the current fiber's region allocator returning
244/// a reference to the new allocation.
245///
246/// Use this to return dynamically sized values over the ABI boundary, for
247/// example in RPC handlers.
248///
249/// Note that the returned slice's lifetime is not really `'static`, but is
250/// determined by the following call to `box_region_truncate`.
251#[inline]
252pub fn copy_to_region(data: &[u8]) -> Result<&'static [u8], BoxError> {
253    let region_slice = allocate_on_region(data.len())?;
254    region_slice.copy_from_slice(data);
255    Ok(region_slice)
256}
257
258////////////////////////////////////////////////////////////////////////////////
259// RegionBuffer
260////////////////////////////////////////////////////////////////////////////////
261
262// TODO: move to tarantool module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/210
263/// TODO
264pub struct RegionBuffer {
265    guard: RegionGuard,
266
267    start: *mut u8,
268    count: usize,
269}
270
271impl RegionBuffer {
272    #[inline(always)]
273    #[allow(clippy::new_without_default)]
274    pub fn new() -> Self {
275        Self {
276            guard: RegionGuard::new(),
277            start: NonNull::dangling().as_ptr(),
278            count: 0,
279        }
280    }
281
282    #[track_caller]
283    pub fn push(&mut self, data: &[u8]) -> Result<(), BoxError> {
284        let added_count = data.len();
285        let new_count = self.count + added_count;
286        unsafe {
287            let save_point = ffi::box_region_used();
288            let pointer: *mut u8 = ffi::box_region_alloc(added_count) as _;
289
290            if pointer.is_null() {
291                #[rustfmt::skip]
292                return Err(BoxError::new(TarantoolErrorCode::MemoryIssue, format!("failed to allocate {added_count} bytes on the region allocator")));
293            }
294
295            if self.start.is_null() || pointer == self.start.add(self.count) {
296                // New allocation is contiguous with the previous one
297                memcpy(pointer, data.as_ptr(), added_count);
298                self.count = new_count;
299                if self.start.is_null() {
300                    self.start = pointer;
301                }
302            } else {
303                // New allocation is in a different slab, need to reallocate
304                ffi::box_region_truncate(save_point);
305
306                let new_count = self.count + added_count;
307                let pointer: *mut u8 = ffi::box_region_alloc(new_count) as _;
308                memcpy(pointer, self.start, self.count);
309                memcpy(pointer.add(self.count), data.as_ptr(), added_count);
310                self.start = pointer;
311                self.count = new_count;
312            }
313        }
314
315        Ok(())
316    }
317
318    #[inline(always)]
319    pub fn get(&self) -> &[u8] {
320        if self.start.is_null() {
321            // Cannot construct a slice from a null pointer even if len is 0
322            &[]
323        } else {
324            unsafe { std::slice::from_raw_parts(self.start, self.count) }
325        }
326    }
327
328    #[inline]
329    pub fn into_raw_parts(self) -> (&'static [u8], usize) {
330        let save_point = self.guard.used_at_creation();
331        std::mem::forget(self.guard);
332        if self.start.is_null() {
333            // Cannot construct a slice from a null pointer even if len is 0
334            return (&[], save_point);
335        }
336        let slice = unsafe { std::slice::from_raw_parts(self.start, self.count) };
337        (slice, save_point)
338    }
339}
340
341impl std::io::Write for RegionBuffer {
342    #[inline(always)]
343    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
344        if let Err(e) = self.push(data) {
345            #[rustfmt::skip]
346            return Err(std::io::Error::new(std::io::ErrorKind::OutOfMemory, e.message()));
347        }
348
349        Ok(data.len())
350    }
351
352    #[inline(always)]
353    fn flush(&mut self) -> std::io::Result<()> {
354        Ok(())
355    }
356}
357
/// Copies `count` bytes from `source` to `destination`.
///
/// Unlike building slices over both ranges, this does not require the
/// destination memory to be initialized, which matters for fresh region
/// allocations.
///
/// # Safety
/// - `source` must be valid for reads of `count` bytes and `destination` must
///   be valid for writes of `count` bytes. Both must be non-null even when
///   `count` is 0.
/// - The two ranges must not overlap.
#[inline(always)]
unsafe fn memcpy(destination: *mut u8, source: *const u8, count: usize) {
    std::ptr::copy_nonoverlapping(source, destination, count)
}
364
365////////////////////////////////////////////////////////////////////////////////
366// DisplayErrorLocation
367////////////////////////////////////////////////////////////////////////////////
368
369// TODO: move to taratool-module https://git.picodata.io/picodata/picodata/tarantool-module/-/issues/211
370pub struct DisplayErrorLocation<'a>(pub &'a BoxError);
371
372impl std::fmt::Display for DisplayErrorLocation<'_> {
373    #[inline]
374    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
375        if let Some((file, line)) = self.0.file().zip(self.0.line()) {
376            write!(f, "{file}:{line}: ")?;
377        }
378        Ok(())
379    }
380}
381
382////////////////////////////////////////////////////////////////////////////////
383// DisplayAsHexBytesLimitted
384////////////////////////////////////////////////////////////////////////////////
385
386// TODO: move to taratool-module https://git.picodata.io/picodata/picodata/tarantool-module/-/merge_requests/523
387pub struct DisplayAsHexBytesLimitted<'a>(pub &'a [u8]);
388
389impl std::fmt::Display for DisplayAsHexBytesLimitted<'_> {
390    #[inline]
391    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
392        if self.0.len() > 512 {
393            f.write_str("<too-big-to-display>")
394        } else {
395            tarantool::util::DisplayAsHexBytes(self.0).fmt(f)
396        }
397    }
398}
399
400////////////////////////////////////////////////////////////////////////////////
401// msgpack
402////////////////////////////////////////////////////////////////////////////////
403
404/// Decode a utf-8 string from the provided msgpack.
405/// Advances the cursor to the first byte after the encoded string.
406#[track_caller]
407#[inline]
408pub fn msgpack_decode_str(data: &[u8]) -> Result<&str, BoxError> {
409    let mut cursor = Cursor::new(data);
410    let length = rmp::decode::read_str_len(&mut cursor).map_err(invalid_msgpack)? as usize;
411
412    let res = str_from_cursor(length, &mut cursor)?;
413    let (_, tail) = cursor_split(&cursor);
414    if !tail.is_empty() {
415        return Err(invalid_msgpack(format!(
416            "unexpected data after msgpack value: {}",
417            DisplayAsHexBytesLimitted(tail)
418        )));
419    }
420
421    Ok(res)
422}
423
424/// Decode a utf-8 string from the provided msgpack.
425/// Advances the cursor to the first byte after the encoded string.
426#[track_caller]
427pub fn msgpack_read_str<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a str, BoxError> {
428    let length = rmp::decode::read_str_len(cursor).map_err(invalid_msgpack)? as usize;
429
430    str_from_cursor(length, cursor)
431}
432
433/// Continues decoding a utf-8 string from the provided msgpack after `marker`
434/// which must have been decode from the same `buffer`. The `buffer` cursor
435/// must be set to the first byte after the decoded `marker`.
436/// Advances the cursor to the first byte after the encoded string.
437///
438/// Returns `Ok(None)` if `marker` doesn't correspond to a msgpack string.
439/// Returns errors in other failure cases:
440/// - if there's not enough data in stream
441/// - if string is not valid utf-8
442#[track_caller]
443pub fn msgpack_read_rest_of_str<'a>(
444    marker: rmp::Marker,
445    cursor: &mut Cursor<&'a [u8]>,
446) -> Result<Option<&'a str>, BoxError> {
447    use rmp::decode::RmpRead as _;
448
449    let length = match marker {
450        rmp::Marker::FixStr(v) => v as usize,
451        rmp::Marker::Str8 => cursor.read_data_u8().map_err(invalid_msgpack)? as usize,
452        rmp::Marker::Str16 => cursor.read_data_u16().map_err(invalid_msgpack)? as usize,
453        rmp::Marker::Str32 => cursor.read_data_u32().map_err(invalid_msgpack)? as usize,
454        _ => return Ok(None),
455    };
456
457    str_from_cursor(length, cursor).map(Some)
458}
459
460#[inline]
461#[track_caller]
462fn str_from_cursor<'a>(length: usize, cursor: &mut Cursor<&'a [u8]>) -> Result<&'a str, BoxError> {
463    let start_index = cursor.position() as usize;
464    let data = *cursor.get_ref();
465    let remaining_length = data.len() - start_index;
466    if remaining_length < length {
467        return Err(invalid_msgpack(format!(
468            "expected a string of length {length}, got {remaining_length}"
469        )));
470    }
471
472    let end_index = start_index + length;
473    let res = std::str::from_utf8(&data[start_index..end_index]).map_err(invalid_msgpack)?;
474    cursor.set_position(end_index as _);
475    Ok(res)
476}
477
478/// Decode binary data from the provided msgpack.
479#[track_caller]
480pub fn msgpack_decode_bin(data: &[u8]) -> Result<&[u8], BoxError> {
481    let mut cursor = Cursor::new(data);
482    let length = rmp::decode::read_bin_len(&mut cursor).map_err(invalid_msgpack)? as usize;
483
484    let res = bin_from_cursor(length, &mut cursor)?;
485    let (_, tail) = cursor_split(&cursor);
486    if !tail.is_empty() {
487        return Err(invalid_msgpack(format!(
488            "unexpected data after msgpack value: {}",
489            DisplayAsHexBytesLimitted(tail)
490        )));
491    }
492
493    Ok(res)
494}
495
496/// Decode binary data from the provided msgpack.
497/// Advances the cursor to the first byte after the encoded binary data.
498#[track_caller]
499pub fn msgpack_read_bin<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], BoxError> {
500    let length = rmp::decode::read_bin_len(cursor).map_err(invalid_msgpack)? as usize;
501
502    bin_from_cursor(length, cursor)
503}
504
505/// Continues decoding a binary data from the provided msgpack after `marker`
506/// which must have been decode from the same `cursor`. The `cursor` cursor
507/// must be set to the first byte after the decoded `marker`.
508/// Advances the cursor to the first byte after the encoded binary data.
509///
510/// Returns `Ok(None)` if `marker` doesn't correspond to msgpack binary data.
511/// Returns errors in other failure cases:
512/// - if there's not enough data in stream
513#[track_caller]
514pub fn msgpack_read_rest_of_bin<'a>(
515    marker: rmp::Marker,
516    cursor: &mut Cursor<&'a [u8]>,
517) -> Result<Option<&'a [u8]>, BoxError> {
518    use rmp::decode::RmpRead as _;
519
520    let length = match marker {
521        rmp::Marker::Bin8 => cursor.read_data_u8().map_err(invalid_msgpack)? as usize,
522        rmp::Marker::Bin16 => cursor.read_data_u16().map_err(invalid_msgpack)? as usize,
523        rmp::Marker::Bin32 => cursor.read_data_u32().map_err(invalid_msgpack)? as usize,
524        _ => return Ok(None),
525    };
526
527    bin_from_cursor(length, cursor).map(Some)
528}
529
530#[inline]
531#[track_caller]
532fn bin_from_cursor<'a>(length: usize, cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], BoxError> {
533    let start_index = cursor.position() as usize;
534    let data = *cursor.get_ref();
535    let remaining_length = data.len() - start_index;
536    if remaining_length < length {
537        return Err(invalid_msgpack(format!(
538            "expected binary data of length {length}, got {remaining_length}"
539        )));
540    }
541
542    let end_index = start_index + length;
543    let res = &data[start_index..end_index];
544    cursor.set_position(end_index as _);
545    Ok(res)
546}
547
// TODO Remove when [`std::io::Cursor::split`] is stabilized.
/// Splits the cursor's buffer at the cursor's position into the consumed
/// prefix and the remaining suffix.
///
/// A position at or past the end of the buffer yields an empty suffix.
fn cursor_split<'a>(cursor: &Cursor<&'a [u8]>) -> (&'a [u8], &'a [u8]) {
    let buffer: &'a [u8] = *cursor.get_ref();
    let end = buffer.len();
    let position = cursor.position();
    let split_index = if position < end as u64 {
        position as usize
    } else {
        end
    };
    buffer.split_at(split_index)
}
554
555#[inline(always)]
556#[track_caller]
557fn invalid_msgpack(error: impl ToString) -> BoxError {
558    BoxError::new(TarantoolErrorCode::InvalidMsgpack, error.to_string())
559}
560
561////////////////////////////////////////////////////////////////////////////////
562// miscellaneous
563////////////////////////////////////////////////////////////////////////////////
564
/// Returns the start of `data` as a [`NonNull`] pointer.
///
/// A slice's data pointer is never null (empty slices use a dangling, aligned
/// pointer), so this conversion cannot fail and needs no `unsafe`.
#[inline(always)]
fn as_non_null_ptr<T>(data: &[T]) -> NonNull<T> {
    // `NonNull<[T]>` from a reference is always valid; `cast` drops the length.
    NonNull::from(data).cast()
}
573
574// TODO: this should be in tarantool module
575pub fn tarantool_error_to_box_error(e: tarantool::error::Error) -> BoxError {
576    match e {
577        tarantool::error::Error::Tarantool(e) => e,
578        other => BoxError::new(ErrorCode::Other, other.to_string()),
579    }
580}
581
582////////////////////////////////////////////////////////////////////////////////
583// test
584////////////////////////////////////////////////////////////////////////////////
585
#[cfg(feature = "internal_test")]
mod test {
    use super::*;

    /// Serializing into a `RegionBuffer` through `io::Write` must produce the
    /// same bytes as serializing into a `Vec`.
    #[tarantool::test]
    fn region_buffer() {
        #[derive(serde::Serialize, Debug)]
        struct S {
            name: String,
            x: f32,
            y: f32,
            array: Vec<(i32, i32, bool)>,
        }

        let value = S {
            name: "foo".into(),
            x: 4.2,
            y: 6.9,
            array: vec![(1, 2, true), (3, 4, false)],
        };

        let expected = rmp_serde::to_vec(&value).unwrap();

        let mut region_buffer = RegionBuffer::new();
        rmp_serde::encode::write(&mut region_buffer, &value).unwrap();

        assert_eq!(expected, region_buffer.get());
    }
}