#![allow(unused)]
use super::static_layout::StaticLayout;
use itertools::{repeat_n, Itertools as _};
use spacetimedb_sats::bsatn::DecodeError;
use spacetimedb_sats::layout::{
AlgebraicTypeLayout, HasLayout as _, PrimitiveType, ProductTypeLayout, ProductTypeLayoutView, RowTypeLayout,
};
use spacetimedb_sats::memory_usage::MemoryUsage;
use std::sync::Arc;
pub(crate) fn static_bsatn_validator(ty: &RowTypeLayout) -> StaticBsatnValidator {
let tree = row_type_to_tree(ty.product());
let insns = tree_to_insns(&tree).into();
StaticBsatnValidator { insns }
}
/// Lowers the product type `ty` to a single [`Tree`] of validation checks.
///
/// Offsets recorded in the tree are counted from `0`.
fn row_type_to_tree(ty: ProductTypeLayoutView<'_>) -> Tree {
    let mut offset = 0;
    let mut checks = Vec::new();
    extend_trees_for_product_type(ty, &mut offset, &mut checks);
    sub_trees_to_tree(checks)
}
/// Collapses a list of sub-trees into one [`Tree`].
///
/// - No sub-trees become [`Tree::Empty`].
/// - A single sub-tree is returned as-is, without wrapping.
/// - Two or more sub-trees become a [`Tree::Sequence`].
fn sub_trees_to_tree(mut sub_trees: Vec<Tree>) -> Tree {
    if sub_trees.len() > 1 {
        return Tree::Sequence { sub_trees };
    }
    // At most one element is left: hand it out directly, or fall back to `Empty`.
    sub_trees.pop().unwrap_or(Tree::Empty)
}
/// Appends to `sub_trees` the checks for every element of the product type `ty`, in element order.
///
/// `current_offset` is advanced past each element as it is visited.
fn extend_trees_for_product_type(ty: ProductTypeLayoutView<'_>, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
    ty.elements
        .iter()
        .for_each(|elem| extend_trees_for_algebraic_type(&elem.ty, current_offset, sub_trees));
}
/// Appends to `sub_trees` the checks needed for a value of type `ty` located at `*current_offset`,
/// and advances `*current_offset` past that value.
///
/// - `Bool` emits a [`Tree::CheckBool`] and occupies one byte.
/// - Any other primitive emits no check and only advances the offset by `prim_ty.size()`.
/// - A product recurses into its elements.
/// - A sum occupies one tag byte followed by the variant data.
///   If all variants need the same checks, a [`Tree::CheckTag`] plus that shared tree is emitted;
///   otherwise a [`Tree::Sum`] is emitted that dispatches on the tag.
///
/// # Panics
///
/// Reaches `unreachable!()` on `AlgebraicTypeLayout::VarLen`.
fn extend_trees_for_algebraic_type(ty: &AlgebraicTypeLayout, current_offset: &mut usize, sub_trees: &mut Vec<Tree>) {
    match ty {
        // Must come before the general `Primitive` arm: `Bool` is the primitive that gets a check.
        AlgebraicTypeLayout::Primitive(PrimitiveType::Bool) => {
            let offset = *current_offset as u16;
            *current_offset += 1;
            sub_trees.push(Tree::CheckBool { offset });
        }
        AlgebraicTypeLayout::Primitive(prim_ty) => *current_offset += prim_ty.size(),
        AlgebraicTypeLayout::Product(prod_ty) => {
            extend_trees_for_product_type(prod_ty.view(), current_offset, sub_trees);
        }
        AlgebraicTypeLayout::Sum(sum_ty) => {
            // The tag sits at the current offset; the variant data starts right after it.
            let tag_offset = *current_offset as u16;
            // NOTE(review): `as u8` wraps to `0` for a sum with 256 variants — confirm such sums cannot reach here.
            let num_variants = sum_ty.variants.len() as u8;
            *current_offset += 1;
            let data_offset = *current_offset;

            // Every variant's data starts at `data_offset`, so the offset is reset before each variant.
            // After the loop, `end_offset` holds the offset reached by the last variant visited
            // (or `data_offset` when there are no variants).
            let mut end_offset = data_offset;
            let mut variants = Vec::with_capacity(sum_ty.variants.len());
            for variant in sum_ty.variants.iter() {
                end_offset = data_offset;
                let mut variant_checks = Vec::new();
                extend_trees_for_algebraic_type(&variant.ty, &mut end_offset, &mut variant_checks);
                variants.push(sub_trees_to_tree(variant_checks));
            }
            *current_offset = end_offset;

            if variants.iter().all_equal() {
                // All variants need identical checks, so there is no need to branch on the tag:
                // validate the tag, then run the shared checks once (if there is any variant at all).
                sub_trees.push(Tree::CheckTag {
                    tag_offset,
                    num_variants,
                });
                sub_trees.extend(variants.pop());
            } else {
                sub_trees.push(Tree::Sum {
                    tag_offset,
                    tag_data_processors: variants,
                });
            }
        }
        AlgebraicTypeLayout::VarLen(_) => unreachable!(),
    }
}
/// A tree-shaped intermediate representation of the checks needed to validate a row's bytes.
///
/// Built by `row_type_to_tree` and flattened into a list of `Insn`s by `tree_to_insns`.
#[derive(Debug, PartialEq, Eq)]
enum Tree {
/// No checks at all; compiles to no instructions.
Empty,
/// Checks that are compiled one after another, in order.
Sequence { sub_trees: Vec<Tree> },
/// Check that the byte at `offset` is a valid `bool`, i.e., `0` or `1`.
CheckBool { offset: u16 },
/// Check that the byte at `tag_offset` is a tag strictly less than `num_variants`.
CheckTag {
tag_offset: u16,
num_variants: u8,
},
/// Check the tag at `tag_offset` and then run the checks of the variant it selects.
/// `tag_data_processors[tag]` holds the checks for variant `tag`.
Sum {
tag_offset: u16,
tag_data_processors: Vec<Tree>,
},
}
/// Flattens `tree` into a linear program of [`Insn`]s.
///
/// A [`Tree::Sum`] is compiled to:
/// 1. an [`Insn::CheckReadTagRelBranch`],
/// 2. a jump table with one [`Insn::Goto`] per variant, pointing at that variant's code,
/// 3. each variant's code, each followed by an [`Insn::Goto`] to the first instruction after the sum.
///
/// Jumps whose target is not yet known are emitted as [`Insn::FIXUP`] and patched afterwards.
/// Finally, gotos at the very end of the program are stripped by [`remove_trailing_gotos`].
fn tree_to_insns(tree: &Tree) -> Vec<Insn> {
    /// Appends the instructions for `node` to `out`.
    fn emit(node: &Tree, out: &mut Vec<Insn>) {
        match node {
            Tree::Empty => {}
            Tree::Sequence { sub_trees } => {
                for child in sub_trees.iter() {
                    emit(child, out);
                }
            }
            Tree::CheckBool { offset } => out.push(Insn::CheckBool(*offset)),
            Tree::CheckTag {
                tag_offset,
                num_variants,
            } => out.push(Insn::CheckTag(CheckTag {
                tag_offset: *tag_offset,
                num_variants: *num_variants,
            })),
            Tree::Sum {
                tag_offset,
                tag_data_processors,
            } => {
                let variant_count = tag_data_processors.len();

                // The branching instruction, followed by one placeholder jump-table entry per variant.
                out.push(Insn::CheckReadTagRelBranch(CheckTag {
                    tag_offset: *tag_offset,
                    num_variants: variant_count as u8,
                }));
                let table_start = out.len();
                out.extend(repeat_n(Insn::FIXUP, variant_count));

                // Positions of the "leave the sum" jumps, one after each variant's code.
                let mut exit_slots = Vec::with_capacity(variant_count);
                for (tag, variant_tree) in tag_data_processors.iter().enumerate() {
                    // This variant's code starts here; patch its jump-table entry accordingly.
                    out[table_start + tag] = Insn::Goto(out.len() as u16);
                    emit(variant_tree, out);
                    // The target of the exit jump is only known once all variants are emitted.
                    exit_slots.push(out.len());
                    out.push(Insn::FIXUP);
                }

                // Everything after the sum starts here; patch all exit jumps to point at it.
                let exit = Insn::Goto(out.len() as u16);
                for slot in exit_slots {
                    out[slot] = exit;
                }
            }
        }
    }

    let mut program = Vec::new();
    emit(tree, &mut program);
    remove_trailing_gotos(&mut program);
    program
}
/// Removes every [`Insn::Goto`] at the end of `program`, stopping at the last non-goto instruction.
///
/// The interpreter in `validate_bsatn` stops as soon as the instruction pointer
/// is past the end of the program, so a jump to a removed slot simply ends validation.
fn remove_trailing_gotos(program: &mut Vec<Insn>) {
    let trailing = program
        .iter()
        .rev()
        .take_while(|insn| matches!(insn, Insn::Goto(_)))
        .count();
    program.truncate(program.len() - trailing);
}
/// Operands of a tag check: the tag byte at `tag_offset` must be strictly less than `num_variants`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CheckTag {
/// Offset of the tag byte within the bytes being validated.
tag_offset: u16,
/// Exclusive upper bound on valid tag values.
num_variants: u8,
}
/// One instruction of the validation program interpreted by `validate_bsatn`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Insn {
/// Check that the byte at the given offset is `0` or `1`, then continue with the next instruction.
CheckBool(u16),
/// Check that the tag is valid, then continue with the next instruction.
CheckTag(CheckTag),
/// Check that the tag is valid, then advance the instruction pointer by `tag + 1`,
/// i.e., to the `tag`-th of the instructions that immediately follow this one.
CheckReadTagRelBranch(CheckTag),
/// Set the instruction pointer to the given absolute instruction index.
Goto(u16),
}
impl Insn {
/// Placeholder jump emitted by `tree_to_insns` while the real target is not yet known;
/// it is overwritten with the actual `Goto` once the target address has been determined.
const FIXUP: Self = Self::Goto(u16::MAX);
}
// `Insn` is `Copy` and made up of inline integer fields only.
// No trait methods are overridden here, so the trait's default implementations apply.
impl MemoryUsage for Insn {}
/// A compiled validation program, as produced by `static_bsatn_validator`
/// and interpreted by `validate_bsatn`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StaticBsatnValidator {
/// The instructions of the program, in execution order starting at index `0`.
insns: Arc<[Insn]>,
}
impl MemoryUsage for StaticBsatnValidator {
fn heap_usage(&self) -> usize {
let Self { insns } = self;
insns.heap_usage()
}
}
unsafe fn check_tag(bytes: &[u8], check: CheckTag) -> Result<u8, DecodeError> {
let tag = *unsafe { bytes.get_unchecked(check.tag_offset as usize) };
if tag < check.num_variants {
Ok(tag)
} else {
Err(DecodeError::InvalidTag { tag, sum_name: None })
}
}
pub(crate) unsafe fn validate_bsatn(
program: &StaticBsatnValidator,
static_layout: &StaticLayout,
bytes: &[u8],
) -> Result<(), DecodeError> {
let expected = static_layout.bsatn_length as usize;
let given = bytes.len();
if expected != given {
return Err(DecodeError::InvalidLen { expected, given });
}
let program = &*program.insns;
let mut instr_ptr = 0;
loop {
match program.get(instr_ptr as usize).copied() {
None => break,
Some(Insn::CheckBool(offset)) => {
instr_ptr += 1;
let byte = *unsafe { bytes.get_unchecked(offset as usize) };
if byte > 1 {
return Err(DecodeError::InvalidBool(byte));
}
}
Some(Insn::Goto(new_insn)) => instr_ptr = new_insn,
Some(Insn::CheckTag(check)) => {
unsafe { check_tag(bytes, check) }?;
instr_ptr += 1;
}
Some(Insn::CheckReadTagRelBranch(check)) => {
let tag = unsafe { check_tag(bytes, check) }?;
instr_ptr += tag as u16 + 1;
}
}
}
Ok(())
}
// Property tests for the static BSATN validator.
#[cfg(test)]
pub mod test {
use super::*;
use crate::{
bflatn_to::write_row_to_page, blob_store::HashMapBlobStore, page::Page, row_type_visitor::row_type_visitor,
};
use proptest::{prelude::*, prop_assert_eq, proptest};
use spacetimedb_sats::bsatn::to_vec;
use spacetimedb_sats::proptest::generate_typed_row;
use spacetimedb_sats::{AlgebraicType, ProductType};
proptest! {
// Many generated row types are rejected below, hence the raised reject limit.
// Fewer cases are run under miri.
#![proptest_config(ProptestConfig {
max_global_rejects: 65536,
cases: if cfg!(miri) { 8 } else { 2048 },
..<_>::default()
})]
// For any generated row type that has a `StaticLayout`,
// validating the row's BSATN must succeed exactly when `write_row_to_page` succeeds for the same row.
#[test]
fn validation_same_as_write_row_to_pages((ty, val) in generate_typed_row()) {
let ty: RowTypeLayout = ty.into();
// Row types without a static layout are out of scope for the validator; reject the case.
let Some(static_layout) = StaticLayout::for_row_type(&ty) else {
return Err(TestCaseError::reject("Var-length type"));
};
let validator = static_bsatn_validator(&ty);
let bsatn = to_vec(&val).unwrap();
let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };
let mut page = Page::new(ty.size());
let visitor = row_type_visitor(&ty);
let blob_store = &mut HashMapBlobStore::default();
let res_write = unsafe { write_row_to_page(&mut page, blob_store, &visitor, &ty, &val) };
// Only agreement on success vs. failure is compared, not the concrete errors.
prop_assert_eq!(res_validate.is_ok(), res_write.is_ok());
}
// A single-`bool` row whose only byte is `2` or greater must be rejected with `InvalidBool`.
#[test]
fn bad_bool_validates_to_error(byte in 2u8..) {
let ty: RowTypeLayout = ProductType::from([AlgebraicType::Bool]).into();
let static_layout = StaticLayout::for_row_type(&ty).unwrap();
let validator = static_bsatn_validator(&ty);
let bsatn = [byte];
let res_validate = unsafe { validate_bsatn(&validator, &static_layout, &bsatn) };
prop_assert_eq!(res_validate, Err(DecodeError::InvalidBool(byte)));
}
}
}