use super::{
indexes::{Byte, Bytes},
util::range_move,
};
use core::mem::MaybeUninit;
use core::ptr;
use smallvec::SmallVec;
use spacetimedb_data_structures::slim_slice::SlimSmallSliceBox;
use spacetimedb_sats::layout::{
AlgebraicTypeLayout, HasLayout, PrimitiveType, ProductTypeElementLayout, ProductTypeLayoutView, RowTypeLayout,
SumTypeLayout, SumTypeVariantLayout,
};
use spacetimedb_sats::memory_usage::MemoryUsage;
/// A precomputed recipe for converting rows of a fixed-size row type between BFLATN and BSATN
/// using a short series of `memcpy`s.
///
/// Built by `StaticLayout::for_row_type`, which returns `None` for row types whose BSATN
/// encoding does not have a single constant length.
#[derive(PartialEq, Eq, Debug, Clone)]
// NOTE(review): the reason for forcing 8-byte alignment is not visible in this file — confirm before changing.
#[repr(align(8))]
pub struct StaticLayout {
    /// Length in bytes of the BSATN encoding of every row of this type;
    /// used to size output buffers up front.
    pub(crate) bsatn_length: u16,
    /// The `memcpy`s to perform, in increasing BSATN offset order, with zero-length entries removed.
    fields: SlimSmallSliceBox<MemcpyField, 3>,
}
impl MemoryUsage for StaticLayout {
    /// Heap bytes owned by this layout: whatever the `fields` box owns plus the
    /// (scalar) `bsatn_length`'s own reported usage.
    fn heap_usage(&self) -> usize {
        // Destructure exhaustively so that adding a field forces this impl to be revisited.
        let Self {
            bsatn_length: len,
            fields: memcpys,
        } = self;
        memcpys.heap_usage() + len.heap_usage()
    }
}
impl StaticLayout {
    /// Encodes the BFLATN `row` as BSATN into the front of `buf`,
    /// issuing one `memcpy` per entry of `self.fields`.
    ///
    /// Bytes of `buf` past `self.bsatn_length` are left untouched.
    ///
    /// # Safety
    ///
    /// - `buf.len() >= self.bsatn_length`.
    /// - `row` must hold a valid, initialized BFLATN row of the row type `self` was computed for,
    ///   so that `row[f.bflatn_offset .. f.bflatn_offset + f.length]` is in bounds and
    ///   initialized for every `f` in `self.fields`.
    unsafe fn serialize_row_into(&self, buf: &mut [MaybeUninit<Byte>], row: &Bytes) {
        debug_assert!(buf.len() >= self.bsatn_length as usize);
        self.fields.iter().for_each(|field| {
            // SAFETY: each field's bounds requirements follow directly from the caller's.
            unsafe { field.copy_bflatn_to_bsatn(row, buf) }
        });
    }

    /// Encodes the BFLATN `row` as BSATN into a freshly allocated, exactly-sized `Vec<u8>`.
    ///
    /// # Safety
    ///
    /// `row` must hold a valid, initialized BFLATN row of the row type `self` was computed for.
    pub(crate) unsafe fn serialize_row_into_vec(&self, row: &Bytes) -> Vec<u8> {
        let mut out = Vec::with_capacity(self.bsatn_length as usize);
        // SAFETY: `out` has spare capacity for at least `bsatn_length` bytes,
        // and the caller guarantees `row` is valid for `self`.
        unsafe { self.serialize_row_into(out.spare_capacity_mut(), row) };
        // SAFETY: the BSATN offsets of `self.fields` tile `0 .. bsatn_length` contiguously,
        // so the call above initialized exactly that prefix.
        unsafe { out.set_len(self.bsatn_length as usize) };
        out
    }

    /// Encodes the BFLATN `row` as BSATN, appending the bytes to the end of `buf`.
    ///
    /// # Safety
    ///
    /// `row` must hold a valid, initialized BFLATN row of the row type `self` was computed for.
    pub(crate) unsafe fn serialize_row_extend(&self, buf: &mut Vec<u8>, row: &Bytes) {
        let old_len = buf.len();
        let extra = self.bsatn_length as usize;
        buf.reserve(extra);
        // SAFETY: after `reserve`, the spare capacity holds at least `extra` slots,
        // and the caller guarantees `row` is valid for `self`.
        unsafe { self.serialize_row_into(&mut buf.spare_capacity_mut()[..extra], row) };
        // SAFETY: `0 .. old_len` was already initialized, and the call above
        // initialized `old_len .. old_len + extra`.
        unsafe { buf.set_len(old_len + extra) };
    }

    /// Decodes the BSATN `row` into the BFLATN buffer `buf`, writing only non-padding bytes.
    ///
    /// # Safety
    ///
    /// - `row.len() >= self.bsatn_length`.
    /// - `buf` must be large enough to hold a BFLATN row of the type `self` was computed for,
    ///   i.e. `f.bflatn_offset + f.length <= buf.len()` for every `f` in `self.fields`.
    // `allow(unused)`: presumably only the tests call this at present — confirm before removing.
    #[allow(unused)]
    pub(crate) unsafe fn deserialize_row_into(&self, buf: &mut Bytes, row: &[u8]) {
        self.fields.iter().for_each(|field| {
            // SAFETY: each field's bounds requirements follow directly from the caller's.
            unsafe { field.copy_bsatn_to_bflatn(row, buf) }
        });
    }

    /// Returns whether `row_a` and `row_b` agree on every byte covered by `self.fields`,
    /// i.e. on every non-padding byte. Stops at the first differing field.
    ///
    /// # Safety
    ///
    /// Both rows must hold valid, initialized BFLATN rows of the row type `self` was computed for.
    pub(crate) unsafe fn eq(&self, row_a: &Bytes, row_b: &Bytes) -> bool {
        for field in self.fields.iter() {
            // SAFETY: both rows are long enough to contain every field, per the caller.
            if !unsafe { field.eq(row_a, row_b) } {
                return false;
            }
        }
        true
    }

    /// Computes the static layout for `row_type`.
    ///
    /// Returns `None` when the row has variable-length components, or contains a sum type
    /// whose variants do not all share the same static layout.
    pub fn for_row_type(row_type: &RowTypeLayout) -> Option<Self> {
        // A non-fixed row can never have a static layout; skip the walk entirely.
        if !row_type.layout().fixed {
            return None;
        }
        let mut builder = LayoutBuilder::new_builder();
        builder.visit_product(row_type.product()).map(|()| builder.build())
    }
}
/// One contiguous run of bytes shared by the BFLATN and BSATN encodings of a row.
///
/// The run starts at `bflatn_offset` in the BFLATN row and at `bsatn_offset` in the BSATN
/// buffer, and is `length` bytes long in both. Runs are produced by `LayoutBuilder`, which
/// starts a new run wherever the BFLATN layout has padding, so a run never spans padding bytes.
#[derive(PartialEq, Eq, Debug, Copy, Clone)]
struct MemcpyField {
    /// Offset in the BFLATN row where the run begins, in bytes.
    bflatn_offset: u16,
    /// Offset in the BSATN buffer where the run begins, in bytes.
    bsatn_offset: u16,
    /// Number of bytes in the run.
    length: u16,
}
// Uses the trait's default method(s): `MemcpyField` is just three `u16`s and owns no heap data.
impl MemoryUsage for MemcpyField {}
impl MemcpyField {
    /// Copies `self.length` bytes from the BFLATN `src`, starting at `self.bflatn_offset`,
    /// into the BSATN `dst`, starting at `self.bsatn_offset`.
    ///
    /// # Safety
    ///
    /// 1. `src.len() >= self.bflatn_offset + self.length`.
    /// 2. `dst.len() >= self.bsatn_offset + self.length`.
    unsafe fn copy_bflatn_to_bsatn(&self, src: &Bytes, dst: &mut [MaybeUninit<Byte>]) {
        let count = self.length as usize;
        // SAFETY: by requirement 1, `bflatn_offset` lies within (or one past the end of) `src`.
        let from = unsafe { src.as_ptr().add(self.bflatn_offset as usize) };
        // SAFETY: by requirement 2, `bsatn_offset` lies within (or one past the end of) `dst`.
        let to = unsafe { dst.as_mut_ptr().add(self.bsatn_offset as usize) }.cast::<Byte>();
        // SAFETY: `from` is readable and `to` writable for `count` bytes by requirements 1 and 2;
        // byte alignment is trivial; a shared and an exclusive borrow cannot overlap.
        unsafe { ptr::copy_nonoverlapping(from, to, count) }
    }

    /// Copies `self.length` bytes from the BSATN `src`, starting at `self.bsatn_offset`,
    /// into the BFLATN `dst`, starting at `self.bflatn_offset`.
    ///
    /// # Safety
    ///
    /// 1. `src.len() >= self.bsatn_offset + self.length`.
    /// 2. `dst.len() >= self.bflatn_offset + self.length`.
    unsafe fn copy_bsatn_to_bflatn(&self, src: &Bytes, dst: &mut Bytes) {
        let count = self.length as usize;
        // SAFETY: by requirement 1, `bsatn_offset` lies within (or one past the end of) `src`.
        let from = unsafe { src.as_ptr().add(self.bsatn_offset as usize) };
        // SAFETY: by requirement 2, `bflatn_offset` lies within (or one past the end of) `dst`.
        let to = unsafe { dst.as_mut_ptr().add(self.bflatn_offset as usize) };
        // SAFETY: `from` is readable and `to` writable for `count` bytes by requirements 1 and 2;
        // byte alignment is trivial; a shared and an exclusive borrow cannot overlap.
        unsafe { ptr::copy_nonoverlapping(from, to, count) }
    }

    /// Returns whether `row_a` and `row_b` hold the same bytes in this field's BFLATN span.
    ///
    /// # Safety
    ///
    /// Both `row_a.len()` and `row_b.len()` must be at least `self.bflatn_offset + self.length`.
    unsafe fn eq(&self, row_a: &Bytes, row_b: &Bytes) -> bool {
        let span = range_move(0..self.length as usize, self.bflatn_offset as usize);
        // SAFETY: `span` ends at `bflatn_offset + length`, which both rows reach per the caller.
        let (lhs, rhs) = unsafe { (row_a.get_unchecked(span.clone()), row_b.get_unchecked(span)) };
        lhs == rhs
    }

    /// A zero-length field copies nothing; the builder uses these as placeholders.
    fn is_empty(&self) -> bool {
        matches!(self.length, 0)
    }
}
/// Incrementally computes a [`StaticLayout`] by walking a row type.
struct LayoutBuilder {
    /// The `memcpy`s discovered so far. Never empty: the last entry is the "current" field
    /// that the next contiguous bytes are appended to. May contain zero-length entries,
    /// which `build` discards.
    fields: Vec<MemcpyField>,
}
impl LayoutBuilder {
    /// Returns a builder seeded with one empty field at offset 0 in both encodings.
    ///
    /// The seed guarantees that `fields` is never empty,
    /// which `current_field` and `current_field_mut` rely on.
    fn new_builder() -> Self {
        Self {
            fields: vec![MemcpyField {
                bflatn_offset: 0,
                bsatn_offset: 0,
                length: 0,
            }],
        }
    }

    /// Consumes the builder, producing the final [`StaticLayout`].
    ///
    /// Zero-length fields (the seed and any padding markers that were never extended) are
    /// dropped. The BSATN length is where the last remaining field ends, or 0 if none remain.
    fn build(self) -> StaticLayout {
        let LayoutBuilder { fields } = self;
        let fields: SmallVec<[_; 3]> = fields.into_iter().filter(|field| !field.is_empty()).collect();
        let fields: SlimSmallSliceBox<MemcpyField, 3> = fields.into();
        // BSATN offsets are assigned contiguously, so the last field's end is the total length.
        let bsatn_length = fields.last().map(|last| last.bsatn_offset + last.length).unwrap_or(0);
        StaticLayout { bsatn_length, fields }
    }

    /// The field that newly visited contiguous bytes are appended to.
    fn current_field(&self) -> &MemcpyField {
        self.fields
            .last()
            .expect("`LayoutBuilder::fields` is seeded by `new_builder` and never shrinks")
    }

    /// Mutable access to the field that newly visited contiguous bytes are appended to.
    fn current_field_mut(&mut self) -> &mut MemcpyField {
        self.fields
            .last_mut()
            .expect("`LayoutBuilder::fields` is seeded by `new_builder` and never shrinks")
    }

    /// BFLATN offset just past the end of the current field.
    fn next_bflatn_offset(&self) -> u16 {
        let last = self.current_field();
        last.bflatn_offset + last.length
    }

    /// BSATN offset just past the end of the current field.
    fn next_bsatn_offset(&self) -> u16 {
        let last = self.current_field();
        last.bsatn_offset + last.length
    }

    /// Visits every element of `product`, whose first byte is at the current BFLATN position.
    ///
    /// Element offsets are relative to the product's start, so that start is captured once
    /// before any element is visited. Returns `None` if any element lacks a static layout.
    fn visit_product(&mut self, product: ProductTypeLayoutView) -> Option<()> {
        let base_bflatn_offset = self.next_bflatn_offset();
        for elt in product.elements.iter() {
            self.visit_product_element(elt, base_bflatn_offset)?;
        }
        Some(())
    }

    /// Visits a single product element located at BFLATN offset `product_base_offset + elt.offset`.
    ///
    /// If the element does not start exactly where the current field ends (there is padding in
    /// between), a new empty field is started at the element's offset so the padding is skipped.
    fn visit_product_element(&mut self, elt: &ProductTypeElementLayout, product_base_offset: u16) -> Option<()> {
        let elt_offset = product_base_offset + elt.offset;
        let next_bflatn_offset = self.next_bflatn_offset();
        if next_bflatn_offset != elt_offset {
            let bsatn_offset = self.next_bsatn_offset();
            self.fields.push(MemcpyField {
                bsatn_offset,
                bflatn_offset: elt_offset,
                length: 0,
            });
        }
        self.visit_value(&elt.ty)
    }

    /// Dispatches on the kind of `val`. Variable-length types have no static layout, so they yield `None`.
    fn visit_value(&mut self, val: &AlgebraicTypeLayout) -> Option<()> {
        match val {
            AlgebraicTypeLayout::Sum(sum) => self.visit_sum(sum),
            AlgebraicTypeLayout::Product(prod) => self.visit_product(prod.view()),
            AlgebraicTypeLayout::Primitive(prim) => {
                self.visit_primitive(prim);
                Some(())
            }
            AlgebraicTypeLayout::VarLen(_) => None,
        }
    }

    /// Visits a sum type whose tag byte is at the current BFLATN position.
    ///
    /// Returns `None` if the sum has no variants, or if its variants do not all share the same
    /// `StaticLayout` (the BSATN length would then depend on which variant is present).
    /// Payload-free sums encode as just the tag byte. Otherwise the tag is appended to the
    /// current field and the payload, laid out per the first variant, follows. A new field is
    /// started if there is padding between tag and payload.
    fn visit_sum(&mut self, sum: &SumTypeLayout) -> Option<()> {
        // No variants means the never type; there is nothing to lay out.
        let first_variant = sum.variants.first()?;
        // Lays out a variant's payload in isolation, relative to the payload's own start.
        let variant_layout = |variant: &SumTypeVariantLayout| {
            let mut builder = LayoutBuilder::new_builder();
            builder.visit_value(&variant.ty)?;
            Some(builder.build())
        };
        let first_variant_layout = variant_layout(first_variant)?;
        for later_variant in &sum.variants[1..] {
            let later_variant_layout = variant_layout(later_variant)?;
            if later_variant_layout != first_variant_layout {
                return None;
            }
        }
        if first_variant_layout.bsatn_length == 0 {
            // No variant carries data: the tag byte is the entire encoding of this value.
            self.visit_primitive(&PrimitiveType::U8);
            return Some(());
        }
        let tag_bflatn_offset = self.next_bflatn_offset();
        let payload_bflatn_offset = tag_bflatn_offset + sum.payload_offset;
        let tag_bsatn_offset = self.next_bsatn_offset();
        // BSATN has no padding: the payload immediately follows the 1-byte tag.
        let payload_bsatn_offset = tag_bsatn_offset + 1;
        // The tag is a single byte, appended to the current field.
        self.visit_primitive(&PrimitiveType::U8);
        if sum.payload_offset > 1 {
            // Padding between tag and payload in BFLATN: start a fresh field at the payload.
            self.fields.push(MemcpyField {
                bflatn_offset: payload_bflatn_offset,
                bsatn_offset: payload_bsatn_offset,
                length: 0,
            });
        }
        // All variants share one layout, so the first variant stands in for all of them.
        self.visit_value(&first_variant.ty)?;
        Some(())
    }

    /// Primitives are copied byte-for-byte, so this just extends the current field by `prim.size()`.
    fn visit_primitive(&mut self, prim: &PrimitiveType) {
        self.current_field_mut().length += prim.size() as u16
    }
}
// Tests for `StaticLayout`: hand-computed expected layouts for known types, types that must
// not get a static layout, and property tests comparing the fast paths against the slow paths.
#[cfg(test)]
mod test {
    use super::*;
    use crate::{blob_store::HashMapBlobStore, page_pool::PagePool};
    use proptest::prelude::*;
    use spacetimedb_sats::{bsatn, proptest::generate_typed_row, AlgebraicType, ProductType};

    /// Asserts that `StaticLayout::for_row_type` computes exactly the given layout for `ty`.
    ///
    /// `fields` lists the expected `MemcpyField`s as `(bflatn_offset, bsatn_offset, length)`
    /// triples. Panics if no static layout is computed at all.
    fn assert_expected_layout(ty: ProductType, bsatn_length: u16, fields: &[(u16, u16, u16)]) {
        let expected_layout = StaticLayout {
            bsatn_length,
            fields: fields
                .iter()
                .copied()
                .map(|(bflatn_offset, bsatn_offset, length)| MemcpyField {
                    bflatn_offset,
                    bsatn_offset,
                    length,
                })
                .collect::<SmallVec<_>>()
                .into(),
        };
        let row_type = RowTypeLayout::from(ty.clone());
        let Some(computed_layout) = StaticLayout::for_row_type(&row_type) else {
            panic!("assert_expected_layout: Computed `None` for row {row_type:#?}\nExpected:{expected_layout:#?}");
        };
        assert_eq!(
            computed_layout, expected_layout,
            "assert_expected_layout: Computed layout (left) doesn't match expected (right) for {ty:?}",
        );
    }

    /// A row with a single primitive column is one `memcpy` of the primitive's full size.
    #[test]
    fn known_types_expected_layout_plain() {
        for prim in [
            AlgebraicType::Bool,
            AlgebraicType::U8,
            AlgebraicType::I8,
            AlgebraicType::U16,
            AlgebraicType::I16,
            AlgebraicType::U32,
            AlgebraicType::I32,
            AlgebraicType::U64,
            AlgebraicType::I64,
            AlgebraicType::U128,
            AlgebraicType::I128,
            AlgebraicType::U256,
            AlgebraicType::I256,
        ] {
            let size = AlgebraicTypeLayout::from(prim.clone()).size() as u16;
            assert_expected_layout(ProductType::from([prim]), size, &[(0, 0, size)]);
        }
    }

    /// Hand-computed layouts for rows involving padding, sums, and nested products.
    /// Each entry is `(row type, expected bsatn_length, expected (bflatn, bsatn, len) fields)`.
    #[test]
    fn known_types_expected_layout_complex() {
        for (ty, bsatn_length, fields) in [
            // Empty row: nothing to copy.
            (ProductType::new([].into()), 0, &[][..]),
            // 1-byte payloads sit right after the tag: one 2-byte run.
            (
                ProductType::from([AlgebraicType::sum([
                    AlgebraicType::U8,
                    AlgebraicType::I8,
                    AlgebraicType::Bool,
                ])]),
                2,
                &[(0, 0, 2)][..],
            ),
            // All variants are 4 contiguous bytes; the BFLATN payload starts at offset 4,
            // so the tag and payload are separate runs.
            (
                ProductType::from([AlgebraicType::sum([
                    AlgebraicType::product([
                        AlgebraicType::U8,
                        AlgebraicType::U8,
                        AlgebraicType::U8,
                        AlgebraicType::U8,
                    ]),
                    AlgebraicType::product([AlgebraicType::U16, AlgebraicType::U16]),
                    AlgebraicType::U32,
                ])]),
                5,
                &[(0, 0, 1), (4, 1, 4)][..],
            ),
            // Tag, then a 16-byte payload at BFLATN offset 16; the following U32 is adjacent
            // to the payload and merges into the same run.
            (
                ProductType::from([
                    AlgebraicType::sum([AlgebraicType::U128, AlgebraicType::I128]),
                    AlgebraicType::U32,
                ]),
                21,
                &[(0, 0, 1), (16, 1, 20)][..],
            ),
            // As above, with a 32-byte payload at BFLATN offset 32.
            (
                ProductType::from([
                    AlgebraicType::sum([AlgebraicType::U256, AlgebraicType::I256]),
                    AlgebraicType::U32,
                ]),
                37,
                &[(0, 0, 1), (32, 1, 36)][..],
            ),
            // Descending sizes need no padding: a single run.
            (
                ProductType::from([
                    AlgebraicType::U256,
                    AlgebraicType::U128,
                    AlgebraicType::U64,
                    AlgebraicType::U32,
                    AlgebraicType::U16,
                    AlgebraicType::U8,
                ]),
                63,
                &[(0, 0, 63)][..],
            ),
            // Ascending sizes: one padding byte after the U8, then everything else is contiguous.
            (
                ProductType::from([
                    AlgebraicType::U8,
                    AlgebraicType::U16,
                    AlgebraicType::U32,
                    AlgebraicType::U64,
                    AlgebraicType::U128,
                ]),
                31,
                &[(0, 0, 1), (2, 1, 30)][..],
            ),
            // Payload-free sum with one variant: just the tag byte.
            (
                ProductType::from([AlgebraicType::sum([AlgebraicType::product::<[AlgebraicType; 0]>([])])]),
                1,
                &[(0, 0, 1)][..],
            ),
            // Payload-free sum with two variants: just the tag byte.
            (
                ProductType::from([AlgebraicType::sum([
                    AlgebraicType::product::<[AlgebraicType; 0]>([]),
                    AlgebraicType::product::<[AlgebraicType; 0]>([]),
                ])]),
                1,
                &[(0, 0, 1)][..],
            ),
            // `(U8, U8)` and `(Bool, Bool)` share a layout: tag plus 2 payload bytes in one run.
            (
                ProductType::from([AlgebraicType::sum([
                    AlgebraicType::product([AlgebraicType::U8, AlgebraicType::U8]),
                    AlgebraicType::product([AlgebraicType::Bool, AlgebraicType::Bool]),
                ])]),
                3,
                &[(0, 0, 3)][..],
            ),
            // Two adjacent 2-byte sums with no padding: one 4-byte run.
            (
                ProductType::from([
                    AlgebraicType::sum([AlgebraicType::Bool, AlgebraicType::U8]),
                    AlgebraicType::sum([AlgebraicType::U8, AlgebraicType::Bool]),
                ]),
                4,
                &[(0, 0, 4)][..],
            ),
            // U16, 2-byte sum, U16: no padding anywhere, one 6-byte run.
            (
                ProductType::from([
                    AlgebraicType::U16,
                    AlgebraicType::sum([AlgebraicType::U8, AlgebraicType::Bool]),
                    AlgebraicType::U16,
                ]),
                6,
                &[(0, 0, 6)][..],
            ),
            // U32 plus the sum tag form one run; a padding byte precedes the U16 payload at 6,
            // which merges with the trailing U32.
            (
                ProductType::from([
                    AlgebraicType::U32,
                    AlgebraicType::sum([AlgebraicType::U16, AlgebraicType::I16]),
                    AlgebraicType::U32,
                ]),
                11,
                &[(0, 0, 5), (6, 5, 6)][..],
            ),
        ] {
            assert_expected_layout(ty, bsatn_length, fields);
        }
    }

    /// Types that must not get a static layout: variable-length types, the never type
    /// (a sum with no variants), and a sum whose variants have different layouts.
    #[test]
    fn known_types_not_applicable() {
        for ty in [
            AlgebraicType::String,
            AlgebraicType::bytes(),
            AlgebraicType::never(),
            AlgebraicType::array(AlgebraicType::U16),
            AlgebraicType::sum([AlgebraicType::U8, AlgebraicType::U16]),
        ] {
            let layout = RowTypeLayout::from(ProductType::from([ty]));
            if let Some(computed) = StaticLayout::for_row_type(&layout) {
                panic!("Expected row type not to have a constant BSATN layout!\nRow type: {layout:#?}\nBSATN layout: {computed:#?}");
            }
        }
    }

    proptest! {
        // Many generated rows contain var-len types and are rejected, so allow plenty of rejects.
        #![proptest_config(ProptestConfig { max_global_rejects: 65536, ..Default::default()})]

        // For random rows with a static layout, both fast BFLATN -> BSATN paths
        // (`serialize_row_into_vec` and `serialize_row_extend`) must match `bsatn::to_vec`.
        #[test]
        fn known_bsatn_same_as_bflatn_from((ty, val) in generate_typed_row()) {
            let pool = PagePool::new_for_test();
            let mut blob_store = HashMapBlobStore::default();
            let mut table = crate::table::test::table(ty);
            let Some(static_layout) = table.static_layout().cloned() else {
                return Err(TestCaseError::reject("Var-length type"));
            };
            let (_, row_ref) = table.insert(&pool, &mut blob_store, &val).unwrap();
            let bytes = row_ref.get_row_data();
            let slow_path = bsatn::to_vec(&row_ref).unwrap();
            let fast_path = unsafe {
                static_layout.serialize_row_into_vec(bytes)
            };
            let mut fast_path2 = Vec::new();
            unsafe {
                static_layout.serialize_row_extend(&mut fast_path2, bytes)
            };
            assert_eq!(slow_path, fast_path);
            assert_eq!(slow_path, fast_path2);
        }

        // For random rows with a static layout, decoding the BSATN of `val` into a zeroed buffer
        // must reproduce the row bytes the table stored. Since only non-padding bytes are written,
        // this also implicitly expects the table to store padding bytes as zero.
        #[test]
        fn known_bflatn_same_as_pv_from((ty, val) in generate_typed_row()) {
            let pool = PagePool::new_for_test();
            let mut blob_store = HashMapBlobStore::default();
            let mut table = crate::table::test::table(ty);
            let Some(static_layout) = table.static_layout().cloned() else {
                return Err(TestCaseError::reject("Var-length type"));
            };
            let bsatn = bsatn::to_vec(&val).unwrap();
            let (_, row_ref) = table.insert(&pool, &mut blob_store, &val).unwrap();
            let slow_path = row_ref.get_row_data();
            let mut fast_path = vec![0u8; slow_path.len()];
            unsafe {
                static_layout.deserialize_row_into(&mut fast_path, &bsatn);
            };
            assert_eq!(slow_path, fast_path);
        }
    }
}