use crate::error_code::ErrorCode;
use abi_stable::StableAbi;
use std::io::Cursor;
use std::ptr::NonNull;
use tarantool::error::BoxError;
use tarantool::error::TarantoolErrorCode;
use tarantool::ffi::tarantool as ffi;
/// Pointer-and-length description of a byte range, with a `#[repr(C)]` layout.
///
/// The type is `Copy` and has no lifetime parameter, so the compiler does not
/// track how long the referenced memory stays valid; the `unsafe` accessors
/// leave that responsibility to the caller.
#[repr(C)]
#[derive(StableAbi, Clone, Copy, Debug)]
pub struct FfiSafeBytes {
/// Address of the first byte; `NonNull`, so never a null pointer.
pointer: NonNull<u8>,
/// Number of bytes in the range.
len: usize,
}
impl FfiSafeBytes {
    /// Number of bytes in the described range.
    #[inline(always)]
    pub fn len(self) -> usize {
        let Self { len, .. } = self;
        len
    }

    /// `true` when the described range holds no bytes.
    #[inline(always)]
    pub fn is_empty(self) -> bool {
        self.len() == 0
    }

    /// Assembles a value from an address and a byte count without any checks.
    ///
    /// # Safety
    ///
    /// Nothing is validated here. The pair is later handed to
    /// `std::slice::from_raw_parts` by [`Self::as_bytes`], so it must satisfy
    /// that function's requirements by the time `as_bytes` is called.
    #[inline(always)]
    pub unsafe fn from_raw_parts(pointer: NonNull<u8>, len: usize) -> Self {
        Self { pointer, len }
    }

    /// Splits the value back into its raw address and byte count.
    #[inline(always)]
    pub fn into_raw_parts(self) -> (*mut u8, usize) {
        let Self { pointer, len } = self;
        (pointer.as_ptr(), len)
    }

    /// Reinterprets the stored address and length as a byte slice.
    ///
    /// # Safety
    ///
    /// The returned lifetime `'a` is chosen by the caller and is not tied to
    /// anything, so the caller must guarantee that the memory is readable,
    /// initialized and not mutated for the whole of `'a`.
    pub unsafe fn as_bytes<'a>(self) -> &'a [u8] {
        let (start, len) = self.into_raw_parts();
        std::slice::from_raw_parts(start, len)
    }
}
impl Default for FfiSafeBytes {
    /// Builds an empty value: length zero and a dangling (but non-null) address.
    #[inline(always)]
    fn default() -> Self {
        let pointer = NonNull::<u8>::dangling();
        let len = 0;
        Self { pointer, len }
    }
}
impl<'a> From<&'a [u8]> for FfiSafeBytes {
    /// Captures the address and length of `value`.
    ///
    /// The lifetime `'a` is not retained in the result, so the borrow checker
    /// no longer protects the referenced bytes after this conversion.
    #[inline(always)]
    fn from(value: &'a [u8]) -> Self {
        let len = value.len();
        let pointer = as_non_null_ptr(value);
        Self { pointer, len }
    }
}
impl<'a> From<&'a str> for FfiSafeBytes {
    /// Captures the address and byte length of the UTF-8 data behind `value`.
    ///
    /// The lifetime `'a` is not retained in the result.
    #[inline(always)]
    fn from(value: &'a str) -> Self {
        let bytes = value.as_bytes();
        Self {
            pointer: as_non_null_ptr(bytes),
            len: bytes.len(),
        }
    }
}
/// Pointer-and-length description of string data, with a `#[repr(C)]` layout.
///
/// Has the same shape as a raw byte range; `as_str` additionally treats the
/// bytes as UTF-8. The type is `Copy` and has no lifetime parameter, so the
/// validity of the referenced memory is the caller's responsibility.
#[repr(C)]
#[derive(StableAbi, Clone, Copy, Debug)]
pub struct FfiSafeStr {
/// Address of the first byte; `NonNull`, so never a null pointer.
pointer: NonNull<u8>,
/// Length of the string data in bytes.
len: usize,
}
impl FfiSafeStr {
    /// Length of the string data in bytes.
    #[inline(always)]
    pub fn len(self) -> usize {
        let Self { len, .. } = self;
        len
    }

    /// `true` when the string data is zero bytes long.
    #[inline(always)]
    pub fn is_empty(self) -> bool {
        self.len() == 0
    }

    /// Assembles a value from an address and a byte count without any checks.
    ///
    /// # Safety
    ///
    /// Nothing is validated here. [`Self::as_str`] later treats the range as
    /// UTF-8 text, so the range must be readable and hold valid UTF-8 by the
    /// time it is accessed.
    #[inline(always)]
    pub unsafe fn from_raw_parts(pointer: NonNull<u8>, len: usize) -> Self {
        Self { pointer, len }
    }

    /// Captures the address and length of `bytes` without checking the encoding.
    ///
    /// # Safety
    ///
    /// `bytes` is not validated; [`Self::as_str`] assumes it is valid UTF-8.
    /// The borrow of `bytes` is not retained in the result.
    pub unsafe fn from_utf8_unchecked(bytes: &[u8]) -> Self {
        Self {
            pointer: as_non_null_ptr(bytes),
            len: bytes.len(),
        }
    }

    /// Splits the value back into its raw address and byte count.
    #[inline(always)]
    pub fn into_raw_parts(self) -> (*mut u8, usize) {
        let Self { pointer, len } = self;
        (pointer.as_ptr(), len)
    }

    /// Reinterprets the stored bytes as a `&str`.
    ///
    /// Builds with `debug_assertions` validate the encoding and panic on
    /// invalid UTF-8; other builds skip the validation entirely.
    ///
    /// # Safety
    ///
    /// Same requirements as [`Self::as_bytes`], and the bytes must be valid
    /// UTF-8.
    #[inline]
    pub unsafe fn as_str<'a>(self) -> &'a str {
        let bytes = self.as_bytes();
        if !cfg!(debug_assertions) {
            return std::str::from_utf8_unchecked(bytes);
        }
        std::str::from_utf8(bytes).expect("should only be used with valid utf8")
    }

    /// Reinterprets the stored address and length as a byte slice.
    ///
    /// # Safety
    ///
    /// The returned lifetime `'a` is chosen by the caller and is not tied to
    /// anything, so the caller must guarantee that the memory is readable,
    /// initialized and not mutated for the whole of `'a`.
    #[inline(always)]
    pub unsafe fn as_bytes<'a>(self) -> &'a [u8] {
        let (start, len) = self.into_raw_parts();
        std::slice::from_raw_parts(start, len)
    }
}
impl Default for FfiSafeStr {
    /// Builds an empty value: length zero and a dangling (but non-null) address.
    #[inline(always)]
    fn default() -> Self {
        let pointer = NonNull::<u8>::dangling();
        let len = 0;
        Self { pointer, len }
    }
}
impl<'a> From<&'a str> for FfiSafeStr {
    /// Captures the address and byte length of `value`.
    ///
    /// The lifetime `'a` is not retained in the result.
    #[inline(always)]
    fn from(value: &'a str) -> Self {
        let bytes = value.as_bytes();
        Self {
            pointer: as_non_null_ptr(bytes),
            len: bytes.len(),
        }
    }
}
/// Scope guard for the region allocator.
///
/// Remembers the value of `box_region_used()` at construction and passes it to
/// `box_region_truncate` when dropped.
#[derive(Debug)]
pub struct RegionGuard {
// Value reported by `box_region_used()` when the guard was created.
save_point: usize,
}
impl RegionGuard {
    /// Creates a guard that remembers the current `box_region_used()` value.
    #[inline(always)]
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self {
            save_point: unsafe { ffi::box_region_used() },
        }
    }

    /// The `box_region_used()` value that was recorded when the guard was created.
    #[inline(always)]
    pub fn used_at_creation(&self) -> usize {
        let Self { save_point } = self;
        *save_point
    }
}
impl Drop for RegionGuard {
    /// Truncates the region back to the usage recorded at construction.
    fn drop(&mut self) {
        let save_point = self.save_point;
        unsafe { ffi::box_region_truncate(save_point) }
    }
}
/// Requests `size` bytes from `box_region_alloc` and exposes them as a slice.
///
/// Returns `BoxError::last()` when the allocator hands back a null pointer.
///
/// NOTE(review): the `'static` lifetime is not enforced by anything visible
/// here; presumably the memory only lives until the region is truncated —
/// confirm with callers.
#[inline]
fn allocate_on_region(size: usize) -> Result<&'static mut [u8], BoxError> {
    let start: *mut u8 = unsafe { ffi::box_region_alloc(size) }.cast::<u8>();
    if !start.is_null() {
        return Ok(unsafe { std::slice::from_raw_parts_mut(start, size) });
    }
    Err(BoxError::last())
}
/// Copies `data` into a fresh allocation obtained from `allocate_on_region`.
///
/// # Errors
///
/// Forwards the allocation error unchanged.
#[inline]
pub fn copy_to_region(data: &[u8]) -> Result<&'static [u8], BoxError> {
    match allocate_on_region(data.len()) {
        Ok(destination) => {
            destination.copy_from_slice(data);
            Ok(destination)
        }
        Err(error) => Err(error),
    }
}
/// Byte buffer whose chunks are allocated with `box_region_alloc`.
///
/// Data is appended with `push` (or through `std::io::Write`) and extracted
/// with `try_into_raw_parts` / `try_into_vec`, which join the pushed chunks
/// via `box_region_join`.
#[derive(Debug)]
pub struct RegionBuffer {
// Created together with the buffer; dropped with it unless
// `try_into_raw_parts` succeeds and forgets it.
guard: RegionGuard,
// Address returned by the first successful allocation in `push`; null until then.
start: *mut u8,
// Total number of bytes pushed so far.
count: usize,
}
impl RegionBuffer {
    /// Creates an empty buffer together with a fresh [`RegionGuard`].
    #[inline(always)]
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        let guard = RegionGuard::new();
        Self {
            guard,
            start: std::ptr::null_mut(),
            count: 0,
        }
    }

    /// Allocates `data.len()` bytes with `box_region_alloc` and copies `data` there.
    ///
    /// The address of the first successful allocation is remembered in `start`.
    ///
    /// # Errors
    ///
    /// Returns a `MemoryIssue` error when the allocator returns a null pointer.
    #[track_caller]
    pub fn push(&mut self, data: &[u8]) -> Result<(), BoxError> {
        let added_count = data.len();
        let pointer: *mut u8 = unsafe { ffi::box_region_alloc(added_count) } as _;
        if pointer.is_null() {
            #[rustfmt::skip]
            return Err(BoxError::new(TarantoolErrorCode::MemoryIssue, format!("failed to allocate {added_count} bytes on the region allocator")));
        }

        unsafe { memcpy(pointer, data.as_ptr(), added_count) };
        self.count += added_count;
        if self.start.is_null() {
            self.start = pointer;
        }
        Ok(())
    }

    /// Always panics; kept only so that old callers still compile.
    #[deprecated = "no longer supported, consider using RegionBuffer::into_raw_parts instead"]
    #[inline(always)]
    pub fn get(&self) -> &[u8] {
        unimplemented!("RegionBuffer::get is no longer supported")
    }

    /// Like [`Self::try_into_raw_parts`], but panics if joining fails.
    #[inline]
    pub fn into_raw_parts(self) -> (&'static [u8], usize) {
        self.try_into_raw_parts().unwrap()
    }

    /// Joins everything pushed so far and releases the buffer without
    /// truncating the region.
    ///
    /// On success returns the joined bytes plus the region usage recorded when
    /// the buffer was created; the guard is forgotten, so its `Drop` does not
    /// run. On failure the error is returned together with the untouched buffer.
    pub fn try_into_raw_parts(self) -> Result<(&'static [u8], usize), (BoxError, Self)> {
        match unsafe { self.join() } {
            Err(error) => Err((error, self)),
            Ok(joined) => {
                let Self { guard, .. } = self;
                let save_point = guard.used_at_creation();
                // Skip `RegionGuard::drop`, which would truncate the region.
                std::mem::forget(guard);
                Ok((joined, save_point))
            }
        }
    }

    /// Produces one contiguous slice holding all pushed bytes.
    ///
    /// An empty buffer yields an empty slice without touching the allocator.
    /// Otherwise `box_region_join` is used when `has_box_region_join()`
    /// reports it as available; if not, an `Unsupported` error is returned.
    #[inline]
    unsafe fn join(&self) -> Result<&'static [u8], BoxError> {
        use crate::internal::ffi;

        let total = self.count;
        if total == 0 {
            return Ok(&[]);
        }
        if !ffi::has_box_region_join() {
            return Err(BoxError::new(
                TarantoolErrorCode::Unsupported,
                "box_region_join is not supported in this version of picodata",
            ));
        }

        let joined = unsafe { ffi::box_region_join(total) };
        if joined.is_null() {
            return Err(BoxError::last());
        }
        Ok(unsafe { std::slice::from_raw_parts(joined.cast(), total) })
    }

    /// Joins everything pushed so far and copies it into a `Vec`.
    ///
    /// The buffer (and with it the guard) is dropped before returning, on both
    /// the success and the error path.
    pub fn try_into_vec(self) -> Result<Vec<u8>, BoxError> {
        let joined = unsafe { self.join() }?;
        let bytes = joined.to_vec();
        drop(self);
        Ok(bytes)
    }
}
impl std::io::Write for RegionBuffer {
    /// Pushes all of `data`; a failed push is reported as `ErrorKind::OutOfMemory`
    /// carrying the message of the underlying error.
    #[inline(always)]
    fn write(&mut self, data: &[u8]) -> std::io::Result<usize> {
        match self.push(data) {
            Ok(()) => Ok(data.len()),
            Err(e) => Err(std::io::Error::new(
                std::io::ErrorKind::OutOfMemory,
                e.message(),
            )),
        }
    }

    /// Nothing is buffered outside of the region, so this is a no-op.
    #[inline(always)]
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// Copies `count` bytes from `source` to `destination`.
///
/// Implemented with `std::ptr::copy_nonoverlapping` so that no Rust reference
/// is ever created over the destination: `slice::from_raw_parts_mut` requires
/// the memory to be initialized already, which does not hold for a freshly
/// allocated buffer.
///
/// # Safety
///
/// - `source` must be valid for reads of `count` bytes.
/// - `destination` must be valid for writes of `count` bytes.
/// - The two ranges must not overlap.
#[inline(always)]
unsafe fn memcpy(destination: *mut u8, source: *const u8, count: usize) {
    std::ptr::copy_nonoverlapping(source, destination, count)
}
/// Display adapter that prints where a [`BoxError`] was raised.
///
/// Writes `"{file}:{line}: "` when the error reports both a file and a line,
/// and writes nothing otherwise.
pub struct DisplayErrorLocation<'a>(pub &'a BoxError);

impl std::fmt::Display for DisplayErrorLocation<'_> {
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let error = self.0;
        match (error.file(), error.line()) {
            (Some(file), Some(line)) => write!(f, "{file}:{line}: "),
            _ => Ok(()),
        }
    }
}
/// Display adapter for byte slices that refuses to print large inputs.
///
/// Slices of up to 512 bytes are delegated to
/// `tarantool::util::DisplayAsHexBytes`; anything longer is replaced by the
/// placeholder `<too-big-to-display>`.
pub struct DisplayAsHexBytesLimitted<'a>(pub &'a [u8]);

impl std::fmt::Display for DisplayAsHexBytesLimitted<'_> {
    #[inline]
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let bytes = self.0;
        if bytes.len() <= 512 {
            return tarantool::util::DisplayAsHexBytes(bytes).fmt(f);
        }
        f.write_str("<too-big-to-display>")
    }
}
/// Decodes `data` as exactly one msgpack string and returns its contents.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when the string header cannot be read,
/// when the payload is shorter than declared or is not valid UTF-8, or when
/// any bytes remain after the string.
#[track_caller]
#[inline]
pub fn msgpack_decode_str(data: &[u8]) -> Result<&str, BoxError> {
    let mut reader = Cursor::new(data);
    let declared_len = rmp::decode::read_str_len(&mut reader).map_err(invalid_msgpack)?;
    let decoded = str_from_cursor(declared_len as usize, &mut reader)?;

    let (_, leftover) = cursor_split(&reader);
    if leftover.is_empty() {
        return Ok(decoded);
    }
    Err(invalid_msgpack(format!(
        "unexpected data after msgpack value: {}",
        DisplayAsHexBytesLimitted(leftover)
    )))
}
/// Reads one msgpack string at the cursor's position and advances past it.
///
/// Unlike `msgpack_decode_str`, bytes following the string are left alone.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when the header cannot be read or the
/// payload is truncated or not valid UTF-8.
#[track_caller]
pub fn msgpack_read_str<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a str, BoxError> {
    let declared_len = rmp::decode::read_str_len(cursor).map_err(invalid_msgpack)?;
    str_from_cursor(declared_len as usize, cursor)
}
/// Finishes reading a msgpack string whose `marker` has already been consumed.
///
/// For `FixStr` the length comes from the marker itself; for `Str8`, `Str16`
/// and `Str32` it is read from the cursor. Any other marker yields `Ok(None)`
/// without touching the cursor.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when the length cannot be read or the
/// payload is truncated or not valid UTF-8.
#[track_caller]
pub fn msgpack_read_rest_of_str<'a>(
    marker: rmp::Marker,
    cursor: &mut Cursor<&'a [u8]>,
) -> Result<Option<&'a str>, BoxError> {
    use rmp::decode::RmpRead as _;

    let declared_len = match marker {
        rmp::Marker::FixStr(short_len) => usize::from(short_len),
        rmp::Marker::Str8 => usize::from(cursor.read_data_u8().map_err(invalid_msgpack)?),
        rmp::Marker::Str16 => usize::from(cursor.read_data_u16().map_err(invalid_msgpack)?),
        rmp::Marker::Str32 => cursor.read_data_u32().map_err(invalid_msgpack)? as usize,
        _ => return Ok(None),
    };
    let text = str_from_cursor(declared_len, cursor)?;
    Ok(Some(text))
}
/// Borrows the next `length` bytes at the cursor's position as a `&str` and
/// advances the cursor past them.
///
/// A cursor positioned beyond the end of its buffer (which `Cursor` permits)
/// is treated as having zero bytes remaining instead of underflowing the
/// remaining-length computation.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when fewer than `length` bytes remain or
/// when the bytes are not valid UTF-8. The cursor is not moved on error.
#[inline]
#[track_caller]
fn str_from_cursor<'a>(length: usize, cursor: &mut Cursor<&'a [u8]>) -> Result<&'a str, BoxError> {
    let data: &'a [u8] = *cursor.get_ref();
    let position = cursor.position();
    // The comparison is done in u64 so that a huge position is never truncated
    // into an in-bounds index.
    let remaining: &'a [u8] = if position <= data.len() as u64 {
        &data[position as usize..]
    } else {
        &[]
    };

    let remaining_length = remaining.len();
    if remaining_length < length {
        return Err(invalid_msgpack(format!(
            "expected a string of length {length}, got {remaining_length}"
        )));
    }

    let res = std::str::from_utf8(&remaining[..length]).map_err(invalid_msgpack)?;
    cursor.set_position(position + length as u64);
    Ok(res)
}
/// Decodes `data` as exactly one msgpack binary value and returns its payload.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when the binary header cannot be read,
/// when the payload is shorter than declared, or when any bytes remain after
/// the value.
#[track_caller]
pub fn msgpack_decode_bin(data: &[u8]) -> Result<&[u8], BoxError> {
    let mut reader = Cursor::new(data);
    let declared_len = rmp::decode::read_bin_len(&mut reader).map_err(invalid_msgpack)?;
    let payload = bin_from_cursor(declared_len as usize, &mut reader)?;

    let (_, leftover) = cursor_split(&reader);
    if leftover.is_empty() {
        return Ok(payload);
    }
    Err(invalid_msgpack(format!(
        "unexpected data after msgpack value: {}",
        DisplayAsHexBytesLimitted(leftover)
    )))
}
/// Reads one msgpack binary value at the cursor's position and advances past it.
///
/// Unlike `msgpack_decode_bin`, bytes following the value are left alone.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when the header cannot be read or the
/// payload is truncated.
#[track_caller]
pub fn msgpack_read_bin<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], BoxError> {
    let declared_len = rmp::decode::read_bin_len(cursor).map_err(invalid_msgpack)?;
    bin_from_cursor(declared_len as usize, cursor)
}
/// Finishes reading a msgpack binary value whose `marker` has already been consumed.
///
/// For `Bin8`, `Bin16` and `Bin32` the length is read from the cursor. Any
/// other marker yields `Ok(None)` without touching the cursor.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when the length cannot be read or the
/// payload is truncated.
#[track_caller]
pub fn msgpack_read_rest_of_bin<'a>(
    marker: rmp::Marker,
    cursor: &mut Cursor<&'a [u8]>,
) -> Result<Option<&'a [u8]>, BoxError> {
    use rmp::decode::RmpRead as _;

    let declared_len = match marker {
        rmp::Marker::Bin8 => usize::from(cursor.read_data_u8().map_err(invalid_msgpack)?),
        rmp::Marker::Bin16 => usize::from(cursor.read_data_u16().map_err(invalid_msgpack)?),
        rmp::Marker::Bin32 => cursor.read_data_u32().map_err(invalid_msgpack)? as usize,
        _ => return Ok(None),
    };
    let payload = bin_from_cursor(declared_len, cursor)?;
    Ok(Some(payload))
}
/// Borrows the next `length` bytes at the cursor's position and advances the
/// cursor past them.
///
/// A cursor positioned beyond the end of its buffer (which `Cursor` permits)
/// is treated as having zero bytes remaining instead of underflowing the
/// remaining-length computation.
///
/// # Errors
///
/// Returns an `InvalidMsgpack` error when fewer than `length` bytes remain.
/// The cursor is not moved on error.
#[inline]
#[track_caller]
fn bin_from_cursor<'a>(length: usize, cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], BoxError> {
    let data: &'a [u8] = *cursor.get_ref();
    let position = cursor.position();
    // The comparison is done in u64 so that a huge position is never truncated
    // into an in-bounds index.
    let remaining: &'a [u8] = if position <= data.len() as u64 {
        &data[position as usize..]
    } else {
        &[]
    };

    let remaining_length = remaining.len();
    if remaining_length < length {
        return Err(invalid_msgpack(format!(
            "expected binary data of length {length}, got {remaining_length}"
        )));
    }

    let res = &remaining[..length];
    cursor.set_position(position + length as u64);
    Ok(res)
}
/// Splits the cursor's buffer into the part before its position and the part
/// from its position onwards.
///
/// A position beyond the end of the buffer is clamped to the buffer length, so
/// the second half is empty in that case.
fn cursor_split<'a>(cursor: &Cursor<&'a [u8]>) -> (&'a [u8], &'a [u8]) {
    let whole: &'a [u8] = *cursor.get_ref();
    let position = cursor.position();
    let mid = if position < whole.len() as u64 {
        position as usize
    } else {
        whole.len()
    };
    whole.split_at(mid)
}
/// Wraps `error`'s string representation in a `BoxError` with the
/// `InvalidMsgpack` error code.
#[inline(always)]
#[track_caller]
fn invalid_msgpack(error: impl ToString) -> BoxError {
    let message = error.to_string();
    BoxError::new(TarantoolErrorCode::InvalidMsgpack, message)
}
/// Returns the address of the first element of `data` as a `NonNull` pointer.
///
/// A slice reference is never null, so the conversion needs no `unsafe`:
/// `NonNull::from` wraps the slice reference and `cast` narrows the slice
/// pointer to a pointer to its first element. For an empty slice the result is
/// the same (dangling but non-null) address that `data.as_ptr()` reports.
#[inline(always)]
fn as_non_null_ptr<T>(data: &[T]) -> NonNull<T> {
    NonNull::from(data).cast::<T>()
}
pub fn tarantool_error_to_box_error(e: tarantool::error::Error) -> BoxError {
match e {
tarantool::error::Error::Tarantool(e) => e,
other => BoxError::new(ErrorCode::Other, other.to_string()),
}
}
// Tests for `RegionBuffer`. Compiled only when the `internal_test` feature is
// enabled and the crate is not being built with `cfg(test)`; the test
// functions are registered through the `#[tarantool::test]` attribute.
#[cfg(all(feature = "internal_test", not(test)))]
mod test {
use super::*;
// Serializing into a `RegionBuffer` must produce the same bytes as
// serializing the same value with `rmp_serde::to_vec`.
#[tarantool::test]
fn region_buffer() {
#[derive(serde::Serialize, Debug)]
struct S {
name: String,
x: f32,
y: f32,
array: Vec<(i32, i32, bool)>,
}
let s = S {
name: "foo".into(),
x: 4.2,
y: 6.9,
array: vec![(1, 2, true), (3, 4, false)],
};
// Reference encoding.
let vec = rmp_serde::to_vec(&s).unwrap();
let mut buffer = RegionBuffer::new();
rmp_serde::encode::write(&mut buffer, &s).unwrap();
let data = buffer.try_into_vec().unwrap();
assert_eq!(vec, data);
}
// With a single small push, joining is expected to return the very same
// address that `push` recorded in `start`, and to keep doing so when
// repeated.
#[tarantool::test]
fn region_buffer_tiny_allocation() {
// Outer guard: `into_raw_parts` below forgets the buffer's own guard, so
// this one is what truncates the region at the end of the test.
let _guard = RegionGuard::new();
let mut buffer = RegionBuffer::new();
buffer.push(&[1, 2, 3]).unwrap();
let data = unsafe { buffer.join().unwrap() };
assert_eq!(data, &[1, 2, 3]);
assert_eq!(data.as_ptr(), buffer.start);
// Joining a second time must give the same bytes at the same address.
let data2 = unsafe { buffer.join().unwrap() };
assert_eq!(data2, data);
assert_eq!(data2.as_ptr(), buffer.start);
let (data3, _) = buffer.into_raw_parts();
assert_eq!(data3, data);
assert_eq!(data3.as_ptr(), data.as_ptr());
}
// With a payload larger than `SLAB_SIZE`, the joined data is expected to live
// at a different address than the first pushed chunk.
#[tarantool::test]
fn region_buffer_big_allocation() {
// Dimensions of the generated input: N rows of M columns of K bytes each.
const N: usize = 4923;
const M: usize = 85;
const K: usize = 10;
// Compile-time check that the raw payload alone exceeds `SLAB_SIZE`.
const {
// NOTE(review): presumably the slab size of the region allocator — confirm.
const SLAB_SIZE: usize = u16::MAX as usize + 1;
assert!(N * M * K > SLAB_SIZE);
};
let t0 = std::time::Instant::now();
let mut input = Vec::with_capacity(N);
for i in 0..N {
let mut row = Vec::with_capacity(M);
for j in 0..M {
let mut col = Vec::with_capacity(K);
for k in 0..K {
// Deterministic filler value; the cast to `u8` truncates on purpose.
let v = (1 + i + j) * (1 + k);
col.push(v as u8);
}
row.push(col);
}
input.push(row);
}
tarantool::say_info!("generating data took {:?}", t0.elapsed());
let t0 = std::time::Instant::now();
let mut buffer = RegionBuffer::new();
rmp_serde::encode::write(&mut buffer, &input).unwrap();
let data = unsafe { buffer.join().unwrap() };
tarantool::say_info!(
"serializing data to region allocator took {:?}",
t0.elapsed()
);
// The joined slice must not start where the first chunk was written.
assert_ne!(data.as_ptr(), buffer.start);
let t0 = std::time::Instant::now();
// Reference encoding produced with the regular Rust allocator.
let control = rmp_serde::to_vec(&input).unwrap();
tarantool::say_info!("serializing data to rust allocator took {:?}", t0.elapsed());
assert_eq!(control, data);
}
}