use super::{
blob_store::BlobStore,
indexes::{Bytes, PageOffset, RowPointer, SquashedOffset},
page::{GranuleOffsetIter, Page, VarView},
page_pool::PagePool,
pages::Pages,
table::BlobNumBytes,
util::range_move,
var_len::{VarLenGranule, VarLenMembers, VarLenRef},
};
use spacetimedb_sats::{
bsatn::{self, to_writer, DecodeError},
buffer::BufWriter,
de::DeserializeSeed as _,
i256,
layout::{
align_to, AlgebraicTypeLayout, HasLayout, ProductTypeLayoutView, RowTypeLayout, SumTypeLayout, VarLenType,
},
u256, AlgebraicType, AlgebraicValue, ProductValue, SumValue,
};
use thiserror::Error;
/// Errors that can occur while writing a row into pages in BFLATN format.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum Error {
    /// Decoding a BSATN-encoded row failed.
    #[error(transparent)]
    Decode(#[from] DecodeError),
    /// The value did not match the type it was being written at.
    /// Holds the expected type followed by the offending value.
    #[error("Expected a value of type {0:?}, but found {1:?}")]
    WrongType(AlgebraicType, AlgebraicValue),
    /// An error propagated from a single-page operation, such as fixed-len or var-len allocation.
    #[error(transparent)]
    PageError(#[from] super::page::Error),
    /// An error propagated from the page manager while finding a page to insert into.
    #[error(transparent)]
    PagesError(#[from] super::pages::Error),
}
pub unsafe fn write_row_to_pages_bsatn(
pool: &PagePool,
pages: &mut Pages,
visitor: &impl VarLenMembers,
blob_store: &mut dyn BlobStore,
ty: &RowTypeLayout,
mut bytes: &[u8],
squashed_offset: SquashedOffset,
) -> Result<(RowPointer, BlobNumBytes), Error> {
let val = ty.product().deserialize(bsatn::Deserializer::new(&mut bytes))?;
unsafe { write_row_to_pages(pool, pages, visitor, blob_store, ty, &val, squashed_offset) }
}
/// Writes the row `val`, typed at `ty`, into some page of `pages` that has room for it.
/// Large var-len objects go into `blob_store`.
///
/// Returns the new row's pointer (tagged with `squashed_offset`)
/// and the number of bytes inserted into `blob_store`.
///
/// # Errors
///
/// - [`Error::PagesError`] if no suitable page could be obtained.
/// - Any error returned by [`write_row_to_page`], such as [`Error::WrongType`].
///
/// # Safety
///
/// NOTE(review): the exact contract is defined by `Pages::with_page_to_insert_row`,
/// `Page::alloc_fixed_len` and `VarLenMembers::visit_var_len` — confirm there.
/// From how the arguments are used here:
/// - `pages` must store rows whose fixed-length part has size `ty.size()`.
/// - `visitor` must visit the var-len members of `ty`
///   in the same order this module writes them.
pub unsafe fn write_row_to_pages(
    pool: &PagePool,
    pages: &mut Pages,
    visitor: &impl VarLenMembers,
    blob_store: &mut dyn BlobStore,
    ty: &RowTypeLayout,
    val: &ProductValue,
    squashed_offset: SquashedOffset,
) -> Result<(RowPointer, BlobNumBytes), Error> {
    let needed_granules = required_var_len_granules_for_row(val);

    let (target_page, written) = pages.with_page_to_insert_row(pool, ty.size(), needed_granules, |page| {
        // SAFETY: forwarded from this function's caller.
        unsafe { write_row_to_page(page, blob_store, visitor, ty, val) }
    })?;

    // Propagate a failure from inside the page write; otherwise build the pointer.
    let (row_offset, blob_bytes) = written?;
    Ok((
        RowPointer::new(false, target_page, row_offset, squashed_offset),
        blob_bytes,
    ))
}
/// Writes the row `val`, typed at `ty`, into `page`.
///
/// Allocates the fixed-length part first, then writes every field,
/// allocating var-len granules as it goes.
/// Large objects are only handed to `blob_store` after the whole row has been
/// written successfully.
///
/// On failure, every var-len allocation made so far is freed,
/// the fixed-length slot is released, and the error is returned.
///
/// Returns the row's offset within `page` and the number of bytes inserted into `blob_store`.
///
/// # Safety
///
/// NOTE(review): see `Page::alloc_fixed_len`, `FixedView::free` and
/// `VarLenMembers::visit_var_len` for the authoritative contract. As used here:
/// - `page` must store rows whose fixed-length part has size `ty.size()`.
/// - `visitor` must visit the var-len members of `ty`
///   in the order `write_product` writes them.
pub unsafe fn write_row_to_page(
    page: &mut Page,
    blob_store: &mut dyn BlobStore,
    visitor: &impl VarLenMembers,
    ty: &RowTypeLayout,
    val: &ProductValue,
) -> Result<(PageOffset, BlobNumBytes), Error> {
    let row_size = ty.size();

    // SAFETY: the caller guarantees `page` holds rows of `row_size` bytes.
    let row_offset = unsafe { page.alloc_fixed_len(row_size)? };

    let (mut fixed, var_view) = page.split_fixed_var_mut();
    let mut buf = BflatnSerializedRowBuffer {
        fixed_buf: fixed.get_row_mut(row_offset, row_size),
        curr_offset: 0,
        var_view,
        last_allocated_var_len_index: 0,
        large_blob_insertions: Vec::new(),
    };

    match buf.write_product(ty.product(), val) {
        Ok(()) => {
            // The row is fully in place, so blobs can now be committed to the store.
            // A failed write therefore never touches `blob_store`.
            let blob_bytes = buf.write_large_blobs(blob_store);
            Ok((row_offset, blob_bytes))
        }
        Err(err) => {
            // SAFETY: `visitor` matches `ty` per the caller's contract,
            // and `buf` tracks how many var-len refs were actually written.
            unsafe { buf.roll_back_var_len_allocations(visitor) };
            // SAFETY: `row_offset` was allocated just above with `row_size`
            // and has not been published anywhere.
            unsafe { fixed.free(row_offset, row_size) };
            Err(err)
        }
    }
}
/// Scratch state used while writing a single row into a page in BFLATN format.
struct BflatnSerializedRowBuffer<'page> {
    /// The fixed-length part of the row being written, borrowed from the page.
    fixed_buf: &'page mut Bytes,
    /// Offset within `fixed_buf` where the next fixed-length write lands.
    curr_offset: usize,
    /// How many `VarLenRef`s have been written into `fixed_buf` so far.
    /// Despite the name, this is a count.
    /// On failure, it tells the rollback how many var-len objects to free.
    last_allocated_var_len_index: usize,
    /// Large objects whose insertion into the blob store is deferred
    /// until the whole row has been written successfully.
    large_blob_insertions: Vec<(VarLenRef, Vec<u8>)>,
    /// The page's var-len section, used to allocate granules for strings and arrays.
    var_view: VarView<'page>,
}
impl BflatnSerializedRowBuffer<'_> {
unsafe fn roll_back_var_len_allocations(&mut self, visitor: &impl VarLenMembers) {
let visitor_iter = unsafe { visitor.visit_var_len(self.fixed_buf) };
for vlr in visitor_iter.take(self.last_allocated_var_len_index) {
unsafe { self.var_view.free_object_ignore_blob(*vlr) };
}
}
fn write_large_blobs(mut self, blob_store: &mut dyn BlobStore) -> BlobNumBytes {
let mut blob_store_inserted_bytes = BlobNumBytes::default();
for (vlr, value) in self.large_blob_insertions {
unsafe {
blob_store_inserted_bytes += self.var_view.write_large_blob_hash_to_granule(blob_store, &value, vlr);
}
}
blob_store_inserted_bytes
}
fn write_value(&mut self, ty: &AlgebraicTypeLayout, val: &AlgebraicValue) -> Result<(), Error> {
debug_assert_eq!(
self.curr_offset,
align_to(self.curr_offset, ty.align()),
"curr_offset {} insufficiently aligned for type {:#?}",
self.curr_offset,
val,
);
match (ty, val) {
(AlgebraicTypeLayout::Sum(ty), AlgebraicValue::Sum(val)) => self.write_sum(ty, val)?,
(AlgebraicTypeLayout::Product(ty), AlgebraicValue::Product(val)) => self.write_product(ty.view(), val)?,
(&AlgebraicTypeLayout::Bool, AlgebraicValue::Bool(val)) => self.write_bool(*val),
(&AlgebraicTypeLayout::I8, AlgebraicValue::I8(val)) => self.write_i8(*val),
(&AlgebraicTypeLayout::U8, AlgebraicValue::U8(val)) => self.write_u8(*val),
(&AlgebraicTypeLayout::I16, AlgebraicValue::I16(val)) => self.write_i16(*val),
(&AlgebraicTypeLayout::U16, AlgebraicValue::U16(val)) => self.write_u16(*val),
(&AlgebraicTypeLayout::I32, AlgebraicValue::I32(val)) => self.write_i32(*val),
(&AlgebraicTypeLayout::U32, AlgebraicValue::U32(val)) => self.write_u32(*val),
(&AlgebraicTypeLayout::I64, AlgebraicValue::I64(val)) => self.write_i64(*val),
(&AlgebraicTypeLayout::U64, AlgebraicValue::U64(val)) => self.write_u64(*val),
(&AlgebraicTypeLayout::I128, AlgebraicValue::I128(val)) => self.write_i128(val.0),
(&AlgebraicTypeLayout::U128, AlgebraicValue::U128(val)) => self.write_u128(val.0),
(&AlgebraicTypeLayout::I256, AlgebraicValue::I256(val)) => self.write_i256(**val),
(&AlgebraicTypeLayout::U256, AlgebraicValue::U256(val)) => self.write_u256(**val),
(&AlgebraicTypeLayout::F32, AlgebraicValue::F32(val)) => self.write_f32((*val).into()),
(&AlgebraicTypeLayout::F64, AlgebraicValue::F64(val)) => self.write_f64((*val).into()),
(&AlgebraicTypeLayout::String, AlgebraicValue::String(val)) => self.write_string(val)?,
(AlgebraicTypeLayout::VarLen(VarLenType::Array(_)), val @ AlgebraicValue::Array(_)) => {
self.write_av_bsatn(val)?
}
(ty, val) => Err(Error::WrongType(ty.algebraic_type(), val.clone()))?,
}
Ok(())
}
fn write_sum(&mut self, ty: &SumTypeLayout, val: &SumValue) -> Result<(), Error> {
let SumValue { tag, ref value } = *val;
let variant_ty = &ty.variants[tag as usize];
let variant_offset = self.curr_offset + ty.offset_of_variant_data(tag);
let tag_offset = self.curr_offset + ty.offset_of_tag();
self.curr_offset = variant_offset;
self.write_value(&variant_ty.ty, value)?;
self.curr_offset = tag_offset;
self.write_u8(tag);
Ok(())
}
fn write_product(&mut self, ty: ProductTypeLayoutView<'_>, val: &ProductValue) -> Result<(), Error> {
if ty.elements.len() != val.elements.len() {
return Err(Error::WrongType(
ty.algebraic_type(),
AlgebraicValue::Product(val.clone()),
));
}
let base_offset = self.curr_offset;
for (elt_ty, elt) in ty.elements.iter().zip(val.elements.iter()) {
self.curr_offset = base_offset + elt_ty.offset as usize;
self.write_value(&elt_ty.ty, elt)?;
}
Ok(())
}
fn write_string(&mut self, val: &str) -> Result<(), Error> {
let val = val.as_bytes();
let (vlr, in_blob) = self.var_view.alloc_for_slice(val)?;
if in_blob {
self.defer_insert_large_blob(vlr, val.to_vec());
}
self.write_var_len_ref(vlr);
Ok(())
}
fn write_av_bsatn(&mut self, val: &AlgebraicValue) -> Result<(), Error> {
let len_in_bytes = bsatn_len(val);
let (vlr, in_blob) = self.var_view.alloc_for_len(len_in_bytes)?;
self.write_var_len_ref(vlr);
if in_blob {
let mut bytes = Vec::with_capacity(len_in_bytes);
val.encode(&mut bytes);
self.defer_insert_large_blob(vlr, bytes);
} else {
let iter = unsafe { self.var_view.granule_offset_iter(vlr.first_granule) };
let mut writer = GranuleBufWriter { buf: None, iter };
to_writer(&mut writer, val).unwrap();
}
struct GranuleBufWriter<'vv, 'page> {
buf: Option<(PageOffset, usize)>,
iter: GranuleOffsetIter<'page, 'vv>,
}
impl BufWriter for GranuleBufWriter<'_, '_> {
fn put_slice(&mut self, mut slice: &[u8]) {
while !slice.is_empty() {
let (offset, start) = match self.buf.take() {
Some(buf @ (_, start)) if start < VarLenGranule::DATA_SIZE => buf,
_ => {
let next = self.iter.next();
debug_assert!(next.is_some());
let next = unsafe { next.unwrap_unchecked() };
(next, 0)
}
};
let capacity_remains = VarLenGranule::DATA_SIZE - start;
debug_assert!(capacity_remains > 0);
let extend_len = capacity_remains.min(slice.len());
let (extend_with, rest) = slice.split_at(extend_len);
let write_to = unsafe { self.iter.get_mut_data(offset, start) };
for (to, byte) in write_to.iter_mut().zip(extend_with) {
*to = *byte;
}
slice = rest;
self.buf = Some((offset, start + extend_len));
}
}
}
Ok(())
}
fn write_var_len_ref(&mut self, val: VarLenRef) {
self.write_u16(val.length_in_bytes);
self.write_u16(val.first_granule.0);
self.last_allocated_var_len_index += 1;
}
fn defer_insert_large_blob(&mut self, vlr: VarLenRef, obj_bytes: Vec<u8>) {
self.large_blob_insertions.push((vlr, obj_bytes));
}
fn write_bytes<const N: usize>(&mut self, bytes: &[u8; N]) {
self.fixed_buf[range_move(0..N, self.curr_offset)].copy_from_slice(bytes);
self.curr_offset += N;
}
fn write_u8(&mut self, val: u8) {
self.write_bytes(&[val]);
}
fn write_i8(&mut self, val: i8) {
self.write_u8(val as u8);
}
fn write_bool(&mut self, val: bool) {
self.write_u8(val as u8);
}
fn write_u16(&mut self, val: u16) {
self.write_bytes(&val.to_le_bytes());
}
fn write_i16(&mut self, val: i16) {
self.write_bytes(&val.to_le_bytes());
}
fn write_u32(&mut self, val: u32) {
self.write_bytes(&val.to_le_bytes());
}
fn write_i32(&mut self, val: i32) {
self.write_bytes(&val.to_le_bytes());
}
fn write_u64(&mut self, val: u64) {
self.write_bytes(&val.to_le_bytes());
}
fn write_i64(&mut self, val: i64) {
self.write_bytes(&val.to_le_bytes());
}
fn write_u128(&mut self, val: u128) {
self.write_bytes(&val.to_le_bytes());
}
fn write_i128(&mut self, val: i128) {
self.write_bytes(&val.to_le_bytes());
}
fn write_u256(&mut self, val: u256) {
self.write_bytes(&val.to_le_bytes());
}
fn write_i256(&mut self, val: i256) {
self.write_bytes(&val.to_le_bytes());
}
fn write_f32(&mut self, val: f32) {
self.write_bytes(&val.to_le_bytes());
}
fn write_f64(&mut self, val: f64) {
self.write_bytes(&val.to_le_bytes());
}
}
/// Counts how many var-len granules `val` needs, so a page with enough free granules
/// can be chosen before writing.
///
/// Strings contribute their byte length. Arrays contribute their BSATN-encoded length.
/// Each length is converted to granules by `VarLenGranule::bytes_to_granules`.
/// Products and sums are traversed recursively; all other values need no granules.
fn required_var_len_granules_for_row(val: &ProductValue) -> usize {
    /// Granules needed by all elements of a product.
    fn granules_in_product(prod: &ProductValue) -> usize {
        prod.elements.iter().map(granules_in_value).sum()
    }

    /// Granules needed by a single value.
    fn granules_in_value(av: &AlgebraicValue) -> usize {
        match av {
            AlgebraicValue::Product(prod) => granules_in_product(prod),
            AlgebraicValue::Sum(sum) => granules_in_value(&sum.value),
            AlgebraicValue::Array(_) => granules_for_len(bsatn_len(av)),
            AlgebraicValue::String(s) => granules_for_len(s.len()),
            _ => 0,
        }
    }

    /// Granules needed to store a byte string of the given length.
    fn granules_for_len(len_in_bytes: usize) -> usize {
        VarLenGranule::bytes_to_granules(len_in_bytes).0
    }

    granules_in_product(val)
}
/// Returns the length in bytes of the BSATN encoding of `val`.
///
/// # Panics
///
/// Panics if computing the BSATN length fails. Callers in this module also treat
/// BSATN-encoding an `AlgebraicValue` as infallible (`write_av_bsatn` unwraps
/// `to_writer`), so a failure here indicates a bug rather than bad input.
fn bsatn_len(val: &AlgebraicValue) -> usize {
    bsatn::to_len(val).expect("computing the BSATN length of an `AlgebraicValue` should not fail")
}
// Round-trip tests: write a generated row into a page, then read it back.
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::{
        bflatn_from::serialize_row_from_page, blob_store::HashMapBlobStore, page::tests::hash_unmodified_save_get,
        row_type_visitor::row_type_visitor,
    };
    use proptest::{prelude::*, prop_assert_eq, proptest};
    use spacetimedb_sats::algebraic_value::ser::ValueSerializer;
    use spacetimedb_sats::proptest::generate_typed_row;

    proptest! {
        // Far fewer cases under Miri, which is much slower than native execution.
        #![proptest_config(ProptestConfig::with_cases(if cfg!(miri) { 8 } else { 2048 }))]
        // Writes a random typed row with `write_row_to_page`, reads it back with
        // `serialize_row_from_page`, and checks that the value is unchanged.
        // It also checks that writing changed the page's "unmodified" hash,
        // and that reading did not change it.
        #[test]
        fn av_serde_round_trip_through_page((ty, val) in generate_typed_row()) {
            let ty: RowTypeLayout = ty.into();
            let mut page = Page::new(ty.size());
            let visitor = row_type_visitor(&ty);
            let blob_store = &mut HashMapBlobStore::default();

            let hash_pre_ins = hash_unmodified_save_get(&mut page);

            let (offset, _) = unsafe { write_row_to_page(&mut page, blob_store, &visitor, &ty, &val).unwrap() };

            // Inserting a row must change the page's hash.
            let hash_pre_ser = hash_unmodified_save_get(&mut page);
            assert_ne!(hash_pre_ins, hash_pre_ser);

            let read_val = unsafe { serialize_row_from_page(ValueSerializer, &page, blob_store, offset, &ty) }
                .unwrap().into_product().unwrap();

            prop_assert_eq!(val, read_val);
            // Reading the row back must leave the page's hash unchanged.
            assert_eq!(hash_pre_ser, *page.unmodified_hash().unwrap());
        }
    }
}